Skip to content

Commit 4e16c94

Browse files
committed
Add multicast tensor
stack-info: PR: #346, branch: joydddd/stack/17
1 parent 642836c commit 4e16c94

File tree

10 files changed

+1038
-23
lines changed

10 files changed

+1038
-23
lines changed

helion/_compiler/device_ir.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
from .type_propagation import LiteralType
5555
from .type_propagation import NumericType
5656
from .type_propagation import SequenceType
57+
from .type_propagation import StackTensorType
5758
from .type_propagation import TensorType
5859
from .type_propagation import TileIndexType
5960
from .type_propagation import TypeInfo
@@ -321,12 +322,14 @@ def build_rolled_reductions(self) -> None:
321322
graph_to_info = {}
322323
allow_loop = False
323324

324-
# First, check if any graph contains matmul with rdim
325+
# First, check if any graph contains matmul or dev_ptrs stacking with rdim
325326
# If so, we can't roll any graphs in this reduction dimension
326327
can_roll_graphs = True
327328
for graph_info in self.graphs:
328329
roller = ReductionRoller(self, rdim, {})
329-
if roller.has_matmul_with_rdim(graph_info.graph):
330+
if roller.has_matmul_with_rdim(
331+
graph_info.graph
332+
) or roller.has_stack_tensor_with_rdim(graph_info.graph):
330333
can_roll_graphs = False
331334
break
332335

@@ -870,7 +873,9 @@ def visit_Assign(self, node: ast.Assign) -> None:
870873
assert isinstance(target.value, ExtendedAST)
871874
assert target.value._type_info is not None
872875
target_origin = target.value._type_info.origin # pyright: ignore[reportOptionalMemberAccess]
873-
if not target_origin.is_host():
876+
if not target_origin.is_host() and not isinstance(
877+
target.value._type_info, StackTensorType
878+
):
874879
# Get the variable name for the error message
875880
var_name = (
876881
target.value.id
@@ -895,7 +900,9 @@ def _assign_subscript(self, target: ast.Subscript, val: object) -> None:
895900
assert isinstance(target.value, ExtendedAST)
896901
assert target.value._type_info is not None
897902
target_origin = target.value._type_info.origin
898-
assert target_origin.is_host()
903+
assert target_origin.is_host() or isinstance(
904+
target.value._type_info, StackTensorType
905+
)
899906

900907
return hl.store(
901908
self.visit(target.value), # pyright: ignore[reportArgumentType]
@@ -928,6 +935,8 @@ def visit_Subscript(self, node: ast.Subscript) -> object:
928935
if isinstance(node.slice, ast.Constant):
929936
return self.visit(value)[self.visit(node.slice)] # pyright: ignore[reportIndexIssue]
930937
raise exc.InvalidSequenceSubscription(node.slice)
938+
if isinstance(type_info, StackTensorType):
939+
return hl.load(self.visit(value), self._subscript_slice_proxy(node.slice)) # pyright: ignore[reportArgumentType]
931940
if type_info is not None and type_info.origin.is_host():
932941
return hl.load(self.visit(value), self._subscript_slice_proxy(node.slice)) # pyright: ignore[reportArgumentType]
933942
return hl.subscript(self.visit(value), self._subscript_slice_proxy(node.slice)) # pyright: ignore[reportArgumentType]

helion/_compiler/indexing_strategy.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import sympy
1010
import torch
11+
from torch._inductor.utils import triton_type
1112

1213
from .. import exc
1314
from .._compat import get_tensor_descriptor_fn_name
@@ -19,10 +20,15 @@
1920
from .variable_origin import BlockSizeOrigin
2021

2122
if TYPE_CHECKING:
23+
from collections.abc import Sequence
24+
2225
from ..runtime.config import Config
2326
from .device_function import TensorDescriptorArg
2427
from .inductor_lowering import CodegenState
2528

29+
SymIntLike = torch.SymInt | int
30+
ShapeLike = Sequence[SymIntLike]
31+
2632

2733
class IndexingStrategy:
2834
def codegen_load(
@@ -296,6 +302,147 @@ def codegen_store(
296302
)
297303

298304

305+
class StackIndexingStrategy:
    """
    Emit Triton pointer arithmetic for loads/stores that target a stack of
    device pointers which all share one subscript.

    The offset and mask are computed once against the ``tensor_like`` template
    tensor and then broadcast against the ``dev_ptrs`` tensor, so the result
    has the stack dimensions first followed by the subscript dimensions.

    e.g. for a 1D offset tensor and a 1D dev_ptr array, the stack offset is:
        stack_offset = dev_ptrs[:, None] + offset[None, :]
    """

    @staticmethod
    def get_broadcast_str(
        stack_shape: ShapeLike,
        subscript_shape: ShapeLike,
    ) -> tuple[str, str]:
        """
        Build the two indexing suffixes used to broadcast pointers and offsets.

        Args:
            stack_shape: shape of the dev_ptr tensor.
            subscript_shape: shape of the subscript applied to each tensor.

        Returns:
            ``(stack_broadcast, tensor_broadcast)``. The first keeps the stack
            dims (``:``) and adds ``None`` for every subscript dim; the second
            is the mirror image, for the per-tensor offset/mask.
        """
        stack_rank = len(stack_shape)
        subscript_rank = len(subscript_shape)

        keep_stack = [":"] * stack_rank + ["None"] * subscript_rank
        keep_subscript = ["None"] * stack_rank + [":"] * subscript_rank

        return f"[{', '.join(keep_stack)}]", f"[{', '.join(keep_subscript)}]"

    @staticmethod
    def get_mask_expr(
        state: CodegenState,
        indexing: SubscriptIndexing,
        stack_shape: ShapeLike,
        subscript_shape: ShapeLike,
    ) -> ast.AST | None:
        """
        Combine the mask over the stack dims with the subscript mask.

        Returns:
            An AST for the ``&``-joined mask expression, or ``None`` when
            neither the stack dims nor the subscript need masking.
        """
        stack_broadcast, tensor_broadcast = StackIndexingStrategy.get_broadcast_str(
            stack_shape, subscript_shape
        )

        # One entry per stack dim that maps to a block id with a mask variable.
        per_dim_masks: list[str] = []
        for dim, size in enumerate(stack_shape):
            block_id = CompileEnvironment.current().get_block_id(size)
            if block_id is None:
                continue
            mask_var = state.codegen.mask_var(block_id)
            if mask_var is None:
                continue
            expand = state.tile_strategy.expand_str(stack_shape, dim)
            per_dim_masks.append(f"({mask_var}{expand})")

        terms: list[str] = []
        if per_dim_masks:
            stack_mask = f"({'&'.join(per_dim_masks)})"
            if len(per_dim_masks) < len(stack_shape):
                # Some stack dims are unmasked: widen to the full stack shape
                # before appending the subscript broadcast suffix.
                stack_mask = f"tl.broadcast_to({stack_mask}, {state.tile_strategy.shape_str(stack_shape)})"
            terms.append(f"({stack_mask}){stack_broadcast}")

        if indexing.has_mask():
            terms.append(f"(tensor_mask){tensor_broadcast}")
            return expr_from_string("&".join(terms), tensor_mask=indexing.mask_expr)
        return expr_from_string("&".join(terms)) if terms else None

    @staticmethod
    def codegen_load(
        state: CodegenState,
        stack_tensor: tuple[torch.Tensor, torch.Tensor],
        dev_ptrs_ast: ast.AST,
        subscript: list[object],
        extra_mask: ast.AST | None,
    ) -> ast.AST:
        """
        Generate a ``tl.load`` over every pointer in the stack.

        ``stack_tensor`` is unpacked as ``(tensor_like, dev_ptrs)``. When a
        mask is present the load is emitted with ``other=0``.
        """
        template, pointers = stack_tensor
        indexing = SubscriptIndexing.create(state, template, subscript, extra_mask)
        inner_shape = SubscriptIndexing.compute_shape(template, subscript)
        outer_shape = list(pointers.size())

        mask_ast = StackIndexingStrategy.get_mask_expr(
            state, indexing, outer_shape, inner_shape
        )
        if mask_ast is None:
            mask_ast = expr_from_string("None")
            other_arg = ""
        else:
            other_arg = ", other=0"

        ptr_suffix, offset_suffix = StackIndexingStrategy.get_broadcast_str(
            outer_shape, inner_shape
        )
        elem_type = triton_type(template.dtype)
        address = f"(base.to(tl.pointer_type({elem_type}))){ptr_suffix} + (offset){offset_suffix}"
        return expr_from_string(
            f"tl.load({address}, mask{other_arg})",
            base=dev_ptrs_ast,
            offset=indexing.index_expr,
            mask=mask_ast,
        )

    @staticmethod
    def codegen_store(
        state: CodegenState,
        stack_tensor: tuple[torch.Tensor, torch.Tensor],
        dev_ptrs_ast: ast.AST,
        subscript: list[object],
        value: ast.AST,
        extra_mask: ast.AST | None,
    ) -> ast.AST:
        """
        Generate a ``tl.store`` of ``value`` through every pointer in the stack.

        ``stack_tensor`` is unpacked as ``(tensor_like, dev_ptrs)``.
        """
        template, pointers = stack_tensor
        indexing = SubscriptIndexing.create(state, template, subscript, extra_mask)
        inner_shape = SubscriptIndexing.compute_shape(template, subscript)
        outer_shape = list(pointers.size())

        mask_ast = StackIndexingStrategy.get_mask_expr(
            state, indexing, outer_shape, inner_shape
        )
        if mask_ast is None:
            mask_ast = expr_from_string("None")

        ptr_suffix, offset_suffix = StackIndexingStrategy.get_broadcast_str(
            outer_shape, inner_shape
        )
        elem_type = triton_type(template.dtype)
        address = f"base.to(tl.pointer_type({elem_type})){ptr_suffix} + (offset){offset_suffix}"
        return expr_from_string(
            f"tl.store({address}, value, mask)",
            base=dev_ptrs_ast,
            value=value,
            offset=indexing.index_expr,
            mask=mask_ast,
        )
444+
445+
299446
class SubscriptIndexing(NamedTuple):
300447
index_expr: ast.AST
301448
mask_expr: ast.AST

helion/_compiler/roll_reduction.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import torch
77
from torch.fx import map_arg
88

9+
from ..language import _MEMORY_OPS
910
from ..language._tracing_ops import _for_loop
1011
from ..language._tracing_ops import _get_symnode
1112
from ..language._tracing_ops import _host_tensor
@@ -277,6 +278,35 @@ def is_matmul_with_rdim(node: torch.fx.Node) -> bool:
277278

278279
return any(is_matmul_with_rdim(node) for node in graph.nodes)
279280

281+
def has_stack_tensor_with_rdim(self, graph: torch.fx.Graph) -> bool:
    """Return True if any memory op in ``graph`` targets a stack whose shape uses this rdim.

    A node qualifies when it is a ``call_function`` whose target is in
    ``_MEMORY_OPS``, its first argument is a 2-tuple, and the fake tensor
    stored in the ``"val"`` metadata of the tuple's second element has a
    dimension whose block id equals ``self.rdim.block_id``.
    """

    def uses_rdim_in_stack(node: torch.fx.Node) -> bool:
        """Check a single node; see the enclosing docstring for the criteria."""
        if node.op != "call_function" or node.target not in _MEMORY_OPS:
            return False

        target = node.args[0]
        # Only 2-element tuples are inspected; everything else is ignored.
        if not isinstance(target, tuple) or len(target) != 2:
            return False

        stack_node = target[1]
        assert isinstance(stack_node, torch.fx.Node)
        stack_val = stack_node.meta.get("val", None)
        if not isinstance(stack_val, torch.Tensor):
            return False

        return any(
            CompileEnvironment.current().get_block_id(size) == self.rdim.block_id
            for size in stack_val.size()
        )

    for candidate in graph.nodes:
        if uses_rdim_in_stack(candidate):
            return True
    return False
309+
280310
def process(self, graph: torch.fx.Graph) -> torch.fx.Graph:
281311
for node in graph.nodes:
282312
if self.should_go_in_inner_graph(node):

helion/_compiler/type_propagation.py

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from ..autotuner.config_spec import BlockSizeSpec
2828
from ..language._decorators import get_device_func_replacement
2929
from ..language._decorators import is_api_func
30+
from ..language.stack_tensor import StackTensor
3031
from ..language.tile_proxy import Tile
3132
from ..language.tile_proxy import _CheckForIndexCalls
3233
from .ast_extension import ExtendedAST
@@ -1294,6 +1295,86 @@ def propagate_attribute(self, attr: str, origin: AttributeOrigin) -> TypeInfo:
12941295
return self.element_types[attr]
12951296

12961297

1298+
class StackTensorType(ClassType):
    """Type information for a ``StackTensor`` value.

    ``element_types`` is expected to hold two ``TensorType`` entries, keyed
    ``"tensor_like"`` and ``"dev_ptrs"``; the methods below assert this before
    using them. Indexing yields a shape made of the ``dev_ptrs`` dimensions
    followed by the indexed ``tensor_like`` dimensions.
    """

    element_types: dict[str, TypeInfo]  # pyright: ignore[reportIncompatibleVariableOverride]

    def proxy(self) -> StackTensor:  # pyright: ignore[reportIncompatibleMethodOverride]
        """Build a ``StackTensor`` from the proxies of both element types.

        Proxy-mode tracing is disabled and the FAKE dispatch mode is unset
        while the object is constructed; the mode is restored in ``finally``.
        """
        with proxy_tensor.disable_proxy_modes_tracing():
            saved_fake_mode = torch._C._unset_dispatch_mode(  # pyright: ignore[reportAttributeAccessIssue]
                torch._C._TorchDispatchModeKey.FAKE  # pyright: ignore[reportAttributeAccessIssue]
            )
            try:
                template_type = self.element_types["tensor_like"]
                assert isinstance(template_type, TensorType)
                pointers_type = self.element_types["dev_ptrs"]
                assert isinstance(pointers_type, TensorType)
                return StackTensor(template_type.proxy(), pointers_type.proxy())
            finally:
                assert saved_fake_mode is not None
                torch._C._set_dispatch_mode(saved_fake_mode)  # pyright: ignore[reportAttributeAccessIssue]

    def merge(self, other: TypeInfo) -> TypeInfo:
        """Merge element-wise with another ``StackTensorType`` that has the same keys.

        Any other combination is delegated to the parent class.
        """
        if not isinstance(other, StackTensorType):
            return super().merge(other)

        mine = self.element_types
        theirs = other.element_types
        if mine.keys() != theirs.keys():
            return super().merge(other)

        merged = {name: info.merge(theirs[name]) for name, info in mine.items()}
        return StackTensorType(origin=other.origin, element_types=merged)

    def _device_indexing_size(self, key: TypeInfo) -> list[int | torch.SymInt]:
        """Return the ``dev_ptrs`` shape followed by ``tensor_like``'s indexed shape for ``key``."""
        template_type = self.element_types["tensor_like"]
        assert isinstance(template_type, TensorType)
        indexed_size = template_type._device_indexing_size(key)

        pointers_type = self.element_types["dev_ptrs"]
        assert isinstance(pointers_type, TensorType)

        return [*pointers_type.fake_value.size(), *indexed_size]

    def propagate_setitem(
        self, key: TypeInfo, value: TypeInfo, origin: Origin
    ) -> TypeInfo:
        """Validate ``self[key] = value`` and return ``self``.

        On the host only a ``TensorOperationInWrapper`` warning is issued.
        On device a tensor value must match the rank of the indexed shape,
        numeric/literal values are accepted as-is, and anything else raises
        ``RequiresTensorInAssignment``.
        """
        if origin.is_host():
            warning(exc.TensorOperationInWrapper)
            return self

        lhs_shape = self._device_indexing_size(key)
        if isinstance(value, TensorType):
            lhs_rank = len(lhs_shape)
            rhs_rank = value.fake_value.ndim
            if lhs_rank != rhs_rank:
                raise exc.RankMismatch(
                    lhs_rank,
                    rhs_rank,
                    f"LHS shape: {tuple(lhs_shape)}, RHS shape: {tuple(value.fake_value.shape)}",
                )
        elif not isinstance(value, (NumericType, LiteralType)):
            # Scalars (numeric/literal) are allowed and broadcast to the
            # tensor shape; every other value kind is rejected.
            raise exc.RequiresTensorInAssignment(value)
        return self

    def propagate_getitem(self, key: TypeInfo, origin: Origin) -> TypeInfo:
        """Return the ``TensorType`` produced by ``self[key]``.

        The result is created with ``new_empty`` on the ``tensor_like`` proxy
        using the stacked indexing shape. A host origin only triggers a
        ``TensorOperationInWrapper`` warning.
        """
        if origin.is_host():
            warning(exc.TensorOperationInWrapper)

        template_type = self.element_types["tensor_like"]
        assert isinstance(template_type, TensorType)
        template = template_type.proxy()
        return TensorType(origin, template.new_empty(self._device_indexing_size(key)))
1376+
1377+
12971378
class SliceType(CollectionType):
12981379
element_types: slice # pyright: ignore[reportIncompatibleVariableOverride]
12991380

@@ -1619,7 +1700,7 @@ def _assign(self, lhs: ast.AST, rhs: TypeInfo) -> None:
16191700
if isinstance(lhs, ast.Subscript):
16201701
# TODO(jansel): test different types of subscript
16211702
lhs_base_type = self.visit(lhs.value)
1622-
if isinstance(lhs_base_type, TensorType):
1703+
if isinstance(lhs_base_type, (TensorType, StackTensorType)):
16231704
self.visit(lhs) # need to populate shape info
16241705
lhs_base_type = lhs_base_type.propagate_setitem(
16251706
self.visit(lhs.slice), rhs, self.origin()

0 commit comments

Comments
 (0)