Skip to content

Commit 0331b5d

Browse files
committed
Fix bug where augmented assignment kept on taping
1 parent 2325204 commit 0331b5d

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

firedrake/adjoint/function.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,11 @@ def wrapper(self, other, *args, **kwargs):
149149
def _ad_annotate_iadd(__iadd__):
150150
@wraps(__iadd__)
151151
def wrapper(self, other, **kwargs):
152+
with stop_annotating():
153+
func = __iadd__(self, other, **kwargs)
154+
152155
ad_block_tag = kwargs.pop("ad_block_tag", None)
153156
annotate = annotate_tape(kwargs)
154-
func = __iadd__(self, other, **kwargs)
155-
156157
if annotate:
157158
block = FunctionAssignBlock(func, self + other, ad_block_tag=ad_block_tag)
158159
tape = get_working_tape()
@@ -167,10 +168,11 @@ def wrapper(self, other, **kwargs):
167168
def _ad_annotate_isub(__isub__):
168169
@wraps(__isub__)
169170
def wrapper(self, other, **kwargs):
171+
with stop_annotating():
172+
func = __isub__(self, other, **kwargs)
173+
170174
ad_block_tag = kwargs.pop("ad_block_tag", None)
171175
annotate = annotate_tape(kwargs)
172-
func = __isub__(self, other, **kwargs)
173-
174176
if annotate:
175177
block = FunctionAssignBlock(func, self - other, ad_block_tag=ad_block_tag)
176178
tape = get_working_tape()
@@ -185,10 +187,11 @@ def wrapper(self, other, **kwargs):
185187
def _ad_annotate_imul(__imul__):
186188
@wraps(__imul__)
187189
def wrapper(self, other, **kwargs):
190+
with stop_annotating():
191+
func = __imul__(self, other, **kwargs)
192+
188193
ad_block_tag = kwargs.pop("ad_block_tag", None)
189194
annotate = annotate_tape(kwargs)
190-
func = __imul__(self, other, **kwargs)
191-
192195
if annotate:
193196
block = FunctionAssignBlock(func, self*other, ad_block_tag=ad_block_tag)
194197
tape = get_working_tape()
@@ -203,10 +206,11 @@ def wrapper(self, other, **kwargs):
203206
def _ad_annotate_idiv(__idiv__):
204207
@wraps(__idiv__)
205208
def wrapper(self, other, **kwargs):
209+
with stop_annotating():
210+
func = __idiv__(self, other, **kwargs)
211+
206212
ad_block_tag = kwargs.pop("ad_block_tag", None)
207213
annotate = annotate_tape(kwargs)
208-
func = __idiv__(self, other, **kwargs)
209-
210214
if annotate:
211215
block = FunctionAssignBlock(func, self/other, ad_block_tag=ad_block_tag)
212216
tape = get_working_tape()

0 commit comments

Comments
 (0)