|
1 | 1 | from functools import cached_property, lru_cache, wraps |
2 | 2 | from itertools import count |
| 3 | +import numbers |
3 | 4 | import operator |
4 | 5 | from .block import Block |
5 | 6 | from .overloaded_type import OverloadedType, register_overloaded_type |
@@ -115,24 +116,28 @@ def annotate_operator(sp_operator): |
115 | 116 | def wrapper(fn): |
116 | 117 | @wraps(fn) |
117 | 118 | def annotated_operator(*args): |
118 | | - for arg in args: |
119 | | - if isinstance(arg, OverloadedType): |
120 | | - cls = type(arg) |
121 | | - break |
122 | | - else: |
123 | | - raise TypeError("OverloadedType required") |
124 | 119 | output = fn(*args) |
125 | | - if output is NotImplemented: |
| 120 | + if not isinstance(output, numbers.Complex): |
| 121 | + # Not annotated |
126 | 122 | return output |
127 | | - args = tuple(arg if isinstance(arg, OverloadedType) else cls(arg) for arg in args) |
128 | | - output = cls(output) |
| 123 | + output = AdjFloat(output) # Error here if not Real |
| 124 | + |
129 | 125 | if annotate_tape(): |
130 | 126 | args = list(args) |
131 | | - bv = set() |
| 127 | + adjfloat_bv = set() |
132 | 128 | for i, arg in enumerate(args): |
133 | | - if arg.block_variable in bv: |
134 | | - args[i] = +arg # copy |
135 | | - bv.add(arg.block_variable) |
| 129 | + if isinstance(arg, AdjFloat): |
| 130 | + if arg.block_variable in adjfloat_bv: |
| 131 | + args[i] = +arg # copy |
| 132 | + adjfloat_bv.add(arg.block_variable) |
| 133 | + elif isinstance(arg, OverloadedType): |
| 134 | + pass |
| 135 | + elif isinstance(arg, numbers.Complex): |
| 136 | + args[i] = AdjFloat(arg) # Error here if not real |
| 137 | + else: |
| 138 | + # Not annotated |
| 139 | + return output |
| 140 | + |
136 | 141 | block = AdjFloatExprBlock(sp_operator, *args) |
137 | 142 | tape = get_working_tape() |
138 | 143 | tape.add_block(block) |
@@ -173,11 +178,13 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): |
173 | 178 | return NotImplemented |
174 | 179 | if len(kwargs) > 0: |
175 | 180 | return NotImplemented |
176 | | - if ufunc not in _ops: |
177 | | - return NotImplemented |
178 | 181 | if len(inputs) == 0: |
179 | 182 | return NotImplemented |
180 | | - return _ops[ufunc](*inputs) |
| 183 | + if ufunc in _ops: |
| 184 | + return _ops[ufunc](*inputs) |
| 185 | + else: |
| 186 | + # Not annotated |
| 187 | + return ufunc(*(float(arg) if isinstance(arg, AdjFloat) else arg for arg in inputs)) |
181 | 188 |
|
182 | 189 | @annotate_operator(Operator(sp.Abs, 1)) |
183 | 190 | def __abs__(self): |
|
0 commit comments