feat(watcher): ensure subagent triggers after turns and only once per interaction

This commit is contained in:
Aishanee Shah
2026-04-10 01:59:48 +00:00
parent a9a37b2b3f
commit 4b88dc1bcf
2 changed files with 32 additions and 86 deletions
+25 -19
View File
@@ -618,25 +618,6 @@ export class GeminiClient {
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
let turn = new Turn(this.getChat(), prompt_id);
const watcherInterval = this.config.getExperimentalWatcherInterval();
if (
this.config.isExperimentalWatcherEnabled() &&
(this.sessionTurnCount === 1 ||
this.sessionTurnCount % watcherInterval === 0)
) {
const watcherResult = await this.tryRunWatcher(prompt_id, signal);
if (watcherResult?.feedback) {
const feedback = watcherResult.feedback;
const feedbackRequest = [
{
text: `System: Feedback from Watcher Sub Agent based on recent progress (Review of last ${this.config.getExperimentalWatcherInterval()} turns):\n\n${feedback}`,
},
];
// Inject feedback into the conversation
this.getChat().addHistory(createUserContent(feedbackRequest));
}
}
if (
this.config.getMaxSessionTurns() > 0 &&
this.sessionTurnCount > this.config.getMaxSessionTurns()
@@ -912,6 +893,7 @@ export class GeminiClient {
}
}
}
return turn;
}
@@ -1069,6 +1051,30 @@ export class GeminiClient {
}
}
}
// Trigger Watcher after the full interaction (including tool recursions) is complete.
// But only if we are at the top-level sendMessageStream (not a continuation).
if (!continuationHandled && !isInvalidStreamRetry && !stopHookActive) {
const watcherInterval = this.config.getExperimentalWatcherInterval();
const currentTurn = this.sessionTurnCount;
if (
this.config.isExperimentalWatcherEnabled() &&
currentTurn > 0 &&
(currentTurn === 1 || currentTurn % watcherInterval === 0)
) {
const watcherResult = await this.tryRunWatcher(prompt_id, signal);
if (watcherResult?.feedback) {
const feedback = watcherResult.feedback;
const feedbackRequest = [
{
text: `System: Feedback from Watcher Sub Agent based on recent progress (Review of last ${watcherInterval} turns):\n\n${feedback}`,
},
];
// Inject feedback into the conversation for the NEXT turn
this.getChat().addHistory(createUserContent(feedbackRequest));
}
}
}
}
return turn;
+7 -67
View File
@@ -104,16 +104,6 @@ describe('GeminiClient Watcher Integration', () => {
// Use type assertion for testing purposes to access protected members
const clientAccess = client as unknown as {
context: AgentLoopContext;
sessionTurnCount: number;
tryCompressChat: () => Promise<{ compressionStatus: string }>;
_getActiveModelForCurrentTurn: () => string;
processTurn: (
request: unknown,
signal: AbortSignal,
promptId: string,
maxTokens: number,
forceFullContext: boolean,
) => AsyncGenerator;
};
Object.defineProperty(clientAccess.context, 'toolRegistry', {
@@ -125,29 +115,19 @@ describe('GeminiClient Watcher Integration', () => {
clientAccess.context as unknown as { agentRegistry: unknown }
).agentRegistry = {
getAllDefinitions: vi.fn().mockReturnValue([]),
initialize: vi.fn().mockResolvedValue(undefined),
};
await config.storage.initialize();
await client.initialize();
vi.spyOn(clientAccess, 'tryCompressChat').mockResolvedValue({
compressionStatus: 'skipped',
});
vi.spyOn(clientAccess, '_getActiveModelForCurrentTurn').mockReturnValue(
'gemini-pro',
);
clientAccess.sessionTurnCount = 1;
const promptId = 'test-prompt';
const signal = new AbortController().signal;
const generator = clientAccess.processTurn(
const generator = client.sendMessageStream(
[{ text: 'test' }],
signal,
promptId,
10,
false,
);
for await (const _ of generator) {
// Intentionally consume
@@ -181,16 +161,6 @@ describe('GeminiClient Watcher Integration', () => {
// Use type assertion for testing purposes to access protected members
const clientAccess = client as unknown as {
context: AgentLoopContext;
sessionTurnCount: number;
tryCompressChat: () => Promise<{ compressionStatus: string }>;
_getActiveModelForCurrentTurn: () => string;
processTurn: (
request: unknown,
signal: AbortSignal,
promptId: string,
maxTokens: number,
forceFullContext: boolean,
) => AsyncGenerator;
};
Object.defineProperty(clientAccess.context, 'toolRegistry', {
@@ -202,29 +172,19 @@ describe('GeminiClient Watcher Integration', () => {
clientAccess.context as unknown as { agentRegistry: unknown }
).agentRegistry = {
getAllDefinitions: vi.fn().mockReturnValue([]),
initialize: vi.fn().mockResolvedValue(undefined),
};
await config.storage.initialize();
await client.initialize();
vi.spyOn(clientAccess, 'tryCompressChat').mockResolvedValue({
compressionStatus: 'skipped',
});
vi.spyOn(clientAccess, '_getActiveModelForCurrentTurn').mockReturnValue(
'gemini-pro',
);
clientAccess.sessionTurnCount = 1;
const promptId = 'test-prompt';
const signal = new AbortController().signal;
const generator = clientAccess.processTurn(
const generator = client.sendMessageStream(
[{ text: 'test' }],
signal,
promptId,
10,
false,
);
for await (const _ of generator) {
// Intentionally consume
@@ -285,16 +245,6 @@ describe('GeminiClient Watcher Integration', () => {
// Use type assertion for testing purposes to access protected members
const clientAccess = client as unknown as {
context: AgentLoopContext;
sessionTurnCount: number;
tryCompressChat: () => Promise<{ compressionStatus: string }>;
_getActiveModelForCurrentTurn: () => string;
processTurn: (
request: unknown,
signal: AbortSignal,
promptId: string,
maxTokens: number,
forceFullContext: boolean,
) => AsyncGenerator;
};
Object.defineProperty(clientAccess.context, 'toolRegistry', {
@@ -306,31 +256,21 @@ describe('GeminiClient Watcher Integration', () => {
clientAccess.context as unknown as { agentRegistry: unknown }
).agentRegistry = {
getAllDefinitions: vi.fn().mockReturnValue([]),
initialize: vi.fn().mockResolvedValue(undefined),
};
await config.storage.initialize();
await client.initialize();
vi.spyOn(clientAccess, 'tryCompressChat').mockResolvedValue({
compressionStatus: 'skipped',
});
vi.spyOn(clientAccess, '_getActiveModelForCurrentTurn').mockReturnValue(
'gemini-pro',
);
const promptId = 'test-prompt';
const signal = new AbortController().signal;
// Simulate 11 turns
for (let i = 1; i <= 11; i++) {
clientAccess.sessionTurnCount = i;
const generator = clientAccess.processTurn(
const promptId = `test-prompt-${i}`;
const generator = client.sendMessageStream(
[{ text: `turn ${i}` }],
signal,
promptId,
10,
false,
);
for await (const _ of generator) {
// consume