
Commit b526748

Add IntelDPASLayout to Gluon (#5273)
This PR fixes #5159
1 parent c2a39f4 commit b526748

5 files changed: +115 additions, −0 deletions

5 files changed

+115
-0
lines changed

python/src/gluon_ir.cc

Lines changed: 23 additions & 0 deletions

@@ -5,6 +5,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Types.h"
 #include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h"
+#include "third_party/intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
 #include "triton/Analysis/Utility.h"
 #include "triton/Dialect/Gluon/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/Attributes.h"

@@ -102,12 +103,15 @@ struct GluonLayouts {
   py::handle AMDMFMALayout;
   py::handle AMDWMMALayout;
   py::handle PaddedSharedLayout;
+  py::handle IntelDPASLayout;

   GluonLayouts() {
     auto layouts =
         py::module::import("triton.experimental.gluon.language._layouts");
     auto amdLayouts =
         py::module::import("triton.experimental.gluon.language.amd._layouts");
+    auto intelLayouts =
+        py::module::import("triton.experimental.gluon.language.intel._layouts");
     AutoLayout = py::object(layouts.attr("AutoLayout")).release();
     BlockedLayout = py::object(layouts.attr("BlockedLayout")).release();
     SliceLayout = py::object(layouts.attr("SliceLayout")).release();

@@ -125,6 +129,8 @@ struct GluonLayouts {
     AMDWMMALayout = py::object(amdLayouts.attr("AMDWMMALayout")).release();
     PaddedSharedLayout =
         py::object(layouts.attr("PaddedSharedLayout")).release();
+    IntelDPASLayout =
+        py::object(intelLayouts.attr("IntelDPASLayout")).release();

     auto core = py::module::import("triton.language.core");
   }

@@ -247,6 +253,12 @@ py::object layoutToGluon(Attribute layout) {
     return layouts.PaddedSharedLayout(intervalPaddingPairs,
                                       ll.getBases().lookup(kOffset),
                                       ll.getBases().lookup(kBlock), shape);
+  } else if (auto intelDpas = dyn_cast<ttg::intel::DpasEncodingAttr>(layout)) {
+    return layouts.IntelDPASLayout(
+        intelDpas.getRepeatCount(), intelDpas.getSystolicDepth(),
+        intelDpas.getExecutionSize(), intelDpas.getOpsPerChannel(),
+        toStdVector(intelDpas.getWarpsPerCTA()),
+        toStdVector(intelDpas.getRepCluster()), intelDpas.getThreadsPerWarp());
   }

   throw py::value_error("Unhandled encoding encountered");

@@ -385,6 +397,17 @@ void init_gluon_ir(py::module &&m) {
             return ttg::AMDWmmaEncodingAttr::get(
                 ctx, version, transposed, warpsPerCta, ctaLayout, instrShape);
           })
+      .def("get_intel_dpas_layout",
+           [](GluonOpBuilder &self, unsigned repeatCount,
+              unsigned systolicDepth, unsigned executionSize,
+              unsigned opsPerChannel, std::vector<unsigned> &warpsPerCTA,
+              std::vector<unsigned> &repCluster,
+              unsigned threadsPerWarp) -> Attribute {
+             auto ctx = self.getContext();
+             return ttg::intel::DpasEncodingAttr::get(
+                 ctx, repeatCount, systolicDepth, executionSize, opsPerChannel,
+                 warpsPerCTA, repCluster, threadsPerWarp);
+           })
       .def("get_padded_shared_layout",
            [](GluonOpBuilder &self, std::vector<unsigned> &intervals,
              std::vector<unsigned> &paddings,
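
The two C++ additions are the two directions of the same mapping: get_intel_dpas_layout builds a ttg::intel::DpasEncodingAttr from the Python-side fields, and the new branch in layoutToGluon converts such an attribute back into a layouts.IntelDPASLayout. A minimal sketch of the Python-side call path, assuming a GluonOpBuilder handle named builder (normally created internally by the Gluon frontend; shown here only for illustration):

# Illustration only: the Gluon frontend performs this lowering internally.
from triton.experimental.gluon.language.intel import IntelDPASLayout

layout = IntelDPASLayout(repeatCount=8, systolic_depth=8, execution_size=8, ops_per_chan=1,
                         warps_per_cta=[4, 1], rep_cluster=[1, 1], threads_per_warp=32)
# _to_ir forwards the fields to the new binding in the same order, producing a
# ttg::intel::DpasEncodingAttr; layoutToGluon performs the inverse conversion.
attr = layout._to_ir(builder)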

python/test/gluon/test_lowerings.py

Lines changed: 2 additions & 0 deletions

@@ -125,6 +125,8 @@ def _reduce_layouts():
         ttgl.amd.AMDMFMALayout(version=4, instr_shape=[32, 32, 16], transposed=True, warps_per_cta=[1, 4]),
         ttgl.amd.AMDWMMALayout(version=1, transposed=True, warps_per_cta=[1, 4]),
         ttgl.amd.AMDWMMALayout(version=2, transposed=True, warps_per_cta=[1, 4]),
+        ttgl.intel.IntelDPASLayout(repeatCount=8, systolic_depth=8, execution_size=8, ops_per_chan=1,
+                                   warps_per_cta=[4, 1], rep_cluster=[1, 1], threads_per_warp=32),
         ttgl.DotOperandLayout(
             parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[2, 4], ctas_per_cga=[1, 1],
                                                cta_split_num=[1, 1], cta_order=[0, 1], instr_shape=[16, 8]),
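
The new layout is exercised through the existing reduction-lowering parametrization rather than a dedicated test. On a machine with a suitable Intel XPU build, a run along these lines should pick it up (the exact -k filter is a guess, since it depends on how pytest renders the parametrized layout in the test id):

python -m pytest python/test/gluon/test_lowerings.py -k IntelDPAS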

python/triton/experimental/gluon/language/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -125,4 +125,5 @@

 from . import nvidia
 from . import amd
+from . import intel
 from . import extra

python/triton/experimental/gluon/language/intel/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+from ._layouts import IntelDPASLayout
+
+__all__ = ["IntelDPASLayout"]

python/triton/experimental/gluon/language/intel/_layouts.py

Lines changed: 86 additions & 0 deletions

@@ -0,0 +1,86 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import List, Optional
+from triton.language.core import _unwrap_if_constexpr
+
+from triton.experimental.gluon.language._layouts import DistributedLayout
+
+__all__ = [
+    "IntelDPASLayout",
+]
+
+
+@dataclass(frozen=True)
+class IntelDPASLayout(DistributedLayout):
+    """
+    Represents a layout for Intel DPAS (dot product accumulator) operations.
+
+    Args:
+        repeatCount (int): Number of repeats for the operation.
+        systolic_depth (int): Systolic array depth.
+        execution_size (int): Execution size.
+        ops_per_chan (int): Operations per channel.
+        warps_per_cta (List[int]): Warp layout in the block.
+        rep_cluster (List[int]): Cluster repetition configuration.
+        threads_per_warp (int): Number of threads per warp.
+    """
+
+    repeatCount: int
+    systolic_depth: int
+    execution_size: int
+    ops_per_chan: int
+    warps_per_cta: List[int]
+    rep_cluster: List[int]
+    threads_per_warp: int
+    cta_order: Optional[List[int]] = None
+
+    def __post_init__(self):
+        super().__setattr__("repeatCount", _unwrap_if_constexpr(self.repeatCount))
+        super().__setattr__("systolic_depth", _unwrap_if_constexpr(self.systolic_depth))
+        super().__setattr__("execution_size", _unwrap_if_constexpr(self.execution_size))
+        super().__setattr__("ops_per_chan", _unwrap_if_constexpr(self.ops_per_chan))
+        super().__setattr__("warps_per_cta", _unwrap_if_constexpr(self.warps_per_cta))
+        super().__setattr__("rep_cluster", _unwrap_if_constexpr(self.rep_cluster))
+        super().__setattr__("threads_per_warp", _unwrap_if_constexpr(self.threads_per_warp))
+        # Compute cta_order as reversed range of warps_per_cta length, if not provided
+        super().__setattr__("cta_order", list(reversed(range(len(self.warps_per_cta)))))
+
+        self.verify()
+
+    def _to_ir(self, builder):
+        # TODO: Replace with actual Intel DPAS IR builder method
+        return builder.get_intel_dpas_layout(
+            self.repeatCount,
+            self.systolic_depth,
+            self.execution_size,
+            self.ops_per_chan,
+            self.warps_per_cta,
+            self.rep_cluster,
+            self.threads_per_warp,
+        )
+
+    def mangle(self) -> str:
+
+        def stringify(x):
+            if x is None:
+                return ""
+            return "_".join(map(str, x))
+
+        return f"IntelDPAS_{self.repeatCount}_{self.systolic_depth}_{self.execution_size}_{self.ops_per_chan}_{stringify(self.warps_per_cta)}_{stringify(self.rep_cluster)}_{self.threads_per_warp}_IntelDPAS"
+
+    def verify(self):
+        # TODO Do we need verify?
+        return
+
+    def __hash__(self):
+        return hash((
+            self.repeatCount,
+            self.systolic_depth,
+            self.execution_size,
+            self.ops_per_chan,
+            tuple(self.warps_per_cta),
+            tuple(self.rep_cluster),
+            self.threads_per_warp,
+            tuple(self.cta_order),
+        ))
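
As a quick usage sketch (assuming a Triton build that includes this commit), the layout can be constructed directly from Gluon code via the new intel submodule re-export; the parameters below mirror the new entry in test_lowerings.py:

import triton.experimental.gluon.language as ttgl

# Construct the layout with the same values exercised by the lowering tests.
layout = ttgl.intel.IntelDPASLayout(repeatCount=8, systolic_depth=8, execution_size=8, ops_per_chan=1,
                                    warps_per_cta=[4, 1], rep_cluster=[1, 1], threads_per_warp=32)

# __post_init__ derives cta_order as the reversed dimension order of warps_per_cta.
assert layout.cta_order == [1, 0]

# mangle() folds every field into a flat identifier string.
print(layout.mangle())  # IntelDPAS_8_8_8_1_4_1_1_1_32_IntelDPAS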
