fix(auth): support OAuth credentials in ACP mode and auto-refresh tokens

Remove the early ValidateEnvironment gate from CreateProvider that only
checked env vars and --provider-api-key, blocking stored OAuth credentials
from working. Each provider creation function already handles its own auth
resolution with clear error messages.

Update ValidateEnvironment to also check stored Anthropic credentials so
the model selector UI correctly shows Anthropic models for OAuth users.

Add automatic token refresh in oauthTransport so long-lived ACP sessions
survive token renewals. Surface actionable auth error messages in ACP
session creation.

Fix pre-existing staticcheck SA5011 warnings in test files.
This commit is contained in:
Ed Zynda
2026-03-15 12:38:23 +03:00
parent fefbf19b42
commit a05da5f3ab
8 changed files with 54 additions and 8 deletions
+1
View File
@@ -90,6 +90,7 @@ func (a *Agent) NewSession(ctx context.Context, params acp.NewSessionRequest) (a
sess, err := a.registry.create(ctx, cwd)
if err != nil {
log.Error("acp: session creation failed", "cwd", cwd, "error", err)
return acp.NewSessionResponse{}, fmt.Errorf("create session: %w", err)
}
+7
View File
@@ -3,6 +3,7 @@ package acpserver
import (
"context"
"fmt"
"strings"
"sync"
kit "github.com/mark3labs/kit/pkg/kit"
@@ -39,6 +40,12 @@ func (r *sessionRegistry) create(ctx context.Context, cwd string) (*acpSession,
Streaming: true,
})
if err != nil {
// Provide actionable guidance for provider auth errors, which are
// the most common failure mode when running via ACP.
msg := err.Error()
if strings.Contains(msg, "API key") || strings.Contains(msg, "credentials") || strings.Contains(msg, "OAuth") {
return nil, fmt.Errorf("provider authentication failed: %w — run 'kit auth login <provider>' or set the appropriate environment variable before starting 'kit acp'", err)
}
return nil, fmt.Errorf("create kit instance: %w", err)
}
+2
View File
@@ -51,6 +51,7 @@ func TestCredentialManager(t *testing.T) {
}
if creds == nil {
t.Fatal("Expected credentials to be returned")
return
}
if creds.APIKey != testAPIKey {
t.Errorf("Expected API key %s, got %s", testAPIKey, creds.APIKey)
@@ -236,6 +237,7 @@ func TestCredentialStorePersistence(t *testing.T) {
}
if creds == nil {
t.Fatal("Expected credentials to persist")
return
}
if creds.APIKey != testAPIKey {
t.Errorf("Expected API key %s, got %s", testAPIKey, creds.APIKey)
+18 -5
View File
@@ -210,10 +210,11 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResul
}
}
// Validate environment variables
if err := registry.ValidateEnvironment(provider, config.ProviderAPIKey); err != nil {
return nil, err
}
// NOTE: We intentionally skip registry.ValidateEnvironment() here.
// Each create*Provider function handles its own auth resolution and
// produces provider-specific error messages. The early env-var check
// was too narrow — it didn't account for stored credentials (e.g.
// OAuth tokens from 'kit auth login') and blocked valid auth paths.
// Validate config against known model limits when metadata is available
if modelInfo != nil {
@@ -1042,9 +1043,21 @@ type oauthTransport struct {
}
func (t *oauthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// Resolve the freshest available token. The credential manager
// automatically refreshes tokens nearing expiry (5-minute buffer).
// This keeps long-lived sessions (e.g. ACP) working across token
// renewals. Falls back to the originally-provided token if the
// credential manager is unavailable.
token := t.accessToken
if cm, err := auth.NewCredentialManager(); err == nil {
if fresh, err := cm.GetValidAccessToken(); err == nil && fresh != "" {
token = fresh
}
}
newReq := req.Clone(req.Context())
newReq.Header.Del("x-api-key")
newReq.Header.Set("Authorization", "Bearer "+t.accessToken)
newReq.Header.Set("Authorization", "Bearer "+token)
newReq.Header.Set("anthropic-beta", "oauth-2025-04-20")
newReq.Header.Set("anthropic-version", "2023-06-01")
+1
View File
@@ -78,6 +78,7 @@ func TestCreateOAuthHTTPClient(t *testing.T) {
if client == nil {
t.Fatal("expected non-nil client")
return
}
// Check that the transport is an oauthTransport
+18 -3
View File
@@ -6,6 +6,8 @@ import (
"fmt"
"os"
"strings"
"github.com/mark3labs/kit/internal/auth"
)
//go:embed embedded_models.json
@@ -171,14 +173,27 @@ func (r *ModelsRegistry) GetRequiredEnvVars(provider string) ([]string, error) {
return providerInfo.Env, nil
}
// ValidateEnvironment checks if required environment variables are set.
// Returns nil for providers not in the registry (unknown providers are
// assumed to handle auth themselves or via --provider-api-key).
// ValidateEnvironment checks if required credentials are available for a
// provider. It checks the explicit API key, stored credentials (for
// providers that support them, such as Anthropic OAuth), and environment
// variables. Returns nil for providers not in the registry (unknown
// providers are assumed to handle auth themselves or via --provider-api-key).
func (r *ModelsRegistry) ValidateEnvironment(provider string, apiKey string) error {
if apiKey != "" {
return nil
}
// For anthropic, also check stored credentials (OAuth / API key)
// since auth resolution goes through the credential manager, not
// just environment variables.
if provider == "anthropic" {
if cm, err := auth.NewCredentialManager(); err == nil {
if has, _ := cm.HasAnthropicCredentials(); has {
return nil
}
}
}
envVars, err := r.GetRequiredEnvVars(provider)
if err != nil {
// Unknown provider — nothing to validate
+2
View File
@@ -24,6 +24,7 @@ func TestUsageTracker_OAuthCosts(t *testing.T) {
stats := regularTracker.GetLastRequestStats()
if stats == nil {
t.Fatal("Expected stats to be non-nil")
return
}
// Check that costs are calculated for regular API key
@@ -48,6 +49,7 @@ func TestUsageTracker_OAuthCosts(t *testing.T) {
oauthStats := oauthTracker.GetLastRequestStats()
if oauthStats == nil {
t.Fatal("Expected OAuth stats to be non-nil")
return
}
// Check that all costs are $0 for OAuth
+5
View File
@@ -24,6 +24,7 @@ func TestHookRegistry_RegisterAndRun(t *testing.T) {
got := hr.run("hello")
if got == nil {
t.Fatal("expected non-nil result")
return
}
if *got != "handled: hello" {
t.Errorf("expected 'handled: hello', got %q", *got)
@@ -51,6 +52,7 @@ func TestHookRegistry_FirstNonNilWins(t *testing.T) {
got := hr.run("test")
if got == nil {
t.Fatal("expected non-nil result")
return
}
if *got != "second: test" {
t.Errorf("expected 'second: test', got %q", *got)
@@ -77,6 +79,7 @@ func TestHookRegistry_PriorityOrdering(t *testing.T) {
got := hr.run("x")
if got == nil {
t.Fatal("expected non-nil result")
return
}
if *got != "high" {
t.Errorf("expected 'high' (priority 0 runs first), got %q", *got)
@@ -441,6 +444,7 @@ func TestBeforeTurnHook_PromptOverride(t *testing.T) {
result := hr.run(BeforeTurnHook{Prompt: "original"})
if result == nil {
t.Fatal("expected non-nil result")
return
}
if result.Prompt == nil || *result.Prompt != "modified prompt" {
t.Errorf("expected prompt override, got %v", result.Prompt)
@@ -462,6 +466,7 @@ func TestBeforeTurnHook_InjectSystemAndContext(t *testing.T) {
result := hr.run(BeforeTurnHook{Prompt: "hello"})
if result == nil {
t.Fatal("expected non-nil result")
return
}
if result.SystemPrompt == nil || *result.SystemPrompt != "be concise" {
t.Errorf("expected system prompt injection")