mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-25 12:34:38 -07:00
feat(hooks): implement granular stop and block behavior for agent hooks (#15824)
This commit is contained in:
@@ -46,6 +46,7 @@ import type {
|
||||
} from '../services/modelConfigService.js';
|
||||
import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js';
|
||||
import { HookSystem } from '../hooks/hookSystem.js';
|
||||
import type { DefaultHookOutput } from '../hooks/types.js';
|
||||
import * as policyCatalog from '../availability/policyCatalog.js';
|
||||
|
||||
vi.mock('../services/chatCompressionService.js');
|
||||
@@ -2781,6 +2782,136 @@ ${JSON.stringify(
|
||||
expect(client['hookStateMap'].has('old-id')).toBe(false);
|
||||
expect(client['hookStateMap'].has('new-id')).toBe(true);
|
||||
});
|
||||
|
||||
it('should stop execution in BeforeAgent when hook returns continue: false', async () => {
|
||||
const { fireBeforeAgentHook } = await import('./clientHookTriggers.js');
|
||||
vi.mocked(fireBeforeAgentHook).mockResolvedValue({
|
||||
shouldStopExecution: () => true,
|
||||
getEffectiveReason: () => 'Stopped by hook',
|
||||
} as DefaultHookOutput);
|
||||
|
||||
const mockChat: Partial<GeminiChat> = {
|
||||
addHistory: vi.fn(),
|
||||
getHistory: vi.fn().mockReturnValue([]),
|
||||
getLastPromptTokenCount: vi.fn(),
|
||||
};
|
||||
client['chat'] = mockChat as GeminiChat;
|
||||
|
||||
const request = [{ text: 'Hello' }];
|
||||
const stream = client.sendMessageStream(
|
||||
request,
|
||||
new AbortController().signal,
|
||||
'test-prompt',
|
||||
);
|
||||
const events = await fromAsync(stream);
|
||||
|
||||
expect(events).toContainEqual({
|
||||
type: GeminiEventType.AgentExecutionStopped,
|
||||
value: { reason: 'Stopped by hook' },
|
||||
});
|
||||
expect(mockChat.addHistory).toHaveBeenCalledWith({
|
||||
role: 'user',
|
||||
parts: request,
|
||||
});
|
||||
expect(mockTurnRunFn).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should block execution in BeforeAgent when hook returns decision: block', async () => {
|
||||
const { fireBeforeAgentHook } = await import('./clientHookTriggers.js');
|
||||
vi.mocked(fireBeforeAgentHook).mockResolvedValue({
|
||||
shouldStopExecution: () => false,
|
||||
isBlockingDecision: () => true,
|
||||
getEffectiveReason: () => 'Blocked by hook',
|
||||
} as DefaultHookOutput);
|
||||
|
||||
const mockChat: Partial<GeminiChat> = {
|
||||
addHistory: vi.fn(),
|
||||
getHistory: vi.fn().mockReturnValue([]),
|
||||
getLastPromptTokenCount: vi.fn(),
|
||||
};
|
||||
client['chat'] = mockChat as GeminiChat;
|
||||
|
||||
const request = [{ text: 'Hello' }];
|
||||
const stream = client.sendMessageStream(
|
||||
request,
|
||||
new AbortController().signal,
|
||||
'test-prompt',
|
||||
);
|
||||
const events = await fromAsync(stream);
|
||||
|
||||
expect(events).toContainEqual({
|
||||
type: GeminiEventType.AgentExecutionBlocked,
|
||||
value: {
|
||||
reason: 'Blocked by hook',
|
||||
},
|
||||
});
|
||||
expect(mockChat.addHistory).not.toHaveBeenCalled();
|
||||
expect(mockTurnRunFn).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should stop execution in AfterAgent when hook returns continue: false', async () => {
|
||||
const { fireAfterAgentHook } = await import('./clientHookTriggers.js');
|
||||
vi.mocked(fireAfterAgentHook).mockResolvedValue({
|
||||
shouldStopExecution: () => true,
|
||||
getEffectiveReason: () => 'Stopped after agent',
|
||||
} as DefaultHookOutput);
|
||||
|
||||
mockTurnRunFn.mockImplementation(async function* () {
|
||||
yield { type: GeminiEventType.Content, value: 'Hello' };
|
||||
});
|
||||
|
||||
const stream = client.sendMessageStream(
|
||||
{ text: 'Hi' },
|
||||
new AbortController().signal,
|
||||
'test-prompt',
|
||||
);
|
||||
const events = await fromAsync(stream);
|
||||
|
||||
expect(events).toContainEqual({
|
||||
type: GeminiEventType.AgentExecutionStopped,
|
||||
value: { reason: 'Stopped after agent' },
|
||||
});
|
||||
// sendMessageStream should not recurse
|
||||
expect(mockTurnRunFn).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('should yield AgentExecutionBlocked and recurse in AfterAgent when hook returns decision: block', async () => {
|
||||
const { fireAfterAgentHook } = await import('./clientHookTriggers.js');
|
||||
vi.mocked(fireAfterAgentHook)
|
||||
.mockResolvedValueOnce({
|
||||
shouldStopExecution: () => false,
|
||||
isBlockingDecision: () => true,
|
||||
getEffectiveReason: () => 'Please explain',
|
||||
} as DefaultHookOutput)
|
||||
.mockResolvedValueOnce({
|
||||
shouldStopExecution: () => false,
|
||||
isBlockingDecision: () => false,
|
||||
} as DefaultHookOutput);
|
||||
|
||||
mockTurnRunFn.mockImplementation(async function* () {
|
||||
yield { type: GeminiEventType.Content, value: 'Response' };
|
||||
});
|
||||
|
||||
const stream = client.sendMessageStream(
|
||||
{ text: 'Hi' },
|
||||
new AbortController().signal,
|
||||
'test-prompt',
|
||||
);
|
||||
const events = await fromAsync(stream);
|
||||
|
||||
expect(events).toContainEqual({
|
||||
type: GeminiEventType.AgentExecutionBlocked,
|
||||
value: { reason: 'Please explain' },
|
||||
});
|
||||
// Should have called turn run twice (original + re-prompt)
|
||||
expect(mockTurnRunFn).toHaveBeenCalledTimes(2);
|
||||
expect(mockTurnRunFn).toHaveBeenNthCalledWith(
|
||||
2,
|
||||
expect.anything(),
|
||||
[{ text: 'Please explain' }],
|
||||
expect.anything(),
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -11,6 +11,7 @@ import type {
|
||||
Tool,
|
||||
GenerateContentResponse,
|
||||
} from '@google/genai';
|
||||
import { createUserContent } from '@google/genai';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import {
|
||||
getDirectoryContextString,
|
||||
@@ -65,8 +66,12 @@ const MAX_TURNS = 100;
|
||||
|
||||
type BeforeAgentHookReturn =
|
||||
| {
|
||||
type: GeminiEventType.Error;
|
||||
value: { error: Error };
|
||||
type: GeminiEventType.AgentExecutionStopped;
|
||||
value: { reason: string };
|
||||
}
|
||||
| {
|
||||
type: GeminiEventType.AgentExecutionBlocked;
|
||||
value: { reason: string };
|
||||
}
|
||||
| { additionalContext: string | undefined }
|
||||
| undefined;
|
||||
@@ -135,13 +140,20 @@ export class GeminiClient {
|
||||
const hookOutput = await fireBeforeAgentHook(messageBus, request);
|
||||
hookState.hasFiredBeforeAgent = true;
|
||||
|
||||
if (hookOutput?.isBlockingDecision() || hookOutput?.shouldStopExecution()) {
|
||||
if (hookOutput?.shouldStopExecution()) {
|
||||
return {
|
||||
type: GeminiEventType.Error,
|
||||
type: GeminiEventType.AgentExecutionStopped,
|
||||
value: {
|
||||
error: new Error(
|
||||
`BeforeAgent hook blocked processing: ${hookOutput.getEffectiveReason()}`,
|
||||
),
|
||||
reason: hookOutput.getEffectiveReason(),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
if (hookOutput?.isBlockingDecision()) {
|
||||
return {
|
||||
type: GeminiEventType.AgentExecutionBlocked,
|
||||
value: {
|
||||
reason: hookOutput.getEffectiveReason(),
|
||||
},
|
||||
};
|
||||
}
|
||||
@@ -747,7 +759,18 @@ export class GeminiClient {
|
||||
prompt_id,
|
||||
);
|
||||
if (hookResult) {
|
||||
if ('type' in hookResult && hookResult.type === GeminiEventType.Error) {
|
||||
if (
|
||||
'type' in hookResult &&
|
||||
hookResult.type === GeminiEventType.AgentExecutionStopped
|
||||
) {
|
||||
// Add user message to history before returning so it's kept in the transcript
|
||||
this.getChat().addHistory(createUserContent(request));
|
||||
yield hookResult;
|
||||
return new Turn(this.getChat(), prompt_id);
|
||||
} else if (
|
||||
'type' in hookResult &&
|
||||
hookResult.type === GeminiEventType.AgentExecutionBlocked
|
||||
) {
|
||||
yield hookResult;
|
||||
return new Turn(this.getChat(), prompt_id);
|
||||
} else if ('additionalContext' in hookResult) {
|
||||
@@ -781,11 +804,24 @@ export class GeminiClient {
|
||||
turn,
|
||||
);
|
||||
|
||||
if (
|
||||
hookOutput?.isBlockingDecision() ||
|
||||
hookOutput?.shouldStopExecution()
|
||||
) {
|
||||
if (hookOutput?.shouldStopExecution()) {
|
||||
yield {
|
||||
type: GeminiEventType.AgentExecutionStopped,
|
||||
value: {
|
||||
reason: hookOutput.getEffectiveReason(),
|
||||
},
|
||||
};
|
||||
return turn;
|
||||
}
|
||||
|
||||
if (hookOutput?.isBlockingDecision()) {
|
||||
const continueReason = hookOutput.getEffectiveReason();
|
||||
yield {
|
||||
type: GeminiEventType.AgentExecutionBlocked,
|
||||
value: {
|
||||
reason: continueReason,
|
||||
},
|
||||
};
|
||||
const continueRequest = [{ text: continueReason }];
|
||||
yield* this.sendMessageStream(
|
||||
continueRequest,
|
||||
|
||||
@@ -66,12 +66,28 @@ export enum GeminiEventType {
|
||||
ContextWindowWillOverflow = 'context_window_will_overflow',
|
||||
InvalidStream = 'invalid_stream',
|
||||
ModelInfo = 'model_info',
|
||||
AgentExecutionStopped = 'agent_execution_stopped',
|
||||
AgentExecutionBlocked = 'agent_execution_blocked',
|
||||
}
|
||||
|
||||
export type ServerGeminiRetryEvent = {
|
||||
type: GeminiEventType.Retry;
|
||||
};
|
||||
|
||||
export type ServerGeminiAgentExecutionStoppedEvent = {
|
||||
type: GeminiEventType.AgentExecutionStopped;
|
||||
value: {
|
||||
reason: string;
|
||||
};
|
||||
};
|
||||
|
||||
export type ServerGeminiAgentExecutionBlockedEvent = {
|
||||
type: GeminiEventType.AgentExecutionBlocked;
|
||||
value: {
|
||||
reason: string;
|
||||
};
|
||||
};
|
||||
|
||||
export type ServerGeminiContextWindowWillOverflowEvent = {
|
||||
type: GeminiEventType.ContextWindowWillOverflow;
|
||||
value: {
|
||||
@@ -204,7 +220,9 @@ export type ServerGeminiStreamEvent =
|
||||
| ServerGeminiRetryEvent
|
||||
| ServerGeminiContextWindowWillOverflowEvent
|
||||
| ServerGeminiInvalidStreamEvent
|
||||
| ServerGeminiModelInfoEvent;
|
||||
| ServerGeminiModelInfoEvent
|
||||
| ServerGeminiAgentExecutionStoppedEvent
|
||||
| ServerGeminiAgentExecutionBlockedEvent;
|
||||
|
||||
// A turn manages the agentic loop turn within the server context.
|
||||
export class Turn {
|
||||
|
||||
Reference in New Issue
Block a user