mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-13 07:30:52 -07:00
Use GetOperation to poll for OnboardUser completion (#15827)
Co-authored-by: Vedant Mahajan <vedant.04.mahajan@gmail.com> Co-authored-by: Tommaso Sciortino <sciortino@gmail.com>
This commit is contained in:
@@ -432,6 +432,31 @@ describe('CodeAssistServer', () => {
|
||||
expect(response.name).toBe('operations/123');
|
||||
});
|
||||
|
||||
it('should call the getOperation endpoint', async () => {
|
||||
const { server } = createTestServer();
|
||||
|
||||
const mockResponse = {
|
||||
name: 'operations/123',
|
||||
done: true,
|
||||
response: {
|
||||
cloudaicompanionProject: {
|
||||
id: 'test-project',
|
||||
name: 'projects/test-project',
|
||||
},
|
||||
},
|
||||
};
|
||||
vi.spyOn(server, 'requestGetOperation').mockResolvedValue(mockResponse);
|
||||
|
||||
const response = await server.getOperation('operations/123');
|
||||
|
||||
expect(server.requestGetOperation).toHaveBeenCalledWith('operations/123');
|
||||
expect(response.name).toBe('operations/123');
|
||||
expect(response.response?.cloudaicompanionProject?.id).toBe('test-project');
|
||||
expect(response.response?.cloudaicompanionProject?.name).toBe(
|
||||
'projects/test-project',
|
||||
);
|
||||
});
|
||||
|
||||
it('should call the loadCodeAssist endpoint', async () => {
|
||||
const { server } = createTestServer();
|
||||
const mockResponse = {
|
||||
|
||||
@@ -51,7 +51,6 @@ import {
|
||||
recordConversationOffered,
|
||||
} from './telemetry.js';
|
||||
import { getClientMetadata } from './experiments/client_metadata.js';
|
||||
|
||||
/** HTTP options to be used in each of the requests. */
|
||||
export interface HttpOptions {
|
||||
/** Additional HTTP headers to be sent with the request. */
|
||||
@@ -160,6 +159,10 @@ export class CodeAssistServer implements ContentGenerator {
|
||||
return this.requestPost<LongRunningOperationResponse>('onboardUser', req);
|
||||
}
|
||||
|
||||
async getOperation(name: string): Promise<LongRunningOperationResponse> {
|
||||
return this.requestGetOperation<LongRunningOperationResponse>(name);
|
||||
}
|
||||
|
||||
async loadCodeAssist(
|
||||
req: LoadCodeAssistRequest,
|
||||
): Promise<LoadCodeAssistResponse> {
|
||||
@@ -289,9 +292,12 @@ export class CodeAssistServer implements ContentGenerator {
|
||||
return res.data as T;
|
||||
}
|
||||
|
||||
async requestGet<T>(method: string, signal?: AbortSignal): Promise<T> {
|
||||
private async makeGetRequest<T>(
|
||||
url: string,
|
||||
signal?: AbortSignal,
|
||||
): Promise<T> {
|
||||
const res = await this.client.request({
|
||||
url: this.getMethodUrl(method),
|
||||
url,
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
@@ -303,6 +309,14 @@ export class CodeAssistServer implements ContentGenerator {
|
||||
return res.data as T;
|
||||
}
|
||||
|
||||
async requestGet<T>(method: string, signal?: AbortSignal): Promise<T> {
|
||||
return this.makeGetRequest<T>(this.getMethodUrl(method), signal);
|
||||
}
|
||||
|
||||
async requestGetOperation<T>(name: string, signal?: AbortSignal): Promise<T> {
|
||||
return this.makeGetRequest<T>(this.getOperationUrl(name), signal);
|
||||
}
|
||||
|
||||
async requestStreamingPost<T>(
|
||||
method: string,
|
||||
req: object,
|
||||
@@ -345,10 +359,18 @@ export class CodeAssistServer implements ContentGenerator {
|
||||
})();
|
||||
}
|
||||
|
||||
getMethodUrl(method: string): string {
|
||||
private getBaseUrl(): string {
|
||||
const endpoint =
|
||||
process.env['CODE_ASSIST_ENDPOINT'] ?? CODE_ASSIST_ENDPOINT;
|
||||
return `${endpoint}/${CODE_ASSIST_API_VERSION}:${method}`;
|
||||
return `${endpoint}/${CODE_ASSIST_API_VERSION}`;
|
||||
}
|
||||
|
||||
getMethodUrl(method: string): string {
|
||||
return `${this.getBaseUrl()}:${method}`;
|
||||
}
|
||||
|
||||
getOperationUrl(name: string): string {
|
||||
return `${this.getBaseUrl()}/${name}`;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -106,9 +106,11 @@ describe('setupUser for existing user', () => {
|
||||
describe('setupUser for new user', () => {
|
||||
let mockLoad: ReturnType<typeof vi.fn>;
|
||||
let mockOnboardUser: ReturnType<typeof vi.fn>;
|
||||
let mockGetOperation: ReturnType<typeof vi.fn>;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.resetAllMocks();
|
||||
vi.useFakeTimers();
|
||||
mockLoad = vi.fn();
|
||||
mockOnboardUser = vi.fn().mockResolvedValue({
|
||||
done: true,
|
||||
@@ -118,16 +120,19 @@ describe('setupUser for new user', () => {
|
||||
},
|
||||
},
|
||||
});
|
||||
mockGetOperation = vi.fn();
|
||||
vi.mocked(CodeAssistServer).mockImplementation(
|
||||
() =>
|
||||
({
|
||||
loadCodeAssist: mockLoad,
|
||||
onboardUser: mockOnboardUser,
|
||||
getOperation: mockGetOperation,
|
||||
}) as unknown as CodeAssistServer,
|
||||
);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers();
|
||||
vi.unstubAllEnvs();
|
||||
});
|
||||
|
||||
@@ -221,4 +226,74 @@ describe('setupUser for new user', () => {
|
||||
ProjectIdRequiredError,
|
||||
);
|
||||
});
|
||||
|
||||
it('should poll getOperation when onboardUser returns done=false', async () => {
|
||||
vi.stubEnv('GOOGLE_CLOUD_PROJECT', 'test-project');
|
||||
mockLoad.mockResolvedValue({
|
||||
allowedTiers: [mockPaidTier],
|
||||
});
|
||||
|
||||
const operationName = 'operations/123';
|
||||
|
||||
mockOnboardUser.mockResolvedValueOnce({
|
||||
name: operationName,
|
||||
done: false,
|
||||
});
|
||||
|
||||
mockGetOperation
|
||||
.mockResolvedValueOnce({
|
||||
name: operationName,
|
||||
done: false,
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
name: operationName,
|
||||
done: true,
|
||||
response: {
|
||||
cloudaicompanionProject: {
|
||||
id: 'server-project',
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const setupPromise = setupUser({} as OAuth2Client);
|
||||
|
||||
await vi.advanceTimersByTimeAsync(5000);
|
||||
await vi.advanceTimersByTimeAsync(5000);
|
||||
|
||||
const userData = await setupPromise;
|
||||
|
||||
expect(mockOnboardUser).toHaveBeenCalledTimes(1);
|
||||
expect(mockGetOperation).toHaveBeenCalledTimes(2);
|
||||
expect(mockGetOperation).toHaveBeenCalledWith(operationName);
|
||||
expect(userData).toEqual({
|
||||
projectId: 'server-project',
|
||||
userTier: 'standard-tier',
|
||||
});
|
||||
});
|
||||
|
||||
it('should not poll getOperation when onboardUser returns done=true immediately', async () => {
|
||||
vi.stubEnv('GOOGLE_CLOUD_PROJECT', 'test-project');
|
||||
mockLoad.mockResolvedValue({
|
||||
allowedTiers: [mockPaidTier],
|
||||
});
|
||||
|
||||
mockOnboardUser.mockResolvedValueOnce({
|
||||
name: 'operations/123',
|
||||
done: true,
|
||||
response: {
|
||||
cloudaicompanionProject: {
|
||||
id: 'server-project',
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const userData = await setupUser({} as OAuth2Client);
|
||||
|
||||
expect(mockOnboardUser).toHaveBeenCalledTimes(1);
|
||||
expect(mockGetOperation).not.toHaveBeenCalled();
|
||||
expect(userData).toEqual({
|
||||
projectId: 'server-project',
|
||||
userTier: 'standard-tier',
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -89,11 +89,13 @@ export async function setupUser(client: AuthClient): Promise<UserData> {
|
||||
};
|
||||
}
|
||||
|
||||
// Poll onboardUser until long running operation is complete.
|
||||
let lroRes = await caServer.onboardUser(onboardReq);
|
||||
while (!lroRes.done) {
|
||||
await new Promise((f) => setTimeout(f, 5000));
|
||||
lroRes = await caServer.onboardUser(onboardReq);
|
||||
if (!lroRes.done && lroRes.name) {
|
||||
const operationName = lroRes.name;
|
||||
while (!lroRes.done) {
|
||||
await new Promise((f) => setTimeout(f, 5000));
|
||||
lroRes = await caServer.getOperation(operationName);
|
||||
}
|
||||
}
|
||||
|
||||
if (!lroRes.response?.cloudaicompanionProject?.id) {
|
||||
|
||||
Reference in New Issue
Block a user