mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-13 19:20:06 +00:00
fix gemini
This commit is contained in:
@@ -10,26 +10,21 @@ require (
|
||||
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834
|
||||
github.com/cloudwego/eino v0.3.41
|
||||
github.com/cloudwego/eino-ext/components/model/claude v0.0.0-20250609074000-b7f307dffa18
|
||||
github.com/cloudwego/eino-ext/components/model/gemini v0.0.0-20250609074000-b7f307dffa18
|
||||
github.com/cloudwego/eino-ext/components/model/ollama v0.0.0-20250609074000-b7f307dffa18
|
||||
github.com/cloudwego/eino-ext/components/model/openai v0.0.0-20250609074000-b7f307dffa18
|
||||
github.com/getkin/kin-openapi v0.118.0
|
||||
github.com/google/generative-ai-go v0.19.0
|
||||
github.com/mark3labs/mcp-go v0.31.0
|
||||
github.com/spf13/cobra v1.8.1
|
||||
github.com/spf13/viper v1.20.1
|
||||
golang.org/x/term v0.31.0
|
||||
google.golang.org/api v0.228.0
|
||||
google.golang.org/genai v1.10.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
require (
|
||||
cloud.google.com/go v0.116.0 // indirect
|
||||
cloud.google.com/go/ai v0.8.0 // indirect
|
||||
cloud.google.com/go/auth v0.15.0 // indirect
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
|
||||
cloud.google.com/go/compute/metadata v0.6.0 // indirect
|
||||
cloud.google.com/go/longrunning v0.5.7 // indirect
|
||||
github.com/alecthomas/chroma/v2 v2.14.0 // indirect
|
||||
github.com/anthropics/anthropic-sdk-go v0.2.0-alpha.8 // indirect
|
||||
github.com/atotto/clipboard v0.1.4 // indirect
|
||||
@@ -65,12 +60,14 @@ require (
|
||||
github.com/go-openapi/jsonpointer v0.21.0 // indirect
|
||||
github.com/go-openapi/swag v0.23.0 // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.2.1 // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/google/s2a-go v0.1.9 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.6 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.14.1 // indirect
|
||||
github.com/goph/emperror v0.17.2 // indirect
|
||||
github.com/gorilla/css v1.0.1 // indirect
|
||||
github.com/gorilla/websocket v1.5.3 // indirect
|
||||
github.com/invopop/yaml v0.1.0 // indirect
|
||||
github.com/josharian/intern v1.0.0 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
@@ -105,7 +102,6 @@ require (
|
||||
github.com/yuin/goldmark v1.7.8 // indirect
|
||||
github.com/yuin/goldmark-emoji v1.0.5 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.59.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.59.0 // indirect
|
||||
go.opentelemetry.io/otel v1.34.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.34.0 // indirect
|
||||
@@ -115,9 +111,6 @@ require (
|
||||
golang.org/x/arch v0.12.0 // indirect
|
||||
golang.org/x/crypto v0.36.0 // indirect
|
||||
golang.org/x/net v0.37.0 // indirect
|
||||
golang.org/x/oauth2 v0.28.0 // indirect
|
||||
golang.org/x/time v0.11.0 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250313205543-e70fdf4c4cb4 // indirect
|
||||
google.golang.org/grpc v1.71.0 // indirect
|
||||
google.golang.org/protobuf v1.36.6 // indirect
|
||||
|
||||
@@ -1,15 +1,9 @@
|
||||
cloud.google.com/go v0.116.0 h1:B3fRrSDkLRt5qSHWe40ERJvhvnQwdZiHu0bJOpldweE=
|
||||
cloud.google.com/go v0.116.0/go.mod h1:cEPSRWPzZEswwdr9BxE6ChEn01dWlTaF05LiC2Xs70U=
|
||||
cloud.google.com/go/ai v0.8.0 h1:rXUEz8Wp2OlrM8r1bfmpF2+VKqc1VJpafE3HgzRnD/w=
|
||||
cloud.google.com/go/ai v0.8.0/go.mod h1:t3Dfk4cM61sytiggo2UyGsDVW3RF1qGZaUKDrZFyqkE=
|
||||
cloud.google.com/go/auth v0.15.0 h1:Ly0u4aA5vG/fsSsxu98qCQBemXtAtJf+95z9HK+cxps=
|
||||
cloud.google.com/go/auth v0.15.0/go.mod h1:WJDGqZ1o9E9wKIL+IwStfyn/+s59zl4Bi+1KQNVXLZ8=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c=
|
||||
cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I=
|
||||
cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg=
|
||||
cloud.google.com/go/longrunning v0.5.7 h1:WLbHekDbjK1fVFD3ibpFFVoyizlLRl73I7YKuAKilhU=
|
||||
cloud.google.com/go/longrunning v0.5.7/go.mod h1:8GClkudohy1Fxm3owmBGid8W0pSgodEMwEAztp38Xng=
|
||||
github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ=
|
||||
github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE=
|
||||
github.com/airbrake/gobrake v3.6.1+incompatible/go.mod h1:wM4gu3Cn0W0K7GUuVWnlXZU11AGBXMILnrdOU8Kn00o=
|
||||
@@ -99,8 +93,6 @@ github.com/cloudwego/eino v0.3.41 h1:bb6W5/+8QE+jlhDBY2TxcxYu6odpNQ436oDdBF45jgQ
|
||||
github.com/cloudwego/eino v0.3.41/go.mod h1:wUjz990apdsaOraOXdh6CdhVXq8DJsOvLsVlxNTcNfY=
|
||||
github.com/cloudwego/eino-ext/components/model/claude v0.0.0-20250609074000-b7f307dffa18 h1:foS8HJW0U8KOYy1hyWITSZUeMBZbckaBUjQQUIPwauw=
|
||||
github.com/cloudwego/eino-ext/components/model/claude v0.0.0-20250609074000-b7f307dffa18/go.mod h1:IhvvyyldQVIyOogbEbmmOjrxhZbpvbwjeiQJ4ZoNOX4=
|
||||
github.com/cloudwego/eino-ext/components/model/gemini v0.0.0-20250609074000-b7f307dffa18 h1:pyiKX7sTo9BgtbOq/twj5olsMl/iKj5RTIpDqgAtdBM=
|
||||
github.com/cloudwego/eino-ext/components/model/gemini v0.0.0-20250609074000-b7f307dffa18/go.mod h1:NTYXf6aAoO2zBES9S1lzkBvQoyD6UcUGvLmUAS5TMRU=
|
||||
github.com/cloudwego/eino-ext/components/model/ollama v0.0.0-20250609074000-b7f307dffa18 h1:UxZVTapUwbzkRIP4Bl/VKni65wI+sfq6oPveZ+76aww=
|
||||
github.com/cloudwego/eino-ext/components/model/ollama v0.0.0-20250609074000-b7f307dffa18/go.mod h1:giNUFqA+V7xrm/EDvH7JFnDqoWI+e2m1SVAnReU+Fd8=
|
||||
github.com/cloudwego/eino-ext/components/model/openai v0.0.0-20250609074000-b7f307dffa18 h1:FAg2QmtJ0tA/3BmQlUPdhZ9Nzzsov76ry7b3Gn86dAs=
|
||||
@@ -151,8 +143,6 @@ github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRx
|
||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/generative-ai-go v0.19.0 h1:R71szggh8wHMCUlEMsW2A/3T+5LdEIkiaHSYgSpUgdg=
|
||||
github.com/google/generative-ai-go v0.19.0/go.mod h1:JYolL13VG7j79kM5BtHz4qwONHkeJQzOCkKXnpqtS/E=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
@@ -171,6 +161,8 @@ github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfre
|
||||
github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8=
|
||||
github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0=
|
||||
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
|
||||
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
|
||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||
@@ -331,8 +323,6 @@ github.com/yuin/goldmark-emoji v1.0.5 h1:EMVWyCGPlXJfUXBXpuMu+ii3TIaxbVBnEX9uaDC
|
||||
github.com/yuin/goldmark-emoji v1.0.5/go.mod h1:tTkZEbwu5wkPmgTcitqddVxY9osFZiavD+r4AzQrh1U=
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
|
||||
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.59.0 h1:rgMkmiGfix9vFJDcDi1PK8WEQP4FLQwLDfhp5ZLpFeE=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.59.0/go.mod h1:ijPqXp5P6IRRByFVVg9DY8P5HkxkHE5ARIa+86aXPf4=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.59.0 h1:CV7UdSGJt/Ao6Gp4CXckLxVRRsRgDHoI8XjbL3PDl8s=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.59.0/go.mod h1:FRmFuRJfag1IZ2dPkHnEoSFVgTVPUd2qf5Vi69hLb8I=
|
||||
go.opentelemetry.io/otel v1.34.0 h1:zRLXxLCgL1WyKsPVrgbSdMN4c0FMkDAskSTQP+0hdUY=
|
||||
@@ -361,8 +351,6 @@ golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa/go.mod h1:zk2irFbV9DP96SEBUU
|
||||
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
|
||||
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
||||
golang.org/x/oauth2 v0.28.0 h1:CrgCKl8PPAVtLnU3c+EDw6x11699EWlsDeWNWKdIOkc=
|
||||
golang.org/x/oauth2 v0.28.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610=
|
||||
golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
@@ -378,12 +366,8 @@ golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0=
|
||||
golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU=
|
||||
golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0=
|
||||
golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||
google.golang.org/api v0.228.0 h1:X2DJ/uoWGnY5obVjewbp8icSL5U4FzuCfy9OjbLSnLs=
|
||||
google.golang.org/api v0.228.0/go.mod h1:wNvRS1Pbe8r4+IfBIniV8fwCpGwTrYa+kMUDiC5z5a4=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422 h1:GVIKPyP/kLIyVOgOnTwFOrvQaQUzOzGMCxgFUOEmm24=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422/go.mod h1:b6h1vNKhxaSoEI+5jc3PJUCustfli/mRab7295pY7rw=
|
||||
google.golang.org/genai v1.10.0 h1:ETP0Yksn5KUSEn5+ihMOnP3IqjZ+7Z4i0LjJslEXatI=
|
||||
google.golang.org/genai v1.10.0/go.mod h1:TyfOKRz/QyCaj6f/ZDt505x+YreXnY40l2I6k8TvgqY=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250313205543-e70fdf4c4cb4 h1:iK2jbkWL86DXjEx0qiHcRE9dE4/Ahua5k6V8OWFb//c=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250313205543-e70fdf4c4cb4/go.mod h1:LuRYeWDFV6WOn90g357N17oMCaxpgCnbi/44qJvDn2I=
|
||||
google.golang.org/grpc v1.71.0 h1:kF77BGdPTQ4/JZWMlb9VpJ5pa25aqvVqogsxNHHdeBg=
|
||||
|
||||
@@ -0,0 +1,704 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime/debug"
|
||||
|
||||
"github.com/cloudwego/eino/callbacks"
|
||||
"github.com/cloudwego/eino/components"
|
||||
"github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/getkin/kin-openapi/openapi3"
|
||||
"google.golang.org/genai"
|
||||
)
|
||||
|
||||
var _ model.ToolCallingChatModel = (*ChatModel)(nil)
|
||||
|
||||
// NewChatModel creates a new Gemini chat model instance
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: The context for the operation
|
||||
// - cfg: Configuration for the Gemini model
|
||||
//
|
||||
// Returns:
|
||||
// - model.ChatModel: A chat model interface implementation
|
||||
// - error: Any error that occurred during creation
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// model, err := gemini.NewChatModel(ctx, &gemini.Config{
|
||||
// Client: client,
|
||||
// Model: "gemini-pro",
|
||||
// })
|
||||
func NewChatModel(_ context.Context, cfg *Config) (*ChatModel, error) {
|
||||
return &ChatModel{
|
||||
cli: cfg.Client,
|
||||
model: cfg.Model,
|
||||
maxTokens: cfg.MaxTokens,
|
||||
temperature: cfg.Temperature,
|
||||
topP: cfg.TopP,
|
||||
topK: cfg.TopK,
|
||||
responseSchema: cfg.ResponseSchema,
|
||||
enableCodeExecution: cfg.EnableCodeExecution,
|
||||
safetySettings: cfg.SafetySettings,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Config contains the configuration options for the Gemini model
|
||||
type Config struct {
|
||||
// Client is the Gemini API client instance
|
||||
// Required for making API calls to Gemini
|
||||
Client *genai.Client
|
||||
|
||||
// Model specifies which Gemini model to use
|
||||
// Examples: "gemini-pro", "gemini-pro-vision", "gemini-1.5-flash"
|
||||
Model string
|
||||
|
||||
// MaxTokens limits the maximum number of tokens in the response
|
||||
// Optional. Example: maxTokens := 100
|
||||
MaxTokens *int
|
||||
|
||||
// Temperature controls randomness in responses
|
||||
// Range: [0.0, 1.0], where 0.0 is more focused and 1.0 is more creative
|
||||
// Optional. Example: temperature := float32(0.7)
|
||||
Temperature *float32
|
||||
|
||||
// TopP controls diversity via nucleus sampling
|
||||
// Range: [0.0, 1.0], where 1.0 disables nucleus sampling
|
||||
// Optional. Example: topP := float32(0.95)
|
||||
TopP *float32
|
||||
|
||||
// TopK controls diversity by limiting the top K tokens to sample from
|
||||
// Optional. Example: topK := int32(40)
|
||||
TopK *int32
|
||||
|
||||
// ResponseSchema defines the structure for JSON responses
|
||||
// Optional. Used when you want structured output in JSON format
|
||||
ResponseSchema *openapi3.Schema
|
||||
|
||||
// EnableCodeExecution allows the model to execute code
|
||||
// Warning: Be cautious with code execution in production
|
||||
// Optional. Default: false
|
||||
EnableCodeExecution bool
|
||||
|
||||
// SafetySettings configures content filtering for different harm categories
|
||||
// Controls the model's filtering behavior for potentially harmful content
|
||||
// Optional.
|
||||
SafetySettings []*genai.SafetySetting
|
||||
}
|
||||
|
||||
// options contains Gemini-specific options for model configuration
|
||||
type options struct {
|
||||
TopK *int32
|
||||
ResponseSchema *openapi3.Schema
|
||||
}
|
||||
|
||||
type ChatModel struct {
|
||||
cli *genai.Client
|
||||
|
||||
model string
|
||||
maxTokens *int
|
||||
topP *float32
|
||||
temperature *float32
|
||||
topK *int32
|
||||
responseSchema *openapi3.Schema
|
||||
tools []*genai.Tool
|
||||
origTools []*schema.ToolInfo
|
||||
toolChoice *schema.ToolChoice
|
||||
enableCodeExecution bool
|
||||
safetySettings []*genai.SafetySetting
|
||||
}
|
||||
|
||||
func (cm *ChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (message *schema.Message, err error) {
|
||||
ctx = callbacks.EnsureRunInfo(ctx, cm.GetType(), components.ComponentOfChatModel)
|
||||
|
||||
config, conf, err := cm.buildGenerateConfig(opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx = callbacks.OnStart(ctx, &model.CallbackInput{
|
||||
Messages: input,
|
||||
Tools: model.GetCommonOptions(&model.Options{Tools: cm.origTools}, opts...).Tools,
|
||||
Config: conf,
|
||||
})
|
||||
defer func() {
|
||||
if err != nil {
|
||||
callbacks.OnError(ctx, err)
|
||||
}
|
||||
}()
|
||||
|
||||
if len(input) == 0 {
|
||||
return nil, fmt.Errorf("gemini input is empty")
|
||||
}
|
||||
|
||||
contents, err := cm.convertSchemaMessages(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result, err := cm.cli.Models.GenerateContent(ctx, cm.model, contents, config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate content failed: %w", err)
|
||||
}
|
||||
|
||||
message, err = cm.convertResponse(result)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert response failed: %w", err)
|
||||
}
|
||||
|
||||
callbacks.OnEnd(ctx, cm.convertCallbackOutput(message, conf))
|
||||
return message, nil
|
||||
}
|
||||
|
||||
func (cm *ChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (result *schema.StreamReader[*schema.Message], err error) {
|
||||
ctx = callbacks.EnsureRunInfo(ctx, cm.GetType(), components.ComponentOfChatModel)
|
||||
|
||||
config, conf, err := cm.buildGenerateConfig(opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx = callbacks.OnStart(ctx, &model.CallbackInput{
|
||||
Messages: input,
|
||||
Tools: model.GetCommonOptions(&model.Options{Tools: cm.origTools}, opts...).Tools,
|
||||
Config: conf,
|
||||
})
|
||||
defer func() {
|
||||
if err != nil {
|
||||
callbacks.OnError(ctx, err)
|
||||
}
|
||||
}()
|
||||
|
||||
if len(input) == 0 {
|
||||
return nil, fmt.Errorf("gemini input is empty")
|
||||
}
|
||||
|
||||
contents, err := cm.convertSchemaMessages(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sr, sw := schema.Pipe[*model.CallbackOutput](1)
|
||||
go func() {
|
||||
defer func() {
|
||||
panicErr := recover()
|
||||
if panicErr != nil {
|
||||
_ = sw.Send(nil, newPanicErr(panicErr, debug.Stack()))
|
||||
}
|
||||
sw.Close()
|
||||
}()
|
||||
|
||||
for resp, err := range cm.cli.Models.GenerateContentStream(ctx, cm.model, contents, config) {
|
||||
if err != nil {
|
||||
sw.Send(nil, err)
|
||||
return
|
||||
}
|
||||
|
||||
message, err := cm.convertResponse(resp)
|
||||
if err != nil {
|
||||
sw.Send(nil, err)
|
||||
return
|
||||
}
|
||||
|
||||
closed := sw.Send(cm.convertCallbackOutput(message, conf), nil)
|
||||
if closed {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
srList := sr.Copy(2)
|
||||
callbacks.OnEndWithStreamOutput(ctx, srList[0])
|
||||
return schema.StreamReaderWithConvert(srList[1], func(t *model.CallbackOutput) (*schema.Message, error) {
|
||||
return t.Message, nil
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (cm *ChatModel) WithTools(tools []*schema.ToolInfo) (model.ToolCallingChatModel, error) {
|
||||
if len(tools) == 0 {
|
||||
return nil, errors.New("no tools to bind")
|
||||
}
|
||||
gTools, err := cm.convertToGeminiTools(tools)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert to gemini tools failed: %w", err)
|
||||
}
|
||||
|
||||
tc := schema.ToolChoiceAllowed
|
||||
ncm := *cm
|
||||
ncm.toolChoice = &tc
|
||||
ncm.tools = gTools
|
||||
ncm.origTools = tools
|
||||
return &ncm, nil
|
||||
}
|
||||
|
||||
func (cm *ChatModel) BindTools(tools []*schema.ToolInfo) error {
|
||||
if len(tools) == 0 {
|
||||
return errors.New("no tools to bind")
|
||||
}
|
||||
gTools, err := cm.convertToGeminiTools(tools)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cm.tools = gTools
|
||||
cm.origTools = tools
|
||||
tc := schema.ToolChoiceAllowed
|
||||
cm.toolChoice = &tc
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cm *ChatModel) BindForcedTools(tools []*schema.ToolInfo) error {
|
||||
if len(tools) == 0 {
|
||||
return errors.New("no tools to bind")
|
||||
}
|
||||
gTools, err := cm.convertToGeminiTools(tools)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cm.tools = gTools
|
||||
cm.origTools = tools
|
||||
tc := schema.ToolChoiceForced
|
||||
cm.toolChoice = &tc
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cm *ChatModel) buildGenerateConfig(opts ...model.Option) (*genai.GenerateContentConfig, *model.Config, error) {
|
||||
commonOptions := model.GetCommonOptions(&model.Options{
|
||||
Temperature: cm.temperature,
|
||||
MaxTokens: cm.maxTokens,
|
||||
TopP: cm.topP,
|
||||
Tools: nil,
|
||||
ToolChoice: cm.toolChoice,
|
||||
}, opts...)
|
||||
geminiOptions := model.GetImplSpecificOptions(&options{
|
||||
TopK: cm.topK,
|
||||
ResponseSchema: cm.responseSchema,
|
||||
}, opts...)
|
||||
|
||||
conf := &model.Config{}
|
||||
config := &genai.GenerateContentConfig{}
|
||||
|
||||
// Set model
|
||||
if commonOptions.Model != nil {
|
||||
conf.Model = *commonOptions.Model
|
||||
} else {
|
||||
conf.Model = cm.model
|
||||
}
|
||||
|
||||
// Set temperature
|
||||
if commonOptions.Temperature != nil {
|
||||
conf.Temperature = *commonOptions.Temperature
|
||||
config.Temperature = commonOptions.Temperature
|
||||
} else if cm.temperature != nil {
|
||||
conf.Temperature = *cm.temperature
|
||||
config.Temperature = cm.temperature
|
||||
}
|
||||
|
||||
// Set max tokens
|
||||
if commonOptions.MaxTokens != nil {
|
||||
conf.MaxTokens = *commonOptions.MaxTokens
|
||||
config.MaxOutputTokens = int32(*commonOptions.MaxTokens)
|
||||
} else if cm.maxTokens != nil {
|
||||
conf.MaxTokens = *cm.maxTokens
|
||||
config.MaxOutputTokens = int32(*cm.maxTokens)
|
||||
}
|
||||
|
||||
// Set top P
|
||||
if commonOptions.TopP != nil {
|
||||
conf.TopP = *commonOptions.TopP
|
||||
config.TopP = commonOptions.TopP
|
||||
} else if cm.topP != nil {
|
||||
conf.TopP = *cm.topP
|
||||
config.TopP = cm.topP
|
||||
}
|
||||
|
||||
// Set top K
|
||||
if geminiOptions.TopK != nil {
|
||||
config.TopK = genai.Ptr(float32(*geminiOptions.TopK))
|
||||
} else if cm.topK != nil {
|
||||
config.TopK = genai.Ptr(float32(*cm.topK))
|
||||
}
|
||||
|
||||
// Set tools
|
||||
tools := cm.tools
|
||||
if commonOptions.Tools != nil {
|
||||
var err error
|
||||
tools, err = cm.convertToGeminiTools(commonOptions.Tools)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
if len(tools) > 0 {
|
||||
config.Tools = tools
|
||||
}
|
||||
|
||||
// Set tool choice
|
||||
if commonOptions.ToolChoice != nil {
|
||||
switch *commonOptions.ToolChoice {
|
||||
case schema.ToolChoiceForbidden:
|
||||
config.ToolConfig = &genai.ToolConfig{
|
||||
FunctionCallingConfig: &genai.FunctionCallingConfig{
|
||||
Mode: genai.FunctionCallingConfigModeNone,
|
||||
},
|
||||
}
|
||||
case schema.ToolChoiceAllowed:
|
||||
config.ToolConfig = &genai.ToolConfig{
|
||||
FunctionCallingConfig: &genai.FunctionCallingConfig{
|
||||
Mode: genai.FunctionCallingConfigModeAuto,
|
||||
},
|
||||
}
|
||||
case schema.ToolChoiceForced:
|
||||
if len(tools) == 0 {
|
||||
return nil, nil, fmt.Errorf("tool choice is forced but no tools provided")
|
||||
}
|
||||
config.ToolConfig = &genai.ToolConfig{
|
||||
FunctionCallingConfig: &genai.FunctionCallingConfig{
|
||||
Mode: genai.FunctionCallingConfigModeAny,
|
||||
},
|
||||
}
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("tool choice=%s not supported", *commonOptions.ToolChoice)
|
||||
}
|
||||
}
|
||||
|
||||
// Set safety settings
|
||||
if len(cm.safetySettings) > 0 {
|
||||
config.SafetySettings = cm.safetySettings
|
||||
}
|
||||
|
||||
// Set response schema for JSON mode
|
||||
if geminiOptions.ResponseSchema != nil {
|
||||
gSchema, err := cm.convertOpenAPISchema(geminiOptions.ResponseSchema)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("convert response schema failed: %w", err)
|
||||
}
|
||||
config.ResponseMIMEType = "application/json"
|
||||
config.ResponseSchema = gSchema
|
||||
}
|
||||
|
||||
return config, conf, nil
|
||||
}
|
||||
|
||||
func (cm *ChatModel) convertToGeminiTools(tools []*schema.ToolInfo) ([]*genai.Tool, error) {
|
||||
if len(tools) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var functionDeclarations []*genai.FunctionDeclaration
|
||||
for _, tool := range tools {
|
||||
openSchema, err := tool.ToOpenAPIV3()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get open schema failed: %w", err)
|
||||
}
|
||||
|
||||
gSchema, err := cm.convertOpenAPISchema(openSchema)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert open schema failed: %w", err)
|
||||
}
|
||||
|
||||
funcDecl := &genai.FunctionDeclaration{
|
||||
Name: tool.Name,
|
||||
Description: tool.Desc,
|
||||
Parameters: gSchema,
|
||||
}
|
||||
functionDeclarations = append(functionDeclarations, funcDecl)
|
||||
}
|
||||
|
||||
return []*genai.Tool{{FunctionDeclarations: functionDeclarations}}, nil
|
||||
}
|
||||
|
||||
func (cm *ChatModel) convertOpenAPISchema(schema *openapi3.Schema) (*genai.Schema, error) {
|
||||
if schema == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
result := &genai.Schema{
|
||||
Description: schema.Description,
|
||||
}
|
||||
|
||||
switch schema.Type {
|
||||
case openapi3.TypeObject:
|
||||
result.Type = genai.TypeObject
|
||||
if schema.Properties != nil {
|
||||
properties := make(map[string]*genai.Schema)
|
||||
for name, prop := range schema.Properties {
|
||||
if prop == nil || prop.Value == nil {
|
||||
continue
|
||||
}
|
||||
propSchema, err := cm.convertOpenAPISchema(prop.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
properties[name] = propSchema
|
||||
}
|
||||
result.Properties = properties
|
||||
}
|
||||
if schema.Required != nil {
|
||||
result.Required = schema.Required
|
||||
}
|
||||
case openapi3.TypeArray:
|
||||
result.Type = genai.TypeArray
|
||||
if schema.Items != nil && schema.Items.Value != nil {
|
||||
itemSchema, err := cm.convertOpenAPISchema(schema.Items.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result.Items = itemSchema
|
||||
}
|
||||
case openapi3.TypeString:
|
||||
result.Type = genai.TypeString
|
||||
if schema.Enum != nil {
|
||||
enums := make([]string, 0, len(schema.Enum))
|
||||
for _, e := range schema.Enum {
|
||||
if str, ok := e.(string); ok {
|
||||
enums = append(enums, str)
|
||||
} else {
|
||||
return nil, fmt.Errorf("enum value must be a string, schema: %+v", schema)
|
||||
}
|
||||
}
|
||||
result.Enum = enums
|
||||
}
|
||||
case openapi3.TypeNumber:
|
||||
result.Type = genai.TypeNumber
|
||||
case openapi3.TypeInteger:
|
||||
result.Type = genai.TypeInteger
|
||||
case openapi3.TypeBoolean:
|
||||
result.Type = genai.TypeBoolean
|
||||
default:
|
||||
result.Type = genai.TypeUnspecified
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (cm *ChatModel) convertSchemaMessages(messages []*schema.Message) ([]*genai.Content, error) {
|
||||
var contents []*genai.Content
|
||||
for _, message := range messages {
|
||||
content, err := cm.convertSchemaMessage(message)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert schema message failed: %w", err)
|
||||
}
|
||||
if content != nil {
|
||||
contents = append(contents, content)
|
||||
}
|
||||
}
|
||||
return contents, nil
|
||||
}
|
||||
|
||||
func (cm *ChatModel) convertSchemaMessage(message *schema.Message) (*genai.Content, error) {
|
||||
if message == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var parts []*genai.Part
|
||||
|
||||
// Handle tool calls
|
||||
if message.ToolCalls != nil {
|
||||
for _, call := range message.ToolCalls {
|
||||
var args map[string]any
|
||||
if err := json.Unmarshal([]byte(call.Function.Arguments), &args); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal tool call arguments failed: %w", err)
|
||||
}
|
||||
parts = append(parts, &genai.Part{
|
||||
FunctionCall: &genai.FunctionCall{
|
||||
Name: call.Function.Name,
|
||||
Args: args,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Handle tool responses
|
||||
if message.Role == schema.Tool {
|
||||
var response map[string]any
|
||||
if err := json.Unmarshal([]byte(message.Content), &response); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal tool response failed: %w", err)
|
||||
}
|
||||
parts = append(parts, &genai.Part{
|
||||
FunctionResponse: &genai.FunctionResponse{
|
||||
Name: message.ToolCallID,
|
||||
Response: response,
|
||||
},
|
||||
})
|
||||
} else {
|
||||
// Handle text content
|
||||
if message.Content != "" {
|
||||
parts = append(parts, &genai.Part{Text: message.Content})
|
||||
}
|
||||
|
||||
// Handle multi-content (images, audio, etc.)
|
||||
for _, content := range message.MultiContent {
|
||||
switch content.Type {
|
||||
case schema.ChatMessagePartTypeText:
|
||||
parts = append(parts, &genai.Part{Text: content.Text})
|
||||
case schema.ChatMessagePartTypeImageURL:
|
||||
if content.ImageURL != nil {
|
||||
parts = append(parts, &genai.Part{
|
||||
FileData: &genai.FileData{
|
||||
MIMEType: content.ImageURL.MIMEType,
|
||||
FileURI: content.ImageURL.URI,
|
||||
},
|
||||
})
|
||||
}
|
||||
case schema.ChatMessagePartTypeAudioURL:
|
||||
if content.AudioURL != nil {
|
||||
parts = append(parts, &genai.Part{
|
||||
FileData: &genai.FileData{
|
||||
MIMEType: content.AudioURL.MIMEType,
|
||||
FileURI: content.AudioURL.URI,
|
||||
},
|
||||
})
|
||||
}
|
||||
case schema.ChatMessagePartTypeVideoURL:
|
||||
if content.VideoURL != nil {
|
||||
parts = append(parts, &genai.Part{
|
||||
FileData: &genai.FileData{
|
||||
MIMEType: content.VideoURL.MIMEType,
|
||||
FileURI: content.VideoURL.URI,
|
||||
},
|
||||
})
|
||||
}
|
||||
case schema.ChatMessagePartTypeFileURL:
|
||||
if content.FileURL != nil {
|
||||
parts = append(parts, &genai.Part{
|
||||
FileData: &genai.FileData{
|
||||
MIMEType: content.FileURL.MIMEType,
|
||||
FileURI: content.FileURL.URI,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(parts) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return &genai.Content{
|
||||
Role: string(cm.convertRole(message.Role)),
|
||||
Parts: parts,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (cm *ChatModel) convertRole(role schema.RoleType) genai.Role {
|
||||
switch role {
|
||||
case schema.Assistant:
|
||||
return genai.RoleModel
|
||||
case schema.User:
|
||||
return genai.RoleUser
|
||||
case schema.Tool:
|
||||
return genai.RoleUser // Tool responses are treated as user messages in the new API
|
||||
default:
|
||||
return genai.RoleUser
|
||||
}
|
||||
}
|
||||
|
||||
func (cm *ChatModel) convertResponse(resp *genai.GenerateContentResponse) (*schema.Message, error) {
|
||||
if len(resp.Candidates) == 0 {
|
||||
return nil, fmt.Errorf("gemini result is empty")
|
||||
}
|
||||
|
||||
candidate := resp.Candidates[0]
|
||||
message := &schema.Message{
|
||||
Role: schema.Assistant,
|
||||
ResponseMeta: &schema.ResponseMeta{
|
||||
FinishReason: string(candidate.FinishReason),
|
||||
},
|
||||
}
|
||||
|
||||
// Handle usage metadata
|
||||
if resp.UsageMetadata != nil {
|
||||
message.ResponseMeta.Usage = &schema.TokenUsage{
|
||||
PromptTokens: int(resp.UsageMetadata.PromptTokenCount),
|
||||
CompletionTokens: int(resp.UsageMetadata.CandidatesTokenCount),
|
||||
TotalTokens: int(resp.UsageMetadata.TotalTokenCount),
|
||||
}
|
||||
}
|
||||
|
||||
// Process content parts
|
||||
var textParts []string
|
||||
for _, part := range candidate.Content.Parts {
|
||||
switch {
|
||||
case part.Text != "":
|
||||
textParts = append(textParts, part.Text)
|
||||
case part.FunctionCall != nil:
|
||||
args, err := json.Marshal(part.FunctionCall.Args)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal function call arguments failed: %w", err)
|
||||
}
|
||||
message.ToolCalls = append(message.ToolCalls, schema.ToolCall{
|
||||
ID: part.FunctionCall.Name,
|
||||
Function: schema.FunctionCall{
|
||||
Name: part.FunctionCall.Name,
|
||||
Arguments: string(args),
|
||||
},
|
||||
})
|
||||
case part.ExecutableCode != nil:
|
||||
textParts = append(textParts, part.ExecutableCode.Code)
|
||||
case part.CodeExecutionResult != nil:
|
||||
textParts = append(textParts, part.CodeExecutionResult.Output)
|
||||
}
|
||||
}
|
||||
|
||||
// Set content
|
||||
if len(textParts) == 1 {
|
||||
message.Content = textParts[0]
|
||||
} else if len(textParts) > 1 {
|
||||
for _, text := range textParts {
|
||||
message.MultiContent = append(message.MultiContent, schema.ChatMessagePart{
|
||||
Type: schema.ChatMessagePartTypeText,
|
||||
Text: text,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return message, nil
|
||||
}
|
||||
|
||||
func (cm *ChatModel) convertCallbackOutput(message *schema.Message, conf *model.Config) *model.CallbackOutput {
|
||||
callbackOutput := &model.CallbackOutput{
|
||||
Message: message,
|
||||
Config: conf,
|
||||
}
|
||||
if message.ResponseMeta != nil && message.ResponseMeta.Usage != nil {
|
||||
callbackOutput.TokenUsage = &model.TokenUsage{
|
||||
PromptTokens: message.ResponseMeta.Usage.PromptTokens,
|
||||
CompletionTokens: message.ResponseMeta.Usage.CompletionTokens,
|
||||
TotalTokens: message.ResponseMeta.Usage.TotalTokens,
|
||||
}
|
||||
}
|
||||
return callbackOutput
|
||||
}
|
||||
|
||||
func (cm *ChatModel) IsCallbacksEnabled() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
const typ = "Gemini"
|
||||
|
||||
func (cm *ChatModel) GetType() string {
|
||||
return typ
|
||||
}
|
||||
|
||||
type panicErr struct {
|
||||
info any
|
||||
stack []byte
|
||||
}
|
||||
|
||||
func (p *panicErr) Error() string {
|
||||
return fmt.Sprintf("panic error: %v, \nstack: %s", p.info, string(p.stack))
|
||||
}
|
||||
|
||||
func newPanicErr(info any, stack []byte) error {
|
||||
return &panicErr{
|
||||
info: info,
|
||||
stack: stack,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,12 +7,12 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/model/claude"
|
||||
"github.com/cloudwego/eino-ext/components/model/gemini"
|
||||
"github.com/cloudwego/eino-ext/components/model/ollama"
|
||||
"github.com/cloudwego/eino-ext/components/model/openai"
|
||||
"github.com/cloudwego/eino/components/model"
|
||||
"github.com/google/generative-ai-go/genai"
|
||||
"google.golang.org/api/option"
|
||||
"google.golang.org/genai"
|
||||
|
||||
"github.com/mark3labs/mcphost/internal/models/gemini"
|
||||
)
|
||||
|
||||
// ProviderConfig holds configuration for creating LLM providers
|
||||
@@ -105,7 +105,10 @@ func createGoogleProvider(ctx context.Context, config *ProviderConfig, modelName
|
||||
return nil, fmt.Errorf("Google API key not provided. Use --google-api-key flag or GOOGLE_API_KEY/GEMINI_API_KEY environment variable")
|
||||
}
|
||||
|
||||
client, err := genai.NewClient(ctx, option.WithAPIKey(apiKey))
|
||||
client, err := genai.NewClient(ctx, &genai.ClientConfig{
|
||||
APIKey: apiKey,
|
||||
Backend: genai.BackendGeminiAPI,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Google client: %v", err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user