Compare commits

...

1 Commits

Author SHA1 Message Date
Innei 912d71789b init 2026-01-17 22:14:54 +08:00
9 changed files with 679 additions and 43 deletions
+170 -28
View File
@@ -4,6 +4,7 @@ import crypto from 'node:crypto';
import querystring from 'node:querystring';
import { URL } from 'node:url';
import type { OIDCCallbackResult } from '@/core/infrastructure/OIDCCallbackServerManager';
import { createLogger } from '@/utils/logger';
import RemoteServerConfigCtr from './RemoteServerConfigCtr';
@@ -12,9 +13,23 @@ import { ControllerModule, IpcMethod } from './index';
// Create logger
const logger = createLogger('controllers:AuthCtr');
interface AuthorizationSuccess {
code: string;
state: string;
}
interface AuthorizationError {
error: string;
errorDescription?: string;
state: string;
}
type AuthorizationResult = AuthorizationSuccess | AuthorizationError;
type AuthorizationSource = 'local' | 'polling';
/**
* Authentication Controller
* Implements OAuth authorization flow using intermediate page + polling mechanism
* Implements OAuth authorization flow using local callback with polling fallback
*/
export default class AuthCtr extends ControllerModule {
static override readonly groupName = 'auth';
@@ -25,11 +40,16 @@ export default class AuthCtr extends ControllerModule {
return this.app.getController(RemoteServerConfigCtr);
}
private get oidcCallbackServerManager() {
return this.app.oidcCallbackServerManager;
}
/**
* Current PKCE parameters
*/
private codeVerifier: string | null = null;
private authRequestState: string | null = null;
private authorizationHandled = false;
/**
* Polling related parameters
@@ -60,6 +80,7 @@ export default class AuthCtr extends ControllerModule {
@IpcMethod()
async requestAuthorization(config: DataSyncConfig) {
// Clear any old authorization state
this.authorizationHandled = false;
this.clearAuthorizationState();
const remoteUrl = await this.remoteServerConfigCtr.getRemoteServerUrl(config);
@@ -81,6 +102,31 @@ export default class AuthCtr extends ControllerModule {
this.authRequestState = crypto.randomBytes(16).toString('hex');
logger.debug(`Generated state parameter: ${this.authRequestState}`);
const callbackManager = this.oidcCallbackServerManager;
let localCallback: { port: number; waitForCallback: Promise<OIDCCallbackResult> } | null =
null;
if (callbackManager) {
try {
localCallback = await callbackManager.startCallbackServer(this.authRequestState);
const registered = await this.registerNotifyPort(
remoteUrl,
this.authRequestState,
localCallback.port,
);
if (!registered) {
localCallback.waitForCallback.catch(() => undefined);
await callbackManager.stopCallbackServer();
localCallback = null;
}
} catch (error) {
logger.warn('Failed to start local callback server:', error);
await callbackManager.stopCallbackServer();
localCallback = null;
}
}
// Construct authorization URL with new redirect_uri
const authUrl = new URL('/oidc/auth', remoteUrl);
const redirectUri = this.constructRedirectUri(remoteUrl);
@@ -110,9 +156,18 @@ export default class AuthCtr extends ControllerModule {
// Start polling for credentials
this.startPolling();
if (localCallback) {
localCallback.waitForCallback
.then((result) => this.handleAuthorizationResult(result, 'local'))
.catch((error) => {
logger.warn('Local callback channel stopped:', error);
});
}
return { success: true };
} catch (error) {
logger.error('Authorization request failed:', error);
this.clearAuthorizationState();
return { error: error.message, success: false };
}
}
@@ -142,6 +197,95 @@ export default class AuthCtr extends ControllerModule {
}
}
private async registerNotifyPort(
remoteUrl: string,
state: string,
notifyPort: number,
): Promise<boolean> {
try {
const url = new URL('/oidc/handoff', remoteUrl);
const response = await fetch(url.toString(), {
body: JSON.stringify({ client: 'desktop', id: state, notifyPort }),
headers: {
'Content-Type': 'application/json',
},
method: 'POST',
});
if (!response.ok) {
logger.warn(`Failed to register notifyPort: ${response.status} ${response.statusText}`);
return false;
}
return true;
} catch (error) {
logger.warn('Failed to register notifyPort:', error);
return false;
}
}
private isAuthorizationError(result: AuthorizationResult): result is AuthorizationError {
return 'error' in result && typeof result.error === 'string';
}
private getAuthorizationErrorMessage(result: AuthorizationError): string {
if (result.errorDescription) {
return `${result.error}: ${result.errorDescription}`;
}
return result.error;
}
private async handleAuthorizationResult(
result: AuthorizationResult,
source: AuthorizationSource,
): Promise<void> {
if (this.authorizationHandled) return;
this.authorizationHandled = true;
this.stopPolling();
await this.oidcCallbackServerManager?.stopCallbackServer();
logger.info(`Received authorization result via ${source} channel`);
if (!this.authRequestState || result.state !== this.authRequestState) {
logger.error(
`Invalid state parameter: expected ${this.authRequestState}, received ${result.state}`,
);
this.broadcastAuthorizationFailed('Invalid state parameter');
this.clearAuthorizationState();
return;
}
if (this.isAuthorizationError(result)) {
const errorMessage = this.getAuthorizationErrorMessage(result);
this.broadcastAuthorizationFailed(errorMessage);
this.clearAuthorizationState();
return;
}
try {
if (!this.codeVerifier) {
throw new Error('Missing code verifier');
}
const exchangeResult = await this.exchangeCodeForToken(result.code, this.codeVerifier);
if (exchangeResult.success) {
logger.info('Authorization successful');
this.broadcastAuthorizationSuccessful();
} else {
logger.warn(`Authorization failed: ${exchangeResult.error || 'Unknown error'}`);
this.broadcastAuthorizationFailed(exchangeResult.error || 'Unknown error');
}
} catch (error) {
const message = error instanceof Error ? error.message : String(error);
this.broadcastAuthorizationFailed(message);
} finally {
this.clearAuthorizationState();
}
}
/**
* 启动轮询机制获取凭证
*/
@@ -158,9 +302,13 @@ export default class AuthCtr extends ControllerModule {
this.pollingInterval = setInterval(async () => {
try {
if (this.authorizationHandled) return;
// Check if polling has timed out
if (Date.now() - startTime > maxPollTime) {
logger.warn('Credential polling timed out');
this.authorizationHandled = true;
void this.oidcCallbackServerManager?.stopCallbackServer();
this.clearAuthorizationState();
this.broadcastAuthorizationFailed('Authorization timed out');
return;
@@ -171,30 +319,12 @@ export default class AuthCtr extends ControllerModule {
if (result) {
logger.info('Successfully received credentials from polling');
this.stopPolling();
// Validate state parameter
if (result.state !== this.authRequestState) {
logger.error(
`Invalid state parameter: expected ${this.authRequestState}, received ${result.state}`,
);
this.broadcastAuthorizationFailed('Invalid state parameter');
return;
}
// Exchange code for tokens
const exchangeResult = await this.exchangeCodeForToken(result.code, this.codeVerifier!);
if (exchangeResult.success) {
logger.info('Authorization successful');
this.broadcastAuthorizationSuccessful();
} else {
logger.warn(`Authorization failed: ${exchangeResult.error || 'Unknown error'}`);
this.broadcastAuthorizationFailed(exchangeResult.error || 'Unknown error');
}
await this.handleAuthorizationResult(result, 'polling');
}
} catch (error) {
logger.error('Error during credential polling:', error);
this.authorizationHandled = true;
void this.oidcCallbackServerManager?.stopCallbackServer();
this.clearAuthorizationState();
this.broadcastAuthorizationFailed('Polling error: ' + error.message);
}
@@ -218,6 +348,7 @@ export default class AuthCtr extends ControllerModule {
private clearAuthorizationState() {
logger.debug('Clearing authorization state');
this.stopPolling();
void this.oidcCallbackServerManager?.stopCallbackServer();
this.codeVerifier = null;
this.authRequestState = null;
this.cachedRemoteUrl = null;
@@ -287,7 +418,7 @@ export default class AuthCtr extends ControllerModule {
* Poll for credentials
* Sends HTTP request directly to remote server
*/
private async pollForCredentials(): Promise<{ code: string; state: string } | null> {
private async pollForCredentials(): Promise<AuthorizationResult | null> {
if (!this.authRequestState || !this.cachedRemoteUrl) {
return null;
}
@@ -325,17 +456,27 @@ export default class AuthCtr extends ControllerModule {
const data = (await response.json()) as {
data: {
id: string;
payload: { code: string; state: string };
payload: Record<string, unknown>;
};
success: boolean;
};
if (data.success && data.data?.payload) {
logger.debug('Successfully retrieved credentials from handoff');
return {
code: data.data.payload.code,
state: data.data.payload.state,
};
const payload = data.data.payload as Record<string, unknown>;
const code = typeof payload.code === 'string' ? payload.code : undefined;
const state = typeof payload.state === 'string' ? payload.state : undefined;
const error = typeof payload.error === 'string' ? payload.error : undefined;
const errorDescription =
typeof payload.error_description === 'string' ? payload.error_description : undefined;
if (code && state) {
return { code, state };
}
if (error && state) {
return { error, errorDescription, state };
}
}
return null;
@@ -590,6 +731,7 @@ export default class AuthCtr extends ControllerModule {
logger.debug('Cleaning up AuthCtr timers');
this.stopPolling();
this.stopAutoRefresh();
void this.oidcCallbackServerManager?.stopCallbackServer();
}
/**
+4
View File
@@ -16,6 +16,7 @@ import { createLogger } from '@/utils/logger';
import { BrowserManager } from './browser/BrowserManager';
import { I18nManager } from './infrastructure/I18nManager';
import { IoCContainer } from './infrastructure/IoCContainer';
import { OIDCCallbackServerManager } from './infrastructure/OIDCCallbackServerManager';
import { ProtocolManager } from './infrastructure/ProtocolManager';
import { RendererUrlManager } from './infrastructure/RendererUrlManager';
import { StaticFileServerManager } from './infrastructure/StaticFileServerManager';
@@ -44,6 +45,7 @@ export class App {
shortcutManager: ShortcutManager;
trayManager: TrayManager;
staticFileServerManager: StaticFileServerManager;
oidcCallbackServerManager: OIDCCallbackServerManager;
protocolManager: ProtocolManager;
rendererUrlManager: RendererUrlManager;
chromeFlags: string[] = ['OverlayScrollbar', 'FluentOverlayScrollbar', 'FluentScrollbar'];
@@ -118,6 +120,7 @@ export class App {
this.shortcutManager = new ShortcutManager(this);
this.trayManager = new TrayManager(this);
this.staticFileServerManager = new StaticFileServerManager(this);
this.oidcCallbackServerManager = new OIDCCallbackServerManager(this);
this.protocolManager = new ProtocolManager(this);
// Configure renderer loading strategy (dev server vs static export)
@@ -397,5 +400,6 @@ export class App {
// 执行清理操作
this.staticFileServerManager.destroy();
this.oidcCallbackServerManager.destroy();
};
}
@@ -129,6 +129,14 @@ vi.mock('../infrastructure/StaticFileServerManager', () => ({
})),
}));
vi.mock('../infrastructure/OIDCCallbackServerManager', () => ({
OIDCCallbackServerManager: vi.fn().mockImplementation(() => ({
destroy: vi.fn(),
startCallbackServer: vi.fn(),
stopCallbackServer: vi.fn(),
})),
}));
vi.mock('../infrastructure/UpdaterManager', () => ({
UpdaterManager: vi.fn().mockImplementation(() => ({
initialize: vi.fn().mockResolvedValue(undefined),
@@ -0,0 +1,206 @@
import { getPort } from 'get-port-please';
import { IncomingMessage, ServerResponse, createServer } from 'node:http';
import { createLogger } from '@/utils/logger';
import type { App } from '../App';
const logger = createLogger('core:OIDCCallbackServerManager');
const CALLBACK_PORT_MIN = 34_210;
const CALLBACK_PORT_MAX = 34_219;
const DEFAULT_TIMEOUT_MS = 5 * 60 * 1000;
interface PendingCallback {
expectedState: string;
reject: (error: Error) => void;
resolve: (result: OIDCCallbackResult) => void;
timeoutId: NodeJS.Timeout;
}
export interface OIDCCallbackResult {
code?: string;
error?: string;
errorDescription?: string;
state: string;
}
export class OIDCCallbackServerManager {
private httpServer: ReturnType<typeof createServer> | null = null;
private serverPort = 0;
private pendingCallback: PendingCallback | null = null;
constructor() {
logger.debug('OIDCCallbackServerManager initialized');
}
async startCallbackServer(
expectedState: string,
timeoutMs: number = DEFAULT_TIMEOUT_MS,
): Promise<{ port: number; waitForCallback: Promise<OIDCCallbackResult> }> {
await this.stopCallbackServer();
this.serverPort = await getPort({
host: '127.0.0.1',
port: CALLBACK_PORT_MIN,
ports: Array.from({ length: CALLBACK_PORT_MAX - CALLBACK_PORT_MIN }, (_, index) => {
return CALLBACK_PORT_MIN + index + 1;
}),
});
const waitForCallback = new Promise<OIDCCallbackResult>((resolve, reject) => {
const timeoutId = setTimeout(() => {
if (this.pendingCallback) {
this.pendingCallback = null;
}
reject(new Error('Local callback timed out'));
void this.stopCallbackServer();
}, timeoutMs);
this.pendingCallback = {
expectedState,
reject,
resolve,
timeoutId,
};
});
await new Promise<void>((resolve, reject) => {
const server = createServer(async (req, res) => {
try {
await this.handleHttpRequest(req, res);
} catch (error) {
logger.error('Unhandled error in OIDC callback server:', error);
if (!res.headersSent) {
res.writeHead(500, { 'Content-Type': 'text/plain' });
res.end('Internal Server Error');
}
}
});
server.on('error', (error) => {
logger.error('OIDC callback server error:', error);
reject(error);
});
server.listen(this.serverPort, '127.0.0.1', () => {
this.httpServer = server;
logger.info(`OIDC callback server started on port ${this.serverPort}`);
resolve();
});
});
return { port: this.serverPort, waitForCallback };
}
async stopCallbackServer(): Promise<void> {
if (this.pendingCallback) {
clearTimeout(this.pendingCallback.timeoutId);
this.pendingCallback.reject(new Error('Local callback server stopped'));
this.pendingCallback = null;
}
if (!this.httpServer) {
this.serverPort = 0;
return;
}
await new Promise<void>((resolve) => {
this.httpServer?.close(() => resolve());
});
this.httpServer = null;
this.serverPort = 0;
logger.info('OIDC callback server stopped');
}
destroy() {
void this.stopCallbackServer();
}
getPort(): number {
return this.serverPort;
}
private resolvePendingCallback(result: OIDCCallbackResult) {
if (!this.pendingCallback) return;
clearTimeout(this.pendingCallback.timeoutId);
this.pendingCallback.resolve(result);
this.pendingCallback = null;
}
private async handleHttpRequest(req: IncomingMessage, res: ServerResponse): Promise<void> {
const url = new URL(req.url || '/', `http://127.0.0.1:${this.serverPort}`);
if (req.method !== 'GET' || url.pathname !== '/notify') {
res.writeHead(404, { 'Content-Type': 'text/plain' });
res.end('Not Found');
return;
}
const state = url.searchParams.get('state');
if (!state || state !== this.pendingCallback?.expectedState) {
res.writeHead(400, { 'Content-Type': 'text/html; charset=utf-8' });
res.end(this.getErrorPageHtml('Invalid state or missing state'));
return;
}
const error = url.searchParams.get('error');
const errorDescription = url.searchParams.get('error_description');
const code = url.searchParams.get('code');
if (error) {
this.resolvePendingCallback({
error,
errorDescription: errorDescription || undefined,
state,
});
res.writeHead(200, { 'Content-Type': 'text/html; charset=utf-8' });
res.end(this.getErrorPageHtml(errorDescription || error));
void this.stopCallbackServer();
return;
}
if (code) {
this.resolvePendingCallback({ code, state });
res.writeHead(200, { 'Content-Type': 'text/html; charset=utf-8' });
res.end(this.getSuccessPageHtml());
void this.stopCallbackServer();
return;
}
res.writeHead(400, { 'Content-Type': 'text/html; charset=utf-8' });
res.end(this.getErrorPageHtml('Missing code'));
}
private getSuccessPageHtml(): string {
return `<!DOCTYPE html>
<html>
<head><title>Authorization complete</title></head>
<body style="font-family: system-ui; display: flex; justify-content: center; align-items: center; height: 100vh; margin: 0;">
<div style="text-align: center;">
<h1 style="color: #2e7d32;">Authorization complete</h1>
<p>You can close this window and return to LobeHub Desktop.</p>
</div>
</body>
</html>`;
}
private getErrorPageHtml(message: string): string {
return `<!DOCTYPE html>
<html>
<head><title>Authorization failed</title></head>
<body style="font-family: system-ui; display: flex; justify-content: center; align-items: center; height: 100vh; margin: 0;">
<div style="text-align: center;">
<h1 style="color: #d32f2f;">Authorization failed</h1>
<p>${message}</p>
<p>Please close this window and return to LobeHub Desktop.</p>
</div>
</body>
</html>`;
}
}
@@ -0,0 +1,95 @@
# Desktop OIDC Local Callback Plan
## Summary
Switch the desktop OIDC flow from polling-only to a dual-channel flow that prefers a local callback server and falls back to polling. Bind the local notify port to the OIDC state on the server so the callback route can safely redirect to localhost without taking a port from user input.
## Goals
- Reduce latency by receiving the callback locally.
- Keep polling as a fallback if the local channel fails.
- Avoid open-redirect risks by binding notify ports to state on the server.
- Preserve existing PKCE and state validation.
## Flow
Current flow:
```
Browser -> /oidc/callback/desktop -> DB(handoff) -> Electron polls /oidc/handoff -> Token exchange
```
New flow:
```
Electron starts local server on 127.0.0.1:34210..34219
Electron POST /oidc/handoff (id=state, client=desktop, notifyPort)
Browser -> /oidc/auth (state)
Browser -> /oidc/callback/desktop
-> DB(handoff) { code/state or error } + notifyPort (if registered)
-> if notifyPort: redirect to http://127.0.0.1:{port}/notify
Electron waits for local callback; polling continues as fallback
First successful result wins, then exchange token
```
## Data model
Reuse `oauth_handoffs.payload` for two phases:
- Pending: `{ notifyPort: number }`
- Complete:
- Success: `{ code: string; state: string; notifyPort?: number }`
- Error: `{ error: string; error_description?: string; state: string; notifyPort?: number }`
`fetchAndConsume` must only return a record when `code` or `error` is present, otherwise keep it for later.
## API changes
Add a registration endpoint to bind state -> notifyPort:
- `POST /oidc/handoff`
- Body: `{ id: string; client: 'desktop'; notifyPort: number }`
- Validates port in the allowed range.
- Upserts payload to include `notifyPort`.
`GET /oidc/handoff` stays the polling endpoint.
## Server changes
- `/oidc/callback/desktop`:
- Accept `code/state` or `error/error_description/state`.
- Merge payload with any existing `notifyPort`.
- Redirect to local `http://127.0.0.1:{notifyPort}/notify` only if the port is registered for this state.
- Otherwise, redirect to the normal success/error pages.
## Desktop changes
- New `OIDCCallbackServerManager` to:
- Bind `127.0.0.1` and listen on `34210..34219`.
- Resolve a promise on `/notify?code=...&state=...` or `/notify?error=...&state=...`.
- Return a success or error HTML page to the browser.
- Time out after 5 minutes and close the server.
- `AuthCtr`:
- Start local server and register `notifyPort` before opening the browser.
- Start polling as fallback.
- If local callback returns `error`, stop polling and surface failure.
- If local callback times out or is invalid, keep polling.
- Only the first successful result triggers token exchange.
## Security
- Redirect only to a port stored on the server for this state.
- Only listen on `127.0.0.1`.
- Keep PKCE + state validation unchanged.
- Validate the notify port is in the configured range.
## Test plan
- `OAuthHandoffModel`:
- Ensure `fetchAndConsume` ignores pending records with only `notifyPort`.
- Ensure upsert/merge preserves `notifyPort`.
- Desktop:
- Mock local callback manager and verify polling still works as fallback.
- Verify local callback error stops polling and reports failure.
@@ -81,6 +81,23 @@ describe('OAuthHandoffModel', () => {
expect(deleted).toBeUndefined();
});
it('should ignore pending records without credentials', async () => {
await oauthHandoffModel.create({
id: 'handoff-pending',
client: 'desktop',
payload: { notifyPort: 34210 },
});
const result = await oauthHandoffModel.fetchAndConsume('handoff-pending', 'desktop');
expect(result).toBeNull();
const record = await serverDB.query.oauthHandoffs.findFirst({
where: eq(oauthHandoffs.id, 'handoff-pending'),
});
expect(record).toBeDefined();
});
it('should return null for non-existent credentials', async () => {
const result = await oauthHandoffModel.fetchAndConsume('non-existent', 'desktop');
+51 -3
View File
@@ -11,6 +11,15 @@ export class OAuthHandoffModel {
this.db = db;
}
private isCredentialPayload(payload: Record<string, unknown> | null | undefined): boolean {
if (!payload) return false;
return typeof payload.code === 'string' || typeof payload.error === 'string';
}
private getFreshCutoff(): Date {
return new Date(Date.now() - 5 * 60 * 1000);
}
/**
* Create a new OAuth handoff record
* @param params Credential data
@@ -26,6 +35,41 @@ export class OAuthHandoffModel {
return result;
};
/**
* Find an active handoff record without consuming it
*/
findActive = async (id: string, client: string): Promise<OAuthHandoffItem | null> => {
const fiveMinutesAgo = this.getFreshCutoff();
return this.db.query.oauthHandoffs.findFirst({
where: and(
eq(oauthHandoffs.id, id),
eq(oauthHandoffs.client, client),
sql`${oauthHandoffs.createdAt} > ${fiveMinutesAgo}`,
),
});
};
/**
* Upsert payload for a handoff record
*/
upsertPayload = async (
id: string,
client: string,
payload: Record<string, unknown>,
): Promise<OAuthHandoffItem> => {
const [result] = await this.db
.insert(oauthHandoffs)
.values({ client, id, payload })
.onConflictDoUpdate({
set: { payload },
target: oauthHandoffs.id,
})
.returning();
return result;
};
/**
* Fetch and consume OAuth credentials
* This method queries the record first, and if found, deletes it immediately to ensure credentials can only be used once
@@ -35,7 +79,7 @@ export class OAuthHandoffModel {
*/
fetchAndConsume = async (id: string, client: string): Promise<OAuthHandoffItem | null> => {
// First find the record while checking if it's expired (5 minute TTL)
const fiveMinutesAgo = new Date(Date.now() - 5 * 60 * 1000);
const fiveMinutesAgo = this.getFreshCutoff();
const handoff = await this.db.query.oauthHandoffs.findFirst({
where: and(
@@ -50,6 +94,10 @@ export class OAuthHandoffModel {
return null;
}
if (!this.isCredentialPayload(handoff.payload)) {
return null;
}
// Immediately delete the record to ensure one-time use
await this.db.delete(oauthHandoffs).where(eq(oauthHandoffs.id, id));
@@ -62,7 +110,7 @@ export class OAuthHandoffModel {
* @returns Number of records cleaned up
*/
cleanupExpired = async (): Promise<number> => {
const fiveMinutesAgo = new Date(Date.now() - 5 * 60 * 1000);
const fiveMinutesAgo = this.getFreshCutoff();
const result = await this.db
.delete(oauthHandoffs)
@@ -79,7 +127,7 @@ export class OAuthHandoffModel {
* @returns Whether it exists and is not expired
*/
exists = async (id: string, client: string): Promise<boolean> => {
const fiveMinutesAgo = new Date(Date.now() - 5 * 60 * 1000);
const fiveMinutesAgo = this.getFreshCutoff();
const handoff = await this.db.query.oauthHandoffs.findFirst({
where: and(
@@ -8,6 +8,18 @@ import { correctOIDCUrl } from '@/utils/server/correctOIDCUrl';
const log = debug('lobe-oidc:callback:desktop');
const errorPathname = '/oauth/callback/error';
const DESKTOP_NOTIFY_PORT_MIN = 34210;
const DESKTOP_NOTIFY_PORT_MAX = 34219;
const isAllowedNotifyPort = (value: number): boolean =>
Number.isInteger(value) && value >= DESKTOP_NOTIFY_PORT_MIN && value <= DESKTOP_NOTIFY_PORT_MAX;
interface OAuthCallbackPayload {
code?: string;
error?: string;
error_description?: string;
state: string;
}
/**
* 安全地构建重定向URL,使用经过验证的 correctOIDCUrl 防止开放重定向攻击
@@ -25,10 +37,22 @@ export const GET = async (req: NextRequest) => {
try {
const searchParams = req.nextUrl.searchParams;
const code = searchParams.get('code');
const error = searchParams.get('error');
const errorDescription = searchParams.get('error_description');
const state = searchParams.get('state'); // This `state` is the handoff ID
if (!code || !state || typeof code !== 'string' || typeof state !== 'string') {
log('Missing code or state in form data');
if (!state || typeof state !== 'string') {
log('Missing state in callback');
const errorUrl = buildRedirectUrl(req, errorPathname);
errorUrl.searchParams.set('reason', 'invalid_request');
log('Redirecting to error URL: %s', errorUrl.toString());
return NextResponse.redirect(errorUrl);
}
if (!code && !error) {
log('Missing code or error in callback');
const errorUrl = buildRedirectUrl(req, errorPathname);
errorUrl.searchParams.set('reason', 'invalid_request');
@@ -41,21 +65,20 @@ export const GET = async (req: NextRequest) => {
// The 'client' is 'desktop' because this redirect_uri is for the desktop client.
const client = 'desktop';
const payload = { code, state };
const payload: OAuthCallbackPayload = error
? { error, error_description: errorDescription || undefined, state }
: { code: code as string, state };
const id = state;
const authHandoffModel = new OAuthHandoffModel(serverDB);
await authHandoffModel.create({ client, id, payload });
const existing = await authHandoffModel.findActive(id, client);
const mergedPayload = { ...(existing?.payload || {}), ...payload };
const notifyPort =
typeof mergedPayload.notifyPort === 'number' ? mergedPayload.notifyPort : undefined;
await authHandoffModel.upsertPayload(id, client, mergedPayload);
log('Handoff record created successfully for id: %s', id);
const successUrl = buildRedirectUrl(req, '/oauth/callback/success');
// 添加调试日志
log('Request host header: %s', req.headers.get('host'));
log('Request x-forwarded-host: %s', req.headers.get('x-forwarded-host'));
log('Request x-forwarded-proto: %s', req.headers.get('x-forwarded-proto'));
log('Constructed success URL: %s', successUrl.toString());
// cleanup expired
after(async () => {
const cleanedCount = await authHandoffModel.cleanupExpired();
@@ -63,6 +86,41 @@ export const GET = async (req: NextRequest) => {
log('Cleaned up %d expired handoff records', cleanedCount);
});
if (notifyPort && isAllowedNotifyPort(notifyPort)) {
const localNotifyUrl = new URL(`http://127.0.0.1:${notifyPort}/notify`);
if (error) {
localNotifyUrl.searchParams.set('error', error);
if (errorDescription) {
localNotifyUrl.searchParams.set('error_description', errorDescription);
}
} else if (code) {
localNotifyUrl.searchParams.set('code', code);
}
localNotifyUrl.searchParams.set('state', state);
return NextResponse.redirect(localNotifyUrl);
}
const successUrl = buildRedirectUrl(req, '/oauth/callback/success');
const errorUrl = buildRedirectUrl(req, errorPathname);
// 添加调试日志
log('Request host header: %s', req.headers.get('host'));
log('Request x-forwarded-host: %s', req.headers.get('x-forwarded-host'));
log('Request x-forwarded-proto: %s', req.headers.get('x-forwarded-proto'));
log('Constructed success URL: %s', successUrl.toString());
if (error) {
errorUrl.searchParams.set('reason', error);
if (errorDescription) {
errorUrl.searchParams.set('errorMessage', errorDescription);
}
return NextResponse.redirect(errorUrl);
}
return NextResponse.redirect(successUrl);
} catch (error) {
log('Error in OIDC callback: %O', error);
+58
View File
@@ -6,6 +6,18 @@ import { serverDB } from '@/database/server';
const log = debug('lobe-oidc:handoff');
const DESKTOP_NOTIFY_PORT_MIN = 34210;
const DESKTOP_NOTIFY_PORT_MAX = 34219;
interface HandoffRegistrationBody {
client: string;
id: string;
notifyPort: number;
}
const isAllowedNotifyPort = (value: number): boolean =>
Number.isInteger(value) && value >= DESKTOP_NOTIFY_PORT_MIN && value <= DESKTOP_NOTIFY_PORT_MAX;
/**
* GET /oidc/handoff?id=xxx&client=xxx
* 轮询获取并消费认证凭证
@@ -44,3 +56,49 @@ export async function GET(request: NextRequest) {
return NextResponse.json({ error: 'Internal server error' }, { status: 500 });
}
}
/**
* POST /oidc/handoff
* Register desktop notifyPort for a given handoff id (state)
*/
export async function POST(request: NextRequest) {
log('Received POST request for /oidc/handoff');
let body: HandoffRegistrationBody;
try {
body = (await request.json()) as HandoffRegistrationBody;
} catch {
return NextResponse.json({ error: 'Invalid JSON body' }, { status: 400 });
}
const { client, id, notifyPort } = body;
if (!id || !client) {
return NextResponse.json(
{ error: 'Missing required parameters: id and client' },
{ status: 400 },
);
}
if (client !== 'desktop') {
return NextResponse.json({ error: 'Unsupported client type' }, { status: 400 });
}
if (!isAllowedNotifyPort(notifyPort)) {
return NextResponse.json({ error: 'Invalid notifyPort' }, { status: 400 });
}
try {
const authHandoffModel = new OAuthHandoffModel(serverDB);
const existing = await authHandoffModel.findActive(id, client);
const payload = { ...(existing?.payload || {}), notifyPort };
await authHandoffModel.upsertPayload(id, client, payload);
return NextResponse.json({ success: true });
} catch (error) {
log('Error registering notifyPort: %O', error);
return NextResponse.json({ error: 'Internal server error' }, { status: 500 });
}
}