Skip to content

Commit d6091c8

Browse files
laithsakka authored and pytorchmergebot committed
Add compile time instruction count metric (#133834)
Run with `PYTHONPATH=$(pwd) python benchmarks/update_hint_benchmark.py out`. As of this diff, compile_time_instruction_count counts the number of instructions executed within convert_frame.compile_inner: ``` update_hint_regression,compile_time_instruction_count,10522459165 ``` Results from CI will be added once populated. Pull Request resolved: #133834 Approved by: https://github.com/aorenste
1 parent ef0f591 commit d6091c8

File tree

6 files changed

+122
-20
lines changed

6 files changed

+122
-20
lines changed

benchmarks/dynamo/pr_time_benchmarks/benchmark_base.py

Lines changed: 64 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from fbscribelogger import make_scribe_logger
55

66
import torch._C._instruction_counter as i_counter
7+
import torch._dynamo.config as config
8+
from torch._dynamo.utils import CompileTimeInstructionCounter
79

810

911
scribe_log_torch_benchmark_compile_time = make_scribe_logger(
@@ -51,10 +53,19 @@
5153

5254

5355
class BenchmarkBase(ABC):
54-
_instruction_count = False
56+
# Measure the total number of instructions spent in _work.
57+
_enable_instruction_count = False
58+
59+
# Measure the total number of instructions spent in convert_frame.compile_inner.
60+
# TODO: are there other parts we need to add?
61+
_enable_compile_time_instruction_count = False
5562

5663
def enable_instruction_count(self):
57-
self._instruction_count = True
64+
self._enable_instruction_count = True
65+
return self
66+
67+
def enable_compile_time_instruction_count(self):
68+
self._enable_compile_time_instruction_count = True
5869
return self
5970

6071
def name(self):
@@ -64,29 +75,44 @@ def description(self):
6475
return ""
6576

6677
@abstractmethod
67-
def prepare(self):
78+
def _prepare(self):
6879
pass
6980

7081
@abstractmethod
71-
def work(self):
82+
def _work(self):
7283
pass
7384

74-
def prepare_once(self): # noqa: B027
85+
def _prepare_once(self): # noqa: B027
7586
pass
7687

77-
def count_instructions(self):
88+
def _count_instructions(self):
7889
print(f"collecting instruction count for {self.name()}")
79-
self.prepare_once()
80-
8190
results = []
8291
for i in range(10):
83-
self.prepare()
92+
self._prepare()
8493
id = i_counter.start()
85-
self.work()
94+
self._work()
8695
count = i_counter.end(id)
8796
print(f"instruction count for iteration {i} is {count}")
88-
if i != 0:
89-
results.append(count)
97+
results.append(count)
98+
return min(results)
99+
100+
def _count_compile_time_instructions(self):
    """Return the smallest compile-time instruction count seen over 10 runs.

    Each iteration calls ``self._prepare()``, clears
    ``CompileTimeInstructionCounter``, runs ``self._work()`` and reads the
    accumulated value back. Recording is switched on through
    ``config.record_compile_time_instruction_count`` for the duration of the
    measurement only.

    Returns:
        The minimum of the per-iteration counts.
    """
    print(f"collecting compile time instruction count for {self.name()}")

    # The flag is global state: remember what it was and put it back even if
    # _prepare()/_work() raises, so a failed benchmark cannot leave recording
    # enabled for whatever runs next in this process.
    previous = config.record_compile_time_instruction_count
    config.record_compile_time_instruction_count = True
    try:
        results = []
        for i in range(10):
            self._prepare()
            # CompileTimeInstructionCounter.record() wraps the _compile_inner
            # call in convert_frame.compile_inner, hence this only counts
            # instructions spent in compile_inner.
            CompileTimeInstructionCounter.clear()
            self._work()
            count = CompileTimeInstructionCounter.value()
            print(f"compile time instruction count for iteration {i} is {count}")
            results.append(count)
    finally:
        config.record_compile_time_instruction_count = previous

    return min(results)
91117

92118
def append_results(self, path):
@@ -102,12 +128,36 @@ def print(self):
102128
print(f"{entry[0]},{entry[1]},{entry[2]}")
103129

104130
def collect_all(self):
131+
self._prepare_once()
105132
self.results = []
106-
if self._instruction_count:
107-
r = self.count_instructions()
133+
if (
134+
self._enable_instruction_count
135+
and self._enable_compile_time_instruction_count
136+
):
137+
raise RuntimeError(
138+
"not supported until we update the logger, both logs to the same field now"
139+
)
140+
141+
if self._enable_instruction_count:
142+
r = self._count_instructions()
108143
self.results.append((self.name(), "instruction_count", r))
109144
scribe_log_torch_benchmark_compile_time(
110145
name=self.name(),
111146
instruction_count=r,
112147
)
148+
if self._enable_compile_time_instruction_count:
149+
r = self._count_compile_time_instructions()
150+
151+
self.results.append(
152+
(
153+
self.name(),
154+
"compile_time_instruction_count",
155+
r,
156+
)
157+
)
158+
# TODO add a new field compile_time_instruction_count to the logger.
159+
scribe_log_torch_benchmark_compile_time(
160+
name=self.name(),
161+
instruction_count=r,
162+
)
113163
return self

benchmarks/dynamo/pr_time_benchmarks/benchmarks/update_hint_benchmark.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,17 @@ def name(self):
1515
def description(self):
1616
return "information at https://github.com/pytorch/pytorch/pull/129893"
1717

18-
def prepare_once(self):
18+
def _prepare_once(self):
1919
torch._dynamo.config.capture_scalar_outputs = True
2020
random.seed(42)
2121
self.splits = torch.randint(10, (self.N,))
2222
sz = self.splits.sum().item()
2323
self.input = torch.randn(sz)
2424

25-
def prepare(self):
25+
def _prepare(self):
2626
torch._dynamo.reset()
2727

28-
def work(self):
28+
def _work(self):
2929
@torch.compile(fullgraph=True)
3030
def f(a, b):
3131
xs = b.tolist()
@@ -34,12 +34,15 @@ def f(a, b):
3434
torch._check(x <= self.N)
3535
return a.split(xs)
3636

37-
f(self.input, self.splits)
37+
for i in range(1000):
38+
f(self.input, self.splits)
3839

3940

4041
def main():
4142
result_path = sys.argv[1]
42-
Benchmark().enable_instruction_count().collect_all().append_results(result_path)
43+
Benchmark().enable_compile_time_instruction_count().collect_all().append_results(
44+
result_path
45+
)
4346

4447

4548
if __name__ == "__main__":

torch/_C/_instruction_counter.pyi

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Defined in torch/csrc/instruction_counter/Module.cpp
2+
3+
# Begin an instruction-count measurement; the returned int is the handle
# that callers later pass to end().
def start() -> int: ...
4+
# Finish the measurement identified by `id` (a value previously returned by
# start()) and return its instruction count.
def end(id: int) -> int: ...

torch/_dynamo/config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,10 @@ def _get_optimize_ddp_mode():
374374
# Inline inbuilt nn modules
375375
inline_inbuilt_nn_modules = not is_fbcode()
376376

377+
# When set, total compile time instruction count is recorded using
378+
# torch._dynamo.utils.CompileTimeInstructionCounter.
379+
record_compile_time_instruction_count = False
380+
377381

378382
def default_debug_dir_root():
379383
# [@compile_ignored: debug]

torch/_dynamo/convert_frame.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import torch._logging
2828
from torch._C._dynamo.guards import GlobalStateGuard
2929
from torch._dynamo.distributed import get_compile_pg
30+
from torch._dynamo.utils import CompileTimeInstructionCounter
3031
from torch._guards import compile_context, CompileContext, CompileId, tracing
3132
from torch._logging import structured
3233
from torch._utils_internal import (
@@ -652,7 +653,8 @@ def compile_inner(
652653
transform: Callable[[List[Instruction], Dict[str, Any]], Any],
653654
) -> Optional[GuardedCode]:
654655
with dynamo_timed("_compile.compile_inner", phase_name="entire_frame_compile"):
655-
return _compile_inner(code, one_graph, hooks, transform)
656+
with CompileTimeInstructionCounter.record():
657+
return _compile_inner(code, one_graph, hooks, transform)
656658

657659
@compile_time_strobelight_meta(phase_name="compile_inner")
658660
@maybe_cprofile

torch/_dynamo/utils.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
from torch import fx
6565
from torch._C import (
6666
_get_function_stack_at,
67+
_instruction_counter,
6768
_len_torch_function_stack,
6869
_pop_torch_function_stack,
6970
_push_on_torch_function_stack,
@@ -3203,3 +3204,41 @@ def get_user_object_from_id(obj_id):
32033204
def store_user_object_weakref(obj):
32043205
obj_id = id(obj)
32053206
user_obj_id_to_weakref[obj_id] = weakref.ref(obj)
3207+
3208+
3209+
class CompileTimeInstructionCounter:
3210+
_counter: int = 0
3211+
_id: int = -1
3212+
_depth = 0
3213+
3214+
@classmethod
3215+
def start(cls) -> None:
3216+
cls._depth = cls._depth + 1
3217+
if cls._depth == 1:
3218+
cls._id = _instruction_counter.start()
3219+
3220+
@classmethod
3221+
def end(cls) -> None:
3222+
cls._depth = cls._depth - 1
3223+
if cls._depth == 0:
3224+
cls._counter += _instruction_counter.end(cls._id)
3225+
cls._id = -1
3226+
3227+
@classmethod
3228+
def clear(cls) -> None:
3229+
cls._counter = 0
3230+
3231+
@classmethod
3232+
def value(cls) -> int:
3233+
return cls._counter
3234+
3235+
@classmethod
3236+
@contextmanager
3237+
def record(cls):
3238+
try:
3239+
if config.record_compile_time_instruction_count:
3240+
cls.start()
3241+
yield
3242+
finally:
3243+
if config.record_compile_time_instruction_count:
3244+
cls.end()

0 commit comments

Comments
 (0)