diff --git a/internal/session/compaction_cycle_test.go b/internal/session/compaction_cycle_test.go
new file mode 100644
index 00000000..83cecefe
--- /dev/null
+++ b/internal/session/compaction_cycle_test.go
@@ -0,0 +1,90 @@
+package session
+
+import (
+	"testing"
+
+	"github.com/mark3labs/kit/internal/message"
+)
+
+// TestCompactionParentCycleRegression tests that after multiple compactions,
+// newly appended messages always have a valid parent chain and BuildContext
+// returns the expected number of messages.
+func TestCompactionParentCycleRegression(t *testing.T) {
+	tm := InMemoryTreeSession("/test")
+
+	// mustAppend fails the test at the offending call instead of letting an
+	// empty ID from a rejected append flow into the later assertions.
+	mustAppend := func(msg message.Message) string {
+		t.Helper()
+		id, err := tm.AppendMessage(msg)
+		if err != nil {
+			t.Fatalf("AppendMessage failed: %v", err)
+		}
+		return id
+	}
+	// mustCompact does the same for compaction entries.
+	mustCompact := func(summary, firstKeptEntryID string) string {
+		t.Helper()
+		id, err := tm.AppendCompaction(summary, firstKeptEntryID, 1000, 500, 1, []string{}, []string{})
+		if err != nil {
+			t.Fatalf("AppendCompaction(%q) failed: %v", summary, err)
+		}
+		return id
+	}
+
+	// Simulate a long conversation with multiple compactions.
+	msg1 := mustAppend(message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "msg1"}}})
+	msg2 := mustAppend(message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "msg2"}}})
+
+	// First compaction
+	comp1 := mustCompact("Summary 1", msg1)
+
+	msg3 := mustAppend(message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "msg3"}}})
+	msg4 := mustAppend(message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "msg4"}}})
+
+	// Second compaction
+	comp2 := mustCompact("Summary 2", msg3)
+
+	msg5 := mustAppend(message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "msg5"}}})
+	msg6 := mustAppend(message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "msg6"}}})
+
+	// Every appended entry must be retrievable by ID.
+	for _, id := range []string{msg1, msg2, comp1, msg3, msg4, comp2, msg5, msg6} {
+		entry := tm.GetEntry(id)
+		if entry == nil {
+			t.Fatalf("entry %s not found in index", id)
+		}
+	}
+
+	// Walk parent chain from msg6 — must reach root without cycles
+	visited := make(map[string]bool)
+	current := msg6
+	for current != "" {
+		if visited[current] {
+			t.Fatalf("cycle detected at entry %s", current)
+		}
+		visited[current] = true
+		entry := tm.GetEntry(current)
+		if entry == nil {
+			t.Fatalf("entry %s missing from index during parent walk", current)
+		}
+		parent := ""
+		switch e := entry.(type) {
+		case *MessageEntry:
+			parent = e.ParentID
+		case *CompactionEntry:
+			parent = e.ParentID
+		default:
+			// Only messages and compactions were appended above.
+			t.Fatalf("unexpected entry type %T at %s", entry, current)
+		}
+		current = parent
+	}
+
+	// BuildContext should return the "Summary 2" compaction summary plus
+	// msg3, msg4, msg5 and msg6 = 5 messages.
+	msgs, _, _ := tm.BuildContext()
+	if len(msgs) != 5 {
+		t.Fatalf("expected 5 messages, got %d: %+v", len(msgs), msgs)
+	}
+}
diff --git a/internal/session/tree_cycle_test.go b/internal/session/tree_cycle_test.go
new file mode 100644
index 00000000..78b2db6a
--- /dev/null
+++ b/internal/session/tree_cycle_test.go
@@ -0,0 +1,141 @@
+package session
+
+import (
+	"testing"
+
+	"github.com/mark3labs/kit/internal/message"
+)
+
+// TestDetectCycleWithCorruptedParentChain tests that cycle detection works
+// when a corrupted session has circular parent references.
+func TestDetectCycleWithCorruptedParentChain(t *testing.T) {
+	tm := InMemoryTreeSession("/test")
+
+	// mustAppend fails the test at the offending call instead of letting an
+	// empty ID from a rejected append flow into the later assertions.
+	mustAppend := func(msg message.Message) string {
+		t.Helper()
+		id, err := tm.AppendMessage(msg)
+		if err != nil {
+			t.Fatalf("AppendMessage failed: %v", err)
+		}
+		return id
+	}
+
+	// Create normal chain: msg1 -> msg2 -> msg3
+	id1 := mustAppend(message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "msg1"}}})
+	id2 := mustAppend(message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "msg2"}}})
+	id3 := mustAppend(message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "msg3"}}})
+
+	// Simulate corruption: manually set msg1's parent to msg3, creating cycle
+	// This simulates the condition seen in the user's session
+	corrupted := false
+	for _, entry := range tm.entries {
+		if e, ok := entry.(*MessageEntry); ok && e.ID == id1 {
+			e.ParentID = id3 // Create cycle: msg1 -> msg3 -> ... -> msg1
+			corrupted = true
+			break
+		}
+	}
+	if !corrupted {
+		t.Fatalf("entry %s not found in tm.entries; cycle was not injected", id1)
+	}
+
+	// Walking parents from id3 goes id3 -> id2 -> id1 and then arrives back at
+	// id3, so id3 is the first entry seen twice and the one DetectCycle reports.
+	cycle, entry := tm.DetectCycle(id3)
+	if !cycle {
+		t.Fatal("expected to detect cycle, but none found")
+	}
+	if entry != id3 {
+		t.Fatalf("expected cycle at %s, got %s (id1=%s, id2=%s)", id3, entry, id1, id2)
+	}
+
+	// BuildContext should still work (it has its own cycle detection)
+	// but will truncate at the cycle point
+	msgs, _, _ := tm.BuildContext()
+	if len(msgs) == 0 {
+		t.Fatal("BuildContext returned no messages")
+	}
+}
+
+// TestAppendMessageRejectsInvalidParent tests that AppendMessage rejects
+// appending when the current leaf has a broken parent chain.
+func TestAppendMessageRejectsInvalidParent(t *testing.T) {
+	tm := InMemoryTreeSession("/test")
+
+	// Create normal message
+	id1, err := tm.AppendMessage(message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "msg1"}}})
+	if err != nil {
+		t.Fatalf("failed to append msg1: %v", err)
+	}
+
+	// Simulate corruption: set leafID to a non-existent ID
+	tm.leafID = "non-existent-id"
+
+	// Next append should fail validation
+	_, err = tm.AppendMessage(message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "msg2"}}})
+	if err == nil {
+		t.Fatal("expected error when appending with invalid leafID, got nil")
+	}
+
+	// Restore valid leafID
+	tm.leafID = id1
+
+	// Append should succeed now
+	_, err = tm.AppendMessage(message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "msg3"}}})
+	if err != nil {
+		t.Fatalf("failed to append msg3 after restoring leafID: %v", err)
+	}
+}
+
+// TestBuildContextHandlesCycleGracefully tests that BuildContext handles
+// cycles gracefully by truncating the branch.
+func TestBuildContextHandlesCycleGracefully(t *testing.T) {
+	tm := InMemoryTreeSession("/test")
+
+	// mustAppend fails the test at the offending call instead of letting an
+	// empty ID from a rejected append flow into the later assertions.
+	mustAppend := func(msg message.Message) string {
+		t.Helper()
+		id, err := tm.AppendMessage(msg)
+		if err != nil {
+			t.Fatalf("AppendMessage failed: %v", err)
+		}
+		return id
+	}
+
+	// Create messages
+	id1 := mustAppend(message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "msg1"}}})
+	mustAppend(message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "msg2"}}})
+	id3 := mustAppend(message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "msg3"}}})
+
+	// Verify normal case works
+	msgs, _, _ := tm.BuildContext()
+	if len(msgs) != 3 {
+		t.Fatalf("expected 3 messages, got %d", len(msgs))
+	}
+
+	// Simulate cycle: set msg1's parent to msg3
+	corrupted := false
+	for _, entry := range tm.entries {
+		if e, ok := entry.(*MessageEntry); ok && e.ID == id1 {
+			e.ParentID = id3
+			corrupted = true
+			break
+		}
+	}
+	// The expected count is 3 both before and after the corruption, so fail
+	// loudly if the cycle was never injected rather than passing vacuously.
+	if !corrupted {
+		t.Fatalf("entry %s not found in tm.entries; cycle was not injected", id1)
+	}
+
+	// BuildContext should handle cycle gracefully (getBranchLocked has cycle detection)
+	msgs, _, _ = tm.BuildContext()
+	// Should only include messages from the cycle: msg3, msg2, msg1
+	// (msg3 is leaf, walks to msg2 -> msg1 -> msg3 (cycle detected, stops))
+	if len(msgs) != 3 {
+		t.Fatalf("expected 3 messages in cycle case, got %d: %+v", len(msgs), msgs)
+	}
+}
diff --git a/internal/session/tree_manager.go b/internal/session/tree_manager.go
index 4d90215c..7d2a8b9c 100644
--- a/internal/session/tree_manager.go
+++ b/internal/session/tree_manager.go
@@ -365,6 +365,9 @@ func OpenTreeSession(path string) (*TreeManager, error) {
 		tm.leafID = tm.EntryID(tm.entries[len(tm.entries)-1])
 	}
 
+	// Validate tree integrity and log diagnostics
+	tm.LogTreeDiagnostics()
+
 	// Open file for appending.
f, err := os.OpenFile(path, os.O_WRONLY|os.O_APPEND, 0644) if err != nil { @@ -410,6 +413,12 @@ func (tm *TreeManager) AppendMessage(msg message.Message) (string, error) { tm.mu.Lock() defer tm.mu.Unlock() + // Validate parent chain before appending to detect/prevent cycles + // that could be caused by external file corruption or race conditions. + if err := tm.validateParentChainLocked(tm.leafID, ""); err != nil { + return "", fmt.Errorf("parent chain validation failed: %w", err) + } + entry, err := NewMessageEntry(tm.leafID, msg) if err != nil { return "", err @@ -518,6 +527,13 @@ func (tm *TreeManager) AppendCompaction(summary, firstKeptEntryID string, tokens tm.mu.Lock() defer tm.mu.Unlock() + // Validate that firstKeptEntryID exists if provided + if firstKeptEntryID != "" { + if _, ok := tm.index[firstKeptEntryID]; !ok { + return "", fmt.Errorf("first kept entry %q does not exist", firstKeptEntryID) + } + } + // The compaction entry has no parent, making it a new "root" for the // post-compaction branch. This ensures old compacted messages are not // traversed when walking from the current leaf. @@ -1213,12 +1229,32 @@ func (tm *TreeManager) getBranchLocked(fromID string) []any { } // buildTreeNode recursively builds a TreeNode from an entry ID. +// It includes a depth limit to prevent infinite recursion in case of +// corrupted parent-child relationships. func (tm *TreeManager) buildTreeNode(id string) *TreeNode { + return tm.buildTreeNodeDepth(id, 0, make(map[string]bool)) +} + +// buildTreeNodeDepth is the internal implementation with depth tracking. 
+func (tm *TreeManager) buildTreeNodeDepth(id string, depth int, visited map[string]bool) *TreeNode { + const maxDepth = 1000 + if depth > maxDepth { + // Cycle or extremely deep tree detected, stop recursing + return nil + } + if visited[id] { + // Cycle detected, stop recursing + return nil + } + entry, ok := tm.index[id] if !ok { return nil } + visited[id] = true + defer delete(visited, id) + node := &TreeNode{ Entry: entry, ID: id, @@ -1226,7 +1262,7 @@ func (tm *TreeManager) buildTreeNode(id string) *TreeNode { } for _, childID := range tm.childIndex[id] { - child := tm.buildTreeNode(childID) + child := tm.buildTreeNodeDepth(childID, depth+1, visited) if child != nil { node.Children = append(node.Children, child) } diff --git a/internal/session/tree_validation.go b/internal/session/tree_validation.go new file mode 100644 index 00000000..898cdbd4 --- /dev/null +++ b/internal/session/tree_validation.go @@ -0,0 +1,143 @@ +package session + +import ( + "fmt" + "log" +) + +// ValidateParentChain checks that the parent ID points to an existing entry +// and that appending this entry would not create a cycle. This should be called +// before appending any entry to the tree. +// Returns an error if the parent is invalid or would create a cycle. 
+func (tm *TreeManager) ValidateParentChain(parentID string, newEntryID string) error {
+	if parentID == "" {
+		// Empty parent is valid (root entry)
+		return nil
+	}
+
+	// This exported method reads tm.index, so take the read lock the same way
+	// LogTreeDiagnostics does. Callers that already hold tm.mu must use
+	// validateParentChainLocked instead.
+	tm.mu.RLock()
+	defer tm.mu.RUnlock()
+
+	// Check that parent exists
+	if _, ok := tm.index[parentID]; !ok {
+		return fmt.Errorf("parent entry %q does not exist in index", parentID)
+	}
+
+	// Walk up from parentID to the root. The walk fails if an entry repeats
+	// (the stored chain already loops), if newEntryID shows up among the
+	// ancestors, or if a parent ID is missing from the index.
+	visited := make(map[string]bool)
+	current := parentID
+	for current != "" {
+		if visited[current] {
+			return fmt.Errorf("existing cycle detected at entry %q", current)
+		}
+		visited[current] = true
+
+		// Safety check: if somehow we reach the new entry ID, that's a cycle
+		if current == newEntryID {
+			return fmt.Errorf("would create cycle: entry %q cannot be its own ancestor", newEntryID)
+		}
+
+		entry, ok := tm.index[current]
+		if !ok {
+			return fmt.Errorf("broken parent chain: entry %q not found", current)
+		}
+		current = tm.entryParentID(entry)
+	}
+
+	return nil
+}
+
+// DetectCycle walks the parent chain from the given entry ID and reports
+// whether an entry is reached twice. This is used for diagnostics.
+//
+// On a cycle, cycleEntry is the first entry visited a second time (fromID
+// itself when fromID lies on the cycle). A chain that ends at a root, or at an
+// ID missing from the index, yields (false, "").
+// It takes the read lock itself; code already holding tm.mu should call
+// detectCycleLocked instead.
+func (tm *TreeManager) DetectCycle(fromID string) (cycleDetected bool, cycleEntry string) {
+	tm.mu.RLock()
+	defer tm.mu.RUnlock()
+	return tm.detectCycleLocked(fromID)
+}
+
+// LogTreeDiagnostics logs information about the tree structure for debugging.
+// Call this after OpenTreeSession or when anomalies are detected.
+func (tm *TreeManager) LogTreeDiagnostics() {
+	tm.mu.RLock()
+	defer tm.mu.RUnlock()
+
+	log.Printf("[TreeManager] Entry count: %d, Leaf ID: %s", len(tm.entries), tm.leafID)
+
+	// Report a loop in the parent chain that starts at the current leaf.
+	if leaf := tm.leafID; leaf != "" {
+		looped, at := tm.detectCycleLocked(leaf)
+		if looped {
+			log.Printf("[TreeManager] WARNING: Cycle detected in tree at entry %s", at)
+		}
+	}
+
+	// Tally entries per EntryType; entry kinds not listed below are counted
+	// under "unknown".
+	tally := map[EntryType]int{}
+	for _, item := range tm.entries {
+		kind := EntryType("unknown")
+		switch v := item.(type) {
+		case *CompactionEntry:
+			kind = v.Type
+		case *ExtensionDataEntry:
+			kind = v.Type
+		case *SessionInfoEntry:
+			kind = v.Type
+		case *LabelEntry:
+			kind = v.Type
+		case *BranchSummaryEntry:
+			kind = v.Type
+		case *ModelChangeEntry:
+			kind = v.Type
+		case *MessageEntry:
+			kind = v.Type
+		}
+		tally[kind]++
+	}
+	log.Printf("[TreeManager] Entry types: %+v", tally)
+}
+
+// detectCycleLocked follows parent links from fromID and reports the first
+// entry reached twice, if any. The caller must already hold tm.mu.
+func (tm *TreeManager) detectCycleLocked(fromID string) (bool, string) {
+	seen := make(map[string]bool)
+	for id := fromID; id != ""; {
+		if seen[id] {
+			return true, id
+		}
+		seen[id] = true
+		node, found := tm.index[id]
+		if !found {
+			return false, ""
+		}
+		id = tm.entryParentID(node)
+	}
+	return false, ""
+}
+
+// validateParentChainLocked is the check append methods run before attaching
+// under parentID; newEntryID is not consulted. The caller must hold tm.mu.
+func (tm *TreeManager) validateParentChainLocked(parentID string, newEntryID string) error {
+	if parentID == "" {
+		return nil
+	}
+	if _, known := tm.index[parentID]; !known {
+		return fmt.Errorf("parent entry %q does not exist", parentID)
+	}
+	looped, at := tm.detectCycleLocked(parentID)
+	if looped {
+		return fmt.Errorf("existing cycle detected at entry %q in parent chain", at)
+	}
+	return nil
+}