mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-11 22:51:00 -07:00
179 lines
5.2 KiB
TypeScript
179 lines
5.2 KiB
TypeScript
/**
 * @license
 * Copyright 2025 Google LLC
 * SPDX-License-Identifier: Apache-2.0
 */
|
|
|
|
import type {
|
|
Task as SDKTask,
|
|
TaskStatusUpdateEvent,
|
|
SendStreamingMessageSuccessResponse,
|
|
} from '@a2a-js/sdk';
|
|
import {
|
|
ApprovalMode,
|
|
DEFAULT_GEMINI_MODEL,
|
|
DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES,
|
|
DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD,
|
|
GeminiClient,
|
|
HookSystem,
|
|
PolicyDecision,
|
|
} from '@google/gemini-cli-core';
|
|
import { createMockMessageBus } from '@google/gemini-cli-core/src/test-utils/mock-message-bus.js';
|
|
import type { Config, Storage } from '@google/gemini-cli-core';
|
|
import { expect, vi } from 'vitest';
|
|
|
|
/**
 * Builds a `Config` test double with fixed defaults for the a2a-server tests.
 *
 * Members are stubs — `vi.fn()` mocks or plain arrow functions — that return
 * fixed values. Any member present in `overrides` replaces the corresponding
 * default. Four collaborators are wired up after the literal is built:
 * - `getMessageBus`: returns a mock message bus.
 * - `getHookSystem`: returns a real `HookSystem` built on this mock.
 * - `getGeminiClient`: returns a real `GeminiClient` built on this mock.
 * - `getPolicyEngine`: returns a fake whose `check` allows in YOLO mode and
 *   asks the user otherwise.
 *
 * Each of those four defaults is only installed when the caller did not
 * override that member.
 *
 * @param overrides - Config members to replace on the returned mock.
 * @returns The mock config, typed as `Partial<Config>`.
 */
export function createMockConfig(
  overrides: Partial<Config> = {},
): Partial<Config> {
  const mockConfig = {
    getToolRegistry: vi.fn().mockReturnValue({
      getTool: vi.fn(),
      getAllToolNames: vi.fn().mockReturnValue([]),
      getAllTools: vi.fn().mockReturnValue([]),
      getToolsByServer: vi.fn().mockReturnValue([]),
    }),
    getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT),
    getIdeMode: vi.fn().mockReturnValue(false),
    isInteractive: () => true,
    getAllowedTools: vi.fn().mockReturnValue([]),
    getWorkspaceContext: vi.fn().mockReturnValue({
      isPathWithinWorkspace: () => true,
    }),
    getTargetDir: () => '/test',
    getCheckpointingEnabled: vi.fn().mockReturnValue(false),
    storage: {
      getProjectTempDir: () => '/tmp',
      getProjectTempCheckpointsDir: () => '/tmp/checkpoints',
    } as Storage,
    getTruncateToolOutputThreshold: () =>
      DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD,
    getTruncateToolOutputLines: () => DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES,
    getActiveModel: vi.fn().mockReturnValue(DEFAULT_GEMINI_MODEL),
    getDebugMode: vi.fn().mockReturnValue(false),
    getContentGeneratorConfig: vi.fn().mockReturnValue({ model: 'gemini-pro' }),
    getModel: vi.fn().mockReturnValue('gemini-pro'),
    getUsageStatisticsEnabled: vi.fn().mockReturnValue(false),
    setFallbackModelHandler: vi.fn(),
    initialize: vi.fn().mockResolvedValue(undefined),
    getProxy: vi.fn().mockReturnValue(undefined),
    getHistory: vi.fn().mockReturnValue([]),
    getEmbeddingModel: vi.fn().mockReturnValue('text-embedding-004'),
    getSessionId: vi.fn().mockReturnValue('test-session-id'),
    getUserTier: vi.fn(),
    // Placeholders; replaced below with working fakes unless overridden.
    getMessageBus: vi.fn(),
    getPolicyEngine: vi.fn(),
    getEnableExtensionReloading: vi.fn().mockReturnValue(false),
    getEnableHooks: vi.fn().mockReturnValue(false),
    getMcpClientManager: vi.fn().mockReturnValue({
      getMcpServers: vi.fn().mockReturnValue({}),
    }),
    getGitService: vi.fn(),
    ...overrides,
  } as unknown as Config;

  // The defaults below used to be assigned unconditionally, which silently
  // discarded any caller override of these members. Only fill them in when
  // the caller left them unset.
  if (overrides.getMessageBus === undefined) {
    mockConfig.getMessageBus = vi.fn().mockReturnValue(createMockMessageBus());
  }

  // HookSystem and GeminiClient are constructed with the mock itself, after
  // getMessageBus has been assigned (same order as before; presumably their
  // constructors may read from the config — not verified here).
  if (overrides.getHookSystem === undefined) {
    mockConfig.getHookSystem = vi
      .fn()
      .mockReturnValue(new HookSystem(mockConfig));
  }

  if (overrides.getGeminiClient === undefined) {
    mockConfig.getGeminiClient = vi
      .fn()
      .mockReturnValue(new GeminiClient(mockConfig));
  }

  if (overrides.getPolicyEngine === undefined) {
    mockConfig.getPolicyEngine = vi.fn().mockReturnValue({
      // The approval mode is read on every call, so later changes to
      // getApprovalMode on the mock are respected.
      check: async () => {
        const mode = mockConfig.getApprovalMode();
        if (mode === ApprovalMode.YOLO) {
          return { decision: PolicyDecision.ALLOW };
        }
        return { decision: PolicyDecision.ASK_USER };
      },
    });
  }

  return mockConfig;
}
|
|
|
|
/**
 * Builds a JSON-RPC 2.0 `message/stream` request (id `'1'`).
 *
 * The request carries a single user text part. Its `coderAgent` metadata is
 * of kind `agent-settings` with workspace path `/tmp`.
 *
 * @param text - Text of the single message part.
 * @param messageId - Identifier placed on the message.
 * @param taskId - Attached as `params.taskId` when truthy. When omitted or
 *   empty, the key is left out entirely.
 * @returns The request object.
 */
export function createStreamMessageRequest(
  text: string,
  messageId: string,
  taskId?: string,
): {
  jsonrpc: string;
  id: string;
  method: string;
  params: {
    message: {
      kind: string;
      role: string;
      parts: [{ kind: string; text: string }];
      messageId: string;
    };
    metadata: {
      coderAgent: {
        kind: string;
        workspacePath: string;
      };
    };
    taskId?: string;
  };
} {
  return {
    jsonrpc: '2.0',
    id: '1',
    method: 'message/stream',
    params: {
      message: {
        kind: 'message',
        role: 'user',
        parts: [{ kind: 'text', text }],
        messageId,
      },
      metadata: {
        coderAgent: { kind: 'agent-settings', workspacePath: '/tmp' },
      },
      // Spread nothing when there is no task id, so the key stays absent.
      ...(taskId ? { taskId } : {}),
    },
  };
}
|
|
|
|
/**
 * Asserts that a streamed task ends correctly.
 *
 * The last event must be a status update that meets all of these:
 * - `coderAgent` metadata of kind `state-change`;
 * - state `input-required`;
 * - `final: true`.
 *
 * No other event may be marked final.
 *
 * @param events - Streaming responses in the order they were received.
 */
export function assertUniqueFinalEventIsLast(
  events: SendStreamingMessageSuccessResponse[],
) {
  // Fail with a readable assertion instead of a TypeError on an empty stream.
  expect(events.length).toBeGreaterThan(0);
  const lastIndex = events.length - 1;

  // Final event is input-required & final.
  const finalEvent = events[lastIndex].result as TaskStatusUpdateEvent;
  expect(finalEvent.metadata?.['coderAgent']).toMatchObject({
    kind: 'state-change',
  });
  expect(finalEvent.status?.state).toBe('input-required');
  expect(finalEvent.final).toBe(true);

  // There is exactly one final event and it's the last one. Collecting every
  // final index means a failure shows where any stray final events are.
  const finalIndices = events.flatMap((e, i) =>
    (e.result as TaskStatusUpdateEvent).final ? [i] : [],
  );
  expect(finalIndices).toEqual([lastIndex]);
}
|
|
|
|
/**
 * Asserts how a task stream opens.
 *
 * - The first event must be the `task` object in the `submitted` state.
 * - The second must be a `status-update` in the `working` state.
 *
 * @param events - Streaming responses in the order they were received.
 */
export function assertTaskCreationAndWorkingStatus(
  events: SendStreamingMessageSuccessResponse[],
) {
  // Both opening events are indexed below. Fail with a readable assertion,
  // rather than a TypeError, when the stream is shorter than that.
  expect(events.length).toBeGreaterThanOrEqual(2);

  // Initial task creation event.
  const taskEvent = events[0].result as SDKTask;
  expect(taskEvent.kind).toBe('task');
  expect(taskEvent.status.state).toBe('submitted');

  // Status update: working.
  const workingEvent = events[1].result as TaskStatusUpdateEvent;
  expect(workingEvent.kind).toBe('status-update');
  expect(workingEvent.status.state).toBe('working');
}
|