@@ -8,6 +8,7 @@
 
 import sympy
 import torch
+from torch._inductor.utils import triton_type
 import triton
 
 from .. import exc
@@ -20,10 +21,15 @@
 from .variable_origin import BlockSizeOrigin
 
 if TYPE_CHECKING:
+    from collections.abc import Sequence
+
     from ..runtime.config import Config
     from .device_function import TensorDescriptorArg
     from .inductor_lowering import CodegenState
 
+    SymIntLike = torch.SymInt | int
+    ShapeLike = Sequence[SymIntLike]
+
 
 class IndexingStrategy:
     def codegen_load(
@@ -289,6 +295,147 @@ def codegen_store(
         )
 
 
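+# Indexing strategy for loads/stores that go through a tile of device
+# pointers (dev_ptrs): the pointer-tile dims and the subscripted tensor
+# dims are broadcast against each other to address every base pointer.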
+class MulticastIndexingStrategy:
+    @staticmethod
+    def get_broadcast_str(
+        multicast_shape: ShapeLike,
+        subscript_shape: ShapeLike,
+    ) -> tuple[str, str]:
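+        # Subscript strings that broadcast pointer-tile dims against tensor
+        # dims, e.g. ranks (1, 2) -> ("[:, None, None]", "[None, :, :]").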
+        multicast_broadcast_keys = [":" for _ in multicast_shape] + [
+            "None" for _ in subscript_shape
+        ]
+        multicast_broadcast = f"[{', '.join(multicast_broadcast_keys)}]"
+        tensor_broadcast_keys = ["None" for _ in multicast_shape] + [
+            ":" for _ in subscript_shape
+        ]
+        tensor_broadcast = f"[{', '.join(tensor_broadcast_keys)}]"
+
+        return multicast_broadcast, tensor_broadcast
+
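+    # Combines the pointer-tile mask with the regular per-tensor mask;
+    # returns None when no masking is required.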
+    @staticmethod
+    def get_mask_expr(
+        state: CodegenState,
+        indexing: SubscriptIndexing,
+        multicast_shape: ShapeLike,
+        subscript_shape: ShapeLike,
+    ) -> ast.AST | None:
+        multicast_broadcast, tensor_broadcast = (
+            MulticastIndexingStrategy.get_broadcast_str(
+                multicast_shape, subscript_shape
+            )
+        )
+
+        mask_exprs = []
+        dev_ptr_mask_exprs = []
+
+        # Generate a mask term for each pointer-tile dim that has a mask var.
+        for dim, size in enumerate(multicast_shape):
+            if (
+                index := CompileEnvironment.current().get_block_id(size)
+            ) is not None and (mask_var := state.codegen.mask_var(index)) is not None:
+                expand = state.tile_strategy.expand_str(multicast_shape, dim)
+                dev_ptr_mask_exprs.append(f"({mask_var}{expand})")
+
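+        # AND the per-dim masks together, broadcasting up to the full
+        # pointer-tile shape when only some dims contributed a mask.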
+        if dev_ptr_mask_exprs:
+            dev_ptr_mask_expr = f"({'&'.join(dev_ptr_mask_exprs)})"
+            if len(dev_ptr_mask_exprs) < len(multicast_shape):
+                dev_ptr_mask_expr = f"tl.broadcast_to({dev_ptr_mask_expr}, {state.tile_strategy.shape_str(multicast_shape)})"
+            dev_ptr_mask_expr = f"({dev_ptr_mask_expr}){multicast_broadcast}"
+            mask_exprs.append(dev_ptr_mask_expr)
+
+        if indexing.has_mask():
+            mask_exprs.append(f"(tensor_mask){tensor_broadcast}")
+            return expr_from_string(
+                "&".join(mask_exprs), tensor_mask=indexing.mask_expr
+            )
+        if mask_exprs:
+            return expr_from_string("&".join(mask_exprs))
+        return None
| 354 | + |
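+    # Emits a tl.load of tensor_like through every pointer in dev_ptrs,
+    # adding the subscript offsets to each broadcasted base pointer.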
+    @staticmethod
+    def codegen_load(
+        state: CodegenState,
+        tensors: tuple[torch.Tensor, torch.Tensor],
+        dev_ptrs_ast: ast.AST,
+        subscript: list[object],
+        extra_mask: ast.AST | None,
+    ) -> ast.AST:
+        tensor_like, dev_ptrs = tensors
+        indexing = SubscriptIndexing.create(state, tensor_like, subscript, extra_mask)
+        subscripts_shape = SubscriptIndexing.compute_shape(tensor_like, subscript)
+        multicast_shape = [*dev_ptrs.size()]
+
+        mask_expr = MulticastIndexingStrategy.get_mask_expr(
+            state, indexing, multicast_shape, subscripts_shape
+        )
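+        # Without a mask, pass mask=None and omit other=0.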
+        extra = ", other=0"
+        if mask_expr is None:
+            mask_expr = expr_from_string("None")
+            extra = ""
+
+        multicast_broadcast, tensor_broadcast = (
+            MulticastIndexingStrategy.get_broadcast_str(
+                multicast_shape, subscripts_shape
+            )
+        )
+
+        dtype = triton_type(tensor_like.dtype)
+        return expr_from_string(
+            f"tl.load((base.to(tl.pointer_type({dtype}))){multicast_broadcast} + (offset){tensor_broadcast}, mask{extra})",
+            base=dev_ptrs_ast,
+            offset=indexing.index_expr,
+            mask=mask_expr,
+        )
| 389 | + |
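+    # Store counterpart of codegen_load; uses the same broadcasted addressing.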
+    @staticmethod
+    def codegen_store(
+        state: CodegenState,
+        tensors: tuple[torch.Tensor, torch.Tensor],
+        dev_ptrs_ast: ast.AST,
+        subscript: list[object],
+        value: ast.AST,
+        extra_mask: ast.AST | None,
+    ) -> ast.AST:
+        tensor_like, dev_ptrs = tensors
+        indexing = SubscriptIndexing.create(state, tensor_like, subscript, extra_mask)
+        subscripts_shape = SubscriptIndexing.compute_shape(tensor_like, subscript)
+        multicast_shape = [*dev_ptrs.size()]
+
+        mask_expr = MulticastIndexingStrategy.get_mask_expr(
+            state, indexing, multicast_shape, subscripts_shape
+        )
+        if mask_expr is None:
+            mask_expr = expr_from_string("None")
+
+        multicast_broadcast, tensor_broadcast = (
+            MulticastIndexingStrategy.get_broadcast_str(
+                multicast_shape, subscripts_shape
+            )
+        )
+
+        dtype = triton_type(tensor_like.dtype)
+        return expr_from_string(
+            f"tl.store((base.to(tl.pointer_type({dtype}))){multicast_broadcast} + (offset){tensor_broadcast}, value, mask)",
+            base=dev_ptrs_ast,
+            value=value,
+            offset=indexing.index_expr,
+            mask=mask_expr,
+        )
+
+
 class SubscriptIndexing(NamedTuple):
     index_expr: ast.AST
     mask_expr: ast.AST