Skip to content

Commit 204b93d

Browse files
XuehaiPanpobin6
authored andcommitted
[dynamo] support operator.methodcaller (pytorch#141137)
Pull Request resolved: pytorch#141137 Approved by: https://github.com/jansel ghstack dependencies: pytorch#141122
1 parent 5fe1ad4 commit 204b93d

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

test/dynamo/test_functions.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3109,6 +3109,25 @@ def fn(x, y):
31093109
y = torch.randn(3, 4)
31103110
self.assertEqual(opt_fn(x, y), fn(x, y))
31113111

3112+
def test_methodcaller(self):
3113+
for name, args, kwargs in (
3114+
("size", (), {}),
3115+
("size", (0,), {}),
3116+
("add", (torch.randn(3, 4),), {}),
3117+
("add", (torch.randn(3, 4),), {"alpha": 2.0}),
3118+
):
3119+
with self.subTest(name=name, args=args, kwargs=kwargs):
3120+
3121+
def fn(x, y):
3122+
caller = operator.methodcaller(name, *args, **kwargs)
3123+
return caller(x), caller(y)
3124+
3125+
opt_fn = torch.compile(fullgraph=True)(fn)
3126+
3127+
x = torch.randn(3, 4)
3128+
y = torch.randn(3, 4)
3129+
self.assertEqual(opt_fn(x, y), fn(x, y))
3130+
31123131
def gen_random_range_args(self):
31133132
args_count = random.randint(1, 3)
31143133
args = [random.randint(-10, 10) for _ in range(args_count)]

torch/_dynamo/polyfills/operator.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313

1414
# Most unary and binary operators are handled by BuiltinVariable (e.g., `pos`, `add`)
15-
__all__ = ["attrgetter", "itemgetter"]
15+
__all__ = ["attrgetter", "itemgetter", "methodcaller"]
1616

1717

1818
_T = TypeVar("_T")
@@ -95,3 +95,15 @@ def getter(obj: Any) -> tuple[Any, ...]: # type: ignore[misc]
9595
return tuple(obj[item] for item in items)
9696

9797
return getter
98+
99+
100+
# Reference: https://docs.python.org/3/library/operator.html#operator.methodcaller
101+
@substitute_in_graph(operator.methodcaller, is_embedded_type=True) # type: ignore[arg-type]
102+
def methodcaller(name: str, /, *args: Any, **kwargs: Any) -> Callable[[Any], Any]:
103+
if not isinstance(name, str):
104+
raise TypeError("method name must be a string")
105+
106+
def caller(obj: Any) -> Any:
107+
return getattr(obj, name)(*args, **kwargs)
108+
109+
return caller

0 commit comments

Comments
 (0)