From 08e657ee9744dfef970f0ea9d9039480d4dafa5a Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 1 Oct 2025 15:27:56 +0000 Subject: [PATCH] Add ATen bucketing pass as well --- autoparallel/auto_bucketing.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/autoparallel/auto_bucketing.py b/autoparallel/auto_bucketing.py index f36d88be..b9fd5329 100644 --- a/autoparallel/auto_bucketing.py +++ b/autoparallel/auto_bucketing.py @@ -4,6 +4,9 @@ # LICENSE file in the root directory of this source tree. import torch +from torch._inductor.fx_passes.overlap_preserving_bucketer import ( + OverlapPreservingBucketer, +) from torch._inductor.fx_passes.overlap_scheduling import OverlapScheduler from .autobucketing_util import bucket_func, bucket_plan, bucket_utils, reorder @@ -101,14 +104,26 @@ class aten_autobucketing_config: max_in_flight_gb = 2.0 compute_overlap_multipler = 1.0 max_coll_distance = 100 + max_bucket_memory_gb = 1.0 def aten_autobucketing_reordering_pass( - gm: torch.fx.Graph, configs: "aten_autobucketing_config" + graph: torch.fx.Graph, configs: "aten_autobucketing_config" ) -> torch.fx.GraphModule: - return OverlapScheduler( - gm.owning_module, + overlap_scheduler = OverlapScheduler( + graph.owning_module, compute_overlap_multipler=configs.compute_overlap_multipler, max_in_flight_gb=configs.max_in_flight_gb, max_coll_distance=configs.max_coll_distance, - ).run() + ) + + gm = overlap_scheduler.run() + bucketer = OverlapPreservingBucketer( + graph=overlap_scheduler.graph, + collective_info=overlap_scheduler.collective_info, + node_ancestors=overlap_scheduler.node_ancestors, + scheduled=overlap_scheduler.scheduled, + max_bucket_memory_gb=configs.max_bucket_memory_gb, + ) + bucketer.bucket_collectives() + return gm