feat(core): enhance AgentSession with system instruction support and internal cancellation

This commit is contained in:
Abhi
2026-02-23 21:07:26 -05:00
parent b826256f3e
commit 805a83de3f
2 changed files with 91 additions and 7 deletions
+54 -1
View File
@@ -24,6 +24,7 @@ import {
type ToolCallRequestInfo,
} from '../scheduler/types.js';
import { type ResumedSessionData } from '../services/chatRecordingService.js';
import { type GeminiClient } from '../core/client.js';
vi.mock('../core/client.js');
vi.mock('../scheduler/scheduler.js');
@@ -33,6 +34,8 @@ describe('AgentSession', () => {
let mockConfig: ReturnType<typeof makeFakeConfig>;
let mockClient: {
sendMessageStream: ReturnType<typeof vi.fn>;
isInitialized: ReturnType<typeof vi.fn>;
initialize: ReturnType<typeof vi.fn>;
getChat: ReturnType<typeof vi.fn>;
getCurrentSequenceModel: ReturnType<typeof vi.fn>;
getHistory: ReturnType<typeof vi.fn>;
@@ -56,10 +59,13 @@ describe('AgentSession', () => {
mockClient = {
sendMessageStream: vi.fn(),
isInitialized: vi.fn().mockReturnValue(false),
initialize: vi.fn().mockResolvedValue(undefined),
getChat: vi.fn().mockReturnValue({
recordCompletedToolCalls: vi.fn(),
setHistory: vi.fn(),
getHistory: vi.fn().mockReturnValue([]),
setSystemInstruction: vi.fn(),
}),
getCurrentSequenceModel: vi.fn().mockReturnValue('test-model'),
getHistory: vi.fn().mockReturnValue([]),
@@ -78,7 +84,7 @@ describe('AgentSession', () => {
};
vi.spyOn(mockConfig, 'getGeminiClient').mockReturnValue(
mockClient as unknown as import('../core/client.js').GeminiClient,
mockClient as unknown as GeminiClient,
);
vi.mocked(Scheduler).mockImplementation(
(options) =>
@@ -428,6 +434,53 @@ describe('AgentSession', () => {
expect(finishEvent.value.reason).toBe(AgentTerminateMode.ABORTED);
});
it('should apply systemInstruction from AgentConfig', async () => {
const customConfig: AgentConfig = {
...agentConfig,
systemInstruction: 'You are a helpful assistant.',
};
// Mock isInitialized to true so constructor can set it
mockClient.isInitialized.mockReturnValue(true);
const mockChat = { setSystemInstruction: vi.fn() };
mockClient.getChat.mockReturnValue(mockChat);
// Re-create to trigger constructor logic
new AgentSession('test-session-3', customConfig, mockConfig);
expect(mockChat.setSystemInstruction).toHaveBeenCalledWith(
'You are a helpful assistant.',
);
});
it('should abort internal operations if caller stops iterating', async () => {
let internalSignal: AbortSignal | undefined;
mockClient.sendMessageStream.mockImplementation(async function* (
_parts: unknown,
signal: AbortSignal,
) {
internalSignal = signal;
yield { type: GeminiEventType.Content, value: 'Part 1' };
yield { type: GeminiEventType.Content, value: 'Part 2' };
});
const promptStream = session.prompt('Test cancellation');
const iterator = promptStream[Symbol.asyncIterator]();
const firstEvent = await iterator.next();
expect(firstEvent.value?.type).toBe('agent_start');
const secondEvent = await iterator.next(); // content Part 1
expect(secondEvent.value?.type).toBe(GeminiEventType.Content);
// Caller stops here and closes the generator
if (iterator.return) {
await iterator.return();
}
expect(internalSignal?.aborted).toBe(true);
});
it('should respect maxTurns from config', async () => {
const customSession = new AgentSession(
'test-session-2',
+37 -6
View File
@@ -53,6 +53,15 @@ export class AgentSession {
schedulerId: this.schedulerId,
});
this.compressionService = new ChatCompressionService();
// Ensure system instruction is set from AgentConfig
if (this.config.systemInstruction) {
if (this.client.isInitialized()) {
this.client
.getChat()
.setSystemInstruction(this.config.systemInstruction);
}
}
}
/**
@@ -66,6 +75,11 @@ export class AgentSession {
resumedSessionData.conversation.messages,
);
await this.client.resumeChat(clientHistory, resumedSessionData);
// Re-apply system instruction after resume since resume re-creates the chat
if (this.config.systemInstruction) {
this.client.getChat().setSystemInstruction(this.config.systemInstruction);
}
}
/**
@@ -76,6 +90,11 @@ export class AgentSession {
input: string | Part[],
signal?: AbortSignal,
): AsyncIterable<AgentEvent> {
const internalController = new AbortController();
const combinedSignal = signal
? AbortSignal.any([signal, internalController.signal])
: internalController.signal;
yield {
type: 'agent_start',
value: { sessionId: this.sessionId },
@@ -91,7 +110,7 @@ export class AgentSession {
try {
while (maxTurns === -1 || this.totalTurns < maxTurns) {
if (signal?.aborted) {
if (combinedSignal.aborted) {
terminationReason = AgentTerminateMode.ABORTED;
break;
}
@@ -108,7 +127,7 @@ export class AgentSession {
currentInput,
promptId,
isContinuation ? undefined : input,
signal,
combinedSignal,
);
for await (const event of results.events) {
@@ -121,13 +140,13 @@ export class AgentSession {
break;
}
if (signal?.aborted) {
if (combinedSignal.aborted) {
terminationReason = AgentTerminateMode.ABORTED;
break;
}
if (results.toolCalls.length > 0) {
const toolRun = this.executeTools(results.toolCalls, signal);
const toolRun = this.executeTools(results.toolCalls, combinedSignal);
let resultsTools;
while (true) {
const { value, done } = await toolRun.next();
@@ -138,8 +157,8 @@ export class AgentSession {
yield value;
}
if (resultsTools.stopExecution || (signal && signal.aborted)) {
if (signal && signal.aborted) {
if (resultsTools.stopExecution || combinedSignal.aborted) {
if (combinedSignal.aborted) {
terminationReason = AgentTerminateMode.ABORTED;
} else if (resultsTools.stopExecutionInfo) {
terminationReason = AgentTerminateMode.ERROR;
@@ -166,6 +185,7 @@ export class AgentSession {
}
}
} finally {
internalController.abort();
yield {
type: 'agent_finish',
value: {
@@ -197,6 +217,17 @@ export class AgentSession {
const toolCalls: ToolCallRequestInfo[] = [];
let loopDetected = false;
// Ensure client is initialized before sending message
if (!this.client.isInitialized()) {
await this.client.initialize();
// Re-apply system instruction after initialization
if (this.config.systemInstruction) {
this.client
.getChat()
.setSystemInstruction(this.config.systemInstruction);
}
}
const stream = this.client.sendMessageStream(
parts,
signal ?? new AbortController().signal,