
Commit a6d6f62

faran928 authored and facebook-github-bot committed
Swap IntNBit TBE Kernel with SSD Embedding DB TBE Kernel for SSD Inference Enablement (#3134)
Summary:
Pull Request resolved: #3134

For SSD inference, we have added EmbeddingDB as a custom in-house storage that is not exposed to OSS. We leverage the TGIF stack to rewrite the IntNBit TBE kernel with the SSD EmbeddingDB TBE kernel, since the SSD TBE embedding kernel can't be exposed within the TorchRec code base.

Additionally, for SSD we only provide the di_sharding_pass, and SSD can be enabled without having additional DI shards. In that case, the tables assigned to the CPU host can simply be table-wise (TW) sharded, and the TW sharding logic was added accordingly. The diff also includes an option to manually enable TW sharding using the DI / Universal Sharding Pass logic, which can be used to override the automated TW universal sharding behavior in case that is less efficient (as it is still being improved).

Reviewed By: gyllstromk

Differential Revision: D76953960

fbshipit-source-id: b2cef7eb118dd5f94242b85c2a931a4d365d58d4
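
The kernel swap itself happens in the internal TGIF stack, but the general pattern is a module rewrite over the sharded model: walk the module tree, find each IntNBit TBE lookup, and splice in an SSD-backed replacement built from the same fused params. The sketch below is a minimal illustration of that pattern, not the internal pass; is_intnbit_tbe and build_ssd_tbe are hypothetical callables standing in for the non-OSS EmbeddingDB kernel factory.

from typing import Any, Callable, Dict, List, Tuple

from torch import nn


def swap_tbe_kernels(
    model: nn.Module,
    is_intnbit_tbe: Callable[[nn.Module], bool],
    build_ssd_tbe: Callable[[nn.Module, Dict[str, Any]], nn.Module],
) -> nn.Module:
    # fused_params kept on the sharded module by the one-line change below.
    fused_params: Dict[str, Any] = getattr(model, "_fused_params", None) or {}

    # Collect (parent, attribute name, child) triples first so the swap does
    # not mutate the tree while it is being traversed.
    to_swap: List[Tuple[nn.Module, str, nn.Module]] = []
    for parent in model.modules():
        for name, child in parent.named_children():
            if is_intnbit_tbe(child):
                to_swap.append((parent, name, child))

    # Replace each IntNBit TBE lookup with an SSD-backed equivalent.
    for parent, name, child in to_swap:
        setattr(parent, name, build_ssd_tbe(child, fused_params))
    return model

In the actual stack this rewrite runs as an internal publish-time pass, which is why the only OSS-visible change in this diff is keeping fused_params on the sharded module (see below).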
1 parent a580e35 commit a6d6f62


1 file changed: +1 -0 lines changed


torchrec/distributed/quant_embeddingbag.py

Lines changed: 1 addition & 0 deletions
@@ -224,6 +224,7 @@ def __init__(
         self._is_weighted: bool = module.is_weighted()
         self._lookups: List[nn.Module] = []
         self._create_lookups(fused_params, device)
+        self._fused_params = fused_params

         # Ensure output dist is set for post processing from an inference runtime (ie. setting device from runtime).
         self._output_dists: torch.nn.ModuleList = torch.nn.ModuleList()
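
For the table-wise fallback described in the summary, the internal di_sharding_pass decides placement; in OSS TorchRec a similar effect, pinning the tables left on the CPU host to table-wise sharding, can be approximated with planner constraints. A minimal sketch under that assumption (the table names are hypothetical, and this uses the public ParameterConstraints / ShardingType APIs rather than the internal pass):

from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.types import ShardingType

# Hypothetical names of tables that the SSD placement step leaves on the CPU host.
cpu_host_tables = ["table_user_history", "table_ad_features"]

# Constrain those tables to table-wise (TW) sharding; all remaining tables
# keep the planner's default sharding proposals.
constraints = {
    name: ParameterConstraints(sharding_types=[ShardingType.TABLE_WISE.value])
    for name in cpu_host_tables
}

# The constraints dict would then be passed to the sharding planner
# (e.g. EmbeddingShardingPlanner(..., constraints=constraints)) when the plan
# for SSD inference is generated.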
