@@ -79,9 +79,7 @@ def __init__(self, proc, kwargs, use_randomization=False):
7979 )
8080 self .env [a .name ] = kwargs [str (a .name )]
8181 else :
82- assert a .type .is_numeric (), "arg {a.name} is not a numeric type"
83- assert a .type .basetype () != T .R , "arg basetype {a.name} should not be T.R"
84- self .simple_typecheck_buffer (a , kwargs )
82+ self .typecheck_input_buffer (a , kwargs )
8583 self .env [a .name ] = kwargs [str (a .name )]
8684
8785 # evaluate preconditions
@@ -102,31 +100,54 @@ def _new_scope(self):
102100 def _del_scope (self ):
103101 self .env = self .env .parents
104102
105- # input buffers should be numpy arrays with floating-point values
106- def simple_typecheck_buffer (self , fnarg , kwargs ):
107- typ = fnarg .type
108- buf = kwargs [str (fnarg .name )]
109- nm = fnarg .name
103+ def typecheck_input_buffer (self , proc_arg , kwargs ):
104+ nm = proc_arg .name
105+ if not proc_arg .type .is_numeric ():
106+ raise TypeError (f"arg { nm } is expected to be numeric" )
107+
108+ basetype = proc_arg .type .basetype ()
109+ buf = kwargs [str (proc_arg .name )]
110110
111- # check data type
112111 pre = f"bad argument '{ nm } '"
113112 if not isinstance (buf , np .ndarray ):
114113 raise TypeError (f"{ pre } : expected numpy.ndarray" )
115- elif buf .dtype != float and buf .dtype != np .float32 and buf .dtype != np .float16 :
116- raise TypeError (
117- f"{ pre } : expected buffer of floating-point values; "
118- f"had '{ buf .dtype } ' values"
119- )
120-
121- # check shape
122- if typ .is_real_scalar ():
114+
115+ if isinstance (basetype , T .F32 ):
116+ if buf .dtype != np .float32 :
117+ raise TypeError (f"{ pre } : received { buf .dtype } values" )
118+
119+ if isinstance (basetype , T .F16 ):
120+ if buf .dtype != np .float16 :
121+ raise TypeError (f"{ pre } : received { buf .dtype } values" )
122+
123+ if isinstance (basetype , (T .F64 , T .Num )):
124+ if buf .dtype != np .float64 :
125+ raise TypeError (f"{ pre } : received { buf .dtype } values" )
126+
127+ if isinstance (basetype , T .INT8 ):
128+ if buf .dtype != np .int8 :
129+ raise TypeError (f"{ pre } : received { buf .dtype } values" )
130+
131+ if isinstance (basetype , T .INT32 ):
132+ if buf .dtype != np .int32 :
133+ raise TypeError (f"{ pre } : received { buf .dtype } values" )
134+
135+ if isinstance (basetype , T .UINT8 ):
136+ if buf .dtype != np .uint8 :
137+ raise TypeError (f"{ pre } : received { buf .dtype } values" )
138+
139+ if isinstance (basetype , T .UINT16 ):
140+ if buf .dtype != np .uint16 :
141+ raise TypeError (f"{ pre } : received { buf .dtype } values" )
142+
143+ if proc_arg .type .is_real_scalar ():
123144 if tuple (buf .shape ) != (1 ,):
124145 raise TypeError (
125146 f"{ pre } : expected buffer of shape (1,), "
126147 f"but got shape { tuple (buf .shape )} "
127148 )
128149 else :
129- shape = self .eval_shape (typ )
150+ shape = self .eval_shape (proc_arg . type )
130151 if shape != tuple (buf .shape ):
131152 raise TypeError (
132153 f"{ pre } : expected buffer of shape { shape } , "
@@ -196,7 +217,6 @@ def eval_s(self, s):
196217 # TODO: Maybe randomize?
197218 self .env [s .name ] = np .empty (size )
198219
199- # TODO (andrew) figure out a way to test this, no explicit frees that I can find
200220 elif isinstance (s , LoopIR .Free ):
201221 # use extension to chain map from python docs
202222 del self .env [s .name ]
0 commit comments