diff --git a/go.mod b/go.mod index cfe90e9e..83d69b4d 100644 --- a/go.mod +++ b/go.mod @@ -10,26 +10,21 @@ require ( github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 github.com/cloudwego/eino v0.3.41 github.com/cloudwego/eino-ext/components/model/claude v0.0.0-20250609074000-b7f307dffa18 - github.com/cloudwego/eino-ext/components/model/gemini v0.0.0-20250609074000-b7f307dffa18 github.com/cloudwego/eino-ext/components/model/ollama v0.0.0-20250609074000-b7f307dffa18 github.com/cloudwego/eino-ext/components/model/openai v0.0.0-20250609074000-b7f307dffa18 github.com/getkin/kin-openapi v0.118.0 - github.com/google/generative-ai-go v0.19.0 github.com/mark3labs/mcp-go v0.31.0 github.com/spf13/cobra v1.8.1 github.com/spf13/viper v1.20.1 golang.org/x/term v0.31.0 - google.golang.org/api v0.228.0 + google.golang.org/genai v1.10.0 gopkg.in/yaml.v3 v3.0.1 ) require ( cloud.google.com/go v0.116.0 // indirect - cloud.google.com/go/ai v0.8.0 // indirect cloud.google.com/go/auth v0.15.0 // indirect - cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect cloud.google.com/go/compute/metadata v0.6.0 // indirect - cloud.google.com/go/longrunning v0.5.7 // indirect github.com/alecthomas/chroma/v2 v2.14.0 // indirect github.com/anthropics/anthropic-sdk-go v0.2.0-alpha.8 // indirect github.com/atotto/clipboard v0.1.4 // indirect @@ -65,12 +60,14 @@ require ( github.com/go-openapi/jsonpointer v0.21.0 // indirect github.com/go-openapi/swag v0.23.0 // indirect github.com/go-viper/mapstructure/v2 v2.2.1 // indirect + github.com/google/go-cmp v0.7.0 // indirect github.com/google/s2a-go v0.1.9 // indirect github.com/google/uuid v1.6.0 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.6 // indirect github.com/googleapis/gax-go/v2 v2.14.1 // indirect github.com/goph/emperror v0.17.2 // indirect github.com/gorilla/css v1.0.1 // indirect + github.com/gorilla/websocket v1.5.3 // indirect github.com/invopop/yaml v0.1.0 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect @@ -105,7 +102,6 @@ require ( github.com/yuin/goldmark v1.7.8 // indirect github.com/yuin/goldmark-emoji v1.0.5 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect - go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.59.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.59.0 // indirect go.opentelemetry.io/otel v1.34.0 // indirect go.opentelemetry.io/otel/metric v1.34.0 // indirect @@ -115,9 +111,6 @@ require ( golang.org/x/arch v0.12.0 // indirect golang.org/x/crypto v0.36.0 // indirect golang.org/x/net v0.37.0 // indirect - golang.org/x/oauth2 v0.28.0 // indirect - golang.org/x/time v0.11.0 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250313205543-e70fdf4c4cb4 // indirect google.golang.org/grpc v1.71.0 // indirect google.golang.org/protobuf v1.36.6 // indirect diff --git a/go.sum b/go.sum index 71c2c6b4..6d4a73fa 100644 --- a/go.sum +++ b/go.sum @@ -1,15 +1,9 @@ cloud.google.com/go v0.116.0 h1:B3fRrSDkLRt5qSHWe40ERJvhvnQwdZiHu0bJOpldweE= cloud.google.com/go v0.116.0/go.mod h1:cEPSRWPzZEswwdr9BxE6ChEn01dWlTaF05LiC2Xs70U= -cloud.google.com/go/ai v0.8.0 h1:rXUEz8Wp2OlrM8r1bfmpF2+VKqc1VJpafE3HgzRnD/w= -cloud.google.com/go/ai v0.8.0/go.mod h1:t3Dfk4cM61sytiggo2UyGsDVW3RF1qGZaUKDrZFyqkE= cloud.google.com/go/auth v0.15.0 h1:Ly0u4aA5vG/fsSsxu98qCQBemXtAtJf+95z9HK+cxps= cloud.google.com/go/auth v0.15.0/go.mod h1:WJDGqZ1o9E9wKIL+IwStfyn/+s59zl4Bi+1KQNVXLZ8= -cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc= -cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c= cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I= cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg= -cloud.google.com/go/longrunning v0.5.7 h1:WLbHekDbjK1fVFD3ibpFFVoyizlLRl73I7YKuAKilhU= -cloud.google.com/go/longrunning v0.5.7/go.mod h1:8GClkudohy1Fxm3owmBGid8W0pSgodEMwEAztp38Xng= github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ= github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE= github.com/airbrake/gobrake v3.6.1+incompatible/go.mod h1:wM4gu3Cn0W0K7GUuVWnlXZU11AGBXMILnrdOU8Kn00o= @@ -99,8 +93,6 @@ github.com/cloudwego/eino v0.3.41 h1:bb6W5/+8QE+jlhDBY2TxcxYu6odpNQ436oDdBF45jgQ github.com/cloudwego/eino v0.3.41/go.mod h1:wUjz990apdsaOraOXdh6CdhVXq8DJsOvLsVlxNTcNfY= github.com/cloudwego/eino-ext/components/model/claude v0.0.0-20250609074000-b7f307dffa18 h1:foS8HJW0U8KOYy1hyWITSZUeMBZbckaBUjQQUIPwauw= github.com/cloudwego/eino-ext/components/model/claude v0.0.0-20250609074000-b7f307dffa18/go.mod h1:IhvvyyldQVIyOogbEbmmOjrxhZbpvbwjeiQJ4ZoNOX4= -github.com/cloudwego/eino-ext/components/model/gemini v0.0.0-20250609074000-b7f307dffa18 h1:pyiKX7sTo9BgtbOq/twj5olsMl/iKj5RTIpDqgAtdBM= -github.com/cloudwego/eino-ext/components/model/gemini v0.0.0-20250609074000-b7f307dffa18/go.mod h1:NTYXf6aAoO2zBES9S1lzkBvQoyD6UcUGvLmUAS5TMRU= github.com/cloudwego/eino-ext/components/model/ollama v0.0.0-20250609074000-b7f307dffa18 h1:UxZVTapUwbzkRIP4Bl/VKni65wI+sfq6oPveZ+76aww= github.com/cloudwego/eino-ext/components/model/ollama v0.0.0-20250609074000-b7f307dffa18/go.mod h1:giNUFqA+V7xrm/EDvH7JFnDqoWI+e2m1SVAnReU+Fd8= github.com/cloudwego/eino-ext/components/model/openai v0.0.0-20250609074000-b7f307dffa18 h1:FAg2QmtJ0tA/3BmQlUPdhZ9Nzzsov76ry7b3Gn86dAs= @@ -151,8 +143,6 @@ github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRx github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= -github.com/google/generative-ai-go v0.19.0 h1:R71szggh8wHMCUlEMsW2A/3T+5LdEIkiaHSYgSpUgdg= -github.com/google/generative-ai-go v0.19.0/go.mod h1:JYolL13VG7j79kM5BtHz4qwONHkeJQzOCkKXnpqtS/E= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -171,6 +161,8 @@ github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfre github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8= github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= @@ -331,8 +323,6 @@ github.com/yuin/goldmark-emoji v1.0.5 h1:EMVWyCGPlXJfUXBXpuMu+ii3TIaxbVBnEX9uaDC github.com/yuin/goldmark-emoji v1.0.5/go.mod h1:tTkZEbwu5wkPmgTcitqddVxY9osFZiavD+r4AzQrh1U= go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= -go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.59.0 h1:rgMkmiGfix9vFJDcDi1PK8WEQP4FLQwLDfhp5ZLpFeE= -go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.59.0/go.mod h1:ijPqXp5P6IRRByFVVg9DY8P5HkxkHE5ARIa+86aXPf4= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.59.0 h1:CV7UdSGJt/Ao6Gp4CXckLxVRRsRgDHoI8XjbL3PDl8s= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.59.0/go.mod h1:FRmFuRJfag1IZ2dPkHnEoSFVgTVPUd2qf5Vi69hLb8I= go.opentelemetry.io/otel v1.34.0 h1:zRLXxLCgL1WyKsPVrgbSdMN4c0FMkDAskSTQP+0hdUY= @@ -361,8 +351,6 @@ golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa/go.mod h1:zk2irFbV9DP96SEBUU golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c= golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= -golang.org/x/oauth2 v0.28.0 h1:CrgCKl8PPAVtLnU3c+EDw6x11699EWlsDeWNWKdIOkc= -golang.org/x/oauth2 v0.28.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= @@ -378,12 +366,8 @@ golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= -golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0= -golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= -google.golang.org/api v0.228.0 h1:X2DJ/uoWGnY5obVjewbp8icSL5U4FzuCfy9OjbLSnLs= -google.golang.org/api v0.228.0/go.mod h1:wNvRS1Pbe8r4+IfBIniV8fwCpGwTrYa+kMUDiC5z5a4= -google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422 h1:GVIKPyP/kLIyVOgOnTwFOrvQaQUzOzGMCxgFUOEmm24= -google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422/go.mod h1:b6h1vNKhxaSoEI+5jc3PJUCustfli/mRab7295pY7rw= +google.golang.org/genai v1.10.0 h1:ETP0Yksn5KUSEn5+ihMOnP3IqjZ+7Z4i0LjJslEXatI= +google.golang.org/genai v1.10.0/go.mod h1:TyfOKRz/QyCaj6f/ZDt505x+YreXnY40l2I6k8TvgqY= google.golang.org/genproto/googleapis/rpc v0.0.0-20250313205543-e70fdf4c4cb4 h1:iK2jbkWL86DXjEx0qiHcRE9dE4/Ahua5k6V8OWFb//c= google.golang.org/genproto/googleapis/rpc v0.0.0-20250313205543-e70fdf4c4cb4/go.mod h1:LuRYeWDFV6WOn90g357N17oMCaxpgCnbi/44qJvDn2I= google.golang.org/grpc v1.71.0 h1:kF77BGdPTQ4/JZWMlb9VpJ5pa25aqvVqogsxNHHdeBg= diff --git a/internal/models/gemini/gemini.go b/internal/models/gemini/gemini.go new file mode 100644 index 00000000..510c005b --- /dev/null +++ b/internal/models/gemini/gemini.go @@ -0,0 +1,704 @@ +package gemini + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "runtime/debug" + + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/components" + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/schema" + "github.com/getkin/kin-openapi/openapi3" + "google.golang.org/genai" +) + +var _ model.ToolCallingChatModel = (*ChatModel)(nil) + +// NewChatModel creates a new Gemini chat model instance +// +// Parameters: +// - ctx: The context for the operation +// - cfg: Configuration for the Gemini model +// +// Returns: +// - model.ChatModel: A chat model interface implementation +// - error: Any error that occurred during creation +// +// Example: +// +// model, err := gemini.NewChatModel(ctx, &gemini.Config{ +// Client: client, +// Model: "gemini-pro", +// }) +func NewChatModel(_ context.Context, cfg *Config) (*ChatModel, error) { + return &ChatModel{ + cli: cfg.Client, + model: cfg.Model, + maxTokens: cfg.MaxTokens, + temperature: cfg.Temperature, + topP: cfg.TopP, + topK: cfg.TopK, + responseSchema: cfg.ResponseSchema, + enableCodeExecution: cfg.EnableCodeExecution, + safetySettings: cfg.SafetySettings, + }, nil +} + +// Config contains the configuration options for the Gemini model +type Config struct { + // Client is the Gemini API client instance + // Required for making API calls to Gemini + Client *genai.Client + + // Model specifies which Gemini model to use + // Examples: "gemini-pro", "gemini-pro-vision", "gemini-1.5-flash" + Model string + + // MaxTokens limits the maximum number of tokens in the response + // Optional. Example: maxTokens := 100 + MaxTokens *int + + // Temperature controls randomness in responses + // Range: [0.0, 1.0], where 0.0 is more focused and 1.0 is more creative + // Optional. Example: temperature := float32(0.7) + Temperature *float32 + + // TopP controls diversity via nucleus sampling + // Range: [0.0, 1.0], where 1.0 disables nucleus sampling + // Optional. Example: topP := float32(0.95) + TopP *float32 + + // TopK controls diversity by limiting the top K tokens to sample from + // Optional. Example: topK := int32(40) + TopK *int32 + + // ResponseSchema defines the structure for JSON responses + // Optional. Used when you want structured output in JSON format + ResponseSchema *openapi3.Schema + + // EnableCodeExecution allows the model to execute code + // Warning: Be cautious with code execution in production + // Optional. Default: false + EnableCodeExecution bool + + // SafetySettings configures content filtering for different harm categories + // Controls the model's filtering behavior for potentially harmful content + // Optional. + SafetySettings []*genai.SafetySetting +} + +// options contains Gemini-specific options for model configuration +type options struct { + TopK *int32 + ResponseSchema *openapi3.Schema +} + +type ChatModel struct { + cli *genai.Client + + model string + maxTokens *int + topP *float32 + temperature *float32 + topK *int32 + responseSchema *openapi3.Schema + tools []*genai.Tool + origTools []*schema.ToolInfo + toolChoice *schema.ToolChoice + enableCodeExecution bool + safetySettings []*genai.SafetySetting +} + +func (cm *ChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (message *schema.Message, err error) { + ctx = callbacks.EnsureRunInfo(ctx, cm.GetType(), components.ComponentOfChatModel) + + config, conf, err := cm.buildGenerateConfig(opts...) + if err != nil { + return nil, err + } + + ctx = callbacks.OnStart(ctx, &model.CallbackInput{ + Messages: input, + Tools: model.GetCommonOptions(&model.Options{Tools: cm.origTools}, opts...).Tools, + Config: conf, + }) + defer func() { + if err != nil { + callbacks.OnError(ctx, err) + } + }() + + if len(input) == 0 { + return nil, fmt.Errorf("gemini input is empty") + } + + contents, err := cm.convertSchemaMessages(input) + if err != nil { + return nil, err + } + + result, err := cm.cli.Models.GenerateContent(ctx, cm.model, contents, config) + if err != nil { + return nil, fmt.Errorf("generate content failed: %w", err) + } + + message, err = cm.convertResponse(result) + if err != nil { + return nil, fmt.Errorf("convert response failed: %w", err) + } + + callbacks.OnEnd(ctx, cm.convertCallbackOutput(message, conf)) + return message, nil +} + +func (cm *ChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (result *schema.StreamReader[*schema.Message], err error) { + ctx = callbacks.EnsureRunInfo(ctx, cm.GetType(), components.ComponentOfChatModel) + + config, conf, err := cm.buildGenerateConfig(opts...) + if err != nil { + return nil, err + } + + ctx = callbacks.OnStart(ctx, &model.CallbackInput{ + Messages: input, + Tools: model.GetCommonOptions(&model.Options{Tools: cm.origTools}, opts...).Tools, + Config: conf, + }) + defer func() { + if err != nil { + callbacks.OnError(ctx, err) + } + }() + + if len(input) == 0 { + return nil, fmt.Errorf("gemini input is empty") + } + + contents, err := cm.convertSchemaMessages(input) + if err != nil { + return nil, err + } + + sr, sw := schema.Pipe[*model.CallbackOutput](1) + go func() { + defer func() { + panicErr := recover() + if panicErr != nil { + _ = sw.Send(nil, newPanicErr(panicErr, debug.Stack())) + } + sw.Close() + }() + + for resp, err := range cm.cli.Models.GenerateContentStream(ctx, cm.model, contents, config) { + if err != nil { + sw.Send(nil, err) + return + } + + message, err := cm.convertResponse(resp) + if err != nil { + sw.Send(nil, err) + return + } + + closed := sw.Send(cm.convertCallbackOutput(message, conf), nil) + if closed { + return + } + } + }() + + srList := sr.Copy(2) + callbacks.OnEndWithStreamOutput(ctx, srList[0]) + return schema.StreamReaderWithConvert(srList[1], func(t *model.CallbackOutput) (*schema.Message, error) { + return t.Message, nil + }), nil +} + +func (cm *ChatModel) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatModel, error) { + if len(tools) == 0 { + return nil, errors.New("no tools to bind") + } + gTools, err := cm.convertToGeminiTools(tools) + if err != nil { + return nil, fmt.Errorf("convert to gemini tools failed: %w", err) + } + + tc := schema.ToolChoiceAllowed + ncm := *cm + ncm.toolChoice = &tc + ncm.tools = gTools + ncm.origTools = tools + return &ncm, nil +} + +func (cm *ChatModel) BindTools(tools []*schema.ToolInfo) error { + if len(tools) == 0 { + return errors.New("no tools to bind") + } + gTools, err := cm.convertToGeminiTools(tools) + if err != nil { + return err + } + + cm.tools = gTools + cm.origTools = tools + tc := schema.ToolChoiceAllowed + cm.toolChoice = &tc + return nil +} + +func (cm *ChatModel) BindForcedTools(tools []*schema.ToolInfo) error { + if len(tools) == 0 { + return errors.New("no tools to bind") + } + gTools, err := cm.convertToGeminiTools(tools) + if err != nil { + return err + } + + cm.tools = gTools + cm.origTools = tools + tc := schema.ToolChoiceForced + cm.toolChoice = &tc + return nil +} + +func (cm *ChatModel) buildGenerateConfig(opts ...model.Option) (*genai.GenerateContentConfig, *model.Config, error) { + commonOptions := model.GetCommonOptions(&model.Options{ + Temperature: cm.temperature, + MaxTokens: cm.maxTokens, + TopP: cm.topP, + Tools: nil, + ToolChoice: cm.toolChoice, + }, opts...) + geminiOptions := model.GetImplSpecificOptions(&options{ + TopK: cm.topK, + ResponseSchema: cm.responseSchema, + }, opts...) + + conf := &model.Config{} + config := &genai.GenerateContentConfig{} + + // Set model + if commonOptions.Model != nil { + conf.Model = *commonOptions.Model + } else { + conf.Model = cm.model + } + + // Set temperature + if commonOptions.Temperature != nil { + conf.Temperature = *commonOptions.Temperature + config.Temperature = commonOptions.Temperature + } else if cm.temperature != nil { + conf.Temperature = *cm.temperature + config.Temperature = cm.temperature + } + + // Set max tokens + if commonOptions.MaxTokens != nil { + conf.MaxTokens = *commonOptions.MaxTokens + config.MaxOutputTokens = int32(*commonOptions.MaxTokens) + } else if cm.maxTokens != nil { + conf.MaxTokens = *cm.maxTokens + config.MaxOutputTokens = int32(*cm.maxTokens) + } + + // Set top P + if commonOptions.TopP != nil { + conf.TopP = *commonOptions.TopP + config.TopP = commonOptions.TopP + } else if cm.topP != nil { + conf.TopP = *cm.topP + config.TopP = cm.topP + } + + // Set top K + if geminiOptions.TopK != nil { + config.TopK = genai.Ptr(float32(*geminiOptions.TopK)) + } else if cm.topK != nil { + config.TopK = genai.Ptr(float32(*cm.topK)) + } + + // Set tools + tools := cm.tools + if commonOptions.Tools != nil { + var err error + tools, err = cm.convertToGeminiTools(commonOptions.Tools) + if err != nil { + return nil, nil, err + } + } + if len(tools) > 0 { + config.Tools = tools + } + + // Set tool choice + if commonOptions.ToolChoice != nil { + switch *commonOptions.ToolChoice { + case schema.ToolChoiceForbidden: + config.ToolConfig = &genai.ToolConfig{ + FunctionCallingConfig: &genai.FunctionCallingConfig{ + Mode: genai.FunctionCallingConfigModeNone, + }, + } + case schema.ToolChoiceAllowed: + config.ToolConfig = &genai.ToolConfig{ + FunctionCallingConfig: &genai.FunctionCallingConfig{ + Mode: genai.FunctionCallingConfigModeAuto, + }, + } + case schema.ToolChoiceForced: + if len(tools) == 0 { + return nil, nil, fmt.Errorf("tool choice is forced but no tools provided") + } + config.ToolConfig = &genai.ToolConfig{ + FunctionCallingConfig: &genai.FunctionCallingConfig{ + Mode: genai.FunctionCallingConfigModeAny, + }, + } + default: + return nil, nil, fmt.Errorf("tool choice=%s not supported", *commonOptions.ToolChoice) + } + } + + // Set safety settings + if len(cm.safetySettings) > 0 { + config.SafetySettings = cm.safetySettings + } + + // Set response schema for JSON mode + if geminiOptions.ResponseSchema != nil { + gSchema, err := cm.convertOpenAPISchema(geminiOptions.ResponseSchema) + if err != nil { + return nil, nil, fmt.Errorf("convert response schema failed: %w", err) + } + config.ResponseMIMEType = "application/json" + config.ResponseSchema = gSchema + } + + return config, conf, nil +} + +func (cm *ChatModel) convertToGeminiTools(tools []*schema.ToolInfo) ([]*genai.Tool, error) { + if len(tools) == 0 { + return nil, nil + } + + var functionDeclarations []*genai.FunctionDeclaration + for _, tool := range tools { + openSchema, err := tool.ToOpenAPIV3() + if err != nil { + return nil, fmt.Errorf("get open schema failed: %w", err) + } + + gSchema, err := cm.convertOpenAPISchema(openSchema) + if err != nil { + return nil, fmt.Errorf("convert open schema failed: %w", err) + } + + funcDecl := &genai.FunctionDeclaration{ + Name: tool.Name, + Description: tool.Desc, + Parameters: gSchema, + } + functionDeclarations = append(functionDeclarations, funcDecl) + } + + return []*genai.Tool{{FunctionDeclarations: functionDeclarations}}, nil +} + +func (cm *ChatModel) convertOpenAPISchema(schema *openapi3.Schema) (*genai.Schema, error) { + if schema == nil { + return nil, nil + } + + result := &genai.Schema{ + Description: schema.Description, + } + + switch schema.Type { + case openapi3.TypeObject: + result.Type = genai.TypeObject + if schema.Properties != nil { + properties := make(map[string]*genai.Schema) + for name, prop := range schema.Properties { + if prop == nil || prop.Value == nil { + continue + } + propSchema, err := cm.convertOpenAPISchema(prop.Value) + if err != nil { + return nil, err + } + properties[name] = propSchema + } + result.Properties = properties + } + if schema.Required != nil { + result.Required = schema.Required + } + case openapi3.TypeArray: + result.Type = genai.TypeArray + if schema.Items != nil && schema.Items.Value != nil { + itemSchema, err := cm.convertOpenAPISchema(schema.Items.Value) + if err != nil { + return nil, err + } + result.Items = itemSchema + } + case openapi3.TypeString: + result.Type = genai.TypeString + if schema.Enum != nil { + enums := make([]string, 0, len(schema.Enum)) + for _, e := range schema.Enum { + if str, ok := e.(string); ok { + enums = append(enums, str) + } else { + return nil, fmt.Errorf("enum value must be a string, schema: %+v", schema) + } + } + result.Enum = enums + } + case openapi3.TypeNumber: + result.Type = genai.TypeNumber + case openapi3.TypeInteger: + result.Type = genai.TypeInteger + case openapi3.TypeBoolean: + result.Type = genai.TypeBoolean + default: + result.Type = genai.TypeUnspecified + } + + return result, nil +} + +func (cm *ChatModel) convertSchemaMessages(messages []*schema.Message) ([]*genai.Content, error) { + var contents []*genai.Content + for _, message := range messages { + content, err := cm.convertSchemaMessage(message) + if err != nil { + return nil, fmt.Errorf("convert schema message failed: %w", err) + } + if content != nil { + contents = append(contents, content) + } + } + return contents, nil +} + +func (cm *ChatModel) convertSchemaMessage(message *schema.Message) (*genai.Content, error) { + if message == nil { + return nil, nil + } + + var parts []*genai.Part + + // Handle tool calls + if message.ToolCalls != nil { + for _, call := range message.ToolCalls { + var args map[string]any + if err := json.Unmarshal([]byte(call.Function.Arguments), &args); err != nil { + return nil, fmt.Errorf("unmarshal tool call arguments failed: %w", err) + } + parts = append(parts, &genai.Part{ + FunctionCall: &genai.FunctionCall{ + Name: call.Function.Name, + Args: args, + }, + }) + } + } + + // Handle tool responses + if message.Role == schema.Tool { + var response map[string]any + if err := json.Unmarshal([]byte(message.Content), &response); err != nil { + return nil, fmt.Errorf("unmarshal tool response failed: %w", err) + } + parts = append(parts, &genai.Part{ + FunctionResponse: &genai.FunctionResponse{ + Name: message.ToolCallID, + Response: response, + }, + }) + } else { + // Handle text content + if message.Content != "" { + parts = append(parts, &genai.Part{Text: message.Content}) + } + + // Handle multi-content (images, audio, etc.) + for _, content := range message.MultiContent { + switch content.Type { + case schema.ChatMessagePartTypeText: + parts = append(parts, &genai.Part{Text: content.Text}) + case schema.ChatMessagePartTypeImageURL: + if content.ImageURL != nil { + parts = append(parts, &genai.Part{ + FileData: &genai.FileData{ + MIMEType: content.ImageURL.MIMEType, + FileURI: content.ImageURL.URI, + }, + }) + } + case schema.ChatMessagePartTypeAudioURL: + if content.AudioURL != nil { + parts = append(parts, &genai.Part{ + FileData: &genai.FileData{ + MIMEType: content.AudioURL.MIMEType, + FileURI: content.AudioURL.URI, + }, + }) + } + case schema.ChatMessagePartTypeVideoURL: + if content.VideoURL != nil { + parts = append(parts, &genai.Part{ + FileData: &genai.FileData{ + MIMEType: content.VideoURL.MIMEType, + FileURI: content.VideoURL.URI, + }, + }) + } + case schema.ChatMessagePartTypeFileURL: + if content.FileURL != nil { + parts = append(parts, &genai.Part{ + FileData: &genai.FileData{ + MIMEType: content.FileURL.MIMEType, + FileURI: content.FileURL.URI, + }, + }) + } + } + } + } + + if len(parts) == 0 { + return nil, nil + } + + return &genai.Content{ + Role: string(cm.convertRole(message.Role)), + Parts: parts, + }, nil +} + +func (cm *ChatModel) convertRole(role schema.RoleType) genai.Role { + switch role { + case schema.Assistant: + return genai.RoleModel + case schema.User: + return genai.RoleUser + case schema.Tool: + return genai.RoleUser // Tool responses are treated as user messages in the new API + default: + return genai.RoleUser + } +} + +func (cm *ChatModel) convertResponse(resp *genai.GenerateContentResponse) (*schema.Message, error) { + if len(resp.Candidates) == 0 { + return nil, fmt.Errorf("gemini result is empty") + } + + candidate := resp.Candidates[0] + message := &schema.Message{ + Role: schema.Assistant, + ResponseMeta: &schema.ResponseMeta{ + FinishReason: string(candidate.FinishReason), + }, + } + + // Handle usage metadata + if resp.UsageMetadata != nil { + message.ResponseMeta.Usage = &schema.TokenUsage{ + PromptTokens: int(resp.UsageMetadata.PromptTokenCount), + CompletionTokens: int(resp.UsageMetadata.CandidatesTokenCount), + TotalTokens: int(resp.UsageMetadata.TotalTokenCount), + } + } + + // Process content parts + var textParts []string + for _, part := range candidate.Content.Parts { + switch { + case part.Text != "": + textParts = append(textParts, part.Text) + case part.FunctionCall != nil: + args, err := json.Marshal(part.FunctionCall.Args) + if err != nil { + return nil, fmt.Errorf("marshal function call arguments failed: %w", err) + } + message.ToolCalls = append(message.ToolCalls, schema.ToolCall{ + ID: part.FunctionCall.Name, + Function: schema.FunctionCall{ + Name: part.FunctionCall.Name, + Arguments: string(args), + }, + }) + case part.ExecutableCode != nil: + textParts = append(textParts, part.ExecutableCode.Code) + case part.CodeExecutionResult != nil: + textParts = append(textParts, part.CodeExecutionResult.Output) + } + } + + // Set content + if len(textParts) == 1 { + message.Content = textParts[0] + } else if len(textParts) > 1 { + for _, text := range textParts { + message.MultiContent = append(message.MultiContent, schema.ChatMessagePart{ + Type: schema.ChatMessagePartTypeText, + Text: text, + }) + } + } + + return message, nil +} + +func (cm *ChatModel) convertCallbackOutput(message *schema.Message, conf *model.Config) *model.CallbackOutput { + callbackOutput := &model.CallbackOutput{ + Message: message, + Config: conf, + } + if message.ResponseMeta != nil && message.ResponseMeta.Usage != nil { + callbackOutput.TokenUsage = &model.TokenUsage{ + PromptTokens: message.ResponseMeta.Usage.PromptTokens, + CompletionTokens: message.ResponseMeta.Usage.CompletionTokens, + TotalTokens: message.ResponseMeta.Usage.TotalTokens, + } + } + return callbackOutput +} + +func (cm *ChatModel) IsCallbacksEnabled() bool { + return true +} + +const typ = "Gemini" + +func (cm *ChatModel) GetType() string { + return typ +} + +type panicErr struct { + info any + stack []byte +} + +func (p *panicErr) Error() string { + return fmt.Sprintf("panic error: %v, \nstack: %s", p.info, string(p.stack)) +} + +func newPanicErr(info any, stack []byte) error { + return &panicErr{ + info: info, + stack: stack, + } +} + diff --git a/internal/models/providers.go b/internal/models/providers.go index f66afaba..f5361342 100644 --- a/internal/models/providers.go +++ b/internal/models/providers.go @@ -7,12 +7,12 @@ import ( "strings" "github.com/cloudwego/eino-ext/components/model/claude" - "github.com/cloudwego/eino-ext/components/model/gemini" "github.com/cloudwego/eino-ext/components/model/ollama" "github.com/cloudwego/eino-ext/components/model/openai" "github.com/cloudwego/eino/components/model" - "github.com/google/generative-ai-go/genai" - "google.golang.org/api/option" + "google.golang.org/genai" + + "github.com/mark3labs/mcphost/internal/models/gemini" ) // ProviderConfig holds configuration for creating LLM providers @@ -105,7 +105,10 @@ func createGoogleProvider(ctx context.Context, config *ProviderConfig, modelName return nil, fmt.Errorf("Google API key not provided. Use --google-api-key flag or GOOGLE_API_KEY/GEMINI_API_KEY environment variable") } - client, err := genai.NewClient(ctx, option.WithAPIKey(apiKey)) + client, err := genai.NewClient(ctx, &genai.ClientConfig{ + APIKey: apiKey, + Backend: genai.BackendGeminiAPI, + }) if err != nil { return nil, fmt.Errorf("failed to create Google client: %v", err) }