diff --git a/QEfficient/cloud/finetune.py b/QEfficient/cloud/finetune.py
index d8de58951..b8138e8b0 100644
--- a/QEfficient/cloud/finetune.py
+++ b/QEfficient/cloud/finetune.py
@@ -80,9 +80,8 @@ def setup_distributed_training(train_config: TrainConfig) -> None:
     assert torch_device.index is None, f"DDP requires only device type, got: {torch_device}"
     dist_backend_map = {"cpu": "gloo", "qaic": "qccl", "cuda": "gloo"}
     dist.init_process_group(backend=dist_backend_map[torch_device.type])
-    if not train_config.enable_pp:
-        # from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank
-        getattr(torch, torch_device.type).set_device(dist.get_rank())
+    # from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank * num_pp_stages
+    getattr(torch, torch_device.type).set_device(dist.get_rank() * train_config.num_pp_stages)
 
 
 def setup_seeds(seed: int) -> None:
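
For reviewers, a minimal sketch (illustration only, not part of the patch) of the device-indexing scheme the new set_device call implements: each DDP rank owns a contiguous block of num_pp_stages devices, so rank r's base device is index r * num_pp_stages. The world_size of 4, num_pp_stages of 2, and the "qaic" device label below are assumed values for the example, not taken from the patch.

# Illustration only: rank -> base device index under the updated scheme.
def base_device_index(rank: int, num_pp_stages: int) -> int:
    # Each rank owns a contiguous block of num_pp_stages devices;
    # pipeline stage s of rank r runs on device r * num_pp_stages + s.
    return rank * num_pp_stages

if __name__ == "__main__":
    world_size, num_pp_stages = 4, 2  # assumed values for illustration
    for rank in range(world_size):
        base = base_device_index(rank, num_pp_stages)
        owned = [f"qaic:{base + s}" for s in range(num_pp_stages)]
        print(f"rank {rank}: set_device({base}); stages on {owned}")

Note that with num_pp_stages == 1 the new expression reduces to the previous set_device(dist.get_rank()) behavior, which is why the enable_pp guard can be dropped rather than special-cased.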