mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 03:30:26 +00:00
beaa623fc5
* support system prompt for ollama * support system prompt for openai * support system prompt for google gemini * fixed setting for google gemini system-prompt * support system prompt for anthropic * edit README * edit comments * fix system-prompt flag desc * modified var name for systemPromp
627 lines
16 KiB
Go
627 lines
16 KiB
Go
package cmd
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"os"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/charmbracelet/glamour/styles"
|
|
"github.com/charmbracelet/huh"
|
|
"github.com/charmbracelet/huh/spinner"
|
|
"github.com/charmbracelet/log"
|
|
|
|
"github.com/charmbracelet/glamour"
|
|
mcpclient "github.com/mark3labs/mcp-go/client"
|
|
"github.com/mark3labs/mcp-go/mcp"
|
|
"github.com/mark3labs/mcphost/pkg/history"
|
|
"github.com/mark3labs/mcphost/pkg/llm"
|
|
"github.com/mark3labs/mcphost/pkg/llm/anthropic"
|
|
"github.com/mark3labs/mcphost/pkg/llm/google"
|
|
"github.com/mark3labs/mcphost/pkg/llm/ollama"
|
|
"github.com/mark3labs/mcphost/pkg/llm/openai"
|
|
"github.com/spf13/cobra"
|
|
"golang.org/x/term"
|
|
)
|
|
|
|
var (
|
|
renderer *glamour.TermRenderer
|
|
configFile string
|
|
systemPromptFile string
|
|
messageWindow int
|
|
modelFlag string // New flag for model selection
|
|
openaiBaseURL string // Base URL for OpenAI API
|
|
anthropicBaseURL string // Base URL for Anthropic API
|
|
openaiAPIKey string
|
|
anthropicAPIKey string
|
|
googleAPIKey string
|
|
)
|
|
|
|
// Retry policy applied by runPrompt when a request fails with an
// "overloaded_error".
const (
	initialBackoff = time.Second      // wait before the first retry
	maxBackoff     = 30 * time.Second // ceiling for the doubled wait
	maxRetries     = 5                // retries attempted before giving up
)
|
|
|
|
var rootCmd = &cobra.Command{
|
|
Use: "mcphost",
|
|
Short: "Chat with AI models through a unified interface",
|
|
Long: `MCPHost is a CLI tool that allows you to interact with various AI models
|
|
through a unified interface. It supports various tools through MCP servers
|
|
and provides streaming responses.
|
|
|
|
Available models can be specified using the --model flag:
|
|
- Anthropic Claude (default): anthropic:claude-3-5-sonnet-latest
|
|
- OpenAI: openai:gpt-4
|
|
- Ollama models: ollama:modelname
|
|
- Google: google:modelname
|
|
|
|
Example:
|
|
mcphost -m ollama:qwen2.5:3b
|
|
mcphost -m openai:gpt-4
|
|
mcphost -m google:gemini-2.0-flash`,
|
|
RunE: func(cmd *cobra.Command, args []string) error {
|
|
return runMCPHost(context.Background())
|
|
},
|
|
}
|
|
|
|
func Execute() {
|
|
if err := rootCmd.Execute(); err != nil {
|
|
os.Exit(1)
|
|
}
|
|
}
|
|
|
|
// debugMode is bound to the --debug flag in init; runMCPHost reads it to pick
// the log level and to decide whether log entries report their caller.
var debugMode bool
|
|
|
|
func init() {
|
|
rootCmd.PersistentFlags().
|
|
StringVar(&configFile, "config", "", "config file (default is $HOME/.mcp.json)")
|
|
rootCmd.PersistentFlags().
|
|
StringVar(&systemPromptFile, "system-prompt", "", "system prompt json file")
|
|
rootCmd.PersistentFlags().
|
|
IntVar(&messageWindow, "message-window", 10, "number of messages to keep in context")
|
|
rootCmd.PersistentFlags().
|
|
StringVarP(&modelFlag, "model", "m", "anthropic:claude-3-5-sonnet-latest",
|
|
"model to use (format: provider:model, e.g. anthropic:claude-3-5-sonnet-latest or ollama:qwen2.5:3b)")
|
|
|
|
// Add debug flag
|
|
rootCmd.PersistentFlags().
|
|
BoolVar(&debugMode, "debug", false, "enable debug logging")
|
|
|
|
flags := rootCmd.PersistentFlags()
|
|
flags.StringVar(&openaiBaseURL, "openai-url", "", "base URL for OpenAI API (defaults to api.openai.com)")
|
|
flags.StringVar(&anthropicBaseURL, "anthropic-url", "", "base URL for Anthropic API (defaults to api.anthropic.com)")
|
|
flags.StringVar(&openaiAPIKey, "openai-api-key", "", "OpenAI API key")
|
|
flags.StringVar(&anthropicAPIKey, "anthropic-api-key", "", "Anthropic API key")
|
|
flags.StringVar(&googleAPIKey, "google-api-key", "", "Google (Gemini) API key")
|
|
}
|
|
|
|
// Add new function to create provider
|
|
func createProvider(ctx context.Context, modelString, systemPrompt string) (llm.Provider, error) {
|
|
parts := strings.SplitN(modelString, ":", 2)
|
|
if len(parts) < 2 {
|
|
return nil, fmt.Errorf(
|
|
"invalid model format. Expected provider:model, got %s",
|
|
modelString,
|
|
)
|
|
}
|
|
|
|
provider := parts[0]
|
|
model := parts[1]
|
|
|
|
switch provider {
|
|
case "anthropic":
|
|
apiKey := anthropicAPIKey
|
|
if apiKey == "" {
|
|
apiKey = os.Getenv("ANTHROPIC_API_KEY")
|
|
}
|
|
|
|
if apiKey == "" {
|
|
return nil, fmt.Errorf(
|
|
"Anthropic API key not provided. Use --anthropic-api-key flag or ANTHROPIC_API_KEY environment variable",
|
|
)
|
|
}
|
|
return anthropic.NewProvider(apiKey, anthropicBaseURL, model, systemPrompt), nil
|
|
|
|
case "ollama":
|
|
return ollama.NewProvider(model, systemPrompt)
|
|
|
|
case "openai":
|
|
apiKey := openaiAPIKey
|
|
if apiKey == "" {
|
|
apiKey = os.Getenv("OPENAI_API_KEY")
|
|
}
|
|
|
|
if apiKey == "" {
|
|
return nil, fmt.Errorf(
|
|
"OpenAI API key not provided. Use --openai-api-key flag or OPENAI_API_KEY environment variable",
|
|
)
|
|
}
|
|
return openai.NewProvider(apiKey, openaiBaseURL, model, systemPrompt), nil
|
|
|
|
case "google":
|
|
apiKey := googleAPIKey
|
|
if apiKey == "" {
|
|
apiKey = os.Getenv("GOOGLE_API_KEY")
|
|
}
|
|
if apiKey == "" {
|
|
// The project structure is provider specific, but Google calls this GEMINI_API_KEY in e.g. AI Studio. Support both.
|
|
apiKey = os.Getenv("GEMINI_API_KEY")
|
|
}
|
|
return google.NewProvider(ctx, apiKey, model, systemPrompt)
|
|
|
|
default:
|
|
return nil, fmt.Errorf("unsupported provider: %s", provider)
|
|
}
|
|
}
|
|
|
|
func pruneMessages(messages []history.HistoryMessage) []history.HistoryMessage {
|
|
if len(messages) <= messageWindow {
|
|
return messages
|
|
}
|
|
|
|
// Keep only the most recent messages based on window size
|
|
messages = messages[len(messages)-messageWindow:]
|
|
|
|
// Handle messages
|
|
toolUseIds := make(map[string]bool)
|
|
toolResultIds := make(map[string]bool)
|
|
|
|
// First pass: collect all tool use and result IDs
|
|
for _, msg := range messages {
|
|
for _, block := range msg.Content {
|
|
if block.Type == "tool_use" {
|
|
toolUseIds[block.ID] = true
|
|
} else if block.Type == "tool_result" {
|
|
toolResultIds[block.ToolUseID] = true
|
|
}
|
|
}
|
|
}
|
|
|
|
// Second pass: filter out orphaned tool calls/results
|
|
var prunedMessages []history.HistoryMessage
|
|
for _, msg := range messages {
|
|
var prunedBlocks []history.ContentBlock
|
|
for _, block := range msg.Content {
|
|
keep := true
|
|
if block.Type == "tool_use" {
|
|
keep = toolResultIds[block.ID]
|
|
} else if block.Type == "tool_result" {
|
|
keep = toolUseIds[block.ToolUseID]
|
|
}
|
|
if keep {
|
|
prunedBlocks = append(prunedBlocks, block)
|
|
}
|
|
}
|
|
// Only include messages that have content or are not assistant messages
|
|
if (len(prunedBlocks) > 0 && msg.Role == "assistant") ||
|
|
msg.Role != "assistant" {
|
|
hasTextBlock := false
|
|
for _, block := range msg.Content {
|
|
if block.Type == "text" {
|
|
hasTextBlock = true
|
|
break
|
|
}
|
|
}
|
|
if len(prunedBlocks) > 0 || hasTextBlock {
|
|
msg.Content = prunedBlocks
|
|
prunedMessages = append(prunedMessages, msg)
|
|
}
|
|
}
|
|
}
|
|
return prunedMessages
|
|
}
|
|
|
|
func getTerminalWidth() int {
|
|
width, _, err := term.GetSize(int(os.Stdout.Fd()))
|
|
if err != nil {
|
|
return 80 // Fallback width
|
|
}
|
|
return width - 20
|
|
}
|
|
|
|
func handleHistoryCommand(messages []history.HistoryMessage) {
|
|
displayMessageHistory(messages)
|
|
}
|
|
|
|
func updateRenderer() error {
|
|
width := getTerminalWidth()
|
|
var err error
|
|
renderer, err = glamour.NewTermRenderer(
|
|
glamour.WithStandardStyle(styles.TokyoNightStyle),
|
|
glamour.WithWordWrap(width),
|
|
)
|
|
return err
|
|
}
|
|
|
|
// Method implementations for simpleMessage
|
|
func runPrompt(
|
|
ctx context.Context,
|
|
provider llm.Provider,
|
|
mcpClients map[string]mcpclient.MCPClient,
|
|
tools []llm.Tool,
|
|
prompt string,
|
|
messages *[]history.HistoryMessage,
|
|
) error {
|
|
// Display the user's prompt if it's not empty (i.e., not a tool response)
|
|
if prompt != "" {
|
|
fmt.Printf("\n%s\n", promptStyle.Render("You: "+prompt))
|
|
*messages = append(
|
|
*messages,
|
|
history.HistoryMessage{
|
|
Role: "user",
|
|
Content: []history.ContentBlock{{
|
|
Type: "text",
|
|
Text: prompt,
|
|
}},
|
|
},
|
|
)
|
|
}
|
|
|
|
var message llm.Message
|
|
var err error
|
|
backoff := initialBackoff
|
|
retries := 0
|
|
|
|
// Convert MessageParam to llm.Message for provider
|
|
// Messages already implement llm.Message interface
|
|
llmMessages := make([]llm.Message, len(*messages))
|
|
for i := range *messages {
|
|
llmMessages[i] = &(*messages)[i]
|
|
}
|
|
|
|
for {
|
|
action := func() {
|
|
message, err = provider.CreateMessage(
|
|
ctx,
|
|
prompt,
|
|
llmMessages,
|
|
tools,
|
|
)
|
|
}
|
|
_ = spinner.New().Title("Thinking...").Action(action).Run()
|
|
|
|
if err != nil {
|
|
// Check if it's an overloaded error
|
|
if strings.Contains(err.Error(), "overloaded_error") {
|
|
if retries >= maxRetries {
|
|
return fmt.Errorf(
|
|
"claude is currently overloaded. please wait a few minutes and try again",
|
|
)
|
|
}
|
|
|
|
log.Warn("Claude is overloaded, backing off...",
|
|
"attempt", retries+1,
|
|
"backoff", backoff.String())
|
|
|
|
time.Sleep(backoff)
|
|
backoff *= 2
|
|
if backoff > maxBackoff {
|
|
backoff = maxBackoff
|
|
}
|
|
retries++
|
|
continue
|
|
}
|
|
// If it's not an overloaded error, return the error immediately
|
|
return err
|
|
}
|
|
// If we got here, the request succeeded
|
|
break
|
|
}
|
|
|
|
var messageContent []history.ContentBlock
|
|
|
|
// Handle the message response
|
|
if str, err := renderer.Render("\nAssistant: "); message.GetContent() != "" && err == nil {
|
|
fmt.Print(str)
|
|
}
|
|
|
|
toolResults := []history.ContentBlock{}
|
|
messageContent = []history.ContentBlock{}
|
|
|
|
// Add text content
|
|
if message.GetContent() != "" {
|
|
if err := updateRenderer(); err != nil {
|
|
return fmt.Errorf("error updating renderer: %v", err)
|
|
}
|
|
str, err := renderer.Render(message.GetContent() + "\n")
|
|
if err != nil {
|
|
log.Error("Failed to render response", "error", err)
|
|
fmt.Print(message.GetContent() + "\n")
|
|
} else {
|
|
fmt.Print(str)
|
|
}
|
|
messageContent = append(messageContent, history.ContentBlock{
|
|
Type: "text",
|
|
Text: message.GetContent(),
|
|
})
|
|
}
|
|
|
|
// Handle tool calls
|
|
for _, toolCall := range message.GetToolCalls() {
|
|
log.Info("🔧 Using tool", "name", toolCall.GetName())
|
|
|
|
input, _ := json.Marshal(toolCall.GetArguments())
|
|
messageContent = append(messageContent, history.ContentBlock{
|
|
Type: "tool_use",
|
|
ID: toolCall.GetID(),
|
|
Name: toolCall.GetName(),
|
|
Input: input,
|
|
})
|
|
|
|
// Log usage statistics if available
|
|
inputTokens, outputTokens := message.GetUsage()
|
|
if inputTokens > 0 || outputTokens > 0 {
|
|
log.Info("Usage statistics",
|
|
"input_tokens", inputTokens,
|
|
"output_tokens", outputTokens,
|
|
"total_tokens", inputTokens+outputTokens)
|
|
}
|
|
|
|
parts := strings.Split(toolCall.GetName(), "__")
|
|
if len(parts) != 2 {
|
|
fmt.Printf(
|
|
"Error: Invalid tool name format: %s\n",
|
|
toolCall.GetName(),
|
|
)
|
|
continue
|
|
}
|
|
|
|
serverName, toolName := parts[0], parts[1]
|
|
mcpClient, ok := mcpClients[serverName]
|
|
if !ok {
|
|
fmt.Printf("Error: Server not found: %s\n", serverName)
|
|
continue
|
|
}
|
|
|
|
var toolArgs map[string]interface{}
|
|
if err := json.Unmarshal(input, &toolArgs); err != nil {
|
|
fmt.Printf("Error parsing tool arguments: %v\n", err)
|
|
continue
|
|
}
|
|
|
|
var toolResultPtr *mcp.CallToolResult
|
|
action := func() {
|
|
req := mcp.CallToolRequest{}
|
|
req.Params.Name = toolName
|
|
req.Params.Arguments = toolArgs
|
|
toolResultPtr, err = mcpClient.CallTool(
|
|
context.Background(),
|
|
req,
|
|
)
|
|
}
|
|
_ = spinner.New().
|
|
Title(fmt.Sprintf("Running tool %s...", toolName)).
|
|
Action(action).
|
|
Run()
|
|
|
|
if err != nil {
|
|
errMsg := fmt.Sprintf(
|
|
"Error calling tool %s: %v",
|
|
toolName,
|
|
err,
|
|
)
|
|
fmt.Printf("\n%s\n", errorStyle.Render(errMsg))
|
|
|
|
// Add error message as tool result
|
|
toolResults = append(toolResults, history.ContentBlock{
|
|
Type: "tool_result",
|
|
ToolUseID: toolCall.GetID(),
|
|
Content: []history.ContentBlock{{
|
|
Type: "text",
|
|
Text: errMsg,
|
|
}},
|
|
})
|
|
continue
|
|
}
|
|
|
|
toolResult := *toolResultPtr
|
|
|
|
if toolResult.Content != nil {
|
|
log.Debug("raw tool result content", "content", toolResult.Content)
|
|
|
|
// Create the tool result block
|
|
resultBlock := history.ContentBlock{
|
|
Type: "tool_result",
|
|
ToolUseID: toolCall.GetID(),
|
|
Content: toolResult.Content,
|
|
}
|
|
|
|
// Extract text content
|
|
var resultText string
|
|
// Handle array content directly since we know it's []interface{}
|
|
for _, item := range toolResult.Content {
|
|
if contentMap, ok := item.(mcp.TextContent); ok {
|
|
resultText += fmt.Sprintf("%v ", contentMap.Text)
|
|
}
|
|
}
|
|
|
|
resultBlock.Text = strings.TrimSpace(resultText)
|
|
log.Debug("created tool result block",
|
|
"block", resultBlock,
|
|
"tool_id", toolCall.GetID())
|
|
|
|
toolResults = append(toolResults, resultBlock)
|
|
}
|
|
}
|
|
|
|
*messages = append(*messages, history.HistoryMessage{
|
|
Role: message.GetRole(),
|
|
Content: messageContent,
|
|
})
|
|
|
|
if len(toolResults) > 0 {
|
|
for _, toolResult := range toolResults {
|
|
*messages = append(*messages, history.HistoryMessage{
|
|
Role: "tool",
|
|
Content: []history.ContentBlock{toolResult},
|
|
})
|
|
}
|
|
// Make another call to get Claude's response to the tool results
|
|
return runPrompt(ctx, provider, mcpClients, tools, "", messages)
|
|
}
|
|
|
|
fmt.Println() // Add spacing
|
|
return nil
|
|
}
|
|
|
|
func runMCPHost(ctx context.Context) error {
|
|
// Set up logging based on debug flag
|
|
if debugMode {
|
|
log.SetLevel(log.DebugLevel)
|
|
// Enable caller information for debug logs
|
|
log.SetReportCaller(true)
|
|
} else {
|
|
log.SetLevel(log.InfoLevel)
|
|
log.SetReportCaller(false)
|
|
}
|
|
|
|
systemPrompt, err := loadSystemPrompt(systemPromptFile)
|
|
if err != nil {
|
|
return fmt.Errorf("error loading system prompt: %v", err)
|
|
}
|
|
|
|
// Create the provider based on the model flag
|
|
provider, err := createProvider(ctx, modelFlag, systemPrompt)
|
|
if err != nil {
|
|
return fmt.Errorf("error creating provider: %v", err)
|
|
}
|
|
|
|
// Split the model flag and get just the model name
|
|
parts := strings.SplitN(modelFlag, ":", 2)
|
|
log.Info("Model loaded",
|
|
"provider", provider.Name(),
|
|
"model", parts[1])
|
|
|
|
mcpConfig, err := loadMCPConfig()
|
|
if err != nil {
|
|
return fmt.Errorf("error loading MCP config: %v", err)
|
|
}
|
|
|
|
mcpClients, err := createMCPClients(mcpConfig)
|
|
if err != nil {
|
|
return fmt.Errorf("error creating MCP clients: %v", err)
|
|
}
|
|
|
|
defer func() {
|
|
log.Info("Shutting down MCP servers...")
|
|
for name, client := range mcpClients {
|
|
if err := client.Close(); err != nil {
|
|
log.Error("Failed to close server", "name", name, "error", err)
|
|
} else {
|
|
log.Info("Server closed", "name", name)
|
|
}
|
|
}
|
|
}()
|
|
|
|
for name := range mcpClients {
|
|
log.Info("Server connected", "name", name)
|
|
}
|
|
|
|
var allTools []llm.Tool
|
|
for serverName, mcpClient := range mcpClients {
|
|
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
|
toolsResult, err := mcpClient.ListTools(ctx, mcp.ListToolsRequest{})
|
|
cancel()
|
|
|
|
if err != nil {
|
|
log.Error(
|
|
"Error fetching tools",
|
|
"server",
|
|
serverName,
|
|
"error",
|
|
err,
|
|
)
|
|
continue
|
|
}
|
|
|
|
serverTools := mcpToolsToAnthropicTools(serverName, toolsResult.Tools)
|
|
allTools = append(allTools, serverTools...)
|
|
log.Info(
|
|
"Tools loaded",
|
|
"server",
|
|
serverName,
|
|
"count",
|
|
len(toolsResult.Tools),
|
|
)
|
|
}
|
|
|
|
if err := updateRenderer(); err != nil {
|
|
return fmt.Errorf("error initializing renderer: %v", err)
|
|
}
|
|
|
|
messages := make([]history.HistoryMessage, 0)
|
|
|
|
// Main interaction loop
|
|
for {
|
|
var prompt string
|
|
err := huh.NewForm(huh.NewGroup(huh.NewText().
|
|
Title("Enter your prompt (Type /help for commands, Ctrl+C to quit)").
|
|
Value(&prompt).
|
|
CharLimit(5000)),
|
|
).WithWidth(getTerminalWidth()).
|
|
WithTheme(huh.ThemeCharm()).
|
|
Run()
|
|
|
|
if err != nil {
|
|
// Check if it's a user abort (Ctrl+C)
|
|
if errors.Is(err, huh.ErrUserAborted) {
|
|
fmt.Println("\nGoodbye!")
|
|
return nil // Exit cleanly
|
|
}
|
|
return err // Return other errors normally
|
|
}
|
|
|
|
if prompt == "" {
|
|
continue
|
|
}
|
|
|
|
// Handle slash commands
|
|
handled, err := handleSlashCommand(
|
|
prompt,
|
|
mcpConfig,
|
|
mcpClients,
|
|
messages,
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if handled {
|
|
continue
|
|
}
|
|
|
|
if len(messages) > 0 {
|
|
messages = pruneMessages(messages)
|
|
}
|
|
err = runPrompt(ctx, provider, mcpClients, allTools, prompt, &messages)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
// loadSystemPrompt reads the JSON file at filePath and returns the value of
// its top-level "systemPrompt" field; all other fields are ignored.
//
// An empty filePath means no system prompt file was given and yields "" with a
// nil error. A file that cannot be read or is not valid JSON yields an error
// that wraps the underlying cause, so callers can inspect it with errors.Is
// and errors.As.
func loadSystemPrompt(filePath string) (string, error) {
	if filePath == "" {
		return "", nil
	}

	data, err := os.ReadFile(filePath)
	if err != nil {
		return "", fmt.Errorf("reading system prompt file: %w", err)
	}

	// Parse only the systemPrompt field.
	var config struct {
		SystemPrompt string `json:"systemPrompt"`
	}
	if err := json.Unmarshal(data, &config); err != nil {
		return "", fmt.Errorf("parsing system prompt file %q: %w", filePath, err)
	}

	return config.SystemPrompt, nil
}
|