mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-13 19:20:06 +00:00
fix(auth): support OAuth credentials in ACP mode and auto-refresh tokens
Remove the early ValidateEnvironment gate from CreateProvider that only checked env vars and --provider-api-key, blocking stored OAuth credentials from working. Each provider creation function already handles its own auth resolution with clear error messages. Update ValidateEnvironment to also check stored Anthropic credentials so the model selector UI correctly shows Anthropic models for OAuth users. Add automatic token refresh in oauthTransport so long-lived ACP sessions survive token renewals. Surface actionable auth error messages in ACP session creation. Fix pre-existing staticcheck SA5011 warnings in test files.
This commit is contained in:
@@ -90,6 +90,7 @@ func (a *Agent) NewSession(ctx context.Context, params acp.NewSessionRequest) (a
|
||||
|
||||
sess, err := a.registry.create(ctx, cwd)
|
||||
if err != nil {
|
||||
log.Error("acp: session creation failed", "cwd", cwd, "error", err)
|
||||
return acp.NewSessionResponse{}, fmt.Errorf("create session: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package acpserver
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
@@ -39,6 +40,12 @@ func (r *sessionRegistry) create(ctx context.Context, cwd string) (*acpSession,
|
||||
Streaming: true,
|
||||
})
|
||||
if err != nil {
|
||||
// Provide actionable guidance for provider auth errors, which are
|
||||
// the most common failure mode when running via ACP.
|
||||
msg := err.Error()
|
||||
if strings.Contains(msg, "API key") || strings.Contains(msg, "credentials") || strings.Contains(msg, "OAuth") {
|
||||
return nil, fmt.Errorf("provider authentication failed: %w — run 'kit auth login <provider>' or set the appropriate environment variable before starting 'kit acp'", err)
|
||||
}
|
||||
return nil, fmt.Errorf("create kit instance: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -51,6 +51,7 @@ func TestCredentialManager(t *testing.T) {
|
||||
}
|
||||
if creds == nil {
|
||||
t.Fatal("Expected credentials to be returned")
|
||||
return
|
||||
}
|
||||
if creds.APIKey != testAPIKey {
|
||||
t.Errorf("Expected API key %s, got %s", testAPIKey, creds.APIKey)
|
||||
@@ -236,6 +237,7 @@ func TestCredentialStorePersistence(t *testing.T) {
|
||||
}
|
||||
if creds == nil {
|
||||
t.Fatal("Expected credentials to persist")
|
||||
return
|
||||
}
|
||||
if creds.APIKey != testAPIKey {
|
||||
t.Errorf("Expected API key %s, got %s", testAPIKey, creds.APIKey)
|
||||
|
||||
@@ -210,10 +210,11 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResul
|
||||
}
|
||||
}
|
||||
|
||||
// Validate environment variables
|
||||
if err := registry.ValidateEnvironment(provider, config.ProviderAPIKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// NOTE: We intentionally skip registry.ValidateEnvironment() here.
|
||||
// Each create*Provider function handles its own auth resolution and
|
||||
// produces provider-specific error messages. The early env-var check
|
||||
// was too narrow — it didn't account for stored credentials (e.g.
|
||||
// OAuth tokens from 'kit auth login') and blocked valid auth paths.
|
||||
|
||||
// Validate config against known model limits when metadata is available
|
||||
if modelInfo != nil {
|
||||
@@ -1042,9 +1043,21 @@ type oauthTransport struct {
|
||||
}
|
||||
|
||||
func (t *oauthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Resolve the freshest available token. The credential manager
|
||||
// automatically refreshes tokens nearing expiry (5-minute buffer).
|
||||
// This keeps long-lived sessions (e.g. ACP) working across token
|
||||
// renewals. Falls back to the originally-provided token if the
|
||||
// credential manager is unavailable.
|
||||
token := t.accessToken
|
||||
if cm, err := auth.NewCredentialManager(); err == nil {
|
||||
if fresh, err := cm.GetValidAccessToken(); err == nil && fresh != "" {
|
||||
token = fresh
|
||||
}
|
||||
}
|
||||
|
||||
newReq := req.Clone(req.Context())
|
||||
newReq.Header.Del("x-api-key")
|
||||
newReq.Header.Set("Authorization", "Bearer "+t.accessToken)
|
||||
newReq.Header.Set("Authorization", "Bearer "+token)
|
||||
newReq.Header.Set("anthropic-beta", "oauth-2025-04-20")
|
||||
newReq.Header.Set("anthropic-version", "2023-06-01")
|
||||
|
||||
|
||||
@@ -78,6 +78,7 @@ func TestCreateOAuthHTTPClient(t *testing.T) {
|
||||
|
||||
if client == nil {
|
||||
t.Fatal("expected non-nil client")
|
||||
return
|
||||
}
|
||||
|
||||
// Check that the transport is an oauthTransport
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/mark3labs/kit/internal/auth"
|
||||
)
|
||||
|
||||
//go:embed embedded_models.json
|
||||
@@ -171,14 +173,27 @@ func (r *ModelsRegistry) GetRequiredEnvVars(provider string) ([]string, error) {
|
||||
return providerInfo.Env, nil
|
||||
}
|
||||
|
||||
// ValidateEnvironment checks if required environment variables are set.
|
||||
// Returns nil for providers not in the registry (unknown providers are
|
||||
// assumed to handle auth themselves or via --provider-api-key).
|
||||
// ValidateEnvironment checks if required credentials are available for a
|
||||
// provider. It checks the explicit API key, stored credentials (for
|
||||
// providers that support them, such as Anthropic OAuth), and environment
|
||||
// variables. Returns nil for providers not in the registry (unknown
|
||||
// providers are assumed to handle auth themselves or via --provider-api-key).
|
||||
func (r *ModelsRegistry) ValidateEnvironment(provider string, apiKey string) error {
|
||||
if apiKey != "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// For anthropic, also check stored credentials (OAuth / API key)
|
||||
// since auth resolution goes through the credential manager, not
|
||||
// just environment variables.
|
||||
if provider == "anthropic" {
|
||||
if cm, err := auth.NewCredentialManager(); err == nil {
|
||||
if has, _ := cm.HasAnthropicCredentials(); has {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
envVars, err := r.GetRequiredEnvVars(provider)
|
||||
if err != nil {
|
||||
// Unknown provider — nothing to validate
|
||||
|
||||
@@ -24,6 +24,7 @@ func TestUsageTracker_OAuthCosts(t *testing.T) {
|
||||
stats := regularTracker.GetLastRequestStats()
|
||||
if stats == nil {
|
||||
t.Fatal("Expected stats to be non-nil")
|
||||
return
|
||||
}
|
||||
|
||||
// Check that costs are calculated for regular API key
|
||||
@@ -48,6 +49,7 @@ func TestUsageTracker_OAuthCosts(t *testing.T) {
|
||||
oauthStats := oauthTracker.GetLastRequestStats()
|
||||
if oauthStats == nil {
|
||||
t.Fatal("Expected OAuth stats to be non-nil")
|
||||
return
|
||||
}
|
||||
|
||||
// Check that all costs are $0 for OAuth
|
||||
|
||||
@@ -24,6 +24,7 @@ func TestHookRegistry_RegisterAndRun(t *testing.T) {
|
||||
got := hr.run("hello")
|
||||
if got == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
return
|
||||
}
|
||||
if *got != "handled: hello" {
|
||||
t.Errorf("expected 'handled: hello', got %q", *got)
|
||||
@@ -51,6 +52,7 @@ func TestHookRegistry_FirstNonNilWins(t *testing.T) {
|
||||
got := hr.run("test")
|
||||
if got == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
return
|
||||
}
|
||||
if *got != "second: test" {
|
||||
t.Errorf("expected 'second: test', got %q", *got)
|
||||
@@ -77,6 +79,7 @@ func TestHookRegistry_PriorityOrdering(t *testing.T) {
|
||||
got := hr.run("x")
|
||||
if got == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
return
|
||||
}
|
||||
if *got != "high" {
|
||||
t.Errorf("expected 'high' (priority 0 runs first), got %q", *got)
|
||||
@@ -441,6 +444,7 @@ func TestBeforeTurnHook_PromptOverride(t *testing.T) {
|
||||
result := hr.run(BeforeTurnHook{Prompt: "original"})
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
return
|
||||
}
|
||||
if result.Prompt == nil || *result.Prompt != "modified prompt" {
|
||||
t.Errorf("expected prompt override, got %v", result.Prompt)
|
||||
@@ -462,6 +466,7 @@ func TestBeforeTurnHook_InjectSystemAndContext(t *testing.T) {
|
||||
result := hr.run(BeforeTurnHook{Prompt: "hello"})
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
return
|
||||
}
|
||||
if result.SystemPrompt == nil || *result.SystemPrompt != "be concise" {
|
||||
t.Errorf("expected system prompt injection")
|
||||
|
||||
Reference in New Issue
Block a user