diff --git a/llama/generation.py b/llama/generation.py
index 96be4b29..a5a7a4e2 100644
--- a/llama/generation.py
+++ b/llama/generation.py
@@ -19,6 +19,13 @@
 from llama.model import ModelArgs, Transformer
 from llama.tokenizer import ChatFormat, Dialog, Message, Tokenizer
 
+if torch.backends.mps.is_available():
+    device = torch.device('mps')
+    torch.set_default_device('mps')
+elif torch.cuda.is_available():
+    device = torch.device('cuda')
+else:
+    device = torch.device('cpu')
 
 class CompletionPrediction(TypedDict, total=False):
     generation: str
@@ -67,16 +74,20 @@ def build(
         assert 1 <= max_seq_len <= 8192, f"max_seq_len must be between 1 and 8192, got {max_seq_len}."
         assert os.path.isdir(ckpt_dir), f"Checkpoint directory '{ckpt_dir}' does not exist."
         assert os.path.isfile(tokenizer_path), f"Tokenizer file '{tokenizer_path}' does not exist."
-
+
         if not torch.distributed.is_initialized():
-            torch.distributed.init_process_group("nccl")
+            if device.type == "cuda":
+                torch.distributed.init_process_group("nccl")
+            else:
+                torch.distributed.init_process_group("gloo")
         if not model_parallel_is_initialized():
             if model_parallel_size is None:
                 model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
             initialize_model_parallel(model_parallel_size)
 
         local_rank = int(os.environ.get("LOCAL_RANK", 0))
-        torch.cuda.set_device(local_rank)
+        if device.type == "cuda":
+            torch.cuda.set_device(local_rank)
 
         # seed must be the same in all processes
         torch.manual_seed(seed)
@@ -102,10 +113,14 @@ def build(
         )
         tokenizer = Tokenizer(model_path=tokenizer_path)
         assert model_args.vocab_size == tokenizer.n_words
-        if torch.cuda.is_bf16_supported():
-            torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
+        if device.type == "cuda":
+            if torch.cuda.is_bf16_supported():
+                torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
+            else:
+                torch.set_default_tensor_type(torch.cuda.HalfTensor)
         else:
-            torch.set_default_tensor_type(torch.cuda.HalfTensor)
+            torch.set_default_tensor_type(torch.HalfTensor)
+
         model = Transformer(model_args)
         model.load_state_dict(checkpoint, strict=False)
         print(f"Loaded in {time.time() - start_time:.2f} seconds")
@@ -156,14 +171,14 @@ def generate(
         total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)
 
         pad_id = self.tokenizer.pad_id
-        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
+        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=device)
         for k, t in enumerate(prompt_tokens):
-            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
+            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=device)
         if logprobs:
             token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
 
         prev_pos = 0
-        eos_reached = torch.tensor([False] * bsz, device="cuda")
+        eos_reached = torch.tensor([False] * bsz, device=device)
         input_text_mask = tokens != pad_id
         if min_prompt_len == total_len:
             logits = self.model.forward(tokens, prev_pos)
diff --git a/llama/model.py b/llama/model.py
index e388c038..529b32a2 100644
--- a/llama/model.py
+++ b/llama/model.py
@@ -15,6 +15,13 @@
 )
 from torch import nn
 
+if torch.backends.mps.is_available():
+    device = torch.device('mps')
+    torch.set_default_device('mps')
+elif torch.cuda.is_available():
+    device = torch.device('cuda')
+else:
+    device = torch.device('cpu')
 
 @dataclass
 class ModelArgs:
@@ -48,6 +55,7 @@ def forward(self, x):
 
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+    freqs = freqs.to(device)
     t = torch.arange(end, device=freqs.device, dtype=torch.float32)
     freqs = torch.outer(t, freqs)
     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
@@ -103,28 +111,28 @@ def __init__(self, args: ModelArgs):
             bias=False,
             gather_output=False,
             init_method=lambda x: x,
-        )
+        ).to(device)
         self.wk = ColumnParallelLinear(
             args.dim,
             self.n_kv_heads * self.head_dim,
             bias=False,
             gather_output=False,
             init_method=lambda x: x,
-        )
+        ).to(device)
         self.wv = ColumnParallelLinear(
             args.dim,
             self.n_kv_heads * self.head_dim,
             bias=False,
             gather_output=False,
             init_method=lambda x: x,
-        )
+        ).to(device)
         self.wo = RowParallelLinear(
             args.n_heads * self.head_dim,
             args.dim,
             bias=False,
             input_is_parallel=True,
             init_method=lambda x: x,
-        )
+        ).to(device)
 
         self.cache_k = torch.zeros(
             (
@@ -133,7 +141,7 @@
                 self.n_local_kv_heads,
                 self.head_dim,
             )
-        ).cuda()
+        ).to(device)
         self.cache_v = torch.zeros(
             (
                 args.max_batch_size,
@@ -141,7 +149,7 @@
                 self.n_local_kv_heads,
                 self.head_dim,
             )
-        ).cuda()
+        ).to(device)
 
     def forward(
         self,
@@ -257,16 +265,16 @@ def __init__(self, params: ModelArgs):
 
         self.tok_embeddings = VocabParallelEmbedding(
             params.vocab_size, params.dim, init_method=lambda x: x
-        )
+        ).to(device)
 
         self.layers = torch.nn.ModuleList()
         for layer_id in range(params.n_layers):
-            self.layers.append(TransformerBlock(layer_id, params))
+            self.layers.append(TransformerBlock(layer_id, params).to(device))
 
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = ColumnParallelLinear(
             params.dim, params.vocab_size, bias=False, init_method=lambda x: x
-        )
+        ).to(device)
 
         self.freqs_cis = precompute_freqs_cis(
             params.dim // params.n_heads,
@@ -276,6 +284,7 @@ def __init__(self, params: ModelArgs):
 
     @torch.inference_mode()
     def forward(self, tokens: torch.Tensor, start_pos: int):
+        tokens = tokens.to(device)
         _bsz, seqlen = tokens.shape
         h = self.tok_embeddings(tokens)
         self.freqs_cis = self.freqs_cis.to(h.device)
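
For reference, the device and distributed-backend selection that the patch repeats in both files reduces to the fallback chain sketched below. This is a minimal standalone sketch, not code from the repository; pick_device and pick_dist_backend are illustrative helper names.

import torch


def pick_device() -> torch.device:
    # Prefer Apple-silicon MPS, then CUDA, then plain CPU -- the same
    # fallback order the patch installs at module import time.
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def pick_dist_backend(device: torch.device) -> str:
    # NCCL requires CUDA; any other device falls back to the CPU-based
    # gloo backend, mirroring the change in Llama.build().
    return "nccl" if device.type == "cuda" else "gloo"


if __name__ == "__main__":
    dev = pick_device()
    print(f"device={dev}, distributed backend={pick_dist_backend(dev)}")

Comparing on device.type rather than on the device object itself keeps the check correct whether the device was created as torch.device('cuda') or as 'cuda:0'.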