Files
gemini-cli/packages/a2a-server/src/utils/testing_utils.ts

179 lines
5.2 KiB
TypeScript

/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type {
Task as SDKTask,
TaskStatusUpdateEvent,
SendStreamingMessageSuccessResponse,
} from '@a2a-js/sdk';
import {
ApprovalMode,
DEFAULT_GEMINI_MODEL,
DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES,
DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD,
GeminiClient,
HookSystem,
PolicyDecision,
} from '@google/gemini-cli-core';
import { createMockMessageBus } from '@google/gemini-cli-core/src/test-utils/mock-message-bus.js';
import type { Config, Storage } from '@google/gemini-cli-core';
import { expect, vi } from 'vitest';
export function createMockConfig(
overrides: Partial<Config> = {},
): Partial<Config> {
const mockConfig = {
getToolRegistry: vi.fn().mockReturnValue({
getTool: vi.fn(),
getAllToolNames: vi.fn().mockReturnValue([]),
getAllTools: vi.fn().mockReturnValue([]),
getToolsByServer: vi.fn().mockReturnValue([]),
}),
getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT),
getIdeMode: vi.fn().mockReturnValue(false),
isInteractive: () => true,
getAllowedTools: vi.fn().mockReturnValue([]),
getWorkspaceContext: vi.fn().mockReturnValue({
isPathWithinWorkspace: () => true,
}),
getTargetDir: () => '/test',
getCheckpointingEnabled: vi.fn().mockReturnValue(false),
storage: {
getProjectTempDir: () => '/tmp',
getProjectTempCheckpointsDir: () => '/tmp/checkpoints',
} as Storage,
getTruncateToolOutputThreshold: () =>
DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD,
getTruncateToolOutputLines: () => DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES,
getActiveModel: vi.fn().mockReturnValue(DEFAULT_GEMINI_MODEL),
getDebugMode: vi.fn().mockReturnValue(false),
getContentGeneratorConfig: vi.fn().mockReturnValue({ model: 'gemini-pro' }),
getModel: vi.fn().mockReturnValue('gemini-pro'),
getUsageStatisticsEnabled: vi.fn().mockReturnValue(false),
setFallbackModelHandler: vi.fn(),
initialize: vi.fn().mockResolvedValue(undefined),
getProxy: vi.fn().mockReturnValue(undefined),
getHistory: vi.fn().mockReturnValue([]),
getEmbeddingModel: vi.fn().mockReturnValue('text-embedding-004'),
getSessionId: vi.fn().mockReturnValue('test-session-id'),
getUserTier: vi.fn(),
getMessageBus: vi.fn(),
getPolicyEngine: vi.fn(),
getEnableExtensionReloading: vi.fn().mockReturnValue(false),
getEnableHooks: vi.fn().mockReturnValue(false),
getMcpClientManager: vi.fn().mockReturnValue({
getMcpServers: vi.fn().mockReturnValue({}),
}),
getGitService: vi.fn(),
...overrides,
} as unknown as Config;
mockConfig.getMessageBus = vi.fn().mockReturnValue(createMockMessageBus());
mockConfig.getHookSystem = vi
.fn()
.mockReturnValue(new HookSystem(mockConfig));
mockConfig.getGeminiClient = vi
.fn()
.mockReturnValue(new GeminiClient(mockConfig));
mockConfig.getPolicyEngine = vi.fn().mockReturnValue({
check: async () => {
const mode = mockConfig.getApprovalMode();
if (mode === ApprovalMode.YOLO) {
return { decision: PolicyDecision.ALLOW };
}
return { decision: PolicyDecision.ASK_USER };
},
});
return mockConfig;
}
export function createStreamMessageRequest(
text: string,
messageId: string,
taskId?: string,
) {
const request: {
jsonrpc: string;
id: string;
method: string;
params: {
message: {
kind: string;
role: string;
parts: [{ kind: string; text: string }];
messageId: string;
};
metadata: {
coderAgent: {
kind: string;
workspacePath: string;
};
};
taskId?: string;
};
} = {
jsonrpc: '2.0',
id: '1',
method: 'message/stream',
params: {
message: {
kind: 'message',
role: 'user',
parts: [{ kind: 'text', text }],
messageId,
},
metadata: {
coderAgent: {
kind: 'agent-settings',
workspacePath: '/tmp',
},
},
},
};
if (taskId) {
request.params.taskId = taskId;
}
return request;
}
export function assertUniqueFinalEventIsLast(
events: SendStreamingMessageSuccessResponse[],
) {
// Final event is input-required & final
const finalEvent = events[events.length - 1].result as TaskStatusUpdateEvent;
expect(finalEvent.metadata?.['coderAgent']).toMatchObject({
kind: 'state-change',
});
expect(finalEvent.status?.state).toBe('input-required');
expect(finalEvent.final).toBe(true);
// There is only one event with final and its the last
expect(
events.filter((e) => (e.result as TaskStatusUpdateEvent).final).length,
).toBe(1);
expect(
events.findIndex((e) => (e.result as TaskStatusUpdateEvent).final),
).toBe(events.length - 1);
}
export function assertTaskCreationAndWorkingStatus(
events: SendStreamingMessageSuccessResponse[],
) {
// Initial task creation event
const taskEvent = events[0].result as SDKTask;
expect(taskEvent.kind).toBe('task');
expect(taskEvent.status.state).toBe('submitted');
// Status update: working
const workingEvent = events[1].result as TaskStatusUpdateEvent;
expect(workingEvent.kind).toBe('status-update');
expect(workingEvent.status.state).toBe('working');
}