mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-17 07:13:07 -07:00
feat(watcher): ensure subagent triggers after turns and only once per interaction
This commit is contained in:
@@ -618,25 +618,6 @@ export class GeminiClient {
|
||||
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
|
||||
let turn = new Turn(this.getChat(), prompt_id);
|
||||
|
||||
const watcherInterval = this.config.getExperimentalWatcherInterval();
|
||||
if (
|
||||
this.config.isExperimentalWatcherEnabled() &&
|
||||
(this.sessionTurnCount === 1 ||
|
||||
this.sessionTurnCount % watcherInterval === 0)
|
||||
) {
|
||||
const watcherResult = await this.tryRunWatcher(prompt_id, signal);
|
||||
if (watcherResult?.feedback) {
|
||||
const feedback = watcherResult.feedback;
|
||||
const feedbackRequest = [
|
||||
{
|
||||
text: `System: Feedback from Watcher Sub Agent based on recent progress (Review of last ${this.config.getExperimentalWatcherInterval()} turns):\n\n${feedback}`,
|
||||
},
|
||||
];
|
||||
// Inject feedback into the conversation
|
||||
this.getChat().addHistory(createUserContent(feedbackRequest));
|
||||
}
|
||||
}
|
||||
|
||||
if (
|
||||
this.config.getMaxSessionTurns() > 0 &&
|
||||
this.sessionTurnCount > this.config.getMaxSessionTurns()
|
||||
@@ -912,6 +893,7 @@ export class GeminiClient {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return turn;
|
||||
}
|
||||
|
||||
@@ -1069,6 +1051,30 @@ export class GeminiClient {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Trigger Watcher after the full interaction (including tool recursions) is complete.
|
||||
// But only if we are at the top-level sendMessageStream (not a continuation).
|
||||
if (!continuationHandled && !isInvalidStreamRetry && !stopHookActive) {
|
||||
const watcherInterval = this.config.getExperimentalWatcherInterval();
|
||||
const currentTurn = this.sessionTurnCount;
|
||||
if (
|
||||
this.config.isExperimentalWatcherEnabled() &&
|
||||
currentTurn > 0 &&
|
||||
(currentTurn === 1 || currentTurn % watcherInterval === 0)
|
||||
) {
|
||||
const watcherResult = await this.tryRunWatcher(prompt_id, signal);
|
||||
if (watcherResult?.feedback) {
|
||||
const feedback = watcherResult.feedback;
|
||||
const feedbackRequest = [
|
||||
{
|
||||
text: `System: Feedback from Watcher Sub Agent based on recent progress (Review of last ${watcherInterval} turns):\n\n${feedback}`,
|
||||
},
|
||||
];
|
||||
// Inject feedback into the conversation for the NEXT turn
|
||||
this.getChat().addHistory(createUserContent(feedbackRequest));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return turn;
|
||||
|
||||
@@ -104,16 +104,6 @@ describe('GeminiClient Watcher Integration', () => {
|
||||
// Use type assertion for testing purposes to access protected members
|
||||
const clientAccess = client as unknown as {
|
||||
context: AgentLoopContext;
|
||||
sessionTurnCount: number;
|
||||
tryCompressChat: () => Promise<{ compressionStatus: string }>;
|
||||
_getActiveModelForCurrentTurn: () => string;
|
||||
processTurn: (
|
||||
request: unknown,
|
||||
signal: AbortSignal,
|
||||
promptId: string,
|
||||
maxTokens: number,
|
||||
forceFullContext: boolean,
|
||||
) => AsyncGenerator;
|
||||
};
|
||||
|
||||
Object.defineProperty(clientAccess.context, 'toolRegistry', {
|
||||
@@ -125,29 +115,19 @@ describe('GeminiClient Watcher Integration', () => {
|
||||
clientAccess.context as unknown as { agentRegistry: unknown }
|
||||
).agentRegistry = {
|
||||
getAllDefinitions: vi.fn().mockReturnValue([]),
|
||||
initialize: vi.fn().mockResolvedValue(undefined),
|
||||
};
|
||||
|
||||
await config.storage.initialize();
|
||||
await client.initialize();
|
||||
|
||||
vi.spyOn(clientAccess, 'tryCompressChat').mockResolvedValue({
|
||||
compressionStatus: 'skipped',
|
||||
});
|
||||
vi.spyOn(clientAccess, '_getActiveModelForCurrentTurn').mockReturnValue(
|
||||
'gemini-pro',
|
||||
);
|
||||
|
||||
clientAccess.sessionTurnCount = 1;
|
||||
|
||||
const promptId = 'test-prompt';
|
||||
const signal = new AbortController().signal;
|
||||
|
||||
const generator = clientAccess.processTurn(
|
||||
const generator = client.sendMessageStream(
|
||||
[{ text: 'test' }],
|
||||
signal,
|
||||
promptId,
|
||||
10,
|
||||
false,
|
||||
);
|
||||
for await (const _ of generator) {
|
||||
// Intentionally consume
|
||||
@@ -181,16 +161,6 @@ describe('GeminiClient Watcher Integration', () => {
|
||||
// Use type assertion for testing purposes to access protected members
|
||||
const clientAccess = client as unknown as {
|
||||
context: AgentLoopContext;
|
||||
sessionTurnCount: number;
|
||||
tryCompressChat: () => Promise<{ compressionStatus: string }>;
|
||||
_getActiveModelForCurrentTurn: () => string;
|
||||
processTurn: (
|
||||
request: unknown,
|
||||
signal: AbortSignal,
|
||||
promptId: string,
|
||||
maxTokens: number,
|
||||
forceFullContext: boolean,
|
||||
) => AsyncGenerator;
|
||||
};
|
||||
|
||||
Object.defineProperty(clientAccess.context, 'toolRegistry', {
|
||||
@@ -202,29 +172,19 @@ describe('GeminiClient Watcher Integration', () => {
|
||||
clientAccess.context as unknown as { agentRegistry: unknown }
|
||||
).agentRegistry = {
|
||||
getAllDefinitions: vi.fn().mockReturnValue([]),
|
||||
initialize: vi.fn().mockResolvedValue(undefined),
|
||||
};
|
||||
|
||||
await config.storage.initialize();
|
||||
await client.initialize();
|
||||
|
||||
vi.spyOn(clientAccess, 'tryCompressChat').mockResolvedValue({
|
||||
compressionStatus: 'skipped',
|
||||
});
|
||||
vi.spyOn(clientAccess, '_getActiveModelForCurrentTurn').mockReturnValue(
|
||||
'gemini-pro',
|
||||
);
|
||||
|
||||
clientAccess.sessionTurnCount = 1;
|
||||
|
||||
const promptId = 'test-prompt';
|
||||
const signal = new AbortController().signal;
|
||||
|
||||
const generator = clientAccess.processTurn(
|
||||
const generator = client.sendMessageStream(
|
||||
[{ text: 'test' }],
|
||||
signal,
|
||||
promptId,
|
||||
10,
|
||||
false,
|
||||
);
|
||||
for await (const _ of generator) {
|
||||
// Intentionally consume
|
||||
@@ -285,16 +245,6 @@ describe('GeminiClient Watcher Integration', () => {
|
||||
// Use type assertion for testing purposes to access protected members
|
||||
const clientAccess = client as unknown as {
|
||||
context: AgentLoopContext;
|
||||
sessionTurnCount: number;
|
||||
tryCompressChat: () => Promise<{ compressionStatus: string }>;
|
||||
_getActiveModelForCurrentTurn: () => string;
|
||||
processTurn: (
|
||||
request: unknown,
|
||||
signal: AbortSignal,
|
||||
promptId: string,
|
||||
maxTokens: number,
|
||||
forceFullContext: boolean,
|
||||
) => AsyncGenerator;
|
||||
};
|
||||
|
||||
Object.defineProperty(clientAccess.context, 'toolRegistry', {
|
||||
@@ -306,31 +256,21 @@ describe('GeminiClient Watcher Integration', () => {
|
||||
clientAccess.context as unknown as { agentRegistry: unknown }
|
||||
).agentRegistry = {
|
||||
getAllDefinitions: vi.fn().mockReturnValue([]),
|
||||
initialize: vi.fn().mockResolvedValue(undefined),
|
||||
};
|
||||
|
||||
await config.storage.initialize();
|
||||
await client.initialize();
|
||||
|
||||
vi.spyOn(clientAccess, 'tryCompressChat').mockResolvedValue({
|
||||
compressionStatus: 'skipped',
|
||||
});
|
||||
vi.spyOn(clientAccess, '_getActiveModelForCurrentTurn').mockReturnValue(
|
||||
'gemini-pro',
|
||||
);
|
||||
|
||||
const promptId = 'test-prompt';
|
||||
const signal = new AbortController().signal;
|
||||
|
||||
// Simulate 11 turns
|
||||
for (let i = 1; i <= 11; i++) {
|
||||
clientAccess.sessionTurnCount = i;
|
||||
|
||||
const generator = clientAccess.processTurn(
|
||||
const promptId = `test-prompt-${i}`;
|
||||
const generator = client.sendMessageStream(
|
||||
[{ text: `turn ${i}` }],
|
||||
signal,
|
||||
promptId,
|
||||
10,
|
||||
false,
|
||||
);
|
||||
for await (const _ of generator) {
|
||||
// consume
|
||||
|
||||
Reference in New Issue
Block a user