@@ -163,10 +163,17 @@ def __init__(
163163 dimTransformedDV_x , dimTransformedDV_y = self ._define_MITIM_transformations ()
164164 # ------------------------------------------------------------------------------------------------------------
165165
166- self .train_X_added_full = torch .empty ((0 , dimTransformedDV_x_full )).to (self .dfT )
167- self .train_X_added = torch .empty ((0 , dimTransformedDV_x )).to (self .dfT )
168- self .train_Y_added = torch .empty ((0 , dimTransformedDV_y )).to (self .dfT )
169- self .train_Yvar_added = torch .empty ((0 , dimTransformedDV_y )).to (self .dfT )
166+ x_transformed = input_transform_physics (self .train_X )
167+ shape = list (x_transformed .shape )
168+ shape [- 2 ] = 0
169+ shape [- 1 ] = dimTransformedDV_x_full
170+
171+ self .train_X_added_full = torch .empty (* shape ).to (self .dfT )
172+ shape [- 1 ] = dimTransformedDV_x
173+ self .train_X_added = torch .empty (* shape ).to (self .dfT )
174+ shape [- 1 ] = 1
175+ self .train_Y_added = torch .empty (* shape ).to (self .dfT )
176+ self .train_Yvar_added = torch .empty (* shape ).to (self .dfT )
170177
171178 # --------------------------------------------------------------------------------------
172179 # Make sure that very small variations are not captured
@@ -202,17 +209,7 @@ def __init__(
202209 if (self .fileTraining is not None ) and (self .train_X .shape [0 ] + self .train_X_added .shape [0 ] > 0 ):
203210 self .write_datafile (input_transform_physics , outcome_transform_physics )
204211
205- # -------------------------------------------------------------------------------------
206- # Obtain normalization constants now (although during training this is messed up, so needed later too)
207- # -------------------------------------------------------------------------------------
208212
209- self .normalization_pass (
210- input_transform_physics ,
211- input_transform_normalization ,
212- outcome_transform_physics ,
213- output_transformed_standardization ,
214- )
215-
216213 # ------------------------------------------------------------------------------------
217214 # Combine transformations in chain of PHYSICS + NORMALIZATION
218215 # ------------------------------------------------------------------------------------
@@ -222,9 +219,15 @@ def __init__(
222219 ).to (self .dfT )
223220
224221 outcome_transform = BOTORCHtools .ChainedOutcomeTransform (
225- tf1 = outcome_transform_physics , tf2 = output_transformed_standardization # , tf3=BOTORCHtools.OutcomeToBatchDimension()
222+ tf1 = outcome_transform_physics , tf2 = output_transformed_standardization , tf3 = BOTORCHtools .OutcomeToBatchDimension ()
226223 ).to (self .dfT )
227224
225+ # -------------------------------------------------------------------------------------
226+ # Obtain normalization constants now (although during training this is messed up, so needed later too)
227+ # -------------------------------------------------------------------------------------
228+
229+ self .normalization_pass (input_transform , outcome_transform )
230+
228231 self .variables = (
229232 self .surrogate_transformation_variables [self .outputs [0 ]]
230233 if (
@@ -305,7 +308,7 @@ def _define_MITIM_transformations(self):
305308 # Broadcast the input transformation to all outputs
306309 # ------------------------------------------------------------------------------------
307310
308- input_transformation_physics = input_transformations_physics [ 0 ] # BOTORCHtools.BatchBroadcastedInputTransform(input_transformations_physics)
311+ input_transformation_physics = BOTORCHtools .BatchBroadcastedInputTransform (input_transformations_physics )
309312
310313 transformed_X = input_transformation_physics (self .train_X )
311314
@@ -331,41 +334,41 @@ def _define_MITIM_transformations(self):
331334 output_transformed_standardization , \
332335 dimTransformedDV_x , dimTransformedDV_y
333336
334- def normalization_pass (
335- self ,
336- input_transform_physics ,
337- input_transform_normalization ,
338- outcome_transform_physics ,
339- outcome_transform_normalization ,
340- ):
341- input_transform_normalization .training = True
342- outcome_transform_normalization .training = True
343- outcome_transform_normalization ._is_trained = torch .tensor (False )
337+ def normalization_pass (self ,input_transform , outcome_transform ):
338+ '''
339+ The goal of this is to capture NOW the normalization and standardization constants,
340+ by accounting for both the actual data and the added data from file
341+ '''
344342
345- train_X_transformed = input_transform_physics (self .train_X )
346- train_Y_transformed , train_Yvar_transformed = outcome_transform_physics (self .train_X , self .train_Y , self .train_Yvar )
343+ # Get input normalization and outcome standardization in training mode
344+ input_transform ['tf2' ].training = True
345+ outcome_transform ['tf2' ].training = True
346+ outcome_transform ['tf2' ]._is_trained = torch .tensor (False )
347347
348- train_X_transformed = torch .cat (
349- (input_transform_physics (self .train_X ), self .train_X_added ), axis = 0
350- )
351- y , yvar = outcome_transform_physics (self .train_X , self .train_Y , self .train_Yvar )
352- train_Y_transformed = torch .cat ((y , self .train_Y_added ), axis = 0 )
353- train_Yvar_transformed = torch .cat ((yvar , self .train_Yvar_added ), axis = 0 )
348+ # Get the input normalization constants by physics-transforming the train_x and adding the data from file
349+ train_X_transformed = input_transform ['tf1' ](self .train_X )
350+ train_X_transformed = torch .cat ((train_X_transformed , self .train_X_added ), axis = - 2 )
351+ _ = input_transform ['tf2' ](train_X_transformed )
354352
355- train_X_transformed_norm = input_transform_normalization (train_X_transformed )
356- (
357- train_Y_transformed_norm ,
358- train_Yvar_transformed_norm ,
359- ) = outcome_transform_normalization (train_Y_transformed , train_Yvar_transformed )
353+ # Get the outcome standardization constants by physics-transforming the train_y and adding the data from file
354+ # With the caveat that the added points have to not be batched
355+ train_Y_transformed , train_Yvar_transformed = outcome_transform ['tf1' ](self .train_X , self .train_Y , self .train_Yvar )
356+ y , yvar = outcome_transform ['tf1' ](self .train_X , self .train_Y , self .train_Yvar )
357+
358+ train_Y_transformed = torch .cat ((y , outcome_transform ['tf3' ].untransform (self .train_Y_added )[0 ]), axis = - 2 )
359+ train_Yvar_transformed = torch .cat ((yvar , outcome_transform ['tf3' ].untransform (self .train_Yvar_added )[0 ]), axis = - 2 )
360+
361+ train_Y_transformed_norm , train_Yvar_transformed_norm = outcome_transform ['tf2' ](train_Y_transformed , train_Yvar_transformed )
360362
361363 # Make sure they are not on training mode
362- input_transform_normalization .training = False
363- outcome_transform_normalization .training = False
364- outcome_transform_normalization ._is_trained = torch .tensor (True )
364+ input_transform ['tf2' ].training = False
365+ outcome_transform ['tf2' ].training = False
366+ outcome_transform ['tf2' ]._is_trained = torch .tensor (True )
367+
365368
366369 def fit (self ):
367370 print (
368- f"\t - Fitting model to { self .train_X .shape [0 ]+ self .train_X_added .shape [0 ]} points"
371+ f"\t - Fitting model to { self .train_X .shape [- 2 ]+ self .train_X_added .shape [- 2 ]} points"
369372 )
370373
371374 # ---------------------------------------------------------------------------------------------------
@@ -398,8 +401,6 @@ def fit(self):
398401 with fundamental_model_context (self ):
399402 track_fval = self .perform_model_fit (mll )
400403
401- embed ()
402-
403404 # ---------------------------------------------------------------------------------------------------
404405 # Assess optimization
405406 # ---------------------------------------------------------------------------------------------------
@@ -409,12 +410,7 @@ def fit(self):
410411 # Go back to defining the right normalizations, because the optimizer has to work on training mode...
410411 # ---------------------------------------------------------------------------------------------------
411412
412- self .normalization_pass (
413- self .gpmodel .input_transform ["tf1" ],
414- self .gpmodel .input_transform ["tf2" ],
415- self .gpmodel .outcome_transform ["tf1" ],
416- self .gpmodel .outcome_transform ["tf2" ],
417- )
413+ self .normalization_pass (self .gpmodel .input_transform , self .gpmodel .outcome_transform )
418414
419415 def perform_model_fit (self , mll ):
420416 self .gpmodel .train ()
@@ -903,29 +899,17 @@ def __init__(self, surrogate_model):
903899
904900 def __enter__ (self ):
905901 # Works for individual models, not ModelList
906- self .surrogate_model .gpmodel .input_transform .tf1 .flag_to_evaluate = False
902+ for i in range (len (self .surrogate_model .gpmodel .input_transform .tf1 .transforms )):
903+ self .surrogate_model .gpmodel .input_transform .tf1 .transforms [i ].flag_to_evaluate = False
907904 self .surrogate_model .gpmodel .outcome_transform .tf1 .flag_to_evaluate = False
908905
909906 return self .surrogate_model
910907
911908 def __exit__ (self , * args ):
912- self .surrogate_model .gpmodel .input_transform .tf1 .flag_to_evaluate = True
909+ for i in range (len (self .surrogate_model .gpmodel .input_transform .tf1 .transforms )):
910+ self .surrogate_model .gpmodel .input_transform .tf1 .transforms [i ].flag_to_evaluate = True
913911 self .surrogate_model .gpmodel .outcome_transform .tf1 .flag_to_evaluate = True
914912
915- # def __enter__(self):
916- # # Works for individual models, not ModelList
917- # embed()
918- # for i in range(len(self.surrogate_model.gpmodel.input_transform.tf1.transforms)):
919- # self.surrogate_model.gpmodel.input_transform.tf1.transforms[i].flag_to_evaluate = False
920- # self.surrogate_model.gpmodel.outcome_transform.tf1.flag_to_evaluate = False
921-
922- # return self.surrogate_model
923-
924- # def __exit__(self, *args):
925- # for i in range(len(self.surrogate_model.gpmodel.input_transform.tf1.transforms)):
926- # self.surrogate_model.gpmodel.input_transform.tf1.transforms[i].flag_to_evaluate = True
927- # self.surrogate_model.gpmodel.outcome_transform.tf1.flag_to_evaluate = True
928-
929913def create_df_portals (x , y , yvar , x_names , output , max_x = 20 ):
930914
931915 new_data = []
0 commit comments