Skip to content

Commit 53e010c

Browse files
committed
misc testing
1 parent 876775b commit 53e010c

File tree

4 files changed

+217
-225
lines changed

4 files changed

+217
-225
lines changed

src/mitim_modules/portals/PORTALStools.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,7 @@
99

1010
def selectSurrogate(output, surrogateOptions, CGYROrun=False):
1111

12-
print(
13-
f'\t- Selecting surrogate options for "{output}" to be run'
14-
)
12+
print(f'\t- Selecting surrogate options for "{output}" to be run')
1513

1614
if output is not None:
1715
# If it's a target, just linear

src/mitim_tools/opt_tools/BOTORCHtools.py

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,8 @@
1717
from botorch.models.transforms.input import InputTransform
1818
from botorch.models.transforms.outcome import OutcomeTransform, Standardize
1919
from botorch.models.utils import validate_input_scaling
20-
from botorch.models.utils.gpytorch_modules import (
21-
get_covar_module_with_dim_scaled_prior,
22-
get_gaussian_likelihood_with_lognormal_prior,
23-
)
2420
from botorch.utils.types import DEFAULT
25-
from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
26-
from gpytorch.means.constant_mean import ConstantMean
2721
from gpytorch.models.exact_gp import ExactGP
28-
from gpytorch.module import Module
2922
from torch import Tensor
3023

3124
from linear_operator.operators import CholLinearOperator, DiagLinearOperator
@@ -136,7 +129,6 @@ def __init__(
136129
batch_shape=self._aug_batch_shape, variables=variables
137130
)
138131

139-
140132
"""
141133
-----------------------------------------------------------------------
142134
GP Kernel - Covariance
@@ -562,23 +554,15 @@ def __init__(self, *gp_models):
562554
def prepareToGenerateCommons(self):
563555
self.models[0].input_transform.tf1.flag_to_store = True
564556
# Make sure that this ModelListGP evaluation is fresh
565-
if (
566-
"parameters_combined"
567-
in self.models[0].input_transform.tf1.surrogate_parameters
568-
):
569-
del self.models[0].input_transform.tf1.surrogate_parameters[
570-
"parameters_combined"
571-
]
557+
if ("surrogate_parameters" in self.models[0].input_transform.tf1.__dict__) and \
558+
("parameters_combined" in self.models[0].input_transform.tf1.surrogate_parameters):
559+
del self.models[0].input_transform.tf1.surrogate_parameters["parameters_combined"]
572560

573561
def cold_startCommons(self):
574562
self.models[0].input_transform.tf1.flag_to_store = False
575-
if (
576-
"parameters_combined"
577-
in self.models[0].input_transform.tf1.surrogate_parameters
578-
):
579-
del self.models[0].input_transform.tf1.surrogate_parameters[
580-
"parameters_combined"
581-
]
563+
if ("surrogate_parameters" in self.models[0].input_transform.tf1.__dict__) and \
564+
("parameters_combined" in self.models[0].input_transform.tf1.surrogate_parameters):
565+
del self.models[0].input_transform.tf1.surrogate_parameters["parameters_combined"]
582566

583567
def transform_inputs(self, X):
584568
self.prepareToGenerateCommons()

0 commit comments

Comments
 (0)