mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-22 17:53:04 -07:00
feat(core): implement additionalContext for BeforeModel hooks and aggregate context from multiple hooks
This commit is contained in:
@@ -197,6 +197,10 @@ request format.
|
||||
outgoing request (e.g., changing models or temperature).
|
||||
- `hookSpecificOutput.llm_response`: A **Synthetic Response** object. If
|
||||
provided, the CLI skips the LLM call entirely and uses this as the response.
|
||||
- `hookSpecificOutput.additionalContext`: (`string`) Text that is **appended**
|
||||
to the end of the user message (wrapped in `<hook_context>` tags). If
|
||||
`llm_request` also modifies contents, it will be appended to the modified
|
||||
contents.
|
||||
- `decision`: Set to `"deny"` to block the request and abort the turn.
|
||||
- **Exit Code 2 (Block Turn)**: Aborts the turn and skips the LLM call. Uses
|
||||
`stderr` as the error message.
|
||||
|
||||
@@ -423,6 +423,52 @@ describe('HookAggregator', () => {
|
||||
const llmRequest = output.hookSpecificOutput?.llm_request;
|
||||
expect(llmRequest?.['model']).toBe('model2'); // Later value wins
|
||||
});
|
||||
|
||||
it('should aggregate additionalContext for BeforeModel hooks', () => {
|
||||
const results: HookExecutionResult[] = [
|
||||
{
|
||||
hookConfig: {
|
||||
type: HookType.Command,
|
||||
command: 'h1',
|
||||
timeout: 30000,
|
||||
},
|
||||
eventName: HookEventName.BeforeModel,
|
||||
success: true,
|
||||
output: {
|
||||
hookSpecificOutput: {
|
||||
hookEventName: 'BeforeModel',
|
||||
additionalContext: 'Context 1',
|
||||
},
|
||||
},
|
||||
duration: 10,
|
||||
},
|
||||
{
|
||||
hookConfig: {
|
||||
type: HookType.Command,
|
||||
command: 'h2',
|
||||
timeout: 30000,
|
||||
},
|
||||
eventName: HookEventName.BeforeModel,
|
||||
success: true,
|
||||
output: {
|
||||
hookSpecificOutput: {
|
||||
hookEventName: 'BeforeModel',
|
||||
additionalContext: 'Context 2',
|
||||
},
|
||||
},
|
||||
duration: 10,
|
||||
},
|
||||
];
|
||||
|
||||
const aggregated = aggregator.aggregateResults(
|
||||
results,
|
||||
HookEventName.BeforeModel,
|
||||
);
|
||||
|
||||
expect(
|
||||
aggregated.finalOutput?.hookSpecificOutput?.['additionalContext'],
|
||||
).toBe('Context 1\nContext 2');
|
||||
});
|
||||
});
|
||||
|
||||
describe('extractAdditionalContext', () => {
|
||||
|
||||
@@ -221,8 +221,12 @@ export class HookAggregator {
|
||||
*/
|
||||
private mergeWithFieldReplacement(outputs: HookOutput[]): HookOutput {
|
||||
let merged: HookOutput = {};
|
||||
const additionalContexts: string[] = [];
|
||||
|
||||
for (const output of outputs) {
|
||||
// Collect additional context
|
||||
this.extractAdditionalContext(output, additionalContexts);
|
||||
|
||||
// Later outputs override earlier ones
|
||||
merged = {
|
||||
...merged,
|
||||
@@ -234,6 +238,14 @@ export class HookAggregator {
|
||||
};
|
||||
}
|
||||
|
||||
// Add merged additional context
|
||||
if (additionalContexts.length > 0) {
|
||||
merged.hookSpecificOutput = {
|
||||
...(merged.hookSpecificOutput || {}),
|
||||
additionalContext: additionalContexts.join('\n'),
|
||||
};
|
||||
}
|
||||
|
||||
return merged;
|
||||
}
|
||||
|
||||
|
||||
@@ -655,10 +655,9 @@ describe('HookRunner', () => {
|
||||
|
||||
// Verify that the second hook received modified input
|
||||
const secondHookInput = JSON.parse(
|
||||
vi.mocked(mockSpawn.stdin.write).mock.calls[1][0],
|
||||
vi.mocked(mockSpawn.stdin.write).mock.calls[1][0] as string,
|
||||
);
|
||||
expect(secondHookInput.prompt).toContain('Original prompt');
|
||||
expect(secondHookInput.prompt).toContain('Context from hook 1');
|
||||
expect(secondHookInput.prompt).toBe('Original prompt');
|
||||
});
|
||||
|
||||
it('should pass modified LLM request from one hook to the next for BeforeModel', async () => {
|
||||
|
||||
@@ -15,7 +15,6 @@ import {
|
||||
type HookInput,
|
||||
type HookOutput,
|
||||
type HookExecutionResult,
|
||||
type BeforeAgentInput,
|
||||
type BeforeModelInput,
|
||||
type BeforeModelOutput,
|
||||
type BeforeToolInput,
|
||||
@@ -180,22 +179,6 @@ export class HookRunner {
|
||||
// Apply modifications based on hook output and event type
|
||||
if (hookOutput.hookSpecificOutput) {
|
||||
switch (eventName) {
|
||||
case HookEventName.BeforeAgent:
|
||||
if ('additionalContext' in hookOutput.hookSpecificOutput) {
|
||||
// For BeforeAgent, we could modify the prompt with additional context
|
||||
const additionalContext =
|
||||
hookOutput.hookSpecificOutput['additionalContext'];
|
||||
if (
|
||||
typeof additionalContext === 'string' &&
|
||||
'prompt' in modifiedInput
|
||||
) {
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
(modifiedInput as BeforeAgentInput).prompt +=
|
||||
'\n\n' + additionalContext;
|
||||
}
|
||||
}
|
||||
break;
|
||||
|
||||
case HookEventName.BeforeModel:
|
||||
if ('llm_request' in hookOutput.hookSpecificOutput) {
|
||||
// For BeforeModel, we update the LLM request
|
||||
|
||||
@@ -24,6 +24,7 @@ import {
|
||||
import type {
|
||||
GenerateContentParameters,
|
||||
GenerateContentResponse,
|
||||
Content,
|
||||
ToolConfig,
|
||||
} from '@google/genai';
|
||||
|
||||
@@ -271,6 +272,66 @@ describe('Hook Output Classes', () => {
|
||||
const output = new BeforeModelHookOutput({});
|
||||
expect(output.applyLLMRequestModifications(target)).toBe(target);
|
||||
});
|
||||
|
||||
it('applyLLMRequestModifications should append additionalContext to contents', () => {
|
||||
const target: GenerateContentParameters = {
|
||||
model: 'gemini-pro',
|
||||
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
||||
};
|
||||
const output = new BeforeModelHookOutput({
|
||||
hookSpecificOutput: { additionalContext: 'New Context' },
|
||||
});
|
||||
const result = output.applyLLMRequestModifications(target);
|
||||
expect(result.contents).toHaveLength(1);
|
||||
const contents = result.contents as Content[];
|
||||
expect(contents[0].parts!).toHaveLength(2);
|
||||
expect(contents[0].parts![1]).toEqual({
|
||||
text: '\n\n<hook_context>New Context</hook_context>',
|
||||
});
|
||||
});
|
||||
|
||||
it('applyLLMRequestModifications should create new user message if none exists', () => {
|
||||
const target: GenerateContentParameters = {
|
||||
model: 'gemini-pro',
|
||||
contents: [{ role: 'model', parts: [{ text: 'Hi' }] }],
|
||||
};
|
||||
const output = new BeforeModelHookOutput({
|
||||
hookSpecificOutput: { additionalContext: 'New Context' },
|
||||
});
|
||||
const result = output.applyLLMRequestModifications(target);
|
||||
expect(result.contents).toHaveLength(2);
|
||||
const contents = result.contents as Content[];
|
||||
expect(contents[1].role).toBe('user');
|
||||
expect(contents[1].parts![0]).toEqual({
|
||||
text: '\n\n<hook_context>New Context</hook_context>',
|
||||
});
|
||||
});
|
||||
|
||||
it('applyLLMRequestModifications should handle both llm_request and additionalContext', () => {
|
||||
const target: GenerateContentParameters = {
|
||||
model: 'gemini-pro',
|
||||
contents: [{ role: 'user', parts: [{ text: 'original' }] }],
|
||||
};
|
||||
const mockRequest = {
|
||||
// Our mock fromHookLLMRequest just spreads the request, so we can use contents here for convenience in tests
|
||||
// even though LLMRequest uses messages.
|
||||
contents: [{ role: 'user', parts: [{ text: 'modified' }] }],
|
||||
} as unknown as Partial<LLMRequest>;
|
||||
const output = new BeforeModelHookOutput({
|
||||
hookSpecificOutput: {
|
||||
llm_request: mockRequest as LLMRequest,
|
||||
additionalContext: 'New Context',
|
||||
},
|
||||
});
|
||||
const result = output.applyLLMRequestModifications(target);
|
||||
expect(result.contents).toHaveLength(1);
|
||||
const contents = result.contents as Content[];
|
||||
expect(contents[0].parts!).toHaveLength(2);
|
||||
expect(contents[0].parts![0]).toEqual({ text: 'modified' });
|
||||
expect(contents[0].parts![1]).toEqual({
|
||||
text: '\n\n<hook_context>New Context</hook_context>',
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('BeforeToolSelectionHookOutput', () => {
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
import type {
|
||||
GenerateContentResponse,
|
||||
GenerateContentParameters,
|
||||
Content,
|
||||
ToolConfig as GenAIToolConfig,
|
||||
ToolListUnion,
|
||||
} from '@google/genai';
|
||||
@@ -374,6 +375,8 @@ export class BeforeModelHookOutput extends DefaultHookOutput {
|
||||
override applyLLMRequestModifications(
|
||||
target: GenerateContentParameters,
|
||||
): GenerateContentParameters {
|
||||
let resultTarget = target;
|
||||
|
||||
if (this.hookSpecificOutput && 'llm_request' in this.hookSpecificOutput) {
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
const hookRequest = this.hookSpecificOutput[
|
||||
@@ -386,13 +389,62 @@ export class BeforeModelHookOutput extends DefaultHookOutput {
|
||||
hookRequest as LLMRequest,
|
||||
target,
|
||||
);
|
||||
return {
|
||||
resultTarget = {
|
||||
...target,
|
||||
...sdkRequest,
|
||||
};
|
||||
}
|
||||
}
|
||||
return target;
|
||||
|
||||
const additionalContext = this.getAdditionalContext();
|
||||
if (additionalContext) {
|
||||
const originalContents = resultTarget.contents;
|
||||
let contents: Content[];
|
||||
|
||||
if (Array.isArray(originalContents)) {
|
||||
contents = (originalContents as unknown[]).map((c) => {
|
||||
if (typeof c === 'string') {
|
||||
return { role: 'user', parts: [{ text: c }] };
|
||||
}
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
const content = c as Content;
|
||||
return { ...content, parts: [...(content.parts || [])] };
|
||||
});
|
||||
} else if (typeof originalContents === 'string') {
|
||||
contents = [{ role: 'user', parts: [{ text: originalContents }] }];
|
||||
} else if (originalContents && 'role' in originalContents) {
|
||||
const c = originalContents;
|
||||
contents = [{ ...c, parts: [...(c.parts || [])] }];
|
||||
} else {
|
||||
contents = [];
|
||||
}
|
||||
|
||||
const wrappedContext = `\n\n<hook_context>${additionalContext}</hook_context>`;
|
||||
|
||||
let lastUserMessageIndex = -1;
|
||||
for (let i = contents.length - 1; i >= 0; i--) {
|
||||
if (contents[i].role === 'user') {
|
||||
lastUserMessageIndex = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (lastUserMessageIndex !== -1) {
|
||||
if (!contents[lastUserMessageIndex].parts) {
|
||||
contents[lastUserMessageIndex].parts = [];
|
||||
}
|
||||
contents[lastUserMessageIndex].parts!.push({ text: wrappedContext });
|
||||
} else {
|
||||
contents.push({ role: 'user', parts: [{ text: wrappedContext }] });
|
||||
}
|
||||
|
||||
resultTarget = {
|
||||
...resultTarget,
|
||||
contents,
|
||||
};
|
||||
}
|
||||
|
||||
return resultTarget;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -683,6 +735,7 @@ export interface BeforeModelOutput extends HookOutput {
|
||||
hookEventName: 'BeforeModel';
|
||||
llm_request?: Partial<LLMRequest>;
|
||||
llm_response?: LLMResponse;
|
||||
additionalContext?: string;
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user