Files
kit/internal/models/providers.go
T

722 lines
21 KiB
Go
Raw Normal View History

2025-06-09 14:38:31 +03:00
package models
import (
2025-06-25 14:27:19 +03:00
"bytes"
2025-06-09 14:38:31 +03:00
"context"
"crypto/tls"
2025-06-25 14:27:19 +03:00
"encoding/json"
2025-06-09 14:38:31 +03:00
"fmt"
2025-06-25 14:27:19 +03:00
"io"
"maps"
2025-06-25 14:27:19 +03:00
"net/http"
2025-06-09 14:38:31 +03:00
"os"
"strings"
2025-06-26 13:32:18 +03:00
"time"
2025-06-09 14:38:31 +03:00
"charm.land/fantasy"
"charm.land/fantasy/providers/anthropic"
"charm.land/fantasy/providers/azure"
"charm.land/fantasy/providers/bedrock"
"charm.land/fantasy/providers/google"
"charm.land/fantasy/providers/openai"
"charm.land/fantasy/providers/openaicompat"
"charm.land/fantasy/providers/openrouter"
2025-06-10 14:20:40 +03:00
2025-06-25 14:27:19 +03:00
"github.com/mark3labs/mcphost/internal/auth"
"github.com/mark3labs/mcphost/internal/ui/progress"
2025-06-09 14:38:31 +03:00
)
2025-06-25 14:27:19 +03:00
// ClaudeCodePrompt is the system prompt required when authenticating with
// Anthropic OAuth credentials. The OAuth transport in this file places it first
// in the "system" field of /v1/messages request bodies.
const ClaudeCodePrompt = "You are Claude Code, Anthropic's official CLI for Claude."
// resolveModelAlias resolves model aliases to their full names using the registry
func resolveModelAlias(provider, modelName string) string {
registry := GetGlobalRegistry()
2025-06-25 17:24:37 +03:00
2025-06-25 14:27:19 +03:00
aliasMap := map[string]string{
2025-06-25 17:24:37 +03:00
"claude-opus-latest": "claude-opus-4-20250514",
"claude-sonnet-latest": "claude-sonnet-4-20250514",
"claude-4-opus-latest": "claude-opus-4-20250514",
"claude-4-sonnet-latest": "claude-sonnet-4-20250514",
2025-06-25 14:27:19 +03:00
"claude-3-5-haiku-latest": "claude-3-5-haiku-20241022",
2025-06-25 17:24:37 +03:00
"claude-3-5-sonnet-latest": "claude-3-5-sonnet-20241022",
2025-06-25 14:27:19 +03:00
"claude-3-7-sonnet-latest": "claude-3-7-sonnet-20250219",
"claude-3-opus-latest": "claude-3-opus-20240229",
}
2025-06-25 17:24:37 +03:00
2025-06-25 14:27:19 +03:00
if resolved, exists := aliasMap[modelName]; exists {
if _, err := registry.ValidateModel(provider, resolved); err == nil {
return resolved
}
}
2025-06-25 17:24:37 +03:00
2025-06-25 14:27:19 +03:00
return modelName
}
2025-11-12 16:48:46 +03:00
// ProviderConfig carries everything CreateProvider needs to build a model:
// the "provider:model" selector, credential and endpoint overrides, and
// optional sampling / hardware parameters. Nil pointer fields mean "not set".
type ProviderConfig struct {
	// ModelString selects the model as "provider:model". SystemPrompt holds
	// the system prompt text. ProviderAPIKey, when non-empty, is used instead of
	// looking the key up from the provider's environment variables, and
	// ProviderURL overrides the provider's default base URL (and also
	// disables registry validation in CreateProvider).
	ModelString, SystemPrompt, ProviderAPIKey, ProviderURL string

	// MaxTokens caps output tokens; it is checked against the model's output
	// limit and only forwarded to Ollama when positive.
	MaxTokens int

	// Sampling controls; nil / empty means "use the provider default".
	Temperature, TopP *float32
	TopK              *int32
	StopSequences     []string

	// NumGPU and MainGPU are forwarded to Ollama as num_gpu / main_gpu.
	NumGPU, MainGPU *int32

	// TLSSkipVerify makes the HTTP clients skip TLS certificate verification.
	TLSSkipVerify bool
}
2025-11-12 16:48:46 +03:00
// ProviderResult contains the result of provider creation.
2025-06-26 13:32:18 +03:00
type ProviderResult struct {
// Model is the created fantasy LanguageModel
Model fantasy.LanguageModel
2025-11-12 16:48:46 +03:00
// Message contains optional feedback for the user
Message string
2025-06-26 13:32:18 +03:00
}
// CreateProvider creates a fantasy LanguageModel based on the provider configuration.
2025-11-12 16:48:46 +03:00
// It validates the model, checks required environment variables, and initializes
// the appropriate provider.
2025-11-12 16:48:46 +03:00
//
// Supported providers: anthropic, openai, google, ollama, azure, google-vertex-anthropic,
// openrouter, bedrock
2025-06-26 13:32:18 +03:00
func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResult, error) {
2025-06-09 14:38:31 +03:00
parts := strings.SplitN(config.ModelString, ":", 2)
if len(parts) < 2 {
return nil, fmt.Errorf("invalid model format. Expected provider:model, got %s", config.ModelString)
}
provider := parts[0]
modelName := parts[1]
2025-06-25 14:27:19 +03:00
// Resolve model aliases before validation (for OAuth compatibility)
if provider == "anthropic" || provider == "google-vertex-anthropic" {
2025-06-25 14:27:19 +03:00
modelName = resolveModelAlias(provider, modelName)
}
2025-06-18 13:48:08 +03:00
// Get the global registry for validation
registry := GetGlobalRegistry()
// Validate the model exists (skip for ollama as it's not in models.dev, and skip when using custom provider URL)
if provider != "ollama" && config.ProviderURL == "" {
2025-06-18 13:48:08 +03:00
modelInfo, err := registry.ValidateModel(provider, modelName)
if err != nil {
suggestions := registry.SuggestModels(provider, modelName)
if len(suggestions) > 0 {
return nil, fmt.Errorf("%v. Did you mean one of: %s", err, strings.Join(suggestions, ", "))
}
return nil, err
}
if err := registry.ValidateEnvironment(provider, config.ProviderAPIKey); err != nil {
return nil, err
}
if err := validateModelConfig(config, modelInfo); err != nil {
return nil, err
}
}
2025-06-09 14:38:31 +03:00
switch provider {
case "anthropic":
return createAnthropicProvider(ctx, config, modelName)
2025-06-09 14:38:31 +03:00
case "openai":
return createOpenAIProvider(ctx, config, modelName)
2025-06-09 14:38:31 +03:00
case "google":
return createGoogleProvider(ctx, config, modelName)
2025-06-09 14:38:31 +03:00
case "ollama":
return createOllamaProvider(ctx, config, modelName)
2025-06-24 07:47:46 +02:00
case "azure":
return createAzureProvider(ctx, config, modelName)
case "google-vertex-anthropic":
return createVertexAnthropicProvider(ctx, config, modelName)
case "openrouter":
return createOpenRouterProvider(ctx, config, modelName)
case "bedrock":
return createBedrockProvider(ctx, config, modelName)
2025-06-09 14:38:31 +03:00
default:
return nil, fmt.Errorf("unsupported provider: %s. Supported: anthropic, openai, google, ollama, azure, google-vertex-anthropic, openrouter, bedrock", provider)
2025-06-09 14:38:31 +03:00
}
}
2025-06-18 13:48:08 +03:00
// validateModelConfig validates configuration parameters against model capabilities
func validateModelConfig(config *ProviderConfig, modelInfo *ModelInfo) error {
if config.Temperature != nil && !modelInfo.Temperature {
config.Temperature = nil
2025-06-18 13:48:08 +03:00
}
if config.MaxTokens > modelInfo.Limit.Output {
return fmt.Errorf("max_tokens (%d) exceeds model's output limit (%d) for %s",
config.MaxTokens, modelInfo.Limit.Output, modelInfo.ID)
}
return nil
}
func createAnthropicProvider(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
apiKey, source, err := auth.GetAnthropicAPIKey(config.ProviderAPIKey)
if err != nil {
return nil, err
2025-06-24 07:47:46 +02:00
}
if os.Getenv("DEBUG") != "" || os.Getenv("MCPHOST_DEBUG") != "" {
fmt.Fprintf(os.Stderr, "Using Anthropic API key from: %s\n", source)
2025-06-24 07:47:46 +02:00
}
var opts []anthropic.Option
opts = append(opts, anthropic.WithAPIKey(apiKey))
2025-06-24 07:47:46 +02:00
if config.ProviderURL != "" {
opts = append(opts, anthropic.WithBaseURL(config.ProviderURL))
2025-06-24 07:47:46 +02:00
}
// Handle OAuth vs API key authentication
if strings.HasPrefix(source, "stored OAuth") {
httpClient := createOAuthHTTPClient(apiKey, config.TLSSkipVerify)
opts = append(opts, anthropic.WithHTTPClient(httpClient))
// Note: For OAuth, the API key is set as a placeholder; the transport handles auth
} else if config.TLSSkipVerify {
opts = append(opts, anthropic.WithHTTPClient(createHTTPClientWithTLSConfig(true)))
2025-06-24 07:47:46 +02:00
}
provider, err := anthropic.New(opts...)
if err != nil {
return nil, fmt.Errorf("failed to create Anthropic provider: %w", err)
2025-06-24 07:47:46 +02:00
}
model, err := provider.LanguageModel(ctx, modelName)
if err != nil {
return nil, fmt.Errorf("failed to create Anthropic model: %w", err)
2025-06-24 07:47:46 +02:00
}
return &ProviderResult{Model: model}, nil
}
func createVertexAnthropicProvider(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
projectID := firstNonEmpty(
os.Getenv("GOOGLE_VERTEX_PROJECT"),
os.Getenv("ANTHROPIC_VERTEX_PROJECT_ID"),
os.Getenv("GOOGLE_CLOUD_PROJECT"),
os.Getenv("GCLOUD_PROJECT"),
os.Getenv("CLOUDSDK_CORE_PROJECT"),
)
if projectID == "" {
return nil, fmt.Errorf("Google Vertex project ID not provided. Set ANTHROPIC_VERTEX_PROJECT_ID, GOOGLE_CLOUD_PROJECT, or GCLOUD_PROJECT environment variable")
2025-06-24 07:47:46 +02:00
}
region := firstNonEmpty(
os.Getenv("GOOGLE_VERTEX_LOCATION"),
os.Getenv("ANTHROPIC_VERTEX_REGION"),
os.Getenv("CLOUD_ML_REGION"),
)
if region == "" {
region = "global"
}
var opts []anthropic.Option
opts = append(opts, anthropic.WithVertex(projectID, region))
2025-06-24 07:47:46 +02:00
provider, err := anthropic.New(opts...)
2025-06-25 14:27:19 +03:00
if err != nil {
return nil, fmt.Errorf("failed to create Vertex Anthropic provider: %w", err)
2025-06-09 14:38:31 +03:00
}
2025-06-25 14:27:19 +03:00
model, err := provider.LanguageModel(ctx, modelName)
if err != nil {
return nil, fmt.Errorf("failed to create Vertex Anthropic model: %w", err)
2025-06-11 11:45:55 +03:00
}
return &ProviderResult{Model: model}, nil
}
2025-06-25 14:27:19 +03:00
func createOpenAIProvider(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
apiKey := config.ProviderAPIKey
if apiKey == "" {
apiKey = os.Getenv("OPENAI_API_KEY")
2025-06-09 14:38:31 +03:00
}
if apiKey == "" {
return nil, fmt.Errorf("OpenAI API key not provided. Use --provider-api-key flag or OPENAI_API_KEY environment variable")
2025-06-11 11:45:55 +03:00
}
var opts []openai.Option
opts = append(opts, openai.WithAPIKey(apiKey))
2025-06-25 14:27:19 +03:00
2025-06-11 11:45:55 +03:00
if config.ProviderURL != "" {
opts = append(opts, openai.WithBaseURL(config.ProviderURL))
2025-06-11 11:45:55 +03:00
}
if config.TLSSkipVerify {
opts = append(opts, openai.WithHTTPClient(createHTTPClientWithTLSConfig(true)))
2025-06-11 11:45:55 +03:00
}
provider, err := openai.New(opts...)
if err != nil {
return nil, fmt.Errorf("failed to create OpenAI provider: %w", err)
2025-06-09 14:38:31 +03:00
}
model, err := provider.LanguageModel(ctx, modelName)
if err != nil {
return nil, fmt.Errorf("failed to create OpenAI model: %w", err)
2025-06-09 14:38:31 +03:00
}
return &ProviderResult{Model: model}, nil
2025-06-09 14:38:31 +03:00
}
// createGoogleProvider builds a Gemini model using an API key. The key is the
// first non-empty value of config.ProviderAPIKey, $GOOGLE_API_KEY,
// $GEMINI_API_KEY and $GOOGLE_GENERATIVE_AI_API_KEY.
//
// NOTE(review): config.ProviderURL and config.TLSSkipVerify are not applied
// for this provider.
func createGoogleProvider(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
	candidates := []string{
		config.ProviderAPIKey,
		os.Getenv("GOOGLE_API_KEY"),
		os.Getenv("GEMINI_API_KEY"),
		os.Getenv("GOOGLE_GENERATIVE_AI_API_KEY"),
	}
	apiKey := firstNonEmpty(candidates...)
	if apiKey == "" {
		return nil, fmt.Errorf("Google API key not provided. Use --provider-api-key flag or GOOGLE_API_KEY/GEMINI_API_KEY/GOOGLE_GENERATIVE_AI_API_KEY environment variable")
	}

	p, err := google.New(google.WithGeminiAPIKey(apiKey))
	if err != nil {
		return nil, fmt.Errorf("failed to create Google provider: %w", err)
	}
	model, err := p.LanguageModel(ctx, modelName)
	if err != nil {
		return nil, fmt.Errorf("failed to create Google model: %w", err)
	}
	return &ProviderResult{Model: model}, nil
}
// createAzureProvider builds an Azure OpenAI model. Two values are required:
//   - the API key: config.ProviderAPIKey, else $AZURE_OPENAI_API_KEY;
//   - the base URL: config.ProviderURL, else $AZURE_OPENAI_BASE_URL.
//
// TLSSkipVerify installs a client that skips certificate verification.
func createAzureProvider(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
	apiKey := firstNonEmpty(config.ProviderAPIKey, os.Getenv("AZURE_OPENAI_API_KEY"))
	if apiKey == "" {
		return nil, fmt.Errorf("Azure OpenAI API key not provided. Use --provider-api-key flag or AZURE_OPENAI_API_KEY environment variable")
	}
	baseURL := firstNonEmpty(config.ProviderURL, os.Getenv("AZURE_OPENAI_BASE_URL"))
	if baseURL == "" {
		return nil, fmt.Errorf("Azure OpenAI Base URL not provided. Use --provider-url flag or AZURE_OPENAI_BASE_URL environment variable")
	}

	opts := []azure.Option{
		azure.WithAPIKey(apiKey),
		azure.WithBaseURL(baseURL),
	}
	if config.TLSSkipVerify {
		opts = append(opts, azure.WithHTTPClient(createHTTPClientWithTLSConfig(true)))
	}

	p, err := azure.New(opts...)
	if err != nil {
		return nil, fmt.Errorf("failed to create Azure OpenAI provider: %w", err)
	}
	model, err := p.LanguageModel(ctx, modelName)
	if err != nil {
		return nil, fmt.Errorf("failed to create Azure OpenAI model: %w", err)
	}
	return &ProviderResult{Model: model}, nil
}
func createOpenRouterProvider(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
2025-06-11 11:45:55 +03:00
apiKey := config.ProviderAPIKey
2025-06-09 14:38:31 +03:00
if apiKey == "" {
apiKey = os.Getenv("OPENROUTER_API_KEY")
2025-06-09 14:38:31 +03:00
}
if apiKey == "" {
return nil, fmt.Errorf("OpenRouter API key not provided. Use --provider-api-key flag or OPENROUTER_API_KEY environment variable")
2025-06-09 14:38:31 +03:00
}
var opts []openrouter.Option
opts = append(opts, openrouter.WithAPIKey(apiKey))
2025-06-11 11:45:55 +03:00
provider, err := openrouter.New(opts...)
if err != nil {
return nil, fmt.Errorf("failed to create OpenRouter provider: %w", err)
}
model, err := provider.LanguageModel(ctx, modelName)
if err != nil {
return nil, fmt.Errorf("failed to create OpenRouter model: %w", err)
2025-06-11 11:45:55 +03:00
}
return &ProviderResult{Model: model}, nil
}
2025-06-11 11:45:55 +03:00
func createBedrockProvider(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
var opts []bedrock.Option
// Bedrock uses AWS SDK default credential chain (env vars, shared config, etc.)
provider, err := bedrock.New(opts...)
if err != nil {
return nil, fmt.Errorf("failed to create Bedrock provider: %w", err)
2025-06-11 11:45:55 +03:00
}
model, err := provider.LanguageModel(ctx, modelName)
if err != nil {
return nil, fmt.Errorf("failed to create Bedrock model: %w", err)
2025-06-09 14:38:31 +03:00
}
return &ProviderResult{Model: model}, nil
2025-06-09 14:38:31 +03:00
}
func createOllamaProvider(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
baseURL := "http://localhost:11434"
if host := os.Getenv("OLLAMA_HOST"); host != "" {
baseURL = host
}
if config.ProviderURL != "" {
baseURL = config.ProviderURL
}
// Pre-load model with GPU fallback
loadingResult, err := loadOllamaModelWithFallback(ctx, baseURL, modelName, config)
var loadingMessage string
2025-06-09 14:38:31 +03:00
if err != nil {
loadingMessage = ""
} else {
loadingMessage = loadingResult.Message
2025-06-09 14:38:31 +03:00
}
// Use openaicompat provider pointed at Ollama's OpenAI-compatible endpoint
ollamaAPIBase := strings.TrimRight(baseURL, "/") + "/v1"
2025-06-09 14:38:31 +03:00
var opts []openaicompat.Option
opts = append(opts, openaicompat.WithBaseURL(ollamaAPIBase))
opts = append(opts, openaicompat.WithName("ollama"))
if config.ProviderAPIKey != "" {
opts = append(opts, openaicompat.WithAPIKey(config.ProviderAPIKey))
} else {
// Ollama doesn't require an API key, but the openaicompat provider might need one
opts = append(opts, openaicompat.WithAPIKey("ollama"))
2025-06-11 11:45:55 +03:00
}
if config.TLSSkipVerify {
opts = append(opts, openaicompat.WithHTTPClient(createHTTPClientWithTLSConfig(true)))
2025-06-11 11:45:55 +03:00
}
provider, err := openaicompat.New(opts...)
if err != nil {
return nil, fmt.Errorf("failed to create Ollama provider: %w", err)
2025-06-11 11:45:55 +03:00
}
model, err := provider.LanguageModel(ctx, modelName)
if err != nil {
return nil, fmt.Errorf("failed to create Ollama model: %w", err)
2025-06-11 11:45:55 +03:00
}
return &ProviderResult{
Model: model,
Message: loadingMessage,
}, nil
2025-06-09 14:38:31 +03:00
}
2025-11-12 16:48:46 +03:00
// OllamaLoadingResult reports the outcome of pre-loading an Ollama model.
type OllamaLoadingResult struct {
	Message string // human-readable summary, e.g. whether CPU fallback was used
}
// loadOllamaModelWithFallback loads an Ollama model with GPU settings and automatic CPU fallback
func loadOllamaModelWithFallback(ctx context.Context, baseURL, modelName string, config *ProviderConfig) (*OllamaLoadingResult, error) {
client := createHTTPClientWithTLSConfig(config.TLSSkipVerify)
2025-06-26 13:32:18 +03:00
// Phase 1: Check if model exists locally
if err := checkOllamaModelExists(client, baseURL, modelName); err != nil {
// Phase 2: Pull model if not found
if err := pullOllamaModel(ctx, client, baseURL, modelName); err != nil {
return nil, fmt.Errorf("failed to pull model %s: %v", modelName, err)
}
}
// Phase 3: Warmup the model
options := buildOllamaOptions(config)
2025-06-26 13:32:18 +03:00
_, err := loadOllamaModelWithOptions(ctx, client, baseURL, modelName, options)
if err != nil {
// Phase 4: Fallback to CPU if GPU memory insufficient
if isGPUMemoryError(err) {
cpuOptions := make(map[string]any)
maps.Copy(cpuOptions, options)
cpuOptions["num_gpu"] = 0
2025-06-26 13:32:18 +03:00
_, cpuErr := loadOllamaModelWithOptions(ctx, client, baseURL, modelName, cpuOptions)
2025-06-26 13:32:18 +03:00
if cpuErr != nil {
return nil, fmt.Errorf("failed to load model on GPU (%v) and CPU fallback failed (%v)", err, cpuErr)
}
return &OllamaLoadingResult{
Message: "Insufficient GPU memory, falling back to CPU inference",
}, nil
}
return nil, err
}
return &OllamaLoadingResult{
Message: "Model loaded successfully on GPU",
}, nil
}
// buildOllamaOptions translates the optional sampling and GPU settings in
// config into the "options" object of an Ollama request.
//
// Only settings that are present are added: non-nil pointers, a non-empty
// StopSequences and a positive MaxTokens (sent as num_predict). Pointer values
// are dereferenced, and int32 values are widened to int.
func buildOllamaOptions(config *ProviderConfig) map[string]any {
	out := map[string]any{}

	if t := config.Temperature; t != nil {
		out["temperature"] = *t
	}
	if p := config.TopP; p != nil {
		out["top_p"] = *p
	}
	if stop := config.StopSequences; len(stop) > 0 {
		out["stop"] = stop
	}
	if n := config.MaxTokens; n > 0 {
		out["num_predict"] = n
	}

	intSettings := map[string]*int32{
		"top_k":    config.TopK,
		"num_gpu":  config.NumGPU,
		"main_gpu": config.MainGPU,
	}
	for key, ptr := range intSettings {
		if ptr != nil {
			out[key] = int(*ptr)
		}
	}
	return out
}
2025-06-26 13:32:18 +03:00
// checkOllamaModelExists reports whether modelName is present on the Ollama
// server at baseURL by POSTing {"model": modelName} to /api/show.
//
// It returns nil on HTTP 200. Otherwise it returns an error: either the
// transport failure, or a "not found" error that includes the status code.
// The probe uses its own 5-second timeout. The response body is drained so the
// keep-alive connection can be reused by the following requests.
func checkOllamaModelExists(client *http.Client, baseURL, modelName string) error {
	// Marshalling a map[string]string cannot fail.
	jsonBody, _ := json.Marshal(map[string]string{"model": modelName})

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/api/show", bytes.NewReader(jsonBody))
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	// Best-effort drain; a read error here does not change the answer.
	_, _ = io.Copy(io.Discard, resp.Body)

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("model not found locally (status %d)", resp.StatusCode)
	}
	return nil
}
// pullOllamaModel pulls modelName onto the Ollama server at baseURL with pull
// progress displayed. It is shorthand for pullOllamaModelWithProgress with
// showProgress set to true.
func pullOllamaModel(ctx context.Context, client *http.Client, baseURL, modelName string) error {
	const withProgress = true
	return pullOllamaModelWithProgress(ctx, client, baseURL, modelName, withProgress)
}
// pullOllamaModelWithProgress asks the Ollama server at baseURL to pull
// modelName via POST /api/pull and consumes the streamed response to the end.
// When showProgress is true, the stream is read through a
// progress.ProgressReader (closed when done) so download progress is shown.
//
// The whole pull is bounded by a 5-minute timeout derived from ctx. A non-200
// status returns an error containing the status code and the response body.
// The stream itself is discarded as it is read rather than buffered in memory.
//
// NOTE(review): Ollama may report pull failures inside a 200 streaming
// response as {"error": "..."} lines; those are not detected here.
func pullOllamaModelWithProgress(ctx context.Context, client *http.Client, baseURL, modelName string, showProgress bool) error {
	// Marshalling a map[string]string cannot fail.
	jsonBody, _ := json.Marshal(map[string]string{"name": modelName})

	pullCtx, cancel := context.WithTimeout(ctx, 5*time.Minute)
	defer cancel()

	req, err := http.NewRequestWithContext(pullCtx, http.MethodPost, baseURL+"/api/pull", bytes.NewReader(jsonBody))
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("failed to pull model (status %d): %s", resp.StatusCode, string(body))
	}

	var stream io.Reader = resp.Body
	if showProgress {
		progressReader := progress.NewProgressReader(resp.Body)
		defer progressReader.Close()
		stream = progressReader
	}
	if _, err := io.Copy(io.Discard, stream); err != nil {
		return fmt.Errorf("reading pull stream for %s: %w", modelName, err)
	}
	return nil
}
// loadOllamaModelWithOptions warms modelName up on the Ollama server with a
// non-streaming POST /api/generate for the prompt "Hello".
//
// The request uses a copy of options with num_predict forced to 1, so only a
// single token is generated; options itself is left untouched. The request is
// bounded by a 30-second timeout derived from ctx.
//
// On success it returns options as passed in. On a non-200 status the error
// includes the status code and the response body.
func loadOllamaModelWithOptions(ctx context.Context, client *http.Client, baseURL, modelName string, options map[string]any) (map[string]any, error) {
	warmup := make(map[string]any, len(options)+1)
	maps.Copy(warmup, options)
	warmup["num_predict"] = 1

	payload, _ := json.Marshal(map[string]any{
		"model":   modelName,
		"prompt":  "Hello",
		"stream":  false,
		"options": warmup,
	})

	reqCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
	defer cancel()

	req, err := http.NewRequestWithContext(reqCtx, "POST", baseURL+"/api/generate", bytes.NewBuffer(payload))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	body, readErr := io.ReadAll(resp.Body)
	switch {
	case resp.StatusCode != http.StatusOK:
		return nil, fmt.Errorf("warmup request failed (status %d): %s", resp.StatusCode, string(body))
	case readErr != nil:
		return nil, readErr
	default:
		return options, nil
	}
}
// isGPUMemoryError reports whether err's message looks like a memory
// exhaustion failure.
//
// The message is lower-cased and checked for "out of memory" (which also
// covers "cuda out of memory"), "insufficient memory" or "gpu memory". A nil
// error is never a memory error.
func isGPUMemoryError(err error) bool {
	if err == nil {
		return false
	}
	msg := strings.ToLower(err.Error())
	for _, marker := range []string{"out of memory", "insufficient memory", "gpu memory"} {
		if strings.Contains(msg, marker) {
			return true
		}
	}
	return false
}
// createHTTPClientWithTLSConfig returns an HTTP client, optionally with TLS
// certificate verification disabled.
//
// With skipVerify false it returns a plain client, which uses
// http.DefaultTransport. With skipVerify true it clones http.DefaultTransport
// and only flips InsecureSkipVerify. Cloning keeps proxy-from-environment, the
// default dial/TLS-handshake timeouts, connection pooling and HTTP/2
// negotiation, all of which a bare &http.Transport{} would silently drop.
func createHTTPClientWithTLSConfig(skipVerify bool) *http.Client {
	if !skipVerify {
		return &http.Client{}
	}

	var transport *http.Transport
	if def, ok := http.DefaultTransport.(*http.Transport); ok {
		transport = def.Clone()
	} else {
		// DefaultTransport was replaced with something that cannot be cloned.
		transport = &http.Transport{}
	}
	if transport.TLSClientConfig == nil {
		transport.TLSClientConfig = &tls.Config{}
	}
	// Deliberate and opt-in: the user asked to skip certificate verification.
	transport.TLSClientConfig.InsecureSkipVerify = true

	return &http.Client{Transport: transport}
}
2025-06-25 14:27:19 +03:00
// createOAuthHTTPClient creates an HTTP client that adds OAuth headers for Anthropic API
func createOAuthHTTPClient(accessToken string, skipVerify bool) *http.Client {
var base = http.DefaultTransport
if skipVerify {
base = &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
}
}
2025-06-25 14:27:19 +03:00
return &http.Client{
Transport: &oauthTransport{
accessToken: accessToken,
base: base,
2025-06-25 14:27:19 +03:00
},
}
}
// oauthTransport is an http.RoundTripper that authenticates Anthropic API
// requests with an OAuth bearer token and hands the request to base.
type oauthTransport struct {
	accessToken string            // sent as "Authorization: Bearer <token>"
	base        http.RoundTripper // transport that actually performs the request
}
// RoundTrip implements http.RoundTripper. It works on a clone of req, so the
// caller's request is never modified:
//   - x-api-key is removed and "Authorization: Bearer <accessToken>" is set;
//   - "oauth-2025-04-20" is placed in anthropic-beta, and any beta flags
//     already present are kept (comma-separated) instead of being overwritten;
//   - anthropic-version is set to 2023-06-01;
//   - for POST requests whose path contains /v1/messages, the JSON body is
//     rewritten by injectClaudeCodePrompt.
//
// If the body rewrite fails, the original bytes are forwarded unchanged. If the
// body cannot be read, an error is returned rather than sending a partial body.
// The original body is closed after reading, as the RoundTripper contract
// requires. GetBody is updated so retries and redirects resend the rewritten
// body.
func (t *oauthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	const oauthBeta = "oauth-2025-04-20"

	newReq := req.Clone(req.Context())
	newReq.Header.Del("x-api-key")
	newReq.Header.Set("Authorization", "Bearer "+t.accessToken)

	betas := []string{oauthBeta}
	for _, value := range req.Header.Values("anthropic-beta") {
		for _, beta := range strings.Split(value, ",") {
			if beta = strings.TrimSpace(beta); beta != "" && beta != oauthBeta {
				betas = append(betas, beta)
			}
		}
	}
	newReq.Header.Set("anthropic-beta", strings.Join(betas, ","))
	newReq.Header.Set("anthropic-version", "2023-06-01")

	if req.Method == http.MethodPost && strings.Contains(req.URL.Path, "/v1/messages") && req.Body != nil {
		body, err := io.ReadAll(req.Body)
		_ = req.Body.Close()
		if err != nil {
			return nil, fmt.Errorf("failed to read request body for OAuth prompt injection: %w", err)
		}
		if modified, injErr := t.injectClaudeCodePrompt(body); injErr == nil {
			body = modified
		}
		newReq.Body = io.NopCloser(bytes.NewReader(body))
		newReq.ContentLength = int64(len(body))
		newReq.GetBody = func() (io.ReadCloser, error) {
			return io.NopCloser(bytes.NewReader(body)), nil
		}
	}

	return t.base.RoundTrip(newReq)
}
func (t *oauthTransport) injectClaudeCodePrompt(body []byte) ([]byte, error) {
var data map[string]any
2025-06-25 14:27:19 +03:00
if err := json.Unmarshal(body, &data); err != nil {
return body, nil
2025-06-25 14:27:19 +03:00
}
systemRaw, hasSystem := data["system"]
if !hasSystem {
data["system"] = ClaudeCodePrompt
return json.Marshal(data)
}
switch system := systemRaw.(type) {
case string:
if system == ClaudeCodePrompt {
return body, nil
}
data["system"] = []any{
map[string]any{"type": "text", "text": ClaudeCodePrompt},
map[string]any{"type": "text", "text": system},
2025-06-25 14:27:19 +03:00
}
case []any:
2025-06-25 14:27:19 +03:00
if len(system) > 0 {
if first, ok := system[0].(map[string]any); ok {
2025-06-25 14:27:19 +03:00
if text, ok := first["text"].(string); ok && text == ClaudeCodePrompt {
return body, nil
}
}
}
newSystem := []any{
map[string]any{"type": "text", "text": ClaudeCodePrompt},
2025-06-25 14:27:19 +03:00
}
data["system"] = append(newSystem, system...)
}
return json.Marshal(data)
}
// firstNonEmpty returns the first argument that is not the empty string, or ""
// when every argument is empty or none is given.
func firstNonEmpty(values ...string) string {
	for i := 0; i < len(values); i++ {
		if s := values[i]; len(s) > 0 {
			return s
		}
	}
	return ""
}