diff --git a/docs/docs/concepts/low_level.md b/docs/docs/concepts/low_level.md index cb24e4178..d17b6a0af 100644 --- a/docs/docs/concepts/low_level.md +++ b/docs/docs/concepts/low_level.md @@ -255,12 +255,18 @@ const StateWithDocuments = Annotation.Root({ Just like `MessagesAnnotation`, there is a prebuilt Zod schema called `MessagesZodSchema` that provides the same functionality, but uses Zod for defining the state instead of the `Annotation` API. +You can use `MessagesZodSchema` in place of `MessagesAnnotation`, and use Zod to tack on additional state beyond just messages. + ```typescript import { MessagesZodSchema, StateGraph } from "@langchain/langgraph"; import { z } from "zod"; -const graph = new StateGraph(MessagesZodSchema) +const StateWithMessages = MessagesZodSchema.and( + z.object({ counter: z.array(z.string()) }) +); + +const graph = new StateGraph(StateWithMessages) .addNode(...) ... ``` diff --git a/libs/langgraph/src/graph/messages_annotation.ts b/libs/langgraph/src/graph/messages_annotation.ts index 25635f05c..ae61454e7 100644 --- a/libs/langgraph/src/graph/messages_annotation.ts +++ b/libs/langgraph/src/graph/messages_annotation.ts @@ -86,6 +86,19 @@ export const MessagesAnnotation = Annotation.Root({ * default: () => [], * }), * }); + * + * You can also expand this schema to include other fields and retain the core messages field using native zod methods like `z.intersection()` or `.and()` + * @example + * ```ts + * import { MessagesZodState, StateGraph } from "@langchain/langgraph"; + * + * const schema = MessagesZodState.and( + * z.object({ count: z.number() }), + * ); + * + * const graph = new StateGraph(schema) + * .addNode(...) + * ... * ``` */ export const MessagesZodState = z.object({ diff --git a/libs/langgraph/src/graph/zod/state.ts b/libs/langgraph/src/graph/zod/state.ts index 8fb7eeb9a..884c50774 100644 --- a/libs/langgraph/src/graph/zod/state.ts +++ b/libs/langgraph/src/graph/zod/state.ts @@ -20,9 +20,13 @@ export interface Meta { default?: () => ValueType; } -export type AnyZodObject = z.ZodObject; +type RawZodObject = z.ZodObject; -function isZodType(value: unknown): value is z.ZodType { +export type AnyZodObject = + | RawZodObject + | z.ZodIntersection; + +export function isZodType(value: unknown): value is z.ZodType { return ( typeof value === "object" && value != null && @@ -48,11 +52,46 @@ export function isZodDefault( * @internal */ export function isAnyZodObject(value: unknown): value is AnyZodObject { - return ( - isZodType(value) && - "partial" in value && - typeof value.partial === "function" - ); + if (isZodObject(value)) { + return true; + } + if (isZodObjectIntersection(value)) { + return true; + } + return false; +} + +/** + * @internal + */ +export function isZodObject( + value: unknown +): value is z.ZodObject { + if (!isZodType(value)) return false; + if ("partial" in value && typeof value.partial === "function") { + return true; + } + return true; +} + +/** + * @internal + */ +export function isZodObjectIntersection( + value: unknown +): value is z.ZodIntersection { + if (!isZodType(value)) return false; + const maybeDef = (value as { _def?: unknown })._def; + if ( + !maybeDef || + typeof maybeDef !== "object" || + !("left" in maybeDef) || + !("right" in maybeDef) + ) { + return false; + } + const { left, right } = maybeDef as { left: unknown; right: unknown }; + return isAnyZodObject(left) && isAnyZodObject(right); } export function withLangGraph( @@ -92,33 +131,52 @@ export function extendMeta( META_MAP.set(schema, newMeta); } -export type ZodToStateDefinition = { - [key in keyof T["shape"]]: T["shape"][key] extends z.ZodType< - infer V, - z.ZodTypeDef, - infer U - > - ? BaseChannel +export type ZodToStateDefinition = + // Handle ZodObject + T extends z.ZodObject + ? { + [K in keyof Shape]: Shape[K] extends z.ZodType< + infer V, + z.ZodTypeDef, + infer U + > + ? BaseChannel + : never; + } + : // Handle ZodIntersection of two ZodObjects + T extends z.ZodIntersection + ? ZodToStateDefinition & ZodToStateDefinition : never; -}; - -export function getChannelsFromZod( - schema: z.ZodObject -): ZodToStateDefinition> { - const channels = {} as Record; - for (const key in schema.shape) { - if (Object.prototype.hasOwnProperty.call(schema.shape, key)) { - const keySchema = schema.shape[key]; - const meta = getMeta(keySchema); - if (meta?.reducer) { - channels[key] = new BinaryOperatorAggregate>( - meta.reducer.fn, - meta.default - ); - } else { - channels[key] = new LastValue(); + +export function getChannelsFromZod( + schema: T +): ZodToStateDefinition { + // Handle ZodObject + if (isZodObject(schema)) { + const channels = {} as Record; + for (const key in schema.shape) { + if (Object.prototype.hasOwnProperty.call(schema.shape, key)) { + const keySchema = schema.shape[key]; + const meta = getMeta(keySchema); + if (meta?.reducer) { + type ValueType = z.infer[typeof key]; + channels[key] = new BinaryOperatorAggregate( + meta.reducer.fn, + meta.default + ); + } else { + channels[key] = new LastValue(); + } } } + return channels as ZodToStateDefinition; + } + // Handle ZodIntersection of two ZodObjects + if (isZodObjectIntersection(schema)) { + // Recursively extract channels from both sides and merge + const left = getChannelsFromZod(schema._def.left as AnyZodObject); + const right = getChannelsFromZod(schema._def.right as AnyZodObject); + return { ...left, ...right } as ZodToStateDefinition; } - return channels as ZodToStateDefinition>; + return {} as ZodToStateDefinition; } diff --git a/libs/langgraph/src/tests/prebuilt.test.ts b/libs/langgraph/src/tests/prebuilt.test.ts index 8e4428e97..1fdb053b1 100644 --- a/libs/langgraph/src/tests/prebuilt.test.ts +++ b/libs/langgraph/src/tests/prebuilt.test.ts @@ -7,6 +7,7 @@ import { StructuredTool, tool } from "@langchain/core/tools"; import { AIMessage, BaseMessage, + BaseMessageLike, HumanMessage, RemoveMessage, SystemMessage, @@ -44,6 +45,7 @@ import { MessagesAnnotation, MessagesZodState, } from "../graph/messages_annotation.js"; +import { withLangGraph } from "../graph/zod/state.js"; // Tracing slows down the tests beforeAll(() => { @@ -1293,6 +1295,41 @@ describe("MessagesZodState", () => { expect(result.messages[0].content).toEqual("updated"); expect(result.messages[1].content).toEqual("message 2"); }); + + it("should handle intersection with additional fields", async () => { + const schema = z.object({ + messages: withLangGraph(z.custom(), { + reducer: { + schema: z.union([ + z.custom(), + z.array(z.custom()), + ]), + fn: messagesStateReducer, + }, + default: () => [], + }), + count: z.number(), + }); + + type State = z.infer; + + const graph = new StateGraph(schema) + .addNode("process", ({ messages, count }: State) => ({ + messages: [...messages, new HumanMessage(`count: ${count}`)], + count: count + 1, + })) + .addEdge("__start__", "process") + .compile(); + + const result = await graph.invoke({ + messages: [new HumanMessage("start")], + count: 0, + }); + + expect(result.messages.length).toEqual(2); + expect(result.messages[1].content).toEqual("count: 0"); + expect(result.count).toEqual(1); + }); }); describe("messagesStateReducer", () => { diff --git a/libs/langgraph/src/tests/zod.test.ts b/libs/langgraph/src/tests/zod.test.ts new file mode 100644 index 000000000..9ee6a50d7 --- /dev/null +++ b/libs/langgraph/src/tests/zod.test.ts @@ -0,0 +1,143 @@ +import { z } from "zod"; +import { + isZodType, + isZodDefault, + isAnyZodObject, + isZodObject, + isZodObjectIntersection, + withLangGraph, + getMeta, + extendMeta, + getChannelsFromZod, + type Meta, +} from "../graph/zod/state.js"; +import { BinaryOperatorAggregate } from "../channels/binop.js"; +import { LastValue } from "../channels/last_value.js"; + +describe("Zod State Functions", () => { + describe("Type Checking Functions", () => { + test("isZodType", () => { + expect(isZodType(z.string())).toBe(true); + expect(isZodType(z.number())).toBe(true); + expect(isZodType({})).toBe(false); + expect(isZodType(null)).toBe(false); + expect(isZodType(undefined)).toBe(false); + }); + + test("isZodDefault", () => { + expect(isZodDefault(z.string().default("test"))).toBe(true); + expect(isZodDefault(z.string())).toBe(false); + expect(isZodDefault({})).toBe(false); + }); + + test("isZodObject", () => { + const schema = z.object({ name: z.string() }); + expect(isZodObject(schema)).toBe(true); + expect(isZodObject(z.string())).toBe(false); + expect(isZodObject({})).toBe(false); + }); + + test("isZodObjectIntersection", () => { + const schema1 = z.object({ name: z.string() }); + const schema2 = z.object({ age: z.number() }); + const intersection = schema1.and(schema2); + + expect(isZodObjectIntersection(intersection)).toBe(true); + expect(isZodObjectIntersection(schema1)).toBe(false); + expect(isZodObjectIntersection({})).toBe(false); + }); + + test("isAnyZodObject", () => { + const schema = z.object({ name: z.string() }); + const schema1 = z.object({ name: z.string() }); + const schema2 = z.object({ age: z.number() }); + const intersection = schema1.and(schema2); + + expect(isAnyZodObject(schema)).toBe(true); + expect(isAnyZodObject(intersection)).toBe(true); + expect(isAnyZodObject(z.string())).toBe(false); + expect(isAnyZodObject({})).toBe(false); + }); + }); + + describe("Meta Functions", () => { + test("withLangGraph and getMeta", () => { + const schema = z.string(); + const meta: Meta = { + jsonSchemaExtra: { + langgraph_type: "prompt", + }, + reducer: { + fn: (a: string, b: string) => a + b, + }, + default: () => "default", + }; + + const enhancedSchema = withLangGraph(schema, meta); + const retrievedMeta = getMeta(enhancedSchema); + + expect(retrievedMeta).toEqual(meta); + }); + + test("extendMeta", () => { + const schema = z.string(); + const initialMeta: Meta = { + jsonSchemaExtra: { + langgraph_type: "prompt", + }, + }; + + withLangGraph(schema, initialMeta); + + extendMeta(schema, (existingMeta: Meta | undefined) => ({ + ...existingMeta, + reducer: { + fn: (a: string, b: string) => a + b, + }, + default: () => "default", + })); + + const updatedMeta = getMeta(schema); + expect(updatedMeta?.reducer).toBeDefined(); + expect(updatedMeta?.default).toBeDefined(); + }); + }); + + describe("getChannelsFromZod", () => { + test("simple object schema", () => { + const schema = z.object({ + name: z.string(), + count: z.number().default(0), + }); + + const channels = getChannelsFromZod(schema); + expect(channels.name).toBeInstanceOf(LastValue); + expect(channels.count).toBeInstanceOf(LastValue); + }); + + test("schema with reducer", () => { + const schema = z.object({ + messages: withLangGraph(z.array(z.string()), { + reducer: { + fn: (a: string[], b: string[]) => [...a, ...b], + schema: z.array(z.string()), + }, + default: () => [], + }), + }); + + const channels = getChannelsFromZod(schema); + expect(channels.messages).toBeInstanceOf(BinaryOperatorAggregate); + }); + + test("intersection schema", () => { + const schema1 = z.object({ name: z.string() }); + const schema2 = z.object({ age: z.number() }); + const intersection = schema1.and(schema2); + + const channels = getChannelsFromZod(intersection); + expect(channels.name).toBeInstanceOf(LastValue); + expect(channels.age).toBeInstanceOf(LastValue); + }); + }); +});