Add message window to prevent overloading context

This commit is contained in:
Ed Zynda
2024-12-10 22:19:35 +03:00
parent 57fe2c6795
commit b79b521911
4 changed files with 125 additions and 2 deletions
+1
View File
@@ -2,3 +2,4 @@
.env
aidocs/
.mcp.json
*.log
+13
View File
@@ -22,6 +22,7 @@ This architecture allows language models to:
- Tool calling capabilities for both model types
- Configurable MCP server locations and arguments
- Consistent command interface across model types
- Configurable message history window for context management
## Installation 📦
@@ -88,15 +89,27 @@ mcphost ollama --model mistral
mcphost --config /path/to/config.json
```
### Setting Message Window Size
Control how many previous messages are kept in context:
```bash
mcphost --message-window 15
```
The default window size is 10 messages.
## Available Commands 💻
While chatting, you can use these commands:
- `/help`: Show available commands
- `/tools`: List all available tools
- `/servers`: List configured MCP servers
- `/history`: Display conversation history
- `/quit`: Exit the application
- `Ctrl+C`: Exit at any time
### Global Flags
- `--config`: Specify custom config file location
- `--message-window`: Set number of messages to keep in context (default: 10)
## Requirements 📋
- Go 1.18 or later
+3
View File
@@ -252,6 +252,9 @@ When you do need to use a tool, explain what you're doing first.`,
continue
}
if len(messages) > 0 {
messages = pruneMessages(messages)
}
err = runOllamaPrompt(client, mcpClients, allTools, prompt, &messages)
if err != nil {
return err
+108 -2
View File
@@ -12,6 +12,7 @@ import (
"github.com/charmbracelet/huh"
"github.com/charmbracelet/huh/spinner"
"github.com/charmbracelet/log"
"github.com/ollama/ollama/api"
"github.com/charmbracelet/glamour"
mcpclient "github.com/mark3labs/mcp-go/client"
@@ -23,7 +24,8 @@ import (
var (
renderer *glamour.TermRenderer
configFile string
configFile string
messageWindow int
)
const (
@@ -51,6 +53,107 @@ func Execute() {
func init() {
rootCmd.PersistentFlags().
StringVar(&configFile, "config", "", "config file (default is $HOME/mcp.json)")
rootCmd.PersistentFlags().
IntVar(&messageWindow, "message-window", 10, "number of messages to keep in context")
}
func pruneMessages[T MessageParam | api.Message](messages []T) []T {
if len(messages) <= messageWindow {
return messages
}
// Keep only the most recent messages based on window size
messages = messages[len(messages)-messageWindow:]
switch any(messages[0]).(type) {
case MessageParam:
// Handle Anthropic messages
toolUseIds := make(map[string]bool)
toolResultIds := make(map[string]bool)
// First pass: collect all tool use and result IDs
for _, msg := range messages {
m := any(msg).(MessageParam)
for _, block := range m.Content {
if block.Type == "tool_use" {
toolUseIds[block.ID] = true
} else if block.Type == "tool_result" {
toolResultIds[block.ToolUseID] = true
}
}
}
// Second pass: filter out orphaned tool calls/results
var prunedMessages []T
for _, msg := range messages {
m := any(msg).(MessageParam)
var prunedBlocks []ContentBlock
for _, block := range m.Content {
keep := true
if block.Type == "tool_use" {
keep = toolResultIds[block.ID]
} else if block.Type == "tool_result" {
keep = toolUseIds[block.ToolUseID]
}
if keep {
prunedBlocks = append(prunedBlocks, block)
}
}
// Only include messages that have content or are not assistant messages
if (len(prunedBlocks) > 0 && m.Role == "assistant") || m.Role != "assistant" {
hasTextBlock := false
for _, block := range m.Content {
if block.Type == "text" {
hasTextBlock = true
break
}
}
if len(prunedBlocks) > 0 || hasTextBlock {
m.Content = prunedBlocks
prunedMessages = append(prunedMessages, any(m).(T))
}
}
}
return prunedMessages
case api.Message:
// Handle Ollama messages
var prunedMessages []T
for i, msg := range messages {
m := any(msg).(api.Message)
// If this message has tool calls, ensure we keep the next message (tool response)
if len(m.ToolCalls) > 0 {
if i+1 < len(messages) {
next := any(messages[i+1]).(api.Message)
if next.Role == "tool" {
prunedMessages = append(prunedMessages, msg)
prunedMessages = append(prunedMessages, messages[i+1])
continue
}
}
// If no matching tool response, skip this message
continue
}
// Skip tool responses that don't have a preceding tool call
if m.Role == "tool" {
if i > 0 {
prev := any(messages[i-1]).(api.Message)
if len(prev.ToolCalls) > 0 {
continue // Already handled in the tool call case
}
}
continue // Skip orphaned tool response
}
// Keep all other messages
prunedMessages = append(prunedMessages, msg)
}
return prunedMessages
}
return messages
}
func getTerminalWidth() int {
@@ -353,7 +456,7 @@ func runMCPHost() error {
Title("Enter your prompt (Type /help for commands, Ctrl+C to quit)").
Value(&prompt),
),
).WithWidth(width)
).WithWidth(width).WithTheme(huh.ThemeCharm())
err := form.Run()
if err != nil {
@@ -384,6 +487,9 @@ func runMCPHost() error {
continue
}
if len(messages) > 0 {
messages = pruneMessages(messages)
}
err = runPrompt(client, mcpClients, allTools, prompt, &messages)
if err != nil {
return err