Skip to content

Commit 27ae79e

Browse files
committed
typecheck input buffer
1 parent c59781f commit 27ae79e

File tree

1 file changed

+39
-19
lines changed

1 file changed

+39
-19
lines changed

src/exo/LoopIR_interpreter.py

Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,7 @@ def __init__(self, proc, kwargs, use_randomization=False):
7979
)
8080
self.env[a.name] = kwargs[str(a.name)]
8181
else:
82-
assert a.type.is_numeric(), "arg {a.name} is not a numeric type"
83-
assert a.type.basetype() != T.R, "arg basetype {a.name} should not be T.R"
84-
self.simple_typecheck_buffer(a, kwargs)
82+
self.typecheck_input_buffer(a, kwargs)
8583
self.env[a.name] = kwargs[str(a.name)]
8684

8785
# evaluate preconditions
@@ -102,31 +100,54 @@ def _new_scope(self):
102100
def _del_scope(self):
103101
self.env = self.env.parents
104102

105-
# input buffers should be numpy arrays with floating-point values
106-
def simple_typecheck_buffer(self, fnarg, kwargs):
107-
typ = fnarg.type
108-
buf = kwargs[str(fnarg.name)]
109-
nm = fnarg.name
103+
def typecheck_input_buffer(self, proc_arg, kwargs):
104+
nm = proc_arg.name
105+
if not proc_arg.type.is_numeric():
106+
raise TypeError(f"arg {nm} is expected to be numeric")
107+
108+
basetype = proc_arg.type.basetype()
109+
buf = kwargs[str(proc_arg.name)]
110110

111-
# check data type
112111
pre = f"bad argument '{nm}'"
113112
if not isinstance(buf, np.ndarray):
114113
raise TypeError(f"{pre}: expected numpy.ndarray")
115-
elif buf.dtype != float and buf.dtype != np.float32 and buf.dtype != np.float16:
116-
raise TypeError(
117-
f"{pre}: expected buffer of floating-point values; "
118-
f"had '{buf.dtype}' values"
119-
)
120-
121-
# check shape
122-
if typ.is_real_scalar():
114+
115+
if isinstance(basetype, T.F32):
116+
if buf.dtype != np.float32:
117+
raise TypeError(f"{pre}: received {buf.dtype} values")
118+
119+
if isinstance(basetype, T.F16):
120+
if buf.dtype != np.float16:
121+
raise TypeError(f"{pre}: received {buf.dtype} values")
122+
123+
if isinstance(basetype, (T.F64, T.Num)):
124+
if buf.dtype != np.float64:
125+
raise TypeError(f"{pre}: received {buf.dtype} values")
126+
127+
if isinstance(basetype, T.INT8):
128+
if buf.dtype != np.int8:
129+
raise TypeError(f"{pre}: received {buf.dtype} values")
130+
131+
if isinstance(basetype, T.INT32):
132+
if buf.dtype != np.int32:
133+
raise TypeError(f"{pre}: received {buf.dtype} values")
134+
135+
if isinstance(basetype, T.UINT8):
136+
if buf.dtype != np.uint8:
137+
raise TypeError(f"{pre}: received {buf.dtype} values")
138+
139+
if isinstance(basetype, T.UINT16):
140+
if buf.dtype != np.uint16:
141+
raise TypeError(f"{pre}: received {buf.dtype} values")
142+
143+
if proc_arg.type.is_real_scalar():
123144
if tuple(buf.shape) != (1,):
124145
raise TypeError(
125146
f"{pre}: expected buffer of shape (1,), "
126147
f"but got shape {tuple(buf.shape)}"
127148
)
128149
else:
129-
shape = self.eval_shape(typ)
150+
shape = self.eval_shape(proc_arg.type)
130151
if shape != tuple(buf.shape):
131152
raise TypeError(
132153
f"{pre}: expected buffer of shape {shape}, "
@@ -196,7 +217,6 @@ def eval_s(self, s):
196217
# TODO: Maybe randomize?
197218
self.env[s.name] = np.empty(size)
198219

199-
# TODO (andrew) figure out a way to test this, no explicit frees that I can find
200220
elif isinstance(s, LoopIR.Free):
201221
# use extension to chain map from python docs
202222
del self.env[s.name]

0 commit comments

Comments
 (0)