Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 23 additions & 16 deletions numba_cuda/numba/cuda/core/inline_closurecall.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,34 +331,40 @@ def check(arg, name):
self.calltypes = calltypes

def inline_ir(
self, caller_ir, block, i, callee_ir, callee_freevars, arg_typs=None
self, caller_ir, block, i, callee_ir, callee_freevars,
arg_typs=None, preserve_ir=True,
):
"""Inlines the callee_ir in the caller_ir at statement index i of block
`block`, callee_freevars are the free variables for the callee_ir. If
the callee_ir is derived from a function `func` then this is
`func.__code__.co_freevars`. If `arg_typs` is given and the InlineWorker
instance was initialized with a typemap and calltypes then they will be
appropriately updated based on the arg_typs.
appropriately updated based on the arg_typs. If `preserve_ir` is
True, the callee_ir object will be copied before mutating, otherwise it
will be mutated in place.
"""

# Always copy the callee IR, it gets mutated
def copy_ir(the_ir):
kernel_copy = the_ir.copy()
kernel_copy.blocks = {}
for block_label, block in the_ir.blocks.items():
new_block = copy.deepcopy(the_ir.blocks[block_label])
kernel_copy.blocks[block_label] = new_block
return kernel_copy

callee_ir = copy_ir(callee_ir)
# Save a reference to the incoming callee_ir
callee_ir_original = callee_ir

# When preserve_ir is True, create a copy of the FunctionIR object
# to mutate. Set preserve_ir to False if callee_ir does not persist
# between calls to inline_ir.
if preserve_ir:
def copy_ir(the_ir):
kernel_copy = the_ir.copy()
kernel_copy.blocks = {}
for block_label, block in the_ir.blocks.items():
new_block = copy.deepcopy(the_ir.blocks[block_label])
kernel_copy.blocks[block_label] = new_block
return kernel_copy

callee_ir = copy_ir(callee_ir)

# check that the contents of the callee IR is something that can be
# inlined if a validator is present
if self.validator is not None:
self.validator(callee_ir)

# save an unmutated copy of the callee_ir to return
callee_ir_original = copy_ir(callee_ir)
scope = block.scope
instr = block.body[i]
call_expr = instr.value
Expand Down Expand Up @@ -468,7 +474,8 @@ def inline_function(self, caller_ir, block, i, function, arg_typs=None):
callee_ir = self.run_untyped_passes(function)
freevars = function.__code__.co_freevars
return self.inline_ir(
caller_ir, block, i, callee_ir, freevars, arg_typs=arg_typs
caller_ir, block, i, callee_ir, freevars,
arg_typs=arg_typs, preserve_ir=False,
)

def run_untyped_passes(self, func, enable_ssa=False):
Expand Down