diff --git a/R/adversarial_rf.R b/R/adversarial_rf.R index 79bae3ce..15e2595e 100644 --- a/R/adversarial_rf.R +++ b/R/adversarial_rf.R @@ -128,9 +128,12 @@ adversarial_rf <- function( nodeIDs <- stats::predict(rf0, x_real, type = 'terminalNodes')$predictions tmp <- data.table('tree' = rep(seq_len(num_trees), each = n), 'leaf' = as.vector(nodeIDs)) - x_real_dt <- rbindlist(lapply(seq_len(num_trees), function(b) { + x_real_dt <- do.call(rbind, lapply(seq_len(num_trees), function(b) { cbind(x_real, tmp[tree == b]) })) + x_real_dt[, factor_cols] <- setDF( + lapply(x_real_dt[, factor_cols, drop = FALSE], as.numeric) + ) tmp <- tmp[sample(.N, n, replace = TRUE)] tmp <- unique(tmp[, cnt := .N, by = .(tree, leaf)]) draw_from <- merge(tmp, x_real_dt, by = c('tree', 'leaf'),