Merge remote-tracking branch 'origin/main' into mk-packing

This commit is contained in:
mkorwel
2025-07-01 18:57:38 -05:00
22 changed files with 577 additions and 159 deletions
+1 -1
View File
@@ -1,6 +1,6 @@
{
"name": "@google/gemini-cli-core",
"version": "0.1.8",
"version": "0.1.9",
"description": "Gemini CLI Server",
"repository": "google-gemini/gemini-cli",
"type": "module",
+2 -1
View File
@@ -12,11 +12,12 @@ import { CodeAssistServer, HttpOptions } from './server.js';
export async function createCodeAssistContentGenerator(
httpOptions: HttpOptions,
authType: AuthType,
sessionId?: string,
): Promise<ContentGenerator> {
if (authType === AuthType.LOGIN_WITH_GOOGLE) {
const authClient = await getOauthClient();
const projectId = await setupUser(authClient);
return new CodeAssistServer(authClient, projectId, httpOptions);
return new CodeAssistServer(authClient, projectId, httpOptions, sessionId);
}
throw new Error(`Unsupported authType: ${authType}`);
@@ -37,6 +37,7 @@ describe('converter', () => {
labels: undefined,
safetySettings: undefined,
generationConfig: undefined,
session_id: undefined,
},
});
});
@@ -59,6 +60,34 @@ describe('converter', () => {
labels: undefined,
safetySettings: undefined,
generationConfig: undefined,
session_id: undefined,
},
});
});
it('should convert a request with sessionId', () => {
const genaiReq: GenerateContentParameters = {
model: 'gemini-pro',
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
};
const codeAssistReq = toGenerateContentRequest(
genaiReq,
'my-project',
'session-123',
);
expect(codeAssistReq).toEqual({
model: 'gemini-pro',
project: 'my-project',
request: {
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
systemInstruction: undefined,
cachedContent: undefined,
tools: undefined,
toolConfig: undefined,
labels: undefined,
safetySettings: undefined,
generationConfig: undefined,
session_id: 'session-123',
},
});
});
+5 -1
View File
@@ -44,6 +44,7 @@ interface VertexGenerateContentRequest {
labels?: Record<string, string>;
safetySettings?: SafetySetting[];
generationConfig?: VertexGenerationConfig;
session_id?: string;
}
interface VertexGenerationConfig {
@@ -114,11 +115,12 @@ export function fromCountTokenResponse(
export function toGenerateContentRequest(
req: GenerateContentParameters,
project?: string,
sessionId?: string,
): CAGenerateContentRequest {
return {
model: req.model,
project,
request: toVertexGenerateContentRequest(req),
request: toVertexGenerateContentRequest(req, sessionId),
};
}
@@ -136,6 +138,7 @@ export function fromGenerateContentResponse(
function toVertexGenerateContentRequest(
req: GenerateContentParameters,
sessionId?: string,
): VertexGenerateContentRequest {
return {
contents: toContents(req.contents),
@@ -146,6 +149,7 @@ function toVertexGenerateContentRequest(
labels: req.config?.labels,
safetySettings: req.config?.safetySettings,
generationConfig: toVertexGenerationConfig(req.config),
session_id: sessionId,
};
}
+3 -2
View File
@@ -48,6 +48,7 @@ export class CodeAssistServer implements ContentGenerator {
readonly client: OAuth2Client,
readonly projectId?: string,
readonly httpOptions: HttpOptions = {},
readonly sessionId?: string,
) {}
async generateContentStream(
@@ -55,7 +56,7 @@ export class CodeAssistServer implements ContentGenerator {
): Promise<AsyncGenerator<GenerateContentResponse>> {
const resps = await this.requestStreamingPost<CaGenerateContentResponse>(
'streamGenerateContent',
toGenerateContentRequest(req, this.projectId),
toGenerateContentRequest(req, this.projectId, this.sessionId),
req.config?.abortSignal,
);
return (async function* (): AsyncGenerator<GenerateContentResponse> {
@@ -70,7 +71,7 @@ export class CodeAssistServer implements ContentGenerator {
): Promise<GenerateContentResponse> {
const resp = await this.requestPost<CaGenerateContentResponse>(
'generateContent',
toGenerateContentRequest(req, this.projectId),
toGenerateContentRequest(req, this.projectId, this.sessionId),
req.config?.abortSignal,
);
return fromGenerateContentResponse(resp);
+178
View File
@@ -402,5 +402,183 @@ describe('Gemini Client (client.ts)', () => {
// Assert
expect(finalResult).toBeInstanceOf(Turn);
});
it('should stop infinite loop after MAX_TURNS when nextSpeaker always returns model', async () => {
// Get the mocked checkNextSpeaker function and configure it to trigger infinite loop
const { checkNextSpeaker } = await import(
'../utils/nextSpeakerChecker.js'
);
const mockCheckNextSpeaker = vi.mocked(checkNextSpeaker);
mockCheckNextSpeaker.mockResolvedValue({
next_speaker: 'model',
reasoning: 'Test case - always continue',
});
// Mock Turn to have no pending tool calls (which would allow nextSpeaker check)
const mockStream = (async function* () {
yield { type: 'content', value: 'Continue...' };
})();
mockTurnRunFn.mockReturnValue(mockStream);
const mockChat: Partial<GeminiChat> = {
addHistory: vi.fn(),
getHistory: vi.fn().mockReturnValue([]),
};
client['chat'] = mockChat as GeminiChat;
const mockGenerator: Partial<ContentGenerator> = {
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
};
client['contentGenerator'] = mockGenerator as ContentGenerator;
// Use a signal that never gets aborted
const abortController = new AbortController();
const signal = abortController.signal;
// Act - Start the stream that should loop
const stream = client.sendMessageStream(
[{ text: 'Start conversation' }],
signal,
);
// Count how many stream events we get
let eventCount = 0;
let finalResult: Turn | undefined;
// Consume the stream and count iterations
while (true) {
const result = await stream.next();
if (result.done) {
finalResult = result.value;
break;
}
eventCount++;
// Safety check to prevent actual infinite loop in test
if (eventCount > 200) {
abortController.abort();
throw new Error(
'Test exceeded expected event limit - possible actual infinite loop',
);
}
}
// Assert
expect(finalResult).toBeInstanceOf(Turn);
// Debug: Check how many times checkNextSpeaker was called
const callCount = mockCheckNextSpeaker.mock.calls.length;
// If infinite loop protection is working, checkNextSpeaker should be called many times
// but stop at MAX_TURNS (100). Since each recursive call should trigger checkNextSpeaker,
// we expect it to be called multiple times before hitting the limit
expect(mockCheckNextSpeaker).toHaveBeenCalled();
// The test should demonstrate that the infinite loop protection works:
// - If checkNextSpeaker is called many times (close to MAX_TURNS), it shows the loop was happening
// - If it's only called once, the recursive behavior might not be triggered
if (callCount === 0) {
throw new Error(
'checkNextSpeaker was never called - the recursive condition was not met',
);
} else if (callCount === 1) {
// This might be expected behavior if the turn has pending tool calls or other conditions prevent recursion
console.log(
'checkNextSpeaker called only once - no infinite loop occurred',
);
} else {
console.log(
`checkNextSpeaker called ${callCount} times - infinite loop protection worked`,
);
// If called multiple times, we expect it to be stopped before MAX_TURNS
expect(callCount).toBeLessThanOrEqual(100); // Should not exceed MAX_TURNS
}
// The stream should produce events and eventually terminate
expect(eventCount).toBeGreaterThanOrEqual(1);
expect(eventCount).toBeLessThan(200); // Should not exceed our safety limit
});
it('should respect MAX_TURNS limit even when turns parameter is set to a large value', async () => {
// This test verifies that the infinite loop protection works even when
// someone tries to bypass it by calling with a very large turns value
// Get the mocked checkNextSpeaker function and configure it to trigger infinite loop
const { checkNextSpeaker } = await import(
'../utils/nextSpeakerChecker.js'
);
const mockCheckNextSpeaker = vi.mocked(checkNextSpeaker);
mockCheckNextSpeaker.mockResolvedValue({
next_speaker: 'model',
reasoning: 'Test case - always continue',
});
// Mock Turn to have no pending tool calls (which would allow nextSpeaker check)
const mockStream = (async function* () {
yield { type: 'content', value: 'Continue...' };
})();
mockTurnRunFn.mockReturnValue(mockStream);
const mockChat: Partial<GeminiChat> = {
addHistory: vi.fn(),
getHistory: vi.fn().mockReturnValue([]),
};
client['chat'] = mockChat as GeminiChat;
const mockGenerator: Partial<ContentGenerator> = {
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
};
client['contentGenerator'] = mockGenerator as ContentGenerator;
// Use a signal that never gets aborted
const abortController = new AbortController();
const signal = abortController.signal;
// Act - Start the stream with an extremely high turns value
// This simulates a case where the turns protection is bypassed
const stream = client.sendMessageStream(
[{ text: 'Start conversation' }],
signal,
Number.MAX_SAFE_INTEGER, // Bypass the MAX_TURNS protection
);
// Count how many stream events we get
let eventCount = 0;
const maxTestIterations = 1000; // Higher limit to show the loop continues
// Consume the stream and count iterations
try {
while (true) {
const result = await stream.next();
if (result.done) {
break;
}
eventCount++;
// This test should hit this limit, demonstrating the infinite loop
if (eventCount > maxTestIterations) {
abortController.abort();
// This is the expected behavior - we hit the infinite loop
break;
}
}
} catch (error) {
// If the test framework times out, that also demonstrates the infinite loop
console.error('Test timed out or errored:', error);
}
// Assert that the fix works - the loop should stop at MAX_TURNS
const callCount = mockCheckNextSpeaker.mock.calls.length;
// With the fix: even when turns is set to a very high value,
// the loop should stop at MAX_TURNS (100)
expect(callCount).toBeLessThanOrEqual(100); // Should not exceed MAX_TURNS
expect(eventCount).toBeLessThanOrEqual(200); // Should have reasonable number of events
console.log(
`Infinite loop protection working: checkNextSpeaker called ${callCount} times, ` +
`${eventCount} events generated (properly bounded by MAX_TURNS)`,
);
});
});
});
+5 -2
View File
@@ -68,6 +68,7 @@ export class GeminiClient {
async initialize(contentGeneratorConfig: ContentGeneratorConfig) {
this.contentGenerator = await createContentGenerator(
contentGeneratorConfig,
this.config.getSessionId(),
);
this.chat = await this.startChat();
}
@@ -219,7 +220,9 @@ export class GeminiClient {
signal: AbortSignal,
turns: number = this.MAX_TURNS,
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
if (!turns) {
// Ensure turns never exceeds MAX_TURNS to prevent infinite loops
const boundedTurns = Math.min(turns, this.MAX_TURNS);
if (!boundedTurns) {
return new Turn(this.getChat());
}
@@ -242,7 +245,7 @@ export class GeminiClient {
const nextRequest = [{ text: 'Please continue.' }];
// This recursive call's events will be yielded out, but the final
// turn object will be from the top-level call.
yield* this.sendMessageStream(nextRequest, signal, turns - 1);
yield* this.sendMessageStream(nextRequest, signal, boundedTurns - 1);
}
}
return turn;
+6 -1
View File
@@ -101,6 +101,7 @@ export async function createContentGeneratorConfig(
export async function createContentGenerator(
config: ContentGeneratorConfig,
sessionId?: string,
): Promise<ContentGenerator> {
const version = process.env.CLI_VERSION || process.version;
const httpOptions = {
@@ -109,7 +110,11 @@ export async function createContentGenerator(
},
};
if (config.authType === AuthType.LOGIN_WITH_GOOGLE) {
return createCodeAssistContentGenerator(httpOptions, config.authType);
return createCodeAssistContentGenerator(
httpOptions,
config.authType,
sessionId,
);
}
if (
@@ -196,6 +196,11 @@ describe('fileUtils', () => {
vi.restoreAllMocks(); // Restore spies on actualNodeFs
});
it('should detect typescript type by extension (ts)', () => {
expect(detectFileType('file.ts')).toBe('text');
expect(detectFileType('file.test.ts')).toBe('text');
});
it('should detect image type by extension (png)', () => {
mockMimeLookup.mockReturnValueOnce('image/png');
expect(detectFileType('file.png')).toBe('image');
+7 -1
View File
@@ -100,8 +100,14 @@ export function detectFileType(
filePath: string,
): 'text' | 'image' | 'pdf' | 'audio' | 'video' | 'binary' {
const ext = path.extname(filePath).toLowerCase();
const lookedUpMimeType = mime.lookup(filePath); // Returns false if not found, or the mime type string
// The mimetype for "ts" is MPEG transport stream (a video format) but we want
// to assume these are typescript files instead.
if (ext === '.ts') {
return 'text';
}
const lookedUpMimeType = mime.lookup(filePath); // Returns false if not found, or the mime type string
if (lookedUpMimeType) {
if (lookedUpMimeType.startsWith('image/')) {
return 'image';