Skip to content

Commit 775a976

Browse files
feat(langchain/createAgent): add before/after agent hooks (#9161)
1 parent 54debd8 commit 775a976

21 files changed

+2730
-730
lines changed

libs/langchain/src/agents/ReactAgent.ts

Lines changed: 239 additions & 37 deletions
Large diffs are not rendered by default.

libs/langchain/src/agents/annotation.ts

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,12 @@ export function createAgentAnnotationConditional<
4949
}
5050
}
5151

52-
if (!hasStructuredResponse) {
53-
return z.object(zodSchema);
52+
// Only include structuredResponse when responseFormat is defined
53+
if (hasStructuredResponse) {
54+
zodSchema.structuredResponse = z.any().optional();
5455
}
5556

56-
return z.object({
57-
...zodSchema,
58-
structuredResponse: z.any().optional(),
59-
});
57+
return z.object(zodSchema);
6058
}
6159

6260
export const PreHookAnnotation: AnnotationRoot<{

libs/langchain/src/agents/middleware.ts

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,10 @@ export function createMiddleware<
7777
* - Undefined
7878
*/
7979
contextSchema?: TContextSchema;
80+
/**
81+
* Explitictly defines which targets are allowed to be jumped to from the `beforeAgent` hook.
82+
*/
83+
beforeAgentJumpTo?: JumpToTarget[];
8084
/**
8185
* Explitictly defines which targets are allowed to be jumped to from the `beforeModel` hook.
8286
*/
@@ -85,6 +89,10 @@ export function createMiddleware<
8589
* Explitictly defines which targets are allowed to be jumped to from the `afterModel` hook.
8690
*/
8791
afterModelJumpTo?: JumpToTarget[];
92+
/**
93+
* Explitictly defines which targets are allowed to be jumped to from the `afterAgent` hook.
94+
*/
95+
afterAgentJumpTo?: JumpToTarget[];
8896
/**
8997
* Additional tools registered by the middleware.
9098
*/
@@ -209,6 +217,43 @@ export function createMiddleware<
209217
>
210218
) => Promise<AIMessage> | AIMessage
211219
) => Promise<AIMessage> | AIMessage;
220+
/**
221+
* The function to run before the agent execution starts. This function is called once at the start of the agent invocation.
222+
* It allows to modify the state of the agent before any model calls or tool executions.
223+
*
224+
* @param state - The middleware state
225+
* @param runtime - The middleware runtime
226+
* @returns The modified middleware state or undefined to pass through
227+
*/
228+
beforeAgent?: (
229+
state: (TSchema extends InteropZodObject
230+
? InferInteropZodInput<TSchema>
231+
: {}) &
232+
AgentBuiltInState,
233+
runtime: Runtime<
234+
TContextSchema extends InteropZodObject
235+
? InferInteropZodOutput<TContextSchema>
236+
: TContextSchema extends InteropZodDefault<any>
237+
? InferInteropZodOutput<TContextSchema>
238+
: TContextSchema extends InteropZodOptional<any>
239+
? Partial<InferInteropZodOutput<TContextSchema>>
240+
: never
241+
>
242+
) =>
243+
| Promise<
244+
MiddlewareResult<
245+
Partial<
246+
TSchema extends InteropZodObject
247+
? InferInteropZodInput<TSchema>
248+
: {}
249+
>
250+
>
251+
>
252+
| MiddlewareResult<
253+
Partial<
254+
TSchema extends InteropZodObject ? InferInteropZodInput<TSchema> : {}
255+
>
256+
>;
212257
/**
213258
* The function to run before the model call. This function is called before the model is invoked and before the `wrapModelRequest` hook.
214259
* It allows to modify the state of the agent.
@@ -283,13 +328,52 @@ export function createMiddleware<
283328
TSchema extends InteropZodObject ? InferInteropZodInput<TSchema> : {}
284329
>
285330
>;
331+
/**
332+
* The function to run after the agent execution completes. This function is called once at the end of the agent invocation.
333+
* It allows to modify the final state of the agent after all model calls and tool executions are complete.
334+
*
335+
* @param state - The middleware state
336+
* @param runtime - The middleware runtime
337+
* @returns The modified middleware state or undefined to pass through
338+
*/
339+
afterAgent?: (
340+
state: (TSchema extends InteropZodObject
341+
? InferInteropZodInput<TSchema>
342+
: {}) &
343+
AgentBuiltInState,
344+
runtime: Runtime<
345+
TContextSchema extends InteropZodObject
346+
? InferInteropZodOutput<TContextSchema>
347+
: TContextSchema extends InteropZodDefault<any>
348+
? InferInteropZodOutput<TContextSchema>
349+
: TContextSchema extends InteropZodOptional<any>
350+
? Partial<InferInteropZodOutput<TContextSchema>>
351+
: never
352+
>
353+
) =>
354+
| Promise<
355+
MiddlewareResult<
356+
Partial<
357+
TSchema extends InteropZodObject
358+
? InferInteropZodInput<TSchema>
359+
: {}
360+
>
361+
>
362+
>
363+
| MiddlewareResult<
364+
Partial<
365+
TSchema extends InteropZodObject ? InferInteropZodInput<TSchema> : {}
366+
>
367+
>;
286368
}): AgentMiddleware<TSchema, TContextSchema, any> {
287369
const middleware: AgentMiddleware<TSchema, TContextSchema, any> = {
288370
name: config.name,
289371
stateSchema: config.stateSchema,
290372
contextSchema: config.contextSchema,
373+
beforeAgentJumpTo: config.beforeAgentJumpTo,
291374
beforeModelJumpTo: config.beforeModelJumpTo,
292375
afterModelJumpTo: config.afterModelJumpTo,
376+
afterAgentJumpTo: config.afterAgentJumpTo,
293377
tools: config.tools ?? [],
294378
};
295379

@@ -303,6 +387,24 @@ export function createMiddleware<
303387
Promise.resolve(config.wrapModelRequest!(request, handler));
304388
}
305389

390+
if (config.beforeAgent) {
391+
middleware.beforeAgent = async (state, runtime) =>
392+
Promise.resolve(
393+
config.beforeAgent!(
394+
state,
395+
runtime as Runtime<
396+
TContextSchema extends InteropZodObject
397+
? InferInteropZodOutput<TContextSchema>
398+
: TContextSchema extends InteropZodDefault<any>
399+
? InferInteropZodOutput<TContextSchema>
400+
: TContextSchema extends InteropZodOptional<any>
401+
? Partial<InferInteropZodOutput<TContextSchema>>
402+
: never
403+
>
404+
)
405+
);
406+
}
407+
306408
if (config.beforeModel) {
307409
middleware.beforeModel = async (state, runtime) =>
308410
Promise.resolve(
@@ -339,5 +441,23 @@ export function createMiddleware<
339441
);
340442
}
341443

444+
if (config.afterAgent) {
445+
middleware.afterAgent = async (state, runtime) =>
446+
Promise.resolve(
447+
config.afterAgent!(
448+
state,
449+
runtime as Runtime<
450+
TContextSchema extends InteropZodObject
451+
? InferInteropZodOutput<TContextSchema>
452+
: TContextSchema extends InteropZodDefault<any>
453+
? InferInteropZodOutput<TContextSchema>
454+
: TContextSchema extends InteropZodOptional<any>
455+
? Partial<InferInteropZodOutput<TContextSchema>>
456+
: never
457+
>
458+
)
459+
);
460+
}
461+
342462
return middleware;
343463
}

libs/langchain/src/agents/middleware/tests/hitl.test.ts

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import { z } from "zod/v3";
22
import { describe, it, expect, vi, beforeEach } from "vitest";
33
import { tool } from "@langchain/core/tools";
4-
import { HumanMessage, ToolMessage } from "@langchain/core/messages";
4+
import { AIMessage, HumanMessage, ToolMessage } from "@langchain/core/messages";
55
import { Command } from "@langchain/langgraph";
66
import { MemorySaver } from "@langchain/langgraph-checkpoint";
77

@@ -95,6 +95,7 @@ describe("humanInTheLoopMiddleware", () => {
9595
args: { filename: "greeting.txt", content: "Hello World" },
9696
},
9797
],
98+
[],
9899
],
99100
});
100101

@@ -136,14 +137,35 @@ describe("humanInTheLoopMiddleware", () => {
136137

137138
// Verify response
138139
const mathMessages = mathResult.messages;
139-
expect(mathMessages).toHaveLength(3);
140+
expect(mathMessages).toHaveLength(4);
141+
/**
142+
* 1st message: Human message with prompt
143+
*/
144+
expect(HumanMessage.isInstance(mathMessages[0])).toBe(true);
140145
expect(mathMessages[0]).toEqual(
141146
new _AnyIdHumanMessage("Calculate 42 * 17")
142147
);
148+
/**
149+
* 2nd message: AIMessage calling tool
150+
*/
151+
expect(AIMessage.isInstance(mathMessages[1])).toBe(true);
143152
expect(mathMessages[1].content).toEqual(
144153
expect.stringContaining("You are a helpful assistant.")
145154
);
146-
expect(mathMessages[2].content).toBe("42 * 17 = 714");
155+
/**
156+
* 3rd message: ToolMessage with tool response
157+
*/
158+
expect(ToolMessage.isInstance(mathMessages[2])).toBe(true);
159+
expect(mathMessages[2].content).toEqual(
160+
expect.stringContaining("42 * 17 = 714")
161+
);
162+
/**
163+
* 4th message: AI response
164+
*/
165+
expect(AIMessage.isInstance(mathMessages[3])).toBe(true);
166+
expect(mathMessages[3].content).toEqual(
167+
expect.stringContaining("42 * 17 = 714")
168+
);
147169

148170
// Test 2: Write file tool (requires approval)
149171
model.index = 1;

libs/langchain/src/agents/middleware/types.ts

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,10 @@ export interface AgentMiddleware<
8686
stateSchema?: TSchema;
8787
contextSchema?: TContextSchema;
8888
name: string;
89+
beforeAgentJumpTo?: JumpToTarget[];
8990
beforeModelJumpTo?: JumpToTarget[];
9091
afterModelJumpTo?: JumpToTarget[];
92+
afterAgentJumpTo?: JumpToTarget[];
9193
tools?: (ClientTool | ServerTool)[];
9294
/**
9395
* Wraps tool execution with custom logic. This allows you to:
@@ -200,6 +202,19 @@ export interface AgentMiddleware<
200202
>
201203
) => Promise<AIMessage> | AIMessage
202204
): Promise<AIMessage> | AIMessage;
205+
beforeAgent?(
206+
state: (TSchema extends InteropZodObject
207+
? InferInteropZodInput<TSchema>
208+
: {}) &
209+
AgentBuiltInState,
210+
runtime: Runtime<TFullContext>
211+
): Promise<
212+
MiddlewareResult<
213+
Partial<
214+
TSchema extends InteropZodObject ? InferInteropZodInput<TSchema> : {}
215+
>
216+
>
217+
>;
203218
beforeModel?(
204219
state: (TSchema extends InteropZodObject
205220
? InferInteropZodInput<TSchema>
@@ -226,6 +241,19 @@ export interface AgentMiddleware<
226241
>
227242
>
228243
>;
244+
afterAgent?(
245+
state: (TSchema extends InteropZodObject
246+
? InferInteropZodInput<TSchema>
247+
: {}) &
248+
AgentBuiltInState,
249+
runtime: Runtime<TFullContext>
250+
): Promise<
251+
MiddlewareResult<
252+
Partial<
253+
TSchema extends InteropZodObject ? InferInteropZodInput<TSchema> : {}
254+
>
255+
>
256+
>;
229257
}
230258

231259
/**
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import { z } from "zod/v3";
2+
import { RunnableConfig } from "@langchain/core/runnables";
3+
import { MiddlewareNode, MiddlewareNodeOptions } from "./middleware.js";
4+
import type { AgentMiddleware, MiddlewareResult } from "../middleware/types.js";
5+
import type { AgentBuiltInState, Runtime } from "../runtime.js";
6+
7+
/**
8+
* Node for executing a single middleware's afterAgent hook.
9+
*/
10+
export class AfterAgentNode<
11+
TStateSchema extends Record<string, unknown> = Record<string, unknown>,
12+
TContextSchema extends Record<string, unknown> = Record<string, unknown>
13+
> extends MiddlewareNode<TStateSchema, TContextSchema> {
14+
lc_namespace = ["langchain", "agents", "afterAgentNodes"];
15+
16+
constructor(
17+
public middleware: AgentMiddleware<
18+
z.ZodObject<z.ZodRawShape>,
19+
z.ZodObject<z.ZodRawShape>
20+
>,
21+
options: MiddlewareNodeOptions
22+
) {
23+
super(
24+
{
25+
name: `AfterAgentNode_${middleware.name}`,
26+
func: async (
27+
state: TStateSchema,
28+
config?: RunnableConfig<TContextSchema>
29+
) => this.invokeMiddleware(state, config),
30+
},
31+
options
32+
);
33+
}
34+
35+
runHook(state: TStateSchema, runtime: Runtime<TContextSchema>) {
36+
return this.middleware.afterAgent!(
37+
state as Record<string, unknown> & AgentBuiltInState,
38+
runtime as Runtime<unknown>
39+
) as Promise<MiddlewareResult<TStateSchema>>;
40+
}
41+
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import { z } from "zod/v3";
2+
import { RunnableConfig } from "@langchain/core/runnables";
3+
import { MiddlewareNode, type MiddlewareNodeOptions } from "./middleware.js";
4+
import type { AgentMiddleware, MiddlewareResult } from "../middleware/types.js";
5+
import type { AgentBuiltInState, Runtime } from "../runtime.js";
6+
7+
/**
8+
* Node for executing a single middleware's beforeAgent hook.
9+
*/
10+
export class BeforeAgentNode<
11+
TStateSchema extends Record<string, unknown> = Record<string, unknown>,
12+
TContextSchema extends Record<string, unknown> = Record<string, unknown>
13+
> extends MiddlewareNode<TStateSchema, TContextSchema> {
14+
lc_namespace = ["langchain", "agents", "beforeAgentNodes"];
15+
16+
constructor(
17+
public middleware: AgentMiddleware<
18+
z.ZodObject<z.ZodRawShape>,
19+
z.ZodObject<z.ZodRawShape>
20+
>,
21+
options: MiddlewareNodeOptions
22+
) {
23+
super(
24+
{
25+
name: `BeforeAgentNode_${middleware.name}`,
26+
func: async (
27+
state: TStateSchema,
28+
config?: RunnableConfig<TContextSchema>
29+
) => this.invokeMiddleware(state, config),
30+
},
31+
options
32+
);
33+
}
34+
35+
runHook(state: TStateSchema, runtime: Runtime<TContextSchema>) {
36+
return this.middleware.beforeAgent!(
37+
state as Record<string, unknown> & AgentBuiltInState,
38+
runtime as Runtime<unknown>
39+
) as Promise<MiddlewareResult<TStateSchema>>;
40+
}
41+
}

0 commit comments

Comments
 (0)