Commit de34f58

Revert "Made FlexAttention error on subgraph lowering failure (#140331)"
This reverts commit e68bc76. Reverted pytorch/pytorch#140331 on behalf of https://github.com/malfet due to Looks like it regressed trunk, see https://hud.pytorch.org/hud/pytorch/pytorch/55f1959fc148240e55a91d7625c750736ea5f2e4/1?per_page=50&name_filter=linux-focal-cuda12.1-py3.10-gcc9-sm86&mergeLF=true ([comment](pytorch/pytorch#140331 (comment)))
1 parent 55f1959 commit de34f58

File tree: 7 files changed (+81, -123 lines)


test/inductor/test_flex_attention.py

Lines changed: 0 additions & 55 deletions
@@ -2965,61 +2965,6 @@ def test_mixed_device_error_message(self):
         with self.assertRaisesRegex(ValueError, expected_error_message):
             flex_attention(query, key, value)

-    @supported_platform
-    def test_captured_wrong_device_error_message(self):
-        means = torch.randn(64, 3).cuda()
-        length_scales = torch.logspace(0.001, 0.1, 8)
-
-        def euclidean_dist_pos_embed(score, b, h, q_idx, k_idx):
-            q_pos = means[q_idx]
-            k_pos = means[k_idx]
-            dist = (q_pos - k_pos).pow(2).sum(-1).sqrt()
-            scale = length_scales[h]
-            inv_dist = torch.exp(-dist / scale)
-            return inv_dist * score
-
-        expected_error_message = "Buffers cannot be created"
-
-        q, k, v = (torch.randn(1, 8, 64, 64, device="cuda") for _ in range(3))
-        with self.assertRaisesRegex(RuntimeError, expected_error_message):
-            torch.compile(flex_attention)(q, k, v, score_mod=euclidean_dist_pos_embed)
-
-    @supported_platform
-    def test_cant_lower_error_message(self):
-        # We can't lower a 256-element reduction inside a pointwise reduction
-        means = torch.randn(64, 256).cuda()
-        length_scales = torch.logspace(0.001, 0.1, 8).cuda()
-
-        def euclidean_dist_pos_embed(score, b, h, q_idx, k_idx):
-            q_pos = means[q_idx]
-            k_pos = means[k_idx]
-            dist = (q_pos - k_pos).pow(2).sum(-1).sqrt()
-            scale = length_scales[h]
-            inv_dist = torch.exp(-dist / scale)
-            return inv_dist * score
-
-        expected_error_message = "Buffers cannot be created"
-
-        q, k, v = (torch.randn(1, 8, 64, 64, device="cuda") for _ in range(3))
-        with self.assertRaisesRegex(RuntimeError, expected_error_message):
-            torch.compile(flex_attention)(q, k, v, score_mod=euclidean_dist_pos_embed)
-
-    @supported_platform
-    def test_reduction_unrolled(self):
-        # We can't lower a 256-element reduction inside a pointwise reduction
-        means = torch.randn(S, 3).cuda()
-        length_scales = torch.logspace(0.001, 0.1, H).cuda()
-
-        def euclidean_dist_pos_embed(score, b, h, q_idx, k_idx):
-            q_pos = means[q_idx]
-            k_pos = means[k_idx]
-            dist = (q_pos - k_pos).pow(2).sum(-1).sqrt()
-            scale = length_scales[h]
-            inv_dist = torch.exp(-dist / scale)
-            return inv_dist * score
-
-        self.run_test(euclidean_dist_pos_embed, torch.bfloat16)
-
     @supported_platform
     def test_invalid_block_size(self):
         # Create tensors on different devices
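
For context on what the deleted tests exercised: each one compiled flex_attention with a score_mod that indexes into tensors captured from the enclosing scope (so-called lifted buffers). A self-contained sketch of that usage pattern follows; the shapes and the distance-based bias mirror the tests above, but the snippet is illustrative, and whether it compiles cleanly or raises after this revert depends on the captured buffers' devices and on what Inductor can lower pointwise.

import torch
from torch.nn.attention.flex_attention import flex_attention

# Captured ("lifted") buffers referenced from inside the score_mod.
means = torch.randn(64, 3, device="cuda")
length_scales = torch.logspace(0.001, 0.1, 8, device="cuda")

def euclidean_dist_pos_embed(score, b, h, q_idx, k_idx):
    # Distance-based positional bias computed from the captured buffers.
    dist = (means[q_idx] - means[k_idx]).pow(2).sum(-1).sqrt()
    return torch.exp(-dist / length_scales[h]) * score

q, k, v = (torch.randn(1, 8, 64, 64, device="cuda") for _ in range(3))
out = torch.compile(flex_attention)(q, k, v, score_mod=euclidean_dist_pos_embed)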

test/inductor/test_torchinductor_opinfo.py

Lines changed: 0 additions & 4 deletions
@@ -443,10 +443,6 @@ def wrapper_noop_set_seed(op, *args, **kwargs):
         "atol": 3e-4,
         "rtol": 0.002,
     },
-    ("nn.functional.triplet_margin_with_distance_loss", f16): {
-        "atol": 3e-4,
-        "rtol": 0.003,
-    },
     ("softmax", f16): {"atol": 1e-4, "rtol": 0.02},
     ("polygamma.polygamma_n_0", f32): {"atol": 1e-3, "rtol": 1e-4},
     ("polygamma.polygamma_n_1", f32): {"atol": 1e-3, "rtol": 1e-4},
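
These entries belong to the per-(op, dtype) tolerance-override table used by the opinfo tests; removing the triplet_margin_with_distance_loss/f16 entry means that combination falls back to the default tolerances. A rough sketch of how such an override table is typically consumed (the helper name and defaults here are illustrative, not the actual test harness):

import torch

f16 = torch.float16

# Illustrative override table in the same (op_name, dtype) -> kwargs shape.
overrides = {
    ("softmax", f16): {"atol": 1e-4, "rtol": 0.02},
}

def assert_op_close(op_name, dtype, actual, expected):
    # Fall back to assert_close's dtype-based defaults when no override exists.
    tol = overrides.get((op_name, dtype), {})
    torch.testing.assert_close(actual, expected, **tol)

x = torch.randn(16)
assert_op_close("softmax", f16, x.softmax(-1), x.softmax(-1))  # uses the override
assert_op_close("relu", f16, x.relu(), x.relu())               # falls back to defaults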

torch/_higher_order_ops/flex_attention.py

Lines changed: 12 additions & 12 deletions
@@ -410,8 +410,8 @@ def flex_attention_functionalize(
     assert isinstance(mask_mod_other_buffers_unwrapped, tuple)

     example_vals = (
-        [query_unwrapped.new_zeros(())]
-        + [query_unwrapped.new_zeros((), dtype=torch.int) for _ in range(4)]
+        [torch.zeros((), dtype=query.dtype)]
+        + [torch.zeros((), dtype=torch.int) for _ in range(4)]
         + list(score_mod_other_buffers_unwrapped)
     )
     with ctx.redispatch_to_next() as m:
@@ -710,11 +710,11 @@ def flex_attention_autograd(
     input_requires_grad = any(t.requires_grad for t in (query, key, value))
     if torch.is_grad_enabled() and input_requires_grad:
         example_vals = (
-            query.new_zeros((), requires_grad=input_requires_grad),
-            query.new_zeros((), dtype=torch.int),
-            query.new_zeros((), dtype=torch.int),
-            query.new_zeros((), dtype=torch.int),
-            query.new_zeros((), dtype=torch.int),
+            torch.zeros((), dtype=query.dtype, requires_grad=input_requires_grad),
+            torch.zeros((), dtype=torch.int),
+            torch.zeros((), dtype=torch.int),
+            torch.zeros((), dtype=torch.int),
+            torch.zeros((), dtype=torch.int),
         )
         fw_graph, bw_graph = create_fw_bw_graph(
             score_mod, example_vals, score_mod_other_buffers
@@ -930,11 +930,11 @@ def trace_flex_attention_backward(
         mask_mod_other_buffers,
     )

-    fw_example_vals = [query.new_zeros((), requires_grad=query.requires_grad)] + [
-        query.new_zeros((), dtype=torch.int) for _ in range(4)
-    ]
-    bw_example_vals = fw_example_vals + [query.new_zeros(())]
-    mask_example_vals = [query.new_zeros((), dtype=torch.int) for _ in range(4)]
+    fw_example_vals = [
+        torch.zeros((), dtype=query.dtype, requires_grad=query.requires_grad)
+    ] + [torch.zeros((), dtype=torch.int) for _ in range(4)]
+    bw_example_vals = fw_example_vals + [torch.zeros((), dtype=query.dtype)]
+    mask_example_vals = [torch.zeros((), dtype=torch.int) for _ in range(4)]
     mask_graph = block_mask[-1]
     with TransformGetItemToIndex():
         fw_graph = reenter_make_fx(fw_graph)(*fw_example_vals, *score_mod_other_buffers)
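
The substance of these hunks is how the scalar example values used to trace the score_mod and mask_mod graphs are built: the reverted code used query.new_zeros(()), which inherits both dtype and device from query, while the restored code uses torch.zeros((), dtype=query.dtype), which matches the dtype but always lives on the CPU. A quick sketch of that difference (the device choice is illustrative; on a CPU-only machine the two coincide):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
query = torch.randn(1, 8, 64, 64, dtype=torch.float16, device=device)

a = query.new_zeros(())                  # scalar on query.device, query.dtype
b = torch.zeros((), dtype=query.dtype)   # scalar on CPU, query.dtype

print(a.dtype, a.device)   # torch.float16, query's device
print(b.dtype, b.device)   # torch.float16, cpu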

torch/_inductor/ir.py

Lines changed: 1 addition & 3 deletions
@@ -1254,10 +1254,8 @@ def fn(index: int) -> OpsValue:
             isinstance(reduction_numel, Integer)
             and V.graph.sizevars.size_hint(reduction_numel)
             < config.unroll_reductions_threshold
-            and (sympy_product(ranges) != 1 or device.type == "cuda")
+            and sympy_product(ranges) != 1
         ):
-            # NB: This works around https://github.com/pytorch/pytorch/issues/140457
-            # since turning reductions into pointwise ops can exacerbate this problem
            return Pointwise.create(
                device=device,
                dtype=dst_dtype,
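
The restored condition decides when Inductor unrolls a small reduction into pointwise ops: the reduction extent must be statically known and below config.unroll_reductions_threshold, and the surrounding pointwise iteration space must not collapse to a single element. The reverted version additionally forced this path on CUDA even for single-element ranges, together with the workaround comment. A simplified restatement of the restored gate (plain ints stand in for the sympy expressions; the default threshold value is an assumption):

def should_unroll_reduction(reduction_numel: int, pointwise_numel: int,
                            unroll_reductions_threshold: int = 8) -> bool:
    # Unroll (turn the reduction into pointwise ops) only when the reduction
    # itself is tiny and the pointwise iteration space is non-trivial.
    # The real check also requires reduction_numel to be a static Integer.
    return (
        reduction_numel < unroll_reductions_threshold
        and pointwise_numel != 1
    )

print(should_unroll_reduction(reduction_numel=4, pointwise_numel=1024))    # True
print(should_unroll_reduction(reduction_numel=4, pointwise_numel=1))       # False
print(should_unroll_reduction(reduction_numel=256, pointwise_numel=1024))  # False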

torch/_inductor/kernel/flex_attention.py

Lines changed: 54 additions & 37 deletions
@@ -104,40 +104,59 @@ def build_subgraph_buffer(
         args: The args that are passed into the subgraph. Contains both fixed and lifted inputs.
         subgraph: The Subgraph ir for which to produce the output node
     """
-    from ..subgraph_lowering import PointwiseSubgraphLowering
-
-    pw_subgraph = PointwiseSubgraphLowering(
-        subgraph.graph_module, root_graph_lowering=V.graph
-    )
-    with V.set_graph_handler(pw_subgraph):  # type: ignore[arg-type]
-        pw_subgraph.run(*args)
-
-    def convert_output_node_to_buffer(output):
-        if output is None:
-            return None
-        output_buffer = output
-        assert isinstance(output_buffer, TensorBox), (
-            "The output node for flex attention's subgraph must be a TensorBox, but got: ",
-            type(output_buffer),
-        )
-        assert isinstance(output_buffer.data, StorageBox), (
-            "The output node for the flex attention subgraph must be a StorageBox, but got: ",
-            type(output_buffer),
-        )
-        subgraph_buffer = ComputedBuffer(
-            name=None,
-            layout=FlexibleLayout(
-                device=output_buffer.data.get_device(),
-                dtype=output_buffer.data.get_dtype(),
-                size=output_buffer.data.get_size(),
-            ),
-            data=output_buffer.data.data,  # type: ignore[arg-type]
-        )
-        return subgraph_buffer
-
-    # node.args[0] is either a single element or a list of elements
-    # representing all outputs of the function.
-    return tree_map(convert_output_node_to_buffer, pw_subgraph.graph_outputs)
+    cnt = 0
+    env = {}
+    for node in subgraph.graph_module.graph.nodes:
+        # There are two classes of placeholder inpts that we need
+        # to handle differently. For the first n_scalar_inps inputs
+        # we expect that these placeholders were generated by the make_fx call
+        # in the flex Attention HOP. So we need to create a new placeholder
+        # TensorBox for each of these inputs. For the rest of the inputs we
+        # expect that these are lifted inputs that fill up the '*other_buffers'
+        # tuple and already have corresponding TensorBoxes passed in as args.
+        with V.graph.set_current_node(node):
+            if node.op == "placeholder":
+                env[node] = args[cnt]
+                cnt += 1
+            elif node.op == "call_function":
+                # For call_function we use the default lowerings and pass in the
+                # already created TensorBoxes as args
+
+                args, kwargs = tree_map(
+                    lambda x: env[x] if x in env else x, (node.args, node.kwargs)
+                )
+                env[node] = lowerings[node.target](*args, **kwargs)
+            elif node.op == "output":
+
+                def convert_output_node_to_buffer(output):
+                    if output is None:
+                        return None
+                    output_node = output
+                    output_buffer = env[output_node]
+                    assert isinstance(output_buffer, TensorBox), (
+                        "The output node for flex attention's subgraph must be a TensorBox, but got: ",
+                        type(output_buffer),
+                    )
+                    assert isinstance(output_buffer.data, StorageBox), (
+                        "The output node for the flex attention subgraph must be a StorageBox, but got: ",
+                        type(output_buffer),
+                    )
+                    subgraph_buffer = ComputedBuffer(
+                        name=None,
+                        layout=FlexibleLayout(
+                            device=output_buffer.data.get_device(),
+                            dtype=output_buffer.data.get_dtype(),
+                            size=output_buffer.data.get_size(),
+                        ),
+                        data=output_buffer.data.data,  # type: ignore[arg-type]
+                    )
+                    return subgraph_buffer
+
+                # node.args[0] is either a single element or a list of elements
+                # representing all outputs of the function.
+                return tree_map(convert_output_node_to_buffer, node.args[0])
+
+    raise ValueError("FlexAttention was passed a subgraph with no output node!")


 # Inner Triton functions shared by flex_attention & split-k decoding kernels.
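
The restored build_subgraph_buffer is essentially a hand-rolled FX interpreter: placeholders are bound to the incoming args, call_function nodes are dispatched through Inductor's lowerings table, and the output node is converted into a ComputedBuffer. The same walking pattern over a plain FX graph, with ordinary tensor ops standing in for Inductor lowerings, looks roughly like this (a sketch, not Inductor code):

import torch
import torch.fx as fx
from torch.fx.node import map_arg

def run_graph(gm: fx.GraphModule, *inputs):
    env, it = {}, iter(inputs)
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            env[node] = next(it)                      # bind positional inputs
        elif node.op == "call_function":
            args, kwargs = map_arg((node.args, node.kwargs), lambda n: env[n])
            env[node] = node.target(*args, **kwargs)  # dispatch the op directly
        elif node.op == "output":
            # node.args[0] is a single value or a tuple of all outputs.
            return map_arg(node.args[0], lambda n: env[n])
    raise ValueError("graph has no output node")

def f(x, y):
    return torch.relu(x + y)

gm = fx.symbolic_trace(f)
print(run_graph(gm, torch.randn(3), torch.randn(3)))
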
@@ -503,7 +522,7 @@ def forward_block_mn(
     ) | indent_except_first(2) }}

     if CHECK_BLOCK_BOUNDARY:
-        mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
+        mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, float("-inf"))
     # apply mask for partially unmasked blocks
     post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
17391758
joint_placeholder_inps = fwd_placeholder_inps + [
17401759
create_placeholder("grad_score_mod", dtype, device)
17411760
]
1742-
# Sometimes we have weird unused nodes here
1743-
joint_graph.graph_module.graph.eliminate_dead_code()
17441761
joint_subgraph_buffer, *_ = build_subgraph_buffer(
17451762
joint_placeholder_inps + list(score_mod_other_buffers), joint_graph
17461763
)
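
The deleted call removed dead nodes from the joint (backward) score_mod graph before lowering it. FX dead-code elimination itself is a one-liner on any graph; a small standalone demonstration, unrelated to flex attention and only showing the API:

import torch
import torch.fx as fx

def f(x):
    unused = x * 2          # never used: becomes a dead node in the traced graph
    return x + 1

gm = fx.symbolic_trace(f)
print(len(list(gm.graph.nodes)))   # placeholder, mul, add, output -> 4 nodes
gm.graph.eliminate_dead_code()     # drops the unused, side-effect-free mul
gm.recompile()
print(len(list(gm.graph.nodes)))   # 3 nodes remain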

torch/_inductor/lowering.py

Lines changed: 1 addition & 1 deletion
@@ -74,7 +74,7 @@


 log = logging.getLogger(__name__)
-lowerings: Dict[Callable[..., Any], Callable[..., Any]] = {}
+lowerings: Dict[torch._ops.OpOverload, Callable[..., Any]] = {}
 # Use maybe_layout_constraints to access this dict, we lazily register tag-based layout constraints
 _maybe_layout_constraints: Dict[
     torch._ops.OpOverload, Optional[Callable[..., Any]]
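
The restored annotation narrows the keys of Inductor's lowerings table back to torch._ops.OpOverload, which is also what PointwiseSubgraphLowering.call_function asserts before looking ops up. A tiny illustration of the key type (not the registration API, just what an OpOverload is and how a table keyed by it behaves):

import torch

add = torch.ops.aten.add.Tensor                # a concrete OpOverload
print(isinstance(add, torch._ops.OpOverload))  # True

# A lowerings-style table keyed by OpOverload.
table = {add: lambda x, y: x + y}
print(table[torch.ops.aten.add.Tensor](torch.ones(2), torch.ones(2)))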

torch/_inductor/subgraph_lowering.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from . import ir
1414
from .exc import SubgraphLoweringException
1515
from .ops_handler import SimpleCSEHandler
16+
from .sizevars import SizeVarAllocator
1617
from .virtualized import ops, V, WrapperHandler
1718

1819

@@ -21,11 +22,6 @@
2122

2223

2324
class PointwiseSubgraphLowering(torch.fx.Interpreter):
24-
"""
25-
Lowers a pointwise subgraph to a single set of buffers with a separate
26-
lowering object. Errors if buffers are created unexpectedly
27-
"""
28-
2925
graph_outputs: Optional[List[ir.IRNode]]
3026

3127
def __init__(
@@ -37,19 +33,18 @@ def __init__(
3733
self.graph_outputs = None
3834
self.root_graph = root_graph_lowering
3935

36+
@property
37+
def sizevars(self) -> SizeVarAllocator:
38+
return self.root_graph.sizevars
39+
4040
def mark_buffer_mutated(self, name: str) -> None:
4141
raise SubgraphLoweringException("Mutations are not supported in this context")
4242

4343
def register_buffer(self, buffer: ir.Buffer) -> str:
4444
raise SubgraphLoweringException(
45-
"Buffers cannot be created while lowering a pointwise subgraph. "
46-
"This could be for a good reason (e.g. you're calling an op we can't codegen as a pointwise op), "
47-
"but it could also be a bug. Please file a bug report if you think this should be supportable."
45+
"Buffer creation is not supported in this context"
4846
)
4947

50-
def __getattr__(self, name: str) -> Any:
51-
return getattr(self.root_graph, name)
52-
5348
def call_function(
5449
self,
5550
target: Callable[[Any], Any], # type: ignore[override]
@@ -61,11 +56,18 @@ def call_function(
6156
if target is operator.getitem and isinstance(args[0], (list, tuple, dict)):
6257
return super().call_function(target, args, kwargs)
6358

59+
assert isinstance(target, torch._ops.OpOverload)
60+
6461
if target not in lowerings:
6562
raise SubgraphLoweringException(
6663
f"{target} not supported in subgraph, (missing lowering)"
6764
)
6865

66+
if torch.Tag.pointwise not in target.tags:
67+
raise SubgraphLoweringException(
68+
f"Only pointwise operators are supported in this context, but got {target}"
69+
)
70+
6971
return lowerings[target](*args, **kwargs)
7072

7173
def output(self, target: str, args: Tuple[Any], kwargs: Dict[str, Any]) -> None: # type: ignore[override]
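
The added guard in call_function is what produces the "Only pointwise operators are supported" error when a score_mod needs a reduction or other non-pointwise op. The check relies on the operator tags carried by each OpOverload; a minimal standalone sketch of the same guard (the error type and op choices are illustrative):

import torch

def check_pointwise(target: torch._ops.OpOverload) -> None:
    # Mirrors the restored guard: only ops tagged pointwise may be lowered here.
    if torch.Tag.pointwise not in target.tags:
        raise RuntimeError(f"Only pointwise operators are supported, but got {target}")

check_pointwise(torch.ops.aten.mul.Tensor)            # elementwise: passes
try:
    check_pointwise(torch.ops.aten.sum.dim_IntList)   # reduction: rejected
except RuntimeError as e:
    print(e)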
