diff --git a/test/distributed/test_c10d_ops_xccl.py b/test/distributed/test_c10d_ops_xccl.py new file mode 100644 index 0000000000000..5d041058ead41 --- /dev/null +++ b/test/distributed/test_c10d_ops_xccl.py @@ -0,0 +1,852 @@ +# Owner(s): ["oncall: distributed"] +# This test file contains positive tests for c10d with XCCL backend. +# During the test, it is expected that ProcessGroup will not be aborted, destroyed or incur fatal error. +# Please be mindful of this when adding tests here. +# If you need to add tests for group creation, abort or destroy, please add tests in test_c10d_xccl.py. + +# There are two ways to launch tests in this file: +# 1. Run this file directly with `python test_c10d_ops_xccl.py` +# 2. Use multi-process launcher, e.g. `torchrun --standalone --nproc-per-node 2 test_c10d_ops_xccl.py` + +import math +import os +import sys +import tempfile + +import torch +import torch.distributed as c10d + + +if not c10d.is_available() or not c10d.is_xccl_available(): + print("c10d XCCL not available, skipping tests", file=sys.stderr) + sys.exit(0) + + +import torch.distributed as dist +from torch.testing._internal.common_distributed import ( + init_multigpu_helper, + MultiProcContinousTest, + requires_xccl, +) +from torch.testing._internal.common_utils import ( + skip_but_pass_in_sandcastle_if, + skipIfRocm, + TEST_WITH_DEV_DBG_ASAN, + TEST_XPU, +) + + +if TEST_WITH_DEV_DBG_ASAN: + print( + "Skip ASAN as torch + multiprocessing spawn have known issues", file=sys.stderr + ) + sys.exit(0) + +TEST_MULTIGPU = TEST_XPU and torch.xpu.device_count() >= 2 + +class ProcessGroupXCCLOpTest(MultiProcContinousTest): + @classmethod + def backend_str(cls) -> str: + return "xccl" + + # @classmethod + # def opts(cls): + # opts = c10d.ProcessGroupXCCL.Options() + # return opts + + @property + def rank_to_GPU(self): + # return rank to GPU map + return init_multigpu_helper(self.world_size, "xccl") + + # TODO: wait reduce + # @requires_xccl() + # @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + # def test_empty_tensors(self): + # pg = self.pg + # local_device_idx = self.rank_to_GPU[self.rank][0] + + # xs = [torch.FloatTensor([]).xpu(local_device_idx)] + # pg.broadcast(xs).wait() + # self.assertEqual(0, xs[0].numel()) + + # pg.allreduce(xs).wait() + # self.assertEqual(0, xs[0].numel()) + + # pg.reduce(xs).wait() + # self.assertEqual(0, xs[0].numel()) + + # ys = [ + # [ + # torch.FloatTensor([]).xpu(local_device_idx) + # for _ in range(self.world_size) + # ] + # ] + # pg.allgather(ys, xs).wait() + # for y in ys[0]: + # self.assertEqual(0, y.numel()) + + # ys = [torch.FloatTensor([]).xpu(local_device_idx)] + # xs = [ + # [ + # torch.FloatTensor([]).xpu(local_device_idx) + # for _ in range(self.world_size) + # ] + # ] + # pg.reduce_scatter(ys, xs).wait() + # self.assertEqual(0, ys[0].numel()) + + @requires_xccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + def test_broadcast_ops(self): + pg = self.pg + + def broadcast(xs, rootRank, rootTensor): + opts = c10d.BroadcastOptions() + opts.rootRank = rootRank + opts.rootTensor = rootTensor + work = pg.broadcast(xs, opts) + work.wait() + return xs + + # Every rank is root once + for i in range(self.world_size): + # Run with 1 input tensor + x = torch.tensor([self.rank]).xpu(self.rank_to_GPU[self.rank][0]) + output = broadcast([x], i, 0) + self.assertEqual(torch.tensor([i]), output[0]) + + expected_tensor = torch.empty([i + 1, i + 1]).fill_(i + 1) + xs = [ + torch.empty([i + 1, i + 
1]).fill_(-1).xpu(device=device_idx) + for device_idx in self.rank_to_GPU[self.rank] + ] + + # test with multiple input tensors (multiple gpu in one rank) + for j in range(len(xs)): + if self.rank == i: + xs[j] = expected_tensor.xpu(device=self.rank_to_GPU[self.rank][j]) + + broadcast(xs, i, j) + + for tensor in xs: + self.assertEqual(tensor, expected_tensor) + + @requires_xccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + def test_allreduce_ops(self): + device_count = torch.xpu.device_count() + pg = self.pg + local_device_id = self.rank_to_GPU[self.rank][0] + + def allreduce(tensors, op): + opts = c10d.AllreduceOptions() + opts.reduceOp = op + work = pg.allreduce(tensors, opts) + work.wait() + + # Sum + tensors = [torch.tensor([self.rank + 1]).xpu(local_device_id)] + + allreduce(tensors, c10d.ReduceOp.SUM) + + ndev = self.world_size + self.assertEqual( + torch.tensor([ndev * (ndev + 1) // 2]), + tensors[0], + ) + + # Product + tensors = [torch.tensor([self.rank + 1]).xpu(local_device_id)] + + allreduce(tensors, c10d.ReduceOp.PRODUCT) + self.assertEqual(torch.tensor([math.factorial(self.world_size)]), tensors[0]) + + # Min + tensors = [torch.tensor([self.rank + 1]).xpu(local_device_id)] + + allreduce(tensors, c10d.ReduceOp.MIN) + self.assertEqual(torch.tensor([1]), tensors[0]) + + # Max + tensors = [torch.tensor([self.rank + 1]).xpu(local_device_id)] + + allreduce(tensors, c10d.ReduceOp.MAX) + self.assertEqual(torch.tensor([self.world_size]), tensors[0]) + + for op, err in zip( + (c10d.ReduceOp.BAND, c10d.ReduceOp.BOR, c10d.ReduceOp.BXOR), + ("ReduceOp.BAND", "ReduceOp.BOR", "ReduceOp.BXOR"), + ): + with self.assertRaisesRegex(ValueError, "Cannot use " + err + " with XCCL"): + allreduce(tensors, op) + + # TODO: wait all2all + # @requires_xccl() + # @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + # def test_alltoall_ops_with_xpufree_race(self): + # pg = self.pg + # opts = c10d.AllToAllOptions() + # local_device = f"xpu:{self.rank_to_GPU[self.rank][0]}" + # torch.xpu.set_device(local_device) + # input = torch.rand(1000, 1000, device=local_device) + # output = torch.rand(1000, 1000, device=local_device) + # race_tensors = [] + # # create some tensors to race with alltoall collective + # for _ in range(10): + # tmp = [] + # for i in range(5): + # tmp.append(torch.rand(10 ** (3 + i), device=local_device)) + # race_tensors.append(tmp) + + # for i in range(10): + # race_tensors.pop() + # work = pg.alltoall_base(output, input, [], [], opts) + # # this triggers xpuFree + # torch.xpu.empty_cache() + # work.wait() + # torch.xpu.synchronize(device=local_device) + + # TODO: wait reduce + # @requires_xccl() + # @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + # def test_reduce_ops(self): + # pg = self.pg + # local_device_id = self.rank_to_GPU[self.rank][0] + + # def reduce(xs, rootRank, rootTensor, op=None): + # opts = c10d.ReduceOptions() + # opts.rootRank = rootRank + # opts.rootTensor = rootTensor + # if op: + # opts.reduceOp = op + # work = pg.reduce(xs, opts) + # work.wait() + + # # for every root tensor + # for rt in range(self.world_size): + # tensors = [torch.tensor([self.rank + 1]).xpu(local_device_id)] + + # reduce(tensors, rt, 0) + + # if self.rank == rt: + # self.assertEqual( + # torch.tensor([self.world_size * (self.world_size + 1) // 2]), + # tensors[0], + # ) + # else: + # self.assertEqual( + # torch.tensor([self.rank + 1]), + # tensors[0], + # ) + + # for op, err in zip( + # 
(c10d.ReduceOp.BAND, c10d.ReduceOp.BOR, c10d.ReduceOp.BXOR), + # ("ReduceOp.BAND", "ReduceOp.BOR", "ReduceOp.BXOR"), + # ): + # with self.assertRaisesRegex( + # ValueError, "Cannot use " + err + " with XCCL" + # ): + # reduce(tensors, self.rank, rt, op) + + # # Premul sum + # if torch.xpu.xccl.version() >= (2, 11, 1): + # for factor in (3.0, torch.tensor([5.0], device=local_device_id)): + # if isinstance(factor, torch.Tensor): + # factor_ref = factor.cpu().item() + # else: + # factor_ref = factor + # float_tensors = [ + # torch.tensor( + # [self.rank + 1.0], device=f"xpu:{local_device_id}" + # ) + # ] + # float_tensors_ref = [ + # torch.tensor( + # [(self.rank + 1.0) * factor_ref], + # device=f"xpu:{local_device_id}", + # ) + # ] + + # reduce(float_tensors_ref, rt, 0) + # reduce(float_tensors, rt, 0, c10d._make_xccl_premul_sum(factor)) + # if self.rank == rt: + # self.assertEqual(float_tensors_ref[0], float_tensors[0]) + + @requires_xccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + def test_allgather_ops(self): + pg = self.pg + local_device_ids = self.rank_to_GPU[self.rank] + + def allgather(output_ts, input_ts): + work = pg.allgather(output_ts, input_ts) + return work.wait() + + tensors = [torch.empty(2, 2).fill_(2).xpu(device=i) for i in local_device_ids] + output_tensors = [] + expected_output = [] + + output_per_gpu = ( + [torch.empty(2, 2).fill_(-1)] * len(local_device_ids) * self.world_size + ) + expected_per_gpu = ( + [torch.empty(2, 2).fill_(2)] * len(local_device_ids) * self.world_size + ) + + for gpu in local_device_ids: + output_tensors.append([t.xpu(device=gpu) for t in output_per_gpu]) + expected_output.append([t.xpu(device=gpu) for t in expected_per_gpu]) + + result = allgather(output_tensors, tensors) + + # Verification + self.assertEqual(output_tensors, expected_output) + + @requires_xccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + def test_allgather_base_ops(self): + pg = self.pg + local_device_id = self.rank_to_GPU[self.rank][0] + + def allgather_base(output_t, input_t): + work = pg._allgather_base(output_t, input_t) + work.wait() + + # allgather_base is GPU number agnostic. 
+ # Each rank contribute one tensor regardless of GPU counts + tensor = torch.tensor([self.rank]).xpu(local_device_id) + output_t = torch.empty((self.world_size), dtype=tensor.dtype).xpu( + local_device_id + ) + + allgather_base(output_t, tensor) + + # Verification + self.assertEqual(torch.arange(self.world_size), output_t) + + @requires_xccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + def test_allgather_base_basics(self): + pg = self.pg + local_device_id = self.rank_to_GPU[self.rank][0] + + def allgather_base(output_t, input_t): + work = pg._allgather_base(output_t, input_t) + work.wait() + + # anticipate an error + with self.assertRaisesRegex( + ValueError, + "output tensor size must be equal to world_size times input tensor size", + ): + tensor = torch.tensor([self.rank]).xpu(local_device_id) + output_t = torch.empty((self.world_size + 1), dtype=tensor.dtype).xpu( + local_device_id + ) + # fails the check because output_t is not correctly sized + allgather_base(output_t, tensor) + + # anticipate an error + with self.assertRaisesRegex( + TypeError, "output tensor must have the same type as input tensor" + ): + tensor = torch.tensor([self.rank], dtype=torch.float).xpu(local_device_id) + output_t = torch.empty((self.world_size + 1), dtype=torch.long).xpu( + local_device_id + ) + # fails the check because the dtype is different + allgather_base(output_t, tensor) + + # TODO: wait gather + # @requires_xccl() + # @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + # def test_gather_ops(self): + # pg = self.pg + # local_device_ids = self.rank_to_GPU[self.rank] + # num_gpus = len(local_device_ids) + + # def gather(output_t, input_t, rootRank): + # opts = c10d.GatherOptions() + # opts.rootRank = rootRank + # if rootRank == self.rank: + # work = pg.gather(output_t, input_t, opts) + # else: + # work = pg.gather([], input_t, opts) + # work.wait() + + # # init input + # tensors = [] + # for device_id in local_device_ids: + # tensors.append(torch.tensor([self.rank]).xpu(device_id)) + + # # init output + # output_ts = [] + # for idx in range(num_gpus): + # gpu_idx = local_device_ids[idx] + # output_ts.append([]) + # for rank in range(self.world_size): + # output_ts[idx].append(torch.tensor([-1]).xpu(gpu_idx)) + + # expected = [[torch.tensor([rank]) for rank in range(self.world_size)]] + # for rank in range(self.world_size): + # gather(output_ts, tensors, rank) + # if rank == self.rank: + # self.assertEqual(expected, output_ts) + + # @requires_xccl() + # @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + # def test_gather_stress(self): + # pg = self.pg + # local_device_ids = self.rank_to_GPU[self.rank] + # num_gpus = len(local_device_ids) + + # def gather(output_t, input_t, rootRank): + # opts = c10d.GatherOptions() + # opts.rootRank = rootRank + # if rootRank == self.rank: + # work = pg.gather(output_t, input_t, opts) + # else: + # work = pg.gather([], input_t, opts) + # work.wait() + + # stress_length = 1000 + + # # init input + # tensors = [] + # for i in range(stress_length): + # tensors.append([]) + # for device_id in local_device_ids: + # tensors[i].append(torch.tensor([self.rank]).xpu(device_id)) + + # # init output + # output_ts = [] + # for i in range(stress_length): + # output_ts.append([[] for _ in range(num_gpus)]) + # for idx, ls in enumerate(output_ts[i]): + # gpu_idx = local_device_ids[idx] + # for _ in range(self.world_size): + # ls.append(torch.tensor([-1]).xpu(gpu_idx)) + + # 
expected = [[torch.tensor([rank]) for rank in range(self.world_size)]] + # for i in range(stress_length): + # for rank in range(self.world_size): + # gather(output_ts[i], tensors[i], rank) + # # Verification + # if rank == self.rank: + # self.assertEqual(output_ts[i], expected) + + # @requires_xccl() + # @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + # def test_gather_checks(self): + # pg = self.pg + # device_id = self.rank_to_GPU[self.rank][0] + + # # init input + # tensor = torch.tensor([self.rank]).xpu(device_id) + + # # init output + # output_ts = [] + # for rank in range(self.world_size): + # output_ts.append(torch.tensor([-1]).xpu(device_id)) + + # with self.assertRaisesRegex(ValueError, "invalid root rank"): + # opts = c10d.GatherOptions() + # opts.rootRank = -1 + # pg.gather([output_ts], [tensor], opts) + + # with self.assertRaisesRegex(TypeError, "incompatible function arguments"): + # pg.gather([output_ts], [tensor], 0) + + # with self.assertRaisesRegex(ValueError, "invalid root rank"): + # opts = c10d.GatherOptions() + # opts.rootRank = self.world_size + # pg.gather([output_ts], [tensor], opts) + + # with self.assertRaisesRegex( + # # throws error message from dispatcher + # RuntimeError, + # "There were no tensor arguments to this function", + # ): + # opts = c10d.GatherOptions() + # opts.rootRank = 0 + # pg.gather([output_ts], [], opts) + + # TODO: wait scatter + # @requires_xccl() + # @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + # def test_scatter_ops(self): + # pg = self.pg + # local_device_ids = self.rank_to_GPU[self.rank] + # num_gpus = len(local_device_ids) + + # def scatter(output_t, input_t, rootRank): + # opts = c10d.ScatterOptions() + # opts.rootRank = rootRank + # if rootRank == self.rank: + # work = pg.scatter(output_t, input_t, opts) + # else: + # work = pg.scatter(output_t, [], opts) + # work.wait() + + # # init output + # tensors = [] + # for device_id in local_device_ids: + # tensors.append(torch.tensor([-1]).xpu(device_id)) + + # # init input + # scatter_list = [] + # for idx in range(num_gpus): + # gpu_idx = local_device_ids[idx] + # scatter_list.append([]) + # for rank in range(self.world_size): + # scatter_list[idx].append(torch.tensor([rank]).xpu(gpu_idx)) + + # # test each rank to scatter + # expected = [torch.tensor([self.rank])] + # for rank in range(self.world_size): + # scatter(tensors, scatter_list, rank) + # self.assertEqual(expected, tensors) + + # @requires_xccl() + # @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + # def test_scatter_stress(self): + # pg = self.pg + # local_device_ids = self.rank_to_GPU[self.rank] + # num_gpus = len(local_device_ids) + + # def scatter(output_t, input_t, rootRank): + # opts = c10d.ScatterOptions() + # opts.rootRank = rootRank + # if rootRank == self.rank: + # work = pg.scatter(output_t, input_t, opts) + # else: + # work = pg.scatter(output_t, [], opts) + # work.wait() + + # stress_length = 1000 + + # # init output + # tensors = [] + # for i in range(stress_length): + # tensors.append([]) + # for device_id in local_device_ids: + # tensors[i].append(torch.tensor([-1]).xpu(device_id)) + + # # init input + # scatter_list = [] + # for i in range(stress_length): + # scatter_list.append([[] for _ in range(num_gpus)]) + # for idx, ls in enumerate(scatter_list[i]): + # gpu_idx = local_device_ids[idx] + # for rank in range(self.world_size): + # ls.append(torch.tensor([rank]).xpu(gpu_idx)) + + # # test each rank to 
scatter + # expected = [torch.tensor([self.rank])] + # for i in range(stress_length): + # for rank in range(self.world_size): + # scatter(tensors[i], scatter_list[i], rank) + # # Verification + # self.assertEqual(tensors[i], expected) + + # @requires_xccl() + # @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + # def test_scatter_checks(self): + # pg = self.pg + # local_device_ids = self.rank_to_GPU[self.rank] + # num_gpus = len(local_device_ids) + + # # init output + # tensors = [] + # for device_id in local_device_ids: + # tensors.append(torch.tensor([-1]).xpu(device_id)) + + # # init input + # scatter_list = [] + # for idx in range(num_gpus): + # gpu_idx = local_device_ids[idx] + # scatter_list.append([]) + # for rank in range(self.world_size): + # scatter_list[idx].append(torch.tensor([rank]).xpu(gpu_idx)) + + # with self.assertRaisesRegex(ValueError, "invalid root rank"): + # opts = c10d.ScatterOptions() + # opts.rootRank = -1 + # pg.scatter(tensors, scatter_list, opts) + + # with self.assertRaisesRegex(TypeError, "incompatible function arguments"): + # pg.scatter(tensors, scatter_list, 0) + + # with self.assertRaisesRegex(ValueError, "invalid root rank"): + # opts = c10d.ScatterOptions() + # opts.rootRank = self.world_size + # pg.scatter(tensors, scatter_list, opts) + + # with self.assertRaisesRegex( + # # throws error message from dispatcher + # RuntimeError, + # "There were no tensor arguments to this function", + # ): + # opts = c10d.ScatterOptions() + # opts.rootRank = 0 + # pg.scatter([], scatter_list, opts) + + @requires_xccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + def test_reduce_scatter_base_basics(self): + pg = self.pg + local_device_id = self.rank_to_GPU[self.rank][0] + + def reduce_scatter_base(output_t, input_t): + work = pg._reduce_scatter_base(output_t, input_t) + work.wait() + + # anticipate an error + with self.assertRaisesRegex( + ValueError, + "input tensor must be the same size as output size times world size", + ): + input_t = torch.tensor([self.rank]).xpu(local_device_id) + output_t = torch.empty((self.world_size + 1), dtype=input_t.dtype).xpu( + local_device_id + ) + # fails the check because output_t is not correctly sized + reduce_scatter_base(output_t, input_t) + + # anticipate an error + with self.assertRaisesRegex( + TypeError, "input tensor must be the same type as the output tensor." 
+ ): + tensor = torch.tensor([self.rank], dtype=torch.float).xpu(local_device_id) + output_t = torch.empty((self.world_size + 1), dtype=torch.long).xpu( + local_device_id + ) + # fails the check because the dtype is different + reduce_scatter_base(output_t, tensor) + + @requires_xccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + def test_reduce_scatter_ops(self): + pg = self.pg + local_device_ids = self.rank_to_GPU[self.rank] + num_gpus = len(local_device_ids) + + def reduce_scatter(outputs, input_lists, op): + opts = c10d.ReduceScatterOptions() + opts.reduceOp = op + work = pg.reduce_scatter(outputs, input_lists, opts) + work.wait() + + output = [torch.tensor([0]).xpu(i) for i in local_device_ids] + + # GPU/rank + # 0 [1], [2], [3], [4] + # 1 [2], [3], [4], [5] + # 2 [3], [4], [5], [6] + # 3 [4], [5], [6], [7] + + # Sum + tensor_lists = [] + input_per_gpu = [] + + for i in range(self.world_size): + input_per_gpu.append(torch.tensor([self.rank + i + 1])) + + for gpu in local_device_ids: + tensor_lists.append([t.xpu(device=gpu) for t in input_per_gpu]) + + reduce_scatter(output, tensor_lists, c10d.ReduceOp.SUM) + + for i in range(num_gpus): + expected = torch.tensor( + [ + (1 + self.world_size) * self.world_size // 2 + + self.world_size * self.rank + ] + ) + + self.assertEqual(expected, output[i]) + + # Min + reduce_scatter(output, tensor_lists, c10d.ReduceOp.MIN) + + for i in range(num_gpus): + expected = torch.tensor([self.rank + 1 + i]) + self.assertEqual(expected, output[i]) + + # Max + reduce_scatter(output, tensor_lists, c10d.ReduceOp.MAX) + + for i in range(num_gpus): + expected = torch.tensor([self.rank + self.world_size + i]) + self.assertEqual(expected, output[i]) + + # Product + reduce_scatter(output, tensor_lists, c10d.ReduceOp.PRODUCT) + + # math package don't have math.perm until python 3.8, so + # we implement a naive version here. + def perm(n, k): + prod_val = n + for val in range(n - k + 1, n): + prod_val *= val + return prod_val + + for i in range(num_gpus): + prod_val = perm(self.rank + self.world_size, self.world_size) + + expected = torch.tensor([prod_val]) + self.assertEqual(expected, output[i]) + + # Test the input params overridden scenarios, aka, when the input is + # a list and output is just one tensor. 
+ # Sum + output_tensor = torch.empty_like(input_per_gpu[0][0]).xpu(self.rank) + input_list = [tensor[0].xpu(self.rank) for tensor in input_per_gpu] + pg.reduce_scatter(output_tensor, input_list, c10d.ReduceOp.SUM).wait() + expected = torch.tensor( + (1 + self.world_size) * self.world_size // 2 + self.world_size * self.rank + ) + self.assertEqual(expected, output_tensor) + + # Min + pg.reduce_scatter(output_tensor, input_list, c10d.ReduceOp.MIN).wait() + expected = torch.tensor(self.rank + 1) + self.assertEqual(expected, output_tensor) + + # Max + pg.reduce_scatter(output_tensor, input_list, c10d.ReduceOp.MAX).wait() + expected = torch.tensor(self.rank + self.world_size) + self.assertEqual(expected, output_tensor) + + # Product + pg.reduce_scatter(output_tensor, input_list, c10d.ReduceOp.PRODUCT).wait() + prod_val = self.rank + 1 + for k in range(1, self.world_size): + prod_val = prod_val * (self.rank + 1 + k) + expected = torch.tensor(prod_val) + self.assertEqual(expected, output_tensor) + + + @requires_xccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + def test_reduce_scatter_base_ops(self): + pg = self.pg + local_device_id = self.rank_to_GPU[self.rank][0] + + def reduce_scatter_base(output_t, input_t): + work = pg._reduce_scatter_base(output_t, input_t) + work.wait() + + # reduce_scatter_base is GPU number agnostic. + # Each rank contribute one tensor regardless of GPU counts + output_t = torch.empty([1]).xpu(local_device_id) + tensor = torch.arange(self.world_size, dtype=output_t.dtype).xpu( + local_device_id + ) + + reduce_scatter_base(output_t, tensor) + + # Verification + self.assertEqual(output_t[0], self.rank * self.world_size) + + @requires_xccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + def test_barrier(self): + pg = self.pg + local_device_ids = self.rank_to_GPU[self.rank] + + def allreduce(tensors): + opts = c10d.AllreduceOptions() + work = pg.allreduce(tensors, opts) + return work + + # Making the collective to operate on + # 1, 2, 3, 4, .... 
len(local_device_ids) GPUs + tensors_list = [[] for _ in range(len(local_device_ids))] + + for i in range(1, len(local_device_ids) + 1): + for j in range(i): + tensors_list[i - 1].append( + torch.tensor([j + 1]).xpu(local_device_ids[j]) + ) + + works = [] + for tensors in tensors_list: + work = allreduce(tensors) + works.append(work) + + # Barrier will ensure that all previous work is completed + pg.barrier().wait() + + for i in range(1, len(local_device_ids) + 1): + for j in range(i): + self.assertEqual( + torch.tensor([(j + 1) * self.world_size]), tensors_list[i - 1][j] + ) + + # TODO: wait send/recv + # @requires_xccl() + # @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + # def test_send_recv(self): + # pg = self.pg + # device = self.rank_to_GPU[self.rank][0] + + # # Generate the same random tensor + # torch.manual_seed(0) + # send_tensor = torch.rand(10, 10, device=device) + # if self.rank == 0: + # dist.send(send_tensor, 1) + # if self.rank == 1: + # recv_tensor = torch.rand(10, 10, device=device) + # dist.recv(recv_tensor, 0) + # self.assertEqual(send_tensor, recv_tensor) + + # @requires_xccl() + # @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + # def test_send_recv_complex(self): + # pg = self.pg + # device = self.rank_to_GPU[self.rank][0] + + # # Generate the same random tensor + # torch.manual_seed(0) + # send_tensor = torch.rand(10, 10, dtype=torch.cfloat, device=device) + # if self.rank == 0: + # dist.send(send_tensor, 1) + # if self.rank == 1: + # recv_tensor = torch.rand(10, 10, dtype=torch.cfloat, device=device) + # dist.recv(recv_tensor, 0) + # self.assertEqual(send_tensor, recv_tensor) + + # @requires_xccl() + # @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + # def test_send_recv_object_list(self): + # device = self.rank_to_GPU[self.rank][0] + + # val = 99 if self.rank == 0 else None + # object_list = [val] * self.world_size + # if self.rank == 0: + # dist.send_object_list(object_list, 1, device=device) + # if self.rank == 1: + # dist.recv_object_list(object_list, 0, device=device) + # self.assertEqual(object_list[0], 99) + + +if __name__ == "__main__": + rank = int(os.getenv("RANK", -1)) + world_size = int(os.getenv("WORLD_SIZE", 2)) + + if rank != -1: + # Launched with torchrun or other multi-proc launchers. Directly run the test. + ProcessGroupXCCLOpTest.run_rank(rank, world_size) + else: + # Launched as a single process. Spawn subprocess to run the tests. + # Also need a rendezvous file for `init_process_group` purpose. 
+ rdvz_file = tempfile.NamedTemporaryFile(delete=False).name + torch.multiprocessing.spawn( + ProcessGroupXCCLOpTest.run_rank, + nprocs=world_size, + args=(world_size, rdvz_file), + ) + diff --git a/torch/csrc/distributed/c10d/ProcessGroup.hpp b/torch/csrc/distributed/c10d/ProcessGroup.hpp index 83d2729fc43d4..c7f9609bcf0cd 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.hpp @@ -610,6 +610,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { tensor = at::empty( {1}, at::TensorOptions().device(at::DeviceType::CUDA).dtype(at::kByte)); + } else if (backendType_ == c10d::ProcessGroup::BackendType::XCCL) { + // set xpu tensor for override cpu dispatch + tensor = at::empty( + {1}, + at::TensorOptions().device(at::DeviceType::XPU).dtype(at::kByte)); } else { // Default to using cpu implementation tensor = at::empty( diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp index d26d25ae03e39..d2473b3c95004 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp @@ -2,6 +2,8 @@ #include #include +#include +#include #include #include #include @@ -16,6 +18,16 @@ namespace c10d { namespace { + +// wait nonblocking implement +AutoXcclGroup::AutoXcclGroup() { + ccl::group_start(); +} + +AutoXcclGroup::~AutoXcclGroup() noexcept(false) { + ccl::group_end(); +} + std::map xcclOps = { {ReduceOp::MIN, ccl::reduction::min}, {ReduceOp::MAX, ccl::reduction::max}, @@ -33,8 +45,22 @@ std::map xcclDatatypes = { {at::kDouble, ccl::datatype::float64}, {at::kBFloat16, ccl::datatype::bfloat16}, {at::kBool, ccl::datatype::uint8}, + // use for allgather + {at::kFloat8_e5m2, ccl::datatype::uint8}, + {at::kFloat8_e4m3fn, ccl::datatype::uint8}, + {at::kFloat8_e4m3fnuz, ccl::datatype::uint8}, + {at::kFloat8_e5m2fnuz, ccl::datatype::uint8}, }; +bool check_same_size(const std::vector& input_tensors) { + for (const auto& input_tensor : input_tensors) { + if (!input_tensors[0].is_same_size(input_tensor)) { + return false; + } + } + return true; +} + void check_xpu_single_tensor(const at::Tensor& tensor) { if (!tensor.is_xpu() || tensor.is_sparse()) { C10_THROW_ERROR(ValueError, "Tensors must be XPU and dense"); @@ -44,6 +70,34 @@ void check_xpu_single_tensor(const at::Tensor& tensor) { } } +int64_t check_xpu_tensors_same_device(const std::vector& tensors) { + if (tensors.size() == 0) { + C10_THROW_ERROR(ValueError, "Tensor list must be nonempty"); + } + + const auto& first = tensors.front(); + + int64_t total_numel = 0; + for (const auto& t : tensors) { + if (!t.is_xpu() || t.is_sparse()) { + C10_THROW_ERROR(ValueError, "Tensors must be XPU and dense"); + } + if (t.scalar_type() != first.scalar_type()) { + C10_THROW_ERROR(TypeError, "Tensors must have identical type"); + } + if (!t.is_non_overlapping_and_dense()) { + C10_THROW_ERROR(ValueError, "Tensors must be non-overlapping and dense"); + } + TORCH_CHECK_WITH( + ValueError, + t.get_device() == tensors[0].get_device(), + "Expected list of tensors on the same device"); + total_numel += t.numel(); + } + + return total_numel; +} + ccl::datatype getXcclDataType(at::ScalarType type) { auto it = xcclDatatypes.find(type); TORCH_CHECK_WITH( @@ -78,6 +132,10 @@ static std::mutex xcclCommDevIdxMapMutex; static std::unordered_map, int> xcclCommDevIdxMap; constexpr int64_t kSynchronizeBusyWaitMillis = 10; +// Before implementing send/recv, the xcclActiveGroupCounter_ variable has no +// effect. 
+thread_local uint64_t ProcessGroupXCCL::xcclActiveGroupCounter_ = 0; + ProcessGroupXCCL::WorkXCCL::WorkXCCL( at::Device& device, int rank, @@ -131,6 +189,10 @@ void ProcessGroupXCCL::WorkXCCL::synchronizeInternal( std::chrono::milliseconds(kSynchronizeBusyWaitMillis)); } } + if (barrierTensor_.defined()) { + auto currentStream = at::xpu::getCurrentXPUStream(device_.index()); + currentStream.synchronize(); + } } bool ProcessGroupXCCL::WorkXCCL::wait(std::chrono::milliseconds timeout) { @@ -138,6 +200,9 @@ bool ProcessGroupXCCL::WorkXCCL::wait(std::chrono::milliseconds timeout) { return true; } +constexpr const char* MULTI_DEVICE_ERROR_MSG = + "Expecting one tensor only but got multiple"; + ProcessGroupXCCL::ProcessGroupXCCL( const c10::intrusive_ptr& store, int rank, @@ -184,6 +249,8 @@ std::shared_ptr ProcessGroupXCCL::getXCCLComm( "the devices are empty "); } + usedDeviceIdxs_.insert(device.index()); + { std::lock_guard lock(mutex_); if (devXCCLCommMap_.find(deviceKey) != devXCCLCommMap_.end()) { @@ -234,10 +301,63 @@ std::shared_ptr ProcessGroupXCCL::getXCCLComm( return it->second; } +void ProcessGroupXCCL::groupStart() { + ccl::group_start(); + ++xcclActiveGroupCounter_; +} + +void ProcessGroupXCCL::groupEnd() { + ccl::group_end(); + --xcclActiveGroupCounter_; +} + +// TODO: wait p2p enable +static constexpr int CoalActive = 0x01, CoalColl = 0x02; +void ProcessGroupXCCL::startCoalescing() { + coalescedDevice_.set_index(-1); + coalescedComm_ = nullptr; + coalescing_state_ |= CoalActive; + groupStart(); +} + +c10::intrusive_ptr ProcessGroupXCCL::endCoalescing(OpType optype) { + if (coalescedComm_ == nullptr) { + // There is no actual work being coalesced, return here + groupEnd(); + coalescing_state_ = 0; + return nullptr; + } + TORCH_CHECK( + coalescedDevice_.index() >= 0, + "Somthing went wrong. 
Did you call end_coalescing before start_coalescing?"); + + auto comm = coalescedComm_; + auto device = coalescedDevice_; + + const auto key = std::to_string(device.index()); + auto stream = xcclStreams_.at(key); + + auto work = initWork(device, rank_, optype); + work->blockingWait_ = blockingWait_; + + groupEnd(); + + work->xcclEndEvent_->record(stream); + + coalescing_state_ = 0; + coalescedComm_ = nullptr; + return work; +} + +c10::intrusive_ptr ProcessGroupXCCL::endCoalescing() { + // Default OpType to COALESCED if not specified + return endCoalescing(OpType::COALESCED); +} + template c10::intrusive_ptr ProcessGroupXCCL::collective( - at::Tensor& input, - at::Tensor& output, + std::vector& inputs, + std::vector& outputs, Fn fn, PreProcess pre, PostProcess post, @@ -246,27 +366,49 @@ c10::intrusive_ptr ProcessGroupXCCL::collective( using attr_t = typename traits::template arg<2>::type; attr_t attr = ccl::create_operation_attr(); - auto device = input.device(); + auto device = inputs[0].device(); const auto key = std::to_string(device.index()); auto comm = getXCCLComm(key, device); + if (coalescing_state_ & CoalActive) { + coalescing_state_ |= CoalColl; + if (coalescedDevice_.index() < 0) { + coalescedDevice_ = device; + } else { + TORCH_CHECK( + coalescedDevice_.index() == device.index(), MULTI_DEVICE_ERROR_MSG); + } + if (coalescedComm_ == nullptr) { + coalescedComm_ = comm; + } else { + TORCH_CHECK(coalescedComm_ == comm, MULTI_DEVICE_ERROR_MSG); + } + } + auto stream = xcclStreams_.at(key); - std::vector outputs{output}; c10::intrusive_ptr work; work = initWork(device, rank_, opType); - work->outputs_ = - std::make_shared>(std::move(outputs)); - c10::xpu::XPUCachingAllocator::recordStream( - input.storage().data_ptr(), stream); + work->outputs_ = std::make_shared>(outputs); - auto ccl_stream = ccl::create_stream(stream.queue()); + at::xpu::OptionalXPUGuard gpuGuard(device); - fn(input, output, attr, *comm, ccl_stream); + pre(stream, work); - work->xcclEndEvent_->record(stream); + for (const auto& input : inputs) { + c10::xpu::XPUCachingAllocator::recordStream( + input.storage().data_ptr(), stream); + } + + fn(inputs[0], outputs[0], attr, *comm, stream); + + post(stream, work); + + if (!coalescing_state_) { + work->xcclEndEvent_->record(stream); + } std::vector streams = {stream.unwrap()}; c10::MultiStreamGuard streamGuard(streams); @@ -279,6 +421,19 @@ c10::intrusive_ptr ProcessGroupXCCL::collective( return work; } +template +c10::intrusive_ptr ProcessGroupXCCL::collective( + at::Tensor& input, + at::Tensor& output, + Fn fn, + PreProcess pre, + PostProcess post, + OpType opType) { + auto inputs = std::vector{input}; + auto outputs = std::vector{output}; + return collective(inputs, outputs, fn, pre, post, opType); +} + template c10::intrusive_ptr ProcessGroupXCCL::collective( at::Tensor& input, @@ -289,43 +444,571 @@ c10::intrusive_ptr ProcessGroupXCCL::collective( input, output, fn, - [](at::xpu::XPUStream&, - c10::intrusive_ptr& work) {}, - [](at::xpu::XPUStream&, - c10::intrusive_ptr& work) {}, + [](at::xpu::XPUStream&, c10::intrusive_ptr&) { + }, + [](at::xpu::XPUStream&, c10::intrusive_ptr&) { + }, opType); } +template +c10::intrusive_ptr ProcessGroupXCCL::collectiveCoalesced( + std::vector& inputs, + std::vector& outputs, + Fn fn, + OpType opType) { + using traits = function_traits; + using attr_t = typename traits::template arg<2>::type; + attr_t attr = ccl::create_operation_attr(); + + auto device = inputs[0].device(); + const auto key = std::to_string(device.index()); + auto 
comm = getXCCLComm(key, device); + + if (coalescing_state_ & CoalActive) { + coalescing_state_ |= CoalColl; + if (coalescedDevice_.index() < 0) { + coalescedDevice_ = device; + } else { + TORCH_CHECK( + coalescedDevice_.index() == device.index(), MULTI_DEVICE_ERROR_MSG); + } + if (coalescedComm_ == nullptr) { + coalescedComm_ = comm; + } else { + TORCH_CHECK(coalescedComm_ == comm, MULTI_DEVICE_ERROR_MSG); + } + } + + auto stream = xcclStreams_.at(key); + + c10::intrusive_ptr work; + + work = initWork(device, rank_, opType); + + work->outputs_ = std::make_shared>(outputs); + + at::xpu::OptionalXPUGuard gpuGuard(device); + + { + AutoXcclGroup xccl_group_guard; + for (const auto i : c10::irange(inputs.size())) { + c10::xpu::XPUCachingAllocator::recordStream( + inputs[i].storage().data_ptr(), stream); + fn(inputs[i], outputs[i], attr, *comm, stream); + } + } + + work->xcclEndEvent_->record(stream); + + std::vector streams = {stream.unwrap()}; + c10::MultiStreamGuard streamGuard(streams); + std::vector devices{device}; + work->future_ = c10::make_intrusive( + c10::ListType::create(c10::TensorType::get()), devices); + work->future_->markCompleted(at::IValue(*work->outputs_)); + work->blockingWait_ = blockingWait_; + + return work; +} + +c10::intrusive_ptr ProcessGroupXCCL::allreduce_impl( + at::Tensor& tensor, + const AllreduceOptions& opts) { + return collective( + tensor, + tensor, + [&](at::Tensor& input, + at::Tensor& output, + ccl::allreduce_attr attr, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + auto xcclDataType = getXcclDataType(input.scalar_type()); + auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); + ccl::event ret_evt; + ret_evt = ccl::allreduce( + input.data_ptr(), + output.data_ptr(), + (size_t)input.numel(), + xcclDataType, + xcclReduceOp, + comm, + ccl::create_stream(stream.queue()), + attr); + return ret_evt; + }, + OpType::ALLREDUCE); +} + c10::intrusive_ptr ProcessGroupXCCL::allreduce( std::vector& tensors, const AllreduceOptions& opts) { + TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + auto tensor = tensors.back(); + check_xpu_single_tensor(tensor); + TORCH_CHECK( + !isFloat8Type(tensor.scalar_type()), + "Float8 dtypes are not currenlty supported for XCCL reductions"); + + return allreduce_impl(tensor, opts); +} + +c10::intrusive_ptr ProcessGroupXCCL::allreduce_coalesced( + std::vector& tensors, + const AllreduceCoalescedOptions& opts) { + check_xpu_tensors_same_device(tensors); TORCH_CHECK( - tensors.size() == 1, "Expecting one tensor only but got multiple"); + !isFloat8Type(tensors.back().scalar_type()), + "Float8 dtypes are not currenlty supported for XCCL reductions"); + + return collectiveCoalesced( + tensors, + tensors, + [&](at::Tensor& input, + at::Tensor& output, + ccl::allreduce_attr attr, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + auto xcclDataType = getXcclDataType(input.scalar_type()); + auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); + ccl::event ret_evt; + ret_evt = ccl::allreduce( + input.data_ptr(), + output.data_ptr(), + (size_t)input.numel(), + xcclDataType, + xcclReduceOp, + comm, + ccl::create_stream(stream.queue()), + attr); + return ret_evt; + }, + OpType::COALESCED); +} + +c10::intrusive_ptr ProcessGroupXCCL::broadcast( + std::vector& tensors, + const BroadcastOptions& opts) { + TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); auto tensor = tensors.back(); + if (tensor.is_complex()) { + tensor = at::view_as_real(tensor); + } check_xpu_single_tensor(tensor); + + const auto root = 
opts.rootRank + opts.rootTensor; + return collective( tensor, tensor, [&](at::Tensor& input, at::Tensor& output, - ccl::allreduce_attr attr, + ccl::broadcast_attr attr, xcclComm_t& comm, - ccl::stream& stream) { + at::xpu::XPUStream& stream) { auto xcclDataType = getXcclDataType(input.scalar_type()); - auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); ccl::event ret_evt; - ret_evt = ccl::allreduce( + ret_evt = ccl::broadcast( + input.data_ptr(), + (size_t)input.numel(), + xcclDataType, + root, + comm, + ccl::create_stream(stream.queue()), + attr); + return ret_evt; + }, + OpType::BROADCAST); +} + +c10::intrusive_ptr ProcessGroupXCCL::_broadcast_oop( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + const BroadcastOptions& opts) { + if (outputTensor.numel() != inputTensor.numel()) { + C10_THROW_ERROR( + ValueError, + "Tensor input and output of _broadcast_oop must have the same number of elements "); + } + const auto root = opts.rootRank + opts.rootTensor; + return collective( + inputTensor, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + ccl::broadcast_attr attr, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + auto xcclDataType = getXcclDataType(input.scalar_type()); + ccl::event ret_evt; + ret_evt = ccl::broadcast( + input.data_ptr(), + (size_t)input.numel(), + xcclDataType, + root, + comm, + ccl::create_stream(stream.queue()), + attr); + return ret_evt; + }, + OpType::BROADCAST); +} + +c10::intrusive_ptr ProcessGroupXCCL::_reduce_oop( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + const ReduceOptions& opts) { + if (outputTensor.numel() != inputTensor.numel()) { + C10_THROW_ERROR( + ValueError, + "Tensor input and output of _reduce_oop must have the same number of elements "); + } + return collective( + inputTensor, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + ccl::reduce_attr attr, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + const int root = opts.rootRank + opts.rootTensor; + const auto xcclDataType = getXcclDataType(input.scalar_type()); + const auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); + ccl::event ret_evt; + ret_evt = ccl::reduce( input.data_ptr(), output.data_ptr(), (size_t)input.numel(), xcclDataType, xcclReduceOp, + root, comm, - stream, + ccl::create_stream(stream.queue())); + return ret_evt; + }, + OpType::REDUCE); +} + +c10::intrusive_ptr ProcessGroupXCCL::allgather( + std::vector>& outputTensors, + std::vector& inputTensors, + const AllgatherOptions& opts) { + TORCH_CHECK(inputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + // @lint-ignore CLANGTIDY + auto inputTensor = inputTensors.back(); + check_xpu_single_tensor(inputTensor); + // @lint-ignore CLANGTIDY + std::vector& outputTensors_ = outputTensors.back(); + + bool same_size = check_same_size(outputTensors_); + if (same_size) { + // Flatten a vector of tensors into a single, stacked tensor. 
+ at::Tensor outputFlattened = newLikeFlat(outputTensors_); + + return collective( + inputTensor, + outputFlattened, + [&](at::Tensor& input, + at::Tensor& output, + ccl::allgather_attr attr, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); + auto xcclDataType = getXcclDataType(input.scalar_type()); + ccl::event ret_evt; + + ret_evt = ccl::allgather( + input.data_ptr(), + output.data_ptr(), + (size_t)input.numel(), + xcclDataType, + comm, + ccl::create_stream(stream.queue()), + attr); + return ret_evt; + }, + [](at::xpu::XPUStream&, + c10::intrusive_ptr& work) {}, + [&](at::xpu::XPUStream& Stream, + c10::intrusive_ptr& work) { + // Copy the flattened output tensors to the outputs. + c10::StreamGuard guard(Stream); + for (const auto j : c10::irange(outputTensors_.size())) { + c10::xpu::XPUCachingAllocator::recordStream( + outputTensors_[j].storage().data_ptr(), Stream); + outputTensors_[j].copy_(outputFlattened[j], true); + } + }, + OpType::ALLGATHER); + } else { + const auto num_reduces = outputTensors_.size(); + startCoalescing(); + for (const int i : c10::irange(num_reduces)) { + auto& output = outputTensors_[i]; + auto& input = (i == rank_) ? inputTensor : output; + auto broadcastOpts = BroadcastOptions{ + static_cast(i), static_cast(0), opts.timeout}; + _broadcast_oop(output, input, broadcastOpts); + } + auto work = endCoalescing(OpType::ALLGATHER); + return work; + } +} + +c10::intrusive_ptr ProcessGroupXCCL::_allgather_base( + at::Tensor& output_tensor, + at::Tensor& input_tensor, + const AllgatherOptions& opts) { + check_xpu_single_tensor(input_tensor); + check_xpu_single_tensor(output_tensor); + + if (input_tensor.dtype() != output_tensor.dtype()) { + C10_THROW_ERROR( + TypeError, "output tensor must have the same type as input tensor"); + } + + if (input_tensor.numel() * size_ != output_tensor.numel()) { + C10_THROW_ERROR( + ValueError, + "output tensor size must be equal to world_size times input tensor size"); + } + + return collective( + input_tensor, + output_tensor, + [&](at::Tensor& input, + at::Tensor& output, + ccl::allgather_attr attr, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); + auto xcclDataType = getXcclDataType(input.scalar_type()); + ccl::event ret_evt; + ret_evt = ccl::allgather( + input.data_ptr(), + output.data_ptr(), + (size_t)input.numel(), + xcclDataType, + comm, + ccl::create_stream(stream.queue()), attr); return ret_evt; }, - OpType::ALLREDUCE); + OpType::_ALLGATHER_BASE); +} + +c10::intrusive_ptr ProcessGroupXCCL::allgather_into_tensor_coalesced( + std::vector& outputs, + std::vector& inputs, + const AllgatherOptions& opts) { + return collectiveCoalesced( + inputs, + outputs, + [&](at::Tensor& input, + at::Tensor& output, + ccl::allgather_attr attr, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + auto xcclDataType = getXcclDataType(input.scalar_type()); + ccl::event ret_evt; + ret_evt = ccl::allgather( + input.data_ptr(), + output.data_ptr(), + (size_t)input.numel(), + xcclDataType, + comm, + ccl::create_stream(stream.queue()), + attr); + return ret_evt; + }, + OpType::COALESCED); +} + +c10::intrusive_ptr ProcessGroupXCCL::reduce_scatter( + std::vector& outputTensors, + std::vector>& inputTensors, + const ReduceScatterOptions& opts) { + TORCH_CHECK(outputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + // @lint-ignore CLANGTIDY + auto outputTensor = outputTensors.back(); 
+ check_xpu_single_tensor(outputTensor); + // @lint-ignore CLANGTIDY + auto inputTensors_ = inputTensors.back(); + TORCH_CHECK( + !isFloat8Type(outputTensor.scalar_type()), + "Float8 dtypes are not currenlty supported for XCCL reductions"); + + bool same_size = check_same_size(inputTensors_); + if (same_size) { + // Flatten a vector of tensors into a single, stacked tensor. + at::Tensor inputFlattened = newLikeFlat(inputTensors_); + return collective( + inputFlattened, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + ccl::reduce_attr attr, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); + auto xcclDataType = getXcclDataType(input.scalar_type()); + auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); + ccl::event ret_evt; + ret_evt = ccl::reduce_scatter( + input.data_ptr(), + output.data_ptr(), + (size_t)output.numel(), + xcclDataType, + xcclReduceOp, + comm, + ccl::create_stream(stream.queue())); + return ret_evt; + }, + [&](at::xpu::XPUStream& Stream, + c10::intrusive_ptr& work) { + // Copy the input tensors to the flattened inputs. + c10::StreamGuard guard(Stream); + for (const auto j : c10::irange(inputTensors_.size())) { + c10::xpu::XPUCachingAllocator::recordStream( + inputTensors_[j].storage().data_ptr(), Stream); + inputFlattened[j].copy_(inputTensors_[j], true); + } + }, + [&](at::xpu::XPUStream&, + c10::intrusive_ptr&) {}, + OpType::REDUCE_SCATTER); + } else { + const auto num_reduces = inputTensors_.size(); + startCoalescing(); + for (const int i : c10::irange(num_reduces)) { + auto& input = inputTensors_[i]; + auto& output = (i == rank_) ? outputTensor : input; + auto reduceOpts = ReduceOptions{ + opts.reduceOp, + static_cast(i), + static_cast(0), + opts.timeout}; + _reduce_oop(output, input, reduceOpts); + } + auto work = endCoalescing(OpType::REDUCE_SCATTER); + return work; + } +} + +c10::intrusive_ptr ProcessGroupXCCL::_reduce_scatter_base( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + const ReduceScatterOptions& opts) { + if (inputTensor.dtype() != outputTensor.dtype()) { + C10_THROW_ERROR( + TypeError, "input tensor must be the same type as the output tensor."); + } + + if (inputTensor.numel() != outputTensor.numel() * size_) { + C10_THROW_ERROR( + ValueError, + "input tensor must be the same size as output size times world size"); + } + + // @lint-ignore CLANGTIDY + const auto& tensor = outputTensor; + TORCH_CHECK( + !isFloat8Type(tensor.scalar_type()), + "Float8 dtypes are not currenlty supported for XCCL reductions"); + + return collective( + inputTensor, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + ccl::reduce_attr attr, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); + auto xcclDataType = getXcclDataType(input.scalar_type()); + auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); + ccl::event ret_evt; + ret_evt = ccl::reduce_scatter( + input.data_ptr(), + output.data_ptr(), + (size_t)output.numel(), + xcclDataType, + xcclReduceOp, + comm, + ccl::create_stream(stream.queue())); + return ret_evt; + }, + OpType::_REDUCE_SCATTER_BASE); +} + +c10::intrusive_ptr ProcessGroupXCCL::reduce_scatter_tensor_coalesced( + std::vector& outputs, + std::vector& inputs, + const ReduceScatterOptions& opts) { + TORCH_CHECK( + !isFloat8Type(inputs.back().scalar_type()), + "Float8 dtypes are not currenlty supported for XCCL reductions"); + return 
collectiveCoalesced( + inputs, + outputs, + [&](at::Tensor& input, + at::Tensor& output, + ccl::reduce_attr attr, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); + auto xcclDataType = getXcclDataType(input.scalar_type()); + auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); + ccl::event ret_evt; + ret_evt = ccl::reduce_scatter( + input.data_ptr(), + output.data_ptr(), + (size_t)output.numel(), + xcclDataType, + xcclReduceOp, + comm, + ccl::create_stream(stream.queue())); + return ret_evt; + }, + OpType::COALESCED); +} + +c10::intrusive_ptr ProcessGroupXCCL::barrier(const BarrierOptions& opts) { + // Device to use for barrier + int barDevIdx = -1; + + // See nccl barrier comments + if (!opts.device_ids.empty()) { + barDevIdx = opts.device_ids[0]; + } else if (getBoundDeviceId()) { + barDevIdx = (*getBoundDeviceId()).index(); + } else if (!usedDeviceIdxs_.empty()) { + barDevIdx = *usedDeviceIdxs_.begin(); + } else { + barDevIdx = + static_cast(rank_ % at::detail::getXPUHooks().getNumGPUs()); + } + + TORCH_CHECK_WITH( + ValueError, + barDevIdx >= 0, + "Failed to infer a GPU device id to perform barrier. "); + auto barDevice = at::Device(at::DeviceType::XPU, barDevIdx); + + at::Tensor barrierTensor = + at::zeros({1}, at::TensorOptions().device(barDevice).dtype(at::kFloat)); + + auto work = allreduce_impl(barrierTensor); + + auto xcclWork = dynamic_cast(work.get()); + TORCH_CHECK(xcclWork); + xcclWork->barrierTensor_ = std::move(barrierTensor); + return work; } } // namespace c10d diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp index 99b815f2138b4..790b6df99e91f 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp @@ -30,6 +30,13 @@ #include namespace c10d { +namespace { +struct AutoXcclGroup { + AutoXcclGroup(); + ~AutoXcclGroup() noexcept(false); +}; +} // namespace + static std::vector TORCH_XCCL_BLOCKING_WAIT = { "TORCH_XCCL_BLOCKING_WAIT", "XCCL_BLOCKING_WAIT"}; @@ -70,12 +77,13 @@ class TORCH_API ProcessGroupXCCL : public Backend { } std::vector result() override { - TORCH_CHECK(false, "ProcessGroupXCCL::WorkXCCL::result not implemented"); + return *outputs_; } protected: at::Device device_; std::shared_ptr xcclEndEvent_; + at::Tensor barrierTensor_; bool blockingWait_ = false; std::chrono::time_point workStartTime_; @@ -101,6 +109,12 @@ class TORCH_API ProcessGroupXCCL : public Backend { return std::string(XCCL_BACKEND_NAME); } + void startCoalescing() override; + + c10::intrusive_ptr endCoalescing() override; + + c10::intrusive_ptr endCoalescing(OpType optype); + std::shared_ptr getXCCLComm( const std::string& deviceKey, at::Device& device); @@ -128,6 +142,26 @@ class TORCH_API ProcessGroupXCCL : public Backend { PostProcess post, OpType opType); + template + c10::intrusive_ptr collective( + std::vector& inputs, + std::vector& outputs, + Fn fn, + PreProcess pre, + PostProcess post, + OpType opType); + + template + c10::intrusive_ptr collectiveCoalesced( + std::vector& input, + std::vector& output, + Fn fn, + OpType opType); + + c10::intrusive_ptr allreduce_impl( + at::Tensor& tensor, + const AllreduceOptions& opts = AllreduceOptions()); + c10::intrusive_ptr allreduce( std::vector& tensors, const AllreduceOptions& opts = AllreduceOptions()) override; @@ -135,9 +169,7 @@ class TORCH_API ProcessGroupXCCL : public Backend { c10::intrusive_ptr allreduce_coalesced( 
std::vector& tensors, const AllreduceCoalescedOptions& opts = - AllreduceCoalescedOptions()) override { - TORCH_CHECK(false, "ProcessGroupXCCL::allreduce_coalesced not implemented"); - } + AllreduceCoalescedOptions()) override; c10::intrusive_ptr reduce( std::vector& tensors, @@ -145,25 +177,29 @@ class TORCH_API ProcessGroupXCCL : public Backend { TORCH_CHECK(false, "ProcessGroupXCCL::reduce not implemented"); } + c10::intrusive_ptr _reduce_oop( + at::Tensor& outputTensors, + at::Tensor& inputTensors, + const ReduceOptions& opts = ReduceOptions()); + c10::intrusive_ptr broadcast( std::vector& tensors, - const BroadcastOptions& opts = BroadcastOptions()) override { - TORCH_CHECK(false, "ProcessGroupXCCL::broadcast not implemented"); - } + const BroadcastOptions& opts = BroadcastOptions()) override; + + c10::intrusive_ptr _broadcast_oop( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + const BroadcastOptions& opts); c10::intrusive_ptr allgather( std::vector>& outputTensors, std::vector& inputTensors, - const AllgatherOptions& opts = AllgatherOptions()) override { - TORCH_CHECK(false, "ProcessGroupXCCL::allgather not implemented"); - } + const AllgatherOptions& opts = AllgatherOptions()) override; c10::intrusive_ptr _allgather_base( at::Tensor& outputbuffer, at::Tensor& inputbuffer, - const AllgatherOptions& opts = AllgatherOptions()) override { - TORCH_CHECK(false, "ProcessGroupXCCL::_allgather_base not implemented"); - } + const AllgatherOptions& opts = AllgatherOptions()) override; c10::intrusive_ptr allgather_coalesced( std::vector>& outputTensorLists, @@ -175,40 +211,25 @@ class TORCH_API ProcessGroupXCCL : public Backend { c10::intrusive_ptr allgather_into_tensor_coalesced( std::vector& outputs, std::vector& inputs, - const AllgatherOptions& opts = AllgatherOptions()) override { - TORCH_CHECK( - false, - "ProcessGroupXCCL::allgather_into_tensor_coalesced not implemented"); - } + const AllgatherOptions& opts = AllgatherOptions()) override; c10::intrusive_ptr reduce_scatter( std::vector& outputTensors, std::vector>& inputTensors, - const ReduceScatterOptions& opts = ReduceScatterOptions()) override { - TORCH_CHECK(false, "ProcessGroupXCCL::reduce_scatter not implemented"); - } + const ReduceScatterOptions& opts = ReduceScatterOptions()) override; c10::intrusive_ptr _reduce_scatter_base( at::Tensor& outputTensor, at::Tensor& inputTensor, - const ReduceScatterOptions& opts = ReduceScatterOptions()) override { - TORCH_CHECK( - false, "ProcessGroupXCCL::_reduce_scatter_base not implemented"); - } + const ReduceScatterOptions& opts = ReduceScatterOptions()) override; c10::intrusive_ptr reduce_scatter_tensor_coalesced( std::vector& outputs, std::vector& inputs, - const ReduceScatterOptions& opts = ReduceScatterOptions()) override { - TORCH_CHECK( - false, - "ProcessGroupXCCL::reduce_scatter_tensor_coalesced not implemented"); - } + const ReduceScatterOptions& opts = ReduceScatterOptions()) override; c10::intrusive_ptr barrier( - const BarrierOptions& opts = BarrierOptions()) override { - TORCH_CHECK(false, "ProcessGroupXCCL::barrier not implemented"); - } + const BarrierOptions& opts = BarrierOptions()) override; c10::intrusive_ptr alltoall_base( at::Tensor& outputTensor, @@ -240,6 +261,10 @@ class TORCH_API ProcessGroupXCCL : public Backend { TORCH_CHECK(false, "ProcessGroupXCCL::recv not implemented"); } + void groupStart(); + + void groupEnd(); + c10::intrusive_ptr gather( std::vector>& outputTensors, std::vector& inputTensors, @@ -261,8 +286,12 @@ class TORCH_API ProcessGroupXCCL : 
public Backend { std::unordered_map<std::string, std::shared_ptr<xcclComm_t>> devXCCLCommMap_; c10::intrusive_ptr<Store> store_; std::mutex mutex_; + std::set<int> usedDeviceIdxs_; + int coalescing_state_ = 0; + at::Device coalescedDevice_ = at::Device("xpu"); + std::shared_ptr<xcclComm_t> coalescedComm_ = nullptr; bool blockingWait_ = false; - + static thread_local uint64_t xcclActiveGroupCounter_; private: XCCL_KVS kvs; std::mutex kvs_mutex;
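
For reviewers who want to exercise the new collectives outside of the unit tests, below is a minimal smoke-test sketch. It only uses operations implemented in this patch (allreduce, broadcast, barrier). The file name, launch command, and environment handling are illustrative rather than part of the change, and the sketch assumes a PyTorch build with XPU devices and the "xccl" backend registered.

# xccl_smoke.py -- illustrative only, not part of this patch.
# Assumes a PyTorch build with XPU devices and the "xccl" backend available.
# Launch with: torchrun --standalone --nproc-per-node 2 xccl_smoke.py
import torch
import torch.distributed as dist


def main():
    # torchrun provides RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT via the environment.
    dist.init_process_group(backend="xccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.xpu.set_device(rank % torch.xpu.device_count())

    # all_reduce: each rank contributes rank + 1, so the SUM is 1 + 2 + ... + world_size.
    t = torch.tensor([rank + 1.0], device="xpu")
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    assert t.item() == world_size * (world_size + 1) / 2

    # broadcast from rank 0: every rank ends up with rank 0's value.
    b = torch.tensor([float(rank)], device="xpu")
    dist.broadcast(b, src=0)
    assert b.item() == 0.0

    # barrier: dispatches through the XPU tensor path added in ProcessGroup.hpp above.
    dist.barrier()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()

The assertions mirror the expectations in test_allreduce_ops, test_broadcast_ops, and test_barrier above, so a passing run gives a quick end-to-end sanity check of the same code paths.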