25 changes: 24 additions & 1 deletion torchrec/distributed/dist_data.py
@@ -866,23 +866,46 @@ class VariableBatchPooledEmbeddingsAllToAll(nn.Module):
Example::

kjt_split = [1, 2]

# kjt_split gives the number of features owned by each rank: here rank 0 owns f0
# (producing t0) and rank 1 owns f1 and f2 (producing t1).

emb_dim_per_rank_per_feature = [[2], [3, 3]]
a2a = VariableBatchPooledEmbeddingsAllToAll(
pg, emb_dim_per_rank_per_feature, device
)

t0 = torch.rand(6) # 2 * (2 + 1)
t1 = torch.rand(24) # 3 * (1 + 3) + 3 * (2 + 2)
t1 = torch.rand(24) # 3 * (1 + 2) + 3 * (3 + 2)

# t0 and t1 are the flattened send buffers of pooled embedding outputs produced on the
# ranks that own the features. Each buffer's length is embedding_dim * (sum of variable
# batch sizes for that feature across all destination ranks), summed over the features
# owned by the sending rank.

# r0_batch_size r1_batch_size
# f_0: 2 1
# -----------------------------------------
# f_1: 1 2
# f_2: 3 2

# batch_size_per_rank_per_feature is specified from the perspective of the sending rank:
# outer index = destination rank, inner list = features owned by the sending rank
# (in emb_dim_per_rank_per_feature order)

r0_batch_size_per_rank_per_feature = [[2], [1]]
r1_batch_size_per_rank_per_feature = [[1, 3], [2, 2]]

# r0 wants r1 wants
# f0: 2 1
# f1: 1 2
# f2: 3 2
# which gives the batch_size_per_feature_pre_a2a vector for each receiving rank

r0_batch_size_per_feature_pre_a2a = [2, 1, 3]
r1_batch_size_per_feature_pre_a2a = [1, 2, 2]

# r0 should receive f0: 2 (from r0), f1: 1 (from r1), f2: 3 (from r1)
# r1 should receive f0: 1 (from r0), f1: 2 (from r1), f2: 2 (from r1)

rank0_output = a2a(
t0, r0_batch_size_per_rank_per_feature, r0_batch_size_per_feature_pre_a2a
).wait()
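To sanity-check the arithmetic in the docstring example above, here is a small standalone sketch (plain Python, not part of the PR; no process group or torch needed). It re-derives the send-buffer lengths of t0 and t1, and, under the assumption that each rank's received buffer is the flat concatenation of emb_dim * batch_size over all features, the lengths each rank should receive.

# Standalone sanity check for the docstring example above (not part of the diff).
emb_dim_per_rank_per_feature = [[2], [3, 3]]  # f0 on rank 0; f1, f2 on rank 1

# From each sending rank's perspective: outer index = destination rank,
# inner list = features owned by the sender.
r0_batch_size_per_rank_per_feature = [[2], [1]]
r1_batch_size_per_rank_per_feature = [[1, 3], [2, 2]]


def send_buffer_len(emb_dims, batch_size_per_rank_per_feature):
    # Sum over owned features of emb_dim * (total batch size across destination ranks).
    return sum(
        dim * sum(per_rank[f] for per_rank in batch_size_per_rank_per_feature)
        for f, dim in enumerate(emb_dims)
    )


assert send_buffer_len([2], r0_batch_size_per_rank_per_feature) == 6    # 2 * (2 + 1)
assert send_buffer_len([3, 3], r1_batch_size_per_rank_per_feature) == 24  # 3 * (1 + 2) + 3 * (3 + 2)

# Received lengths per rank, derived from the *_pre_a2a vectors (assuming the received
# buffer is the flat concatenation of emb_dim * batch_size over f0, f1, f2).
all_emb_dims = [2, 3, 3]
r0_batch_size_per_feature_pre_a2a = [2, 1, 3]
r1_batch_size_per_feature_pre_a2a = [1, 2, 2]
r0_recv = sum(d * b for d, b in zip(all_emb_dims, r0_batch_size_per_feature_pre_a2a))  # 16
r1_recv = sum(d * b for d, b in zip(all_emb_dims, r1_batch_size_per_feature_pre_a2a))  # 14
assert r0_recv + r1_recv == 6 + 24  # everything sent is received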
27 changes: 14 additions & 13 deletions torchrec/distributed/sharding/tw_sharding.py
@@ -345,9 +345,8 @@ def __init__(
else None
)
self._emb_dim_per_rank_per_feature = emb_dim_per_rank_per_feature
self._dist: Optional[
Union[PooledEmbeddingsAllToAll, VariableBatchPooledEmbeddingsAllToAll]
] = None
self._dist: Optional[PooledEmbeddingsAllToAll] = None
self._variable_dist: Optional[VariableBatchPooledEmbeddingsAllToAll] = None

def forward(
self,
@@ -371,7 +370,10 @@ def forward(
if sharding_ctx is None:
return cast(PooledEmbeddingsAllToAll, self._dist)(local_embs)
elif sharding_ctx.variable_batch_per_feature:
return cast(VariableBatchPooledEmbeddingsAllToAll, self._dist)(
assert (
self._variable_dist is not None
), "variable batch dist is not initialized!"
return self._variable_dist(
local_embs,
batch_size_per_rank_per_feature=sharding_ctx.batch_size_per_rank_per_feature,
batch_size_per_feature_pre_a2a=sharding_ctx.batch_size_per_feature_pre_a2a,
@@ -386,21 +388,20 @@ def _create_output_dist_module(
self, sharding_ctx: Optional[EmbeddingShardingContext] = None
) -> None:
if sharding_ctx is not None and sharding_ctx.variable_batch_per_feature:
self._dist = VariableBatchPooledEmbeddingsAllToAll(
self._variable_dist = VariableBatchPooledEmbeddingsAllToAll(
pg=self._pg,
emb_dim_per_rank_per_feature=self._emb_dim_per_rank_per_feature,
device=self._device,
callbacks=None,
codecs=self._codecs,
)
else:
self._dist = PooledEmbeddingsAllToAll(
pg=self._pg,
dim_sum_per_rank=self._dim_sum_per_rank,
device=self._device,
callbacks=self._callbacks,
codecs=self._codecs,
)
self._dist = PooledEmbeddingsAllToAll(
pg=self._pg,
dim_sum_per_rank=self._dim_sum_per_rank,
device=self._device,
callbacks=self._callbacks,
codecs=self._codecs,
)


class TwPooledEmbeddingSharding(
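A note on the pattern tw_sharding.py adopts above: instead of one Union-typed self._dist plus cast() at each call site, the fixed-batch and variable-batch all-to-all modules now live in separately typed attributes, and forward dispatches on the sharding context. Below is a minimal, self-contained illustration of that shape using hypothetical stand-in classes (the real modules take a process group, embedding dims, codecs, etc.); it is a sketch of the dispatch structure, not the actual TorchRec code.

# Minimal illustration of the dispatch pattern; _FixedBatchA2A / _VariableBatchA2A are
# hypothetical stand-ins, not the real PooledEmbeddingsAllToAll modules.
from typing import Any, Optional


class _FixedBatchA2A:
    def __call__(self, embs: Any) -> str:
        return f"fixed-batch a2a({embs})"


class _VariableBatchA2A:
    def __call__(self, embs: Any, **kwargs: Any) -> str:
        return f"variable-batch a2a({embs})"


class OutputDistSketch:
    def __init__(self) -> None:
        # One attribute per concrete type: no Union, no cast() at the call site.
        self._dist: Optional[_FixedBatchA2A] = None
        self._variable_dist: Optional[_VariableBatchA2A] = None

    def create(self, variable_batch_per_feature: bool) -> None:
        # Mirrors the diff: the variable-batch module is built only when requested,
        # while the fixed-batch module is always built.
        if variable_batch_per_feature:
            self._variable_dist = _VariableBatchA2A()
        self._dist = _FixedBatchA2A()

    def forward(self, embs: Any, variable_batch_per_feature: bool) -> str:
        if variable_batch_per_feature:
            assert self._variable_dist is not None, "variable batch dist is not initialized!"
            return self._variable_dist(embs, batch_size_per_rank_per_feature=[[1]])
        assert self._dist is not None
        return self._dist(embs)

With one attribute per concrete type, a type checker can verify each call site against that type directly, which is what dropping the Union and the casts buys.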
49 changes: 23 additions & 26 deletions torchrec/distributed/sharding/twrw_sharding.py
@@ -472,18 +472,14 @@ def __init__(
if qcomm_codecs_registry
else None
)
self._intra_dist: Optional[
Union[
PooledEmbeddingsReduceScatter,
VariableBatchPooledEmbeddingsReduceScatter,
]
] = None
self._cross_dist: Optional[
Union[
PooledEmbeddingsAllToAll,
VariableBatchPooledEmbeddingsAllToAll,
]
] = None
self._intra_dist: Optional[PooledEmbeddingsReduceScatter] = None
self._cross_dist: Optional[PooledEmbeddingsAllToAll] = None
self._variable_intra_dist: Optional[
VariableBatchPooledEmbeddingsReduceScatter
] = None
self._variable_cross_dist: Optional[VariableBatchPooledEmbeddingsAllToAll] = (
None
)

def forward(
self,
@@ -514,13 +510,15 @@ def forward(
sharding_ctx.batch_size_per_rank_per_feature,
)
rs_result = cast(
VariableBatchPooledEmbeddingsReduceScatter, self._intra_dist
VariableBatchPooledEmbeddingsReduceScatter, self._variable_intra_dist
)(
local_embs,
batch_size_per_rank_per_feature=batch_size_per_feature_sum_by_cross_group,
embedding_dims=self._emb_dim_per_node_per_feature[current_node],
).wait()
return cast(VariableBatchPooledEmbeddingsAllToAll, self._cross_dist)(
return cast(
VariableBatchPooledEmbeddingsAllToAll, self._variable_cross_dist
)(
rs_result,
batch_size_per_rank_per_feature=batch_size_per_rank_per_feature_by_cross_group[
local_rank
@@ -615,28 +613,27 @@ def _create_output_dist_modules(
self, sharding_ctx: Optional[EmbeddingShardingContext] = None
) -> None:
if sharding_ctx is not None and sharding_ctx.variable_batch_per_feature:
self._intra_dist = VariableBatchPooledEmbeddingsReduceScatter(
self._variable_intra_dist = VariableBatchPooledEmbeddingsReduceScatter(
pg=self._intra_pg,
codecs=self._intra_codecs,
)
self._cross_dist = VariableBatchPooledEmbeddingsAllToAll(
self._variable_cross_dist = VariableBatchPooledEmbeddingsAllToAll(
pg=self._cross_pg,
emb_dim_per_rank_per_feature=self._emb_dim_per_node_per_feature,
device=self._device,
callbacks=None, # don't pass permute callback, handle in LazyAwaitable
codecs=self._cross_codecs,
)
else:
self._intra_dist = PooledEmbeddingsReduceScatter(
pg=self._intra_pg,
codecs=self._intra_codecs,
)
self._cross_dist = PooledEmbeddingsAllToAll(
pg=self._cross_pg,
dim_sum_per_rank=self._dim_sum_per_node,
device=self._device,
codecs=self._cross_codecs,
)
self._intra_dist = PooledEmbeddingsReduceScatter(
pg=self._intra_pg,
codecs=self._intra_codecs,
)
self._cross_dist = PooledEmbeddingsAllToAll(
pg=self._cross_pg,
dim_sum_per_rank=self._dim_sum_per_node,
device=self._device,
codecs=self._cross_codecs,
)


class TwRwPooledEmbeddingSharding(
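The twrw_sharding.py change applies the same split to both stages of the table-wise/row-wise output path: an intra-node reduce-scatter followed by a cross-node all-to-all, each stage now holding a fixed-batch and a variable-batch attribute. The following is a rough sketch of the variable-batch flow in forward(), using mock awaitables and hypothetical stand-in classes rather than the real TorchRec modules; only the rs-then-a2a structure is mimicked.

# Hypothetical stand-ins; the real modules wrap process-group collectives and return
# Awaitables, which is the only behavior mimicked here.
from typing import Any, Optional


class _MockAwaitable:
    def __init__(self, value: Any) -> None:
        self._value = value

    def wait(self) -> Any:
        return self._value


class _MockVariableIntraRS:  # stands in for VariableBatchPooledEmbeddingsReduceScatter
    def __call__(self, embs: Any, **kwargs: Any) -> _MockAwaitable:
        return _MockAwaitable(f"all_to_all(reduce_scatter({embs}))" if False else f"reduce_scatter({embs})")


class _MockVariableCrossA2A:  # stands in for VariableBatchPooledEmbeddingsAllToAll
    def __call__(self, embs: Any, **kwargs: Any) -> _MockAwaitable:
        return _MockAwaitable(f"all_to_all({embs})")


class TwRwOutputDistSketch:
    def __init__(self) -> None:
        # Only the variable-batch pair is shown; the fixed-batch pair
        # (_intra_dist / _cross_dist) is analogous.
        self._variable_intra_dist: Optional[_MockVariableIntraRS] = None
        self._variable_cross_dist: Optional[_MockVariableCrossA2A] = None

    def create(self) -> None:
        self._variable_intra_dist = _MockVariableIntraRS()
        self._variable_cross_dist = _MockVariableCrossA2A()

    def forward(self, local_embs: Any) -> Any:
        assert self._variable_intra_dist is not None
        assert self._variable_cross_dist is not None
        # Stage 1: reduce-scatter pooled embeddings within the node.
        rs_result = self._variable_intra_dist(local_embs).wait()
        # Stage 2: all-to-all the node-local result across nodes.
        return self._variable_cross_dist(rs_result).wait()

Calling create() and then forward("embs") on this sketch returns "all_to_all(reduce_scatter(embs))", mirroring the order in which the real forward() chains the two collectives.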