diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index 4edfed3e56..d51c036cee 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -618,25 +618,6 @@ export class GeminiClient { ): AsyncGenerator { let turn = new Turn(this.getChat(), prompt_id); - const watcherInterval = this.config.getExperimentalWatcherInterval(); - if ( - this.config.isExperimentalWatcherEnabled() && - (this.sessionTurnCount === 1 || - this.sessionTurnCount % watcherInterval === 0) - ) { - const watcherResult = await this.tryRunWatcher(prompt_id, signal); - if (watcherResult?.feedback) { - const feedback = watcherResult.feedback; - const feedbackRequest = [ - { - text: `System: Feedback from Watcher Sub Agent based on recent progress (Review of last ${this.config.getExperimentalWatcherInterval()} turns):\n\n${feedback}`, - }, - ]; - // Inject feedback into the conversation - this.getChat().addHistory(createUserContent(feedbackRequest)); - } - } - if ( this.config.getMaxSessionTurns() > 0 && this.sessionTurnCount > this.config.getMaxSessionTurns() @@ -912,6 +893,7 @@ export class GeminiClient { } } } + return turn; } @@ -1069,6 +1051,30 @@ export class GeminiClient { } } } + + // Trigger Watcher after the full interaction (including tool recursions) is complete. + // But only if we are at the top-level sendMessageStream (not a continuation). + if (!continuationHandled && !isInvalidStreamRetry && !stopHookActive) { + const watcherInterval = this.config.getExperimentalWatcherInterval(); + const currentTurn = this.sessionTurnCount; + if ( + this.config.isExperimentalWatcherEnabled() && + currentTurn > 0 && + (currentTurn === 1 || currentTurn % watcherInterval === 0) + ) { + const watcherResult = await this.tryRunWatcher(prompt_id, signal); + if (watcherResult?.feedback) { + const feedback = watcherResult.feedback; + const feedbackRequest = [ + { + text: `System: Feedback from Watcher Sub Agent based on recent progress (Review of last ${watcherInterval} turns):\n\n${feedback}`, + }, + ]; + // Inject feedback into the conversation for the NEXT turn + this.getChat().addHistory(createUserContent(feedbackRequest)); + } + } + } } return turn; diff --git a/packages/core/src/core/client_watcher.test.ts b/packages/core/src/core/client_watcher.test.ts index d22d307925..c6f4075591 100644 --- a/packages/core/src/core/client_watcher.test.ts +++ b/packages/core/src/core/client_watcher.test.ts @@ -104,16 +104,6 @@ describe('GeminiClient Watcher Integration', () => { // Use type assertion for testing purposes to access protected members const clientAccess = client as unknown as { context: AgentLoopContext; - sessionTurnCount: number; - tryCompressChat: () => Promise<{ compressionStatus: string }>; - _getActiveModelForCurrentTurn: () => string; - processTurn: ( - request: unknown, - signal: AbortSignal, - promptId: string, - maxTokens: number, - forceFullContext: boolean, - ) => AsyncGenerator; }; Object.defineProperty(clientAccess.context, 'toolRegistry', { @@ -125,29 +115,19 @@ describe('GeminiClient Watcher Integration', () => { clientAccess.context as unknown as { agentRegistry: unknown } ).agentRegistry = { getAllDefinitions: vi.fn().mockReturnValue([]), + initialize: vi.fn().mockResolvedValue(undefined), }; await config.storage.initialize(); await client.initialize(); - vi.spyOn(clientAccess, 'tryCompressChat').mockResolvedValue({ - compressionStatus: 'skipped', - }); - vi.spyOn(clientAccess, '_getActiveModelForCurrentTurn').mockReturnValue( - 'gemini-pro', - ); - - clientAccess.sessionTurnCount = 1; - const promptId = 'test-prompt'; const signal = new AbortController().signal; - const generator = clientAccess.processTurn( + const generator = client.sendMessageStream( [{ text: 'test' }], signal, promptId, - 10, - false, ); for await (const _ of generator) { // Intentionally consume @@ -181,16 +161,6 @@ describe('GeminiClient Watcher Integration', () => { // Use type assertion for testing purposes to access protected members const clientAccess = client as unknown as { context: AgentLoopContext; - sessionTurnCount: number; - tryCompressChat: () => Promise<{ compressionStatus: string }>; - _getActiveModelForCurrentTurn: () => string; - processTurn: ( - request: unknown, - signal: AbortSignal, - promptId: string, - maxTokens: number, - forceFullContext: boolean, - ) => AsyncGenerator; }; Object.defineProperty(clientAccess.context, 'toolRegistry', { @@ -202,29 +172,19 @@ describe('GeminiClient Watcher Integration', () => { clientAccess.context as unknown as { agentRegistry: unknown } ).agentRegistry = { getAllDefinitions: vi.fn().mockReturnValue([]), + initialize: vi.fn().mockResolvedValue(undefined), }; await config.storage.initialize(); await client.initialize(); - vi.spyOn(clientAccess, 'tryCompressChat').mockResolvedValue({ - compressionStatus: 'skipped', - }); - vi.spyOn(clientAccess, '_getActiveModelForCurrentTurn').mockReturnValue( - 'gemini-pro', - ); - - clientAccess.sessionTurnCount = 1; - const promptId = 'test-prompt'; const signal = new AbortController().signal; - const generator = clientAccess.processTurn( + const generator = client.sendMessageStream( [{ text: 'test' }], signal, promptId, - 10, - false, ); for await (const _ of generator) { // Intentionally consume @@ -285,16 +245,6 @@ describe('GeminiClient Watcher Integration', () => { // Use type assertion for testing purposes to access protected members const clientAccess = client as unknown as { context: AgentLoopContext; - sessionTurnCount: number; - tryCompressChat: () => Promise<{ compressionStatus: string }>; - _getActiveModelForCurrentTurn: () => string; - processTurn: ( - request: unknown, - signal: AbortSignal, - promptId: string, - maxTokens: number, - forceFullContext: boolean, - ) => AsyncGenerator; }; Object.defineProperty(clientAccess.context, 'toolRegistry', { @@ -306,31 +256,21 @@ describe('GeminiClient Watcher Integration', () => { clientAccess.context as unknown as { agentRegistry: unknown } ).agentRegistry = { getAllDefinitions: vi.fn().mockReturnValue([]), + initialize: vi.fn().mockResolvedValue(undefined), }; await config.storage.initialize(); await client.initialize(); - vi.spyOn(clientAccess, 'tryCompressChat').mockResolvedValue({ - compressionStatus: 'skipped', - }); - vi.spyOn(clientAccess, '_getActiveModelForCurrentTurn').mockReturnValue( - 'gemini-pro', - ); - - const promptId = 'test-prompt'; const signal = new AbortController().signal; // Simulate 11 turns for (let i = 1; i <= 11; i++) { - clientAccess.sessionTurnCount = i; - - const generator = clientAccess.processTurn( + const promptId = `test-prompt-${i}`; + const generator = client.sendMessageStream( [{ text: `turn ${i}` }], signal, promptId, - 10, - false, ); for await (const _ of generator) { // consume