Use GetOperation to poll for OnboardUser completion (#15827)

Co-authored-by: Vedant Mahajan <vedant.04.mahajan@gmail.com>
Co-authored-by: Tommaso Sciortino <sciortino@gmail.com>
This commit is contained in:
Ishaan Gupta
2026-01-07 00:38:59 +05:30
committed by GitHub
parent 9172e28315
commit cce4574143
4 changed files with 133 additions and 9 deletions

View File

@@ -432,6 +432,31 @@ describe('CodeAssistServer', () => {
expect(response.name).toBe('operations/123');
});
it('should call the getOperation endpoint', async () => {
const { server } = createTestServer();
const mockResponse = {
name: 'operations/123',
done: true,
response: {
cloudaicompanionProject: {
id: 'test-project',
name: 'projects/test-project',
},
},
};
vi.spyOn(server, 'requestGetOperation').mockResolvedValue(mockResponse);
const response = await server.getOperation('operations/123');
expect(server.requestGetOperation).toHaveBeenCalledWith('operations/123');
expect(response.name).toBe('operations/123');
expect(response.response?.cloudaicompanionProject?.id).toBe('test-project');
expect(response.response?.cloudaicompanionProject?.name).toBe(
'projects/test-project',
);
});
it('should call the loadCodeAssist endpoint', async () => {
const { server } = createTestServer();
const mockResponse = {

View File

@@ -51,7 +51,6 @@ import {
recordConversationOffered,
} from './telemetry.js';
import { getClientMetadata } from './experiments/client_metadata.js';
/** HTTP options to be used in each of the requests. */
export interface HttpOptions {
/** Additional HTTP headers to be sent with the request. */
@@ -160,6 +159,10 @@ export class CodeAssistServer implements ContentGenerator {
return this.requestPost<LongRunningOperationResponse>('onboardUser', req);
}
async getOperation(name: string): Promise<LongRunningOperationResponse> {
return this.requestGetOperation<LongRunningOperationResponse>(name);
}
async loadCodeAssist(
req: LoadCodeAssistRequest,
): Promise<LoadCodeAssistResponse> {
@@ -289,9 +292,12 @@ export class CodeAssistServer implements ContentGenerator {
return res.data as T;
}
async requestGet<T>(method: string, signal?: AbortSignal): Promise<T> {
private async makeGetRequest<T>(
url: string,
signal?: AbortSignal,
): Promise<T> {
const res = await this.client.request({
url: this.getMethodUrl(method),
url,
method: 'GET',
headers: {
'Content-Type': 'application/json',
@@ -303,6 +309,14 @@ export class CodeAssistServer implements ContentGenerator {
return res.data as T;
}
async requestGet<T>(method: string, signal?: AbortSignal): Promise<T> {
return this.makeGetRequest<T>(this.getMethodUrl(method), signal);
}
async requestGetOperation<T>(name: string, signal?: AbortSignal): Promise<T> {
return this.makeGetRequest<T>(this.getOperationUrl(name), signal);
}
async requestStreamingPost<T>(
method: string,
req: object,
@@ -345,10 +359,18 @@ export class CodeAssistServer implements ContentGenerator {
})();
}
getMethodUrl(method: string): string {
private getBaseUrl(): string {
const endpoint =
process.env['CODE_ASSIST_ENDPOINT'] ?? CODE_ASSIST_ENDPOINT;
return `${endpoint}/${CODE_ASSIST_API_VERSION}:${method}`;
return `${endpoint}/${CODE_ASSIST_API_VERSION}`;
}
getMethodUrl(method: string): string {
return `${this.getBaseUrl()}:${method}`;
}
getOperationUrl(name: string): string {
return `${this.getBaseUrl()}/${name}`;
}
}

View File

@@ -106,9 +106,11 @@ describe('setupUser for existing user', () => {
describe('setupUser for new user', () => {
let mockLoad: ReturnType<typeof vi.fn>;
let mockOnboardUser: ReturnType<typeof vi.fn>;
let mockGetOperation: ReturnType<typeof vi.fn>;
beforeEach(() => {
vi.resetAllMocks();
vi.useFakeTimers();
mockLoad = vi.fn();
mockOnboardUser = vi.fn().mockResolvedValue({
done: true,
@@ -118,16 +120,19 @@ describe('setupUser for new user', () => {
},
},
});
mockGetOperation = vi.fn();
vi.mocked(CodeAssistServer).mockImplementation(
() =>
({
loadCodeAssist: mockLoad,
onboardUser: mockOnboardUser,
getOperation: mockGetOperation,
}) as unknown as CodeAssistServer,
);
});
afterEach(() => {
vi.useRealTimers();
vi.unstubAllEnvs();
});
@@ -221,4 +226,74 @@ describe('setupUser for new user', () => {
ProjectIdRequiredError,
);
});
it('should poll getOperation when onboardUser returns done=false', async () => {
vi.stubEnv('GOOGLE_CLOUD_PROJECT', 'test-project');
mockLoad.mockResolvedValue({
allowedTiers: [mockPaidTier],
});
const operationName = 'operations/123';
mockOnboardUser.mockResolvedValueOnce({
name: operationName,
done: false,
});
mockGetOperation
.mockResolvedValueOnce({
name: operationName,
done: false,
})
.mockResolvedValueOnce({
name: operationName,
done: true,
response: {
cloudaicompanionProject: {
id: 'server-project',
},
},
});
const setupPromise = setupUser({} as OAuth2Client);
await vi.advanceTimersByTimeAsync(5000);
await vi.advanceTimersByTimeAsync(5000);
const userData = await setupPromise;
expect(mockOnboardUser).toHaveBeenCalledTimes(1);
expect(mockGetOperation).toHaveBeenCalledTimes(2);
expect(mockGetOperation).toHaveBeenCalledWith(operationName);
expect(userData).toEqual({
projectId: 'server-project',
userTier: 'standard-tier',
});
});
it('should not poll getOperation when onboardUser returns done=true immediately', async () => {
vi.stubEnv('GOOGLE_CLOUD_PROJECT', 'test-project');
mockLoad.mockResolvedValue({
allowedTiers: [mockPaidTier],
});
mockOnboardUser.mockResolvedValueOnce({
name: 'operations/123',
done: true,
response: {
cloudaicompanionProject: {
id: 'server-project',
},
},
});
const userData = await setupUser({} as OAuth2Client);
expect(mockOnboardUser).toHaveBeenCalledTimes(1);
expect(mockGetOperation).not.toHaveBeenCalled();
expect(userData).toEqual({
projectId: 'server-project',
userTier: 'standard-tier',
});
});
});

View File

@@ -89,11 +89,13 @@ export async function setupUser(client: AuthClient): Promise<UserData> {
};
}
// Poll onboardUser until long running operation is complete.
let lroRes = await caServer.onboardUser(onboardReq);
while (!lroRes.done) {
await new Promise((f) => setTimeout(f, 5000));
lroRes = await caServer.onboardUser(onboardReq);
if (!lroRes.done && lroRes.name) {
const operationName = lroRes.name;
while (!lroRes.done) {
await new Promise((f) => setTimeout(f, 5000));
lroRes = await caServer.getOperation(operationName);
}
}
if (!lroRes.response?.cloudaicompanionProject?.id) {