From c9637090faf083f7313f75340aa6c75aa7ada685 Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Thu, 2 Apr 2026 14:45:03 +0300 Subject: [PATCH] feat(subagent): return early error for invalid model instead of silent fallback - Add ValidateModelString() to ModelsRegistry for format, provider, and model name validation with typo suggestions - Validate model in Kit.Subagent() before expensive Kit.New() setup - Remove silent fallback to parent model on creation failure - Error propagates as tool result so calling agent can self-correct - Add registry_test.go covering format, provider, and suggestion cases --- internal/models/registry.go | 46 ++++++++++++++++ internal/models/registry_test.go | 92 ++++++++++++++++++++++++++++++++ pkg/kit/kit.go | 29 +++++----- 3 files changed, 151 insertions(+), 16 deletions(-) create mode 100644 internal/models/registry_test.go diff --git a/internal/models/registry.go b/internal/models/registry.go index a5c3bd47..06a153f4 100644 --- a/internal/models/registry.go +++ b/internal/models/registry.go @@ -400,6 +400,52 @@ func (r *ModelsRegistry) GetProviderInfo(provider string) *ProviderInfo { return &info } +// ValidateModelString checks whether a model string is well-formed and refers +// to a known provider. It returns a user-friendly error with suggestions when +// the model or provider is unrecognised. Passing validation does not guarantee +// that API authentication will succeed — it only catches obvious mistakes +// (typos, missing provider prefix, non-existent provider names) early so that +// callers such as subagent spawning can return fast feedback. +// +// Unknown models under a known provider are allowed (the provider API is the +// authority), but a completely unknown provider is rejected. +func (r *ModelsRegistry) ValidateModelString(modelString string) error { + provider, modelName, err := ParseModelString(modelString) + if err != nil { + return err + } + + // Ollama and custom are always valid — model names are user-defined. + if provider == "ollama" || provider == "custom" { + return nil + } + + // Check if the provider exists in the registry. + providerInfo := r.GetProviderInfo(provider) + if providerInfo == nil { + known := r.GetSupportedProviders() + return fmt.Errorf( + "unknown provider %q in model string %q. Known providers: %s", + provider, modelString, strings.Join(known, ", "), + ) + } + + // Provider exists — check if the model is known. An unknown model is + // only a warning (the provider API decides), but we surface suggestions + // so the caller can self-correct. + if r.LookupModel(provider, modelName) == nil { + if suggestions := r.SuggestModels(provider, modelName); len(suggestions) > 0 { + return fmt.Errorf( + "model %q not found for provider %s. Did you mean one of: %s", + modelName, provider, strings.Join(suggestions, ", "), + ) + } + // No suggestions — let it through; the provider API is the authority. + } + + return nil +} + // Global registry instance var globalRegistry = NewModelsRegistry() diff --git a/internal/models/registry_test.go b/internal/models/registry_test.go new file mode 100644 index 00000000..629323a5 --- /dev/null +++ b/internal/models/registry_test.go @@ -0,0 +1,92 @@ +package models + +import ( + "strings" + "testing" +) + +func TestValidateModelString(t *testing.T) { + registry := GetGlobalRegistry() + + tests := []struct { + name string + model string + wantErr bool + errSubstr string // expected substring in error message (empty = don't check) + }{ + { + name: "valid anthropic model", + model: "anthropic/claude-sonnet-4-6", + wantErr: false, + }, + { + name: "missing provider prefix", + model: "claude-sonnet-4-6", + wantErr: true, + errSubstr: "invalid model format", + }, + { + name: "empty string", + model: "", + wantErr: true, + errSubstr: "invalid model format", + }, + { + name: "unknown provider", + model: "fakeprovider/some-model", + wantErr: true, + errSubstr: "unknown provider", + }, + { + name: "ollama always valid", + model: "ollama/llama3", + wantErr: false, + }, + { + name: "custom always valid", + model: "custom/my-fine-tune", + wantErr: false, + }, + { + name: "empty provider", + model: "/claude-sonnet-4-6", + wantErr: true, + errSubstr: "invalid model format", + }, + { + name: "empty model name", + model: "anthropic/", + wantErr: true, + errSubstr: "invalid model format", + }, + { + name: "unknown model under known provider (no suggestions)", + model: "anthropic/totally-unknown-xyz-999", + wantErr: false, // no suggestions → passes through + }, + { + name: "typo model under known provider with suggestions", + model: "anthropic/claude-sonet", // misspelled "sonnet" + wantErr: true, + errSubstr: "Did you mean", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := registry.ValidateModelString(tt.model) + if tt.wantErr && err == nil { + t.Errorf("ValidateModelString(%q) = nil, want error", tt.model) + } + if !tt.wantErr && err != nil { + t.Errorf("ValidateModelString(%q) = %v, want nil", tt.model, err) + } + if tt.errSubstr != "" && err != nil { + if !strings.Contains(err.Error(), tt.errSubstr) { + t.Errorf("ValidateModelString(%q) error = %q, want substring %q", + tt.model, err.Error(), tt.errSubstr) + } + } + }) + } +} diff --git a/pkg/kit/kit.go b/pkg/kit/kit.go index 2bce6cd3..7d5d4ef3 100644 --- a/pkg/kit/kit.go +++ b/pkg/kit/kit.go @@ -920,6 +920,17 @@ func (m *Kit) Subagent(ctx context.Context, cfg SubagentConfig) (*SubagentResult } } + // Early validation: check model format and provider before doing any + // expensive work (MCP init, system prompt composition, etc.). This + // gives the calling agent immediate feedback it can act on — e.g. + // correcting a typo — instead of waiting for a full Kit.New() cycle + // that silently falls back to the parent model. + if model != m.modelString { + if err := models.GetGlobalRegistry().ValidateModelString(model); err != nil { + return nil, fmt.Errorf("invalid subagent model %q: %w", model, err) + } + } + // Default system prompt. systemPrompt := cfg.SystemPrompt if systemPrompt == "" { @@ -932,9 +943,7 @@ func (m *Kit) Subagent(ctx context.Context, cfg SubagentConfig) (*SubagentResult tools = SubagentTools() } - // Create child Kit instance. If the requested model fails (bad name, - // unsupported provider, etc.), fall back to the parent's model so the - // agent gets a useful error message instead of a hard failure. + // Create child Kit instance. childOpts := &Options{ Model: model, SystemPrompt: systemPrompt, @@ -943,19 +952,7 @@ func (m *Kit) Subagent(ctx context.Context, cfg SubagentConfig) (*SubagentResult Quiet: true, } child, err := New(ctx, childOpts) - if err != nil && model != m.modelString { - // Model-specific failure — retry with parent's model. - childOpts.Model = m.modelString - child, err = New(ctx, childOpts) - if err != nil { - return nil, fmt.Errorf("failed to create subagent: %w", err) - } - // Prepend a note so the agent knows which model is actually running. - cfg.Prompt = fmt.Sprintf( - "[Note: requested model %q was not available, using %s instead.]\n\n%s", - model, m.modelString, cfg.Prompt, - ) - } else if err != nil { + if err != nil { return nil, fmt.Errorf("failed to create subagent: %w", err) } defer func() { _ = child.Close() }()