mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 03:30:26 +00:00
873 lines
26 KiB
Go
873 lines
26 KiB
Go
package models
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/tls"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"strings"
|
|
"time"
|
|
|
|
einoclaude "github.com/cloudwego/eino-ext/components/model/claude"
|
|
"github.com/cloudwego/eino-ext/components/model/ollama"
|
|
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
|
|
"github.com/cloudwego/eino/components/model"
|
|
"github.com/mark3labs/mcphost/internal/models/anthropic"
|
|
"github.com/mark3labs/mcphost/internal/models/openai"
|
|
"github.com/mark3labs/mcphost/internal/ui/progress"
|
|
"github.com/ollama/ollama/api"
|
|
"google.golang.org/genai"
|
|
|
|
"github.com/mark3labs/mcphost/internal/auth"
|
|
"github.com/mark3labs/mcphost/internal/models/gemini"
|
|
)
|
|
|
|
// ClaudeCodePrompt is the system prompt that OAuth-authenticated requests to
// Anthropic's API must carry as their first system message so that the
// application is identified properly.
const ClaudeCodePrompt = "You are Claude Code, Anthropic's official CLI for Claude."
|
|
|
|
// resolveModelAlias resolves model aliases to their full names using the registry
|
|
func resolveModelAlias(provider, modelName string) string {
|
|
registry := GetGlobalRegistry()
|
|
|
|
// Common alias patterns for Anthropic models - using Claude 4 as the latest/default
|
|
aliasMap := map[string]string{
|
|
// Claude 4 models (latest and most capable)
|
|
"claude-opus-latest": "claude-opus-4-20250514",
|
|
"claude-sonnet-latest": "claude-sonnet-4-20250514",
|
|
"claude-4-opus-latest": "claude-opus-4-20250514",
|
|
"claude-4-sonnet-latest": "claude-sonnet-4-20250514",
|
|
|
|
// Claude 3.x models for backward compatibility
|
|
"claude-3-5-haiku-latest": "claude-3-5-haiku-20241022",
|
|
"claude-3-5-sonnet-latest": "claude-3-5-sonnet-20241022",
|
|
"claude-3-7-sonnet-latest": "claude-3-7-sonnet-20250219",
|
|
"claude-3-opus-latest": "claude-3-opus-20240229",
|
|
}
|
|
|
|
// Check if it's a known alias
|
|
if resolved, exists := aliasMap[modelName]; exists {
|
|
// Verify the resolved model exists in the registry
|
|
if _, err := registry.ValidateModel(provider, resolved); err == nil {
|
|
return resolved
|
|
}
|
|
}
|
|
|
|
// Return original if no alias found or resolved model doesn't exist
|
|
return modelName
|
|
}
|
|
|
|
// ProviderConfig carries everything needed to build an LLM provider:
// which model to use, how to authenticate, where to connect, and the
// generation parameters to apply.
type ProviderConfig struct {
	// ModelString selects the model, written as "provider:model",
	// e.g. "anthropic:claude-3-sonnet-20240620" or "openai:gpt-4".
	ModelString string

	// SystemPrompt is the system message / instruction text for the model.
	SystemPrompt string

	// ProviderAPIKey authenticates against the provider
	// (OpenAI, Anthropic, Google and Azure).
	ProviderAPIKey string

	// ProviderURL overrides the provider's API base URL; usable for custom
	// OpenAI, Anthropic, Ollama or Azure endpoints.
	ProviderURL string

	// --- Generation parameters ---

	// MaxTokens caps the number of tokens in a response.
	MaxTokens int

	// Temperature (0.0 to 1.0) controls randomness; lower values give more
	// focused, deterministic output. nil means "not set".
	Temperature *float32

	// TopP is the nucleus-sampling threshold (0.0 to 1.0); 1.0 considers
	// all tokens. nil means "not set".
	TopP *float32

	// TopK bounds how many tokens are sampled from. Used by the Anthropic
	// and Google providers. nil means "not set".
	TopK *int32

	// StopSequences lists strings that end generation when produced.
	StopSequences []string

	// --- Ollama-specific parameters ---

	// NumGPU is the number of GPUs to use for inference.
	NumGPU *int32

	// MainGPU picks the GPU that acts as the primary device.
	MainGPU *int32

	// --- TLS ---

	// TLSSkipVerify disables TLS certificate verification. This is
	// insecure; reserve it for development or self-signed certificates.
	TLSSkipVerify bool
}
|
|
|
|
// ProviderResult contains the result of provider creation.
|
|
// It includes both the created model instance and any informational
|
|
// messages that should be displayed to the user.
|
|
type ProviderResult struct {
|
|
// Model is the created LLM provider instance that implements
|
|
// the ToolCallingChatModel interface for tool-enabled conversations
|
|
Model model.ToolCallingChatModel
|
|
|
|
// Message contains optional feedback for the user
|
|
// Example: "Insufficient GPU memory, falling back to CPU inference"
|
|
Message string
|
|
}
|
|
|
|
// CreateProvider creates an eino ToolCallingChatModel based on the provider configuration.
|
|
// It validates the model, checks required environment variables, and initializes
|
|
// the appropriate provider (Anthropic, OpenAI, Google, Ollama, or Azure).
|
|
//
|
|
// Parameters:
|
|
// - ctx: Context for the operation, used for cancellation and deadlines
|
|
// - config: Provider configuration containing model details and parameters
|
|
//
|
|
// Returns:
|
|
// - *ProviderResult: Contains the created model and any informational messages
|
|
// - error: Returns an error if provider creation fails, model validation fails,
|
|
// or required credentials are missing
|
|
//
|
|
// Supported providers:
|
|
// - anthropic: Claude models with API key or OAuth authentication
|
|
// - openai: GPT models including reasoning models like o1
|
|
// - google: Gemini models
|
|
// - ollama: Local models with automatic GPU/CPU fallback
|
|
// - azure: Azure OpenAI Service deployments
|
|
func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResult, error) {
|
|
parts := strings.SplitN(config.ModelString, ":", 2)
|
|
if len(parts) < 2 {
|
|
return nil, fmt.Errorf("invalid model format. Expected provider:model, got %s", config.ModelString)
|
|
}
|
|
|
|
provider := parts[0]
|
|
modelName := parts[1]
|
|
|
|
// Resolve model aliases before validation (for OAuth compatibility)
|
|
if provider == "anthropic" {
|
|
modelName = resolveModelAlias(provider, modelName)
|
|
}
|
|
|
|
// Get the global registry for validation
|
|
registry := GetGlobalRegistry()
|
|
|
|
// Validate the model exists (skip for ollama as it's not in models.dev, and skip when using custom provider URL)
|
|
if provider != "ollama" && config.ProviderURL == "" {
|
|
modelInfo, err := registry.ValidateModel(provider, modelName)
|
|
if err != nil {
|
|
// Provide helpful suggestions
|
|
suggestions := registry.SuggestModels(provider, modelName)
|
|
if len(suggestions) > 0 {
|
|
return nil, fmt.Errorf("%v. Did you mean one of: %s", err, strings.Join(suggestions, ", "))
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
// Validate environment variables
|
|
if err := registry.ValidateEnvironment(provider, config.ProviderAPIKey); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Validate configuration parameters against model capabilities
|
|
if err := validateModelConfig(config, modelInfo); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
switch provider {
|
|
case "anthropic":
|
|
model, err := createAnthropicProvider(ctx, config, modelName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &ProviderResult{Model: model, Message: ""}, nil
|
|
case "openai":
|
|
model, err := createOpenAIProvider(ctx, config, modelName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &ProviderResult{Model: model, Message: ""}, nil
|
|
case "google":
|
|
model, err := createGoogleProvider(ctx, config, modelName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &ProviderResult{Model: model, Message: ""}, nil
|
|
case "ollama":
|
|
return createOllamaProviderWithResult(ctx, config, modelName)
|
|
case "azure":
|
|
model, err := createAzureOpenAIProvider(ctx, config, modelName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &ProviderResult{Model: model, Message: ""}, nil
|
|
default:
|
|
return nil, fmt.Errorf("unsupported provider: %s", provider)
|
|
}
|
|
}
|
|
|
|
// validateModelConfig validates configuration parameters against model capabilities
|
|
func validateModelConfig(config *ProviderConfig, modelInfo *ModelInfo) error {
|
|
// Omit temperature if not supported by the model
|
|
if config.Temperature != nil && !modelInfo.Temperature {
|
|
config.Temperature = nil
|
|
}
|
|
|
|
// Warn about context limits if MaxTokens is set too high
|
|
if config.MaxTokens > modelInfo.Limit.Output {
|
|
return fmt.Errorf("max_tokens (%d) exceeds model's output limit (%d) for %s",
|
|
config.MaxTokens, modelInfo.Limit.Output, modelInfo.ID)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func createAzureOpenAIProvider(ctx context.Context, config *ProviderConfig, modelName string) (model.ToolCallingChatModel, error) {
|
|
apiKey := config.ProviderAPIKey
|
|
if apiKey == "" {
|
|
apiKey = os.Getenv("AZURE_OPENAI_API_KEY")
|
|
}
|
|
if apiKey == "" {
|
|
return nil, fmt.Errorf("Azure OpenAI API key not provided. Use --provider-api-key flag or AZURE_OPENAI_API_KEY environment variable")
|
|
}
|
|
|
|
azureConfig := &einoopenai.ChatModelConfig{
|
|
APIKey: apiKey,
|
|
Model: modelName,
|
|
ByAzure: true, // Indicate this is an Azure OpenAI model
|
|
APIVersion: "2025-01-01-preview", // Default Azure OpenAI API version
|
|
}
|
|
|
|
if config.ProviderURL != "" {
|
|
azureConfig.BaseURL = config.ProviderURL
|
|
} else {
|
|
azureConfig.BaseURL = os.Getenv("AZURE_OPENAI_BASE_URL")
|
|
}
|
|
if azureConfig.BaseURL == "" {
|
|
return nil, fmt.Errorf("Azure OpenAI Base URL not provided. Use --provider-url flag or AZURE_OPENAI_BASE_URL environment variable")
|
|
}
|
|
|
|
if config.MaxTokens > 0 {
|
|
azureConfig.MaxTokens = &config.MaxTokens
|
|
}
|
|
|
|
if config.Temperature != nil {
|
|
azureConfig.Temperature = config.Temperature
|
|
}
|
|
|
|
if config.TopP != nil {
|
|
azureConfig.TopP = config.TopP
|
|
}
|
|
|
|
if len(config.StopSequences) > 0 {
|
|
azureConfig.Stop = config.StopSequences
|
|
}
|
|
|
|
// Set HTTP client with TLS config if needed
|
|
if config.TLSSkipVerify {
|
|
azureConfig.HTTPClient = createHTTPClientWithTLSConfig(true)
|
|
}
|
|
|
|
return openai.NewCustomChatModel(ctx, azureConfig)
|
|
}
|
|
|
|
func createAnthropicProvider(ctx context.Context, config *ProviderConfig, modelName string) (model.ToolCallingChatModel, error) {
|
|
apiKey, source, err := auth.GetAnthropicAPIKey(config.ProviderAPIKey)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Log the source of the API key in debug mode (without revealing the key)
|
|
if os.Getenv("DEBUG") != "" || os.Getenv("MCPHOST_DEBUG") != "" {
|
|
fmt.Fprintf(os.Stderr, "Using Anthropic API key from: %s\n", source)
|
|
}
|
|
|
|
// Model alias resolution is handled in CreateProvider
|
|
|
|
maxTokens := config.MaxTokens
|
|
if maxTokens == 0 {
|
|
maxTokens = 4096 // Default value
|
|
}
|
|
|
|
claudeConfig := &einoclaude.Config{
|
|
Model: modelName,
|
|
MaxTokens: maxTokens,
|
|
}
|
|
|
|
// Handle OAuth vs API key authentication
|
|
if strings.HasPrefix(source, "stored OAuth") {
|
|
// For OAuth tokens, we need to use Authorization: Bearer header
|
|
// Create a custom HTTP client that adds the proper headers
|
|
claudeConfig.HTTPClient = createOAuthHTTPClient(apiKey, config.TLSSkipVerify)
|
|
// Set a dummy API key to prevent the library from failing validation
|
|
claudeConfig.APIKey = "oauth-placeholder"
|
|
} else {
|
|
// For API keys, use the standard x-api-key header
|
|
claudeConfig.APIKey = apiKey
|
|
// Set HTTP client with TLS config if needed
|
|
if config.TLSSkipVerify {
|
|
claudeConfig.HTTPClient = createHTTPClientWithTLSConfig(true)
|
|
}
|
|
}
|
|
|
|
if config.ProviderURL != "" {
|
|
claudeConfig.BaseURL = &config.ProviderURL
|
|
}
|
|
|
|
if config.Temperature != nil {
|
|
claudeConfig.Temperature = config.Temperature
|
|
}
|
|
|
|
if config.TopP != nil {
|
|
claudeConfig.TopP = config.TopP
|
|
}
|
|
|
|
if config.TopK != nil {
|
|
claudeConfig.TopK = config.TopK
|
|
}
|
|
|
|
if len(config.StopSequences) > 0 {
|
|
claudeConfig.StopSequences = config.StopSequences
|
|
}
|
|
|
|
return anthropic.NewCustomChatModel(ctx, claudeConfig)
|
|
}
|
|
|
|
func createOpenAIProvider(ctx context.Context, config *ProviderConfig, modelName string) (model.ToolCallingChatModel, error) {
|
|
apiKey := config.ProviderAPIKey
|
|
if apiKey == "" {
|
|
apiKey = os.Getenv("OPENAI_API_KEY")
|
|
}
|
|
if apiKey == "" {
|
|
return nil, fmt.Errorf("OpenAI API key not provided. Use --provider-api-key flag or OPENAI_API_KEY environment variable")
|
|
}
|
|
|
|
openaiConfig := &einoopenai.ChatModelConfig{
|
|
APIKey: apiKey,
|
|
Model: modelName,
|
|
}
|
|
|
|
if config.ProviderURL != "" {
|
|
openaiConfig.BaseURL = config.ProviderURL
|
|
}
|
|
|
|
// Set HTTP client with TLS config if needed
|
|
if config.TLSSkipVerify {
|
|
openaiConfig.HTTPClient = createHTTPClientWithTLSConfig(true)
|
|
}
|
|
|
|
// Check if this is a reasoning model to handle beta limitations (skip validation if using custom URL)
|
|
registry := GetGlobalRegistry()
|
|
isReasoningModel := false
|
|
if config.ProviderURL == "" {
|
|
if modelInfo, err := registry.ValidateModel("openai", modelName); err == nil && modelInfo.Reasoning {
|
|
isReasoningModel = true
|
|
}
|
|
}
|
|
|
|
if config.MaxTokens > 0 {
|
|
if isReasoningModel {
|
|
// For reasoning models, use MaxCompletionTokens instead of MaxTokens
|
|
if openaiConfig.ExtraFields == nil {
|
|
openaiConfig.ExtraFields = make(map[string]any)
|
|
}
|
|
openaiConfig.ExtraFields["max_completion_tokens"] = config.MaxTokens
|
|
} else {
|
|
// For non-reasoning models, use MaxTokens as usual
|
|
openaiConfig.MaxTokens = &config.MaxTokens
|
|
}
|
|
}
|
|
|
|
// For reasoning models, skip temperature and top_p due to beta limitations
|
|
if !isReasoningModel {
|
|
if config.Temperature != nil {
|
|
openaiConfig.Temperature = config.Temperature
|
|
}
|
|
|
|
if config.TopP != nil {
|
|
openaiConfig.TopP = config.TopP
|
|
}
|
|
}
|
|
|
|
if len(config.StopSequences) > 0 {
|
|
openaiConfig.Stop = config.StopSequences
|
|
}
|
|
|
|
return openai.NewCustomChatModel(ctx, openaiConfig)
|
|
}
|
|
|
|
func createGoogleProvider(ctx context.Context, config *ProviderConfig, modelName string) (model.ToolCallingChatModel, error) {
|
|
apiKey := config.ProviderAPIKey
|
|
if apiKey == "" {
|
|
apiKey = os.Getenv("GOOGLE_API_KEY")
|
|
}
|
|
if apiKey == "" {
|
|
apiKey = os.Getenv("GEMINI_API_KEY")
|
|
}
|
|
if apiKey == "" {
|
|
apiKey = os.Getenv("GOOGLE_GENERATIVE_AI_API_KEY")
|
|
}
|
|
if apiKey == "" {
|
|
return nil, fmt.Errorf("Google API key not provided. Use --provider-api-key flag or GOOGLE_API_KEY/GEMINI_API_KEY/GOOGLE_GENERATIVE_AI_API_KEY environment variable")
|
|
}
|
|
|
|
clientConfig := &genai.ClientConfig{
|
|
APIKey: apiKey,
|
|
Backend: genai.BackendGeminiAPI,
|
|
}
|
|
|
|
// Set HTTP client with TLS config if needed
|
|
if config.TLSSkipVerify {
|
|
clientConfig.HTTPClient = createHTTPClientWithTLSConfig(true)
|
|
}
|
|
|
|
client, err := genai.NewClient(ctx, clientConfig)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create Google client: %v", err)
|
|
}
|
|
|
|
geminiConfig := &gemini.Config{
|
|
Client: client,
|
|
Model: modelName,
|
|
}
|
|
|
|
if config.MaxTokens > 0 {
|
|
geminiConfig.MaxTokens = &config.MaxTokens
|
|
}
|
|
|
|
if config.Temperature != nil {
|
|
geminiConfig.Temperature = config.Temperature
|
|
}
|
|
|
|
if config.TopP != nil {
|
|
geminiConfig.TopP = config.TopP
|
|
}
|
|
|
|
if config.TopK != nil {
|
|
geminiConfig.TopK = config.TopK
|
|
}
|
|
|
|
return gemini.NewChatModel(ctx, geminiConfig)
|
|
}
|
|
|
|
// OllamaLoadingResult contains the result of model loading with actual settings used.
|
|
// It provides information about how the model was loaded, including any fallback
|
|
// behavior that occurred during initialization.
|
|
type OllamaLoadingResult struct {
|
|
// Options contains the actual Ollama options used for loading
|
|
// May differ from requested options if fallback occurred (e.g., CPU instead of GPU)
|
|
Options *api.Options
|
|
|
|
// Message describes the loading result
|
|
// Example: "Model loaded successfully on GPU" or
|
|
// "Insufficient GPU memory, falling back to CPU inference"
|
|
Message string
|
|
}
|
|
|
|
// loadOllamaModelWithFallback loads an Ollama model with GPU settings and automatic CPU fallback
|
|
func loadOllamaModelWithFallback(ctx context.Context, baseURL, modelName string, options *api.Options, tlsSkipVerify bool) (*OllamaLoadingResult, error) {
|
|
client := createHTTPClientWithTLSConfig(tlsSkipVerify)
|
|
|
|
// Phase 1: Check if model exists locally
|
|
if err := checkOllamaModelExists(client, baseURL, modelName); err != nil {
|
|
// Phase 2: Pull model if not found
|
|
if err := pullOllamaModel(ctx, client, baseURL, modelName); err != nil {
|
|
return nil, fmt.Errorf("failed to pull model %s: %v", modelName, err)
|
|
}
|
|
}
|
|
|
|
// Phase 3: Load model with GPU settings
|
|
_, err := loadOllamaModelWithOptions(ctx, client, baseURL, modelName, options)
|
|
if err != nil {
|
|
// Phase 4: Fallback to CPU if GPU memory insufficient
|
|
if isGPUMemoryError(err) {
|
|
cpuOptions := *options
|
|
cpuOptions.NumGPU = 0
|
|
|
|
_, cpuErr := loadOllamaModelWithOptions(ctx, client, baseURL, modelName, &cpuOptions)
|
|
if cpuErr != nil {
|
|
return nil, fmt.Errorf("failed to load model on GPU (%v) and CPU fallback failed (%v)", err, cpuErr)
|
|
}
|
|
|
|
return &OllamaLoadingResult{
|
|
Options: &cpuOptions,
|
|
Message: "Insufficient GPU memory, falling back to CPU inference",
|
|
}, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
return &OllamaLoadingResult{
|
|
Options: options,
|
|
Message: "Model loaded successfully on GPU",
|
|
}, nil
|
|
}
|
|
|
|
// checkOllamaModelExists asks the Ollama server at baseURL, via
// POST /api/show, whether modelName is available locally.
//
// It returns nil on a 200 response. Any other status yields the error
// "model not found locally"; request construction and transport failures are
// returned as-is. The request runs under its own 5-second timeout.
func checkOllamaModelExists(client *http.Client, baseURL, modelName string) error {
	// A map of strings always marshals, so the error is not checked.
	payload, _ := json.Marshal(map[string]string{"model": modelName})

	reqCtx, stop := context.WithTimeout(context.Background(), 5*time.Second)
	defer stop()

	endpoint := baseURL + "/api/show"
	request, err := http.NewRequestWithContext(reqCtx, http.MethodPost, endpoint, bytes.NewBuffer(payload))
	if err != nil {
		return err
	}
	request.Header.Set("Content-Type", "application/json")

	response, err := client.Do(request)
	if err != nil {
		return err
	}
	defer response.Body.Close()

	if response.StatusCode == http.StatusOK {
		return nil
	}
	return fmt.Errorf("model not found locally")
}
|
|
|
|
// pullOllamaModel pulls a model from the registry
|
|
func pullOllamaModel(ctx context.Context, client *http.Client, baseURL, modelName string) error {
|
|
return pullOllamaModelWithProgress(ctx, client, baseURL, modelName, true)
|
|
}
|
|
|
|
// pullOllamaModelWithProgress pulls a model from the registry with optional progress display
|
|
func pullOllamaModelWithProgress(ctx context.Context, client *http.Client, baseURL, modelName string, showProgress bool) error {
|
|
reqBody := map[string]string{"name": modelName}
|
|
jsonBody, _ := json.Marshal(reqBody)
|
|
|
|
// Use a longer timeout for pulling models (5 minutes)
|
|
pullCtx, cancel := context.WithTimeout(ctx, 5*time.Minute)
|
|
defer cancel()
|
|
|
|
req, err := http.NewRequestWithContext(pullCtx, "POST", baseURL+"/api/pull", bytes.NewBuffer(jsonBody))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
return fmt.Errorf("failed to pull model (status %d): %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
// Read the streaming response with optional progress display
|
|
if showProgress {
|
|
progressReader := progress.NewProgressReader(resp.Body)
|
|
defer progressReader.Close()
|
|
_, err = io.ReadAll(progressReader)
|
|
} else {
|
|
_, err = io.ReadAll(resp.Body)
|
|
}
|
|
return err
|
|
}
|
|
|
|
// loadOllamaModelWithOptions loads a model with specific options using a warmup request
|
|
func loadOllamaModelWithOptions(ctx context.Context, client *http.Client, baseURL, modelName string, options *api.Options) (*api.Options, error) {
|
|
// Create a copy of options for warmup to avoid modifying the original
|
|
warmupOptions := *options
|
|
warmupOptions.NumPredict = 1 // Limit response length for warmup
|
|
|
|
reqBody := map[string]interface{}{
|
|
"model": modelName,
|
|
"prompt": "Hello",
|
|
"stream": false,
|
|
"options": &warmupOptions,
|
|
}
|
|
|
|
jsonBody, _ := json.Marshal(reqBody)
|
|
|
|
// Use medium timeout for warmup (30 seconds)
|
|
warmupCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
|
defer cancel()
|
|
|
|
req, err := http.NewRequestWithContext(warmupCtx, "POST", baseURL+"/api/generate", bytes.NewBuffer(jsonBody))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
return nil, fmt.Errorf("warmup request failed (status %d): %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
// Read response to completion
|
|
_, err = io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return options, nil
|
|
}
|
|
|
|
// isGPUMemoryError reports whether err looks like an insufficient-GPU-memory
// failure, judged by a case-insensitive search of its message for a few known
// phrases. A nil error is never a GPU memory error.
func isGPUMemoryError(err error) bool {
	if err == nil {
		return false
	}

	msg := strings.ToLower(err.Error())

	// "out of memory" also covers the more specific "cuda out of memory".
	for _, marker := range []string{"out of memory", "insufficient memory", "gpu memory"} {
		if strings.Contains(msg, marker) {
			return true
		}
	}
	return false
}
|
|
|
|
// bearerTransport is an http.RoundTripper that adds an
// "Authorization: Bearer <token>" header to every request before handing it
// to the base transport.
type bearerTransport struct {
	// base performs the actual request; http.DefaultTransport is used when nil.
	base http.RoundTripper
	// token is the bearer token placed in the Authorization header.
	token string
}

// RoundTrip sends a clone of req with the Authorization header set. The
// caller's request is left unmodified, as the RoundTripper contract requires.
func (t *bearerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	base := t.base
	if base == nil {
		// Keep the zero value usable instead of panicking on a nil base.
		base = http.DefaultTransport
	}

	authed := req.Clone(req.Context())
	if authed.Header == nil {
		// Clone keeps a nil header map nil; Set would panic on it.
		authed.Header = make(http.Header)
	}
	authed.Header.Set("Authorization", "Bearer "+t.token)

	return base.RoundTrip(authed)
}
|
|
|
|
func createOllamaProviderWithResult(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
|
|
baseURL := "http://localhost:11434" // Default Ollama URL
|
|
|
|
// Check for custom Ollama host from environment
|
|
if host := os.Getenv("OLLAMA_HOST"); host != "" {
|
|
baseURL = host
|
|
}
|
|
|
|
// Override with ProviderURL if provided
|
|
if config.ProviderURL != "" {
|
|
baseURL = config.ProviderURL
|
|
}
|
|
|
|
// Set up options for Ollama using the api.Options struct
|
|
options := &api.Options{}
|
|
|
|
if config.Temperature != nil {
|
|
options.Temperature = *config.Temperature
|
|
}
|
|
|
|
if config.TopP != nil {
|
|
options.TopP = *config.TopP
|
|
}
|
|
|
|
if config.TopK != nil {
|
|
options.TopK = int(*config.TopK)
|
|
}
|
|
|
|
if len(config.StopSequences) > 0 {
|
|
options.Stop = config.StopSequences
|
|
}
|
|
|
|
if config.MaxTokens > 0 {
|
|
options.NumPredict = config.MaxTokens
|
|
}
|
|
|
|
// Set GPU configuration for Ollama
|
|
if config.NumGPU != nil {
|
|
options.NumGPU = int(*config.NumGPU)
|
|
}
|
|
|
|
if config.MainGPU != nil {
|
|
options.MainGPU = int(*config.MainGPU)
|
|
}
|
|
|
|
// Create a clean copy of options for the final model
|
|
finalOptions := &api.Options{}
|
|
*finalOptions = *options // Copy all fields
|
|
|
|
// Try to pre-load the model with GPU settings and automatic CPU fallback
|
|
// If this fails, fall back to the original behavior
|
|
loadingResult, err := loadOllamaModelWithFallback(ctx, baseURL, modelName, options, config.TLSSkipVerify)
|
|
var loadingMessage string
|
|
|
|
if err != nil {
|
|
// Pre-loading failed, use original options and no message
|
|
loadingMessage = ""
|
|
} else {
|
|
// Pre-loading succeeded, update GPU settings that worked
|
|
finalOptions.NumGPU = loadingResult.Options.NumGPU
|
|
finalOptions.MainGPU = loadingResult.Options.MainGPU
|
|
loadingMessage = loadingResult.Message
|
|
}
|
|
|
|
ollamaConfig := &ollama.ChatModelConfig{
|
|
BaseURL: baseURL,
|
|
Model: modelName,
|
|
Options: finalOptions,
|
|
}
|
|
|
|
if config.ProviderAPIKey != "" {
|
|
transport := http.DefaultTransport
|
|
authTransport := &bearerTransport{
|
|
base: transport,
|
|
token: config.ProviderAPIKey,
|
|
}
|
|
ollamaConfig.HTTPClient = &http.Client{
|
|
Transport: authTransport,
|
|
}
|
|
}
|
|
|
|
// Set HTTP client with TLS config if needed
|
|
if config.TLSSkipVerify {
|
|
ollamaConfig.HTTPClient = createHTTPClientWithTLSConfig(true)
|
|
}
|
|
|
|
chatModel, err := ollama.NewChatModel(ctx, ollamaConfig)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &ProviderResult{
|
|
Model: chatModel,
|
|
Message: loadingMessage,
|
|
}, nil
|
|
}
|
|
|
|
// createHTTPClientWithTLSConfig returns an HTTP client, optionally with TLS
// certificate verification disabled.
//
// With skipVerify false it returns a plain &http.Client{}, which uses
// http.DefaultTransport. With skipVerify true it returns a client whose
// transport is a copy of http.DefaultTransport with InsecureSkipVerify
// enabled, so the default transport's other settings (proxy function,
// timeouts, connection pooling, HTTP/2) are kept and the shared default
// transport itself is not modified.
func createHTTPClientWithTLSConfig(skipVerify bool) *http.Client {
	if !skipVerify {
		return &http.Client{}
	}

	// Start from the default transport's settings rather than a zero-value
	// Transport, which would drop them all.
	var transport *http.Transport
	if def, ok := http.DefaultTransport.(*http.Transport); ok {
		transport = def.Clone()
	} else {
		// DefaultTransport was replaced by something else; fall back to a
		// bare transport.
		transport = &http.Transport{}
	}

	if transport.TLSClientConfig == nil {
		transport.TLSClientConfig = &tls.Config{}
	}
	// Insecure by request: only for development or self-signed certificates.
	transport.TLSClientConfig.InsecureSkipVerify = true

	return &http.Client{Transport: transport}
}
|
|
|
|
// createOAuthHTTPClient creates an HTTP client that adds OAuth headers for Anthropic API
|
|
func createOAuthHTTPClient(accessToken string, skipVerify bool) *http.Client {
|
|
var base http.RoundTripper = http.DefaultTransport
|
|
if skipVerify {
|
|
base = &http.Transport{
|
|
TLSClientConfig: &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
},
|
|
}
|
|
}
|
|
|
|
return &http.Client{
|
|
Transport: &oauthTransport{
|
|
accessToken: accessToken,
|
|
base: base,
|
|
},
|
|
}
|
|
}
|
|
|
|
// oauthTransport is an HTTP transport that adds OAuth headers
|
|
type oauthTransport struct {
|
|
accessToken string
|
|
base http.RoundTripper
|
|
}
|
|
|
|
func (t *oauthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
// Clone the request to avoid modifying the original
|
|
newReq := req.Clone(req.Context())
|
|
|
|
// Remove any existing x-api-key header (from the dummy API key)
|
|
newReq.Header.Del("x-api-key")
|
|
|
|
// Add OAuth headers as required by Anthropic's OAuth API
|
|
newReq.Header.Set("Authorization", "Bearer "+t.accessToken)
|
|
newReq.Header.Set("anthropic-beta", "oauth-2025-04-20")
|
|
newReq.Header.Set("anthropic-version", "2023-06-01")
|
|
|
|
// Inject Claude Code system prompt for /v1/messages endpoint
|
|
if req.Method == "POST" && strings.Contains(req.URL.Path, "/v1/messages") && req.Body != nil {
|
|
body, err := io.ReadAll(req.Body)
|
|
if err == nil {
|
|
modifiedBody, err := t.injectClaudeCodePrompt(body)
|
|
if err == nil {
|
|
newReq.Body = io.NopCloser(bytes.NewReader(modifiedBody))
|
|
newReq.ContentLength = int64(len(modifiedBody))
|
|
}
|
|
}
|
|
}
|
|
|
|
// Use the base transport to make the request
|
|
return t.base.RoundTrip(newReq)
|
|
}
|
|
|
|
// injectClaudeCodePrompt modifies the request body to inject Claude Code system prompt
|
|
func (t *oauthTransport) injectClaudeCodePrompt(body []byte) ([]byte, error) {
|
|
var data map[string]interface{}
|
|
if err := json.Unmarshal(body, &data); err != nil {
|
|
return body, nil // Return original if not JSON
|
|
}
|
|
|
|
// Check if request has a system prompt
|
|
systemRaw, hasSystem := data["system"]
|
|
if !hasSystem {
|
|
// No system prompt, inject Claude Code identification
|
|
data["system"] = ClaudeCodePrompt
|
|
return json.Marshal(data)
|
|
}
|
|
|
|
switch system := systemRaw.(type) {
|
|
case string:
|
|
// Handle string system prompt
|
|
if system == ClaudeCodePrompt {
|
|
// Already correct, leave as-is
|
|
return body, nil
|
|
}
|
|
// Convert to array with Claude Code first
|
|
data["system"] = []interface{}{
|
|
map[string]interface{}{"type": "text", "text": ClaudeCodePrompt},
|
|
map[string]interface{}{"type": "text", "text": system},
|
|
}
|
|
|
|
case []interface{}:
|
|
// Handle array system prompt
|
|
if len(system) > 0 {
|
|
// Check if first element has correct text
|
|
if first, ok := system[0].(map[string]interface{}); ok {
|
|
if text, ok := first["text"].(string); ok && text == ClaudeCodePrompt {
|
|
// Already has Claude Code first, return as-is
|
|
return body, nil
|
|
}
|
|
}
|
|
}
|
|
// Prepend Claude Code identification
|
|
newSystem := []interface{}{
|
|
map[string]interface{}{"type": "text", "text": ClaudeCodePrompt},
|
|
}
|
|
data["system"] = append(newSystem, system...)
|
|
}
|
|
|
|
// Re-marshal the modified data
|
|
return json.Marshal(data)
|
|
}
|