feat(core): update context aggregation to use per-hook <hook_context> blocks for better provenance

This commit is contained in:
Michael Bleigh
2026-03-26 16:26:23 -07:00
parent 8faf2e31ea
commit cda12a4f02
9 changed files with 101 additions and 51 deletions
+2 -3
View File
@@ -198,9 +198,8 @@ request format.
- `hookSpecificOutput.llm_response`: A **Synthetic Response** object. If
provided, the CLI skips the LLM call entirely and uses this as the response.
- `hookSpecificOutput.additionalContext`: (`string`) Text that is **appended**
to the end of the user message (wrapped in `<hook_context>` tags). If
`llm_request` also modifies contents, it will be appended to the modified
contents.
to the end of the user message. If `llm_request` also modifies contents, it
will be appended to the modified contents.
- `decision`: Set to `"deny"` to block the request and abort the turn.
- **Exit Code 2 (Block Turn)**: Aborts the turn and skips the LLM call. Uses
`stderr` as the error message.
+3 -2
View File
@@ -665,8 +665,9 @@ export async function main() {
const additionalContext = result.getAdditionalContext();
if (additionalContext) {
// Prepend context to input (System Context -> Stdin -> Question)
const wrappedContext = `<hook_context>${additionalContext}</hook_context>`;
input = input ? `${wrappedContext}\n\n${input}` : wrappedContext;
input = input
? `${additionalContext}\n\n${input}`
: additionalContext;
}
}
}
+2 -4
View File
@@ -432,13 +432,11 @@ export const AppContainer = (props: AppContainerProps) => {
}
const additionalContext = result.getAdditionalContext();
const geminiClient = config.getGeminiClient();
if (additionalContext && geminiClient) {
await geminiClient.addHistory({
role: 'user',
parts: [
{ text: `<hook_context>${additionalContext}</hook_context>` },
],
parts: [{ text: `\n\n${additionalContext}` }],
});
}
}
+1 -4
View File
@@ -913,10 +913,7 @@ export class GeminiClient {
const additionalContext = hookResult.additionalContext;
if (additionalContext) {
const requestArray = Array.isArray(request) ? request : [request];
request = [
...requestArray,
{ text: `<hook_context>${additionalContext}</hook_context>` },
];
request = [...requestArray, { text: `\n\n${additionalContext}` }];
}
}
}
@@ -220,7 +220,7 @@ export async function executeToolWithHooks(
// Add additional context from hooks to the tool result
const additionalContext = afterOutput?.getAdditionalContext();
if (additionalContext) {
const wrappedContext = `\n\n<hook_context>${additionalContext}</hook_context>`;
const wrappedContext = `\n\n${additionalContext}`;
if (typeof toolResult.llmContent === 'string') {
toolResult.llmContent += wrappedContext;
} else if (Array.isArray(toolResult.llmContent)) {
+37 -2
View File
@@ -467,7 +467,9 @@ describe('HookAggregator', () => {
expect(
aggregated.finalOutput?.hookSpecificOutput?.['additionalContext'],
).toBe('Context 1\nContext 2');
).toBe(
'<hook_context hook="h1">\nContext 1\n</hook_context>\n<hook_context hook="h2">\nContext 2\n</hook_context>',
);
});
});
@@ -516,7 +518,40 @@ describe('HookAggregator', () => {
expect(aggregated.success).toBe(true);
expect(
aggregated.finalOutput?.hookSpecificOutput?.['additionalContext'],
).toBe('Context from hook 1\nContext from hook 2');
).toBe(
'<hook_context hook="test-command">\nContext from hook 1\n</hook_context>\n<hook_context hook="test-command">\nContext from hook 2\n</hook_context>',
);
});
it('should sanitize additional context by escaping < and > tags', () => {
const results: HookExecutionResult[] = [
{
hookConfig: {
type: HookType.Command,
command: 'test-hook',
},
eventName: HookEventName.AfterTool,
success: true,
output: {
hookSpecificOutput: {
hookEventName: 'AfterTool',
additionalContext: 'context with <b>bold</b> and <script> tags',
},
},
duration: 10,
},
];
const aggregated = aggregator.aggregateResults(
results,
HookEventName.AfterTool,
);
expect(
aggregated.finalOutput?.hookSpecificOutput?.['additionalContext'],
).toBe(
'<hook_context hook="test-hook">\ncontext with &lt;b&gt;bold&lt;/b&gt; and &lt;script&gt; tags\n</hook_context>',
);
});
});
});
+38 -24
View File
@@ -58,7 +58,7 @@ export class HookAggregator {
}
// Merge outputs using event-specific strategy
const mergedOutput = this.mergeOutputs(allOutputs, eventName);
const mergedOutput = this.mergeOutputs(results, eventName);
const finalOutput = mergedOutput
? this.createSpecificHookOutput(mergedOutput, eventName)
: undefined;
@@ -79,10 +79,11 @@ export class HookAggregator {
* consistent default behaviors (e.g., default decision='allow' for OR logic)
*/
private mergeOutputs(
outputs: HookOutput[],
results: HookExecutionResult[],
eventName: HookEventName,
): HookOutput | undefined {
if (outputs.length === 0) {
const resultsWithOutput = results.filter((r) => r.output);
if (resultsWithOutput.length === 0) {
return undefined;
}
@@ -92,28 +93,25 @@ export class HookAggregator {
case HookEventName.BeforeAgent:
case HookEventName.AfterAgent:
case HookEventName.SessionStart:
return this.mergeWithOrDecision(outputs);
return this.mergeWithOrDecision(resultsWithOutput);
case HookEventName.BeforeModel:
case HookEventName.AfterModel:
return this.mergeWithFieldReplacement(outputs);
return this.mergeWithFieldReplacement(resultsWithOutput);
case HookEventName.BeforeToolSelection:
return this.mergeToolSelectionOutputs(
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
outputs as BeforeToolSelectionOutput[],
);
return this.mergeToolSelectionOutputs(resultsWithOutput);
default:
// For other events, use simple merge
return this.mergeSimple(outputs);
return this.mergeSimple(resultsWithOutput);
}
}
/**
* Merge outputs with OR decision logic and message concatenation
*/
private mergeWithOrDecision(outputs: HookOutput[]): HookOutput {
private mergeWithOrDecision(results: HookExecutionResult[]): HookOutput {
const merged: HookOutput = {
continue: true,
suppressOutput: false,
@@ -128,7 +126,9 @@ export class HookAggregator {
let hasAskDecision = false;
let hasContinueFalse = false;
for (const output of outputs) {
for (const result of results) {
const output = result.output!;
// Handle continue flag
if (output.continue === false) {
hasContinueFalse = true;
@@ -184,7 +184,7 @@ export class HookAggregator {
}
// Collect additional context from hook-specific outputs
this.extractAdditionalContext(output, additionalContexts);
this.extractAdditionalContext(result, additionalContexts);
}
// Set final decision if no blocking or ask decision was found
@@ -219,13 +219,17 @@ export class HookAggregator {
/**
* Merge outputs with later fields replacing earlier fields
*/
private mergeWithFieldReplacement(outputs: HookOutput[]): HookOutput {
private mergeWithFieldReplacement(
results: HookExecutionResult[],
): HookOutput {
let merged: HookOutput = {};
const additionalContexts: string[] = [];
for (const output of outputs) {
for (const result of results) {
const output = result.output!;
// Collect additional context
this.extractAdditionalContext(output, additionalContexts);
this.extractAdditionalContext(result, additionalContexts);
// Later outputs override earlier ones
merged = {
@@ -263,7 +267,7 @@ export class HookAggregator {
* If one hook restricts and another re-enables, the union takes the re-enabled tool.
*/
private mergeToolSelectionOutputs(
outputs: BeforeToolSelectionOutput[],
results: HookExecutionResult[],
): BeforeToolSelectionOutput {
const merged: BeforeToolSelectionOutput = {};
@@ -271,7 +275,9 @@ export class HookAggregator {
let hasNoneMode = false;
let hasAnyMode = false;
for (const output of outputs) {
for (const result of results) {
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
const output = result.output as BeforeToolSelectionOutput;
const toolConfig = output.hookSpecificOutput?.toolConfig;
if (!toolConfig) {
continue;
@@ -326,11 +332,11 @@ export class HookAggregator {
/**
* Simple merge for events without special logic
*/
private mergeSimple(outputs: HookOutput[]): HookOutput {
private mergeSimple(results: HookExecutionResult[]): HookOutput {
let merged: HookOutput = {};
for (const output of outputs) {
merged = { ...merged, ...output };
for (const result of results) {
merged = { ...merged, ...result.output! };
}
return merged;
@@ -363,10 +369,10 @@ export class HookAggregator {
* Extract additional context from hook-specific outputs
*/
private extractAdditionalContext(
output: HookOutput,
result: HookExecutionResult,
contexts: string[],
): void {
const specific = output.hookSpecificOutput;
const specific = result.output?.hookSpecificOutput;
if (!specific) {
return;
}
@@ -377,7 +383,15 @@ export class HookAggregator {
// eslint-disable-next-line no-restricted-syntax
typeof specific['additionalContext'] === 'string'
) {
contexts.push(specific['additionalContext']);
const hookName =
result.hookConfig.name || result.hookConfig.command || 'unknown-hook';
// Sanitize the context text before wrapping to prevent tag injection
const contextText = specific['additionalContext']
.replace(/</g, '&lt;')
.replace(/>/g, '&gt;');
contexts.push(
`<hook_context hook="${hookName}">\n${contextText}\n</hook_context>`,
);
}
}
}
+15 -8
View File
@@ -188,14 +188,14 @@ describe('Hook Output Classes', () => {
expect(output.getAdditionalContext()).toBe('some context');
});
it('getAdditionalContext should sanitize context by escaping <', () => {
it('getAdditionalContext should return raw context without sanitization (handled by aggregator)', () => {
const output = new DefaultHookOutput({
hookSpecificOutput: {
additionalContext: 'context with <tag> and </hook_context>',
},
});
expect(output.getAdditionalContext()).toBe(
'context with &lt;tag&gt; and &lt;/hook_context&gt;',
'context with <tag> and </hook_context>',
);
});
@@ -279,14 +279,17 @@ describe('Hook Output Classes', () => {
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
};
const output = new BeforeModelHookOutput({
hookSpecificOutput: { additionalContext: 'New Context' },
hookSpecificOutput: {
additionalContext:
'<hook_context hook="test">\nNew Context\n</hook_context>',
},
});
const result = output.applyLLMRequestModifications(target);
expect(result.contents).toHaveLength(1);
const contents = result.contents as Content[];
expect(contents[0].parts!).toHaveLength(2);
expect(contents[0].parts![1]).toEqual({
text: '\n\n<hook_context>New Context</hook_context>',
text: '\n\n<hook_context hook="test">\nNew Context\n</hook_context>',
});
});
@@ -296,14 +299,17 @@ describe('Hook Output Classes', () => {
contents: [{ role: 'model', parts: [{ text: 'Hi' }] }],
};
const output = new BeforeModelHookOutput({
hookSpecificOutput: { additionalContext: 'New Context' },
hookSpecificOutput: {
additionalContext:
'<hook_context hook="test">\nNew Context\n</hook_context>',
},
});
const result = output.applyLLMRequestModifications(target);
expect(result.contents).toHaveLength(2);
const contents = result.contents as Content[];
expect(contents[1].role).toBe('user');
expect(contents[1].parts![0]).toEqual({
text: '\n\n<hook_context>New Context</hook_context>',
text: '\n\n<hook_context hook="test">\nNew Context\n</hook_context>',
});
});
@@ -320,7 +326,8 @@ describe('Hook Output Classes', () => {
const output = new BeforeModelHookOutput({
hookSpecificOutput: {
llm_request: mockRequest as LLMRequest,
additionalContext: 'New Context',
additionalContext:
'<hook_context hook="test">\nNew Context\n</hook_context>',
},
});
const result = output.applyLLMRequestModifications(target);
@@ -329,7 +336,7 @@ describe('Hook Output Classes', () => {
expect(contents[0].parts!).toHaveLength(2);
expect(contents[0].parts![0]).toEqual({ text: 'modified' });
expect(contents[0].parts![1]).toEqual({
text: '\n\n<hook_context>New Context</hook_context>',
text: '\n\n<hook_context hook="test">\nNew Context\n</hook_context>',
});
});
});
+2 -3
View File
@@ -271,8 +271,7 @@ export class DefaultHookOutput implements HookOutput {
return undefined;
}
// Sanitize by escaping < and > to prevent tag injection
return context.replace(/</g, '&lt;').replace(/>/g, '&gt;');
return context;
}
return undefined;
}
@@ -419,7 +418,7 @@ export class BeforeModelHookOutput extends DefaultHookOutput {
contents = [];
}
const wrappedContext = `\n\n<hook_context>${additionalContext}</hook_context>`;
const wrappedContext = `\n\n${additionalContext}`;
let lastUserMessageIndex = -1;
for (let i = contents.length - 1; i >= 0; i--) {