5 changes: 5 additions & 0 deletions .changeset/wet-taxis-heal.md
@@ -0,0 +1,5 @@
---
"@langchain/aws": minor
---

feat(aws): allow bedrock Application Inference Profile
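A minimal usage sketch of the option this PR adds, based on the interface change below; the ARN, region, and credentials are illustrative placeholders, not real values:

import { ChatBedrockConverse } from "@langchain/aws";
import { HumanMessage } from "@langchain/core/messages";

const model = new ChatBedrockConverse({
  // Keep `model` set to the underlying model ID so model-specific
  // metadata and behavior still apply.
  model: "anthropic.claude-3-haiku-20240307-v1:0",
  // When set, the application inference profile ARN overrides `model`
  // as the `modelId` sent to Bedrock.
  applicationInferenceProfile:
    "arn:aws:bedrock:eu-west-1:123456789012:application-inference-profile/my-profile",
  region: "eu-west-1",
});

const response = await model.invoke([new HumanMessage("Hello!")]);
console.log(response.content);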
16 changes: 14 additions & 2 deletions libs/providers/langchain-aws/src/chat_models.ts
@@ -98,6 +98,15 @@ export interface ChatBedrockConverseInput
*/
model?: string;

/**
* Application Inference Profile ARN to use for the model.
* For example, "arn:aws:bedrock:eu-west-1:123456789102:application-inference-profile/fm16bt65tzgx".
* When set, this ARN overrides `this.model` as the `modelId` sent in the Converse API request.
* You must still provide `model` with the normal model ID so that model-specific metadata and behavior are preserved.
* See the link below for more details on creating and using application inference profiles.
* @link https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-create.html
*/
applicationInferenceProfile?: string;

/**
* The AWS region e.g. `us-west-2`.
* Fallback to AWS_DEFAULT_REGION env variable or region specified in ~/.aws/config
@@ -664,6 +673,8 @@ export class ChatBedrockConverse

model = "anthropic.claude-3-haiku-20240307-v1:0";

applicationInferenceProfile?: string;

streaming = false;

region: string;
@@ -745,6 +756,7 @@ export class ChatBedrockConverse

this.region = region;
this.model = rest?.model ?? this.model;
this.applicationInferenceProfile = rest?.applicationInferenceProfile;
this.streaming = rest?.streaming ?? this.streaming;
this.temperature = rest?.temperature;
this.maxTokens = rest?.maxTokens;
@@ -866,7 +878,7 @@ export class ChatBedrockConverse
const params = this.invocationParams(options);

const command = new ConverseCommand({
modelId: this.model,
modelId: this.applicationInferenceProfile ?? this.model,
messages: converseMessages,
system: converseSystem,
requestMetadata: options.requestMetadata,
@@ -907,7 +919,7 @@ export class ChatBedrockConverse
streamUsage = options.streamUsage;
}
const command = new ConverseStreamCommand({
modelId: this.model,
modelId: this.applicationInferenceProfile ?? this.model,
messages: converseMessages,
system: converseSystem,
requestMetadata: options.requestMetadata,
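Both the invoke and streaming paths resolve the outgoing `modelId` the same way, as `this.applicationInferenceProfile ?? this.model`, so the profile ARN wins whenever it is set. A short sketch of the streaming path, under the same illustrative placeholders as above:

import { ChatBedrockConverse } from "@langchain/aws";

const model = new ChatBedrockConverse({
  model: "anthropic.claude-3-haiku-20240307-v1:0",
  applicationInferenceProfile:
    "arn:aws:bedrock:eu-west-1:123456789012:application-inference-profile/my-profile",
  region: "eu-west-1",
});

// The ConverseStreamCommand is built with the profile ARN as modelId.
const stream = await model.stream("Write a haiku about rivers.");
for await (const chunk of stream) {
  // Chunk content may be a string or a content-block array; print text only.
  process.stdout.write(typeof chunk.content === "string" ? chunk.content : "");
}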
203 changes: 202 additions & 1 deletion libs/providers/langchain-aws/src/tests/chat_models.test.ts
@@ -9,11 +9,12 @@ import {
import { concat } from "@langchain/core/utils/stream";
import {
ConversationRole as BedrockConversationRole,
BedrockRuntimeClient,
type Message as BedrockMessage,
type SystemContentBlock as BedrockSystemContentBlock,
} from "@aws-sdk/client-bedrock-runtime";
import { z } from "zod/v3";
import { describe, expect, test, it } from "vitest";
import { describe, expect, test, it, vi } from "vitest";
import { convertToConverseMessages } from "../utils/message_inputs.js";
import { handleConverseStreamContentBlockDelta } from "../utils/message_outputs.js";
import { ChatBedrockConverse } from "../chat_models.js";
@@ -451,6 +452,206 @@ test("Streaming supports empty string chunks", async () => {
expect(finalChunk.content).toBe("Hello world!");
});

describe("applicationInferenceProfile parameter", () => {
const baseConstructorArgs = {
region: "us-east-1",
credentials: {
secretAccessKey: "test-secret-key",
accessKeyId: "test-access-key",
},
};

it("should initialize applicationInferenceProfile from constructor", () => {
const testArn =
"arn:aws:bedrock:eu-west-1:123456789012:application-inference-profile/test-profile";
const model = new ChatBedrockConverse({
...baseConstructorArgs,
model: "anthropic.claude-3-haiku-20240307-v1:0",
applicationInferenceProfile: testArn,
});
expect(model.model).toBe("anthropic.claude-3-haiku-20240307-v1:0");
expect(model.applicationInferenceProfile).toBe(testArn);
});

it("should be undefined when not provided in constructor", () => {
const model = new ChatBedrockConverse({
...baseConstructorArgs,
model: "anthropic.claude-3-haiku-20240307-v1:0",
});

expect(model.model).toBe("anthropic.claude-3-haiku-20240307-v1:0");
expect(model.applicationInferenceProfile).toBeUndefined();
});

it("should send applicationInferenceProfile as modelId in ConverseCommand when provided", async () => {
const testArn =
"arn:aws:bedrock:eu-west-1:123456789012:application-inference-profile/test-profile";
const mockSend = vi.fn().mockResolvedValue({
output: {
message: {
role: "assistant",
content: [{ text: "Test response" }],
},
},
stopReason: "end_turn",
usage: {
inputTokens: 10,
outputTokens: 5,
totalTokens: 15,
},
});

const mockClient = {
send: mockSend,
} as unknown as BedrockRuntimeClient;

const model = new ChatBedrockConverse({
...baseConstructorArgs,
model: "anthropic.claude-3-haiku-20240307-v1:0",
applicationInferenceProfile: testArn,
client: mockClient,
});

await model.invoke([new HumanMessage("Hello")]);

// Verify that send was called
expect(mockSend).toHaveBeenCalledTimes(1);

// Verify that the command was created with applicationInferenceProfile as modelId
const commandArg = mockSend.mock.calls[0][0];
expect(commandArg.input.modelId).toBe(testArn);
expect(commandArg.input.modelId).not.toBe(
"anthropic.claude-3-haiku-20240307-v1:0"
);
});

it("should send model as modelId in ConverseCommand when applicationInferenceProfile is not provided", async () => {
const mockSend = vi.fn().mockResolvedValue({
output: {
message: {
role: "assistant",
content: [{ text: "Test response" }],
},
},
stopReason: "end_turn",
usage: {
inputTokens: 10,
outputTokens: 5,
totalTokens: 15,
},
});

const mockClient = {
send: mockSend,
} as unknown as BedrockRuntimeClient;

const model = new ChatBedrockConverse({
...baseConstructorArgs,
model: "anthropic.claude-3-haiku-20240307-v1:0",
client: mockClient,
});

await model.invoke([new HumanMessage("Hello")]);

// Verify that send was called
expect(mockSend).toHaveBeenCalledTimes(1);

// Verify that the command was created with model as modelId
const commandArg = mockSend.mock.calls[0][0];
expect(commandArg.input.modelId).toBe(
"anthropic.claude-3-haiku-20240307-v1:0"
);
});

it("should send applicationInferenceProfile as modelId in ConverseStreamCommand when provided", async () => {
const testArn =
"arn:aws:bedrock:eu-west-1:123456789012:application-inference-profile/test-profile";
const mockSend = vi.fn().mockResolvedValue({
stream: (async function* () {
yield {
contentBlockDelta: {
contentBlockIndex: 0,
delta: { text: "Test" },
},
};
yield {
metadata: {
usage: {
inputTokens: 10,
outputTokens: 5,
totalTokens: 15,
},
},
};
})(),
});

const mockClient = {
send: mockSend,
} as unknown as BedrockRuntimeClient;

const model = new ChatBedrockConverse({
...baseConstructorArgs,
model: "anthropic.claude-3-haiku-20240307-v1:0",
applicationInferenceProfile: testArn,
streaming: true,
client: mockClient,
});

await model.invoke([new HumanMessage("Hello")]);

expect(mockSend).toHaveBeenCalledTimes(1);

const commandArg = mockSend.mock.calls[0][0];
expect(commandArg.input.modelId).toBe(testArn);
expect(commandArg.input.modelId).not.toBe(
"anthropic.claude-3-haiku-20240307-v1:0"
);
});

it("should send model as modelId in ConverseStreamCommand when applicationInferenceProfile is not provided", async () => {
const mockSend = vi.fn().mockResolvedValue({
stream: (async function* () {
yield {
contentBlockDelta: {
contentBlockIndex: 0,
delta: { text: "Test" },
},
};
yield {
metadata: {
usage: {
inputTokens: 10,
outputTokens: 5,
totalTokens: 15,
},
},
};
})(),
});

const mockClient = {
send: mockSend,
} as unknown as BedrockRuntimeClient;

const model = new ChatBedrockConverse({
...baseConstructorArgs,
model: "anthropic.claude-3-haiku-20240307-v1:0",
streaming: true,
client: mockClient,
});

await model.invoke([new HumanMessage("Hello")]);

expect(mockSend).toHaveBeenCalledTimes(1);

const commandArg = mockSend.mock.calls[0][0];
expect(commandArg.input.modelId).toBe(
"anthropic.claude-3-haiku-20240307-v1:0"
);
});
});

describe("tool_choice works for supported models", () => {
const tool = {
name: "weather",