Implement saved sessions (#82)

This commit is contained in:
Ed Zynda
2025-06-25 20:25:14 +03:00
committed by GitHub
parent b56fb3c597
commit 904cbc6b37
4 changed files with 571 additions and 21 deletions
+224 -18
View File
@@ -2,6 +2,7 @@ package cmd
import (
"context"
"encoding/json"
"fmt"
"io"
"log"
@@ -13,6 +14,7 @@ import (
"github.com/mark3labs/mcphost/internal/auth"
"github.com/mark3labs/mcphost/internal/config"
"github.com/mark3labs/mcphost/internal/models"
"github.com/mark3labs/mcphost/internal/session"
"github.com/mark3labs/mcphost/internal/tokens"
"github.com/mark3labs/mcphost/internal/ui"
"github.com/spf13/cobra"
@@ -34,6 +36,10 @@ var (
maxSteps int
scriptMCPConfig *config.Config // Used to override config in script mode
// Session management
saveSessionPath string
loadSessionPath string
// Model generation parameters
maxTokens int
temperature float32
@@ -69,6 +75,11 @@ Examples:
mcphost -p "What is the weather like today?"
mcphost -p "Calculate 15 * 23" --quiet
# Session management
mcphost --save-session ./my-session.json -p "Hello"
mcphost --load-session ./my-session.json -p "Continue our conversation"
mcphost --load-session ./session.json --save-session ./session.json -p "Next message"
# Script mode
mcphost script myscript.sh`,
RunE: func(cmd *cobra.Command, args []string) error {
@@ -155,6 +166,12 @@ func init() {
rootCmd.PersistentFlags().
IntVar(&maxSteps, "max-steps", 0, "maximum number of agent steps (0 for unlimited)")
// Session management flags
rootCmd.PersistentFlags().
StringVar(&saveSessionPath, "save-session", "", "save session to file after each message")
rootCmd.PersistentFlags().
StringVar(&loadSessionPath, "load-session", "", "load session from file at startup")
flags := rootCmd.PersistentFlags()
flags.StringVar(&providerURL, "provider-url", "", "base URL for the provider API (applies to OpenAI, Anthropic, Ollama, and Google)")
flags.StringVar(&providerAPIKey, "provider-api-key", "", "API key for the provider (applies to OpenAI, Anthropic, and Google)")
@@ -380,10 +397,120 @@ func runNormalMode(ctx context.Context) error {
// Main interaction logic
var messages []*schema.Message
var sessionManager *session.Manager
// Load existing session if specified
if loadSessionPath != "" {
loadedSession, err := session.LoadFromFile(loadSessionPath)
if err != nil {
return fmt.Errorf("failed to load session: %v", err)
}
// Convert session messages to schema messages
for _, msg := range loadedSession.Messages {
messages = append(messages, msg.ConvertToSchemaMessage())
}
// If we're also saving, use the loaded session with the session manager
if saveSessionPath != "" {
sessionManager = session.NewManagerWithSession(loadedSession, saveSessionPath)
}
if !quietFlag && cli != nil {
// Create a map of tool call IDs to tool calls for quick lookup
toolCallMap := make(map[string]session.ToolCall)
for _, sessionMsg := range loadedSession.Messages {
if sessionMsg.Role == "assistant" && len(sessionMsg.ToolCalls) > 0 {
for _, tc := range sessionMsg.ToolCalls {
toolCallMap[tc.ID] = tc
}
}
}
// Display all previous messages as they would have appeared
for _, sessionMsg := range loadedSession.Messages {
if sessionMsg.Role == "user" {
cli.DisplayUserMessage(sessionMsg.Content)
} else if sessionMsg.Role == "assistant" {
// Display tool calls if present
if len(sessionMsg.ToolCalls) > 0 {
for _, tc := range sessionMsg.ToolCalls {
// Convert arguments to string
var argsStr string
if argBytes, err := json.Marshal(tc.Arguments); err == nil {
argsStr = string(argBytes)
}
// Display tool call
cli.DisplayToolCallMessage(tc.Name, argsStr)
}
}
// Display assistant response (only if there's content)
if sessionMsg.Content != "" {
cli.DisplayAssistantMessage(sessionMsg.Content)
}
} else if sessionMsg.Role == "tool" {
// Display tool result
if sessionMsg.ToolCallID != "" {
if toolCall, exists := toolCallMap[sessionMsg.ToolCallID]; exists {
// Convert arguments to string
var argsStr string
if argBytes, err := json.Marshal(toolCall.Arguments); err == nil {
argsStr = string(argBytes)
}
// Parse tool result content - it might be JSON-encoded MCP content
resultContent := sessionMsg.Content
// Try to parse as MCP content structure
var mcpContent struct {
Content []struct {
Type string `json:"type"`
Text string `json:"text"`
} `json:"content"`
}
// First try to unmarshal as-is
if err := json.Unmarshal([]byte(sessionMsg.Content), &mcpContent); err == nil {
// Extract text from MCP content structure
if len(mcpContent.Content) > 0 && mcpContent.Content[0].Type == "text" {
resultContent = mcpContent.Content[0].Text
}
} else {
// If that fails, try unquoting first (in case it's double-encoded)
var unquoted string
if err := json.Unmarshal([]byte(sessionMsg.Content), &unquoted); err == nil {
if err := json.Unmarshal([]byte(unquoted), &mcpContent); err == nil {
if len(mcpContent.Content) > 0 && mcpContent.Content[0].Type == "text" {
resultContent = mcpContent.Content[0].Text
}
}
}
}
// Display tool result (assuming no error for saved results)
cli.DisplayToolMessage(toolCall.Name, argsStr, resultContent, false)
}
}
}
}
}
} else if saveSessionPath != "" {
// Only saving, create new session manager
sessionManager = session.NewManager(saveSessionPath)
// Set metadata
sessionManager.SetMetadata(session.Metadata{
MCPHostVersion: "dev", // TODO: Get actual version
Provider: parts[0],
Model: modelName,
})
}
// Check if running in non-interactive mode
if promptFlag != "" {
return runNonInteractiveMode(ctx, mcpAgent, cli, promptFlag, modelName, messages, quietFlag, noExitFlag, mcpConfig)
return runNonInteractiveMode(ctx, mcpAgent, cli, promptFlag, modelName, messages, quietFlag, noExitFlag, mcpConfig, sessionManager)
}
// Quiet mode is not allowed in interactive mode
@@ -391,7 +518,7 @@ func runNormalMode(ctx context.Context) error {
return fmt.Errorf("--quiet flag can only be used with --prompt/-p")
}
return runInteractiveMode(ctx, mcpAgent, cli, serverNames, toolNames, modelName, messages)
return runInteractiveMode(ctx, mcpAgent, cli, serverNames, toolNames, modelName, messages, sessionManager)
}
// AgenticLoopConfig configures the behavior of the unified agentic loop
@@ -405,10 +532,11 @@ type AgenticLoopConfig struct {
Quiet bool // suppress all output except final response
// Context data
ServerNames []string // for slash commands
ToolNames []string // for slash commands
ModelName string // for display
MCPConfig *config.Config // for continuing to interactive mode
ServerNames []string // for slash commands
ToolNames []string // for slash commands
ModelName string // for display
MCPConfig *config.Config // for continuing to interactive mode
SessionManager *session.Manager // for session persistence
}
// runAgenticLoop handles all execution modes with a single unified loop
@@ -424,7 +552,7 @@ func runAgenticLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes
tempMessages := append(messages, schema.UserMessage(config.InitialPrompt))
// Process the initial prompt with tool calls
response, err := runAgenticStep(ctx, mcpAgent, cli, tempMessages, config)
response, conversationMessages, err := runAgenticStep(ctx, mcpAgent, cli, tempMessages, config)
if err != nil {
// Check if this was a user cancellation
if err.Error() == "generation cancelled by user" && cli != nil {
@@ -437,9 +565,25 @@ func runAgenticLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes
}
} else {
// Only add to history after successful completion
messages = append(messages, schema.UserMessage(config.InitialPrompt))
userMsg := schema.UserMessage(config.InitialPrompt)
messages = append(messages, userMsg)
messages = append(messages, response)
// Save to session if session manager is available
if config.SessionManager != nil {
// Simple approach: save the entire conversation history
// This includes the user message + all generated messages
allMessages := append([]*schema.Message{userMsg}, conversationMessages...)
// Clear the session and save the complete history
if err := config.SessionManager.ReplaceAllMessages(allMessages); err != nil {
// Log error but don't fail the operation
if cli != nil && !config.Quiet {
cli.DisplayError(fmt.Errorf("failed to save conversation to session: %v", err))
}
}
}
// If not continuing to interactive mode, exit here
if !config.ContinueAfterRun {
return nil
@@ -459,7 +603,7 @@ func runAgenticLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes
}
// runAgenticStep processes a single step of the agentic loop (handles tool calls)
func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, messages []*schema.Message, config AgenticLoopConfig) (*schema.Message, error) {
func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, messages []*schema.Message, config AgenticLoopConfig) (*schema.Message, []*schema.Message, error) {
var currentSpinner *ui.Spinner
// Start initial spinner (skip if quiet)
@@ -468,7 +612,7 @@ func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes
currentSpinner.Start()
}
response, err := mcpAgent.GenerateWithLoop(ctx, messages,
result, err := mcpAgent.GenerateWithLoop(ctx, messages,
// Tool call handler - called when a tool is about to be executed
func(toolName, toolArgs string) {
if !config.Quiet && cli != nil {
@@ -499,7 +643,36 @@ func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes
// Tool result handler - called when a tool execution completes
func(toolName, toolArgs, result string, isError bool) {
if !config.Quiet && cli != nil {
cli.DisplayToolMessage(toolName, toolArgs, result, isError)
// Parse tool result content - it might be JSON-encoded MCP content
resultContent := result
// Try to parse as MCP content structure
var mcpContent struct {
Content []struct {
Type string `json:"type"`
Text string `json:"text"`
} `json:"content"`
}
// First try to unmarshal as-is
if err := json.Unmarshal([]byte(result), &mcpContent); err == nil {
// Extract text from MCP content structure
if len(mcpContent.Content) > 0 && mcpContent.Content[0].Type == "text" {
resultContent = mcpContent.Content[0].Text
}
} else {
// If that fails, try unquoting first (in case it's double-encoded)
var unquoted string
if err := json.Unmarshal([]byte(result), &unquoted); err == nil {
if err := json.Unmarshal([]byte(unquoted), &mcpContent); err == nil {
if len(mcpContent.Content) > 0 && mcpContent.Content[0].Type == "text" {
resultContent = mcpContent.Content[0].Text
}
}
}
}
cli.DisplayToolMessage(toolName, toolArgs, resultContent, isError)
// Start spinner again for next LLM call
currentSpinner = ui.NewSpinner("Thinking...")
currentSpinner.Start()
@@ -540,14 +713,18 @@ func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes
if !config.Quiet && cli != nil {
cli.DisplayError(fmt.Errorf("agent error: %v", err))
}
return nil, err
return nil, nil, err
}
// Get the final response and conversation messages
response := result.FinalResponse
conversationMessages := result.ConversationMessages
// Display assistant response with model name (skip if quiet)
if !config.Quiet && cli != nil {
if err := cli.DisplayAssistantMessageWithModel(response.Content, config.ModelName); err != nil {
cli.DisplayError(fmt.Errorf("display error: %v", err))
return nil, err
return nil, nil, err
}
// Update usage tracking with the last user message and response
@@ -567,7 +744,8 @@ func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes
fmt.Print(response.Content)
}
return response, nil
// Return the final response and all conversation messages
return response, conversationMessages, nil
}
// runInteractiveLoop handles the interactive portion of the agentic loop
@@ -603,7 +781,7 @@ func runInteractiveLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI,
tempMessages := append(messages, schema.UserMessage(prompt))
// Process the user input with tool calls
response, err := runAgenticStep(ctx, mcpAgent, cli, tempMessages, config)
response, conversationMessages, err := runAgenticStep(ctx, mcpAgent, cli, tempMessages, config)
if err != nil {
// Check if this was a user cancellation
if err.Error() == "generation cancelled by user" {
@@ -615,13 +793,39 @@ func runInteractiveLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI,
}
// Only add to history after successful completion
messages = append(messages, schema.UserMessage(prompt))
userMsg := schema.UserMessage(prompt)
messages = append(messages, userMsg)
messages = append(messages, response)
// Save to session if session manager is available
if config.SessionManager != nil {
if err := config.SessionManager.AddMessage(userMsg); err != nil {
// Log error but don't fail the operation
cli.DisplayError(fmt.Errorf("failed to save user message to session: %v", err))
}
// Save all conversation messages (includes tool calls and results)
// Find the messages that were generated during this conversation
if len(conversationMessages) > len(tempMessages) {
// Extract only the new messages generated during this step
newMessages := conversationMessages[len(tempMessages):]
if err := config.SessionManager.AddMessages(newMessages); err != nil {
// Log error but don't fail the operation
cli.DisplayError(fmt.Errorf("failed to save conversation messages to session: %v", err))
}
} else {
// No tool calls, just save the final response
if err := config.SessionManager.AddMessage(response); err != nil {
// Log error but don't fail the operation
cli.DisplayError(fmt.Errorf("failed to save assistant message to session: %v", err))
}
}
}
}
}
// runNonInteractiveMode handles the non-interactive mode execution
func runNonInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, prompt, modelName string, messages []*schema.Message, quiet, noExit bool, mcpConfig *config.Config) error {
func runNonInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, prompt, modelName string, messages []*schema.Message, quiet, noExit bool, mcpConfig *config.Config, sessionManager *session.Manager) error {
// Prepare data for slash commands (needed if continuing to interactive mode)
var serverNames []string
for name := range mcpConfig.MCPServers {
@@ -646,13 +850,14 @@ func runNonInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.C
ToolNames: toolNames,
ModelName: modelName,
MCPConfig: mcpConfig,
SessionManager: sessionManager,
}
return runAgenticLoop(ctx, mcpAgent, cli, messages, config)
}
// runInteractiveMode handles the interactive mode execution
func runInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, serverNames, toolNames []string, modelName string, messages []*schema.Message) error {
func runInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, serverNames, toolNames []string, modelName string, messages []*schema.Message, sessionManager *session.Manager) error {
// Configure and run unified agentic loop
config := AgenticLoopConfig{
IsInteractive: true,
@@ -663,6 +868,7 @@ func runInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI,
ToolNames: toolNames,
ModelName: modelName,
MCPConfig: nil, // Not needed for pure interactive mode
SessionManager: sessionManager,
}
return runAgenticLoop(ctx, mcpAgent, cli, messages, config)
+17 -3
View File
@@ -69,9 +69,15 @@ func NewAgent(ctx context.Context, config *AgentConfig) (*Agent, error) {
}, nil
}
// GenerateWithLoopResult contains the result and conversation history
type GenerateWithLoopResult struct {
FinalResponse *schema.Message
ConversationMessages []*schema.Message // All messages in the conversation (including tool calls and results)
}
// GenerateWithLoop processes messages with a custom loop that displays tool calls in real-time
func (a *Agent) GenerateWithLoop(ctx context.Context, messages []*schema.Message,
onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler, onResponse ResponseHandler, onToolCallContent ToolCallContentHandler) (*schema.Message, error) {
onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler, onResponse ResponseHandler, onToolCallContent ToolCallContentHandler) (*GenerateWithLoopResult, error) {
// Create a copy of messages to avoid modifying the original
workingMessages := make([]*schema.Message, len(messages))
@@ -127,6 +133,7 @@ func (a *Agent) GenerateWithLoop(ctx context.Context, messages []*schema.Message
// Check if this is a tool call or final response
if len(response.ToolCalls) > 0 {
// Display any content that accompanies the tool calls
if response.Content != "" && onToolCallContent != nil {
onToolCallContent(response.Content)
@@ -193,12 +200,19 @@ func (a *Agent) GenerateWithLoop(ctx context.Context, messages []*schema.Message
if onResponse != nil && response.Content != "" {
onResponse(response.Content)
}
return response, nil
return &GenerateWithLoopResult{
FinalResponse: response,
ConversationMessages: workingMessages,
}, nil
}
}
// If we reach here, we've exceeded max steps
return schema.AssistantMessage("Maximum number of steps reached.", nil), nil
finalResponse := schema.AssistantMessage("Maximum number of steps reached.", nil)
return &GenerateWithLoopResult{
FinalResponse: finalResponse,
ConversationMessages: workingMessages,
}, nil
}
// GetTools returns the list of available tools
+151
View File
@@ -0,0 +1,151 @@
package session
import (
"fmt"
"sync"
"github.com/cloudwego/eino/schema"
)
// Manager manages session state and auto-saving
type Manager struct {
session *Session
filePath string
mutex sync.RWMutex
}
// NewManager creates a new session manager
func NewManager(filePath string) *Manager {
return &Manager{
session: NewSession(),
filePath: filePath,
}
}
// NewManagerWithSession creates a new session manager with an existing session
func NewManagerWithSession(session *Session, filePath string) *Manager {
return &Manager{
session: session,
filePath: filePath,
}
}
// AddMessage adds a message to the session and auto-saves
func (m *Manager) AddMessage(msg *schema.Message) error {
m.mutex.Lock()
defer m.mutex.Unlock()
sessionMsg := ConvertFromSchemaMessage(msg)
m.session.AddMessage(sessionMsg)
if m.filePath != "" {
return m.session.SaveToFile(m.filePath)
}
return nil
}
// AddMessages adds multiple messages to the session and auto-saves
func (m *Manager) AddMessages(msgs []*schema.Message) error {
m.mutex.Lock()
defer m.mutex.Unlock()
for _, msg := range msgs {
sessionMsg := ConvertFromSchemaMessage(msg)
m.session.AddMessage(sessionMsg)
}
if m.filePath != "" {
return m.session.SaveToFile(m.filePath)
}
return nil
}
// ReplaceAllMessages replaces all messages in the session with the provided messages
func (m *Manager) ReplaceAllMessages(msgs []*schema.Message) error {
m.mutex.Lock()
defer m.mutex.Unlock()
// Clear existing messages
m.session.Messages = []Message{}
// Add all new messages
for _, msg := range msgs {
sessionMsg := ConvertFromSchemaMessage(msg)
m.session.AddMessage(sessionMsg)
}
if m.filePath != "" {
return m.session.SaveToFile(m.filePath)
}
return nil
}
// SetMetadata sets the session metadata
func (m *Manager) SetMetadata(metadata Metadata) error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.session.SetMetadata(metadata)
if m.filePath != "" {
return m.session.SaveToFile(m.filePath)
}
return nil
}
// GetMessages returns all messages as schema.Message slice
func (m *Manager) GetMessages() []*schema.Message {
m.mutex.RLock()
defer m.mutex.RUnlock()
messages := make([]*schema.Message, len(m.session.Messages))
for i, msg := range m.session.Messages {
messages[i] = msg.ConvertToSchemaMessage()
}
return messages
}
// GetSession returns a copy of the current session
func (m *Manager) GetSession() *Session {
m.mutex.RLock()
defer m.mutex.RUnlock()
// Return a copy to prevent external modification
sessionCopy := *m.session
sessionCopy.Messages = make([]Message, len(m.session.Messages))
copy(sessionCopy.Messages, m.session.Messages)
return &sessionCopy
}
// Save manually saves the session to file
func (m *Manager) Save() error {
m.mutex.RLock()
defer m.mutex.RUnlock()
if m.filePath == "" {
return fmt.Errorf("no file path specified for session manager")
}
return m.session.SaveToFile(m.filePath)
}
// GetFilePath returns the file path for this session
func (m *Manager) GetFilePath() string {
return m.filePath
}
// MessageCount returns the number of messages in the session
func (m *Manager) MessageCount() int {
m.mutex.RLock()
defer m.mutex.RUnlock()
return len(m.session.Messages)
}
+179
View File
@@ -0,0 +1,179 @@
package session
import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"os"
"time"
"github.com/cloudwego/eino/schema"
)
// Session represents a complete conversation session with metadata
type Session struct {
Version string `json:"version"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
Metadata Metadata `json:"metadata"`
Messages []Message `json:"messages"`
}
// Metadata contains session metadata
type Metadata struct {
MCPHostVersion string `json:"mcphost_version"`
Provider string `json:"provider"`
Model string `json:"model"`
}
// Message represents a single message in the session
type Message struct {
ID string `json:"id"`
Role string `json:"role"`
Content string `json:"content"`
Timestamp time.Time `json:"timestamp"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"` // For tool result messages
}
// ToolCall represents a tool call within a message
type ToolCall struct {
ID string `json:"id"`
Name string `json:"name"`
Arguments any `json:"arguments"`
}
// NewSession creates a new session with default values
func NewSession() *Session {
return &Session{
Version: "1.0",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
Messages: []Message{},
Metadata: Metadata{},
}
}
// AddMessage adds a message to the session
func (s *Session) AddMessage(msg Message) {
if msg.ID == "" {
msg.ID = generateMessageID()
}
if msg.Timestamp.IsZero() {
msg.Timestamp = time.Now()
}
s.Messages = append(s.Messages, msg)
s.UpdatedAt = time.Now()
}
// SetMetadata sets the session metadata
func (s *Session) SetMetadata(metadata Metadata) {
s.Metadata = metadata
s.UpdatedAt = time.Now()
}
// SaveToFile saves the session to a JSON file
func (s *Session) SaveToFile(filePath string) error {
s.UpdatedAt = time.Now()
data, err := json.MarshalIndent(s, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal session: %v", err)
}
return os.WriteFile(filePath, data, 0644)
}
// LoadFromFile loads a session from a JSON file
func LoadFromFile(filePath string) (*Session, error) {
data, err := os.ReadFile(filePath)
if err != nil {
return nil, fmt.Errorf("failed to read session file: %v", err)
}
var session Session
if err := json.Unmarshal(data, &session); err != nil {
return nil, fmt.Errorf("failed to unmarshal session: %v", err)
}
return &session, nil
}
// ConvertFromSchemaMessage converts a schema.Message to a session Message
func ConvertFromSchemaMessage(msg *schema.Message) Message {
sessionMsg := Message{
Role: string(msg.Role),
Content: msg.Content,
Timestamp: time.Now(),
}
// Convert tool calls if present (for assistant messages)
if len(msg.ToolCalls) > 0 {
sessionMsg.ToolCalls = make([]ToolCall, len(msg.ToolCalls))
for i, tc := range msg.ToolCalls {
sessionMsg.ToolCalls[i] = ToolCall{
ID: tc.ID,
Name: tc.Function.Name,
Arguments: tc.Function.Arguments,
}
}
}
// Handle tool result messages - extract tool call ID from ToolCallID field
if msg.Role == schema.Tool && msg.ToolCallID != "" {
sessionMsg.ToolCallID = msg.ToolCallID
}
return sessionMsg
}
// ConvertToSchemaMessage converts a session Message to a schema.Message
func (m *Message) ConvertToSchemaMessage() *schema.Message {
msg := &schema.Message{
Role: schema.RoleType(m.Role),
Content: m.Content,
}
// Convert tool calls if present (for assistant messages)
if len(m.ToolCalls) > 0 {
msg.ToolCalls = make([]schema.ToolCall, len(m.ToolCalls))
for i, tc := range m.ToolCalls {
// Arguments are already stored as a string, use them directly
var argsStr string
if str, ok := tc.Arguments.(string); ok {
argsStr = str
} else {
// Fallback: marshal to JSON if not a string
if argBytes, err := json.Marshal(tc.Arguments); err == nil {
argsStr = string(argBytes)
}
}
msg.ToolCalls[i] = schema.ToolCall{
ID: tc.ID,
Function: schema.FunctionCall{
Name: tc.Name,
Arguments: argsStr,
},
}
}
}
// Handle tool result messages - set the tool call ID
if m.Role == "tool" && m.ToolCallID != "" {
msg.ToolCallID = m.ToolCallID
}
return msg
}
// generateMessageID generates a unique message ID
func generateMessageID() string {
bytes := make([]byte, 8)
rand.Read(bytes)
return "msg_" + hex.EncodeToString(bytes)
}