Files
kit/cmd/root.go
T
Hyunchul Jung beaa623fc5 feat: support system prompt (#29)
* support system prompt for ollama

* support system prompt for openai

* support system prompt for google gemini

* fixed setting for google gemini system-prompt

* support system prompt for anthropic

* edit README

* edit comments

* fix system-prompt flag desc

* modified var name for systemPrompt
2025-04-23 17:39:21 +03:00

627 lines
16 KiB
Go

package cmd
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"strings"
"time"
"github.com/charmbracelet/glamour/styles"
"github.com/charmbracelet/huh"
"github.com/charmbracelet/huh/spinner"
"github.com/charmbracelet/log"
"github.com/charmbracelet/glamour"
mcpclient "github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcphost/pkg/history"
"github.com/mark3labs/mcphost/pkg/llm"
"github.com/mark3labs/mcphost/pkg/llm/anthropic"
"github.com/mark3labs/mcphost/pkg/llm/google"
"github.com/mark3labs/mcphost/pkg/llm/ollama"
"github.com/mark3labs/mcphost/pkg/llm/openai"
"github.com/spf13/cobra"
"golang.org/x/term"
)
// Package-level state for the root command. Apart from renderer, every
// variable here is bound to a command-line flag in init.
var (
	// renderer turns markdown replies into terminal output; updateRenderer
	// rebuilds it.
	renderer *glamour.TermRenderer

	// General settings.
	configFile       string
	systemPromptFile string
	messageWindow    int
	modelFlag        string // provider:model selector (--model / -m)

	// Endpoint overrides; an empty value leaves the provider's default URL.
	openaiBaseURL    string
	anthropicBaseURL string

	// API keys; when empty, createProvider consults environment variables.
	openaiAPIKey    string
	anthropicAPIKey string
	googleAPIKey    string
)
// Retry policy for provider requests that fail with an "overloaded_error".
// The wait starts at initialBackoff and doubles after every retry. It never
// exceeds maxBackoff. At most maxRetries retries follow the first attempt.
const (
	initialBackoff = time.Second
	maxBackoff     = 30 * time.Second
	maxRetries     = 5
)
// rootCmd is the top-level "mcphost" command. Running it starts the
// interactive chat session via runMCPHost; the command and its positional
// arguments are not inspected.
var rootCmd = &cobra.Command{
	Use:   "mcphost",
	Short: "Chat with AI models through a unified interface",
	Long: `MCPHost is a CLI tool that allows you to interact with various AI models
through a unified interface. It supports various tools through MCP servers
and provides streaming responses.
Available models can be specified using the --model flag:
- Anthropic Claude (default): anthropic:claude-3-5-sonnet-latest
- OpenAI: openai:gpt-4
- Ollama models: ollama:modelname
- Google: google:modelname
Example:
mcphost -m ollama:qwen2.5:3b
mcphost -m openai:gpt-4
mcphost -m google:gemini-2.0-flash`,
	RunE: func(_ *cobra.Command, _ []string) error {
		// The session runs under a fresh background context.
		sessionCtx := context.Background()
		return runMCPHost(sessionCtx)
	},
}
// Execute runs rootCmd and exits the process with status 1 if it returns an
// error. The error value itself is not printed here.
func Execute() {
	err := rootCmd.Execute()
	if err == nil {
		return
	}
	os.Exit(1)
}
// debugMode is bound to --debug; runMCPHost uses it to choose the log level.
var debugMode bool

// init registers rootCmd's persistent flags. Each flag writes into one of the
// package-level variables declared in this file.
func init() {
	pf := rootCmd.PersistentFlags()

	// General behaviour.
	pf.StringVar(&configFile, "config", "", "config file (default is $HOME/.mcp.json)")
	pf.StringVar(&systemPromptFile, "system-prompt", "", "system prompt json file")
	pf.IntVar(&messageWindow, "message-window", 10, "number of messages to keep in context")
	pf.StringVarP(&modelFlag, "model", "m", "anthropic:claude-3-5-sonnet-latest",
		"model to use (format: provider:model, e.g. anthropic:claude-3-5-sonnet-latest or ollama:qwen2.5:3b)")
	pf.BoolVar(&debugMode, "debug", false, "enable debug logging")

	// Provider endpoints.
	pf.StringVar(&openaiBaseURL, "openai-url", "", "base URL for OpenAI API (defaults to api.openai.com)")
	pf.StringVar(&anthropicBaseURL, "anthropic-url", "", "base URL for Anthropic API (defaults to api.anthropic.com)")

	// Provider credentials.
	pf.StringVar(&openaiAPIKey, "openai-api-key", "", "OpenAI API key")
	pf.StringVar(&anthropicAPIKey, "anthropic-api-key", "", "Anthropic API key")
	pf.StringVar(&googleAPIKey, "google-api-key", "", "Google (Gemini) API key")
}
// createProvider builds the llm.Provider selected by modelString.
//
// modelString has the form "provider:model", for example
// "anthropic:claude-3-5-sonnet-latest" or "ollama:qwen2.5:3b". Only the first
// colon separates the provider from the model, so model names may contain
// colons. Both parts must be non-empty. systemPrompt is passed unchanged to
// the provider constructor.
//
// API keys come from the matching --*-api-key flag. Otherwise they come from
// the environment: ANTHROPIC_API_KEY, OPENAI_API_KEY, or GOOGLE_API_KEY then
// GEMINI_API_KEY. Ollama needs no key.
//
// It returns an error for a malformed modelString, an unsupported provider, a
// missing API key, or a failing provider constructor.
func createProvider(ctx context.Context, modelString, systemPrompt string) (llm.Provider, error) {
	provider, model, ok := strings.Cut(modelString, ":")
	if !ok || provider == "" || model == "" {
		return nil, fmt.Errorf(
			"invalid model format. Expected provider:model, got %s",
			modelString,
		)
	}

	switch provider {
	case "anthropic":
		apiKey := anthropicAPIKey
		if apiKey == "" {
			apiKey = os.Getenv("ANTHROPIC_API_KEY")
		}
		if apiKey == "" {
			return nil, errors.New(
				"Anthropic API key not provided. Use --anthropic-api-key flag or ANTHROPIC_API_KEY environment variable",
			)
		}
		return anthropic.NewProvider(apiKey, anthropicBaseURL, model, systemPrompt), nil

	case "ollama":
		return ollama.NewProvider(model, systemPrompt)

	case "openai":
		apiKey := openaiAPIKey
		if apiKey == "" {
			apiKey = os.Getenv("OPENAI_API_KEY")
		}
		if apiKey == "" {
			return nil, errors.New(
				"OpenAI API key not provided. Use --openai-api-key flag or OPENAI_API_KEY environment variable",
			)
		}
		return openai.NewProvider(apiKey, openaiBaseURL, model, systemPrompt), nil

	case "google":
		apiKey := googleAPIKey
		if apiKey == "" {
			apiKey = os.Getenv("GOOGLE_API_KEY")
		}
		if apiKey == "" {
			// Google calls this GEMINI_API_KEY in e.g. AI Studio, so accept both names.
			apiKey = os.Getenv("GEMINI_API_KEY")
		}
		if apiKey == "" {
			// Fail fast like the other keyed providers instead of handing an
			// empty key to the client and failing later with a vaguer error.
			return nil, errors.New(
				"Google API key not provided. Use --google-api-key flag or GOOGLE_API_KEY/GEMINI_API_KEY environment variable",
			)
		}
		return google.NewProvider(ctx, apiKey, model, systemPrompt)

	default:
		return nil, fmt.Errorf("unsupported provider: %s", provider)
	}
}
// pruneMessages keeps only the most recent messageWindow messages. It then
// removes tool blocks whose partner fell outside the window:
//   - a "tool_use" block is kept only if a "tool_result" with a matching
//     ToolUseID remains;
//   - a "tool_result" block is kept only if its "tool_use" remains.
//
// All other block types are kept. A message left with no blocks is dropped.
// If the history already fits in the window, it is returned unchanged and
// no orphan filtering is done. A negative messageWindow is treated as 0, so
// every message is dropped.
func pruneMessages(messages []history.HistoryMessage) []history.HistoryMessage {
	// Clamp the window: a negative --message-window would otherwise produce an
	// out-of-range slice expression below and panic.
	window := messageWindow
	if window < 0 {
		window = 0
	}
	if len(messages) <= window {
		return messages
	}

	// Keep only the most recent messages based on window size.
	messages = messages[len(messages)-window:]

	// First pass: collect the IDs of tool calls and tool results still in the window.
	toolUseIDs := make(map[string]bool)
	toolResultIDs := make(map[string]bool)
	for _, msg := range messages {
		for _, block := range msg.Content {
			switch block.Type {
			case "tool_use":
				toolUseIDs[block.ID] = true
			case "tool_result":
				toolResultIDs[block.ToolUseID] = true
			}
		}
	}

	// Second pass: drop orphaned tool calls and results. Text blocks are
	// always kept, so "has a text block" implies "has kept blocks". Keeping
	// every message with at least one remaining block therefore matches the
	// previous role/text-based condition.
	var pruned []history.HistoryMessage
	for _, msg := range messages {
		var kept []history.ContentBlock
		for _, block := range msg.Content {
			switch block.Type {
			case "tool_use":
				if !toolResultIDs[block.ID] {
					continue
				}
			case "tool_result":
				if !toolUseIDs[block.ToolUseID] {
					continue
				}
			}
			kept = append(kept, block)
		}
		if len(kept) > 0 {
			msg.Content = kept // msg is a copy; the caller's slice is not modified
			pruned = append(pruned, msg)
		}
	}
	return pruned
}
// getTerminalWidth returns the column width used for wrapped output. This is
// the width of the terminal on stdout minus a 20-column margin.
//
// It returns 80 when the size cannot be determined or the terminal reports a
// non-positive width. On a terminal of 20 columns or fewer it returns the
// full width. The result is therefore always positive, which matters because
// callers use it as a word-wrap / form width.
func getTerminalWidth() int {
	const (
		fallbackWidth = 80
		margin        = 20
	)
	width, _, err := term.GetSize(int(os.Stdout.Fd()))
	if err != nil || width <= 0 {
		return fallbackWidth
	}
	if width <= margin {
		// Too narrow to spare the margin; subtracting it would give a
		// zero or negative width.
		return width
	}
	return width - margin
}
// handleHistoryCommand prints the conversation so far by delegating directly
// to displayMessageHistory.
func handleHistoryCommand(msgs []history.HistoryMessage) {
	displayMessageHistory(msgs)
}
// updateRenderer rebuilds the package-level markdown renderer with the Tokyo
// Night style, wrapping at the current getTerminalWidth(). The renderer
// variable is assigned whatever glamour returns, even when it also returns an
// error. That error is passed back to the caller.
func updateRenderer() error {
	wrapAt := getTerminalWidth()
	r, buildErr := glamour.NewTermRenderer(
		glamour.WithStandardStyle(styles.TokyoNightStyle),
		glamour.WithWordWrap(wrapAt),
	)
	renderer = r
	return buildErr
}
// runPrompt performs one conversational turn against provider.
//
// A non-empty prompt is echoed and appended to *messages as a user message.
// An empty prompt appends nothing; it is used for the follow-up request after
// tool results. The provider's reply is printed. Every tool call in the reply
// is then executed on the MCP server named by the tool's "server__tool" name.
//
// The assistant reply is appended to *messages, followed by one "tool"
// message per tool result. Every tool_use block is paired with a tool_result
// block, which is an error result when the call could not be made. This way
// the history never holds an unanswered tool call. If any tool calls were
// made, runPrompt recurses with an empty prompt so the model can respond to
// the results.
func runPrompt(
	ctx context.Context,
	provider llm.Provider,
	mcpClients map[string]mcpclient.MCPClient,
	tools []llm.Tool,
	prompt string,
	messages *[]history.HistoryMessage,
) error {
	// Display the user's prompt if it's not empty (i.e., not a tool response).
	if prompt != "" {
		fmt.Printf("\n%s\n", promptStyle.Render("You: "+prompt))
		*messages = append(*messages, history.HistoryMessage{
			Role: "user",
			Content: []history.ContentBlock{{
				Type: "text",
				Text: prompt,
			}},
		})
	}

	message, err := createMessageWithBackoff(ctx, provider, prompt, *messages, tools)
	if err != nil {
		return err
	}

	messageContent := []history.ContentBlock{}
	if text := message.GetContent(); text != "" {
		if err := printAssistantText(text); err != nil {
			return err
		}
		messageContent = append(messageContent, history.ContentBlock{
			Type: "text",
			Text: text,
		})
	}

	// Log usage statistics once per reply, if the provider reports them.
	if inputTokens, outputTokens := message.GetUsage(); inputTokens > 0 || outputTokens > 0 {
		log.Info("Usage statistics",
			"input_tokens", inputTokens,
			"output_tokens", outputTokens,
			"total_tokens", inputTokens+outputTokens)
	}

	toolResults := []history.ContentBlock{}
	for _, toolCall := range message.GetToolCalls() {
		toolUse, toolResult := executeToolCall(
			ctx,
			mcpClients,
			toolCall.GetID(),
			toolCall.GetName(),
			toolCall.GetArguments(),
		)
		messageContent = append(messageContent, toolUse)
		toolResults = append(toolResults, toolResult)
	}

	*messages = append(*messages, history.HistoryMessage{
		Role:    message.GetRole(),
		Content: messageContent,
	})

	if len(toolResults) == 0 {
		fmt.Println() // Add spacing
		return nil
	}
	for _, toolResult := range toolResults {
		*messages = append(*messages, history.HistoryMessage{
			Role:    "tool",
			Content: []history.ContentBlock{toolResult},
		})
	}
	// Make another call so the model can respond to the tool results.
	return runPrompt(ctx, provider, mcpClients, tools, "", messages)
}

// createMessageWithBackoff asks provider for the next assistant message and
// shows a spinner while waiting.
//
// Errors whose text contains "overloaded_error" are retried with exponential
// backoff: the wait starts at initialBackoff, doubles each time up to
// maxBackoff, and at most maxRetries retries are made. Any other error is
// returned immediately. Cancelling ctx during a backoff wait returns
// ctx.Err().
func createMessageWithBackoff(
	ctx context.Context,
	provider llm.Provider,
	prompt string,
	messages []history.HistoryMessage,
	tools []llm.Tool,
) (llm.Message, error) {
	// *history.HistoryMessage implements llm.Message; point into the caller's
	// backing array instead of copying the messages.
	llmMessages := make([]llm.Message, len(messages))
	for i := range messages {
		llmMessages[i] = &messages[i]
	}

	backoff := initialBackoff
	for retries := 0; ; retries++ {
		var (
			message llm.Message
			err     error
		)
		_ = spinner.New().
			Title("Thinking...").
			Action(func() {
				message, err = provider.CreateMessage(ctx, prompt, llmMessages, tools)
			}).
			Run()
		if err == nil {
			return message, nil
		}
		// Only the "overloaded" condition is retried; anything else is returned as-is.
		if !strings.Contains(err.Error(), "overloaded_error") {
			return nil, err
		}
		if retries >= maxRetries {
			return nil, errors.New(
				"claude is currently overloaded. please wait a few minutes and try again",
			)
		}
		log.Warn("Claude is overloaded, backing off...",
			"attempt", retries+1,
			"backoff", backoff.String())

		// Wait out the backoff, but stop early if the caller cancels ctx.
		timer := time.NewTimer(backoff)
		select {
		case <-ctx.Done():
			timer.Stop()
			return nil, ctx.Err()
		case <-timer.C:
		}

		backoff *= 2
		if backoff > maxBackoff {
			backoff = maxBackoff
		}
	}
}

// printAssistantText prints the "Assistant:" header and then the reply text
// rendered as markdown. It rebuilds the renderer first so the word wrap
// follows the current terminal width. If rendering the reply fails, the raw
// text is printed instead. Only a failure to rebuild the renderer is returned.
func printAssistantText(text string) error {
	if header, err := renderer.Render("\nAssistant: "); err == nil {
		fmt.Print(header)
	}
	if err := updateRenderer(); err != nil {
		return fmt.Errorf("error updating renderer: %w", err)
	}
	rendered, err := renderer.Render(text + "\n")
	if err != nil {
		log.Error("Failed to render response", "error", err)
		fmt.Print(text + "\n")
		return nil
	}
	fmt.Print(rendered)
	return nil
}

// executeToolCall runs a single model-requested tool call.
//
// name must be "server__tool". args are round-tripped through JSON into a
// map and sent to that server's MCP client under ctx. The function returns
// two blocks: the tool_use block to record on the assistant message, and the
// tool_result block that answers it. On any failure (bad name, unknown
// server, bad arguments, or a failed call) it prints the error and returns
// an error tool_result rather than dropping the call.
func executeToolCall(
	ctx context.Context,
	mcpClients map[string]mcpclient.MCPClient,
	id, name string,
	args interface{},
) (history.ContentBlock, history.ContentBlock) {
	log.Info("🔧 Using tool", "name", name)

	input, marshalErr := json.Marshal(args)
	if marshalErr != nil {
		// Keep the recorded tool_use well-formed so the error result has a partner.
		input = []byte("{}")
	}
	toolUse := history.ContentBlock{
		Type:  "tool_use",
		ID:    id,
		Name:  name,
		Input: input,
	}
	fail := func(errMsg string) (history.ContentBlock, history.ContentBlock) {
		fmt.Printf("\n%s\n", errorStyle.Render(errMsg))
		return toolUse, textToolResult(id, errMsg)
	}
	if marshalErr != nil {
		return fail(fmt.Sprintf("Error encoding tool arguments: %v", marshalErr))
	}

	parts := strings.Split(name, "__")
	if len(parts) != 2 {
		return fail(fmt.Sprintf("Error: Invalid tool name format: %s", name))
	}
	serverName, toolName := parts[0], parts[1]

	mcpClient, ok := mcpClients[serverName]
	if !ok {
		return fail(fmt.Sprintf("Error: Server not found: %s", serverName))
	}

	var toolArgs map[string]interface{}
	if err := json.Unmarshal(input, &toolArgs); err != nil {
		return fail(fmt.Sprintf("Error parsing tool arguments: %v", err))
	}

	var (
		result  *mcp.CallToolResult
		callErr error
	)
	_ = spinner.New().
		Title(fmt.Sprintf("Running tool %s...", toolName)).
		Action(func() {
			req := mcp.CallToolRequest{}
			req.Params.Name = toolName
			req.Params.Arguments = toolArgs
			result, callErr = mcpClient.CallTool(ctx, req)
		}).
		Run()
	if callErr != nil {
		return fail(fmt.Sprintf("Error calling tool %s: %v", toolName, callErr))
	}
	if result == nil || result.Content == nil {
		// Still answer the tool_use so the history stays well-formed.
		return toolUse, textToolResult(id, fmt.Sprintf("Tool %s returned no content", toolName))
	}

	log.Debug("raw tool result content", "content", result.Content)

	// Collect the text items of the result into a single space-separated string.
	var texts []string
	for _, item := range result.Content {
		if textContent, ok := item.(mcp.TextContent); ok {
			texts = append(texts, textContent.Text)
		}
	}
	resultBlock := history.ContentBlock{
		Type:      "tool_result",
		ToolUseID: id,
		Content:   result.Content,
		Text:      strings.TrimSpace(strings.Join(texts, " ")),
	}
	log.Debug("created tool result block",
		"block", resultBlock,
		"tool_id", id)
	return toolUse, resultBlock
}

// textToolResult builds a tool_result block that answers the tool_use with ID
// toolUseID. Its content is a single text block holding text.
func textToolResult(toolUseID, text string) history.ContentBlock {
	return history.ContentBlock{
		Type:      "tool_result",
		ToolUseID: toolUseID,
		Content: []history.ContentBlock{{
			Type: "text",
			Text: text,
		}},
	}
}
// runMCPHost runs the interactive chat session. It:
//   - configures logging from --debug;
//   - loads the optional system prompt;
//   - builds the provider selected by --model;
//   - starts the configured MCP servers and collects their tools;
//   - loops reading prompts from the terminal.
//
// Slash commands go to handleSlashCommand. Other non-blank prompts prune the
// history and are sent through runPrompt. The session ends cleanly on Ctrl+C
// and ends with an error if a form, slash command, or runPrompt fails. MCP
// clients are closed when the function returns.
func runMCPHost(ctx context.Context) error {
	// Set up logging based on the debug flag.
	if debugMode {
		log.SetLevel(log.DebugLevel)
		// Enable caller information for debug logs.
		log.SetReportCaller(true)
	} else {
		log.SetLevel(log.InfoLevel)
		log.SetReportCaller(false)
	}

	systemPrompt, err := loadSystemPrompt(systemPromptFile)
	if err != nil {
		return fmt.Errorf("error loading system prompt: %w", err)
	}

	// Create the provider based on the model flag.
	provider, err := createProvider(ctx, modelFlag, systemPrompt)
	if err != nil {
		return fmt.Errorf("error creating provider: %w", err)
	}
	// createProvider has already validated the provider:model format.
	_, modelName, _ := strings.Cut(modelFlag, ":")
	log.Info("Model loaded",
		"provider", provider.Name(),
		"model", modelName)

	mcpConfig, err := loadMCPConfig()
	if err != nil {
		return fmt.Errorf("error loading MCP config: %w", err)
	}

	mcpClients, err := createMCPClients(mcpConfig)
	if err != nil {
		return fmt.Errorf("error creating MCP clients: %w", err)
	}
	defer func() {
		log.Info("Shutting down MCP servers...")
		for name, client := range mcpClients {
			if err := client.Close(); err != nil {
				log.Error("Failed to close server", "name", name, "error", err)
			} else {
				log.Info("Server closed", "name", name)
			}
		}
	}()

	for name := range mcpClients {
		log.Info("Server connected", "name", name)
	}

	// Gather tools from every server. A server that fails to list its tools is
	// logged and skipped rather than aborting the session.
	var allTools []llm.Tool
	for serverName, mcpClient := range mcpClients {
		listCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
		toolsResult, err := mcpClient.ListTools(listCtx, mcp.ListToolsRequest{})
		cancel()
		if err != nil {
			log.Error("Error fetching tools", "server", serverName, "error", err)
			continue
		}
		allTools = append(allTools, mcpToolsToAnthropicTools(serverName, toolsResult.Tools)...)
		log.Info("Tools loaded", "server", serverName, "count", len(toolsResult.Tools))
	}

	if err := updateRenderer(); err != nil {
		return fmt.Errorf("error initializing renderer: %w", err)
	}

	messages := make([]history.HistoryMessage, 0)

	// Main interaction loop.
	for {
		var prompt string
		err := huh.NewForm(huh.NewGroup(huh.NewText().
			Title("Enter your prompt (Type /help for commands, Ctrl+C to quit)").
			Value(&prompt).
			CharLimit(5000)),
		).WithWidth(getTerminalWidth()).
			WithTheme(huh.ThemeCharm()).
			Run()
		if err != nil {
			// A user abort (Ctrl+C) ends the session cleanly.
			if errors.Is(err, huh.ErrUserAborted) {
				fmt.Println("\nGoodbye!")
				return nil
			}
			return err
		}

		// Skip blank submissions. Providers may reject whitespace-only text,
		// and any runPrompt error below would end the whole session.
		if strings.TrimSpace(prompt) == "" {
			continue
		}

		handled, err := handleSlashCommand(prompt, mcpConfig, mcpClients, messages)
		if err != nil {
			return err
		}
		if handled {
			continue
		}

		if len(messages) > 0 {
			messages = pruneMessages(messages)
		}

		if err := runPrompt(ctx, provider, mcpClients, allTools, prompt, &messages); err != nil {
			return err
		}
	}
}
// loadSystemPrompt reads the system prompt from the JSON file at filePath.
//
// Only the top-level "systemPrompt" string field is used; other fields are
// ignored. An empty filePath means "no system prompt" and returns ("", nil).
// A file without a "systemPrompt" field also yields "". Read and parse
// failures are wrapped with %w, so callers can inspect the cause, for example
// errors.Is(err, os.ErrNotExist).
func loadSystemPrompt(filePath string) (string, error) {
	if filePath == "" {
		return "", nil
	}

	data, err := os.ReadFile(filePath)
	if err != nil {
		// The *PathError from os.ReadFile already names the file.
		return "", fmt.Errorf("error reading system prompt file: %w", err)
	}

	// Parse only the systemPrompt field.
	var config struct {
		SystemPrompt string `json:"systemPrompt"`
	}
	if err := json.Unmarshal(data, &config); err != nil {
		return "", fmt.Errorf("error parsing system prompt file %q: %w", filePath, err)
	}
	return config.SystemPrompt, nil
}