diff --git a/data_generator.py b/data_generator.py
index d7ccc1e..ce664fc 100644
--- a/data_generator.py
+++ b/data_generator.py
@@ -1,90 +1,473 @@
-import os
+import argparse
+import asyncio
+import json
+import re
+
+from dataclasses import dataclass
+from itertools import islice
+from typing import Callable, Dict, Iterator, List, Optional, Tuple, Awaitable
 import torch
-from transformers import (
-    AutoModelForCausalLM,
-    AutoTokenizer,
-)
+import torch.nn.functional as F
+import tritonclient.grpc.aio as grpcclient
 from datasets import load_dataset
-from prompts import retrieval_prompt
-from data_generation.retrieval import RetrievalPostprocessing
-from data_generation.calendar import CalendarPostprocessing
-from data_generation.calculator import CalculatorPostprocessing
-from data_generation.api_checker import check_apis_available
-import json
-import time
-import argparse
+from transformers import AutoTokenizer
+from tqdm import tqdm
+from tools import Calculator, Tool, WikiSearch
 
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description='do some continuations')
-    parser.add_argument('--device_id', type=int, default=0)
-    parser.add_argument("--num_devices", type=int, default=8)
-    args = parser.parse_args()
-    gpt_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
-    prompt_tokens = gpt_tokenizer(retrieval_prompt, return_tensors="pt")["input_ids"]
-    start_tokens = [
-        gpt_tokenizer("[")["input_ids"][0],
-        gpt_tokenizer(" [")["input_ids"][0],
+"""
+Tool use parsing regex, e.g.:
+[Calculator(1 / 2) -> 0.5] into ("Calculator", "1 / 2", "0.5")
+[Date()] into ("Date", None, None)
+ [WikiSearch('abcdef')] into ("WikiSearch", "'abcdef'", None)
+"""
+TOOL_REGEX = r"\s?\[([A-Za-z]+)\((.*)\)(\s->\s.*)?\]"
+
+# Delimiters used by ToolUse.render to mark an API call in the generated training text
+TOOLFORMER_API_START = ""
+TOOLFORMER_API_RESPONSE = ""
+TOOLFORMER_API_END = ""
+
+
+@dataclass
+class ToolUse:
+    tool: Tool
+    args: str
+    output: str
+    tokens: torch.Tensor
+    prompt: str
+    insert_index: int
+
+    def __str__(self):
+        return f"{self.tool.name}({self.args}) -> {self.output}"
+
+    def render(self, tokenizer: AutoTokenizer) -> str:
+        """
+        Renders the tool use as a string.
+        """
+        prefix = tokenizer.decode(self.tokens[0, : self.insert_index])
+        suffix = tokenizer.decode(self.tokens[0, self.insert_index :])
+        return f"{prefix}{TOOLFORMER_API_START}{self.tool.name}({self.args}){TOOLFORMER_API_RESPONSE}{self.output}{TOOLFORMER_API_END}{suffix}"
+
+
+def interpret_tools(input_text: str, tools: Dict[str, Callable[[str], str]]) -> str:
+    """
+    Interprets the use of tools in text and replaces each call with its output.
+    """
+    output_text = input_text
+    for match in re.finditer(TOOL_REGEX, input_text):
+        tool_name = match.group(1)
+        tool_args = match.group(2)  # empty string if no args
+        tool_output: Optional[str] = match.group(3)
+        if tool_name in tools and tool_output is None:
+            args = [tool_args] if tool_args else []
+            tool_output = tools[tool_name](*args)
+            output = f"[{tool_name}({tool_args}) -> {tool_output}]"
+            output_text = output_text.replace(match.group(0), output)
+        elif tool_name not in tools:
+            print(f"Unknown tool: {tool_name}")
+
+    return output_text
+
+
+"""
+To sample API calls, we write a prompt that encourages an LM to annotate text with API calls.
+Then, we find the top k locations with the highest probability of an API start token.
+Then, we sample m API calls from each of the top k locations, using the few-shot prompt as a prefix and the API end token as the end-of-sequence token.
+
+Then we execute all found API calls and keep their results.
+Then we filter API calls by measuring the cross entropy loss between the original text with the API call and its result prefixed to it,
+the original text with no call, and the original text with the API call args but without the output.
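+An API call is kept only when the variant with the call and its result scores a lower weighted loss, by more than a threshold tau, than the better of the other two variants.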
+"""
+
+
+async def sample_api_calls(
+    tool: Tool,
+    model: Callable[[torch.Tensor], Awaitable[Tuple[torch.Tensor, torch.Tensor]]],
+    tokenizer: AutoTokenizer,
+    prompt: str,
+    k: int,
+    m: int,
+    start_tokens: List[int],
+    end_token: int,
+    api_call_threshold: float = 0.05,
+    max_length: int = 1024,
+    new_tokens: int = 100,
+):
+    # Build the annotation prompt by substituting the input prompt into the tool's few-shot template,
+    # like "Input: example1\nOutput: annotated example1\nInput: \nOutput: "
+    annotate_prompt = tool.prompt.replace("", prompt)
+    # Concatenate the prompt so the model continues from `...\nOutput: `
+    unannotated_output = annotate_prompt + prompt
+
+    prompted_tokens = tokenizer(
+        [unannotated_output],
+        return_tensors="pt",
+        truncation=True,
+        max_length=max_length - new_tokens,
+    ).input_ids
+    prefix_tokens = tokenizer(
+        [annotate_prompt],
+        return_tensors="pt",
+        truncation=True,
+        max_length=max_length - new_tokens,
+    ).input_ids
+    input_tokens = prompted_tokens[:, :-1]
+    labels = prompted_tokens[:, 1:]
+    logits, _ = await model(input_tokens)
+    probs = F.softmax(logits.float(), dim=-1)
+
+    # Find the top k locations with the highest probability of an API start token.
+
+    # Mask out positions whose next token already is an API start token, and positions
+    # that fall inside the few-shot prefix.
+    remove_tokens = ~torch.any(
+        torch.stack([labels == start_token for start_token in start_tokens]),
+        dim=0,
+    )
+    remove_tokens[:, : prefix_tokens.shape[1]] = False
+
+    probs_for_api = probs[0, :, start_tokens].max(dim=-1).values
+    probs_for_api = probs_for_api * remove_tokens[0].float()
+
+    top_k_probs, top_k_indices = torch.topk(probs_for_api, min(len(probs_for_api), k))
+
+    for idx, prob in zip(top_k_indices, top_k_probs):
+        if prob.item() < api_call_threshold:
+            break
+
+        insert_api_at = idx.item()
+
+        selected_start_token = probs[:, insert_api_at, start_tokens].argmax().item()
+
+        for i in range(m):
+            _, api_calls = await model(
+                torch.cat(
+                    [
+                        input_tokens[:, :insert_api_at],
+                        torch.full(
+                            [1, 1],
+                            start_tokens[selected_start_token],
+                            dtype=torch.long,
+                        ),
+                    ],
+                    dim=1,
+                ),
+                new_tokens=new_tokens,
+            )
+
+            api_call_str = tokenizer.decode(api_calls[0][insert_api_at:])
+            match = re.match(TOOL_REGEX, api_call_str)
+
+            if match is None:
+                continue
+
+            tool_name = match.group(1)
+            tool_args = match.group(2)
+
+            if tool_name == tool.name:
+                loop = asyncio.get_running_loop()
+
+                api_output = await loop.run_in_executor(None, tool, tool_args)
+
+                few_shot_prompt_start = prefix_tokens.shape[1]
+                prompt_tokens = input_tokens[:, few_shot_prompt_start:]
+                yield ToolUse(
+                    tool=tool,
+                    args=tool_args,
+                    output=api_output,
+                    insert_index=insert_api_at - few_shot_prompt_start,
+                    prompt=prompt,
+                    tokens=prompt_tokens,
+                )
+
+
+async def api_loss_reduction(
+    model: Callable[[torch.Tensor], Awaitable[Tuple[torch.Tensor, torch.Tensor]]],
+    tokenizer: AutoTokenizer,
+    tool_use: ToolUse,
+    max_length: int = 1024,
+):
+    api_use = tokenizer(
+        f" [{tool_use.tool.name}({tool_use.args}) -> {tool_use.output}]",
+        return_tensors="pt",
+    )
+
+    api_args = tokenizer(
+        f"[{tool_use.tool.name}({tool_use.args})]",
+        return_tensors="pt",
+    )
+
+    input_list = [
+        torch.cat([api_use.input_ids[0], tool_use.tokens[0]], dim=0)[:max_length],
+        torch.cat([api_args.input_ids[0], tool_use.tokens[0]], dim=0)[:max_length],
+        tool_use.tokens[0],
     ]
-    end_tokens = [
-        gpt_tokenizer("]")["input_ids"][0],
-        gpt_tokenizer(" ]")["input_ids"][0],
-    ]  # TODO: keep second?
-    api_handler = RetrievalPostprocessing(start_tokens, end_tokens)
-    model = AutoModelForCausalLM.from_pretrained(
-        "EleutherAI/gpt-j-6B",
-        revision="float16",
-        torch_dtype=torch.float16,
-        low_cpu_mem_usage=True,
-    ).cuda()
-    dataset = load_dataset("c4", "en", split="train", streaming=True)
-    iter_data = iter(dataset)
-    test = False
-    counter = 0
-    file_counter = 0
-    found_examples = 0
-    output_dataset = list()
-    start_time = time.process_time()
-    num_examples = int(25000.0/float(args.num_devices))
-    start_count = -1
-    if os.path.isfile(f"retrieval_data_{args.device_id}.json"):
-        with open(f"retrieval_data_{args.device_id}.json") as f:
-            output_dataset = json.load(f)
-            start_count = output_dataset[-1]['file_index']
-            for item in output_dataset:
-                num_examples -= len(item['retrieval_outputs'])
-    while found_examples < num_examples:
-        data = next(iter_data)
-        if file_counter < start_count:
-            file_counter += 1
-            continue
-        if file_counter % args.num_devices != args.device_id:
-            file_counter += 1
-            continue
-        available = check_apis_available(data, gpt_tokenizer)
-        test = available.retrieval
-        if test:
-            data_outputs = api_handler.parse_article(data, model, gpt_tokenizer)
-            output_dataset.append(
-                {
-                    "file_index": file_counter,
-                    "text": data["text"],
-                    "retrieval_outputs": data_outputs
-                }
+
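+    # Batch the three variants (call + result, call only, plain text) so one forward pass scores them all.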
gpt_tokenizer(" ]")["input_ids"][0], - ] # TODO: keep second? - api_handler = RetrievalPostprocessing(start_tokens, end_tokens) - model = AutoModelForCausalLM.from_pretrained( - "EleutherAI/gpt-j-6B", - revision="float16", - torch_dtype=torch.float16, - low_cpu_mem_usage=True, - ).cuda() - dataset = load_dataset("c4", "en", split="train", streaming=True) - iter_data = iter(dataset) - test = False - counter = 0 - file_counter = 0 - found_examples = 0 - output_dataset = list() - start_time = time.process_time() - num_examples = int(25000.0/float(args.num_devices)) - start_count = -1 - if os.path.isfile(f"retrieval_data_{args.device_id}.json"): - with open(f"retrieval_data_{args.device_id}.json") as f: - output_dataset = json.load(f) - start_count = output_dataset[-1]['file_index'] - for item in output_dataset: - num_examples -= len(item['retrieval_outputs']) - while found_examples < num_examples: - data = next(iter_data) - if file_counter < start_count: - file_counter += 1 - continue - if file_counter % args.num_devices != args.device_id: - file_counter += 1 - continue - available = check_apis_available(data, gpt_tokenizer) - test = available.retrieval - if test: - data_outputs = api_handler.parse_article(data, model, gpt_tokenizer) - output_dataset.append( - { - "file_index": file_counter, - "text": data["text"], - "retrieval_outputs": data_outputs - } + + inputs = torch.nn.utils.rnn.pad_sequence( + input_list, batch_first=True, padding_value=tokenizer.pad_token_id + )[:, :max_length] + + input_tokens = inputs[:, :-1] + label_tokens = inputs[:, 1:] + input_lengths = torch.tensor([x.shape[0] for x in input_list]) + prompt_length = input_lengths[2].item() - 1 + suffix_length = prompt_length - tool_use.insert_index + + logits, _ = await model(input_tokens) + + def weighted_cross_entropy(logits, labels, length): + un_weighted_xent = F.cross_entropy( + logits[length - suffix_length : length], + labels[length - suffix_length : length], + reduction="none", + ) + + weights = 1.0 - 0.2 * torch.arange(un_weighted_xent.shape[0]) + weights = torch.maximum(weights, torch.zeros_like(weights)) + return (un_weighted_xent * weights).sum() + + L_plus = weighted_cross_entropy(logits[0], label_tokens[0], input_lengths[0]) + + L_minus_with_api_args = weighted_cross_entropy( + logits[1], label_tokens[1], input_lengths[1] + ) + + L_minus_without_api = weighted_cross_entropy( + logits[2], label_tokens[2], input_lengths[2] + ) + + L_minus = min(L_minus_with_api_args, L_minus_without_api) + + return L_minus - L_plus + + +def prepare_inference_inputs( + inputs_ids: torch.IntTensor, new_tokens: int = 1, temperature: float = 1.0 +): + batch_size = inputs_ids.shape[0] + + input_ids_input = grpcclient.InferInput("input_ids", inputs_ids.shape, "INT32") + input_ids_input.set_data_from_numpy(inputs_ids.int().cpu().numpy()) + + new_tokens_input = grpcclient.InferInput( + "tensor_of_seq_len", [batch_size, new_tokens], "INT32" + ) + new_tokens_input.set_data_from_numpy( + torch.zeros(batch_size, new_tokens, dtype=torch.int32).cpu().numpy() + ) + + temperature_input = grpcclient.InferInput("temperature", [batch_size, 1], "FP32") + temperature_input.set_data_from_numpy( + torch.full([batch_size, 1], temperature, dtype=torch.float32).cpu().numpy() + ) + + inputs = [input_ids_input, new_tokens_input, temperature_input] + outputs = [ + grpcclient.InferRequestedOutput("logits"), + grpcclient.InferRequestedOutput("output_ids"), + ] + return inputs, outputs + + +async def infer( + triton_client, model_name, input_ids, new_tokens: int 
+    new_tokens_input = grpcclient.InferInput(
+        "tensor_of_seq_len", [batch_size, new_tokens], "INT32"
+    )
+    new_tokens_input.set_data_from_numpy(
+        torch.zeros(batch_size, new_tokens, dtype=torch.int32).cpu().numpy()
+    )
+
+    temperature_input = grpcclient.InferInput("temperature", [batch_size, 1], "FP32")
+    temperature_input.set_data_from_numpy(
+        torch.full([batch_size, 1], temperature, dtype=torch.float32).cpu().numpy()
+    )
+
+    inputs = [input_ids_input, new_tokens_input, temperature_input]
+    outputs = [
+        grpcclient.InferRequestedOutput("logits"),
+        grpcclient.InferRequestedOutput("output_ids"),
+    ]
+    return inputs, outputs
+
+
+async def infer(
+    triton_client, model_name, input_ids, new_tokens: int = 1, temperature: float = 1.0
+):
+    inputs, outputs = prepare_inference_inputs(input_ids, new_tokens, temperature)
+
+    triton_model_name = model_name.replace("/", "--")
+
+    result = await triton_client.infer(
+        model_name=triton_model_name, inputs=inputs, outputs=outputs
+    )
+
+    logits = torch.tensor(result.as_numpy("logits").copy(), requires_grad=False)
+    output_ids = torch.tensor(result.as_numpy("output_ids").copy(), requires_grad=False)
+
+    return logits, output_ids
+
+
+async def main(
+    model_name,
+    url,
+    output_file,
+    tau=0.5,
+    max_concurrent=32,
+    max_samples=10,
+    max_datapoints=None,
+    max_length=1024,
+):
+    async with grpcclient.InferenceServerClient(
+        url=url,
+    ) as triton_client:
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        tokenizer.pad_token = tokenizer.eos_token
+        tokenizer.truncation_side = "left"
+
+        print(
+            "Example output:",
+            ToolUse(
+                tool=Calculator(),
+                args="1+1",
+                output="2",
+                tokens=tokenizer(
+                    "the result of 1+1 is 2", return_tensors="pt"
+                ).input_ids,
+                prompt="the result of 1+1 is 2",
+                insert_index=6,
+            ).render(tokenizer=tokenizer),
+        )
+
+        async def infer_model(input_ids, new_tokens: int = 1, temperature: float = 1.0):
+            return await infer(
+                triton_client, model_name, input_ids, new_tokens, temperature
             )
-            prev_found = found_examples
-            found_examples += len(output_dataset[-1]["retrieval_outputs"])
-            eta_s = (num_examples - found_examples) * (time.process_time()-start_time) / max(1, found_examples)
-            eta_m = eta_s // 60
-            eta_h = eta_m // 60
-            eta_m = eta_m - (eta_h*60)
-            eta_s = eta_s - ((eta_m*60) + (eta_h*60*60))
-            print(f"Found: {found_examples}/{num_examples}, ETA: {eta_h}H:{eta_m}M:{eta_s}s")
-            if found_examples//100 > prev_found//100:
-                with open(f"retrieval_data_{args.device_id}.json", 'w') as f:
-                    json.dump(output_dataset, f, indent=2)
-            counter += 1
-        file_counter += 1
-    with open(f"retrieval_data_{args.device_id}.json", 'w') as f:
-        json.dump(output_dataset, f, indent=2)
\ No newline at end of file
+
+        start_tokens = [
+            tokenizer("[")["input_ids"][0],
+            tokenizer(" [")["input_ids"][0],
+        ]
+        end_token = tokenizer("]")["input_ids"][0]
+
+        tools = [Calculator(), WikiSearch()]
+
+        dataset = load_dataset("c4", "en", split="train", streaming=True)
+        iter_data = iter(dataset)
+
+        if max_datapoints is not None:
+            iter_data = islice(iter_data, max_datapoints)
+
+        async def sample_and_filter_api_calls(tool, text, top_k, n_gen):
+            async for tool_use in sample_api_calls(
+                tool=tool,
+                model=infer_model,
+                tokenizer=tokenizer,
+                prompt=text,
+                k=top_k,
+                m=n_gen,
+                start_tokens=start_tokens,
+                end_token=end_token,
+                max_length=max_length,
+                api_call_threshold=0.05,
+            ):
+                lm_loss_diff = await api_loss_reduction(
+                    infer_model, tokenizer, tool_use, max_length=max_length
+                )
+                if lm_loss_diff.item() > tau:
+                    return tool_use
+
+        pbar = tqdm(total=max_datapoints)
+        pbar.set_description("Datapoints processed")
+
+        tooled_pbar = tqdm(total=max_samples)
+        tooled_pbar.set_description("Tool uses sampled")
+        with open(output_file, "w") as f:
+            counter = 0
+
+            while True:
+                data_samples = []
+                for _ in range(max_concurrent):
+                    try:
+                        data = next(iter_data)
+                        for tool in tools:
+                            if tool.heuristic(data):
+                                data_samples.append((tool, data))
+                    except StopIteration:
+                        break
+                tasks = [
+                    asyncio.create_task(
+                        sample_and_filter_api_calls(
+                            tool, data["text"], top_k=5, n_gen=1
+                        )
+                    )
+                    for tool, data in data_samples
+                ]
+
+                for sampled_tool_use in asyncio.as_completed(tasks):
+                    tool_use = await sampled_tool_use
+                    pbar.update(1)
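+                    # sample_and_filter_api_calls returns None when no sampled call clears the tau threshold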
+                    if tool_use is not None:
+                        counter += 1
+                        tooled_pbar.update(1)
+                        print(tool_use)
+
+                        f.write(
+                            json.dumps(dict(text=tool_use.render(tokenizer))) + "\n"
+                        )
+                        f.flush()
+
+                        if counter >= max_samples:
+                            return
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--model_name",
+        type=str,
+        default="EleutherAI/gpt-j-6B",
+        help="Name of the model to use",
+    )
+
+    parser.add_argument(
+        "--url",
+        type=str,
+        default="localhost:8001",
+        help="URL of the gRPC inference service of the Triton Inference Server",
+    )
+
+    parser.add_argument(
+        "--output_file",
+        type=str,
+        default="output.jsonl",
+        help="Path to the output file",
+    )
+
+    parser.add_argument(
+        "--tau",
+        type=float,
+        default=0.5,
+        help="Threshold on the LM loss reduction for an API call to be considered useful",
+    )
+
+    parser.add_argument(
+        "--max-concurrent",
+        type=int,
+        default=32,
+        help="Maximum number of samples to process concurrently",
+    )
+
+    parser.add_argument(
+        "--max-samples",
+        type=int,
+        default=10,
+        help="Maximum number of tool-annotated samples to generate",
+    )
+
+    parser.add_argument(
+        "--max-datapoints",
+        type=int,
+        default=None,
+        help="Maximum number of datapoints to sample from",
+    )
+
+    parser.add_argument(
+        "--max-length",
+        type=int,
+        default=2048,
+        help="Maximum length in tokens of the model input, including generated tokens",
+    )
+
+    args = parser.parse_args()
+    asyncio.run(
+        main(
+            args.model_name,
+            args.url,
+            args.output_file,
+            args.tau,
+            args.max_concurrent,
+            args.max_samples,
+            args.max_datapoints,
+            args.max_length,
+        )
+    )
diff --git a/inference/.gitignore b/inference/.gitignore
new file mode 100644
index 0000000..44c4a86
--- /dev/null
+++ b/inference/.gitignore
@@ -0,0 +1 @@
+model_store/**
\ No newline at end of file
diff --git a/inference/README.md b/inference/README.md
new file mode 100644
index 0000000..0225ba3
--- /dev/null
+++ b/inference/README.md
@@ -0,0 +1,3 @@
+`python convert_to_triton.py --model EleutherAI/gpt-j-6B`
+
+`./start_server.sh`
\ No newline at end of file
diff --git a/inference/__init__.py b/inference/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/inference/convert_to_triton.py b/inference/convert_to_triton.py
new file mode 100644
index 0000000..e100025
--- /dev/null
+++ b/inference/convert_to_triton.py
@@ -0,0 +1,113 @@
+# Adapted from https://github.com/CarperAI/trlx/blob/93c90cbdc3c6b463f565b09340ca1f74271285c5/examples/hh/triton_config.pbtxt
+
+import argparse
+import os
+from string import Template
+
+import torch
+from torch import nn
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument(
+    "--model", type=str, required=True, help="Path to HF checkpoint with the base model"
+)
+
+parser.add_argument(
+    "--max-batch-size", type=int, default=4, help="Maximum batch size for inference"
+)
+
+parser.add_argument(
+    "--revision",
+    type=str,
+    required=False,
+    help="Optional branch/commit of the HF checkpoint",
+)
+
+parser.add_argument("--device", type=int, default=0)
+args = parser.parse_args()
+
+device = torch.device(args.device)
+
+
+class ModelLogits(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+
+    @torch.inference_mode()
+    def forward(self, input_ids: torch.Tensor):
+        return self.model(input_ids).logits
+
+
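+# Wraps the traced logits model in a token-by-token sampling loop (temperature-scaled
+# multinomial) so the whole generation step can be scripted and served from Triton.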
+class InferModel(nn.Module):
+    def __init__(self, traced_model):
+        super().__init__()
+        self.traced_model = traced_model
+
+    @torch.inference_mode()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        tensor_of_seq_len: torch.Tensor,
+        temperature: torch.Tensor,
+    ):
+        for _ in range(tensor_of_seq_len.shape[1] - 1):
+            logits = self.traced_model(input_ids)
+            next_token = torch.multinomial(
+                torch.softmax(logits[:, -1, :] / temperature, dim=-1), 1
+            ).squeeze(1)
+            input_ids = torch.cat([input_ids, next_token.unsqueeze(1)], dim=1)
+
+        # in TorchScript, the above logits var lifetime doesn't escape the loop's scope
+        logits = self.traced_model(input_ids).float()
+        next_token = torch.multinomial(
+            torch.softmax(logits[:, -1, :] / temperature, dim=-1), 1
+        ).squeeze(1)
+        input_ids = torch.cat([input_ids, next_token.unsqueeze(1)], dim=1)
+
+        return input_ids.int(), logits
+
+
+print(f"Converting {args.model} to TorchScript...")
+tokenizer = AutoTokenizer.from_pretrained(args.model)
+model = ModelLogits(AutoModelForCausalLM.from_pretrained(args.model))
+model.eval()
+model.requires_grad_(False)
+model = model.half().to(device)
+
+input = tokenizer("annotator model's hash is 0x", return_tensors="pt").to(device)
+print(f"{model(input.input_ids)=}")
+
+traced_script_module = torch.jit.trace(model, input.input_ids)
+
+print(f"{traced_script_module(input.input_ids)=}")
+
+print("Scripting generation wrapper...")
+
+scripted_generator_model = torch.jit.script(InferModel(traced_script_module))
+
+print(f"{input.input_ids=}")
+x = input.input_ids, torch.empty(1, 5).cuda(), torch.full([1, 1], 1.0).cuda()
+print(f"{(scripted_generator_model(*x))=}")
+print(f"{tokenizer.decode(scripted_generator_model(*x)[0][0])=}")
+
+sanitized_name = args.model.replace("/", "--")
+print("Model renamed to ", sanitized_name)
+
+print("Saving TorchScript model...")
+
+os.makedirs(f"model_store/{sanitized_name}/1", exist_ok=True)
+scripted_generator_model.save(f"model_store/{sanitized_name}/1/traced-model.pt")
+
+config_path = os.path.join(
+    os.path.dirname(os.path.realpath(__file__)), "triton_config.pbtxt"
+)
+with open(config_path) as f:
+    template = Template(f.read())
+config = template.substitute(
+    {"model_name": sanitized_name, "max_batch_size": args.max_batch_size}
+)
+with open(f"model_store/{sanitized_name}/config.pbtxt", "w") as f:
+    f.write(config)
diff --git a/inference/start_server.sh b/inference/start_server.sh
new file mode 100755
index 0000000..30ac433
--- /dev/null
+++ b/inference/start_server.sh
@@ -0,0 +1 @@
+docker run --gpus=1 --rm --net=host -v ${PWD}/model_store:/model_store nvcr.io/nvidia/tritonserver:23.01-py3 tritonserver --model-repository=/model_store
\ No newline at end of file
diff --git a/inference/triton_config.pbtxt b/inference/triton_config.pbtxt
new file mode 100644
index 0000000..07c5fb2
--- /dev/null
+++ b/inference/triton_config.pbtxt
@@ -0,0 +1,66 @@
+name: "${model_name}"
+backend: "pytorch"
+default_model_filename: "traced-model.pt"
+max_batch_size: ${max_batch_size}
+
+parameters {
+  key: "model_name"
+  value: {
+    string_value: "${model_name}"
+  }
+}
+
+instance_group [
+  {
+    count: 1
+    kind: KIND_GPU
+    gpus: [0]
+  }
+]
+
+input [
+  {
+    name: "input_ids"
+    data_type: TYPE_INT32
+    dims: [-1]
+  },
+  {
+    name: "tensor_of_seq_len"
+    data_type: TYPE_INT32
+    dims: [-1]
+  },
+  {
+    name: "temperature"
+    data_type: TYPE_FP32
+    dims: [-1]
+  }
+]
+
+output [
+  {
+    name: "output_ids"
+    data_type: TYPE_INT32
+    dims: [-1]
+  },
+  {
+    name: "logits"
+    data_type: TYPE_FP32
+    dims: [-1]
+  }
+]
+
+parameters {
+  key: "data_type"
+  value: {
+    string_value: "fp16"
+  }
+}
+
+parameters: {
+  key: "INFERENCE_MODE"
+  value: {
+    string_value: "true"
+  }
+}
+
+version_policy: {specific: {versions: [1]}}
diff --git a/prompts.py b/prompts.py
index cb96f1c..718cec9 100644
--- a/prompts.py
+++ b/prompts.py
@@ -13,9 +13,8 @@
 Output: I went to Paris in 1994 and stayed there until 2011, so in total, it was [Calculator(2011 - 1994)] 17 years.
 Input: From this, we have 4 * 30 minutes = 120 minutes.
 Output: From this, we have 4 * 30 minutes = [Calculator(4 * 30)] 120 minutes.
-Input: x
-Output:
-"""
+Input: 
+Output: """
 
 retrieval_prompt = """
 Your task is to complete a given piece of text.
@@ -29,8 +28,7 @@
 Input: Metformin is the first-line drug for patients with type 2 diabetes and obesity.
 Output: Metformin is the first-line drug for [Retrieval("illness, diabetes, obesity")] patients with type 2 diabetes and obesity.
 Input: 
-Output:
-"""
+Output: """
 
 llmchain_prompt = """
 Your task is to complete a given piece of text.
@@ -58,9 +56,8 @@
 Output: But what are the risks during production of nanomaterials? [WikiSearch("nanomaterial production risks")] Some nanomaterials may give rise to various kinds of lung damage.
 Input: Metformin is the first-line drug for patients with type 2 diabetes and obesity.
 Output: Metformin is the first-line drug for [WikiSearch("Metformin first-line drug")] patients with type 2 diabetes and obesity.
-Input: x
-Output:
-"""
+Input: 
+Output: """
 
 machine_translation_prompt = """
 Your task is to complete a given piece of text by using a Machine Translation API.
@@ -72,9 +69,8 @@
 Output: In Morris de Jonge’s Jeschuah, der klassische jüdische Mann [MT(der klassische jüdische Mann)], there is a description of a Jewish writer
 Input: 南 京 高 淳 县 住 房 和 城 乡 建 设 局 城 市 新 区 设 计 a plane of reference Gaochun is one of seven districts of the provincial capital Nanjing
 Output: [MT(南京高淳县住房和城乡建设局 城市新 区 设 计)] a plane of reference Gaochun is one of seven districts of the provincial capital Nanjing
-Input: x
-Output:
-"""
+Input: 
+Output: """
 
 calendar_prompt = """
 Your task is to add calls to a Calendar API to a piece of text.
@@ -91,6 +87,5 @@
 Output: The number of days from now until Christmas is [Calendar()] 30.
 Input: The store is never open on the weekend, so today it is closed.
 Output: The store is never open on the weekend, so today [Calendar()] it is closed.
-Input: x
-Output:
-"""
+Input: 
+Output: """
diff --git a/requirements.txt b/requirements.txt
index 194f7d3..2de045c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,4 +3,6 @@ wolframalpha
 transformers
 openai
 langchain
-cohere
\ No newline at end of file
+cohere
+tritonclient==2.31.0
+tqdm
\ No newline at end of file
diff --git a/tools.py b/tools.py
index 9f0cdb5..e4ddd84 100644
--- a/tools.py
+++ b/tools.py
@@ -1,4 +1,7 @@
 import copy
+import re
+import random
+
 import requests
 import calendar
 import json
@@ -15,6 +18,13 @@
 )
 from typing import List
 from operator import truediv, mul, add, sub
+from prompts import (
+    calculator_prompt,
+    wikipedia_search_prompt,
+    machine_translation_prompt,
+    calendar_prompt,
+    retrieval_prompt,
+)
 
 from langchain.chains import LLMChain
 from langchain import Cohere, PromptTemplate
@@ -22,6 +32,20 @@
 from googleapiclient.discovery import build
 
 
+class Tool:
+    prompt: str = ""
+
+    def __init_subclass__(cls) -> None:
+        """Give instances of this class a name attribute corresponding to the subclass name."""
+        cls.name = cls.__name__
+
+    def heuristic(self, input: dict) -> bool:
+        return True
+
+    def __call__(self, text: str) -> str:
+        return ""
+
+
 """
 Calendar
 
 Uses Python's datetime and calendar libraries to retrieve the current date.
 
 input:
 
 - None
 
 output:
 
 - A string, the current date.
 
 """
 
 
-def Calendar(date=datetime.datetime.now()):
-    return f"Today is {calendar.day_name[date.weekday()]}, {calendar.month_name[date.month]} {date.day}, {date.year}."
+class Calendar(Tool):
+    prompt = calendar_prompt
+
+    def __init__(self, date=datetime.datetime.now()):
+        self.date = date
+
+    def __call__(self, text: str) -> str:
+        return f"Today is {calendar.day_name[self.date.weekday()]}, {calendar.month_name[self.date.month]} {self.date.day}, {self.date.year}."
 
 
 """
@@ -139,13 +169,16 @@ def colbertv2_get_request(url: str, query: str, k: int):
     return topk
 
 
-def WikiSearch(input_query: str):
-    k = 10
-    retrieval_model = ColBERTv2(
-        "http://ec2-44-228-128-229.us-west-2.compute.amazonaws.com:8893/api/search"
-    )
-    output = retrieval_model(input_query, k)
-    return output
+class WikiSearch(Tool):
+    prompt = wikipedia_search_prompt
+
+    def __call__(self, input_query: str):
+        k = 10
+        retrieval_model = ColBERTv2(
+            "http://ec2-44-228-128-229.us-west-2.compute.amazonaws.com:8893/api/search"
+        )
+        output = retrieval_model(input_query, k)
+        return output
 
 
 """
@@ -159,17 +192,22 @@ def WikiSearch(input_query: str):
 """
 
 
-def MT(input_query: str):
-    model_name = "facebook/nllb-200-distilled-600M"
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
-    input_ids = tokenizer(input_query, return_tensors="pt")
-    outputs = model.generate(
-        **input_ids,
-        forced_bos_token_id=tokenizer.lang_code_to_id["eng_Latn"],
-    )
-    output = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
-    return output
+class MT(Tool):
+    prompt = machine_translation_prompt
+
+    def __init__(self):
+        self.model_name = "facebook/nllb-200-distilled-600M"
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
+
+    def __call__(self, text: str) -> str:
+        input_ids = self.tokenizer(text, return_tensors="pt")
+        outputs = self.model.generate(
+            **input_ids,
+            forced_bos_token_id=self.tokenizer.lang_code_to_id["eng_Latn"],
+        )
+        output = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
+        return output
 
 
 """
@@ -185,14 +223,36 @@ def MT(input_query: str):
 """
 
 
-def Calculator(input_query: str):
-    operators = {"+": add, "-": sub, "*": mul, "/": truediv}
-    if input_query.isdigit():
-        return float(input_query)
-    for c in operators.keys():
-        left, operator, right = input_query.partition(c)
-        if operator in operators:
-            return round(operators[operator](Calculator(left), Calculator(right)), 2)
+class Calculator(Tool):
+    prompt = calculator_prompt
+
+    def heuristic(self, input: dict):
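+        # Keep texts that contain an explicit calculation (digits joined by an operator
+        # plus an "equals"-style cue); texts that merely contain three or more numbers
+        # are kept with 1% probability to limit noise.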
+        text = input["text"]
+        calc_pattern = re.compile(r"^(\d+[\+\-\*\/]{1})+\d+$")
+        operators = bool(re.search(calc_pattern, text))
+        equals = any(
+            ["=" in text, "equal to" in text, "total of" in text, "average of" in text]
+        )
+        if not (operators and equals):
+            text = text.replace("\n", " ")
+            text = text.split(" ")
+            text = [item for item in text if item.replace(".", "", 1).isnumeric()]
+            if len(text) >= 3:
+                if random.randint(0, 99) == 0:
+                    return True
+        else:
+            return True
+
+        return False
+
+    def __call__(self, input_query: str):
+        operators = {"+": add, "-": sub, "*": mul, "/": truediv}
+        if input_query.isdigit():
+            return float(input_query)
+        for c in operators.keys():
+            left, operator, right = input_query.partition(c)
+            if operator in operators:
+                return round(operators[operator](self(left), self(right)), 2)
 
 
 # Other Optional Tools
@@ -206,6 +266,8 @@ def Calculator(input_query: str):
 Requires that you set your COHERE_API_KEY environment variable before starting.
 """
+
+
 def langchain_llmchain(input_question):
     # TODO: Check succinct if it's good once we don't have rate limited APIs
     template = """Please be succinct in your answer to this question.
@@ -231,21 +293,22 @@ def langchain_llmchain(input_question):
 """
 
 
-def HuggingfaceAPI(input_query: str):
-    model_id = "gpt-neox-20b"
-    API_TOKEN = "YOUR_API_TOKEN"
-    API_URL = "https://api-inference.huggingface.co/models/{model_id}".format(
-        model_id=model_id
-    )
-    headers = {"Authorization": f"Bearer {API_TOKEN}".format(API_TOKEN=API_TOKEN)}
+class HuggingFaceAPI(Tool):
+    def __call__(self, input_query: str):
+        model_id = "gpt-neox-20b"
+        API_TOKEN = "YOUR_API_TOKEN"
+        API_URL = "https://api-inference.huggingface.co/models/{model_id}".format(
+            model_id=model_id
+        )
+        headers = {"Authorization": f"Bearer {API_TOKEN}"}
 
-    def query(payload):
-        data = json.dumps(payload)
-        response = requests.request("POST", API_URL, headers=headers, data=data)
-        return json.loads(response.content.decode("utf-8"))
+        def query(payload):
+            data = json.dumps(payload)
+            response = requests.request("POST", API_URL, headers=headers, data=data)
+            return json.loads(response.content.decode("utf-8"))
 
-    data = query(input_query)
-    return data[0]["generated_text"]
+        data = query(input_query)
+        return data[0]["generated_text"]
 
 
 """