mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-13 19:20:06 +00:00
0703dd1602
Each spinner created a new tea.NewProgram which sent DECRQM queries for synchronized output mode 2026. When the program exited and restored cooked terminal mode, the terminal's DECRPM response leaked as visible ^[[?2026;2$y characters. Replace Bubble Tea spinner with a simple goroutine animation loop writing directly to stderr via lipgloss.
722 lines
21 KiB
Go
722 lines
21 KiB
Go
package models
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/tls"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"maps"
|
|
"net/http"
|
|
"os"
|
|
"strings"
|
|
"time"
|
|
|
|
"charm.land/fantasy"
|
|
"charm.land/fantasy/providers/anthropic"
|
|
"charm.land/fantasy/providers/azure"
|
|
"charm.land/fantasy/providers/bedrock"
|
|
"charm.land/fantasy/providers/google"
|
|
"charm.land/fantasy/providers/openai"
|
|
"charm.land/fantasy/providers/openaicompat"
|
|
"charm.land/fantasy/providers/openrouter"
|
|
|
|
"github.com/mark3labs/mcphost/internal/auth"
|
|
"github.com/mark3labs/mcphost/internal/ui/progress"
|
|
)
|
|
|
|
// ClaudeCodePrompt is the system prompt that requests authenticated via
// OAuth are required to carry.
const ClaudeCodePrompt = "You are Claude Code, Anthropic's official CLI for Claude."
|
|
|
|
// resolveModelAlias resolves model aliases to their full names using the registry
|
|
func resolveModelAlias(provider, modelName string) string {
|
|
registry := GetGlobalRegistry()
|
|
|
|
aliasMap := map[string]string{
|
|
"claude-opus-latest": "claude-opus-4-20250514",
|
|
"claude-sonnet-latest": "claude-sonnet-4-20250514",
|
|
"claude-4-opus-latest": "claude-opus-4-20250514",
|
|
"claude-4-sonnet-latest": "claude-sonnet-4-20250514",
|
|
|
|
"claude-3-5-haiku-latest": "claude-3-5-haiku-20241022",
|
|
"claude-3-5-sonnet-latest": "claude-3-5-sonnet-20241022",
|
|
"claude-3-7-sonnet-latest": "claude-3-7-sonnet-20250219",
|
|
"claude-3-opus-latest": "claude-3-opus-20240229",
|
|
}
|
|
|
|
if resolved, exists := aliasMap[modelName]; exists {
|
|
if _, err := registry.ValidateModel(provider, resolved); err == nil {
|
|
return resolved
|
|
}
|
|
}
|
|
|
|
return modelName
|
|
}
|
|
|
|
// ProviderConfig carries the settings used to construct an LLM provider.
type ProviderConfig struct {
	// ModelString selects the provider and model, in "provider:model" form.
	ModelString string
	// SystemPrompt is the system prompt for the model.
	// NOTE(review): not read by the provider constructors in this file.
	SystemPrompt string

	// ProviderAPIKey, when non-empty, is used in preference to the
	// provider's API key environment variable.
	ProviderAPIKey string
	// ProviderURL overrides the provider's base URL. When it is set,
	// registry validation of the model is skipped.
	ProviderURL string

	// MaxTokens is checked against the model's output limit during
	// validation and, when positive, sent to Ollama as num_predict.
	MaxTokens int
	// Temperature is cleared during validation if the model does not
	// support it.
	Temperature *float32
	// TopP, TopK and StopSequences are forwarded to Ollama as top_p, top_k
	// and stop when set.
	TopP          *float32
	TopK          *int32
	StopSequences []string

	// NumGPU and MainGPU are forwarded to Ollama as num_gpu and main_gpu.
	NumGPU  *int32
	MainGPU *int32

	// TLSSkipVerify disables TLS certificate verification on the HTTP
	// clients built for providers that honor it.
	TLSSkipVerify bool
}
|
|
|
|
// ProviderResult contains the result of provider creation.
|
|
type ProviderResult struct {
|
|
// Model is the created fantasy LanguageModel
|
|
Model fantasy.LanguageModel
|
|
// Message contains optional feedback for the user
|
|
Message string
|
|
}
|
|
|
|
// CreateProvider creates a fantasy LanguageModel based on the provider configuration.
|
|
// It validates the model, checks required environment variables, and initializes
|
|
// the appropriate provider.
|
|
//
|
|
// Supported providers: anthropic, openai, google, ollama, azure, google-vertex-anthropic,
|
|
// openrouter, bedrock
|
|
func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResult, error) {
|
|
parts := strings.SplitN(config.ModelString, ":", 2)
|
|
if len(parts) < 2 {
|
|
return nil, fmt.Errorf("invalid model format. Expected provider:model, got %s", config.ModelString)
|
|
}
|
|
|
|
provider := parts[0]
|
|
modelName := parts[1]
|
|
|
|
// Resolve model aliases before validation (for OAuth compatibility)
|
|
if provider == "anthropic" || provider == "google-vertex-anthropic" {
|
|
modelName = resolveModelAlias(provider, modelName)
|
|
}
|
|
|
|
// Get the global registry for validation
|
|
registry := GetGlobalRegistry()
|
|
|
|
// Validate the model exists (skip for ollama as it's not in models.dev, and skip when using custom provider URL)
|
|
if provider != "ollama" && config.ProviderURL == "" {
|
|
modelInfo, err := registry.ValidateModel(provider, modelName)
|
|
if err != nil {
|
|
suggestions := registry.SuggestModels(provider, modelName)
|
|
if len(suggestions) > 0 {
|
|
return nil, fmt.Errorf("%v. Did you mean one of: %s", err, strings.Join(suggestions, ", "))
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
if err := registry.ValidateEnvironment(provider, config.ProviderAPIKey); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := validateModelConfig(config, modelInfo); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
switch provider {
|
|
case "anthropic":
|
|
return createAnthropicProvider(ctx, config, modelName)
|
|
case "openai":
|
|
return createOpenAIProvider(ctx, config, modelName)
|
|
case "google":
|
|
return createGoogleProvider(ctx, config, modelName)
|
|
case "ollama":
|
|
return createOllamaProvider(ctx, config, modelName)
|
|
case "azure":
|
|
return createAzureProvider(ctx, config, modelName)
|
|
case "google-vertex-anthropic":
|
|
return createVertexAnthropicProvider(ctx, config, modelName)
|
|
case "openrouter":
|
|
return createOpenRouterProvider(ctx, config, modelName)
|
|
case "bedrock":
|
|
return createBedrockProvider(ctx, config, modelName)
|
|
default:
|
|
return nil, fmt.Errorf("unsupported provider: %s. Supported: anthropic, openai, google, ollama, azure, google-vertex-anthropic, openrouter, bedrock", provider)
|
|
}
|
|
}
|
|
|
|
// validateModelConfig validates configuration parameters against model capabilities
|
|
func validateModelConfig(config *ProviderConfig, modelInfo *ModelInfo) error {
|
|
if config.Temperature != nil && !modelInfo.Temperature {
|
|
config.Temperature = nil
|
|
}
|
|
|
|
if config.MaxTokens > modelInfo.Limit.Output {
|
|
return fmt.Errorf("max_tokens (%d) exceeds model's output limit (%d) for %s",
|
|
config.MaxTokens, modelInfo.Limit.Output, modelInfo.ID)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func createAnthropicProvider(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
|
|
apiKey, source, err := auth.GetAnthropicAPIKey(config.ProviderAPIKey)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if os.Getenv("DEBUG") != "" || os.Getenv("MCPHOST_DEBUG") != "" {
|
|
fmt.Fprintf(os.Stderr, "Using Anthropic API key from: %s\n", source)
|
|
}
|
|
|
|
var opts []anthropic.Option
|
|
opts = append(opts, anthropic.WithAPIKey(apiKey))
|
|
|
|
if config.ProviderURL != "" {
|
|
opts = append(opts, anthropic.WithBaseURL(config.ProviderURL))
|
|
}
|
|
|
|
// Handle OAuth vs API key authentication
|
|
if strings.HasPrefix(source, "stored OAuth") {
|
|
httpClient := createOAuthHTTPClient(apiKey, config.TLSSkipVerify)
|
|
opts = append(opts, anthropic.WithHTTPClient(httpClient))
|
|
// Note: For OAuth, the API key is set as a placeholder; the transport handles auth
|
|
} else if config.TLSSkipVerify {
|
|
opts = append(opts, anthropic.WithHTTPClient(createHTTPClientWithTLSConfig(true)))
|
|
}
|
|
|
|
provider, err := anthropic.New(opts...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create Anthropic provider: %w", err)
|
|
}
|
|
|
|
model, err := provider.LanguageModel(ctx, modelName)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create Anthropic model: %w", err)
|
|
}
|
|
|
|
return &ProviderResult{Model: model}, nil
|
|
}
|
|
|
|
func createVertexAnthropicProvider(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
|
|
projectID := firstNonEmpty(
|
|
os.Getenv("GOOGLE_VERTEX_PROJECT"),
|
|
os.Getenv("ANTHROPIC_VERTEX_PROJECT_ID"),
|
|
os.Getenv("GOOGLE_CLOUD_PROJECT"),
|
|
os.Getenv("GCLOUD_PROJECT"),
|
|
os.Getenv("CLOUDSDK_CORE_PROJECT"),
|
|
)
|
|
if projectID == "" {
|
|
return nil, fmt.Errorf("Google Vertex project ID not provided. Set ANTHROPIC_VERTEX_PROJECT_ID, GOOGLE_CLOUD_PROJECT, or GCLOUD_PROJECT environment variable")
|
|
}
|
|
|
|
region := firstNonEmpty(
|
|
os.Getenv("GOOGLE_VERTEX_LOCATION"),
|
|
os.Getenv("ANTHROPIC_VERTEX_REGION"),
|
|
os.Getenv("CLOUD_ML_REGION"),
|
|
)
|
|
if region == "" {
|
|
region = "global"
|
|
}
|
|
|
|
var opts []anthropic.Option
|
|
opts = append(opts, anthropic.WithVertex(projectID, region))
|
|
|
|
provider, err := anthropic.New(opts...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create Vertex Anthropic provider: %w", err)
|
|
}
|
|
|
|
model, err := provider.LanguageModel(ctx, modelName)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create Vertex Anthropic model: %w", err)
|
|
}
|
|
|
|
return &ProviderResult{Model: model}, nil
|
|
}
|
|
|
|
func createOpenAIProvider(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
|
|
apiKey := config.ProviderAPIKey
|
|
if apiKey == "" {
|
|
apiKey = os.Getenv("OPENAI_API_KEY")
|
|
}
|
|
if apiKey == "" {
|
|
return nil, fmt.Errorf("OpenAI API key not provided. Use --provider-api-key flag or OPENAI_API_KEY environment variable")
|
|
}
|
|
|
|
var opts []openai.Option
|
|
opts = append(opts, openai.WithAPIKey(apiKey))
|
|
|
|
if config.ProviderURL != "" {
|
|
opts = append(opts, openai.WithBaseURL(config.ProviderURL))
|
|
}
|
|
|
|
if config.TLSSkipVerify {
|
|
opts = append(opts, openai.WithHTTPClient(createHTTPClientWithTLSConfig(true)))
|
|
}
|
|
|
|
provider, err := openai.New(opts...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create OpenAI provider: %w", err)
|
|
}
|
|
|
|
model, err := provider.LanguageModel(ctx, modelName)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create OpenAI model: %w", err)
|
|
}
|
|
|
|
return &ProviderResult{Model: model}, nil
|
|
}
|
|
|
|
func createGoogleProvider(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
|
|
apiKey := firstNonEmpty(
|
|
config.ProviderAPIKey,
|
|
os.Getenv("GOOGLE_API_KEY"),
|
|
os.Getenv("GEMINI_API_KEY"),
|
|
os.Getenv("GOOGLE_GENERATIVE_AI_API_KEY"),
|
|
)
|
|
if apiKey == "" {
|
|
return nil, fmt.Errorf("Google API key not provided. Use --provider-api-key flag or GOOGLE_API_KEY/GEMINI_API_KEY/GOOGLE_GENERATIVE_AI_API_KEY environment variable")
|
|
}
|
|
|
|
var opts []google.Option
|
|
opts = append(opts, google.WithGeminiAPIKey(apiKey))
|
|
|
|
provider, err := google.New(opts...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create Google provider: %w", err)
|
|
}
|
|
|
|
model, err := provider.LanguageModel(ctx, modelName)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create Google model: %w", err)
|
|
}
|
|
|
|
return &ProviderResult{Model: model}, nil
|
|
}
|
|
|
|
func createAzureProvider(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
|
|
apiKey := config.ProviderAPIKey
|
|
if apiKey == "" {
|
|
apiKey = os.Getenv("AZURE_OPENAI_API_KEY")
|
|
}
|
|
if apiKey == "" {
|
|
return nil, fmt.Errorf("Azure OpenAI API key not provided. Use --provider-api-key flag or AZURE_OPENAI_API_KEY environment variable")
|
|
}
|
|
|
|
baseURL := config.ProviderURL
|
|
if baseURL == "" {
|
|
baseURL = os.Getenv("AZURE_OPENAI_BASE_URL")
|
|
}
|
|
if baseURL == "" {
|
|
return nil, fmt.Errorf("Azure OpenAI Base URL not provided. Use --provider-url flag or AZURE_OPENAI_BASE_URL environment variable")
|
|
}
|
|
|
|
var opts []azure.Option
|
|
opts = append(opts, azure.WithAPIKey(apiKey))
|
|
opts = append(opts, azure.WithBaseURL(baseURL))
|
|
|
|
if config.TLSSkipVerify {
|
|
opts = append(opts, azure.WithHTTPClient(createHTTPClientWithTLSConfig(true)))
|
|
}
|
|
|
|
provider, err := azure.New(opts...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create Azure OpenAI provider: %w", err)
|
|
}
|
|
|
|
model, err := provider.LanguageModel(ctx, modelName)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create Azure OpenAI model: %w", err)
|
|
}
|
|
|
|
return &ProviderResult{Model: model}, nil
|
|
}
|
|
|
|
func createOpenRouterProvider(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
|
|
apiKey := config.ProviderAPIKey
|
|
if apiKey == "" {
|
|
apiKey = os.Getenv("OPENROUTER_API_KEY")
|
|
}
|
|
if apiKey == "" {
|
|
return nil, fmt.Errorf("OpenRouter API key not provided. Use --provider-api-key flag or OPENROUTER_API_KEY environment variable")
|
|
}
|
|
|
|
var opts []openrouter.Option
|
|
opts = append(opts, openrouter.WithAPIKey(apiKey))
|
|
|
|
provider, err := openrouter.New(opts...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create OpenRouter provider: %w", err)
|
|
}
|
|
|
|
model, err := provider.LanguageModel(ctx, modelName)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create OpenRouter model: %w", err)
|
|
}
|
|
|
|
return &ProviderResult{Model: model}, nil
|
|
}
|
|
|
|
func createBedrockProvider(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
|
|
var opts []bedrock.Option
|
|
|
|
// Bedrock uses AWS SDK default credential chain (env vars, shared config, etc.)
|
|
provider, err := bedrock.New(opts...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create Bedrock provider: %w", err)
|
|
}
|
|
|
|
model, err := provider.LanguageModel(ctx, modelName)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create Bedrock model: %w", err)
|
|
}
|
|
|
|
return &ProviderResult{Model: model}, nil
|
|
}
|
|
|
|
func createOllamaProvider(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
|
|
baseURL := "http://localhost:11434"
|
|
if host := os.Getenv("OLLAMA_HOST"); host != "" {
|
|
baseURL = host
|
|
}
|
|
if config.ProviderURL != "" {
|
|
baseURL = config.ProviderURL
|
|
}
|
|
|
|
// Pre-load model with GPU fallback
|
|
loadingResult, err := loadOllamaModelWithFallback(ctx, baseURL, modelName, config)
|
|
var loadingMessage string
|
|
if err != nil {
|
|
loadingMessage = ""
|
|
} else {
|
|
loadingMessage = loadingResult.Message
|
|
}
|
|
|
|
// Use openaicompat provider pointed at Ollama's OpenAI-compatible endpoint
|
|
ollamaAPIBase := strings.TrimRight(baseURL, "/") + "/v1"
|
|
|
|
var opts []openaicompat.Option
|
|
opts = append(opts, openaicompat.WithBaseURL(ollamaAPIBase))
|
|
opts = append(opts, openaicompat.WithName("ollama"))
|
|
|
|
if config.ProviderAPIKey != "" {
|
|
opts = append(opts, openaicompat.WithAPIKey(config.ProviderAPIKey))
|
|
} else {
|
|
// Ollama doesn't require an API key, but the openaicompat provider might need one
|
|
opts = append(opts, openaicompat.WithAPIKey("ollama"))
|
|
}
|
|
|
|
if config.TLSSkipVerify {
|
|
opts = append(opts, openaicompat.WithHTTPClient(createHTTPClientWithTLSConfig(true)))
|
|
}
|
|
|
|
provider, err := openaicompat.New(opts...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create Ollama provider: %w", err)
|
|
}
|
|
|
|
model, err := provider.LanguageModel(ctx, modelName)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create Ollama model: %w", err)
|
|
}
|
|
|
|
return &ProviderResult{
|
|
Model: model,
|
|
Message: loadingMessage,
|
|
}, nil
|
|
}
|
|
|
|
// OllamaLoadingResult reports the outcome of pre-loading an Ollama model.
type OllamaLoadingResult struct {
	// Message is a human-readable description of how the model was loaded.
	Message string
}
|
|
|
|
// loadOllamaModelWithFallback loads an Ollama model with GPU settings and automatic CPU fallback
|
|
func loadOllamaModelWithFallback(ctx context.Context, baseURL, modelName string, config *ProviderConfig) (*OllamaLoadingResult, error) {
|
|
client := createHTTPClientWithTLSConfig(config.TLSSkipVerify)
|
|
|
|
// Phase 1: Check if model exists locally
|
|
if err := checkOllamaModelExists(client, baseURL, modelName); err != nil {
|
|
// Phase 2: Pull model if not found
|
|
if err := pullOllamaModel(ctx, client, baseURL, modelName); err != nil {
|
|
return nil, fmt.Errorf("failed to pull model %s: %v", modelName, err)
|
|
}
|
|
}
|
|
|
|
// Phase 3: Warmup the model
|
|
options := buildOllamaOptions(config)
|
|
_, err := loadOllamaModelWithOptions(ctx, client, baseURL, modelName, options)
|
|
if err != nil {
|
|
// Phase 4: Fallback to CPU if GPU memory insufficient
|
|
if isGPUMemoryError(err) {
|
|
cpuOptions := make(map[string]any)
|
|
maps.Copy(cpuOptions, options)
|
|
cpuOptions["num_gpu"] = 0
|
|
|
|
_, cpuErr := loadOllamaModelWithOptions(ctx, client, baseURL, modelName, cpuOptions)
|
|
if cpuErr != nil {
|
|
return nil, fmt.Errorf("failed to load model on GPU (%v) and CPU fallback failed (%v)", err, cpuErr)
|
|
}
|
|
|
|
return &OllamaLoadingResult{
|
|
Message: "Insufficient GPU memory, falling back to CPU inference",
|
|
}, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
return &OllamaLoadingResult{
|
|
Message: "Model loaded successfully on GPU",
|
|
}, nil
|
|
}
|
|
|
|
func buildOllamaOptions(config *ProviderConfig) map[string]any {
|
|
options := make(map[string]any)
|
|
if config.Temperature != nil {
|
|
options["temperature"] = *config.Temperature
|
|
}
|
|
if config.TopP != nil {
|
|
options["top_p"] = *config.TopP
|
|
}
|
|
if config.TopK != nil {
|
|
options["top_k"] = int(*config.TopK)
|
|
}
|
|
if len(config.StopSequences) > 0 {
|
|
options["stop"] = config.StopSequences
|
|
}
|
|
if config.MaxTokens > 0 {
|
|
options["num_predict"] = config.MaxTokens
|
|
}
|
|
if config.NumGPU != nil {
|
|
options["num_gpu"] = int(*config.NumGPU)
|
|
}
|
|
if config.MainGPU != nil {
|
|
options["main_gpu"] = int(*config.MainGPU)
|
|
}
|
|
return options
|
|
}
|
|
|
|
// checkOllamaModelExists asks the Ollama server at baseURL (POST /api/show)
// whether modelName is available. It returns nil on HTTP 200 and an error
// otherwise, including when the request cannot be built or sent. The request
// uses its own 5-second timeout.
func checkOllamaModelExists(client *http.Client, baseURL, modelName string) error {
	payload, _ := json.Marshal(map[string]string{"model": modelName})

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/api/show", bytes.NewReader(payload))
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode == http.StatusOK {
		return nil
	}
	return fmt.Errorf("model not found locally")
}
|
|
|
|
func pullOllamaModel(ctx context.Context, client *http.Client, baseURL, modelName string) error {
|
|
return pullOllamaModelWithProgress(ctx, client, baseURL, modelName, true)
|
|
}
|
|
|
|
func pullOllamaModelWithProgress(ctx context.Context, client *http.Client, baseURL, modelName string, showProgress bool) error {
|
|
reqBody := map[string]string{"name": modelName}
|
|
jsonBody, _ := json.Marshal(reqBody)
|
|
|
|
pullCtx, cancel := context.WithTimeout(ctx, 5*time.Minute)
|
|
defer cancel()
|
|
|
|
req, err := http.NewRequestWithContext(pullCtx, "POST", baseURL+"/api/pull", bytes.NewBuffer(jsonBody))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
return fmt.Errorf("failed to pull model (status %d): %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
if showProgress {
|
|
progressReader := progress.NewProgressReader(resp.Body)
|
|
defer progressReader.Close()
|
|
_, err = io.ReadAll(progressReader)
|
|
} else {
|
|
_, err = io.ReadAll(resp.Body)
|
|
}
|
|
return err
|
|
}
|
|
|
|
// loadOllamaModelWithOptions warms a model up by sending a one-token,
// non-streaming generate request (POST /api/generate) with the given options.
// The request is bounded by a 30-second timeout derived from ctx.
//
// The caller's options map is not modified: num_predict is overridden on a
// copy. On success the original options map is returned.
func loadOllamaModelWithOptions(ctx context.Context, client *http.Client, baseURL, modelName string, options map[string]any) (map[string]any, error) {
	// Copy so the num_predict override does not leak back to the caller.
	warmup := make(map[string]any)
	maps.Copy(warmup, options)
	warmup["num_predict"] = 1

	payload, _ := json.Marshal(map[string]any{
		"model":   modelName,
		"prompt":  "Hello",
		"stream":  false,
		"options": warmup,
	})

	warmupCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
	defer cancel()

	req, err := http.NewRequestWithContext(warmupCtx, http.MethodPost, baseURL+"/api/generate", bytes.NewReader(payload))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		detail, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("warmup request failed (status %d): %s", resp.StatusCode, string(detail))
	}

	// The response content is irrelevant; it only has to be read to the end.
	if _, err := io.Copy(io.Discard, resp.Body); err != nil {
		return nil, err
	}
	return options, nil
}
|
|
|
|
// isGPUMemoryError reports whether err's message, compared
// case-insensitively, contains one of the known out-of-memory phrases.
// err must not be nil.
func isGPUMemoryError(err error) bool {
	msg := strings.ToLower(err.Error())
	markers := [...]string{
		"out of memory",
		"insufficient memory",
		"cuda out of memory",
		"gpu memory",
	}
	for _, marker := range markers {
		if strings.Contains(msg, marker) {
			return true
		}
	}
	return false
}
|
|
|
|
// createHTTPClientWithTLSConfig creates an HTTP client with optional TLS skip verify.
//
// When skipVerify is false a plain client using the default transport is
// returned. When it is true the client gets its own transport with
// InsecureSkipVerify enabled. That transport is cloned from
// http.DefaultTransport where possible, so proxy, timeout and HTTP/2 settings
// are kept and the shared default transport is never modified.
func createHTTPClientWithTLSConfig(skipVerify bool) *http.Client {
	if !skipVerify {
		return &http.Client{}
	}

	transport := &http.Transport{}
	// DefaultTransport may have been replaced by something that is not an
	// *http.Transport; in that case fall back to a bare transport.
	if def, ok := http.DefaultTransport.(*http.Transport); ok {
		transport = def.Clone()
	}
	if transport.TLSClientConfig == nil {
		transport.TLSClientConfig = &tls.Config{}
	}
	transport.TLSClientConfig.InsecureSkipVerify = true

	return &http.Client{Transport: transport}
}
|
|
|
|
// createOAuthHTTPClient creates an HTTP client that adds OAuth headers for Anthropic API
|
|
func createOAuthHTTPClient(accessToken string, skipVerify bool) *http.Client {
|
|
var base = http.DefaultTransport
|
|
if skipVerify {
|
|
base = &http.Transport{
|
|
TLSClientConfig: &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
},
|
|
}
|
|
}
|
|
|
|
return &http.Client{
|
|
Transport: &oauthTransport{
|
|
accessToken: accessToken,
|
|
base: base,
|
|
},
|
|
}
|
|
}
|
|
|
|
// oauthTransport is an http.RoundTripper that rewrites outgoing requests for
// OAuth authentication before delegating to base.
type oauthTransport struct {
	// accessToken is sent as the Bearer token on every request.
	accessToken string

	// base performs the actual round trip.
	base http.RoundTripper
}
|
|
|
|
func (t *oauthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
newReq := req.Clone(req.Context())
|
|
newReq.Header.Del("x-api-key")
|
|
newReq.Header.Set("Authorization", "Bearer "+t.accessToken)
|
|
newReq.Header.Set("anthropic-beta", "oauth-2025-04-20")
|
|
newReq.Header.Set("anthropic-version", "2023-06-01")
|
|
|
|
if req.Method == "POST" && strings.Contains(req.URL.Path, "/v1/messages") && req.Body != nil {
|
|
body, err := io.ReadAll(req.Body)
|
|
if err == nil {
|
|
modifiedBody, err := t.injectClaudeCodePrompt(body)
|
|
if err == nil {
|
|
newReq.Body = io.NopCloser(bytes.NewReader(modifiedBody))
|
|
newReq.ContentLength = int64(len(modifiedBody))
|
|
}
|
|
}
|
|
}
|
|
|
|
return t.base.RoundTrip(newReq)
|
|
}
|
|
|
|
func (t *oauthTransport) injectClaudeCodePrompt(body []byte) ([]byte, error) {
|
|
var data map[string]any
|
|
if err := json.Unmarshal(body, &data); err != nil {
|
|
return body, nil
|
|
}
|
|
|
|
systemRaw, hasSystem := data["system"]
|
|
if !hasSystem {
|
|
data["system"] = ClaudeCodePrompt
|
|
return json.Marshal(data)
|
|
}
|
|
|
|
switch system := systemRaw.(type) {
|
|
case string:
|
|
if system == ClaudeCodePrompt {
|
|
return body, nil
|
|
}
|
|
data["system"] = []any{
|
|
map[string]any{"type": "text", "text": ClaudeCodePrompt},
|
|
map[string]any{"type": "text", "text": system},
|
|
}
|
|
case []any:
|
|
if len(system) > 0 {
|
|
if first, ok := system[0].(map[string]any); ok {
|
|
if text, ok := first["text"].(string); ok && text == ClaudeCodePrompt {
|
|
return body, nil
|
|
}
|
|
}
|
|
}
|
|
newSystem := []any{
|
|
map[string]any{"type": "text", "text": ClaudeCodePrompt},
|
|
}
|
|
data["system"] = append(newSystem, system...)
|
|
}
|
|
|
|
return json.Marshal(data)
|
|
}
|
|
|
|
// firstNonEmpty scans values in order and returns the first one that is not
// the empty string. It returns "" when every value is empty or when no values
// are given.
func firstNonEmpty(values ...string) string {
	for i := range values {
		if len(values[i]) > 0 {
			return values[i]
		}
	}
	return ""
}
|