Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
9972987
Add catalyst/mlir/test/PrintUseNameLocAsPrefix.mlir file.
rturrado Aug 26, 2025
55a84c2
Update PrintUseNameLocAsPrefix.mlir.
rturrado Aug 27, 2025
0bce28e
generate_ir: pass arg_names to lower_jaxpr_to_mlir.
rturrado Sep 3, 2025
c0f02a5
[no ci] bump nightly version
Sep 12, 2025
5e58c62
[no ci] bump nightly version
Sep 13, 2025
f20430c
Merge branch 'main' into issue_1714_add_variable_names_to_the_ir
rturrado Sep 13, 2025
9e0a996
[no ci] bump nightly version
Sep 16, 2025
03d0a2b
[no ci] bump nightly version
Sep 17, 2025
f7f7a7e
[no ci] bump nightly version
Sep 18, 2025
31ddb88
Add argument names as name locations in the frontend, and add use the…
rturrado Sep 18, 2025
17b1cfd
Fix test_option_use_nameloc.py.
rturrado Sep 18, 2025
726f094
Fix CodeFactor checks.
rturrado Sep 18, 2025
0cbe469
Merge branch 'PennyLaneAI:main' into main
rturrado Sep 18, 2025
6ee021e
Merge remote-tracking branch 'origin' into issue_1714_add_variable_na…
rturrado Sep 18, 2025
daf2c99
[no ci] bump nightly version
Sep 19, 2025
9df3bf6
[no ci] bump nightly version
Sep 20, 2025
a7ebdd0
[no ci] bump nightly version
Sep 23, 2025
e139daf
[no ci] bump nightly version
Sep 24, 2025
e6b3d4f
[no ci] bump nightly version
Sep 25, 2025
126e626
[no ci] bump nightly version
Sep 26, 2025
9c72207
[no ci] bump nightly version
Sep 27, 2025
8e3cc1a
[no ci] bump nightly version
Sep 30, 2025
a58843d
[no ci] bump nightly version
Oct 1, 2025
6da9eb1
[no ci] bump nightly version
Oct 2, 2025
58b79b0
Merge remote-tracking branch 'upstream/main'
rturrado Oct 2, 2025
e81eb00
Update to origin/main.
rturrado Oct 2, 2025
67df639
Run black.
rturrado Oct 8, 2025
7748331
Merge branch 'main' into issue_1714_add_variable_names_to_the_ir
rturrado Oct 9, 2025
9634bc9
Fix code coverage warning about uncovered code in mlir_opt
rturrado Oct 9, 2025
8af233e
[no ci] bump nightly version
Oct 10, 2025
646039b
[no ci] bump nightly version
Oct 11, 2025
c0d1e97
[no ci] bump nightly version
Oct 14, 2025
4251be5
[no ci] bump nightly version
Oct 15, 2025
a4519e1
[no ci] bump nightly version
Oct 16, 2025
9218936
[no ci] bump nightly version
Oct 17, 2025
d8c042f
[no ci] bump nightly version
Oct 18, 2025
f740e4c
Merge branch 'main' into issue_1714_add_variable_names_to_the_ir
rturrado Oct 20, 2025
54649ee
Address David Ittah's review comments
rturrado Oct 20, 2025
06984ba
Remove PrintUseNameLocAsPrefix.mlir
rturrado Oct 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions doc/catalyst-cli/catalyst-cli.rst
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ intermediate files are saved.
Keep intermediate files after each pipeline in the compilation. By default, no intermediate files
are saved. Using ``--keep-intermediate`` is equivalent to using ``--save-ir-after-each=pipeline``.

``--use-nameloc-as-prefix[=<true|false>]``
""""""""""""""""""""""""""""""""""""""""""

Print SSA IDs using their name location, if provided, as prefix. By default, name location information is not used.
Name location, or named source location, is a type of source location information that allows attaching a name to a child location.

``--{passname}``
"""""""""""""""

Expand Down
2 changes: 1 addition & 1 deletion frontend/catalyst/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@
Version number (major.minor.patch[-label])
"""

__version__ = "0.14.0-dev3"
__version__ = "0.14.0-dev10"
11 changes: 9 additions & 2 deletions frontend/catalyst/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def _quantum_opt(*args, stdin=None):
return _catalyst(("--tool", "opt"), *args, stdin=stdin)


def canonicalize(*args, stdin=None):
def canonicalize(*args, stdin=None, options: Optional[CompileOptions] = None):
"""Run opt with canonicalization

echo ${stdin} | catalyst --tool=opt \
Expand All @@ -316,7 +316,11 @@ def canonicalize(*args, stdin=None):

Returns stdout string
"""
return _quantum_opt(("--pass-pipeline", "builtin.module(canonicalize)"), *args, stdin=stdin)
opts = ["--pass-pipeline", "builtin.module(canonicalize)"]
if options and options.use_nameloc:
opts.append("--use-nameloc-as-prefix")

return _quantum_opt(*opts, *args, stdin=stdin)


def _options_to_cli_flags(options):
Expand Down Expand Up @@ -349,6 +353,9 @@ def _options_to_cli_flags(options):
extra_args += ["--save-ir-after-each=pass"]
extra_args += ["--dump-module-scope"]

if options.use_nameloc:
extra_args += ["--use-nameloc-as-prefix"]

if options.verbose:
extra_args += ["--verbose"]

Expand Down
8 changes: 6 additions & 2 deletions frontend/catalyst/jax_extras/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,13 @@


@debug_logger
def jaxpr_to_mlir(func_name, jaxpr):
def jaxpr_to_mlir(jaxpr, func_name, arg_names):
"""Lower a Jaxpr into an MLIR module.

Args:
func_name(str): function name
jaxpr(Jaxpr): Jaxpr code to lower
func_name(str): function name
arg_names(list[str]): list of argument names

Returns:
module: the MLIR module corresponding to ``func``
Expand All @@ -81,6 +82,7 @@ def jaxpr_to_mlir(func_name, jaxpr):
platform="cpu",
axis_context=axis_context,
name_stack=name_stack,
arg_names=arg_names,
)

return module, context
Expand All @@ -97,6 +99,7 @@ def custom_lower_jaxpr_to_module(
axis_context: AxisContext,
name_stack,
replicated_args=None,
arg_names=None,
arg_shardings=None,
result_shardings=None,
):
Expand Down Expand Up @@ -149,6 +152,7 @@ def custom_lower_jaxpr_to_module(
effects,
public=True,
replicated_args=replicated_args,
arg_names=arg_names,
arg_shardings=arg_shardings,
result_shardings=result_shardings,
name_stack=name_stack,
Expand Down
5 changes: 3 additions & 2 deletions frontend/catalyst/jax_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,12 +665,13 @@ def trace_to_jaxpr(func, static_argnums, abstracted_axes, args, kwargs, debug_in


@debug_logger
def lower_jaxpr_to_mlir(jaxpr, func_name):
def lower_jaxpr_to_mlir(jaxpr, func_name, arg_names):
"""Lower a JAXPR to MLIR.
Args:
ClosedJaxpr: the JAXPR to lower to MLIR
func_name: a name to use for the MLIR function
arg_names: list of parameter names for the MLIR function
Returns:
ir.Module: the MLIR module coontaining the JAX program
Expand All @@ -680,7 +681,7 @@ def lower_jaxpr_to_mlir(jaxpr, func_name):
MemrefCallable.clearcache()

with transient_jax_config({"jax_dynamic_shapes": True}):
mlir_module, ctx = jaxpr_to_mlir(func_name, jaxpr)
mlir_module, ctx = jaxpr_to_mlir(jaxpr, func_name, arg_names)

return mlir_module, ctx

Expand Down
29 changes: 17 additions & 12 deletions frontend/catalyst/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from catalyst.tracing.type_signatures import (
filter_static_args,
get_abstract_signature,
get_arg_names,
get_type_annotations,
merge_static_argname_into_argnum,
merge_static_args,
Expand Down Expand Up @@ -79,6 +80,7 @@ def qjit(
async_qnodes=False,
target="binary",
keep_intermediate=False,
use_nameloc=False,
verbose=False,
logfile=None,
pipelines=None,
Expand Down Expand Up @@ -121,6 +123,8 @@ def qjit(
- :attr:`~.QJIT.mlir`: MLIR representation after canonicalization
- :attr:`~.QJIT.mlir_opt`: MLIR representation after optimization
- :attr:`~.QJIT.qir`: QIR in LLVM IR form
use_nameloc (bool): If ``True``, function parameter names are added to the IR as name
locations.
verbose (bool): If ``True``, the tools and flags used by Catalyst behind the scenes are
printed out.
logfile (Optional[TextIOWrapper]): File object to write verbose messages to (default -
Expand Down Expand Up @@ -517,7 +521,6 @@ class QJIT(CatalystCallable):
:ivar jaxpr: This attribute stores the Jaxpr compiled from the function as a string.
:ivar mlir: This attribute stores the MLIR compiled from the function as a string.
:ivar qir: This attribute stores the QIR in LLVM IR form compiled from the function as a string.

"""

@debug_logger_init
Expand Down Expand Up @@ -562,20 +565,26 @@ def __init__(self, fn, compile_options):

@property
def mlir(self):
"""obtain the MLIR representation after canonicalization"""
"""Obtain the MLIR representation after canonicalization"""
# Canonicalize the MLIR since there can be a lot of redundancy coming from JAX.
if not self.mlir_module:
return None

return canonicalize(stdin=str(self.mlir_module))
stdin = self.mlir_module.operation.get_asm(
enable_debug_info=self.compile_options.use_nameloc
)
return canonicalize(stdin=stdin, options=self.compile_options)

@property
def mlir_opt(self):
"""obtain the MLIR representation after optimization"""
"""Obtain the MLIR representation after optimization"""
if not self.mlir_module:
return None

return to_mlir_opt(stdin=str(self.mlir_module), options=self.compile_options)
stdin = self.mlir_module.operation.get_asm(
enable_debug_info=self.compile_options.use_nameloc
)
return to_mlir_opt(stdin=stdin, options=self.compile_options)

@debug_logger
def __call__(self, *args, **kwargs):
Expand Down Expand Up @@ -604,7 +613,6 @@ def __call__(self, *args, **kwargs):
@debug_logger
def aot_compile(self):
"""Compile Python function on initialization using the type hint signature."""

self.workspace = self._get_workspace()

# TODO: awkward, refactor or redesign the target feature
Expand Down Expand Up @@ -643,7 +651,6 @@ def jit_compile(self, args, **kwargs):
bool: whether the provided arguments will require promotion to be used with the compiled
function
"""

cached_fn, requires_promotion = self.fn_cache.lookup(args)

if cached_fn is None:
Expand Down Expand Up @@ -774,8 +781,9 @@ def generate_ir(self):
Returns:
Tuple[ir.Module, str]: the in-memory MLIR module and its string representation
"""

mlir_module, ctx = lower_jaxpr_to_mlir(self.jaxpr, self.__name__)
mlir_module, ctx = lower_jaxpr_to_mlir(
self.jaxpr, self.__name__, get_arg_names(self.jaxpr.in_avals, self.original_function)
)

# Inject Runtime Library-specific functions (e.g. setup/teardown).
inject_functions(mlir_module, ctx, self.compile_options.seed)
Expand All @@ -790,7 +798,6 @@ def compile(self):
Returns:
Tuple[CompiledFunction, str]: the compilation result and LLVMIR
"""

# WARNING: assumption is that the first function is the entry point to the compiled program.
entry_point_func = self.mlir_module.body.operations[0]
restype = entry_point_func.type.results
Expand Down Expand Up @@ -833,7 +840,6 @@ def run(self, args, kwargs):
Returns:
Any: results of the execution arranged into the original function's output PyTrees
"""

results = self.compiled_function(*args, **kwargs)

# TODO: Move this to the compiled function object.
Expand All @@ -853,7 +859,6 @@ def _validate_configuration(self):

def _get_workspace(self):
"""Get or create a workspace to use for compilation."""

workspace_name = self.__name__
preferred_workspace_dir = os.getcwd() if self.use_cwd_for_workspace else None

Expand Down
3 changes: 3 additions & 0 deletions frontend/catalyst/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ class CompileOptions:
- ``False`` or ``0`` or ``"none"`` (default): No intermediate files are kept.
- ``True`` or ``1`` or ``"pipeline"``: Intermediate files are saved after each pipeline.
- ``2`` or ``"pass"``: Intermediate files are saved after each pass.
use_nameloc (Optional[bool]): If ``True``, add function parameter names to the IR as name
locations.
pipelines (Optional[List[Tuple[str,List[str]]]]): A list of tuples. The first entry of the
tuple corresponds to the name of a pipeline. The second entry of the tuple corresponds
to a list of MLIR passes.
Expand Down Expand Up @@ -115,6 +117,7 @@ class CompileOptions:
logfile: Optional[TextIOWrapper] = sys.stderr
target: Optional[str] = "binary"
keep_intermediate: Optional[Union[str, int, bool, KeepIntermediateLevel]] = False
use_nameloc: Optional[bool] = False
pipelines: Optional[List[Any]] = None
autograph: Optional[bool] = False
autograph_include: Optional[Iterable[str]] = ()
Expand Down
27 changes: 26 additions & 1 deletion frontend/catalyst/tracing/type_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from typing import Callable

import jax
from jax._src.core import shaped_abstractify
from jax._src.core import DShapedArray, shaped_abstractify
from jax._src.interpreters.partial_eval import infer_lambda_input_type
from jax._src.pjit import _flat_axes_specs
from jax.core import AbstractValue
Expand Down Expand Up @@ -324,3 +324,28 @@ def promote_arguments(target_signature, args):
promoted_args.append(promoted_arg)

return tree_unflatten(treedef, promoted_args)


def get_arg_names(qjit_jaxpr_in_avals: tuple[AbstractValue, ...], qjit_original_function: Callable):
"""Construct a list of argument names, with the size of qjit_jaxpr_in_avals, and fill it with
the names of the parameters of the original function signature.
The number of parameters of the original function could be different to the number of
elements in qjit_jaxpr_in_avals. For example, if a function with one parameter is invoked with a
dynamic argument, qjit_jaxpr_in_avals will contain two elements (a dynamically-shaped array, and
its type).

Args:
qjit_jaxpr_in_avals: list of abstract values that represent the inputs to the QJIT's JAXPR
qjit_original_function: QJIT's original function

Returns:
A list of argument names with the same number of elements than qjit_jaxpr_in_avals.
The argument names are assigned from the list of parameters of the original function,
in order, and until that list is empty. Then left to empty strings.
"""
arg_names = [""] * len(qjit_jaxpr_in_avals)
param_values = [p.name for p in inspect.signature(qjit_original_function).parameters.values()]
for in_aval_index, in_aval in enumerate(qjit_jaxpr_in_avals):
if len(param_values) > 0 and type(in_aval) != DShapedArray:
arg_names[in_aval_index] = param_values.pop(0)
return arg_names
54 changes: 54 additions & 0 deletions frontend/test/lit/test_option_use_nameloc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright 2025 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for use name location option."""

# RUN: %PYTHON %s | FileCheck %s

from utils import print_mlir, print_mlir_opt

from catalyst import qjit


# CHECK-LABEL: @jit_f
@qjit(use_nameloc=True)
def f(x: float, y: float):
"""Check that MLIR module contains name location information, and MLIR code uses that name
location information.
"""
# CHECK: %x: tensor<f64>, %y: tensor<f64>
return x * y


assert str(f.mlir_module.body.operations[0].arguments[0].location) == 'loc("x")'
assert str(f.mlir_module.body.operations[0].arguments[1].location) == 'loc("y")'

print_mlir(f, 0.3, 0.4)


# CHECK-LABEL: @jit_f_opt
@qjit(use_nameloc=True)
def f_opt(x: float, y: float):
"""Check that MLIR module contains name location information, and MLIR code uses that name
location information.
Same test as before, but now we exercise mlir_opt property.
"""
# CHECK: %x: tensor<f64>, %y: tensor<f64>
return x * y


assert str(f_opt.mlir_module.body.operations[0].arguments[0].location) == 'loc("x")'
assert str(f_opt.mlir_module.body.operations[0].arguments[1].location) == 'loc("y")'

print_mlir_opt(f_opt, 0.3, 0.4)
5 changes: 5 additions & 0 deletions frontend/test/lit/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,8 @@ def print_jaxpr(f, *args, **kwargs):
def print_mlir(f, *args, **kwargs):
"""Print mlir code of a function"""
return print_attr(f, "mlir", *args, **kwargs)


def print_mlir_opt(f, *args, **kwargs):
"""Print mlir code of a function"""
return print_attr(f, "mlir_opt", *args, **kwargs)
7 changes: 7 additions & 0 deletions frontend/test/pytest/test_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,13 @@ def test_option_dialect_plugin_tuple(self):
assert ("--load-dialect-plugin", path) in flags
assert isinstance(options.dialect_plugins, set)

def test_option_use_nameloc(self):
"""Test use name location option"""

options = CompileOptions(use_nameloc=True)
flags = _options_to_cli_flags(options)
assert "--use-nameloc-as-prefix" in flags

def test_option_not_lower_to_llvm(self):
"""Test not lower to llvm"""
options = CompileOptions(lower_to_llvm=False)
Expand Down
Loading