Skip to content

Commit

Permalink
Merge pull request #31 from bips-hb/optional_rounding
Browse files Browse the repository at this point in the history
add optional rounding for expct and forge
  • Loading branch information
mnwright authored Jun 11, 2024
2 parents c88612c + f60f464 commit 1faedb1
Show file tree
Hide file tree
Showing 8 changed files with 63 additions and 25 deletions.
8 changes: 4 additions & 4 deletions R/expct.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#' @param evidence_row_mode Interpretation of rows in multi-row evidence. If \code{'separate'},
#' each row in \code{evidence} is a separate conditioning event for which \code{n_synth} synthetic samples
#' are generated. If \code{'or'}, the rows are combined with a logical or; see Examples.
#' @param round Round continuous variables to their respective maximum precision in the real data set?
#' @param stepsize Stepsize defining number of evidence rows handled in one for each step.
#' Defaults to nrow(evidence)/num_registered_workers for \code{parallel == TRUE}.
#' @param parallel Compute in parallel? Must register backend beforehand, e.g.
Expand Down Expand Up @@ -83,6 +84,7 @@ expct <- function(
query = NULL,
evidence = NULL,
evidence_row_mode = c("separate", "or"),
round = FALSE,
stepsize = 0,
parallel = TRUE) {

Expand Down Expand Up @@ -215,7 +217,7 @@ expct <- function(

# Create dataset with expectations
x_synth <- cbind(synth_cnt, synth_cat)
x_synth <- post_x(x_synth, params)
x_synth <- post_x(x_synth, params, round)

x_synth
}
Expand All @@ -226,6 +228,4 @@ expct <- function(
}

return(x_synth_)
}


}
8 changes: 5 additions & 3 deletions R/forge.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#' @param evidence_row_mode Interpretation of rows in multi-row evidence. If \code{'separate'},
#' each row in \code{evidence} is a separate conditioning event for which \code{n_synth} synthetic samples
#' are generated. If \code{'or'}, the rows are combined with a logical or; see Examples.
#' @param round Round continuous variables to their respective maximum precision in the real data set?
#' @param sample_NAs Sample NAs respecting the probability for missing values in the original data.
#' @param stepsize Stepsize defining number of evidence rows handled in one for each step.
#' Defaults to nrow(evidence)/num_registered_workers for \code{parallel == TRUE}.
Expand Down Expand Up @@ -99,6 +100,7 @@ forge <- function(
n_synth,
evidence = NULL,
evidence_row_mode = c("separate", "or"),
round = TRUE,
sample_NAs = FALSE,
stepsize = 0,
parallel = TRUE) {
Expand Down Expand Up @@ -254,15 +256,15 @@ forge <- function(
}

# Clean up, export
x_synth <- post_x(x_synth, params)
x_synth <- post_x(x_synth, params, round)

if (sample_NAs) {
setDT(x_synth)
NA_share <- rbind(NA_share_cnt, NA_share_cat)
setorder(NA_share[,variable := factor(variable, levels = params$meta[,variable])], variable, idx)
NA_share[,dat := rbinom(.N, 1, prob = NA_share)]
x_synth[dcast(NA_share,formula = idx ~ variable, value.var = "dat")[,-"idx"] == 1] <- NA
x_synth <- post_x(x_synth, params)
x_synth <- post_x(x_synth, params, round)
}

if (evidence_row_mode == "separate" & any(omega[, is.na(f_idx)])) {
Expand All @@ -279,7 +281,7 @@ forge <- function(
x_synth[, idx := rep(indices_sampled, each = n_synth)]
x_synth <- rbind(x_synth, rows_na, fill = T)
setorder(x_synth, idx)[, idx := NULL]
x_synth <- post_x(x_synth, params)
x_synth <- post_x(x_synth, params, round)
}
x_synth
}
Expand Down
25 changes: 15 additions & 10 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ col_rename <- function(cn, old_name) {
#' @keywords internal

col_rename_all <- function(cn) {

if ('y' %in% cn) {
cn[which(cn == 'y')] <- col_rename(cn, 'y')
}
Expand Down Expand Up @@ -130,11 +130,12 @@ prep_x <- function(x) {
#'
#' @param x Input data.frame.
#' @param params Circuit parameters learned via \code{\link{forde}}.
#' @param round Round continuous variables to their respective maximum precision in the real data set?
#'
#' @import data.table
#' @keywords internal

post_x <- function(x, params) {
post_x <- function(x, params, round = TRUE) {

# To avoid data.table check issues
variable <- val <- NULL
Expand All @@ -150,7 +151,7 @@ post_x <- function(x, params) {
idx_integer <- meta_tmp[, which(class == 'integer')]

# Recode
if (sum(idx_numeric) > 0L) {
if (sum(idx_numeric) > 0L & round) {
x[, idx_numeric] <- lapply(idx_numeric, function(j) {
round(as.numeric(x[[j]]), meta_tmp$decimals[j])
})
Expand All @@ -171,7 +172,11 @@ post_x <- function(x, params) {
if (sum(idx_integer) > 0L) {
x[, idx_integer] <- lapply(idx_integer, function(j) {
if (is.numeric(x[[j]])) {
as.integer(round(x[[j]]))
if (round) {
as.integer(round(x[[j]]))
} else {
x[[j]]
}
} else {
as.integer(as.character(x[[j]]))
}
Expand Down Expand Up @@ -385,7 +390,7 @@ cforde <- function(params, evidence, row_mode = c("separate", "or"), stepsize =

# Calculate updates for cat params matching cat conditions
cat_new <- merge(merge(relevant_leaves, cat_conds, by = "c_idx", allow.cartesian = T), cat, by = c("f_idx","variable", "val"))

# Ensure probabilities sum to 1
cat_new[, cvg_factor := sum(prob), by = .(f_idx, c_idx, variable)]
cat_new[, prob := prob/cvg_factor]
Expand All @@ -407,7 +412,7 @@ cforde <- function(params, evidence, row_mode = c("separate", "or"), stepsize =
relevant_leaves <- updates_relevant_leaves$relevant_leaves[,`:=` (f_idx = .I, f_idx_uncond = f_idx)][]
cnt_new <- setcolorder(merge(relevant_leaves, updates_relevant_leaves$cnt_new, by.x = c("c_idx", "f_idx_uncond"), by.y = c("c_idx", "f_idx"), sort = F), c("f_idx","c_idx","variable","min","max","val","cvg_factor"))[]
cat_new <- setcolorder(merge(relevant_leaves, updates_relevant_leaves$cat_new, by.x = c("c_idx", "f_idx_uncond"), by.y = c("c_idx", "f_idx"), sort = F), c("f_idx","c_idx","variable","val","prob","cvg_factor"))[]

# Check for conditions with no matching leaves and handle this according to row_mode
if (relevant_leaves[,uniqueN(c_idx)] < nconds_conditioned) {
if (relevant_leaves[,uniqueN(c_idx)] == 0 & row_mode == "or") {
Expand All @@ -424,8 +429,8 @@ cforde <- function(params, evidence, row_mode = c("separate", "or"), stepsize =
setnames(forest_new, "cvg", "cvg_arf")

cvg_new <- unique(rbind(cat_new[, .(f_idx, c_idx, variable, cvg_factor)],
cnt_new[, .(f_idx, c_idx, variable, cvg_factor)]),
by = c("f_idx", "variable"))[,-"variable"]
cnt_new[, .(f_idx, c_idx, variable, cvg_factor)]),
by = c("f_idx", "variable"))[,-"variable"]

if (nrow(cvg_new) > 0) {
# Use log transformation to avoid overflow
Expand Down Expand Up @@ -480,7 +485,7 @@ cforde <- function(params, evidence, row_mode = c("separate", "or"), stepsize =
forest_new_unconditioned[, `:=` (c_idx = rep(conds_unconditioned,each = nrow(forest)), f_idx_uncond = f_idx, cvg_arf = cvg)]
forest_new <- rbind(forest_new, forest_new_unconditioned)
}

setorder(setcolorder(forest_new,c("f_idx","c_idx","f_idx_uncond","tree","leaf","cvg_arf","cvg")), c_idx, f_idx, f_idx_uncond, tree, leaf)

list(evidence_input = evidence, evidence_prepped = condition_long, cnt = cnt_new, cat = cat_new, forest = forest_new)
Expand Down Expand Up @@ -553,7 +558,7 @@ prep_cond <- function(evidence, params, row_mode) {

# Interval syntax, e.g. (X,Inf)
condition_long[(variable %in% cnt_cols) & str_detect(val, "\\("),
c("val", "min", "max") := cbind(c(NA_real_, transpose(strsplit(substr(val, 2, nchar(val) - 1), split = ","))))]
c("val", "min", "max") := cbind(c(NA_real_, transpose(strsplit(substr(val, 2, nchar(val) - 1), split = ","))))]

# >, < syntax
condition_long[(variable %in% cnt_cols) & str_detect(val, "<"),
Expand Down
6 changes: 3 additions & 3 deletions man/expct.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions man/forge.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion man/post_x.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 28 additions & 2 deletions tests/testthat/test-return_types.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,19 @@ test_that("FORGE returns matrix when called with matrix", {
expect_true(is.matrix(x_synth))
})

test_that("FORGE returns same column types", {
test_that("FORGE returns correct column types", {
n <- 50
dat <- data.frame(numeric = rnorm(n),
integer = sample(1L:5L, n, replace = TRUE),
integer_factor = sample(1L:5L, n, replace = TRUE),
integer_numeric = sample(1L:50L, n, replace = FALSE),
character = sample(letters[1:5], n, replace = TRUE),
factor = factor(sample(letters[1:5], n, replace = TRUE)),
logical = (sample(0:1, n, replace = TRUE) == 1))

expect_warning(arf <- adversarial_rf(dat, num_trees = 2, verbose = FALSE, parallel = FALSE))
psi <- forde(arf, dat, parallel = FALSE)

# with round = TRUE
x_synth <- forge(psi, n_synth = 20, parallel = FALSE)

# No NAs
Expand All @@ -78,8 +81,31 @@ test_that("FORGE returns same column types", {
classes <- sapply(dat, class)
classes_synth <- sapply(x_synth, class)
expect_equal(classes, classes_synth)

# with round = FALSE
x_synth <- forge(psi, n_synth = 20, round = FALSE, parallel = FALSE)

# Keep non-integer_numeric column types
classes <- sapply(dat, class)
classes_synth <- sapply(x_synth, class)
expect_equal(classes[-3], classes_synth[-3])
# Output integer_numeric as numeric
expect_true(classes_synth[3] == "numeric")
})

test_that("FORGE does not round to real data set precision if 'round == FALSE'", {
arf <- adversarial_rf(iris, num_trees = 2, verbose = FALSE, parallel = FALSE)
psi <- forde(arf, iris, parallel = FALSE)
x_synth <- forge(psi, n_synth = 20, round = FALSE, parallel = FALSE)
x_synth_rounded <- arf:::post_x(x_synth, psi, round = TRUE)

# Check if continuous variables were not rounded
expect_false(all(x_synth[,1:4] == x_synth_rounded[,1:4]))
expect_equal(data.frame(lapply(x_synth[,1:4], round, 1)), x_synth_rounded[,1:4])
})



# test_that("MAP returns proper column types", {
# n <- 50
# dat <- data.frame(numeric = rnorm(n),
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test_expct.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ test_that("expct returns correct values", {
# For all classes
res <- expct(psi, query = "Sepal.Length", parallel = FALSE)
expect_equal(colnames(res), "Sepal.Length")
expect_equal(res$Sepal.Length, round(mean(iris$Sepal.Length), 1), tolerance = .2)
expect_equal(res$Sepal.Length, mean(iris$Sepal.Length))

# Only for setosa
res <- expct(psi, query = "Sepal.Length", evidence = data.frame(Species = "setosa"), parallel = FALSE)
expect_equal(res$Sepal.Length, round(mean(iris[iris$Species == "setosa", "Sepal.Length"]), 1))
expect_equal(res$Sepal.Length, mean(iris[iris$Species == "setosa", "Sepal.Length"]), tolerance = .1)
})

test_that("expct works for vectorized evidence", {
Expand Down

0 comments on commit 1faedb1

Please sign in to comment.