Skip to content

Commit a943a26

Browse files
committed
Add a new stage to generate zebin for XPU.
Signed-off-by: Lu,Chengjun <[email protected]>
1 parent 67f901d commit a943a26

File tree

2 files changed

+58
-51
lines changed

2 files changed

+58
-51
lines changed

python/triton/compiler/compiler.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def parse(full_name, ext, context):
138138
return module
139139
if ext == "llir" or ext == "ptx" or ext == "amdgcn":
140140
return Path(full_name).read_text()
141-
if ext == "cubin" or ext == "hsaco":
141+
if ext == "cubin" or ext == "hsaco" or ext == "zebin":
142142
return Path(full_name).read_bytes()
143143
if ext == "spv":
144144
return Path(full_name).read_bytes()
@@ -332,7 +332,7 @@ def compile(src, target=None, options=None, _env_vars=None):
332332
print(f"\nOverriding kernel with file {full_name}")
333333
next_module = parse(full_name, ext, context)
334334
# If TRITON_STORE_BINARY_ONLY is 1, only store cubin/hsaco/json
335-
if (not store_only_binary) or (ext in ("cubin", "hsaco", "json", "spv")):
335+
if (not store_only_binary) or (ext in ("cubin", "hsaco", "zebin", "json", "spv")):
336336
metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename)
337337
if fn_dump_manager is not None:
338338
fn_dump_manager.put(next_module, ir_filename)
@@ -433,11 +433,15 @@ def __init__(self, src, metadata_group, hash):
433433
self.name = self.metadata.name
434434
# stores the text of each level of IR that was generated during compilation
435435
asm_files = [Path(p) for c, p in metadata_group.items() if not c.endswith(".json")]
436+
437+
def read_file(path):
438+
try:
439+
return path.read_text()
440+
except UnicodeDecodeError:
441+
return path.read_bytes()
442+
443+
self.asm = AsmDict({file.suffix[1:]: read_file(file) for file in asm_files})
436444
binary_ext = backend.binary_ext
437-
self.asm = AsmDict({
438-
file.suffix[1:]: file.read_bytes() if file.suffix[1:] == binary_ext else file.read_text()
439-
for file in asm_files
440-
})
441445
self.metadata_group = metadata_group
442446
self.kernel = self.asm[binary_ext]
443447
# binaries are lazily initialized
@@ -477,8 +481,7 @@ def raise_(err):
477481
knobs.runtime.kernel_load_start_hook(self.module, self.function, self.name, self.metadata_group, self.hash)
478482
# TODO: n_regs, n_spills should be metadata generated when calling `ptxas`
479483
self.module, self.function, self.n_regs, self.n_spills, self.n_max_threads = driver.active.utils.load_binary(
480-
self.name, self.kernel, self.metadata.shared, self.metadata.build_flags,
481-
not self.metadata.generate_native_code, device)
484+
self.name, self.kernel, self.metadata.shared, self.metadata.build_flags, False, device)
482485
if hasattr(self.metadata, "threads_per_warp"):
483486
warp_size = self.metadata.threads_per_warp
484487
else:

third_party/intel/backend/compiler.py

Lines changed: 47 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def __init__(self, target: tuple) -> None:
120120
mod = compile_module_from_src(src=Path(os.path.join(dirname, "arch_parser.c")).read_text(), name="arch_utils")
121121
self.device_arch = knobs.intel.device_arch or mod.parse_device_arch(target.arch.get('architecture', 0))
122122
self.properties = self.parse_target(target.arch)
123-
self.binary_ext = "spv"
123+
self.binary_ext = "zebin"
124124

125125
def get_target_name(self, options) -> str:
126126
return f"xpu:{self.device_arch}"
@@ -374,6 +374,10 @@ def make_llir(src, metadata, options):
374374
def make_spv(src, metadata, options, device_arch):
375375
spirv, name = intel.translate_to_spirv(src)
376376
metadata["name"] = name
377+
return spirv
378+
379+
@staticmethod
380+
def make_zebin(src, metadata, options, device_arch):
377381
if options.grf_mode == 'small':
378382
metadata["build_flags"] = "-cl-intel-128-GRF-per-thread"
379383
elif options.grf_mode == 'large':
@@ -392,50 +396,49 @@ def make_spv(src, metadata, options, device_arch):
392396
if knobs.intel.dump_shader_info:
393397
# The IGC (Intel Graphic Compiler) only parses the options at first time in JIT-ing the binary per process.
394398
# Have to use the `ocloc` to generate the binary in sub-process to work around the limitation.
395-
assert options.generate_native_code, "Only support native code generation with shader dump"
399+
# assert options.generate_native_code, "Only support native code generation with shader dump"
396400
shader_dump_opt = f" -igc_opts ',DumpToCustomDir={metadata['cache_dir']},ShaderDumpEnable=1'"
397401

398-
metadata["generate_native_code"] = options.generate_native_code
399-
400-
if options.generate_native_code:
401-
with tempfile.TemporaryDirectory() as temp_dir:
402-
with tempfile.NamedTemporaryFile(mode='wb', suffix='.spv', dir=temp_dir, delete=False) as fsrc:
403-
fsrc.write(spirv)
404-
fbin = fsrc.name + '.o'
405-
406-
ocloc_cmd = [
407-
'ocloc', 'compile', '-file', fsrc.name, '-o', fbin, '-spirv_input', '-device', device_arch,
408-
'-options', metadata["build_flags"] + shader_dump_opt
409-
]
410-
411-
try:
412-
output = subprocess.check_output(ocloc_cmd, stderr=subprocess.STDOUT, text=True)
413-
if 'spilled' in output and metadata["build_flags"].find("-cl-intel-256-GRF-per-thread") == -1:
414-
"""
415-
The exact message is something like:
416-
warning: kernel matmul_kernel compiled SIMD16 allocated 128 regs and spilled around 217
417-
is "spilled" enough for now?
418-
"""
419-
metadata["build_flags"] += " -cl-intel-256-GRF-per-thread"
420-
# re-run with new build flags
421-
ocloc_cmd[-1] = metadata["build_flags"] + shader_dump_opt
422-
subprocess.check_output(ocloc_cmd, stderr=subprocess.STDOUT, text=True)
423-
except subprocess.CalledProcessError as e:
424-
if e.returncode == 255:
425-
error = 'Internal Triton ZEBIN codegen error'
426-
elif e.returncode == 128 + signal.SIGSEGV:
427-
error = '`ocloc` raised SIGSEGV'
428-
else:
429-
error = f'`ocloc` failed with error code {e.returncode}'
430-
431-
raise RuntimeError(f'{error}\n'
432-
f'`ocloc` stderr:\n{e.output}\n'
433-
f'Repro command: {ocloc_cmd}\n') from e
434-
435-
with open(fbin, 'rb') as f:
436-
zebin = f.read()
437-
return zebin
438-
return spirv
402+
# metadata["generate_native_code"] = options.generate_native_code
403+
404+
# if options.generate_native_code:
405+
with tempfile.TemporaryDirectory() as temp_dir:
406+
with tempfile.NamedTemporaryFile(mode='wb', suffix='.spv', dir=temp_dir, delete=False) as fsrc:
407+
fsrc.write(src)
408+
fbin = fsrc.name + '.o'
409+
410+
ocloc_cmd = [
411+
'ocloc', 'compile', '-file', fsrc.name, '-o', fbin, '-spirv_input', '-device', device_arch, '-options',
412+
metadata["build_flags"] + shader_dump_opt
413+
]
414+
415+
try:
416+
output = subprocess.check_output(ocloc_cmd, stderr=subprocess.STDOUT, text=True)
417+
if 'spilled' in output and metadata["build_flags"].find("-cl-intel-256-GRF-per-thread") == -1:
418+
"""
419+
The exact message is something like:
420+
warning: kernel matmul_kernel compiled SIMD16 allocated 128 regs and spilled around 217
421+
is "spilled" enough for now?
422+
"""
423+
metadata["build_flags"] += " -cl-intel-256-GRF-per-thread"
424+
# re-run with new build flags
425+
ocloc_cmd[-1] = metadata["build_flags"] + shader_dump_opt
426+
subprocess.check_output(ocloc_cmd, stderr=subprocess.STDOUT, text=True)
427+
except subprocess.CalledProcessError as e:
428+
if e.returncode == 255:
429+
error = 'Internal Triton ZEBIN codegen error'
430+
elif e.returncode == 128 + signal.SIGSEGV:
431+
error = '`ocloc` raised SIGSEGV'
432+
else:
433+
error = f'`ocloc` failed with error code {e.returncode}'
434+
435+
raise RuntimeError(f'{error}\n'
436+
f'`ocloc` stderr:\n{e.output}\n'
437+
f'Repro command: {ocloc_cmd}\n') from e
438+
439+
with open(fbin, 'rb') as f:
440+
zebin = f.read()
441+
return zebin
439442

440443
def add_stages(self, stages, options, language):
441444
if language == Language.TRITON:
@@ -445,6 +448,7 @@ def add_stages(self, stages, options, language):
445448
stages["ttgir"] = lambda src, metadata: self.gluon_to_ttgir(src, metadata, options)
446449
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options)
447450
stages["spv"] = lambda src, metadata: self.make_spv(src, metadata, options, self.device_arch)
451+
stages["zebin"] = lambda src, metadata: self.make_zebin(src, metadata, options, self.device_arch)
448452

449453
@functools.lru_cache()
450454
def hash(self):

0 commit comments

Comments
 (0)