diff --git a/torchprime/sharding/shard_model.py b/torchprime/sharding/shard_model.py index bf2030e6..737c64a1 100644 --- a/torchprime/sharding/shard_model.py +++ b/torchprime/sharding/shard_model.py @@ -315,3 +315,13 @@ def __init__(self, mod, mark_sharding, spec): def forward(self, *args, **kwargs): return self.mark_sharding(self._orig_mod(*args, **kwargs), self.spec) + + @property + def module(self): + return self._orig_mod + + def __getattr__(self, name: str): + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + return getattr(self.module, name)