mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-10 22:21:22 -07:00
fix(hooks): deduplicate agent hooks and add cross-platform integration tests (#15701)
This commit is contained in:
2
integration-tests/hooks-agent-flow-multistep.responses
Normal file
2
integration-tests/hooks-agent-flow-multistep.responses
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
{"method":"generateContentStream","response":[{"candidates":[{"content":{"parts":[{"functionCall":{"name":"list_dir","args":{"path":"."}}}],"role":"model"},"finishReason":"STOP","index":0}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":10,"totalTokenCount":20}}]}
|
||||||
|
{"method":"generateContentStream","response":[{"candidates":[{"content":{"parts":[{"text":"Final Answer"}],"role":"model"},"finishReason":"STOP","index":0}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":10,"totalTokenCount":20}}]}
|
||||||
1
integration-tests/hooks-agent-flow.responses
Normal file
1
integration-tests/hooks-agent-flow.responses
Normal file
@@ -0,0 +1 @@
|
|||||||
|
{"method":"generateContentStream","response":[{"candidates":[{"content":{"parts":[{"text":"**Responding**\n\nI will respond to the user's request.\n\n","thought":true}],"role":"model"},"index":0}],"usageMetadata":{"promptTokenCount":100,"totalTokenCount":120,"promptTokensDetails":[{"modality":"TEXT","tokenCount":100}],"thoughtsTokenCount":20}},{"candidates":[{"content":{"parts":[{"text":"Response to: "}],"role":"model"},"index":0}],"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":5,"totalTokenCount":125,"promptTokensDetails":[{"modality":"TEXT","tokenCount":100}],"thoughtsTokenCount":20}},{"candidates":[{"content":{"parts":[{"text":"Hello World"}],"role":"model"},"finishReason":"STOP","index":0}],"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":7,"totalTokenCount":127,"promptTokensDetails":[{"modality":"TEXT","tokenCount":100}],"thoughtsTokenCount":20}}]}
|
||||||
238
integration-tests/hooks-agent-flow.test.ts
Normal file
238
integration-tests/hooks-agent-flow.test.ts
Normal file
@@ -0,0 +1,238 @@
|
|||||||
|
/**
|
||||||
|
* @license
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { describe, it, expect, beforeEach, afterEach } from 'vitest';
|
||||||
|
import { TestRig } from './test-helper.js';
|
||||||
|
import { join } from 'node:path';
|
||||||
|
import { writeFileSync } from 'node:fs';
|
||||||
|
|
||||||
|
describe('Hooks Agent Flow', () => {
|
||||||
|
let rig: TestRig;
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
rig = new TestRig();
|
||||||
|
});
|
||||||
|
|
||||||
|
afterEach(async () => {
|
||||||
|
if (rig) {
|
||||||
|
await rig.cleanup();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('BeforeAgent Hooks', () => {
|
||||||
|
it('should inject additional context via BeforeAgent hook', async () => {
|
||||||
|
await rig.setup('should inject additional context via BeforeAgent hook', {
|
||||||
|
fakeResponsesPath: join(
|
||||||
|
import.meta.dirname,
|
||||||
|
'hooks-agent-flow.responses',
|
||||||
|
),
|
||||||
|
});
|
||||||
|
|
||||||
|
const hookScript = `
|
||||||
|
try {
|
||||||
|
const output = {
|
||||||
|
decision: "allow",
|
||||||
|
hookSpecificOutput: {
|
||||||
|
hookEventName: "BeforeAgent",
|
||||||
|
additionalContext: "SYSTEM INSTRUCTION: This is injected context."
|
||||||
|
}
|
||||||
|
};
|
||||||
|
process.stdout.write(JSON.stringify(output));
|
||||||
|
} catch (e) {
|
||||||
|
console.error('Failed to write stdout:', e);
|
||||||
|
process.exit(1);
|
||||||
|
}
|
||||||
|
console.error('DEBUG: BeforeAgent hook executed');
|
||||||
|
`;
|
||||||
|
|
||||||
|
const scriptPath = join(rig.testDir!, 'before_agent_context.cjs');
|
||||||
|
writeFileSync(scriptPath, hookScript);
|
||||||
|
|
||||||
|
await rig.setup('should inject additional context via BeforeAgent hook', {
|
||||||
|
settings: {
|
||||||
|
tools: {
|
||||||
|
enableHooks: true,
|
||||||
|
},
|
||||||
|
hooks: {
|
||||||
|
BeforeAgent: [
|
||||||
|
{
|
||||||
|
hooks: [
|
||||||
|
{
|
||||||
|
type: 'command',
|
||||||
|
command: `node "${scriptPath}"`,
|
||||||
|
timeout: 5000,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
await rig.run({ args: 'Hello test' });
|
||||||
|
|
||||||
|
// Verify hook execution and telemetry
|
||||||
|
const hookTelemetryFound = await rig.waitForTelemetryEvent('hook_call');
|
||||||
|
expect(hookTelemetryFound).toBeTruthy();
|
||||||
|
|
||||||
|
const hookLogs = rig.readHookLogs();
|
||||||
|
const beforeAgentLog = hookLogs.find(
|
||||||
|
(log) => log.hookCall.hook_event_name === 'BeforeAgent',
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(beforeAgentLog).toBeDefined();
|
||||||
|
expect(beforeAgentLog?.hookCall.stdout).toContain('injected context');
|
||||||
|
expect(beforeAgentLog?.hookCall.stdout).toContain('"decision":"allow"');
|
||||||
|
expect(beforeAgentLog?.hookCall.stdout).toContain(
|
||||||
|
'SYSTEM INSTRUCTION: This is injected context.',
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('AfterAgent Hooks', () => {
|
||||||
|
it('should receive prompt and response in AfterAgent hook', async () => {
|
||||||
|
await rig.setup('should receive prompt and response in AfterAgent hook', {
|
||||||
|
fakeResponsesPath: join(
|
||||||
|
import.meta.dirname,
|
||||||
|
'hooks-agent-flow.responses',
|
||||||
|
),
|
||||||
|
});
|
||||||
|
|
||||||
|
const hookScript = `
|
||||||
|
const fs = require('fs');
|
||||||
|
try {
|
||||||
|
const input = fs.readFileSync(0, 'utf-8');
|
||||||
|
console.error('DEBUG: AfterAgent hook input received');
|
||||||
|
process.stdout.write("Received Input: " + input);
|
||||||
|
// Ensure separation between the echo and the JSON output if they were to run together (though relying on separate console calls usually separates by newline)
|
||||||
|
// usage of process.stdout.write does NOT add newline.
|
||||||
|
// But here we want strictly the output "Received Input..." to be present.
|
||||||
|
// We also need to output the JSON decision for the hook runner to consider it successful?
|
||||||
|
// Actually HookRunner parses the *last* valid JSON block or treats text as system message.
|
||||||
|
// If we output mixed text and JSON, HookRunner might get confused if we don't handle it right.
|
||||||
|
// Existing test expects "Received Input" in stdout. And "Hello World".
|
||||||
|
// It DOES NOT parse the decision?
|
||||||
|
// Wait, HookRunner logic:
|
||||||
|
// "if (exitCode === EXIT_CODE_SUCCESS && stdout.trim()) ... JSON.parse ..."
|
||||||
|
// If JSON.parse fails: "Not JSON, convert plain text to structured output"
|
||||||
|
// So if we output formatted text, it becomes "systemMessage".
|
||||||
|
// That is fine for this test as we don't check the decision, just the stdout content.
|
||||||
|
} catch (err) {
|
||||||
|
console.error('Hook Failed:', err);
|
||||||
|
process.exit(1);
|
||||||
|
}
|
||||||
|
`;
|
||||||
|
|
||||||
|
const scriptPath = join(rig.testDir!, 'after_agent_verify.cjs');
|
||||||
|
writeFileSync(scriptPath, hookScript);
|
||||||
|
|
||||||
|
await rig.setup('should receive prompt and response in AfterAgent hook', {
|
||||||
|
settings: {
|
||||||
|
tools: {
|
||||||
|
enableHooks: true,
|
||||||
|
},
|
||||||
|
hooks: {
|
||||||
|
AfterAgent: [
|
||||||
|
{
|
||||||
|
hooks: [
|
||||||
|
{
|
||||||
|
type: 'command',
|
||||||
|
command: `node "${scriptPath}"`,
|
||||||
|
timeout: 5000,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
await rig.run({ args: 'Hello validation' });
|
||||||
|
|
||||||
|
const hookTelemetryFound = await rig.waitForTelemetryEvent('hook_call');
|
||||||
|
expect(hookTelemetryFound).toBeTruthy();
|
||||||
|
|
||||||
|
const hookLogs = rig.readHookLogs();
|
||||||
|
const afterAgentLog = hookLogs.find(
|
||||||
|
(log) => log.hookCall.hook_event_name === 'AfterAgent',
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(afterAgentLog).toBeDefined();
|
||||||
|
// Verify the hook stdout contains the input we echoed which proves the hook received the prompt and response
|
||||||
|
expect(afterAgentLog?.hookCall.stdout).toContain('Received Input');
|
||||||
|
expect(afterAgentLog?.hookCall.stdout).toContain('Hello validation');
|
||||||
|
// The fake response contains "Hello World"
|
||||||
|
expect(afterAgentLog?.hookCall.stdout).toContain('Hello World');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('Multi-step Loops', () => {
|
||||||
|
it('should fire BeforeAgent and AfterAgent exactly once per turn despite tool calls', async () => {
|
||||||
|
await rig.setup(
|
||||||
|
'should fire BeforeAgent and AfterAgent exactly once per turn despite tool calls',
|
||||||
|
{
|
||||||
|
fakeResponsesPath: join(
|
||||||
|
import.meta.dirname,
|
||||||
|
'hooks-agent-flow-multistep.responses',
|
||||||
|
),
|
||||||
|
settings: {
|
||||||
|
tools: {
|
||||||
|
enableHooks: true,
|
||||||
|
},
|
||||||
|
hooks: {
|
||||||
|
BeforeAgent: [
|
||||||
|
{
|
||||||
|
hooks: [
|
||||||
|
{
|
||||||
|
type: 'command',
|
||||||
|
command: `node -e "console.log('BeforeAgent Fired')"`,
|
||||||
|
timeout: 5000,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
AfterAgent: [
|
||||||
|
{
|
||||||
|
hooks: [
|
||||||
|
{
|
||||||
|
type: 'command',
|
||||||
|
command: `node -e "console.log('AfterAgent Fired')"`,
|
||||||
|
timeout: 5000,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
await rig.run({ args: 'Do a multi-step task' });
|
||||||
|
|
||||||
|
const hookLogs = rig.readHookLogs();
|
||||||
|
const beforeAgentLogs = hookLogs.filter(
|
||||||
|
(log) => log.hookCall.hook_event_name === 'BeforeAgent',
|
||||||
|
);
|
||||||
|
const afterAgentLogs = hookLogs.filter(
|
||||||
|
(log) => log.hookCall.hook_event_name === 'AfterAgent',
|
||||||
|
);
|
||||||
|
|
||||||
|
// Should ensure BeforeAgent fired once
|
||||||
|
expect(beforeAgentLogs).toHaveLength(1);
|
||||||
|
|
||||||
|
// Should ensure AfterAgent fired once
|
||||||
|
// Note: If the tool call itself triggered BeforeTool/AfterTool, that's fine,
|
||||||
|
// but BeforeAgent/AfterAgent should only wrap the *entire* turn (User Request -> Final Answer).
|
||||||
|
expect(afterAgentLogs).toHaveLength(1);
|
||||||
|
|
||||||
|
// Verify the output log content to ensure we actually got the final answer
|
||||||
|
// (This implies the loop completed successfully)
|
||||||
|
const afterAgentLog = afterAgentLogs[0];
|
||||||
|
expect(afterAgentLog).toBeDefined();
|
||||||
|
expect(afterAgentLog?.hookCall.stdout).toContain('AfterAgent Fired');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -44,7 +44,7 @@ describe('Hooks System Integration', () => {
|
|||||||
{
|
{
|
||||||
type: 'command',
|
type: 'command',
|
||||||
command:
|
command:
|
||||||
'echo "{\\"decision\\": \\"block\\", \\"reason\\": \\"File writing blocked by security policy\\"}"',
|
"node -e \"console.log(JSON.stringify({decision: 'block', reason: 'File writing blocked by security policy'}))\"",
|
||||||
timeout: 5000,
|
timeout: 5000,
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
@@ -97,7 +97,7 @@ describe('Hooks System Integration', () => {
|
|||||||
{
|
{
|
||||||
type: 'command',
|
type: 'command',
|
||||||
command:
|
command:
|
||||||
'echo "{\\"decision\\": \\"allow\\", \\"reason\\": \\"File writing approved\\"}"',
|
"node -e \"console.log(JSON.stringify({decision: 'allow', reason: 'File writing approved'}))\"",
|
||||||
timeout: 5000,
|
timeout: 5000,
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
@@ -129,7 +129,7 @@ describe('Hooks System Integration', () => {
|
|||||||
describe('Command Hooks - Additional Context', () => {
|
describe('Command Hooks - Additional Context', () => {
|
||||||
it('should add additional context from AfterTool hooks', async () => {
|
it('should add additional context from AfterTool hooks', async () => {
|
||||||
const command =
|
const command =
|
||||||
'echo "{\\"hookSpecificOutput\\": {\\"hookEventName\\": \\"AfterTool\\", \\"additionalContext\\": \\"Security scan: File content appears safe\\"}}"';
|
"node -e \"console.log(JSON.stringify({hookSpecificOutput: {hookEventName: 'AfterTool', additionalContext: 'Security scan: File content appears safe'}}))\"";
|
||||||
await rig.setup('should add additional context from AfterTool hooks', {
|
await rig.setup('should add additional context from AfterTool hooks', {
|
||||||
fakeResponsesPath: join(
|
fakeResponsesPath: join(
|
||||||
import.meta.dirname,
|
import.meta.dirname,
|
||||||
@@ -190,27 +190,24 @@ describe('Hooks System Integration', () => {
|
|||||||
'hooks-system.before-model.responses',
|
'hooks-system.before-model.responses',
|
||||||
),
|
),
|
||||||
});
|
});
|
||||||
const hookScript = `#!/bin/bash
|
const hookScript = `const fs = require('fs');
|
||||||
echo '{
|
console.log(JSON.stringify({
|
||||||
"decision": "allow",
|
decision: "allow",
|
||||||
"hookSpecificOutput": {
|
hookSpecificOutput: {
|
||||||
"hookEventName": "BeforeModel",
|
hookEventName: "BeforeModel",
|
||||||
"llm_request": {
|
llm_request: {
|
||||||
"messages": [
|
messages: [
|
||||||
{
|
{
|
||||||
"role": "user",
|
role: "user",
|
||||||
"content": "Please respond with exactly: The security hook modified this request successfully."
|
content: "Please respond with exactly: The security hook modified this request successfully."
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}'`;
|
}));`;
|
||||||
|
|
||||||
const scriptPath = join(rig.testDir!, 'before_model_hook.sh');
|
const scriptPath = join(rig.testDir!, 'before_model_hook.cjs');
|
||||||
writeFileSync(scriptPath, hookScript);
|
writeFileSync(scriptPath, hookScript);
|
||||||
// Make executable
|
|
||||||
const { execSync } = await import('node:child_process');
|
|
||||||
execSync(`chmod +x "${scriptPath}"`);
|
|
||||||
|
|
||||||
await rig.setup('should modify LLM requests with BeforeModel hooks', {
|
await rig.setup('should modify LLM requests with BeforeModel hooks', {
|
||||||
settings: {
|
settings: {
|
||||||
@@ -223,7 +220,7 @@ echo '{
|
|||||||
hooks: [
|
hooks: [
|
||||||
{
|
{
|
||||||
type: 'command',
|
type: 'command',
|
||||||
command: scriptPath,
|
command: `node "${scriptPath}"`,
|
||||||
timeout: 5000,
|
timeout: 5000,
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
@@ -250,7 +247,9 @@ echo '{
|
|||||||
expect(hookTelemetryFound[0].hookCall.hook_event_name).toBe(
|
expect(hookTelemetryFound[0].hookCall.hook_event_name).toBe(
|
||||||
'BeforeModel',
|
'BeforeModel',
|
||||||
);
|
);
|
||||||
expect(hookTelemetryFound[0].hookCall.hook_name).toBe(scriptPath);
|
expect(hookTelemetryFound[0].hookCall.hook_name).toBe(
|
||||||
|
`node "${scriptPath}"`,
|
||||||
|
);
|
||||||
expect(hookTelemetryFound[0].hookCall.hook_input).toBeDefined();
|
expect(hookTelemetryFound[0].hookCall.hook_input).toBeDefined();
|
||||||
expect(hookTelemetryFound[0].hookCall.hook_output).toBeDefined();
|
expect(hookTelemetryFound[0].hookCall.hook_output).toBeDefined();
|
||||||
expect(hookTelemetryFound[0].hookCall.exit_code).toBe(0);
|
expect(hookTelemetryFound[0].hookCall.exit_code).toBe(0);
|
||||||
@@ -270,30 +269,28 @@ echo '{
|
|||||||
),
|
),
|
||||||
});
|
});
|
||||||
// Create a hook script that modifies the LLM response
|
// Create a hook script that modifies the LLM response
|
||||||
const hookScript = `#!/bin/bash
|
const hookScript = `const fs = require('fs');
|
||||||
echo '{
|
console.log(JSON.stringify({
|
||||||
"hookSpecificOutput": {
|
hookSpecificOutput: {
|
||||||
"hookEventName": "AfterModel",
|
hookEventName: "AfterModel",
|
||||||
"llm_response": {
|
llm_response: {
|
||||||
"candidates": [
|
candidates: [
|
||||||
{
|
{
|
||||||
"content": {
|
content: {
|
||||||
"role": "model",
|
role: "model",
|
||||||
"parts": [
|
parts: [
|
||||||
"[FILTERED] Response has been filtered for security compliance."
|
"[FILTERED] Response has been filtered for security compliance."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"finishReason": "STOP"
|
finishReason: "STOP"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}'`;
|
}));`;
|
||||||
|
|
||||||
const scriptPath = join(rig.testDir!, 'after_model_hook.sh');
|
const scriptPath = join(rig.testDir!, 'after_model_hook.cjs');
|
||||||
writeFileSync(scriptPath, hookScript);
|
writeFileSync(scriptPath, hookScript);
|
||||||
const { execSync } = await import('node:child_process');
|
|
||||||
execSync(`chmod +x "${scriptPath}"`);
|
|
||||||
|
|
||||||
await rig.setup('should modify LLM responses with AfterModel hooks', {
|
await rig.setup('should modify LLM responses with AfterModel hooks', {
|
||||||
settings: {
|
settings: {
|
||||||
@@ -306,7 +303,7 @@ echo '{
|
|||||||
hooks: [
|
hooks: [
|
||||||
{
|
{
|
||||||
type: 'command',
|
type: 'command',
|
||||||
command: scriptPath,
|
command: `node "${scriptPath}"`,
|
||||||
timeout: 5000,
|
timeout: 5000,
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
@@ -343,7 +340,7 @@ echo '{
|
|||||||
);
|
);
|
||||||
// Create inline hook command (works on both Unix and Windows)
|
// Create inline hook command (works on both Unix and Windows)
|
||||||
const hookCommand =
|
const hookCommand =
|
||||||
'echo "{\\"hookSpecificOutput\\": {\\"hookEventName\\": \\"BeforeToolSelection\\", \\"toolConfig\\": {\\"mode\\": \\"ANY\\", \\"allowedFunctionNames\\": [\\"read_file\\", \\"run_shell_command\\"]}}}"';
|
"node -e \"console.log(JSON.stringify({hookSpecificOutput: {hookEventName: 'BeforeToolSelection', toolConfig: {mode: 'ANY', allowedFunctionNames: ['read_file', 'run_shell_command']}}}))\"";
|
||||||
|
|
||||||
await rig.setup(
|
await rig.setup(
|
||||||
'should modify tool selection with BeforeToolSelection hooks',
|
'should modify tool selection with BeforeToolSelection hooks',
|
||||||
@@ -404,19 +401,17 @@ echo '{
|
|||||||
),
|
),
|
||||||
});
|
});
|
||||||
// Create a hook script that adds context to the prompt
|
// Create a hook script that adds context to the prompt
|
||||||
const hookScript = `#!/bin/bash
|
const hookScript = `const fs = require('fs');
|
||||||
echo '{
|
console.log(JSON.stringify({
|
||||||
"decision": "allow",
|
decision: "allow",
|
||||||
"hookSpecificOutput": {
|
hookSpecificOutput: {
|
||||||
"hookEventName": "BeforeAgent",
|
hookEventName: "BeforeAgent",
|
||||||
"additionalContext": "SYSTEM INSTRUCTION: You are in a secure environment. Always mention security compliance in your responses."
|
additionalContext: "SYSTEM INSTRUCTION: You are in a secure environment. Always mention security compliance in your responses."
|
||||||
}
|
}
|
||||||
}'`;
|
}));`;
|
||||||
|
|
||||||
const scriptPath = join(rig.testDir!, 'before_agent_hook.sh');
|
const scriptPath = join(rig.testDir!, 'before_agent_hook.cjs');
|
||||||
writeFileSync(scriptPath, hookScript);
|
writeFileSync(scriptPath, hookScript);
|
||||||
const { execSync } = await import('node:child_process');
|
|
||||||
execSync(`chmod +x "${scriptPath}"`);
|
|
||||||
|
|
||||||
await rig.setup('should augment prompts with BeforeAgent hooks', {
|
await rig.setup('should augment prompts with BeforeAgent hooks', {
|
||||||
settings: {
|
settings: {
|
||||||
@@ -429,7 +424,7 @@ echo '{
|
|||||||
hooks: [
|
hooks: [
|
||||||
{
|
{
|
||||||
type: 'command',
|
type: 'command',
|
||||||
command: scriptPath,
|
command: `node "${scriptPath}"`,
|
||||||
timeout: 5000,
|
timeout: 5000,
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
@@ -452,9 +447,10 @@ echo '{
|
|||||||
|
|
||||||
describe('Notification Hooks - Permission Handling', () => {
|
describe('Notification Hooks - Permission Handling', () => {
|
||||||
it('should handle notification hooks for tool permissions', async () => {
|
it('should handle notification hooks for tool permissions', async () => {
|
||||||
|
// Create inline hook command (works on both Unix and Windows)
|
||||||
// Create inline hook command (works on both Unix and Windows)
|
// Create inline hook command (works on both Unix and Windows)
|
||||||
const hookCommand =
|
const hookCommand =
|
||||||
'echo "{\\"suppressOutput\\": false, \\"systemMessage\\": \\"Permission request logged by security hook\\"}"';
|
'node -e "console.log(JSON.stringify({suppressOutput: false, systemMessage: \'Permission request logged by security hook\'}))"';
|
||||||
|
|
||||||
await rig.setup('should handle notification hooks for tool permissions', {
|
await rig.setup('should handle notification hooks for tool permissions', {
|
||||||
fakeResponsesPath: join(
|
fakeResponsesPath: join(
|
||||||
@@ -548,9 +544,9 @@ echo '{
|
|||||||
it('should execute hooks sequentially when configured', async () => {
|
it('should execute hooks sequentially when configured', async () => {
|
||||||
// Create inline hook commands (works on both Unix and Windows)
|
// Create inline hook commands (works on both Unix and Windows)
|
||||||
const hook1Command =
|
const hook1Command =
|
||||||
'echo "{\\"decision\\": \\"allow\\", \\"hookSpecificOutput\\": {\\"hookEventName\\": \\"BeforeAgent\\", \\"additionalContext\\": \\"Step 1: Initial validation passed.\\"}}"';
|
"node -e \"console.log(JSON.stringify({decision: 'allow', hookSpecificOutput: {hookEventName: 'BeforeAgent', additionalContext: 'Step 1: Initial validation passed.'}}))\"";
|
||||||
const hook2Command =
|
const hook2Command =
|
||||||
'echo "{\\"decision\\": \\"allow\\", \\"hookSpecificOutput\\": {\\"hookEventName\\": \\"BeforeAgent\\", \\"additionalContext\\": \\"Step 2: Security check completed.\\"}}"';
|
"node -e \"console.log(JSON.stringify({decision: 'allow', hookSpecificOutput: {hookEventName: 'BeforeAgent', additionalContext: 'Step 2: Security check completed.'}}))\"";
|
||||||
|
|
||||||
await rig.setup('should execute hooks sequentially when configured', {
|
await rig.setup('should execute hooks sequentially when configured', {
|
||||||
fakeResponsesPath: join(
|
fakeResponsesPath: join(
|
||||||
@@ -621,23 +617,22 @@ echo '{
|
|||||||
),
|
),
|
||||||
});
|
});
|
||||||
// Create a hook script that validates the input format
|
// Create a hook script that validates the input format
|
||||||
const hookScript = `#!/bin/bash
|
const hookScript = `const fs = require('fs');
|
||||||
# Read JSON input from stdin
|
const input = fs.readFileSync(0, 'utf-8');
|
||||||
input=$(cat)
|
try {
|
||||||
|
const json = JSON.parse(input);
|
||||||
|
// Check fields
|
||||||
|
if (json.session_id && json.cwd && json.hook_event_name && json.timestamp && json.tool_name && json.tool_input) {
|
||||||
|
console.log(JSON.stringify({decision: "allow", reason: "Input format is correct"}));
|
||||||
|
} else {
|
||||||
|
console.log(JSON.stringify({decision: "block", reason: "Input format is invalid"}));
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
console.log(JSON.stringify({decision: "block", reason: "Invalid JSON"}));
|
||||||
|
}`;
|
||||||
|
|
||||||
# Check for required fields
|
const scriptPath = join(rig.testDir!, 'input_validation_hook.cjs');
|
||||||
if echo "$input" | jq -e '.session_id and .cwd and .hook_event_name and .timestamp and .tool_name and .tool_input' > /dev/null 2>&1; then
|
|
||||||
echo '{"decision": "allow", "reason": "Input format is correct"}'
|
|
||||||
exit 0
|
|
||||||
else
|
|
||||||
echo '{"decision": "block", "reason": "Input format is invalid"}'
|
|
||||||
exit 0
|
|
||||||
fi`;
|
|
||||||
|
|
||||||
const scriptPath = join(rig.testDir!, 'input_validation_hook.sh');
|
|
||||||
writeFileSync(scriptPath, hookScript);
|
writeFileSync(scriptPath, hookScript);
|
||||||
const { execSync } = await import('node:child_process');
|
|
||||||
execSync(`chmod +x "${scriptPath}"`);
|
|
||||||
|
|
||||||
await rig.setup('should provide correct input format to hooks', {
|
await rig.setup('should provide correct input format to hooks', {
|
||||||
settings: {
|
settings: {
|
||||||
@@ -650,7 +645,7 @@ fi`;
|
|||||||
hooks: [
|
hooks: [
|
||||||
{
|
{
|
||||||
type: 'command',
|
type: 'command',
|
||||||
command: scriptPath,
|
command: `node "${scriptPath}"`,
|
||||||
timeout: 5000,
|
timeout: 5000,
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
@@ -682,11 +677,11 @@ fi`;
|
|||||||
it('should handle hooks for all major event types', async () => {
|
it('should handle hooks for all major event types', async () => {
|
||||||
// Create inline hook commands (works on both Unix and Windows)
|
// Create inline hook commands (works on both Unix and Windows)
|
||||||
const beforeToolCommand =
|
const beforeToolCommand =
|
||||||
'echo "{\\"decision\\": \\"allow\\", \\"systemMessage\\": \\"BeforeTool: File operation logged\\"}"';
|
"node -e \"console.log(JSON.stringify({decision: 'allow', systemMessage: 'BeforeTool: File operation logged'}))\"";
|
||||||
const afterToolCommand =
|
const afterToolCommand =
|
||||||
'echo "{\\"hookSpecificOutput\\": {\\"hookEventName\\": \\"AfterTool\\", \\"additionalContext\\": \\"AfterTool: Operation completed successfully\\"}}"';
|
"node -e \"console.log(JSON.stringify({hookSpecificOutput: {hookEventName: 'AfterTool', additionalContext: 'AfterTool: Operation completed successfully'}}))\"";
|
||||||
const beforeAgentCommand =
|
const beforeAgentCommand =
|
||||||
'echo "{\\"decision\\": \\"allow\\", \\"hookSpecificOutput\\": {\\"hookEventName\\": \\"BeforeAgent\\", \\"additionalContext\\": \\"BeforeAgent: User request processed\\"}}"';
|
"node -e \"console.log(JSON.stringify({decision: 'allow', hookSpecificOutput: {hookEventName: 'BeforeAgent', additionalContext: 'BeforeAgent: User request processed'}}))\"";
|
||||||
|
|
||||||
await rig.setup('should handle hooks for all major event types', {
|
await rig.setup('should handle hooks for all major event types', {
|
||||||
fakeResponsesPath: join(
|
fakeResponsesPath: join(
|
||||||
@@ -802,10 +797,10 @@ fi`;
|
|||||||
// Create a hook script that fails
|
// Create a hook script that fails
|
||||||
// Create inline hook commands (works on both Unix and Windows)
|
// Create inline hook commands (works on both Unix and Windows)
|
||||||
// Failing hook: exits with non-zero code
|
// Failing hook: exits with non-zero code
|
||||||
const failingCommand = 'exit 1';
|
const failingCommand = 'node -e "process.exit(1)"';
|
||||||
// Working hook: returns success with JSON
|
// Working hook: returns success with JSON
|
||||||
const workingCommand =
|
const workingCommand =
|
||||||
'echo "{\\"decision\\": \\"allow\\", \\"reason\\": \\"Working hook succeeded\\"}"';
|
"node -e \"console.log(JSON.stringify({decision: 'allow', reason: 'Working hook succeeded'}))\"";
|
||||||
|
|
||||||
await rig.setup('should handle hook failures gracefully', {
|
await rig.setup('should handle hook failures gracefully', {
|
||||||
settings: {
|
settings: {
|
||||||
@@ -855,7 +850,7 @@ fi`;
|
|||||||
it('should generate telemetry events for hook executions', async () => {
|
it('should generate telemetry events for hook executions', async () => {
|
||||||
// Create inline hook command (works on both Unix and Windows)
|
// Create inline hook command (works on both Unix and Windows)
|
||||||
const hookCommand =
|
const hookCommand =
|
||||||
'echo "{\\"decision\\": \\"allow\\", \\"reason\\": \\"Telemetry test hook\\"}"';
|
"node -e \"console.log(JSON.stringify({decision: 'allow', reason: 'Telemetry test hook'}))\"";
|
||||||
|
|
||||||
await rig.setup('should generate telemetry events for hook executions', {
|
await rig.setup('should generate telemetry events for hook executions', {
|
||||||
fakeResponsesPath: join(
|
fakeResponsesPath: join(
|
||||||
@@ -898,7 +893,7 @@ fi`;
|
|||||||
it('should fire SessionStart hook on app startup', async () => {
|
it('should fire SessionStart hook on app startup', async () => {
|
||||||
// Create inline hook command that outputs JSON
|
// Create inline hook command that outputs JSON
|
||||||
const sessionStartCommand =
|
const sessionStartCommand =
|
||||||
'echo "{\\"decision\\": \\"allow\\", \\"systemMessage\\": \\"Session starting on startup\\"}"';
|
"node -e \"console.log(JSON.stringify({decision: 'allow', systemMessage: 'Session starting on startup'}))\"";
|
||||||
|
|
||||||
await rig.setup('should fire SessionStart hook on app startup', {
|
await rig.setup('should fire SessionStart hook on app startup', {
|
||||||
fakeResponsesPath: join(
|
fakeResponsesPath: join(
|
||||||
@@ -958,9 +953,9 @@ fi`;
|
|||||||
it('should fire SessionEnd and SessionStart hooks on /clear command', async () => {
|
it('should fire SessionEnd and SessionStart hooks on /clear command', async () => {
|
||||||
// Create inline hook commands for both SessionEnd and SessionStart
|
// Create inline hook commands for both SessionEnd and SessionStart
|
||||||
const sessionEndCommand =
|
const sessionEndCommand =
|
||||||
'echo "{\\"decision\\": \\"allow\\", \\"systemMessage\\": \\"Session ending due to clear\\"}"';
|
"node -e \"console.log(JSON.stringify({decision: 'allow', systemMessage: 'Session ending due to clear'}))\"";
|
||||||
const sessionStartCommand =
|
const sessionStartCommand =
|
||||||
'echo "{\\"decision\\": \\"allow\\", \\"systemMessage\\": \\"Session starting after clear\\"}"';
|
"node -e \"console.log(JSON.stringify({decision: 'allow', systemMessage: 'Session starting after clear'}))\"";
|
||||||
|
|
||||||
await rig.setup(
|
await rig.setup(
|
||||||
'should fire SessionEnd and SessionStart hooks on /clear command',
|
'should fire SessionEnd and SessionStart hooks on /clear command',
|
||||||
@@ -1136,7 +1131,7 @@ fi`;
|
|||||||
it('should fire PreCompress hook on automatic compression', async () => {
|
it('should fire PreCompress hook on automatic compression', async () => {
|
||||||
// Create inline hook command that outputs JSON
|
// Create inline hook command that outputs JSON
|
||||||
const preCompressCommand =
|
const preCompressCommand =
|
||||||
'echo "{\\"decision\\": \\"allow\\", \\"systemMessage\\": \\"PreCompress hook executed for automatic compression\\"}"';
|
"node -e \"console.log(JSON.stringify({decision: 'allow', systemMessage: 'PreCompress hook executed for automatic compression'}))\"";
|
||||||
|
|
||||||
await rig.setup('should fire PreCompress hook on automatic compression', {
|
await rig.setup('should fire PreCompress hook on automatic compression', {
|
||||||
fakeResponsesPath: join(
|
fakeResponsesPath: join(
|
||||||
@@ -1203,7 +1198,7 @@ fi`;
|
|||||||
describe('SessionEnd on Exit', () => {
|
describe('SessionEnd on Exit', () => {
|
||||||
it('should fire SessionEnd hook on graceful exit in non-interactive mode', async () => {
|
it('should fire SessionEnd hook on graceful exit in non-interactive mode', async () => {
|
||||||
const sessionEndCommand =
|
const sessionEndCommand =
|
||||||
'echo "{\\"decision\\": \\"allow\\", \\"systemMessage\\": \\"SessionEnd hook executed on exit\\"}"';
|
"node -e \"console.log(JSON.stringify({decision: 'allow', systemMessage: 'SessionEnd hook executed on exit'}))\"";
|
||||||
|
|
||||||
await rig.setup('should fire SessionEnd hook on graceful exit', {
|
await rig.setup('should fire SessionEnd hook on graceful exit', {
|
||||||
fakeResponsesPath: join(
|
fakeResponsesPath: join(
|
||||||
@@ -1297,20 +1292,17 @@ fi`;
|
|||||||
});
|
});
|
||||||
|
|
||||||
// Create two hook scripts - one enabled, one disabled
|
// Create two hook scripts - one enabled, one disabled
|
||||||
const enabledHookScript = `#!/bin/bash
|
const enabledHookScript = `const fs = require('fs');
|
||||||
echo '{"decision": "allow", "systemMessage": "Enabled hook executed"}'`;
|
console.log(JSON.stringify({decision: "allow", systemMessage: "Enabled hook executed"}));`;
|
||||||
|
|
||||||
const disabledHookScript = `#!/bin/bash
|
const disabledHookScript = `const fs = require('fs');
|
||||||
echo '{"decision": "block", "systemMessage": "Disabled hook should not execute", "reason": "This hook should be disabled"}'`;
|
console.log(JSON.stringify({decision: "block", systemMessage: "Disabled hook should not execute", reason: "This hook should be disabled"}));`;
|
||||||
|
|
||||||
const enabledPath = join(rig.testDir!, 'enabled_hook.sh');
|
const enabledPath = join(rig.testDir!, 'enabled_hook.cjs');
|
||||||
const disabledPath = join(rig.testDir!, 'disabled_hook.sh');
|
const disabledPath = join(rig.testDir!, 'disabled_hook.cjs');
|
||||||
|
|
||||||
writeFileSync(enabledPath, enabledHookScript);
|
writeFileSync(enabledPath, enabledHookScript);
|
||||||
writeFileSync(disabledPath, disabledHookScript);
|
writeFileSync(disabledPath, disabledHookScript);
|
||||||
const { execSync } = await import('node:child_process');
|
|
||||||
execSync(`chmod +x "${enabledPath}"`);
|
|
||||||
execSync(`chmod +x "${disabledPath}"`);
|
|
||||||
|
|
||||||
await rig.setup('should not execute hooks disabled in settings file', {
|
await rig.setup('should not execute hooks disabled in settings file', {
|
||||||
settings: {
|
settings: {
|
||||||
@@ -1323,18 +1315,18 @@ echo '{"decision": "block", "systemMessage": "Disabled hook should not execute",
|
|||||||
hooks: [
|
hooks: [
|
||||||
{
|
{
|
||||||
type: 'command',
|
type: 'command',
|
||||||
command: enabledPath,
|
command: `node "${enabledPath}"`,
|
||||||
timeout: 5000,
|
timeout: 5000,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
type: 'command',
|
type: 'command',
|
||||||
command: disabledPath,
|
command: `node "${disabledPath}"`,
|
||||||
timeout: 5000,
|
timeout: 5000,
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
disabled: [disabledPath], // Disable the second hook
|
disabled: [`node "${disabledPath}"`], // Disable the second hook
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
@@ -1358,10 +1350,10 @@ echo '{"decision": "block", "systemMessage": "Disabled hook should not execute",
|
|||||||
// Check hook telemetry - only enabled hook should have executed
|
// Check hook telemetry - only enabled hook should have executed
|
||||||
const hookLogs = rig.readHookLogs();
|
const hookLogs = rig.readHookLogs();
|
||||||
const enabledHookLog = hookLogs.find(
|
const enabledHookLog = hookLogs.find(
|
||||||
(log) => log.hookCall.hook_name === enabledPath,
|
(log) => log.hookCall.hook_name === `node "${enabledPath}"`,
|
||||||
);
|
);
|
||||||
const disabledHookLog = hookLogs.find(
|
const disabledHookLog = hookLogs.find(
|
||||||
(log) => log.hookCall.hook_name === disabledPath,
|
(log) => log.hookCall.hook_name === `node "${disabledPath}"`,
|
||||||
);
|
);
|
||||||
|
|
||||||
expect(enabledHookLog).toBeDefined();
|
expect(enabledHookLog).toBeDefined();
|
||||||
@@ -1380,20 +1372,17 @@ echo '{"decision": "block", "systemMessage": "Disabled hook should not execute",
|
|||||||
);
|
);
|
||||||
|
|
||||||
// Create two hook scripts - one that will be disabled, one that won't
|
// Create two hook scripts - one that will be disabled, one that won't
|
||||||
const activeHookScript = `#!/bin/bash
|
const activeHookScript = `const fs = require('fs');
|
||||||
echo '{"decision": "allow", "systemMessage": "Active hook executed"}'`;
|
console.log(JSON.stringify({decision: "allow", systemMessage: "Active hook executed"}));`;
|
||||||
|
|
||||||
const disabledHookScript = `#!/bin/bash
|
const disabledHookScript = `const fs = require('fs');
|
||||||
echo '{"decision": "block", "systemMessage": "Disabled hook should not execute", "reason": "This hook is disabled"}'`;
|
console.log(JSON.stringify({decision: "block", systemMessage: "Disabled hook should not execute", reason: "This hook is disabled"}));`;
|
||||||
|
|
||||||
const activePath = join(rig.testDir!, 'active_hook.sh');
|
const activePath = join(rig.testDir!, 'active_hook.cjs');
|
||||||
const disabledPath = join(rig.testDir!, 'disabled_hook.sh');
|
const disabledPath = join(rig.testDir!, 'disabled_hook.cjs');
|
||||||
|
|
||||||
writeFileSync(activePath, activeHookScript);
|
writeFileSync(activePath, activeHookScript);
|
||||||
writeFileSync(disabledPath, disabledHookScript);
|
writeFileSync(disabledPath, disabledHookScript);
|
||||||
const { execSync } = await import('node:child_process');
|
|
||||||
execSync(`chmod +x "${activePath}"`);
|
|
||||||
execSync(`chmod +x "${disabledPath}"`);
|
|
||||||
|
|
||||||
await rig.setup(
|
await rig.setup(
|
||||||
'should respect disabled hooks across multiple operations',
|
'should respect disabled hooks across multiple operations',
|
||||||
@@ -1408,18 +1397,18 @@ echo '{"decision": "block", "systemMessage": "Disabled hook should not execute",
|
|||||||
hooks: [
|
hooks: [
|
||||||
{
|
{
|
||||||
type: 'command',
|
type: 'command',
|
||||||
command: activePath,
|
command: `node "${activePath}"`,
|
||||||
timeout: 5000,
|
timeout: 5000,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
type: 'command',
|
type: 'command',
|
||||||
command: disabledPath,
|
command: `node "${disabledPath}"`,
|
||||||
timeout: 5000,
|
timeout: 5000,
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
disabled: [disabledPath], // Disable the second hook
|
disabled: [`node "${disabledPath}"`], // Disable the second hook
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -1441,10 +1430,10 @@ echo '{"decision": "block", "systemMessage": "Disabled hook should not execute",
|
|||||||
// Check hook telemetry
|
// Check hook telemetry
|
||||||
const hookLogs1 = rig.readHookLogs();
|
const hookLogs1 = rig.readHookLogs();
|
||||||
const activeHookLog1 = hookLogs1.find(
|
const activeHookLog1 = hookLogs1.find(
|
||||||
(log) => log.hookCall.hook_name === activePath,
|
(log) => log.hookCall.hook_name === `node "${activePath}"`,
|
||||||
);
|
);
|
||||||
const disabledHookLog1 = hookLogs1.find(
|
const disabledHookLog1 = hookLogs1.find(
|
||||||
(log) => log.hookCall.hook_name === disabledPath,
|
(log) => log.hookCall.hook_name === `node "${disabledPath}"`,
|
||||||
);
|
);
|
||||||
|
|
||||||
expect(activeHookLog1).toBeDefined();
|
expect(activeHookLog1).toBeDefined();
|
||||||
@@ -1465,7 +1454,7 @@ echo '{"decision": "block", "systemMessage": "Disabled hook should not execute",
|
|||||||
// Verify disabled hook still hasn't executed
|
// Verify disabled hook still hasn't executed
|
||||||
const hookLogs2 = rig.readHookLogs();
|
const hookLogs2 = rig.readHookLogs();
|
||||||
const disabledHookCalls = hookLogs2.filter(
|
const disabledHookCalls = hookLogs2.filter(
|
||||||
(log) => log.hookCall.hook_name === disabledPath,
|
(log) => log.hookCall.hook_name === `node "${disabledPath}"`,
|
||||||
);
|
);
|
||||||
expect(disabledHookCalls.length).toBe(0);
|
expect(disabledHookCalls.length).toBe(0);
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -81,6 +81,10 @@ vi.mock('node:fs', () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
// --- Mocks ---
|
// --- Mocks ---
|
||||||
|
interface MockTurnContext {
|
||||||
|
getResponseText: Mock<() => string>;
|
||||||
|
}
|
||||||
|
|
||||||
const mockTurnRunFn = vi.fn();
|
const mockTurnRunFn = vi.fn();
|
||||||
|
|
||||||
vi.mock('./turn', async (importOriginal) => {
|
vi.mock('./turn', async (importOriginal) => {
|
||||||
@@ -94,6 +98,8 @@ vi.mock('./turn', async (importOriginal) => {
|
|||||||
constructor() {
|
constructor() {
|
||||||
// The constructor can be empty or do some mock setup
|
// The constructor can be empty or do some mock setup
|
||||||
}
|
}
|
||||||
|
|
||||||
|
getResponseText = vi.fn().mockReturnValue('Mock Response');
|
||||||
}
|
}
|
||||||
// Export the mock class as 'Turn'
|
// Export the mock class as 'Turn'
|
||||||
return {
|
return {
|
||||||
@@ -129,6 +135,15 @@ vi.mock('../telemetry/uiTelemetry.js', () => ({
|
|||||||
},
|
},
|
||||||
}));
|
}));
|
||||||
vi.mock('../hooks/hookSystem.js');
|
vi.mock('../hooks/hookSystem.js');
|
||||||
|
vi.mock('./clientHookTriggers.js', () => ({
|
||||||
|
fireBeforeAgentHook: vi.fn(),
|
||||||
|
fireAfterAgentHook: vi.fn().mockResolvedValue({
|
||||||
|
decision: 'allow',
|
||||||
|
continue: false,
|
||||||
|
suppressOutput: false,
|
||||||
|
systemMessage: undefined,
|
||||||
|
}),
|
||||||
|
}));
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Array.fromAsync ponyfill, which will be available in es 2024.
|
* Array.fromAsync ponyfill, which will be available in es 2024.
|
||||||
@@ -543,16 +558,22 @@ describe('Gemini Client (client.ts)', () => {
|
|||||||
await client.tryCompressChat('prompt-1', false); // force = false
|
await client.tryCompressChat('prompt-1', false); // force = false
|
||||||
|
|
||||||
// 3. Assert Step 1: Check that the flag became true
|
// 3. Assert Step 1: Check that the flag became true
|
||||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
// 3. Assert Step 1: Check that the flag became true
|
||||||
expect((client as any).hasFailedCompressionAttempt).toBe(true);
|
expect(
|
||||||
|
(client as unknown as { hasFailedCompressionAttempt: boolean })
|
||||||
|
.hasFailedCompressionAttempt,
|
||||||
|
).toBe(true);
|
||||||
|
|
||||||
// 4. Test Step 2: Trigger a forced failure
|
// 4. Test Step 2: Trigger a forced failure
|
||||||
|
|
||||||
await client.tryCompressChat('prompt-2', true); // force = true
|
await client.tryCompressChat('prompt-2', true); // force = true
|
||||||
|
|
||||||
// 5. Assert Step 2: Check that the flag REMAINS true
|
// 5. Assert Step 2: Check that the flag REMAINS true
|
||||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
// 5. Assert Step 2: Check that the flag REMAINS true
|
||||||
expect((client as any).hasFailedCompressionAttempt).toBe(true);
|
expect(
|
||||||
|
(client as unknown as { hasFailedCompressionAttempt: boolean })
|
||||||
|
.hasFailedCompressionAttempt,
|
||||||
|
).toBe(true);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should not trigger summarization if token count is below threshold', async () => {
|
it('should not trigger summarization if token count is below threshold', async () => {
|
||||||
@@ -2615,5 +2636,152 @@ ${JSON.stringify(
|
|||||||
'test-session-id',
|
'test-session-id',
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('Hook System', () => {
|
||||||
|
let mockMessageBus: { publish: Mock; subscribe: Mock };
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.clearAllMocks();
|
||||||
|
mockMessageBus = { publish: vi.fn(), subscribe: vi.fn() };
|
||||||
|
|
||||||
|
// Force override config methods on the client instance
|
||||||
|
client['config'].getEnableHooks = vi.fn().mockReturnValue(true);
|
||||||
|
client['config'].getMessageBus = vi
|
||||||
|
.fn()
|
||||||
|
.mockReturnValue(mockMessageBus);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should fire BeforeAgent and AfterAgent exactly once for a simple turn', async () => {
|
||||||
|
const promptId = 'test-prompt-hook-1';
|
||||||
|
const request = { text: 'Hello Hooks' };
|
||||||
|
const signal = new AbortController().signal;
|
||||||
|
const { fireBeforeAgentHook, fireAfterAgentHook } = await import(
|
||||||
|
'./clientHookTriggers.js'
|
||||||
|
);
|
||||||
|
|
||||||
|
mockTurnRunFn.mockImplementation(async function* (
|
||||||
|
this: MockTurnContext,
|
||||||
|
) {
|
||||||
|
this.getResponseText.mockReturnValue('Hook Response');
|
||||||
|
yield { type: GeminiEventType.Content, value: 'Hook Response' };
|
||||||
|
});
|
||||||
|
|
||||||
|
const stream = client.sendMessageStream(request, signal, promptId);
|
||||||
|
while (!(await stream.next()).done);
|
||||||
|
|
||||||
|
expect(fireBeforeAgentHook).toHaveBeenCalledTimes(1);
|
||||||
|
expect(fireAfterAgentHook).toHaveBeenCalledTimes(1);
|
||||||
|
expect(fireAfterAgentHook).toHaveBeenCalledWith(
|
||||||
|
expect.anything(),
|
||||||
|
request,
|
||||||
|
'Hook Response',
|
||||||
|
);
|
||||||
|
|
||||||
|
// Map should be empty
|
||||||
|
expect(client['hookStateMap'].size).toBe(0);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should fire BeforeAgent once and AfterAgent once even with recursion', async () => {
|
||||||
|
const { checkNextSpeaker } = await import(
|
||||||
|
'../utils/nextSpeakerChecker.js'
|
||||||
|
);
|
||||||
|
vi.mocked(checkNextSpeaker)
|
||||||
|
.mockResolvedValueOnce({ next_speaker: 'model', reasoning: 'more' })
|
||||||
|
.mockResolvedValueOnce(null);
|
||||||
|
|
||||||
|
const promptId = 'test-prompt-hook-recursive';
|
||||||
|
const request = { text: 'Recursion Test' };
|
||||||
|
const signal = new AbortController().signal;
|
||||||
|
const { fireBeforeAgentHook, fireAfterAgentHook } = await import(
|
||||||
|
'./clientHookTriggers.js'
|
||||||
|
);
|
||||||
|
|
||||||
|
let callCount = 0;
|
||||||
|
mockTurnRunFn.mockImplementation(async function* (
|
||||||
|
this: MockTurnContext,
|
||||||
|
) {
|
||||||
|
callCount++;
|
||||||
|
const response = `Response ${callCount}`;
|
||||||
|
this.getResponseText.mockReturnValue(response);
|
||||||
|
yield { type: GeminiEventType.Content, value: response };
|
||||||
|
});
|
||||||
|
|
||||||
|
const stream = client.sendMessageStream(request, signal, promptId);
|
||||||
|
while (!(await stream.next()).done);
|
||||||
|
|
||||||
|
// BeforeAgent should fire ONLY once despite multiple internal turns
|
||||||
|
expect(fireBeforeAgentHook).toHaveBeenCalledTimes(1);
|
||||||
|
|
||||||
|
// AfterAgent should fire ONLY when the stack unwinds
|
||||||
|
expect(fireAfterAgentHook).toHaveBeenCalledTimes(1);
|
||||||
|
|
||||||
|
// Check cumulative response (separated by newline)
|
||||||
|
expect(fireAfterAgentHook).toHaveBeenCalledWith(
|
||||||
|
expect.anything(),
|
||||||
|
request,
|
||||||
|
'Response 1\nResponse 2',
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(client['hookStateMap'].size).toBe(0);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should use original request in AfterAgent hook even when continuation happened', async () => {
|
||||||
|
const { checkNextSpeaker } = await import(
|
||||||
|
'../utils/nextSpeakerChecker.js'
|
||||||
|
);
|
||||||
|
vi.mocked(checkNextSpeaker)
|
||||||
|
.mockResolvedValueOnce({ next_speaker: 'model', reasoning: 'more' })
|
||||||
|
.mockResolvedValueOnce(null);
|
||||||
|
|
||||||
|
const promptId = 'test-prompt-hook-original-req';
|
||||||
|
const request = { text: 'Do something' };
|
||||||
|
const signal = new AbortController().signal;
|
||||||
|
const { fireAfterAgentHook } = await import('./clientHookTriggers.js');
|
||||||
|
|
||||||
|
mockTurnRunFn.mockImplementation(async function* (
|
||||||
|
this: MockTurnContext,
|
||||||
|
) {
|
||||||
|
this.getResponseText.mockReturnValue('Ok');
|
||||||
|
yield { type: GeminiEventType.Content, value: 'Ok' };
|
||||||
|
});
|
||||||
|
|
||||||
|
const stream = client.sendMessageStream(request, signal, promptId);
|
||||||
|
while (!(await stream.next()).done);
|
||||||
|
|
||||||
|
expect(fireAfterAgentHook).toHaveBeenCalledWith(
|
||||||
|
expect.anything(),
|
||||||
|
request, // Should be 'Do something'
|
||||||
|
expect.stringContaining('Ok'),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should cleanup state when prompt_id changes', async () => {
|
||||||
|
const signal = new AbortController().signal;
|
||||||
|
mockTurnRunFn.mockImplementation(async function* (
|
||||||
|
this: MockTurnContext,
|
||||||
|
) {
|
||||||
|
this.getResponseText.mockReturnValue('Ok');
|
||||||
|
yield { type: GeminiEventType.Content, value: 'Ok' };
|
||||||
|
});
|
||||||
|
|
||||||
|
client['hookStateMap'].set('old-id', {
|
||||||
|
hasFiredBeforeAgent: true,
|
||||||
|
cumulativeResponse: 'Old',
|
||||||
|
activeCalls: 0,
|
||||||
|
originalRequest: { text: 'Old' },
|
||||||
|
});
|
||||||
|
client['lastPromptId'] = 'old-id';
|
||||||
|
|
||||||
|
const stream = client.sendMessageStream(
|
||||||
|
{ text: 'New' },
|
||||||
|
signal,
|
||||||
|
'new-id',
|
||||||
|
);
|
||||||
|
await stream.next();
|
||||||
|
|
||||||
|
expect(client['hookStateMap'].has('old-id')).toBe(false);
|
||||||
|
expect(client['hookStateMap'].has('new-id')).toBe(true);
|
||||||
|
});
|
||||||
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import type {
|
|||||||
Tool,
|
Tool,
|
||||||
GenerateContentResponse,
|
GenerateContentResponse,
|
||||||
} from '@google/genai';
|
} from '@google/genai';
|
||||||
|
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||||
import {
|
import {
|
||||||
getDirectoryContextString,
|
getDirectoryContextString,
|
||||||
getInitialChatHistory,
|
getInitialChatHistory,
|
||||||
@@ -42,6 +43,7 @@ import {
|
|||||||
fireBeforeAgentHook,
|
fireBeforeAgentHook,
|
||||||
fireAfterAgentHook,
|
fireAfterAgentHook,
|
||||||
} from './clientHookTriggers.js';
|
} from './clientHookTriggers.js';
|
||||||
|
import type { DefaultHookOutput } from '../hooks/types.js';
|
||||||
import {
|
import {
|
||||||
ContentRetryFailureEvent,
|
ContentRetryFailureEvent,
|
||||||
NextSpeakerCheckEvent,
|
NextSpeakerCheckEvent,
|
||||||
@@ -61,6 +63,14 @@ import type { RetryAvailabilityContext } from '../utils/retry.js';
|
|||||||
|
|
||||||
const MAX_TURNS = 100;
|
const MAX_TURNS = 100;
|
||||||
|
|
||||||
|
type BeforeAgentHookReturn =
|
||||||
|
| {
|
||||||
|
type: GeminiEventType.Error;
|
||||||
|
value: { error: Error };
|
||||||
|
}
|
||||||
|
| { additionalContext: string | undefined }
|
||||||
|
| undefined;
|
||||||
|
|
||||||
export class GeminiClient {
|
export class GeminiClient {
|
||||||
private chat?: GeminiChat;
|
private chat?: GeminiChat;
|
||||||
private sessionTurnCount = 0;
|
private sessionTurnCount = 0;
|
||||||
@@ -84,6 +94,95 @@ export class GeminiClient {
|
|||||||
this.lastPromptId = this.config.getSessionId();
|
this.lastPromptId = this.config.getSessionId();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Hook state to deduplicate BeforeAgent calls and track response for
|
||||||
|
// AfterAgent
|
||||||
|
private hookStateMap = new Map<
|
||||||
|
string,
|
||||||
|
{
|
||||||
|
hasFiredBeforeAgent: boolean;
|
||||||
|
cumulativeResponse: string;
|
||||||
|
activeCalls: number;
|
||||||
|
originalRequest: PartListUnion;
|
||||||
|
}
|
||||||
|
>();
|
||||||
|
|
||||||
|
private async fireBeforeAgentHookSafe(
|
||||||
|
messageBus: MessageBus,
|
||||||
|
request: PartListUnion,
|
||||||
|
prompt_id: string,
|
||||||
|
): Promise<BeforeAgentHookReturn> {
|
||||||
|
let hookState = this.hookStateMap.get(prompt_id);
|
||||||
|
if (!hookState) {
|
||||||
|
hookState = {
|
||||||
|
hasFiredBeforeAgent: false,
|
||||||
|
cumulativeResponse: '',
|
||||||
|
activeCalls: 0,
|
||||||
|
originalRequest: request,
|
||||||
|
};
|
||||||
|
this.hookStateMap.set(prompt_id, hookState);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Increment active calls for this prompt_id
|
||||||
|
// This is called at the start of sendMessageStream, so it acts as an entry
|
||||||
|
// counter. We increment here, assuming this helper is ALWAYS called at
|
||||||
|
// entry.
|
||||||
|
hookState.activeCalls++;
|
||||||
|
|
||||||
|
if (hookState.hasFiredBeforeAgent) {
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
|
|
||||||
|
const hookOutput = await fireBeforeAgentHook(messageBus, request);
|
||||||
|
hookState.hasFiredBeforeAgent = true;
|
||||||
|
|
||||||
|
if (hookOutput?.isBlockingDecision() || hookOutput?.shouldStopExecution()) {
|
||||||
|
return {
|
||||||
|
type: GeminiEventType.Error,
|
||||||
|
value: {
|
||||||
|
error: new Error(
|
||||||
|
`BeforeAgent hook blocked processing: ${hookOutput.getEffectiveReason()}`,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
const additionalContext = hookOutput?.getAdditionalContext();
|
||||||
|
if (additionalContext) {
|
||||||
|
return { additionalContext };
|
||||||
|
}
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
|
|
||||||
|
private async fireAfterAgentHookSafe(
|
||||||
|
messageBus: MessageBus,
|
||||||
|
currentRequest: PartListUnion,
|
||||||
|
prompt_id: string,
|
||||||
|
turn?: Turn,
|
||||||
|
): Promise<DefaultHookOutput | undefined> {
|
||||||
|
const hookState = this.hookStateMap.get(prompt_id);
|
||||||
|
// Only fire on the outermost call (when activeCalls is 1)
|
||||||
|
if (!hookState || hookState.activeCalls !== 1) {
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (turn && turn.pendingToolCalls.length > 0) {
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
|
|
||||||
|
const finalResponseText =
|
||||||
|
hookState.cumulativeResponse ||
|
||||||
|
turn?.getResponseText() ||
|
||||||
|
'[no response text]';
|
||||||
|
const finalRequest = hookState.originalRequest || currentRequest;
|
||||||
|
|
||||||
|
const hookOutput = await fireAfterAgentHook(
|
||||||
|
messageBus,
|
||||||
|
finalRequest,
|
||||||
|
finalResponseText,
|
||||||
|
);
|
||||||
|
return hookOutput;
|
||||||
|
}
|
||||||
|
|
||||||
private updateTelemetryTokenCount() {
|
private updateTelemetryTokenCount() {
|
||||||
if (this.chat) {
|
if (this.chat) {
|
||||||
uiTelemetryService.setLastPromptTokenCount(
|
uiTelemetryService.setLastPromptTokenCount(
|
||||||
@@ -400,63 +499,27 @@ export class GeminiClient {
|
|||||||
return this.config.getActiveModel();
|
return this.config.getActiveModel();
|
||||||
}
|
}
|
||||||
|
|
||||||
async *sendMessageStream(
|
private async *processTurn(
|
||||||
request: PartListUnion,
|
request: PartListUnion,
|
||||||
signal: AbortSignal,
|
signal: AbortSignal,
|
||||||
prompt_id: string,
|
prompt_id: string,
|
||||||
turns: number = MAX_TURNS,
|
boundedTurns: number,
|
||||||
isInvalidStreamRetry: boolean = false,
|
isInvalidStreamRetry: boolean,
|
||||||
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
|
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
|
||||||
if (!isInvalidStreamRetry) {
|
// Re-initialize turn (it was empty before if in loop, or new instance)
|
||||||
this.config.resetTurn();
|
let turn = new Turn(this.getChat(), prompt_id);
|
||||||
}
|
|
||||||
|
|
||||||
// Fire BeforeAgent hook through MessageBus (only if hooks are enabled)
|
|
||||||
const hooksEnabled = this.config.getEnableHooks();
|
|
||||||
const messageBus = this.config.getMessageBus();
|
|
||||||
if (hooksEnabled && messageBus) {
|
|
||||||
const hookOutput = await fireBeforeAgentHook(messageBus, request);
|
|
||||||
|
|
||||||
if (
|
|
||||||
hookOutput?.isBlockingDecision() ||
|
|
||||||
hookOutput?.shouldStopExecution()
|
|
||||||
) {
|
|
||||||
yield {
|
|
||||||
type: GeminiEventType.Error,
|
|
||||||
value: {
|
|
||||||
error: new Error(
|
|
||||||
`BeforeAgent hook blocked processing: ${hookOutput.getEffectiveReason()}`,
|
|
||||||
),
|
|
||||||
},
|
|
||||||
};
|
|
||||||
return new Turn(this.getChat(), prompt_id);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add additional context from hooks to the request
|
|
||||||
const additionalContext = hookOutput?.getAdditionalContext();
|
|
||||||
if (additionalContext) {
|
|
||||||
const requestArray = Array.isArray(request) ? request : [request];
|
|
||||||
request = [...requestArray, { text: additionalContext }];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (this.lastPromptId !== prompt_id) {
|
|
||||||
this.loopDetector.reset(prompt_id);
|
|
||||||
this.lastPromptId = prompt_id;
|
|
||||||
this.currentSequenceModel = null;
|
|
||||||
}
|
|
||||||
this.sessionTurnCount++;
|
this.sessionTurnCount++;
|
||||||
if (
|
if (
|
||||||
this.config.getMaxSessionTurns() > 0 &&
|
this.config.getMaxSessionTurns() > 0 &&
|
||||||
this.sessionTurnCount > this.config.getMaxSessionTurns()
|
this.sessionTurnCount > this.config.getMaxSessionTurns()
|
||||||
) {
|
) {
|
||||||
yield { type: GeminiEventType.MaxSessionTurns };
|
yield { type: GeminiEventType.MaxSessionTurns };
|
||||||
return new Turn(this.getChat(), prompt_id);
|
return turn;
|
||||||
}
|
}
|
||||||
// Ensure turns never exceeds MAX_TURNS to prevent infinite loops
|
|
||||||
const boundedTurns = Math.min(turns, MAX_TURNS);
|
|
||||||
if (!boundedTurns) {
|
if (!boundedTurns) {
|
||||||
return new Turn(this.getChat(), prompt_id);
|
return turn;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for context window overflow
|
// Check for context window overflow
|
||||||
@@ -478,7 +541,7 @@ export class GeminiClient {
|
|||||||
type: GeminiEventType.ContextWindowWillOverflow,
|
type: GeminiEventType.ContextWindowWillOverflow,
|
||||||
value: { estimatedRequestTokenCount, remainingTokenCount },
|
value: { estimatedRequestTokenCount, remainingTokenCount },
|
||||||
};
|
};
|
||||||
return new Turn(this.getChat(), prompt_id);
|
return turn;
|
||||||
}
|
}
|
||||||
|
|
||||||
const compressed = await this.tryCompressChat(prompt_id, false);
|
const compressed = await this.tryCompressChat(prompt_id, false);
|
||||||
@@ -514,7 +577,8 @@ export class GeminiClient {
|
|||||||
this.forceFullIdeContext = false;
|
this.forceFullIdeContext = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
const turn = new Turn(this.getChat(), prompt_id);
|
// Re-initialize turn with fresh history
|
||||||
|
turn = new Turn(this.getChat(), prompt_id);
|
||||||
|
|
||||||
const controller = new AbortController();
|
const controller = new AbortController();
|
||||||
const linkedSignal = AbortSignal.any([signal, controller.signal]);
|
const linkedSignal = AbortSignal.any([signal, controller.signal]);
|
||||||
@@ -555,6 +619,9 @@ export class GeminiClient {
|
|||||||
yield { type: GeminiEventType.ModelInfo, value: modelToUse };
|
yield { type: GeminiEventType.ModelInfo, value: modelToUse };
|
||||||
|
|
||||||
const resultStream = turn.run(modelConfigKey, request, linkedSignal);
|
const resultStream = turn.run(modelConfigKey, request, linkedSignal);
|
||||||
|
let isError = false;
|
||||||
|
let isInvalidStream = false;
|
||||||
|
|
||||||
for await (const event of resultStream) {
|
for await (const event of resultStream) {
|
||||||
if (this.loopDetector.addAndCheck(event)) {
|
if (this.loopDetector.addAndCheck(event)) {
|
||||||
yield { type: GeminiEventType.LoopDetected };
|
yield { type: GeminiEventType.LoopDetected };
|
||||||
@@ -566,94 +633,181 @@ export class GeminiClient {
|
|||||||
this.updateTelemetryTokenCount();
|
this.updateTelemetryTokenCount();
|
||||||
|
|
||||||
if (event.type === GeminiEventType.InvalidStream) {
|
if (event.type === GeminiEventType.InvalidStream) {
|
||||||
if (this.config.getContinueOnFailedApiCall()) {
|
isInvalidStream = true;
|
||||||
if (isInvalidStreamRetry) {
|
}
|
||||||
// We already retried once, so stop here.
|
if (event.type === GeminiEventType.Error) {
|
||||||
logContentRetryFailure(
|
isError = true;
|
||||||
this.config,
|
}
|
||||||
new ContentRetryFailureEvent(
|
}
|
||||||
4, // 2 initial + 2 after injections
|
|
||||||
'FAILED_AFTER_PROMPT_INJECTION',
|
if (isError) {
|
||||||
modelToUse,
|
return turn;
|
||||||
),
|
}
|
||||||
);
|
|
||||||
return turn;
|
// Update cumulative response in hook state
|
||||||
}
|
// We do this immediately after the stream finishes for THIS turn.
|
||||||
const nextRequest = [{ text: 'System: Please continue.' }];
|
const hooksEnabled = this.config.getEnableHooks();
|
||||||
yield* this.sendMessageStream(
|
if (hooksEnabled) {
|
||||||
nextRequest,
|
const responseText = turn.getResponseText() || '';
|
||||||
signal,
|
const hookState = this.hookStateMap.get(prompt_id);
|
||||||
prompt_id,
|
if (hookState && responseText) {
|
||||||
boundedTurns - 1,
|
// Append with newline if not empty
|
||||||
true, // Set isInvalidStreamRetry to true
|
hookState.cumulativeResponse = hookState.cumulativeResponse
|
||||||
|
? `${hookState.cumulativeResponse}\n${responseText}`
|
||||||
|
: responseText;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isInvalidStream) {
|
||||||
|
if (this.config.getContinueOnFailedApiCall()) {
|
||||||
|
if (isInvalidStreamRetry) {
|
||||||
|
logContentRetryFailure(
|
||||||
|
this.config,
|
||||||
|
new ContentRetryFailureEvent(
|
||||||
|
4,
|
||||||
|
'FAILED_AFTER_PROMPT_INJECTION',
|
||||||
|
modelToUse,
|
||||||
|
),
|
||||||
);
|
);
|
||||||
return turn;
|
return turn;
|
||||||
}
|
}
|
||||||
}
|
const nextRequest = [{ text: 'System: Please continue.' }];
|
||||||
if (event.type === GeminiEventType.Error) {
|
// Recursive call - update turn with result
|
||||||
return turn;
|
turn = yield* this.sendMessageStream(
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!turn.pendingToolCalls.length && signal && !signal.aborted) {
|
|
||||||
// Check if next speaker check is needed
|
|
||||||
if (this.config.getQuotaErrorOccurred()) {
|
|
||||||
return turn;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (this.config.getSkipNextSpeakerCheck()) {
|
|
||||||
return turn;
|
|
||||||
}
|
|
||||||
|
|
||||||
const nextSpeakerCheck = await checkNextSpeaker(
|
|
||||||
this.getChat(),
|
|
||||||
this.config.getBaseLlmClient(),
|
|
||||||
signal,
|
|
||||||
prompt_id,
|
|
||||||
);
|
|
||||||
logNextSpeakerCheck(
|
|
||||||
this.config,
|
|
||||||
new NextSpeakerCheckEvent(
|
|
||||||
prompt_id,
|
|
||||||
turn.finishReason?.toString() || '',
|
|
||||||
nextSpeakerCheck?.next_speaker || '',
|
|
||||||
),
|
|
||||||
);
|
|
||||||
if (nextSpeakerCheck?.next_speaker === 'model') {
|
|
||||||
const nextRequest = [{ text: 'Please continue.' }];
|
|
||||||
// This recursive call's events will be yielded out, and the final
|
|
||||||
// turn object from the recursive call will be returned.
|
|
||||||
return yield* this.sendMessageStream(
|
|
||||||
nextRequest,
|
nextRequest,
|
||||||
signal,
|
signal,
|
||||||
prompt_id,
|
prompt_id,
|
||||||
boundedTurns - 1,
|
boundedTurns - 1,
|
||||||
// isInvalidStreamRetry is false here, as this is a next speaker check
|
true,
|
||||||
);
|
);
|
||||||
|
return turn;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fire AfterAgent hook through MessageBus (only if hooks are enabled)
|
if (!turn.pendingToolCalls.length && signal && !signal.aborted) {
|
||||||
if (hooksEnabled && messageBus) {
|
|
||||||
const responseText = turn.getResponseText() || '[no response text]';
|
|
||||||
const hookOutput = await fireAfterAgentHook(
|
|
||||||
messageBus,
|
|
||||||
request,
|
|
||||||
responseText,
|
|
||||||
);
|
|
||||||
|
|
||||||
// For AfterAgent hooks, blocking/stop execution should force continuation
|
|
||||||
if (
|
if (
|
||||||
hookOutput?.isBlockingDecision() ||
|
!this.config.getQuotaErrorOccurred() &&
|
||||||
hookOutput?.shouldStopExecution()
|
!this.config.getSkipNextSpeakerCheck()
|
||||||
) {
|
) {
|
||||||
const continueReason = hookOutput.getEffectiveReason();
|
const nextSpeakerCheck = await checkNextSpeaker(
|
||||||
const continueRequest = [{ text: continueReason }];
|
this.getChat(),
|
||||||
yield* this.sendMessageStream(
|
this.config.getBaseLlmClient(),
|
||||||
continueRequest,
|
|
||||||
signal,
|
signal,
|
||||||
prompt_id,
|
prompt_id,
|
||||||
boundedTurns - 1,
|
|
||||||
);
|
);
|
||||||
|
logNextSpeakerCheck(
|
||||||
|
this.config,
|
||||||
|
new NextSpeakerCheckEvent(
|
||||||
|
prompt_id,
|
||||||
|
turn.finishReason?.toString() || '',
|
||||||
|
nextSpeakerCheck?.next_speaker || '',
|
||||||
|
),
|
||||||
|
);
|
||||||
|
if (nextSpeakerCheck?.next_speaker === 'model') {
|
||||||
|
const nextRequest = [{ text: 'Please continue.' }];
|
||||||
|
turn = yield* this.sendMessageStream(
|
||||||
|
nextRequest,
|
||||||
|
signal,
|
||||||
|
prompt_id,
|
||||||
|
boundedTurns - 1,
|
||||||
|
// isInvalidStreamRetry is false
|
||||||
|
);
|
||||||
|
return turn;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return turn;
|
||||||
|
}
|
||||||
|
|
||||||
|
async *sendMessageStream(
|
||||||
|
request: PartListUnion,
|
||||||
|
signal: AbortSignal,
|
||||||
|
prompt_id: string,
|
||||||
|
turns: number = MAX_TURNS,
|
||||||
|
isInvalidStreamRetry: boolean = false,
|
||||||
|
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
|
||||||
|
if (!isInvalidStreamRetry) {
|
||||||
|
this.config.resetTurn();
|
||||||
|
}
|
||||||
|
|
||||||
|
const hooksEnabled = this.config.getEnableHooks();
|
||||||
|
const messageBus = this.config.getMessageBus();
|
||||||
|
|
||||||
|
if (this.lastPromptId !== prompt_id) {
|
||||||
|
this.loopDetector.reset(prompt_id);
|
||||||
|
this.hookStateMap.delete(this.lastPromptId);
|
||||||
|
this.lastPromptId = prompt_id;
|
||||||
|
this.currentSequenceModel = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (hooksEnabled && messageBus) {
|
||||||
|
const hookResult = await this.fireBeforeAgentHookSafe(
|
||||||
|
messageBus,
|
||||||
|
request,
|
||||||
|
prompt_id,
|
||||||
|
);
|
||||||
|
if (hookResult) {
|
||||||
|
if ('type' in hookResult && hookResult.type === GeminiEventType.Error) {
|
||||||
|
yield hookResult;
|
||||||
|
return new Turn(this.getChat(), prompt_id);
|
||||||
|
} else if ('additionalContext' in hookResult) {
|
||||||
|
const additionalContext = hookResult.additionalContext;
|
||||||
|
if (additionalContext) {
|
||||||
|
const requestArray = Array.isArray(request) ? request : [request];
|
||||||
|
request = [...requestArray, { text: additionalContext }];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const boundedTurns = Math.min(turns, MAX_TURNS);
|
||||||
|
let turn = new Turn(this.getChat(), prompt_id);
|
||||||
|
|
||||||
|
try {
|
||||||
|
turn = yield* this.processTurn(
|
||||||
|
request,
|
||||||
|
signal,
|
||||||
|
prompt_id,
|
||||||
|
boundedTurns,
|
||||||
|
isInvalidStreamRetry,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Fire AfterAgent hook if we have a turn and no pending tools
|
||||||
|
if (hooksEnabled && messageBus) {
|
||||||
|
const hookOutput = await this.fireAfterAgentHookSafe(
|
||||||
|
messageBus,
|
||||||
|
request,
|
||||||
|
prompt_id,
|
||||||
|
turn,
|
||||||
|
);
|
||||||
|
|
||||||
|
if (
|
||||||
|
hookOutput?.isBlockingDecision() ||
|
||||||
|
hookOutput?.shouldStopExecution()
|
||||||
|
) {
|
||||||
|
const continueReason = hookOutput.getEffectiveReason();
|
||||||
|
const continueRequest = [{ text: continueReason }];
|
||||||
|
yield* this.sendMessageStream(
|
||||||
|
continueRequest,
|
||||||
|
signal,
|
||||||
|
prompt_id,
|
||||||
|
boundedTurns - 1,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
const hookState = this.hookStateMap.get(prompt_id);
|
||||||
|
if (hookState) {
|
||||||
|
hookState.activeCalls--;
|
||||||
|
const isPendingTools =
|
||||||
|
turn?.pendingToolCalls && turn.pendingToolCalls.length > 0;
|
||||||
|
const isAborted = signal?.aborted;
|
||||||
|
|
||||||
|
if (hookState.activeCalls <= 0) {
|
||||||
|
if (!isPendingTools || isAborted) {
|
||||||
|
this.hookStateMap.delete(prompt_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user