Skip to content

Commit 5342c6d

Browse files
committed
Working version but grad calculation extremely slow
1 parent 53e010c commit 5342c6d

File tree

4 files changed

+15
-28
lines changed

4 files changed

+15
-28
lines changed

src/mitim_tools/opt_tools/BOTORCHtools.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,8 +190,6 @@ def __init__(
190190
)
191191

192192

193-
194-
195193
# TODO: Allow subsetting of other covar modules
196194
if outcome_transform is not None:
197195
self.outcome_transform = outcome_transform

src/mitim_tools/opt_tools/STEPtools.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -163,11 +163,6 @@ def fit_step(self, avoidPoints=None, fitWithTrainingDataIfContains=None):
163163
with open(self.fileOutputs, "a") as f:
164164
f.write(f" (took total of {txt_time})")
165165

166-
embed()
167-
x = torch.rand(10_000, self.train_X.shape[-1]).to(self.dfT)
168-
with IOtools.speeder("/Users/pablorf/PROJECTS/project_2024_PORTALSdevelopment/speed/profiler_gp64.prof") as s:
169-
self.GP["combined_model"].gpmodel.posterior(x)
170-
171166
def _fit_multioutput_model(self):
172167

173168
surrogateOptions = self.surrogateOptions["selectSurrogate"]('AllMITIM', self.surrogateOptions)
@@ -359,7 +354,7 @@ def defineFunctions(self, scalarized_objective):
359354
I create this so that, upon reading a pickle, I re-call it. Otherwise, it is very heavy to store lambdas
360355
"""
361356

362-
self.evaluators = {"GP": self.GP["combined_model"]}
357+
self.evaluators = {"GP": self.GP["mo_model"]}
363358

364359
# **************************************************************************************************
365360
# Objective (multi-objective model -> single objective residual)
@@ -442,14 +437,6 @@ def residual(Y, X = None):
442437
)
443438
)
444439

445-
446-
embed()
447-
x = torch.rand(64, self.train_X.shape[-1]).to(self.dfT)
448-
with IOtools.speeder("/Users/pablorf/PROJECTS/project_2024_PORTALSdevelopment/speed/profiler_acq64.prof") as s:
449-
self.evaluators["acq_function"](x)
450-
451-
452-
453440
# **************************************************************************************************
454441
# Quick function to return components (I need this for ROOT too, since I need the components)
455442
# **************************************************************************************************

src/mitim_tools/opt_tools/SURROGATEtools.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,6 @@ def perform_model_fit(self, mll):
448448
-mll.forward(mll.model(*mll.model.train_inputs), mll.model.train_targets)
449449
.detach()
450450
]
451-
embed()
452451

453452
def callback(x, y, mll=mll):
454453
track_fval.append(y.fval)

src/mitim_tools/opt_tools/optimizers/BOTORCHoptim.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import torch
2-
import types
32
import botorch
43
import random
54
from mitim_tools.opt_tools import OPTtools
@@ -37,6 +36,7 @@ def findOptima(fun, optimization_params = {}, writeTrajectory=False):
3736
"sample_around_best": True,
3837
"disp": 50 if read_verbose_level() == 5 else False,
3938
"seed": fun.seed,
39+
"maxiter": 100,
4040
}
4141

4242
"""
@@ -64,16 +64,19 @@ def __call__(self, x, *args, **kwargs):
6464
seq_message = f'({"sequential" if sequential_q else "joint"}) ' if q>1 else ''
6565
print(f"\t\t- Optimizing using optimize_acqf: {q = } {seq_message}, {num_restarts = }, {raw_samples = }")
6666

67-
with IOtools.timer(name = "\n\t- Optimization", name_timer = '\t\t- Time: '):
68-
x_opt, _ = botorch.optim.optimize_acqf(
69-
acq_function=fun_opt,
70-
bounds=fun.bounds_mod,
71-
raw_samples=raw_samples,
72-
q=q,
73-
sequential=sequential_q,
74-
num_restarts=num_restarts,
75-
options=options,
76-
)
67+
68+
#with IOtools.timer(name = "\n\t- Optimization", name_timer = '\t\t- Time: '):
69+
#with IOtools.speeder("/Users/pablorf/PROJECTS/project_2024_PORTALSdevelopment/speed/profiler_opt.prof") as s:
70+
x_opt, _ = botorch.optim.optimize_acqf(
71+
acq_function=fun_opt,
72+
bounds=fun.bounds_mod,
73+
raw_samples=raw_samples,
74+
q=q,
75+
sequential=sequential_q,
76+
num_restarts=num_restarts,
77+
options=options,
78+
)
79+
embed()
7780

7881
acq_evaluated = torch.Tensor(acq_evaluated)
7982

0 commit comments

Comments
 (0)