Skip to content

Adding torch accelerator to ddp-tutorial-series example #1376

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions distributed/ddp-tutorial-series/multigpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,24 @@ def ddp_setup(rank, world_size):
world_size: Total number of processes
"""
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
torch.cuda.set_device(rank)
init_process_group(backend="nccl", rank=rank, world_size=world_size)
os.environ["MASTER_PORT"] = "12453"


if torch.accelerator.is_available():
device_type = torch.accelerator.current_accelerator()
torch.accelerator.set_device_idx(rank)
device: torch.device = torch.device(f"{device_type}:{rank}")
torch.accelerator.device_index(rank)
print(f"Running on rank {rank} on device {device}")
backend = torch.distributed.get_default_backend_for_device(device)
torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size, device_id=device)
else:
device = torch.device("cpu")
print(f"Running on device {device}")
torch.distributed.init_process_group(backend="gloo", device_id=device)

# torch.cuda.set_device(rank)
# init_process_group(backend="xccl", rank=rank, world_size=world_size)

class Trainer:
def __init__(
Expand Down Expand Up @@ -100,5 +115,6 @@ def main(rank: int, world_size: int, save_every: int, total_epochs: int, batch_s
parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
args = parser.parse_args()

world_size = torch.cuda.device_count()
world_size = torch.accelerator.device_count()
print(world_size)
mp.spawn(main, args=(world_size, args.save_every, args.total_epochs, args.batch_size), nprocs=world_size)
25 changes: 20 additions & 5 deletions distributed/ddp-tutorial-series/multigpu_torchrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,21 @@


def ddp_setup():
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
init_process_group(backend="nccl")
rank = int(os.environ["LOCAL_RANK"])
if torch.accelerator.is_available():
device_type = torch.accelerator.current_accelerator()
device: torch.device = torch.device(f"{device_type}:{rank}")
torch.accelerator.device_index(rank)
print(f"Running on rank {rank} on device {device}")
backend = torch.distributed.get_default_backend_for_device(device)
torch.distributed.init_process_group(backend=backend)
return device_type
else:
device = torch.device("cpu")
print(f"Running on device {device}")
torch.distributed.init_process_group(backend="gloo")
return device


class Trainer:
def __init__(
Expand All @@ -22,6 +35,7 @@ def __init__(
optimizer: torch.optim.Optimizer,
save_every: int,
snapshot_path: str,
device
) -> None:
self.gpu_id = int(os.environ["LOCAL_RANK"])
self.model = model.to(self.gpu_id)
Expand All @@ -30,14 +44,15 @@ def __init__(
self.save_every = save_every
self.epochs_run = 0
self.snapshot_path = snapshot_path
self.device = device
if os.path.exists(snapshot_path):
print("Loading snapshot")
self._load_snapshot(snapshot_path)

self.model = DDP(self.model, device_ids=[self.gpu_id])

def _load_snapshot(self, snapshot_path):
loc = f"cuda:{self.gpu_id}"
loc = str(self.device)
snapshot = torch.load(snapshot_path, map_location=loc)
self.model.load_state_dict(snapshot["MODEL_STATE"])
self.epochs_run = snapshot["EPOCHS_RUN"]
Expand Down Expand Up @@ -92,10 +107,10 @@ def prepare_dataloader(dataset: Dataset, batch_size: int):


def main(save_every: int, total_epochs: int, batch_size: int, snapshot_path: str = "snapshot.pt"):
ddp_setup()
device = ddp_setup()
dataset, model, optimizer = load_train_objs()
train_data = prepare_dataloader(dataset, batch_size)
trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path)
trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path, device)
trainer.train(total_epochs)
destroy_process_group()

Expand Down
24 changes: 19 additions & 5 deletions distributed/ddp-tutorial-series/multinode.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,20 @@


def ddp_setup():
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
init_process_group(backend="nccl")
rank = int(os.environ["LOCAL_RANK"])
if torch.accelerator.is_available():
device_type = torch.accelerator.current_accelerator()
device: torch.device = torch.device(f"{device_type}:{rank}")
torch.accelerator.device_index(rank)
print(f"Running on rank {rank} on device {device}")
backend = torch.distributed.get_default_backend_for_device(device)
torch.distributed.init_process_group(backend=backend)
return device_type
else:
device = torch.device("cpu")
print(f"Running on device {device}")
torch.distributed.init_process_group(backend="gloo")
return device

class Trainer:
def __init__(
Expand All @@ -22,6 +34,7 @@ def __init__(
optimizer: torch.optim.Optimizer,
save_every: int,
snapshot_path: str,
device
) -> None:
self.local_rank = int(os.environ["LOCAL_RANK"])
self.global_rank = int(os.environ["RANK"])
Expand All @@ -31,14 +44,15 @@ def __init__(
self.save_every = save_every
self.epochs_run = 0
self.snapshot_path = snapshot_path
self.device = device
if os.path.exists(snapshot_path):
print("Loading snapshot")
self._load_snapshot(snapshot_path)

self.model = DDP(self.model, device_ids=[self.local_rank])

def _load_snapshot(self, snapshot_path):
loc = f"cuda:{self.local_rank}"
loc = str(self.device)
snapshot = torch.load(snapshot_path, map_location=loc)
self.model.load_state_dict(snapshot["MODEL_STATE"])
self.epochs_run = snapshot["EPOCHS_RUN"]
Expand Down Expand Up @@ -93,10 +107,10 @@ def prepare_dataloader(dataset: Dataset, batch_size: int):


def main(save_every: int, total_epochs: int, batch_size: int, snapshot_path: str = "snapshot.pt"):
ddp_setup()
device = ddp_setup()
dataset, model, optimizer = load_train_objs()
train_data = prepare_dataloader(dataset, batch_size)
trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path)
trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path, device)
trainer.train(total_epochs)
destroy_process_group()

Expand Down
2 changes: 1 addition & 1 deletion distributed/ddp-tutorial-series/requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
torch>=1.11.0
torch>=2.7
10 changes: 10 additions & 0 deletions distributed/ddp-tutorial-series/run_example.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# /bin/bash
# bash run_example.sh {file_to_run.py} {num_gpus}
# where file_to_run = example to run. Default = 'example.py'
# num_gpus = num local gpus to use (must be at least 2). Default = 2

# samples to run include:
# example.py

echo "Launching ${1:-example.py} with ${2:-2} gpus"
torchrun --nnodes=1 --nproc_per_node=${2:-2} --rdzv_id=101 --rdzv_endpoint="localhost:5972" ${1:-example.py}
6 changes: 6 additions & 0 deletions run_distributed_examples.sh
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ function distributed_tensor_parallelism() {
uv run bash run_example.sh fsdp_tp_example.py || error "2D parallel example failed"
}

function distributed_ddp-tutorial-series() {
uv run bash run_example.sh multigpu.py || error "ddp tutorial series multigpu example failed"
uv run bash run_example.sh multigpu_torchrun.py || error "ddp tutorial series multigpu torchrun example failed"
uv run bash run_example.sh multinode.py || error "ddp tutorial series multinode example failed"
}

function distributed_ddp() {
uv run main.py || error "ddp example failed"
}
Expand Down