Skip to content

Commit 14f8981

Browse files
committed
init llama infer
1 parent 9721837 commit 14f8981

File tree

1 file changed

+70
-0
lines changed
  • torchprime/experimental/torchax_models/inference

1 file changed

+70
-0
lines changed
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import torchax.interop
2+
from torchprime.experimental.torchax_models.inference.llama_run import model
3+
import torch
4+
import torchax
5+
import torchax.config
6+
import jax
7+
import time
8+
9+
# Torchax environment; entered as a context manager in the __main__ block
# around the code that moves the model and inputs to the "jax" device.
env = torchax.default_env()
10+
# Fixed seed so the random weights and fake token batches are repeatable.
torch.manual_seed(42)
11+
# Make bfloat16 the default floating-point dtype for tensors created below.
torch.set_default_dtype(torch.bfloat16)
12+
# NOTE(review): presumably switches torchax to its faster configuration --
# confirm the exact effect against the torchax documentation.
torchax.enable_performance_mode()
13+
14+
# Toy-sized model hyperparameters passed to model.ModelArgs in __main__.
# The trailing numbers look like the full-size values these were scaled
# down from -- TODO confirm.
# max_seq_len is also used as the sequence length of every fake batch.
max_seq_len = 512 # 8192
15+
vocab_size = 128 # 32000
16+
n_layer = 1
17+
n_heads = 4
18+
# With n_heads = 4 this gives dim // n_heads == 2 per head.
dim = 8
19+
block_size = 16 # 2048
20+
# Number of sequences in each batch yielded by fake_dataloader.
batch_size = 1
21+
22+
23+
def fake_dataloader(size, vocab_size, seqlen, batch_size):
    """Yield ``size`` batches of random token ids.

    Each batch is an integer tensor of shape ``(batch_size, seqlen)`` whose
    values are drawn from ``[0, vocab_size)``; it is always created on the
    CPU device. Batches are produced lazily, one per iteration.
    """
    batch_shape = (batch_size, seqlen)
    for _ in range(size):
        yield torch.randint(0, vocab_size, batch_shape, device="cpu")
27+
28+
29+
if __name__ == "__main__":
    # Micro-benchmark: build a toy Llama from the module-level
    # hyperparameters, jit its forward pass through torchax, and print the
    # wall-clock latency of a few forward steps on random token batches.
    with torch.no_grad():
        model_args = model.ModelArgs(
            block_size=block_size,
            vocab_size=vocab_size,
            n_layer=n_layer,
            n_heads=n_heads,
            dim=dim,
            max_seq_len=max_seq_len,
        )
        # Table sized by the per-head dimension and the maximum sequence
        # length (presumably the rotary-embedding frequencies -- confirm in
        # llama_run.model), cast to bfloat16 like the model itself.
        freqs_cis = model.precompute_freqs_cis(
            model_args.dim // model_args.n_heads,
            model_args.max_seq_len,
            model_args.rope_theta,
            model_args.use_scaled_rope,
        ).to(torch.bfloat16)
        m = model.Transformer(model_args)
        m.to(torch.bfloat16)

        # TODO: pass the weights as arguments of forward
        def forward(tokens, freqs_cis, mask):
            # The model is reached through the closure, not passed in.
            # The literal 0 is presumably the start position -- TODO confirm
            # against Transformer.forward.
            return m(tokens, 0, freqs_cis=freqs_cis, mask=mask)

        jitted_forward = torchax.interop.jax_jit(forward)

        # Five random batches; the generator is lazy, so each batch is
        # actually created inside the `with env:` block below.
        data_iter = fake_dataloader(5, vocab_size, max_seq_len, batch_size)
        with env:
            m.to("jax")
            freqs_cis = freqs_cis.to("jax")
            for i, tokens in enumerate(data_iter):
                tokens = tokens.to("jax")
                # NOTE(review): the mask is all ones with the same shape and
                # dtype as the token batch -- confirm this is what
                # Transformer.forward expects for `mask`.
                mask = torch.ones_like(tokens)
                step_start = time.perf_counter()
                output = jitted_forward(tokens, freqs_cis, mask)
                # Wait for the result so the timer covers the computation
                # itself. Step 0 is expected to also include jit
                # tracing/compilation time.
                jax.block_until_ready(torchax.tensor.t2j(output))
                step_end = time.perf_counter()
                print(
                    i,
                    "step latency: ",
                    step_end - step_start,
                )

0 commit comments

Comments
 (0)