From 40d58190d7d7f7cce99a6c2526bef3e93bca901f Mon Sep 17 00:00:00 2001
From: Will Constable
Date: Wed, 23 Jul 2025 18:16:32 -0700
Subject: [PATCH] fix DTensor placements for table wise sharding (#3215)

Summary: TSIA

Reviewed By: XilunWu

Differential Revision: D78594015
---
 torchrec/distributed/sharding/tw_sharding.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/torchrec/distributed/sharding/tw_sharding.py b/torchrec/distributed/sharding/tw_sharding.py
index d4c1ca48a..e7c5a1687 100644
--- a/torchrec/distributed/sharding/tw_sharding.py
+++ b/torchrec/distributed/sharding/tw_sharding.py
@@ -135,12 +135,9 @@ def _shard(
             dtensor_metadata = None
             if self._env.output_dtensor:
                 dtensor_metadata = DTensorMetadata(
-                    mesh=(
-                        self._env.device_mesh["replicate"]  # pyre-ignore[16]
-                        if self._is_2D_parallel
-                        else self._env.device_mesh
-                    ),
-                    placements=(Replicate(),),
+                    mesh=self._env.device_mesh,
+                    placements=(Replicate(),)
+                    * (self._env.device_mesh.ndim),  # pyre-ignore[16]
                     size=(
                         info.embedding_config.num_embeddings,
                         info.embedding_config.embedding_dim,
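
For context: DTensor validates that the placements tuple has exactly one entry
per device-mesh dimension, which is the invariant the patched expression
`(Replicate(),) * self._env.device_mesh.ndim` restores when the full (possibly
2D) mesh is passed. A minimal standalone sketch of that invariant follows; it
is not part of the patch, the names `mesh` and `table` are illustrative, and
it assumes a 4-rank job launched via torchrun (the 2x2 mesh stands in for a
2D-parallel torchrec setup).

    # Sketch only: shows why the placements tuple must scale with mesh.ndim.
    import torch
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import Replicate, distribute_tensor

    dist.init_process_group("gloo")

    # Two named dimensions, mirroring a 2D-parallel mesh layout.
    mesh = init_device_mesh("cpu", (2, 2), mesh_dim_names=("replicate", "shard"))

    table = torch.randn(8, 4)  # stand-in for an embedding table

    # The patched expression: one Replicate() placement per mesh dimension.
    placements = (Replicate(),) * mesh.ndim  # -> (Replicate(), Replicate())
    dt = distribute_tensor(table, mesh, placements)
    assert len(dt.placements) == mesh.ndim

    # A single-entry (Replicate(),) paired with the full 2D mesh would trip
    # DTensor's placement-count validation instead.
    dist.destroy_process_group()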