From ec996fd1e107a6b9c5abe107586b262dc826d3cd Mon Sep 17 00:00:00 2001 From: Roman Gelembjuk Date: Wed, 16 Apr 2025 18:18:44 +0100 Subject: [PATCH] Add support of config for SSE servers. Add Authorization header support (#24) * Add support of config for SSE servers. Add Authorization header support * Change the format of JSON for SSE servers to be similar to other tools --------- Co-authored-by: Roman Gelembjuk --- cmd/mcp.go | 134 ++++++++++++++++++++++++++++++++++++++++++++-------- cmd/root.go | 6 +-- go.mod | 3 +- go.sum | 6 ++- 4 files changed, 121 insertions(+), 28 deletions(-) diff --git a/cmd/mcp.go b/cmd/mcp.go index c3686a8c..1511c038 100644 --- a/cmd/mcp.go +++ b/cmd/mcp.go @@ -21,6 +21,11 @@ import ( "github.com/mark3labs/mcphost/pkg/llm" ) +const ( + transportStdio = "stdio" + transportSSE = "sse" +) + var ( // Tokyo Night theme colors tokyoPurple = lipgloss.Color("99") // #9d7cd8 @@ -60,15 +65,66 @@ var ( ) type MCPConfig struct { - MCPServers map[string]ServerConfig `json:"mcpServers"` + MCPServers map[string]ServerConfigWrapper `json:"mcpServers"` } -type ServerConfig struct { +type ServerConfig interface { + GetType() string +} + +type STDIOServerConfig struct { Command string `json:"command"` Args []string `json:"args"` Env map[string]string `json:"env,omitempty"` } +func (s STDIOServerConfig) GetType() string { + return transportStdio +} + +type SSEServerConfig struct { + Url string `json:"url"` + Headers []string `json:"headers,omitempty"` +} + +func (s SSEServerConfig) GetType() string { + return transportSSE +} + +type ServerConfigWrapper struct { + Config ServerConfig +} + +func (w *ServerConfigWrapper) UnmarshalJSON(data []byte) error { + var typeField struct { + Url string `json:"url"` + } + + if err := json.Unmarshal(data, &typeField); err != nil { + return err + } + if typeField.Url != "" { + // If the URL field is present, treat it as an SSE server + var sse SSEServerConfig + if err := json.Unmarshal(data, &sse); err != nil { + return err + } + w.Config = sse + } else { + // Otherwise, treat it as a STDIOServerConfig + var stdio STDIOServerConfig + if err := json.Unmarshal(data, &stdio); err != nil { + return err + } + w.Config = stdio + } + + return nil +} +func (w ServerConfigWrapper) MarshalJSON() ([]byte, error) { + return json.Marshal(w.Config) +} + func mcpToolsToAnthropicTools( serverName string, mcpTools []mcp.Tool, @@ -108,7 +164,7 @@ func loadMCPConfig() (*MCPConfig, error) { if _, err := os.Stat(configPath); os.IsNotExist(err) { // Create default config defaultConfig := MCPConfig{ - MCPServers: make(map[string]ServerConfig), + MCPServers: make(map[string]ServerConfigWrapper), } // Create the file with default config @@ -149,31 +205,45 @@ func createMCPClients( clients := make(map[string]mcpclient.MCPClient) for name, server := range config.MCPServers { - var env []string - for k, v := range server.Env { - env = append(env, fmt.Sprintf("%s=%s", k, v)) - } var client mcpclient.MCPClient var err error - if server.Command == "sse_server" { - if len(server.Args) == 0 { - return nil, fmt.Errorf( - "no arguments provided for sse command", - ) + if server.Config.GetType() == transportSSE { + sseConfig := server.Config.(SSEServerConfig) + + options := []mcpclient.ClientOption{} + + if sseConfig.Headers != nil { + // Parse headers from the config + headers := make(map[string]string) + for _, header := range sseConfig.Headers { + parts := strings.SplitN(header, ":", 2) + if len(parts) == 2 { + key := strings.TrimSpace(parts[0]) + value := strings.TrimSpace(parts[1]) + headers[key] = value + } + } + options = append(options, mcpclient.WithHeaders(headers)) } client, err = mcpclient.NewSSEMCPClient( - server.Args[0], + sseConfig.Url, + options..., ) if err == nil { err = client.(*mcpclient.SSEMCPClient).Start(context.Background()) } } else { + stdioConfig := server.Config.(STDIOServerConfig) + var env []string + for k, v := range stdioConfig.Env { + env = append(env, fmt.Sprintf("%s=%s", k, v)) + } client, err = mcpclient.NewStdioMCPClient( - server.Command, + stdioConfig.Command, env, - server.Args...) + stdioConfig.Args...) } if err != nil { for _, c := range clients { @@ -310,15 +380,37 @@ func handleServersCommand(config *MCPConfig) { } else { for name, server := range config.MCPServers { markdown.WriteString(fmt.Sprintf("# %s\n\n", name)) - markdown.WriteString("*Command*\n") - markdown.WriteString(fmt.Sprintf("`%s`\n\n", server.Command)) - markdown.WriteString("*Arguments*\n") - if len(server.Args) > 0 { - markdown.WriteString(fmt.Sprintf("`%s`\n", strings.Join(server.Args, " "))) + if server.Config.GetType() == transportSSE { + sseConfig := server.Config.(SSEServerConfig) + markdown.WriteString("*Url*\n") + markdown.WriteString(fmt.Sprintf("`%s`\n\n", sseConfig.Url)) + markdown.WriteString("*headers*\n") + if sseConfig.Headers != nil { + for _, header := range sseConfig.Headers { + parts := strings.SplitN(header, ":", 2) + if len(parts) == 2 { + key := strings.TrimSpace(parts[0]) + markdown.WriteString("`" + key + ": [REDACTED]`\n") + } + } + } else { + markdown.WriteString("*None*\n") + } + } else { - markdown.WriteString("*None*\n") + stdioConfig := server.Config.(STDIOServerConfig) + markdown.WriteString("*Command*\n") + markdown.WriteString(fmt.Sprintf("`%s`\n\n", stdioConfig.Command)) + + markdown.WriteString("*Arguments*\n") + if len(stdioConfig.Args) > 0 { + markdown.WriteString(fmt.Sprintf("`%s`\n", strings.Join(stdioConfig.Args, " "))) + } else { + markdown.WriteString("*None*\n") + } } + markdown.WriteString("\n") // Add spacing between servers } } diff --git a/cmd/root.go b/cmd/root.go index a0c4e2b1..e57570cb 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -431,10 +431,8 @@ func runPrompt( var resultText string // Handle array content directly since we know it's []interface{} for _, item := range toolResult.Content { - if contentMap, ok := item.(map[string]interface{}); ok { - if text, ok := contentMap["text"]; ok { - resultText += fmt.Sprintf("%v ", text) - } + if contentMap, ok := item.(mcp.TextContent); ok { + resultText += fmt.Sprintf("%v ", contentMap.Text) } } diff --git a/go.mod b/go.mod index d66bb7b7..045d1c8e 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/charmbracelet/lipgloss v1.0.0 github.com/charmbracelet/log v0.4.0 github.com/google/generative-ai-go v0.19.0 - github.com/mark3labs/mcp-go v0.8.2 + github.com/mark3labs/mcp-go v0.20.0 github.com/ollama/ollama v0.5.1 github.com/spf13/cobra v1.8.1 golang.org/x/term v0.30.0 @@ -37,6 +37,7 @@ require ( github.com/gorilla/css v1.0.1 // indirect github.com/microcosm-cc/bluemonday v1.0.27 // indirect github.com/muesli/reflow v0.3.0 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect github.com/yuin/goldmark v1.7.4 // indirect github.com/yuin/goldmark-emoji v1.0.3 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect diff --git a/go.sum b/go.sum index 5263c05c..59197112 100644 --- a/go.sum +++ b/go.sum @@ -86,8 +86,8 @@ github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2 github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= -github.com/mark3labs/mcp-go v0.8.2 h1:OtqqXlRqjXs6zuMhf1uiuQ2iqBrhMGgLpDeVDUWMKFc= -github.com/mark3labs/mcp-go v0.8.2/go.mod h1:cjMlBU0cv/cj9kjlgmRhoJ5JREdS7YX83xeIG9Ko/jE= +github.com/mark3labs/mcp-go v0.20.0 h1:NYZDZ10GBKHVz4SdQ2tPFSDFQFKCTrTZJLn4wj6jAaw= +github.com/mark3labs/mcp-go v0.20.0/go.mod h1:KmJndYv7GIgcPVwEKJjNcbhVQ+hJGJhrCCB/9xITzpE= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4= @@ -120,6 +120,8 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yuin/goldmark v1.7.1/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg= github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=