diff --git a/distributed/ddp-tutorial-series/multigpu.py b/distributed/ddp-tutorial-series/multigpu.py
index 7e11633305..9f573e88e1 100644
--- a/distributed/ddp-tutorial-series/multigpu.py
+++ b/distributed/ddp-tutorial-series/multigpu.py
@@ -17,9 +17,20 @@ def ddp_setup(rank, world_size):
         world_size: Total number of processes
     """
     os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = "12355"
-    torch.cuda.set_device(rank)
-    init_process_group(backend="nccl", rank=rank, world_size=world_size)
+    os.environ["MASTER_PORT"] = "12453"
+
+
+    if torch.accelerator.is_available():
+        device_type = torch.accelerator.current_accelerator()
+        torch.accelerator.set_device_idx(rank)
+        device: torch.device = torch.device(f"{device_type}:{rank}")
+        print(f"Running on rank {rank} on device {device}")
+        backend = torch.distributed.get_default_backend_for_device(device)
+        torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size, device_id=device)
+    else:
+        device = torch.device("cpu")
+        print(f"Running on device {device}")
+        torch.distributed.init_process_group(backend="gloo", rank=rank, world_size=world_size, device_id=device)
 
 class Trainer:
     def __init__(
@@ -100,5 +111,5 @@ def main(rank: int, world_size: int, save_every: int, total_epochs: int, batch_s
     parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
     args = parser.parse_args()
 
-    world_size = torch.cuda.device_count()
+    world_size = torch.accelerator.device_count()
     mp.spawn(main, args=(world_size, args.save_every, args.total_epochs, args.batch_size), nprocs=world_size)
diff --git a/distributed/ddp-tutorial-series/multigpu_torchrun.py b/distributed/ddp-tutorial-series/multigpu_torchrun.py
index 32d6254d2d..5a0118112b 100644
--- a/distributed/ddp-tutorial-series/multigpu_torchrun.py
+++ b/distributed/ddp-tutorial-series/multigpu_torchrun.py
@@ -11,8 +11,21 @@
 
 
 def ddp_setup():
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
-    init_process_group(backend="nccl")
+    rank = int(os.environ["LOCAL_RANK"])
+    if torch.accelerator.is_available():
+        device_type = torch.accelerator.current_accelerator()
+        device: torch.device = torch.device(f"{device_type}:{rank}")
+        torch.accelerator.set_device_idx(rank)
+        print(f"Running on rank {rank} on device {device}")
+        backend = torch.distributed.get_default_backend_for_device(device)
+        torch.distributed.init_process_group(backend=backend)
+        return device
+    else:
+        device = torch.device("cpu")
+        print(f"Running on device {device}")
+        torch.distributed.init_process_group(backend="gloo")
+        return device
+
 
 class Trainer:
     def __init__(
@@ -22,6 +35,7 @@ def __init__(
         optimizer: torch.optim.Optimizer,
         save_every: int,
         snapshot_path: str,
+        device: torch.device,
     ) -> None:
         self.gpu_id = int(os.environ["LOCAL_RANK"])
         self.model = model.to(self.gpu_id)
@@ -30,6 +44,7 @@ def __init__(
         self.save_every = save_every
         self.epochs_run = 0
         self.snapshot_path = snapshot_path
+        self.device = device
         if os.path.exists(snapshot_path):
             print("Loading snapshot")
             self._load_snapshot(snapshot_path)
@@ -37,7 +52,7 @@ def __init__(
         self.model = DDP(self.model, device_ids=[self.gpu_id])
 
     def _load_snapshot(self, snapshot_path):
-        loc = f"cuda:{self.gpu_id}"
+        loc = str(self.device)
         snapshot = torch.load(snapshot_path, map_location=loc)
         self.model.load_state_dict(snapshot["MODEL_STATE"])
         self.epochs_run = snapshot["EPOCHS_RUN"]
@@ -92,10 +107,10 @@ def prepare_dataloader(dataset: Dataset, batch_size: int):
 
 
 def main(save_every: int, total_epochs: int, batch_size: int, snapshot_path: str = "snapshot.pt"):
-    ddp_setup()
+    device = ddp_setup()
     dataset, model, optimizer = load_train_objs()
     train_data = prepare_dataloader(dataset, batch_size)
-    trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path)
+    trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path, device)
     trainer.train(total_epochs)
     destroy_process_group()
 
diff --git a/distributed/ddp-tutorial-series/multinode.py b/distributed/ddp-tutorial-series/multinode.py
index 2cbae84b56..838056a42c 100644
--- a/distributed/ddp-tutorial-series/multinode.py
+++ b/distributed/ddp-tutorial-series/multinode.py
@@ -11,8 +11,20 @@
 
 
 def ddp_setup():
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
-    init_process_group(backend="nccl")
+    rank = int(os.environ["LOCAL_RANK"])
+    if torch.accelerator.is_available():
+        device_type = torch.accelerator.current_accelerator()
+        device: torch.device = torch.device(f"{device_type}:{rank}")
+        torch.accelerator.set_device_idx(rank)
+        print(f"Running on rank {rank} on device {device}")
+        backend = torch.distributed.get_default_backend_for_device(device)
+        torch.distributed.init_process_group(backend=backend)
+        return device
+    else:
+        device = torch.device("cpu")
+        print(f"Running on device {device}")
+        torch.distributed.init_process_group(backend="gloo")
+        return device
 
 class Trainer:
     def __init__(
@@ -22,6 +34,7 @@ def __init__(
         optimizer: torch.optim.Optimizer,
         save_every: int,
         snapshot_path: str,
+        device: torch.device,
     ) -> None:
         self.local_rank = int(os.environ["LOCAL_RANK"])
         self.global_rank = int(os.environ["RANK"])
@@ -31,6 +44,7 @@ def __init__(
         self.save_every = save_every
         self.epochs_run = 0
         self.snapshot_path = snapshot_path
+        self.device = device
         if os.path.exists(snapshot_path):
             print("Loading snapshot")
             self._load_snapshot(snapshot_path)
@@ -38,7 +52,7 @@ def __init__(
         self.model = DDP(self.model, device_ids=[self.local_rank])
 
     def _load_snapshot(self, snapshot_path):
-        loc = f"cuda:{self.local_rank}"
+        loc = str(self.device)
         snapshot = torch.load(snapshot_path, map_location=loc)
         self.model.load_state_dict(snapshot["MODEL_STATE"])
         self.epochs_run = snapshot["EPOCHS_RUN"]
@@ -93,10 +107,10 @@ def prepare_dataloader(dataset: Dataset, batch_size: int):
 
 
 def main(save_every: int, total_epochs: int, batch_size: int, snapshot_path: str = "snapshot.pt"):
-    ddp_setup()
+    device = ddp_setup()
     dataset, model, optimizer = load_train_objs()
     train_data = prepare_dataloader(dataset, batch_size)
-    trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path)
+    trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path, device)
     trainer.train(total_epochs)
     destroy_process_group()
 
diff --git a/distributed/ddp-tutorial-series/requirements.txt b/distributed/ddp-tutorial-series/requirements.txt
index 9270a1d6ee..d5656cb6b2 100644
--- a/distributed/ddp-tutorial-series/requirements.txt
+++ b/distributed/ddp-tutorial-series/requirements.txt
@@ -1 +1 @@
-torch>=1.11.0
\ No newline at end of file
+torch>=2.7
\ No newline at end of file
diff --git a/distributed/ddp-tutorial-series/run_example.sh b/distributed/ddp-tutorial-series/run_example.sh
new file mode 100644
index 0000000000..9320951532
--- /dev/null
+++ b/distributed/ddp-tutorial-series/run_example.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+# bash run_example.sh {file_to_run.py} {num_gpus}
+# where file_to_run = example to run. Default = 'example.py'
+# num_gpus = num local gpus to use (must be at least 2). Default = 2
+
+# samples to run include:
+# example.py
+
+echo "Launching ${1:-example.py} with ${2:-2} gpus"
+torchrun --nnodes=1 --nproc_per_node=${2:-2} --rdzv_id=101 --rdzv_endpoint="localhost:5972" ${1:-example.py}
\ No newline at end of file
diff --git a/run_distributed_examples.sh b/run_distributed_examples.sh
index e1f579c072..a7b03e489b 100755
--- a/run_distributed_examples.sh
+++ b/run_distributed_examples.sh
@@ -50,6 +50,12 @@ function distributed_tensor_parallelism() {
     uv run bash run_example.sh fsdp_tp_example.py || error "2D parallel example failed"
 }
 
+function distributed_ddp-tutorial-series() {
+    uv run bash run_example.sh multigpu.py || error "ddp tutorial series multigpu example failed"
+    uv run bash run_example.sh multigpu_torchrun.py || error "ddp tutorial series multigpu torchrun example failed"
+    uv run bash run_example.sh multinode.py || error "ddp tutorial series multinode example failed"
+}
+
 function distributed_ddp() {
     uv run main.py || error "ddp example failed"
 }