Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add optional rounding for expct and forge #31

Merged
merged 6 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions R/expct.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#' @param evidence_row_mode Interpretation of rows in multi-row evidence. If \code{'separate'},
#' each row in \code{evidence} is a separate conditioning event for which \code{n_synth} synthetic samples
#' are generated. If \code{'or'}, the rows are combined with a logical or; see Examples.
#' @param round Round continuous variables to their respective maximum precision in the real data set?
#' @param stepsize Step size, i.e. the number of evidence rows handled in each step.
#' Defaults to nrow(evidence)/num_registered_workers for \code{parallel == TRUE}.
#' @param parallel Compute in parallel? Must register backend beforehand, e.g.
Expand Down Expand Up @@ -83,6 +84,7 @@ expct <- function(
query = NULL,
evidence = NULL,
evidence_row_mode = c("separate", "or"),
round = FALSE,
stepsize = 0,
parallel = TRUE) {

Expand Down Expand Up @@ -215,7 +217,7 @@ expct <- function(

# Create dataset with expectations
x_synth <- cbind(synth_cnt, synth_cat)
x_synth <- post_x(x_synth, params)
x_synth <- post_x(x_synth, params, round)

x_synth
}
Expand All @@ -226,6 +228,4 @@ expct <- function(
}

return(x_synth_)
}


}
8 changes: 5 additions & 3 deletions R/forge.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#' @param evidence_row_mode Interpretation of rows in multi-row evidence. If \code{'separate'},
#' each row in \code{evidence} is a separate conditioning event for which \code{n_synth} synthetic samples
#' are generated. If \code{'or'}, the rows are combined with a logical or; see Examples.
#' @param round Round continuous variables to their respective maximum precision in the real data set?
#' @param sample_NAs Sample NAs respecting the probability for missing values in the original data.
#' @param stepsize Step size, i.e. the number of evidence rows handled in each step.
#' Defaults to nrow(evidence)/num_registered_workers for \code{parallel == TRUE}.
Expand Down Expand Up @@ -99,6 +100,7 @@ forge <- function(
n_synth,
evidence = NULL,
evidence_row_mode = c("separate", "or"),
round = TRUE,
sample_NAs = FALSE,
stepsize = 0,
parallel = TRUE) {
Expand Down Expand Up @@ -254,15 +256,15 @@ forge <- function(
}

# Clean up, export
x_synth <- post_x(x_synth, params)
x_synth <- post_x(x_synth, params, round)

if (sample_NAs) {
setDT(x_synth)
NA_share <- rbind(NA_share_cnt, NA_share_cat)
setorder(NA_share[,variable := factor(variable, levels = params$meta[,variable])], variable, idx)
NA_share[,dat := rbinom(.N, 1, prob = NA_share)]
x_synth[dcast(NA_share,formula = idx ~ variable, value.var = "dat")[,-"idx"] == 1] <- NA
x_synth <- post_x(x_synth, params)
x_synth <- post_x(x_synth, params, round)
}

if (evidence_row_mode == "separate" & any(omega[, is.na(f_idx)])) {
Expand All @@ -279,7 +281,7 @@ forge <- function(
x_synth[, idx := rep(indices_sampled, each = n_synth)]
x_synth <- rbind(x_synth, rows_na, fill = T)
setorder(x_synth, idx)[, idx := NULL]
x_synth <- post_x(x_synth, params)
x_synth <- post_x(x_synth, params, round)
}
x_synth
}
Expand Down
25 changes: 15 additions & 10 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ col_rename <- function(cn, old_name) {
#' @keywords internal

col_rename_all <- function(cn) {

if ('y' %in% cn) {
cn[which(cn == 'y')] <- col_rename(cn, 'y')
}
Expand Down Expand Up @@ -130,11 +130,12 @@ prep_x <- function(x) {
#'
#' @param x Input data.frame.
#' @param params Circuit parameters learned via \code{\link{forde}}.
#' @param round Round continuous variables to their respective maximum precision in the real data set?
#'
#' @import data.table
#' @keywords internal

post_x <- function(x, params) {
post_x <- function(x, params, round = TRUE) {

# To avoid data.table check issues
variable <- val <- NULL
Expand All @@ -150,7 +151,7 @@ post_x <- function(x, params) {
idx_integer <- meta_tmp[, which(class == 'integer')]

# Recode
if (sum(idx_numeric) > 0L) {
if (sum(idx_numeric) > 0L & round) {
x[, idx_numeric] <- lapply(idx_numeric, function(j) {
round(as.numeric(x[[j]]), meta_tmp$decimals[j])
})
Expand All @@ -171,7 +172,11 @@ post_x <- function(x, params) {
if (sum(idx_integer) > 0L) {
x[, idx_integer] <- lapply(idx_integer, function(j) {
if (is.numeric(x[[j]])) {
as.integer(round(x[[j]]))
if (round) {
as.integer(round(x[[j]]))
} else {
x[[j]]
}
} else {
as.integer(as.character(x[[j]]))
}
Expand Down Expand Up @@ -385,7 +390,7 @@ cforde <- function(params, evidence, row_mode = c("separate", "or"), stepsize =

# Calculate updates for cat params matching cat conditions
cat_new <- merge(merge(relevant_leaves, cat_conds, by = "c_idx", allow.cartesian = T), cat, by = c("f_idx","variable", "val"))

# Ensure probabilities sum to 1
cat_new[, cvg_factor := sum(prob), by = .(f_idx, c_idx, variable)]
cat_new[, prob := prob/cvg_factor]
Expand All @@ -407,7 +412,7 @@ cforde <- function(params, evidence, row_mode = c("separate", "or"), stepsize =
relevant_leaves <- updates_relevant_leaves$relevant_leaves[,`:=` (f_idx = .I, f_idx_uncond = f_idx)][]
cnt_new <- setcolorder(merge(relevant_leaves, updates_relevant_leaves$cnt_new, by.x = c("c_idx", "f_idx_uncond"), by.y = c("c_idx", "f_idx"), sort = F), c("f_idx","c_idx","variable","min","max","val","cvg_factor"))[]
cat_new <- setcolorder(merge(relevant_leaves, updates_relevant_leaves$cat_new, by.x = c("c_idx", "f_idx_uncond"), by.y = c("c_idx", "f_idx"), sort = F), c("f_idx","c_idx","variable","val","prob","cvg_factor"))[]

# Check for conditions with no matching leaves and handle this according to row_mode
if (relevant_leaves[,uniqueN(c_idx)] < nconds_conditioned) {
if (relevant_leaves[,uniqueN(c_idx)] == 0 & row_mode == "or") {
Expand All @@ -424,8 +429,8 @@ cforde <- function(params, evidence, row_mode = c("separate", "or"), stepsize =
setnames(forest_new, "cvg", "cvg_arf")

cvg_new <- unique(rbind(cat_new[, .(f_idx, c_idx, variable, cvg_factor)],
cnt_new[, .(f_idx, c_idx, variable, cvg_factor)]),
by = c("f_idx", "variable"))[,-"variable"]
cnt_new[, .(f_idx, c_idx, variable, cvg_factor)]),
by = c("f_idx", "variable"))[,-"variable"]

if (nrow(cvg_new) > 0) {
# Use log transformation to avoid overflow
Expand Down Expand Up @@ -480,7 +485,7 @@ cforde <- function(params, evidence, row_mode = c("separate", "or"), stepsize =
forest_new_unconditioned[, `:=` (c_idx = rep(conds_unconditioned,each = nrow(forest)), f_idx_uncond = f_idx, cvg_arf = cvg)]
forest_new <- rbind(forest_new, forest_new_unconditioned)
}

setorder(setcolorder(forest_new,c("f_idx","c_idx","f_idx_uncond","tree","leaf","cvg_arf","cvg")), c_idx, f_idx, f_idx_uncond, tree, leaf)

list(evidence_input = evidence, evidence_prepped = condition_long, cnt = cnt_new, cat = cat_new, forest = forest_new)
Expand Down Expand Up @@ -553,7 +558,7 @@ prep_cond <- function(evidence, params, row_mode) {

# Interval syntax, e.g. (X,Inf)
condition_long[(variable %in% cnt_cols) & str_detect(val, "\\("),
c("val", "min", "max") := cbind(c(NA_real_, transpose(strsplit(substr(val, 2, nchar(val) - 1), split = ","))))]
c("val", "min", "max") := cbind(c(NA_real_, transpose(strsplit(substr(val, 2, nchar(val) - 1), split = ","))))]

# >, < syntax
condition_long[(variable %in% cnt_cols) & str_detect(val, "<"),
Expand Down
6 changes: 3 additions & 3 deletions man/expct.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions man/forge.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion man/post_x.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 28 additions & 2 deletions tests/testthat/test-return_types.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,19 @@ test_that("FORGE returns matrix when called with matrix", {
expect_true(is.matrix(x_synth))
})

test_that("FORGE returns same column types", {
test_that("FORGE returns correct column types", {
n <- 50
dat <- data.frame(numeric = rnorm(n),
integer = sample(1L:5L, n, replace = TRUE),
integer_factor = sample(1L:5L, n, replace = TRUE),
integer_numeric = sample(1L:50L, n, replace = FALSE),
character = sample(letters[1:5], n, replace = TRUE),
factor = factor(sample(letters[1:5], n, replace = TRUE)),
logical = (sample(0:1, n, replace = TRUE) == 1))

expect_warning(arf <- adversarial_rf(dat, num_trees = 2, verbose = FALSE, parallel = FALSE))
psi <- forde(arf, dat, parallel = FALSE)

# with round = TRUE
x_synth <- forge(psi, n_synth = 20, parallel = FALSE)

# No NAs
Expand All @@ -78,8 +81,31 @@ test_that("FORGE returns same column types", {
classes <- sapply(dat, class)
classes_synth <- sapply(x_synth, class)
expect_equal(classes, classes_synth)

# with round = FALSE
x_synth <- forge(psi, n_synth = 20, round = FALSE, parallel = FALSE)

# Keep non-integer_numeric column types
classes <- sapply(dat, class)
classes_synth <- sapply(x_synth, class)
expect_equal(classes[-3], classes_synth[-3])
# Output integer_numeric as numeric
expect_true(classes_synth[3] == "numeric")
})

# forge(round = FALSE) must hand back unrounded continuous values: applying
# post_x(round = TRUE) afterwards has to change them, and rounding the raw
# values to one decimal has to reproduce the post-processed result.
test_that("FORGE does not round to real data set precision if 'round == FALSE'", {
  forest <- adversarial_rf(iris, num_trees = 2, verbose = FALSE, parallel = FALSE)
  params <- forde(forest, iris, parallel = FALSE)
  synth_raw <- forge(psi = params, n_synth = 20, round = FALSE, parallel = FALSE)
  synth_rounded <- arf:::post_x(synth_raw, params, round = TRUE)

  # Continuous columns (1 to 4) of the raw sample must not already be rounded
  raw_cnt <- synth_raw[,1:4]
  rounded_cnt <- synth_rounded[,1:4]
  expect_false(all(raw_cnt == rounded_cnt))

  # Rounding the raw values by hand to one decimal must match post_x()
  manually_rounded <- data.frame(lapply(raw_cnt, function(col) round(col, 1)))
  expect_equal(manually_rounded, rounded_cnt)
})



# test_that("MAP returns proper column types", {
# n <- 50
# dat <- data.frame(numeric = rnorm(n),
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test_expct.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ test_that("expct returns correct values", {
# For all classes
res <- expct(psi, query = "Sepal.Length", parallel = FALSE)
expect_equal(colnames(res), "Sepal.Length")
expect_equal(res$Sepal.Length, round(mean(iris$Sepal.Length), 1), tolerance = .2)
expect_equal(res$Sepal.Length, mean(iris$Sepal.Length))

# Only for setosa
res <- expct(psi, query = "Sepal.Length", evidence = data.frame(Species = "setosa"), parallel = FALSE)
expect_equal(res$Sepal.Length, round(mean(iris[iris$Species == "setosa", "Sepal.Length"]), 1))
expect_equal(res$Sepal.Length, mean(iris[iris$Species == "setosa", "Sepal.Length"]), tolerance = .1)
})

test_that("expct works for vectorized evidence", {
Expand Down
Loading