mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-14 08:01:02 -07:00
feat: GenerateImage tool for built-in image generation
This commit is contained in:
@@ -832,6 +832,7 @@ export async function loadCliConfig(
|
||||
skillsSupport: settings.skills?.enabled ?? true,
|
||||
disabledSkills: settings.skills?.disabled,
|
||||
experimentalJitContext: settings.experimental?.jitContext,
|
||||
imageGeneration: settings.experimental?.imageGeneration,
|
||||
modelSteering: settings.experimental?.modelSteering,
|
||||
toolOutputMasking: settings.experimental?.toolOutputMasking,
|
||||
noBrowser: !!process.env['NO_BROWSER'],
|
||||
|
||||
@@ -1811,6 +1811,16 @@ const SETTINGS_SCHEMA = {
|
||||
description: 'Enable planning features (Plan Mode and tools).',
|
||||
showInDialog: true,
|
||||
},
|
||||
imageGeneration: {
|
||||
type: 'boolean',
|
||||
label: 'Image Generation',
|
||||
category: 'Experimental',
|
||||
requiresRestart: true,
|
||||
default: false,
|
||||
description:
|
||||
'Enable generating images with Nano Banana (experimental).',
|
||||
showInDialog: true,
|
||||
},
|
||||
modelSteering: {
|
||||
type: 'boolean',
|
||||
label: 'Model Steering',
|
||||
|
||||
@@ -58,6 +58,7 @@ import { shellsCommand } from '../ui/commands/shellsCommand.js';
|
||||
import { vimCommand } from '../ui/commands/vimCommand.js';
|
||||
import { setupGithubCommand } from '../ui/commands/setupGithubCommand.js';
|
||||
import { terminalSetupCommand } from '../ui/commands/terminalSetupCommand.js';
|
||||
import { imageCommand } from '../ui/commands/imageCommand.js';
|
||||
|
||||
/**
|
||||
* Loads the core, hard-coded slash commands that are an integral part
|
||||
@@ -119,6 +120,7 @@ export class BuiltinCommandLoader implements ICommandLoader {
|
||||
]
|
||||
: [extensionsCommand(this.config?.getEnableExtensionReloading())]),
|
||||
helpCommand,
|
||||
imageCommand,
|
||||
shortcutsCommand,
|
||||
...(this.config?.getEnableHooksUI() ? [hooksCommand] : []),
|
||||
rewindCommand,
|
||||
|
||||
132
packages/cli/src/ui/commands/imageCommand.test.ts
Normal file
132
packages/cli/src/ui/commands/imageCommand.test.ts
Normal file
@@ -0,0 +1,132 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { parseImageArgs, imageCommand } from './imageCommand.js';
|
||||
|
||||
describe('parseImageArgs', () => {
|
||||
it('should parse a simple prompt with no flags', () => {
|
||||
const result = parseImageArgs('a sunset over the ocean');
|
||||
expect(result.prompt).toBe('a sunset over the ocean');
|
||||
expect(result.flags).toEqual({});
|
||||
});
|
||||
|
||||
it('should parse prompt with a space-separated flag value', () => {
|
||||
const result = parseImageArgs('a sunset --ratio 16:9');
|
||||
expect(result.prompt).toBe('a sunset');
|
||||
expect(result.flags['ratio']).toBe('16:9');
|
||||
});
|
||||
|
||||
it('should parse --return as boolean flag', () => {
|
||||
const result = parseImageArgs('a cat --return');
|
||||
expect(result.prompt).toBe('a cat');
|
||||
expect(result.flags['return']).toBe(true);
|
||||
});
|
||||
|
||||
it('should parse inline flag values with =', () => {
|
||||
const result = parseImageArgs('a cat --ratio=16:9 --size=4K');
|
||||
expect(result.prompt).toBe('a cat');
|
||||
expect(result.flags['ratio']).toBe('16:9');
|
||||
expect(result.flags['size']).toBe('4K');
|
||||
});
|
||||
|
||||
it('should handle multiple flags', () => {
|
||||
const result = parseImageArgs(
|
||||
'abstract wallpaper --ratio 21:9 --size 4K --count 2 --return',
|
||||
);
|
||||
expect(result.prompt).toBe('abstract wallpaper');
|
||||
expect(result.flags['ratio']).toBe('21:9');
|
||||
expect(result.flags['size']).toBe('4K');
|
||||
expect(result.flags['count']).toBe('2');
|
||||
expect(result.flags['return']).toBe(true);
|
||||
});
|
||||
|
||||
it('should return empty prompt when input starts with flags', () => {
|
||||
const result = parseImageArgs('--ratio 16:9');
|
||||
expect(result.prompt).toBe('');
|
||||
expect(result.flags['ratio']).toBe('16:9');
|
||||
});
|
||||
|
||||
it('should handle empty input', () => {
|
||||
const result = parseImageArgs('');
|
||||
expect(result.prompt).toBe('');
|
||||
expect(result.flags).toEqual({});
|
||||
});
|
||||
});
|
||||
|
||||
describe('imageCommand', () => {
|
||||
const mockContext = {} as Parameters<
|
||||
NonNullable<typeof imageCommand.action>
|
||||
>[0];
|
||||
|
||||
it('should return error for empty args', () => {
|
||||
const result = imageCommand.action!(mockContext, '');
|
||||
expect(result).toEqual(
|
||||
expect.objectContaining({
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should return error when prompt is empty (only flags)', () => {
|
||||
const result = imageCommand.action!(mockContext, '--ratio 16:9');
|
||||
expect(result).toEqual(
|
||||
expect.objectContaining({
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: expect.stringContaining('No prompt provided'),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should return tool action for valid prompt', () => {
|
||||
const result = imageCommand.action!(mockContext, 'a sunset over the ocean');
|
||||
expect(result).toEqual({
|
||||
type: 'tool',
|
||||
toolName: 'generate_image',
|
||||
toolArgs: { prompt: 'a sunset over the ocean' },
|
||||
});
|
||||
});
|
||||
|
||||
it('should map all flags to tool args correctly', () => {
|
||||
const result = imageCommand.action!(
|
||||
mockContext,
|
||||
'a cat --ratio 16:9 --size 2K --count 3 --model gemini-3-pro-image-preview --edit ./img.png --output ./out --return',
|
||||
);
|
||||
expect(result).toEqual({
|
||||
type: 'tool',
|
||||
toolName: 'generate_image',
|
||||
toolArgs: {
|
||||
prompt: 'a cat',
|
||||
aspect_ratio: '16:9',
|
||||
size: '2K',
|
||||
count: 3,
|
||||
model: 'gemini-3-pro-image-preview',
|
||||
input_image: './img.png',
|
||||
output_path: './out',
|
||||
return_to_context: true,
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should have correct metadata', () => {
|
||||
expect(imageCommand.name).toBe('image');
|
||||
expect(imageCommand.altNames).toContain('img');
|
||||
expect(imageCommand.kind).toBe('built-in');
|
||||
expect(imageCommand.autoExecute).toBe(false);
|
||||
});
|
||||
|
||||
it('should provide flag completions', () => {
|
||||
const completions = imageCommand.completion!(mockContext, '--ra');
|
||||
expect(completions).toContain('--ratio');
|
||||
});
|
||||
|
||||
it('should return empty completions for non-flag input', () => {
|
||||
const completions = imageCommand.completion!(mockContext, 'some');
|
||||
expect(completions).toEqual([]);
|
||||
});
|
||||
});
|
||||
110
packages/cli/src/ui/commands/imageCommand.ts
Normal file
110
packages/cli/src/ui/commands/imageCommand.ts
Normal file
@@ -0,0 +1,110 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { GENERATE_IMAGE_TOOL_NAME } from '@google/gemini-cli-core';
|
||||
import type { SlashCommand, SlashCommandActionReturn } from './types.js';
|
||||
import { CommandKind } from './types.js';
|
||||
|
||||
/**
 * Result of splitting the raw `/image` argument string.
 */
interface ParsedImageArgs {
  /** Free-form prompt text: every token that precedes the first `--flag`. */
  prompt: string;
  /**
   * Parsed flags keyed by name without the leading `--`. Values are strings,
   * except for the value-less `--return` switch, which is stored as `true`.
   */
  flags: Record<string, string | boolean>;
}

/**
 * Splits the argument string of the `/image` command into a prompt and flags.
 *
 * Grammar: `<prompt words...> [--name value | --name=value | --return]...`
 *
 * - The prompt is everything before the first token starting with `--`.
 * - `--name=value` keeps everything after the FIRST `=` as the value, so
 *   values that themselves contain `=` are preserved.
 * - `--return` is a boolean switch and never consumes the following token.
 * - Any other flag consumes the next token as its value unless that token is
 *   itself a flag (or there is none), in which case the flag is dropped.
 * - Tokens after the first flag that are neither flags nor flag values are
 *   ignored.
 *
 * @param input Raw text typed after the command name.
 * @returns The prompt (possibly empty) and the parsed flag map.
 */
export function parseImageArgs(input: string): ParsedImageArgs {
  const flags: Record<string, string | boolean> = {};

  // Trim first: splitting untrimmed input on whitespace yields empty leading
  // or trailing tokens, which would leak stray spaces into the prompt.
  const trimmed = input.trim();
  const parts = trimmed === '' ? [] : trimmed.split(/\s+/);

  // Collect prompt text (everything before the first --flag).
  let i = 0;
  while (i < parts.length && !parts[i].startsWith('--')) {
    i++;
  }
  const prompt = parts.slice(0, i).join(' ');

  // Parse flags.
  while (i < parts.length) {
    const part = parts[i];
    if (part.startsWith('--')) {
      const eqIndex = part.indexOf('=');
      const flagName = eqIndex === -1 ? part.slice(2) : part.slice(2, eqIndex);

      if (eqIndex !== -1) {
        // Keep the whole remainder so values such as `a=b` survive intact.
        flags[flagName] = part.slice(eqIndex + 1);
      } else if (flagName === 'return') {
        flags[flagName] = true;
      } else if (i + 1 < parts.length && !parts[i + 1].startsWith('--')) {
        flags[flagName] = parts[i + 1];
        i++; // Skip the token that was consumed as this flag's value.
      }
    }
    i++;
  }

  return { prompt, flags };
}
|
||||
|
||||
/**
 * `/image` (alias `/img`) slash command.
 *
 * Translates `<prompt> [--flags]` into a `generate_image` tool invocation.
 * Flag-to-tool-argument mapping:
 *   --ratio  -> aspect_ratio
 *   --size   -> size
 *   --count  -> count (validated as a positive integer)
 *   --model  -> model
 *   --edit   -> input_image
 *   --output -> output_path
 *   --return -> return_to_context
 *
 * Returns an error message action instead of a tool action when the input is
 * empty, has no prompt, or carries an invalid `--count`.
 */
export const imageCommand: SlashCommand = {
  name: 'image',
  altNames: ['img'],
  description: 'Generate or edit images using Nano Banana',
  kind: CommandKind.BUILT_IN,
  autoExecute: false,

  action: (_context, args): SlashCommandActionReturn | void => {
    if (!args.trim()) {
      return {
        type: 'message',
        messageType: 'error',
        content:
          'Usage: /image <prompt> [--ratio 16:9] [--size 2K] [--count 3] [--edit path/to/image.png]',
      };
    }

    const { prompt, flags } = parseImageArgs(args);

    if (!prompt) {
      return {
        type: 'message',
        messageType: 'error',
        content:
          'Error: No prompt provided. The prompt must come before any --flags.',
      };
    }

    const toolArgs: Record<string, unknown> = { prompt };
    if (flags['ratio']) toolArgs['aspect_ratio'] = flags['ratio'];
    if (flags['size']) toolArgs['size'] = flags['size'];

    if (flags['count']) {
      // Validate up front: a bare parseInt would forward NaN for `--count abc`
      // and silently accept partial numbers such as `3.5` or `2x`.
      const rawCount = String(flags['count']);
      const count = /^\d+$/.test(rawCount) ? parseInt(rawCount, 10) : NaN;
      if (!Number.isInteger(count) || count < 1) {
        return {
          type: 'message',
          messageType: 'error',
          content: `Error: Invalid --count value "${rawCount}". Expected a positive integer.`,
        };
      }
      toolArgs['count'] = count;
    }

    if (flags['model']) toolArgs['model'] = flags['model'];
    if (flags['edit']) toolArgs['input_image'] = flags['edit'];
    if (flags['output']) toolArgs['output_path'] = flags['output'];

    // `--return=false` parses to the (truthy) string 'false'; honour the
    // explicit opt-out rather than treating it as enabled.
    const returnFlag = flags['return'];
    if (returnFlag && returnFlag !== 'false') {
      toolArgs['return_to_context'] = true;
    }

    return {
      type: 'tool',
      toolName: GENERATE_IMAGE_TOOL_NAME,
      toolArgs,
    };
  },

  // Offers flag-name completion only once the user has started typing `--`.
  completion: (_context, partialArg) => {
    const flagOptions = [
      '--ratio',
      '--size',
      '--count',
      '--model',
      '--edit',
      '--output',
      '--return',
    ];
    if (partialArg.startsWith('--')) {
      return flagOptions.filter((f) => f.startsWith(partialArg));
    }
    return [];
  },
};
|
||||
@@ -34,6 +34,7 @@ import { WriteFileTool } from '../tools/write-file.js';
|
||||
import { WebFetchTool } from '../tools/web-fetch.js';
|
||||
import { MemoryTool, setGeminiMdFilename } from '../tools/memoryTool.js';
|
||||
import { WebSearchTool } from '../tools/web-search.js';
|
||||
import { GenerateImageTool } from '../tools/generate-image.js';
|
||||
import { AskUserTool } from '../tools/ask-user.js';
|
||||
import { ExitPlanModeTool } from '../tools/exit-plan-mode.js';
|
||||
import { EnterPlanModeTool } from '../tools/enter-plan-mode.js';
|
||||
@@ -373,6 +374,7 @@ import { getErrorMessage } from '../utils/errors.js';
|
||||
import {
|
||||
ENTER_PLAN_MODE_TOOL_NAME,
|
||||
EXIT_PLAN_MODE_TOOL_NAME,
|
||||
GENERATE_IMAGE_TOOL_NAME,
|
||||
} from '../tools/tool-names.js';
|
||||
|
||||
export type { FileFilteringOptions };
|
||||
@@ -561,6 +563,7 @@ export interface ConfigParameters {
|
||||
disabledSkills?: string[];
|
||||
adminSkillsEnabled?: boolean;
|
||||
experimentalJitContext?: boolean;
|
||||
imageGeneration?: boolean;
|
||||
toolOutputMasking?: Partial<ToolOutputMaskingConfig>;
|
||||
disableLLMCorrection?: boolean;
|
||||
plan?: boolean;
|
||||
@@ -772,6 +775,7 @@ export class Config implements McpContext {
|
||||
private readonly adminSkillsEnabled: boolean;
|
||||
|
||||
private readonly experimentalJitContext: boolean;
|
||||
private readonly imageGeneration: boolean;
|
||||
private readonly disableLLMCorrection: boolean;
|
||||
private readonly planEnabled: boolean;
|
||||
private readonly planModeRoutingEnabled: boolean;
|
||||
@@ -871,6 +875,7 @@ export class Config implements McpContext {
|
||||
this.adminSkillsEnabled = params.adminSkillsEnabled ?? true;
|
||||
this.modelAvailabilityService = new ModelAvailabilityService();
|
||||
this.experimentalJitContext = params.experimentalJitContext ?? false;
|
||||
this.imageGeneration = params.imageGeneration ?? false;
|
||||
this.modelSteering = params.modelSteering ?? false;
|
||||
this.userHintService = new UserHintService(() =>
|
||||
this.isModelSteeringEnabled(),
|
||||
@@ -1260,6 +1265,36 @@ export class Config implements McpContext {
|
||||
},
|
||||
);
|
||||
this.setRemoteAdminSettings(adminControls);
|
||||
|
||||
// Re-evaluate image generation tool registration after auth change
|
||||
this.updateImageGenToolRegistration();
|
||||
}
|
||||
|
||||
/**
 * Synchronises the registry entry for the image generation tool with the
 * current configuration. Invoked after authentication completes or is
 * refreshed.
 *
 * The tool stays registered only when the `imageGeneration` setting is on AND
 * the active auth type is Gemini API key or Vertex AI; in every other case it
 * is unregistered. An existing registration is never duplicated.
 */
private updateImageGenToolRegistration(): void {
  // Evaluated lazily so the auth config is not read when the feature is off.
  const authAllowsImageGen = (): boolean => {
    switch (this.getContentGeneratorConfig()?.authType) {
      case AuthType.USE_GEMINI:
      case AuthType.USE_VERTEX_AI:
        return true;
      default:
        return false;
    }
  };

  if (!this.imageGeneration || !authAllowsImageGen()) {
    this.toolRegistry?.unregisterTool(GENERATE_IMAGE_TOOL_NAME);
    return;
  }

  const alreadyRegistered = Boolean(
    this.toolRegistry?.getTool(GENERATE_IMAGE_TOOL_NAME),
  );
  if (!alreadyRegistered) {
    this.toolRegistry?.registerTool(
      new GenerateImageTool(this, this.messageBus),
    );
  }
}
|
||||
|
||||
async getExperimentsAsync(): Promise<Experiments | undefined> {
|
||||
@@ -2234,6 +2269,10 @@ export class Config implements McpContext {
|
||||
return this.experimentalZedIntegration;
|
||||
}
|
||||
|
||||
/**
 * Reports whether the experimental image generation setting was enabled
 * when this Config was constructed (defaults to `false`).
 */
isImageGenerationEnabled(): boolean {
  return this.imageGeneration;
}
|
||||
|
||||
getListExtensions(): boolean {
|
||||
return this.listExtensions;
|
||||
}
|
||||
@@ -2841,6 +2880,19 @@ export class Config implements McpContext {
|
||||
);
|
||||
}
|
||||
|
||||
// Register image generation tool if enabled and auth supports it
|
||||
if (this.imageGeneration) {
|
||||
const authType = this.getContentGeneratorConfig()?.authType;
|
||||
if (
|
||||
authType === AuthType.USE_GEMINI ||
|
||||
authType === AuthType.USE_VERTEX_AI
|
||||
) {
|
||||
maybeRegister(GenerateImageTool, () =>
|
||||
registry.registerTool(new GenerateImageTool(this, this.messageBus)),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Register Subagents as Tools
|
||||
this.registerSubAgentTools(registry);
|
||||
|
||||
|
||||
@@ -169,6 +169,7 @@ export * from './tools/mcp-tool.js';
|
||||
export * from './tools/write-todos.js';
|
||||
export * from './tools/activate-skill.js';
|
||||
export * from './tools/ask-user.js';
|
||||
export * from './tools/generate-image.js';
|
||||
|
||||
// MCP OAuth
|
||||
export { MCPOAuthProvider } from './mcp/oauth-provider.js';
|
||||
|
||||
@@ -32,3 +32,4 @@ export const ACTIVATE_SKILL_TOOL_NAME = 'activate_skill';
|
||||
export const ASK_USER_TOOL_NAME = 'ask_user';
|
||||
export const EXIT_PLAN_MODE_TOOL_NAME = 'exit_plan_mode';
|
||||
export const ENTER_PLAN_MODE_TOOL_NAME = 'enter_plan_mode';
|
||||
export const GENERATE_IMAGE_TOOL_NAME = 'generate_image';
|
||||
|
||||
@@ -38,6 +38,7 @@ export {
|
||||
ASK_USER_TOOL_NAME,
|
||||
EXIT_PLAN_MODE_TOOL_NAME,
|
||||
ENTER_PLAN_MODE_TOOL_NAME,
|
||||
GENERATE_IMAGE_TOOL_NAME,
|
||||
} from './base-declarations.js';
|
||||
|
||||
// Re-export sets for compatibility
|
||||
|
||||
420
packages/core/src/tools/generate-image.test.ts
Normal file
420
packages/core/src/tools/generate-image.test.ts
Normal file
@@ -0,0 +1,420 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import * as fs from 'node:fs';
|
||||
|
||||
import type { GenerateImageParams } from './generate-image.js';
|
||||
import { GenerateImageTool ,
|
||||
promptToFilename,
|
||||
getUniqueFilename,
|
||||
validateOutputPath,
|
||||
} from './generate-image.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { createMockMessageBus } from '../test-utils/mock-message-bus.js';
|
||||
|
||||
// Mock the @google/genai module
|
||||
const mockGenerateContent = vi.fn();
|
||||
vi.mock('@google/genai', () => ({
|
||||
GoogleGenAI: vi.fn().mockImplementation(() => ({
|
||||
models: {
|
||||
generateContent: mockGenerateContent,
|
||||
},
|
||||
})),
|
||||
}));
|
||||
|
||||
// Mock node:fs - must handle both default and named exports
|
||||
vi.mock('node:fs', async () => {
|
||||
const actual = await vi.importActual<typeof import('node:fs')>('node:fs');
|
||||
return {
|
||||
...actual,
|
||||
default: {
|
||||
...actual,
|
||||
existsSync: vi.fn(),
|
||||
statSync: vi.fn(),
|
||||
mkdirSync: vi.fn(),
|
||||
promises: {
|
||||
...actual.promises,
|
||||
readFile: vi.fn(),
|
||||
writeFile: vi.fn(),
|
||||
},
|
||||
},
|
||||
existsSync: vi.fn(),
|
||||
statSync: vi.fn(),
|
||||
mkdirSync: vi.fn(),
|
||||
promises: {
|
||||
...actual.promises,
|
||||
readFile: vi.fn(),
|
||||
writeFile: vi.fn(),
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
const FAKE_CWD = '/fake/project';
|
||||
// Valid base64 image data (no padding chars in the middle, >1000 chars)
|
||||
const FAKE_BASE64_IMAGE = 'A'.repeat(2000);
|
||||
|
||||
function createMockConfig(): Config {
|
||||
return {
|
||||
getTargetDir: () => FAKE_CWD,
|
||||
getContentGeneratorConfig: () => ({
|
||||
apiKey: 'test-api-key',
|
||||
authType: 'gemini-api-key',
|
||||
}),
|
||||
} as unknown as Config;
|
||||
}
|
||||
|
||||
function mockSuccessResponse() {
|
||||
return {
|
||||
candidates: [
|
||||
{
|
||||
content: {
|
||||
parts: [
|
||||
{
|
||||
inlineData: {
|
||||
data: FAKE_BASE64_IMAGE,
|
||||
mimeType: 'image/png',
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
}
|
||||
|
||||
describe('GenerateImageTool', () => {
|
||||
let tool: GenerateImageTool;
|
||||
let mockConfig: Config;
|
||||
|
||||
beforeEach(() => {
|
||||
mockConfig = createMockConfig();
|
||||
tool = new GenerateImageTool(mockConfig, createMockMessageBus());
|
||||
mockGenerateContent.mockReset();
|
||||
|
||||
// Default: output dir doesn't exist (will be created), files don't exist
|
||||
vi.mocked(fs.existsSync).mockReturnValue(false);
|
||||
vi.mocked(fs.mkdirSync).mockReturnValue(undefined);
|
||||
vi.mocked(fs.promises.writeFile).mockResolvedValue(undefined);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('build / parameter validation', () => {
|
||||
it('should return an invocation for a valid prompt', () => {
|
||||
const params: GenerateImageParams = { prompt: 'a sunset over the ocean' };
|
||||
const invocation = tool.build(params);
|
||||
expect(invocation).toBeDefined();
|
||||
expect(invocation.params).toEqual(params);
|
||||
});
|
||||
|
||||
it('should throw for an empty prompt', () => {
|
||||
expect(() => tool.build({ prompt: '' })).toThrow(
|
||||
"The 'prompt' parameter cannot be empty.",
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw for a whitespace-only prompt', () => {
|
||||
expect(() => tool.build({ prompt: ' ' })).toThrow(
|
||||
"The 'prompt' parameter cannot be empty.",
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw if count is below 1', () => {
|
||||
expect(() => tool.build({ prompt: 'test', count: 0 })).toThrow();
|
||||
});
|
||||
|
||||
it('should throw if count is above 4', () => {
|
||||
expect(() => tool.build({ prompt: 'test', count: 5 })).toThrow();
|
||||
});
|
||||
|
||||
it('should accept valid count values', () => {
|
||||
expect(tool.build({ prompt: 'test', count: 1 })).toBeDefined();
|
||||
expect(tool.build({ prompt: 'test', count: 4 })).toBeDefined();
|
||||
});
|
||||
|
||||
it('should reject output_path outside cwd', () => {
|
||||
expect(() =>
|
||||
tool.build({ prompt: 'test', output_path: '/other/dir' }),
|
||||
).toThrow('Output path must be within the current working directory.');
|
||||
});
|
||||
|
||||
it('should accept output_path within cwd', () => {
|
||||
const invocation = tool.build({
|
||||
prompt: 'test',
|
||||
output_path: 'my-images',
|
||||
});
|
||||
expect(invocation).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('getDescription', () => {
|
||||
it('should show generation mode for text-to-image', () => {
|
||||
const invocation = tool.build({ prompt: 'a sunset' });
|
||||
expect(invocation.getDescription()).toContain('Generate image');
|
||||
expect(invocation.getDescription()).toContain('a sunset');
|
||||
});
|
||||
|
||||
it('should show edit mode with source path when input_image is provided', () => {
|
||||
const invocation = tool.build({
|
||||
prompt: 'make it blue',
|
||||
input_image: '/fake/img.png',
|
||||
});
|
||||
const desc = invocation.getDescription();
|
||||
expect(desc).toContain('Edit image');
|
||||
expect(desc).toContain('Source: /fake/img.png');
|
||||
});
|
||||
|
||||
it('should show optional params when provided', () => {
|
||||
const invocation = tool.build({
|
||||
prompt: 'test',
|
||||
count: 3,
|
||||
aspect_ratio: '16:9',
|
||||
size: '2K',
|
||||
});
|
||||
const desc = invocation.getDescription();
|
||||
expect(desc).toContain('Count: 3');
|
||||
expect(desc).toContain('Ratio: 16:9');
|
||||
expect(desc).toContain('Size: 2K');
|
||||
});
|
||||
|
||||
it('should show default 1:1 aspect ratio for text-to-image', () => {
|
||||
const invocation = tool.build({ prompt: 'test' });
|
||||
const desc = invocation.getDescription();
|
||||
expect(desc).toContain('Ratio: 1:1');
|
||||
});
|
||||
|
||||
it('should omit aspect ratio for edit mode when not provided', () => {
|
||||
const invocation = tool.build({
|
||||
prompt: 'make it blue',
|
||||
input_image: '/fake/img.png',
|
||||
});
|
||||
const desc = invocation.getDescription();
|
||||
expect(desc).not.toContain('Ratio:');
|
||||
});
|
||||
|
||||
it('should omit count and size when not provided', () => {
|
||||
const invocation = tool.build({ prompt: 'test' });
|
||||
const desc = invocation.getDescription();
|
||||
expect(desc).not.toContain('Count:');
|
||||
expect(desc).not.toContain('Size:');
|
||||
});
|
||||
|
||||
it('should truncate long prompts in description', () => {
|
||||
const longPrompt = 'a'.repeat(100);
|
||||
const invocation = tool.build({ prompt: longPrompt });
|
||||
const desc = invocation.getDescription();
|
||||
expect(desc).toContain('...');
|
||||
});
|
||||
});
|
||||
|
||||
describe('execute', () => {
|
||||
it('should return cancel result when signal is already aborted', async () => {
|
||||
const controller = new AbortController();
|
||||
controller.abort();
|
||||
const invocation = tool.build({ prompt: 'test' });
|
||||
const result = await invocation.execute(controller.signal);
|
||||
expect(result.llmContent).toBe('Image generation cancelled.');
|
||||
});
|
||||
|
||||
it('should generate an image successfully', async () => {
|
||||
mockGenerateContent.mockResolvedValue(mockSuccessResponse());
|
||||
|
||||
const invocation = tool.build({ prompt: 'a cute cat' });
|
||||
const signal = new AbortController().signal;
|
||||
const result = await invocation.execute(signal);
|
||||
|
||||
expect(result.llmContent).toContain('Successfully generated 1 image(s)');
|
||||
expect(result.returnDisplay).toContain('Generated 1 image(s)');
|
||||
expect(vi.mocked(fs.promises.writeFile)).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return error when API returns no image data', async () => {
|
||||
mockGenerateContent.mockResolvedValue({
|
||||
candidates: [
|
||||
{
|
||||
content: {
|
||||
parts: [{ text: 'just text, no image' }],
|
||||
},
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
const invocation = tool.build({ prompt: 'test' });
|
||||
const signal = new AbortController().signal;
|
||||
const result = await invocation.execute(signal);
|
||||
|
||||
expect(result.llmContent).toContain('No image data in API response');
|
||||
});
|
||||
|
||||
it('should handle API auth errors', async () => {
|
||||
mockGenerateContent.mockRejectedValue(new Error('api key not valid'));
|
||||
|
||||
const invocation = tool.build({ prompt: 'test' });
|
||||
const signal = new AbortController().signal;
|
||||
const result = await invocation.execute(signal);
|
||||
|
||||
expect(result.llmContent).toContain('Authentication error');
|
||||
});
|
||||
|
||||
it('should handle safety filter errors', async () => {
|
||||
mockGenerateContent.mockRejectedValue(
|
||||
new Error('Content was blocked by safety filters'),
|
||||
);
|
||||
|
||||
const invocation = tool.build({ prompt: 'test' });
|
||||
const signal = new AbortController().signal;
|
||||
const result = await invocation.execute(signal);
|
||||
|
||||
expect(result.llmContent).toContain('safety filters');
|
||||
});
|
||||
|
||||
it('should validate input image existence during execute', async () => {
|
||||
// existsSync returns false for everything including the input image
|
||||
vi.mocked(fs.existsSync).mockReturnValue(false);
|
||||
|
||||
const invocation = tool.build({
|
||||
prompt: 'edit this',
|
||||
input_image: '/fake/missing.png',
|
||||
});
|
||||
const signal = new AbortController().signal;
|
||||
const result = await invocation.execute(signal);
|
||||
|
||||
expect(result.llmContent).toContain('Input image not found');
|
||||
});
|
||||
|
||||
it('should include inlineData when return_to_context is true', async () => {
|
||||
mockGenerateContent.mockResolvedValue(mockSuccessResponse());
|
||||
|
||||
const invocation = tool.build({
|
||||
prompt: 'a cat',
|
||||
return_to_context: true,
|
||||
});
|
||||
const signal = new AbortController().signal;
|
||||
const result = await invocation.execute(signal);
|
||||
|
||||
expect(Array.isArray(result.llmContent)).toBe(true);
|
||||
const parts = result.llmContent as Array<Record<string, unknown>>;
|
||||
expect(parts).toHaveLength(2);
|
||||
expect(parts[1]).toHaveProperty('inlineData');
|
||||
});
|
||||
|
||||
it('should stream progress for batch generation', async () => {
|
||||
mockGenerateContent.mockResolvedValue(mockSuccessResponse());
|
||||
|
||||
const updateOutput = vi.fn();
|
||||
const invocation = tool.build({ prompt: 'test', count: 2 });
|
||||
const signal = new AbortController().signal;
|
||||
await invocation.execute(signal, updateOutput);
|
||||
|
||||
expect(updateOutput).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Generating image 1 of 2'),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle partial success in batch generation', async () => {
|
||||
mockGenerateContent
|
||||
.mockResolvedValueOnce(mockSuccessResponse())
|
||||
.mockRejectedValueOnce(new Error('quota exceeded'));
|
||||
|
||||
const invocation = tool.build({ prompt: 'test', count: 2 });
|
||||
const signal = new AbortController().signal;
|
||||
const result = await invocation.execute(signal);
|
||||
|
||||
expect(result.llmContent).toContain('Successfully generated 1 image(s)');
|
||||
expect(result.llmContent).toContain('Warnings');
|
||||
});
|
||||
|
||||
it('should abort batch immediately on auth error', async () => {
|
||||
mockGenerateContent.mockRejectedValue(new Error('api key not valid'));
|
||||
|
||||
const invocation = tool.build({ prompt: 'test', count: 3 });
|
||||
const signal = new AbortController().signal;
|
||||
const result = await invocation.execute(signal);
|
||||
|
||||
expect(mockGenerateContent).toHaveBeenCalledTimes(1);
|
||||
expect(result.llmContent).toContain('Authentication error');
|
||||
});
|
||||
|
||||
it('should handle text fallback response parsing', async () => {
|
||||
mockGenerateContent.mockResolvedValue({
|
||||
candidates: [
|
||||
{
|
||||
content: {
|
||||
parts: [{ text: FAKE_BASE64_IMAGE }],
|
||||
},
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
const invocation = tool.build({ prompt: 'test' });
|
||||
const signal = new AbortController().signal;
|
||||
const result = await invocation.execute(signal);
|
||||
|
||||
expect(result.llmContent).toContain('Successfully generated 1 image(s)');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('promptToFilename', () => {
|
||||
it('should convert prompt to lowercase with underscores', () => {
|
||||
expect(promptToFilename('A Sunset Over The Ocean')).toBe(
|
||||
'a_sunset_over_the_ocean',
|
||||
);
|
||||
});
|
||||
|
||||
it('should remove special characters', () => {
|
||||
expect(promptToFilename('hello! @world #2024')).toBe('hello_world_2024');
|
||||
});
|
||||
|
||||
it('should truncate to 32 characters', () => {
|
||||
const longPrompt = 'a'.repeat(50);
|
||||
expect(promptToFilename(longPrompt).length).toBe(32);
|
||||
});
|
||||
|
||||
it('should return default name for empty result', () => {
|
||||
expect(promptToFilename('!!!')).toBe('generated_image');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getUniqueFilename', () => {
|
||||
it('should return simple filename when no file exists', () => {
|
||||
vi.mocked(fs.existsSync).mockReturnValue(false);
|
||||
expect(getUniqueFilename('/dir', 'test', '.png')).toBe('test.png');
|
||||
});
|
||||
|
||||
it('should add variation index for batch', () => {
|
||||
vi.mocked(fs.existsSync).mockReturnValue(false);
|
||||
expect(getUniqueFilename('/dir', 'test', '.png', 0)).toBe('test.png');
|
||||
expect(getUniqueFilename('/dir', 'test', '.png', 1)).toBe('test_v2.png');
|
||||
expect(getUniqueFilename('/dir', 'test', '.png', 2)).toBe('test_v3.png');
|
||||
});
|
||||
|
||||
it('should auto-increment when file exists', () => {
|
||||
vi.mocked(fs.existsSync)
|
||||
.mockReturnValueOnce(true) // test.png exists
|
||||
.mockReturnValueOnce(false); // test_1.png doesn't exist
|
||||
expect(getUniqueFilename('/dir', 'test', '.png')).toBe('test_1.png');
|
||||
});
|
||||
});
|
||||
|
||||
describe('validateOutputPath', () => {
|
||||
it('should accept paths within cwd', () => {
|
||||
expect(validateOutputPath('images', '/project')).toBeNull();
|
||||
expect(validateOutputPath('./output', '/project')).toBeNull();
|
||||
});
|
||||
|
||||
it('should reject paths outside cwd', () => {
|
||||
expect(validateOutputPath('/other/dir', '/project')).toContain(
|
||||
'within the current working directory',
|
||||
);
|
||||
expect(validateOutputPath('../outside', '/project')).toContain(
|
||||
'within the current working directory',
|
||||
);
|
||||
});
|
||||
});
|
||||
760
packages/core/src/tools/generate-image.ts
Normal file
760
packages/core/src/tools/generate-image.ts
Normal file
@@ -0,0 +1,760 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import * as fs from 'node:fs';
|
||||
import path from 'node:path';
|
||||
import { GoogleGenAI } from '@google/genai';
|
||||
import type { ToolInvocation, ToolLocation, ToolResult } from './tools.js';
|
||||
import { BaseDeclarativeTool, BaseToolInvocation, Kind } from './tools.js';
|
||||
import type { AnsiOutput } from '../utils/terminalSerializer.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import {
|
||||
GENERATE_IMAGE_TOOL_NAME,
|
||||
GENERATE_IMAGE_DISPLAY_NAME,
|
||||
} from './tool-names.js';
|
||||
import { getSpecificMimeType } from '../utils/fileUtils.js';
|
||||
|
||||
// ─── Constants ───────────────────────────────────────────────────────────────
|
||||
|
||||
export const DEFAULT_IMAGE_MODEL = 'gemini-3.1-flash-image-preview';
|
||||
export const DEFAULT_OUTPUT_DIR = 'generated-images';
|
||||
const MAX_FILENAME_LENGTH = 32;
|
||||
const MAX_INPUT_IMAGE_SIZE = 20 * 1024 * 1024; // 20MB
|
||||
|
||||
// ─── Types ───────────────────────────────────────────────────────────────────
|
||||
|
||||
/** Arguments accepted by the image generation tool; only `prompt` is required. */
export interface GenerateImageParams {
  /** Description of the image to create, or of the edits to apply to `input_image`. */
  prompt: string;
  /** Path of an existing image to edit; relative paths are resolved against the working directory. */
  input_image?: string;
  /** Directory that receives the generated files, resolved against the working directory. */
  output_path?: string;
  /** Filename stem for the output; a trailing extension is stripped before use. */
  filename?: string;
  /** Number of variations requested; the generator clamps it to the range 1-4. */
  count?: number;
  /** When true, image data is also returned inline so the model can inspect it. */
  return_to_context?: boolean;
  /** Requested aspect ratio, e.g. '16:9'. */
  aspect_ratio?: string;
  /** Requested resolution tier, e.g. '1K'. */
  size?: string;
  /** Image model override; takes precedence over GEMINI_IMAGE_MODEL and the default. */
  model?: string;
}
|
||||
|
||||
/** Outcome of a generation run, which may be a partial success for batches. */
export interface ImageGenerationResult {
  /** True when at least one image was written to disk. */
  success: boolean;
  /** Absolute paths of every image that was saved, in generation order. */
  filePaths: string[];
  /** MIME type that accompanies `base64Data`. */
  mimeType: string;
  /** Base64 payload of the image returned to the model's context, if any. */
  base64Data?: string;
  /** Per-variation failure or cancellation messages; absent when there were none. */
  errors?: string[];
}
|
||||
|
||||
/** Everything `generateImages` needs for one run. */
export interface GenerateImageOptions {
  /** Source of the credentials used to build the image client. */
  config: Config;
  /** Tool arguments describing what to generate. */
  params: GenerateImageParams;
  /** Working directory against which relative input and output paths are resolved. */
  cwd: string;
  /** Checked before and after each API call to stop a batch early. */
  signal: AbortSignal;
  /** Optional sink for progress messages emitted during batch generation. */
  updateOutput?: (output: string) => void;
}
|
||||
|
||||
/** Arguments for a single image API call. */
interface GenerateImageApiParams {
  /** Client the request is sent through. */
  ai: GoogleGenAI;
  /** Name of the image model to call. */
  modelName: string;
  /** Prompt text sent as the first content part. */
  prompt: string;
  /** Base64 contents of the image to edit; sent only together with its MIME type. */
  inputImageBase64?: string;
  /** MIME type of `inputImageBase64`. */
  inputImageMimeType?: string;
  /** Aspect ratio forwarded in the image config when set. */
  aspectRatio?: string;
  /** Resolution tier forwarded in the image config when set. */
  size?: string;
  /** Whether the call edits an existing image rather than generating from text. */
  isEditing: boolean;
}
|
||||
|
||||
// ─── Image Generation Engine ─────────────────────────────────────────────────
|
||||
|
||||
/**
 * Builds the GoogleGenAI client used for image requests, based on the
 * content-generator settings held by Config.
 *
 * A dedicated instance is created rather than reusing the chat model's client,
 * since image generation requires different models.
 *
 * @param config Provides the content-generator configuration.
 * @returns A Vertex AI client (project and location read from
 *   GOOGLE_CLOUD_PROJECT / GOOGLE_CLOUD_LOCATION) when the configuration is
 *   flagged for Vertex AI, otherwise a client built from the configured API key.
 */
export function createImageGenClient(config: Config): GoogleGenAI {
  const generatorConfig = config.getContentGeneratorConfig();
  const useVertexAi = Boolean(generatorConfig?.vertexai);

  if (!useVertexAi) {
    return new GoogleGenAI({ apiKey: generatorConfig?.apiKey });
  }

  const project = process.env['GOOGLE_CLOUD_PROJECT'];
  const location = process.env['GOOGLE_CLOUD_LOCATION'];
  return new GoogleGenAI({ vertexai: true, project, location });
}
|
||||
|
||||
/**
 * Heuristic for text that is plausibly a base64-encoded image: at least 1000
 * characters long and made up only of base64 alphabet characters followed by
 * at most two '=' padding characters.
 */
function isValidBase64ImageData(data: string): boolean {
  const isLongEnough = Boolean(data) && data.length >= 1000;
  return isLongEnough && /^[A-Za-z0-9+/]*={0,2}$/.test(data);
}
|
||||
|
||||
/**
 * Derives a filename stem from a prompt: lowercases it, drops everything but
 * a-z, 0-9 and whitespace, turns each whitespace run into an underscore and
 * truncates to MAX_FILENAME_LENGTH characters.
 *
 * @returns The derived stem, or 'generated_image' when nothing is left.
 */
export function promptToFilename(prompt: string): string {
  const lowered = prompt.toLowerCase();
  const alphanumericOnly = lowered.replace(/[^a-z0-9\s]/g, '');
  const underscored = alphanumericOnly.replace(/\s+/g, '_');
  const truncated = underscored.substring(0, MAX_FILENAME_LENGTH);

  return truncated || 'generated_image';
}
|
||||
|
||||
/**
 * Maps an image MIME type to the file extension used when saving.
 * Anything other than JPEG, WebP or GIF is saved as '.png'.
 */
function mimeToExtension(mimeType: string): string {
  if (mimeType === 'image/jpeg') return '.jpg';
  if (mimeType === 'image/webp') return '.webp';
  if (mimeType === 'image/gif') return '.gif';
  return '.png';
}
|
||||
|
||||
/**
 * Picks a filename that does not yet exist in `outputDir`.
 *
 * Variations after the first (variationIndex > 0) get a `_v<index + 1>` suffix
 * on the stem. If the resulting name is taken, `_1`, `_2`, ... is appended
 * until a free name is found.
 *
 * @param outputDir Directory that is probed for existing files.
 * @param baseName Filename stem without extension.
 * @param extension Extension including the leading dot.
 * @param variationIndex Zero-based position within a batch, if any.
 * @returns The chosen filename (not a full path).
 */
export function getUniqueFilename(
  outputDir: string,
  baseName: string,
  extension: string,
  variationIndex?: number,
): string {
  const variationNumber = (variationIndex ?? 0) + 1;
  const stem =
    variationNumber > 1 ? `${baseName}_v${variationNumber}` : baseName;
  const isTaken = (name: string): boolean =>
    fs.existsSync(path.join(outputDir, name));

  let candidate = `${stem}${extension}`;
  for (let attempt = 1; isTaken(candidate); attempt++) {
    candidate = `${stem}_${attempt}${extension}`;
  }
  return candidate;
}
|
||||
|
||||
/**
 * Checks that an input image can be used for editing.
 *
 * The path must exist, refer to a regular file, be no larger than
 * MAX_INPUT_IMAGE_SIZE and map to an `image/*` MIME type.
 *
 * @param inputImagePath Absolute path, or a path relative to `cwd`.
 * @param cwd Directory used to resolve relative paths.
 * @returns `null` when the image is acceptable, otherwise a user-facing
 *   error message.
 */
export function validateInputImage(
  inputImagePath: string,
  cwd: string,
): string | null {
  const resolved = path.isAbsolute(inputImagePath)
    ? inputImagePath
    : path.resolve(cwd, inputImagePath);
  const notFound = `Image generation failed: Input image not found at '${inputImagePath}'.`;

  if (!fs.existsSync(resolved)) {
    return notFound;
  }

  let stat: fs.Stats;
  try {
    stat = fs.statSync(resolved);
  } catch {
    // The file vanished (or became unreadable) between the two checks.
    return notFound;
  }

  // A directory passes existsSync and has a small size, but reading it later
  // would throw, so reject anything that is not a regular file up front.
  if (!stat.isFile()) {
    return `Image generation failed: Input path '${inputImagePath}' is not a file.`;
  }

  if (stat.size > MAX_INPUT_IMAGE_SIZE) {
    return 'Image generation failed: Input image exceeds 20MB size limit.';
  }

  const mime = getSpecificMimeType(resolved);
  if (!mime || !mime.startsWith('image/')) {
    return `Image generation failed: '${inputImagePath}' is not a supported image format. Supports PNG, JPEG, and WebP.`;
  }

  return null;
}
|
||||
|
||||
/**
 * Ensures an output directory stays inside the working directory.
 *
 * Containment is decided from the relative path between the two resolved
 * locations rather than by string-prefix comparison, so it also works when
 * `cwd` is a filesystem root (where `cwd + path.sep` would never be a prefix
 * of any child path).
 *
 * @param outputPath Directory supplied by the caller, absolute or relative to `cwd`.
 * @param cwd The working directory that must contain the output path.
 * @returns `null` when the path is `cwd` itself or below it, otherwise a
 *   user-facing error message.
 */
export function validateOutputPath(
  outputPath: string,
  cwd: string,
): string | null {
  const root = path.resolve(cwd);
  const relative = path.relative(root, path.resolve(root, outputPath));

  // '..' or '../…' climbs out of root; an absolute result means the target is
  // on a different root altogether (e.g. another Windows drive).
  const escapesRoot =
    relative === '..' ||
    relative.startsWith(`..${path.sep}`) ||
    path.isAbsolute(relative);

  return escapesRoot
    ? 'Image generation failed: Output path must be within the current working directory.'
    : null;
}
|
||||
|
||||
/**
 * Sends one generateContent request and extracts the first image from the
 * first candidate of the response.
 *
 * The request carries the prompt and, when both the data and its MIME type are
 * supplied, the input image as inline data. `size` and `aspectRatio` are
 * forwarded as `imageConfig` only when set; with neither set no config is sent.
 *
 * Each response part is checked for inline data first; a text part that looks
 * like base64 image data is accepted as a fallback and treated as PNG.
 *
 * @returns The image payload and MIME type, or `null` when the response
 *   contains no usable image.
 */
async function callImageApi(
  params: GenerateImageApiParams,
): Promise<{ base64Data: string; mimeType: string } | null> {
  const {
    ai,
    modelName,
    prompt,
    inputImageBase64,
    inputImageMimeType,
    aspectRatio,
    size,
  } = params;

  const requestParts: Array<Record<string, unknown>> = [{ text: prompt }];
  if (inputImageBase64 && inputImageMimeType) {
    requestParts.push({
      inlineData: { data: inputImageBase64, mimeType: inputImageMimeType },
    });
  }

  const imageConfig: Record<string, string> = {};
  if (size) {
    imageConfig['imageSize'] = size;
  }
  // The caller decides whether an aspect ratio applies (including while
  // editing); whatever it passes is forwarded as-is.
  if (aspectRatio) {
    imageConfig['aspectRatio'] = aspectRatio;
  }
  const hasImageConfig = Object.keys(imageConfig).length > 0;

  const response = await ai.models.generateContent({
    model: modelName,
    contents: [{ role: 'user', parts: requestParts }],
    config: hasImageConfig ? { imageConfig } : undefined,
  });

  const responseParts = response.candidates?.[0]?.content?.parts ?? [];
  for (const part of responseParts) {
    const inline = part.inlineData;
    if (inline?.data && inline?.mimeType) {
      return { base64Data: inline.data, mimeType: inline.mimeType };
    }
    if (part.text && isValidBase64ImageData(part.text)) {
      return { base64Data: part.text, mimeType: 'image/png' };
    }
  }

  return null;
}
|
||||
|
||||
/**
 * Classifies an API error into a user-friendly message.
 * Auth errors are flagged separately so batch generation can abort immediately.
 *
 * Matching is case-insensitive: the error text is lowercased once and compared
 * against lowercase markers, for `Error` instances and other thrown values
 * alike.
 *
 * @param error Whatever was thrown by the API call or the file write.
 * @returns The message to surface and whether the failure is an
 *   authentication/permission problem. Unrecognised errors keep their original
 *   text, prefixed with 'Image generation failed: '.
 */
function classifyApiError(error: unknown): {
  message: string;
  isAuthError: boolean;
} {
  const rawMessage = error instanceof Error ? error.message : String(error);
  const haystack = rawMessage.toLowerCase();
  const mentions = (...needles: string[]): boolean =>
    needles.some((needle) => haystack.includes(needle));

  if (mentions('api key not valid', 'permission denied', '403')) {
    return {
      message:
        'Image generation failed: Authentication error. Ensure your API key or Vertex AI credentials have access to image generation models.',
      isAuthError: true,
    };
  }

  if (mentions('safety', 'blocked')) {
    return {
      message:
        'Image generation failed: The prompt was blocked by safety filters. Please modify your prompt and try again.',
      isAuthError: false,
    };
  }

  if (mentions('quota', '429')) {
    return {
      message:
        'Image generation failed: API quota exceeded. Check your Google Cloud console for quota details.',
      isAuthError: false,
    };
  }

  if (mentions('500')) {
    return {
      message:
        'Image generation failed: Internal service error. Please try again.',
      isAuthError: false,
    };
  }

  if (mentions('enotfound', 'econnrefused', 'network')) {
    return {
      message:
        'Image generation failed: Network error. Check your internet connection.',
      isAuthError: false,
    };
  }

  return {
    message: `Image generation failed: ${rawMessage}`,
    isAuthError: false,
  };
}
|
||||
|
||||
/**
 * Main entry point for image generation. Generates one or more images
 * sequentially, saving each to disk. Supports partial success for batches.
 *
 * - The variation count is clamped to 1-4.
 * - The model is taken from `params.model`, then GEMINI_IMAGE_MODEL, then
 *   DEFAULT_IMAGE_MODEL.
 * - The output directory is created if it does not exist.
 * - A caller-supplied filename is reduced to its final path segment, so it can
 *   never place a file outside the output directory; if nothing usable
 *   remains, the name is derived from the prompt instead.
 * - When `return_to_context` is set, the first image that is successfully
 *   saved is returned inline, even if earlier variations failed.
 * - An authentication error aborts the batch immediately; other errors are
 *   recorded per variation and the batch continues.
 *
 * @returns `success` is true when at least one image was written.
 * @throws If the output directory cannot be created or the input image cannot
 *   be read; those operations are not wrapped in error handling here.
 */
export async function generateImages(
  options: GenerateImageOptions,
): Promise<ImageGenerationResult> {
  const { config, params, cwd, signal, updateOutput } = options;
  const count = Math.min(Math.max(params.count || 1, 1), 4);
  const modelName =
    params.model || process.env['GEMINI_IMAGE_MODEL'] || DEFAULT_IMAGE_MODEL;
  const outputDir = path.resolve(cwd, params.output_path || DEFAULT_OUTPUT_DIR);

  if (!fs.existsSync(outputDir)) {
    fs.mkdirSync(outputDir, { recursive: true });
  }

  let inputImageBase64: string | undefined;
  let inputImageMimeType: string | undefined;
  const isEditing = !!params.input_image;

  if (params.input_image) {
    const resolvedInput = path.isAbsolute(params.input_image)
      ? params.input_image
      : path.resolve(cwd, params.input_image);
    const imageBuffer = await fs.promises.readFile(resolvedInput);
    inputImageBase64 = imageBuffer.toString('base64');
    inputImageMimeType = getSpecificMimeType(resolvedInput) || 'image/png';
  }

  // Keep only the last path segment of a caller-supplied filename (both
  // separator styles) so values such as '../x' cannot escape outputDir, then
  // drop a trailing extension as before.
  const requestedStem = (
    (params.filename ?? '').split(/[\\/]/).pop() ?? ''
  ).replace(/\.[^.]+$/, '');
  // A stem that is empty or consists solely of dots is not a usable name.
  const baseName = /[^.]/.test(requestedStem)
    ? requestedStem
    : promptToFilename(params.prompt);

  const ai = createImageGenClient(config);
  const generatedFiles: string[] = [];
  const errors: string[] = [];
  let firstImageBase64: string | undefined;
  let firstImageMimeType = 'image/png';

  // Reports whether the run was cancelled, noting it once partial results exist.
  const stopForCancellation = (): boolean => {
    if (!signal.aborted) {
      return false;
    }
    if (generatedFiles.length > 0) {
      errors.push('User cancelled. Returning partial results.');
    }
    return true;
  };

  for (let i = 0; i < count; i++) {
    if (stopForCancellation()) {
      break;
    }

    if (count > 1) {
      updateOutput?.(`Generating image ${i + 1} of ${count}...`);
    }

    try {
      const result = await callImageApi({
        ai,
        modelName,
        prompt: params.prompt,
        inputImageBase64,
        inputImageMimeType,
        // When editing without an explicit ratio, send none; otherwise
        // default text-to-image requests to 1:1.
        aspectRatio:
          isEditing && !params.aspect_ratio
            ? undefined
            : params.aspect_ratio || '1:1',
        size: params.size || '1K',
        isEditing,
      });

      // The request itself is not cancellable here, so re-check afterwards
      // and discard the response if the user cancelled meanwhile.
      if (stopForCancellation()) {
        break;
      }

      if (!result) {
        errors.push(`Variation ${i + 1}: No image data in API response.`);
        continue;
      }

      const extension = mimeToExtension(result.mimeType);
      const filename = getUniqueFilename(
        outputDir,
        baseName,
        extension,
        count > 1 ? i : undefined,
      );
      const fullPath = path.join(outputDir, filename);

      await fs.promises.writeFile(
        fullPath,
        Buffer.from(result.base64Data, 'base64'),
      );
      generatedFiles.push(fullPath);

      // Capture the first image that actually succeeded, not merely index 0,
      // so a failed first variation does not leave the context empty.
      if (params.return_to_context && firstImageBase64 === undefined) {
        firstImageBase64 = result.base64Data;
        firstImageMimeType = result.mimeType;
      }

      if (count > 1) {
        updateOutput?.(`Generated image ${i + 1} of ${count}: ${fullPath}`);
      }
    } catch (error: unknown) {
      const classified = classifyApiError(error);
      errors.push(`Variation ${i + 1}: ${classified.message}`);

      // Retrying with the same credentials cannot succeed; stop the batch.
      if (classified.isAuthError) {
        return {
          success: false,
          filePaths: generatedFiles,
          mimeType: firstImageMimeType,
          errors,
        };
      }
    }
  }

  return {
    success: generatedFiles.length > 0,
    filePaths: generatedFiles,
    mimeType: firstImageMimeType,
    base64Data: firstImageBase64,
    errors: errors.length > 0 ? errors : undefined,
  };
}
|
||||
|
||||
// ─── Tool Schema ─────────────────────────────────────────────────────────────
|
||||
|
||||
/** Tool description shown to the model. */
const GENERATE_IMAGE_DESCRIPTION =
  'Generates images from text prompts or edits existing images using Nano Banana image generation models. When an input_image path is provided, the prompt describes the desired edits to that image. Supports generating multiple variations.';

/** Aspect ratios offered through the `aspect_ratio` parameter. */
const ASPECT_RATIO_CHOICES = [
  '1:1',
  '16:9',
  '9:16',
  '3:2',
  '2:3',
  '4:3',
  '3:4',
  '4:5',
  '5:4',
  '21:9',
  '4:1',
  '1:4',
  '8:1',
  '1:8',
];

/** Resolution tiers offered through the `size` parameter. */
const SIZE_CHOICES = ['512px', '1K', '2K', '4K'];

/** Models offered through the `model` parameter. */
const MODEL_CHOICES = [
  'gemini-3.1-flash-image-preview',
  'gemini-3-pro-image-preview',
  'gemini-2.5-flash-image',
];

/** JSON schema for the tool's parameters; only `prompt` is required. */
const GENERATE_IMAGE_SCHEMA = {
  type: 'object' as const,
  properties: {
    prompt: {
      type: 'string' as const,
      description:
        'A detailed text description of the image to generate, or editing instructions when used with input_image. Be specific about style, composition, colors, lighting, and subject matter.',
    },
    input_image: {
      type: 'string' as const,
      description:
        'Optional. Absolute path to an existing image file to edit or transform. When provided, the prompt describes the desired edits. Supports PNG, JPEG, and WebP formats.',
    },
    output_path: {
      type: 'string' as const,
      description:
        "Optional. Directory path (relative to cwd) where the generated image should be saved. Must be within the current working directory. Defaults to './generated-images/'.",
    },
    filename: {
      type: 'string' as const,
      description:
        'Optional. Custom filename for the output image (without extension). If not provided, a filename is auto-generated from the prompt.',
    },
    count: {
      type: 'integer' as const,
      description:
        'Number of image variations to generate (1-4). Defaults to 1.',
      minimum: 1,
      maximum: 4,
      default: 1,
    },
    return_to_context: {
      type: 'boolean' as const,
      description:
        'Optional. If true, the generated image is returned as inlineData in the tool result so the model can see and iterate on it. Use this when the user wants to refine or iterate on the generated image. Defaults to false.',
      default: false,
    },
    aspect_ratio: {
      type: 'string' as const,
      description:
        "Optional. Aspect ratio of the generated image. Defaults to '1:1' for text-to-image. For image editing, defaults to the input image's original aspect ratio.",
      enum: ASPECT_RATIO_CHOICES,
      default: '1:1',
    },
    size: {
      type: 'string' as const,
      description:
        "Optional. Output resolution tier. '1K' is ~1024px on the long edge. Defaults to '1K'.",
      enum: SIZE_CHOICES,
      default: '1K',
    },
    model: {
      type: 'string' as const,
      description:
        "Optional. Override the image generation model. Defaults to 'gemini-3.1-flash-image-preview'.",
      enum: MODEL_CHOICES,
    },
  },
  required: ['prompt'] as const,
};
|
||||
|
||||
// ─── Tool Invocation ─────────────────────────────────────────────────────────
|
||||
|
||||
/**
 * One parameter-bound run of the image generation tool: describes the
 * request, reports the directory it writes to, and performs the generation.
 */
class GenerateImageInvocation extends BaseToolInvocation<
  GenerateImageParams,
  ToolResult
> {
  constructor(
    params: GenerateImageParams,
    private readonly config: Config,
    messageBus: MessageBus,
    toolName?: string,
    toolDisplayName?: string,
  ) {
    super(params, messageBus, toolName, toolDisplayName);
  }

  /**
   * Multi-line summary of the request: mode and (truncated) prompt, source
   * image when editing, output directory, model, and the count / ratio / size
   * lines when they apply.
   */
  getDescription(): string {
    const { prompt, input_image: inputImage } = this.params;
    const truncatedPrompt =
      prompt.length > 80 ? prompt.substring(0, 80) + '...' : prompt;
    const mode = inputImage ? 'Edit image' : 'Generate image';
    // Same precedence as generateImages, so the summary names the model that
    // will actually be called when GEMINI_IMAGE_MODEL is set.
    const model =
      this.params.model ||
      process.env['GEMINI_IMAGE_MODEL'] ||
      DEFAULT_IMAGE_MODEL;
    const outputDir = this.params.output_path || `./${DEFAULT_OUTPUT_DIR}/`;
    const count = this.params.count || 1;

    const lines = [`${mode}: "${truncatedPrompt}"`];
    if (inputImage) {
      lines.push(`  Source: ${inputImage}`);
    }
    lines.push(`  Output: ${outputDir}`);
    lines.push(`  Model: ${model}`);
    if (count > 1) {
      lines.push(`  Count: ${count}`);
    }
    if (!inputImage) {
      // Text-to-image requests default to 1:1.
      lines.push(`  Ratio: ${this.params.aspect_ratio || '1:1'}`);
    } else if (this.params.aspect_ratio) {
      // When editing, a ratio is only shown if one was requested.
      lines.push(`  Ratio: ${this.params.aspect_ratio}`);
    }
    if (this.params.size) {
      lines.push(`  Size: ${this.params.size}`);
    }
    return lines.join('\n');
  }

  /** The directory that will receive the generated files. */
  override toolLocations(): ToolLocation[] {
    return [
      {
        path: path.resolve(
          this.config.getTargetDir(),
          this.params.output_path || DEFAULT_OUTPUT_DIR,
        ),
      },
    ];
  }

  /**
   * Validates the input image and output path, runs the generation and turns
   * the outcome into a tool result.
   *
   * @param signal Cancels the run; checked here and between variations.
   * @param updateOutput Receives progress text during batch generation.
   * @returns On success, the saved file paths (plus any warnings), and the
   *   inline image when `return_to_context` was requested and an image was
   *   captured. On failure or cancellation, a plain error message.
   */
  async execute(
    signal: AbortSignal,
    updateOutput?: (output: string | AnsiOutput) => void,
  ): Promise<ToolResult> {
    if (signal.aborted) {
      return {
        llmContent: 'Image generation cancelled.',
        returnDisplay: 'Cancelled.',
      };
    }

    const cwd = this.config.getTargetDir();

    // Validate input image if provided
    if (this.params.input_image) {
      const inputError = validateInputImage(this.params.input_image, cwd);
      if (inputError) {
        return { llmContent: inputError, returnDisplay: inputError };
      }
    }

    // Validate output path
    if (this.params.output_path) {
      const outputError = validateOutputPath(this.params.output_path, cwd);
      if (outputError) {
        return { llmContent: outputError, returnDisplay: outputError };
      }
    }

    const result = await generateImages({
      config: this.config,
      params: this.params,
      cwd,
      signal,
      updateOutput: updateOutput
        ? (msg: string) => updateOutput(msg)
        : undefined,
    });

    if (!result.success) {
      const errorMsg = result.errors?.join('\n') || 'Image generation failed.';
      return { llmContent: errorMsg, returnDisplay: errorMsg };
    }

    const count = result.filePaths.length;
    const fileList = result.filePaths.map((p) => `- ${p}`).join('\n');
    // Partial failures are passed to the model as warnings.
    const warningText = result.errors
      ? '\n\nWarnings:\n' + result.errors.join('\n')
      : '';
    const returnDisplay =
      `Generated ${count} image(s):\n` +
      result.filePaths.map((p) => `  ${p}`).join('\n');

    if (this.params.return_to_context && result.base64Data) {
      return {
        llmContent: [
          {
            text: `Successfully generated ${count} image(s):\n${fileList}${warningText}\n\nThe first image is included below for review.`,
          },
          {
            inlineData: {
              data: result.base64Data,
              mimeType: result.mimeType,
            },
          },
        ],
        returnDisplay,
      };
    }

    return {
      llmContent: `Successfully generated ${count} image(s):\n${fileList}${warningText}\n\nThe images have been saved to disk. You can reference these file paths in subsequent operations.`,
      returnDisplay,
    };
  }
}
|
||||
|
||||
// ─── Tool Builder ────────────────────────────────────────────────────────────
|
||||
|
||||
/**
 * Built-in tool for generating and editing images using Nano Banana models.
 * Gated behind the `imageGeneration` setting and requires Gemini API key
 * or Vertex AI authentication.
 */
export class GenerateImageTool extends BaseDeclarativeTool<
  GenerateImageParams,
  ToolResult
> {
  static readonly Name = GENERATE_IMAGE_TOOL_NAME;

  constructor(
    private readonly config: Config,
    messageBus: MessageBus,
  ) {
    const isOutputMarkdown = true;
    const canUpdateOutput = true; // batch progress is streamed

    super(
      GenerateImageTool.Name,
      GENERATE_IMAGE_DISPLAY_NAME,
      GENERATE_IMAGE_DESCRIPTION,
      Kind.Execute,
      GENERATE_IMAGE_SCHEMA,
      messageBus,
      isOutputMarkdown,
      canUpdateOutput,
    );
  }

  /**
   * Value-level checks on the parameters: the prompt must contain
   * non-whitespace text, `count` (when given) must lie in 1-4, and
   * `output_path` (when given) must stay inside the target directory.
   *
   * @returns An error message for the first failing check, or `null`.
   */
  protected override validateToolParamValues(
    params: GenerateImageParams,
  ): string | null {
    const { prompt, count, output_path: outputPath } = params;

    if (!prompt || prompt.trim() === '') {
      return "The 'prompt' parameter cannot be empty.";
    }

    const countOutOfRange = count !== undefined && (count < 1 || count > 4);
    if (countOutOfRange) {
      return "The 'count' parameter must be between 1 and 4.";
    }

    if (!outputPath) {
      return null;
    }
    return validateOutputPath(outputPath, this.config.getTargetDir()) || null;
  }

  /** Creates the invocation that will carry out a validated request. */
  protected createInvocation(
    params: GenerateImageParams,
    messageBus: MessageBus,
    toolName?: string,
    toolDisplayName?: string,
  ): ToolInvocation<GenerateImageParams, ToolResult> {
    const bus = messageBus ?? this.messageBus;
    return new GenerateImageInvocation(
      params,
      this.config,
      bus,
      toolName,
      toolDisplayName,
    );
  }
}
|
||||
@@ -22,6 +22,7 @@ import {
|
||||
ASK_USER_TOOL_NAME,
|
||||
EXIT_PLAN_MODE_TOOL_NAME,
|
||||
ENTER_PLAN_MODE_TOOL_NAME,
|
||||
GENERATE_IMAGE_TOOL_NAME,
|
||||
} from './definitions/coreTools.js';
|
||||
|
||||
export {
|
||||
@@ -42,8 +43,11 @@ export {
|
||||
ASK_USER_TOOL_NAME,
|
||||
EXIT_PLAN_MODE_TOOL_NAME,
|
||||
ENTER_PLAN_MODE_TOOL_NAME,
|
||||
GENERATE_IMAGE_TOOL_NAME,
|
||||
};
|
||||
|
||||
export const GENERATE_IMAGE_DISPLAY_NAME = 'GenerateImage';
|
||||
|
||||
export const LS_TOOL_NAME_LEGACY = 'list_directory'; // Just to be safe if anything used the old exported name directly
|
||||
|
||||
export const EDIT_TOOL_NAMES = new Set([EDIT_TOOL_NAME, WRITE_FILE_TOOL_NAME]);
|
||||
@@ -110,6 +114,7 @@ export const ALL_BUILTIN_TOOL_NAMES = [
|
||||
GET_INTERNAL_DOCS_TOOL_NAME,
|
||||
ENTER_PLAN_MODE_TOOL_NAME,
|
||||
EXIT_PLAN_MODE_TOOL_NAME,
|
||||
GENERATE_IMAGE_TOOL_NAME,
|
||||
] as const;
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user