mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-12 21:03:05 -07:00
feat(core): Stop context window overflow when sending chat (#10459)
This commit is contained in:
@@ -1854,6 +1854,109 @@ describe('useGeminiStream', () => {
|
||||
});
|
||||
});
|
||||
|
||||
it('should add info message for ContextWindowWillOverflow event', async () => {
  // The mocked client stream emits exactly one overflow event and then ends.
  const overflowStream = (async function* () {
    yield {
      type: ServerGeminiEventType.ContextWindowWillOverflow,
      value: {
        estimatedRequestTokenCount: 100,
        remainingTokenCount: 50,
      },
    };
  })();
  mockSendMessageStream.mockReturnValue(overflowStream);

  // Render the hook with the suite's shared mocks and no-op callbacks.
  const { result } = renderHook(() =>
    useGeminiStream(
      new MockedGeminiClientClass(mockConfig),
      [],
      mockAddItem,
      mockConfig,
      mockLoadedSettings,
      mockOnDebugMessage,
      mockHandleSlashCommand,
      false,
      () => 'vscode' as EditorType,
      () => {},
      () => Promise.resolve(),
      false,
      () => {},
      () => {},
      () => {},
      () => {},
      80,
      24,
    ),
  );

  // Submitting a query makes the hook consume the mocked stream.
  await act(async () => {
    await result.current.submitQuery('Test overflow');
  });

  // The history item must quote both token counts carried by the event.
  const expectedInfoItem = {
    type: 'info',
    text: `Sending this message (100 tokens) might exceed the remaining context window limit (50 tokens). Please try reducing the size of your message or use the \`/compress\` command to compress the chat history.`,
  };
  await waitFor(() => {
    expect(mockAddItem).toHaveBeenCalledWith(
      expectedInfoItem,
      expect.any(Number),
    );
  });
});
|
||||
|
||||
it('should call onCancelSubmit when ContextWindowWillOverflow event is received', async () => {
  // Spy handed to the hook in place of its cancel-submit callback.
  const onCancelSubmitSpy = vi.fn();

  // The mocked client stream emits exactly one overflow event and then ends.
  const overflowStream = (async function* () {
    yield {
      type: ServerGeminiEventType.ContextWindowWillOverflow,
      value: {
        estimatedRequestTokenCount: 100,
        remainingTokenCount: 50,
      },
    };
  })();
  mockSendMessageStream.mockReturnValue(overflowStream);

  // Same wiring as the sibling overflow test, except the spy replaces one of
  // the no-op callbacks.
  const { result } = renderHook(() =>
    useGeminiStream(
      new MockedGeminiClientClass(mockConfig),
      [],
      mockAddItem,
      mockConfig,
      mockLoadedSettings,
      mockOnDebugMessage,
      mockHandleSlashCommand,
      false,
      () => 'vscode' as EditorType,
      () => {},
      () => Promise.resolve(),
      false,
      () => {},
      () => {},
      onCancelSubmitSpy,
      () => {},
      80,
      24,
    ),
  );

  // Submitting a query makes the hook consume the mocked stream.
  await act(async () => {
    await result.current.submitQuery('Test overflow');
  });

  // Handling the overflow event must have triggered the cancel callback.
  await waitFor(() => {
    expect(onCancelSubmitSpy).toHaveBeenCalled();
  });
});
|
||||
|
||||
it('should not add message for STOP finish reason', async () => {
|
||||
// Setup mock to return a stream with STOP finish reason
|
||||
mockSendMessageStream.mockReturnValue(
|
||||
|
||||
@@ -637,6 +637,21 @@ export const useGeminiStream = (
|
||||
[addItem, config],
|
||||
);
|
||||
|
||||
/**
 * Reacts to a ContextWindowWillOverflow stream event.
 *
 * Invokes `onCancelSubmit` first, then appends an `info` history item that
 * reports both token counts and suggests shortening the message or running
 * `/compress`.
 *
 * @param estimatedRequestTokenCount Estimated token size of the outgoing request.
 * @param remainingTokenCount Tokens reported as still available in the context window.
 */
const handleContextWindowWillOverflowEvent = useCallback(
  (estimatedRequestTokenCount: number, remainingTokenCount: number) => {
    onCancelSubmit();

    const text = `Sending this message (${estimatedRequestTokenCount} tokens) might exceed the remaining context window limit (${remainingTokenCount} tokens). Please try reducing the size of your message or use the \`/compress\` command to compress the chat history.`;
    addItem({ type: 'info', text }, Date.now());
  },
  [addItem, onCancelSubmit],
);
|
||||
|
||||
const handleLoopDetectionConfirmation = useCallback(
|
||||
(result: { userSelection: 'disable' | 'keep' }) => {
|
||||
setLoopDetectionConfirmationRequest(null);
|
||||
@@ -709,6 +724,12 @@ export const useGeminiStream = (
|
||||
case ServerGeminiEventType.MaxSessionTurns:
|
||||
handleMaxSessionTurnsEvent();
|
||||
break;
|
||||
case ServerGeminiEventType.ContextWindowWillOverflow:
|
||||
handleContextWindowWillOverflowEvent(
|
||||
event.value.estimatedRequestTokenCount,
|
||||
event.value.remainingTokenCount,
|
||||
);
|
||||
break;
|
||||
case ServerGeminiEventType.Finished:
|
||||
handleFinishedEvent(
|
||||
event as ServerGeminiFinishedEvent,
|
||||
@@ -746,6 +767,7 @@ export const useGeminiStream = (
|
||||
handleChatCompressionEvent,
|
||||
handleFinishedEvent,
|
||||
handleMaxSessionTurnsEvent,
|
||||
handleContextWindowWillOverflowEvent,
|
||||
handleCitationEvent,
|
||||
],
|
||||
);
|
||||
|
||||
@@ -1505,6 +1505,113 @@ ${JSON.stringify(
|
||||
);
|
||||
});
|
||||
|
||||
it('should yield ContextWindowWillOverflow when the context window is about to overflow', async () => {
  // Arrange: a 1000-token limit with 900 tokens already used leaves 100.
  const MOCKED_TOKEN_LIMIT = 1000;
  const lastPromptTokenCount = 900;
  vi.mocked(tokenLimit).mockReturnValue(MOCKED_TOKEN_LIMIT);
  vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue(
    lastPromptTokenCount,
  );
  const remainingTokenCount = MOCKED_TOKEN_LIMIT - lastPromptTokenCount;

  // The overflow threshold is 95% of the remaining tokens (95 here), so the
  // request estimate must exceed 95. The estimate is the serialized request
  // length divided by four; 400 characters of text yields a bit over 100.
  const request: Part[] = [{ text: 'a'.repeat(400) }];
  const serializedLength = JSON.stringify(request).length;
  const estimatedRequestTokenCount = Math.floor(serializedLength / 4);

  // Stub compression so it reports that nothing was compressed.
  vi.spyOn(client, 'tryCompressChat').mockResolvedValue({
    originalTokenCount: lastPromptTokenCount,
    newTokenCount: lastPromptTokenCount,
    compressionStatus: CompressionStatus.NOOP,
  });

  // Act: drain the whole stream into an array of events.
  const abortSignal = new AbortController().signal;
  const events = await fromAsync(
    client.sendMessageStream(request, abortSignal, 'prompt-id-overflow'),
  );

  // Assert: the overflow event carries both computed counts...
  const expectedEvent = {
    type: GeminiEventType.ContextWindowWillOverflow,
    value: { estimatedRequestTokenCount, remainingTokenCount },
  };
  expect(events).toContainEqual(expectedEvent);
  // ...and the turn was never run.
  expect(mockTurnRunFn).not.toHaveBeenCalled();
});
|
||||
|
||||
it("should use the sticky model's token limit for the overflow check", async () => {
  // Arrange: the sticky model has a smaller limit than any other model.
  const STICKY_MODEL = 'gemini-1.5-flash';
  const STICKY_MODEL_LIMIT = 1000;
  const CONFIG_MODEL_LIMIT = 2000;
  vi.mocked(tokenLimit).mockImplementation((model) =>
    model === STICKY_MODEL ? STICKY_MODEL_LIMIT : CONFIG_MODEL_LIMIT,
  );

  // Pin the sticky model on the client for the current sequence.
  client['currentSequenceModel'] = STICKY_MODEL;

  // 900 tokens already used: 100 remain under the sticky limit.
  const lastPromptTokenCount = 900;
  vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue(
    lastPromptTokenCount,
  );
  const remainingTokenCount = STICKY_MODEL_LIMIT - lastPromptTokenCount;

  // The overflow threshold is 95% of the remaining tokens (95 here), so the
  // request estimate (serialized length / 4) must exceed 95.
  const request: Part[] = [{ text: 'a'.repeat(400) }];
  const serializedLength = JSON.stringify(request).length;
  const estimatedRequestTokenCount = Math.floor(serializedLength / 4);

  // Stub compression so it reports that nothing was compressed.
  vi.spyOn(client, 'tryCompressChat').mockResolvedValue({
    originalTokenCount: lastPromptTokenCount,
    newTokenCount: lastPromptTokenCount,
    compressionStatus: CompressionStatus.NOOP,
  });

  // Act: drain the whole stream into an array of events.
  const abortSignal = new AbortController().signal;
  const events = await fromAsync(
    client.sendMessageStream(
      request,
      abortSignal,
      'test-session-id', // Use the same ID as the session to keep stickiness
    ),
  );

  // Assert: the overflow is computed from the sticky model's limit.
  const expectedEvent = {
    type: GeminiEventType.ContextWindowWillOverflow,
    value: { estimatedRequestTokenCount, remainingTokenCount },
  };
  expect(events).toContainEqual(expectedEvent);
  expect(tokenLimit).toHaveBeenCalledWith(STICKY_MODEL);
  expect(mockTurnRunFn).not.toHaveBeenCalled();
});
|
||||
|
||||
describe('Model Routing', () => {
|
||||
let mockRouterService: { route: Mock };
|
||||
|
||||
|
||||
@@ -458,6 +458,19 @@ My setup is complete. I will provide my first command in the next turn.
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Resolves the model name to use for the current turn.
 *
 * A truthy `currentSequenceModel` is returned as-is. Otherwise the configured
 * model is read, the `DEFAULT_GEMINI_MODEL_AUTO` placeholder is swapped for
 * `DEFAULT_GEMINI_MODEL`, and the result is passed through `getEffectiveModel`
 * together with the config's fallback-mode flag.
 *
 * @returns The resolved model name.
 */
private _getEffectiveModelForCurrentTurn(): string {
  if (this.currentSequenceModel) {
    return this.currentSequenceModel;
  }

  let requestedModel: string = this.config.getModel();
  if (requestedModel === DEFAULT_GEMINI_MODEL_AUTO) {
    // The auto placeholder is replaced with the concrete default model.
    requestedModel = DEFAULT_GEMINI_MODEL;
  }

  const inFallbackMode = this.config.isInFallbackMode();
  return getEffectiveModel(inFallbackMode, requestedModel);
}
|
||||
|
||||
async *sendMessageStream(
|
||||
request: PartListUnion,
|
||||
signal: AbortSignal,
|
||||
@@ -483,6 +496,25 @@ My setup is complete. I will provide my first command in the next turn.
|
||||
return new Turn(this.getChat(), prompt_id);
|
||||
}
|
||||
|
||||
// Check for context window overflow
|
||||
const modelForLimitCheck = this._getEffectiveModelForCurrentTurn();
|
||||
|
||||
const estimatedRequestTokenCount = Math.floor(
|
||||
JSON.stringify(request).length / 4,
|
||||
);
|
||||
|
||||
const remainingTokenCount =
|
||||
tokenLimit(modelForLimitCheck) -
|
||||
uiTelemetryService.getLastPromptTokenCount();
|
||||
|
||||
if (estimatedRequestTokenCount > remainingTokenCount * 0.95) {
|
||||
yield {
|
||||
type: GeminiEventType.ContextWindowWillOverflow,
|
||||
value: { estimatedRequestTokenCount, remainingTokenCount },
|
||||
};
|
||||
return new Turn(this.getChat(), prompt_id);
|
||||
}
|
||||
|
||||
const compressed = await this.tryCompressChat(prompt_id, false);
|
||||
|
||||
if (compressed.compressionStatus === CompressionStatus.COMPRESSED) {
|
||||
@@ -674,14 +706,7 @@ My setup is complete. I will provide my first command in the next turn.
|
||||
// If the model is 'auto', we will use a placeholder model to check.
|
||||
// Compression occurs before we choose a model, so calling `count_tokens`
|
||||
// before the model is chosen would result in an error.
|
||||
const configModel = this.config.getModel();
|
||||
let model: string =
|
||||
configModel === DEFAULT_GEMINI_MODEL_AUTO
|
||||
? DEFAULT_GEMINI_MODEL
|
||||
: configModel;
|
||||
|
||||
// Check if the model needs to be a fallback
|
||||
model = getEffectiveModel(this.config.isInFallbackMode(), model);
|
||||
const model = this._getEffectiveModelForCurrentTurn();
|
||||
|
||||
const curatedHistory = this.getChat().getHistory(true);
|
||||
|
||||
|
||||
@@ -59,12 +59,21 @@ export enum GeminiEventType {
|
||||
LoopDetected = 'loop_detected',
|
||||
Citation = 'citation',
|
||||
Retry = 'retry',
|
||||
ContextWindowWillOverflow = 'context_window_will_overflow',
|
||||
}
|
||||
|
||||
/** Stream event tagged `GeminiEventType.Retry`; it carries no payload. */
export type ServerGeminiRetryEvent = {
  type: GeminiEventType.Retry;
};
|
||||
|
||||
/**
 * Stream event tagged `GeminiEventType.ContextWindowWillOverflow`, reporting
 * the token figures behind a context-window overflow warning.
 */
export type ServerGeminiContextWindowWillOverflowEvent = {
  type: GeminiEventType.ContextWindowWillOverflow;
  value: {
    /** Estimated token size of the request that was about to be sent. */
    estimatedRequestTokenCount: number;
    /** Tokens still available in the model's context window. */
    remainingTokenCount: number;
  };
};
|
||||
|
||||
export interface StructuredError {
|
||||
message: string;
|
||||
status?: number;
|
||||
@@ -193,7 +202,8 @@ export type ServerGeminiStreamEvent =
|
||||
| ServerGeminiToolCallRequestEvent
|
||||
| ServerGeminiToolCallResponseEvent
|
||||
| ServerGeminiUserCancelledEvent
|
||||
| ServerGeminiRetryEvent;
|
||||
| ServerGeminiRetryEvent
|
||||
| ServerGeminiContextWindowWillOverflowEvent;
|
||||
|
||||
// A turn manages the agentic loop turn within the server context.
|
||||
export class Turn {
|
||||
|
||||
Reference in New Issue
Block a user