mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-13 19:20:06 +00:00
Improve Ollama and GPU handling (#85)
* preload ollama models * fix num-gpu * add download progress bar for ollama models * fmt * reorg
This commit is contained in:
+33
-8
@@ -182,8 +182,9 @@ func init() {
|
||||
flags.StringSliceVar(&stopSequences, "stop-sequences", nil, "custom stop sequences (comma-separated)")
|
||||
|
||||
// Ollama-specific parameters
|
||||
flags.Int32Var(&numGPU, "num-gpu", 1, "number of GPUs to use for Ollama models")
|
||||
flags.Int32Var(&mainGPU, "main-gpu", 0, "main GPU to use for Ollama models")
|
||||
flags.Int32Var(&numGPU, "num-gpu-layers", -1, "number of model layers to offload to GPU for Ollama models (-1 for auto-detect)")
|
||||
flags.MarkHidden("num-gpu-layers") // Advanced option, hidden from help
|
||||
flags.Int32Var(&mainGPU, "main-gpu", 0, "main GPU device to use for Ollama models")
|
||||
|
||||
// Bind flags to viper for config file support
|
||||
viper.BindPFlag("system-prompt", rootCmd.PersistentFlags().Lookup("system-prompt"))
|
||||
@@ -198,7 +199,7 @@ func init() {
|
||||
viper.BindPFlag("top-p", rootCmd.PersistentFlags().Lookup("top-p"))
|
||||
viper.BindPFlag("top-k", rootCmd.PersistentFlags().Lookup("top-k"))
|
||||
viper.BindPFlag("stop-sequences", rootCmd.PersistentFlags().Lookup("stop-sequences"))
|
||||
viper.BindPFlag("num-gpu", rootCmd.PersistentFlags().Lookup("num-gpu"))
|
||||
viper.BindPFlag("num-gpu-layers", rootCmd.PersistentFlags().Lookup("num-gpu-layers"))
|
||||
viper.BindPFlag("main-gpu", rootCmd.PersistentFlags().Lookup("main-gpu"))
|
||||
|
||||
// Defaults are already set in flag definitions, no need to duplicate in viper
|
||||
@@ -265,7 +266,7 @@ func runNormalMode(ctx context.Context) error {
|
||||
temperature := float32(viper.GetFloat64("temperature"))
|
||||
topP := float32(viper.GetFloat64("top-p"))
|
||||
topK := int32(viper.GetInt("top-k"))
|
||||
numGPU := int32(viper.GetInt("num-gpu"))
|
||||
numGPU := int32(viper.GetInt("num-gpu-layers"))
|
||||
mainGPU := int32(viper.GetInt("main-gpu"))
|
||||
|
||||
modelConfig := &models.ProviderConfig{
|
||||
@@ -290,8 +291,27 @@ func runNormalMode(ctx context.Context) error {
|
||||
MaxSteps: viper.GetInt("max-steps"), // Pass 0 for infinite, agent will handle it
|
||||
}
|
||||
|
||||
// Create the agent
|
||||
mcpAgent, err := agent.NewAgent(ctx, agentConfig)
|
||||
// Create the agent with spinner for Ollama models
|
||||
var mcpAgent *agent.Agent
|
||||
|
||||
if strings.HasPrefix(viper.GetString("model"), "ollama:") && !quietFlag {
|
||||
// Create a temporary CLI for the spinner
|
||||
tempCli, tempErr := ui.NewCLI(viper.GetBool("debug"))
|
||||
if tempErr == nil {
|
||||
err = tempCli.ShowSpinner("Loading Ollama model...", func() error {
|
||||
var agentErr error
|
||||
mcpAgent, agentErr = agent.NewAgent(ctx, agentConfig)
|
||||
return agentErr
|
||||
})
|
||||
} else {
|
||||
// Fallback without spinner
|
||||
mcpAgent, err = agent.NewAgent(ctx, agentConfig)
|
||||
}
|
||||
} else {
|
||||
// No spinner for other providers
|
||||
mcpAgent, err = agent.NewAgent(ctx, agentConfig)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create agent: %v", err)
|
||||
}
|
||||
@@ -344,8 +364,13 @@ func runNormalMode(ctx context.Context) error {
|
||||
if len(parts) == 2 {
|
||||
cli.DisplayInfo(fmt.Sprintf("Model loaded: %s (%s)", parts[0], parts[1]))
|
||||
}
|
||||
cli.DisplayInfo(fmt.Sprintf("Loaded %d tools from MCP servers", len(tools)))
|
||||
|
||||
// Display loading message if available (e.g., GPU fallback info)
|
||||
if loadingMessage := mcpAgent.GetLoadingMessage(); loadingMessage != "" {
|
||||
cli.DisplayInfo(loadingMessage)
|
||||
}
|
||||
|
||||
cli.DisplayInfo(fmt.Sprintf("Loaded %d tools from MCP servers", len(tools)))
|
||||
// Display debug configuration if debug mode is enabled
|
||||
if viper.GetBool("debug") {
|
||||
debugConfig := map[string]any{
|
||||
@@ -361,7 +386,7 @@ func runNormalMode(ctx context.Context) error {
|
||||
|
||||
// Add Ollama-specific parameters if using Ollama
|
||||
if strings.HasPrefix(viper.GetString("model"), "ollama:") {
|
||||
debugConfig["num-gpu"] = viper.GetInt("num-gpu")
|
||||
debugConfig["num-gpu-layers"] = viper.GetInt("num-gpu-layers")
|
||||
debugConfig["main-gpu"] = viper.GetInt("main-gpu")
|
||||
}
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@ require (
|
||||
github.com/cloudwego/eino-ext/components/model/claude v0.0.0-20250609074000-b7f307dffa18
|
||||
github.com/cloudwego/eino-ext/components/model/ollama v0.0.0-20250609074000-b7f307dffa18
|
||||
github.com/cloudwego/eino-ext/components/model/openai v0.0.0-20250609074000-b7f307dffa18
|
||||
github.com/getkin/kin-openapi v0.131.0
|
||||
github.com/mark3labs/mcp-filesystem-server v0.11.1
|
||||
github.com/mark3labs/mcp-go v0.32.0
|
||||
github.com/ollama/ollama v0.5.12
|
||||
@@ -25,6 +24,8 @@ require (
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
require github.com/getkin/kin-openapi v0.118.0
|
||||
|
||||
require (
|
||||
cloud.google.com/go v0.116.0 // indirect
|
||||
cloud.google.com/go/auth v0.15.0 // indirect
|
||||
@@ -51,6 +52,7 @@ require (
|
||||
github.com/bytedance/sonic/loader v0.2.4 // indirect
|
||||
github.com/catppuccin/go v0.2.0 // indirect
|
||||
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
|
||||
github.com/charmbracelet/harmonica v0.2.0 // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.13 // indirect
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20250327172914-2fdc97757edf // indirect
|
||||
github.com/cloudwego/base64x v0.1.5 // indirect
|
||||
@@ -76,6 +78,7 @@ require (
|
||||
github.com/goph/emperror v0.17.2 // indirect
|
||||
github.com/gorilla/css v1.0.1 // indirect
|
||||
github.com/gorilla/websocket v1.5.3 // indirect
|
||||
github.com/invopop/yaml v0.1.0 // indirect
|
||||
github.com/josharian/intern v1.0.0 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.9 // indirect
|
||||
@@ -87,8 +90,6 @@ require (
|
||||
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect
|
||||
github.com/muesli/reflow v0.3.0 // indirect
|
||||
github.com/nikolalohinski/gonja v1.5.3 // indirect
|
||||
github.com/oasdiff/yaml v0.0.0-20250309154309-f31be36b4037 // indirect
|
||||
github.com/oasdiff/yaml3 v0.0.0-20250309153720-d2182401db90 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.3 // indirect
|
||||
github.com/perimeterx/marshmallow v1.1.5 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
@@ -127,8 +128,8 @@ require (
|
||||
|
||||
require (
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
||||
github.com/charmbracelet/bubbles v0.20.0
|
||||
github.com/charmbracelet/bubbletea v1.2.4
|
||||
github.com/charmbracelet/bubbles v0.21.0
|
||||
github.com/charmbracelet/bubbletea v1.3.5
|
||||
github.com/charmbracelet/glamour v0.10.0
|
||||
github.com/charmbracelet/x/ansi v0.8.0 // indirect
|
||||
github.com/charmbracelet/x/term v0.2.1 // indirect
|
||||
|
||||
@@ -73,14 +73,16 @@ github.com/bytedance/sonic/loader v0.2.4/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFos
|
||||
github.com/catppuccin/go v0.2.0 h1:ktBeIrIP42b/8FGiScP9sgrWOss3lw0Z5SktRoithGA=
|
||||
github.com/catppuccin/go v0.2.0/go.mod h1:8IHJuMGaUUjQM82qBrGNBv7LFq6JI3NnQCF6MOlZjpc=
|
||||
github.com/certifi/gocertifi v0.0.0-20190105021004-abcd57078448/go.mod h1:GJKEexRPVJrBSOjoqN5VNOIKJ5Q3RViH6eu3puDRwx4=
|
||||
github.com/charmbracelet/bubbles v0.20.0 h1:jSZu6qD8cRQ6k9OMfR1WlM+ruM8fkPWkHvQWD9LIutE=
|
||||
github.com/charmbracelet/bubbles v0.20.0/go.mod h1:39slydyswPy+uVOHZ5x/GjwVAFkCsV8IIVy+4MhzwwU=
|
||||
github.com/charmbracelet/bubbletea v1.2.4 h1:KN8aCViA0eps9SCOThb2/XPIlea3ANJLUkv3KnQRNCE=
|
||||
github.com/charmbracelet/bubbletea v1.2.4/go.mod h1:Qr6fVQw+wX7JkWWkVyXYk/ZUQ92a6XNekLXa3rR18MM=
|
||||
github.com/charmbracelet/bubbles v0.21.0 h1:9TdC97SdRVg/1aaXNVWfFH3nnLAwOXr8Fn6u6mfQdFs=
|
||||
github.com/charmbracelet/bubbles v0.21.0/go.mod h1:HF+v6QUR4HkEpz62dx7ym2xc71/KBHg+zKwJtMw+qtg=
|
||||
github.com/charmbracelet/bubbletea v1.3.5 h1:JAMNLTbqMOhSwoELIr0qyP4VidFq72/6E9j7HHmRKQc=
|
||||
github.com/charmbracelet/bubbletea v1.3.5/go.mod h1:TkCnmH+aBd4LrXhXcqrKiYwRs7qyQx5rBgH5fVY3v54=
|
||||
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs=
|
||||
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk=
|
||||
github.com/charmbracelet/glamour v0.10.0 h1:MtZvfwsYCx8jEPFJm3rIBFIMZUfUJ765oX8V6kXldcY=
|
||||
github.com/charmbracelet/glamour v0.10.0/go.mod h1:f+uf+I/ChNmqo087elLnVdCiVgjSKWuXa/l6NU2ndYk=
|
||||
github.com/charmbracelet/harmonica v0.2.0 h1:8NxJWRWg/bzKqqEaaeFNipOu77YR5t8aSwG4pgaUBiQ=
|
||||
github.com/charmbracelet/harmonica v0.2.0/go.mod h1:KSri/1RMQOZLbw7AHqgcBycp8pgJnQMYYT8QZRqZ1Ao=
|
||||
github.com/charmbracelet/huh v0.3.0 h1:CxPplWkgW2yUTDDG0Z4S5HH8SJOosWHd4LxCvi0XsKE=
|
||||
github.com/charmbracelet/huh v0.3.0/go.mod h1:fujUdKX8tC45CCSaRQdw789O6uaCRwx8l2NDyKfC4jA=
|
||||
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 h1:ZR7e0ro+SZZiIZD7msJyA+NjkCNNavuiPBLgerbOziE=
|
||||
@@ -89,8 +91,8 @@ github.com/charmbracelet/x/ansi v0.8.0 h1:9GTq3xq9caJW8ZrBTe0LIe2fvfLR/bYXKTx2ll
|
||||
github.com/charmbracelet/x/ansi v0.8.0/go.mod h1:wdYl/ONOLHLIVmQaxbIYEC/cRKOQyjTkowiI4blgS9Q=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.13 h1:/KBBKHuVRbq1lYx5BzEHBAFBP8VcQzJejZ/IA3iR28k=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.13/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20240815200342-61de596daa2b h1:MnAMdlwSltxJyULnrYbkZpp4k58Co7Tah3ciKhSNo0Q=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20240815200342-61de596daa2b/go.mod h1:wDlXFlCrmJ8J+swcL/MnGUuYnqgQdW9rhSD61oNMb6U=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20241011142426-46044092ad91 h1:payRxjMjKgx2PaCWLZ4p3ro9y97+TVLZNaRZgJwSVDQ=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20241011142426-46044092ad91/go.mod h1:wDlXFlCrmJ8J+swcL/MnGUuYnqgQdW9rhSD61oNMb6U=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20250327172914-2fdc97757edf h1:rLG0Yb6MQSDKdB52aGX55JT1oi0P0Kuaj7wi1bLUpnI=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20250327172914-2fdc97757edf/go.mod h1:B3UgsnsBZS/eX42BlaNiJkD1pPOUa+oF1IYC6Yd2CEU=
|
||||
github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
|
||||
@@ -131,8 +133,8 @@ github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/
|
||||
github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||
github.com/gabriel-vasile/mimetype v1.4.9 h1:5k+WDwEsD9eTLL8Tz3L0VnmVh9QxGjRmjBvAG7U/oYY=
|
||||
github.com/gabriel-vasile/mimetype v1.4.9/go.mod h1:WnSQhFKJuBlRyLiKohA/2DtIlPFAbguNaG7QCHcyGok=
|
||||
github.com/getkin/kin-openapi v0.131.0 h1:NO2UeHnFKRYhZ8wg6Nyh5Cq7dHk4suQQr72a4pMrDxE=
|
||||
github.com/getkin/kin-openapi v0.131.0/go.mod h1:3OlG51PCYNsPByuiMB0t4fjnNlIDnaEDsjiKUV8nL58=
|
||||
github.com/getkin/kin-openapi v0.118.0 h1:z43njxPmJ7TaPpMSCQb7PN0dEYno4tyBPQcrFdHoLuM=
|
||||
github.com/getkin/kin-openapi v0.118.0/go.mod h1:l5e9PaFUo9fyLJCPGQeXI2ML8c3P8BHOEV2VaAVf/pc=
|
||||
github.com/getsentry/raven-go v0.2.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ=
|
||||
github.com/go-check/check v0.0.0-20180628173108-788fd7840127 h1:0gkP6mzaMqkmpcJYCFOLkIBwI7xFExG03bbkOkCvUPI=
|
||||
github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclKsC9YodN5RgxqK/VD9HM9JsCSh7rNhMZE98=
|
||||
@@ -141,8 +143,10 @@ github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
|
||||
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||
github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg=
|
||||
github.com/go-openapi/jsonpointer v0.21.0 h1:YgdVicSA9vH5RiHs9TZW5oyafXZFc6+2Vc1rr/O9oNQ=
|
||||
github.com/go-openapi/jsonpointer v0.21.0/go.mod h1:IUyH9l/+uyhIYQ/PXVA41Rexl+kOkAPDdXEYns6fzUY=
|
||||
github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk=
|
||||
github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE=
|
||||
github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ=
|
||||
github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM=
|
||||
@@ -173,6 +177,7 @@ github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25d
|
||||
github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k=
|
||||
github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8=
|
||||
github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0=
|
||||
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
|
||||
@@ -180,6 +185,8 @@ github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSo
|
||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/invopop/yaml v0.1.0 h1:YW3WGUoJEXYfzWBjn00zIlrw7brGVD0fUKRYDPAPhrc=
|
||||
github.com/invopop/yaml v0.1.0/go.mod h1:2XuRLgs/ouIrW3XNzuNj7J3Nvu/Dig5MXvbCEdiBN3Q=
|
||||
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
|
||||
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
|
||||
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
|
||||
@@ -202,6 +209,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
|
||||
github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
|
||||
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
|
||||
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
|
||||
github.com/mark3labs/mcp-filesystem-server v0.11.1 h1:7uKIZRMaKWfgvtDj/uLAvo0+7Mwb8gxo5DJywhqFW88=
|
||||
@@ -240,10 +249,6 @@ github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc
|
||||
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
||||
github.com/nikolalohinski/gonja v1.5.3 h1:GsA+EEaZDZPGJ8JtpeGN78jidhOlxeJROpqMT9fTj9c=
|
||||
github.com/nikolalohinski/gonja v1.5.3/go.mod h1:RmjwxNiXAEqcq1HeK5SSMmqFJvKOfTfXhkJv6YBtPa4=
|
||||
github.com/oasdiff/yaml v0.0.0-20250309154309-f31be36b4037 h1:G7ERwszslrBzRxj//JalHPu/3yz+De2J+4aLtSRlHiY=
|
||||
github.com/oasdiff/yaml v0.0.0-20250309154309-f31be36b4037/go.mod h1:2bpvgLBZEtENV5scfDFEtB/5+1M4hkQhDQrccEJ/qGw=
|
||||
github.com/oasdiff/yaml3 v0.0.0-20250309153720-d2182401db90 h1:bQx3WeLcUWy+RletIKwUIt4x3t8n2SxavmoclizMb8c=
|
||||
github.com/oasdiff/yaml3 v0.0.0-20250309153720-d2182401db90/go.mod h1:y5+oSEHCPT/DGrS++Wc/479ERge0zTFxaF8PbGKcg2o=
|
||||
github.com/ollama/ollama v0.5.12 h1:qM+k/ozyHLJzEQoAEPrUQ0qXqsgDEEdpIVwuwScrd2U=
|
||||
github.com/ollama/ollama v0.5.12/go.mod h1:ibdmDvb/TjKY1OArBWIazL3pd1DHTk8eG2MMjEkWhiI=
|
||||
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
|
||||
@@ -251,6 +256,7 @@ github.com/onsi/ginkgo v1.8.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+W
|
||||
github.com/onsi/gomega v1.5.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
|
||||
github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M=
|
||||
github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc=
|
||||
github.com/perimeterx/marshmallow v1.1.4/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw=
|
||||
github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX50IvK2s=
|
||||
github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw=
|
||||
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
@@ -323,6 +329,9 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go v1.2.7 h1:qYhyWUUd6WbiM+C6JZAUkIJt/1WrjzNHY9+KCIjVqTo=
|
||||
github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M=
|
||||
github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY=
|
||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||
github.com/x-cray/logrus-prefixed-formatter v0.5.2 h1:00txxvfBM9muc0jiLIEAkAcIMJzfthRT6usrui8uGmg=
|
||||
@@ -463,6 +472,7 @@ google.golang.org/grpc v1.71.0/go.mod h1:H0GRtasmQOh9LkFoCPDu3ZrwUtD1YGE+b2vYBYd
|
||||
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
|
||||
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
@@ -473,6 +483,7 @@ gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50=
|
||||
|
||||
+16
-9
@@ -41,16 +41,17 @@ type ToolCallContentHandler func(content string)
|
||||
|
||||
// Agent is the agent with real-time tool call display.
|
||||
type Agent struct {
|
||||
toolManager *tools.MCPToolManager
|
||||
model model.ToolCallingChatModel
|
||||
maxSteps int
|
||||
systemPrompt string
|
||||
toolManager *tools.MCPToolManager
|
||||
model model.ToolCallingChatModel
|
||||
maxSteps int
|
||||
systemPrompt string
|
||||
loadingMessage string // Message from provider loading (e.g., GPU fallback info)
|
||||
}
|
||||
|
||||
// NewAgent creates an agent with MCP tool integration and real-time tool call display
|
||||
func NewAgent(ctx context.Context, config *AgentConfig) (*Agent, error) {
|
||||
// Create the LLM provider
|
||||
model, err := models.CreateProvider(ctx, config.ModelConfig)
|
||||
providerResult, err := models.CreateProvider(ctx, config.ModelConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create model provider: %v", err)
|
||||
}
|
||||
@@ -62,10 +63,11 @@ func NewAgent(ctx context.Context, config *AgentConfig) (*Agent, error) {
|
||||
}
|
||||
|
||||
return &Agent{
|
||||
toolManager: toolManager,
|
||||
model: model,
|
||||
maxSteps: config.MaxSteps, // Keep 0 for infinite, handle in loop
|
||||
systemPrompt: config.SystemPrompt,
|
||||
toolManager: toolManager,
|
||||
model: providerResult.Model,
|
||||
maxSteps: config.MaxSteps, // Keep 0 for infinite, handle in loop
|
||||
systemPrompt: config.SystemPrompt,
|
||||
loadingMessage: providerResult.Message,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -220,6 +222,11 @@ func (a *Agent) GetTools() []tool.BaseTool {
|
||||
return a.toolManager.GetTools()
|
||||
}
|
||||
|
||||
// GetLoadingMessage returns the loading message from provider creation (e.g., GPU fallback info)
|
||||
func (a *Agent) GetLoadingMessage() string {
|
||||
return a.loadingMessage
|
||||
}
|
||||
|
||||
// generateWithCancellation calls the LLM with ESC key cancellation support
|
||||
func (a *Agent) generateWithCancellation(ctx context.Context, messages []*schema.Message, toolInfos []*schema.ToolInfo) (*schema.Message, error) {
|
||||
// Create a cancellable context for just this LLM call
|
||||
|
||||
+232
-14
@@ -9,11 +9,13 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/model/claude"
|
||||
"github.com/cloudwego/eino-ext/components/model/ollama"
|
||||
"github.com/cloudwego/eino-ext/components/model/openai"
|
||||
"github.com/cloudwego/eino/components/model"
|
||||
"github.com/mark3labs/mcphost/internal/ui/progress"
|
||||
"github.com/ollama/ollama/api"
|
||||
"google.golang.org/genai"
|
||||
|
||||
@@ -76,8 +78,14 @@ type ProviderConfig struct {
|
||||
MainGPU *int32
|
||||
}
|
||||
|
||||
// ProviderResult contains the result of provider creation
|
||||
type ProviderResult struct {
|
||||
Model model.ToolCallingChatModel
|
||||
Message string // Optional message for user feedback (e.g., GPU fallback info)
|
||||
}
|
||||
|
||||
// CreateProvider creates an eino ToolCallingChatModel based on the provider configuration
|
||||
func CreateProvider(ctx context.Context, config *ProviderConfig) (model.ToolCallingChatModel, error) {
|
||||
func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResult, error) {
|
||||
parts := strings.SplitN(config.ModelString, ":", 2)
|
||||
if len(parts) < 2 {
|
||||
return nil, fmt.Errorf("invalid model format. Expected provider:model, got %s", config.ModelString)
|
||||
@@ -119,15 +127,31 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (model.ToolCall
|
||||
|
||||
switch provider {
|
||||
case "anthropic":
|
||||
return createAnthropicProvider(ctx, config, modelName)
|
||||
model, err := createAnthropicProvider(ctx, config, modelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ProviderResult{Model: model, Message: ""}, nil
|
||||
case "openai":
|
||||
return createOpenAIProvider(ctx, config, modelName)
|
||||
model, err := createOpenAIProvider(ctx, config, modelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ProviderResult{Model: model, Message: ""}, nil
|
||||
case "google":
|
||||
return createGoogleProvider(ctx, config, modelName)
|
||||
model, err := createGoogleProvider(ctx, config, modelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ProviderResult{Model: model, Message: ""}, nil
|
||||
case "ollama":
|
||||
return createOllamaProvider(ctx, config, modelName)
|
||||
return createOllamaProviderWithResult(ctx, config, modelName)
|
||||
case "azure":
|
||||
return createAzureOpenAIProvider(ctx, config, modelName)
|
||||
model, err := createAzureOpenAIProvider(ctx, config, modelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ProviderResult{Model: model, Message: ""}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported provider: %s", provider)
|
||||
}
|
||||
@@ -353,7 +377,175 @@ func createGoogleProvider(ctx context.Context, config *ProviderConfig, modelName
|
||||
return gemini.NewChatModel(ctx, geminiConfig)
|
||||
}
|
||||
|
||||
func createOllamaProvider(ctx context.Context, config *ProviderConfig, modelName string) (model.ToolCallingChatModel, error) {
|
||||
// OllamaLoadingResult contains the result of model loading with actual settings used
|
||||
type OllamaLoadingResult struct {
|
||||
Options *api.Options
|
||||
Message string
|
||||
}
|
||||
|
||||
// loadOllamaModelWithFallback loads an Ollama model with GPU settings and automatic CPU fallback
|
||||
func loadOllamaModelWithFallback(ctx context.Context, baseURL, modelName string, options *api.Options) (*OllamaLoadingResult, error) {
|
||||
client := &http.Client{}
|
||||
|
||||
// Phase 1: Check if model exists locally
|
||||
if err := checkOllamaModelExists(client, baseURL, modelName); err != nil {
|
||||
// Phase 2: Pull model if not found
|
||||
if err := pullOllamaModel(ctx, client, baseURL, modelName); err != nil {
|
||||
return nil, fmt.Errorf("failed to pull model %s: %v", modelName, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 3: Load model with GPU settings
|
||||
_, err := loadOllamaModelWithOptions(ctx, client, baseURL, modelName, options)
|
||||
if err != nil {
|
||||
// Phase 4: Fallback to CPU if GPU memory insufficient
|
||||
if isGPUMemoryError(err) {
|
||||
cpuOptions := *options
|
||||
cpuOptions.NumGPU = 0
|
||||
|
||||
_, cpuErr := loadOllamaModelWithOptions(ctx, client, baseURL, modelName, &cpuOptions)
|
||||
if cpuErr != nil {
|
||||
return nil, fmt.Errorf("failed to load model on GPU (%v) and CPU fallback failed (%v)", err, cpuErr)
|
||||
}
|
||||
|
||||
return &OllamaLoadingResult{
|
||||
Options: &cpuOptions,
|
||||
Message: "Insufficient GPU memory, falling back to CPU inference",
|
||||
}, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &OllamaLoadingResult{
|
||||
Options: options,
|
||||
Message: "Model loaded successfully on GPU",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// checkOllamaModelExists reports whether modelName is available on the Ollama
// server at baseURL by POSTing {"model": modelName} to its /api/show endpoint.
//
// It returns nil when the server answers 200 OK. Every other outcome is
// returned as an error: a request that cannot be encoded, built or sent, or
// a non-200 response (the status code is included in the message so that a
// server failure can be told apart from a missing model).
//
// The request is limited to 5 seconds. This function receives no context, so
// the deadline is derived from context.Background() and is not tied to any
// caller cancellation.
func checkOllamaModelExists(client *http.Client, baseURL, modelName string) error {
	payload, err := json.Marshal(map[string]string{"model": modelName})
	if err != nil {
		return fmt.Errorf("encoding show request for model %q: %w", modelName, err)
	}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/api/show", bytes.NewReader(payload))
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("model not found locally (status %d)", resp.StatusCode)
	}
	return nil
}
|
||||
|
||||
// pullOllamaModel pulls a model from the registry
|
||||
func pullOllamaModel(ctx context.Context, client *http.Client, baseURL, modelName string) error {
|
||||
return pullOllamaModelWithProgress(ctx, client, baseURL, modelName, true)
|
||||
}
|
||||
|
||||
// pullOllamaModelWithProgress pulls a model from the registry with optional progress display
|
||||
func pullOllamaModelWithProgress(ctx context.Context, client *http.Client, baseURL, modelName string, showProgress bool) error {
|
||||
reqBody := map[string]string{"name": modelName}
|
||||
jsonBody, _ := json.Marshal(reqBody)
|
||||
|
||||
// Use a longer timeout for pulling models (5 minutes)
|
||||
pullCtx, cancel := context.WithTimeout(ctx, 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(pullCtx, "POST", baseURL+"/api/pull", bytes.NewBuffer(jsonBody))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("failed to pull model (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Read the streaming response with optional progress display
|
||||
if showProgress {
|
||||
progressReader := progress.NewProgressReader(resp.Body)
|
||||
defer progressReader.Close()
|
||||
_, err = io.ReadAll(progressReader)
|
||||
} else {
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// loadOllamaModelWithOptions loads a model with specific options using a warmup request
|
||||
func loadOllamaModelWithOptions(ctx context.Context, client *http.Client, baseURL, modelName string, options *api.Options) (*api.Options, error) {
|
||||
// Create a copy of options for warmup to avoid modifying the original
|
||||
warmupOptions := *options
|
||||
warmupOptions.NumPredict = 1 // Limit response length for warmup
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
"model": modelName,
|
||||
"prompt": "Hello",
|
||||
"stream": false,
|
||||
"options": &warmupOptions,
|
||||
}
|
||||
|
||||
jsonBody, _ := json.Marshal(reqBody)
|
||||
|
||||
// Use medium timeout for warmup (30 seconds)
|
||||
warmupCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(warmupCtx, "POST", baseURL+"/api/generate", bytes.NewBuffer(jsonBody))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("warmup request failed (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Read response to completion
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return options, nil
|
||||
}
|
||||
|
||||
// isGPUMemoryError reports whether err looks like an insufficient-memory
// failure, judged by a case-insensitive substring match on err.Error().
//
// A nil error is never a memory error, so the function is safe to call
// without a prior nil check.
func isGPUMemoryError(err error) bool {
	if err == nil {
		return false
	}

	msg := strings.ToLower(err.Error())
	// "out of memory" also matches the more specific "cuda out of memory".
	for _, marker := range []string{"out of memory", "insufficient memory", "gpu memory"} {
		if strings.Contains(msg, marker) {
			return true
		}
	}
	return false
}
|
||||
|
||||
func createOllamaProviderWithResult(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
|
||||
baseURL := "http://localhost:11434" // Default Ollama URL
|
||||
|
||||
// Check for custom Ollama host from environment
|
||||
@@ -366,11 +558,6 @@ func createOllamaProvider(ctx context.Context, config *ProviderConfig, modelName
|
||||
baseURL = config.ProviderURL
|
||||
}
|
||||
|
||||
ollamaConfig := &ollama.ChatModelConfig{
|
||||
BaseURL: baseURL,
|
||||
Model: modelName,
|
||||
}
|
||||
|
||||
// Set up options for Ollama using the api.Options struct
|
||||
options := &api.Options{}
|
||||
|
||||
@@ -403,9 +590,40 @@ func createOllamaProvider(ctx context.Context, config *ProviderConfig, modelName
|
||||
options.MainGPU = int(*config.MainGPU)
|
||||
}
|
||||
|
||||
ollamaConfig.Options = options
|
||||
// Create a clean copy of options for the final model
|
||||
finalOptions := &api.Options{}
|
||||
*finalOptions = *options // Copy all fields
|
||||
|
||||
return ollama.NewChatModel(ctx, ollamaConfig)
|
||||
// Try to pre-load the model with GPU settings and automatic CPU fallback
|
||||
// If this fails, fall back to the original behavior
|
||||
loadingResult, err := loadOllamaModelWithFallback(ctx, baseURL, modelName, options)
|
||||
var loadingMessage string
|
||||
|
||||
if err != nil {
|
||||
// Pre-loading failed, use original options and no message
|
||||
loadingMessage = ""
|
||||
} else {
|
||||
// Pre-loading succeeded, update GPU settings that worked
|
||||
finalOptions.NumGPU = loadingResult.Options.NumGPU
|
||||
finalOptions.MainGPU = loadingResult.Options.MainGPU
|
||||
loadingMessage = loadingResult.Message
|
||||
}
|
||||
|
||||
ollamaConfig := &ollama.ChatModelConfig{
|
||||
BaseURL: baseURL,
|
||||
Model: modelName,
|
||||
Options: finalOptions,
|
||||
}
|
||||
|
||||
chatModel, err := ollama.NewChatModel(ctx, ollamaConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ProviderResult{
|
||||
Model: chatModel,
|
||||
Message: loadingMessage,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// createOAuthHTTPClient creates an HTTP client that adds OAuth headers for Anthropic API
|
||||
|
||||
@@ -0,0 +1,267 @@
|
||||
package progress
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/bubbles/progress"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
var helpStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#626262")).Render
|
||||
|
||||
const (
|
||||
padding = 2
|
||||
maxWidth = 80
|
||||
)
|
||||
|
||||
// OllamaPullProgress is one decoded JSON line from the Ollama pull API's
// streaming response.
type OllamaPullProgress struct {
	// Status is the textual state reported for this update.
	Status string `json:"status"`

	// Digest identifies the item the update refers to; it may be empty.
	Digest string `json:"digest,omitempty"`

	// Total and Completed are the overall and finished amounts for the item.
	// parseProgressLine treats both as byte counts.
	Total     int64 `json:"total,omitempty"`
	Completed int64 `json:"completed,omitempty"`
}
|
||||
|
||||
// Messages delivered to the progress program.
type (
	// progressMsg carries a new completion fraction and status line.
	progressMsg struct {
		percent float64
		status  string
	}

	// progressErrMsg reports an error encountered while tracking progress.
	progressErrMsg struct{ err error }

	// progressCompleteMsg signals that the tracked operation has finished.
	progressCompleteMsg struct{}
)
|
||||
|
||||
// ProgressModel represents the progress bar model
|
||||
type ProgressModel struct {
|
||||
progress progress.Model
|
||||
status string
|
||||
err error
|
||||
complete bool
|
||||
}
|
||||
|
||||
// NewProgressModel creates a new progress model
|
||||
func NewProgressModel() ProgressModel {
|
||||
return ProgressModel{
|
||||
progress: progress.New(progress.WithDefaultGradient()),
|
||||
status: "Initializing...",
|
||||
}
|
||||
}
|
||||
|
||||
// Init initializes the progress model
|
||||
func (m ProgressModel) Init() tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update handles progress updates
|
||||
func (m ProgressModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case tea.KeyMsg:
|
||||
if msg.String() == "q" || msg.String() == "ctrl+c" {
|
||||
return m, tea.Quit
|
||||
}
|
||||
return m, nil
|
||||
|
||||
case tea.WindowSizeMsg:
|
||||
m.progress.Width = msg.Width - padding*2 - 4
|
||||
if m.progress.Width > maxWidth {
|
||||
m.progress.Width = maxWidth
|
||||
}
|
||||
return m, nil
|
||||
|
||||
case progressErrMsg:
|
||||
m.err = msg.err
|
||||
return m, tea.Quit
|
||||
|
||||
case progressCompleteMsg:
|
||||
m.complete = true
|
||||
return m, tea.Quit
|
||||
|
||||
case progressMsg:
|
||||
var cmds []tea.Cmd
|
||||
m.status = msg.status
|
||||
|
||||
if msg.percent >= 1.0 {
|
||||
m.complete = true
|
||||
cmds = append(cmds, tea.Quit)
|
||||
}
|
||||
|
||||
cmds = append(cmds, m.progress.SetPercent(msg.percent))
|
||||
return m, tea.Batch(cmds...)
|
||||
|
||||
case progress.FrameMsg:
|
||||
progressModel, cmd := m.progress.Update(msg)
|
||||
m.progress = progressModel.(progress.Model)
|
||||
return m, cmd
|
||||
|
||||
default:
|
||||
return m, nil
|
||||
}
|
||||
}
|
||||
|
||||
// View renders the progress bar
|
||||
func (m ProgressModel) View() string {
|
||||
if m.err != nil {
|
||||
return fmt.Sprintf("Error: %s\n", m.err.Error())
|
||||
}
|
||||
|
||||
if m.complete {
|
||||
return fmt.Sprintf("\n%s%s\n\n%sComplete!\n",
|
||||
strings.Repeat(" ", padding),
|
||||
m.progress.View(),
|
||||
strings.Repeat(" ", padding))
|
||||
}
|
||||
|
||||
pad := strings.Repeat(" ", padding)
|
||||
return fmt.Sprintf("\n%s%s\n%s%s\n\n%s",
|
||||
pad, m.progress.View(),
|
||||
pad, m.status,
|
||||
pad+helpStyle("Press 'q' or Ctrl+C to cancel"))
|
||||
}
|
||||
|
||||
// ProgressReader wraps an io.Reader to provide progress updates for Ollama pull operations
|
||||
type ProgressReader struct {
|
||||
reader io.Reader
|
||||
program *tea.Program
|
||||
model ProgressModel
|
||||
lastLine string
|
||||
done chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewProgressReader creates a new progress reader for Ollama pull operations
|
||||
func NewProgressReader(reader io.Reader) *ProgressReader {
|
||||
model := NewProgressModel()
|
||||
// Create program with standard settings
|
||||
program := tea.NewProgram(model)
|
||||
|
||||
pr := &ProgressReader{
|
||||
reader: reader,
|
||||
program: program,
|
||||
model: model,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Start the TUI in a goroutine
|
||||
pr.wg.Add(1)
|
||||
go func() {
|
||||
defer pr.wg.Done()
|
||||
if _, err := program.Run(); err != nil {
|
||||
// Handle error silently for now
|
||||
}
|
||||
close(pr.done)
|
||||
}()
|
||||
|
||||
return pr
|
||||
}
|
||||
|
||||
// Read implements io.Reader and parses Ollama streaming responses
|
||||
func (pr *ProgressReader) Read(p []byte) (n int, err error) {
|
||||
n, err = pr.reader.Read(p)
|
||||
if n > 0 {
|
||||
// Parse the JSON lines for progress information
|
||||
data := string(p[:n])
|
||||
pr.lastLine += data
|
||||
|
||||
// Process complete lines
|
||||
for {
|
||||
lineEnd := strings.Index(pr.lastLine, "\n")
|
||||
if lineEnd == -1 {
|
||||
break
|
||||
}
|
||||
|
||||
line := strings.TrimSpace(pr.lastLine[:lineEnd])
|
||||
pr.lastLine = pr.lastLine[lineEnd+1:]
|
||||
|
||||
if line != "" {
|
||||
pr.parseProgressLine(line)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err == io.EOF {
|
||||
// Send completion message and ensure program quits
|
||||
pr.program.Send(progressCompleteMsg{})
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
// parseProgressLine parses a single JSON line from Ollama pull response
|
||||
func (pr *ProgressReader) parseProgressLine(line string) {
|
||||
var progress OllamaPullProgress
|
||||
if err := json.Unmarshal([]byte(line), &progress); err != nil {
|
||||
return // Ignore malformed JSON
|
||||
}
|
||||
|
||||
var percent float64
|
||||
status := progress.Status
|
||||
|
||||
// Calculate progress percentage if we have total and completed
|
||||
if progress.Total > 0 && progress.Completed >= 0 {
|
||||
percent = float64(progress.Completed) / float64(progress.Total)
|
||||
|
||||
// Format status with progress info
|
||||
if progress.Digest != "" {
|
||||
status = fmt.Sprintf("%s (%s)", progress.Status, progress.Digest[:12])
|
||||
}
|
||||
|
||||
// Add size information
|
||||
if progress.Total > 0 {
|
||||
totalMB := float64(progress.Total) / (1024 * 1024)
|
||||
completedMB := float64(progress.Completed) / (1024 * 1024)
|
||||
status = fmt.Sprintf("%s - %.1f/%.1f MB", status, completedMB, totalMB)
|
||||
}
|
||||
} else {
|
||||
// For status-only updates (like "pulling manifest"), show indeterminate progress
|
||||
if strings.Contains(strings.ToLower(progress.Status), "pulling") ||
|
||||
strings.Contains(strings.ToLower(progress.Status), "downloading") {
|
||||
// Keep current progress or show small progress for activity
|
||||
percent = 0.1
|
||||
} else if strings.Contains(strings.ToLower(progress.Status), "success") ||
|
||||
strings.Contains(strings.ToLower(progress.Status), "complete") {
|
||||
percent = 1.0
|
||||
}
|
||||
}
|
||||
|
||||
pr.program.Send(progressMsg{
|
||||
percent: percent,
|
||||
status: status,
|
||||
})
|
||||
}
|
||||
|
||||
// Close stops the progress display and waits for cleanup
|
||||
func (pr *ProgressReader) Close() error {
|
||||
// Send completion message to trigger quit
|
||||
pr.program.Send(progressCompleteMsg{})
|
||||
|
||||
// Wait for the program to finish with timeout
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
pr.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Wait for completion or timeout after 2 seconds
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Program finished normally
|
||||
case <-ctx.Done():
|
||||
// Timeout - force kill the program
|
||||
pr.program.Kill()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -64,7 +64,7 @@ func (m spinnerModel) View() string {
|
||||
Foreground(theme.Text).
|
||||
Italic(true)
|
||||
|
||||
return fmt.Sprintf("%s %s",
|
||||
return fmt.Sprintf(" %s %s",
|
||||
spinnerStyle.Render(m.spinner.View()),
|
||||
messageStyle.Render(m.message))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user