fix gemini

This commit is contained in:
Ed Zynda
2025-06-10 14:20:40 +03:00
parent 57efdb5332
commit 82b50fbf0e
4 changed files with 718 additions and 34 deletions
+3 -10
View File
@@ -10,26 +10,21 @@ require (
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834
github.com/cloudwego/eino v0.3.41
github.com/cloudwego/eino-ext/components/model/claude v0.0.0-20250609074000-b7f307dffa18
github.com/cloudwego/eino-ext/components/model/gemini v0.0.0-20250609074000-b7f307dffa18
github.com/cloudwego/eino-ext/components/model/ollama v0.0.0-20250609074000-b7f307dffa18
github.com/cloudwego/eino-ext/components/model/openai v0.0.0-20250609074000-b7f307dffa18
github.com/getkin/kin-openapi v0.118.0
github.com/google/generative-ai-go v0.19.0
github.com/mark3labs/mcp-go v0.31.0
github.com/spf13/cobra v1.8.1
github.com/spf13/viper v1.20.1
golang.org/x/term v0.31.0
google.golang.org/api v0.228.0
google.golang.org/genai v1.10.0
gopkg.in/yaml.v3 v3.0.1
)
require (
cloud.google.com/go v0.116.0 // indirect
cloud.google.com/go/ai v0.8.0 // indirect
cloud.google.com/go/auth v0.15.0 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
cloud.google.com/go/compute/metadata v0.6.0 // indirect
cloud.google.com/go/longrunning v0.5.7 // indirect
github.com/alecthomas/chroma/v2 v2.14.0 // indirect
github.com/anthropics/anthropic-sdk-go v0.2.0-alpha.8 // indirect
github.com/atotto/clipboard v0.1.4 // indirect
@@ -65,12 +60,14 @@ require (
github.com/go-openapi/jsonpointer v0.21.0 // indirect
github.com/go-openapi/swag v0.23.0 // indirect
github.com/go-viper/mapstructure/v2 v2.2.1 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/google/s2a-go v0.1.9 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.6 // indirect
github.com/googleapis/gax-go/v2 v2.14.1 // indirect
github.com/goph/emperror v0.17.2 // indirect
github.com/gorilla/css v1.0.1 // indirect
github.com/gorilla/websocket v1.5.3 // indirect
github.com/invopop/yaml v0.1.0 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
@@ -105,7 +102,6 @@ require (
github.com/yuin/goldmark v1.7.8 // indirect
github.com/yuin/goldmark-emoji v1.0.5 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.59.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.59.0 // indirect
go.opentelemetry.io/otel v1.34.0 // indirect
go.opentelemetry.io/otel/metric v1.34.0 // indirect
@@ -115,9 +111,6 @@ require (
golang.org/x/arch v0.12.0 // indirect
golang.org/x/crypto v0.36.0 // indirect
golang.org/x/net v0.37.0 // indirect
golang.org/x/oauth2 v0.28.0 // indirect
golang.org/x/time v0.11.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250313205543-e70fdf4c4cb4 // indirect
google.golang.org/grpc v1.71.0 // indirect
google.golang.org/protobuf v1.36.6 // indirect
+4 -20
View File
@@ -1,15 +1,9 @@
cloud.google.com/go v0.116.0 h1:B3fRrSDkLRt5qSHWe40ERJvhvnQwdZiHu0bJOpldweE=
cloud.google.com/go v0.116.0/go.mod h1:cEPSRWPzZEswwdr9BxE6ChEn01dWlTaF05LiC2Xs70U=
cloud.google.com/go/ai v0.8.0 h1:rXUEz8Wp2OlrM8r1bfmpF2+VKqc1VJpafE3HgzRnD/w=
cloud.google.com/go/ai v0.8.0/go.mod h1:t3Dfk4cM61sytiggo2UyGsDVW3RF1qGZaUKDrZFyqkE=
cloud.google.com/go/auth v0.15.0 h1:Ly0u4aA5vG/fsSsxu98qCQBemXtAtJf+95z9HK+cxps=
cloud.google.com/go/auth v0.15.0/go.mod h1:WJDGqZ1o9E9wKIL+IwStfyn/+s59zl4Bi+1KQNVXLZ8=
cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc=
cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c=
cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I=
cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg=
cloud.google.com/go/longrunning v0.5.7 h1:WLbHekDbjK1fVFD3ibpFFVoyizlLRl73I7YKuAKilhU=
cloud.google.com/go/longrunning v0.5.7/go.mod h1:8GClkudohy1Fxm3owmBGid8W0pSgodEMwEAztp38Xng=
github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ=
github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE=
github.com/airbrake/gobrake v3.6.1+incompatible/go.mod h1:wM4gu3Cn0W0K7GUuVWnlXZU11AGBXMILnrdOU8Kn00o=
@@ -99,8 +93,6 @@ github.com/cloudwego/eino v0.3.41 h1:bb6W5/+8QE+jlhDBY2TxcxYu6odpNQ436oDdBF45jgQ
github.com/cloudwego/eino v0.3.41/go.mod h1:wUjz990apdsaOraOXdh6CdhVXq8DJsOvLsVlxNTcNfY=
github.com/cloudwego/eino-ext/components/model/claude v0.0.0-20250609074000-b7f307dffa18 h1:foS8HJW0U8KOYy1hyWITSZUeMBZbckaBUjQQUIPwauw=
github.com/cloudwego/eino-ext/components/model/claude v0.0.0-20250609074000-b7f307dffa18/go.mod h1:IhvvyyldQVIyOogbEbmmOjrxhZbpvbwjeiQJ4ZoNOX4=
github.com/cloudwego/eino-ext/components/model/gemini v0.0.0-20250609074000-b7f307dffa18 h1:pyiKX7sTo9BgtbOq/twj5olsMl/iKj5RTIpDqgAtdBM=
github.com/cloudwego/eino-ext/components/model/gemini v0.0.0-20250609074000-b7f307dffa18/go.mod h1:NTYXf6aAoO2zBES9S1lzkBvQoyD6UcUGvLmUAS5TMRU=
github.com/cloudwego/eino-ext/components/model/ollama v0.0.0-20250609074000-b7f307dffa18 h1:UxZVTapUwbzkRIP4Bl/VKni65wI+sfq6oPveZ+76aww=
github.com/cloudwego/eino-ext/components/model/ollama v0.0.0-20250609074000-b7f307dffa18/go.mod h1:giNUFqA+V7xrm/EDvH7JFnDqoWI+e2m1SVAnReU+Fd8=
github.com/cloudwego/eino-ext/components/model/openai v0.0.0-20250609074000-b7f307dffa18 h1:FAg2QmtJ0tA/3BmQlUPdhZ9Nzzsov76ry7b3Gn86dAs=
@@ -151,8 +143,6 @@ github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRx
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/generative-ai-go v0.19.0 h1:R71szggh8wHMCUlEMsW2A/3T+5LdEIkiaHSYgSpUgdg=
github.com/google/generative-ai-go v0.19.0/go.mod h1:JYolL13VG7j79kM5BtHz4qwONHkeJQzOCkKXnpqtS/E=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
@@ -171,6 +161,8 @@ github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfre
github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8=
github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0=
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
@@ -331,8 +323,6 @@ github.com/yuin/goldmark-emoji v1.0.5 h1:EMVWyCGPlXJfUXBXpuMu+ii3TIaxbVBnEX9uaDC
github.com/yuin/goldmark-emoji v1.0.5/go.mod h1:tTkZEbwu5wkPmgTcitqddVxY9osFZiavD+r4AzQrh1U=
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.59.0 h1:rgMkmiGfix9vFJDcDi1PK8WEQP4FLQwLDfhp5ZLpFeE=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.59.0/go.mod h1:ijPqXp5P6IRRByFVVg9DY8P5HkxkHE5ARIa+86aXPf4=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.59.0 h1:CV7UdSGJt/Ao6Gp4CXckLxVRRsRgDHoI8XjbL3PDl8s=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.59.0/go.mod h1:FRmFuRJfag1IZ2dPkHnEoSFVgTVPUd2qf5Vi69hLb8I=
go.opentelemetry.io/otel v1.34.0 h1:zRLXxLCgL1WyKsPVrgbSdMN4c0FMkDAskSTQP+0hdUY=
@@ -361,8 +351,6 @@ golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa/go.mod h1:zk2irFbV9DP96SEBUU
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/oauth2 v0.28.0 h1:CrgCKl8PPAVtLnU3c+EDw6x11699EWlsDeWNWKdIOkc=
golang.org/x/oauth2 v0.28.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610=
golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
@@ -378,12 +366,8 @@ golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0=
golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU=
golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0=
golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
google.golang.org/api v0.228.0 h1:X2DJ/uoWGnY5obVjewbp8icSL5U4FzuCfy9OjbLSnLs=
google.golang.org/api v0.228.0/go.mod h1:wNvRS1Pbe8r4+IfBIniV8fwCpGwTrYa+kMUDiC5z5a4=
google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422 h1:GVIKPyP/kLIyVOgOnTwFOrvQaQUzOzGMCxgFUOEmm24=
google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422/go.mod h1:b6h1vNKhxaSoEI+5jc3PJUCustfli/mRab7295pY7rw=
google.golang.org/genai v1.10.0 h1:ETP0Yksn5KUSEn5+ihMOnP3IqjZ+7Z4i0LjJslEXatI=
google.golang.org/genai v1.10.0/go.mod h1:TyfOKRz/QyCaj6f/ZDt505x+YreXnY40l2I6k8TvgqY=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250313205543-e70fdf4c4cb4 h1:iK2jbkWL86DXjEx0qiHcRE9dE4/Ahua5k6V8OWFb//c=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250313205543-e70fdf4c4cb4/go.mod h1:LuRYeWDFV6WOn90g357N17oMCaxpgCnbi/44qJvDn2I=
google.golang.org/grpc v1.71.0 h1:kF77BGdPTQ4/JZWMlb9VpJ5pa25aqvVqogsxNHHdeBg=
+704
View File
@@ -0,0 +1,704 @@
package gemini
import (
"context"
"encoding/json"
"errors"
"fmt"
"runtime/debug"
"github.com/cloudwego/eino/callbacks"
"github.com/cloudwego/eino/components"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/schema"
"github.com/getkin/kin-openapi/openapi3"
"google.golang.org/genai"
)
var _ model.ToolCallingChatModel = (*ChatModel)(nil)
// NewChatModel creates a new Gemini chat model instance
//
// Parameters:
// - ctx: The context for the operation
// - cfg: Configuration for the Gemini model
//
// Returns:
// - model.ChatModel: A chat model interface implementation
// - error: Any error that occurred during creation
//
// Example:
//
// model, err := gemini.NewChatModel(ctx, &gemini.Config{
// Client: client,
// Model: "gemini-pro",
// })
func NewChatModel(_ context.Context, cfg *Config) (*ChatModel, error) {
return &ChatModel{
cli: cfg.Client,
model: cfg.Model,
maxTokens: cfg.MaxTokens,
temperature: cfg.Temperature,
topP: cfg.TopP,
topK: cfg.TopK,
responseSchema: cfg.ResponseSchema,
enableCodeExecution: cfg.EnableCodeExecution,
safetySettings: cfg.SafetySettings,
}, nil
}
// Config contains the configuration options for the Gemini model
type Config struct {
// Client is the Gemini API client instance
// Required for making API calls to Gemini
Client *genai.Client
// Model specifies which Gemini model to use
// Examples: "gemini-pro", "gemini-pro-vision", "gemini-1.5-flash"
Model string
// MaxTokens limits the maximum number of tokens in the response
// Optional. Example: maxTokens := 100
MaxTokens *int
// Temperature controls randomness in responses
// Range: [0.0, 1.0], where 0.0 is more focused and 1.0 is more creative
// Optional. Example: temperature := float32(0.7)
Temperature *float32
// TopP controls diversity via nucleus sampling
// Range: [0.0, 1.0], where 1.0 disables nucleus sampling
// Optional. Example: topP := float32(0.95)
TopP *float32
// TopK controls diversity by limiting the top K tokens to sample from
// Optional. Example: topK := int32(40)
TopK *int32
// ResponseSchema defines the structure for JSON responses
// Optional. Used when you want structured output in JSON format
ResponseSchema *openapi3.Schema
// EnableCodeExecution allows the model to execute code
// Warning: Be cautious with code execution in production
// Optional. Default: false
EnableCodeExecution bool
// SafetySettings configures content filtering for different harm categories
// Controls the model's filtering behavior for potentially harmful content
// Optional.
SafetySettings []*genai.SafetySetting
}
// options contains Gemini-specific options for model configuration
type options struct {
TopK *int32
ResponseSchema *openapi3.Schema
}
type ChatModel struct {
cli *genai.Client
model string
maxTokens *int
topP *float32
temperature *float32
topK *int32
responseSchema *openapi3.Schema
tools []*genai.Tool
origTools []*schema.ToolInfo
toolChoice *schema.ToolChoice
enableCodeExecution bool
safetySettings []*genai.SafetySetting
}
func (cm *ChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (message *schema.Message, err error) {
ctx = callbacks.EnsureRunInfo(ctx, cm.GetType(), components.ComponentOfChatModel)
config, conf, err := cm.buildGenerateConfig(opts...)
if err != nil {
return nil, err
}
ctx = callbacks.OnStart(ctx, &model.CallbackInput{
Messages: input,
Tools: model.GetCommonOptions(&model.Options{Tools: cm.origTools}, opts...).Tools,
Config: conf,
})
defer func() {
if err != nil {
callbacks.OnError(ctx, err)
}
}()
if len(input) == 0 {
return nil, fmt.Errorf("gemini input is empty")
}
contents, err := cm.convertSchemaMessages(input)
if err != nil {
return nil, err
}
result, err := cm.cli.Models.GenerateContent(ctx, cm.model, contents, config)
if err != nil {
return nil, fmt.Errorf("generate content failed: %w", err)
}
message, err = cm.convertResponse(result)
if err != nil {
return nil, fmt.Errorf("convert response failed: %w", err)
}
callbacks.OnEnd(ctx, cm.convertCallbackOutput(message, conf))
return message, nil
}
func (cm *ChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (result *schema.StreamReader[*schema.Message], err error) {
ctx = callbacks.EnsureRunInfo(ctx, cm.GetType(), components.ComponentOfChatModel)
config, conf, err := cm.buildGenerateConfig(opts...)
if err != nil {
return nil, err
}
ctx = callbacks.OnStart(ctx, &model.CallbackInput{
Messages: input,
Tools: model.GetCommonOptions(&model.Options{Tools: cm.origTools}, opts...).Tools,
Config: conf,
})
defer func() {
if err != nil {
callbacks.OnError(ctx, err)
}
}()
if len(input) == 0 {
return nil, fmt.Errorf("gemini input is empty")
}
contents, err := cm.convertSchemaMessages(input)
if err != nil {
return nil, err
}
sr, sw := schema.Pipe[*model.CallbackOutput](1)
go func() {
defer func() {
panicErr := recover()
if panicErr != nil {
_ = sw.Send(nil, newPanicErr(panicErr, debug.Stack()))
}
sw.Close()
}()
for resp, err := range cm.cli.Models.GenerateContentStream(ctx, cm.model, contents, config) {
if err != nil {
sw.Send(nil, err)
return
}
message, err := cm.convertResponse(resp)
if err != nil {
sw.Send(nil, err)
return
}
closed := sw.Send(cm.convertCallbackOutput(message, conf), nil)
if closed {
return
}
}
}()
srList := sr.Copy(2)
callbacks.OnEndWithStreamOutput(ctx, srList[0])
return schema.StreamReaderWithConvert(srList[1], func(t *model.CallbackOutput) (*schema.Message, error) {
return t.Message, nil
}), nil
}
func (cm *ChatModel) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatModel, error) {
if len(tools) == 0 {
return nil, errors.New("no tools to bind")
}
gTools, err := cm.convertToGeminiTools(tools)
if err != nil {
return nil, fmt.Errorf("convert to gemini tools failed: %w", err)
}
tc := schema.ToolChoiceAllowed
ncm := *cm
ncm.toolChoice = &tc
ncm.tools = gTools
ncm.origTools = tools
return &ncm, nil
}
func (cm *ChatModel) BindTools(tools []*schema.ToolInfo) error {
if len(tools) == 0 {
return errors.New("no tools to bind")
}
gTools, err := cm.convertToGeminiTools(tools)
if err != nil {
return err
}
cm.tools = gTools
cm.origTools = tools
tc := schema.ToolChoiceAllowed
cm.toolChoice = &tc
return nil
}
func (cm *ChatModel) BindForcedTools(tools []*schema.ToolInfo) error {
if len(tools) == 0 {
return errors.New("no tools to bind")
}
gTools, err := cm.convertToGeminiTools(tools)
if err != nil {
return err
}
cm.tools = gTools
cm.origTools = tools
tc := schema.ToolChoiceForced
cm.toolChoice = &tc
return nil
}
func (cm *ChatModel) buildGenerateConfig(opts ...model.Option) (*genai.GenerateContentConfig, *model.Config, error) {
commonOptions := model.GetCommonOptions(&model.Options{
Temperature: cm.temperature,
MaxTokens: cm.maxTokens,
TopP: cm.topP,
Tools: nil,
ToolChoice: cm.toolChoice,
}, opts...)
geminiOptions := model.GetImplSpecificOptions(&options{
TopK: cm.topK,
ResponseSchema: cm.responseSchema,
}, opts...)
conf := &model.Config{}
config := &genai.GenerateContentConfig{}
// Set model
if commonOptions.Model != nil {
conf.Model = *commonOptions.Model
} else {
conf.Model = cm.model
}
// Set temperature
if commonOptions.Temperature != nil {
conf.Temperature = *commonOptions.Temperature
config.Temperature = commonOptions.Temperature
} else if cm.temperature != nil {
conf.Temperature = *cm.temperature
config.Temperature = cm.temperature
}
// Set max tokens
if commonOptions.MaxTokens != nil {
conf.MaxTokens = *commonOptions.MaxTokens
config.MaxOutputTokens = int32(*commonOptions.MaxTokens)
} else if cm.maxTokens != nil {
conf.MaxTokens = *cm.maxTokens
config.MaxOutputTokens = int32(*cm.maxTokens)
}
// Set top P
if commonOptions.TopP != nil {
conf.TopP = *commonOptions.TopP
config.TopP = commonOptions.TopP
} else if cm.topP != nil {
conf.TopP = *cm.topP
config.TopP = cm.topP
}
// Set top K
if geminiOptions.TopK != nil {
config.TopK = genai.Ptr(float32(*geminiOptions.TopK))
} else if cm.topK != nil {
config.TopK = genai.Ptr(float32(*cm.topK))
}
// Set tools
tools := cm.tools
if commonOptions.Tools != nil {
var err error
tools, err = cm.convertToGeminiTools(commonOptions.Tools)
if err != nil {
return nil, nil, err
}
}
if len(tools) > 0 {
config.Tools = tools
}
// Set tool choice
if commonOptions.ToolChoice != nil {
switch *commonOptions.ToolChoice {
case schema.ToolChoiceForbidden:
config.ToolConfig = &genai.ToolConfig{
FunctionCallingConfig: &genai.FunctionCallingConfig{
Mode: genai.FunctionCallingConfigModeNone,
},
}
case schema.ToolChoiceAllowed:
config.ToolConfig = &genai.ToolConfig{
FunctionCallingConfig: &genai.FunctionCallingConfig{
Mode: genai.FunctionCallingConfigModeAuto,
},
}
case schema.ToolChoiceForced:
if len(tools) == 0 {
return nil, nil, fmt.Errorf("tool choice is forced but no tools provided")
}
config.ToolConfig = &genai.ToolConfig{
FunctionCallingConfig: &genai.FunctionCallingConfig{
Mode: genai.FunctionCallingConfigModeAny,
},
}
default:
return nil, nil, fmt.Errorf("tool choice=%s not supported", *commonOptions.ToolChoice)
}
}
// Set safety settings
if len(cm.safetySettings) > 0 {
config.SafetySettings = cm.safetySettings
}
// Set response schema for JSON mode
if geminiOptions.ResponseSchema != nil {
gSchema, err := cm.convertOpenAPISchema(geminiOptions.ResponseSchema)
if err != nil {
return nil, nil, fmt.Errorf("convert response schema failed: %w", err)
}
config.ResponseMIMEType = "application/json"
config.ResponseSchema = gSchema
}
return config, conf, nil
}
func (cm *ChatModel) convertToGeminiTools(tools []*schema.ToolInfo) ([]*genai.Tool, error) {
if len(tools) == 0 {
return nil, nil
}
var functionDeclarations []*genai.FunctionDeclaration
for _, tool := range tools {
openSchema, err := tool.ToOpenAPIV3()
if err != nil {
return nil, fmt.Errorf("get open schema failed: %w", err)
}
gSchema, err := cm.convertOpenAPISchema(openSchema)
if err != nil {
return nil, fmt.Errorf("convert open schema failed: %w", err)
}
funcDecl := &genai.FunctionDeclaration{
Name: tool.Name,
Description: tool.Desc,
Parameters: gSchema,
}
functionDeclarations = append(functionDeclarations, funcDecl)
}
return []*genai.Tool{{FunctionDeclarations: functionDeclarations}}, nil
}
func (cm *ChatModel) convertOpenAPISchema(schema *openapi3.Schema) (*genai.Schema, error) {
if schema == nil {
return nil, nil
}
result := &genai.Schema{
Description: schema.Description,
}
switch schema.Type {
case openapi3.TypeObject:
result.Type = genai.TypeObject
if schema.Properties != nil {
properties := make(map[string]*genai.Schema)
for name, prop := range schema.Properties {
if prop == nil || prop.Value == nil {
continue
}
propSchema, err := cm.convertOpenAPISchema(prop.Value)
if err != nil {
return nil, err
}
properties[name] = propSchema
}
result.Properties = properties
}
if schema.Required != nil {
result.Required = schema.Required
}
case openapi3.TypeArray:
result.Type = genai.TypeArray
if schema.Items != nil && schema.Items.Value != nil {
itemSchema, err := cm.convertOpenAPISchema(schema.Items.Value)
if err != nil {
return nil, err
}
result.Items = itemSchema
}
case openapi3.TypeString:
result.Type = genai.TypeString
if schema.Enum != nil {
enums := make([]string, 0, len(schema.Enum))
for _, e := range schema.Enum {
if str, ok := e.(string); ok {
enums = append(enums, str)
} else {
return nil, fmt.Errorf("enum value must be a string, schema: %+v", schema)
}
}
result.Enum = enums
}
case openapi3.TypeNumber:
result.Type = genai.TypeNumber
case openapi3.TypeInteger:
result.Type = genai.TypeInteger
case openapi3.TypeBoolean:
result.Type = genai.TypeBoolean
default:
result.Type = genai.TypeUnspecified
}
return result, nil
}
func (cm *ChatModel) convertSchemaMessages(messages []*schema.Message) ([]*genai.Content, error) {
var contents []*genai.Content
for _, message := range messages {
content, err := cm.convertSchemaMessage(message)
if err != nil {
return nil, fmt.Errorf("convert schema message failed: %w", err)
}
if content != nil {
contents = append(contents, content)
}
}
return contents, nil
}
func (cm *ChatModel) convertSchemaMessage(message *schema.Message) (*genai.Content, error) {
if message == nil {
return nil, nil
}
var parts []*genai.Part
// Handle tool calls
if message.ToolCalls != nil {
for _, call := range message.ToolCalls {
var args map[string]any
if err := json.Unmarshal([]byte(call.Function.Arguments), &args); err != nil {
return nil, fmt.Errorf("unmarshal tool call arguments failed: %w", err)
}
parts = append(parts, &genai.Part{
FunctionCall: &genai.FunctionCall{
Name: call.Function.Name,
Args: args,
},
})
}
}
// Handle tool responses
if message.Role == schema.Tool {
var response map[string]any
if err := json.Unmarshal([]byte(message.Content), &response); err != nil {
return nil, fmt.Errorf("unmarshal tool response failed: %w", err)
}
parts = append(parts, &genai.Part{
FunctionResponse: &genai.FunctionResponse{
Name: message.ToolCallID,
Response: response,
},
})
} else {
// Handle text content
if message.Content != "" {
parts = append(parts, &genai.Part{Text: message.Content})
}
// Handle multi-content (images, audio, etc.)
for _, content := range message.MultiContent {
switch content.Type {
case schema.ChatMessagePartTypeText:
parts = append(parts, &genai.Part{Text: content.Text})
case schema.ChatMessagePartTypeImageURL:
if content.ImageURL != nil {
parts = append(parts, &genai.Part{
FileData: &genai.FileData{
MIMEType: content.ImageURL.MIMEType,
FileURI: content.ImageURL.URI,
},
})
}
case schema.ChatMessagePartTypeAudioURL:
if content.AudioURL != nil {
parts = append(parts, &genai.Part{
FileData: &genai.FileData{
MIMEType: content.AudioURL.MIMEType,
FileURI: content.AudioURL.URI,
},
})
}
case schema.ChatMessagePartTypeVideoURL:
if content.VideoURL != nil {
parts = append(parts, &genai.Part{
FileData: &genai.FileData{
MIMEType: content.VideoURL.MIMEType,
FileURI: content.VideoURL.URI,
},
})
}
case schema.ChatMessagePartTypeFileURL:
if content.FileURL != nil {
parts = append(parts, &genai.Part{
FileData: &genai.FileData{
MIMEType: content.FileURL.MIMEType,
FileURI: content.FileURL.URI,
},
})
}
}
}
}
if len(parts) == 0 {
return nil, nil
}
return &genai.Content{
Role: string(cm.convertRole(message.Role)),
Parts: parts,
}, nil
}
func (cm *ChatModel) convertRole(role schema.RoleType) genai.Role {
switch role {
case schema.Assistant:
return genai.RoleModel
case schema.User:
return genai.RoleUser
case schema.Tool:
return genai.RoleUser // Tool responses are treated as user messages in the new API
default:
return genai.RoleUser
}
}
func (cm *ChatModel) convertResponse(resp *genai.GenerateContentResponse) (*schema.Message, error) {
if len(resp.Candidates) == 0 {
return nil, fmt.Errorf("gemini result is empty")
}
candidate := resp.Candidates[0]
message := &schema.Message{
Role: schema.Assistant,
ResponseMeta: &schema.ResponseMeta{
FinishReason: string(candidate.FinishReason),
},
}
// Handle usage metadata
if resp.UsageMetadata != nil {
message.ResponseMeta.Usage = &schema.TokenUsage{
PromptTokens: int(resp.UsageMetadata.PromptTokenCount),
CompletionTokens: int(resp.UsageMetadata.CandidatesTokenCount),
TotalTokens: int(resp.UsageMetadata.TotalTokenCount),
}
}
// Process content parts
var textParts []string
for _, part := range candidate.Content.Parts {
switch {
case part.Text != "":
textParts = append(textParts, part.Text)
case part.FunctionCall != nil:
args, err := json.Marshal(part.FunctionCall.Args)
if err != nil {
return nil, fmt.Errorf("marshal function call arguments failed: %w", err)
}
message.ToolCalls = append(message.ToolCalls, schema.ToolCall{
ID: part.FunctionCall.Name,
Function: schema.FunctionCall{
Name: part.FunctionCall.Name,
Arguments: string(args),
},
})
case part.ExecutableCode != nil:
textParts = append(textParts, part.ExecutableCode.Code)
case part.CodeExecutionResult != nil:
textParts = append(textParts, part.CodeExecutionResult.Output)
}
}
// Set content
if len(textParts) == 1 {
message.Content = textParts[0]
} else if len(textParts) > 1 {
for _, text := range textParts {
message.MultiContent = append(message.MultiContent, schema.ChatMessagePart{
Type: schema.ChatMessagePartTypeText,
Text: text,
})
}
}
return message, nil
}
func (cm *ChatModel) convertCallbackOutput(message *schema.Message, conf *model.Config) *model.CallbackOutput {
callbackOutput := &model.CallbackOutput{
Message: message,
Config: conf,
}
if message.ResponseMeta != nil && message.ResponseMeta.Usage != nil {
callbackOutput.TokenUsage = &model.TokenUsage{
PromptTokens: message.ResponseMeta.Usage.PromptTokens,
CompletionTokens: message.ResponseMeta.Usage.CompletionTokens,
TotalTokens: message.ResponseMeta.Usage.TotalTokens,
}
}
return callbackOutput
}
func (cm *ChatModel) IsCallbacksEnabled() bool {
return true
}
const typ = "Gemini"
func (cm *ChatModel) GetType() string {
return typ
}
type panicErr struct {
info any
stack []byte
}
func (p *panicErr) Error() string {
return fmt.Sprintf("panic error: %v, \nstack: %s", p.info, string(p.stack))
}
func newPanicErr(info any, stack []byte) error {
return &panicErr{
info: info,
stack: stack,
}
}
+7 -4
View File
@@ -7,12 +7,12 @@ import (
"strings"
"github.com/cloudwego/eino-ext/components/model/claude"
"github.com/cloudwego/eino-ext/components/model/gemini"
"github.com/cloudwego/eino-ext/components/model/ollama"
"github.com/cloudwego/eino-ext/components/model/openai"
"github.com/cloudwego/eino/components/model"
"github.com/google/generative-ai-go/genai"
"google.golang.org/api/option"
"google.golang.org/genai"
"github.com/mark3labs/mcphost/internal/models/gemini"
)
// ProviderConfig holds configuration for creating LLM providers
@@ -105,7 +105,10 @@ func createGoogleProvider(ctx context.Context, config *ProviderConfig, modelName
return nil, fmt.Errorf("Google API key not provided. Use --google-api-key flag or GOOGLE_API_KEY/GEMINI_API_KEY environment variable")
}
client, err := genai.NewClient(ctx, option.WithAPIKey(apiKey))
client, err := genai.NewClient(ctx, &genai.ClientConfig{
APIKey: apiKey,
Backend: genai.BackendGeminiAPI,
})
if err != nil {
return nil, fmt.Errorf("failed to create Google client: %v", err)
}