llama/generation.py (33 changes: 24 additions & 9 deletions)
@@ -19,6 +19,13 @@
 from llama.model import ModelArgs, Transformer
 from llama.tokenizer import ChatFormat, Dialog, Message, Tokenizer

+if torch.backends.mps.is_available():
+    device = torch.device('mps')
+    torch.set_default_device('mps')
+elif torch.cuda.is_available():
+    device = torch.device('cuda')
+else:
+    device = torch.device('cpu')

 class CompletionPrediction(TypedDict, total=False):
     generation: str
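
The block above prefers Apple-silicon MPS, then CUDA, then CPU, and `torch.set_default_device('mps')` makes tensor factory calls allocate on MPS by default. A quick standalone check of the same precedence (a sketch, not part of the patch):

import torch

device = (
    torch.device("mps") if torch.backends.mps.is_available()
    else torch.device("cuda") if torch.cuda.is_available()
    else torch.device("cpu")
)
print(torch.ones(2, 2, device=device).device)  # e.g. mps:0 on Apple silicon
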
@@ -67,16 +74,20 @@ def build(
         assert 1 <= max_seq_len <= 8192, f"max_seq_len must be between 1 and 8192, got {max_seq_len}."
         assert os.path.isdir(ckpt_dir), f"Checkpoint directory '{ckpt_dir}' does not exist."
         assert os.path.isfile(tokenizer_path), f"Tokenizer file '{tokenizer_path}' does not exist."

         if not torch.distributed.is_initialized():
-            torch.distributed.init_process_group("nccl")
+            if device.type == "cuda":
+                torch.distributed.init_process_group("nccl")
+            else:
+                torch.distributed.init_process_group("gloo")
         if not model_parallel_is_initialized():
             if model_parallel_size is None:
                 model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
             initialize_model_parallel(model_parallel_size)

         local_rank = int(os.environ.get("LOCAL_RANK", 0))
-        torch.cuda.set_device(local_rank)
+        if device.type == "cuda":
+            torch.cuda.set_device(local_rank)

         # seed must be the same in all processes
         torch.manual_seed(seed)
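
NCCL is CUDA-only, so non-CUDA hosts now initialize the process group with the gloo backend instead. A minimal single-process sketch of that fallback, with rendezvous values hard-coded for illustration:

import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
backend = "nccl" if torch.cuda.is_available() else "gloo"
dist.init_process_group(backend, rank=0, world_size=1)
print(dist.get_backend())  # "gloo" on MPS/CPU machines
dist.destroy_process_group()
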
@@ -102,10 +113,14 @@ def build(
         )
         tokenizer = Tokenizer(model_path=tokenizer_path)
         assert model_args.vocab_size == tokenizer.n_words
-        if torch.cuda.is_bf16_supported():
-            torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
+        if device.type == "cuda":
+            if torch.cuda.is_bf16_supported():
+                torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
+            else:
+                torch.set_default_tensor_type(torch.cuda.HalfTensor)
         else:
-            torch.set_default_tensor_type(torch.cuda.HalfTensor)
+            torch.set_default_tensor_type(torch.HalfTensor)
+
         model = Transformer(model_args)
         model.load_state_dict(checkpoint, strict=False)
         print(f"Loaded in {time.time() - start_time:.2f} seconds")
@@ -156,14 +171,14 @@ def generate(
         total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)

         pad_id = self.tokenizer.pad_id
-        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
+        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=device)
         for k, t in enumerate(prompt_tokens):
-            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
+            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=device)
         if logprobs:
             token_logprobs = torch.zeros_like(tokens, dtype=torch.float)

         prev_pos = 0
-        eos_reached = torch.tensor([False] * bsz, device="cuda")
+        eos_reached = torch.tensor([False] * bsz, device=device)
         input_text_mask = tokens != pad_id
         if min_prompt_len == total_len:
             logits = self.model.forward(tokens, prev_pos)
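
With these changes the public entry point is unchanged; device placement is decided at import time. A hypothetical invocation on an Apple-silicon machine (checkpoint paths are illustrative, and the script still needs to run under torchrun so RANK/WORLD_SIZE/MASTER_ADDR are set for init_process_group):

from llama import Llama

generator = Llama.build(
    ckpt_dir="Meta-Llama-3-8B/",
    tokenizer_path="Meta-Llama-3-8B/tokenizer.model",
    max_seq_len=512,
    max_batch_size=4,
)
results = generator.text_completion(["The capital of France is"], max_gen_len=16)
print(results[0]["generation"])
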
llama/model.py (27 changes: 18 additions & 9 deletions)
@@ -15,6 +15,13 @@
 )
 from torch import nn

+if torch.backends.mps.is_available():
+    device = torch.device('mps')
+    torch.set_default_device('mps')
+elif torch.cuda.is_available():
+    device = torch.device('cuda')
+else:
+    device = torch.device('cpu')

 @dataclass
 class ModelArgs:
@@ -48,6 +55,7 @@ def forward(self, x):

 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+    freqs = freqs.to(device)
     t = torch.arange(end, device=freqs.device, dtype=torch.float32)
     freqs = torch.outer(t, freqs)
     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
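
Moving `freqs` to `device` keeps the rest of the rotary-embedding table on the same device, since `t` is created with `device=freqs.device`. A self-contained CPU sketch of what the function computes, with illustrative sizes (head_dim 128, 4096 positions):

import torch

dim, end, theta = 128, 4096, 10000.0
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, dtype=torch.float32)
freqs_cis = torch.polar(torch.ones((end, dim // 2)), torch.outer(t, freqs))
print(freqs_cis.shape, freqs_cis.dtype)  # torch.Size([4096, 64]) torch.complex64
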
@@ -103,28 +111,28 @@ def __init__(self, args: ModelArgs):
             bias=False,
             gather_output=False,
             init_method=lambda x: x,
-        )
+        ).to(device)
         self.wk = ColumnParallelLinear(
             args.dim,
             self.n_kv_heads * self.head_dim,
             bias=False,
             gather_output=False,
             init_method=lambda x: x,
-        )
+        ).to(device)
         self.wv = ColumnParallelLinear(
             args.dim,
             self.n_kv_heads * self.head_dim,
             bias=False,
             gather_output=False,
             init_method=lambda x: x,
-        )
+        ).to(device)
         self.wo = RowParallelLinear(
             args.n_heads * self.head_dim,
             args.dim,
             bias=False,
             input_is_parallel=True,
             init_method=lambda x: x,
-        )
+        ).to(device)

         self.cache_k = torch.zeros(
             (
@@ -133,15 +141,15 @@ def __init__(self, args: ModelArgs):
                 self.n_local_kv_heads,
                 self.head_dim,
             )
-        ).cuda()
+        ).to(device)
         self.cache_v = torch.zeros(
             (
                 args.max_batch_size,
                 args.max_seq_len,
                 self.n_local_kv_heads,
                 self.head_dim,
             )
-        ).cuda()
+        ).to(device)

     def forward(
         self,
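
The per-layer KV caches were previously pinned to CUDA with `.cuda()`; `.to(device)` allocates them wherever the rest of the model lives. A rough size estimate for these caches under assumed small-scale settings (batch 4, seq len 512, 8 local KV heads, head_dim 128, float16):

import torch

cache = torch.zeros((4, 512, 8, 128), dtype=torch.float16)
per_layer = 2 * cache.numel() * cache.element_size()  # cache_k plus cache_v
print(f"{per_layer / 2**20:.0f} MiB per layer")  # 8 MiB, times 32 layers for an 8B model
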
@@ -257,16 +265,16 @@ def __init__(self, params: ModelArgs):

         self.tok_embeddings = VocabParallelEmbedding(
             params.vocab_size, params.dim, init_method=lambda x: x
-        )
+        ).to(device)

         self.layers = torch.nn.ModuleList()
         for layer_id in range(params.n_layers):
-            self.layers.append(TransformerBlock(layer_id, params))
+            self.layers.append(TransformerBlock(layer_id, params).to(device))

         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = ColumnParallelLinear(
             params.dim, params.vocab_size, bias=False, init_method=lambda x: x
-        )
+        ).to(device)

         self.freqs_cis = precompute_freqs_cis(
             params.dim // params.n_heads,
@@ -276,6 +284,7 @@ def __init__(self, params: ModelArgs):

     @torch.inference_mode()
     def forward(self, tokens: torch.Tensor, start_pos: int):
+        tokens = tokens.to(device)
         _bsz, seqlen = tokens.shape
         h = self.tok_embeddings(tokens)
         self.freqs_cis = self.freqs_cis.to(h.device)