mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-17 07:13:07 -07:00
feat(core): introduce Agent and AgentSession v1 with ReAct loop and event streaming
This commit is contained in:
@@ -0,0 +1,84 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { Agent } from './agent.js';
|
||||
import { AgentSession } from './session.js';
|
||||
import { makeFakeConfig } from '../test-utils/config.js';
|
||||
import { type AgentConfig } from './types.js';
|
||||
|
||||
// Swap the real AgentSession for a stub so these tests exercise only the Agent
// factory. Every stubbed session exposes a prompt() that emits a fixed
// agent_start / agent_finish pair.
vi.mock('./session.js', () => {
  async function* fakePrompt() {
    yield { type: 'agent_start', value: { sessionId: 'test-session' } };
    yield {
      type: 'agent_finish',
      value: { sessionId: 'test-session', totalTurns: 1 },
    };
  }

  const createFakeSession = () => ({
    prompt: vi.fn().mockImplementation(fakePrompt),
  });

  return {
    AgentSession: vi.fn().mockImplementation(createFakeSession),
  };
});
|
||||
|
||||
// Unit tests for the Agent factory. AgentSession is mocked at module level, so
// these tests only verify how Agent constructs and delegates to sessions.
describe('Agent', () => {
  const agentConfig: AgentConfig = {
    name: 'TestAgent',
    systemInstruction: 'You are a test agent.',
  };
  let mockConfig: ReturnType<typeof makeFakeConfig>;

  beforeEach(() => {
    vi.clearAllMocks();
    mockConfig = makeFakeConfig();
    // Fallback id that Agent should use when the caller supplies none.
    vi.spyOn(mockConfig, 'getSessionId').mockReturnValue('global-session-id');
  });

  it('should create an AgentSession', () => {
    const created = new Agent(agentConfig, mockConfig).createSession(
      'custom-session-id',
    );

    expect(created).toBeDefined();
    expect(AgentSession).toHaveBeenCalledWith(
      'custom-session-id',
      agentConfig,
      mockConfig,
    );
  });

  it('should use global session ID if none provided to createSession', () => {
    new Agent(agentConfig, mockConfig).createSession();

    expect(AgentSession).toHaveBeenCalledWith(
      'global-session-id',
      agentConfig,
      mockConfig,
    );
  });

  it('should pass config and runtime to the session', () => {
    new Agent(agentConfig, mockConfig).createSession('test-id');

    expect(AgentSession).toHaveBeenCalledWith(
      'test-id',
      agentConfig,
      mockConfig,
    );
  });

  it('should prompt through a new session', async () => {
    const underTest = new Agent(agentConfig, mockConfig);

    const received = [];
    for await (const emitted of underTest.prompt('Hello')) {
      received.push(emitted);
    }

    const [first, second] = received;
    expect(received).toHaveLength(2);
    expect(first.type).toBe('agent_start');
    expect(second.type).toBe('agent_finish');
    expect(AgentSession).toHaveBeenCalled();
  });
});
|
||||
@@ -0,0 +1,41 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { type Part } from '@google/genai';
|
||||
import { type Config } from '../config/config.js';
|
||||
import { type AgentEvent, type AgentConfig } from './types.js';
|
||||
import { AgentSession } from './session.js';
|
||||
|
||||
/**
 * A configured agent template.
 *
 * An Agent holds an agent configuration together with the runtime Config and
 * acts as a factory for stateful AgentSession instances.
 */
export class Agent {
  constructor(
    private readonly config: AgentConfig,
    private readonly runtime: Config,
  ) {}

  /**
   * Builds a new stateful session bound to this agent's configuration.
   *
   * @param sessionId Optional explicit id; when omitted the runtime's session
   *   id is used instead.
   * @returns A fresh AgentSession.
   */
  createSession(sessionId?: string): AgentSession {
    return new AgentSession(
      sessionId ?? this.runtime.getSessionId(),
      this.config,
      this.runtime,
    );
  }

  /**
   * Convenience wrapper: creates a new session and streams the events of a
   * single prompt through it.
   *
   * @param input The prompt, as plain text or pre-built parts.
   * @param sessionId Optional session id forwarded to createSession().
   * @param signal Optional abort signal forwarded to the session.
   * @returns The events emitted by the session's prompt().
   */
  async *prompt(
    input: string | Part[],
    sessionId?: string,
    signal?: AbortSignal,
  ): AsyncIterable<AgentEvent> {
    yield* this.createSession(sessionId).prompt(input, signal);
  }
}
|
||||
@@ -549,7 +549,8 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
if (
|
||||
terminateReason !== AgentTerminateMode.ERROR &&
|
||||
terminateReason !== AgentTerminateMode.ABORTED &&
|
||||
terminateReason !== AgentTerminateMode.GOAL
|
||||
terminateReason !== AgentTerminateMode.GOAL &&
|
||||
terminateReason !== AgentTerminateMode.LOOP
|
||||
) {
|
||||
const recoveryResult = await this.executeFinalWarningTurn(
|
||||
chat,
|
||||
|
||||
@@ -0,0 +1,478 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { AgentSession } from './session.js';
|
||||
import { makeFakeConfig } from '../test-utils/config.js';
|
||||
import {
|
||||
type AgentConfig,
|
||||
AgentTerminateMode,
|
||||
type AgentEvent,
|
||||
} from './types.js';
|
||||
import { Scheduler } from '../scheduler/scheduler.js';
|
||||
import { GeminiEventType, CompressionStatus } from '../core/turn.js';
|
||||
import { ChatCompressionService } from '../services/chatCompressionService.js';
|
||||
import {
|
||||
MessageBusType,
|
||||
type ToolCallsUpdateMessage,
|
||||
} from '../confirmation-bus/types.js';
|
||||
import {
|
||||
CoreToolCallStatus,
|
||||
type ToolCallRequestInfo,
|
||||
} from '../scheduler/types.js';
|
||||
import { type ResumedSessionData } from '../services/chatRecordingService.js';
|
||||
|
||||
// Auto-mock the collaborators AgentSession constructs or talks to; concrete
// behaviour is supplied per test through the hand-built mocks below.
vi.mock('../core/client.js');
vi.mock('../scheduler/scheduler.js');
vi.mock('../services/chatCompressionService.js');

describe('AgentSession', () => {
  type FinishEvent = Extract<AgentEvent, { type: 'agent_finish' }>;

  let mockConfig: ReturnType<typeof makeFakeConfig>;
  let mockClient: {
    sendMessageStream: ReturnType<typeof vi.fn>;
    getChat: ReturnType<typeof vi.fn>;
    getCurrentSequenceModel: ReturnType<typeof vi.fn>;
    getHistory: ReturnType<typeof vi.fn>;
    resumeChat: ReturnType<typeof vi.fn>;
  };
  let mockScheduler: {
    schedule: ReturnType<typeof vi.fn>;
  };
  let mockCompressionService: {
    compress: ReturnType<typeof vi.fn>;
  };
  let session: AgentSession;
  const agentConfig: AgentConfig = {
    name: 'TestAgent',
    capabilities: { compression: true },
  };

  // ---------------------------------------------------------------------
  // Local helpers: keep each test focused on the scenario, not the plumbing.
  // ---------------------------------------------------------------------

  /** A model Content event carrying `value`. */
  const content = (value: string) => ({
    type: GeminiEventType.Content,
    value,
  });

  /** A model ToolCallRequest event with empty args. */
  const toolRequest = (callId: string, name: string) => ({
    type: GeminiEventType.ToolCallRequest,
    value: { callId, name, args: {} },
  });

  /**
   * Builds a mock model stream implementation that yields `streamEvents` and
   * then a Finished event with reason 'STOP'.
   */
  const modelTurn = (...streamEvents: unknown[]) =>
    async function* () {
      yield* streamEvents;
      yield { type: GeminiEventType.Finished, value: { reason: 'STOP' } };
    };

  /** Drains an event stream into an array. */
  const collect = async (
    stream: AsyncIterable<AgentEvent>,
  ): Promise<AgentEvent[]> => {
    const events: AgentEvent[] = [];
    for await (const event of stream) {
      events.push(event);
    }
    return events;
  };

  /** Makes the scheduler answer every requested call with an empty response. */
  const respondToEveryCall = () => {
    mockScheduler.schedule.mockImplementation(async (calls) =>
      (calls as ToolCallRequestInfo[]).map((c) => ({
        response: { callId: c.callId, responseParts: [] },
      })),
    );
  };

  /** A TOOL_CALLS_UPDATE bus message for the single call 'call1'. */
  const toolUpdate = (
    schedulerId: string,
    status: CoreToolCallStatus,
    extra: Record<string, unknown> = {},
  ): ToolCallsUpdateMessage => ({
    type: MessageBusType.TOOL_CALLS_UPDATE,
    schedulerId,
    toolCalls: [
      {
        request: { callId: 'call1', name: 'test_tool', args: {} },
        status,
        ...extra,
        schedulerId,
      } as unknown as ToolCallsUpdateMessage['toolCalls'][number],
    ],
  });

  beforeEach(() => {
    vi.clearAllMocks();
    mockConfig = makeFakeConfig();

    mockClient = {
      sendMessageStream: vi.fn(),
      getChat: vi.fn().mockReturnValue({
        recordCompletedToolCalls: vi.fn(),
        setHistory: vi.fn(),
        getHistory: vi.fn().mockReturnValue([]),
      }),
      getCurrentSequenceModel: vi.fn().mockReturnValue('test-model'),
      getHistory: vi.fn().mockReturnValue([]),
      resumeChat: vi.fn(),
    };

    mockScheduler = {
      schedule: vi.fn(),
    };

    mockCompressionService = {
      compress: vi.fn().mockResolvedValue({
        newHistory: null,
        info: { compressionStatus: CompressionStatus.NOOP },
      }),
    };

    vi.spyOn(mockConfig, 'getGeminiClient').mockReturnValue(
      mockClient as unknown as import('../core/client.js').GeminiClient,
    );
    // The mocked Scheduler echoes the schedulerId it was constructed with so
    // tests can read it back.
    vi.mocked(Scheduler).mockImplementation(
      (options) =>
        ({
          ...mockScheduler,
          schedulerId: (options as { schedulerId: string }).schedulerId,
        }) as unknown as Scheduler,
    );
    vi.mocked(ChatCompressionService).mockImplementation(
      () => mockCompressionService as unknown as ChatCompressionService,
    );

    session = new AgentSession('test-session', agentConfig, mockConfig);
  });

  it('should emit agent_start and agent_finish', async () => {
    mockClient.sendMessageStream.mockImplementation(
      modelTurn(content('Hello')),
    );

    const events = await collect(session.prompt('Hi'));

    const finishEvent = events[events.length - 1] as FinishEvent;
    expect(events[0].type).toBe('agent_start');
    expect(finishEvent.type).toBe('agent_finish');
    expect(finishEvent.value.reason).toBe(AgentTerminateMode.GOAL);
    expect(mockClient.sendMessageStream).toHaveBeenCalled();
  });

  it('should handle tool calls and execute them via MessageBus updates', async () => {
    mockClient.sendMessageStream
      .mockImplementationOnce(modelTurn(toolRequest('call1', 'test_tool')))
      .mockImplementationOnce(modelTurn(content('Tool executed')));

    const toolResponse = {
      response: {
        callId: 'call1',
        responseParts: [
          {
            functionResponse: {
              name: 'test_tool',
              response: { ok: true },
              id: 'call1',
            },
          },
        ],
      },
    };

    // Simulate the scheduler reporting Executing -> Success on the bus before
    // resolving with the completed call.
    mockScheduler.schedule.mockImplementation(async () => {
      const bus = mockConfig.getMessageBus();
      const schedulerId = (session as unknown as { schedulerId: string })
        .schedulerId;

      await bus.publish(toolUpdate(schedulerId, CoreToolCallStatus.Executing));
      await bus.publish(
        toolUpdate(schedulerId, CoreToolCallStatus.Success, {
          response: toolResponse.response,
        }),
      );

      return [toolResponse];
    });

    const events = await collect(session.prompt('Run tool'));

    expect(mockClient.sendMessageStream).toHaveBeenCalledTimes(2);
    expect(mockScheduler.schedule).toHaveBeenCalledTimes(1);

    const callStart = events.find((e) => e.type === 'tool_call_start');
    const callFinish = events.find((e) => e.type === 'tool_call_finish');
    expect(callStart).toBeDefined();
    expect(callFinish).toBeDefined();
    expect(
      (callFinish as Extract<AgentEvent, { type: 'tool_call_finish' }>).value
        .callId,
    ).toBe('call1');
  });

  it('should handle multiple consecutive ReAct turns', async () => {
    mockClient.sendMessageStream
      // Turn 1: tool1
      .mockImplementationOnce(modelTurn(toolRequest('c1', 'tool1')))
      // Turn 2: tool2
      .mockImplementationOnce(modelTurn(toolRequest('c2', 'tool2')))
      // Turn 3: final content
      .mockImplementationOnce(modelTurn(content('All done')));
    respondToEveryCall();

    const events = await collect(session.prompt('Start multistep'));

    expect(mockClient.sendMessageStream).toHaveBeenCalledTimes(3);
    expect(mockScheduler.schedule).toHaveBeenCalledTimes(2);
    expect(events.filter((e) => e.type === 'tool_suite_start')).toHaveLength(2);
  });

  it('should handle parallel tool calls in a single turn', async () => {
    mockClient.sendMessageStream
      .mockImplementationOnce(
        modelTurn(toolRequest('p1', 'toolA'), toolRequest('p2', 'toolB')),
      )
      .mockImplementationOnce(modelTurn(content('Parallel done')));
    respondToEveryCall();

    const events = await collect(session.prompt('Parallel'));

    const suiteStart = events.find(
      (e) => e.type === 'tool_suite_start',
    ) as Extract<AgentEvent, { type: 'tool_suite_start' }>;
    expect(suiteStart.value.count).toBe(2);
    expect(mockScheduler.schedule).toHaveBeenCalledTimes(1);
    expect(mockScheduler.schedule).toHaveBeenCalledWith(
      expect.arrayContaining([
        expect.objectContaining({ callId: 'p1' }),
        expect.objectContaining({ callId: 'p2' }),
      ]),
      expect.anything(),
    );
  });

  it('should resume from session data', async () => {
    const resumeData = {
      conversation: {
        messages: [{ type: 'user', content: 'Hello' }],
      },
    } as unknown as ResumedSessionData;

    await session.resume(resumeData);

    expect(mockClient.resumeChat).toHaveBeenCalledWith(
      expect.arrayContaining([
        expect.objectContaining({ role: 'user', parts: [{ text: 'Hello' }] }),
      ]),
      resumeData,
    );
  });

  it('should handle model stream errors gracefully', async () => {
    mockClient.sendMessageStream.mockImplementationOnce(async function* () {
      yield* []; // satisfy require-yield
      throw new Error('Model connection lost');
    });

    const events: AgentEvent[] = [];
    let caught: unknown;
    try {
      for await (const event of session.prompt('Error test')) {
        events.push(event);
      }
    } catch (e) {
      caught = e;
    }

    // The failure must reach the caller instead of being silently dropped...
    expect(caught).toBeInstanceOf(Error);
    expect((caught as Error).message).toBe('Model connection lost');
    // ...and the stream must still be closed out by agent_finish.
    expect(events[0].type).toBe('agent_start');
    expect(events[events.length - 1].type).toBe('agent_finish');
    expect(events.filter((e) => e.type === 'agent_finish')).toHaveLength(1);
  });

  it('should ignore MessageBus updates from other schedulers', async () => {
    mockClient.sendMessageStream
      .mockImplementationOnce(modelTurn(toolRequest('call1', 'test_tool')))
      .mockImplementationOnce(modelTurn(content('Done')));

    mockScheduler.schedule.mockImplementation(async () => {
      // Update from ANOTHER scheduler
      await mockConfig
        .getMessageBus()
        .publish(
          toolUpdate('different-scheduler', CoreToolCallStatus.Executing),
        );

      return [{ response: { callId: 'call1', responseParts: [] } }];
    });

    const events = await collect(session.prompt('Isolation test'));

    // Should NOT see tool_call_start because it was from a different schedulerId
    expect(events.find((e) => e.type === 'tool_call_start')).toBeUndefined();
  });

  it('should terminate with LOOP when loop is detected by model', async () => {
    mockClient.sendMessageStream.mockImplementationOnce(
      modelTurn({ type: GeminiEventType.LoopDetected }),
    );

    const events = await collect(session.prompt('Loop'));

    const finishEvent = events.find(
      (e) => e.type === 'agent_finish',
    ) as FinishEvent;
    expect(finishEvent.value.reason).toBe(AgentTerminateMode.LOOP);
    expect(finishEvent.value.message).toContain('Loop detected');
  });

  it('should handle Part[] input correctly', async () => {
    mockClient.sendMessageStream.mockImplementationOnce(
      modelTurn(content('I see parts')),
    );

    const input = [{ text: 'Hello' }, { text: 'World' }];
    await collect(session.prompt(input));

    // Parts are forwarded untouched, both as the request and as display content.
    expect(mockClient.sendMessageStream).toHaveBeenCalledWith(
      input,
      expect.anything(),
      expect.anything(),
      undefined,
      false,
      input,
    );
  });

  it('should trigger compression if enabled', async () => {
    mockClient.sendMessageStream.mockImplementation(modelTurn(content('Done')));

    // Consume the stream so the loop (and with it the compression check) runs.
    await collect(session.prompt('Compress me'));

    expect(mockCompressionService.compress).toHaveBeenCalled();
  });

  it('should respect abort signal', async () => {
    const controller = new AbortController();
    // Abort in the middle of the model stream.
    mockClient.sendMessageStream.mockImplementation(async function* () {
      yield content('Thinking...');
      controller.abort();
      yield content('Still thinking...');
    });

    const events = await collect(
      session.prompt('Long task', controller.signal),
    );

    const finishEvent = events[events.length - 1] as FinishEvent;
    expect(finishEvent.type).toBe('agent_finish');
    expect(finishEvent.value.reason).toBe(AgentTerminateMode.ABORTED);
  });

  it('should respect maxTurns from config', async () => {
    const customSession = new AgentSession(
      'test-session-2',
      { ...agentConfig, maxTurns: 2 },
      mockConfig,
    );

    // The model asks for a tool on every turn, so only the turn cap can stop it.
    mockClient.sendMessageStream.mockImplementation(
      modelTurn(toolRequest('call', 'test_tool')),
    );

    mockScheduler.schedule.mockResolvedValue([
      {
        response: {
          callId: 'call',
          responseParts: [
            {
              functionResponse: {
                name: 'test_tool',
                response: { ok: true },
                id: 'call',
              },
            },
          ],
        },
      },
    ]);

    const events = await collect(customSession.prompt('Start loop'));

    expect(mockScheduler.schedule).toHaveBeenCalledTimes(2);

    const finishEvent = events[events.length - 1] as FinishEvent;
    expect(finishEvent.type).toBe('agent_finish');
    expect(finishEvent.value.totalTurns).toBe(2);
    expect(finishEvent.value.reason).toBe(AgentTerminateMode.MAX_TURNS);
  });
});
|
||||
@@ -0,0 +1,384 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { type Part } from '@google/genai';
|
||||
import { type Config } from '../config/config.js';
|
||||
import { type GeminiClient } from '../core/client.js';
|
||||
import { type AgentEvent, type AgentConfig } from './types.js';
|
||||
import { Scheduler } from '../scheduler/scheduler.js';
|
||||
import {
|
||||
type ToolCallRequestInfo,
|
||||
type ToolCallResponseInfo,
|
||||
CoreToolCallStatus,
|
||||
} from '../scheduler/types.js';
|
||||
import { GeminiEventType, CompressionStatus } from '../core/turn.js';
|
||||
import { recordToolCallInteractions } from '../code_assist/telemetry.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import { ToolErrorType } from '../tools/tool-error.js';
|
||||
import { ChatCompressionService } from '../services/chatCompressionService.js';
|
||||
import { AgentTerminateMode } from './types.js';
|
||||
import type { ResumedSessionData } from '../services/chatRecordingService.js';
|
||||
import { convertSessionToClientHistory } from '../utils/sessionUtils.js';
|
||||
import {
|
||||
MessageBusType,
|
||||
type ToolCallsUpdateMessage,
|
||||
} from '../confirmation-bus/types.js';
|
||||
|
||||
/**
|
||||
* AgentSession manages the state of a conversation and orchestrates the agent
|
||||
* loop.
|
||||
*/
|
||||
export class AgentSession {
|
||||
private readonly client: GeminiClient;
|
||||
private readonly scheduler: Scheduler;
|
||||
private readonly schedulerId: string;
|
||||
private readonly compressionService: ChatCompressionService;
|
||||
private totalTurns = 0;
|
||||
private hasFailedCompressionAttempt = false;
|
||||
|
||||
constructor(
|
||||
private readonly sessionId: string,
|
||||
private readonly config: AgentConfig,
|
||||
private readonly runtime: Config,
|
||||
) {
|
||||
this.client = this.runtime.getGeminiClient();
|
||||
this.schedulerId = `agent-scheduler-${this.sessionId}-${Math.random().toString(36).substring(2, 9)}`;
|
||||
this.scheduler = new Scheduler({
|
||||
config: this.runtime,
|
||||
messageBus: this.runtime.getMessageBus(),
|
||||
getPreferredEditor: () => undefined,
|
||||
schedulerId: this.schedulerId,
|
||||
});
|
||||
this.compressionService = new ChatCompressionService();
|
||||
}
|
||||
|
||||
/**
|
||||
* Resumes the agent session from persistent storage data.
|
||||
* Hydrates the internal language model client with the previously saved trajectory.
|
||||
*
|
||||
* @param resumedSessionData The raw payload of a previously saved session.
|
||||
*/
|
||||
async resume(resumedSessionData: ResumedSessionData): Promise<void> {
|
||||
const clientHistory = convertSessionToClientHistory(
|
||||
resumedSessionData.conversation.messages,
|
||||
);
|
||||
await this.client.resumeChat(clientHistory, resumedSessionData);
|
||||
}
|
||||
|
||||
/**
|
||||
* Executes the ReAct loop for a given user input.
|
||||
* Returns an AsyncIterable of events occurring during the session.
|
||||
*/
|
||||
async *prompt(
|
||||
input: string | Part[],
|
||||
signal?: AbortSignal,
|
||||
): AsyncIterable<AgentEvent> {
|
||||
yield {
|
||||
type: 'agent_start',
|
||||
value: { sessionId: this.sessionId },
|
||||
};
|
||||
|
||||
let currentInput = input;
|
||||
let isContinuation = false;
|
||||
const maxTurns = this.config.maxTurns ?? -1;
|
||||
|
||||
let terminationReason = AgentTerminateMode.GOAL;
|
||||
let terminationMessage: string | undefined = undefined;
|
||||
let terminationError: unknown | undefined = undefined;
|
||||
|
||||
try {
|
||||
while (maxTurns === -1 || this.totalTurns < maxTurns) {
|
||||
if (signal?.aborted) {
|
||||
terminationReason = AgentTerminateMode.ABORTED;
|
||||
break;
|
||||
}
|
||||
|
||||
this.totalTurns++;
|
||||
const promptId = `${this.sessionId}#${this.totalTurns}`;
|
||||
|
||||
// Compression check (from LocalAgentExecutor / useGeminiStream patterns)
|
||||
if (this.config.capabilities?.compression) {
|
||||
await this.tryCompressChat(promptId);
|
||||
}
|
||||
|
||||
const results = await this.runModelTurn(
|
||||
currentInput,
|
||||
promptId,
|
||||
isContinuation ? undefined : input,
|
||||
signal,
|
||||
);
|
||||
|
||||
for await (const event of results.events) {
|
||||
yield event;
|
||||
}
|
||||
|
||||
if (results.loopDetected) {
|
||||
terminationReason = AgentTerminateMode.LOOP;
|
||||
terminationMessage = 'Loop detected, stopping execution';
|
||||
break;
|
||||
}
|
||||
|
||||
if (signal?.aborted) {
|
||||
terminationReason = AgentTerminateMode.ABORTED;
|
||||
break;
|
||||
}
|
||||
|
||||
if (results.toolCalls.length > 0) {
|
||||
const toolRun = this.executeTools(results.toolCalls, signal);
|
||||
let resultsTools;
|
||||
while (true) {
|
||||
const { value, done } = await toolRun.next();
|
||||
if (done) {
|
||||
resultsTools = value;
|
||||
break;
|
||||
}
|
||||
yield value;
|
||||
}
|
||||
|
||||
if (resultsTools.stopExecution || (signal && signal.aborted)) {
|
||||
if (signal && signal.aborted) {
|
||||
terminationReason = AgentTerminateMode.ABORTED;
|
||||
} else if (resultsTools.stopExecutionInfo) {
|
||||
terminationReason = AgentTerminateMode.ERROR;
|
||||
terminationMessage =
|
||||
resultsTools.stopExecutionInfo.error?.message;
|
||||
terminationError = resultsTools.stopExecutionInfo.error;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
// Check if we hit the turn limit
|
||||
if (maxTurns !== -1 && this.totalTurns >= maxTurns) {
|
||||
terminationReason = AgentTerminateMode.MAX_TURNS;
|
||||
terminationMessage = 'Maximum session turns exceeded.';
|
||||
break;
|
||||
}
|
||||
|
||||
currentInput = resultsTools.nextParts;
|
||||
isContinuation = true;
|
||||
} else {
|
||||
// No more tool calls, turn is complete.
|
||||
terminationReason = AgentTerminateMode.GOAL;
|
||||
break;
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
yield {
|
||||
type: 'agent_finish',
|
||||
value: {
|
||||
sessionId: this.sessionId,
|
||||
totalTurns: this.totalTurns,
|
||||
reason: terminationReason,
|
||||
message: terminationMessage,
|
||||
error: terminationError,
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Calls the model and yields the event stream.
|
||||
* Collects tool call requests for the next phase.
|
||||
*/
|
||||
private async runModelTurn(
|
||||
input: string | Part[],
|
||||
promptId: string,
|
||||
displayContent?: string | Part[],
|
||||
signal?: AbortSignal,
|
||||
): Promise<{
|
||||
toolCalls: ToolCallRequestInfo[];
|
||||
events: AsyncIterable<AgentEvent>;
|
||||
loopDetected: boolean;
|
||||
}> {
|
||||
const parts = Array.isArray(input) ? input : [{ text: input }];
|
||||
const toolCalls: ToolCallRequestInfo[] = [];
|
||||
let loopDetected = false;
|
||||
|
||||
const stream = this.client.sendMessageStream(
|
||||
parts,
|
||||
signal ?? new AbortController().signal,
|
||||
promptId,
|
||||
undefined, // maxTurns (client handles its own)
|
||||
false, // isInvalidStreamRetry
|
||||
displayContent,
|
||||
);
|
||||
|
||||
const eventGenerator = async function* (): AsyncIterable<AgentEvent> {
|
||||
for await (const event of stream) {
|
||||
if (event.type === GeminiEventType.ToolCallRequest) {
|
||||
toolCalls.push(event.value);
|
||||
} else if (event.type === GeminiEventType.LoopDetected) {
|
||||
loopDetected = true;
|
||||
}
|
||||
yield event as AgentEvent;
|
||||
}
|
||||
};
|
||||
|
||||
const events = eventGenerator();
|
||||
|
||||
return {
|
||||
toolCalls,
|
||||
events,
|
||||
get loopDetected() {
|
||||
return loopDetected;
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Executes a batch of tool calls via the Scheduler.
|
||||
*/
|
||||
private async *executeTools(
|
||||
toolCalls: ToolCallRequestInfo[],
|
||||
signal?: AbortSignal,
|
||||
): AsyncGenerator<
|
||||
AgentEvent,
|
||||
{
|
||||
nextParts: Part[];
|
||||
stopExecution: boolean;
|
||||
stopExecutionInfo: ToolCallResponseInfo | undefined;
|
||||
}
|
||||
> {
|
||||
yield {
|
||||
type: 'tool_suite_start',
|
||||
value: { count: toolCalls.length },
|
||||
};
|
||||
|
||||
const eventQueue: AgentEvent[] = [];
|
||||
let resolveNext: (() => void) | undefined;
|
||||
let isFinished = false;
|
||||
|
||||
// Track seen status transitions to avoid duplicate events
|
||||
const seenStatuses = new Map<string, CoreToolCallStatus>();
|
||||
|
||||
const messageBus = this.runtime.getMessageBus();
|
||||
const onToolUpdate = (message: ToolCallsUpdateMessage) => {
|
||||
if (message.schedulerId !== this.schedulerId) return;
|
||||
|
||||
for (const call of message.toolCalls) {
|
||||
const prevStatus = seenStatuses.get(call.request.callId);
|
||||
if (prevStatus === call.status) continue;
|
||||
|
||||
if (call.status === CoreToolCallStatus.Executing) {
|
||||
eventQueue.push({ type: 'tool_call_start', value: call.request });
|
||||
} else if (
|
||||
call.status === CoreToolCallStatus.Success ||
|
||||
call.status === CoreToolCallStatus.Error ||
|
||||
call.status === CoreToolCallStatus.Cancelled
|
||||
) {
|
||||
eventQueue.push({
|
||||
type: 'tool_call_finish',
|
||||
value: call.response,
|
||||
});
|
||||
}
|
||||
seenStatuses.set(call.request.callId, call.status);
|
||||
}
|
||||
resolveNext?.();
|
||||
};
|
||||
|
||||
messageBus.subscribe(MessageBusType.TOOL_CALLS_UPDATE, onToolUpdate);
|
||||
|
||||
const schedulePromise = this.scheduler.schedule(
|
||||
toolCalls,
|
||||
signal ?? new AbortController().signal,
|
||||
);
|
||||
|
||||
try {
|
||||
while (!isFinished || eventQueue.length > 0) {
|
||||
if (eventQueue.length > 0) {
|
||||
const event = eventQueue.shift();
|
||||
if (event) yield event;
|
||||
} else {
|
||||
const waitNext = new Promise<void>((resolve) => {
|
||||
resolveNext = resolve;
|
||||
});
|
||||
await Promise.race([
|
||||
waitNext,
|
||||
schedulePromise.then(() => {
|
||||
isFinished = true;
|
||||
resolveNext?.();
|
||||
}),
|
||||
]);
|
||||
}
|
||||
}
|
||||
|
||||
const completedCalls = await schedulePromise;
|
||||
|
||||
yield {
|
||||
type: 'tool_suite_finish',
|
||||
value: { responses: completedCalls.map((c) => c.response) },
|
||||
};
|
||||
|
||||
// Record tool call info for persistence/telemetry
|
||||
try {
|
||||
const currentModel =
|
||||
this.client.getCurrentSequenceModel() ?? this.runtime.getModel();
|
||||
this.client
|
||||
.getChat()
|
||||
.recordCompletedToolCalls(currentModel, completedCalls);
|
||||
await recordToolCallInteractions(this.runtime, completedCalls);
|
||||
} catch (e) {
|
||||
debugLogger.warn(`Error recording tool call information: ${e}`);
|
||||
}
|
||||
|
||||
const nextParts = completedCalls.flatMap((c) => c.response.responseParts);
|
||||
const stopExecutionInfo = completedCalls.find(
|
||||
(c) => c.response.errorType === ToolErrorType.STOP_EXECUTION,
|
||||
)?.response;
|
||||
|
||||
return {
|
||||
nextParts,
|
||||
stopExecution: !!stopExecutionInfo,
|
||||
stopExecutionInfo,
|
||||
};
|
||||
} finally {
|
||||
messageBus.unsubscribe(MessageBusType.TOOL_CALLS_UPDATE, onToolUpdate);
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Asks the compression service to compress the current chat and applies the
 * reported outcome to this session.
 *
 * Outcome handling:
 * - `COMPRESSION_FAILED_INFLATED_TOKEN_COUNT`: remembers the failure in
 *   `hasFailedCompressionAttempt`, which is handed back to the service on
 *   subsequent calls.
 * - `COMPRESSED` with a new history: swaps the chat history for the new one
 *   and clears the failure flag.
 * - Anything else (including `COMPRESSED` without a history): no change.
 *
 * @param promptId Prompt identifier forwarded to the compression service.
 */
private async tryCompressChat(promptId: string): Promise<void> {
  const activeChat = this.client.getChat();
  // The agent-level model override wins; otherwise use the runtime's model.
  const compressionModel = this.config.model ?? this.runtime.getModel();

  const { newHistory: compressedHistory, info: compressionInfo } =
    await this.compressionService.compress(
      activeChat,
      promptId,
      // NOTE(review): presumably the "force" flag — confirm against compress().
      false,
      compressionModel,
      this.runtime,
      this.hasFailedCompressionAttempt,
    );

  switch (compressionInfo.compressionStatus) {
    case CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT:
      this.hasFailedCompressionAttempt = true;
      break;

    case CompressionStatus.COMPRESSED:
      // A compressed status without a history leaves everything as it was.
      if (!compressedHistory) {
        break;
      }
      activeChat.setHistory(compressedHistory);
      this.hasFailedCompressionAttempt = false;
      break;

    default:
      // Every other status requires no bookkeeping.
      break;
  }
}
|
||||
|
||||
/**
 * Exposes the message history held by this session's client.
 *
 * @returns Whatever `this.client.getHistory()` returns at the time of the call.
 */
getHistory() {
  const history = this.client.getHistory();
  return history;
}
|
||||
|
||||
/**
 * Identifies this session.
 *
 * @returns The `sessionId` stored on this instance.
 */
getSessionId(): string {
  const { sessionId } = this;
  return sessionId;
}
|
||||
}
|
||||
@@ -4,16 +4,66 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
/**
|
||||
* @fileoverview Defines the core configuration interfaces and types for the agent architecture.
|
||||
*/
|
||||
|
||||
import type { Content, FunctionDeclaration } from '@google/genai';
|
||||
import type { AnyDeclarativeTool } from '../tools/tools.js';
|
||||
import { type z } from 'zod';
|
||||
import type { ModelConfig } from '../services/modelConfigService.js';
|
||||
import type { AnySchema } from 'ajv';
|
||||
import type { A2AAuthConfig } from './auth-provider/types.js';
|
||||
import { type ServerGeminiStreamEvent } from '../core/turn.js';
|
||||
import {
|
||||
type ToolCallResponseInfo,
|
||||
type ToolCallRequestInfo,
|
||||
} from '../scheduler/types.js';
|
||||
|
||||
/**
 * Every event that can flow out of the agent loop.
 *
 * The union combines the base Gemini stream events with agent-level events.
 * Each agent-level member is tagged by a string-literal `type` and carries its
 * payload in `value`, so consumers can narrow on `type`.
 */
export type AgentEvent =
  // Events defined by the underlying Gemini stream.
  | ServerGeminiStreamEvent
  // Start of an agent run for the given session.
  | {
      type: 'agent_start';
      value: { sessionId: string };
    }
  // End of an agent run, with the termination reason and optional details.
  | {
      type: 'agent_finish';
      value: {
        sessionId: string;
        totalTurns: number;
        reason: AgentTerminateMode;
        message?: string;
        error?: unknown;
      };
    }
  // A batch of tool calls is starting.
  // NOTE(review): `count` is presumably the number of calls in the batch — confirm at the emit site.
  | {
      type: 'tool_suite_start';
      value: { count: number };
    }
  // A batch of tool calls has finished; carries the response of each call.
  | {
      type: 'tool_suite_finish';
      value: { responses: ToolCallResponseInfo[] };
    }
  // A single tool call has started; carries its request info.
  | {
      type: 'tool_call_start';
      value: ToolCallRequestInfo;
    }
  // A single tool call has finished; carries its response info.
  | {
      type: 'tool_call_finish';
      value: ToolCallResponseInfo;
    }
  // Free-form thought text.
  // NOTE(review): confirm the 'thought' tag does not collide with a `type` already used by ServerGeminiStreamEvent.
  | {
      type: 'thought';
      value: string;
    }
  // Loop detection fired for the given session.
  | {
      type: 'loop_detected';
      value: { sessionId: string };
    };
|
||||
|
||||
/**
 * Static settings that describe a single Agent.
 *
 * Only `name` is required; every other field is an optional override or
 * feature toggle.
 */
export interface AgentConfig {
  /** Name of the agent. */
  name: string;

  /** System instruction (personality/rules) given to the agent. */
  systemInstruction?: string;

  /** Model override; when absent the caller decides which model to use. */
  model?: string;

  /**
   * Upper bound on the number of conversational turns.
   * Use -1 for no limit; -1 is also the default when the field is omitted.
   */
  maxTurns?: number;

  /**
   * Optional capability toggles for this agent.
   * NOTE(review): the exact effect and default of each flag is not visible
   * here — confirm against the session implementation.
   */
  capabilities?: {
    /** Presumably enables chat-history compression. */
    compression?: boolean;
    /** Presumably enables loop detection. */
    loopDetection?: boolean;
    /** Presumably enables IDE context for the agent. */
    ideContext?: boolean;
  };
}
|
||||
|
||||
/**
|
||||
* Describes the possible termination modes for an agent.
|
||||
@@ -24,6 +74,7 @@ export enum AgentTerminateMode {
|
||||
GOAL = 'GOAL',
|
||||
MAX_TURNS = 'MAX_TURNS',
|
||||
ABORTED = 'ABORTED',
|
||||
LOOP = 'LOOP',
|
||||
ERROR_NO_COMPLETE_TASK_CALL = 'ERROR_NO_COMPLETE_TASK_CALL',
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user