diff --git a/libs/langchain-community/src/chat_models/ibm.ts b/libs/langchain-community/src/chat_models/ibm.ts index f9d203604b68..c02b45e4e4ea 100644 --- a/libs/langchain-community/src/chat_models/ibm.ts +++ b/libs/langchain-community/src/chat_models/ibm.ts @@ -668,6 +668,7 @@ export class ChatWatsonx< { ...scopeId, messages: watsonxMessages, + signal: options?.signal, }, watsonxCallbacks ) @@ -676,6 +677,7 @@ export class ChatWatsonx< ...params, ...scopeId, messages: watsonxMessages, + signal: options?.signal, }, watsonxCallbacks ); @@ -695,9 +697,6 @@ export class ChatWatsonx< } generations.push(generation); } - if (options.signal?.aborted) { - throw new Error("AbortError"); - } return { generations, @@ -727,6 +726,7 @@ export class ChatWatsonx< ...scopeId, messages: watsonxMessages, returnObject: true, + signal: options?.signal, }, watsonxCallbacks ) @@ -736,6 +736,7 @@ export class ChatWatsonx< ...scopeId, messages: watsonxMessages, returnObject: true, + signal: options?.signal, }, watsonxCallbacks ); @@ -745,9 +746,6 @@ export class ChatWatsonx< let usage: TextChatUsage | undefined; let currentCompletion = 0; for await (const chunk of stream) { - if (options.signal?.aborted) { - throw new Error("AbortError"); - } if (chunk?.data?.usage) usage = chunk.data.usage; const { data } = chunk; const choice = data.choices[0] as TextChatResultChoice & diff --git a/libs/langchain-community/src/chat_models/tests/ibm.test.ts b/libs/langchain-community/src/chat_models/tests/ibm.test.ts index bba376530bc4..45d03165f09a 100644 --- a/libs/langchain-community/src/chat_models/tests/ibm.test.ts +++ b/libs/langchain-community/src/chat_models/tests/ibm.test.ts @@ -289,4 +289,126 @@ describe("LLM unit tests", () => { testProperties(instance, testProps, notExTestProps); }); }); + + describe("AbortSignal parameter passing", () => { + test("Signal passed to textChat() with projectId", async () => { + const testProps = { + model: "ibm/granite-3-8b-instruct", + version: "2025-01-17", + serviceUrl: "https://test.watsonx.ai", + projectId: "test-project-id", + }; + const instance = new ChatWatsonx({ ...testProps, ...fakeAuthProp }); + + const mockResponse = { + choices: [{ message: { role: "assistant", content: "" } }], + }; + const spy = jest + .spyOn(instance.service, "textChat") + .mockResolvedValue({ result: mockResponse } as any); + + const controller = new AbortController(); + await instance.invoke("test", { signal: controller.signal }); + + expect(spy).toHaveBeenCalledWith( + expect.objectContaining({ signal: controller.signal }), + undefined + ); + + spy.mockRestore(); + }); + + test("Signal passed to deploymentsTextChat() with idOrName", async () => { + const testProps = { + version: "2025-01-17", + serviceUrl: "https://test.watsonx.ai", + idOrName: "test-deployment", + }; + const instance = new ChatWatsonx({ ...testProps, ...fakeAuthProp }); + + const mockResponse = { + choices: [{ message: { role: "assistant", content: "" } }], + }; + const spy = jest + .spyOn(instance.service, "deploymentsTextChat") + .mockResolvedValue({ result: mockResponse } as any); + + const controller = new AbortController(); + await instance.invoke("test", { signal: controller.signal }); + + expect(spy).toHaveBeenCalledWith( + expect.objectContaining({ signal: controller.signal }), + undefined + ); + + spy.mockRestore(); + }); + + test("Signal passed to textChatStream() with projectId", async () => { + const testProps = { + model: "ibm/granite-3-8b-instruct", + version: "2025-01-17", + serviceUrl: "https://test.watsonx.ai", + projectId: "test-project-id", + }; + const instance = new ChatWatsonx({ ...testProps, ...fakeAuthProp }); + + async function* mockStream() { + yield { data: { choices: [{ delta: {} }] } }; + } + + const spy = jest + .spyOn(instance.service, "textChatStream") + .mockResolvedValue(mockStream() as any); + + const controller = new AbortController(); + const stream = await instance.stream("test", { + signal: controller.signal, + }); + + for await (const _chunk of stream) { + /* consume stream */ + } + + expect(spy).toHaveBeenCalledWith( + expect.objectContaining({ signal: controller.signal }), + undefined + ); + + spy.mockRestore(); + }); + + test("Signal passed to deploymentsTextChatStream() with idOrName", async () => { + const testProps = { + version: "2025-01-17", + serviceUrl: "https://test.watsonx.ai", + idOrName: "test-deployment", + }; + const instance = new ChatWatsonx({ ...testProps, ...fakeAuthProp }); + + async function* mockStream() { + yield { data: { choices: [{ delta: {} }] } }; + } + + const spy = jest + .spyOn(instance.service, "deploymentsTextChatStream") + .mockResolvedValue(mockStream() as any); + + const controller = new AbortController(); + const stream = await instance.stream("test", { + signal: controller.signal, + }); + + for await (const _chunk of stream) { + /* consume stream */ + } + + expect(spy).toHaveBeenCalledWith( + expect.objectContaining({ signal: controller.signal }), + undefined + ); + + spy.mockRestore(); + }); + }); });