mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-20 10:10:56 -07:00
Merge remote-tracking branch 'origin/main' into mk-packing
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@google/gemini-cli",
|
||||
"version": "0.1.8",
|
||||
"version": "0.1.9",
|
||||
"description": "Gemini CLI",
|
||||
"repository": "google-gemini/gemini-cli",
|
||||
"type": "module",
|
||||
|
||||
@@ -350,3 +350,131 @@ describe('mergeMcpServers', () => {
|
||||
expect(settings).toEqual(originalSettings);
|
||||
});
|
||||
});
|
||||
|
||||
describe('mergeExcludeTools', () => {
|
||||
it('should merge excludeTools from settings and extensions', async () => {
|
||||
const settings: Settings = { excludeTools: ['tool1', 'tool2'] };
|
||||
const extensions: Extension[] = [
|
||||
{
|
||||
config: {
|
||||
name: 'ext1',
|
||||
version: '1.0.0',
|
||||
excludeTools: ['tool3', 'tool4'],
|
||||
},
|
||||
contextFiles: [],
|
||||
},
|
||||
{
|
||||
config: {
|
||||
name: 'ext2',
|
||||
version: '1.0.0',
|
||||
excludeTools: ['tool5'],
|
||||
},
|
||||
contextFiles: [],
|
||||
},
|
||||
];
|
||||
const config = await loadCliConfig(settings, extensions, 'test-session');
|
||||
expect(config.getExcludeTools()).toEqual(
|
||||
expect.arrayContaining(['tool1', 'tool2', 'tool3', 'tool4', 'tool5']),
|
||||
);
|
||||
expect(config.getExcludeTools()).toHaveLength(5);
|
||||
});
|
||||
|
||||
it('should handle overlapping excludeTools between settings and extensions', async () => {
|
||||
const settings: Settings = { excludeTools: ['tool1', 'tool2'] };
|
||||
const extensions: Extension[] = [
|
||||
{
|
||||
config: {
|
||||
name: 'ext1',
|
||||
version: '1.0.0',
|
||||
excludeTools: ['tool2', 'tool3'],
|
||||
},
|
||||
contextFiles: [],
|
||||
},
|
||||
];
|
||||
const config = await loadCliConfig(settings, extensions, 'test-session');
|
||||
expect(config.getExcludeTools()).toEqual(
|
||||
expect.arrayContaining(['tool1', 'tool2', 'tool3']),
|
||||
);
|
||||
expect(config.getExcludeTools()).toHaveLength(3);
|
||||
});
|
||||
|
||||
it('should handle overlapping excludeTools between extensions', async () => {
|
||||
const settings: Settings = { excludeTools: ['tool1'] };
|
||||
const extensions: Extension[] = [
|
||||
{
|
||||
config: {
|
||||
name: 'ext1',
|
||||
version: '1.0.0',
|
||||
excludeTools: ['tool2', 'tool3'],
|
||||
},
|
||||
contextFiles: [],
|
||||
},
|
||||
{
|
||||
config: {
|
||||
name: 'ext2',
|
||||
version: '1.0.0',
|
||||
excludeTools: ['tool3', 'tool4'],
|
||||
},
|
||||
contextFiles: [],
|
||||
},
|
||||
];
|
||||
const config = await loadCliConfig(settings, extensions, 'test-session');
|
||||
expect(config.getExcludeTools()).toEqual(
|
||||
expect.arrayContaining(['tool1', 'tool2', 'tool3', 'tool4']),
|
||||
);
|
||||
expect(config.getExcludeTools()).toHaveLength(4);
|
||||
});
|
||||
|
||||
it('should return an empty array when no excludeTools are specified', async () => {
|
||||
const settings: Settings = {};
|
||||
const extensions: Extension[] = [];
|
||||
const config = await loadCliConfig(settings, extensions, 'test-session');
|
||||
expect(config.getExcludeTools()).toEqual([]);
|
||||
});
|
||||
|
||||
it('should handle settings with excludeTools but no extensions', async () => {
|
||||
const settings: Settings = { excludeTools: ['tool1', 'tool2'] };
|
||||
const extensions: Extension[] = [];
|
||||
const config = await loadCliConfig(settings, extensions, 'test-session');
|
||||
expect(config.getExcludeTools()).toEqual(
|
||||
expect.arrayContaining(['tool1', 'tool2']),
|
||||
);
|
||||
expect(config.getExcludeTools()).toHaveLength(2);
|
||||
});
|
||||
|
||||
it('should handle extensions with excludeTools but no settings', async () => {
|
||||
const settings: Settings = {};
|
||||
const extensions: Extension[] = [
|
||||
{
|
||||
config: {
|
||||
name: 'ext1',
|
||||
version: '1.0.0',
|
||||
excludeTools: ['tool1', 'tool2'],
|
||||
},
|
||||
contextFiles: [],
|
||||
},
|
||||
];
|
||||
const config = await loadCliConfig(settings, extensions, 'test-session');
|
||||
expect(config.getExcludeTools()).toEqual(
|
||||
expect.arrayContaining(['tool1', 'tool2']),
|
||||
);
|
||||
expect(config.getExcludeTools()).toHaveLength(2);
|
||||
});
|
||||
|
||||
it('should not modify the original settings object', async () => {
|
||||
const settings: Settings = { excludeTools: ['tool1'] };
|
||||
const extensions: Extension[] = [
|
||||
{
|
||||
config: {
|
||||
name: 'ext1',
|
||||
version: '1.0.0',
|
||||
excludeTools: ['tool2'],
|
||||
},
|
||||
contextFiles: [],
|
||||
},
|
||||
];
|
||||
const originalSettings = JSON.parse(JSON.stringify(settings));
|
||||
await loadCliConfig(settings, extensions, 'test-session');
|
||||
expect(settings).toEqual(originalSettings);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -194,6 +194,7 @@ export async function loadCliConfig(
|
||||
);
|
||||
|
||||
const mcpServers = mergeMcpServers(settings, extensions);
|
||||
const excludeTools = mergeExcludeTools(settings, extensions);
|
||||
|
||||
const sandboxConfig = await loadSandboxConfig(settings, argv);
|
||||
|
||||
@@ -206,7 +207,7 @@ export async function loadCliConfig(
|
||||
question: argv.prompt || '',
|
||||
fullContext: argv.all_files || false,
|
||||
coreTools: settings.coreTools || undefined,
|
||||
excludeTools: settings.excludeTools || undefined,
|
||||
excludeTools,
|
||||
toolDiscoveryCommand: settings.toolDiscoveryCommand,
|
||||
toolCallCommand: settings.toolCallCommand,
|
||||
mcpServerCommand: settings.mcpServerCommand,
|
||||
@@ -265,6 +266,20 @@ function mergeMcpServers(settings: Settings, extensions: Extension[]) {
|
||||
}
|
||||
return mcpServers;
|
||||
}
|
||||
|
||||
function mergeExcludeTools(
|
||||
settings: Settings,
|
||||
extensions: Extension[],
|
||||
): string[] {
|
||||
const allExcludeTools = new Set(settings.excludeTools || []);
|
||||
for (const extension of extensions) {
|
||||
for (const tool of extension.config.excludeTools || []) {
|
||||
allExcludeTools.add(tool);
|
||||
}
|
||||
}
|
||||
return [...allExcludeTools];
|
||||
}
|
||||
|
||||
function findEnvFile(startDir: string): string | null {
|
||||
let currentDir = path.resolve(startDir);
|
||||
while (true) {
|
||||
|
||||
@@ -22,6 +22,7 @@ export interface ExtensionConfig {
|
||||
version: string;
|
||||
mcpServers?: Record<string, MCPServerConfig>;
|
||||
contextFileName?: string | string[];
|
||||
excludeTools?: string[];
|
||||
}
|
||||
|
||||
export function loadExtensions(workspaceDir: string): Extension[] {
|
||||
|
||||
@@ -113,7 +113,7 @@ export const InputPrompt: React.FC<InputPromptProps> = ({
|
||||
return;
|
||||
}
|
||||
const query = buffer.text;
|
||||
const selectedSuggestion = completionSuggestions[indexToUse];
|
||||
const suggestion = completionSuggestions[indexToUse].value;
|
||||
|
||||
if (query.trimStart().startsWith('/')) {
|
||||
const parts = query.trimStart().substring(1).split(' ');
|
||||
@@ -122,11 +122,16 @@ export const InputPrompt: React.FC<InputPromptProps> = ({
|
||||
const base = query.substring(0, slashIndex + 1);
|
||||
|
||||
const command = slashCommands.find((cmd) => cmd.name === commandName);
|
||||
if (command && command.completion) {
|
||||
const newValue = `${base}${commandName} ${selectedSuggestion.value}`;
|
||||
buffer.setText(newValue);
|
||||
// Make sure completion isn't the original command when command.completigion hasn't happened yet.
|
||||
if (command && command.completion && suggestion !== commandName) {
|
||||
const newValue = `${base}${commandName} ${suggestion}`;
|
||||
if (newValue === query) {
|
||||
handleSubmitAndClear(newValue);
|
||||
} else {
|
||||
buffer.setText(newValue);
|
||||
}
|
||||
} else {
|
||||
const newValue = base + selectedSuggestion.value;
|
||||
const newValue = base + suggestion;
|
||||
buffer.setText(newValue);
|
||||
handleSubmitAndClear(newValue);
|
||||
}
|
||||
@@ -142,7 +147,7 @@ export const InputPrompt: React.FC<InputPromptProps> = ({
|
||||
buffer.replaceRangeByOffset(
|
||||
autoCompleteStartIndex,
|
||||
buffer.text.length,
|
||||
selectedSuggestion.value,
|
||||
suggestion,
|
||||
);
|
||||
}
|
||||
resetCompletionState();
|
||||
|
||||
@@ -420,6 +420,7 @@ export function useTextBuffer({
|
||||
const [undoStack, setUndoStack] = useState<UndoHistoryEntry[]>([]);
|
||||
const [redoStack, setRedoStack] = useState<UndoHistoryEntry[]>([]);
|
||||
const historyLimit = 100;
|
||||
const [opQueue, setOpQueue] = useState<UpdateOperation[]>([]);
|
||||
|
||||
const [clipboard, setClipboard] = useState<string | null>(null);
|
||||
const [selectionAnchor, setSelectionAnchor] = useState<
|
||||
@@ -526,148 +527,110 @@ export function useTextBuffer({
|
||||
return _restoreState(state);
|
||||
}, [redoStack, lines, cursorRow, cursorCol, _restoreState]);
|
||||
|
||||
const insertStr = useCallback(
|
||||
(str: string): boolean => {
|
||||
dbg('insertStr', { str, beforeCursor: [cursorRow, cursorCol] });
|
||||
if (str === '') return false;
|
||||
const applyOperations = useCallback((ops: UpdateOperation[]) => {
|
||||
if (ops.length === 0) return;
|
||||
setOpQueue((prev) => [...prev, ...ops]);
|
||||
}, []);
|
||||
|
||||
pushUndo();
|
||||
let normalised = str.replace(/\r\n/g, '\n').replace(/\r/g, '\n');
|
||||
normalised = stripUnsafeCharacters(normalised);
|
||||
useEffect(() => {
|
||||
if (opQueue.length === 0) return;
|
||||
|
||||
const parts = normalised.split('\n');
|
||||
|
||||
const newLines = [...lines];
|
||||
const lineContent = currentLine(cursorRow);
|
||||
const before = cpSlice(lineContent, 0, cursorCol);
|
||||
const after = cpSlice(lineContent, cursorCol);
|
||||
newLines[cursorRow] = before + parts[0];
|
||||
|
||||
if (parts.length > 1) {
|
||||
// Adjusted condition for inserting multiple lines
|
||||
const remainingParts = parts.slice(1);
|
||||
const lastPartOriginal = remainingParts.pop() ?? '';
|
||||
newLines.splice(cursorRow + 1, 0, ...remainingParts);
|
||||
newLines.splice(
|
||||
cursorRow + parts.length - 1,
|
||||
0,
|
||||
lastPartOriginal + after,
|
||||
);
|
||||
setCursorRow(cursorRow + parts.length - 1);
|
||||
setCursorCol(cpLen(lastPartOriginal));
|
||||
} else {
|
||||
setCursorCol(cpLen(before) + cpLen(parts[0]));
|
||||
}
|
||||
setLines(newLines);
|
||||
setPreferredCol(null);
|
||||
return true;
|
||||
},
|
||||
[pushUndo, cursorRow, cursorCol, lines, currentLine, setPreferredCol],
|
||||
);
|
||||
|
||||
const applyOperations = useCallback(
|
||||
(ops: UpdateOperation[]) => {
|
||||
if (ops.length === 0) return;
|
||||
|
||||
const expandedOps: UpdateOperation[] = [];
|
||||
for (const op of ops) {
|
||||
if (op.type === 'insert') {
|
||||
let currentText = '';
|
||||
for (const char of toCodePoints(op.payload)) {
|
||||
if (char.codePointAt(0) === 127) {
|
||||
// \x7f
|
||||
if (currentText.length > 0) {
|
||||
expandedOps.push({ type: 'insert', payload: currentText });
|
||||
currentText = '';
|
||||
}
|
||||
expandedOps.push({ type: 'backspace' });
|
||||
} else {
|
||||
currentText += char;
|
||||
const expandedOps: UpdateOperation[] = [];
|
||||
for (const op of opQueue) {
|
||||
if (op.type === 'insert') {
|
||||
let currentText = '';
|
||||
for (const char of toCodePoints(op.payload)) {
|
||||
if (char.codePointAt(0) === 127) {
|
||||
// \x7f
|
||||
if (currentText.length > 0) {
|
||||
expandedOps.push({ type: 'insert', payload: currentText });
|
||||
currentText = '';
|
||||
}
|
||||
}
|
||||
if (currentText.length > 0) {
|
||||
expandedOps.push({ type: 'insert', payload: currentText });
|
||||
}
|
||||
} else {
|
||||
expandedOps.push(op);
|
||||
}
|
||||
}
|
||||
|
||||
if (expandedOps.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
pushUndo(); // Snapshot before applying batch of updates
|
||||
|
||||
const newLines = [...lines];
|
||||
let newCursorRow = cursorRow;
|
||||
let newCursorCol = cursorCol;
|
||||
|
||||
const currentLine = (r: number) => newLines[r] ?? '';
|
||||
|
||||
for (const op of expandedOps) {
|
||||
if (op.type === 'insert') {
|
||||
const str = stripUnsafeCharacters(
|
||||
op.payload.replace(/\r\n/g, '\n').replace(/\r/g, '\n'),
|
||||
);
|
||||
const parts = str.split('\n');
|
||||
const lineContent = currentLine(newCursorRow);
|
||||
const before = cpSlice(lineContent, 0, newCursorCol);
|
||||
const after = cpSlice(lineContent, newCursorCol);
|
||||
|
||||
if (parts.length > 1) {
|
||||
newLines[newCursorRow] = before + parts[0];
|
||||
const remainingParts = parts.slice(1);
|
||||
const lastPartOriginal = remainingParts.pop() ?? '';
|
||||
newLines.splice(newCursorRow + 1, 0, ...remainingParts);
|
||||
newLines.splice(
|
||||
newCursorRow + parts.length - 1,
|
||||
0,
|
||||
lastPartOriginal + after,
|
||||
);
|
||||
newCursorRow = newCursorRow + parts.length - 1;
|
||||
newCursorCol = cpLen(lastPartOriginal);
|
||||
expandedOps.push({ type: 'backspace' });
|
||||
} else {
|
||||
newLines[newCursorRow] = before + parts[0] + after;
|
||||
|
||||
newCursorCol = cpLen(before) + cpLen(parts[0]);
|
||||
}
|
||||
} else if (op.type === 'backspace') {
|
||||
if (newCursorCol === 0 && newCursorRow === 0) continue;
|
||||
|
||||
if (newCursorCol > 0) {
|
||||
const lineContent = currentLine(newCursorRow);
|
||||
newLines[newCursorRow] =
|
||||
cpSlice(lineContent, 0, newCursorCol - 1) +
|
||||
cpSlice(lineContent, newCursorCol);
|
||||
newCursorCol--;
|
||||
} else if (newCursorRow > 0) {
|
||||
const prevLineContent = currentLine(newCursorRow - 1);
|
||||
const currentLineContentVal = currentLine(newCursorRow);
|
||||
const newCol = cpLen(prevLineContent);
|
||||
newLines[newCursorRow - 1] =
|
||||
prevLineContent + currentLineContentVal;
|
||||
newLines.splice(newCursorRow, 1);
|
||||
newCursorRow--;
|
||||
newCursorCol = newCol;
|
||||
currentText += char;
|
||||
}
|
||||
}
|
||||
if (currentText.length > 0) {
|
||||
expandedOps.push({ type: 'insert', payload: currentText });
|
||||
}
|
||||
} else {
|
||||
expandedOps.push(op);
|
||||
}
|
||||
}
|
||||
|
||||
setLines(newLines);
|
||||
setCursorRow(newCursorRow);
|
||||
setCursorCol(newCursorCol);
|
||||
setPreferredCol(null);
|
||||
},
|
||||
[lines, cursorRow, cursorCol, pushUndo, setPreferredCol],
|
||||
);
|
||||
if (expandedOps.length === 0) {
|
||||
setOpQueue([]); // Clear queue even if ops were no-ops
|
||||
return;
|
||||
}
|
||||
|
||||
pushUndo(); // Snapshot before applying batch of updates
|
||||
|
||||
const newLines = [...lines];
|
||||
let newCursorRow = cursorRow;
|
||||
let newCursorCol = cursorCol;
|
||||
|
||||
const currentLine = (r: number) => newLines[r] ?? '';
|
||||
|
||||
for (const op of expandedOps) {
|
||||
if (op.type === 'insert') {
|
||||
const str = stripUnsafeCharacters(
|
||||
op.payload.replace(/\r\n/g, '\n').replace(/\r/g, '\n'),
|
||||
);
|
||||
const parts = str.split('\n');
|
||||
const lineContent = currentLine(newCursorRow);
|
||||
const before = cpSlice(lineContent, 0, newCursorCol);
|
||||
const after = cpSlice(lineContent, newCursorCol);
|
||||
|
||||
if (parts.length > 1) {
|
||||
newLines[newCursorRow] = before + parts[0];
|
||||
const remainingParts = parts.slice(1);
|
||||
const lastPartOriginal = remainingParts.pop() ?? '';
|
||||
newLines.splice(newCursorRow + 1, 0, ...remainingParts);
|
||||
newLines.splice(
|
||||
newCursorRow + parts.length - 1,
|
||||
0,
|
||||
lastPartOriginal + after,
|
||||
);
|
||||
newCursorRow = newCursorRow + parts.length - 1;
|
||||
newCursorCol = cpLen(lastPartOriginal);
|
||||
} else {
|
||||
newLines[newCursorRow] = before + parts[0] + after;
|
||||
|
||||
newCursorCol = cpLen(before) + cpLen(parts[0]);
|
||||
}
|
||||
} else if (op.type === 'backspace') {
|
||||
if (newCursorCol === 0 && newCursorRow === 0) continue;
|
||||
|
||||
if (newCursorCol > 0) {
|
||||
const lineContent = currentLine(newCursorRow);
|
||||
newLines[newCursorRow] =
|
||||
cpSlice(lineContent, 0, newCursorCol - 1) +
|
||||
cpSlice(lineContent, newCursorCol);
|
||||
newCursorCol--;
|
||||
} else if (newCursorRow > 0) {
|
||||
const prevLineContent = currentLine(newCursorRow - 1);
|
||||
const currentLineContentVal = currentLine(newCursorRow);
|
||||
const newCol = cpLen(prevLineContent);
|
||||
newLines[newCursorRow - 1] = prevLineContent + currentLineContentVal;
|
||||
newLines.splice(newCursorRow, 1);
|
||||
newCursorRow--;
|
||||
newCursorCol = newCol;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
setLines(newLines);
|
||||
setCursorRow(newCursorRow);
|
||||
setCursorCol(newCursorCol);
|
||||
setPreferredCol(null);
|
||||
|
||||
// Clear the queue after processing
|
||||
setOpQueue((prev) => prev.slice(opQueue.length));
|
||||
}, [opQueue, lines, cursorRow, cursorCol, pushUndo, setPreferredCol]);
|
||||
|
||||
const insert = useCallback(
|
||||
(ch: string): void => {
|
||||
if (/[\n\r]/.test(ch)) {
|
||||
insertStr(ch);
|
||||
return;
|
||||
}
|
||||
dbg('insert', { ch, beforeCursor: [cursorRow, cursorCol] });
|
||||
|
||||
ch = stripUnsafeCharacters(ch);
|
||||
@@ -694,7 +657,7 @@ export function useTextBuffer({
|
||||
}
|
||||
applyOperations([{ type: 'insert', payload: ch }]);
|
||||
},
|
||||
[applyOperations, cursorRow, cursorCol, isValidPath, insertStr],
|
||||
[applyOperations, cursorRow, cursorCol, isValidPath],
|
||||
);
|
||||
|
||||
const newline = useCallback((): void => {
|
||||
@@ -1397,8 +1360,9 @@ export function useTextBuffer({
|
||||
}, [selectionAnchor, cursorRow, cursorCol, currentLine, setClipboard]),
|
||||
paste: useCallback(() => {
|
||||
if (clipboard === null) return false;
|
||||
return insertStr(clipboard);
|
||||
}, [clipboard, insertStr]),
|
||||
applyOperations([{ type: 'insert', payload: clipboard }]);
|
||||
return true;
|
||||
}, [clipboard, applyOperations]),
|
||||
startSelection: useCallback(
|
||||
() => setSelectionAnchor([cursorRow, cursorCol]),
|
||||
[cursorRow, cursorCol, setSelectionAnchor],
|
||||
|
||||
@@ -135,7 +135,8 @@ export function useCompletion(
|
||||
(cmd) => cmd.name === commandName || cmd.altName === commandName,
|
||||
);
|
||||
|
||||
if (command && command.completion) {
|
||||
// Continue to show command help until user types past command name.
|
||||
if (command && command.completion && parts.length > 1) {
|
||||
const fetchAndSetSuggestions = async () => {
|
||||
setIsLoadingSuggestions(true);
|
||||
if (command.completion) {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@google/gemini-cli-core",
|
||||
"version": "0.1.8",
|
||||
"version": "0.1.9",
|
||||
"description": "Gemini CLI Server",
|
||||
"repository": "google-gemini/gemini-cli",
|
||||
"type": "module",
|
||||
|
||||
@@ -12,11 +12,12 @@ import { CodeAssistServer, HttpOptions } from './server.js';
|
||||
export async function createCodeAssistContentGenerator(
|
||||
httpOptions: HttpOptions,
|
||||
authType: AuthType,
|
||||
sessionId?: string,
|
||||
): Promise<ContentGenerator> {
|
||||
if (authType === AuthType.LOGIN_WITH_GOOGLE) {
|
||||
const authClient = await getOauthClient();
|
||||
const projectId = await setupUser(authClient);
|
||||
return new CodeAssistServer(authClient, projectId, httpOptions);
|
||||
return new CodeAssistServer(authClient, projectId, httpOptions, sessionId);
|
||||
}
|
||||
|
||||
throw new Error(`Unsupported authType: ${authType}`);
|
||||
|
||||
@@ -37,6 +37,7 @@ describe('converter', () => {
|
||||
labels: undefined,
|
||||
safetySettings: undefined,
|
||||
generationConfig: undefined,
|
||||
session_id: undefined,
|
||||
},
|
||||
});
|
||||
});
|
||||
@@ -59,6 +60,34 @@ describe('converter', () => {
|
||||
labels: undefined,
|
||||
safetySettings: undefined,
|
||||
generationConfig: undefined,
|
||||
session_id: undefined,
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should convert a request with sessionId', () => {
|
||||
const genaiReq: GenerateContentParameters = {
|
||||
model: 'gemini-pro',
|
||||
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
||||
};
|
||||
const codeAssistReq = toGenerateContentRequest(
|
||||
genaiReq,
|
||||
'my-project',
|
||||
'session-123',
|
||||
);
|
||||
expect(codeAssistReq).toEqual({
|
||||
model: 'gemini-pro',
|
||||
project: 'my-project',
|
||||
request: {
|
||||
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
||||
systemInstruction: undefined,
|
||||
cachedContent: undefined,
|
||||
tools: undefined,
|
||||
toolConfig: undefined,
|
||||
labels: undefined,
|
||||
safetySettings: undefined,
|
||||
generationConfig: undefined,
|
||||
session_id: 'session-123',
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
@@ -44,6 +44,7 @@ interface VertexGenerateContentRequest {
|
||||
labels?: Record<string, string>;
|
||||
safetySettings?: SafetySetting[];
|
||||
generationConfig?: VertexGenerationConfig;
|
||||
session_id?: string;
|
||||
}
|
||||
|
||||
interface VertexGenerationConfig {
|
||||
@@ -114,11 +115,12 @@ export function fromCountTokenResponse(
|
||||
export function toGenerateContentRequest(
|
||||
req: GenerateContentParameters,
|
||||
project?: string,
|
||||
sessionId?: string,
|
||||
): CAGenerateContentRequest {
|
||||
return {
|
||||
model: req.model,
|
||||
project,
|
||||
request: toVertexGenerateContentRequest(req),
|
||||
request: toVertexGenerateContentRequest(req, sessionId),
|
||||
};
|
||||
}
|
||||
|
||||
@@ -136,6 +138,7 @@ export function fromGenerateContentResponse(
|
||||
|
||||
function toVertexGenerateContentRequest(
|
||||
req: GenerateContentParameters,
|
||||
sessionId?: string,
|
||||
): VertexGenerateContentRequest {
|
||||
return {
|
||||
contents: toContents(req.contents),
|
||||
@@ -146,6 +149,7 @@ function toVertexGenerateContentRequest(
|
||||
labels: req.config?.labels,
|
||||
safetySettings: req.config?.safetySettings,
|
||||
generationConfig: toVertexGenerationConfig(req.config),
|
||||
session_id: sessionId,
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -48,6 +48,7 @@ export class CodeAssistServer implements ContentGenerator {
|
||||
readonly client: OAuth2Client,
|
||||
readonly projectId?: string,
|
||||
readonly httpOptions: HttpOptions = {},
|
||||
readonly sessionId?: string,
|
||||
) {}
|
||||
|
||||
async generateContentStream(
|
||||
@@ -55,7 +56,7 @@ export class CodeAssistServer implements ContentGenerator {
|
||||
): Promise<AsyncGenerator<GenerateContentResponse>> {
|
||||
const resps = await this.requestStreamingPost<CaGenerateContentResponse>(
|
||||
'streamGenerateContent',
|
||||
toGenerateContentRequest(req, this.projectId),
|
||||
toGenerateContentRequest(req, this.projectId, this.sessionId),
|
||||
req.config?.abortSignal,
|
||||
);
|
||||
return (async function* (): AsyncGenerator<GenerateContentResponse> {
|
||||
@@ -70,7 +71,7 @@ export class CodeAssistServer implements ContentGenerator {
|
||||
): Promise<GenerateContentResponse> {
|
||||
const resp = await this.requestPost<CaGenerateContentResponse>(
|
||||
'generateContent',
|
||||
toGenerateContentRequest(req, this.projectId),
|
||||
toGenerateContentRequest(req, this.projectId, this.sessionId),
|
||||
req.config?.abortSignal,
|
||||
);
|
||||
return fromGenerateContentResponse(resp);
|
||||
|
||||
@@ -402,5 +402,183 @@ describe('Gemini Client (client.ts)', () => {
|
||||
// Assert
|
||||
expect(finalResult).toBeInstanceOf(Turn);
|
||||
});
|
||||
|
||||
it('should stop infinite loop after MAX_TURNS when nextSpeaker always returns model', async () => {
|
||||
// Get the mocked checkNextSpeaker function and configure it to trigger infinite loop
|
||||
const { checkNextSpeaker } = await import(
|
||||
'../utils/nextSpeakerChecker.js'
|
||||
);
|
||||
const mockCheckNextSpeaker = vi.mocked(checkNextSpeaker);
|
||||
mockCheckNextSpeaker.mockResolvedValue({
|
||||
next_speaker: 'model',
|
||||
reasoning: 'Test case - always continue',
|
||||
});
|
||||
|
||||
// Mock Turn to have no pending tool calls (which would allow nextSpeaker check)
|
||||
const mockStream = (async function* () {
|
||||
yield { type: 'content', value: 'Continue...' };
|
||||
})();
|
||||
mockTurnRunFn.mockReturnValue(mockStream);
|
||||
|
||||
const mockChat: Partial<GeminiChat> = {
|
||||
addHistory: vi.fn(),
|
||||
getHistory: vi.fn().mockReturnValue([]),
|
||||
};
|
||||
client['chat'] = mockChat as GeminiChat;
|
||||
|
||||
const mockGenerator: Partial<ContentGenerator> = {
|
||||
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
|
||||
};
|
||||
client['contentGenerator'] = mockGenerator as ContentGenerator;
|
||||
|
||||
// Use a signal that never gets aborted
|
||||
const abortController = new AbortController();
|
||||
const signal = abortController.signal;
|
||||
|
||||
// Act - Start the stream that should loop
|
||||
const stream = client.sendMessageStream(
|
||||
[{ text: 'Start conversation' }],
|
||||
signal,
|
||||
);
|
||||
|
||||
// Count how many stream events we get
|
||||
let eventCount = 0;
|
||||
let finalResult: Turn | undefined;
|
||||
|
||||
// Consume the stream and count iterations
|
||||
while (true) {
|
||||
const result = await stream.next();
|
||||
if (result.done) {
|
||||
finalResult = result.value;
|
||||
break;
|
||||
}
|
||||
eventCount++;
|
||||
|
||||
// Safety check to prevent actual infinite loop in test
|
||||
if (eventCount > 200) {
|
||||
abortController.abort();
|
||||
throw new Error(
|
||||
'Test exceeded expected event limit - possible actual infinite loop',
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Assert
|
||||
expect(finalResult).toBeInstanceOf(Turn);
|
||||
|
||||
// Debug: Check how many times checkNextSpeaker was called
|
||||
const callCount = mockCheckNextSpeaker.mock.calls.length;
|
||||
|
||||
// If infinite loop protection is working, checkNextSpeaker should be called many times
|
||||
// but stop at MAX_TURNS (100). Since each recursive call should trigger checkNextSpeaker,
|
||||
// we expect it to be called multiple times before hitting the limit
|
||||
expect(mockCheckNextSpeaker).toHaveBeenCalled();
|
||||
|
||||
// The test should demonstrate that the infinite loop protection works:
|
||||
// - If checkNextSpeaker is called many times (close to MAX_TURNS), it shows the loop was happening
|
||||
// - If it's only called once, the recursive behavior might not be triggered
|
||||
if (callCount === 0) {
|
||||
throw new Error(
|
||||
'checkNextSpeaker was never called - the recursive condition was not met',
|
||||
);
|
||||
} else if (callCount === 1) {
|
||||
// This might be expected behavior if the turn has pending tool calls or other conditions prevent recursion
|
||||
console.log(
|
||||
'checkNextSpeaker called only once - no infinite loop occurred',
|
||||
);
|
||||
} else {
|
||||
console.log(
|
||||
`checkNextSpeaker called ${callCount} times - infinite loop protection worked`,
|
||||
);
|
||||
// If called multiple times, we expect it to be stopped before MAX_TURNS
|
||||
expect(callCount).toBeLessThanOrEqual(100); // Should not exceed MAX_TURNS
|
||||
}
|
||||
|
||||
// The stream should produce events and eventually terminate
|
||||
expect(eventCount).toBeGreaterThanOrEqual(1);
|
||||
expect(eventCount).toBeLessThan(200); // Should not exceed our safety limit
|
||||
});
|
||||
|
||||
it('should respect MAX_TURNS limit even when turns parameter is set to a large value', async () => {
|
||||
// This test verifies that the infinite loop protection works even when
|
||||
// someone tries to bypass it by calling with a very large turns value
|
||||
|
||||
// Get the mocked checkNextSpeaker function and configure it to trigger infinite loop
|
||||
const { checkNextSpeaker } = await import(
|
||||
'../utils/nextSpeakerChecker.js'
|
||||
);
|
||||
const mockCheckNextSpeaker = vi.mocked(checkNextSpeaker);
|
||||
mockCheckNextSpeaker.mockResolvedValue({
|
||||
next_speaker: 'model',
|
||||
reasoning: 'Test case - always continue',
|
||||
});
|
||||
|
||||
// Mock Turn to have no pending tool calls (which would allow nextSpeaker check)
|
||||
const mockStream = (async function* () {
|
||||
yield { type: 'content', value: 'Continue...' };
|
||||
})();
|
||||
mockTurnRunFn.mockReturnValue(mockStream);
|
||||
|
||||
const mockChat: Partial<GeminiChat> = {
|
||||
addHistory: vi.fn(),
|
||||
getHistory: vi.fn().mockReturnValue([]),
|
||||
};
|
||||
client['chat'] = mockChat as GeminiChat;
|
||||
|
||||
const mockGenerator: Partial<ContentGenerator> = {
|
||||
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
|
||||
};
|
||||
client['contentGenerator'] = mockGenerator as ContentGenerator;
|
||||
|
||||
// Use a signal that never gets aborted
|
||||
const abortController = new AbortController();
|
||||
const signal = abortController.signal;
|
||||
|
||||
// Act - Start the stream with an extremely high turns value
|
||||
// This simulates a case where the turns protection is bypassed
|
||||
const stream = client.sendMessageStream(
|
||||
[{ text: 'Start conversation' }],
|
||||
signal,
|
||||
Number.MAX_SAFE_INTEGER, // Bypass the MAX_TURNS protection
|
||||
);
|
||||
|
||||
// Count how many stream events we get
|
||||
let eventCount = 0;
|
||||
const maxTestIterations = 1000; // Higher limit to show the loop continues
|
||||
|
||||
// Consume the stream and count iterations
|
||||
try {
|
||||
while (true) {
|
||||
const result = await stream.next();
|
||||
if (result.done) {
|
||||
break;
|
||||
}
|
||||
eventCount++;
|
||||
|
||||
// This test should hit this limit, demonstrating the infinite loop
|
||||
if (eventCount > maxTestIterations) {
|
||||
abortController.abort();
|
||||
// This is the expected behavior - we hit the infinite loop
|
||||
break;
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
// If the test framework times out, that also demonstrates the infinite loop
|
||||
console.error('Test timed out or errored:', error);
|
||||
}
|
||||
|
||||
// Assert that the fix works - the loop should stop at MAX_TURNS
|
||||
const callCount = mockCheckNextSpeaker.mock.calls.length;
|
||||
|
||||
// With the fix: even when turns is set to a very high value,
|
||||
// the loop should stop at MAX_TURNS (100)
|
||||
expect(callCount).toBeLessThanOrEqual(100); // Should not exceed MAX_TURNS
|
||||
expect(eventCount).toBeLessThanOrEqual(200); // Should have reasonable number of events
|
||||
|
||||
console.log(
|
||||
`Infinite loop protection working: checkNextSpeaker called ${callCount} times, ` +
|
||||
`${eventCount} events generated (properly bounded by MAX_TURNS)`,
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -68,6 +68,7 @@ export class GeminiClient {
|
||||
async initialize(contentGeneratorConfig: ContentGeneratorConfig) {
|
||||
this.contentGenerator = await createContentGenerator(
|
||||
contentGeneratorConfig,
|
||||
this.config.getSessionId(),
|
||||
);
|
||||
this.chat = await this.startChat();
|
||||
}
|
||||
@@ -219,7 +220,9 @@ export class GeminiClient {
|
||||
signal: AbortSignal,
|
||||
turns: number = this.MAX_TURNS,
|
||||
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
|
||||
if (!turns) {
|
||||
// Ensure turns never exceeds MAX_TURNS to prevent infinite loops
|
||||
const boundedTurns = Math.min(turns, this.MAX_TURNS);
|
||||
if (!boundedTurns) {
|
||||
return new Turn(this.getChat());
|
||||
}
|
||||
|
||||
@@ -242,7 +245,7 @@ export class GeminiClient {
|
||||
const nextRequest = [{ text: 'Please continue.' }];
|
||||
// This recursive call's events will be yielded out, but the final
|
||||
// turn object will be from the top-level call.
|
||||
yield* this.sendMessageStream(nextRequest, signal, turns - 1);
|
||||
yield* this.sendMessageStream(nextRequest, signal, boundedTurns - 1);
|
||||
}
|
||||
}
|
||||
return turn;
|
||||
|
||||
@@ -101,6 +101,7 @@ export async function createContentGeneratorConfig(
|
||||
|
||||
export async function createContentGenerator(
|
||||
config: ContentGeneratorConfig,
|
||||
sessionId?: string,
|
||||
): Promise<ContentGenerator> {
|
||||
const version = process.env.CLI_VERSION || process.version;
|
||||
const httpOptions = {
|
||||
@@ -109,7 +110,11 @@ export async function createContentGenerator(
|
||||
},
|
||||
};
|
||||
if (config.authType === AuthType.LOGIN_WITH_GOOGLE) {
|
||||
return createCodeAssistContentGenerator(httpOptions, config.authType);
|
||||
return createCodeAssistContentGenerator(
|
||||
httpOptions,
|
||||
config.authType,
|
||||
sessionId,
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
|
||||
@@ -196,6 +196,11 @@ describe('fileUtils', () => {
|
||||
vi.restoreAllMocks(); // Restore spies on actualNodeFs
|
||||
});
|
||||
|
||||
it('should detect typescript type by extension (ts)', () => {
|
||||
expect(detectFileType('file.ts')).toBe('text');
|
||||
expect(detectFileType('file.test.ts')).toBe('text');
|
||||
});
|
||||
|
||||
it('should detect image type by extension (png)', () => {
|
||||
mockMimeLookup.mockReturnValueOnce('image/png');
|
||||
expect(detectFileType('file.png')).toBe('image');
|
||||
|
||||
@@ -100,8 +100,14 @@ export function detectFileType(
|
||||
filePath: string,
|
||||
): 'text' | 'image' | 'pdf' | 'audio' | 'video' | 'binary' {
|
||||
const ext = path.extname(filePath).toLowerCase();
|
||||
const lookedUpMimeType = mime.lookup(filePath); // Returns false if not found, or the mime type string
|
||||
|
||||
// The mimetype for "ts" is MPEG transport stream (a video format) but we want
|
||||
// to assume these are typescript files instead.
|
||||
if (ext === '.ts') {
|
||||
return 'text';
|
||||
}
|
||||
|
||||
const lookedUpMimeType = mime.lookup(filePath); // Returns false if not found, or the mime type string
|
||||
if (lookedUpMimeType) {
|
||||
if (lookedUpMimeType.startsWith('image/')) {
|
||||
return 'image';
|
||||
|
||||
Reference in New Issue
Block a user