diff --git a/torchrec/distributed/mc_modules.py b/torchrec/distributed/mc_modules.py index 2634c3983..54b54d5a1 100644 --- a/torchrec/distributed/mc_modules.py +++ b/torchrec/distributed/mc_modules.py @@ -827,6 +827,7 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]: "_output_segments_tensor", "_current_iter_tensor", "_scalar_logger._scalar_logger_steps", + "_hash_zch_bucket", ]: continue if name in module._non_persistent_buffers_set: diff --git a/torchrec/modules/hash_mc_modules.py b/torchrec/modules/hash_mc_modules.py index 40ebfd3fe..e3b8f5447 100644 --- a/torchrec/modules/hash_mc_modules.py +++ b/torchrec/modules/hash_mc_modules.py @@ -305,6 +305,7 @@ def __init__( self._max_probe = max_probe self._buckets = total_num_buckets + self.register_buffer("_hash_zch_bucket", torch.tensor(total_num_buckets)) # Do not need to store in buffer since this is created and consumed # at each step https://fburl.com/code/axzimmbx self._evicted_indices = []