diff --git a/devito/ir/clusters/algorithms.py b/devito/ir/clusters/algorithms.py index 34e78752e1..f2bbce6b3f 100644 --- a/devito/ir/clusters/algorithms.py +++ b/devito/ir/clusters/algorithms.py @@ -472,7 +472,7 @@ def callback(self, clusters, prefix, seen=None): expr = Eq(self.B, HaloTouch(*points, halo_scheme=hs)) key0 = lambda i: i in prefix[:-1] or i in hs.loc_indices # noqa: B023 - key1 = lambda i: i not in hs.distributed_defined # noqa: B023 + key1 = lambda i: not i._defines & set(hs.distributed_defined) # noqa: B023 key = lambda i: key0(i) and key1(i) # noqa: B023 ispace = c.ispace.project(key) diff --git a/tests/test_subdomains.py b/tests/test_subdomains.py index 2fafd4e02b..c3e9e68a7a 100644 --- a/tests/test_subdomains.py +++ b/tests/test_subdomains.py @@ -6,7 +6,7 @@ from conftest import assert_structure, opts_tiling from devito import ( - Border, ConditionalDimension, Constant, Eq, Function, Grid, Lt, Operator, + Border, Buffer, ConditionalDimension, Constant, Eq, Function, Grid, Lt, Operator, SparseFunction, SparseTimeFunction, SubDomain, SubDomainSet, TensorFunction, TimeFunction, VectorFunction, solve ) @@ -228,6 +228,40 @@ def define(self, dimensions): sregistry=SymbolRegistry())[0] assert str(expr.rhs) == 'ix*f[ix + 1, iy + 1] + iy' + @pytest.mark.parallel(mode=2) + def test_halo_subdomain(self, mode): + """ + Test halo lowering with temporary dimensions and shifted subdomain access. + """ + space_order = 8 + + class Interface(SubDomain): + name = 'interface' + + def define(self, dimensions): + x, y, z = dimensions + return { + x: ('middle', 0, 0), + y: ('middle', 0, 0), + z: ('middle', space_order//2, 0) + } + + grid = Grid(shape=(9, 9, 9), subdomains=(Interface(),)) + x, y, z = grid.dimensions + time = grid.stepping_dim + u = TimeFunction(name='u', grid=grid, time_order=1, + space_order=space_order, save=Buffer(1), + is_transient=False) + + equation = Eq(u[time + 1, x, y, z - space_order//2], + u.dxdy + u.dydz + + u[time + 1, x, y, z - space_order//2], + subdomain=grid.subdomains['interface']) + + op = Operator([equation]) + + assert_structure(op, ['t', 'txyz', 'txyz'], 'txyzyz') + class TestMultiSubDomain: