diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts index ac8d9f1bd6..b7e85962a5 100644 --- a/packages/core/src/core/client.test.ts +++ b/packages/core/src/core/client.test.ts @@ -291,6 +291,7 @@ describe('Gemini Client (client.ts)', () => { it('should call chat.addHistory with the provided content', async () => { const mockChat = { addHistory: vi.fn(), + setTools: vi.fn(), } as unknown as GeminiChat; client['chat'] = mockChat; @@ -389,6 +390,7 @@ describe('Gemini Client (client.ts)', () => { getHistory: mockGetHistory, addHistory: vi.fn(), setHistory: vi.fn(), + setTools: vi.fn(), getLastPromptTokenCount: vi.fn(), } as unknown as GeminiChat; }); @@ -805,6 +807,7 @@ describe('Gemini Client (client.ts)', () => { const mockChat = { addHistory: vi.fn(), + setTools: vi.fn(), getHistory: vi.fn().mockReturnValue([]), getLastPromptTokenCount: vi.fn(), } as unknown as GeminiChat; @@ -868,6 +871,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), + setTools: vi.fn(), getHistory: vi.fn().mockReturnValue([]), getLastPromptTokenCount: vi.fn(), }; @@ -926,6 +930,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), + setTools: vi.fn(), getHistory: vi.fn().mockReturnValue([]), getLastPromptTokenCount: vi.fn(), }; @@ -1003,6 +1008,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), + setTools: vi.fn(), getHistory: vi.fn().mockReturnValue([]), getLastPromptTokenCount: vi.fn(), }; @@ -1119,6 +1125,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), + setTools: vi.fn(), getHistory: vi.fn().mockReturnValue([]), getLastPromptTokenCount: vi.fn(), }; @@ -1167,6 +1174,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), + setTools: vi.fn(), getHistory: vi.fn().mockReturnValue([]), getLastPromptTokenCount: vi.fn(), }; @@ -1232,6 +1240,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), + setTools: vi.fn(), getHistory: vi.fn().mockReturnValue([]), getLastPromptTokenCount: vi.fn(), }; @@ -1289,6 +1298,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), + setTools: vi.fn(), getHistory: vi.fn().mockReturnValue([]), getLastPromptTokenCount: vi.fn(), }; @@ -1349,6 +1359,7 @@ ${JSON.stringify( const lastPromptTokenCount = 900; const mockChat: Partial = { getLastPromptTokenCount: vi.fn().mockReturnValue(lastPromptTokenCount), + setTools: vi.fn(), getHistory: vi.fn().mockReturnValue([]), }; client['chat'] = mockChat as GeminiChat; @@ -1409,6 +1420,7 @@ ${JSON.stringify( const lastPromptTokenCount = 900; const mockChat: Partial = { getLastPromptTokenCount: vi.fn().mockReturnValue(lastPromptTokenCount), + setTools: vi.fn(), getHistory: vi.fn().mockReturnValue([]), }; client['chat'] = mockChat as GeminiChat; @@ -1467,6 +1479,7 @@ ${JSON.stringify( .fn() .mockReturnValue([{ role: 'user', parts: [{ text: 'old' }] }]), addHistory: vi.fn(), + setTools: vi.fn(), getChatRecordingService: vi.fn().mockReturnValue({ getConversation: vi.fn(), getConversationFilePath: vi.fn(), @@ -1479,6 +1492,7 @@ ${JSON.stringify( .fn() .mockReturnValue([{ role: 'user', parts: [{ text: 'old' }] }]), addHistory: vi.fn(), + setTools: vi.fn(), getChatRecordingService: vi.fn().mockReturnValue({ getConversation: vi.fn(), getConversationFilePath: vi.fn(), @@ -1616,6 +1630,7 @@ ${JSON.stringify( const lastPromptTokenCount = 10000; const mockChat: Partial = { getLastPromptTokenCount: vi.fn().mockReturnValue(lastPromptTokenCount), + setTools: vi.fn(), getHistory: vi.fn().mockReturnValue([]), }; client['chat'] = mockChat as GeminiChat; @@ -1689,6 +1704,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), + setTools: vi.fn(), getHistory: vi.fn().mockReturnValue([]), getLastPromptTokenCount: vi.fn(), }; @@ -1892,6 +1908,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), + setTools: vi.fn(), getHistory: vi.fn().mockReturnValue([]), getLastPromptTokenCount: vi.fn(), }; @@ -1947,6 +1964,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), + setTools: vi.fn(), getHistory: vi.fn().mockReturnValue([]), getLastPromptTokenCount: vi.fn(), }; @@ -1984,6 +2002,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), + setTools: vi.fn(), getHistory: vi.fn().mockReturnValue([]), getLastPromptTokenCount: vi.fn(), }; @@ -2028,6 +2047,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), setHistory: vi.fn(), + setTools: vi.fn(), // Assume history is not empty for delta checks getHistory: vi .fn() @@ -2443,6 +2463,7 @@ ${JSON.stringify( addHistory: vi.fn(), getHistory: vi.fn().mockReturnValue([]), // Default empty history setHistory: vi.fn(), + setTools: vi.fn(), getLastPromptTokenCount: vi.fn(), }; client['chat'] = mockChat as GeminiChat; @@ -2783,6 +2804,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), + setTools: vi.fn(), getHistory: vi.fn().mockReturnValue([]), getLastPromptTokenCount: vi.fn(), }; @@ -2820,6 +2842,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), + setTools: vi.fn(), getHistory: vi.fn().mockReturnValue([]), getLastPromptTokenCount: vi.fn(), }; @@ -2857,6 +2880,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), + setTools: vi.fn(), getHistory: vi.fn().mockReturnValue([]), getLastPromptTokenCount: vi.fn(), }; @@ -3069,6 +3093,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), + setTools: vi.fn(), getHistory: vi.fn().mockReturnValue([]), getLastPromptTokenCount: vi.fn(), }; @@ -3103,6 +3128,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), + setTools: vi.fn(), getHistory: vi.fn().mockReturnValue([]), getLastPromptTokenCount: vi.fn(), }; diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index fdfa1defce..d6afeac4be 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -259,6 +259,10 @@ export class GeminiClient { private lastUsedModelId?: string; async setTools(modelId?: string): Promise { + if (!this.chat) { + return; + } + if (modelId && modelId === this.lastUsedModelId) { return; } @@ -346,6 +350,13 @@ export class GeminiClient { tools, history, resumedSessionData, + async (modelId: string) => { + this.lastUsedModelId = modelId; + const toolRegistry = this.config.getToolRegistry(); + const toolDeclarations = + toolRegistry.getFunctionDeclarations(modelId); + return [{ functionDeclarations: toolDeclarations }]; + }, ); } catch (error) { await reportError( diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts index 35a8e79855..6bea67dc0e 100644 --- a/packages/core/src/core/geminiChat.ts +++ b/packages/core/src/core/geminiChat.ts @@ -247,6 +247,7 @@ export class GeminiChat { private tools: Tool[] = [], private history: Content[] = [], resumedSessionData?: ResumedSessionData, + private readonly onModelChanged?: (modelId: string) => Promise, ) { validateHistory(history); this.chatRecordingService = new ChatRecordingService(config); @@ -581,7 +582,9 @@ export class GeminiChat { } // Track final request parameters for AfterModel hooks - await this.config.getGeminiClient().setTools(modelToUse); + if (this.onModelChanged) { + this.tools = await this.onModelChanged(modelToUse); + } lastModelToUse = modelToUse; lastConfig = config; lastContentsToUse = contentsToUse;