Skip to content

Commit bbc6593

Browse files
committed
good progress
1 parent e8ee02e commit bbc6593

File tree

4 files changed

+88
-144
lines changed

4 files changed

+88
-144
lines changed

src/mitim_tools/opt_tools/BOTORCHtools.py

Lines changed: 24 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,14 @@
1515
# ----------------------------------------------------------------------------------------------------------------------------
1616

1717
from botorch.models.transforms.input import InputTransform
18-
from botorch.models.transforms.outcome import OutcomeTransform, Standardize
18+
from botorch.models.transforms.outcome import OutcomeTransform
1919
from botorch.models.utils import validate_input_scaling
20-
from botorch.utils.types import DEFAULT
21-
from gpytorch.models.exact_gp import ExactGP
2220
from torch import Tensor
23-
2421
from linear_operator.operators import CholLinearOperator, DiagLinearOperator
25-
2622
from typing import Iterable
2723
from torch.nn import ModuleDict
2824
from botorch.posteriors.gpytorch import GPyTorchPosterior
2925
from botorch.posteriors.posterior import Posterior
30-
from gpytorch.distributions import MultitaskMultivariateNormal
3126
from linear_operator.operators import BlockDiagLinearOperator
3227

3328

@@ -59,16 +54,14 @@ def __init__(
5954
f"\t\t\t- FixedNoise: {FixedNoise} (extra noise: {learn_additional_noise}), TypeMean: {TypeMean}, TypeKernel: {TypeKernel}, ConstrainNoise: {ConstrainNoise:.1e}"
6055
)
6156

62-
self.store_training(
63-
train_X,
64-
train_X_added,
65-
train_Y,
66-
train_Y_added,
67-
train_Yvar,
68-
train_Yvar_added,
69-
input_transform,
70-
outcome_transform,
71-
)
57+
# ** Store training data
58+
59+
# x, y are raw untransformed, and I want raw transformed. xa, ya are raw transformed
60+
#x_tr = input_transform["tf1"](train_X)if input_transform is not None else train_X
61+
#y_tr, yv_tr = outcome_transform["tf1"](train_X, train_Y, train_Yvar) if outcome_transform is not None else train_Y, train_Yvar
62+
#self.train_X_usedToTrain = torch.cat((train_X_added, x_tr), axis=-2)
63+
#self.train_Y_usedToTrain = torch.cat((train_Y_added, y_tr), axis=-2)
64+
#self.train_Yvar_usedToTrain = torch.cat((train_Yvar_added, yv_tr), axis=-2)
7265

7366
# Grab num_outputs
7467
self._num_outputs = train_Y.shape[-1]
@@ -91,40 +84,29 @@ def __init__(
9184
# Added points are raw transformed, so I need to normalize them
9285
if train_X_added.shape[0] > 0:
9386
train_X_added = input_transform["tf2"](train_X_added)
94-
train_Y_added, train_Yvar_added = outcome_transform["tf2"](
95-
train_Y_added, train_Yvar_added
96-
)
97-
# -----
98-
99-
train_X_usedToTrain = torch.cat((transformed_X, train_X_added), axis=0)
100-
train_Y_usedToTrain = torch.cat((train_Y, train_Y_added), axis=0)
101-
train_Yvar_usedToTrain = torch.cat((train_Yvar, train_Yvar_added), axis=0)
102-
103-
self._input_batch_shape, self._aug_batch_shape = self.get_batch_dimensions(
104-
train_X=train_X_usedToTrain, train_Y=train_Y_usedToTrain
105-
)
106-
107-
train_Y_usedToTrain = train_Y_usedToTrain.squeeze(-1)
108-
train_Yvar_usedToTrain = train_Yvar_usedToTrain.squeeze(-1)
109-
110-
#self._aug_batch_shape = train_Y.shape[:-2] #<----- New
87+
train_Y_added = outcome_transform["tf3"].untransform(train_Y_added)[0]
88+
train_Yvar_added = outcome_transform["tf3"].untransform(train_Yvar_added)[0]
89+
train_Y_added, train_Yvar_added = outcome_transform["tf3"](*outcome_transform["tf2"](train_Y_added, train_Yvar_added))
11190

91+
# -----
11292

93+
train_X_usedToTrain = torch.cat((transformed_X, train_X_added), axis=-2)
94+
train_Y_usedToTrain = torch.cat((train_Y, train_Y_added), axis=-2)
95+
train_Yvar_usedToTrain = torch.cat((train_Yvar, train_Yvar_added), axis=-2)
11396

11497
# Validate again after applying the transforms
115-
self._validate_tensor_args(X=transformed_X, Y=train_Y, Yvar=train_Yvar)
98+
self._validate_tensor_args(X=train_X_usedToTrain, Y=train_Y_usedToTrain, Yvar=train_Yvar_usedToTrain)
11699
ignore_X_dims = getattr(self, "_ignore_X_dims_scaling_check", None)
117100
validate_input_scaling(
118-
train_X=transformed_X,
119-
train_Y=train_Y,
120-
train_Yvar=train_Yvar,
101+
train_X=train_X_usedToTrain,
102+
train_Y=train_Y_usedToTrain,
103+
train_Yvar=train_Yvar_usedToTrain,
121104
ignore_X_dims=ignore_X_dims,
122105
)
123-
self._set_dimensions(train_X=train_X, train_Y=train_Y)
106+
self._set_dimensions(train_X=train_X_usedToTrain, train_Y=train_Y_usedToTrain)
124107

125-
126-
train_X, train_Y, train_Yvar = self._transform_tensor_args(
127-
X=train_X, Y=train_Y, Yvar=train_Yvar
108+
train_X_usedToTrain, train_Y_usedToTrain, train_Yvar_usedToTrain = self._transform_tensor_args(
109+
X=train_X_usedToTrain, Y=train_Y_usedToTrain, Yvar=train_Yvar_usedToTrain
128110
)
129111

130112
"""
@@ -271,26 +253,6 @@ def __init__(
271253
self.input_transform = input_transform
272254
self.to(train_X)
273255

274-
def store_training(self, x, xa, y, ya, yv, yva, input_transform, outcome_transform):
275-
276-
# x, y are raw untransformed, and I want raw transformed
277-
if input_transform is not None:
278-
x_tr = input_transform["tf1"](x)
279-
else:
280-
x_tr = x
281-
if outcome_transform is not None:
282-
y_tr, yv_tr = outcome_transform["tf1"](x, y, yv)
283-
else:
284-
y_tr, yv_tr = y, yv
285-
286-
# xa, ya are raw transformed
287-
xa_tr = xa
288-
ya_tr, yva_tr = ya, yva
289-
290-
self.train_X_usedToTrain = torch.cat((xa_tr, x_tr), axis=0)
291-
self.train_Y_usedToTrain = torch.cat((ya_tr, y_tr), axis=0)
292-
self.train_Yvar_usedToTrain = torch.cat((yva_tr, yv_tr), axis=0)
293-
294256
# Modify posterior call from BatchedMultiOutputGPyTorchModel to call posterior untransform with "X"
295257
def posterior(
296258
self,
@@ -559,6 +521,7 @@ def untransform_posterior(self, X, posterior):
559521
if i == len(self.values())-1
560522
else tf.untransform_posterior(posterior)
561523
) # Only physics transformation (tf1) takes X
524+
562525

563526
return posterior
564527

src/mitim_tools/opt_tools/STEPtools.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,9 @@ def fit_step(self, avoidPoints=None, fit_output_contains=None):
101101

102102
time1 = datetime.datetime.now()
103103

104-
#self._fit_multioutput_model()
105-
self._fit_individual_models(fit_output_contains=fit_output_contains)
106-
104+
self._fit_multioutput_model(); self.GP["combined_model"] = self.GP["mo_model"]
105+
#self._fit_individual_models(fit_output_contains=fit_output_contains)
106+
107107
txt_time = IOtools.getTimeDifference(time1)
108108
print(f"--> Fitting of all models took {txt_time}")
109109
if self.fileOutputs is not None:

src/mitim_tools/opt_tools/SURROGATEtools.py

Lines changed: 51 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -163,10 +163,17 @@ def __init__(
163163
dimTransformedDV_x, dimTransformedDV_y = self._define_MITIM_transformations()
164164
# ------------------------------------------------------------------------------------------------------------
165165

166-
self.train_X_added_full = torch.empty((0, dimTransformedDV_x_full)).to(self.dfT)
167-
self.train_X_added = torch.empty((0, dimTransformedDV_x)).to(self.dfT)
168-
self.train_Y_added = torch.empty((0, dimTransformedDV_y)).to(self.dfT)
169-
self.train_Yvar_added = torch.empty((0, dimTransformedDV_y)).to(self.dfT)
166+
x_transformed = input_transform_physics(self.train_X)
167+
shape = list(x_transformed.shape)
168+
shape[-2] = 0
169+
shape[-1] = dimTransformedDV_x_full
170+
171+
self.train_X_added_full = torch.empty(*shape).to(self.dfT)
172+
shape[-1] = dimTransformedDV_x
173+
self.train_X_added = torch.empty(*shape).to(self.dfT)
174+
shape[-1] = 1
175+
self.train_Y_added = torch.empty(*shape).to(self.dfT)
176+
self.train_Yvar_added = torch.empty(*shape).to(self.dfT)
170177

171178
# --------------------------------------------------------------------------------------
172179
# Make sure that very small variations are not captured
@@ -202,17 +209,7 @@ def __init__(
202209
if (self.fileTraining is not None) and (self.train_X.shape[0] + self.train_X_added.shape[0] > 0):
203210
self.write_datafile(input_transform_physics, outcome_transform_physics)
204211

205-
# -------------------------------------------------------------------------------------
206-
# Obtain normalization constants now (although during training this is messed up, so needed later too)
207-
# -------------------------------------------------------------------------------------
208212

209-
self.normalization_pass(
210-
input_transform_physics,
211-
input_transform_normalization,
212-
outcome_transform_physics,
213-
output_transformed_standardization,
214-
)
215-
216213
# ------------------------------------------------------------------------------------
217214
# Combine transformations in chain of PHYSICS + NORMALIZATION
218215
# ------------------------------------------------------------------------------------
@@ -222,9 +219,15 @@ def __init__(
222219
).to(self.dfT)
223220

224221
outcome_transform = BOTORCHtools.ChainedOutcomeTransform(
225-
tf1=outcome_transform_physics, tf2=output_transformed_standardization #, tf3=BOTORCHtools.OutcomeToBatchDimension()
222+
tf1=outcome_transform_physics, tf2=output_transformed_standardization, tf3=BOTORCHtools.OutcomeToBatchDimension()
226223
).to(self.dfT)
227224

225+
# -------------------------------------------------------------------------------------
226+
# Obtain normalization constants now (although during training this is messed up, so needed later too)
227+
# -------------------------------------------------------------------------------------
228+
229+
self.normalization_pass(input_transform, outcome_transform)
230+
228231
self.variables = (
229232
self.surrogate_transformation_variables[self.outputs[0]]
230233
if (
@@ -305,7 +308,7 @@ def _define_MITIM_transformations(self):
305308
# Broadcast the input transformation to all outputs
306309
# ------------------------------------------------------------------------------------
307310

308-
input_transformation_physics = input_transformations_physics[0] #BOTORCHtools.BatchBroadcastedInputTransform(input_transformations_physics)
311+
input_transformation_physics = BOTORCHtools.BatchBroadcastedInputTransform(input_transformations_physics)
309312

310313
transformed_X = input_transformation_physics(self.train_X)
311314

@@ -331,41 +334,41 @@ def _define_MITIM_transformations(self):
331334
output_transformed_standardization, \
332335
dimTransformedDV_x, dimTransformedDV_y
333336

334-
def normalization_pass(
335-
self,
336-
input_transform_physics,
337-
input_transform_normalization,
338-
outcome_transform_physics,
339-
outcome_transform_normalization,
340-
):
341-
input_transform_normalization.training = True
342-
outcome_transform_normalization.training = True
343-
outcome_transform_normalization._is_trained = torch.tensor(False)
337+
def normalization_pass(self,input_transform, outcome_transform):
338+
'''
339+
The goal of this is to capture NOW the normalization and standardization constants,
340+
by account for both the actual data and the added data from file
341+
'''
344342

345-
train_X_transformed = input_transform_physics(self.train_X)
346-
train_Y_transformed, train_Yvar_transformed = outcome_transform_physics(self.train_X, self.train_Y, self.train_Yvar)
343+
# Get input normalization and outcome standardization in training mode
344+
input_transform['tf2'].training = True
345+
outcome_transform['tf2'].training = True
346+
outcome_transform['tf2']._is_trained = torch.tensor(False)
347347

348-
train_X_transformed = torch.cat(
349-
(input_transform_physics(self.train_X), self.train_X_added), axis=0
350-
)
351-
y, yvar = outcome_transform_physics(self.train_X, self.train_Y, self.train_Yvar)
352-
train_Y_transformed = torch.cat((y, self.train_Y_added), axis=0)
353-
train_Yvar_transformed = torch.cat((yvar, self.train_Yvar_added), axis=0)
348+
# Get the input normalization constants by physics-transforming the train_x and adding the data from file
349+
train_X_transformed = input_transform['tf1'](self.train_X)
350+
train_X_transformed = torch.cat((train_X_transformed, self.train_X_added), axis=-2)
351+
_ = input_transform['tf2'](train_X_transformed)
354352

355-
train_X_transformed_norm = input_transform_normalization(train_X_transformed)
356-
(
357-
train_Y_transformed_norm,
358-
train_Yvar_transformed_norm,
359-
) = outcome_transform_normalization(train_Y_transformed, train_Yvar_transformed)
353+
# Get the outcome standardization constants by physics-transforming the train_y and adding the data from file
354+
# With the caveat that the added points have to not be batched
355+
train_Y_transformed, train_Yvar_transformed = outcome_transform['tf1'](self.train_X, self.train_Y, self.train_Yvar)
356+
y, yvar = outcome_transform['tf1'](self.train_X, self.train_Y, self.train_Yvar)
357+
358+
train_Y_transformed = torch.cat((y, outcome_transform['tf3'].untransform(self.train_Y_added)[0]), axis=-2)
359+
train_Yvar_transformed = torch.cat((yvar, outcome_transform['tf3'].untransform(self.train_Yvar_added)[0]), axis=0)
360+
361+
train_Y_transformed_norm, train_Yvar_transformed_norm = outcome_transform['tf2'](train_Y_transformed, train_Yvar_transformed)
360362

361363
# Make sure they are not on training mode
362-
input_transform_normalization.training = False
363-
outcome_transform_normalization.training = False
364-
outcome_transform_normalization._is_trained = torch.tensor(True)
364+
input_transform['tf2'].training = False
365+
outcome_transform['tf2'].training = False
366+
outcome_transform['tf2']._is_trained = torch.tensor(True)
367+
365368

366369
def fit(self):
367370
print(
368-
f"\t- Fitting model to {self.train_X.shape[0]+self.train_X_added.shape[0]} points"
371+
f"\t- Fitting model to {self.train_X.shape[-2]+self.train_X_added.shape[-2]} points"
369372
)
370373

371374
# ---------------------------------------------------------------------------------------------------
@@ -398,8 +401,6 @@ def fit(self):
398401
with fundamental_model_context(self):
399402
track_fval = self.perform_model_fit(mll)
400403

401-
embed()
402-
403404
# ---------------------------------------------------------------------------------------------------
404405
# Asses optimization
405406
# ---------------------------------------------------------------------------------------------------
@@ -409,12 +410,7 @@ def fit(self):
409410
# Go back to definining the right normalizations, because the optimizer has to work on training mode...
410411
# ---------------------------------------------------------------------------------------------------
411412

412-
self.normalization_pass(
413-
self.gpmodel.input_transform["tf1"],
414-
self.gpmodel.input_transform["tf2"],
415-
self.gpmodel.outcome_transform["tf1"],
416-
self.gpmodel.outcome_transform["tf2"],
417-
)
413+
self.normalization_pass(self.gpmodel.input_transform, self.gpmodel.outcome_transform)
418414

419415
def perform_model_fit(self, mll):
420416
self.gpmodel.train()
@@ -903,29 +899,17 @@ def __init__(self, surrogate_model):
903899

904900
def __enter__(self):
905901
# Works for individual models, not ModelList
906-
self.surrogate_model.gpmodel.input_transform.tf1.flag_to_evaluate = False
902+
for i in range(len(self.surrogate_model.gpmodel.input_transform.tf1.transforms)):
903+
self.surrogate_model.gpmodel.input_transform.tf1.transforms[i].flag_to_evaluate = False
907904
self.surrogate_model.gpmodel.outcome_transform.tf1.flag_to_evaluate = False
908905

909906
return self.surrogate_model
910907

911908
def __exit__(self, *args):
912-
self.surrogate_model.gpmodel.input_transform.tf1.flag_to_evaluate = True
909+
for i in range(len(self.surrogate_model.gpmodel.input_transform.tf1.transforms)):
910+
self.surrogate_model.gpmodel.input_transform.tf1.transforms[i].flag_to_evaluate = True
913911
self.surrogate_model.gpmodel.outcome_transform.tf1.flag_to_evaluate = True
914912

915-
# def __enter__(self):
916-
# # Works for individual models, not ModelList
917-
# embed()
918-
# for i in range(len(self.surrogate_model.gpmodel.input_transform.tf1.transforms)):
919-
# self.surrogate_model.gpmodel.input_transform.tf1.transforms[i].flag_to_evaluate = False
920-
# self.surrogate_model.gpmodel.outcome_transform.tf1.flag_to_evaluate = False
921-
922-
# return self.surrogate_model
923-
924-
# def __exit__(self, *args):
925-
# for i in range(len(self.surrogate_model.gpmodel.input_transform.tf1.transforms)):
926-
# self.surrogate_model.gpmodel.input_transform.tf1.transforms[i].flag_to_evaluate = True
927-
# self.surrogate_model.gpmodel.outcome_transform.tf1.flag_to_evaluate = True
928-
929913
def create_df_portals(x, y, yvar, x_names, output, max_x = 20):
930914

931915
new_data = []

src/mitim_tools/opt_tools/optimizers/BOTORCHoptim.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ def findOptima(fun, optimization_params = {}, writeTrajectory=False):
3636
"sample_around_best": True,
3737
"disp": 50 if read_verbose_level() == 5 else False,
3838
"seed": fun.seed,
39-
"maxiter": 100,
4039
}
4140

4241
"""
@@ -64,18 +63,16 @@ def __call__(self, x, *args, **kwargs):
6463
seq_message = f'({"sequential" if sequential_q else "joint"}) ' if q>1 else ''
6564
print(f"\t\t- Optimizing using optimize_acqf: {q = } {seq_message}, {num_restarts = }, {raw_samples = }")
6665

67-
68-
#with IOtools.timer(name = "\n\t- Optimization", name_timer = '\t\t- Time: '):
69-
#with IOtools.speeder("/Users/pablorf/PROJECTS/project_2024_PORTALSdevelopment/speed/profiler_opt.prof") as s:
70-
x_opt, _ = botorch.optim.optimize_acqf(
71-
acq_function=fun_opt,
72-
bounds=fun.bounds_mod,
73-
raw_samples=raw_samples,
74-
q=q,
75-
sequential=sequential_q,
76-
num_restarts=num_restarts,
77-
options=options,
78-
)
66+
with IOtools.timer(name = "\n\t- Optimization", name_timer = '\t\t- Time: '):
67+
x_opt, _ = botorch.optim.optimize_acqf(
68+
acq_function=fun_opt,
69+
bounds=fun.bounds_mod,
70+
raw_samples=raw_samples,
71+
q=q,
72+
sequential=sequential_q,
73+
num_restarts=num_restarts,
74+
options=options,
75+
)
7976
embed()
8077

8178
acq_evaluated = torch.Tensor(acq_evaluated)

0 commit comments

Comments
 (0)