Mirror of https://github.com/google-gemini/gemini-cli.git (synced 2026-04-26 21:14:35 -07:00)
feat(core): Stop context window overflow when sending chat (#10459)
@@ -1505,6 +1505,113 @@ ${JSON.stringify(
     );
   });
 
+    it('should yield ContextWindowWillOverflow when the context window is about to overflow', async () => {
+      // Arrange
+      const MOCKED_TOKEN_LIMIT = 1000;
+      vi.mocked(tokenLimit).mockReturnValue(MOCKED_TOKEN_LIMIT);
+
+      // Set last prompt token count
+      const lastPromptTokenCount = 900;
+      vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue(
+        lastPromptTokenCount,
+      );
+
+      // Remaining = 100. Threshold (95%) = 95.
+      // We need a request > 95 tokens.
+      // A string of length 400 is roughly 100 tokens.
+      const longText = 'a'.repeat(400);
+      const request: Part[] = [{ text: longText }];
+      const estimatedRequestTokenCount = Math.floor(
+        JSON.stringify(request).length / 4,
+      );
+      const remainingTokenCount = MOCKED_TOKEN_LIMIT - lastPromptTokenCount;
+
+      // Mock tryCompressChat to not compress
+      vi.spyOn(client, 'tryCompressChat').mockResolvedValue({
+        originalTokenCount: lastPromptTokenCount,
+        newTokenCount: lastPromptTokenCount,
+        compressionStatus: CompressionStatus.NOOP,
+      });
+
+      // Act
+      const stream = client.sendMessageStream(
+        request,
+        new AbortController().signal,
+        'prompt-id-overflow',
+      );
+      const events = await fromAsync(stream);
+
+      // Assert
+      expect(events).toContainEqual({
+        type: GeminiEventType.ContextWindowWillOverflow,
+        value: {
+          estimatedRequestTokenCount,
+          remainingTokenCount,
+        },
+      });
+      // Ensure turn.run is not called
+      expect(mockTurnRunFn).not.toHaveBeenCalled();
+    });
+
+    it("should use the sticky model's token limit for the overflow check", async () => {
+      // Arrange
+      const STICKY_MODEL = 'gemini-1.5-flash';
+      const STICKY_MODEL_LIMIT = 1000;
+      const CONFIG_MODEL_LIMIT = 2000;
+
+      // Set up token limits
+      vi.mocked(tokenLimit).mockImplementation((model) => {
+        if (model === STICKY_MODEL) return STICKY_MODEL_LIMIT;
+        return CONFIG_MODEL_LIMIT;
+      });
+
+      // Set the sticky model
+      client['currentSequenceModel'] = STICKY_MODEL;
+
+      // Set token count
+      const lastPromptTokenCount = 900;
+      vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue(
+        lastPromptTokenCount,
+      );
+
+      // Remaining (sticky) = 100. Threshold (95%) = 95.
+      // We need a request > 95 tokens.
+      const longText = 'a'.repeat(400);
+      const request: Part[] = [{ text: longText }];
+      const estimatedRequestTokenCount = Math.floor(
+        JSON.stringify(request).length / 4,
+      );
+      const remainingTokenCount = STICKY_MODEL_LIMIT - lastPromptTokenCount;
+
+      vi.spyOn(client, 'tryCompressChat').mockResolvedValue({
+        originalTokenCount: lastPromptTokenCount,
+        newTokenCount: lastPromptTokenCount,
+        compressionStatus: CompressionStatus.NOOP,
+      });
+
+      // Act
+      const stream = client.sendMessageStream(
+        request,
+        new AbortController().signal,
+        'test-session-id', // Use the same ID as the session to keep stickiness
+      );
+      const events = await fromAsync(stream);
+
+      // Assert
+      // Should overflow based on the sticky model's limit
+      expect(events).toContainEqual({
+        type: GeminiEventType.ContextWindowWillOverflow,
+        value: {
+          estimatedRequestTokenCount,
+          remainingTokenCount,
+        },
+      });
+      expect(tokenLimit).toHaveBeenCalledWith(STICKY_MODEL);
+      expect(mockTurnRunFn).not.toHaveBeenCalled();
+    });
+
   describe('Model Routing', () => {
     let mockRouterService: { route: Mock };
 
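The tests above collect the streamed events with a fromAsync helper that is not part of this hunk. A minimal sketch of such a helper, assuming it simply drains an async iterable into an array (the real helper may differ):

async function fromAsync<T>(iterable: AsyncIterable<T>): Promise<T[]> {
  // Drain the async iterable so assertions can inspect every yielded event.
  const items: T[] = [];
  for await (const item of iterable) {
    items.push(item);
  }
  return items;
}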
@@ -458,6 +458,19 @@ My setup is complete. I will provide my first command in the next turn.
     }
   }
 
+  private _getEffectiveModelForCurrentTurn(): string {
+    if (this.currentSequenceModel) {
+      return this.currentSequenceModel;
+    }
+
+    const configModel = this.config.getModel();
+    const model: string =
+      configModel === DEFAULT_GEMINI_MODEL_AUTO
+        ? DEFAULT_GEMINI_MODEL
+        : configModel;
+    return getEffectiveModel(this.config.isInFallbackMode(), model);
+  }
+
   async *sendMessageStream(
     request: PartListUnion,
     signal: AbortSignal,
@@ -483,6 +496,25 @@ My setup is complete. I will provide my first command in the next turn.
       return new Turn(this.getChat(), prompt_id);
     }
 
+    // Check for context window overflow
+    const modelForLimitCheck = this._getEffectiveModelForCurrentTurn();
+
+    const estimatedRequestTokenCount = Math.floor(
+      JSON.stringify(request).length / 4,
+    );
+
+    const remainingTokenCount =
+      tokenLimit(modelForLimitCheck) -
+      uiTelemetryService.getLastPromptTokenCount();
+
+    if (estimatedRequestTokenCount > remainingTokenCount * 0.95) {
+      yield {
+        type: GeminiEventType.ContextWindowWillOverflow,
+        value: { estimatedRequestTokenCount, remainingTokenCount },
+      };
+      return new Turn(this.getChat(), prompt_id);
+    }
+
     const compressed = await this.tryCompressChat(prompt_id, false);
 
     if (compressed.compressionStatus === CompressionStatus.COMPRESSED) {
@@ -674,14 +706,7 @@ My setup is complete. I will provide my first command in the next turn.
     // If the model is 'auto', we will use a placeholder model to check.
     // Compression occurs before we choose a model, so calling `count_tokens`
     // before the model is chosen would result in an error.
-    const configModel = this.config.getModel();
-    let model: string =
-      configModel === DEFAULT_GEMINI_MODEL_AUTO
-        ? DEFAULT_GEMINI_MODEL
-        : configModel;
-
-    // Check if the model needs to be a fallback
-    model = getEffectiveModel(this.config.isInFallbackMode(), model);
+    const model = this._getEffectiveModelForCurrentTurn();
 
     const curatedHistory = this.getChat().getHistory(true);
 
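The overflow guard above estimates the request size at roughly four characters per token and yields ContextWindowWillOverflow, without starting the turn, when the estimate exceeds 95% of the tokens still available for the effective model. A standalone restatement of that arithmetic (the helper name is illustrative, not from this change):

// Hypothetical helper that restates the check performed in sendMessageStream.
function wouldOverflowContextWindow(
  requestJsonLength: number,
  modelTokenLimit: number,
  lastPromptTokenCount: number,
): boolean {
  const estimatedRequestTokenCount = Math.floor(requestJsonLength / 4); // ~4 chars per token
  const remainingTokenCount = modelTokenLimit - lastPromptTokenCount;
  return estimatedRequestTokenCount > remainingTokenCount * 0.95;
}

// With the numbers used in the tests: a 1000-token limit and 900 tokens already in
// the prompt leave 100 tokens, so the 95% threshold is 95 and a ~100-token request
// triggers the event.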
@@ -59,12 +59,21 @@ export enum GeminiEventType {
   LoopDetected = 'loop_detected',
   Citation = 'citation',
   Retry = 'retry',
+  ContextWindowWillOverflow = 'context_window_will_overflow',
 }
 
 export type ServerGeminiRetryEvent = {
   type: GeminiEventType.Retry;
 };
 
+export type ServerGeminiContextWindowWillOverflowEvent = {
+  type: GeminiEventType.ContextWindowWillOverflow;
+  value: {
+    estimatedRequestTokenCount: number;
+    remainingTokenCount: number;
+  };
+};
+
 export interface StructuredError {
   message: string;
   status?: number;
@@ -193,7 +202,8 @@ export type ServerGeminiStreamEvent =
   | ServerGeminiToolCallRequestEvent
   | ServerGeminiToolCallResponseEvent
   | ServerGeminiUserCancelledEvent
-  | ServerGeminiRetryEvent;
+  | ServerGeminiRetryEvent
+  | ServerGeminiContextWindowWillOverflowEvent;
 
 // A turn manages the agentic loop turn within the server context.
 export class Turn {
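With the new member in the ServerGeminiStreamEvent union, a caller of sendMessageStream can react before any model call is made. A hedged sketch of such a consumer, assuming Part, GeminiEventType, and ServerGeminiStreamEvent are imported from the modules shown above; the client interface and warning text are placeholders, not part of this change:

// Illustrative consumer; the client shape is a placeholder matching the method shown above.
interface OverflowAwareClient {
  sendMessageStream(
    request: Part[],
    signal: AbortSignal,
    promptId: string,
  ): AsyncIterable<ServerGeminiStreamEvent>;
}

async function runPrompt(
  client: OverflowAwareClient,
  request: Part[],
  promptId: string,
): Promise<void> {
  for await (const event of client.sendMessageStream(
    request,
    new AbortController().signal,
    promptId,
  )) {
    if (event.type === GeminiEventType.ContextWindowWillOverflow) {
      const { estimatedRequestTokenCount, remainingTokenCount } = event.value;
      // The turn is never started in this case; surface the problem to the user instead.
      console.warn(
        `Request (~${estimatedRequestTokenCount} tokens) exceeds the remaining ` +
          `context window (~${remainingTokenCount} tokens).`,
      );
      return;
    }
    // ...handle the other ServerGeminiStreamEvent variants here.
  }
}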