Skip to content
6 changes: 5 additions & 1 deletion src/kirin/analysis/const/prop.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,11 @@ def try_eval_const_pure(
_frame.set_values(stmt.args, tuple(x.data for x in values))
method = self._interp.lookup_registry(frame, stmt)
if method is not None:
value = method(self._interp, _frame, stmt)
try:
value = method(self._interp, _frame, stmt)
except NotImplementedError:
# the concrete interpreter doesn't have the implementation so we cannot evaluate it
return tuple(Unknown() for _ in stmt.results)
Comment on lines +102 to +106
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is redundant? if the method is not None that indicates there is an implementation for the concrete interpretation.

else:
return tuple(Unknown() for _ in stmt.results)
match value:
Expand Down
8 changes: 6 additions & 2 deletions src/kirin/dialects/ilist/constprop.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def one_args(
# 1. if the function is a constant method, and the method is pure, then the map is pure
if isinstance(fn, const.Value) and isinstance(method := fn.data, ir.Method):
self.detect_purity(interp_, frame, stmt, method.code, (fn, const.Unknown()))
if isinstance(collection, const.Value):
if isinstance(collection, const.Value) and stmt in frame.should_be_pure:
return interp_.try_eval_const_pure(frame, stmt, (fn, collection))
elif isinstance(fn, const.PartialLambda):
self.detect_purity(interp_, frame, stmt, fn.code, (fn, const.Unknown()))
Expand All @@ -57,7 +57,11 @@ def two_args(self, interp_: const.Propagate, frame: const.Frame, stmt: Foldl):
method.code,
(fn, const.Unknown(), const.Unknown()),
)
if isinstance(collection, const.Value) and isinstance(init, const.Value):
if (
isinstance(collection, const.Value)
and isinstance(init, const.Value)
and stmt in frame.should_be_pure
):
return interp_.try_eval_const_pure(frame, stmt, (fn, collection, init))
elif isinstance(fn, const.PartialLambda):
self.detect_purity(
Expand Down
52 changes: 52 additions & 0 deletions test/analysis/dataflow/constprop/test_missing_impl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from kirin import ir, types, passes, lowering
from kirin.decl import info, statement
from kirin.prelude import basic_no_opt
from kirin.analysis import const
from kirin.dialects import ilist

new_dialect = ir.Dialect("test")


@statement(dialect=new_dialect)
class DefaultInit(ir.Statement):
name = "test"

traits = frozenset({lowering.FromPythonCall(), ir.Pure()})

result: ir.ResultValue = info.result(types.Int)


dialect_group = basic_no_opt.add(new_dialect)


def test_missing_impl_try_eval_const_pure():
# this test is trying to trigger the code path in propagate.py
# where a statement has no concrete implementation but is pure
# in this case, the ilist will attempt to evaluate the closure
# which contains a call to DefaultInit, which has no implementation
# in the concrete interpreter. In this case we should still be able
# to mark the result as Unknown, rather than failing the analysis.
# In other words, if a statement has no implementation, but is pure,
# the function `try_eval_const_pure` will catch the exception and
# return Unknown for the result.
@dialect_group
def test():
n = 10

def _inner(val: int) -> int:
return DefaultInit() * val # type: ignore

return ilist.map(_inner, ilist.range(n))

passes.HintConst(dialect_group)(test)

for i in range(5):
stmt = test.callable_region.blocks[0].stmts.at(i)
assert all(
isinstance(result.hints.get("const"), const.Value)
for result in stmt.results
)

call_stmt = test.callable_region.blocks[0].stmts.at(5)
assert isinstance(call_stmt, ilist.Map)
assert isinstance(call_stmt.result.hints.get("const"), const.Unknown)
27 changes: 27 additions & 0 deletions test/dialects/test_ilist.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import Any, Literal

from kirin import ir, types, rewrite
from kirin.decl import info, statement
from kirin.passes import aggressive
from kirin.prelude import basic_no_opt, python_basic
from kirin.analysis import const
from kirin.dialects import py, func, ilist, lowering
from kirin.lowering import FromPythonCall
from kirin.passes.typeinfer import TypeInfer


Expand Down Expand Up @@ -386,6 +388,31 @@ def main2():
assert target.data == (6, 6)


def test_ilist_constprop_non_pure():

new_dialect = ir.Dialect("test")

@statement(dialect=new_dialect)
class DefaultInit(ir.Statement):
name = "test"
traits = frozenset({FromPythonCall()})
result: ir.ResultValue = info.result(types.Float)

dialect_group = basic_no_opt.add(new_dialect)

@dialect_group
def test():

def inner(_: int):
return DefaultInit()

return ilist.map(inner, ilist.range(10))

_, res = const.Propagate(dialect_group).run(test)

assert isinstance(res, const.Unknown)


rule = rewrite.Fixpoint(rewrite.Walk(ilist.rewrite.Unroll()))
xs = ilist.IList([1, 2, 3])

Expand Down