Skip to content

Commit 26650ff

Browse files
authored
fix: predict_newdata without coordinates (#102)
* fix: predict_newdata without coordinates * ...
1 parent 9f8a522 commit 26650ff

File tree

9 files changed

+75
-0
lines changed

9 files changed

+75
-0
lines changed

DESCRIPTION

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ Suggests:
4242
rpart,
4343
stars (>= 0.5-5),
4444
testthat (>= 3.0.0)
45+
Remotes:
46+
mlr-org/mlr3
4547
VignetteBuilder:
4648
knitr
4749
Config/testthat/edition: 3

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ export(mask_stack)
4848
export(numeric_layer)
4949
export(predict_spatial)
5050
export(sample_stack)
51+
export(weights_layer)
5152
export(write_raster)
5253
import(checkmate)
5354
import(data.table)

R/data.R

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,24 @@ factor_layer = function(id, levels, in_memory = FALSE) {
4343
list(id = id, type = "factor", levels = levels, in_memory = in_memory)
4444
}
4545

46+
#' @title Weights Layer Generator
47+
#'
48+
#' @description
49+
#' Generates a weights layer when passed to [generate_stack()].
50+
#'
51+
#' @param id (`character(1)`)\cr
52+
#' Layer id.
53+
#' @param in_memory (`logical(1)`)\cr
54+
#' If `FALSE` (default), layer is written to disk.
55+
#'
56+
#' @keywords internal
57+
#' @export
58+
weights_layer = function(id, in_memory = FALSE) {
59+
assert_string(id)
60+
assert_flag(in_memory)
61+
list(id = id, type = "weights", in_memory = in_memory)
62+
}
63+
4664
#' @title Generate Raster Stack
4765
#'
4866
#' @description
@@ -96,6 +114,15 @@ generate_stack = function(layers, layer_size = NULL, dimension = NULL, multi_lay
96114
ras = rast(filename)
97115
}
98116
ras
117+
} else if (layer$type == "weights") {
118+
data = matrix(runif(dimension^2, 0, 1), nrow = dimension)
119+
ras = rast(data)
120+
if (!layer$in_memory && !multi_layer_file) {
121+
filename = tempfile(fileext = ".tif")
122+
writeRaster(ras, filename)
123+
ras = rast(filename)
124+
}
125+
ras
99126
}
100127
})
101128

R/zzz.R

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,9 @@
7171
x$default_measures$classif_st = "classif.ce"
7272
x$default_measures$regr_st = "regr.mse"
7373

74+
x$task_col_roles_optional_newdata$classif_st = c(x$task_col_roles_optional_newdata$classif, c("coordinate", "space", "time"))
75+
x$task_col_roles_optional_newdata$regr_st = c(x$task_col_roles_optional_newdata$regr, c("coordinate", "space", "time"))
76+
7477
# task
7578
x = getFromNamespace("mlr_tasks", ns = "mlr3")
7679
x$add("leipzig", load_task_leipzig)

man/TaskClassifST.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/TaskRegrST.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/weights_layer.Rd

Lines changed: 19 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/helper.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
lapply(list.files(system.file("testthat", package = "mlr3"), pattern = "^helper.*\\.[rR]", full.names = TRUE), source)

tests/testthat/test_LearnerClassifSpatial.R

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,23 @@ test_that("LearnerClassifSpatial ignores observations with missing values", {
2222
expect_true(all(is.na(pred$response[seq(100)])))
2323
expect_numeric(pred$response, any.missing = TRUE, all.missing = FALSE)
2424
})
25+
26+
test_that("LearnerClassifSpatial predicts newdata without optional column roles", {
27+
stack = generate_stack(list(
28+
numeric_layer("x_1"),
29+
weights_layer("weights"),
30+
factor_layer("y", levels = c("a", "b"))),
31+
dimension = 100)
32+
vector = sample_stack(stack, n = 100)
33+
task_train = as_task_classif_st(vector, id = "test_vector", target = "y")
34+
task_train$set_col_roles("weights", roles = "weights_learner")
35+
36+
learner = lrn("classif.rpart")
37+
learner$train(task_train)
38+
39+
stack$weights = NULL
40+
task_predict = as_task_unsupervised(stack, id = "test")
41+
pred = learner$predict_newdata(task_predict$data())
42+
43+
expect_prediction(pred)
44+
})

0 commit comments

Comments
 (0)