diff --git a/internal/session/compaction_test.go b/internal/session/compaction_test.go index 59baf320..0762c784 100644 --- a/internal/session/compaction_test.go +++ b/internal/session/compaction_test.go @@ -129,26 +129,35 @@ func TestCompactionWithNewMessagesAfterCompaction(t *testing.T) { msg4 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Message 4 - after compaction"}}} _, _ = tm.AppendMessage(msg4) - // BuildContext should return: [summary] + [M4 (new after compaction)] + [M3 (kept)] + // BuildContext should return: [summary] + [M3 (kept)] + [M4 (new after compaction)] + // Kept messages must appear BEFORE post-compaction messages so the LLM + // sees the conversation in chronological order. Otherwise the latest + // post-compaction user message would be followed by an older kept user + // message, breaking user/assistant alternation and causing the model to + // respond as if the post-compaction turn never happened. messages, _, _ := tm.BuildContext() if len(messages) != 3 { - t.Fatalf("expected 3 messages (summary + M4 + M3), got %d: %+v", len(messages), messages) + t.Fatalf("expected 3 messages (summary + M3 + M4), got %d: %+v", len(messages), messages) } - // Verify order: summary, M4 (new), M3 (kept) + // Verify order: summary, M3 (kept), M4 (new) if messages[0].Role != fantasy.MessageRoleSystem { t.Errorf("first message should be summary, got %s", messages[0].Role) } - if messages[1].Role != fantasy.MessageRoleAssistant { - t.Errorf("second message should be assistant (M4), got %s", messages[1].Role) + if messages[1].Role != fantasy.MessageRoleUser { + t.Errorf("second message should be user (M3 kept), got %s", messages[1].Role) } - m4Text := messages[1].Content[0].(fantasy.TextPart).Text + m3Text := messages[1].Content[0].(fantasy.TextPart).Text + if m3Text != "Message 3 - kept" { + t.Errorf("unexpected M3 text: %s", m3Text) + } + if messages[2].Role != fantasy.MessageRoleAssistant { + t.Errorf("third message should be assistant (M4 post-compact), got %s", messages[2].Role) + } + m4Text := messages[2].Content[0].(fantasy.TextPart).Text if m4Text != "Message 4 - after compaction" { t.Errorf("unexpected M4 text: %s", m4Text) } - if messages[2].Role != fantasy.MessageRoleUser { - t.Errorf("third message should be user (M3), got %s", messages[2].Role) - } // Verify that M1 is NOT in the context for i, msg := range messages { diff --git a/internal/session/tree_manager.go b/internal/session/tree_manager.go index b3959846..aa4b76cf 100644 --- a/internal/session/tree_manager.go +++ b/internal/session/tree_manager.go @@ -755,9 +755,17 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri } } - // If there is a compaction, inject the summary first and collect - // the kept messages starting from FirstKeptEntryID (since the - // compaction entry's parent chain doesn't include them). + // If there is a compaction, inject the summary first, then the + // preserved "kept" messages (chronologically before the compaction), + // then the post-compaction messages (chronologically after). + // + // Order matters: the kept messages must come BEFORE the post-compaction + // branch so the LLM sees the conversation in chronological order. If the + // kept messages were appended last, the latest user message in the + // current branch would be followed by an older kept user message, + // breaking the strict user/assistant alternation that providers expect + // and causing the model to respond as if the previous turn never + // happened. if lastCompaction != nil { messages = append(messages, fantasy.Message{ Role: fantasy.MessageRoleSystem, @@ -768,49 +776,10 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri }, }) - // Collect entries from the compaction entry itself (at compactionIndex) - // and any entries before it in the branch (newer messages). - for i := compactionIndex; i < len(branch); i++ { - entry := branch[i] - switch e := entry.(type) { - case *MessageEntry: - msg, err := e.ToMessage() - if err != nil { - continue // skip malformed entries - } - msgs := msg.ToLLMMessages() - messages = append(messages, msgs...) - - case *BranchSummaryEntry: - // Convert branch summary to a user message for context. - if e.Summary != "" { - messages = append(messages, fantasy.Message{ - Role: fantasy.MessageRoleUser, - Content: []fantasy.MessagePart{ - fantasy.TextPart{ - Text: fmt.Sprintf("[Branch context: %s]", e.Summary), - }, - }, - }) - } - - case *ModelChangeEntry: - provider = e.Provider - modelID = e.ModelID - - case *CompactionEntry: - // Already handled above (summary injected). - continue - } - } - - // Now collect the kept messages starting from FirstKeptEntryID. - // These are not in the current branch because the compaction entry - // is parented to the first kept entry's parent, not the first kept entry. - // We iterate through entries in order (not using getBranchLocked) to avoid - // walking back to old compacted messages. - // We stop when we reach the compaction entry to avoid double-counting - // messages that were added after the compaction. + // Step 1: collect the kept messages starting from FirstKeptEntryID. + // These are not on the current branch (the compaction entry is a + // new root with no parent), so we iterate tm.entries in append order + // and stop when we reach the compaction entry itself. if lastCompaction.FirstKeptEntryID != "" { found := false for _, entry := range tm.entries { @@ -825,13 +794,12 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri } } - // Stop when we reach the compaction entry itself. - // Messages after the compaction are collected from the branch walk above. + // Stop when we reach the compaction entry itself; messages + // after it are collected from the branch walk below. if entryID == lastCompaction.ID { break } - // Process this kept entry. switch e := entry.(type) { case *MessageEntry: msg, err := e.ToMessage() @@ -860,6 +828,42 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri } } + // Step 2: collect entries on the current branch after the compaction + // entry (these are post-compaction messages). The compaction entry + // itself is skipped — its summary was already injected above. + for i := compactionIndex; i < len(branch); i++ { + entry := branch[i] + switch e := entry.(type) { + case *MessageEntry: + msg, err := e.ToMessage() + if err != nil { + continue + } + msgs := msg.ToLLMMessages() + messages = append(messages, msgs...) + + case *BranchSummaryEntry: + if e.Summary != "" { + messages = append(messages, fantasy.Message{ + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{ + fantasy.TextPart{ + Text: fmt.Sprintf("[Branch context: %s]", e.Summary), + }, + }, + }) + } + + case *ModelChangeEntry: + provider = e.Provider + modelID = e.ModelID + + case *CompactionEntry: + // Summary already injected above. + continue + } + } + return messages, provider, modelID } @@ -1030,44 +1034,22 @@ func (tm *TreeManager) GetContextEntryIDs() []string { var ids []string - // If there's a compaction, we need to collect IDs from: - // 1. Entries after the compaction entry in the branch (newer messages) - // 2. Entries from FirstKeptEntryID onwards (kept messages) + // If there's a compaction, we collect IDs in the same order as + // BuildContext: [summary placeholder, kept messages, post-compaction + // messages]. This ordering must stay in sync with BuildContext so a + // cut-point index can be mapped back to the correct entry ID. if lastCompaction != nil { // Placeholder for the summary system message (no entry ID). ids = append(ids, "") - // Collect IDs from entries after the compaction entry (newer messages). - for i := compactionIndex + 1; i < len(branch); i++ { - entry := branch[i] - switch e := entry.(type) { - case *MessageEntry: - msg, err := e.ToMessage() - if err != nil { - continue - } - msgs := msg.ToLLMMessages() - for range msgs { - ids = append(ids, e.ID) - } - - case *BranchSummaryEntry: - if e.Summary != "" { - ids = append(ids, e.ID) - } - } - } - - // Collect IDs from the kept messages starting at FirstKeptEntryID. - // We iterate through entries in order (not using getBranchLocked) to avoid - // walking back to old compacted messages. - // We stop when we reach the compaction entry to avoid double-counting. + // Step 1: IDs of the kept messages starting at FirstKeptEntryID. + // Iterate tm.entries in append order and stop at the compaction + // entry to avoid double-counting post-compaction messages. if lastCompaction.FirstKeptEntryID != "" { found := false for _, entry := range tm.entries { entryID := tm.EntryID(entry) - // Skip entries until we reach the first kept entry. if !found { if entryID == lastCompaction.FirstKeptEntryID { found = true @@ -1076,7 +1058,6 @@ func (tm *TreeManager) GetContextEntryIDs() []string { } } - // Stop when we reach the compaction entry itself. if entryID == lastCompaction.ID { break } @@ -1100,6 +1081,28 @@ func (tm *TreeManager) GetContextEntryIDs() []string { } } + // Step 2: IDs of entries after the compaction entry on the current + // branch (post-compaction messages). + for i := compactionIndex + 1; i < len(branch); i++ { + entry := branch[i] + switch e := entry.(type) { + case *MessageEntry: + msg, err := e.ToMessage() + if err != nil { + continue + } + msgs := msg.ToLLMMessages() + for range msgs { + ids = append(ids, e.ID) + } + + case *BranchSummaryEntry: + if e.Summary != "" { + ids = append(ids, e.ID) + } + } + } + return ids }