Skip to content

Commit b3c969d

Browse files
thucpn and leehuwuj authored
feat: image generator tool (#135)
--------- Co-authored-by: leehuwuj <[email protected]>
1 parent 628e16d commit b3c969d

File tree

13 files changed

+272
-14
lines changed

13 files changed

+272
-14
lines changed

.changeset/fifty-mugs-suffer.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"create-llama": patch
3+
---
4+
5+
Add image generator tool

helpers/tools.ts

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,24 @@ For better results, you can specify the region parameter to get results from a s
170170
},
171171
],
172172
},
173+
{
174+
display: "Image Generator",
175+
name: "img_gen",
176+
supportedFrameworks: ["fastapi", "express", "nextjs"],
177+
type: ToolType.LOCAL,
178+
envVars: [
179+
{
180+
name: "STABILITY_API_KEY",
181+
description:
182+
"STABILITY_API_KEY key is required to run image generator. Get it here: https://platform.stability.ai/account/keys",
183+
},
184+
{
185+
name: TOOL_SYSTEM_PROMPT_ENV_VAR,
186+
description: "System prompt for image generator tool.",
187+
value: `You are an image generator agent. You help users to generate images using the Stability API.`,
188+
},
189+
],
190+
},
173191
];
174192

175193
export const getTool = (toolName: string): Tool | undefined => {

templates/components/engines/python/agent/tools/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def load_tools(tool_type: str, tool_name: str, config: dict) -> list[FunctionToo
3131
return tool_spec.to_tool_list()
3232
else:
3333
module = importlib.import_module(f"{source_package}.{tool_name}")
34-
tools = module.get_tools()
34+
tools = module.get_tools(**config)
3535
if not all(isinstance(tool, FunctionTool) for tool in tools):
3636
raise ValueError(
3737
f"The module {module} does not contain valid tools"

templates/components/engines/python/agent/tools/duckduckgo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,5 +32,5 @@ def duckduckgo_search(
3232
return results
3333

3434

35-
def get_tools():
35+
def get_tools(**kwargs):
3636
return [FunctionTool.from_defaults(duckduckgo_search)]

templates/components/engines/python/agent/tools/img_gen.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import os
2+
import uuid
3+
import logging
4+
import requests
5+
from typing import Optional
6+
from pydantic import BaseModel, Field
7+
from llama_index.core.tools import FunctionTool
8+
9+
logger = logging.getLogger(__name__)
10+
11+
12+
class ImageGeneratorToolOutput(BaseModel):
    # Result returned by the image generator tool: on success `image_url`
    # is populated, on failure `error_message` carries the reason.

    # Required flag; there is no default, so it must always be supplied.
    is_success: bool = Field(
        default=...,
        description="Whether the image generation was successful.",
    )
    # Only set when generation succeeded.
    image_url: Optional[str] = Field(
        default=None,
        description="The URL of the generated image.",
    )
    # Only set when generation failed.
    error_message: Optional[str] = Field(
        default=None,
        description="The error message if the image generation failed.",
    )
25+
26+
27+
class ImageGeneratorTool:
    """
    Generate an image from a text prompt with the Stability AI HTTP API,
    store it in the local output directory and return a URL built from
    the file server prefix.

    Configuration:
        STABILITY_API_KEY: used when no ``api_key`` argument is given.
        FILESERVER_URL_PREFIX: prefix prepended to the saved file's path.
    """

    _IMG_OUTPUT_FORMAT = "webp"
    _IMG_OUTPUT_DIR = "tool-output"
    _IMG_GEN_API = "https://api.stability.ai/v2beta/stable-image/generate/core"
    # Seconds to wait for the API before giving up. Without a timeout,
    # `requests` waits forever on a stalled connection.
    # NOTE(review): 120s is a generous guess for image generation - tune if needed.
    _REQUEST_TIMEOUT = 120

    def __init__(self, api_key: Optional[str] = None):
        """
        Args:
            api_key: Stability API key. Falls back to the
                STABILITY_API_KEY environment variable when empty or None.

        Raises:
            ValueError: if no API key is available or FILESERVER_URL_PREFIX
                is not set.
        """
        if not api_key:
            api_key = os.getenv("STABILITY_API_KEY")
        self._api_key = api_key
        self.fileserver_url_prefix = os.getenv("FILESERVER_URL_PREFIX")
        # `not` (rather than `is None`) also rejects an empty key, which
        # could never authenticate anyway.
        if not self._api_key:
            raise ValueError(
                "STABILITY_API_KEY key is required to run image generator. Get it here: https://platform.stability.ai/account/keys"
            )
        if self.fileserver_url_prefix is None:
            raise ValueError("FILESERVER_URL_PREFIX is required.")

    def _prepare_output_dir(self):
        """
        Create the output directory if it doesn't exist
        """
        # exist_ok=True already makes this a no-op for an existing directory.
        os.makedirs(self._IMG_OUTPUT_DIR, exist_ok=True)

    def _save_image(self, image_data: bytes) -> str:
        """
        Write the image bytes to a uniquely named file in the output
        directory and return the URL under which it is served.
        """
        self._prepare_output_dir()
        filename = f"{uuid.uuid4()}.{self._IMG_OUTPUT_FORMAT}"
        output_path = os.path.join(self._IMG_OUTPUT_DIR, filename)
        with open(output_path, "wb") as f:
            f.write(image_data)
        # Use the prefix validated in __init__ instead of re-reading the
        # environment, so the URL cannot silently become "None/...".
        url = f"{self.fileserver_url_prefix}/{self._IMG_OUTPUT_DIR}/{filename}"
        logger.info(f"Saved image to {output_path}.\nURL: {url}")
        return url

    def _call_stability_api(self, prompt: str):
        """
        POST the prompt to the image generation endpoint.

        Returns:
            The `requests` response; its ``content`` holds the image bytes.

        Raises:
            requests.exceptions.RequestException: on connection errors,
                timeouts, or a non-2xx status (via ``raise_for_status``).
        """
        headers = {
            "authorization": f"Bearer {self._api_key}",
            "accept": "image/*",
        }
        data = {
            "prompt": prompt,
            "output_format": self._IMG_OUTPUT_FORMAT,
        }

        response = requests.post(
            self._IMG_GEN_API,
            headers=headers,
            # A dummy `files` entry makes `requests` encode the body as
            # multipart/form-data (presumably required by the endpoint).
            files={"none": ""},
            data=data,
            timeout=self._REQUEST_TIMEOUT,
        )
        response.raise_for_status()

        return response

    # NOTE: the docstring below is kept verbatim - it is presumably used as
    # the tool description by FunctionTool.from_defaults.
    def generate_image(self, prompt: str) -> ImageGeneratorToolOutput:
        """
        Use this tool to generate an image based on the prompt.
        Args:
            prompt (str): The prompt to generate the image from.
        """

        try:
            # Call the Stability API
            response = self._call_stability_api(prompt)

            # Save the image and get the URL
            image_url = self._save_image(response.content)

            return ImageGeneratorToolOutput(
                is_success=True,
                image_url=image_url,
            )
        except Exception as e:
            # Failures are reported to the caller in the output object
            # rather than raised; logger.exception records the traceback.
            logger.exception(e)
            return ImageGeneratorToolOutput(
                is_success=False,
                error_message=str(e),
            )
105+
106+
107+
def get_tools(**kwargs):
    """
    Build the tool list for this module.

    Keyword arguments are forwarded unchanged to ImageGeneratorTool.
    """
    generator = ImageGeneratorTool(**kwargs)
    return [FunctionTool.from_defaults(generator.generate_image)]

templates/components/engines/python/agent/tools/interpreter.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,9 @@ class E2BCodeInterpreter:
2929

3030
output_dir = "tool-output"
3131

32-
def __init__(self):
33-
api_key = os.getenv("E2B_API_KEY")
32+
def __init__(self, api_key: str = None):
33+
if api_key is None:
34+
api_key = os.getenv("E2B_API_KEY")
3435
filesever_url_prefix = os.getenv("FILESERVER_URL_PREFIX")
3536
if not api_key:
3637
raise ValueError(
@@ -138,5 +139,5 @@ def interpret(self, code: str) -> E2BToolOutput:
138139
return output
139140

140141

141-
def get_tools():
142-
return [FunctionTool.from_defaults(E2BCodeInterpreter().interpret)]
142+
def get_tools(**kwargs):
143+
return [FunctionTool.from_defaults(E2BCodeInterpreter(**kwargs).interpret)]

templates/components/engines/python/agent/tools/weather.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,5 +69,5 @@ def get_weather_information(cls, location: str) -> dict:
6969
return response.json()
7070

7171

72-
def get_tools():
72+
def get_tools(**kwargs):
7373
return [FunctionTool.from_defaults(OpenMeteoWeather.get_weather_information)]

templates/components/engines/typescript/agent/tools/img-gen.ts

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
import type { JSONSchemaType } from "ajv";
2+
import { FormData } from "formdata-node";
3+
import fs from "fs";
4+
import got from "got";
5+
import { BaseTool, ToolMetadata } from "llamaindex";
6+
import path from "node:path";
7+
import { Readable } from "stream";
8+
9+
/** Input accepted by the image generator tool. */
export type ImgGeneratorParameter = {
  /** Text prompt the image is generated from. */
  prompt: string;
};

/** Optional construction parameters for the image generator tool. */
export type ImgGeneratorToolParams = {
  /** Overrides the default tool metadata when provided. */
  metadata?: ToolMetadata<JSONSchemaType<ImgGeneratorParameter>>;
};

/** Result of an image generation attempt. */
export type ImgGeneratorToolOutput = {
  /** Whether the image was generated and saved. */
  isSuccess: boolean;
  /** URL of the saved image; set on success. */
  imageUrl?: string;
  /** Failure description; set on error. */
  errorMessage?: string;
};
22+
23+
/**
 * Metadata used when the caller does not supply its own: the tool name,
 * its description, and the JSON schema of the single `prompt` argument.
 */
const DEFAULT_META_DATA: ToolMetadata<JSONSchemaType<ImgGeneratorParameter>> = {
  name: "image_generator",
  description: "Use this function to generate an image based on the prompt.",
  parameters: {
    type: "object",
    properties: {
      prompt: { type: "string", description: "The prompt to generate the image" },
    },
    required: ["prompt"],
  },
};
37+
38+
/**
 * Tool that turns a text prompt into an image via the Stability AI HTTP API,
 * saves the image into the local output directory and returns a URL built
 * from `FILESERVER_URL_PREFIX`.
 *
 * Requires the `STABILITY_API_KEY` and `FILESERVER_URL_PREFIX` environment
 * variables; the constructor throws when either is missing.
 */
export class ImgGeneratorTool implements BaseTool<ImgGeneratorParameter> {
  readonly IMG_OUTPUT_FORMAT = "webp";
  readonly IMG_OUTPUT_DIR = "tool-output";
  readonly IMG_GEN_API =
    "https://api.stability.ai/v2beta/stable-image/generate/core";

  metadata: ToolMetadata<JSONSchemaType<ImgGeneratorParameter>>;

  /**
   * @param params - Optional settings; `metadata` replaces the default tool metadata.
   * @throws Error when a required environment variable is not set.
   */
  constructor(params?: ImgGeneratorToolParams) {
    this.checkRequiredEnvVars();
    this.metadata = params?.metadata ?? DEFAULT_META_DATA;
  }

  /**
   * Tool entry point: generates an image for `input.prompt`.
   *
   * @returns The outcome object; failures are reported through
   * `isSuccess: false` instead of a rejected promise.
   */
  async call(input: ImgGeneratorParameter): Promise<ImgGeneratorToolOutput> {
    return await this.generateImage(input.prompt);
  }

  /** Runs the request + save pipeline and converts any error into a failure result. */
  private generateImage = async (
    prompt: string,
  ): Promise<ImgGeneratorToolOutput> => {
    try {
      const buffer = await this.promptToImgBuffer(prompt);
      const imageUrl = this.saveImage(buffer);
      return { isSuccess: true, imageUrl };
    } catch (error) {
      // The detailed error is only logged; the caller gets a generic message.
      console.error(error);
      return {
        isSuccess: false,
        errorMessage: "Failed to generate image. Please try again.",
      };
    }
  };

  /** Posts the prompt as form data to the generation endpoint and returns the image bytes. */
  private promptToImgBuffer = async (prompt: string): Promise<Buffer> => {
    const form = new FormData();
    form.append("prompt", prompt);
    form.append("output_format", this.IMG_OUTPUT_FORMAT);
    const buffer = await got
      .post(this.IMG_GEN_API, {
        // Not sure why it shows a type error when passing form to body,
        // although this follows the documentation: https://github.com/sindresorhus/got/blob/main/documentation/2-options.md#body
        // It still works fine, so the form is cast via unknown to avoid the TypeScript warning.
        // Found a similar issue: https://github.com/sindresorhus/got/discussions/1877
        body: form as unknown as Buffer | Readable | string,
        headers: {
          Authorization: `Bearer ${process.env.STABILITY_API_KEY}`,
          Accept: "image/*",
        },
      })
      .buffer();
    return buffer;
  };

  /**
   * Writes the image to a uniquely named file in the output directory.
   *
   * @returns The URL under which the saved file is expected to be served.
   */
  private saveImage = (buffer: Buffer): string => {
    const filename = `${crypto.randomUUID()}.${this.IMG_OUTPUT_FORMAT}`;
    const outputPath = path.join(this.IMG_OUTPUT_DIR, filename);
    // Make sure the output directory exists; writeFileSync does not create
    // parent directories and would throw ENOENT otherwise. `recursive: true`
    // makes this a no-op when the directory is already there.
    fs.mkdirSync(this.IMG_OUTPUT_DIR, { recursive: true });
    fs.writeFileSync(outputPath, buffer);
    const url = `${process.env.FILESERVER_URL_PREFIX}/${this.IMG_OUTPUT_DIR}/${filename}`;
    console.log(`Saved image to ${outputPath}.\nURL: ${url}`);
    return url;
  };

  /** Throws when an environment variable this tool depends on is missing or empty. */
  private checkRequiredEnvVars = () => {
    if (!process.env.STABILITY_API_KEY) {
      throw new Error(
        "STABILITY_API_KEY key is required to run image generator. Get it here: https://platform.stability.ai/account/keys",
      );
    }
    if (!process.env.FILESERVER_URL_PREFIX) {
      throw new Error(
        "FILESERVER_URL_PREFIX is required to display file output after generation",
      );
    }
  };
}

templates/components/engines/typescript/agent/tools/index.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import { BaseToolWithCall } from "llamaindex";
22
import { ToolsFactory } from "llamaindex/tools/ToolsFactory";
33
import { DuckDuckGoSearchTool, DuckDuckGoToolParams } from "./duckduckgo";
4+
import { ImgGeneratorTool, ImgGeneratorToolParams } from "./img-gen";
45
import { InterpreterTool, InterpreterToolParams } from "./interpreter";
56
import { OpenAPIActionTool } from "./openapi-action";
67
import { WeatherTool, WeatherToolParams } from "./weather";
@@ -39,6 +40,9 @@ const toolFactory: Record<string, ToolCreator> = {
3940
duckduckgo: async (config: unknown) => {
4041
return [new DuckDuckGoSearchTool(config as DuckDuckGoToolParams)];
4142
},
43+
img_gen: async (config: unknown) => {
44+
return [new ImgGeneratorTool(config as ImgGeneratorToolParams)];
45+
},
4246
};
4347

4448
async function createLocalTools(

templates/types/streaming/express/eslintrc.json

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,8 @@
33
"rules": {
44
"max-params": ["error", 4],
55
"prefer-const": "error"
6+
},
7+
"parserOptions": {
8+
"sourceType": "module"
69
}
710
}

0 commit comments

Comments
 (0)