diff --git a/R/forde.R b/R/forde.R index e063b11..d53973e 100644 --- a/R/forde.R +++ b/R/forde.R @@ -152,8 +152,8 @@ forde <- function( classes <- sapply(x, class) x <- suppressWarnings(prep_x(x)) factor_cols <- sapply(x, is.factor) - lvls <- arf$forest$covariate.levels[factor_cols] - if (!is.null(lvls)) { + if (any(factor_cols)) { + lvls <- lapply(x[, factor_cols, drop = FALSE], levels) names(lvls) <- colnames_x[factor_cols] lvl_df <- rbindlist(lapply(seq_along(lvls), function(j) { melt(as.data.table(lvls[j]), measure.vars = names(lvls)[j], diff --git a/tests/testthat/test-return_types.R b/tests/testthat/test-return_types.R index 31a95d6..a60b34e 100644 --- a/tests/testthat/test-return_types.R +++ b/tests/testthat/test-return_types.R @@ -80,6 +80,20 @@ test_that("FORGE returns same column types", { expect_equal(classes, classes_synth) }) +test_that("FORGE returns factors with same levels (and order of levels)", { + arf <- adversarial_rf(iris, num_trees = 2, verbose = FALSE, parallel = FALSE) + psi <- forde(arf, iris, parallel = FALSE) + x_synth <- forge(psi, n_synth = 10, parallel = FALSE) + expect_equal(levels(x_synth$Species), levels(iris$Species)) +}) + +test_that("EXPCT returns factors with same levels (and order of levels)", { + arf <- adversarial_rf(iris, num_trees = 2, verbose = FALSE, parallel = FALSE) + psi <- forde(arf, iris, parallel = FALSE) + x_synth <- expct(psi, parallel = FALSE) + expect_equal(levels(x_synth$Species), levels(iris$Species)) +}) + # test_that("MAP returns proper column types", { # n <- 50 # dat <- data.frame(numeric = rnorm(n),