mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-22 01:33:30 -07:00
feat(core): enhance render-prompt with memoization, slot parsing, and dynamic attributes
This commit is contained in:
@@ -4,8 +4,8 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, expect, it } from 'vitest';
|
||||
import { renderPrompt, p } from './render-prompt.js';
|
||||
import { describe, expect, it, vi } from 'vitest';
|
||||
import { renderPrompt, p, memoize, parseSlots } from './render-prompt.js';
|
||||
import type { PromptContent } from './render-prompt.js';
|
||||
|
||||
type TestContext = { name?: string; shouldRender?: boolean };
|
||||
@@ -338,6 +338,55 @@ const tests: TestCase[] = [
|
||||
context: {},
|
||||
expect: 'Visible start\n\nVisible end',
|
||||
},
|
||||
{
|
||||
desc: 'renders explicit level overriding depth and resetting children',
|
||||
content: {
|
||||
heading: 'Top Level',
|
||||
content: {
|
||||
heading: 'Nested',
|
||||
level: 4,
|
||||
content: {
|
||||
heading: 'Deep',
|
||||
content: 'Text',
|
||||
},
|
||||
},
|
||||
},
|
||||
context: {},
|
||||
expect: '# Top Level\n\n#### Nested\n\n##### Deep\n\nText',
|
||||
},
|
||||
{
|
||||
desc: 'renders level 0 as # and children as level 1 (#)',
|
||||
content: {
|
||||
heading: 'Level 0',
|
||||
level: 0,
|
||||
content: {
|
||||
heading: 'Level 1',
|
||||
content: 'Text',
|
||||
},
|
||||
},
|
||||
context: {},
|
||||
expect: '# Level 0\n\n# Level 1\n\nText',
|
||||
},
|
||||
{
|
||||
desc: 'resolves dynamic attributes synchronously',
|
||||
content: {
|
||||
tag: 'dynamic',
|
||||
attrs: { static: 'val', dyn: (ctx) => `hello-${ctx.name}` },
|
||||
content: 'Inside',
|
||||
},
|
||||
context: { name: 'Alice' },
|
||||
expect: '<dynamic static="val" dyn="hello-Alice">\nInside\n</dynamic>',
|
||||
},
|
||||
{
|
||||
desc: 'resolves dynamic attributes asynchronously',
|
||||
content: {
|
||||
tag: 'async-dynamic',
|
||||
attrs: { dyn: async (ctx) => `async-${ctx.name}` },
|
||||
content: 'Inside',
|
||||
},
|
||||
context: { name: 'Bob' },
|
||||
expect: '<async-dynamic dyn="async-Bob">\nInside\n</async-dynamic>',
|
||||
},
|
||||
];
|
||||
|
||||
describe('renderPrompt', () => {
|
||||
@@ -350,3 +399,56 @@ describe('renderPrompt', () => {
|
||||
expect(result).toBe(test.expect);
|
||||
});
|
||||
});
|
||||
|
||||
describe('memoize', () => {
|
||||
it('should cache result per context instance', () => {
|
||||
const resolver = vi.fn((ctx: TestContext) => ctx.name);
|
||||
const memoized = memoize(resolver);
|
||||
|
||||
const ctx1 = { name: 'Alice' };
|
||||
const ctx2 = { name: 'Bob' };
|
||||
|
||||
expect(memoized(ctx1)).toBe('Alice');
|
||||
expect(memoized(ctx1)).toBe('Alice');
|
||||
expect(resolver).toHaveBeenCalledTimes(1);
|
||||
|
||||
expect(memoized(ctx2)).toBe('Bob');
|
||||
expect(memoized(ctx2)).toBe('Bob');
|
||||
expect(resolver).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it('should handle async resolvers', async () => {
|
||||
const resolver = vi.fn(async (ctx: TestContext) => ctx.name);
|
||||
const memoized = memoize(resolver);
|
||||
|
||||
const ctx = { name: 'Async' };
|
||||
|
||||
expect(await memoized(ctx)).toBe('Async');
|
||||
expect(await memoized(ctx)).toBe('Async');
|
||||
expect(resolver).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe('parseSlots', () => {
|
||||
it('should return empty array for empty string', () => {
|
||||
expect(parseSlots('')).toEqual([]);
|
||||
});
|
||||
|
||||
it('should return string array for no slots', () => {
|
||||
expect(parseSlots('Hello World')).toEqual(['Hello World']);
|
||||
});
|
||||
|
||||
it('should parse a single slot', () => {
|
||||
expect(parseSlots('${slot1}')).toEqual([{ slot: 'slot1' }]);
|
||||
});
|
||||
|
||||
it('should parse slots at the start, middle, and end', () => {
|
||||
expect(parseSlots('${first} middle ${second} end ${third}')).toEqual([
|
||||
{ slot: 'first' },
|
||||
' middle ',
|
||||
{ slot: 'second' },
|
||||
' end ',
|
||||
{ slot: 'third' },
|
||||
]);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -13,10 +13,13 @@ export type PromptSlot = { slot: string; content?: never };
|
||||
export type PromptSection<C> = {
|
||||
/** Add a Markdown heading of appropriate level to this section. */
|
||||
heading?: string;
|
||||
/** Explicitly set the markdown heading depth (1 = #, 2 = ##), overriding tree depth.
|
||||
* 0 is valid and will render as # while setting children to level 1. */
|
||||
level?: number;
|
||||
/** If supplied, wrap this section in an XML tag. */
|
||||
tag?: string;
|
||||
/** If supplied, add attributes to the XML section tag. */
|
||||
attrs?: Record<string, string>;
|
||||
attrs?: Record<string, ContextResolver<C, string>>;
|
||||
/** Formatting of the content inside this section. Defaults to 'block'. */
|
||||
format?:
|
||||
| 'inline'
|
||||
@@ -47,7 +50,11 @@ export type PromptContent<C> = ContextResolver<
|
||||
>;
|
||||
|
||||
type BaseContent = string | StaticSection | PromptSlot | BaseContent[];
|
||||
type StaticSection = Omit<PromptSection<unknown>, 'condition' | 'content'> & {
|
||||
type StaticSection = Omit<
|
||||
PromptSection<unknown>,
|
||||
'condition' | 'content' | 'attrs'
|
||||
> & {
|
||||
attrs?: Record<string, string>;
|
||||
content: BaseContent;
|
||||
};
|
||||
|
||||
@@ -124,10 +131,11 @@ const formatBasic = (
|
||||
}
|
||||
|
||||
const section = c;
|
||||
const currentDepth = section.level ?? depth;
|
||||
const sectionFormat = section.format || 'block';
|
||||
const innerContent = formatBasic(
|
||||
section.content,
|
||||
depth + 1,
|
||||
currentDepth + 1,
|
||||
sectionFormat,
|
||||
resolvedContributions,
|
||||
).trim();
|
||||
@@ -141,7 +149,7 @@ const formatBasic = (
|
||||
}
|
||||
|
||||
if (section.heading) {
|
||||
const headingLevel = Math.min(depth, 6);
|
||||
const headingLevel = Math.max(1, Math.min(currentDepth, 6));
|
||||
result = `\n\n${'#'.repeat(headingLevel)} ${section.heading}\n\n${result.trim()}`;
|
||||
}
|
||||
|
||||
@@ -224,10 +232,20 @@ export async function renderPrompt<C = SystemPromptOptions>({
|
||||
) {
|
||||
return null;
|
||||
}
|
||||
let resolvedAttrs: Record<string, string> | undefined = undefined;
|
||||
if (section.attrs) {
|
||||
resolvedAttrs = {};
|
||||
for (const [key, value] of Object.entries(section.attrs)) {
|
||||
resolvedAttrs[key] =
|
||||
typeof value === 'function' ? await value(context) : value;
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
heading: section.heading,
|
||||
level: section.level,
|
||||
tag: section.tag,
|
||||
attrs: section.attrs,
|
||||
attrs: resolvedAttrs,
|
||||
format: section.format,
|
||||
content: resolvedInner,
|
||||
};
|
||||
@@ -264,3 +282,51 @@ export function prompt<C = SystemPromptOptions>(
|
||||
): PromptContent<C> {
|
||||
return content.length === 1 ? content[0] : content;
|
||||
}
|
||||
|
||||
type Resolver<C, T> = (ctx: C) => T | Promise<T>;
|
||||
|
||||
/**
|
||||
* Creates a memoized selector that caches its result per context instance.
|
||||
* Ideal for efficiently sharing derived state across a prompt tree.
|
||||
*/
|
||||
export function memoize<C extends object, T>(
|
||||
resolver: Resolver<C, T>,
|
||||
): (ctx: C) => T | Promise<T> {
|
||||
const cache = new WeakMap<C, T | Promise<T>>();
|
||||
|
||||
return (ctx: C) => {
|
||||
if (cache.has(ctx)) {
|
||||
return cache.get(ctx)!;
|
||||
}
|
||||
const result = resolver(ctx);
|
||||
cache.set(ctx, result);
|
||||
return result;
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Parses a string containing placeholders like `${slotName}` into a PromptContent array.
|
||||
* Interleaves literal string segments with `{ slot: 'slotName' }` objects.
|
||||
*/
|
||||
export function parseSlots<C>(template: string): Array<PromptContent<C>> {
|
||||
if (!template) return [];
|
||||
|
||||
const regex = /\$\{([^}]+)\}/g;
|
||||
const parts: Array<PromptContent<C>> = [];
|
||||
let lastIndex = 0;
|
||||
let match;
|
||||
|
||||
while ((match = regex.exec(template)) !== null) {
|
||||
if (match.index > lastIndex) {
|
||||
parts.push(template.slice(lastIndex, match.index));
|
||||
}
|
||||
parts.push({ slot: match[1] });
|
||||
lastIndex = regex.lastIndex;
|
||||
}
|
||||
|
||||
if (lastIndex < template.length) {
|
||||
parts.push(template.slice(lastIndex));
|
||||
}
|
||||
|
||||
return parts;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user