
Commit c679690

enable TP

1 parent fd57737 · commit c679690

20 files changed with 620 additions and 608 deletions.
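The diffs below apply one mechanical substitution across the FSDP test suite: `torch.cuda` calls, CUDA streams/events, and `device="cuda"` strings become their `torch.xpu` counterparts, and CUDA-only skip conditions switch to `TEST_XPU`. For orientation, here is a minimal device-agnostic sketch of the same idea; it is not part of this commit, and the helper name `_get_device_type` and the `DEVICE_MODULE` constant are illustrative only.

import torch


def _get_device_type() -> str:
    # Prefer XPU when available (the backend these tests target), then CUDA, then CPU.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


DEVICE_TYPE = _get_device_type()
DEVICE_MODULE = getattr(torch, DEVICE_TYPE)  # torch.xpu, torch.cuda, or torch.cpu

# A world_size property could then read:
#     return min(2, DEVICE_MODULE.device_count())
# instead of hard-coding torch.xpu.device_count().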

test/distributed/_composable/fsdp/test_fully_shard_init.py

Lines changed: 80 additions & 80 deletions (large diff not rendered by default)

test/distributed/_composable/fsdp/test_fully_shard_logging.py

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@
 from torch._dynamo.test_case import run_tests
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.inductor_utils import HAS_CUDA
+from torch.testing._internal.common_utils import TEST_XPU
 from torch.testing._internal.logging_utils import LoggingTestCase

test/distributed/_composable/fsdp/test_fully_shard_memory.py

Lines changed: 9 additions & 9 deletions
@@ -18,7 +18,7 @@
 class TestFullyShardMemory(FSDPTest):
     @property
     def world_size(self) -> int:
-        return min(2, torch.cuda.device_count())
+        return min(2, torch.xpu.device_count())
 
     @skip_if_lt_x_gpu(2)
     def test_fully_shard_training_memory(self):
@@ -56,10 +56,10 @@ def _test_fully_shard_training_memory(
         # Pre-run a linear forward (gemm and bias) and backward (gemm) to
         # allocate the cuBLAS workspaces before measuring the memory usage
         # since the workspace size can differ between hardwares
-        lin = torch.nn.Linear(768, 768, device="cuda")
-        inp = torch.randn(1, 768, device="cuda")
+        lin = torch.nn.Linear(768, 768, device="xpu")
+        inp = torch.randn(1, 768, device="xpu")
         lin(inp).sum().backward()
-        torch.cuda.empty_cache()
+        torch.xpu.empty_cache()
         base_mem_mb = self._get_peak_active_memory_mb()
         vocab_size = 32
         model_args = ModelArgs(
@@ -108,7 +108,7 @@ def _test_fully_shard_training_memory(
         self.assertLessEqual(curr_mem_mb - base_mem_mb, init_mem_mb)
 
         # Use a small input to minimize activation memory usage
-        inp = torch.randint(0, vocab_size, (1, 4), device="cuda")
+        inp = torch.randint(0, vocab_size, (1, 4), device="xpu")
 
         # Forward:
         loss = model(inp)
@@ -166,7 +166,7 @@ def _test_fully_shard_training_memory(
         ) * 4 / 1e6 + buffer_mb
         self.assertLessEqual(mem_mb - base_mem_mb, expected_mem_mb)
         del loss
-        torch.cuda.reset_peak_memory_stats()
+        torch.xpu.reset_peak_memory_stats()
 
         # Optimizer step: unsharded parameters/gradients freed
         if not run_optim_in_backward:
@@ -184,7 +184,7 @@ def _test_fully_shard_training_memory(
         # Zero grad: sharded gradients freed
         if not run_optim_in_backward:
             optim.zero_grad()
-        torch.cuda.reset_peak_memory_stats()  # reset after freeing
+        torch.xpu.reset_peak_memory_stats()  # reset after freeing
         mem_mb = self._get_peak_active_memory_mb()
         expected_mem_mb = 0
         if not use_cpu_offload:
@@ -225,11 +225,11 @@ def test_fully_shard_del_memory(self):
         self.assertEqual(mem_mb, base_mem_mb)
 
     def _get_peak_active_memory_mb(self) -> int:
-        mem_stats = torch.cuda.memory_stats()
+        mem_stats = torch.xpu.memory_stats()
         return round(mem_stats["active_bytes.all.peak"] / 1e6)
 
     def _get_curr_active_memory_mb(self) -> int:
-        mem_stats = torch.cuda.memory_stats()
+        mem_stats = torch.xpu.memory_stats()
         return round(mem_stats["active_bytes.all.current"] / 1e6)
 
     def _register_optim_in_backward(
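The memory assertions above all funnel through the allocator statistics API, which the XPU backend exposes under the same names. A standalone sketch of that measurement recipe, assuming an XPU build of PyTorch (the workload in the middle is a placeholder):

import torch

# Warm up a linear layer so one-time GEMM workspace allocations do not skew the numbers.
lin = torch.nn.Linear(768, 768, device="xpu")
inp = torch.randn(1, 768, device="xpu")
lin(inp).sum().backward()
torch.xpu.empty_cache()
torch.xpu.reset_peak_memory_stats()

# ... run the workload being measured ...

stats = torch.xpu.memory_stats()
peak_mb = round(stats["active_bytes.all.peak"] / 1e6)
curr_mb = round(stats["active_bytes.all.current"] / 1e6)
print(f"peak active: {peak_mb} MB, current active: {curr_mb} MB")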

test/distributed/_composable/fsdp/test_fully_shard_mixed_precision.py

Lines changed: 15 additions & 15 deletions
@@ -32,7 +32,7 @@
 class TestFullyShardMixedPrecisionTraining(FSDPTest):
     @property
     def world_size(self) -> int:
-        return min(4, torch.cuda.device_count())
+        return min(4, torch.xpu.device_count())
 
     def _init_models_and_optims(
         self,
@@ -43,7 +43,7 @@ def _init_models_and_optims(
     ):
         torch.manual_seed(42)
         model = nn.Sequential(*[MLP(16, torch.device("cpu")) for _ in range(3)])
-        ref_model = copy.deepcopy(model).cuda()
+        ref_model = copy.deepcopy(model).xpu()
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
 
         def _shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:
@@ -122,7 +122,7 @@ def assert_fn(output: torch.Tensor):
         )
 
         torch.manual_seed(42 + self.rank + 1)
-        inp = torch.randn((4, 16), device="cuda", dtype=param_dtype)
+        inp = torch.randn((4, 16), device="xpu", dtype=param_dtype)
         for iter_idx in range(10):
             optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
             fsdp_loss = model(inp).sum()
@@ -207,7 +207,7 @@ def assert_fn(output: torch.Tensor):
             reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
         )
         torch.manual_seed(42 + self.rank + 1)
-        inp = torch.randn((4, 16), device="cuda", dtype=param_dtype)
+        inp = torch.randn((4, 16), device="xpu", dtype=param_dtype)
         for iter_idx in range(10):
             optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
             fsdp_loss = model(inp).sum()
@@ -256,7 +256,7 @@ def assert_fn(output: torch.Tensor):
             reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
         )
         torch.manual_seed(42 + self.rank + 1)
-        inp = torch.randn((4, 16), device="cuda", dtype=param_dtype)
+        inp = torch.randn((4, 16), device="xpu", dtype=param_dtype)
         for iter_idx in range(10):
             optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
             fsdp_loss = model(inp).sum()
@@ -307,7 +307,7 @@ def _test_grad_acc_with_reduce_dtype(self, reshard_after_forward: bool):
         # To emulate the mixed precision implementation where forward/backward
         # compute use bf16 and optimizer uses fp32, we maintain both an fp32
         # and a bf16 copy of the reference model
-        ref_model = copy.deepcopy(model).cuda()
+        ref_model = copy.deepcopy(model).xpu()
         ref_model_compute = copy.deepcopy(ref_model).to(param_dtype)
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
         for mlp in model:
@@ -327,7 +327,7 @@ def assert_fn(output: torch.Tensor):
             reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
         )
         torch.manual_seed(42 + self.rank + 1)
-        device = torch.device("cuda")
+        device = torch.device("xpu")
         # Train on the same input to avoid loss explosion
         num_microbatches = 4
         inp = torch.randn((2 * num_microbatches, 16), device=device, dtype=param_dtype)
@@ -387,15 +387,15 @@ def world_size(self) -> int:
 
     @skip_if_lt_x_gpu(1)
     def test_float16_on_one_submodule(self):
-        x = torch.zeros(2, 100, device="cuda")
+        x = torch.zeros(2, 100, device="xpu")
 
         # Subtest 1: use fp16 on the second child submodule -- does not require
         # any additional casting logic
         forward_inputs: dict[str, nn.Module] = {}
         model = SaveForwardInputsModel(
             forward_inputs,
             cast_forward_inputs=False,
-        ).cuda()
+        ).xpu()
         fully_shard(model.c2, mp_policy=MixedPrecisionPolicy(param_dtype=torch.float16))
         fully_shard(model)
         model(x).sum().backward()
@@ -408,7 +408,7 @@ def test_float16_on_one_submodule(self):
         forward_inputs: dict[nn.Module, torch.Tensor] = {}
         model = SaveForwardInputsModel(
             forward_inputs=forward_inputs, cast_forward_inputs=True
-        ).cuda()
+        ).xpu()
         fully_shard(
             model.c2,
             mp_policy=MixedPrecisionPolicy(
@@ -426,7 +426,7 @@ def test_float16_on_one_submodule(self):
         forward_inputs: dict[nn.Module, torch.Tensor] = {}
         model = SaveForwardInputsModel(
             forward_inputs=forward_inputs, cast_forward_inputs=False
-        ).cuda()
+        ).xpu()
         fully_shard(
             model.c1,
             mp_policy=MixedPrecisionPolicy(
@@ -468,13 +468,13 @@ def __init__(self, forward_inputs: dict[str, torch.Tensor]) -> None:
             def forward(self, x: torch.Tensor) -> torch.Tensor:
                 self.forward_inputs["model_input_x"] = x
                 y = torch.ones(
-                    2, 100, device="cuda", dtype=torch.float32
+                    2, 100, device="xpu", dtype=torch.float32
                 )  # external input
                 return self.l2(self.l1(x), y)
 
         forward_inputs: dict[str, torch.Tensor] = {}
-        model = ToyModel(forward_inputs).cuda()
-        x = torch.zeros(2, 100, device="cuda", dtype=torch.float32)
+        model = ToyModel(forward_inputs).xpu()
+        x = torch.zeros(2, 100, device="xpu", dtype=torch.float32)
         fully_shard(
             model.l2,
             mp_policy=MixedPrecisionPolicy(
@@ -577,7 +577,7 @@ def assert_fn(output: torch.Tensor):
             reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
         )
         with patch_reduce_scatter(reduce_scatter):
-            inp = torch.randn((4, 32), device="cuda")
+            inp = torch.randn((4, 32), device="xpu")
             loss = model(inp).sum()
             loss.backward()
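These hunks change only the device the mixed-precision tests run on; the policy wiring is untouched. A compressed sketch of that wiring on XPU, assuming a single-rank process group can be created (the model, dimensions, and the `xccl` backend choice are illustrative assumptions, not taken from the test):

import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

# Single-rank setup so the sketch is self-contained; the real tests use FSDPTest.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("xccl", rank=0, world_size=1)  # backend assumed for XPU
torch.xpu.set_device(0)

model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16)).to("xpu")
# Compute the second block in fp16 while keeping the sharded parameters in fp32.
fully_shard(model[1], mp_policy=MixedPrecisionPolicy(param_dtype=torch.float16))
fully_shard(model)

inp = torch.randn(4, 16, device="xpu")
model(inp).sum().backward()  # model[1] gathers/casts to fp16 for forward/backward
dist.destroy_process_group()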

test/distributed/_composable/fsdp/test_fully_shard_overlap.py

Lines changed: 18 additions & 18 deletions
@@ -35,7 +35,7 @@ class TestFullyShardOverlap(FSDPTest):
 
     @property
     def world_size(self) -> int:
-        return min(2, torch.cuda.device_count())
+        return min(2, torch.xpu.device_count())
 
     @skip_if_lt_x_gpu(2)
     def test_fully_shard_training_overlap(self):
@@ -46,23 +46,23 @@ def test_fully_shard_training_overlap(self):
         model = nn.Sequential(
             *[LinearWithSleep(dim, compute_sleep_ms) for _ in range(num_linears)]
         )
-        ref_model = copy.deepcopy(model).cuda()
+        ref_model = copy.deepcopy(model).xpu()
         for lin in model:
             assert len(list(lin.parameters())) == 1, "Expects only one weight"
             fully_shard(lin, reshard_after_forward=True)
         fully_shard(model, reshard_after_forward=True)
 
         orig_all_gather_into_tensor = dist.all_gather_into_tensor
         orig_reduce_scatter_tensor = dist.reduce_scatter_tensor
-        comm_stream = torch.cuda.Stream()
+        comm_stream = torch.xpu.Stream()
 
         def delay_collective():
             # Share a stream so that all-gather and reduce-scatter block each
             # other like in `ProcessGroupNCCL`
-            comm_stream.wait_stream(torch.cuda.current_stream())
-            with torch.cuda.stream(comm_stream):
-                torch.cuda._sleep(int(comm_sleep_ms * get_cycles_per_ms()))
-            torch.cuda.current_stream().wait_stream(comm_stream)
+            comm_stream.wait_stream(torch.xpu.current_stream())
+            with torch.xpu.stream(comm_stream):
+                torch.xpu._sleep(int(comm_sleep_ms * get_cycles_per_ms()))
+            torch.xpu.current_stream().wait_stream(comm_stream)
 
         def delayed_all_gather(*args, **kwargs):
             delay_collective()
@@ -72,7 +72,7 @@ def delayed_reduce_scatter(*args, **kwargs):
             delay_collective()
             return orig_reduce_scatter_tensor(*args, **kwargs)
 
-        inp = torch.randn((2, dim), device="cuda")
+        inp = torch.randn((2, dim), device="xpu")
         loss = model(inp).sum()  # warmup CUDA and allocator
         loss.backward()
 
@@ -153,17 +153,17 @@ def test_fully_shard_post_optim_event_overlap(self):
         # low-compute linear, where only the low-compute linear uses FSDP
         model = nn.Sequential(
             LinearWithSleep(dim, compute_sleep_ms), nn.Linear(dim, dim)
-        ).cuda()
+        ).xpu()
         fully_shard(model[1], reshard_after_forward=False)
         optim = torch.optim.AdamW(model.parameters(), lr=1e-2)
 
         orig_all_gather_into_tensor = dist.all_gather_into_tensor
 
         def delayed_all_gather(*args, **kwargs):
-            torch.cuda._sleep(int(comm_sleep_ms * get_cycles_per_ms()))
+            torch.xpu._sleep(int(comm_sleep_ms * get_cycles_per_ms()))
             return orig_all_gather_into_tensor(*args, **kwargs)
 
-        inp = torch.randn((2, dim), device="cuda")
+        inp = torch.randn((2, dim), device="xpu")
 
         def run_train_steps(num_iters: int, use_post_optim_event: bool):
             for _ in range(num_iters):
@@ -174,7 +174,7 @@ def run_train_steps(num_iters: int, use_post_optim_event: bool):
                 with implicit_replication():
                     optim.step()
                 if use_post_optim_event:
-                    post_optim_event = torch.cuda.current_stream().record_event()
+                    post_optim_event = torch.xpu.current_stream().record_event()
                     model[1].set_post_optim_event(post_optim_event)
 
         run_train_steps(1, False)  # warmup CUDA and allocator
@@ -205,14 +205,14 @@ def run_train_steps(num_iters: int, use_post_optim_event: bool):
         self.assertGreater(baseline_time, test_time)
 
     def _time_fn(self, fn: Callable):
-        start_event = torch.cuda.Event(enable_timing=True)
-        end_event = torch.cuda.Event(enable_timing=True)
+        start_event = torch.xpu.Event(enable_timing=True)
+        end_event = torch.xpu.Event(enable_timing=True)
         dist.barrier()
-        torch.cuda.synchronize()
+        torch.xpu.synchronize()
         start_event.record()
         fn()
         end_event.record()
-        torch.cuda.synchronize()
+        torch.xpu.synchronize()
         elapsed_time = start_event.elapsed_time(end_event)
         return elapsed_time
 
@@ -223,13 +223,13 @@ class Matmul(torch.autograd.Function):
     def forward(ctx, input: torch.Tensor, weight: torch.Tensor, sleep_ms: int):
         ctx.save_for_backward(input, weight)
         ctx.sleep_ms = sleep_ms
-        torch.cuda._sleep(int(sleep_ms * get_cycles_per_ms()))
+        torch.xpu._sleep(int(sleep_ms * get_cycles_per_ms()))
         return input @ weight
 
     @staticmethod
     def backward(ctx, grad_output: torch.Tensor):
         (input, weight) = ctx.saved_tensors
-        torch.cuda._sleep(int(2 * ctx.sleep_ms * get_cycles_per_ms()))
+        torch.xpu._sleep(int(2 * ctx.sleep_ms * get_cycles_per_ms()))
         grad_input = grad_output @ weight.T
         grad_weight = input.T @ grad_output
         return grad_input, grad_weight, None
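The `_time_fn` and `_sleep` changes above keep the usual event-based timing recipe, only on XPU events. A single-process sketch of that recipe (no `dist.barrier`), assuming an XPU build of PyTorch; the matmul workload is a placeholder:

import torch

start_event = torch.xpu.Event(enable_timing=True)
end_event = torch.xpu.Event(enable_timing=True)

torch.xpu.synchronize()  # drain prior work so it is not counted
start_event.record()
x = torch.randn(4096, 4096, device="xpu")
y = (x @ x).sum()        # the work being timed
end_event.record()
torch.xpu.synchronize()  # wait until end_event has actually been recorded

print(f"elapsed: {start_event.elapsed_time(end_event):.3f} ms")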

test/distributed/_composable/fsdp/test_fully_shard_state.py

Lines changed: 6 additions & 6 deletions
@@ -7,15 +7,15 @@
 from torch.distributed.fsdp import FSDPModule, fully_shard
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_fsdp import FSDPTestMultiThread, MLP
-from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.common_utils import run_tests, TEST_XPU
 
 
 class TestFullyShardState(FSDPTestMultiThread):
     @property
     def world_size(self) -> int:
         return 1
 
-    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    @unittest.skipIf(not TEST_XPU, "no xpu")
     def test_fully_shard_state(self):
         """
         Tests the ability to get the state object from a fully sharded module.
@@ -31,7 +31,7 @@ def test_fully_shard_state(self):
         # Check that each `fully_shard` call constructs a distinct state object
         self.assertEqual(len(set(all_states)), num_mlps + 1)
 
-    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    @unittest.skipIf(not TEST_XPU, "no xpu")
     def test_fully_shard_reapply(self):
         model = MLP(8)
         fully_shard(model)
@@ -41,7 +41,7 @@ def test_fully_shard_reapply(self):
         ):
             fully_shard(model)
 
-    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    @unittest.skipIf(not TEST_XPU, "no xpu")
     def test_fully_shard_cls(self):
         # Check that we only swap class for the module passed to `fully_shard`
         model = MLP(8)
@@ -64,7 +64,7 @@ def test_fully_shard_cls(self):
         self.assertTrue(isinstance(sliced_model, nn.Sequential))
         self.assertFalse(isinstance(sliced_model, FSDPModule))
 
-    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    @unittest.skipIf(not TEST_XPU, "no xpu")
     def test_fully_shard_unsupported_module_cls(self):
         regex = (
             r"fully\_shard does not support containers that do not implement forward"
@@ -76,7 +76,7 @@ def test_fully_shard_unsupported_module_cls(self):
         with self.assertRaisesRegex(ValueError, regex):
             fully_shard(model)
 
-    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    @unittest.skipIf(not TEST_XPU, "no xpu")
     def test_fully_shard_deepcopy(self):
         model = MLP(8)
         fully_shard(model)
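The skip decorators above gate each test on a single backend flag. A hedged sketch of the equivalent guard written to accept either backend (the test class name and the combined condition are illustrative, not part of this commit):

import unittest

import torch
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import TEST_XPU, run_tests


class ExampleGuardTest(unittest.TestCase):
    @unittest.skipIf(not (TEST_CUDA or TEST_XPU), "no accelerator")
    def test_device_count(self):
        # Pick whichever backend flag is set; both modules expose device_count().
        device_module = torch.xpu if TEST_XPU else torch.cuda
        self.assertGreaterEqual(device_module.device_count(), 1)


if __name__ == "__main__":
    run_tests()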
