mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-13 13:22:35 -07:00
Merge remote-tracking branch 'origin/main' into akkr/subagents
This commit is contained in:
@@ -439,6 +439,54 @@ auth:
|
||||
});
|
||||
});
|
||||
|
||||
it('should parse remote agent with Digest via raw value', async () => {
|
||||
const filePath = await writeAgentMarkdown(`---
|
||||
kind: remote
|
||||
name: digest-agent
|
||||
agent_card_url: https://example.com/card
|
||||
auth:
|
||||
type: http
|
||||
scheme: Digest
|
||||
value: username="admin", response="abc123"
|
||||
---
|
||||
`);
|
||||
const result = await parseAgentMarkdown(filePath);
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0]).toMatchObject({
|
||||
kind: 'remote',
|
||||
name: 'digest-agent',
|
||||
auth: {
|
||||
type: 'http',
|
||||
scheme: 'Digest',
|
||||
value: 'username="admin", response="abc123"',
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should parse remote agent with generic raw auth value', async () => {
|
||||
const filePath = await writeAgentMarkdown(`---
|
||||
kind: remote
|
||||
name: raw-agent
|
||||
agent_card_url: https://example.com/card
|
||||
auth:
|
||||
type: http
|
||||
scheme: CustomScheme
|
||||
value: raw-token-value
|
||||
---
|
||||
`);
|
||||
const result = await parseAgentMarkdown(filePath);
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0]).toMatchObject({
|
||||
kind: 'remote',
|
||||
name: 'raw-agent',
|
||||
auth: {
|
||||
type: 'http',
|
||||
scheme: 'CustomScheme',
|
||||
value: 'raw-token-value',
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should throw error for Bearer auth without token', async () => {
|
||||
const filePath = await writeAgentMarkdown(`---
|
||||
kind: remote
|
||||
|
||||
@@ -58,10 +58,11 @@ interface FrontmatterAuthConfig {
|
||||
key?: string;
|
||||
name?: string;
|
||||
// HTTP
|
||||
scheme?: 'Bearer' | 'Basic';
|
||||
scheme?: string;
|
||||
token?: string;
|
||||
username?: string;
|
||||
password?: string;
|
||||
value?: string;
|
||||
}
|
||||
|
||||
interface FrontmatterRemoteAgentDefinition
|
||||
@@ -149,16 +150,21 @@ const apiKeyAuthSchema = z.object({
|
||||
const httpAuthSchema = z.object({
|
||||
...baseAuthFields,
|
||||
type: z.literal('http'),
|
||||
scheme: z.enum(['Bearer', 'Basic']),
|
||||
scheme: z.string().min(1),
|
||||
token: z.string().min(1).optional(),
|
||||
username: z.string().min(1).optional(),
|
||||
password: z.string().min(1).optional(),
|
||||
value: z.string().min(1).optional(),
|
||||
});
|
||||
|
||||
const authConfigSchema = z
|
||||
.discriminatedUnion('type', [apiKeyAuthSchema, httpAuthSchema])
|
||||
.superRefine((data, ctx) => {
|
||||
if (data.type === 'http') {
|
||||
if (data.value) {
|
||||
// Raw mode - only scheme and value are needed
|
||||
return;
|
||||
}
|
||||
if (data.scheme === 'Bearer' && !data.token) {
|
||||
ctx.addIssue({
|
||||
code: z.ZodIssueCode.custom,
|
||||
@@ -358,6 +364,14 @@ function convertFrontmatterAuthToConfig(
|
||||
'Internal error: HTTP scheme missing after validation.',
|
||||
);
|
||||
}
|
||||
if (frontmatter.value) {
|
||||
return {
|
||||
...base,
|
||||
type: 'http',
|
||||
scheme: frontmatter.scheme,
|
||||
value: frontmatter.value,
|
||||
};
|
||||
}
|
||||
switch (frontmatter.scheme) {
|
||||
case 'Bearer':
|
||||
if (!frontmatter.token) {
|
||||
@@ -385,8 +399,8 @@ function convertFrontmatterAuthToConfig(
|
||||
password: frontmatter.password,
|
||||
};
|
||||
default: {
|
||||
const exhaustive: never = frontmatter.scheme;
|
||||
throw new Error(`Unknown HTTP scheme: ${exhaustive}`);
|
||||
// Other IANA schemes without a value should not reach here after validation
|
||||
throw new Error(`Unknown HTTP scheme: ${frontmatter.scheme}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import type {
|
||||
AuthValidationResult,
|
||||
} from './types.js';
|
||||
import { ApiKeyAuthProvider } from './api-key-provider.js';
|
||||
import { HttpAuthProvider } from './http-provider.js';
|
||||
|
||||
export interface CreateAuthProviderOptions {
|
||||
/** Required for OAuth/OIDC token storage. */
|
||||
@@ -50,9 +51,11 @@ export class A2AAuthProviderFactory {
|
||||
return provider;
|
||||
}
|
||||
|
||||
case 'http':
|
||||
// TODO: Implement
|
||||
throw new Error('http auth provider not yet implemented');
|
||||
case 'http': {
|
||||
const provider = new HttpAuthProvider(authConfig);
|
||||
await provider.initialize();
|
||||
return provider;
|
||||
}
|
||||
|
||||
case 'oauth2':
|
||||
// TODO: Implement
|
||||
|
||||
@@ -0,0 +1,133 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { HttpAuthProvider } from './http-provider.js';
|
||||
|
||||
describe('HttpAuthProvider', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('Bearer Authentication', () => {
|
||||
it('should provide Bearer token header', async () => {
|
||||
const config = {
|
||||
type: 'http' as const,
|
||||
scheme: 'Bearer' as const,
|
||||
token: 'test-token',
|
||||
};
|
||||
const provider = new HttpAuthProvider(config);
|
||||
await provider.initialize();
|
||||
|
||||
const headers = await provider.headers();
|
||||
expect(headers).toEqual({ Authorization: 'Bearer test-token' });
|
||||
});
|
||||
|
||||
it('should resolve token from environment variable', async () => {
|
||||
process.env['TEST_TOKEN'] = 'env-token';
|
||||
const config = {
|
||||
type: 'http' as const,
|
||||
scheme: 'Bearer' as const,
|
||||
token: '$TEST_TOKEN',
|
||||
};
|
||||
const provider = new HttpAuthProvider(config);
|
||||
await provider.initialize();
|
||||
|
||||
const headers = await provider.headers();
|
||||
expect(headers).toEqual({ Authorization: 'Bearer env-token' });
|
||||
delete process.env['TEST_TOKEN'];
|
||||
});
|
||||
});
|
||||
|
||||
describe('Basic Authentication', () => {
|
||||
it('should provide Basic auth header', async () => {
|
||||
const config = {
|
||||
type: 'http' as const,
|
||||
scheme: 'Basic' as const,
|
||||
username: 'user',
|
||||
password: 'password',
|
||||
};
|
||||
const provider = new HttpAuthProvider(config);
|
||||
await provider.initialize();
|
||||
|
||||
const headers = await provider.headers();
|
||||
const expected = Buffer.from('user:password').toString('base64');
|
||||
expect(headers).toEqual({ Authorization: `Basic ${expected}` });
|
||||
});
|
||||
});
|
||||
|
||||
describe('Generic/Raw Authentication', () => {
|
||||
it('should provide custom scheme with raw value', async () => {
|
||||
const config = {
|
||||
type: 'http' as const,
|
||||
scheme: 'CustomScheme',
|
||||
value: 'raw-value-here',
|
||||
};
|
||||
const provider = new HttpAuthProvider(config);
|
||||
await provider.initialize();
|
||||
|
||||
const headers = await provider.headers();
|
||||
expect(headers).toEqual({ Authorization: 'CustomScheme raw-value-here' });
|
||||
});
|
||||
|
||||
it('should support Digest via raw value', async () => {
|
||||
const config = {
|
||||
type: 'http' as const,
|
||||
scheme: 'Digest',
|
||||
value: 'username="foo", response="bar"',
|
||||
};
|
||||
const provider = new HttpAuthProvider(config);
|
||||
await provider.initialize();
|
||||
|
||||
const headers = await provider.headers();
|
||||
expect(headers).toEqual({
|
||||
Authorization: 'Digest username="foo", response="bar"',
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('Retry logic', () => {
|
||||
it('should re-initialize on 401 for Bearer', async () => {
|
||||
const config = {
|
||||
type: 'http' as const,
|
||||
scheme: 'Bearer' as const,
|
||||
token: '$DYNAMIC_TOKEN',
|
||||
};
|
||||
process.env['DYNAMIC_TOKEN'] = 'first';
|
||||
const provider = new HttpAuthProvider(config);
|
||||
await provider.initialize();
|
||||
|
||||
process.env['DYNAMIC_TOKEN'] = 'second';
|
||||
const mockResponse = { status: 401 } as Response;
|
||||
const retryHeaders = await provider.shouldRetryWithHeaders(
|
||||
{},
|
||||
mockResponse,
|
||||
);
|
||||
|
||||
expect(retryHeaders).toEqual({ Authorization: 'Bearer second' });
|
||||
delete process.env['DYNAMIC_TOKEN'];
|
||||
});
|
||||
|
||||
it('should stop after max retries', async () => {
|
||||
const config = {
|
||||
type: 'http' as const,
|
||||
scheme: 'Bearer' as const,
|
||||
token: 'token',
|
||||
};
|
||||
const provider = new HttpAuthProvider(config);
|
||||
await provider.initialize();
|
||||
|
||||
const mockResponse = { status: 401 } as Response;
|
||||
|
||||
// MAX_AUTH_RETRIES is 2
|
||||
await provider.shouldRetryWithHeaders({}, mockResponse);
|
||||
await provider.shouldRetryWithHeaders({}, mockResponse);
|
||||
const third = await provider.shouldRetryWithHeaders({}, mockResponse);
|
||||
|
||||
expect(third).toBeUndefined();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,88 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { HttpHeaders } from '@a2a-js/sdk/client';
|
||||
import { BaseA2AAuthProvider } from './base-provider.js';
|
||||
import type { HttpAuthConfig } from './types.js';
|
||||
import { resolveAuthValue } from './value-resolver.js';
|
||||
import { debugLogger } from '../../utils/debugLogger.js';
|
||||
|
||||
/**
|
||||
* Authentication provider for HTTP authentication schemes.
|
||||
* Supports Bearer, Basic, and any IANA-registered scheme via raw value.
|
||||
*/
|
||||
export class HttpAuthProvider extends BaseA2AAuthProvider {
|
||||
readonly type = 'http' as const;
|
||||
|
||||
private resolvedToken?: string;
|
||||
private resolvedUsername?: string;
|
||||
private resolvedPassword?: string;
|
||||
private resolvedValue?: string;
|
||||
|
||||
constructor(private readonly config: HttpAuthConfig) {
|
||||
super();
|
||||
}
|
||||
|
||||
override async initialize(): Promise<void> {
|
||||
const config = this.config;
|
||||
if ('token' in config) {
|
||||
this.resolvedToken = await resolveAuthValue(config.token);
|
||||
} else if ('username' in config) {
|
||||
this.resolvedUsername = await resolveAuthValue(config.username);
|
||||
this.resolvedPassword = await resolveAuthValue(config.password);
|
||||
} else {
|
||||
// Generic raw value for any other IANA-registered scheme
|
||||
this.resolvedValue = await resolveAuthValue(config.value);
|
||||
}
|
||||
debugLogger.debug(
|
||||
`[HttpAuthProvider] Initialized with scheme: ${this.config.scheme}`,
|
||||
);
|
||||
}
|
||||
|
||||
override async headers(): Promise<HttpHeaders> {
|
||||
const config = this.config;
|
||||
if ('token' in config) {
|
||||
if (!this.resolvedToken)
|
||||
throw new Error('HttpAuthProvider not initialized');
|
||||
return { Authorization: `Bearer ${this.resolvedToken}` };
|
||||
}
|
||||
|
||||
if ('username' in config) {
|
||||
if (!this.resolvedUsername || !this.resolvedPassword) {
|
||||
throw new Error('HttpAuthProvider not initialized');
|
||||
}
|
||||
const credentials = Buffer.from(
|
||||
`${this.resolvedUsername}:${this.resolvedPassword}`,
|
||||
).toString('base64');
|
||||
return { Authorization: `Basic ${credentials}` };
|
||||
}
|
||||
|
||||
// Generic raw value for any other IANA-registered scheme
|
||||
if (!this.resolvedValue)
|
||||
throw new Error('HttpAuthProvider not initialized');
|
||||
return { Authorization: `${config.scheme} ${this.resolvedValue}` };
|
||||
}
|
||||
|
||||
/**
|
||||
* Re-resolves credentials on auth failure (e.g. rotated tokens via $ENV or !command).
|
||||
* Respects MAX_AUTH_RETRIES from the base class to prevent infinite loops.
|
||||
*/
|
||||
override async shouldRetryWithHeaders(
|
||||
req: RequestInit,
|
||||
res: Response,
|
||||
): Promise<HttpHeaders | undefined> {
|
||||
if (res.status === 401 || res.status === 403) {
|
||||
if (this.authRetryCount >= BaseA2AAuthProvider.MAX_AUTH_RETRIES) {
|
||||
return undefined;
|
||||
}
|
||||
debugLogger.debug(
|
||||
'[HttpAuthProvider] Re-resolving values after auth failure',
|
||||
);
|
||||
await this.initialize();
|
||||
}
|
||||
return super.shouldRetryWithHeaders(req, res);
|
||||
}
|
||||
}
|
||||
@@ -60,6 +60,12 @@ export type HttpAuthConfig = BaseAuthConfig & {
|
||||
/** For Basic. Supports $ENV_VAR, !command, or literal. */
|
||||
password: string;
|
||||
}
|
||||
| {
|
||||
/** Any IANA-registered scheme (e.g., "Digest", "HOBA", "Custom"). */
|
||||
scheme: string;
|
||||
/** Raw value to be sent as "Authorization: <scheme> <value>". Supports $ENV_VAR, !command, or literal. */
|
||||
value: string;
|
||||
}
|
||||
);
|
||||
|
||||
/** Client config corresponding to OAuth2SecurityScheme. */
|
||||
|
||||
@@ -16,8 +16,11 @@
|
||||
|
||||
import type { Config } from '../../config/config.js';
|
||||
import { LocalAgentExecutor } from '../local-executor.js';
|
||||
import type { AnsiOutput } from '../../utils/terminalSerializer.js';
|
||||
import { BaseToolInvocation, type ToolResult } from '../../tools/tools.js';
|
||||
import {
|
||||
BaseToolInvocation,
|
||||
type ToolResult,
|
||||
type ToolLiveOutput,
|
||||
} from '../../tools/tools.js';
|
||||
import { ToolErrorType } from '../../tools/tool-error.js';
|
||||
import type { AgentInputs, SubagentActivityEvent } from '../types.js';
|
||||
import type { MessageBus } from '../../confirmation-bus/message-bus.js';
|
||||
@@ -82,7 +85,7 @@ export class BrowserAgentInvocation extends BaseToolInvocation<
|
||||
*/
|
||||
async execute(
|
||||
signal: AbortSignal,
|
||||
updateOutput?: (output: string | AnsiOutput) => void,
|
||||
updateOutput?: (output: ToolLiveOutput) => void,
|
||||
): Promise<ToolResult> {
|
||||
let browserManager;
|
||||
|
||||
|
||||
@@ -501,7 +501,7 @@ describe('LocalAgentExecutor', () => {
|
||||
expect(agentRegistry.getTool(subAgentName)).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should enforce qualified names for MCP tools in agent definitions', async () => {
|
||||
it('should automatically qualify MCP tools in agent definitions', async () => {
|
||||
const serverName = 'mcp-server';
|
||||
const toolName = 'mcp-tool';
|
||||
const qualifiedName = `${serverName}${MCP_QUALIFIED_NAME_SEPARATOR}${toolName}`;
|
||||
@@ -530,7 +530,7 @@ describe('LocalAgentExecutor', () => {
|
||||
return undefined;
|
||||
});
|
||||
|
||||
// 1. Qualified name works and registers the tool (using short name per status quo)
|
||||
// 1. Qualified name works and registers the tool (using qualified name)
|
||||
const definition = createTestDefinition([qualifiedName]);
|
||||
const executor = await LocalAgentExecutor.create(
|
||||
definition,
|
||||
@@ -539,14 +539,18 @@ describe('LocalAgentExecutor', () => {
|
||||
);
|
||||
|
||||
const agentRegistry = executor['toolRegistry'];
|
||||
// Registry shortening logic means it's registered as 'mcp-tool' internally
|
||||
expect(agentRegistry.getTool(toolName)).toBeDefined();
|
||||
// It should be registered as the qualified name
|
||||
expect(agentRegistry.getTool(qualifiedName)).toBeDefined();
|
||||
|
||||
// 2. Unqualified name for MCP tool THROWS
|
||||
const badDefinition = createTestDefinition([toolName]);
|
||||
await expect(
|
||||
LocalAgentExecutor.create(badDefinition, mockConfig, onActivity),
|
||||
).rejects.toThrow(/must be requested with its server prefix/);
|
||||
// 2. Unqualified name for MCP tool now also works (and gets upgraded to qualified)
|
||||
const definition2 = createTestDefinition([toolName]);
|
||||
const executor2 = await LocalAgentExecutor.create(
|
||||
definition2,
|
||||
mockConfig,
|
||||
onActivity,
|
||||
);
|
||||
const agentRegistry2 = executor2['toolRegistry'];
|
||||
expect(agentRegistry2.getTool(qualifiedName)).toBeDefined();
|
||||
|
||||
getToolSpy.mockRestore();
|
||||
});
|
||||
@@ -809,25 +813,28 @@ describe('LocalAgentExecutor', () => {
|
||||
expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
type: 'THOUGHT_CHUNK',
|
||||
data: { text: 'T1: Listing' },
|
||||
data: expect.objectContaining({ text: 'T1: Listing' }),
|
||||
}),
|
||||
expect.objectContaining({
|
||||
type: 'TOOL_CALL_END',
|
||||
data: { name: LS_TOOL_NAME, output: 'file1.txt' },
|
||||
data: expect.objectContaining({
|
||||
name: LS_TOOL_NAME,
|
||||
output: 'file1.txt',
|
||||
}),
|
||||
}),
|
||||
expect.objectContaining({
|
||||
type: 'TOOL_CALL_START',
|
||||
data: {
|
||||
data: expect.objectContaining({
|
||||
name: TASK_COMPLETE_TOOL_NAME,
|
||||
args: { finalResult: 'Found file1.txt' },
|
||||
},
|
||||
}),
|
||||
}),
|
||||
expect.objectContaining({
|
||||
type: 'TOOL_CALL_END',
|
||||
data: {
|
||||
data: expect.objectContaining({
|
||||
name: TASK_COMPLETE_TOOL_NAME,
|
||||
output: expect.stringContaining('Output submitted'),
|
||||
},
|
||||
}),
|
||||
}),
|
||||
]),
|
||||
);
|
||||
|
||||
@@ -214,15 +214,14 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
// registry and register it with the agent's isolated registry.
|
||||
const tool = parentToolRegistry.getTool(toolName);
|
||||
if (tool) {
|
||||
if (
|
||||
tool instanceof DiscoveredMCPTool &&
|
||||
!toolName.includes(MCP_QUALIFIED_NAME_SEPARATOR)
|
||||
) {
|
||||
throw new Error(
|
||||
`MCP tool '${toolName}' must be requested with its server prefix (e.g., '${tool.serverName}${MCP_QUALIFIED_NAME_SEPARATOR}${toolName}') in agent '${definition.name}'.`,
|
||||
);
|
||||
if (tool instanceof DiscoveredMCPTool) {
|
||||
// Subagents MUST use fully qualified names for MCP tools to ensure
|
||||
// unambiguous tool calls and to comply with policy requirements.
|
||||
// We automatically "upgrade" any MCP tool to its qualified version.
|
||||
agentToolRegistry.registerTool(tool.asFullyQualifiedTool());
|
||||
} else {
|
||||
agentToolRegistry.registerTool(tool);
|
||||
}
|
||||
agentToolRegistry.registerTool(tool);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -341,13 +340,22 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
};
|
||||
}
|
||||
|
||||
const { nextMessage, submittedOutput, taskCompleted } =
|
||||
const { nextMessage, submittedOutput, taskCompleted, aborted } =
|
||||
await this.processFunctionCalls(
|
||||
functionCalls,
|
||||
combinedSignal,
|
||||
promptId,
|
||||
onWaitingForConfirmation,
|
||||
);
|
||||
|
||||
if (aborted) {
|
||||
return {
|
||||
status: 'stop',
|
||||
terminateReason: AgentTerminateMode.ABORTED,
|
||||
finalResult: null,
|
||||
};
|
||||
}
|
||||
|
||||
if (taskCompleted) {
|
||||
const finalResult = submittedOutput ?? 'Task completed successfully.';
|
||||
return {
|
||||
@@ -929,6 +937,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
nextMessage: Content;
|
||||
submittedOutput: string | null;
|
||||
taskCompleted: boolean;
|
||||
aborted: boolean;
|
||||
}> {
|
||||
const allowedToolNames = new Set(this.toolRegistry.getAllToolNames());
|
||||
// Always allow the completion tool
|
||||
@@ -936,6 +945,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
|
||||
let submittedOutput: string | null = null;
|
||||
let taskCompleted = false;
|
||||
let aborted = false;
|
||||
|
||||
// We'll separate complete_task from other tools
|
||||
const toolRequests: ToolCallRequestInfo[] = [];
|
||||
@@ -950,8 +960,24 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
const toolName = functionCall.name as string;
|
||||
|
||||
let displayName = toolName;
|
||||
let description: string | undefined = undefined;
|
||||
|
||||
try {
|
||||
const tool = this.toolRegistry.getTool(toolName);
|
||||
if (tool) {
|
||||
displayName = tool.displayName ?? toolName;
|
||||
const invocation = tool.build(args);
|
||||
description = invocation.getDescription();
|
||||
}
|
||||
} catch {
|
||||
// Ignore errors during formatting for activity emission
|
||||
}
|
||||
|
||||
this.emitActivity('TOOL_CALL_START', {
|
||||
name: toolName,
|
||||
displayName,
|
||||
description,
|
||||
args,
|
||||
});
|
||||
|
||||
@@ -1149,8 +1175,9 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
this.emitActivity('ERROR', {
|
||||
context: 'tool_call',
|
||||
name: toolName,
|
||||
error: 'Tool call was cancelled.',
|
||||
error: 'Request cancelled.',
|
||||
});
|
||||
aborted = true;
|
||||
}
|
||||
|
||||
// Add result to syncResults to preserve order later
|
||||
@@ -1183,6 +1210,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
nextMessage: { role: 'user', parts: toolResponseParts },
|
||||
submittedOutput,
|
||||
taskCompleted,
|
||||
aborted,
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -4,17 +4,25 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, type Mocked } from 'vitest';
|
||||
import {
|
||||
describe,
|
||||
it,
|
||||
expect,
|
||||
vi,
|
||||
beforeEach,
|
||||
afterEach,
|
||||
type Mocked,
|
||||
} from 'vitest';
|
||||
import type {
|
||||
LocalAgentDefinition,
|
||||
SubagentActivityEvent,
|
||||
AgentInputs,
|
||||
SubagentProgress,
|
||||
} from './types.js';
|
||||
import { LocalSubagentInvocation } from './local-invocation.js';
|
||||
import { LocalAgentExecutor } from './local-executor.js';
|
||||
import { AgentTerminateMode } from './types.js';
|
||||
import { makeFakeConfig } from '../test-utils/config.js';
|
||||
import { ToolErrorType } from '../tools/tool-error.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import { type z } from 'zod';
|
||||
@@ -29,6 +37,7 @@ let mockConfig: Config;
|
||||
const testDefinition: LocalAgentDefinition<z.ZodUnknown> = {
|
||||
kind: 'local',
|
||||
name: 'MockAgent',
|
||||
displayName: 'Mock Agent',
|
||||
description: 'A mock agent.',
|
||||
inputConfig: {
|
||||
inputSchema: {
|
||||
@@ -70,6 +79,10 @@ describe('LocalSubagentInvocation', () => {
|
||||
);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
it('should pass the messageBus to the parent constructor', () => {
|
||||
const params = { task: 'Analyze data' };
|
||||
const invocation = new LocalSubagentInvocation(
|
||||
@@ -173,7 +186,12 @@ describe('LocalSubagentInvocation', () => {
|
||||
mockConfig,
|
||||
expect.any(Function),
|
||||
);
|
||||
expect(updateOutput).toHaveBeenCalledWith('Subagent starting...\n');
|
||||
expect(updateOutput).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
isSubagentProgress: true,
|
||||
agentName: 'MockAgent',
|
||||
}),
|
||||
);
|
||||
|
||||
expect(mockExecutorInstance.run).toHaveBeenCalledWith(params, signal);
|
||||
|
||||
@@ -211,13 +229,17 @@ describe('LocalSubagentInvocation', () => {
|
||||
|
||||
await invocation.execute(signal, updateOutput);
|
||||
|
||||
expect(updateOutput).toHaveBeenCalledWith('Subagent starting...\n');
|
||||
expect(updateOutput).toHaveBeenCalledWith('🤖💭 Analyzing...');
|
||||
expect(updateOutput).toHaveBeenCalledWith('🤖💭 Still thinking.');
|
||||
expect(updateOutput).toHaveBeenCalledTimes(3); // Initial message + 2 thoughts
|
||||
expect(updateOutput).toHaveBeenCalledTimes(3); // Initial + 2 updates
|
||||
const lastCall = updateOutput.mock.calls[2][0] as SubagentProgress;
|
||||
expect(lastCall.recentActivity).toContainEqual(
|
||||
expect.objectContaining({
|
||||
type: 'thought',
|
||||
content: 'Analyzing... Still thinking.',
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should NOT stream other activities (e.g., TOOL_CALL_START, ERROR)', async () => {
|
||||
it('should stream other activities (e.g., TOOL_CALL_START, ERROR)', async () => {
|
||||
mockExecutorInstance.run.mockImplementation(async () => {
|
||||
const onActivity = MockLocalAgentExecutor.create.mock.calls[0][2];
|
||||
|
||||
@@ -226,7 +248,7 @@ describe('LocalSubagentInvocation', () => {
|
||||
isSubagentActivityEvent: true,
|
||||
agentName: 'MockAgent',
|
||||
type: 'TOOL_CALL_START',
|
||||
data: { name: 'ls' },
|
||||
data: { name: 'ls', args: {} },
|
||||
} as SubagentActivityEvent);
|
||||
onActivity({
|
||||
isSubagentActivityEvent: true,
|
||||
@@ -240,9 +262,15 @@ describe('LocalSubagentInvocation', () => {
|
||||
|
||||
await invocation.execute(signal, updateOutput);
|
||||
|
||||
// Should only contain the initial "Subagent starting..." message
|
||||
expect(updateOutput).toHaveBeenCalledTimes(1);
|
||||
expect(updateOutput).toHaveBeenCalledWith('Subagent starting...\n');
|
||||
expect(updateOutput).toHaveBeenCalledTimes(3);
|
||||
const lastCall = updateOutput.mock.calls[2][0] as SubagentProgress;
|
||||
expect(lastCall.recentActivity).toContainEqual(
|
||||
expect.objectContaining({
|
||||
type: 'thought',
|
||||
content: 'Error: Failed',
|
||||
status: 'error',
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should run successfully without an updateOutput callback', async () => {
|
||||
@@ -272,16 +300,19 @@ describe('LocalSubagentInvocation', () => {
|
||||
|
||||
const result = await invocation.execute(signal, updateOutput);
|
||||
|
||||
expect(result.error).toEqual({
|
||||
message: error.message,
|
||||
type: ToolErrorType.EXECUTION_FAILED,
|
||||
});
|
||||
expect(result.returnDisplay).toBe(
|
||||
`Subagent Failed: MockAgent\nError: ${error.message}`,
|
||||
);
|
||||
expect(result.error).toBeUndefined();
|
||||
expect(result.llmContent).toBe(
|
||||
`Subagent 'MockAgent' failed. Error: ${error.message}`,
|
||||
);
|
||||
const display = result.returnDisplay as SubagentProgress;
|
||||
expect(display.isSubagentProgress).toBe(true);
|
||||
expect(display.recentActivity).toContainEqual(
|
||||
expect.objectContaining({
|
||||
type: 'thought',
|
||||
content: `Error: ${error.message}`,
|
||||
status: 'error',
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle executor creation failure', async () => {
|
||||
@@ -291,19 +322,21 @@ describe('LocalSubagentInvocation', () => {
|
||||
const result = await invocation.execute(signal, updateOutput);
|
||||
|
||||
expect(mockExecutorInstance.run).not.toHaveBeenCalled();
|
||||
expect(result.error).toEqual({
|
||||
message: creationError.message,
|
||||
type: ToolErrorType.EXECUTION_FAILED,
|
||||
});
|
||||
expect(result.returnDisplay).toContain(`Error: ${creationError.message}`);
|
||||
expect(result.error).toBeUndefined();
|
||||
expect(result.llmContent).toContain(creationError.message);
|
||||
|
||||
const display = result.returnDisplay as SubagentProgress;
|
||||
expect(display.recentActivity).toContainEqual(
|
||||
expect.objectContaining({
|
||||
content: `Error: ${creationError.message}`,
|
||||
status: 'error',
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
/**
|
||||
* This test verifies that the AbortSignal is correctly propagated and
|
||||
* that a rejection from the executor due to abortion is handled gracefully.
|
||||
*/
|
||||
it('should handle abortion signal during execution', async () => {
|
||||
const abortError = new Error('Aborted');
|
||||
abortError.name = 'AbortError';
|
||||
mockExecutorInstance.run.mockRejectedValue(abortError);
|
||||
|
||||
const controller = new AbortController();
|
||||
@@ -312,14 +345,24 @@ describe('LocalSubagentInvocation', () => {
|
||||
updateOutput,
|
||||
);
|
||||
controller.abort();
|
||||
const result = await executePromise;
|
||||
await expect(executePromise).rejects.toThrow('Aborted');
|
||||
|
||||
expect(mockExecutorInstance.run).toHaveBeenCalledWith(
|
||||
params,
|
||||
controller.signal,
|
||||
);
|
||||
expect(result.error?.message).toBe('Aborted');
|
||||
expect(result.error?.type).toBe(ToolErrorType.EXECUTION_FAILED);
|
||||
});
|
||||
|
||||
it('should throw an error and bubble cancellation when execution returns ABORTED', async () => {
|
||||
const mockOutput = {
|
||||
result: 'Cancelled by user',
|
||||
terminate_reason: AgentTerminateMode.ABORTED,
|
||||
};
|
||||
mockExecutorInstance.run.mockResolvedValue(mockOutput);
|
||||
|
||||
await expect(invocation.execute(signal, updateOutput)).rejects.toThrow(
|
||||
'Operation cancelled by user',
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -6,18 +6,25 @@
|
||||
|
||||
import type { Config } from '../config/config.js';
|
||||
import { LocalAgentExecutor } from './local-executor.js';
|
||||
import type { AnsiOutput } from '../utils/terminalSerializer.js';
|
||||
import { BaseToolInvocation, type ToolResult } from '../tools/tools.js';
|
||||
import { ToolErrorType } from '../tools/tool-error.js';
|
||||
import type {
|
||||
LocalAgentDefinition,
|
||||
AgentInputs,
|
||||
SubagentActivityEvent,
|
||||
import {
|
||||
BaseToolInvocation,
|
||||
type ToolResult,
|
||||
type ToolLiveOutput,
|
||||
} from '../tools/tools.js';
|
||||
import {
|
||||
type LocalAgentDefinition,
|
||||
type AgentInputs,
|
||||
type SubagentActivityEvent,
|
||||
type SubagentProgress,
|
||||
type SubagentActivityItem,
|
||||
AgentTerminateMode,
|
||||
} from './types.js';
|
||||
import { randomUUID } from 'node:crypto';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
|
||||
const INPUT_PREVIEW_MAX_LENGTH = 50;
|
||||
const DESCRIPTION_MAX_LENGTH = 200;
|
||||
const MAX_RECENT_ACTIVITY = 3;
|
||||
|
||||
/**
|
||||
* Represents a validated, executable instance of a subagent tool.
|
||||
@@ -81,11 +88,20 @@ export class LocalSubagentInvocation extends BaseToolInvocation<
|
||||
*/
|
||||
async execute(
|
||||
signal: AbortSignal,
|
||||
updateOutput?: (output: string | AnsiOutput) => void,
|
||||
updateOutput?: (output: ToolLiveOutput) => void,
|
||||
): Promise<ToolResult> {
|
||||
let recentActivity: SubagentActivityItem[] = [];
|
||||
|
||||
try {
|
||||
if (updateOutput) {
|
||||
updateOutput('Subagent starting...\n');
|
||||
// Send initial state
|
||||
const initialProgress: SubagentProgress = {
|
||||
isSubagentProgress: true,
|
||||
agentName: this.definition.name,
|
||||
recentActivity: [],
|
||||
state: 'running',
|
||||
};
|
||||
updateOutput(initialProgress);
|
||||
}
|
||||
|
||||
// Create an activity callback to bridge the executor's events to the
|
||||
@@ -93,11 +109,114 @@ export class LocalSubagentInvocation extends BaseToolInvocation<
|
||||
const onActivity = (activity: SubagentActivityEvent): void => {
|
||||
if (!updateOutput) return;
|
||||
|
||||
if (
|
||||
activity.type === 'THOUGHT_CHUNK' &&
|
||||
typeof activity.data['text'] === 'string'
|
||||
) {
|
||||
updateOutput(`🤖💭 ${activity.data['text']}`);
|
||||
let updated = false;
|
||||
|
||||
switch (activity.type) {
|
||||
case 'THOUGHT_CHUNK': {
|
||||
const text = String(activity.data['text']);
|
||||
const lastItem = recentActivity[recentActivity.length - 1];
|
||||
if (
|
||||
lastItem &&
|
||||
lastItem.type === 'thought' &&
|
||||
lastItem.status === 'running'
|
||||
) {
|
||||
lastItem.content += text;
|
||||
} else {
|
||||
recentActivity.push({
|
||||
id: randomUUID(),
|
||||
type: 'thought',
|
||||
content: text,
|
||||
status: 'running',
|
||||
});
|
||||
}
|
||||
updated = true;
|
||||
break;
|
||||
}
|
||||
case 'TOOL_CALL_START': {
|
||||
const name = String(activity.data['name']);
|
||||
const displayName = activity.data['displayName']
|
||||
? String(activity.data['displayName'])
|
||||
: undefined;
|
||||
const description = activity.data['description']
|
||||
? String(activity.data['description'])
|
||||
: undefined;
|
||||
const args = JSON.stringify(activity.data['args']);
|
||||
recentActivity.push({
|
||||
id: randomUUID(),
|
||||
type: 'tool_call',
|
||||
content: name,
|
||||
displayName,
|
||||
description,
|
||||
args,
|
||||
status: 'running',
|
||||
});
|
||||
updated = true;
|
||||
break;
|
||||
}
|
||||
case 'TOOL_CALL_END': {
|
||||
const name = String(activity.data['name']);
|
||||
// Find the last running tool call with this name
|
||||
for (let i = recentActivity.length - 1; i >= 0; i--) {
|
||||
if (
|
||||
recentActivity[i].type === 'tool_call' &&
|
||||
recentActivity[i].content === name &&
|
||||
recentActivity[i].status === 'running'
|
||||
) {
|
||||
recentActivity[i].status = 'completed';
|
||||
updated = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'ERROR': {
|
||||
const error = String(activity.data['error']);
|
||||
const isCancellation = error === 'Request cancelled.';
|
||||
const toolName = activity.data['name']
|
||||
? String(activity.data['name'])
|
||||
: undefined;
|
||||
|
||||
if (toolName && isCancellation) {
|
||||
for (let i = recentActivity.length - 1; i >= 0; i--) {
|
||||
if (
|
||||
recentActivity[i].type === 'tool_call' &&
|
||||
recentActivity[i].content === toolName &&
|
||||
recentActivity[i].status === 'running'
|
||||
) {
|
||||
recentActivity[i].status = 'cancelled';
|
||||
updated = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
recentActivity.push({
|
||||
id: randomUUID(),
|
||||
type: 'thought', // Treat errors as thoughts for now, or add an error type
|
||||
content: isCancellation ? error : `Error: ${error}`,
|
||||
status: isCancellation ? 'cancelled' : 'error',
|
||||
});
|
||||
updated = true;
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
if (updated) {
|
||||
// Keep only the last N items
|
||||
if (recentActivity.length > MAX_RECENT_ACTIVITY) {
|
||||
recentActivity = recentActivity.slice(-MAX_RECENT_ACTIVITY);
|
||||
}
|
||||
|
||||
const progress: SubagentProgress = {
|
||||
isSubagentProgress: true,
|
||||
agentName: this.definition.name,
|
||||
recentActivity: [...recentActivity], // Copy to avoid mutation issues
|
||||
state: 'running',
|
||||
};
|
||||
|
||||
updateOutput(progress);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -109,6 +228,23 @@ export class LocalSubagentInvocation extends BaseToolInvocation<
|
||||
|
||||
const output = await executor.run(this.params, signal);
|
||||
|
||||
if (output.terminate_reason === AgentTerminateMode.ABORTED) {
|
||||
const progress: SubagentProgress = {
|
||||
isSubagentProgress: true,
|
||||
agentName: this.definition.name,
|
||||
recentActivity: [...recentActivity],
|
||||
state: 'cancelled',
|
||||
};
|
||||
|
||||
if (updateOutput) {
|
||||
updateOutput(progress);
|
||||
}
|
||||
|
||||
const cancelError = new Error('Operation cancelled by user');
|
||||
cancelError.name = 'AbortError';
|
||||
throw cancelError;
|
||||
}
|
||||
|
||||
const resultContent = `Subagent '${this.definition.name}' finished.
|
||||
Termination Reason: ${output.terminate_reason}
|
||||
Result:
|
||||
@@ -131,13 +267,55 @@ ${output.result}
|
||||
const errorMessage =
|
||||
error instanceof Error ? error.message : String(error);
|
||||
|
||||
const isAbort =
|
||||
(error instanceof Error && error.name === 'AbortError') ||
|
||||
errorMessage.includes('Aborted');
|
||||
|
||||
// Mark any running items as error/cancelled
|
||||
for (const item of recentActivity) {
|
||||
if (item.status === 'running') {
|
||||
item.status = isAbort ? 'cancelled' : 'error';
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure the error is reflected in the recent activity for display
|
||||
// But only if it's NOT an abort, or if we want to show "Cancelled" as a thought
|
||||
if (!isAbort) {
|
||||
const lastActivity = recentActivity[recentActivity.length - 1];
|
||||
if (!lastActivity || lastActivity.status !== 'error') {
|
||||
recentActivity.push({
|
||||
id: randomUUID(),
|
||||
type: 'thought',
|
||||
content: `Error: ${errorMessage}`,
|
||||
status: 'error',
|
||||
});
|
||||
// Maintain size limit
|
||||
if (recentActivity.length > MAX_RECENT_ACTIVITY) {
|
||||
recentActivity = recentActivity.slice(-MAX_RECENT_ACTIVITY);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const progress: SubagentProgress = {
|
||||
isSubagentProgress: true,
|
||||
agentName: this.definition.name,
|
||||
recentActivity: [...recentActivity],
|
||||
state: isAbort ? 'cancelled' : 'error',
|
||||
};
|
||||
|
||||
if (updateOutput) {
|
||||
updateOutput(progress);
|
||||
}
|
||||
|
||||
if (isAbort) {
|
||||
throw error;
|
||||
}
|
||||
|
||||
return {
|
||||
llmContent: `Subagent '${this.definition.name}' failed. Error: ${errorMessage}`,
|
||||
returnDisplay: `Subagent Failed: ${this.definition.name}\nError: ${errorMessage}`,
|
||||
error: {
|
||||
message: errorMessage,
|
||||
type: ToolErrorType.EXECUTION_FAILED,
|
||||
},
|
||||
returnDisplay: progress,
|
||||
// We omit the 'error' property so that the UI renders our rich returnDisplay
|
||||
// instead of the raw error message. The llmContent still informs the agent of the failure.
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,6 +30,8 @@ import type { ToolRegistry } from '../tools/tool-registry.js';
|
||||
import { ThinkingLevel } from '@google/genai';
|
||||
import type { AcknowledgedAgentsService } from './acknowledgedAgents.js';
|
||||
import { PolicyDecision } from '../policy/types.js';
|
||||
import { A2AAuthProviderFactory } from './auth-provider/factory.js';
|
||||
import type { A2AAuthProvider } from './auth-provider/types.js';
|
||||
|
||||
vi.mock('./agentLoader.js', () => ({
|
||||
loadAgentsFromDirectory: vi
|
||||
@@ -43,6 +45,12 @@ vi.mock('./a2a-client-manager.js', () => ({
|
||||
},
|
||||
}));
|
||||
|
||||
vi.mock('./auth-provider/factory.js', () => ({
|
||||
A2AAuthProviderFactory: {
|
||||
create: vi.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
function makeMockedConfig(params?: Partial<ConfigParameters>): Config {
|
||||
const config = makeFakeConfig(params);
|
||||
vi.spyOn(config, 'getToolRegistry').mockReturnValue({
|
||||
@@ -546,6 +554,90 @@ describe('AgentRegistry', () => {
|
||||
expect(registry.getDefinition('RemoteAgent')).toEqual(remoteAgent);
|
||||
});
|
||||
|
||||
it('should register a remote agent with authentication configuration', async () => {
|
||||
const mockAuth = {
|
||||
type: 'http' as const,
|
||||
scheme: 'Bearer' as const,
|
||||
token: 'secret-token',
|
||||
};
|
||||
const remoteAgent: AgentDefinition = {
|
||||
kind: 'remote',
|
||||
name: 'RemoteAgentWithAuth',
|
||||
description: 'A remote agent',
|
||||
agentCardUrl: 'https://example.com/card',
|
||||
inputConfig: { inputSchema: { type: 'object' } },
|
||||
auth: mockAuth,
|
||||
};
|
||||
|
||||
const mockHandler = {
|
||||
type: 'http' as const,
|
||||
headers: vi
|
||||
.fn()
|
||||
.mockResolvedValue({ Authorization: 'Bearer secret-token' }),
|
||||
shouldRetryWithHeaders: vi.fn(),
|
||||
} as unknown as A2AAuthProvider;
|
||||
vi.mocked(A2AAuthProviderFactory.create).mockResolvedValue(mockHandler);
|
||||
|
||||
const loadAgentSpy = vi
|
||||
.fn()
|
||||
.mockResolvedValue({ name: 'RemoteAgentWithAuth' });
|
||||
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
|
||||
loadAgent: loadAgentSpy,
|
||||
clearCache: vi.fn(),
|
||||
} as unknown as A2AClientManager);
|
||||
|
||||
await registry.testRegisterAgent(remoteAgent);
|
||||
|
||||
expect(A2AAuthProviderFactory.create).toHaveBeenCalledWith({
|
||||
authConfig: mockAuth,
|
||||
agentName: 'RemoteAgentWithAuth',
|
||||
});
|
||||
expect(loadAgentSpy).toHaveBeenCalledWith(
|
||||
'RemoteAgentWithAuth',
|
||||
'https://example.com/card',
|
||||
mockHandler,
|
||||
);
|
||||
expect(registry.getDefinition('RemoteAgentWithAuth')).toEqual(
|
||||
remoteAgent,
|
||||
);
|
||||
});
|
||||
|
||||
it('should not register remote agent when auth provider factory returns undefined', async () => {
|
||||
const remoteAgent: AgentDefinition = {
|
||||
kind: 'remote',
|
||||
name: 'RemoteAgentBadAuth',
|
||||
description: 'A remote agent',
|
||||
agentCardUrl: 'https://example.com/card',
|
||||
inputConfig: { inputSchema: { type: 'object' } },
|
||||
auth: {
|
||||
type: 'http' as const,
|
||||
scheme: 'Bearer' as const,
|
||||
token: 'secret-token',
|
||||
},
|
||||
};
|
||||
|
||||
vi.mocked(A2AAuthProviderFactory.create).mockResolvedValue(undefined);
|
||||
const loadAgentSpy = vi.fn();
|
||||
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
|
||||
loadAgent: loadAgentSpy,
|
||||
clearCache: vi.fn(),
|
||||
} as unknown as A2AClientManager);
|
||||
|
||||
const warnSpy = vi
|
||||
.spyOn(debugLogger, 'warn')
|
||||
.mockImplementation(() => {});
|
||||
|
||||
await registry.testRegisterAgent(remoteAgent);
|
||||
|
||||
expect(loadAgentSpy).not.toHaveBeenCalled();
|
||||
expect(registry.getDefinition('RemoteAgentBadAuth')).toBeUndefined();
|
||||
expect(warnSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Error loading A2A agent'),
|
||||
expect.any(Error),
|
||||
);
|
||||
warnSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should log remote agent registration in debug mode', async () => {
|
||||
const debugConfig = makeMockedConfig({ debugMode: true });
|
||||
const debugRegistry = new TestableAgentRegistry(debugConfig);
|
||||
@@ -572,6 +664,30 @@ describe('AgentRegistry', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should surface an error if remote agent registration fails', async () => {
|
||||
const remoteAgent: AgentDefinition = {
|
||||
kind: 'remote',
|
||||
name: 'FailingRemoteAgent',
|
||||
description: 'A remote agent',
|
||||
agentCardUrl: 'https://example.com/card',
|
||||
inputConfig: { inputSchema: { type: 'object' } },
|
||||
};
|
||||
|
||||
const error = new Error('401 Unauthorized');
|
||||
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
|
||||
loadAgent: vi.fn().mockRejectedValue(error),
|
||||
} as unknown as A2AClientManager);
|
||||
|
||||
const feedbackSpy = vi.spyOn(coreEvents, 'emitFeedback');
|
||||
|
||||
await registry.testRegisterAgent(remoteAgent);
|
||||
|
||||
expect(feedbackSpy).toHaveBeenCalledWith(
|
||||
'error',
|
||||
`Error loading A2A agent "FailingRemoteAgent": 401 Unauthorized`,
|
||||
);
|
||||
});
|
||||
|
||||
it('should merge user and agent description and skills when registering a remote agent', async () => {
|
||||
const remoteAgent: AgentDefinition = {
|
||||
kind: 'remote',
|
||||
|
||||
@@ -14,7 +14,8 @@ import { CliHelpAgent } from './cli-help-agent.js';
|
||||
import { GeneralistAgent } from './generalist-agent.js';
|
||||
import { BrowserAgentDefinition } from './browser/browserAgentDefinition.js';
|
||||
import { A2AClientManager } from './a2a-client-manager.js';
|
||||
import { ADCHandler } from './remote-invocation.js';
|
||||
import { A2AAuthProviderFactory } from './auth-provider/factory.js';
|
||||
import type { AuthenticationHandler } from '@a2a-js/sdk/client';
|
||||
import { type z } from 'zod';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import { isAutoModel } from '../config/models.js';
|
||||
@@ -118,7 +119,20 @@ export class AgentRegistry {
|
||||
coreEvents.emitFeedback('error', `Agent loading error: ${error.message}`);
|
||||
}
|
||||
await Promise.allSettled(
|
||||
userAgents.agents.map((agent) => this.registerAgent(agent)),
|
||||
userAgents.agents.map(async (agent) => {
|
||||
try {
|
||||
await this.registerAgent(agent);
|
||||
} catch (e) {
|
||||
debugLogger.warn(
|
||||
`[AgentRegistry] Error registering user agent "${agent.name}":`,
|
||||
e,
|
||||
);
|
||||
coreEvents.emitFeedback(
|
||||
'error',
|
||||
`Error registering user agent "${agent.name}": ${e instanceof Error ? e.message : String(e)}`,
|
||||
);
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
// Load project-level agents: .gemini/agents/ (relative to Project Root)
|
||||
@@ -173,7 +187,20 @@ export class AgentRegistry {
|
||||
}
|
||||
|
||||
await Promise.allSettled(
|
||||
agentsToRegister.map((agent) => this.registerAgent(agent)),
|
||||
agentsToRegister.map(async (agent) => {
|
||||
try {
|
||||
await this.registerAgent(agent);
|
||||
} catch (e) {
|
||||
debugLogger.warn(
|
||||
`[AgentRegistry] Error registering project agent "${agent.name}":`,
|
||||
e,
|
||||
);
|
||||
coreEvents.emitFeedback(
|
||||
'error',
|
||||
`Error registering project agent "${agent.name}": ${e instanceof Error ? e.message : String(e)}`,
|
||||
);
|
||||
}
|
||||
}),
|
||||
);
|
||||
} else {
|
||||
coreEvents.emitFeedback(
|
||||
@@ -186,7 +213,20 @@ export class AgentRegistry {
|
||||
for (const extension of this.config.getExtensions()) {
|
||||
if (extension.isActive && extension.agents) {
|
||||
await Promise.allSettled(
|
||||
extension.agents.map((agent) => this.registerAgent(agent)),
|
||||
extension.agents.map(async (agent) => {
|
||||
try {
|
||||
await this.registerAgent(agent);
|
||||
} catch (e) {
|
||||
debugLogger.warn(
|
||||
`[AgentRegistry] Error registering extension agent "${agent.name}":`,
|
||||
e,
|
||||
);
|
||||
coreEvents.emitFeedback(
|
||||
'error',
|
||||
`Error registering extension agent "${agent.name}": ${e instanceof Error ? e.message : String(e)}`,
|
||||
);
|
||||
}
|
||||
}),
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -371,8 +411,20 @@ export class AgentRegistry {
|
||||
// Log remote A2A agent registration for visibility.
|
||||
try {
|
||||
const clientManager = A2AClientManager.getInstance();
|
||||
// Use ADCHandler to ensure we can load agents hosted on secure platforms (e.g. Vertex AI)
|
||||
const authHandler = new ADCHandler();
|
||||
let authHandler: AuthenticationHandler | undefined;
|
||||
if (definition.auth) {
|
||||
const provider = await A2AAuthProviderFactory.create({
|
||||
authConfig: definition.auth,
|
||||
agentName: definition.name,
|
||||
});
|
||||
if (!provider) {
|
||||
throw new Error(
|
||||
`Failed to create auth provider for agent '${definition.name}'`,
|
||||
);
|
||||
}
|
||||
authHandler = provider;
|
||||
}
|
||||
|
||||
const agentCard = await clientManager.loadAgent(
|
||||
remoteDef.name,
|
||||
remoteDef.agentCardUrl,
|
||||
@@ -411,10 +463,9 @@ export class AgentRegistry {
|
||||
this.agents.set(definition.name, definition);
|
||||
this.addAgentPolicy(definition);
|
||||
} catch (e) {
|
||||
debugLogger.warn(
|
||||
`[AgentRegistry] Error loading A2A agent "${definition.name}":`,
|
||||
e,
|
||||
);
|
||||
const errorMessage = `Error loading A2A agent "${definition.name}": ${e instanceof Error ? e.message : String(e)}`;
|
||||
debugLogger.warn(`[AgentRegistry] ${errorMessage}`, e);
|
||||
coreEvents.emitFeedback('error', errorMessage);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -20,14 +20,22 @@ import {
|
||||
} from './a2a-client-manager.js';
|
||||
import type { RemoteAgentDefinition } from './types.js';
|
||||
import { createMockMessageBus } from '../test-utils/mock-message-bus.js';
|
||||
import { A2AAuthProviderFactory } from './auth-provider/factory.js';
|
||||
import type { A2AAuthProvider } from './auth-provider/types.js';
|
||||
|
||||
// Mock A2AClientManager
|
||||
vi.mock('./a2a-client-manager.js', () => {
|
||||
const A2AClientManager = {
|
||||
vi.mock('./a2a-client-manager.js', () => ({
|
||||
A2AClientManager: {
|
||||
getInstance: vi.fn(),
|
||||
};
|
||||
return { A2AClientManager };
|
||||
});
|
||||
},
|
||||
}));
|
||||
|
||||
// Mock A2AAuthProviderFactory
|
||||
vi.mock('./auth-provider/factory.js', () => ({
|
||||
A2AAuthProviderFactory: {
|
||||
create: vi.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
describe('RemoteAgentInvocation', () => {
|
||||
const mockDefinition: RemoteAgentDefinition = {
|
||||
@@ -118,7 +126,7 @@ describe('RemoteAgentInvocation', () => {
|
||||
});
|
||||
|
||||
describe('Execution Logic', () => {
|
||||
it('should lazy load the agent with ADCHandler if not present', async () => {
|
||||
it('should lazy load the agent without auth handler when no auth configured', async () => {
|
||||
mockClientManager.getClient.mockReturnValue(undefined);
|
||||
mockClientManager.sendMessageStream.mockImplementation(
|
||||
async function* () {
|
||||
@@ -143,10 +151,80 @@ describe('RemoteAgentInvocation', () => {
|
||||
expect(mockClientManager.loadAgent).toHaveBeenCalledWith(
|
||||
'test-agent',
|
||||
'http://test-agent/card',
|
||||
expect.objectContaining({
|
||||
headers: expect.any(Function),
|
||||
shouldRetryWithHeaders: expect.any(Function),
|
||||
}),
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
|
||||
it('should use A2AAuthProviderFactory when auth is present in definition', async () => {
|
||||
const mockAuth = {
|
||||
type: 'http' as const,
|
||||
scheme: 'Basic' as const,
|
||||
username: 'admin',
|
||||
password: 'password',
|
||||
};
|
||||
const authDefinition: RemoteAgentDefinition = {
|
||||
...mockDefinition,
|
||||
auth: mockAuth,
|
||||
};
|
||||
|
||||
const mockHandler = {
|
||||
type: 'http' as const,
|
||||
headers: vi.fn().mockResolvedValue({ Authorization: 'Basic dGVzdA==' }),
|
||||
shouldRetryWithHeaders: vi.fn(),
|
||||
} as unknown as A2AAuthProvider;
|
||||
(A2AAuthProviderFactory.create as Mock).mockResolvedValue(mockHandler);
|
||||
mockClientManager.getClient.mockReturnValue(undefined);
|
||||
mockClientManager.sendMessageStream.mockImplementation(
|
||||
async function* () {
|
||||
yield {
|
||||
kind: 'message',
|
||||
messageId: 'msg-1',
|
||||
role: 'agent',
|
||||
parts: [{ kind: 'text', text: 'Hello' }],
|
||||
};
|
||||
},
|
||||
);
|
||||
|
||||
const invocation = new RemoteAgentInvocation(
|
||||
authDefinition,
|
||||
{ query: 'hi' },
|
||||
mockMessageBus,
|
||||
);
|
||||
await invocation.execute(new AbortController().signal);
|
||||
|
||||
expect(A2AAuthProviderFactory.create).toHaveBeenCalledWith({
|
||||
authConfig: mockAuth,
|
||||
agentName: 'test-agent',
|
||||
});
|
||||
expect(mockClientManager.loadAgent).toHaveBeenCalledWith(
|
||||
'test-agent',
|
||||
'http://test-agent/card',
|
||||
mockHandler,
|
||||
);
|
||||
});
|
||||
|
||||
it('should return error when auth provider factory returns undefined for configured auth', async () => {
|
||||
const authDefinition: RemoteAgentDefinition = {
|
||||
...mockDefinition,
|
||||
auth: {
|
||||
type: 'http' as const,
|
||||
scheme: 'Bearer' as const,
|
||||
token: 'secret-token',
|
||||
},
|
||||
};
|
||||
|
||||
(A2AAuthProviderFactory.create as Mock).mockResolvedValue(undefined);
|
||||
mockClientManager.getClient.mockReturnValue(undefined);
|
||||
|
||||
const invocation = new RemoteAgentInvocation(
|
||||
authDefinition,
|
||||
{ query: 'hi' },
|
||||
mockMessageBus,
|
||||
);
|
||||
const result = await invocation.execute(new AbortController().signal);
|
||||
|
||||
expect(result.error?.message).toContain(
|
||||
"Failed to create auth provider for agent 'test-agent'",
|
||||
);
|
||||
});
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ import type { AuthenticationHandler } from '@a2a-js/sdk/client';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import type { AnsiOutput } from '../utils/terminalSerializer.js';
|
||||
import type { SendMessageResult } from './a2a-client-manager.js';
|
||||
import { A2AAuthProviderFactory } from './auth-provider/factory.js';
|
||||
|
||||
/**
|
||||
* Authentication handler implementation using Google Application Default Credentials (ADC).
|
||||
@@ -79,7 +80,7 @@ export class RemoteAgentInvocation extends BaseToolInvocation<
|
||||
// TODO: See if we can reuse the singleton from AppContainer or similar, but for now use getInstance directly
|
||||
// as per the current pattern in the codebase.
|
||||
private readonly clientManager = A2AClientManager.getInstance();
|
||||
private readonly authHandler = new ADCHandler();
|
||||
private authHandler: AuthenticationHandler | undefined;
|
||||
|
||||
constructor(
|
||||
private readonly definition: RemoteAgentDefinition,
|
||||
@@ -107,6 +108,27 @@ export class RemoteAgentInvocation extends BaseToolInvocation<
|
||||
return `Calling remote agent ${this.definition.displayName ?? this.definition.name}`;
|
||||
}
|
||||
|
||||
private async getAuthHandler(): Promise<AuthenticationHandler | undefined> {
|
||||
if (this.authHandler) {
|
||||
return this.authHandler;
|
||||
}
|
||||
|
||||
if (this.definition.auth) {
|
||||
const provider = await A2AAuthProviderFactory.create({
|
||||
authConfig: this.definition.auth,
|
||||
agentName: this.definition.name,
|
||||
});
|
||||
if (!provider) {
|
||||
throw new Error(
|
||||
`Failed to create auth provider for agent '${this.definition.name}'`,
|
||||
);
|
||||
}
|
||||
this.authHandler = provider;
|
||||
}
|
||||
|
||||
return this.authHandler;
|
||||
}
|
||||
|
||||
protected override async getConfirmationDetails(
|
||||
_abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
@@ -138,11 +160,13 @@ export class RemoteAgentInvocation extends BaseToolInvocation<
|
||||
this.taskId = priorState.taskId;
|
||||
}
|
||||
|
||||
const authHandler = await this.getAuthHandler();
|
||||
|
||||
if (!this.clientManager.getClient(this.definition.name)) {
|
||||
await this.clientManager.loadAgent(
|
||||
this.definition.name,
|
||||
this.definition.agentCardUrl,
|
||||
this.authHandler,
|
||||
authHandler,
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -120,6 +120,16 @@ describe('SubAgentInvocation', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should return the correct description', () => {
|
||||
const tool = new SubagentTool(testDefinition, mockConfig, mockMessageBus);
|
||||
const params = {};
|
||||
// @ts-expect-error - accessing protected method for testing
|
||||
const invocation = tool.createInvocation(params, mockMessageBus);
|
||||
expect(invocation.getDescription()).toBe(
|
||||
"Delegating to agent 'LocalAgent'",
|
||||
);
|
||||
});
|
||||
|
||||
it('should delegate shouldConfirmExecute to the inner sub-invocation (remote)', async () => {
|
||||
const tool = new SubagentTool(
|
||||
testRemoteDefinition,
|
||||
|
||||
@@ -12,8 +12,8 @@ import {
|
||||
BaseToolInvocation,
|
||||
type ToolCallConfirmationDetails,
|
||||
isTool,
|
||||
type ToolLiveOutput,
|
||||
} from '../tools/tools.js';
|
||||
import type { AnsiOutput } from '../utils/terminalSerializer.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import type { AgentDefinition, AgentInputs } from './types.js';
|
||||
@@ -155,7 +155,7 @@ class SubAgentInvocation extends BaseToolInvocation<AgentInputs, ToolResult> {
|
||||
|
||||
async execute(
|
||||
signal: AbortSignal,
|
||||
updateOutput?: (output: string | AnsiOutput) => void,
|
||||
updateOutput?: (output: ToolLiveOutput) => void,
|
||||
): Promise<ToolResult> {
|
||||
const validationError = SchemaValidator.validate(
|
||||
this.definition.inputConfig.inputSchema,
|
||||
|
||||
@@ -73,6 +73,32 @@ export interface SubagentActivityEvent {
|
||||
data: Record<string, unknown>;
|
||||
}
|
||||
|
||||
export interface SubagentActivityItem {
|
||||
id: string;
|
||||
type: 'thought' | 'tool_call';
|
||||
content: string;
|
||||
displayName?: string;
|
||||
description?: string;
|
||||
args?: string;
|
||||
status: 'running' | 'completed' | 'error' | 'cancelled';
|
||||
}
|
||||
|
||||
export interface SubagentProgress {
|
||||
isSubagentProgress: true;
|
||||
agentName: string;
|
||||
recentActivity: SubagentActivityItem[];
|
||||
state?: 'running' | 'completed' | 'error' | 'cancelled';
|
||||
}
|
||||
|
||||
export function isSubagentProgress(obj: unknown): obj is SubagentProgress {
|
||||
return (
|
||||
typeof obj === 'object' &&
|
||||
obj !== null &&
|
||||
'isSubagentProgress' in obj &&
|
||||
obj.isSubagentProgress === true
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* The base definition for an agent.
|
||||
* @template TOutput The specific Zod schema for the agent's final output object.
|
||||
|
||||
@@ -223,8 +223,6 @@ import type {
|
||||
ModelConfigService,
|
||||
ModelConfigServiceConfig,
|
||||
} from '../services/modelConfigService.js';
|
||||
import { ExitPlanModeTool } from '../tools/exit-plan-mode.js';
|
||||
import { EnterPlanModeTool } from '../tools/enter-plan-mode.js';
|
||||
import { LocalLiteRtLmClient } from '../core/localLiteRtLmClient.js';
|
||||
|
||||
vi.mock('../core/baseLlmClient.js');
|
||||
@@ -1204,6 +1202,28 @@ describe('Server Config (config.ts)', () => {
|
||||
expect(SubAgentToolMock).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should register EnterPlanModeTool and ExitPlanModeTool when plan is enabled', async () => {
|
||||
const params: ConfigParameters = {
|
||||
...baseParams,
|
||||
plan: true,
|
||||
};
|
||||
const config = new Config(params);
|
||||
|
||||
await config.initialize();
|
||||
|
||||
const registerToolMock = (
|
||||
(await vi.importMock('../tools/tool-registry')) as {
|
||||
ToolRegistry: { prototype: { registerTool: Mock } };
|
||||
}
|
||||
).ToolRegistry.prototype.registerTool;
|
||||
|
||||
const registeredTools = registerToolMock.mock.calls.map(
|
||||
(call) => call[0].constructor.name,
|
||||
);
|
||||
expect(registeredTools).toContain('EnterPlanModeTool');
|
||||
expect(registeredTools).toContain('ExitPlanModeTool');
|
||||
});
|
||||
|
||||
describe('with minified tool class names', () => {
|
||||
beforeEach(() => {
|
||||
Object.defineProperty(
|
||||
@@ -2961,131 +2981,6 @@ describe('Plans Directory Initialization', () => {
|
||||
expect(fs.promises.mkdir).not.toHaveBeenCalledWith(plansDir, {
|
||||
recursive: true,
|
||||
});
|
||||
|
||||
const context = config.getWorkspaceContext();
|
||||
expect(context.getDirectories()).not.toContain(plansDir);
|
||||
});
|
||||
});
|
||||
|
||||
describe('syncPlanModeTools', () => {
|
||||
const baseParams: ConfigParameters = {
|
||||
sessionId: 'test-session',
|
||||
targetDir: '.',
|
||||
debugMode: false,
|
||||
model: 'test-model',
|
||||
cwd: '.',
|
||||
};
|
||||
|
||||
it('should register ExitPlanModeTool and unregister EnterPlanModeTool when in PLAN mode', async () => {
|
||||
const config = new Config({
|
||||
...baseParams,
|
||||
approvalMode: ApprovalMode.PLAN,
|
||||
});
|
||||
const registry = new ToolRegistry(config, config.getMessageBus());
|
||||
vi.spyOn(config, 'getToolRegistry').mockReturnValue(registry);
|
||||
|
||||
const registerSpy = vi.spyOn(registry, 'registerTool');
|
||||
const unregisterSpy = vi.spyOn(registry, 'unregisterTool');
|
||||
const getToolSpy = vi.spyOn(registry, 'getTool');
|
||||
|
||||
getToolSpy.mockImplementation((name) => {
|
||||
if (name === 'enter_plan_mode')
|
||||
return new EnterPlanModeTool(config, config.getMessageBus());
|
||||
return undefined;
|
||||
});
|
||||
|
||||
config.syncPlanModeTools();
|
||||
|
||||
expect(unregisterSpy).toHaveBeenCalledWith('enter_plan_mode');
|
||||
expect(registerSpy).toHaveBeenCalledWith(expect.anything());
|
||||
const registeredTool = registerSpy.mock.calls[0][0];
|
||||
const { ExitPlanModeTool } = await import('../tools/exit-plan-mode.js');
|
||||
expect(registeredTool).toBeInstanceOf(ExitPlanModeTool);
|
||||
});
|
||||
|
||||
it('should register EnterPlanModeTool and unregister ExitPlanModeTool when NOT in PLAN mode and experimental.plan is enabled', async () => {
|
||||
const config = new Config({
|
||||
...baseParams,
|
||||
approvalMode: ApprovalMode.DEFAULT,
|
||||
plan: true,
|
||||
});
|
||||
const registry = new ToolRegistry(config, config.getMessageBus());
|
||||
vi.spyOn(config, 'getToolRegistry').mockReturnValue(registry);
|
||||
|
||||
const registerSpy = vi.spyOn(registry, 'registerTool');
|
||||
const unregisterSpy = vi.spyOn(registry, 'unregisterTool');
|
||||
const getToolSpy = vi.spyOn(registry, 'getTool');
|
||||
|
||||
getToolSpy.mockImplementation((name) => {
|
||||
if (name === 'exit_plan_mode')
|
||||
return new ExitPlanModeTool(config, config.getMessageBus());
|
||||
return undefined;
|
||||
});
|
||||
|
||||
config.syncPlanModeTools();
|
||||
|
||||
expect(unregisterSpy).toHaveBeenCalledWith('exit_plan_mode');
|
||||
expect(registerSpy).toHaveBeenCalledWith(expect.anything());
|
||||
const registeredTool = registerSpy.mock.calls[0][0];
|
||||
const { EnterPlanModeTool } = await import('../tools/enter-plan-mode.js');
|
||||
expect(registeredTool).toBeInstanceOf(EnterPlanModeTool);
|
||||
});
|
||||
|
||||
it('should NOT register EnterPlanModeTool when experimental.plan is disabled', async () => {
|
||||
const config = new Config({
|
||||
...baseParams,
|
||||
approvalMode: ApprovalMode.DEFAULT,
|
||||
plan: false,
|
||||
});
|
||||
const registry = new ToolRegistry(config, config.getMessageBus());
|
||||
vi.spyOn(config, 'getToolRegistry').mockReturnValue(registry);
|
||||
|
||||
const registerSpy = vi.spyOn(registry, 'registerTool');
|
||||
vi.spyOn(registry, 'getTool').mockReturnValue(undefined);
|
||||
|
||||
config.syncPlanModeTools();
|
||||
|
||||
const { EnterPlanModeTool } = await import('../tools/enter-plan-mode.js');
|
||||
const registeredTool = registerSpy.mock.calls.find(
|
||||
(call) => call[0] instanceof EnterPlanModeTool,
|
||||
);
|
||||
expect(registeredTool).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should NOT register EnterPlanModeTool when in YOLO mode, even if plan is enabled', async () => {
|
||||
const config = new Config({
|
||||
...baseParams,
|
||||
approvalMode: ApprovalMode.YOLO,
|
||||
plan: true,
|
||||
});
|
||||
const registry = new ToolRegistry(config, config.getMessageBus());
|
||||
vi.spyOn(config, 'getToolRegistry').mockReturnValue(registry);
|
||||
|
||||
const registerSpy = vi.spyOn(registry, 'registerTool');
|
||||
vi.spyOn(registry, 'getTool').mockReturnValue(undefined);
|
||||
|
||||
config.syncPlanModeTools();
|
||||
|
||||
const { EnterPlanModeTool } = await import('../tools/enter-plan-mode.js');
|
||||
const registeredTool = registerSpy.mock.calls.find(
|
||||
(call) => call[0] instanceof EnterPlanModeTool,
|
||||
);
|
||||
expect(registeredTool).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should call geminiClient.setTools if initialized', async () => {
|
||||
const config = new Config(baseParams);
|
||||
const registry = new ToolRegistry(config, config.getMessageBus());
|
||||
vi.spyOn(config, 'getToolRegistry').mockReturnValue(registry);
|
||||
const client = config.getGeminiClient();
|
||||
vi.spyOn(client, 'isInitialized').mockReturnValue(true);
|
||||
const setToolsSpy = vi
|
||||
.spyOn(client, 'setTools')
|
||||
.mockResolvedValue(undefined);
|
||||
|
||||
config.syncPlanModeTools();
|
||||
|
||||
expect(setToolsSpy).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -370,10 +370,6 @@ import { McpClientManager } from '../tools/mcp-client-manager.js';
|
||||
import { type McpContext } from '../tools/mcp-client.js';
|
||||
import type { EnvironmentSanitizationConfig } from '../services/environmentSanitization.js';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
import {
|
||||
ENTER_PLAN_MODE_TOOL_NAME,
|
||||
EXIT_PLAN_MODE_TOOL_NAME,
|
||||
} from '../tools/tool-names.js';
|
||||
|
||||
export type { FileFilteringOptions };
|
||||
export {
|
||||
@@ -1181,7 +1177,6 @@ export class Config implements McpContext {
|
||||
}
|
||||
|
||||
await this.geminiClient.initialize();
|
||||
this.syncPlanModeTools();
|
||||
this.initialized = true;
|
||||
}
|
||||
|
||||
@@ -2020,52 +2015,15 @@ export class Config implements McpContext {
|
||||
(currentMode === ApprovalMode.YOLO || mode === ApprovalMode.YOLO);
|
||||
|
||||
if (isPlanModeTransition || isYoloModeTransition) {
|
||||
this.syncPlanModeTools();
|
||||
if (this.geminiClient?.isInitialized()) {
|
||||
this.geminiClient.setTools().catch((err) => {
|
||||
debugLogger.error('Failed to update tools', err);
|
||||
});
|
||||
}
|
||||
this.updateSystemInstructionIfInitialized();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Synchronizes enter/exit plan mode tools based on current mode.
|
||||
*/
|
||||
syncPlanModeTools(): void {
|
||||
const registry = this.getToolRegistry();
|
||||
if (!registry) {
|
||||
return;
|
||||
}
|
||||
const approvalMode = this.getApprovalMode();
|
||||
const isPlanMode = approvalMode === ApprovalMode.PLAN;
|
||||
const isYoloMode = approvalMode === ApprovalMode.YOLO;
|
||||
|
||||
if (isPlanMode) {
|
||||
if (registry.getTool(ENTER_PLAN_MODE_TOOL_NAME)) {
|
||||
registry.unregisterTool(ENTER_PLAN_MODE_TOOL_NAME);
|
||||
}
|
||||
if (!registry.getTool(EXIT_PLAN_MODE_TOOL_NAME)) {
|
||||
registry.registerTool(new ExitPlanModeTool(this, this.messageBus));
|
||||
}
|
||||
} else {
|
||||
if (registry.getTool(EXIT_PLAN_MODE_TOOL_NAME)) {
|
||||
registry.unregisterTool(EXIT_PLAN_MODE_TOOL_NAME);
|
||||
}
|
||||
if (this.planEnabled && !isYoloMode) {
|
||||
if (!registry.getTool(ENTER_PLAN_MODE_TOOL_NAME)) {
|
||||
registry.registerTool(new EnterPlanModeTool(this, this.messageBus));
|
||||
}
|
||||
} else {
|
||||
if (registry.getTool(ENTER_PLAN_MODE_TOOL_NAME)) {
|
||||
registry.unregisterTool(ENTER_PLAN_MODE_TOOL_NAME);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (this.geminiClient?.isInitialized()) {
|
||||
this.geminiClient.setTools().catch((err) => {
|
||||
debugLogger.error('Failed to update tools', err);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Logs the duration of the current approval mode.
|
||||
*/
|
||||
|
||||
@@ -10,10 +10,11 @@ import type {
|
||||
ToolResult,
|
||||
AnyDeclarativeTool,
|
||||
AnyToolInvocation,
|
||||
ToolLiveOutput,
|
||||
} from '../tools/tools.js';
|
||||
import { ToolErrorType } from '../tools/tool-error.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import type { AnsiOutput, ShellExecutionConfig } from '../index.js';
|
||||
import type { ShellExecutionConfig } from '../index.js';
|
||||
import { ShellToolInvocation } from '../tools/shell.js';
|
||||
import { DiscoveredMCPToolInvocation } from '../tools/mcp-tool.js';
|
||||
|
||||
@@ -71,7 +72,7 @@ export async function executeToolWithHooks(
|
||||
toolName: string,
|
||||
signal: AbortSignal,
|
||||
tool: AnyDeclarativeTool,
|
||||
liveOutputCallback?: (outputChunk: string | AnsiOutput) => void,
|
||||
liveOutputCallback?: (outputChunk: ToolLiveOutput) => void,
|
||||
shellExecutionConfig?: ShellExecutionConfig,
|
||||
setPidCallback?: (pid: number) => void,
|
||||
config?: Config,
|
||||
|
||||
@@ -189,11 +189,16 @@ export class InvalidStreamError extends Error {
|
||||
readonly type:
|
||||
| 'NO_FINISH_REASON'
|
||||
| 'NO_RESPONSE_TEXT'
|
||||
| 'MALFORMED_FUNCTION_CALL';
|
||||
| 'MALFORMED_FUNCTION_CALL'
|
||||
| 'UNEXPECTED_TOOL_CALL';
|
||||
|
||||
constructor(
|
||||
message: string,
|
||||
type: 'NO_FINISH_REASON' | 'NO_RESPONSE_TEXT' | 'MALFORMED_FUNCTION_CALL',
|
||||
type:
|
||||
| 'NO_FINISH_REASON'
|
||||
| 'NO_RESPONSE_TEXT'
|
||||
| 'MALFORMED_FUNCTION_CALL'
|
||||
| 'UNEXPECTED_TOOL_CALL',
|
||||
) {
|
||||
super(message);
|
||||
this.name = 'InvalidStreamError';
|
||||
@@ -935,6 +940,12 @@ export class GeminiChat {
|
||||
'MALFORMED_FUNCTION_CALL',
|
||||
);
|
||||
}
|
||||
if (finishReason === FinishReason.UNEXPECTED_TOOL_CALL) {
|
||||
throw new InvalidStreamError(
|
||||
'Model stream ended with unexpected tool call.',
|
||||
'UNEXPECTED_TOOL_CALL',
|
||||
);
|
||||
}
|
||||
if (!responseText) {
|
||||
throw new InvalidStreamError(
|
||||
'Model stream ended with empty response text.',
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi } from 'vitest';
|
||||
import {
|
||||
MCPOAuthClientProvider,
|
||||
type OAuthAuthorizationResponse,
|
||||
} from './mcp-oauth-provider.js';
|
||||
import type {
|
||||
OAuthClientInformation,
|
||||
OAuthClientMetadata,
|
||||
OAuthTokens,
|
||||
} from '@modelcontextprotocol/sdk/shared/auth.js';
|
||||
|
||||
// Exercises MCPOAuthClientProvider as a plain in-memory store: every value
// handed to a save* method (or to the constructor) must be readable back
// unchanged, and the redirect callback must receive the authorization URL.
describe('MCPOAuthClientProvider', () => {
  const mockRedirectUrl = 'http://localhost:8090/callback';
  const mockClientMetadata: OAuthClientMetadata = {
    client_name: 'Test Client',
    redirect_uris: [mockRedirectUrl],
    grant_types: ['authorization_code', 'refresh_token'],
    response_types: ['code'],
    token_endpoint_auth_method: 'client_secret_post',
    scope: 'test-scope',
  };
  const mockState = 'test-state-123';

  describe('oauth flow', () => {
    it('should support full OAuth flow', async () => {
      // The redirect handler is injected so the test can observe the URL
      // instead of relying on the provider's default logging callback.
      const onRedirectMock = vi.fn();
      const provider = new MCPOAuthClientProvider(
        mockRedirectUrl,
        mockClientMetadata,
        mockState,
        onRedirectMock,
      );

      // Step 1: Save client information
      const clientInfo: OAuthClientInformation = {
        client_id: 'my-client-id',
        client_secret: 'my-client-secret',
      };
      provider.saveClientInformation(clientInfo);

      // Step 2: Save code verifier
      provider.saveCodeVerifier('my-code-verifier');

      // Step 3: Set up callback server
      const mockAuthResponse: OAuthAuthorizationResponse = {
        code: 'authorization-code',
        state: mockState,
      };
      // Stub with the same shape the provider stores; none of these members
      // are invoked by the provider in this test, only the reference is kept.
      const mockServer = {
        port: Promise.resolve(8090),
        waitForResponse: vi.fn().mockResolvedValue(mockAuthResponse),
        close: vi.fn().mockResolvedValue(undefined),
      };
      provider.saveCallbackServer(mockServer);

      // Step 4: Redirect to authorization
      const authUrl = new URL('http://auth.example.com/authorize');
      await provider.redirectToAuthorization(authUrl);

      // Step 5: Save tokens after exchange
      const tokens: OAuthTokens = {
        access_token: 'final-access-token',
        token_type: 'Bearer',
        expires_in: 3600,
        refresh_token: 'final-refresh-token',
      };
      provider.saveTokens(tokens);

      // Verify all data is stored correctly
      expect(provider.clientInformation()).toEqual(clientInfo);
      expect(provider.codeVerifier()).toBe('my-code-verifier');
      expect(provider.state()).toBe(mockState);
      expect(provider.tokens()).toEqual(tokens);
      expect(onRedirectMock).toHaveBeenCalledWith(authUrl);
      // Identity check: the provider must hand back the very same object.
      expect(provider.getSavedCallbackServer()).toBe(mockServer);
    });
  });
});
|
||||
@@ -0,0 +1,97 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { OAuthClientProvider } from '@modelcontextprotocol/sdk/client/auth.js';
|
||||
import type {
|
||||
OAuthClientInformation,
|
||||
OAuthClientMetadata,
|
||||
OAuthTokens,
|
||||
} from '@modelcontextprotocol/sdk/shared/auth.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
|
||||
/**
 * OAuth authorization response.
 *
 * The pair of string values a callback server's `waitForResponse()` resolves
 * with (see the `CallbackServer` type in this file).
 */
export interface OAuthAuthorizationResponse {
  /** Authorization code carried by the response. */
  code: string;
  /** State value carried by the response. */
  state: string;
}
|
||||
|
||||
/**
 * Handle to a callback server that the provider can store and return.
 *
 * NOTE(review): presumably a local HTTP listener that receives the OAuth
 * redirect; this file only stores the handle and never calls its members,
 * so confirm the exact semantics against the code that creates it.
 */
type CallbackServer = {
  /** Resolves to the server's port number. */
  port: Promise<number>;
  /** Resolves with the authorization response (code and state). */
  waitForResponse: () => Promise<OAuthAuthorizationResponse>;
  /** Asynchronously closes the server. */
  close: () => Promise<void>;
};
|
||||
|
||||
/**
 * In-memory implementation of the MCP SDK's `OAuthClientProvider`.
 *
 * The redirect URL, client metadata, optional state and redirect handler are
 * fixed at construction time. Client information, tokens, the code verifier
 * and a callback-server handle are kept in private instance fields and are
 * returned exactly as they were saved; this class performs no I/O of its own
 * apart from invoking the redirect handler.
 */
export class MCPOAuthClientProvider implements OAuthClientProvider {
  private _clientInformation?: OAuthClientInformation;
  private _tokens?: OAuthTokens;
  private _codeVerifier?: string;
  private _cbServer?: CallbackServer;

  /**
   * @param _redirectUrl URL exposed through {@link redirectUrl}.
   * @param _clientMetadata Metadata exposed through {@link clientMetadata}.
   * @param _state Optional state value returned by {@link state}.
   * @param _onRedirect Handler invoked with the authorization URL by
   *   {@link redirectToAuthorization}. Defaults to logging the URL through
   *   `debugLogger`.
   */
  constructor(
    private readonly _redirectUrl: string | URL,
    private readonly _clientMetadata: OAuthClientMetadata,
    private readonly _state?: string | undefined,
    private readonly _onRedirect: (url: URL) => void = (url) => {
      debugLogger.log(`Redirect to: ${url.toString()}`);
    },
  ) {}

  /** The redirect URL supplied to the constructor. */
  get redirectUrl(): string | URL {
    return this._redirectUrl;
  }

  /** The client metadata supplied to the constructor. */
  get clientMetadata(): OAuthClientMetadata {
    return this._clientMetadata;
  }

  /** Stores a callback-server handle, replacing any previously saved one. */
  saveCallbackServer(server: CallbackServer): void {
    this._cbServer = server;
  }

  /** @returns The saved callback-server handle, or `undefined` if none was saved. */
  getSavedCallbackServer(): CallbackServer | undefined {
    return this._cbServer;
  }

  /** @returns The saved client information, or `undefined` if none was saved. */
  clientInformation(): OAuthClientInformation | undefined {
    return this._clientInformation;
  }

  /** Stores client information, replacing any previously saved value. */
  saveClientInformation(clientInformation: OAuthClientInformation): void {
    this._clientInformation = clientInformation;
  }

  /** @returns The saved tokens, or `undefined` if none were saved. */
  tokens(): OAuthTokens | undefined {
    return this._tokens;
  }

  /** Stores tokens, replacing any previously saved value. */
  saveTokens(tokens: OAuthTokens): void {
    this._tokens = tokens;
  }

  /**
   * Passes the authorization URL to the redirect handler.
   *
   * The handler is called synchronously and its return value is ignored.
   */
  async redirectToAuthorization(authorizationUrl: URL): Promise<void> {
    this._onRedirect(authorizationUrl);
  }

  /** Stores the code verifier, replacing any previously saved value. */
  saveCodeVerifier(codeVerifier: string): void {
    this._codeVerifier = codeVerifier;
  }

  /**
   * @returns The saved code verifier.
   * @throws Error if no verifier was saved. The check is a truthiness test,
   *   so an empty string also counts as "not saved".
   */
  codeVerifier(): string {
    if (!this._codeVerifier) {
      throw new Error('No code verifier saved');
    }
    return this._codeVerifier;
  }

  /**
   * @returns The state value supplied to the constructor.
   * @throws Error if no state was supplied. The check is a truthiness test,
   *   so an empty string also counts as "not supplied".
   */
  state(): string {
    if (!this._state) {
      throw new Error('No code state saved');
    }
    return this._state;
  }
}
|
||||
@@ -10,10 +10,10 @@ import * as crypto from 'node:crypto';
|
||||
import { fileURLToPath } from 'node:url';
|
||||
import { Storage } from '../config/storage.js';
|
||||
import {
|
||||
ApprovalMode,
|
||||
type PolicyEngineConfig,
|
||||
PolicyDecision,
|
||||
type PolicyRule,
|
||||
ApprovalMode,
|
||||
type PolicySettings,
|
||||
type SafetyCheckerRule,
|
||||
} from './types.js';
|
||||
@@ -144,7 +144,8 @@ export function getPolicyTier(
|
||||
*/
|
||||
export function formatPolicyError(error: PolicyFileError): string {
|
||||
const tierLabel = error.tier.toUpperCase();
|
||||
let message = `[${tierLabel}] Policy file error in ${error.fileName}:\n`;
|
||||
const severityLabel = error.severity === 'warning' ? 'warning' : 'error';
|
||||
let message = `[${tierLabel}] Policy file ${severityLabel} in ${error.fileName}:\n`;
|
||||
message += ` ${error.message}`;
|
||||
if (error.details) {
|
||||
message += `\n${error.details}`;
|
||||
@@ -295,7 +296,10 @@ export async function createPolicyEngineConfig(
|
||||
// coreEvents has a buffer that will display these once the UI is ready
|
||||
if (errors.length > 0) {
|
||||
for (const error of errors) {
|
||||
coreEvents.emitFeedback('error', formatPolicyError(error));
|
||||
coreEvents.emitFeedback(
|
||||
error.severity ?? 'error',
|
||||
formatPolicyError(error),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -5,20 +5,21 @@
|
||||
#
|
||||
# Priority bands (tiers):
|
||||
# - Default policies (TOML): 1 + priority/1000 (e.g., priority 100 → 1.100)
|
||||
# - Workspace policies (TOML): 2 + priority/1000 (e.g., priority 100 → 2.100)
|
||||
# - User policies (TOML): 3 + priority/1000 (e.g., priority 100 → 3.100)
|
||||
# - Admin policies (TOML): 4 + priority/1000 (e.g., priority 100 → 4.100)
|
||||
# - Extension policies (TOML): 2 + priority/1000 (e.g., priority 100 → 2.100)
|
||||
# - Workspace policies (TOML): 3 + priority/1000 (e.g., priority 100 → 3.100)
|
||||
# - User policies (TOML): 4 + priority/1000 (e.g., priority 100 → 4.100)
|
||||
# - Admin policies (TOML): 5 + priority/1000 (e.g., priority 100 → 5.100)
|
||||
#
|
||||
# This ensures Admin > User > Workspace > Default hierarchy is always preserved,
|
||||
# This ensures Admin > User > Workspace > Extension > Default hierarchy is always preserved,
|
||||
# while allowing user-specified priorities to work within each tier.
|
||||
#
|
||||
# Settings-based and dynamic rules (all in user tier 3.x):
|
||||
# 3.95: Tools that the user has selected as "Always Allow" in the interactive UI
|
||||
# 3.9: MCP servers excluded list (security: persistent server blocks)
|
||||
# 3.4: Command line flag --exclude-tools (explicit temporary blocks)
|
||||
# 3.3: Command line flag --allowed-tools (explicit temporary allows)
|
||||
# 3.2: MCP servers with trust=true (persistent trusted servers)
|
||||
# 3.1: MCP servers allowed list (persistent general server allows)
|
||||
# Settings-based and dynamic rules (all in user tier 4.x):
|
||||
# 4.95: Tools that the user has selected as "Always Allow" in the interactive UI
|
||||
# 4.9: MCP servers excluded list (security: persistent server blocks)
|
||||
# 4.4: Command line flag --exclude-tools (explicit temporary blocks)
|
||||
# 4.3: Command line flag --allowed-tools (explicit temporary allows)
|
||||
# 4.2: MCP servers with trust=true (persistent trusted servers)
|
||||
# 4.1: MCP servers allowed list (persistent general server allows)
|
||||
#
|
||||
# TOML policy priorities (before transformation):
|
||||
# 10: Write tools default to ASK_USER (becomes 1.010 in default tier)
|
||||
@@ -26,6 +27,33 @@
|
||||
# 70: Plan mode explicit ALLOW override (becomes 1.070 in default tier)
|
||||
# 999: YOLO mode allow-all (becomes 1.999 in default tier)
|
||||
|
||||
# Mode Transitions (into/out of Plan Mode)
|
||||
|
||||
[[rule]]
|
||||
toolName = "enter_plan_mode"
|
||||
decision = "ask_user"
|
||||
priority = 50
|
||||
|
||||
[[rule]]
|
||||
toolName = "enter_plan_mode"
|
||||
decision = "deny"
|
||||
priority = 70
|
||||
modes = ["plan"]
|
||||
deny_message = "You are already in Plan Mode."
|
||||
|
||||
[[rule]]
|
||||
toolName = "exit_plan_mode"
|
||||
decision = "ask_user"
|
||||
priority = 70
|
||||
modes = ["plan"]
|
||||
|
||||
[[rule]]
|
||||
toolName = "exit_plan_mode"
|
||||
decision = "deny"
|
||||
priority = 50
|
||||
deny_message = "You are not currently in Plan Mode. Use enter_plan_mode first to design a plan."
|
||||
|
||||
|
||||
# Catch-All: Deny everything by default in Plan mode.
|
||||
|
||||
[[rule]]
|
||||
@@ -50,7 +78,7 @@ priority = 70
|
||||
modes = ["plan"]
|
||||
|
||||
[[rule]]
|
||||
toolName = ["ask_user", "exit_plan_mode", "save_memory"]
|
||||
toolName = ["ask_user", "save_memory"]
|
||||
decision = "ask_user"
|
||||
priority = 70
|
||||
modes = ["plan"]
|
||||
|
||||
@@ -5,20 +5,21 @@
|
||||
#
|
||||
# Priority bands (tiers):
|
||||
# - Default policies (TOML): 1 + priority/1000 (e.g., priority 100 → 1.100)
|
||||
# - Workspace policies (TOML): 2 + priority/1000 (e.g., priority 100 → 2.100)
|
||||
# - User policies (TOML): 3 + priority/1000 (e.g., priority 100 → 3.100)
|
||||
# - Admin policies (TOML): 4 + priority/1000 (e.g., priority 100 → 4.100)
|
||||
# - Extension policies (TOML): 2 + priority/1000 (e.g., priority 100 → 2.100)
|
||||
# - Workspace policies (TOML): 3 + priority/1000 (e.g., priority 100 → 3.100)
|
||||
# - User policies (TOML): 4 + priority/1000 (e.g., priority 100 → 4.100)
|
||||
# - Admin policies (TOML): 5 + priority/1000 (e.g., priority 100 → 5.100)
|
||||
#
|
||||
# This ensures Admin > User > Workspace > Default hierarchy is always preserved,
|
||||
# This ensures Admin > User > Workspace > Extension > Default hierarchy is always preserved,
|
||||
# while allowing user-specified priorities to work within each tier.
|
||||
#
|
||||
# Settings-based and dynamic rules (all in user tier 3.x):
|
||||
# 3.95: Tools that the user has selected as "Always Allow" in the interactive UI
|
||||
# 3.9: MCP servers excluded list (security: persistent server blocks)
|
||||
# 3.4: Command line flag --exclude-tools (explicit temporary blocks)
|
||||
# 3.3: Command line flag --allowed-tools (explicit temporary allows)
|
||||
# 3.2: MCP servers with trust=true (persistent trusted servers)
|
||||
# 3.1: MCP servers allowed list (persistent general server allows)
|
||||
# Settings-based and dynamic rules (all in user tier 4.x):
|
||||
# 4.95: Tools that the user has selected as "Always Allow" in the interactive UI
|
||||
# 4.9: MCP servers excluded list (security: persistent server blocks)
|
||||
# 4.4: Command line flag --exclude-tools (explicit temporary blocks)
|
||||
# 4.3: Command line flag --allowed-tools (explicit temporary allows)
|
||||
# 4.2: MCP servers with trust=true (persistent trusted servers)
|
||||
# 4.1: MCP servers allowed list (persistent general server allows)
|
||||
#
|
||||
# TOML policy priorities (before transformation):
|
||||
# 10: Write tools default to ASK_USER (becomes 1.010 in default tier)
|
||||
|
||||
@@ -5,20 +5,21 @@
|
||||
#
|
||||
# Priority bands (tiers):
|
||||
# - Default policies (TOML): 1 + priority/1000 (e.g., priority 100 → 1.100)
|
||||
# - Workspace policies (TOML): 2 + priority/1000 (e.g., priority 100 → 2.100)
|
||||
# - User policies (TOML): 3 + priority/1000 (e.g., priority 100 → 3.100)
|
||||
# - Admin policies (TOML): 4 + priority/1000 (e.g., priority 100 → 4.100)
|
||||
# - Extension policies (TOML): 2 + priority/1000 (e.g., priority 100 → 2.100)
|
||||
# - Workspace policies (TOML): 3 + priority/1000 (e.g., priority 100 → 3.100)
|
||||
# - User policies (TOML): 4 + priority/1000 (e.g., priority 100 → 4.100)
|
||||
# - Admin policies (TOML): 5 + priority/1000 (e.g., priority 100 → 5.100)
|
||||
#
|
||||
# This ensures Admin > User > Workspace > Default hierarchy is always preserved,
|
||||
# This ensures Admin > User > Workspace > Extension > Default hierarchy is always preserved,
|
||||
# while allowing user-specified priorities to work within each tier.
|
||||
#
|
||||
# Settings-based and dynamic rules (all in user tier 3.x):
|
||||
# 3.95: Tools that the user has selected as "Always Allow" in the interactive UI
|
||||
# 3.9: MCP servers excluded list (security: persistent server blocks)
|
||||
# 3.4: Command line flag --exclude-tools (explicit temporary blocks)
|
||||
# 3.3: Command line flag --allowed-tools (explicit temporary allows)
|
||||
# 3.2: MCP servers with trust=true (persistent trusted servers)
|
||||
# 3.1: MCP servers allowed list (persistent general server allows)
|
||||
# Settings-based and dynamic rules (all in user tier 4.x):
|
||||
# 4.95: Tools that the user has selected as "Always Allow" in the interactive UI
|
||||
# 4.9: MCP servers excluded list (security: persistent server blocks)
|
||||
# 4.4: Command line flag --exclude-tools (explicit temporary blocks)
|
||||
# 4.3: Command line flag --allowed-tools (explicit temporary allows)
|
||||
# 4.2: MCP servers with trust=true (persistent trusted servers)
|
||||
# 4.1: MCP servers allowed list (persistent general server allows)
|
||||
#
|
||||
# TOML policy priorities (before transformation):
|
||||
# 10: Write tools default to ASK_USER (becomes 1.010 in default tier)
|
||||
|
||||
@@ -5,20 +5,21 @@
|
||||
#
|
||||
# Priority bands (tiers):
|
||||
# - Default policies (TOML): 1 + priority/1000 (e.g., priority 100 → 1.100)
|
||||
# - Workspace policies (TOML): 2 + priority/1000 (e.g., priority 100 → 2.100)
|
||||
# - User policies (TOML): 3 + priority/1000 (e.g., priority 100 → 3.100)
|
||||
# - Admin policies (TOML): 4 + priority/1000 (e.g., priority 100 → 4.100)
|
||||
# - Extension policies (TOML): 2 + priority/1000 (e.g., priority 100 → 2.100)
|
||||
# - Workspace policies (TOML): 3 + priority/1000 (e.g., priority 100 → 3.100)
|
||||
# - User policies (TOML): 4 + priority/1000 (e.g., priority 100 → 4.100)
|
||||
# - Admin policies (TOML): 5 + priority/1000 (e.g., priority 100 → 5.100)
|
||||
#
|
||||
# This ensures Admin > User > Workspace > Default hierarchy is always preserved,
|
||||
# This ensures Admin > User > Workspace > Extension > Default hierarchy is always preserved,
|
||||
# while allowing user-specified priorities to work within each tier.
|
||||
#
|
||||
# Settings-based and dynamic rules (all in user tier 3.x):
|
||||
# 3.95: Tools that the user has selected as "Always Allow" in the interactive UI
|
||||
# 3.9: MCP servers excluded list (security: persistent server blocks)
|
||||
# 3.4: Command line flag --exclude-tools (explicit temporary blocks)
|
||||
# 3.3: Command line flag --allowed-tools (explicit temporary allows)
|
||||
# 3.2: MCP servers with trust=true (persistent trusted servers)
|
||||
# 3.1: MCP servers allowed list (persistent general server allows)
|
||||
# Settings-based and dynamic rules (all in user tier 4.x):
|
||||
# 4.95: Tools that the user has selected as "Always Allow" in the interactive UI
|
||||
# 4.9: MCP servers excluded list (security: persistent server blocks)
|
||||
# 4.4: Command line flag --exclude-tools (explicit temporary blocks)
|
||||
# 4.3: Command line flag --allowed-tools (explicit temporary allows)
|
||||
# 4.2: MCP servers with trust=true (persistent trusted servers)
|
||||
# 4.1: MCP servers allowed list (persistent general server allows)
|
||||
#
|
||||
# TOML policy priorities (before transformation):
|
||||
# 10: Write tools default to ASK_USER (becomes 1.010 in default tier)
|
||||
@@ -36,6 +37,15 @@ decision = "ask_user"
|
||||
priority = 999
|
||||
modes = ["yolo"]
|
||||
|
||||
# Plan mode transitions are blocked in YOLO mode to maintain state consistency
|
||||
# and because planning currently requires human interaction (plan approval),
|
||||
# which conflicts with YOLO's autonomous nature.
|
||||
[[rule]]
|
||||
toolName = ["enter_plan_mode", "exit_plan_mode"]
|
||||
decision = "deny"
|
||||
priority = 999
|
||||
modes = ["yolo"]
|
||||
|
||||
# Allow everything else in YOLO mode
|
||||
[[rule]]
|
||||
decision = "allow"
|
||||
|
||||
@@ -2808,6 +2808,82 @@ describe('PolicyEngine', () => {
|
||||
'Execution of scripts (including those from skills) is blocked',
|
||||
);
|
||||
});
|
||||
|
||||
it('should deny enter_plan_mode when already in PLAN mode', async () => {
|
||||
const rules: PolicyRule[] = [
|
||||
{
|
||||
toolName: 'enter_plan_mode',
|
||||
decision: PolicyDecision.DENY,
|
||||
priority: 70,
|
||||
modes: [ApprovalMode.PLAN],
|
||||
denyMessage: 'You are already in Plan Mode.',
|
||||
},
|
||||
];
|
||||
|
||||
engine = new PolicyEngine({
|
||||
rules,
|
||||
approvalMode: ApprovalMode.PLAN,
|
||||
});
|
||||
|
||||
const result = await engine.check({ name: 'enter_plan_mode' }, undefined);
|
||||
expect(result.decision).toBe(PolicyDecision.DENY);
|
||||
expect(result.rule?.denyMessage).toBe('You are already in Plan Mode.');
|
||||
});
|
||||
|
||||
it('should deny exit_plan_mode when in DEFAULT mode', async () => {
|
||||
const rules: PolicyRule[] = [
|
||||
{
|
||||
toolName: 'exit_plan_mode',
|
||||
decision: PolicyDecision.DENY,
|
||||
priority: 10,
|
||||
modes: [ApprovalMode.DEFAULT],
|
||||
denyMessage: 'You are not in Plan Mode.',
|
||||
},
|
||||
];
|
||||
|
||||
engine = new PolicyEngine({
|
||||
rules,
|
||||
approvalMode: ApprovalMode.DEFAULT,
|
||||
});
|
||||
|
||||
const result = await engine.check({ name: 'exit_plan_mode' }, undefined);
|
||||
expect(result.decision).toBe(PolicyDecision.DENY);
|
||||
expect(result.rule?.denyMessage).toBe('You are not in Plan Mode.');
|
||||
});
|
||||
|
||||
it('should deny both plan tools in YOLO mode', async () => {
|
||||
const rules: PolicyRule[] = [
|
||||
{
|
||||
toolName: 'enter_plan_mode',
|
||||
decision: PolicyDecision.DENY,
|
||||
priority: 999,
|
||||
modes: [ApprovalMode.YOLO],
|
||||
},
|
||||
{
|
||||
toolName: 'exit_plan_mode',
|
||||
decision: PolicyDecision.DENY,
|
||||
priority: 999,
|
||||
modes: [ApprovalMode.YOLO],
|
||||
},
|
||||
];
|
||||
|
||||
engine = new PolicyEngine({
|
||||
rules,
|
||||
approvalMode: ApprovalMode.YOLO,
|
||||
});
|
||||
|
||||
const resultEnter = await engine.check(
|
||||
{ name: 'enter_plan_mode' },
|
||||
undefined,
|
||||
);
|
||||
expect(resultEnter.decision).toBe(PolicyDecision.DENY);
|
||||
|
||||
const resultExit = await engine.check(
|
||||
{ name: 'exit_plan_mode' },
|
||||
undefined,
|
||||
);
|
||||
expect(resultExit.decision).toBe(PolicyDecision.DENY);
|
||||
});
|
||||
});
|
||||
|
||||
describe('removeRulesByTier', () => {
|
||||
|
||||
@@ -14,13 +14,26 @@ import * as fs from 'node:fs/promises';
|
||||
import * as path from 'node:path';
|
||||
import * as os from 'node:os';
|
||||
import { fileURLToPath } from 'node:url';
|
||||
import { loadPoliciesFromToml } from './toml-loader.js';
|
||||
import {
|
||||
loadPoliciesFromToml,
|
||||
validateMcpPolicyToolNames,
|
||||
} from './toml-loader.js';
|
||||
import type { PolicyLoadResult } from './toml-loader.js';
|
||||
import { PolicyEngine } from './policy-engine.js';
|
||||
|
||||
const __filename = fileURLToPath(import.meta.url);
|
||||
const __dirname = path.dirname(__filename);
|
||||
|
||||
/**
 * Returns only errors (severity !== 'warning') from a PolicyLoadResult.
 *
 * Entries with no `severity` set are included, since the filter only
 * excludes entries explicitly marked as warnings.
 */
function getErrors(result: PolicyLoadResult): PolicyLoadResult['errors'] {
  return result.errors.filter((e) => e.severity !== 'warning');
}
|
||||
|
||||
/**
 * Returns only warnings (severity === 'warning') from a PolicyLoadResult.
 *
 * Entries with any other severity, or with no `severity` at all, are
 * excluded.
 */
function getWarnings(result: PolicyLoadResult): PolicyLoadResult['errors'] {
  return result.errors.filter((e) => e.severity === 'warning');
}
|
||||
|
||||
describe('policy-toml-loader', () => {
|
||||
let tempDir: string;
|
||||
|
||||
@@ -189,7 +202,7 @@ priority = 100
|
||||
'grep',
|
||||
'read',
|
||||
]);
|
||||
expect(result.errors).toHaveLength(0);
|
||||
expect(getErrors(result)).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('should transform mcpName to composite toolName', async () => {
|
||||
@@ -228,7 +241,7 @@ modes = ["yolo"]
|
||||
expect(result.rules[0].modes).toEqual(['default', 'yolo']);
|
||||
expect(result.rules[1].toolName).toBe('grep');
|
||||
expect(result.rules[1].modes).toEqual(['yolo']);
|
||||
expect(result.errors).toHaveLength(0);
|
||||
expect(getErrors(result)).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('should parse and transform allow_redirection property', async () => {
|
||||
@@ -259,7 +272,7 @@ deny_message = "Deletion is permanent"
|
||||
expect(result.rules[0].toolName).toBe('rm');
|
||||
expect(result.rules[0].decision).toBe(PolicyDecision.DENY);
|
||||
expect(result.rules[0].denyMessage).toBe('Deletion is permanent');
|
||||
expect(result.errors).toHaveLength(0);
|
||||
expect(getErrors(result)).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('should support modes property for Tier 4 and Tier 5 policies', async () => {
|
||||
@@ -547,8 +560,8 @@ commandRegex = ".*"
|
||||
decision = "allow"
|
||||
priority = 100
|
||||
`);
|
||||
expect(result.errors).toHaveLength(1);
|
||||
const error = result.errors[0];
|
||||
expect(getErrors(result)).toHaveLength(1);
|
||||
const error = getErrors(result)[0];
|
||||
expect(error.errorType).toBe('rule_validation');
|
||||
expect(error.details).toContain('run_shell_command');
|
||||
});
|
||||
@@ -576,8 +589,8 @@ argsPattern = "([a-z)"
|
||||
decision = "allow"
|
||||
priority = 100
|
||||
`);
|
||||
expect(result.errors).toHaveLength(1);
|
||||
const error = result.errors[0];
|
||||
expect(getErrors(result)).toHaveLength(1);
|
||||
const error = getErrors(result)[0];
|
||||
expect(error.errorType).toBe('regex_compilation');
|
||||
expect(error.message).toBe('Invalid regex pattern');
|
||||
});
|
||||
@@ -592,7 +605,7 @@ priority = 100
|
||||
const getPolicyTier = (_dir: string) => 1;
|
||||
const result = await loadPoliciesFromToml([filePath], getPolicyTier);
|
||||
|
||||
expect(result.errors).toHaveLength(0);
|
||||
expect(getErrors(result)).toHaveLength(0);
|
||||
expect(result.rules).toHaveLength(1);
|
||||
expect(result.rules[0].toolName).toBe('test-tool');
|
||||
expect(result.rules[0].decision).toBe(PolicyDecision.ALLOW);
|
||||
@@ -612,6 +625,177 @@ priority = 100
|
||||
});
|
||||
});
|
||||
|
||||
describe('Tool name validation', () => {
|
||||
it('should warn for unrecognized tool names with suggestions', async () => {
|
||||
const result = await runLoadPoliciesFromToml(`
|
||||
[[rule]]
|
||||
toolName = "grob"
|
||||
decision = "allow"
|
||||
priority = 100
|
||||
`);
|
||||
|
||||
const warnings = getWarnings(result);
|
||||
expect(warnings).toHaveLength(1);
|
||||
expect(warnings[0].errorType).toBe('tool_name_warning');
|
||||
expect(warnings[0].severity).toBe('warning');
|
||||
expect(warnings[0].details).toContain('Unrecognized tool name "grob"');
|
||||
expect(warnings[0].details).toContain('glob');
|
||||
// Rules should still load despite warnings
|
||||
expect(result.rules).toHaveLength(1);
|
||||
expect(result.rules[0].toolName).toBe('grob');
|
||||
});
|
||||
|
||||
it('should not warn for valid built-in tool names', async () => {
|
||||
const result = await runLoadPoliciesFromToml(`
|
||||
[[rule]]
|
||||
toolName = "glob"
|
||||
decision = "allow"
|
||||
priority = 100
|
||||
|
||||
[[rule]]
|
||||
toolName = "read_file"
|
||||
decision = "allow"
|
||||
priority = 100
|
||||
`);
|
||||
|
||||
expect(getWarnings(result)).toHaveLength(0);
|
||||
expect(getErrors(result)).toHaveLength(0);
|
||||
expect(result.rules).toHaveLength(2);
|
||||
});
|
||||
|
||||
it('should not warn for wildcard "*"', async () => {
|
||||
const result = await runLoadPoliciesFromToml(`
|
||||
[[rule]]
|
||||
toolName = "*"
|
||||
decision = "allow"
|
||||
priority = 100
|
||||
`);
|
||||
|
||||
expect(getWarnings(result)).toHaveLength(0);
|
||||
expect(getErrors(result)).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('should not warn for MCP format tool names', async () => {
|
||||
const result = await runLoadPoliciesFromToml(`
|
||||
[[rule]]
|
||||
toolName = "my-server__my-tool"
|
||||
decision = "allow"
|
||||
priority = 100
|
||||
|
||||
[[rule]]
|
||||
toolName = "my-server__*"
|
||||
decision = "allow"
|
||||
priority = 100
|
||||
`);
|
||||
|
||||
expect(getWarnings(result)).toHaveLength(0);
|
||||
expect(getErrors(result)).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('should not warn when mcpName is present (skips validation)', async () => {
|
||||
const result = await runLoadPoliciesFromToml(`
|
||||
[[rule]]
|
||||
mcpName = "my-server"
|
||||
toolName = "nonexistent"
|
||||
decision = "allow"
|
||||
priority = 100
|
||||
`);
|
||||
|
||||
expect(getWarnings(result)).toHaveLength(0);
|
||||
expect(getErrors(result)).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('should not warn for legacy aliases', async () => {
|
||||
const result = await runLoadPoliciesFromToml(`
|
||||
[[rule]]
|
||||
toolName = "search_file_content"
|
||||
decision = "allow"
|
||||
priority = 100
|
||||
`);
|
||||
|
||||
expect(getWarnings(result)).toHaveLength(0);
|
||||
expect(getErrors(result)).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('should not warn for discovered tool prefix', async () => {
|
||||
const result = await runLoadPoliciesFromToml(`
|
||||
[[rule]]
|
||||
toolName = "discovered_tool_my_custom_tool"
|
||||
decision = "allow"
|
||||
priority = 100
|
||||
`);
|
||||
|
||||
expect(getWarnings(result)).toHaveLength(0);
|
||||
expect(getErrors(result)).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('should warn for each invalid name in a toolName array', async () => {
|
||||
const result = await runLoadPoliciesFromToml(`
|
||||
[[rule]]
|
||||
toolName = ["grob", "glob", "replce"]
|
||||
decision = "allow"
|
||||
priority = 100
|
||||
`);
|
||||
|
||||
const warnings = getWarnings(result);
|
||||
expect(warnings).toHaveLength(2);
|
||||
expect(warnings[0].details).toContain('"grob"');
|
||||
expect(warnings[1].details).toContain('"replce"');
|
||||
// All rules still load
|
||||
expect(result.rules).toHaveLength(3);
|
||||
});
|
||||
|
||||
it('should not warn for names far from any built-in (dynamic/agent tools)', async () => {
|
||||
const result = await runLoadPoliciesFromToml(`
|
||||
[[rule]]
|
||||
toolName = "delegate_to_agent"
|
||||
decision = "allow"
|
||||
priority = 100
|
||||
|
||||
[[rule]]
|
||||
toolName = "my_custom_tool"
|
||||
decision = "allow"
|
||||
priority = 100
|
||||
`);
|
||||
|
||||
expect(getWarnings(result)).toHaveLength(0);
|
||||
expect(getErrors(result)).toHaveLength(0);
|
||||
expect(result.rules).toHaveLength(2);
|
||||
});
|
||||
|
||||
it('should not warn for catch-all rules (no toolName)', async () => {
|
||||
const result = await runLoadPoliciesFromToml(`
|
||||
[[rule]]
|
||||
decision = "deny"
|
||||
priority = 100
|
||||
`);
|
||||
|
||||
expect(getWarnings(result)).toHaveLength(0);
|
||||
expect(getErrors(result)).toHaveLength(0);
|
||||
expect(result.rules).toHaveLength(1);
|
||||
});
|
||||
|
||||
// A tool-name warning is advisory only: the misspelled rule ("wrte_file") is
// reported once as a warning but is still loaded, alongside the valid rule and
// in file order.
it('should still load rules even with warnings', async () => {
  const result = await runLoadPoliciesFromToml(`
[[rule]]
toolName = "wrte_file"
decision = "deny"
priority = 50

[[rule]]
toolName = "glob"
decision = "allow"
priority = 100
`);

  expect(getWarnings(result)).toHaveLength(1);
  expect(getErrors(result)).toHaveLength(0);
  expect(result.rules).toHaveLength(2);
  expect(result.rules[0].toolName).toBe('wrte_file');
  expect(result.rules[1].toolName).toBe('glob');
});
|
||||
});
|
||||
|
||||
describe('Built-in Plan Mode Policy', () => {
|
||||
it('should allow MCP tools with readOnlyHint annotation in Plan Mode (ASK_USER, not DENY)', async () => {
|
||||
const planTomlPath = path.resolve(__dirname, 'policies', 'plan.toml');
|
||||
@@ -779,4 +963,88 @@ priority = 100
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// Unit tests for validateMcpPolicyToolNames(serverName, discoveredToolNames,
// policyRules). Rule tool names are written fully qualified
// ("<server>__<tool>"), whereas discovered names are the bare tool names.
describe('validateMcpPolicyToolNames', () => {
  it('should warn for MCP tool names that are likely typos', () => {
    const warnings = validateMcpPolicyToolNames(
      'google-workspace',
      ['people.getMe', 'calendar.list', 'calendar.get'],
      [
        {
          toolName: 'google-workspace__people.getxMe',
          source: 'User: workspace.toml',
        },
      ],
    );

    // The warning names the misspelled tool, the server, and the close match.
    expect(warnings).toHaveLength(1);
    expect(warnings[0]).toContain('people.getxMe');
    expect(warnings[0]).toContain('google-workspace');
    expect(warnings[0]).toContain('people.getMe');
  });

  it('should not warn for matching MCP tool names', () => {
    const warnings = validateMcpPolicyToolNames(
      'google-workspace',
      ['people.getMe', 'calendar.list'],
      [
        { toolName: 'google-workspace__people.getMe' },
        { toolName: 'google-workspace__calendar.list' },
      ],
    );

    expect(warnings).toHaveLength(0);
  });

  it('should not warn for wildcard MCP rules', () => {
    const warnings = validateMcpPolicyToolNames(
      'my-server',
      ['tool1', 'tool2'],
      [{ toolName: 'my-server__*' }],
    );

    expect(warnings).toHaveLength(0);
  });

  it('should not warn for rules targeting other servers', () => {
    // "server-b__toolx" does not carry the "server-a__" prefix.
    const warnings = validateMcpPolicyToolNames(
      'server-a',
      ['tool1'],
      [{ toolName: 'server-b__toolx' }],
    );

    expect(warnings).toHaveLength(0);
  });

  it('should not warn for tool names far from any discovered tool', () => {
    const warnings = validateMcpPolicyToolNames(
      'my-server',
      ['tool1', 'tool2'],
      [{ toolName: 'my-server__completely_different_name' }],
    );

    expect(warnings).toHaveLength(0);
  });

  it('should skip rules without toolName', () => {
    const warnings = validateMcpPolicyToolNames(
      'my-server',
      ['tool1'],
      [{ toolName: undefined }],
    );

    expect(warnings).toHaveLength(0);
  });

  it('should include source in warning when available', () => {
    const warnings = validateMcpPolicyToolNames(
      'my-server',
      ['tool1'],
      [{ toolName: 'my-server__tol1', source: 'User: custom.toml' }],
    );

    expect(warnings).toHaveLength(1);
    expect(warnings[0]).toContain('User: custom.toml');
  });
});
|
||||
});
|
||||
|
||||
@@ -13,12 +13,25 @@ import {
|
||||
InProcessCheckerType,
|
||||
} from './types.js';
|
||||
import { buildArgsPatterns, isSafeRegExp } from './utils.js';
|
||||
import {
|
||||
isValidToolName,
|
||||
ALL_BUILTIN_TOOL_NAMES,
|
||||
} from '../tools/tool-names.js';
|
||||
import { getToolSuggestion } from '../utils/tool-utils.js';
|
||||
import levenshtein from 'fast-levenshtein';
|
||||
import fs from 'node:fs/promises';
|
||||
import path from 'node:path';
|
||||
import toml from '@iarna/toml';
|
||||
import { z, type ZodError } from 'zod';
|
||||
import { isNodeError } from '../utils/errors.js';
|
||||
|
||||
/**
|
||||
* Maximum Levenshtein distance to consider a name a likely typo of a built-in tool.
|
||||
* Names further from all built-in tools are assumed to be intentional
|
||||
* (e.g., dynamically registered agent tools) and are not warned about.
|
||||
*/
|
||||
const MAX_TYPO_DISTANCE = 3;
|
||||
|
||||
/**
|
||||
* Schema for a single policy rule in the TOML file (before transformation).
|
||||
*/
|
||||
@@ -100,7 +113,8 @@ export type PolicyFileErrorType =
|
||||
| 'toml_parse'
|
||||
| 'schema_validation'
|
||||
| 'rule_validation'
|
||||
| 'regex_compilation';
|
||||
| 'regex_compilation'
|
||||
| 'tool_name_warning';
|
||||
|
||||
/**
|
||||
* Detailed error information for policy file loading failures.
|
||||
@@ -114,6 +128,7 @@ export interface PolicyFileError {
|
||||
message: string;
|
||||
details?: string;
|
||||
suggestion?: string;
|
||||
severity?: 'error' | 'warning';
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -241,6 +256,36 @@ function validateShellCommandSyntax(
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
 * Checks a policy rule's tool name for a probable misspelling of a built-in
 * tool name.
 *
 * @param name The tool name exactly as written in the policy file.
 * @param ruleIndex Zero-based index of the rule; it is reported one-based in
 *   the returned message.
 * @returns A warning message when `name` is within `MAX_TYPO_DISTANCE` edits
 *   of a built-in tool name; `null` when the name is accepted as-is or is too
 *   far from every built-in name to be a likely typo.
 */
function validateToolName(name: string, ruleIndex: number): string | null {
  // A name that looks like an MCP tool (e.g., "re__ad") could be a typo of a
  // built-in tool ("read_file"), so names containing "__" always fall through
  // to the distance check below, even when they pass isValidToolName.
  if (isValidToolName(name, { allowWildcards: true }) && !name.includes('__')) {
    return null;
  }

  // Only warn if the name is close to a built-in name (likely typo). Names
  // that are very different from all built-in names are likely intentional
  // (dynamic tools, agent tools, etc.).
  const allNames = [...ALL_BUILTIN_TOOL_NAMES];
  let minDistance = Infinity; // stays Infinity when there are no built-in names
  for (const builtinName of allNames) {
    minDistance = Math.min(minDistance, levenshtein.get(name, builtinName));
  }

  if (minDistance > MAX_TYPO_DISTANCE) {
    return null;
  }

  const suggestion = getToolSuggestion(name, allNames);
  return `Rule #${ruleIndex + 1}: Unrecognized tool name "${name}".${suggestion}`;
}
|
||||
|
||||
/**
|
||||
* Transforms a priority number based on the policy tier.
|
||||
* Formula: tier + priority/1000
|
||||
@@ -354,6 +399,35 @@ export async function loadPoliciesFromToml(
|
||||
}
|
||||
}
|
||||
|
||||
// Validate tool names in rules
|
||||
for (let i = 0; i < tomlRules.length; i++) {
|
||||
const rule = tomlRules[i];
|
||||
// Skip MCP-scoped rules — MCP tool names are server-defined and dynamic
|
||||
if (rule.mcpName) continue;
|
||||
|
||||
const toolNames: string[] = rule.toolName
|
||||
? Array.isArray(rule.toolName)
|
||||
? rule.toolName
|
||||
: [rule.toolName]
|
||||
: [];
|
||||
|
||||
for (const name of toolNames) {
|
||||
const warning = validateToolName(name, i);
|
||||
if (warning) {
|
||||
errors.push({
|
||||
filePath,
|
||||
fileName: file,
|
||||
tier: tierName,
|
||||
ruleIndex: i,
|
||||
errorType: 'tool_name_warning',
|
||||
message: 'Unrecognized tool name',
|
||||
details: warning,
|
||||
severity: 'warning',
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Transform rules
|
||||
const parsedRules: PolicyRule[] = (validationResult.data.rule ?? [])
|
||||
.flatMap((rule) => {
|
||||
@@ -439,6 +513,35 @@ export async function loadPoliciesFromToml(
|
||||
|
||||
rules.push(...parsedRules);
|
||||
|
||||
// Validate tool names in safety checker rules
|
||||
const tomlCheckerRules = validationResult.data.safety_checker ?? [];
|
||||
for (let i = 0; i < tomlCheckerRules.length; i++) {
|
||||
const checker = tomlCheckerRules[i];
|
||||
if (checker.mcpName) continue;
|
||||
|
||||
const checkerToolNames: string[] = checker.toolName
|
||||
? Array.isArray(checker.toolName)
|
||||
? checker.toolName
|
||||
: [checker.toolName]
|
||||
: [];
|
||||
|
||||
for (const name of checkerToolNames) {
|
||||
const warning = validateToolName(name, i);
|
||||
if (warning) {
|
||||
errors.push({
|
||||
filePath,
|
||||
fileName: file,
|
||||
tier: tierName,
|
||||
ruleIndex: i,
|
||||
errorType: 'tool_name_warning',
|
||||
message: 'Unrecognized tool name in safety checker',
|
||||
details: warning,
|
||||
severity: 'warning',
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Transform checkers
|
||||
const parsedCheckers: SafetyCheckerRule[] = (
|
||||
validationResult.data.safety_checker ?? []
|
||||
@@ -535,3 +638,55 @@ export async function loadPoliciesFromToml(
|
||||
|
||||
return { rules, checkers, errors };
|
||||
}
|
||||
|
||||
/**
 * Validates MCP tool names in policy rules against actually discovered MCP tools.
 * Called after an MCP server connects and its tools are discovered.
 *
 * Only rules whose `toolName` starts with `<serverName>__` are considered.
 * The remainder of the name is accepted when it is the wildcard `*` or exactly
 * matches a discovered tool. Otherwise a warning is emitted only if the name is
 * within `MAX_TYPO_DISTANCE` edits (Levenshtein) of some discovered tool, i.e.
 * it looks like a typo rather than a reference to an unrelated tool.
 *
 * @param serverName The MCP server name (e.g., "google-workspace")
 * @param discoveredToolNames The tool names discovered from this server (simple names, not fully qualified)
 * @param policyRules The current set of policy rules to validate against
 * @returns Array of warning messages for unrecognized MCP tool names; empty
 *   when nothing was discovered or no rule looks like a typo
 */
export function validateMcpPolicyToolNames(
  serverName: string,
  discoveredToolNames: string[],
  policyRules: ReadonlyArray<{ toolName?: string; source?: string }>,
): string[] {
  // With no discovered tools there is nothing to compare against, so no rule
  // can be classified as a typo.
  if (discoveredToolNames.length === 0) {
    return [];
  }

  const prefix = `${serverName}__`;
  const knownTools = new Set(discoveredToolNames);
  const warnings: string[] = [];

  for (const rule of policyRules) {
    const qualifiedName = rule.toolName;
    // Skip catch-all rules and rules that target other servers or built-ins.
    if (!qualifiedName || !qualifiedName.startsWith(prefix)) continue;

    const toolPart = qualifiedName.slice(prefix.length);

    // Wildcards and exact matches are valid.
    if (toolPart === '*' || knownTools.has(toolPart)) continue;

    // Tool not found — check if it's a likely typo
    let minDistance = Infinity;
    for (const discovered of discoveredToolNames) {
      minDistance = Math.min(minDistance, levenshtein.get(toolPart, discovered));
    }
    if (minDistance > MAX_TYPO_DISTANCE) continue;

    const suggestion = getToolSuggestion(toolPart, discoveredToolNames);
    const source = rule.source ? ` (from ${rule.source})` : '';
    warnings.push(
      `Unrecognized MCP tool "${toolPart}" for server "${serverName}"${source}.${suggestion}`,
    );
  }

  return warnings;
}
|
||||
|
||||
@@ -16,6 +16,8 @@ import { MockTool } from '../test-utils/mock-tool.js';
|
||||
import type { ScheduledToolCall } from './types.js';
|
||||
import { CoreToolCallStatus } from './types.js';
|
||||
import { SHELL_TOOL_NAME } from '../tools/tool-names.js';
|
||||
import { DiscoveredMCPTool } from '../tools/mcp-tool.js';
|
||||
import type { CallableTool } from '@google/genai';
|
||||
import * as fileUtils from '../utils/fileUtils.js';
|
||||
import * as coreToolHookTriggers from '../core/coreToolHookTriggers.js';
|
||||
import { ShellToolInvocation } from '../tools/shell.js';
|
||||
@@ -312,6 +314,162 @@ describe('ToolExecutor', () => {
|
||||
}
|
||||
});
|
||||
|
||||
// An MCP tool whose llmContent is a single text Part longer than the
// configured threshold must have its text saved to a file and replaced by the
// formatted truncated text; the saved path is reported as `outputFile`.
it('should truncate large MCP tool output with single text Part', async () => {
  // 1. Setup Config for Truncation
  vi.spyOn(config, 'getTruncateToolOutputThreshold').mockReturnValue(10);
  vi.spyOn(config.storage, 'getProjectTempDir').mockReturnValue('/tmp');

  const mcpToolName = 'get_big_text';
  const messageBus = createMockMessageBus();
  const mcpTool = new DiscoveredMCPTool(
    {} as CallableTool,
    'my-server',
    mcpToolName,
    'A test MCP tool',
    {},
    messageBus,
  );
  const invocation = mcpTool.build({});
  const longText = 'This is a very long MCP output that should be truncated.';

  // 2. Mock execution returning Part[] with single text Part
  vi.mocked(coreToolHookTriggers.executeToolWithHooks).mockResolvedValue({
    llmContent: [{ text: longText }],
    returnDisplay: longText,
  });

  const scheduledCall: ScheduledToolCall = {
    status: CoreToolCallStatus.Scheduled,
    request: {
      callId: 'call-mcp-trunc',
      name: mcpToolName,
      args: { query: 'test' },
      isClientInitiated: false,
      prompt_id: 'prompt-mcp-trunc',
    },
    tool: mcpTool,
    invocation: invocation as unknown as AnyToolInvocation,
    startTime: Date.now(),
  };

  // 3. Execute
  const result = await executor.execute({
    call: scheduledCall,
    signal: new AbortController().signal,
    onUpdateToolCall: vi.fn(),
  });

  // 4. Verify Truncation Logic
  expect(fileUtils.saveTruncatedToolOutput).toHaveBeenCalledWith(
    longText,
    mcpToolName,
    'call-mcp-trunc',
    expect.any(String),
    'test-session-id',
  );

  expect(fileUtils.formatTruncatedToolOutput).toHaveBeenCalledWith(
    longText,
    '/tmp/truncated_output.txt',
    10,
  );

  expect(result.status).toBe(CoreToolCallStatus.Success);
  if (result.status === CoreToolCallStatus.Success) {
    expect(result.response.outputFile).toBe('/tmp/truncated_output.txt');
  }
});
|
||||
|
||||
// Truncation applies only to a single-Part response: with two Parts the
// output must pass through untouched even though the text exceeds the
// threshold.
it('should not truncate MCP tool output with multiple Parts', async () => {
  vi.spyOn(config, 'getTruncateToolOutputThreshold').mockReturnValue(10);

  const messageBus = createMockMessageBus();
  const mcpTool = new DiscoveredMCPTool(
    {} as CallableTool,
    'my-server',
    'get_big_text',
    'A test MCP tool',
    {},
    messageBus,
  );
  const invocation = mcpTool.build({});
  const longText = 'This is long text that exceeds the threshold.';

  // Part[] with multiple parts — should NOT be truncated
  vi.mocked(coreToolHookTriggers.executeToolWithHooks).mockResolvedValue({
    llmContent: [{ text: longText }, { text: 'second part' }],
    returnDisplay: longText,
  });

  const scheduledCall: ScheduledToolCall = {
    status: CoreToolCallStatus.Scheduled,
    request: {
      callId: 'call-mcp-multi',
      name: 'get_big_text',
      args: {},
      isClientInitiated: false,
      prompt_id: 'prompt-mcp-multi',
    },
    tool: mcpTool,
    invocation: invocation as unknown as AnyToolInvocation,
    startTime: Date.now(),
  };

  const result = await executor.execute({
    call: scheduledCall,
    signal: new AbortController().signal,
    onUpdateToolCall: vi.fn(),
  });

  // Should NOT have been truncated
  expect(fileUtils.saveTruncatedToolOutput).not.toHaveBeenCalled();
  expect(fileUtils.formatTruncatedToolOutput).not.toHaveBeenCalled();
  expect(result.status).toBe(CoreToolCallStatus.Success);
});
|
||||
|
||||
// A single text Part shorter than the threshold (5 characters vs. 10000) must
// not trigger saving a truncated-output file.
it('should not truncate MCP tool output when text is below threshold', async () => {
  vi.spyOn(config, 'getTruncateToolOutputThreshold').mockReturnValue(10000);

  const messageBus = createMockMessageBus();
  const mcpTool = new DiscoveredMCPTool(
    {} as CallableTool,
    'my-server',
    'get_big_text',
    'A test MCP tool',
    {},
    messageBus,
  );
  const invocation = mcpTool.build({});

  vi.mocked(coreToolHookTriggers.executeToolWithHooks).mockResolvedValue({
    llmContent: [{ text: 'short' }],
    returnDisplay: 'short',
  });

  const scheduledCall: ScheduledToolCall = {
    status: CoreToolCallStatus.Scheduled,
    request: {
      callId: 'call-mcp-short',
      name: 'get_big_text',
      args: {},
      isClientInitiated: false,
      prompt_id: 'prompt-mcp-short',
    },
    tool: mcpTool,
    invocation: invocation as unknown as AnyToolInvocation,
    startTime: Date.now(),
  };

  const result = await executor.execute({
    call: scheduledCall,
    signal: new AbortController().signal,
    onUpdateToolCall: vi.fn(),
  });

  expect(fileUtils.saveTruncatedToolOutput).not.toHaveBeenCalled();
  expect(result.status).toBe(CoreToolCallStatus.Success);
});
|
||||
|
||||
it('should report PID updates for shell tools', async () => {
|
||||
// 1. Setup ShellToolInvocation
|
||||
const messageBus = createMockMessageBus();
|
||||
|
||||
@@ -9,7 +9,8 @@ import type {
|
||||
ToolCallResponseInfo,
|
||||
ToolResult,
|
||||
Config,
|
||||
AnsiOutput,
|
||||
ToolResultDisplay,
|
||||
ToolLiveOutput,
|
||||
} from '../index.js';
|
||||
import {
|
||||
ToolErrorType,
|
||||
@@ -18,6 +19,7 @@ import {
|
||||
runInDevTraceSpan,
|
||||
} from '../index.js';
|
||||
import { SHELL_TOOL_NAME } from '../tools/tool-names.js';
|
||||
import { DiscoveredMCPTool } from '../tools/mcp-tool.js';
|
||||
import { ShellToolInvocation } from '../tools/shell.js';
|
||||
import { executeToolWithHooks } from '../core/coreToolHookTriggers.js';
|
||||
import {
|
||||
@@ -44,7 +46,7 @@ import {
|
||||
export interface ToolExecutionContext {
|
||||
call: ToolCall;
|
||||
signal: AbortSignal;
|
||||
outputUpdateHandler?: (callId: string, output: string | AnsiOutput) => void;
|
||||
outputUpdateHandler?: (callId: string, output: ToolLiveOutput) => void;
|
||||
onUpdateToolCall: (updatedCall: ToolCall) => void;
|
||||
}
|
||||
|
||||
@@ -67,7 +69,7 @@ export class ToolExecutor {
|
||||
// Setup live output handling
|
||||
const liveOutputCallback =
|
||||
tool.canUpdateOutput && outputUpdateHandler
|
||||
? (outputChunk: string | AnsiOutput) => {
|
||||
? (outputChunk: ToolLiveOutput) => {
|
||||
outputUpdateHandler(callId, outputChunk);
|
||||
}
|
||||
: undefined;
|
||||
@@ -133,6 +135,7 @@ export class ToolExecutor {
|
||||
completedToolCall = this.createCancelledResult(
|
||||
call,
|
||||
'User cancelled tool execution.',
|
||||
toolResult.returnDisplay,
|
||||
);
|
||||
} else if (toolResult.error === undefined) {
|
||||
completedToolCall = await this.createSuccessResult(
|
||||
@@ -154,7 +157,12 @@ export class ToolExecutor {
|
||||
}
|
||||
} catch (executionError: unknown) {
|
||||
spanMetadata.error = executionError;
|
||||
if (signal.aborted) {
|
||||
const isAbortError =
|
||||
executionError instanceof Error &&
|
||||
(executionError.name === 'AbortError' ||
|
||||
executionError.message.includes('Operation cancelled by user'));
|
||||
|
||||
if (signal.aborted || isAbortError) {
|
||||
completedToolCall = this.createCancelledResult(
|
||||
call,
|
||||
'User cancelled tool execution.',
|
||||
@@ -181,6 +189,7 @@ export class ToolExecutor {
|
||||
private createCancelledResult(
|
||||
call: ToolCall,
|
||||
reason: string,
|
||||
resultDisplay?: ToolResultDisplay,
|
||||
): CancelledToolCall {
|
||||
const errorMessage = `[Operation Cancelled] ${reason}`;
|
||||
const startTime = 'startTime' in call ? call.startTime : undefined;
|
||||
@@ -205,7 +214,7 @@ export class ToolExecutor {
|
||||
},
|
||||
},
|
||||
],
|
||||
resultDisplay: undefined,
|
||||
resultDisplay,
|
||||
error: undefined,
|
||||
errorType: undefined,
|
||||
contentLength: errorMessage.length,
|
||||
@@ -253,6 +262,45 @@ export class ToolExecutor {
|
||||
}),
|
||||
);
|
||||
}
|
||||
} else if (
|
||||
Array.isArray(content) &&
|
||||
content.length === 1 &&
|
||||
'tool' in call &&
|
||||
call.tool instanceof DiscoveredMCPTool
|
||||
) {
|
||||
const firstPart = content[0];
|
||||
if (typeof firstPart === 'object' && typeof firstPart.text === 'string') {
|
||||
const textContent = firstPart.text;
|
||||
const threshold = this.config.getTruncateToolOutputThreshold();
|
||||
|
||||
if (threshold > 0 && textContent.length > threshold) {
|
||||
const originalContentLength = textContent.length;
|
||||
const { outputFile: savedPath } = await saveTruncatedToolOutput(
|
||||
textContent,
|
||||
toolName,
|
||||
callId,
|
||||
this.config.storage.getProjectTempDir(),
|
||||
this.config.getSessionId(),
|
||||
);
|
||||
outputFile = savedPath;
|
||||
const truncatedText = formatTruncatedToolOutput(
|
||||
textContent,
|
||||
outputFile,
|
||||
threshold,
|
||||
);
|
||||
content[0] = { ...firstPart, text: truncatedText };
|
||||
|
||||
logToolOutputTruncated(
|
||||
this.config,
|
||||
new ToolOutputTruncatedEvent(call.request.prompt_id, {
|
||||
toolName,
|
||||
originalContentLength,
|
||||
truncatedContentLength: truncatedText.length,
|
||||
threshold,
|
||||
}),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const response = convertToFunctionResponse(
|
||||
|
||||
@@ -11,8 +11,8 @@ import type {
|
||||
ToolCallConfirmationDetails,
|
||||
ToolConfirmationOutcome,
|
||||
ToolResultDisplay,
|
||||
ToolLiveOutput,
|
||||
} from '../tools/tools.js';
|
||||
import type { AnsiOutput } from '../utils/terminalSerializer.js';
|
||||
import type { ToolErrorType } from '../tools/tool-error.js';
|
||||
import type { SerializableConfirmationDetails } from '../confirmation-bus/types.js';
|
||||
import { type ApprovalMode } from '../policy/types.js';
|
||||
@@ -125,7 +125,7 @@ export type ExecutingToolCall = {
|
||||
request: ToolCallRequestInfo;
|
||||
tool: AnyDeclarativeTool;
|
||||
invocation: AnyToolInvocation;
|
||||
liveOutput?: string | AnsiOutput;
|
||||
liveOutput?: ToolLiveOutput;
|
||||
progressMessage?: string;
|
||||
progressPercent?: number;
|
||||
progress?: number;
|
||||
@@ -197,7 +197,7 @@ export type ConfirmHandler = (
|
||||
|
||||
export type OutputUpdateHandler = (
|
||||
toolCallId: string,
|
||||
outputChunk: string | AnsiOutput,
|
||||
outputChunk: ToolLiveOutput,
|
||||
) => void;
|
||||
|
||||
export type AllToolCallsCompleteHandler = (
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html
|
||||
|
||||
exports[`ReadFileTool > getSchema > should return the Gemini 3 schema when a Gemini 3 modelId is provided 1`] = `"Reads and returns the content of a specified file. To maintain context efficiency, you MUST use 'start_line' and 'end_line' for targeted, surgical reads of specific sections. For your safety, the tool will automatically truncate output exceeding 2000 lines, 2000 characters per line, or 20MB in size; however, triggering these limits is considered token-inefficient. Always retrieve only the minimum content necessary for your next step. Handles text, images (PNG, JPG, GIF, WEBP, SVG, BMP), audio files (MP3, WAV, AIFF, AAC, OGG, FLAC), and PDF files."`;
|
||||
|
||||
exports[`ReadFileTool > getSchema > should return the base schema when no modelId is provided 1`] = `"Reads and returns the content of a specified file. If the file is large, the content will be truncated. The tool's response will clearly indicate if truncation has occurred and will provide details on how to read more of the file using the 'start_line' and 'end_line' parameters. Handles text, images (PNG, JPG, GIF, WEBP, SVG, BMP), audio files (MP3, WAV, AIFF, AAC, OGG, FLAC), and PDF files. For text files, it can read specific line ranges."`;
|
||||
|
||||
exports[`ReadFileTool > getSchema > should return the schema from the resolver when modelId is provided 1`] = `"Reads and returns the content of a specified file. If the file is large, the content will be truncated. The tool's response will clearly indicate if truncation has occurred and will provide details on how to read more of the file using the 'start_line' and 'end_line' parameters. Handles text, images (PNG, JPG, GIF, WEBP, SVG, BMP), audio files (MP3, WAV, AIFF, AAC, OGG, FLAC), and PDF files. For text files, it can read specific line ranges."`;
|
||||
|
||||
@@ -28,9 +28,11 @@ export class AskUserTool extends BaseDeclarativeTool<
|
||||
AskUserParams,
|
||||
ToolResult
|
||||
> {
|
||||
static readonly Name = ASK_USER_TOOL_NAME;
|
||||
|
||||
constructor(messageBus: MessageBus) {
|
||||
super(
|
||||
ASK_USER_TOOL_NAME,
|
||||
AskUserTool.Name,
|
||||
ASK_USER_DISPLAY_NAME,
|
||||
ASK_USER_DEFINITION.base.description!,
|
||||
Kind.Communicate,
|
||||
|
||||
+1
-1
@@ -1197,7 +1197,7 @@ exports[`coreTools snapshots for specific models > Model: gemini-3-pro-preview >
|
||||
|
||||
exports[`coreTools snapshots for specific models > Model: gemini-3-pro-preview > snapshot for tool: read_file 1`] = `
|
||||
{
|
||||
"description": "Reads and returns the content of a specified file. If the file is large, the content will be truncated. The tool's response will clearly indicate if truncation has occurred and will provide details on how to read more of the file using the 'start_line' and 'end_line' parameters. Handles text, images (PNG, JPG, GIF, WEBP, SVG, BMP), audio files (MP3, WAV, AIFF, AAC, OGG, FLAC), and PDF files. For text files, it can read specific line ranges.",
|
||||
"description": "Reads and returns the content of a specified file. To maintain context efficiency, you MUST use 'start_line' and 'end_line' for targeted, surgical reads of specific sections. For your safety, the tool will automatically truncate output exceeding 2000 lines, 2000 characters per line, or 20MB in size; however, triggering these limits is considered token-inefficient. Always retrieve only the minimum content necessary for your next step. Handles text, images (PNG, JPG, GIF, WEBP, SVG, BMP), audio files (MP3, WAV, AIFF, AAC, OGG, FLAC), and PDF files.",
|
||||
"name": "read_file",
|
||||
"parametersJsonSchema": {
|
||||
"properties": {
|
||||
|
||||
@@ -79,6 +79,11 @@ import {
|
||||
getExitPlanModeDeclaration,
|
||||
getActivateSkillDeclaration,
|
||||
} from '../dynamic-declaration-helpers.js';
|
||||
import {
|
||||
DEFAULT_MAX_LINES_TEXT_FILE,
|
||||
MAX_LINE_LENGTH_TEXT_FILE,
|
||||
MAX_FILE_SIZE_MB,
|
||||
} from '../../../utils/constants.js';
|
||||
|
||||
/**
|
||||
* Gemini 3 tool set. Initially a copy of the default legacy set.
|
||||
@@ -86,7 +91,7 @@ import {
|
||||
export const GEMINI_3_SET: CoreToolSet = {
|
||||
read_file: {
|
||||
name: READ_FILE_TOOL_NAME,
|
||||
description: `Reads and returns the content of a specified file. If the file is large, the content will be truncated. The tool's response will clearly indicate if truncation has occurred and will provide details on how to read more of the file using the 'start_line' and 'end_line' parameters. Handles text, images (PNG, JPG, GIF, WEBP, SVG, BMP), audio files (MP3, WAV, AIFF, AAC, OGG, FLAC), and PDF files. For text files, it can read specific line ranges.`,
|
||||
description: `Reads and returns the content of a specified file. To maintain context efficiency, you MUST use 'start_line' and 'end_line' for targeted, surgical reads of specific sections. For your safety, the tool will automatically truncate output exceeding ${DEFAULT_MAX_LINES_TEXT_FILE} lines, ${MAX_LINE_LENGTH_TEXT_FILE} characters per line, or ${MAX_FILE_SIZE_MB}MB in size; however, triggering these limits is considered token-inefficient. Always retrieve only the minimum content necessary for your next step. Handles text, images (PNG, JPG, GIF, WEBP, SVG, BMP), audio files (MP3, WAV, AIFF, AAC, OGG, FLAC), and PDF files.`,
|
||||
parametersJsonSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
|
||||
@@ -27,12 +27,14 @@ export class EnterPlanModeTool extends BaseDeclarativeTool<
|
||||
EnterPlanModeParams,
|
||||
ToolResult
|
||||
> {
|
||||
static readonly Name = ENTER_PLAN_MODE_TOOL_NAME;
|
||||
|
||||
constructor(
|
||||
private config: Config,
|
||||
messageBus: MessageBus,
|
||||
) {
|
||||
super(
|
||||
ENTER_PLAN_MODE_TOOL_NAME,
|
||||
EnterPlanModeTool.Name,
|
||||
'Enter Plan Mode',
|
||||
ENTER_PLAN_MODE_DEFINITION.base.description!,
|
||||
Kind.Plan,
|
||||
|
||||
@@ -35,6 +35,8 @@ export class ExitPlanModeTool extends BaseDeclarativeTool<
|
||||
ExitPlanModeParams,
|
||||
ToolResult
|
||||
> {
|
||||
static readonly Name = EXIT_PLAN_MODE_TOOL_NAME;
|
||||
|
||||
constructor(
|
||||
private config: Config,
|
||||
messageBus: MessageBus,
|
||||
@@ -42,7 +44,7 @@ export class ExitPlanModeTool extends BaseDeclarativeTool<
|
||||
const plansDir = config.storage.getPlansDir();
|
||||
const definition = getExitPlanModeDefinition(plansDir);
|
||||
super(
|
||||
EXIT_PLAN_MODE_TOOL_NAME,
|
||||
ExitPlanModeTool.Name,
|
||||
'Exit Plan Mode',
|
||||
definition.base.description!,
|
||||
Kind.Plan,
|
||||
|
||||
@@ -69,6 +69,7 @@ import { debugLogger } from '../utils/debugLogger.js';
|
||||
import { type MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import { coreEvents } from '../utils/events.js';
|
||||
import type { ResourceRegistry } from '../resources/resource-registry.js';
|
||||
import { validateMcpPolicyToolNames } from '../policy/toml-loader.js';
|
||||
import {
|
||||
sanitizeEnvironment,
|
||||
type EnvironmentSanitizationConfig,
|
||||
@@ -221,6 +222,23 @@ export class McpClient implements McpProgressReporter {
|
||||
this.toolRegistry.registerTool(tool);
|
||||
}
|
||||
this.toolRegistry.sortTools();
|
||||
|
||||
// Validate MCP tool names in policy rules against discovered tools
|
||||
try {
|
||||
const discoveredToolNames = tools.map((t) => t.serverToolName);
|
||||
const policyRules = cliConfig.getPolicyEngine?.()?.getRules() ?? [];
|
||||
const warnings = validateMcpPolicyToolNames(
|
||||
this.serverName,
|
||||
discoveredToolNames,
|
||||
policyRules,
|
||||
);
|
||||
for (const warning of warnings) {
|
||||
coreEvents.emitFeedback('warning', warning);
|
||||
}
|
||||
} catch {
|
||||
// Policy engine may not be available in all contexts (e.g. tests).
|
||||
// Validation is best-effort; skip silently if unavailable.
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -1577,6 +1595,9 @@ export interface McpContext {
|
||||
): void;
|
||||
setUserInteractedWithMcp?(): void;
|
||||
isTrustedFolder(): boolean;
|
||||
getPolicyEngine?(): {
|
||||
getRules(): ReadonlyArray<{ toolName?: string; source?: string }>;
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -588,5 +588,13 @@ describe('ReadFileTool', () => {
|
||||
expect(schema.name).toBe(ReadFileTool.Name);
|
||||
expect(schema.description).toMatchSnapshot();
|
||||
});
|
||||
|
||||
// For a Gemini 3 model id the schema keeps the tool name, and its description
// must match the stored snapshot and mention "surgical reads".
it('should return the Gemini 3 schema when a Gemini 3 modelId is provided', () => {
  const modelId = 'gemini-3-pro-preview';
  const schema = tool.getSchema(modelId);

  expect(schema.name).toBe(ReadFileTool.Name);
  expect(schema.description).toMatchSnapshot();
  expect(schema.description).toContain('surgical reads');
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -17,6 +17,7 @@ import type {
|
||||
ToolCallConfirmationDetails,
|
||||
ToolExecuteConfirmationDetails,
|
||||
PolicyUpdateOptions,
|
||||
ToolLiveOutput,
|
||||
} from './tools.js';
|
||||
import {
|
||||
BaseDeclarativeTool,
|
||||
@@ -149,7 +150,7 @@ export class ShellToolInvocation extends BaseToolInvocation<
|
||||
|
||||
async execute(
|
||||
signal: AbortSignal,
|
||||
updateOutput?: (output: string | AnsiOutput) => void,
|
||||
updateOutput?: (output: ToolLiveOutput) => void,
|
||||
shellExecutionConfig?: ShellExecutionConfig,
|
||||
setPidCallback?: (pid: number) => void,
|
||||
): Promise<ToolResult> {
|
||||
|
||||
@@ -380,20 +380,36 @@ describe('ToolRegistry', () => {
|
||||
});
|
||||
|
||||
describe('getAllToolNames', () => {
|
||||
it('should return all registered tool names', () => {
|
||||
it('should return all registered tool names with qualified names for MCP tools', () => {
|
||||
// Register tools with displayNames in non-alphabetical order
|
||||
const toolC = new MockTool({ name: 'c-tool', displayName: 'Tool C' });
|
||||
const toolA = new MockTool({ name: 'a-tool', displayName: 'Tool A' });
|
||||
const toolB = new MockTool({ name: 'b-tool', displayName: 'Tool B' });
|
||||
const mcpTool = createMCPTool('my-server', 'my-tool', 'desc');
|
||||
|
||||
toolRegistry.registerTool(toolC);
|
||||
toolRegistry.registerTool(toolA);
|
||||
toolRegistry.registerTool(toolB);
|
||||
toolRegistry.registerTool(mcpTool);
|
||||
|
||||
const toolNames = toolRegistry.getAllToolNames();
|
||||
|
||||
// Assert that the returned array contains all tool names
|
||||
expect(toolNames).toEqual(['c-tool', 'a-tool', 'b-tool']);
|
||||
// Assert that the returned array contains all tool names, with MCP qualified
|
||||
expect(toolNames).toContain('c-tool');
|
||||
expect(toolNames).toContain('a-tool');
|
||||
expect(toolNames).toContain('my-server__my-tool');
|
||||
expect(toolNames).toHaveLength(3);
|
||||
});
|
||||
|
||||
it('should deduplicate tool names', () => {
|
||||
const serverName = 'my-server';
|
||||
const toolName = 'my-tool';
|
||||
const mcpTool = createMCPTool(serverName, toolName, 'desc');
|
||||
|
||||
// Register same MCP tool twice (one as alias, one as qualified)
|
||||
toolRegistry.registerTool(mcpTool);
|
||||
toolRegistry.registerTool(mcpTool.asFullyQualifiedTool());
|
||||
|
||||
const toolNames = toolRegistry.getAllToolNames();
|
||||
expect(toolNames).toEqual([`${serverName}__${toolName}`]);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -465,8 +481,8 @@ describe('ToolRegistry', () => {
|
||||
'builtin-1',
|
||||
'builtin-2',
|
||||
DISCOVERED_TOOL_PREFIX + 'discovered-1',
|
||||
'mcp-apple',
|
||||
'mcp-zebra',
|
||||
'apple-server__mcp-apple',
|
||||
'zebra-server__mcp-zebra',
|
||||
]);
|
||||
});
|
||||
});
|
||||
@@ -659,6 +675,34 @@ describe('ToolRegistry', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('getFunctionDeclarations', () => {
|
||||
it('should use fully qualified names for MCP tools in declarations', () => {
|
||||
const serverName = 'my-server';
|
||||
const toolName = 'my-tool';
|
||||
const mcpTool = createMCPTool(serverName, toolName, 'description');
|
||||
|
||||
toolRegistry.registerTool(mcpTool);
|
||||
|
||||
const declarations = toolRegistry.getFunctionDeclarations();
|
||||
expect(declarations).toHaveLength(1);
|
||||
expect(declarations[0].name).toBe(`${serverName}__${toolName}`);
|
||||
});
|
||||
|
||||
it('should deduplicate MCP tools in declarations', () => {
|
||||
const serverName = 'my-server';
|
||||
const toolName = 'my-tool';
|
||||
const mcpTool = createMCPTool(serverName, toolName, 'description');
|
||||
|
||||
// Register both alias and qualified
|
||||
toolRegistry.registerTool(mcpTool);
|
||||
toolRegistry.registerTool(mcpTool.asFullyQualifiedTool());
|
||||
|
||||
const declarations = toolRegistry.getFunctionDeclarations();
|
||||
expect(declarations).toHaveLength(1);
|
||||
expect(declarations[0].name).toBe(`${serverName}__${toolName}`);
|
||||
});
|
||||
});
|
||||
|
||||
describe('plan mode', () => {
|
||||
it('should only return policy-allowed tools in plan mode', () => {
|
||||
// Register several tools
|
||||
|
||||
@@ -554,11 +554,32 @@ export class ToolRegistry {
|
||||
const plansDir = this.config.storage.getPlansDir();
|
||||
|
||||
const declarations: FunctionDeclaration[] = [];
|
||||
const seenNames = new Set<string>();
|
||||
|
||||
this.getActiveTools().forEach((tool) => {
|
||||
const toolName =
|
||||
tool instanceof DiscoveredMCPTool
|
||||
? tool.getFullyQualifiedName()
|
||||
: tool.name;
|
||||
|
||||
if (seenNames.has(toolName)) {
|
||||
return;
|
||||
}
|
||||
seenNames.add(toolName);
|
||||
|
||||
let schema = tool.getSchema(modelId);
|
||||
|
||||
// Ensure the schema name matches the qualified name for MCP tools
|
||||
if (tool instanceof DiscoveredMCPTool) {
|
||||
schema = {
|
||||
...schema,
|
||||
name: toolName,
|
||||
};
|
||||
}
|
||||
|
||||
if (
|
||||
isPlanMode &&
|
||||
(tool.name === WRITE_FILE_TOOL_NAME || tool.name === EDIT_TOOL_NAME)
|
||||
(toolName === WRITE_FILE_TOOL_NAME || toolName === EDIT_TOOL_NAME)
|
||||
) {
|
||||
schema = {
|
||||
...schema,
|
||||
@@ -591,20 +612,42 @@ export class ToolRegistry {
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an array of all registered and discovered tool names which are not
|
||||
* excluded via configuration.
|
||||
* Returns an array of names for all active tools.
|
||||
* For MCP tools, this returns their fully qualified names.
|
||||
* The list is deduplicated.
|
||||
*/
|
||||
getAllToolNames(): string[] {
|
||||
return this.getActiveTools().map((tool) => tool.name);
|
||||
const names = new Set<string>();
|
||||
for (const tool of this.getActiveTools()) {
|
||||
if (tool instanceof DiscoveredMCPTool) {
|
||||
names.add(tool.getFullyQualifiedName());
|
||||
} else {
|
||||
names.add(tool.name);
|
||||
}
|
||||
}
|
||||
return Array.from(names);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an array of all registered and discovered tool instances.
|
||||
*/
|
||||
getAllTools(): AnyDeclarativeTool[] {
|
||||
return this.getActiveTools().sort((a, b) =>
|
||||
const seen = new Set<string>();
|
||||
const tools: AnyDeclarativeTool[] = [];
|
||||
|
||||
for (const tool of this.getActiveTools().sort((a, b) =>
|
||||
a.displayName.localeCompare(b.displayName),
|
||||
);
|
||||
)) {
|
||||
const name =
|
||||
tool instanceof DiscoveredMCPTool
|
||||
? tool.getFullyQualifiedName()
|
||||
: tool.name;
|
||||
if (!seen.has(name)) {
|
||||
seen.add(name);
|
||||
tools.push(tool);
|
||||
}
|
||||
}
|
||||
return tools;
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -19,6 +19,7 @@ import {
|
||||
type Question,
|
||||
} from '../confirmation-bus/types.js';
|
||||
import { type ApprovalMode } from '../policy/types.js';
|
||||
import type { SubagentProgress } from '../agents/types.js';
|
||||
|
||||
/**
|
||||
* Represents a validated and ready-to-execute tool call.
|
||||
@@ -64,7 +65,7 @@ export interface ToolInvocation<
|
||||
*/
|
||||
execute(
|
||||
signal: AbortSignal,
|
||||
updateOutput?: (output: string | AnsiOutput) => void,
|
||||
updateOutput?: (output: ToolLiveOutput) => void,
|
||||
shellExecutionConfig?: ShellExecutionConfig,
|
||||
): Promise<TResult>;
|
||||
}
|
||||
@@ -276,7 +277,7 @@ export abstract class BaseToolInvocation<
|
||||
|
||||
abstract execute(
|
||||
signal: AbortSignal,
|
||||
updateOutput?: (output: string | AnsiOutput) => void,
|
||||
updateOutput?: (output: ToolLiveOutput) => void,
|
||||
shellExecutionConfig?: ShellExecutionConfig,
|
||||
): Promise<TResult>;
|
||||
}
|
||||
@@ -422,7 +423,7 @@ export abstract class DeclarativeTool<
|
||||
async buildAndExecute(
|
||||
params: TParams,
|
||||
signal: AbortSignal,
|
||||
updateOutput?: (output: string | AnsiOutput) => void,
|
||||
updateOutput?: (output: ToolLiveOutput) => void,
|
||||
shellExecutionConfig?: ShellExecutionConfig,
|
||||
): Promise<TResult> {
|
||||
const invocation = this.build(params);
|
||||
@@ -688,7 +689,14 @@ export interface TodoList {
|
||||
todos: Todo[];
|
||||
}
|
||||
|
||||
export type ToolResultDisplay = string | FileDiff | AnsiOutput | TodoList;
|
||||
export type ToolLiveOutput = string | AnsiOutput | SubagentProgress;
|
||||
|
||||
export type ToolResultDisplay =
|
||||
| string
|
||||
| FileDiff
|
||||
| AnsiOutput
|
||||
| TodoList
|
||||
| SubagentProgress;
|
||||
|
||||
export type TodoStatus = 'pending' | 'in_progress' | 'completed' | 'cancelled';
|
||||
|
||||
|
||||
@@ -6,3 +6,7 @@
|
||||
|
||||
export const REFERENCE_CONTENT_START = '--- Content from referenced files ---';
|
||||
export const REFERENCE_CONTENT_END = '--- End of content ---';
|
||||
|
||||
export const DEFAULT_MAX_LINES_TEXT_FILE = 2000;
|
||||
export const MAX_LINE_LENGTH_TEXT_FILE = 2000;
|
||||
export const MAX_FILE_SIZE_MB = 20;
|
||||
|
||||
@@ -8,7 +8,7 @@ import { getErrorMessage, isNodeError } from './errors.js';
|
||||
import { URL } from 'node:url';
|
||||
import { Agent, ProxyAgent, setGlobalDispatcher } from 'undici';
|
||||
|
||||
const DEFAULT_HEADERS_TIMEOUT = 60000; // 60 seconds
|
||||
const DEFAULT_HEADERS_TIMEOUT = 300000; // 5 minutes
|
||||
const DEFAULT_BODY_TIMEOUT = 300000; // 5 minutes
|
||||
|
||||
// Configure default global dispatcher with higher timeouts
|
||||
|
||||
@@ -15,6 +15,11 @@ import { ToolErrorType } from '../tools/tool-error.js';
|
||||
import { BINARY_EXTENSIONS } from './ignorePatterns.js';
|
||||
import { createRequire as createModuleRequire } from 'node:module';
|
||||
import { debugLogger } from './debugLogger.js';
|
||||
import {
|
||||
DEFAULT_MAX_LINES_TEXT_FILE,
|
||||
MAX_LINE_LENGTH_TEXT_FILE,
|
||||
MAX_FILE_SIZE_MB,
|
||||
} from './constants.js';
|
||||
|
||||
const requireModule = createModuleRequire(import.meta.url);
|
||||
|
||||
@@ -52,10 +57,6 @@ export async function loadWasmBinary(
|
||||
}
|
||||
}
|
||||
|
||||
// Constants for text file processing
|
||||
export const DEFAULT_MAX_LINES_TEXT_FILE = 2000;
|
||||
const MAX_LINE_LENGTH_TEXT_FILE = 2000;
|
||||
|
||||
// Default values for encoding and separator format
|
||||
export const DEFAULT_ENCODING: BufferEncoding = 'utf-8';
|
||||
|
||||
@@ -434,11 +435,11 @@ export async function processSingleFileContent(
|
||||
}
|
||||
|
||||
const fileSizeInMB = stats.size / (1024 * 1024);
|
||||
if (fileSizeInMB > 20) {
|
||||
if (fileSizeInMB > MAX_FILE_SIZE_MB) {
|
||||
return {
|
||||
llmContent: 'File size exceeds the 20MB limit.',
|
||||
returnDisplay: 'File size exceeds the 20MB limit.',
|
||||
error: `File size exceeds the 20MB limit: ${filePath} (${fileSizeInMB.toFixed(2)}MB)`,
|
||||
llmContent: `File size exceeds the ${MAX_FILE_SIZE_MB}MB limit.`,
|
||||
returnDisplay: `File size exceeds the ${MAX_FILE_SIZE_MB}MB limit.`,
|
||||
error: `File size exceeds the ${MAX_FILE_SIZE_MB}MB limit: ${filePath} (${fileSizeInMB.toFixed(2)}MB)`,
|
||||
errorType: ToolErrorType.FILE_TOO_LARGE,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -98,6 +98,17 @@ describe('shouldHideToolCall', () => {
|
||||
).toBe(!visible);
|
||||
},
|
||||
);
|
||||
|
||||
it('hides tool calls with a parentCallId', () => {
|
||||
expect(
|
||||
shouldHideToolCall({
|
||||
displayName: 'any_tool',
|
||||
status: CoreToolCallStatus.Success,
|
||||
hasResultDisplay: true,
|
||||
parentCallId: 'some-parent',
|
||||
}),
|
||||
).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getToolSuggestion', () => {
|
||||
|
||||
@@ -28,20 +28,28 @@ export interface ShouldHideToolCallParams {
|
||||
approvalMode?: ApprovalMode;
|
||||
/** Whether the tool has produced a result for display. */
|
||||
hasResultDisplay: boolean;
|
||||
/** The ID of the parent tool call, if any. */
|
||||
parentCallId?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Determines if a tool call should be hidden from the standard tool history UI.
|
||||
*
|
||||
* We hide tools in several cases:
|
||||
* 1. Ask User tools that are in progress, displayed via specialized UI.
|
||||
* 2. Ask User tools that errored without result display, typically param
|
||||
* 1. Tool calls that have a parent, as they are "internal" to another tool (e.g. subagent).
|
||||
* 2. Ask User tools that are in progress, displayed via specialized UI.
|
||||
* 3. Ask User tools that errored without result display, typically param
|
||||
* validation errors that the agent automatically recovers from.
|
||||
* 3. WriteFile and Edit tools when in Plan Mode, redundant because the
|
||||
* 4. WriteFile and Edit tools when in Plan Mode, redundant because the
|
||||
* resulting plans are displayed separately upon exiting plan mode.
|
||||
*/
|
||||
export function shouldHideToolCall(params: ShouldHideToolCallParams): boolean {
|
||||
const { displayName, status, approvalMode, hasResultDisplay } = params;
|
||||
const { displayName, status, approvalMode, hasResultDisplay, parentCallId } =
|
||||
params;
|
||||
|
||||
if (parentCallId) {
|
||||
return true;
|
||||
}
|
||||
|
||||
switch (displayName) {
|
||||
case ASK_USER_DISPLAY_NAME:
|
||||
|
||||
Reference in New Issue
Block a user