Skip to content

Commit 44ac88d

Browse files
authored
Agent parse conversation history (#36)
* Agent parse conversation history
* Add docstrings. Extract shared logic
* Add null guard
* Address Copilot comments
* Reuse content types
1 parent 54e5806 commit 44ac88d

File tree

3 files changed

+267
-20
lines changed

3 files changed

+267
-20
lines changed

src/__tests__/unit/agents.test.ts

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,90 @@ describe('GuardrailAgent', () => {
365365
expect(typeof result.tripwireTriggered).toBe('boolean');
366366
});
367367

368+
// Verifies that, when the agent input is a multi-turn conversation array, the input
// guardrail receives only the text of the most recent user message (its text parts
// joined with a single space) and that the guardrail context exposes the conversation
// history.
it('passes the latest user message text to guardrails for conversation inputs', async () => {
  process.env.OPENAI_API_KEY = 'test';
  const config = {
    version: 1,
    input: {
      version: 1,
      guardrails: [{ name: 'Moderation', config: {} }],
    },
  };

  const { instantiateGuardrails } = await import('../../runtime');
  const runSpy = vi.fn().mockResolvedValue({
    tripwireTriggered: false,
    info: { guardrail_name: 'Moderation' },
  });

  // Element type of the array that instantiateGuardrails resolves to. The cast target
  // is derived from the function's awaited *return* type; deriving it from
  // Parameters<...>[0] would inspect the first argument instead and would not describe
  // the resolved guardrail objects.
  // NOTE(review): assumes instantiateGuardrails returns a promise of an array — confirm.
  type InstantiatedGuardrail = Awaited<ReturnType<typeof instantiateGuardrails>>[number];

  // Replace the real guardrail instantiation once with a single stub whose `run` is a spy,
  // so the arguments forwarded by the agent's input guardrail can be inspected.
  vi.mocked(instantiateGuardrails).mockImplementationOnce(() =>
    Promise.resolve([
      {
        definition: {
          name: 'Moderation',
          description: 'Moderation guardrail',
          mediaType: 'text/plain',
          configSchema: z.object({}),
          checkFn: vi.fn(),
          metadata: {},
          ctxRequirements: z.object({}),
          schema: () => ({}),
          instantiate: vi.fn(),
        },
        config: {},
        run: runSpy,
      } as unknown as InstantiatedGuardrail,
    ])
  );

  const agent = (await GuardrailAgent.create(
    config,
    'Conversation Agent',
    'Handle multi-turn conversations'
  )) as MockAgent;

  const guardrail = agent.inputGuardrails[0] as unknown as {
    execute: (args: { input: unknown; context?: unknown }) => Promise<{
      outputInfo: Record<string, unknown>;
      tripwireTriggered: boolean;
    }>;
  };

  // The final user message carries two text parts; they are expected to be joined
  // with a single space.
  const conversation = [
    { role: 'system', content: 'You are helpful.' },
    { role: 'user', content: [{ type: 'input_text', text: 'First question?' }] },
    { role: 'assistant', content: [{ type: 'output_text', text: 'An answer.' }] },
    {
      role: 'user',
      content: [
        { type: 'input_text', text: 'Latest user message' },
        { type: 'input_text', text: 'with additional context.' },
      ],
    },
  ];

  const result = await guardrail.execute({ input: conversation, context: {} });

  expect(runSpy).toHaveBeenCalledTimes(1);
  const [ctxArgRaw, dataArg] = runSpy.mock.calls[0] as [unknown, string];
  const ctxArg = ctxArgRaw as { getConversationHistory?: () => unknown[] };
  expect(dataArg).toBe('Latest user message with additional context.');
  expect(typeof ctxArg.getConversationHistory).toBe('function');

  const history = ctxArg.getConversationHistory?.() as Array<{ content?: unknown }> | undefined;
  expect(Array.isArray(history)).toBe(true);
  expect(history && history[history.length - 1]?.content).toBe(
    'Latest user message with additional context.'
  );

  expect(result.tripwireTriggered).toBe(false);
  expect(result.outputInfo.input).toBe('Latest user message with additional context.');
});
451+
368452
it('should handle guardrail execution errors based on raiseGuardrailErrors setting', async () => {
369453
process.env.OPENAI_API_KEY = 'test';
370454
const config = {

src/agents.ts

Lines changed: 179 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ import type {
1313
InputGuardrailFunctionArgs,
1414
OutputGuardrailFunctionArgs,
1515
} from '@openai/agents-core';
16-
import { GuardrailLLMContext, GuardrailResult, TextOnlyContent, ContentPart } from './types';
17-
import { ContentUtils } from './utils/content';
16+
import { GuardrailLLMContext, GuardrailResult, TextOnlyContent } from './types';
17+
import { TEXT_CONTENT_TYPES } from './utils/content';
1818
import {
1919
loadPipelineBundles,
2020
instantiateGuardrails,
@@ -250,6 +250,180 @@ function ensureGuardrailContext(
250250
} as GuardrailLLMContext;
251251
}
252252

253+
// Set built from the shared TEXT_CONTENT_TYPES list; used for membership checks on a
// content part's `type` field when deciding whether its `text` may be extracted.
const TEXTUAL_CONTENT_TYPES = new Set<string>(TEXT_CONTENT_TYPES);
254+
// Recursion-depth cap for the text-extraction helpers; once the depth exceeds this
// value they return an empty string instead of descending further.
const MAX_CONTENT_EXTRACTION_DEPTH = 10;
255+
256+
/**
 * Recursively pull text out of an arbitrary content value.
 *
 * - Strings are returned trimmed.
 * - Arrays are processed element by element; non-empty results are joined with a
 *   single space.
 * - Objects contribute their `text` field when it is a string and the type filter
 *   allows it; otherwise their `content` field (if not null/undefined) is searched.
 * - Anything else yields an empty string.
 *
 * @param value Content value to inspect (string, array, object, or anything else).
 * @param depth Current recursion depth; once it exceeds MAX_CONTENT_EXTRACTION_DEPTH
 *   the function returns an empty string.
 * @param filterByType When true, an object's `text` is only used if the object has no
 *   string `type` or its `type` is listed in TEXTUAL_CONTENT_TYPES.
 * @returns The extracted text, or an empty string when nothing was found.
 */
function extractTextFromValue(value: unknown, depth: number, filterByType: boolean): string {
  if (depth > MAX_CONTENT_EXTRACTION_DEPTH) {
    return '';
  }

  if (typeof value === 'string') {
    return value.trim();
  }

  if (Array.isArray(value)) {
    return value
      .map((element: unknown) => extractTextFromValue(element, depth + 1, filterByType))
      .filter((fragment) => fragment.length > 0)
      .join(' ')
      .trim();
  }

  if (typeof value !== 'object' || value === null) {
    return '';
  }

  const fields = value as Record<string, unknown>;
  const declaredType = typeof fields.type === 'string' ? fields.type : null;
  // An empty-string type counts as "typed but unrecognized", never as a text type.
  const isTextualType =
    declaredType !== null && declaredType !== '' && TEXTUAL_CONTENT_TYPES.has(declaredType);
  const textPermitted = !filterByType || isTextualType || declaredType === null;

  const directText = fields.text;
  if (typeof directText === 'string' && textPermitted) {
    return directText.trim();
  }

  // Reached when there is no usable `text` (absent, non-string, or rejected by the
  // type filter): look inside `content` instead.
  const nestedContent = fields.content;
  if (nestedContent == null) {
    return '';
  }
  return extractTextFromValue(nestedContent, depth + 1, filterByType);
}
307+
308+
/**
309+
* Extract text from structured content parts (e.g., the `content` field on a message).
310+
*
311+
* Only textual content-part types enumerated in TEXTUAL_CONTENT_TYPES are considered so
312+
* that non-text modalities (images, tools, etc.) remain ignored.
313+
*/
314+
function extractTextFromContentParts(content: unknown, depth = 0): string {
315+
return extractTextFromValue(content, depth, true);
316+
}
317+
318+
/**
319+
* Extract text from a single message entry.
320+
*
321+
* Handles strings, arrays of content parts, or message-like objects that contain a
322+
* `content` collection or a plain `text` field.
323+
*/
324+
function extractTextFromMessageEntry(entry: unknown, depth = 0): string {
325+
if (depth > MAX_CONTENT_EXTRACTION_DEPTH) {
326+
return '';
327+
}
328+
329+
if (entry == null) {
330+
return '';
331+
}
332+
333+
if (typeof entry === 'string') {
334+
return entry.trim();
335+
}
336+
337+
if (Array.isArray(entry)) {
338+
return extractTextFromContentParts(entry, depth + 1);
339+
}
340+
341+
if (typeof entry === 'object') {
342+
const record = entry as Record<string, unknown>;
343+
344+
if (record.content !== undefined) {
345+
const contentText = extractTextFromContentParts(record.content, depth + 1);
346+
if (contentText) {
347+
return contentText;
348+
}
349+
}
350+
351+
if (typeof record.text === 'string') {
352+
return record.text.trim();
353+
}
354+
}
355+
356+
return extractTextFromValue(entry, depth + 1, false /* allow all types when falling back */);
357+
}
358+
359+
/**
360+
* Extract the latest user-authored text from raw agent input.
361+
*
362+
* Accepts strings, message objects, or arrays of mixed items. Arrays are scanned
363+
* from newest to oldest, returning the first user-role message with textual content.
364+
*/
365+
function extractTextFromAgentInput(input: unknown): string {
366+
if (input == null) {
367+
return '';
368+
}
369+
370+
if (typeof input === 'string') {
371+
return input.trim();
372+
}
373+
374+
if (Array.isArray(input)) {
375+
for (let idx = input.length - 1; idx >= 0; idx -= 1) {
376+
const candidate = input[idx];
377+
if (candidate && typeof candidate === 'object') {
378+
const record = candidate as Record<string, unknown>;
379+
if (record.role === 'user') {
380+
const text = extractTextFromMessageEntry(candidate);
381+
if (text) {
382+
return text;
383+
}
384+
}
385+
} else if (typeof candidate === 'string') {
386+
const text = candidate.trim();
387+
if (text) {
388+
return text;
389+
}
390+
}
391+
}
392+
return '';
393+
}
394+
395+
if (input && typeof input === 'object') {
396+
const record = input as Record<string, unknown>;
397+
if (record.role === 'user') {
398+
const text = extractTextFromMessageEntry(record);
399+
if (text) {
400+
return text;
401+
}
402+
}
403+
404+
if (record.content != null) {
405+
const contentText = extractTextFromContentParts(record.content);
406+
if (contentText) {
407+
return contentText;
408+
}
409+
}
410+
411+
if (typeof record.text === 'string') {
412+
return record.text.trim();
413+
}
414+
}
415+
416+
if (
417+
typeof input === 'number' ||
418+
typeof input === 'boolean' ||
419+
typeof input === 'bigint'
420+
) {
421+
return String(input);
422+
}
423+
424+
return '';
425+
}
426+
253427
function extractLatestUserText(history: NormalizedConversationEntry[]): string {
254428
for (let i = history.length - 1; i >= 0; i -= 1) {
255429
const entry = history[i];
@@ -261,20 +435,9 @@ function extractLatestUserText(history: NormalizedConversationEntry[]): string {
261435
}
262436

263437
function resolveInputText(input: unknown, history: NormalizedConversationEntry[]): string {
264-
if (typeof input === 'string') {
265-
return input;
266-
}
267-
268-
if (input && typeof input === 'object' && 'content' in (input as Record<string, unknown>)) {
269-
const content = (input as { content: string | ContentPart[] }).content;
270-
const message = {
271-
role: 'user',
272-
content,
273-
};
274-
const extracted = ContentUtils.extractTextFromMessage(message);
275-
if (extracted) {
276-
return extracted;
277-
}
438+
const directText = extractTextFromAgentInput(input);
439+
if (directText) {
440+
return directText;
278441
}
279442

280443
return extractLatestUserText(history);

src/utils/content.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,15 @@
77

88
import { Message, ContentPart, TextContentPart, TextOnlyMessageArray } from '../types';
99

10+
/** Content-part `type` values that are treated as text. */
export const TEXT_CONTENT_TYPES = ['input_text', 'text', 'output_text', 'summary_text'] as const;
11+
// Set form of TEXT_CONTENT_TYPES for membership checks (used by ContentUtils.isText).
const TEXT_CONTENT_TYPES_SET = new Set<string>(TEXT_CONTENT_TYPES);
12+
1013
export class ContentUtils {
11-
// Clear: what types are considered text
12-
private static readonly TEXT_TYPES = ['input_text', 'text', 'output_text', 'summary_text'] as const;
13-
1414
/**
1515
* Check if a content part is text-based.
1616
*/
1717
static isText(part: ContentPart): boolean {
18-
return this.TEXT_TYPES.includes(part.type as typeof this.TEXT_TYPES[number]);
18+
return typeof part.type === 'string' && TEXT_CONTENT_TYPES_SET.has(part.type);
1919
}
2020

2121
/**

0 commit comments

Comments
 (0)