Improve Ollama and GPU handling (#85)

* preload ollama models

* fix num-gpu

* add download progress bar for ollama models

* fmt

* reorg
This commit is contained in:
Ed Zynda
2025-06-26 13:32:18 +03:00
committed by GitHub
parent 6219b84937
commit a66c55e175
7 changed files with 578 additions and 49 deletions
+33 -8
View File
@@ -182,8 +182,9 @@ func init() {
flags.StringSliceVar(&stopSequences, "stop-sequences", nil, "custom stop sequences (comma-separated)")
// Ollama-specific parameters
flags.Int32Var(&numGPU, "num-gpu", 1, "number of GPUs to use for Ollama models")
flags.Int32Var(&mainGPU, "main-gpu", 0, "main GPU to use for Ollama models")
flags.Int32Var(&numGPU, "num-gpu-layers", -1, "number of model layers to offload to GPU for Ollama models (-1 for auto-detect)")
flags.MarkHidden("num-gpu-layers") // Advanced option, hidden from help
flags.Int32Var(&mainGPU, "main-gpu", 0, "main GPU device to use for Ollama models")
// Bind flags to viper for config file support
viper.BindPFlag("system-prompt", rootCmd.PersistentFlags().Lookup("system-prompt"))
@@ -198,7 +199,7 @@ func init() {
viper.BindPFlag("top-p", rootCmd.PersistentFlags().Lookup("top-p"))
viper.BindPFlag("top-k", rootCmd.PersistentFlags().Lookup("top-k"))
viper.BindPFlag("stop-sequences", rootCmd.PersistentFlags().Lookup("stop-sequences"))
viper.BindPFlag("num-gpu", rootCmd.PersistentFlags().Lookup("num-gpu"))
viper.BindPFlag("num-gpu-layers", rootCmd.PersistentFlags().Lookup("num-gpu-layers"))
viper.BindPFlag("main-gpu", rootCmd.PersistentFlags().Lookup("main-gpu"))
// Defaults are already set in flag definitions, no need to duplicate in viper
@@ -265,7 +266,7 @@ func runNormalMode(ctx context.Context) error {
temperature := float32(viper.GetFloat64("temperature"))
topP := float32(viper.GetFloat64("top-p"))
topK := int32(viper.GetInt("top-k"))
numGPU := int32(viper.GetInt("num-gpu"))
numGPU := int32(viper.GetInt("num-gpu-layers"))
mainGPU := int32(viper.GetInt("main-gpu"))
modelConfig := &models.ProviderConfig{
@@ -290,8 +291,27 @@ func runNormalMode(ctx context.Context) error {
MaxSteps: viper.GetInt("max-steps"), // Pass 0 for infinite, agent will handle it
}
// Create the agent
mcpAgent, err := agent.NewAgent(ctx, agentConfig)
// Create the agent with spinner for Ollama models
var mcpAgent *agent.Agent
if strings.HasPrefix(viper.GetString("model"), "ollama:") && !quietFlag {
// Create a temporary CLI for the spinner
tempCli, tempErr := ui.NewCLI(viper.GetBool("debug"))
if tempErr == nil {
err = tempCli.ShowSpinner("Loading Ollama model...", func() error {
var agentErr error
mcpAgent, agentErr = agent.NewAgent(ctx, agentConfig)
return agentErr
})
} else {
// Fallback without spinner
mcpAgent, err = agent.NewAgent(ctx, agentConfig)
}
} else {
// No spinner for other providers
mcpAgent, err = agent.NewAgent(ctx, agentConfig)
}
if err != nil {
return fmt.Errorf("failed to create agent: %v", err)
}
@@ -344,8 +364,13 @@ func runNormalMode(ctx context.Context) error {
if len(parts) == 2 {
cli.DisplayInfo(fmt.Sprintf("Model loaded: %s (%s)", parts[0], parts[1]))
}
cli.DisplayInfo(fmt.Sprintf("Loaded %d tools from MCP servers", len(tools)))
// Display loading message if available (e.g., GPU fallback info)
if loadingMessage := mcpAgent.GetLoadingMessage(); loadingMessage != "" {
cli.DisplayInfo(loadingMessage)
}
cli.DisplayInfo(fmt.Sprintf("Loaded %d tools from MCP servers", len(tools)))
// Display debug configuration if debug mode is enabled
if viper.GetBool("debug") {
debugConfig := map[string]any{
@@ -361,7 +386,7 @@ func runNormalMode(ctx context.Context) error {
// Add Ollama-specific parameters if using Ollama
if strings.HasPrefix(viper.GetString("model"), "ollama:") {
debugConfig["num-gpu"] = viper.GetInt("num-gpu")
debugConfig["num-gpu-layers"] = viper.GetInt("num-gpu-layers")
debugConfig["main-gpu"] = viper.GetInt("main-gpu")
}
+6 -5
View File
@@ -14,7 +14,6 @@ require (
github.com/cloudwego/eino-ext/components/model/claude v0.0.0-20250609074000-b7f307dffa18
github.com/cloudwego/eino-ext/components/model/ollama v0.0.0-20250609074000-b7f307dffa18
github.com/cloudwego/eino-ext/components/model/openai v0.0.0-20250609074000-b7f307dffa18
github.com/getkin/kin-openapi v0.131.0
github.com/mark3labs/mcp-filesystem-server v0.11.1
github.com/mark3labs/mcp-go v0.32.0
github.com/ollama/ollama v0.5.12
@@ -25,6 +24,8 @@ require (
gopkg.in/yaml.v3 v3.0.1
)
require github.com/getkin/kin-openapi v0.118.0
require (
cloud.google.com/go v0.116.0 // indirect
cloud.google.com/go/auth v0.15.0 // indirect
@@ -51,6 +52,7 @@ require (
github.com/bytedance/sonic/loader v0.2.4 // indirect
github.com/catppuccin/go v0.2.0 // indirect
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
github.com/charmbracelet/harmonica v0.2.0 // indirect
github.com/charmbracelet/x/cellbuf v0.0.13 // indirect
github.com/charmbracelet/x/exp/slice v0.0.0-20250327172914-2fdc97757edf // indirect
github.com/cloudwego/base64x v0.1.5 // indirect
@@ -76,6 +78,7 @@ require (
github.com/goph/emperror v0.17.2 // indirect
github.com/gorilla/css v1.0.1 // indirect
github.com/gorilla/websocket v1.5.3 // indirect
github.com/invopop/yaml v0.1.0 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.9 // indirect
@@ -87,8 +90,6 @@ require (
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect
github.com/muesli/reflow v0.3.0 // indirect
github.com/nikolalohinski/gonja v1.5.3 // indirect
github.com/oasdiff/yaml v0.0.0-20250309154309-f31be36b4037 // indirect
github.com/oasdiff/yaml3 v0.0.0-20250309153720-d2182401db90 // indirect
github.com/pelletier/go-toml/v2 v2.2.3 // indirect
github.com/perimeterx/marshmallow v1.1.5 // indirect
github.com/pkg/errors v0.9.1 // indirect
@@ -127,8 +128,8 @@ require (
require (
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/charmbracelet/bubbles v0.20.0
github.com/charmbracelet/bubbletea v1.2.4
github.com/charmbracelet/bubbles v0.21.0
github.com/charmbracelet/bubbletea v1.3.5
github.com/charmbracelet/glamour v0.10.0
github.com/charmbracelet/x/ansi v0.8.0 // indirect
github.com/charmbracelet/x/term v0.2.1 // indirect
+23 -12
View File
@@ -73,14 +73,16 @@ github.com/bytedance/sonic/loader v0.2.4/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFos
github.com/catppuccin/go v0.2.0 h1:ktBeIrIP42b/8FGiScP9sgrWOss3lw0Z5SktRoithGA=
github.com/catppuccin/go v0.2.0/go.mod h1:8IHJuMGaUUjQM82qBrGNBv7LFq6JI3NnQCF6MOlZjpc=
github.com/certifi/gocertifi v0.0.0-20190105021004-abcd57078448/go.mod h1:GJKEexRPVJrBSOjoqN5VNOIKJ5Q3RViH6eu3puDRwx4=
github.com/charmbracelet/bubbles v0.20.0 h1:jSZu6qD8cRQ6k9OMfR1WlM+ruM8fkPWkHvQWD9LIutE=
github.com/charmbracelet/bubbles v0.20.0/go.mod h1:39slydyswPy+uVOHZ5x/GjwVAFkCsV8IIVy+4MhzwwU=
github.com/charmbracelet/bubbletea v1.2.4 h1:KN8aCViA0eps9SCOThb2/XPIlea3ANJLUkv3KnQRNCE=
github.com/charmbracelet/bubbletea v1.2.4/go.mod h1:Qr6fVQw+wX7JkWWkVyXYk/ZUQ92a6XNekLXa3rR18MM=
github.com/charmbracelet/bubbles v0.21.0 h1:9TdC97SdRVg/1aaXNVWfFH3nnLAwOXr8Fn6u6mfQdFs=
github.com/charmbracelet/bubbles v0.21.0/go.mod h1:HF+v6QUR4HkEpz62dx7ym2xc71/KBHg+zKwJtMw+qtg=
github.com/charmbracelet/bubbletea v1.3.5 h1:JAMNLTbqMOhSwoELIr0qyP4VidFq72/6E9j7HHmRKQc=
github.com/charmbracelet/bubbletea v1.3.5/go.mod h1:TkCnmH+aBd4LrXhXcqrKiYwRs7qyQx5rBgH5fVY3v54=
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs=
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk=
github.com/charmbracelet/glamour v0.10.0 h1:MtZvfwsYCx8jEPFJm3rIBFIMZUfUJ765oX8V6kXldcY=
github.com/charmbracelet/glamour v0.10.0/go.mod h1:f+uf+I/ChNmqo087elLnVdCiVgjSKWuXa/l6NU2ndYk=
github.com/charmbracelet/harmonica v0.2.0 h1:8NxJWRWg/bzKqqEaaeFNipOu77YR5t8aSwG4pgaUBiQ=
github.com/charmbracelet/harmonica v0.2.0/go.mod h1:KSri/1RMQOZLbw7AHqgcBycp8pgJnQMYYT8QZRqZ1Ao=
github.com/charmbracelet/huh v0.3.0 h1:CxPplWkgW2yUTDDG0Z4S5HH8SJOosWHd4LxCvi0XsKE=
github.com/charmbracelet/huh v0.3.0/go.mod h1:fujUdKX8tC45CCSaRQdw789O6uaCRwx8l2NDyKfC4jA=
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 h1:ZR7e0ro+SZZiIZD7msJyA+NjkCNNavuiPBLgerbOziE=
@@ -89,8 +91,8 @@ github.com/charmbracelet/x/ansi v0.8.0 h1:9GTq3xq9caJW8ZrBTe0LIe2fvfLR/bYXKTx2ll
github.com/charmbracelet/x/ansi v0.8.0/go.mod h1:wdYl/ONOLHLIVmQaxbIYEC/cRKOQyjTkowiI4blgS9Q=
github.com/charmbracelet/x/cellbuf v0.0.13 h1:/KBBKHuVRbq1lYx5BzEHBAFBP8VcQzJejZ/IA3iR28k=
github.com/charmbracelet/x/cellbuf v0.0.13/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs=
github.com/charmbracelet/x/exp/golden v0.0.0-20240815200342-61de596daa2b h1:MnAMdlwSltxJyULnrYbkZpp4k58Co7Tah3ciKhSNo0Q=
github.com/charmbracelet/x/exp/golden v0.0.0-20240815200342-61de596daa2b/go.mod h1:wDlXFlCrmJ8J+swcL/MnGUuYnqgQdW9rhSD61oNMb6U=
github.com/charmbracelet/x/exp/golden v0.0.0-20241011142426-46044092ad91 h1:payRxjMjKgx2PaCWLZ4p3ro9y97+TVLZNaRZgJwSVDQ=
github.com/charmbracelet/x/exp/golden v0.0.0-20241011142426-46044092ad91/go.mod h1:wDlXFlCrmJ8J+swcL/MnGUuYnqgQdW9rhSD61oNMb6U=
github.com/charmbracelet/x/exp/slice v0.0.0-20250327172914-2fdc97757edf h1:rLG0Yb6MQSDKdB52aGX55JT1oi0P0Kuaj7wi1bLUpnI=
github.com/charmbracelet/x/exp/slice v0.0.0-20250327172914-2fdc97757edf/go.mod h1:B3UgsnsBZS/eX42BlaNiJkD1pPOUa+oF1IYC6Yd2CEU=
github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
@@ -131,8 +133,8 @@ github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/
github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/gabriel-vasile/mimetype v1.4.9 h1:5k+WDwEsD9eTLL8Tz3L0VnmVh9QxGjRmjBvAG7U/oYY=
github.com/gabriel-vasile/mimetype v1.4.9/go.mod h1:WnSQhFKJuBlRyLiKohA/2DtIlPFAbguNaG7QCHcyGok=
github.com/getkin/kin-openapi v0.131.0 h1:NO2UeHnFKRYhZ8wg6Nyh5Cq7dHk4suQQr72a4pMrDxE=
github.com/getkin/kin-openapi v0.131.0/go.mod h1:3OlG51PCYNsPByuiMB0t4fjnNlIDnaEDsjiKUV8nL58=
github.com/getkin/kin-openapi v0.118.0 h1:z43njxPmJ7TaPpMSCQb7PN0dEYno4tyBPQcrFdHoLuM=
github.com/getkin/kin-openapi v0.118.0/go.mod h1:l5e9PaFUo9fyLJCPGQeXI2ML8c3P8BHOEV2VaAVf/pc=
github.com/getsentry/raven-go v0.2.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ=
github.com/go-check/check v0.0.0-20180628173108-788fd7840127 h1:0gkP6mzaMqkmpcJYCFOLkIBwI7xFExG03bbkOkCvUPI=
github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclKsC9YodN5RgxqK/VD9HM9JsCSh7rNhMZE98=
@@ -141,8 +143,10 @@ github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg=
github.com/go-openapi/jsonpointer v0.21.0 h1:YgdVicSA9vH5RiHs9TZW5oyafXZFc6+2Vc1rr/O9oNQ=
github.com/go-openapi/jsonpointer v0.21.0/go.mod h1:IUyH9l/+uyhIYQ/PXVA41Rexl+kOkAPDdXEYns6fzUY=
github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk=
github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE=
github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ=
github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM=
@@ -173,6 +177,7 @@ github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25d
github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k=
github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8=
github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0=
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
@@ -180,6 +185,8 @@ github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSo
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/invopop/yaml v0.1.0 h1:YW3WGUoJEXYfzWBjn00zIlrw7brGVD0fUKRYDPAPhrc=
github.com/invopop/yaml v0.1.0/go.mod h1:2XuRLgs/ouIrW3XNzuNj7J3Nvu/Dig5MXvbCEdiBN3Q=
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
@@ -202,6 +209,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/mark3labs/mcp-filesystem-server v0.11.1 h1:7uKIZRMaKWfgvtDj/uLAvo0+7Mwb8gxo5DJywhqFW88=
@@ -240,10 +249,6 @@ github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
github.com/nikolalohinski/gonja v1.5.3 h1:GsA+EEaZDZPGJ8JtpeGN78jidhOlxeJROpqMT9fTj9c=
github.com/nikolalohinski/gonja v1.5.3/go.mod h1:RmjwxNiXAEqcq1HeK5SSMmqFJvKOfTfXhkJv6YBtPa4=
github.com/oasdiff/yaml v0.0.0-20250309154309-f31be36b4037 h1:G7ERwszslrBzRxj//JalHPu/3yz+De2J+4aLtSRlHiY=
github.com/oasdiff/yaml v0.0.0-20250309154309-f31be36b4037/go.mod h1:2bpvgLBZEtENV5scfDFEtB/5+1M4hkQhDQrccEJ/qGw=
github.com/oasdiff/yaml3 v0.0.0-20250309153720-d2182401db90 h1:bQx3WeLcUWy+RletIKwUIt4x3t8n2SxavmoclizMb8c=
github.com/oasdiff/yaml3 v0.0.0-20250309153720-d2182401db90/go.mod h1:y5+oSEHCPT/DGrS++Wc/479ERge0zTFxaF8PbGKcg2o=
github.com/ollama/ollama v0.5.12 h1:qM+k/ozyHLJzEQoAEPrUQ0qXqsgDEEdpIVwuwScrd2U=
github.com/ollama/ollama v0.5.12/go.mod h1:ibdmDvb/TjKY1OArBWIazL3pd1DHTk8eG2MMjEkWhiI=
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
@@ -251,6 +256,7 @@ github.com/onsi/ginkgo v1.8.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+W
github.com/onsi/gomega v1.5.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M=
github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc=
github.com/perimeterx/marshmallow v1.1.4/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw=
github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX50IvK2s=
github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
@@ -323,6 +329,9 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go v1.2.7 h1:qYhyWUUd6WbiM+C6JZAUkIJt/1WrjzNHY9+KCIjVqTo=
github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M=
github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/x-cray/logrus-prefixed-formatter v0.5.2 h1:00txxvfBM9muc0jiLIEAkAcIMJzfthRT6usrui8uGmg=
@@ -463,6 +472,7 @@ google.golang.org/grpc v1.71.0/go.mod h1:H0GRtasmQOh9LkFoCPDu3ZrwUtD1YGE+b2vYBYd
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
@@ -473,6 +483,7 @@ gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50=
+16 -9
View File
@@ -41,16 +41,17 @@ type ToolCallContentHandler func(content string)
// Agent is the agent with real-time tool call display.
type Agent struct {
toolManager *tools.MCPToolManager
model model.ToolCallingChatModel
maxSteps int
systemPrompt string
toolManager *tools.MCPToolManager
model model.ToolCallingChatModel
maxSteps int
systemPrompt string
loadingMessage string // Message from provider loading (e.g., GPU fallback info)
}
// NewAgent creates an agent with MCP tool integration and real-time tool call display
func NewAgent(ctx context.Context, config *AgentConfig) (*Agent, error) {
// Create the LLM provider
model, err := models.CreateProvider(ctx, config.ModelConfig)
providerResult, err := models.CreateProvider(ctx, config.ModelConfig)
if err != nil {
return nil, fmt.Errorf("failed to create model provider: %v", err)
}
@@ -62,10 +63,11 @@ func NewAgent(ctx context.Context, config *AgentConfig) (*Agent, error) {
}
return &Agent{
toolManager: toolManager,
model: model,
maxSteps: config.MaxSteps, // Keep 0 for infinite, handle in loop
systemPrompt: config.SystemPrompt,
toolManager: toolManager,
model: providerResult.Model,
maxSteps: config.MaxSteps, // Keep 0 for infinite, handle in loop
systemPrompt: config.SystemPrompt,
loadingMessage: providerResult.Message,
}, nil
}
@@ -220,6 +222,11 @@ func (a *Agent) GetTools() []tool.BaseTool {
return a.toolManager.GetTools()
}
// GetLoadingMessage returns the loading message from provider creation (e.g., GPU fallback info)
func (a *Agent) GetLoadingMessage() string {
return a.loadingMessage
}
// generateWithCancellation calls the LLM with ESC key cancellation support
func (a *Agent) generateWithCancellation(ctx context.Context, messages []*schema.Message, toolInfos []*schema.ToolInfo) (*schema.Message, error) {
// Create a cancellable context for just this LLM call
+232 -14
View File
@@ -9,11 +9,13 @@ import (
"net/http"
"os"
"strings"
"time"
"github.com/cloudwego/eino-ext/components/model/claude"
"github.com/cloudwego/eino-ext/components/model/ollama"
"github.com/cloudwego/eino-ext/components/model/openai"
"github.com/cloudwego/eino/components/model"
"github.com/mark3labs/mcphost/internal/ui/progress"
"github.com/ollama/ollama/api"
"google.golang.org/genai"
@@ -76,8 +78,14 @@ type ProviderConfig struct {
MainGPU *int32
}
// ProviderResult contains the result of provider creation
type ProviderResult struct {
Model model.ToolCallingChatModel
Message string // Optional message for user feedback (e.g., GPU fallback info)
}
// CreateProvider creates an eino ToolCallingChatModel based on the provider configuration
func CreateProvider(ctx context.Context, config *ProviderConfig) (model.ToolCallingChatModel, error) {
func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResult, error) {
parts := strings.SplitN(config.ModelString, ":", 2)
if len(parts) < 2 {
return nil, fmt.Errorf("invalid model format. Expected provider:model, got %s", config.ModelString)
@@ -119,15 +127,31 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (model.ToolCall
switch provider {
case "anthropic":
return createAnthropicProvider(ctx, config, modelName)
model, err := createAnthropicProvider(ctx, config, modelName)
if err != nil {
return nil, err
}
return &ProviderResult{Model: model, Message: ""}, nil
case "openai":
return createOpenAIProvider(ctx, config, modelName)
model, err := createOpenAIProvider(ctx, config, modelName)
if err != nil {
return nil, err
}
return &ProviderResult{Model: model, Message: ""}, nil
case "google":
return createGoogleProvider(ctx, config, modelName)
model, err := createGoogleProvider(ctx, config, modelName)
if err != nil {
return nil, err
}
return &ProviderResult{Model: model, Message: ""}, nil
case "ollama":
return createOllamaProvider(ctx, config, modelName)
return createOllamaProviderWithResult(ctx, config, modelName)
case "azure":
return createAzureOpenAIProvider(ctx, config, modelName)
model, err := createAzureOpenAIProvider(ctx, config, modelName)
if err != nil {
return nil, err
}
return &ProviderResult{Model: model, Message: ""}, nil
default:
return nil, fmt.Errorf("unsupported provider: %s", provider)
}
@@ -353,7 +377,175 @@ func createGoogleProvider(ctx context.Context, config *ProviderConfig, modelName
return gemini.NewChatModel(ctx, geminiConfig)
}
func createOllamaProvider(ctx context.Context, config *ProviderConfig, modelName string) (model.ToolCallingChatModel, error) {
// OllamaLoadingResult contains the result of model loading with actual settings used
type OllamaLoadingResult struct {
Options *api.Options
Message string
}
// loadOllamaModelWithFallback loads an Ollama model with GPU settings and automatic CPU fallback
func loadOllamaModelWithFallback(ctx context.Context, baseURL, modelName string, options *api.Options) (*OllamaLoadingResult, error) {
client := &http.Client{}
// Phase 1: Check if model exists locally
if err := checkOllamaModelExists(client, baseURL, modelName); err != nil {
// Phase 2: Pull model if not found
if err := pullOllamaModel(ctx, client, baseURL, modelName); err != nil {
return nil, fmt.Errorf("failed to pull model %s: %v", modelName, err)
}
}
// Phase 3: Load model with GPU settings
_, err := loadOllamaModelWithOptions(ctx, client, baseURL, modelName, options)
if err != nil {
// Phase 4: Fallback to CPU if GPU memory insufficient
if isGPUMemoryError(err) {
cpuOptions := *options
cpuOptions.NumGPU = 0
_, cpuErr := loadOllamaModelWithOptions(ctx, client, baseURL, modelName, &cpuOptions)
if cpuErr != nil {
return nil, fmt.Errorf("failed to load model on GPU (%v) and CPU fallback failed (%v)", err, cpuErr)
}
return &OllamaLoadingResult{
Options: &cpuOptions,
Message: "Insufficient GPU memory, falling back to CPU inference",
}, nil
}
return nil, err
}
return &OllamaLoadingResult{
Options: options,
Message: "Model loaded successfully on GPU",
}, nil
}
// checkOllamaModelExists checks if a model exists locally
func checkOllamaModelExists(client *http.Client, baseURL, modelName string) error {
reqBody := map[string]string{"model": modelName}
jsonBody, _ := json.Marshal(reqBody)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "POST", baseURL+"/api/show", bytes.NewBuffer(jsonBody))
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("model not found locally")
}
return nil
}
// pullOllamaModel pulls a model from the registry
func pullOllamaModel(ctx context.Context, client *http.Client, baseURL, modelName string) error {
return pullOllamaModelWithProgress(ctx, client, baseURL, modelName, true)
}
// pullOllamaModelWithProgress pulls a model from the registry with optional progress display
func pullOllamaModelWithProgress(ctx context.Context, client *http.Client, baseURL, modelName string, showProgress bool) error {
reqBody := map[string]string{"name": modelName}
jsonBody, _ := json.Marshal(reqBody)
// Use a longer timeout for pulling models (5 minutes)
pullCtx, cancel := context.WithTimeout(ctx, 5*time.Minute)
defer cancel()
req, err := http.NewRequestWithContext(pullCtx, "POST", baseURL+"/api/pull", bytes.NewBuffer(jsonBody))
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("failed to pull model (status %d): %s", resp.StatusCode, string(body))
}
// Read the streaming response with optional progress display
if showProgress {
progressReader := progress.NewProgressReader(resp.Body)
defer progressReader.Close()
_, err = io.ReadAll(progressReader)
} else {
_, err = io.ReadAll(resp.Body)
}
return err
}
// loadOllamaModelWithOptions loads a model with specific options using a warmup request
func loadOllamaModelWithOptions(ctx context.Context, client *http.Client, baseURL, modelName string, options *api.Options) (*api.Options, error) {
// Create a copy of options for warmup to avoid modifying the original
warmupOptions := *options
warmupOptions.NumPredict = 1 // Limit response length for warmup
reqBody := map[string]interface{}{
"model": modelName,
"prompt": "Hello",
"stream": false,
"options": &warmupOptions,
}
jsonBody, _ := json.Marshal(reqBody)
// Use medium timeout for warmup (30 seconds)
warmupCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(warmupCtx, "POST", baseURL+"/api/generate", bytes.NewBuffer(jsonBody))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("warmup request failed (status %d): %s", resp.StatusCode, string(body))
}
// Read response to completion
_, err = io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
return options, nil
}
// isGPUMemoryError checks if an error indicates insufficient GPU memory
func isGPUMemoryError(err error) bool {
errStr := strings.ToLower(err.Error())
return strings.Contains(errStr, "out of memory") ||
strings.Contains(errStr, "insufficient memory") ||
strings.Contains(errStr, "cuda out of memory") ||
strings.Contains(errStr, "gpu memory")
}
func createOllamaProviderWithResult(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
baseURL := "http://localhost:11434" // Default Ollama URL
// Check for custom Ollama host from environment
@@ -366,11 +558,6 @@ func createOllamaProvider(ctx context.Context, config *ProviderConfig, modelName
baseURL = config.ProviderURL
}
ollamaConfig := &ollama.ChatModelConfig{
BaseURL: baseURL,
Model: modelName,
}
// Set up options for Ollama using the api.Options struct
options := &api.Options{}
@@ -403,9 +590,40 @@ func createOllamaProvider(ctx context.Context, config *ProviderConfig, modelName
options.MainGPU = int(*config.MainGPU)
}
ollamaConfig.Options = options
// Create a clean copy of options for the final model
finalOptions := &api.Options{}
*finalOptions = *options // Copy all fields
return ollama.NewChatModel(ctx, ollamaConfig)
// Try to pre-load the model with GPU settings and automatic CPU fallback
// If this fails, fall back to the original behavior
loadingResult, err := loadOllamaModelWithFallback(ctx, baseURL, modelName, options)
var loadingMessage string
if err != nil {
// Pre-loading failed, use original options and no message
loadingMessage = ""
} else {
// Pre-loading succeeded, update GPU settings that worked
finalOptions.NumGPU = loadingResult.Options.NumGPU
finalOptions.MainGPU = loadingResult.Options.MainGPU
loadingMessage = loadingResult.Message
}
ollamaConfig := &ollama.ChatModelConfig{
BaseURL: baseURL,
Model: modelName,
Options: finalOptions,
}
chatModel, err := ollama.NewChatModel(ctx, ollamaConfig)
if err != nil {
return nil, err
}
return &ProviderResult{
Model: chatModel,
Message: loadingMessage,
}, nil
}
// createOAuthHTTPClient creates an HTTP client that adds OAuth headers for Anthropic API
+267
View File
@@ -0,0 +1,267 @@
package progress
import (
"context"
"encoding/json"
"fmt"
"io"
"strings"
"sync"
"time"
"github.com/charmbracelet/bubbles/progress"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
var helpStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#626262")).Render
const (
padding = 2
maxWidth = 80
)
// OllamaPullProgress represents the progress information from Ollama pull API
type OllamaPullProgress struct {
Status string `json:"status"`
Digest string `json:"digest,omitempty"`
Total int64 `json:"total,omitempty"`
Completed int64 `json:"completed,omitempty"`
}
// progressMsg represents progress updates
type progressMsg struct {
percent float64
status string
}
// progressErrMsg represents errors during progress
type progressErrMsg struct{ err error }
// progressCompleteMsg indicates completion
type progressCompleteMsg struct{}
// ProgressModel represents the progress bar model
type ProgressModel struct {
progress progress.Model
status string
err error
complete bool
}
// NewProgressModel creates a new progress model
func NewProgressModel() ProgressModel {
return ProgressModel{
progress: progress.New(progress.WithDefaultGradient()),
status: "Initializing...",
}
}
// Init initializes the progress model
func (m ProgressModel) Init() tea.Cmd {
return nil
}
// Update handles progress updates
func (m ProgressModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case tea.KeyMsg:
if msg.String() == "q" || msg.String() == "ctrl+c" {
return m, tea.Quit
}
return m, nil
case tea.WindowSizeMsg:
m.progress.Width = msg.Width - padding*2 - 4
if m.progress.Width > maxWidth {
m.progress.Width = maxWidth
}
return m, nil
case progressErrMsg:
m.err = msg.err
return m, tea.Quit
case progressCompleteMsg:
m.complete = true
return m, tea.Quit
case progressMsg:
var cmds []tea.Cmd
m.status = msg.status
if msg.percent >= 1.0 {
m.complete = true
cmds = append(cmds, tea.Quit)
}
cmds = append(cmds, m.progress.SetPercent(msg.percent))
return m, tea.Batch(cmds...)
case progress.FrameMsg:
progressModel, cmd := m.progress.Update(msg)
m.progress = progressModel.(progress.Model)
return m, cmd
default:
return m, nil
}
}
// View renders the progress bar
func (m ProgressModel) View() string {
if m.err != nil {
return fmt.Sprintf("Error: %s\n", m.err.Error())
}
if m.complete {
return fmt.Sprintf("\n%s%s\n\n%sComplete!\n",
strings.Repeat(" ", padding),
m.progress.View(),
strings.Repeat(" ", padding))
}
pad := strings.Repeat(" ", padding)
return fmt.Sprintf("\n%s%s\n%s%s\n\n%s",
pad, m.progress.View(),
pad, m.status,
pad+helpStyle("Press 'q' or Ctrl+C to cancel"))
}
// ProgressReader wraps an io.Reader to provide progress updates for Ollama pull operations
type ProgressReader struct {
reader io.Reader
program *tea.Program
model ProgressModel
lastLine string
done chan struct{}
wg sync.WaitGroup
}
// NewProgressReader creates a new progress reader for Ollama pull operations
func NewProgressReader(reader io.Reader) *ProgressReader {
model := NewProgressModel()
// Create program with standard settings
program := tea.NewProgram(model)
pr := &ProgressReader{
reader: reader,
program: program,
model: model,
done: make(chan struct{}),
}
// Start the TUI in a goroutine
pr.wg.Add(1)
go func() {
defer pr.wg.Done()
if _, err := program.Run(); err != nil {
// Handle error silently for now
}
close(pr.done)
}()
return pr
}
// Read implements io.Reader and parses Ollama streaming responses
func (pr *ProgressReader) Read(p []byte) (n int, err error) {
n, err = pr.reader.Read(p)
if n > 0 {
// Parse the JSON lines for progress information
data := string(p[:n])
pr.lastLine += data
// Process complete lines
for {
lineEnd := strings.Index(pr.lastLine, "\n")
if lineEnd == -1 {
break
}
line := strings.TrimSpace(pr.lastLine[:lineEnd])
pr.lastLine = pr.lastLine[lineEnd+1:]
if line != "" {
pr.parseProgressLine(line)
}
}
}
if err == io.EOF {
// Send completion message and ensure program quits
pr.program.Send(progressCompleteMsg{})
}
return n, err
}
// parseProgressLine parses a single JSON line from Ollama pull response
func (pr *ProgressReader) parseProgressLine(line string) {
var progress OllamaPullProgress
if err := json.Unmarshal([]byte(line), &progress); err != nil {
return // Ignore malformed JSON
}
var percent float64
status := progress.Status
// Calculate progress percentage if we have total and completed
if progress.Total > 0 && progress.Completed >= 0 {
percent = float64(progress.Completed) / float64(progress.Total)
// Format status with progress info
if progress.Digest != "" {
status = fmt.Sprintf("%s (%s)", progress.Status, progress.Digest[:12])
}
// Add size information
if progress.Total > 0 {
totalMB := float64(progress.Total) / (1024 * 1024)
completedMB := float64(progress.Completed) / (1024 * 1024)
status = fmt.Sprintf("%s - %.1f/%.1f MB", status, completedMB, totalMB)
}
} else {
// For status-only updates (like "pulling manifest"), show indeterminate progress
if strings.Contains(strings.ToLower(progress.Status), "pulling") ||
strings.Contains(strings.ToLower(progress.Status), "downloading") {
// Keep current progress or show small progress for activity
percent = 0.1
} else if strings.Contains(strings.ToLower(progress.Status), "success") ||
strings.Contains(strings.ToLower(progress.Status), "complete") {
percent = 1.0
}
}
pr.program.Send(progressMsg{
percent: percent,
status: status,
})
}
// Close stops the progress display and waits for cleanup
func (pr *ProgressReader) Close() error {
// Send completion message to trigger quit
pr.program.Send(progressCompleteMsg{})
// Wait for the program to finish with timeout
done := make(chan struct{})
go func() {
pr.wg.Wait()
close(done)
}()
// Wait for completion or timeout after 2 seconds
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
select {
case <-done:
// Program finished normally
case <-ctx.Done():
// Timeout - force kill the program
pr.program.Kill()
}
return nil
}
+1 -1
View File
@@ -64,7 +64,7 @@ func (m spinnerModel) View() string {
Foreground(theme.Text).
Italic(true)
return fmt.Sprintf("%s %s",
return fmt.Sprintf(" %s %s",
spinnerStyle.Render(m.spinner.View()),
messageStyle.Render(m.message))
}