mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-06-12 12:26:57 -07:00
feat(core): enhance AgentSession with system instruction support and internal cancellation
This commit is contained in:
@@ -24,6 +24,7 @@ import {
|
||||
type ToolCallRequestInfo,
|
||||
} from '../scheduler/types.js';
|
||||
import { type ResumedSessionData } from '../services/chatRecordingService.js';
|
||||
import { type GeminiClient } from '../core/client.js';
|
||||
|
||||
vi.mock('../core/client.js');
|
||||
vi.mock('../scheduler/scheduler.js');
|
||||
@@ -33,6 +34,8 @@ describe('AgentSession', () => {
|
||||
let mockConfig: ReturnType<typeof makeFakeConfig>;
|
||||
let mockClient: {
|
||||
sendMessageStream: ReturnType<typeof vi.fn>;
|
||||
isInitialized: ReturnType<typeof vi.fn>;
|
||||
initialize: ReturnType<typeof vi.fn>;
|
||||
getChat: ReturnType<typeof vi.fn>;
|
||||
getCurrentSequenceModel: ReturnType<typeof vi.fn>;
|
||||
getHistory: ReturnType<typeof vi.fn>;
|
||||
@@ -56,10 +59,13 @@ describe('AgentSession', () => {
|
||||
|
||||
mockClient = {
|
||||
sendMessageStream: vi.fn(),
|
||||
isInitialized: vi.fn().mockReturnValue(false),
|
||||
initialize: vi.fn().mockResolvedValue(undefined),
|
||||
getChat: vi.fn().mockReturnValue({
|
||||
recordCompletedToolCalls: vi.fn(),
|
||||
setHistory: vi.fn(),
|
||||
getHistory: vi.fn().mockReturnValue([]),
|
||||
setSystemInstruction: vi.fn(),
|
||||
}),
|
||||
getCurrentSequenceModel: vi.fn().mockReturnValue('test-model'),
|
||||
getHistory: vi.fn().mockReturnValue([]),
|
||||
@@ -78,7 +84,7 @@ describe('AgentSession', () => {
|
||||
};
|
||||
|
||||
vi.spyOn(mockConfig, 'getGeminiClient').mockReturnValue(
|
||||
mockClient as unknown as import('../core/client.js').GeminiClient,
|
||||
mockClient as unknown as GeminiClient,
|
||||
);
|
||||
vi.mocked(Scheduler).mockImplementation(
|
||||
(options) =>
|
||||
@@ -428,6 +434,53 @@ describe('AgentSession', () => {
|
||||
expect(finishEvent.value.reason).toBe(AgentTerminateMode.ABORTED);
|
||||
});
|
||||
|
||||
it('should apply systemInstruction from AgentConfig', async () => {
|
||||
const customConfig: AgentConfig = {
|
||||
...agentConfig,
|
||||
systemInstruction: 'You are a helpful assistant.',
|
||||
};
|
||||
|
||||
// Mock isInitialized to true so constructor can set it
|
||||
mockClient.isInitialized.mockReturnValue(true);
|
||||
const mockChat = { setSystemInstruction: vi.fn() };
|
||||
mockClient.getChat.mockReturnValue(mockChat);
|
||||
|
||||
// Re-create to trigger constructor logic
|
||||
new AgentSession('test-session-3', customConfig, mockConfig);
|
||||
expect(mockChat.setSystemInstruction).toHaveBeenCalledWith(
|
||||
'You are a helpful assistant.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should abort internal operations if caller stops iterating', async () => {
|
||||
let internalSignal: AbortSignal | undefined;
|
||||
|
||||
mockClient.sendMessageStream.mockImplementation(async function* (
|
||||
_parts: unknown,
|
||||
signal: AbortSignal,
|
||||
) {
|
||||
internalSignal = signal;
|
||||
yield { type: GeminiEventType.Content, value: 'Part 1' };
|
||||
yield { type: GeminiEventType.Content, value: 'Part 2' };
|
||||
});
|
||||
|
||||
const promptStream = session.prompt('Test cancellation');
|
||||
const iterator = promptStream[Symbol.asyncIterator]();
|
||||
|
||||
const firstEvent = await iterator.next();
|
||||
expect(firstEvent.value?.type).toBe('agent_start');
|
||||
|
||||
const secondEvent = await iterator.next(); // content Part 1
|
||||
expect(secondEvent.value?.type).toBe(GeminiEventType.Content);
|
||||
|
||||
// Caller stops here and closes the generator
|
||||
if (iterator.return) {
|
||||
await iterator.return();
|
||||
}
|
||||
|
||||
expect(internalSignal?.aborted).toBe(true);
|
||||
});
|
||||
|
||||
it('should respect maxTurns from config', async () => {
|
||||
const customSession = new AgentSession(
|
||||
'test-session-2',
|
||||
|
||||
@@ -53,6 +53,15 @@ export class AgentSession {
|
||||
schedulerId: this.schedulerId,
|
||||
});
|
||||
this.compressionService = new ChatCompressionService();
|
||||
|
||||
// Ensure system instruction is set from AgentConfig
|
||||
if (this.config.systemInstruction) {
|
||||
if (this.client.isInitialized()) {
|
||||
this.client
|
||||
.getChat()
|
||||
.setSystemInstruction(this.config.systemInstruction);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -66,6 +75,11 @@ export class AgentSession {
|
||||
resumedSessionData.conversation.messages,
|
||||
);
|
||||
await this.client.resumeChat(clientHistory, resumedSessionData);
|
||||
|
||||
// Re-apply system instruction after resume since resume re-creates the chat
|
||||
if (this.config.systemInstruction) {
|
||||
this.client.getChat().setSystemInstruction(this.config.systemInstruction);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -76,6 +90,11 @@ export class AgentSession {
|
||||
input: string | Part[],
|
||||
signal?: AbortSignal,
|
||||
): AsyncIterable<AgentEvent> {
|
||||
const internalController = new AbortController();
|
||||
const combinedSignal = signal
|
||||
? AbortSignal.any([signal, internalController.signal])
|
||||
: internalController.signal;
|
||||
|
||||
yield {
|
||||
type: 'agent_start',
|
||||
value: { sessionId: this.sessionId },
|
||||
@@ -91,7 +110,7 @@ export class AgentSession {
|
||||
|
||||
try {
|
||||
while (maxTurns === -1 || this.totalTurns < maxTurns) {
|
||||
if (signal?.aborted) {
|
||||
if (combinedSignal.aborted) {
|
||||
terminationReason = AgentTerminateMode.ABORTED;
|
||||
break;
|
||||
}
|
||||
@@ -108,7 +127,7 @@ export class AgentSession {
|
||||
currentInput,
|
||||
promptId,
|
||||
isContinuation ? undefined : input,
|
||||
signal,
|
||||
combinedSignal,
|
||||
);
|
||||
|
||||
for await (const event of results.events) {
|
||||
@@ -121,13 +140,13 @@ export class AgentSession {
|
||||
break;
|
||||
}
|
||||
|
||||
if (signal?.aborted) {
|
||||
if (combinedSignal.aborted) {
|
||||
terminationReason = AgentTerminateMode.ABORTED;
|
||||
break;
|
||||
}
|
||||
|
||||
if (results.toolCalls.length > 0) {
|
||||
const toolRun = this.executeTools(results.toolCalls, signal);
|
||||
const toolRun = this.executeTools(results.toolCalls, combinedSignal);
|
||||
let resultsTools;
|
||||
while (true) {
|
||||
const { value, done } = await toolRun.next();
|
||||
@@ -138,8 +157,8 @@ export class AgentSession {
|
||||
yield value;
|
||||
}
|
||||
|
||||
if (resultsTools.stopExecution || (signal && signal.aborted)) {
|
||||
if (signal && signal.aborted) {
|
||||
if (resultsTools.stopExecution || combinedSignal.aborted) {
|
||||
if (combinedSignal.aborted) {
|
||||
terminationReason = AgentTerminateMode.ABORTED;
|
||||
} else if (resultsTools.stopExecutionInfo) {
|
||||
terminationReason = AgentTerminateMode.ERROR;
|
||||
@@ -166,6 +185,7 @@ export class AgentSession {
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
internalController.abort();
|
||||
yield {
|
||||
type: 'agent_finish',
|
||||
value: {
|
||||
@@ -197,6 +217,17 @@ export class AgentSession {
|
||||
const toolCalls: ToolCallRequestInfo[] = [];
|
||||
let loopDetected = false;
|
||||
|
||||
// Ensure client is initialized before sending message
|
||||
if (!this.client.isInitialized()) {
|
||||
await this.client.initialize();
|
||||
// Re-apply system instruction after initialization
|
||||
if (this.config.systemInstruction) {
|
||||
this.client
|
||||
.getChat()
|
||||
.setSystemInstruction(this.config.systemInstruction);
|
||||
}
|
||||
}
|
||||
|
||||
const stream = this.client.sendMessageStream(
|
||||
parts,
|
||||
signal ?? new AbortController().signal,
|
||||
|
||||
Reference in New Issue
Block a user