Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 70 additions & 5 deletions langchain/src/chains/api/api_chain.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { z, ZodTypeAny } from "zod";
import { BaseChain, ChainInputs } from "../base.js";
import { SerializedAPIChain } from "../serde.js";
import { LLMChain } from "../llm_chain.js";
Expand All @@ -9,13 +10,17 @@ import {
API_RESPONSE_PROMPT_TEMPLATE,
} from "./prompts.js";
import { BasePromptTemplate } from "../../index.js";
import { StructuredOutputParser } from "../../output_parsers/structured.js";
import { OutputFixingParser } from "../../output_parsers/fix.js";

export interface APIChainInput extends Omit<ChainInputs, "memory"> {
llm: BaseLanguageModel;
apiAnswerChain: LLMChain;
apiRequestChain: LLMChain;
apiDocs: string;
inputKey?: string;
headers?: Record<string, string>;
allowedMethods?: string[];
/** Key to use for output, defaults to `output` */
outputKey?: string;
}
Expand All @@ -27,6 +32,8 @@ export type APIChainOptions = {
};

export class APIChain extends BaseChain implements APIChainInput {
llm: BaseLanguageModel;

apiAnswerChain: LLMChain;

apiRequestChain: LLMChain;
Expand All @@ -39,6 +46,8 @@ export class APIChain extends BaseChain implements APIChainInput {

outputKey = "output";

allowedMethods = ["GET", "POST"];

get inputKeys() {
return [this.inputKey];
}
Expand All @@ -55,6 +64,7 @@ export class APIChain extends BaseChain implements APIChainInput {
this.inputKey = fields.inputKey ?? this.inputKey;
this.outputKey = fields.outputKey ?? this.outputKey;
this.headers = fields.headers ?? this.headers;
this.allowedMethods = fields.allowedMethods ?? this.allowedMethods;
}

/** @ignore */
Expand All @@ -64,16 +74,47 @@ export class APIChain extends BaseChain implements APIChainInput {
): Promise<ChainValues> {
const question: string = values[this.inputKey];

const api_url = await this.apiRequestChain.predict(
const api_json = await this.apiRequestChain.predict(
{ question, api_docs: this.apiDocs },
runManager?.getChild()
);

const res = await fetch(api_url, { headers: this.headers });
const fixParser = OutputFixingParser.fromLLM(
this.llm,
APIChain.getApiParser()
);
const api_options = await fixParser.parse(api_json);

if (!this.allowedMethods.includes(api_options.api_method)) {
throw new Error(
`${api_options.api_method} is not part of allowedMethods`
);
}

const request_options =
api_options.api_method === "GET" ||
api_options.api_method === "HEAD" ||
api_options.api_method === "DELETE"
? {
method: api_options.api_method,
headers: this.headers,
}
: {
method: api_options.api_method,
headers: this.headers,
body: JSON.stringify(api_options.api_body),
};

const res = await fetch(api_options.api_url, request_options);
const api_response = await res.text();

const answer = await this.apiAnswerChain.predict(
{ question, api_docs: this.apiDocs, api_url, api_response },
{
question,
api_docs: this.apiDocs,
api_url: api_options.api_url,
api_response,
},
runManager?.getChild()
);

Expand All @@ -85,7 +126,7 @@ export class APIChain extends BaseChain implements APIChainInput {
}

static async deserialize(data: SerializedAPIChain) {
const { api_request_chain, api_answer_chain, api_docs } = data;
const { api_request_chain, api_answer_chain, api_docs, llm } = data;

if (!api_request_chain) {
throw new Error("LLMChain must have api_request_chain");
Expand All @@ -99,6 +140,7 @@ export class APIChain extends BaseChain implements APIChainInput {
}

return new APIChain({
llm: await BaseLanguageModel.deserialize(llm),
apiAnswerChain: await LLMChain.deserialize(api_answer_chain),
apiRequestChain: await LLMChain.deserialize(api_request_chain),
apiDocs: api_docs,
Expand All @@ -108,17 +150,39 @@ export class APIChain extends BaseChain implements APIChainInput {
serialize(): SerializedAPIChain {
return {
_type: this._chainType(),
llm: this.llm.serialize(),
api_answer_chain: this.apiAnswerChain.serialize(),
api_request_chain: this.apiRequestChain.serialize(),
api_docs: this.apiDocs,
};
}

static getApiParserSchema(): ZodTypeAny {
return z.object({
api_url: z
.string()
.describe(
"the formatted url in case of GET API call otherwise just the url"
),
api_body: z
.any()
.describe("formatted key value pair for making API call"),
api_method: z.string().describe("API method from documentation"),
});
}

static getApiParser(): StructuredOutputParser<ZodTypeAny> {
return StructuredOutputParser.fromZodSchema(this.getApiParserSchema());
}

static fromLLMAndAPIDocs(
llm: BaseLanguageModel,
apiDocs: string,
options: APIChainOptions &
Omit<APIChainInput, "apiAnswerChain" | "apiRequestChain" | "apiDocs"> = {}
Omit<
APIChainInput,
"apiAnswerChain" | "apiRequestChain" | "apiDocs" | "llm"
> = {}
): APIChain {
const {
apiUrlPrompt = API_URL_PROMPT_TEMPLATE,
Expand All @@ -127,6 +191,7 @@ export class APIChain extends BaseChain implements APIChainInput {
const apiRequestChain = new LLMChain({ prompt: apiUrlPrompt, llm });
const apiAnswerChain = new LLMChain({ prompt: apiResponsePrompt, llm });
return new this({
llm,
apiAnswerChain,
apiRequestChain,
apiDocs,
Expand Down
20 changes: 16 additions & 4 deletions langchain/src/chains/api/prompts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,23 @@ import { PromptTemplate } from "../../prompts/prompt.js";

export const API_URL_RAW_PROMPT_TEMPLATE = `You are given the below API Documentation:
{api_docs}
Using this documentation, generate the full API url to call for answering the user question.
You should build the API url in order to get a response that is as short as possible, while still getting the necessary information to answer the question. Pay attention to deliberately exclude any unnecessary pieces of data in the API call.

Question:{question}
API url:`;
<< FORMATTING >>
Return a JSON string with a JSON object formatted to look like:
{{
"api_url": string \\ the formatted url in case of GET API call otherwise just the url
"api_body": key value \\ formatted key value pair for making API call
"api_method": string \\ API method from documentation
}}

REMEMBER: "api_url" Must be a valid url and should be along with parameters if "api_method" is GET
REMEMBER: "api_body" Should be a valid JSON in case of POST or PUT api call

You should build the json string in order to get a response that is as short as possible, while still getting the necessary information
to answer the question. Pay attention to deliberately exclude any unnecessary pieces of data in the API call.
Question: {question}
json:
`;

export const API_URL_PROMPT_TEMPLATE = /* #__PURE__ */ new PromptTemplate({
inputVariables: ["api_docs", "question"],
Expand Down
1 change: 1 addition & 0 deletions langchain/src/chains/serde.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ export type SerializedVectorDBQAChain = {

export type SerializedAPIChain = {
_type: "api_chain";
llm: SerializedLLM;
api_request_chain: SerializedLLMChain;
api_answer_chain: SerializedLLMChain;
api_docs: string;
Expand Down
62 changes: 60 additions & 2 deletions langchain/src/chains/tests/api_chain.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ import { LLMChain } from "../llm_chain.js";
import { loadChain } from "../load.js";
import { APIChain, APIChainInput } from "../api/api_chain.js";
import {
API_URL_PROMPT_TEMPLATE,
API_RESPONSE_PROMPT_TEMPLATE,
API_URL_PROMPT_TEMPLATE,
} from "../api/prompts.js";
import { OPEN_METEO_DOCS } from "./example_data/open_meteo_docs.js";
import { POST_API_DOCS } from "./example_data/post_api_docs.js";
import { DELETE_API_DOCS } from "./example_data/delete_api_docs.js";

const test_api_docs = `
This API endpoint will search the notes for a user.
Expand All @@ -19,10 +21,26 @@ Query parameters:
q | string | The search term for notes
`;

const post_test_api_docs = `
API documentation:
Endpoint: https://httpbin.org

This API is for sending Postman message

POST /post

POST body table:
message | string | Message to send | required

Response schema (string):
result | string
`;

const testApiData = {
api_docs: test_api_docs,
question: "Search for notes containing langchain",
api_url: "https://httpbin.com/api/notes?q=langchain",
api_json:
'{"api_url":"https://httpbin.com/api/notes?q=langchain","api_method":"GET","api_data":{}}',
api_response: JSON.stringify({
success: true,
results: [{ id: 1, content: "Langchain is awesome!" }],
Expand All @@ -42,6 +60,7 @@ test("Test APIChain", async () => {
});

const apiChainInput: APIChainInput = {
llm: model,
apiAnswerChain,
apiRequestChain,
apiDocs: testApiData.api_docs,
Expand All @@ -65,6 +84,45 @@ test("Test APIChain fromLLMAndApiDocs", async () => {
console.log({ res });
});

test("Test POST APIChain fromLLMAndApiDocs", async () => {
const model = new OpenAI({ modelName: "text-davinci-003" });
const chain = APIChain.fromLLMAndAPIDocs(model, post_test_api_docs);
const res = await chain.call({
question: "send a message hi langchain",
});
console.log({ res });
});

test("Test POST 2 APIChain fromLLMAndApiDocs", async () => {
const model = new OpenAI({ modelName: "text-davinci-003" });
const chain = APIChain.fromLLMAndAPIDocs(model, POST_API_DOCS);
const res = await chain.call({
question: "send a message hi langchain to channel3 with token5",
});
console.log({ res });
});

test("Test DELETE APIChain fromLLMAndApiDocs if not in allowedMethods", async () => {
const model = new OpenAI({ modelName: "text-davinci-003" });
const chain = APIChain.fromLLMAndAPIDocs(model, DELETE_API_DOCS);
await expect(() =>
chain.call({
question: "delete a message with id 15",
})
).rejects.toThrow();
});

test("Test DELETE APIChain fromLLMAndApiDocs if allowed in allowedMethods", async () => {
const model = new OpenAI({ modelName: "text-davinci-003" });
const chain = APIChain.fromLLMAndAPIDocs(model, DELETE_API_DOCS, {
allowedMethods: ["GET", "POST", "DELETE"],
});
const res = await chain.call({
question: "delete a message with id 15",
});
console.log({ res });
});

test("Load APIChain from hub", async () => {
const chain = await loadChain("lc://chains/api/meteo/chain.json");
const res = await chain.call({
Expand Down
11 changes: 11 additions & 0 deletions langchain/src/chains/tests/example_data/delete_api_docs.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
export const DELETE_API_DOCS = `
Endpoint: https://httpbin.org

This API deletes a message
Method: DELETE

DELETE /api/message/{id} HTTP/1.1

Parameter Format Required Default Description
id String Yes message id to be deleted
`;
13 changes: 13 additions & 0 deletions langchain/src/chains/tests/example_data/post_api_docs.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
export const POST_API_DOCS = `
Endpoint: https://httpbin.org

This API Sends a message to a channel.
Method: POST

POST /post

POST Body:
token | string | Authentication token bearing required scopes| required
channel | string | Channel, private group, or IM channel to send message to. Can be an encoded ID, or a name. See below for more details. | required
text | string | text is the message you want to send in a channel
`;