Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions R/evalTargetFun.R
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ evalTargetFun.OptState = function(opt.state, xs, extras) {
if (!isYValid(y2))
stopf("Y-Imputation failed. Must return a numeric of length: %i, but we got: %s",
ny, convertToShortString(y2))
if (hasAttributes(y2, "extras")) {
user.extras = attr(y2, "extras")
y2 = setAttribute(y2, "extras", NULL)
}
}
}

Expand Down
81 changes: 81 additions & 0 deletions tests/testthat/test_impute_y.R
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,84 @@ test_that("impute y parego", {
expect_data_frame(op, nrow = 10 + 5)
expect_true(all(is.na(op$error.message) | op$y1 == 100))
})

test_that("imputation works with extras", {

# silent mode
options(parallelMap.suppress.local.errors = TRUE)

par.set = makeNumericParamSet(len = 2L, lower = 0, upper = 3)

f1 = smoof::makeSingleObjectiveFunction(
fn = function(x) {
y = sum(x^2)
if (y < 5)
return(NA)
if (y > 15)
stop("simulated error")
attr(y, "extras") = list(.dotextra = list("pass"), x1 = x[[1]], x2 = x[[2]])
return(y)
},
par.set = par.set
)

learner = makeLearner("regr.randomForest", predict.type = "se", se.method = "sd")

n.focus.points = 100L
design1 = data.frame(x1 = rep(seq(2, 3, length.out = 5), 2), x2 = rep(c(2, 3), each = 5))
design2 = data.frame(x1 = rep(seq(1, 2, length.out = 5), 2), x2 = rep(c(1, 2), each = 5))

ctrl = makeMBOControl()
ctrl = setMBOControlTermination(ctrl, iters = 20L)
ctrl = setMBOControlInfill(ctrl, opt.focussearch.points = n.focus.points)
expect_error(mbo(f1, design1, learner, ctrl), "simulated error")
expect_error(mbo(f1, design2, learner, ctrl), "must be a numeric of length 1")

ctrl = makeMBOControl(impute.y.fun = function(x, y, opt.path) 0)
ctrl = setMBOControlTermination(ctrl, iters = 20L)
ctrl = setMBOControlInfill(ctrl, opt.focussearch.points = n.focus.points)

# valid points added first, then imputation is missing extras
expect_error(mbo(f1, design1, learner, ctrl), "Trying to add extras but missing.*x1.*x2")

# invalid points evaluated first, which are imputed w/o extras. Valid points then have unexpected extras
expect_error(mbo(f1, design2, learner, ctrl), "Trying to add unknown extra.*x1.*x2")


ctrl = makeMBOControl(impute.y.fun = function(x, y, opt.path) {
structure(20, extras = list(.dotextra = list("fail"), x1 = -x$x[[1]], x2 = -x$x[[2]]))
})

ctrl = setMBOControlTermination(ctrl, iters = 20L)
ctrl = setMBOControlInfill(ctrl, opt.focussearch.points = n.focus.points)

res = mbo(f1, design1, learner, ctrl)
expect_is(res, "MBOResult")

# Check for correct error messages
na.inds = which(getOptPathY(res$opt.path) == 20)
opex = getOptPathX(res$opt.path)
expected.error.above = which(rowSums(opex^2) > 15)
expected.error.below = which(rowSums(opex^2) < 5)
expect_set_equal(na.inds, c(expected.error.above, expected.error.below))
for (ind in 1:getOptPathLength(res$opt.path)) {
opel = getOptPathEl(res$opt.path, ind)
if (ind %in% na.inds) {
expect_equal(unname(opel$y), 20)
expect_equal(opel$extra[c(".dotextra", "x1", "x2")], list(.dotextra = list("fail"), x1 = -opex[ind, 1], x2 = -opex[ind, 2]))
} else {
expect_equal(unname(opel$y), unname(sum(opex[ind, ]^2)))
expect_equal(opel$extra[c(".dotextra", "x1", "x2")], list(.dotextra = list("pass"), x1 = opex[ind, 1], x2 = opex[ind, 2]))
}
if (ind %in% expected.error.above) {
expect_string(getOptPathErrorMessages(res$opt.path)[ind], fixed = "simulated error")
} else if (ind %in% expected.error.below) {
expect_string(getOptPathErrorMessages(res$opt.path)[ind], fixed = "mlrMBO:")
} else {
expect_equal(NA_character_, getOptPathErrorMessages(res$opt.path)[ind])
}
}

# turn off silent mode
options(parallelMap.suppress.local.errors = FALSE)
})
7 changes: 5 additions & 2 deletions tests/testthat/test_smoof_wrappers.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@ context("smoof wrappers")
test_that("wrapped smoof function work", {
f = makeSphereFunction(2L)
fc = addCountingWrapper(f)
fl = addLoggingWrapper(f)
fcl = addCountingWrapper(addLoggingWrapper(f))
fl = addLoggingWrapper(f, logg.x = TRUE)
environment(fl)$logg.x = FALSE # fix for https://github.com/jakobbossek/smoof/issues/143
fcl = addLoggingWrapper(f, logg.x = TRUE)
environment(fcl)$logg.x = FALSE # ditto
fcl = addCountingWrapper(fcl)

learner = makeLearner("regr.rpart")

Expand Down