@@ -31,13 +31,21 @@ def __init__(self, config):
3131 for key , value in self .pv_mapping .items ():
3232 self .__validate_formulas (value ['formula' ])
3333 self .latest_input = {symbol : None for symbol in self .input_list }
34+ self .latest_input_struct = {symbol : None for symbol in self .input_list }
3435 self .latest_transformed = {key : 0 for key in self .pv_mapping .keys ()}
3536 self .updated = False
3637 self .handler_time = None
3738 self .formulas = {}
3839 self .lambdified_formulas = {}
40+ self .direct_formula_inputs = {}
41+ self .renamed_symbol_lookup = {
42+ symbol .replace (':' , '_' ): symbol for symbol in self .input_list
43+ }
3944 for key , value in self .pv_mapping .items ():
4045 self .formulas [key ] = sp .sympify (value ['formula' ].replace (':' , '_' ))
46+ self .direct_formula_inputs [key ] = self .__get_direct_input_symbol (
47+ self .formulas [key ]
48+ )
4149 input_list_renamed = [
4250 symbol .replace (':' , '_' ) for symbol in self .input_list
4351 ]
@@ -53,11 +61,17 @@ def __validate_formulas(self, formula: str):
5361 except Exception as e :
5462 raise Exception (f'Invalid formula: { formula } : { e } ' )
5563
64+ def __get_direct_input_symbol (self , formula_expr ):
65+ if isinstance (formula_expr , sp .Symbol ):
66+ return self .renamed_symbol_lookup .get (str (formula_expr ))
67+ return None
68+
5669 def handler (self , pv_name , value ):
5770 # logger.debug(f"SimpleTransformer handler for {pv_name} with value {value}")
5871
5972 # chek if pv_name is in sel.input_list
6073 if pv_name in self .input_list :
74+ self .latest_input_struct [pv_name ] = dict (value )
6175 # assert value is float
6276 try :
6377 if isinstance (value ['value' ], (float , int , np .float32 )):
@@ -161,6 +175,24 @@ def transform(self):
161175 raise e
162176
163177 for key , value in transformed .items ():
178+ direct_input = self .direct_formula_inputs .get (key )
179+ input_struct = (
180+ self .latest_input_struct .get (direct_input )
181+ if direct_input is not None
182+ else None
183+ )
184+ if isinstance (input_struct , dict ):
185+ passthrough_fields = {
186+ field_name : field_value
187+ for field_name , field_value in input_struct .items ()
188+ if field_name != 'value'
189+ }
190+ if passthrough_fields :
191+ self .latest_transformed [key ] = {
192+ 'value' : value ,
193+ ** passthrough_fields ,
194+ }
195+ continue
164196 self .latest_transformed [key ] = value
165197 self .updated = True
166198
@@ -239,12 +271,14 @@ def __init__(self, config):
239271 # config is a dictionary of output:intput pairs
240272 pv_mapping = config ['variables' ]
241273 self .latest_input = {}
274+ self .latest_input_struct = {}
242275 self .latest_transformed = {}
243276 self .updated = False
244277 self .input_list = list (pv_mapping .values ())
245278
246279 for key , value in pv_mapping .items ():
247280 self .latest_input [value ] = None
281+ self .latest_input_struct [value ] = None
248282 self .latest_transformed [key ] = None
249283 self .pv_mapping = pv_mapping
250284
@@ -253,6 +287,7 @@ def __init__(self, config):
253287 def handler (self , pv_name , value ):
254288 time_start = time .time ()
255289 logger .debug (f'PassThroughTransformer handler for { pv_name } ' )
290+ self .latest_input_struct [pv_name ] = dict (value )
256291 self .latest_input [pv_name ] = value ['value' ]
257292 if all ([value is not None for value in self .latest_input .values ()]):
258293 self .transform ()
@@ -263,10 +298,31 @@ def handler(self, pv_name, value):
263298 def transform (self ):
264299 logger .debug ('Transforming' )
265300 for key , value in self .pv_mapping .items ():
266- self .latest_transformed [key ] = self .latest_input [value ]
301+ input_value = self .latest_input [value ]
302+ input_struct = self .latest_input_struct .get (value )
303+ if isinstance (input_struct , dict ):
304+ passthrough_fields = {
305+ field_name : field_value
306+ for field_name , field_value in input_struct .items ()
307+ if field_name != 'value'
308+ }
309+ if passthrough_fields :
310+ self .latest_transformed [key ] = {
311+ 'value' : input_value ,
312+ ** passthrough_fields ,
313+ }
314+ else :
315+ self .latest_transformed [key ] = input_value
316+ else :
317+ self .latest_transformed [key ] = input_value
267318
268- if isinstance (self .latest_input [value ], np .ndarray ):
269- if self .latest_input [value ].shape != self .latest_transformed [key ].shape :
319+ transformed_value = (
320+ self .latest_transformed [key ]['value' ]
321+ if isinstance (self .latest_transformed [key ], dict )
322+ else self .latest_transformed [key ]
323+ )
324+ if isinstance (input_value , np .ndarray ):
325+ if input_value .shape != transformed_value .shape :
270326 logger .error (f'Shape mismatch between input and output for { key } ' )
271327 self .updated = True
272328
0 commit comments