Skip to content

Commit 5c9a778

Browse files
committed
Handle non-annotated operations
1 parent b62615b commit 5c9a778

File tree

1 file changed

+23
-16
lines changed

1 file changed

+23
-16
lines changed

pyadjoint/adjfloat.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from functools import cached_property, lru_cache, wraps
22
from itertools import count
3+
import numbers
34
import operator
45
from .block import Block
56
from .overloaded_type import OverloadedType, register_overloaded_type
@@ -115,24 +116,28 @@ def annotate_operator(sp_operator):
115116
def wrapper(fn):
116117
@wraps(fn)
117118
def annotated_operator(*args):
118-
for arg in args:
119-
if isinstance(arg, OverloadedType):
120-
cls = type(arg)
121-
break
122-
else:
123-
raise TypeError("OverloadedType required")
124119
output = fn(*args)
125-
if output is NotImplemented:
120+
if not isinstance(output, numbers.Complex):
121+
# Not annotated
126122
return output
127-
args = tuple(arg if isinstance(arg, OverloadedType) else cls(arg) for arg in args)
128-
output = cls(output)
123+
output = AdjFloat(output) # Error here if not Real
124+
129125
if annotate_tape():
130126
args = list(args)
131-
bv = set()
127+
adjfloat_bv = set()
132128
for i, arg in enumerate(args):
133-
if arg.block_variable in bv:
134-
args[i] = +arg # copy
135-
bv.add(arg.block_variable)
129+
if isinstance(arg, AdjFloat):
130+
if arg.block_variable in adjfloat_bv:
131+
args[i] = +arg # copy
132+
adjfloat_bv.add(arg.block_variable)
133+
elif isinstance(arg, OverloadedType):
134+
pass
135+
elif isinstance(arg, numbers.Complex):
136+
args[i] = AdjFloat(arg) # Error here if not real
137+
else:
138+
# Not annotated
139+
return output
140+
136141
block = AdjFloatExprBlock(sp_operator, *args)
137142
tape = get_working_tape()
138143
tape.add_block(block)
@@ -173,11 +178,13 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
173178
return NotImplemented
174179
if len(kwargs) > 0:
175180
return NotImplemented
176-
if ufunc not in _ops:
177-
return NotImplemented
178181
if len(inputs) == 0:
179182
return NotImplemented
180-
return _ops[ufunc](*inputs)
183+
if ufunc in _ops:
184+
return _ops[ufunc](*inputs)
185+
else:
186+
# Not annotated
187+
return ufunc(*(float(arg) if isinstance(arg, AdjFloat) else arg for arg in inputs))
181188

182189
@annotate_operator(Operator(sp.Abs, 1))
183190
def __abs__(self):

0 commit comments

Comments
 (0)