mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 03:30:26 +00:00
Implement saved sessions (#82)
This commit is contained in:
+224
-18
@@ -2,6 +2,7 @@ package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
@@ -13,6 +14,7 @@ import (
|
||||
"github.com/mark3labs/mcphost/internal/auth"
|
||||
"github.com/mark3labs/mcphost/internal/config"
|
||||
"github.com/mark3labs/mcphost/internal/models"
|
||||
"github.com/mark3labs/mcphost/internal/session"
|
||||
"github.com/mark3labs/mcphost/internal/tokens"
|
||||
"github.com/mark3labs/mcphost/internal/ui"
|
||||
"github.com/spf13/cobra"
|
||||
@@ -34,6 +36,10 @@ var (
|
||||
maxSteps int
|
||||
scriptMCPConfig *config.Config // Used to override config in script mode
|
||||
|
||||
// Session management
|
||||
saveSessionPath string
|
||||
loadSessionPath string
|
||||
|
||||
// Model generation parameters
|
||||
maxTokens int
|
||||
temperature float32
|
||||
@@ -69,6 +75,11 @@ Examples:
|
||||
mcphost -p "What is the weather like today?"
|
||||
mcphost -p "Calculate 15 * 23" --quiet
|
||||
|
||||
# Session management
|
||||
mcphost --save-session ./my-session.json -p "Hello"
|
||||
mcphost --load-session ./my-session.json -p "Continue our conversation"
|
||||
mcphost --load-session ./session.json --save-session ./session.json -p "Next message"
|
||||
|
||||
# Script mode
|
||||
mcphost script myscript.sh`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
@@ -155,6 +166,12 @@ func init() {
|
||||
rootCmd.PersistentFlags().
|
||||
IntVar(&maxSteps, "max-steps", 0, "maximum number of agent steps (0 for unlimited)")
|
||||
|
||||
// Session management flags
|
||||
rootCmd.PersistentFlags().
|
||||
StringVar(&saveSessionPath, "save-session", "", "save session to file after each message")
|
||||
rootCmd.PersistentFlags().
|
||||
StringVar(&loadSessionPath, "load-session", "", "load session from file at startup")
|
||||
|
||||
flags := rootCmd.PersistentFlags()
|
||||
flags.StringVar(&providerURL, "provider-url", "", "base URL for the provider API (applies to OpenAI, Anthropic, Ollama, and Google)")
|
||||
flags.StringVar(&providerAPIKey, "provider-api-key", "", "API key for the provider (applies to OpenAI, Anthropic, and Google)")
|
||||
@@ -380,10 +397,120 @@ func runNormalMode(ctx context.Context) error {
|
||||
|
||||
// Main interaction logic
|
||||
var messages []*schema.Message
|
||||
var sessionManager *session.Manager
|
||||
|
||||
// Load existing session if specified
|
||||
if loadSessionPath != "" {
|
||||
loadedSession, err := session.LoadFromFile(loadSessionPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load session: %v", err)
|
||||
}
|
||||
|
||||
// Convert session messages to schema messages
|
||||
for _, msg := range loadedSession.Messages {
|
||||
messages = append(messages, msg.ConvertToSchemaMessage())
|
||||
}
|
||||
|
||||
// If we're also saving, use the loaded session with the session manager
|
||||
if saveSessionPath != "" {
|
||||
sessionManager = session.NewManagerWithSession(loadedSession, saveSessionPath)
|
||||
}
|
||||
|
||||
if !quietFlag && cli != nil {
|
||||
// Create a map of tool call IDs to tool calls for quick lookup
|
||||
toolCallMap := make(map[string]session.ToolCall)
|
||||
for _, sessionMsg := range loadedSession.Messages {
|
||||
if sessionMsg.Role == "assistant" && len(sessionMsg.ToolCalls) > 0 {
|
||||
for _, tc := range sessionMsg.ToolCalls {
|
||||
toolCallMap[tc.ID] = tc
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Display all previous messages as they would have appeared
|
||||
for _, sessionMsg := range loadedSession.Messages {
|
||||
if sessionMsg.Role == "user" {
|
||||
cli.DisplayUserMessage(sessionMsg.Content)
|
||||
} else if sessionMsg.Role == "assistant" {
|
||||
// Display tool calls if present
|
||||
if len(sessionMsg.ToolCalls) > 0 {
|
||||
for _, tc := range sessionMsg.ToolCalls {
|
||||
// Convert arguments to string
|
||||
var argsStr string
|
||||
if argBytes, err := json.Marshal(tc.Arguments); err == nil {
|
||||
argsStr = string(argBytes)
|
||||
}
|
||||
|
||||
// Display tool call
|
||||
cli.DisplayToolCallMessage(tc.Name, argsStr)
|
||||
}
|
||||
}
|
||||
|
||||
// Display assistant response (only if there's content)
|
||||
if sessionMsg.Content != "" {
|
||||
cli.DisplayAssistantMessage(sessionMsg.Content)
|
||||
}
|
||||
} else if sessionMsg.Role == "tool" {
|
||||
// Display tool result
|
||||
if sessionMsg.ToolCallID != "" {
|
||||
if toolCall, exists := toolCallMap[sessionMsg.ToolCallID]; exists {
|
||||
// Convert arguments to string
|
||||
var argsStr string
|
||||
if argBytes, err := json.Marshal(toolCall.Arguments); err == nil {
|
||||
argsStr = string(argBytes)
|
||||
}
|
||||
|
||||
// Parse tool result content - it might be JSON-encoded MCP content
|
||||
resultContent := sessionMsg.Content
|
||||
|
||||
// Try to parse as MCP content structure
|
||||
var mcpContent struct {
|
||||
Content []struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
} `json:"content"`
|
||||
}
|
||||
|
||||
// First try to unmarshal as-is
|
||||
if err := json.Unmarshal([]byte(sessionMsg.Content), &mcpContent); err == nil {
|
||||
// Extract text from MCP content structure
|
||||
if len(mcpContent.Content) > 0 && mcpContent.Content[0].Type == "text" {
|
||||
resultContent = mcpContent.Content[0].Text
|
||||
}
|
||||
} else {
|
||||
// If that fails, try unquoting first (in case it's double-encoded)
|
||||
var unquoted string
|
||||
if err := json.Unmarshal([]byte(sessionMsg.Content), &unquoted); err == nil {
|
||||
if err := json.Unmarshal([]byte(unquoted), &mcpContent); err == nil {
|
||||
if len(mcpContent.Content) > 0 && mcpContent.Content[0].Type == "text" {
|
||||
resultContent = mcpContent.Content[0].Text
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Display tool result (assuming no error for saved results)
|
||||
cli.DisplayToolMessage(toolCall.Name, argsStr, resultContent, false)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if saveSessionPath != "" {
|
||||
// Only saving, create new session manager
|
||||
sessionManager = session.NewManager(saveSessionPath)
|
||||
|
||||
// Set metadata
|
||||
sessionManager.SetMetadata(session.Metadata{
|
||||
MCPHostVersion: "dev", // TODO: Get actual version
|
||||
Provider: parts[0],
|
||||
Model: modelName,
|
||||
})
|
||||
}
|
||||
|
||||
// Check if running in non-interactive mode
|
||||
if promptFlag != "" {
|
||||
return runNonInteractiveMode(ctx, mcpAgent, cli, promptFlag, modelName, messages, quietFlag, noExitFlag, mcpConfig)
|
||||
return runNonInteractiveMode(ctx, mcpAgent, cli, promptFlag, modelName, messages, quietFlag, noExitFlag, mcpConfig, sessionManager)
|
||||
}
|
||||
|
||||
// Quiet mode is not allowed in interactive mode
|
||||
@@ -391,7 +518,7 @@ func runNormalMode(ctx context.Context) error {
|
||||
return fmt.Errorf("--quiet flag can only be used with --prompt/-p")
|
||||
}
|
||||
|
||||
return runInteractiveMode(ctx, mcpAgent, cli, serverNames, toolNames, modelName, messages)
|
||||
return runInteractiveMode(ctx, mcpAgent, cli, serverNames, toolNames, modelName, messages, sessionManager)
|
||||
}
|
||||
|
||||
// AgenticLoopConfig configures the behavior of the unified agentic loop
|
||||
@@ -405,10 +532,11 @@ type AgenticLoopConfig struct {
|
||||
Quiet bool // suppress all output except final response
|
||||
|
||||
// Context data
|
||||
ServerNames []string // for slash commands
|
||||
ToolNames []string // for slash commands
|
||||
ModelName string // for display
|
||||
MCPConfig *config.Config // for continuing to interactive mode
|
||||
ServerNames []string // for slash commands
|
||||
ToolNames []string // for slash commands
|
||||
ModelName string // for display
|
||||
MCPConfig *config.Config // for continuing to interactive mode
|
||||
SessionManager *session.Manager // for session persistence
|
||||
}
|
||||
|
||||
// runAgenticLoop handles all execution modes with a single unified loop
|
||||
@@ -424,7 +552,7 @@ func runAgenticLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes
|
||||
tempMessages := append(messages, schema.UserMessage(config.InitialPrompt))
|
||||
|
||||
// Process the initial prompt with tool calls
|
||||
response, err := runAgenticStep(ctx, mcpAgent, cli, tempMessages, config)
|
||||
response, conversationMessages, err := runAgenticStep(ctx, mcpAgent, cli, tempMessages, config)
|
||||
if err != nil {
|
||||
// Check if this was a user cancellation
|
||||
if err.Error() == "generation cancelled by user" && cli != nil {
|
||||
@@ -437,9 +565,25 @@ func runAgenticLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes
|
||||
}
|
||||
} else {
|
||||
// Only add to history after successful completion
|
||||
messages = append(messages, schema.UserMessage(config.InitialPrompt))
|
||||
userMsg := schema.UserMessage(config.InitialPrompt)
|
||||
messages = append(messages, userMsg)
|
||||
messages = append(messages, response)
|
||||
|
||||
// Save to session if session manager is available
|
||||
if config.SessionManager != nil {
|
||||
// Simple approach: save the entire conversation history
|
||||
// This includes the user message + all generated messages
|
||||
allMessages := append([]*schema.Message{userMsg}, conversationMessages...)
|
||||
|
||||
// Clear the session and save the complete history
|
||||
if err := config.SessionManager.ReplaceAllMessages(allMessages); err != nil {
|
||||
// Log error but don't fail the operation
|
||||
if cli != nil && !config.Quiet {
|
||||
cli.DisplayError(fmt.Errorf("failed to save conversation to session: %v", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If not continuing to interactive mode, exit here
|
||||
if !config.ContinueAfterRun {
|
||||
return nil
|
||||
@@ -459,7 +603,7 @@ func runAgenticLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes
|
||||
}
|
||||
|
||||
// runAgenticStep processes a single step of the agentic loop (handles tool calls)
|
||||
func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, messages []*schema.Message, config AgenticLoopConfig) (*schema.Message, error) {
|
||||
func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, messages []*schema.Message, config AgenticLoopConfig) (*schema.Message, []*schema.Message, error) {
|
||||
var currentSpinner *ui.Spinner
|
||||
|
||||
// Start initial spinner (skip if quiet)
|
||||
@@ -468,7 +612,7 @@ func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes
|
||||
currentSpinner.Start()
|
||||
}
|
||||
|
||||
response, err := mcpAgent.GenerateWithLoop(ctx, messages,
|
||||
result, err := mcpAgent.GenerateWithLoop(ctx, messages,
|
||||
// Tool call handler - called when a tool is about to be executed
|
||||
func(toolName, toolArgs string) {
|
||||
if !config.Quiet && cli != nil {
|
||||
@@ -499,7 +643,36 @@ func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes
|
||||
// Tool result handler - called when a tool execution completes
|
||||
func(toolName, toolArgs, result string, isError bool) {
|
||||
if !config.Quiet && cli != nil {
|
||||
cli.DisplayToolMessage(toolName, toolArgs, result, isError)
|
||||
// Parse tool result content - it might be JSON-encoded MCP content
|
||||
resultContent := result
|
||||
|
||||
// Try to parse as MCP content structure
|
||||
var mcpContent struct {
|
||||
Content []struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
} `json:"content"`
|
||||
}
|
||||
|
||||
// First try to unmarshal as-is
|
||||
if err := json.Unmarshal([]byte(result), &mcpContent); err == nil {
|
||||
// Extract text from MCP content structure
|
||||
if len(mcpContent.Content) > 0 && mcpContent.Content[0].Type == "text" {
|
||||
resultContent = mcpContent.Content[0].Text
|
||||
}
|
||||
} else {
|
||||
// If that fails, try unquoting first (in case it's double-encoded)
|
||||
var unquoted string
|
||||
if err := json.Unmarshal([]byte(result), &unquoted); err == nil {
|
||||
if err := json.Unmarshal([]byte(unquoted), &mcpContent); err == nil {
|
||||
if len(mcpContent.Content) > 0 && mcpContent.Content[0].Type == "text" {
|
||||
resultContent = mcpContent.Content[0].Text
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cli.DisplayToolMessage(toolName, toolArgs, resultContent, isError)
|
||||
// Start spinner again for next LLM call
|
||||
currentSpinner = ui.NewSpinner("Thinking...")
|
||||
currentSpinner.Start()
|
||||
@@ -540,14 +713,18 @@ func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes
|
||||
if !config.Quiet && cli != nil {
|
||||
cli.DisplayError(fmt.Errorf("agent error: %v", err))
|
||||
}
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Get the final response and conversation messages
|
||||
response := result.FinalResponse
|
||||
conversationMessages := result.ConversationMessages
|
||||
|
||||
// Display assistant response with model name (skip if quiet)
|
||||
if !config.Quiet && cli != nil {
|
||||
if err := cli.DisplayAssistantMessageWithModel(response.Content, config.ModelName); err != nil {
|
||||
cli.DisplayError(fmt.Errorf("display error: %v", err))
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Update usage tracking with the last user message and response
|
||||
@@ -567,7 +744,8 @@ func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes
|
||||
fmt.Print(response.Content)
|
||||
}
|
||||
|
||||
return response, nil
|
||||
// Return the final response and all conversation messages
|
||||
return response, conversationMessages, nil
|
||||
}
|
||||
|
||||
// runInteractiveLoop handles the interactive portion of the agentic loop
|
||||
@@ -603,7 +781,7 @@ func runInteractiveLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI,
|
||||
tempMessages := append(messages, schema.UserMessage(prompt))
|
||||
|
||||
// Process the user input with tool calls
|
||||
response, err := runAgenticStep(ctx, mcpAgent, cli, tempMessages, config)
|
||||
response, conversationMessages, err := runAgenticStep(ctx, mcpAgent, cli, tempMessages, config)
|
||||
if err != nil {
|
||||
// Check if this was a user cancellation
|
||||
if err.Error() == "generation cancelled by user" {
|
||||
@@ -615,13 +793,39 @@ func runInteractiveLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI,
|
||||
}
|
||||
|
||||
// Only add to history after successful completion
|
||||
messages = append(messages, schema.UserMessage(prompt))
|
||||
userMsg := schema.UserMessage(prompt)
|
||||
messages = append(messages, userMsg)
|
||||
messages = append(messages, response)
|
||||
|
||||
// Save to session if session manager is available
|
||||
if config.SessionManager != nil {
|
||||
if err := config.SessionManager.AddMessage(userMsg); err != nil {
|
||||
// Log error but don't fail the operation
|
||||
cli.DisplayError(fmt.Errorf("failed to save user message to session: %v", err))
|
||||
}
|
||||
|
||||
// Save all conversation messages (includes tool calls and results)
|
||||
// Find the messages that were generated during this conversation
|
||||
if len(conversationMessages) > len(tempMessages) {
|
||||
// Extract only the new messages generated during this step
|
||||
newMessages := conversationMessages[len(tempMessages):]
|
||||
if err := config.SessionManager.AddMessages(newMessages); err != nil {
|
||||
// Log error but don't fail the operation
|
||||
cli.DisplayError(fmt.Errorf("failed to save conversation messages to session: %v", err))
|
||||
}
|
||||
} else {
|
||||
// No tool calls, just save the final response
|
||||
if err := config.SessionManager.AddMessage(response); err != nil {
|
||||
// Log error but don't fail the operation
|
||||
cli.DisplayError(fmt.Errorf("failed to save assistant message to session: %v", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// runNonInteractiveMode handles the non-interactive mode execution
|
||||
func runNonInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, prompt, modelName string, messages []*schema.Message, quiet, noExit bool, mcpConfig *config.Config) error {
|
||||
func runNonInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, prompt, modelName string, messages []*schema.Message, quiet, noExit bool, mcpConfig *config.Config, sessionManager *session.Manager) error {
|
||||
// Prepare data for slash commands (needed if continuing to interactive mode)
|
||||
var serverNames []string
|
||||
for name := range mcpConfig.MCPServers {
|
||||
@@ -646,13 +850,14 @@ func runNonInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.C
|
||||
ToolNames: toolNames,
|
||||
ModelName: modelName,
|
||||
MCPConfig: mcpConfig,
|
||||
SessionManager: sessionManager,
|
||||
}
|
||||
|
||||
return runAgenticLoop(ctx, mcpAgent, cli, messages, config)
|
||||
}
|
||||
|
||||
// runInteractiveMode handles the interactive mode execution
|
||||
func runInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, serverNames, toolNames []string, modelName string, messages []*schema.Message) error {
|
||||
func runInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, serverNames, toolNames []string, modelName string, messages []*schema.Message, sessionManager *session.Manager) error {
|
||||
// Configure and run unified agentic loop
|
||||
config := AgenticLoopConfig{
|
||||
IsInteractive: true,
|
||||
@@ -663,6 +868,7 @@ func runInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI,
|
||||
ToolNames: toolNames,
|
||||
ModelName: modelName,
|
||||
MCPConfig: nil, // Not needed for pure interactive mode
|
||||
SessionManager: sessionManager,
|
||||
}
|
||||
|
||||
return runAgenticLoop(ctx, mcpAgent, cli, messages, config)
|
||||
|
||||
+17
-3
@@ -69,9 +69,15 @@ func NewAgent(ctx context.Context, config *AgentConfig) (*Agent, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GenerateWithLoopResult contains the result and conversation history
|
||||
type GenerateWithLoopResult struct {
|
||||
FinalResponse *schema.Message
|
||||
ConversationMessages []*schema.Message // All messages in the conversation (including tool calls and results)
|
||||
}
|
||||
|
||||
// GenerateWithLoop processes messages with a custom loop that displays tool calls in real-time
|
||||
func (a *Agent) GenerateWithLoop(ctx context.Context, messages []*schema.Message,
|
||||
onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler, onResponse ResponseHandler, onToolCallContent ToolCallContentHandler) (*schema.Message, error) {
|
||||
onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler, onResponse ResponseHandler, onToolCallContent ToolCallContentHandler) (*GenerateWithLoopResult, error) {
|
||||
|
||||
// Create a copy of messages to avoid modifying the original
|
||||
workingMessages := make([]*schema.Message, len(messages))
|
||||
@@ -127,6 +133,7 @@ func (a *Agent) GenerateWithLoop(ctx context.Context, messages []*schema.Message
|
||||
|
||||
// Check if this is a tool call or final response
|
||||
if len(response.ToolCalls) > 0 {
|
||||
|
||||
// Display any content that accompanies the tool calls
|
||||
if response.Content != "" && onToolCallContent != nil {
|
||||
onToolCallContent(response.Content)
|
||||
@@ -193,12 +200,19 @@ func (a *Agent) GenerateWithLoop(ctx context.Context, messages []*schema.Message
|
||||
if onResponse != nil && response.Content != "" {
|
||||
onResponse(response.Content)
|
||||
}
|
||||
return response, nil
|
||||
return &GenerateWithLoopResult{
|
||||
FinalResponse: response,
|
||||
ConversationMessages: workingMessages,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// If we reach here, we've exceeded max steps
|
||||
return schema.AssistantMessage("Maximum number of steps reached.", nil), nil
|
||||
finalResponse := schema.AssistantMessage("Maximum number of steps reached.", nil)
|
||||
return &GenerateWithLoopResult{
|
||||
FinalResponse: finalResponse,
|
||||
ConversationMessages: workingMessages,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetTools returns the list of available tools
|
||||
|
||||
@@ -0,0 +1,151 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// Manager manages session state and auto-saving
|
||||
type Manager struct {
|
||||
session *Session
|
||||
filePath string
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewManager creates a new session manager
|
||||
func NewManager(filePath string) *Manager {
|
||||
return &Manager{
|
||||
session: NewSession(),
|
||||
filePath: filePath,
|
||||
}
|
||||
}
|
||||
|
||||
// NewManagerWithSession creates a new session manager with an existing session
|
||||
func NewManagerWithSession(session *Session, filePath string) *Manager {
|
||||
return &Manager{
|
||||
session: session,
|
||||
filePath: filePath,
|
||||
}
|
||||
}
|
||||
|
||||
// AddMessage adds a message to the session and auto-saves
|
||||
func (m *Manager) AddMessage(msg *schema.Message) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
sessionMsg := ConvertFromSchemaMessage(msg)
|
||||
m.session.AddMessage(sessionMsg)
|
||||
|
||||
if m.filePath != "" {
|
||||
return m.session.SaveToFile(m.filePath)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddMessages adds multiple messages to the session and auto-saves
|
||||
func (m *Manager) AddMessages(msgs []*schema.Message) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
for _, msg := range msgs {
|
||||
sessionMsg := ConvertFromSchemaMessage(msg)
|
||||
m.session.AddMessage(sessionMsg)
|
||||
}
|
||||
|
||||
if m.filePath != "" {
|
||||
return m.session.SaveToFile(m.filePath)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReplaceAllMessages replaces all messages in the session with the provided messages
|
||||
func (m *Manager) ReplaceAllMessages(msgs []*schema.Message) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
// Clear existing messages
|
||||
m.session.Messages = []Message{}
|
||||
|
||||
// Add all new messages
|
||||
for _, msg := range msgs {
|
||||
sessionMsg := ConvertFromSchemaMessage(msg)
|
||||
m.session.AddMessage(sessionMsg)
|
||||
}
|
||||
|
||||
if m.filePath != "" {
|
||||
return m.session.SaveToFile(m.filePath)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
|
||||
// SetMetadata sets the session metadata
|
||||
func (m *Manager) SetMetadata(metadata Metadata) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
m.session.SetMetadata(metadata)
|
||||
|
||||
if m.filePath != "" {
|
||||
return m.session.SaveToFile(m.filePath)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMessages returns all messages as schema.Message slice
|
||||
func (m *Manager) GetMessages() []*schema.Message {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
messages := make([]*schema.Message, len(m.session.Messages))
|
||||
for i, msg := range m.session.Messages {
|
||||
messages[i] = msg.ConvertToSchemaMessage()
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
// GetSession returns a copy of the current session
|
||||
func (m *Manager) GetSession() *Session {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
// Return a copy to prevent external modification
|
||||
sessionCopy := *m.session
|
||||
sessionCopy.Messages = make([]Message, len(m.session.Messages))
|
||||
copy(sessionCopy.Messages, m.session.Messages)
|
||||
|
||||
return &sessionCopy
|
||||
}
|
||||
|
||||
// Save manually saves the session to file
|
||||
func (m *Manager) Save() error {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
if m.filePath == "" {
|
||||
return fmt.Errorf("no file path specified for session manager")
|
||||
}
|
||||
|
||||
return m.session.SaveToFile(m.filePath)
|
||||
}
|
||||
|
||||
// GetFilePath returns the file path for this session
|
||||
func (m *Manager) GetFilePath() string {
|
||||
return m.filePath
|
||||
}
|
||||
|
||||
// MessageCount returns the number of messages in the session
|
||||
func (m *Manager) MessageCount() int {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
return len(m.session.Messages)
|
||||
}
|
||||
@@ -0,0 +1,179 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// Session represents a complete conversation session with metadata
|
||||
type Session struct {
|
||||
Version string `json:"version"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
Metadata Metadata `json:"metadata"`
|
||||
Messages []Message `json:"messages"`
|
||||
}
|
||||
|
||||
// Metadata contains session metadata
|
||||
type Metadata struct {
|
||||
MCPHostVersion string `json:"mcphost_version"`
|
||||
Provider string `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
// Message represents a single message in the session
|
||||
type Message struct {
|
||||
ID string `json:"id"`
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"` // For tool result messages
|
||||
}
|
||||
|
||||
// ToolCall represents a tool call within a message
|
||||
type ToolCall struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Arguments any `json:"arguments"`
|
||||
}
|
||||
|
||||
|
||||
|
||||
// NewSession creates a new session with default values
|
||||
func NewSession() *Session {
|
||||
return &Session{
|
||||
Version: "1.0",
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
Messages: []Message{},
|
||||
Metadata: Metadata{},
|
||||
}
|
||||
}
|
||||
|
||||
// AddMessage adds a message to the session
|
||||
func (s *Session) AddMessage(msg Message) {
|
||||
if msg.ID == "" {
|
||||
msg.ID = generateMessageID()
|
||||
}
|
||||
if msg.Timestamp.IsZero() {
|
||||
msg.Timestamp = time.Now()
|
||||
}
|
||||
|
||||
s.Messages = append(s.Messages, msg)
|
||||
s.UpdatedAt = time.Now()
|
||||
}
|
||||
|
||||
// SetMetadata sets the session metadata
|
||||
func (s *Session) SetMetadata(metadata Metadata) {
|
||||
s.Metadata = metadata
|
||||
s.UpdatedAt = time.Now()
|
||||
}
|
||||
|
||||
// SaveToFile saves the session to a JSON file
|
||||
func (s *Session) SaveToFile(filePath string) error {
|
||||
s.UpdatedAt = time.Now()
|
||||
|
||||
data, err := json.MarshalIndent(s, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal session: %v", err)
|
||||
}
|
||||
|
||||
return os.WriteFile(filePath, data, 0644)
|
||||
}
|
||||
|
||||
// LoadFromFile loads a session from a JSON file
|
||||
func LoadFromFile(filePath string) (*Session, error) {
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read session file: %v", err)
|
||||
}
|
||||
|
||||
var session Session
|
||||
if err := json.Unmarshal(data, &session); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal session: %v", err)
|
||||
}
|
||||
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// ConvertFromSchemaMessage converts a schema.Message to a session Message
|
||||
func ConvertFromSchemaMessage(msg *schema.Message) Message {
|
||||
sessionMsg := Message{
|
||||
Role: string(msg.Role),
|
||||
Content: msg.Content,
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
// Convert tool calls if present (for assistant messages)
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
sessionMsg.ToolCalls = make([]ToolCall, len(msg.ToolCalls))
|
||||
for i, tc := range msg.ToolCalls {
|
||||
sessionMsg.ToolCalls[i] = ToolCall{
|
||||
ID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Arguments: tc.Function.Arguments,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle tool result messages - extract tool call ID from ToolCallID field
|
||||
if msg.Role == schema.Tool && msg.ToolCallID != "" {
|
||||
sessionMsg.ToolCallID = msg.ToolCallID
|
||||
}
|
||||
|
||||
return sessionMsg
|
||||
}
|
||||
|
||||
// ConvertToSchemaMessage converts a session Message to a schema.Message
|
||||
func (m *Message) ConvertToSchemaMessage() *schema.Message {
|
||||
msg := &schema.Message{
|
||||
Role: schema.RoleType(m.Role),
|
||||
Content: m.Content,
|
||||
}
|
||||
|
||||
// Convert tool calls if present (for assistant messages)
|
||||
if len(m.ToolCalls) > 0 {
|
||||
msg.ToolCalls = make([]schema.ToolCall, len(m.ToolCalls))
|
||||
for i, tc := range m.ToolCalls {
|
||||
// Arguments are already stored as a string, use them directly
|
||||
var argsStr string
|
||||
if str, ok := tc.Arguments.(string); ok {
|
||||
argsStr = str
|
||||
} else {
|
||||
// Fallback: marshal to JSON if not a string
|
||||
if argBytes, err := json.Marshal(tc.Arguments); err == nil {
|
||||
argsStr = string(argBytes)
|
||||
}
|
||||
}
|
||||
|
||||
msg.ToolCalls[i] = schema.ToolCall{
|
||||
ID: tc.ID,
|
||||
Function: schema.FunctionCall{
|
||||
Name: tc.Name,
|
||||
Arguments: argsStr,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle tool result messages - set the tool call ID
|
||||
if m.Role == "tool" && m.ToolCallID != "" {
|
||||
msg.ToolCallID = m.ToolCallID
|
||||
}
|
||||
|
||||
return msg
|
||||
}
|
||||
|
||||
// generateMessageID generates a unique message ID
|
||||
func generateMessageID() string {
|
||||
bytes := make([]byte, 8)
|
||||
rand.Read(bytes)
|
||||
return "msg_" + hex.EncodeToString(bytes)
|
||||
}
|
||||
Reference in New Issue
Block a user