mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-06-12 20:37:08 -07:00
feat(core): add sendBtwStream for tool-less side inquiries
This commit is contained in:
committed by
Mahima Shanware
parent
34b4f1c6e4
commit
0bd797a2be
@@ -29,7 +29,7 @@ import { type AgentLoopContext } from '../config/agent-loop-context.js';
|
||||
import { getCoreSystemPrompt } from './prompts.js';
|
||||
import { checkNextSpeaker } from '../utils/nextSpeakerChecker.js';
|
||||
import { reportError } from '../utils/errorReporting.js';
|
||||
import { GeminiChat } from './geminiChat.js';
|
||||
import { GeminiChat, StreamEventType } from './geminiChat.js';
|
||||
import {
|
||||
retryWithBackoff,
|
||||
type RetryAvailabilityContext,
|
||||
@@ -57,7 +57,7 @@ import type {
|
||||
import {
|
||||
ContentRetryFailureEvent,
|
||||
NextSpeakerCheckEvent,
|
||||
type LlmRole,
|
||||
LlmRole,
|
||||
} from '../telemetry/types.js';
|
||||
import { uiTelemetryService } from '../telemetry/uiTelemetry.js';
|
||||
import type { IdeContext, File } from '../ide/types.js';
|
||||
@@ -71,8 +71,13 @@ import {
|
||||
applyModelSelection,
|
||||
createAvailabilityContextProvider,
|
||||
} from '../availability/policyHelpers.js';
|
||||
import { getDisplayString, resolveModel } from '../config/models.js';
|
||||
import { partToString } from '../utils/partUtils.js';
|
||||
import {
|
||||
getDisplayString,
|
||||
resolveModel,
|
||||
isGemini2Model,
|
||||
} from '../config/models.js';
|
||||
import { getResponseText, partToString } from '../utils/partUtils.js';
|
||||
import { parseThought } from '../utils/thoughtUtils.js';
|
||||
import { coreEvents, CoreEvent } from '../utils/events.js';
|
||||
|
||||
const MAX_TURNS = 100;
|
||||
@@ -1227,6 +1232,81 @@ export class GeminiClient {
|
||||
return info;
|
||||
}
|
||||
|
||||
/**
 * Streams a tool-less side inquiry ("BTW") through the current chat.
 *
 * A ModelInfo event is emitted first. Each chunk produced by the chat's
 * sendBtwStream is then translated into Thought, Content and Finished
 * events. Retry notifications are forwarded as Retry events; any other
 * stream event type is ignored.
 *
 * @param request The side question to send.
 * @param signal Checked each time a stream event arrives; once aborted, a
 *   UserCancelled event is emitted and the generator returns.
 * @param prompt_id Identifier passed to both the Turn and the chat request.
 * @returns The Turn constructed for this inquiry.
 */
async *sendBtwStream(
  request: PartListUnion,
  signal: AbortSignal,
  prompt_id: string,
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
  const turn = new Turn(this.getChat(), prompt_id);

  // Availability/routing is simplified for BTW: reuse the model that is
  // active for the current turn.
  const activeModel = this._getActiveModelForCurrentTurn();
  const configKey: ModelConfigKey = { model: activeModel, isChatModel: true };

  yield { type: GeminiEventType.ModelInfo, value: activeModel };

  // A dedicated BTW role could be introduced to keep telemetry separate;
  // for now LlmRole.MAIN is used since the primary model is answering.
  const sideStream = this.getChat().sendBtwStream(
    configKey,
    request,
    prompt_id,
    signal,
    LlmRole.MAIN,
  );

  for await (const event of sideStream) {
    // Cancellation is only observed when the next event arrives.
    if (signal?.aborted) {
      yield { type: GeminiEventType.UserCancelled };
      return turn;
    }

    if (event.type === 'retry') {
      yield { type: GeminiEventType.Retry };
      continue;
    }

    if (event.type !== StreamEventType.CHUNK) {
      continue;
    }

    const response = event.value;
    if (!response) {
      continue;
    }

    const traceId = response.responseId;
    const candidate = response.candidates?.[0];

    // Surface every thought part of the chunk as its own Thought event.
    for (const part of candidate?.content?.parts ?? []) {
      if (!part.thought) {
        continue;
      }
      yield {
        type: GeminiEventType.Thought,
        value: parseThought(part.text ?? ''),
        traceId,
      };
    }

    const answerText = getResponseText(response);
    if (answerText) {
      yield { type: GeminiEventType.Content, value: answerText, traceId };
    }

    const finishReason = candidate?.finishReason;
    if (finishReason) {
      yield {
        type: GeminiEventType.Finished,
        value: {
          reason: finishReason,
          usageMetadata: response.usageMetadata,
        },
      };
    }
  }

  return turn;
}
|
||||
|
||||
/**
|
||||
* Masks bulky tool outputs to save context window space.
|
||||
*/
|
||||
|
||||
@@ -2617,4 +2617,138 @@ describe('GeminiChat', () => {
|
||||
]);
|
||||
});
|
||||
});
|
||||
|
||||
// Tests for GeminiChat.sendBtwStream: side inquiries must see the current
// history, send no tools, leave the persistent history untouched, and must
// not be serialized behind an in-flight sendMessageStream call.
describe('sendBtwStream', () => {
  it('should reuse history but not update it', async () => {
    // 1. Setup initial history
    const initialHistory: Content[] = [
      { role: 'user', parts: [{ text: 'Main question' }] },
      { role: 'model', parts: [{ text: 'Main answer' }] },
    ];
    chat.setHistory(initialHistory);

    // 2. Mock API response for BTW
    const btwResponse = (async function* () {
      yield {
        candidates: [
          {
            content: { parts: [{ text: 'Side answer' }], role: 'model' },
            finishReason: 'STOP',
          },
        ],
      } as unknown as GenerateContentResponse;
    })();
    vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
      btwResponse,
    );

    // 3. Call sendBtwStream
    const stream = chat.sendBtwStream(
      { model: 'test-model' },
      'Side question',
      'btw-prompt-id',
      new AbortController().signal,
      LlmRole.MAIN,
    );
    for await (const _ of stream) {
      /* consume */
    }

    // 4. Verify API was called with current history + side question
    expect(mockContentGenerator.generateContentStream).toHaveBeenCalledWith(
      expect.objectContaining({
        contents: [
          ...initialHistory,
          { role: 'user', parts: [{ text: 'Side question' }] },
        ],
        config: expect.objectContaining({
          tools: [], // Should be empty for BTW
        }),
      }),
      'btw-prompt-id',
      LlmRole.MAIN,
    );

    // 5. CRITICAL: Verify persistent history was NOT updated
    const persistentHistory = chat.getHistory();
    expect(persistentHistory.length).toBe(2);
    expect(persistentHistory).toEqual(initialHistory);
  });

  it('should not block or be blocked by a main sendMessageStream call if called concurrently', async () => {
    // This is a simplified test for concurrency logic within GeminiChat.
    // Since sendBtwStream does not await this.sendPromise (unlike sendMessageStream),
    // it should be able to start even if a sendMessageStream call is in progress.

    let resolveMain: (value: unknown) => void;
    const mainPromise = new Promise((resolve) => {
      resolveMain = resolve;
    });

    // Mock main stream to hang until mainPromise is resolved; the second
    // call (the BTW request) answers immediately.
    vi.mocked(mockContentGenerator.generateContentStream)
      .mockImplementationOnce(async () =>
        (async function* () {
          await mainPromise;
          yield {
            candidates: [
              {
                content: { parts: [{ text: 'Main done' }] },
                finishReason: 'STOP',
              },
            ],
          } as GenerateContentResponse;
        })(),
      )
      .mockImplementationOnce(async () =>
        (async function* () {
          yield {
            candidates: [
              {
                content: { parts: [{ text: 'BTW done' }] },
                finishReason: 'STOP',
              },
            ],
          } as GenerateContentResponse;
        })(),
      );

    // Start main stream (will hang during iteration)
    const mainStreamGen = await chat.sendMessageStream(
      { model: 'test-model' },
      'Main prompt',
      'main-id',
      new AbortController().signal,
      LlmRole.MAIN,
    );
    const mainStreamNextPromise = mainStreamGen.next();

    // Track whether the main stream's first step has settled so the test can
    // actually verify that it is still pending once BTW has finished.
    let mainSettled = false;
    const markMainSettled = () => {
      mainSettled = true;
    };
    void mainStreamNextPromise.then(markMainSettled, markMainSettled);

    // Attempt BTW stream immediately - it should NOT block on mainPromise
    const btwStream = chat.sendBtwStream(
      { model: 'test-model' },
      'BTW prompt',
      'btw-id',
      new AbortController().signal,
      LlmRole.MAIN,
    );

    const btwEvents = [];
    for await (const event of btwStream) {
      btwEvents.push(event);
    }

    // Assert BTW finished while Main is still hanging
    expect(btwEvents.length).toBeGreaterThan(0);
    expect(mainSettled).toBe(false);
    expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes(
      2,
    );

    // Clean up
    resolveMain!(null);
    await mainStreamNextPromise; // Finish the one we started
    while (!(await mainStreamGen.next()).done) {
      // drain rest
    }
  });
});
|
||||
});
|
||||
|
||||
@@ -484,12 +484,79 @@ export class GeminiChat {
|
||||
return streamWithRetries.call(this);
|
||||
}
|
||||
|
||||
/**
 * Sends a side inquiry (BTW) that reuses current context but doesn't affect history.
 *
 * The request contents (the result of `getHistory(true)` followed by the new
 * user message) are captured when this method is called, not when the
 * returned generator is first iterated. The API call is made with an empty
 * tool list and with the BTW flag set.
 *
 * @param modelConfigKey Model configuration used for the request.
 * @param message The side question; wrapped with `createUserContent`.
 * @param prompt_id Identifier forwarded to the API call.
 * @param signal Abort signal forwarded to the API call.
 * @param role Role forwarded to the API call.
 * @returns A generator yielding a CHUNK event per response chunk. Agent
 *   execution stopped/blocked errors are converted into their corresponding
 *   events; every other error is rethrown to the consumer.
 */
sendBtwStream(
  modelConfigKey: ModelConfigKey,
  message: PartListUnion,
  prompt_id: string,
  signal: AbortSignal,
  role: LlmRole,
): AsyncGenerator<StreamEvent> {
  // We don't await this.sendPromise because /btw is safe to run concurrently.
  // For the same reason there is no completion promise to resolve when the
  // stream ends.
  const requestContents = [
    ...this.getHistory(true),
    createUserContent(message),
  ];

  const btwStream = async function* (
    this: GeminiChat,
  ): AsyncGenerator<StreamEvent, void, void> {
    try {
      const stream = await this.makeApiCallAndProcessStream(
        modelConfigKey,
        requestContents,
        prompt_id,
        signal,
        role,
        [], // No tools for side inquiries
        true, // isBtw flag
      );
      for await (const chunk of stream) {
        yield { type: StreamEventType.CHUNK, value: chunk };
      }
    } catch (error) {
      if (error instanceof AgentExecutionStoppedError) {
        yield {
          type: StreamEventType.AGENT_EXECUTION_STOPPED,
          reason: error.reason,
        };
      } else if (error instanceof AgentExecutionBlockedError) {
        yield {
          type: StreamEventType.AGENT_EXECUTION_BLOCKED,
          reason: error.reason,
        };
        // A blocked execution may still carry a response to show in place
        // of the model output.
        if (error.syntheticResponse) {
          yield {
            type: StreamEventType.CHUNK,
            value: error.syntheticResponse,
          };
        }
      } else {
        throw error;
      }
    }
  };

  return btwStream.call(this);
}
|
||||
|
||||
private async makeApiCallAndProcessStream(
|
||||
modelConfigKey: ModelConfigKey,
|
||||
requestContents: readonly Content[],
|
||||
prompt_id: string,
|
||||
abortSignal: AbortSignal,
|
||||
role: LlmRole,
|
||||
toolsOverride?: Tool[],
|
||||
isBtw?: boolean,
|
||||
): Promise<AsyncGenerator<GenerateContentResponse>> {
|
||||
const contentsForPreviewModel =
|
||||
this.ensureActiveLoopHasThoughtSignatures(requestContents);
|
||||
@@ -560,7 +627,7 @@ export class GeminiChat {
|
||||
// TODO(12622): Ensure we don't overwrite these when they are
|
||||
// passed via config.
|
||||
systemInstruction: this.systemInstruction,
|
||||
tools: this.tools,
|
||||
tools: toolsOverride !== undefined ? toolsOverride : this.tools,
|
||||
abortSignal,
|
||||
};
|
||||
|
||||
@@ -569,7 +636,7 @@ export class GeminiChat {
|
||||
: [...requestContents];
|
||||
|
||||
const hookSystem = this.context.config.getHookSystem();
|
||||
if (hookSystem) {
|
||||
if (hookSystem && !isBtw) {
|
||||
const beforeModelResult = await hookSystem.fireBeforeModelEvent({
|
||||
model: modelToUse,
|
||||
config,
|
||||
@@ -642,7 +709,7 @@ export class GeminiChat {
|
||||
}
|
||||
}
|
||||
|
||||
if (this.onModelChanged) {
|
||||
if (this.onModelChanged && !isBtw) {
|
||||
this.tools = await this.onModelChanged(modelToUse);
|
||||
}
|
||||
|
||||
@@ -714,6 +781,7 @@ export class GeminiChat {
|
||||
lastModelToUse,
|
||||
streamResponse,
|
||||
originalRequest,
|
||||
isBtw,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -873,6 +941,7 @@ export class GeminiChat {
|
||||
model: string,
|
||||
streamResponse: AsyncGenerator<GenerateContentResponse>,
|
||||
originalRequest: GenerateContentParameters,
|
||||
isBtw?: boolean,
|
||||
): AsyncGenerator<GenerateContentResponse> {
|
||||
const modelResponseParts: Part[] = [];
|
||||
|
||||
@@ -895,7 +964,9 @@ export class GeminiChat {
|
||||
if (content.parts.some((part) => part.thought)) {
|
||||
// Record thoughts
|
||||
hasThoughts = true;
|
||||
this.recordThoughtFromContent(content);
|
||||
if (!isBtw) {
|
||||
this.recordThoughtFromContent(content);
|
||||
}
|
||||
}
|
||||
if (content.parts.some((part) => part.functionCall)) {
|
||||
hasToolCall = true;
|
||||
@@ -908,7 +979,7 @@ export class GeminiChat {
|
||||
}
|
||||
|
||||
// Record token usage if this chunk has usageMetadata
|
||||
if (chunk.usageMetadata) {
|
||||
if (chunk.usageMetadata && !isBtw) {
|
||||
this.chatRecordingService.recordMessageTokens(chunk.usageMetadata);
|
||||
if (chunk.usageMetadata.promptTokenCount !== undefined) {
|
||||
this.lastPromptTokenCount = chunk.usageMetadata.promptTokenCount;
|
||||
@@ -916,7 +987,7 @@ export class GeminiChat {
|
||||
}
|
||||
|
||||
const hookSystem = this.context.config.getHookSystem();
|
||||
if (originalRequest && chunk && hookSystem) {
|
||||
if (originalRequest && chunk && hookSystem && !isBtw) {
|
||||
const hookResult = await hookSystem.fireAfterModelEvent(
|
||||
originalRequest,
|
||||
chunk,
|
||||
@@ -965,7 +1036,7 @@ export class GeminiChat {
|
||||
// Record model response text from the collected parts.
|
||||
// Also flush when there are thoughts or a tool call (even with no text)
|
||||
// so that BeforeTool hooks always see the latest transcript state.
|
||||
if (responseText || hasThoughts || hasToolCall) {
|
||||
if (!isBtw && (responseText || hasThoughts || hasToolCall)) {
|
||||
this.chatRecordingService.recordMessage({
|
||||
model,
|
||||
type: 'gemini',
|
||||
@@ -1008,7 +1079,9 @@ export class GeminiChat {
|
||||
}
|
||||
}
|
||||
|
||||
this.history.push({ role: 'model', parts: consolidatedParts });
|
||||
if (!isBtw) {
|
||||
this.history.push({ role: 'model', parts: consolidatedParts });
|
||||
}
|
||||
}
|
||||
|
||||
getLastPromptTokenCount(): number {
|
||||
|
||||
Reference in New Issue
Block a user