feat: add support for Gemini Enterprise (Discovery Engine) assistant

- Implemented EnterpriseAgentProtocol and EnterpriseAgentSession in core
- Authenticates seamlessly via Application Default Credentials (ADC)
- Added robust brace-counting JSON stream parser with partial chunk caching
- Extracted and rendered immersive docArtifacts (markdown tables) E2E
- Integrated with CLI config schema and enabled default 'all tools' execution
- Added comprehensive unit tests verifying all stream events (thoughts, tools, tables)

TAG=agy
CONV=81e82460-f8cd-4c7b-a037-2cbedda4d3c0
This commit is contained in:
Michael Bleigh
2026-05-18 00:59:55 +00:00
parent 5611ff40e7
commit fc8928c089
10 changed files with 760 additions and 19 deletions
+1
View File
@@ -974,6 +974,7 @@ export async function loadCliConfig(
mcpEnabled,
extensionsEnabled,
agents: settings.agents,
enterprise: settings.enterprise,
adminSkillsEnabled,
allowedMcpServers: mcpEnabled
? (argv.allowedMcpServerNames ?? settings.mcp?.allowed)
+38
View File
@@ -1366,6 +1366,44 @@ const SETTINGS_SCHEMA = {
},
},
},
enterprise: {
type: 'object',
label: 'Gemini Enterprise',
category: 'Advanced',
requiresRestart: true,
default: {},
description: 'Settings for Gemini Enterprise integration.',
showInDialog: false,
properties: {
projectId: {
type: 'string',
label: 'Project ID',
category: 'Advanced',
requiresRestart: true,
default: undefined as string | undefined,
description: 'Google Cloud Project ID for Gemini Enterprise.',
showInDialog: true,
},
engineId: {
type: 'string',
label: 'Engine ID',
category: 'Advanced',
requiresRestart: true,
default: undefined as string | undefined,
description: 'Discovery Engine ID for Gemini Enterprise.',
showInDialog: true,
},
location: {
type: 'string',
label: 'Location',
category: 'Advanced',
requiresRestart: true,
default: 'global',
description: 'Google Cloud Location for Gemini Enterprise.',
showInDialog: true,
},
},
},
context: {
type: 'object',
+12 -8
View File
@@ -507,7 +507,8 @@ export async function main() {
// the sandbox because the sandbox will interfere with the Oauth2 web
// redirect.
let initialAuthFailed = false;
if (!settings.merged.security.auth.useExternal && !argv.isCommand) {
const useEnterprise = process.env['GEMINI_CLI_ENTERPRISE_AGENT'] === 'true';
if (!settings.merged.security.auth.useExternal && !argv.isCommand && !useEnterprise) {
try {
if (
partialConfig.isInteractive() &&
@@ -858,13 +859,16 @@ export async function main() {
),
);
const authType = await validateNonInteractiveAuth(
settings.merged.security.auth.selectedType,
settings.merged.security.auth.useExternal,
config,
settings,
);
await config.refreshAuth(authType);
const useEnterprise = process.env['GEMINI_CLI_ENTERPRISE_AGENT'] === 'true';
if (!useEnterprise) {
const authType = await validateNonInteractiveAuth(
settings.merged.security.auth.selectedType,
settings.merged.security.auth.useExternal,
config,
settings,
);
await config.refreshAuth(authType);
}
if (config.getDebugMode()) {
debugLogger.log('Session ID: %s', sessionId);
+3 -1
View File
@@ -66,7 +66,9 @@ interface RunNonInteractiveParams {
export async function runNonInteractive(
params: RunNonInteractiveParams,
): Promise<void> {
const useAgentSession = params.config.getAgentSessionNoninteractiveEnabled();
const useAgentSession =
params.config.getAgentSessionNoninteractiveEnabled() ||
process.env['GEMINI_CLI_ENTERPRISE_AGENT'] === 'true';
if (useAgentSession) {
debugLogger.debug(
'[ADK] Running non-interactive mode with ADK agent session',
@@ -35,6 +35,7 @@ import {
Scheduler,
ROOT_SCHEDULER_ID,
LegacyAgentSession,
EnterpriseAgentSession,
ToolErrorType,
geminiPartsToContentParts,
displayContentToString,
@@ -295,13 +296,16 @@ export async function runNonInteractive({
});
}
// Create LegacyAgentSession — owns the agentic loop
const session = new LegacyAgentSession({
client: geminiClient,
scheduler,
config,
promptId: prompt_id,
});
const useEnterprise = process.env['GEMINI_CLI_ENTERPRISE_AGENT'] === 'true';
// Create AgentSession — owns the agentic loop
const session = useEnterprise
? new EnterpriseAgentSession({ config, promptId: prompt_id })
: new LegacyAgentSession({
client: geminiClient,
scheduler,
config,
promptId: prompt_id,
});
// Wire Ctrl+C to session abort
abortSession = () => {
+8 -3
View File
@@ -90,6 +90,7 @@ import {
logBillingEvent,
ApiKeyUpdatedEvent,
LegacyAgentProtocol,
EnterpriseAgentProtocol,
type InjectionSource,
} from '@google/gemini-cli-core';
import { validateAuthMethod } from '../config/auth.js';
@@ -1173,10 +1174,14 @@ Logging in with Google... Restarting Gemini CLI to continue.
}, [config]);
const streamAgent = useMemo(
() =>
config?.getAgentSessionInteractiveEnabled()
() => {
if (process.env['GEMINI_CLI_ENTERPRISE_AGENT'] === 'true') {
return new EnterpriseAgentProtocol({ config });
}
return config?.getAgentSessionInteractiveEnabled()
? new LegacyAgentProtocol({ config, getPreferredEditor })
: undefined,
: undefined;
},
[config, getPreferredEditor],
);
@@ -0,0 +1,253 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, expect, it, vi, beforeEach, afterEach, type Mock } from 'vitest';
import { EnterpriseAgentSession } from './enterprise-agent-session.js';
import type { Config } from '../config/config.js';
import type { AgentEvent } from './types.js';
import { GoogleAuth } from 'google-auth-library';
// Mock google-auth-library
vi.mock('google-auth-library', () => ({
GoogleAuth: vi.fn(),
}));
describe('EnterpriseAgentSession', () => {
let mockConfig: Config;
let globalFetch: typeof fetch;
beforeEach(() => {
vi.clearAllMocks();
(GoogleAuth as unknown as Mock).mockImplementation(() => ({
getClient: vi.fn().mockResolvedValue({
getAccessToken: vi.fn().mockResolvedValue({ token: 'fake-token' }),
}),
}));
mockConfig = {
getSessionId: vi.fn().mockReturnValue('test-session'),
getEnterpriseConfig: vi.fn().mockReturnValue({
projectId: 'test-project',
engineId: 'test-engine',
location: 'global',
}),
} as unknown as Config;
globalFetch = global.fetch;
});
afterEach(() => {
global.fetch = globalFetch;
vi.restoreAllMocks();
});
const mockFetchResponse = (chunks: string[]) => {
const stream = new ReadableStream({
start(controller) {
for (const chunk of chunks) {
controller.enqueue(new TextEncoder().encode(chunk));
}
controller.close();
},
});
global.fetch = vi.fn().mockResolvedValue({
ok: true,
body: stream,
headers: new Headers(),
} as Response);
};
it('should successfully call Enterprise API and stream responses', async () => {
const chunk1 = JSON.stringify({
sessionInfo: { session: 'projects/test-project/locations/global/collections/default_collection/engines/test-engine/sessions/s1' },
answer: {
replies: [
{
groundedContent: {
content: { text: 'Hello' },
},
},
],
},
}) + '\n';
const chunk2 = JSON.stringify({
answer: {
replies: [
{
groundedContent: {
content: { text: ' World' },
},
},
],
},
}) + '\n';
mockFetchResponse([chunk1, chunk2]);
const session = new EnterpriseAgentSession({ config: mockConfig });
const { streamId } = await session.send({
message: { content: [{ type: 'text', text: 'hi' }] },
});
expect(streamId).toBe('enterprise-stream-1');
const events: AgentEvent[] = [];
for await (const event of session.stream({ streamId: streamId! })) {
events.push(event);
}
expect(events.map(e => e.type)).toEqual([
'agent_start',
'message', // Hello
'message', // World
'agent_end',
]);
const messages = events.filter((e): e is AgentEvent<'message'> => e.type === 'message' && e.role === 'agent');
expect(messages[0].content).toEqual([{ type: 'text', text: 'Hello' }]);
expect(messages[1].content).toEqual([{ type: 'text', text: ' World' }]);
});
it('should handle thoughts', async () => {
const chunk = JSON.stringify({
answer: {
replies: [
{
groundedContent: {
content: { text: 'Thinking...', thought: true },
},
},
{
groundedContent: {
content: { text: 'Final answer' },
},
},
],
},
}) + '\n';
mockFetchResponse([chunk]);
const session = new EnterpriseAgentSession({ config: mockConfig });
const { streamId } = await session.send({
message: { content: [{ type: 'text', text: 'hi' }] },
});
const events: AgentEvent[] = [];
for await (const event of session.stream({ streamId: streamId! })) {
events.push(event);
}
const thoughts = events.filter((e): e is AgentEvent<'message'> => e.type === 'message' && e.content[0]?.type === 'thought');
expect(thoughts).toHaveLength(1);
expect(thoughts[0].content).toEqual([{ type: 'thought', thought: 'Thinking...' }]);
const texts = events.filter((e): e is AgentEvent<'message'> => e.type === 'message' && e.content[0]?.type === 'text' && e.role === 'agent');
expect(texts).toHaveLength(1);
expect(texts[0].content).toEqual([{ type: 'text', text: 'Final answer' }]);
});
it('should handle tool requests and responses (executable code)', async () => {
const chunk1 = JSON.stringify({
answer: {
replies: [
{
groundedContent: {
content: {
executableCode: { code: 'print("hello")' },
},
},
},
],
},
}) + '\n';
const chunk2 = JSON.stringify({
answer: {
replies: [
{
groundedContent: {
content: {
codeExecutionResult: { outcome: 'OUTCOME_OK', output: 'hello\n' },
},
},
},
],
},
}) + '\n';
mockFetchResponse([chunk1, chunk2]);
const session = new EnterpriseAgentSession({ config: mockConfig });
const { streamId } = await session.send({
message: { content: [{ type: 'text', text: 'run code' }] },
});
const events: AgentEvent[] = [];
for await (const event of session.stream({ streamId: streamId! })) {
events.push(event);
}
expect(events.map(e => e.type)).toEqual([
'agent_start',
'tool_request',
'tool_response',
'agent_end',
]);
const toolReq = events.find(e => e.type === 'tool_request') as AgentEvent<'tool_request'>;
expect(toolReq.name).toBe('python_interpreter');
expect(toolReq.args).toEqual({ code: 'print("hello")' });
const toolResp = events.find(e => e.type === 'tool_response') as AgentEvent<'tool_response'>;
expect(toolResp.name).toBe('python_interpreter');
expect(toolResp.content).toEqual([{ type: 'text', text: 'hello\n' }]);
expect(toolResp.isError).toBe(false);
});
it('should handle immersive artifacts (tables/docs)', async () => {
const chunk = JSON.stringify({
answer: {
replies: [
{
groundedContent: {
content: { text: 'Here is the table:\n' },
},
immersiveArtifact: [
{
docArtifact: { text: '| Col 1 | Col 2 |\n|---|---|\n| Val 1 | Val 2 |' },
},
],
},
],
},
}) + '\n';
mockFetchResponse([chunk]);
const session = new EnterpriseAgentSession({ config: mockConfig });
const { streamId } = await session.send({
message: { content: [{ type: 'text', text: 'show table' }] },
});
const events: AgentEvent[] = [];
for await (const event of session.stream({ streamId: streamId! })) {
events.push(event);
}
const texts = events.filter((e): e is AgentEvent<'message'> => e.type === 'message' && e.role === 'agent');
expect(texts).toHaveLength(2);
expect(texts[0].content).toEqual([{ type: 'text', text: 'Here is the table:\n' }]);
expect(texts[1].content).toEqual([{ type: 'text', text: '| Col 1 | Col 2 |\n|---|---|\n| Val 1 | Val 2 |' }]);
});
});
@@ -0,0 +1,420 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { GoogleAuth } from 'google-auth-library';
import type { Config } from '../config/config.js';
import { AgentSession } from './agent-session.js';
import type {
AgentEvent,
AgentProtocol,
AgentSend,
ContentPart,
Unsubscribe,
} from './types.js';
import { fetchWithTimeout } from '../utils/fetch.js';
import { debugLogger } from '../utils/debugLogger.js';
export interface EnterpriseAgentSessionDeps {
config: Config;
promptId?: string;
streamId?: string;
}
export class EnterpriseAgentProtocol implements AgentProtocol {
private _events: AgentEvent[] = [];
private _subscribers = new Set<(event: AgentEvent) => void>();
private _activeStreamId?: string;
private _abortController = new AbortController();
private _streamCounter = 0;
private _eventCounter = 0;
private readonly _config: Config;
private _sessionResourceName?: string;
constructor(deps: EnterpriseAgentSessionDeps) {
this._config = deps.config;
this._sessionResourceName = deps.streamId;
}
get events(): readonly AgentEvent[] {
return this._events;
}
subscribe(callback: (event: AgentEvent) => void): Unsubscribe {
this._subscribers.add(callback);
return () => {
this._subscribers.delete(callback);
};
}
async abort(): Promise<void> {
this._abortController.abort();
}
async send(payload: AgentSend): Promise<{ streamId: string }> {
const message = 'message' in payload ? payload.message : undefined;
if (!message) {
throw new Error(
'EnterpriseAgentSession.send() only supports message sends for the moment.',
);
}
if (this._activeStreamId) {
throw new Error(
'EnterpriseAgentSession.send() cannot be called while a stream is active.',
);
}
this._beginNewStream();
const streamId = this._activeStreamId!;
const userMessage = this._makeUserMessageEvent(
message.content,
message.displayContent,
payload._meta,
);
this._emit([userMessage]);
this._scheduleRunLoop(message.content.map(p => p.type === 'text' ? p.text : '').join(' '));
return { streamId };
}
private _beginNewStream(): void {
this._streamCounter++;
this._eventCounter = 0;
this._abortController = new AbortController();
this._activeStreamId = `enterprise-stream-${this._streamCounter}`;
}
private _scheduleRunLoop(queryText: string): void {
setTimeout(() => {
void this._runLoopInBackground(queryText);
}, 0);
}
private async _runLoopInBackground(queryText: string): Promise<void> {
this._ensureAgentStart();
try {
const enterpriseConfig = this._config.getEnterpriseConfig();
if (!enterpriseConfig?.projectId || !enterpriseConfig?.engineId) {
throw new Error('Gemini Enterprise is not fully configured. projectId and engineId are required in ~/.gemini/settings.json.');
}
const projectId = enterpriseConfig.projectId;
const engineId = enterpriseConfig.engineId;
const location = enterpriseConfig.location ?? 'global';
// Get Auth Token
const auth = new GoogleAuth({
scopes: ['https://www.googleapis.com/auth/cloud-platform'],
});
const client = await auth.getClient();
const tokenResponse = await client.getAccessToken();
const token = tokenResponse.token;
if (!token) {
throw new Error('Failed to retrieve ADC access token.');
}
const endpoint = `https://discoveryengine.googleapis.com/v1alpha/projects/${projectId}/locations/${location}/collections/default_collection/engines/${engineId}/assistants/default_assistant:streamAssist`;
const requestBody = {
query: {
text: queryText,
},
session: this._sessionResourceName || '-',
assistSkippingMode: 'REQUEST_ASSIST',
toolsSpec: {
vertexAiSearchSpec: {},
webGroundingSpec: {},
},
};
debugLogger.debug(`Calling Enterprise API: ${endpoint}`);
debugLogger.debug(`Request Body: ${JSON.stringify(requestBody)}`);
const response = await fetchWithTimeout(endpoint, 60000, {
method: 'POST',
headers: {
'Authorization': `Bearer ${token}`,
'Content-Type': 'application/json',
},
body: JSON.stringify(requestBody),
signal: this._abortController.signal,
});
if (!response.ok) {
const errorText = await response.text();
throw new Error(`Enterprise API call failed with status ${response.status}: ${errorText}`);
}
if (!response.body) {
throw new Error('Response body is empty.');
}
const reader = response.body.getReader();
const decoder = new TextDecoder();
let buffer = '';
let braceCount = 0;
let insideString = false;
let escapeNext = false;
let objectStart = -1;
let lastScannedIndex = 0;
while (true) {
const { done, value } = await reader.read();
if (done) break;
buffer += decoder.decode(value, { stream: true });
for (let i = lastScannedIndex; i < buffer.length; i++) {
const char = buffer[i];
if (escapeNext) {
escapeNext = false;
lastScannedIndex = i + 1;
continue;
}
if (char === '\\') {
escapeNext = true;
lastScannedIndex = i + 1;
continue;
}
if (char === '"') {
insideString = !insideString;
lastScannedIndex = i + 1;
continue;
}
if (!insideString) {
if (char === '{') {
if (braceCount === 0) {
objectStart = i;
}
braceCount++;
} else if (char === '}') {
braceCount--;
if (braceCount === 0 && objectStart !== -1) {
const objectStr = buffer.substring(objectStart, i + 1);
await this._parseAndEmitChunk(objectStr);
// Remove parsed object from buffer and reset index
buffer = buffer.substring(i + 1);
i = -1;
lastScannedIndex = 0;
objectStart = -1;
}
}
}
lastScannedIndex = i + 1;
}
}
this._emit([this._makeAgentEndEvent('completed')]);
} catch (err: unknown) {
if (this._abortController.signal.aborted) {
this._emit([this._makeAgentEndEvent('aborted')]);
} else {
const message = err instanceof Error ? err.message : String(err);
this._emit([
this._makeErrorEvent({
status: 'INTERNAL',
message,
fatal: true,
}),
]);
this._emit([this._makeAgentEndEvent('failed')]);
}
} finally {
this._activeStreamId = undefined;
}
}
private _ensureAgentStart(): void {
// We can emit agent_start early or wait until we get the first chunk.
// For now, emit it at the start of background run.
this._emit([this._makeAgentStartEvent()]);
}
private async _parseAndEmitChunk(line: string): Promise<void> {
try {
const response = JSON.parse(line);
debugLogger.debug(`Received Enterprise Chunk: ${JSON.stringify(response)}`);
if (response.sessionInfo?.session) {
this._sessionResourceName = response.sessionInfo.session;
debugLogger.debug(`Session updated: ${this._sessionResourceName}`);
}
const answer = response.answer;
if (answer) {
if (answer.replies) {
for (const reply of answer.replies) {
if (reply.groundedContent?.content) {
const content = reply.groundedContent.content;
// Handle Thought
if (content.thought) {
if (content.text) {
this._emit([
this._makeMessageEvent('agent', [
{ type: 'thought', thought: content.text }
])
]);
}
}
// Handle Tool Use (Executable Code)
else if (content.executableCode) {
this._emit([
this._makeToolRequestEvent({
requestId: `ent-tool-${Date.now()}`,
name: 'python_interpreter',
args: { code: content.executableCode.code },
display: {
name: 'Python Interpreter',
description: 'Executing code server-side',
}
})
]);
}
// Handle Tool Response (Code Execution Result)
else if (content.codeExecutionResult) {
this._emit([
this._makeToolResponseEvent({
requestId: `ent-tool-${Date.now()}`,
name: 'python_interpreter',
content: [{ type: 'text', text: content.codeExecutionResult.output || '' }],
isError: content.codeExecutionResult.outcome === 'OUTCOME_FAILED',
})
]);
}
// Handle standard Text
else if (content.text) {
this._emit([
this._makeMessageEvent('agent', [
{ type: 'text', text: content.text }
])
]);
}
}
if (reply.immersiveArtifact) {
for (const artifact of reply.immersiveArtifact) {
if (artifact.docArtifact?.text) {
this._emit([
this._makeMessageEvent('agent', [
{ type: 'text', text: artifact.docArtifact.text }
])
]);
}
}
}
}
}
}
} catch (e) {
debugLogger.error(`Failed to parse line: ${line}`, e);
}
}
private _emit(events: AgentEvent[]): void {
if (events.length === 0) return;
const subscribers = [...this._subscribers];
for (const event of events) {
this._events.push(event);
for (const subscriber of subscribers) {
subscriber(event);
}
}
}
private _nextEventFields() {
return {
id: `${this._activeStreamId}-${this._eventCounter++}`,
timestamp: new Date().toISOString(),
streamId: this._activeStreamId!,
};
}
private _makeUserMessageEvent(
content: ContentPart[],
displayContent?: string,
meta?: Record<string, unknown>,
): AgentEvent<'message'> {
const eventContent: ContentPart[] = displayContent
? [{ type: 'text', text: displayContent }]
: content;
return {
...this._nextEventFields(),
type: 'message',
role: 'user',
content: eventContent,
...(meta ? { _meta: meta } : {}),
};
}
private _makeMessageEvent(
role: 'agent' | 'user' | 'developer',
content: ContentPart[],
): AgentEvent<'message'> {
return {
...this._nextEventFields(),
type: 'message',
role,
content,
};
}
private _makeAgentStartEvent(): AgentEvent<'agent_start'> {
return {
...this._nextEventFields(),
type: 'agent_start',
};
}
private _makeAgentEndEvent(reason: string): AgentEvent<'agent_end'> {
return {
...this._nextEventFields(),
type: 'agent_end',
reason,
};
}
private _makeErrorEvent(payload: Omit<AgentEvent<'error'>, 'id' | 'timestamp' | 'streamId' | 'type'>): AgentEvent<'error'> {
return {
...this._nextEventFields(),
type: 'error',
...payload,
};
}
private _makeToolRequestEvent(payload: Omit<AgentEvent<'tool_request'>, 'id' | 'timestamp' | 'streamId' | 'type'>): AgentEvent<'tool_request'> {
return {
...this._nextEventFields(),
type: 'tool_request',
...payload,
};
}
private _makeToolResponseEvent(payload: Omit<AgentEvent<'tool_response'>, 'id' | 'timestamp' | 'streamId' | 'type'>): AgentEvent<'tool_response'> {
return {
...this._nextEventFields(),
type: 'tool_response',
...payload,
};
}
}
export class EnterpriseAgentSession extends AgentSession {
constructor(deps: EnterpriseAgentSessionDeps) {
super(new EnterpriseAgentProtocol(deps));
}
}
+13
View File
@@ -208,6 +208,12 @@ export interface PlanSettings {
modelRouting?: boolean;
}
export interface EnterpriseSettings {
projectId?: string;
engineId?: string;
location?: string;
}
export interface TelemetrySettings {
enabled?: boolean;
traces?: boolean;
@@ -742,6 +748,7 @@ export interface ConfigParameters {
};
vertexAiRouting?: VertexAiRoutingConfig;
logRagSnippets?: boolean;
enterprise?: EnterpriseSettings;
}
export class Config implements McpContext, AgentLoopContext {
@@ -978,12 +985,14 @@ export class Config implements McpContext, AgentLoopContext {
private lastModeSwitchTime: number = performance.now();
readonly injectionService: InjectionService;
private approvedPlanPath: string | undefined;
private readonly enterprise?: EnterpriseSettings;
constructor(params: ConfigParameters) {
this._sessionId = params.sessionId;
this.clientName = params.clientName;
this._clientVersion = params.clientVersion ?? 'unknown';
this.approvedPlanPath = undefined;
this.enterprise = params.enterprise;
this.embeddingModel =
params.embeddingModel ?? DEFAULT_GEMINI_EMBEDDING_MODEL;
@@ -1817,6 +1826,10 @@ export class Config implements McpContext, AgentLoopContext {
return this.clientName;
}
getEnterpriseConfig(): EnterpriseSettings | undefined {
return this.enterprise;
}
setSessionId(sessionId: string): void {
const previousPlansDir = this.storage.isInitialized()
? this.storage.getPlansDir()
+1
View File
@@ -196,6 +196,7 @@ export { resetBrowserSession } from './agents/browser/browserAgentFactory.js';
// Export agent session interface
export * from './agent/agent-session.js';
export * from './agent/legacy-agent-session.js';
export * from './agent/enterprise-agent-session.js';
export * from './agent/event-translator.js';
export * from './agent/content-utils.js';
export * from './agent/tool-display-utils.js';