@@ -35,64 +35,16 @@ def run_interpreter(proc, kwargs):
3535 Interpreter (proc , kwargs )
3636
3737
38- # context is global
39- ctxt = defaultdict (dict )
40-
4138class Interpreter :
4239 def __init__ (self , proc , kwargs , use_randomization = False ):
43- assert isinstance (proc , LoopIR .proc )
44-
45- proc = ParallelAnalysis ().run (proc )
46- proc = PrecisionAnalysis ().run (proc ) # TODO: need this?
47- proc = WindowAnalysis ().apply_proc (proc )
48- proc = MemoryAnalysis ().run (proc ) # TODO: need this?
40+ if not isinstance (proc , LoopIR .proc ):
41+ raise TypeError (f"Expected { proc .name } to be of type proc" )
4942
50- self .proc = proc
5143 self .env = ChainMap ()
5244 self .use_randomization = use_randomization
45+ self .ctxt = defaultdict (dict )
5346
54- # type check args
55- for a in proc .args :
56- if not str (a .name ) in kwargs :
57- raise TypeError (f"expected argument '{ a .name } ' to be supplied" )
58-
59- if a .type is T .size :
60- if not is_pos_int (kwargs [str (a .name )]):
61- raise TypeError (
62- f"expected size '{ a .name } ' to have positive integer value"
63- )
64- self .env [a .name ] = kwargs [str (a .name )]
65- elif a .type is T .index :
66- if type (kwargs [str (a .name )]) is not int :
67- raise TypeError (
68- f"expected index variable '{ a .name } ' to be an integer"
69- )
70- self .env [a .name ] = kwargs [str (a .name )]
71- elif a .type is T .bool :
72- if type (kwargs [str (a .name )]) is not bool :
73- raise TypeError (f"expected bool variable '{ a .name } ' to be a bool" )
74- self .env [a .name ] = kwargs [str (a .name )]
75- elif a .type is T .stride :
76- if type (kwargs [str (a .name )]) is not int :
77- raise TypeError (
78- f"expected stride variable '{ a .name } ' to be an integer"
79- )
80- self .env [a .name ] = kwargs [str (a .name )]
81- else :
82- self .typecheck_input_buffer (a , kwargs )
83- self .env [a .name ] = kwargs [str (a .name )]
84-
85- # evaluate preconditions
86- for pred in proc .preds :
87- if isinstance (pred , LoopIR .Const ):
88- continue
89- else :
90- assert self .eval_e (pred ), "precondition not satisfied"
91-
92- # eval statements
93- self .env = self .env .new_child ()
94- self .eval_stmts (proc .body )
95- self .env = self .env .parents
47+ self .eval_proc (proc , kwargs )
9648
9749 def _new_scope (self ):
9850 self .env = self .env .new_child ()
@@ -154,14 +106,60 @@ def typecheck_input_buffer(self, proc_arg, kwargs):
154106 f"but got shape { tuple (buf .shape )} "
155107 )
156108
109+ def eval_proc (self , proc , kwargs ):
110+ proc = ParallelAnalysis ().run (proc )
111+ proc = PrecisionAnalysis ().run (proc ) # TODO: need this?
112+ proc = WindowAnalysis ().apply_proc (proc )
113+ proc = MemoryAnalysis ().run (proc ) # TODO: need this?
114+
115+ for a in proc .args :
116+ if not str (a .name ) in kwargs :
117+ raise TypeError (f"expected argument '{ a .name } ' to be supplied" )
118+
119+ if a .type is T .size :
120+ if not is_pos_int (kwargs [str (a .name )]):
121+ raise TypeError (
122+ f"expected size '{ a .name } ' to have positive integer value"
123+ )
124+ self .env [a .name ] = kwargs [str (a .name )]
125+ elif a .type is T .index :
126+ if type (kwargs [str (a .name )]) is not int :
127+ raise TypeError (
128+ f"expected index variable '{ a .name } ' to be an integer"
129+ )
130+ self .env [a .name ] = kwargs [str (a .name )]
131+ elif a .type is T .bool :
132+ if type (kwargs [str (a .name )]) is not bool :
133+ raise TypeError (f"expected bool variable '{ a .name } ' to be a bool" )
134+ self .env [a .name ] = kwargs [str (a .name )]
135+ elif a .type is T .stride :
136+ if type (kwargs [str (a .name )]) is not int :
137+ raise TypeError (
138+ f"expected stride variable '{ a .name } ' to be an integer"
139+ )
140+ self .env [a .name ] = kwargs [str (a .name )]
141+ else :
142+ self .typecheck_input_buffer (a , kwargs )
143+ self .env [a .name ] = kwargs [str (a .name )]
144+
145+ # evaluate preconditions
146+ for pred in proc .preds :
147+ if isinstance (pred , LoopIR .Const ):
148+ continue
149+ else :
150+ assert self .eval_e (pred ), "precondition not satisfied"
151+
152+ # eval statements
153+ self .eval_stmts (proc .body )
154+
157155 def eval_stmts (self , stmts ):
158156 for s in stmts :
159157 self .eval_s (s )
160158
161159 def eval_s (self , s ):
162160 if isinstance (s , LoopIR .Pass ):
163161 pass
164-
162+
165163 elif isinstance (s , (LoopIR .Assign , LoopIR .Reduce )):
166164 lbuf = self .env [s .name ]
167165 if len (s .idx ) == 0 :
@@ -179,12 +177,14 @@ def eval_s(self, s):
179177 elif isinstance (s , LoopIR .WriteConfig ):
180178 nm = s .config .name ()
181179 rhs = self .eval_e (s .rhs )
182- ctxt [nm ][s .field ] = rhs
180+ self . ctxt [nm ][s .field ] = rhs
183181
184182 elif isinstance (s , LoopIR .WindowStmt ):
185183 # nm = rbuf[...]
186184 assert s .name not in self .env , "WindowStmt should be a fresh assignment"
187- assert isinstance (s .rhs , LoopIR .WindowExpr ), "WindowStmt rhs should be WindowExpr"
185+ assert isinstance (
186+ s .rhs , LoopIR .WindowExpr
187+ ), "WindowStmt rhs should be WindowExpr"
188188 self .env [s .name ] = self .eval_e (s .rhs )
189189
190190 elif isinstance (s , LoopIR .If ):
@@ -225,7 +225,9 @@ def eval_s(self, s):
225225 argvals = [self .eval_e (a , call_arg = True ) for a in s .args ]
226226 argnames = [str (a .name ) for a in s .f .args ]
227227 kwargs = {nm : val for nm , val in zip (argnames , argvals )}
228- Interpreter (s .f , kwargs , use_randomization = self .use_randomization )
228+ self ._new_scope ()
229+ self .eval_proc (s .f , kwargs )
230+ self ._del_scope ()
229231
230232 else :
231233 assert False , "bad statement case"
@@ -253,10 +255,14 @@ def stringify_w_access(a):
253255 assert False , "bad w_access case"
254256
255257 # hack to handle interval indexes: LoopIR.Interval returns a string representing the interval
256- idx = ("0" ,) if len (e .idx ) == 0 else tuple (stringify_w_access (a ) for a in e .idx )
258+ idx = (
259+ ("0" ,)
260+ if len (e .idx ) == 0
261+ else tuple (stringify_w_access (a ) for a in e .idx )
262+ )
257263 res = eval (f"buf[{ ',' .join (idx )} ]" )
258264 return res
259-
265+
260266 elif isinstance (e , LoopIR .Const ):
261267 return e .val
262268
@@ -268,9 +274,12 @@ def stringify_w_access(a):
268274 return lhs - rhs
269275 elif e .op == "*" :
270276 return lhs * rhs
271- elif e .op == "/" : # is this right?
272- if isinstance (lhs , int ):
273- return (lhs + rhs - 1 ) // rhs
277+ elif e .op == "/" :
278+ if isinstance (lhs , int ) and isinstance (rhs , int ):
279+ # this is what was here before and without the rhs check
280+ # counter example of why this is wrong -3 / 2 == -1 in C and 0 in this impl
281+ # return (lhs + rhs - 1) // rhs
282+ return int (lhs / rhs )
274283 else :
275284 return lhs / rhs
276285 elif e .op == "%" :
@@ -293,9 +302,12 @@ def stringify_w_access(a):
293302 elif isinstance (e , LoopIR .USub ):
294303 return - self .eval_e (e .arg )
295304
305+ # BuiltIns don't go to the interpreter, they are just called (via call) like a proc
306+ # TODO Discuss to make sure
296307 elif isinstance (e , LoopIR .BuiltIn ):
297- args = [self .eval_e (a ) for a in e .args ]
298- return e .f .interpret (args )
308+ assert False , "Not implemented"
309+ # args = [self.eval_e(a) for a in e.args]
310+ # return e.f.interpret(args)
299311
300312 elif isinstance (e , LoopIR .StrideExpr ):
301313 buf = self .env [e .name ]
@@ -305,7 +317,7 @@ def stringify_w_access(a):
305317
306318 elif isinstance (e , LoopIR .ReadConfig ):
307319 nm = e .config .name ()
308- return ctxt [nm ][e .field ]
320+ return self . ctxt [nm ][e .field ]
309321
310322 else :
311323 print (e )
0 commit comments