Skip to content

Commit d165754

Browse files
[#5153] Update the build flag metadata when loading SPIRV kernel. (#5402)
This is a workaround for the feature request in #5153. The IGC build flags stored in the kernel metadata are now updated when large GRF mode is used while loading a SPIR-V kernel, which happens when the register spill size is > 1000. Signed-off-by: etaf <[email protected]> Co-authored-by: Anatoly Myachev <[email protected]>
1 parent 31386a6 commit d165754

File tree

3 files changed

+17
-2
lines changed

3 files changed

+17
-2
lines changed

python/triton/compiler/compiler.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,12 @@ def raise_(err):
468468
self.module, self.function, self.n_regs, self.n_spills, self.n_max_threads = driver.active.utils.load_binary(
469469
self.name, self.kernel, self.metadata.shared, self.metadata.build_flags,
470470
not self.metadata.generate_native_code, device)
471+
# load_binary may select different build flags; record the updated flags in the metadata so that PyTorch can use them.
472+
if hasattr(driver.active.utils, "get_last_selected_build_flags"):
473+
new_build_flags = driver.active.utils.get_last_selected_build_flags()
474+
if new_build_flags != self.metadata.build_flags:
475+
self.metadata = self.metadata._replace(build_flags=new_build_flags)
476+
471477
if hasattr(self.metadata, "threads_per_warp"):
472478
warp_size = self.metadata.threads_per_warp
473479
else:

third_party/intel/backend/driver.c

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,12 @@ sycl::context get_default_context(const sycl::device &sycl_device) {
192192
#endif
193193
}
194194

195+
static BuildFlags last_build_flag("");
196+
197+
// Returns, as a Python str, the build flags saved in last_build_flag, which load_binary assigns just before it returns.
extern "C" EXPORT_FUNC PyObject *get_last_selected_build_flags() {
198+
return Py_BuildValue("s", last_build_flag().data());
199+
}
200+
195201
extern "C" EXPORT_FUNC PyObject *load_binary(PyObject *args) {
196202
const char *name, *build_flags_ptr;
197203
int shared;
@@ -309,7 +315,7 @@ extern "C" EXPORT_FUNC PyObject *load_binary(PyObject *args) {
309315
PyCapsule_New(reinterpret_cast<void *>(fun), "kernel", freeKernel);
310316
auto kernel_bundle_py = PyCapsule_New(reinterpret_cast<void *>(mod),
311317
"kernel_bundle", freeKernelBundle);
312-
318+
last_build_flag = build_flags;
313319
return Py_BuildValue("(OOiii)", kernel_bundle_py, kernel_py, n_regs,
314320
n_spills, n_max_threads);
315321

third_party/intel/backend/driver.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,9 +199,11 @@ def __init__(self, cache_path: str):
199199
self.shared_library.get_device_properties.argtypes = (ctypes.c_int, )
200200
self.shared_library.has_opencl_extension.restype = ctypes.py_object
201201
self.shared_library.has_opencl_extension.argtypes = (ctypes.c_int, ctypes.c_char_p)
202+
self.shared_library.get_last_selected_build_flags.restype = ctypes.py_object
202203

203204
def __getattribute__(self, name):
204-
if name in ("get_device_properties", "init_devices", "wait_on_sycl_queue", "has_opencl_extension"):
205+
if name in ("get_device_properties", "init_devices", "wait_on_sycl_queue", "has_opencl_extension",
206+
"get_last_selected_build_flags"):
205207
shared_library = super().__getattribute__("shared_library")
206208
return getattr(shared_library, name)
207209

@@ -318,6 +320,7 @@ def __init__(self):
318320
self.device_count = mod.init_devices(self.get_sycl_queue())
319321
self.wait_on_sycl_queue = mod.wait_on_sycl_queue
320322
self.has_opencl_extension = mod.has_opencl_extension
323+
self.get_last_selected_build_flags = mod.get_last_selected_build_flags
321324

322325
def get_current_device(self):
323326
import torch

0 commit comments

Comments
 (0)