Skip to content
This repository was archived by the owner on Jun 21, 2024. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
549 changes: 466 additions & 83 deletions data_generator.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions inference/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
model_store/**
3 changes: 3 additions & 0 deletions inference/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
`python convert_to_triton.py --model EleutherAI/gpt-j-6B`

`./start_server.sh`
Empty file added inference/__init__.py
Empty file.
113 changes: 113 additions & 0 deletions inference/convert_to_triton.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Adapted from https://github.com/CarperAI/trlx/blob/93c90cbdc3c6b463f565b09340ca1f74271285c5/examples/hh/triton_config.pbtxt

import argparse
import os
from string import Template

import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer

def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the conversion script.

    Flags:
        --model           (required) path/name of the HF checkpoint to convert.
        --max-batch-size  value substituted into the Triton config (default 4).
        --revision        optional branch/commit of the HF checkpoint.
        --device          integer device index used while converting (default 0).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--model",
        required=True,
        type=str,
        help="Path to HF checkpoint with the base model",
    )
    cli.add_argument(
        "--max-batch-size",
        default=4,
        type=int,
        help="Maximum batch size for inference",
    )
    cli.add_argument(
        "--revision",
        required=False,
        type=str,
        help="Optional branch/commit of the HF checkpoint",
    )
    cli.add_argument("--device", default=0, type=int)
    return cli


# Parsed once at start-up; the rest of the script reads `args` and `device`.
parser = _build_parser()
args = parser.parse_args()

# The device object is built directly from the integer index given on the CLI.
device = torch.device(args.device)


class ModelLogits(nn.Module):
    """Thin wrapper that reduces a model's output object to its ``logits``.

    The wrapped model is called with ``input_ids`` only, and whatever it
    returns must expose a ``logits`` attribute; that attribute is the sole
    return value of this module.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    @torch.inference_mode()
    def forward(self, input_ids: torch.Tensor):
        """Run the wrapped model on ``input_ids`` and return its logits."""
        outputs = self.model(input_ids)
        return outputs.logits


class InferModel(nn.Module):
    """Sampling loop around a logits model, written to be TorchScript-able.

    Each step calls ``traced_model`` on the whole sequence generated so far,
    samples one token per row from the temperature-scaled distribution of the
    last position, and appends it to ``input_ids``.

    Args:
        traced_model: callable mapping ``input_ids`` to logits; the last
            dimension is indexed as the vocabulary and the second-to-last as
            the sequence position.
        min_temperature: lower bound applied to ``temperature`` before the
            division. Without it a temperature of 0 produces inf/NaN
            probabilities and ``torch.multinomial`` fails; with it, very small
            temperatures degrade gracefully towards greedy decoding.
    """

    def __init__(self, traced_model, min_temperature: float = 1e-5):
        super().__init__()
        self.traced_model = traced_model
        self.min_temperature = min_temperature

    def _sample_next(
        self, logits: torch.Tensor, temperature: torch.Tensor
    ) -> torch.Tensor:
        """Sample one token id per row from the last position; shape (batch, 1)."""
        safe_temperature = torch.clamp(temperature, min=self.min_temperature)
        # Only the last position is needed, so cast just that slice to fp32
        # before the softmax instead of the whole logits tensor.
        probs = torch.softmax(logits[:, -1, :].float() / safe_temperature, dim=-1)
        return torch.multinomial(probs, 1)

    @torch.inference_mode()
    def forward(
        self,
        input_ids: torch.Tensor,
        tensor_of_seq_len: torch.Tensor,
        temperature: torch.Tensor,
    ):
        """Generate tokens and return ``(ids, logits)``.

        Args:
            input_ids: prompt token ids, one row per sequence.
            tensor_of_seq_len: only ``shape[1]`` is read; it is the number of
                tokens to generate (at least one token is always generated).
            temperature: divisor applied to the last-position logits; must
                broadcast against a ``(batch, vocab)`` tensor.

        Returns:
            A tuple of the prompt plus generated ids cast to int32, and the
            float logits of the final model call (which saw every token except
            the last sampled one).
        """
        for _ in range(tensor_of_seq_len.shape[1] - 1):
            logits = self.traced_model(input_ids)
            input_ids = torch.cat(
                [input_ids, self._sample_next(logits, temperature)], dim=1
            )

        # In TorchScript a variable first assigned inside the loop does not
        # escape the loop's scope, so the final step is performed outside it
        # to have `logits` available for the return value.
        logits = self.traced_model(input_ids).float()
        input_ids = torch.cat(
            [input_ids, self._sample_next(logits, temperature)], dim=1
        )

        return input_ids.int(), logits


print(f"Converting {args.model} to TorchScript...")
# Forward --revision only when it was given, so the library default applies
# otherwise. Previously the flag was parsed but silently ignored.
hf_kwargs = {"revision": args.revision} if args.revision else {}
tokenizer = AutoTokenizer.from_pretrained(args.model, **hf_kwargs)
model = ModelLogits(AutoModelForCausalLM.from_pretrained(args.model, **hf_kwargs))
model.eval()
model.requires_grad_(False)
model = model.half().to(device)

# NOTE: `input` shadows the builtin; the name is kept because the debug
# f-strings below print the expression text verbatim.
input = tokenizer("annotator model's hash is 0x", return_tensors="pt").to(device)
print(f"{model(input.input_ids)=}")

# Trace the logits-only wrapper with the example prompt.
traced_script_module = torch.jit.trace(model, input.input_ids)

print(f"{traced_script_module(input.input_ids)=}")

print("Scripting generation wrapper...")

scripted_generator_model = torch.jit.script(InferModel(traced_script_module))

print(f"{input.input_ids=}")
# Smoke test: the second dimension of the empty tensor asks for 5 new tokens,
# sampled at temperature 1.0. All tensors live on the selected --device (not
# the default CUDA device) so they match the model's and the prompt's device.
x = (
    input.input_ids,
    torch.empty(1, 5, device=device),
    torch.full([1, 1], 1.0, device=device),
)
print(f"{(scripted_generator_model(*x))=}")
print(f"{tokenizer.decode(scripted_generator_model(*x)[0][0])=}")

# "/" in a hub id would otherwise create a nested directory in the model store.
sanitized_name = args.model.replace("/", "--")
print("Model renamed to ", sanitized_name)

print("Saving TorchScript model...")

# Layout: model_store/<name>/config.pbtxt and model_store/<name>/1/traced-model.pt
model_dir = f"model_store/{sanitized_name}"
os.makedirs(f"{model_dir}/1", exist_ok=True)
scripted_generator_model.save(f"{model_dir}/1/traced-model.pt")

# The config template sits next to this script, independent of the working dir.
config_path = os.path.join(
    os.path.dirname(os.path.realpath(__file__)), "triton_config.pbtxt"
)
with open(config_path) as f:
    template = Template(f.read())
config = template.substitute(
    {"model_name": sanitized_name, "max_batch_size": args.max_batch_size}
)
with open(f"{model_dir}/config.pbtxt", "w") as f:
    f.write(config)
1 change: 1 addition & 0 deletions inference/start_server.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#!/usr/bin/env bash
# Start Triton Inference Server in Docker, serving the models found in
# ./model_store (relative to the directory this script is run from).
# The mount argument is quoted so that a working directory containing spaces
# does not split the -v option into several words.
docker run --gpus=1 --rm --net=host \
  -v "${PWD}/model_store:/model_store" \
  nvcr.io/nvidia/tritonserver:23.01-py3 \
  tritonserver --model-repository=/model_store
66 changes: 66 additions & 0 deletions inference/triton_config.pbtxt
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Triton model configuration template.
# convert_to_triton.py reads this file with string.Template, fills in the
# model_name and max_batch_size placeholders and writes the result as
# config.pbtxt into the model's directory inside model_store.
# Keep comments here free of dollar signs: they would be treated as
# placeholders during substitution.
name: "${model_name}"
backend: "pytorch"
# Must match the file name the converter uses when saving the scripted model.
default_model_filename: "traced-model.pt"
max_batch_size: ${max_batch_size}

parameters {
key: "model_name"
value: {
string_value: "${model_name}"
}
}

# One model instance, pinned to GPU 0.
# NOTE(review): the GPU index is fixed here regardless of the --device value
# used during conversion - confirm this is intended.
instance_group [
{
count: 1
kind: KIND_GPU
gpus: [0]
}
]

# Input names mirror the arguments of InferModel.forward in
# convert_to_triton.py. Only the size of the second dimension of
# tensor_of_seq_len is read there; it sets how many tokens are generated.
input [
{
name: "input_ids"
data_type: TYPE_INT32
dims: [-1]
},
{
name: "tensor_of_seq_len"
data_type: TYPE_INT32
dims: [-1]
},
{
name: "temperature"
data_type: TYPE_FP32
dims: [-1]
}
]

# InferModel.forward returns two values in this order: the token ids cast to
# int32, then the logits of the last model call cast to float.
output [
{
name: "output_ids"
data_type: TYPE_INT32
dims: [-1]
},
{
name: "logits"
data_type: TYPE_FP32
dims: [-1]
}
]

# The converter casts the model to half precision before tracing.
# NOTE(review): it is not visible from here whether the backend reads this
# key or whether it is informational only - confirm.
parameters {
key: "data_type"
value: {
string_value: "fp16"
}
}

# NOTE(review): presumably enables inference mode in the PyTorch backend -
# confirm against the backend documentation.
parameters: {
key: "INFERENCE_MODE"
value: {
string_value: "true"
}
}

# Only version 1 is served; the converter saves the model under the "1"
# subdirectory.
version_policy: {specific: {versions: [1]}}
23 changes: 9 additions & 14 deletions prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@
Output: I went to Paris in 1994 and stayed there until 2011, so in total, it was [Calculator(2011 - 1994)] 17 years.
Input: From this, we have 4 * 30 minutes = 120 minutes.
Output: From this, we have 4 * 30 minutes = [Calculator(4 * 30)] 120 minutes.
Input: x
Output:
"""
Input: <REPLACEGPT>
Output: """

retrieval_prompt = """
Your task is to complete a given piece of text.
Expand All @@ -29,8 +28,7 @@
Input: Metformin is the first-line drug for patients with type 2 diabetes and obesity.
Output: Metformin is the first-line drug for [Retrieval("illness, diabetes, obesity")] patients with type 2 diabetes and obesity.
Input: <REPLACEGPT>
Output:
"""
Output: """

llmchain_prompt = """
Your task is to complete a given piece of text.
Expand Down Expand Up @@ -58,9 +56,8 @@
Output: But what are the risks during production of nanomaterials? [WikiSearch("nanomaterial production risks")] Some nanomaterials may give rise to various kinds of lung damage.
Input: Metformin is the first-line drug for patients with type 2 diabetes and obesity.
Output: Metformin is the first-line drug for [WikiSearch("Metformin first-line drug")] patients with type 2 diabetes and obesity.
Input: x
Output:
"""
Input: <REPLACEGPT>
Output: """

machine_translation_prompt = """
Your task is to complete a given piece of text by using a Machine Translation API.
Expand All @@ -72,9 +69,8 @@
Output: In Morris de Jonge’s Jeschuah, der klassische jüdische Mann [MT(der klassische jüdische Mann)], there is a description of a Jewish writer
Input: 南 京 高 淳 县 住 房 和 城 乡 建 设 局 城 市 新 区 设 计 a plane of reference Gaochun is one of seven districts of the provincial capital Nanjing
Output: [MT(南京高淳县住房和城乡建设局 城市新 区 设 计)] a plane of reference Gaochun is one of seven districts of the provincial capital Nanjing
Input: x
Output:
"""
Input: <REPLACEGPT>
Output: """

calendar_prompt = """
Your task is to add calls to a Calendar API to a piece of text.
Expand All @@ -91,6 +87,5 @@
Output: The number of days from now until Christmas is [Calendar()] 30.
Input: The store is never open on the weekend, so today it is closed.
Output: The store is never open on the weekend, so today [Calendar()] it is closed.
Input: x
Output:
"""
Input: <REPLACEGPT>
Output: """
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,6 @@ wolframalpha
transformers
openai
langchain
cohere
cohere
tritonclient==2.31.0
tqdm
Loading