Commit 03d6165

Update tutorial of device mesh to use fsdp2 (#3472)

1 parent: b78fc75
File tree

1 file changed: +3 additions, −3 deletions

recipes_source/distributed_device_mesh.rst

Lines changed: 3 additions & 3 deletions
@@ -121,7 +121,7 @@ users would not need to manually create and manage shard group and replicate group
     import torch.nn as nn

     from torch.distributed.device_mesh import init_device_mesh
-    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
+    from torch.distributed.fsdp import fully_shard as FSDP


     class ToyModel(nn.Module):
@@ -136,9 +136,9 @@ users would not need to manually create and manage shard group and replicate group


        # HSDP: MeshShape(2, 4)
-       mesh_2d = init_device_mesh("cuda", (2, 4))
+       mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp_replicate", "dp_shard"))
        model = FSDP(
-           ToyModel(), device_mesh=mesh_2d, sharding_strategy=ShardingStrategy.HYBRID_SHARD
+           ToyModel(), device_mesh=mesh_2d
        )

 Let's create a file named ``hsdp.py``.
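The substance of the diff is the move from FSDP1, where hybrid sharding was selected with ``ShardingStrategy.HYBRID_SHARD``, to FSDP2's ``fully_shard``, where the hybrid behavior follows from the named 2D mesh itself: the first dimension replicates, the second shards. As a torch-free sketch of what those names mean for the ``(2, 4)`` mesh in the diff (``mesh_coord`` is a hypothetical helper, not a PyTorch API, and it assumes the row-major rank layout that ``init_device_mesh`` uses by default):

```python
# Sketch: map a global rank to its coordinate along the named mesh
# dimensions ("dp_replicate", "dp_shard") for a (2, 4) mesh, assuming
# row-major ordering. mesh_coord is illustrative only, not a PyTorch API.

def mesh_coord(rank, mesh_shape=(2, 4)):
    """Return (dp_replicate index, dp_shard index) for a global rank."""
    replicate_size, shard_size = mesh_shape
    if not 0 <= rank < replicate_size * shard_size:
        raise ValueError(f"rank {rank} is outside mesh {mesh_shape}")
    return rank // shard_size, rank % shard_size

# Ranks 0-3 share dp_replicate=0 (one shard group of 4 GPUs);
# ranks 4-7 share dp_replicate=1 (the replica of that group).
for r in range(8):
    print(r, mesh_coord(r))
```

Under this layout, parameters are sharded among the 4 ranks of each ``dp_shard`` group and replicated across the 2 ``dp_replicate`` groups, which is what ``HYBRID_SHARD`` expressed as a flag in FSDP1.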
