diff --git a/internal/acpserver/agent.go b/internal/acpserver/agent.go index 1deaf55b..86dbf0ed 100644 --- a/internal/acpserver/agent.go +++ b/internal/acpserver/agent.go @@ -90,6 +90,7 @@ func (a *Agent) NewSession(ctx context.Context, params acp.NewSessionRequest) (a sess, err := a.registry.create(ctx, cwd) if err != nil { + log.Error("acp: session creation failed", "cwd", cwd, "error", err) return acp.NewSessionResponse{}, fmt.Errorf("create session: %w", err) } diff --git a/internal/acpserver/session.go b/internal/acpserver/session.go index 4d99c838..1c3e521e 100644 --- a/internal/acpserver/session.go +++ b/internal/acpserver/session.go @@ -3,6 +3,7 @@ package acpserver import ( "context" "fmt" + "strings" "sync" kit "github.com/mark3labs/kit/pkg/kit" @@ -39,6 +40,12 @@ func (r *sessionRegistry) create(ctx context.Context, cwd string) (*acpSession, Streaming: true, }) if err != nil { + // Provide actionable guidance for provider auth errors, which are + // the most common failure mode when running via ACP. + msg := err.Error() + if strings.Contains(msg, "API key") || strings.Contains(msg, "credentials") || strings.Contains(msg, "OAuth") { + return nil, fmt.Errorf("provider authentication failed: %w — run 'kit auth login ' or set the appropriate environment variable before starting 'kit acp'", err) + } return nil, fmt.Errorf("create kit instance: %w", err) } diff --git a/internal/auth/credentials_test.go b/internal/auth/credentials_test.go index dfc3aeca..70e35667 100644 --- a/internal/auth/credentials_test.go +++ b/internal/auth/credentials_test.go @@ -51,6 +51,7 @@ func TestCredentialManager(t *testing.T) { } if creds == nil { t.Fatal("Expected credentials to be returned") + return } if creds.APIKey != testAPIKey { t.Errorf("Expected API key %s, got %s", testAPIKey, creds.APIKey) @@ -236,6 +237,7 @@ func TestCredentialStorePersistence(t *testing.T) { } if creds == nil { t.Fatal("Expected credentials to persist") + return } if creds.APIKey != testAPIKey { t.Errorf("Expected API key %s, got %s", testAPIKey, creds.APIKey) diff --git a/internal/models/providers.go b/internal/models/providers.go index 87a0f0d0..7968e445 100644 --- a/internal/models/providers.go +++ b/internal/models/providers.go @@ -210,10 +210,11 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResul } } - // Validate environment variables - if err := registry.ValidateEnvironment(provider, config.ProviderAPIKey); err != nil { - return nil, err - } + // NOTE: We intentionally skip registry.ValidateEnvironment() here. + // Each create*Provider function handles its own auth resolution and + // produces provider-specific error messages. The early env-var check + // was too narrow — it didn't account for stored credentials (e.g. + // OAuth tokens from 'kit auth login') and blocked valid auth paths. // Validate config against known model limits when metadata is available if modelInfo != nil { @@ -1042,9 +1043,21 @@ type oauthTransport struct { } func (t *oauthTransport) RoundTrip(req *http.Request) (*http.Response, error) { + // Resolve the freshest available token. The credential manager + // automatically refreshes tokens nearing expiry (5-minute buffer). + // This keeps long-lived sessions (e.g. ACP) working across token + // renewals. Falls back to the originally-provided token if the + // credential manager is unavailable. + token := t.accessToken + if cm, err := auth.NewCredentialManager(); err == nil { + if fresh, err := cm.GetValidAccessToken(); err == nil && fresh != "" { + token = fresh + } + } + newReq := req.Clone(req.Context()) newReq.Header.Del("x-api-key") - newReq.Header.Set("Authorization", "Bearer "+t.accessToken) + newReq.Header.Set("Authorization", "Bearer "+token) newReq.Header.Set("anthropic-beta", "oauth-2025-04-20") newReq.Header.Set("anthropic-version", "2023-06-01") diff --git a/internal/models/providers_test.go b/internal/models/providers_test.go index 30d531bc..691dbb4e 100644 --- a/internal/models/providers_test.go +++ b/internal/models/providers_test.go @@ -78,6 +78,7 @@ func TestCreateOAuthHTTPClient(t *testing.T) { if client == nil { t.Fatal("expected non-nil client") + return } // Check that the transport is an oauthTransport diff --git a/internal/models/registry.go b/internal/models/registry.go index fdf41206..d9baa923 100644 --- a/internal/models/registry.go +++ b/internal/models/registry.go @@ -6,6 +6,8 @@ import ( "fmt" "os" "strings" + + "github.com/mark3labs/kit/internal/auth" ) //go:embed embedded_models.json @@ -171,14 +173,27 @@ func (r *ModelsRegistry) GetRequiredEnvVars(provider string) ([]string, error) { return providerInfo.Env, nil } -// ValidateEnvironment checks if required environment variables are set. -// Returns nil for providers not in the registry (unknown providers are -// assumed to handle auth themselves or via --provider-api-key). +// ValidateEnvironment checks if required credentials are available for a +// provider. It checks the explicit API key, stored credentials (for +// providers that support them, such as Anthropic OAuth), and environment +// variables. Returns nil for providers not in the registry (unknown +// providers are assumed to handle auth themselves or via --provider-api-key). func (r *ModelsRegistry) ValidateEnvironment(provider string, apiKey string) error { if apiKey != "" { return nil } + // For anthropic, also check stored credentials (OAuth / API key) + // since auth resolution goes through the credential manager, not + // just environment variables. + if provider == "anthropic" { + if cm, err := auth.NewCredentialManager(); err == nil { + if has, _ := cm.HasAnthropicCredentials(); has { + return nil + } + } + } + envVars, err := r.GetRequiredEnvVars(provider) if err != nil { // Unknown provider — nothing to validate diff --git a/internal/ui/usage_tracker_test.go b/internal/ui/usage_tracker_test.go index 9033368a..97119a15 100644 --- a/internal/ui/usage_tracker_test.go +++ b/internal/ui/usage_tracker_test.go @@ -24,6 +24,7 @@ func TestUsageTracker_OAuthCosts(t *testing.T) { stats := regularTracker.GetLastRequestStats() if stats == nil { t.Fatal("Expected stats to be non-nil") + return } // Check that costs are calculated for regular API key @@ -48,6 +49,7 @@ func TestUsageTracker_OAuthCosts(t *testing.T) { oauthStats := oauthTracker.GetLastRequestStats() if oauthStats == nil { t.Fatal("Expected OAuth stats to be non-nil") + return } // Check that all costs are $0 for OAuth diff --git a/pkg/kit/hooks_test.go b/pkg/kit/hooks_test.go index db36251f..126e7c90 100644 --- a/pkg/kit/hooks_test.go +++ b/pkg/kit/hooks_test.go @@ -24,6 +24,7 @@ func TestHookRegistry_RegisterAndRun(t *testing.T) { got := hr.run("hello") if got == nil { t.Fatal("expected non-nil result") + return } if *got != "handled: hello" { t.Errorf("expected 'handled: hello', got %q", *got) @@ -51,6 +52,7 @@ func TestHookRegistry_FirstNonNilWins(t *testing.T) { got := hr.run("test") if got == nil { t.Fatal("expected non-nil result") + return } if *got != "second: test" { t.Errorf("expected 'second: test', got %q", *got) @@ -77,6 +79,7 @@ func TestHookRegistry_PriorityOrdering(t *testing.T) { got := hr.run("x") if got == nil { t.Fatal("expected non-nil result") + return } if *got != "high" { t.Errorf("expected 'high' (priority 0 runs first), got %q", *got) @@ -441,6 +444,7 @@ func TestBeforeTurnHook_PromptOverride(t *testing.T) { result := hr.run(BeforeTurnHook{Prompt: "original"}) if result == nil { t.Fatal("expected non-nil result") + return } if result.Prompt == nil || *result.Prompt != "modified prompt" { t.Errorf("expected prompt override, got %v", result.Prompt) @@ -462,6 +466,7 @@ func TestBeforeTurnHook_InjectSystemAndContext(t *testing.T) { result := hr.run(BeforeTurnHook{Prompt: "hello"}) if result == nil { t.Fatal("expected non-nil result") + return } if result.SystemPrompt == nil || *result.SystemPrompt != "be concise" { t.Errorf("expected system prompt injection")