diff --git a/README.md b/README.md index 379ab36a..16de1e5b 100644 --- a/README.md +++ b/README.md @@ -76,6 +76,7 @@ go install github.com/mark3labs/mcphost@latest ## Configuration ⚙️ +### MCP-server MCPHost will automatically create a configuration file at `~/.mcp.json` if it doesn't exist. You can also specify a custom location using the `--config` flag: ```json @@ -107,6 +108,23 @@ Each MCP server entry requires: - For SQLite server: `mcp-server-sqlite` with database path - For filesystem server: `@modelcontextprotocol/server-filesystem` with directory path + +### System-Prompt + +You can specify a custom system prompt using the `--system-prompt` flag. The system prompt should be a JSON file containing the instructions and context you want to provide to the model. For example: + +```json +{ + "systemPrompt": "You're a cat. Name is Neko" +} +``` + +Usage: +```bash +mcphost --system-prompt ./my-system-prompt.json +``` + + ## Usage 🚀 MCPHost is a CLI tool that allows you to interact with various AI models through a unified interface. It supports various tools through MCP servers. @@ -136,6 +154,7 @@ mcphost --model openai: \ - `--anthropic-url string`: Base URL for Anthropic API (defaults to api.anthropic.com) - `--anthropic-api-key string`: Anthropic API key (can also be set via ANTHROPIC_API_KEY environment variable) - `--config string`: Config file location (default is $HOME/.mcp.json) +- `--system-prompt string`: system-prompt file location - `--debug`: Enable debug logging - `--message-window int`: Number of messages to keep in context (default: 10) - `-m, --model string`: Model to use (format: provider:model) (default "anthropic:claude-3-5-sonnet-latest") diff --git a/cmd/root.go b/cmd/root.go index 940e63f8..02c302ab 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -30,6 +30,7 @@ import ( var ( renderer *glamour.TermRenderer configFile string + systemPromptFile string messageWindow int modelFlag string // New flag for model selection openaiBaseURL string // Base URL for OpenAI API @@ -78,6 +79,8 @@ var debugMode bool func init() { rootCmd.PersistentFlags(). StringVar(&configFile, "config", "", "config file (default is $HOME/.mcp.json)") + rootCmd.PersistentFlags(). + StringVar(&systemPromptFile, "system-prompt", "", "system prompt json file") rootCmd.PersistentFlags(). IntVar(&messageWindow, "message-window", 10, "number of messages to keep in context") rootCmd.PersistentFlags(). @@ -97,7 +100,7 @@ func init() { } // Add new function to create provider -func createProvider(ctx context.Context, modelString string) (llm.Provider, error) { +func createProvider(ctx context.Context, modelString, systemPrompt string) (llm.Provider, error) { parts := strings.SplitN(modelString, ":", 2) if len(parts) < 2 { return nil, fmt.Errorf( @@ -121,10 +124,10 @@ func createProvider(ctx context.Context, modelString string) (llm.Provider, erro "Anthropic API key not provided. Use --anthropic-api-key flag or ANTHROPIC_API_KEY environment variable", ) } - return anthropic.NewProvider(apiKey, anthropicBaseURL, model), nil + return anthropic.NewProvider(apiKey, anthropicBaseURL, model, systemPrompt), nil case "ollama": - return ollama.NewProvider(model) + return ollama.NewProvider(model, systemPrompt) case "openai": apiKey := openaiAPIKey @@ -137,7 +140,7 @@ func createProvider(ctx context.Context, modelString string) (llm.Provider, erro "OpenAI API key not provided. Use --openai-api-key flag or OPENAI_API_KEY environment variable", ) } - return openai.NewProvider(apiKey, openaiBaseURL, model), nil + return openai.NewProvider(apiKey, openaiBaseURL, model, systemPrompt), nil case "google": apiKey := googleAPIKey @@ -148,7 +151,7 @@ func createProvider(ctx context.Context, modelString string) (llm.Provider, erro // The project structure is provider specific, but Google calls this GEMINI_API_KEY in e.g. AI Studio. Support both. apiKey = os.Getenv("GEMINI_API_KEY") } - return google.NewProvider(ctx, apiKey, model) + return google.NewProvider(ctx, apiKey, model, systemPrompt) default: return nil, fmt.Errorf("unsupported provider: %s", provider) @@ -476,8 +479,13 @@ func runMCPHost(ctx context.Context) error { log.SetReportCaller(false) } + systemPrompt, err := loadSystemPrompt(systemPromptFile) + if err != nil { + return fmt.Errorf("error loading system prompt: %v", err) + } + // Create the provider based on the model flag - provider, err := createProvider(ctx, modelFlag) + provider, err := createProvider(ctx, modelFlag, systemPrompt) if err != nil { return fmt.Errorf("error creating provider: %v", err) } @@ -594,3 +602,25 @@ func runMCPHost(ctx context.Context) error { } } } + +// loadSystemPrompt loads the system prompt from a JSON file +func loadSystemPrompt(filePath string) (string, error) { + if filePath == "" { + return "", nil + } + + data, err := os.ReadFile(filePath) + if err != nil { + return "", fmt.Errorf("error reading config file: %v", err) + } + + // Parse only the systemPrompt field + var config struct { + SystemPrompt string `json:"systemPrompt"` + } + if err := json.Unmarshal(data, &config); err != nil { + return "", fmt.Errorf("error parsing config file: %v", err) + } + + return config.SystemPrompt, nil +} diff --git a/pkg/llm/anthropic/provider.go b/pkg/llm/anthropic/provider.go index 75b395eb..29d78f28 100644 --- a/pkg/llm/anthropic/provider.go +++ b/pkg/llm/anthropic/provider.go @@ -12,17 +12,19 @@ import ( ) type Provider struct { - client *Client - model string + client *Client + model string + systemPrompt string } -func NewProvider(apiKey string, baseURL string, model string) *Provider { +func NewProvider(apiKey, baseURL, model, systemPrompt string) *Provider { if model == "" { model = "claude-3-5-sonnet-20240620" // 默认模型 } return &Provider{ - client: NewClient(apiKey, baseURL), - model: model, + client: NewClient(apiKey, baseURL), + model: model, + systemPrompt: systemPrompt, } } @@ -135,6 +137,7 @@ func (p *Provider) CreateMessage( Messages: anthropicMessages, MaxTokens: 4096, Tools: anthropicTools, + System: p.systemPrompt, }) if err != nil { return nil, err diff --git a/pkg/llm/anthropic/types.go b/pkg/llm/anthropic/types.go index 551a4c65..45e07a24 100644 --- a/pkg/llm/anthropic/types.go +++ b/pkg/llm/anthropic/types.go @@ -13,6 +13,7 @@ type CreateRequest struct { Model string `json:"model"` Messages []MessageParam `json:"messages"` MaxTokens int `json:"max_tokens"` + System string `json:"system,omitempty"` Tools []Tool `json:"tools,omitempty"` } diff --git a/pkg/llm/google/provider.go b/pkg/llm/google/provider.go index f992e7b1..96ffbb6a 100644 --- a/pkg/llm/google/provider.go +++ b/pkg/llm/google/provider.go @@ -19,12 +19,16 @@ type Provider struct { toolCallID int } -func NewProvider(ctx context.Context, apiKey string, model string) (*Provider, error) { +func NewProvider(ctx context.Context, apiKey, model, systemPrompt string) (*Provider, error) { client, err := genai.NewClient(ctx, option.WithAPIKey(apiKey)) if err != nil { return nil, err } m := client.GenerativeModel(model) + // If systemPrompt is provided, set the system prompt for the model. + if systemPrompt != "" { + m.SystemInstruction = genai.NewUserContent(genai.Text(systemPrompt)) + } return &Provider{ client: client, model: m, diff --git a/pkg/llm/ollama/provider.go b/pkg/llm/ollama/provider.go index bbc7252f..d726e425 100644 --- a/pkg/llm/ollama/provider.go +++ b/pkg/llm/ollama/provider.go @@ -18,19 +18,21 @@ func boolPtr(b bool) *bool { // Provider implements the Provider interface for Ollama type Provider struct { - client *api.Client - model string + client *api.Client + model string + systemPrompt string } // NewProvider creates a new Ollama provider -func NewProvider(model string) (*Provider, error) { +func NewProvider(model string, systemPrompt string) (*Provider, error) { client, err := api.ClientFromEnvironment() if err != nil { return nil, err } return &Provider{ - client: client, - model: model, + client: client, + model: model, + systemPrompt: systemPrompt, }, nil } @@ -48,6 +50,14 @@ func (p *Provider) CreateMessage( // Convert generic messages to Ollama format ollamaMessages := make([]api.Message, 0, len(messages)+1) + // Add system prompt if it exists + if p.systemPrompt != "" { + ollamaMessages = append(ollamaMessages, api.Message{ + Role: "system", + Content: p.systemPrompt, + }) + } + // Add existing messages for _, msg := range messages { // Handle tool responses @@ -151,7 +161,7 @@ func (p *Provider) CreateMessage( "num_messages", len(messages), "num_tools", len(tools)) - log.Debug("sending messages to Ollama", + log.Debug("sending messages to Ollama", "messages", ollamaMessages, "num_tools", len(tools)) diff --git a/pkg/llm/openai/provider.go b/pkg/llm/openai/provider.go index 3ca9e67f..de09bd54 100644 --- a/pkg/llm/openai/provider.go +++ b/pkg/llm/openai/provider.go @@ -12,8 +12,9 @@ import ( ) type Provider struct { - client *Client - model string + client *Client + model string + systemPrompt string } func convertSchema(schema llm.Schema) map[string]interface{} { @@ -30,10 +31,11 @@ func convertSchema(schema llm.Schema) map[string]interface{} { } } -func NewProvider(apiKey string, baseURL string, model string) *Provider { +func NewProvider(apiKey, baseURL, model, systemPrompt string) *Provider { return &Provider{ - client: NewClient(apiKey, baseURL), - model: model, + client: NewClient(apiKey, baseURL), + model: model, + systemPrompt: systemPrompt, } } @@ -50,6 +52,14 @@ func (p *Provider) CreateMessage( openaiMessages := make([]MessageParam, 0, len(messages)) + // Add system prompt if provided + if p.systemPrompt != "" { + openaiMessages = append(openaiMessages, MessageParam{ + Role: "system", + Content: &p.systemPrompt, + }) + } + // Convert previous messages for _, msg := range messages { log.Debug("converting message",