mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-10 22:21:22 -07:00
feat(acp): Add support for AI Gateway auth (#21305)
This commit is contained in:
@@ -208,7 +208,16 @@ describe('GeminiAgent', () => {
|
||||
});
|
||||
|
||||
expect(response.protocolVersion).toBe(acp.PROTOCOL_VERSION);
|
||||
expect(response.authMethods).toHaveLength(3);
|
||||
expect(response.authMethods).toHaveLength(4);
|
||||
const gatewayAuth = response.authMethods?.find(
|
||||
(m) => m.id === AuthType.GATEWAY,
|
||||
);
|
||||
expect(gatewayAuth?._meta).toEqual({
|
||||
gateway: {
|
||||
protocol: 'google',
|
||||
restartRequired: 'false',
|
||||
},
|
||||
});
|
||||
const geminiAuth = response.authMethods?.find(
|
||||
(m) => m.id === AuthType.USE_GEMINI,
|
||||
);
|
||||
@@ -228,6 +237,8 @@ describe('GeminiAgent', () => {
|
||||
expect(mockConfig.refreshAuth).toHaveBeenCalledWith(
|
||||
AuthType.LOGIN_WITH_GOOGLE,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
expect(mockSettings.setValue).toHaveBeenCalledWith(
|
||||
SettingScope.User,
|
||||
@@ -247,6 +258,8 @@ describe('GeminiAgent', () => {
|
||||
expect(mockConfig.refreshAuth).toHaveBeenCalledWith(
|
||||
AuthType.USE_GEMINI,
|
||||
'test-api-key',
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
expect(mockSettings.setValue).toHaveBeenCalledWith(
|
||||
SettingScope.User,
|
||||
@@ -255,6 +268,45 @@ describe('GeminiAgent', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should authenticate correctly with gateway method', async () => {
|
||||
await agent.authenticate({
|
||||
methodId: AuthType.GATEWAY,
|
||||
_meta: {
|
||||
gateway: {
|
||||
baseUrl: 'https://example.com',
|
||||
headers: { Authorization: 'Bearer token' },
|
||||
},
|
||||
},
|
||||
} as unknown as acp.AuthenticateRequest);
|
||||
|
||||
expect(mockConfig.refreshAuth).toHaveBeenCalledWith(
|
||||
AuthType.GATEWAY,
|
||||
undefined,
|
||||
'https://example.com',
|
||||
{ Authorization: 'Bearer token' },
|
||||
);
|
||||
expect(mockSettings.setValue).toHaveBeenCalledWith(
|
||||
SettingScope.User,
|
||||
'security.auth.selectedType',
|
||||
AuthType.GATEWAY,
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw acp.RequestError when gateway payload is malformed', async () => {
|
||||
await expect(
|
||||
agent.authenticate({
|
||||
methodId: AuthType.GATEWAY,
|
||||
_meta: {
|
||||
gateway: {
|
||||
// Invalid baseUrl
|
||||
baseUrl: 123,
|
||||
headers: { Authorization: 'Bearer token' },
|
||||
},
|
||||
},
|
||||
} as unknown as acp.AuthenticateRequest),
|
||||
).rejects.toThrow(/Malformed gateway payload/);
|
||||
});
|
||||
|
||||
it('should create a new session', async () => {
|
||||
vi.useFakeTimers();
|
||||
mockConfig.getContentGeneratorConfig = vi.fn().mockReturnValue({
|
||||
|
||||
@@ -98,6 +98,8 @@ export class GeminiAgent {
|
||||
private sessions: Map<string, Session> = new Map();
|
||||
private clientCapabilities: acp.ClientCapabilities | undefined;
|
||||
private apiKey: string | undefined;
|
||||
private baseUrl: string | undefined;
|
||||
private customHeaders: Record<string, string> | undefined;
|
||||
|
||||
constructor(
|
||||
private config: Config,
|
||||
@@ -131,6 +133,17 @@ export class GeminiAgent {
|
||||
name: 'Vertex AI',
|
||||
description: 'Use an API key with Vertex AI GenAI API',
|
||||
},
|
||||
{
|
||||
id: AuthType.GATEWAY,
|
||||
name: 'AI API Gateway',
|
||||
description: 'Use a custom AI API Gateway',
|
||||
_meta: {
|
||||
gateway: {
|
||||
protocol: 'google',
|
||||
restartRequired: 'false',
|
||||
},
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
await this.config.initialize();
|
||||
@@ -179,7 +192,38 @@ export class GeminiAgent {
|
||||
if (apiKey) {
|
||||
this.apiKey = apiKey;
|
||||
}
|
||||
await this.config.refreshAuth(method, apiKey ?? this.apiKey);
|
||||
|
||||
// Extract gateway details if present
|
||||
const gatewaySchema = z.object({
|
||||
baseUrl: z.string().optional(),
|
||||
headers: z.record(z.string()).optional(),
|
||||
});
|
||||
|
||||
let baseUrl: string | undefined;
|
||||
let headers: Record<string, string> | undefined;
|
||||
|
||||
if (meta?.['gateway']) {
|
||||
const result = gatewaySchema.safeParse(meta['gateway']);
|
||||
if (result.success) {
|
||||
baseUrl = result.data.baseUrl;
|
||||
headers = result.data.headers;
|
||||
} else {
|
||||
throw new acp.RequestError(
|
||||
-32602,
|
||||
`Malformed gateway payload: ${result.error.message}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
this.baseUrl = baseUrl;
|
||||
this.customHeaders = headers;
|
||||
|
||||
await this.config.refreshAuth(
|
||||
method,
|
||||
apiKey ?? this.apiKey,
|
||||
baseUrl,
|
||||
headers,
|
||||
);
|
||||
} catch (e) {
|
||||
throw new acp.RequestError(-32000, getAcpErrorMessage(e));
|
||||
}
|
||||
@@ -209,7 +253,12 @@ export class GeminiAgent {
|
||||
let isAuthenticated = false;
|
||||
let authErrorMessage = '';
|
||||
try {
|
||||
await config.refreshAuth(authType, this.apiKey);
|
||||
await config.refreshAuth(
|
||||
authType,
|
||||
this.apiKey,
|
||||
this.baseUrl,
|
||||
this.customHeaders,
|
||||
);
|
||||
isAuthenticated = true;
|
||||
|
||||
// Extra validation for Gemini API key
|
||||
@@ -371,7 +420,12 @@ export class GeminiAgent {
|
||||
// This satisfies the security requirement to verify the user before executing
|
||||
// potentially unsafe server definitions.
|
||||
try {
|
||||
await config.refreshAuth(selectedAuthType, this.apiKey);
|
||||
await config.refreshAuth(
|
||||
selectedAuthType,
|
||||
this.apiKey,
|
||||
this.baseUrl,
|
||||
this.customHeaders,
|
||||
);
|
||||
} catch (e) {
|
||||
debugLogger.error(`Authentication failed: ${e}`);
|
||||
throw acp.RequestError.authRequired();
|
||||
|
||||
Reference in New Issue
Block a user