add optional rounding for expct and forge
jkapar committed Jun 10, 2024
1 parent c88612c commit 1cdd872
Showing 8 changed files with 51 additions and 31 deletions.
23 changes: 11 additions & 12 deletions R/expct.R
@@ -16,6 +16,7 @@
#' @param evidence_row_mode Interpretation of rows in multi-row evidence. If \code{'separate'},
#' each row in \code{evidence} is a separate conditioning event for which \code{n_synth} synthetic samples
#' are generated. If \code{'or'}, the rows are combined with a logical or; see Examples.
#' @param round Round continuous variables to their respective maximum precision in the real data set?
#' @param stepsize Stepsize defining the number of evidence rows handled in each step.
#' Defaults to nrow(evidence)/num_registered_workers for \code{parallel == TRUE}.
#' @param parallel Compute in parallel? Must register backend beforehand, e.g.
@@ -33,9 +34,6 @@
#' distribution over leaves, with columns \code{f_idx} and \code{wt}. This may
#' be preferable for complex constraints. See Examples.
#'
#' Please note that results for continuous features which are both included in \code{query} and in
#' \code{evidence} with an interval condition are currently inconsistent.
#'
#' @return
#' A one row data frame with values for all query variables.
#'
@@ -83,6 +81,7 @@ expct <- function(
query = NULL,
evidence = NULL,
evidence_row_mode = c("separate", "or"),
round = FALSE,
stepsize = 0,
parallel = TRUE) {

@@ -176,18 +175,17 @@ expct <- function(
} else {
psi_cond <- merge(omega, cparams$cnt[variable %in% query, -c("cvg_factor", "f_idx_uncond")], by = c('c_idx', 'f_idx'),
sort = FALSE, allow.cartesian = TRUE)[prob > 0,]
# calculate absolute weights for sub-leaf areas (resulting from within-row or-conditions)
# draw sub-leaf areas (resulting from within-row or-conditions)
if(any(psi_cond[,prob != 1])) {
psi_cond[, wt := wt*prob]
psi_cond[, I := seq_len(.N), by = .(variable, idx)]
} else {
psi_cond[, I := 1]
psi_cond[, I := .I]
psi_cond <- psi_cond[sort(c(psi_cond[prob == 1, I],
psi_cond[prob > 0 & prob < 1, fifelse(.N > 1, resample(I, 1, prob = prob), 0), by = .(variable, idx)][,V1])), -"I"]
}
psi_cond[, prob := NULL]
}
psi <- unique(rbind(psi_cond,
merge(omega, params$cnt[variable %in% query, ], by.x = 'f_idx_uncond', by.y = 'f_idx',
sort = FALSE, allow.cartesian = TRUE)[,`:=` (val = NA_real_, I = 1)]), by = c("c_idx", "f_idx", "variable", "I"))[, I := NULL]
sort = FALSE, allow.cartesian = TRUE)[,val := NA_real_]), by = c("c_idx", "f_idx", "variable"))
psi[NA_share == 1, wt := 0]
cnt <- psi[is.na(val), val := sum(wt * mu)/sum(wt), by = .(c_idx, variable)]
cnt <- unique(cnt[, .(c_idx, variable, val)])
@@ -208,14 +206,15 @@
psi <- rbind(psi_cond, psi_uncond_relevant)
}
psi[NA_share == 1, wt := 0]
cat <- psi[, sum(wt * prob), by = .(c_idx, variable, val)]
cat <- setDT(cat)[, .SD[which.max.random(V1)], by = .(c_idx, variable)]
psi[prob < 1, prob := sum(wt * prob)/sum(wt), by = .(c_idx, variable, val)]
cat <- setDT(psi)[, .SD[which.max.random(prob)], by = .(c_idx, variable)]
cat <- unique(cat[, .(c_idx, variable, val)])
synth_cat <- dcast(cat, c_idx ~ variable, value.var = 'val')[, c_idx := NULL]
}

# Create dataset with expectations
x_synth <- cbind(synth_cnt, synth_cat)
x_synth <- post_x(x_synth, params)
x_synth <- post_x(x_synth, params, round)

x_synth
}
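With this change, expct() accepts a round argument (default FALSE) that is passed on to post_x(). A minimal usage sketch, mirroring the workflow in the tests below; the query column and the round = TRUE call are illustrative:

arf <- adversarial_rf(iris, num_trees = 2, verbose = FALSE, parallel = FALSE)
psi <- forde(arf, iris, parallel = FALSE)

# Default: unrounded conditional expectation
expct(psi, query = "Sepal.Length", parallel = FALSE)

# New: round the result to the maximum precision observed in iris
expct(psi, query = "Sepal.Length", round = TRUE, parallel = FALSE)
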
8 changes: 5 additions & 3 deletions R/forge.R
@@ -11,6 +11,7 @@
#' @param evidence_row_mode Interpretation of rows in multi-row evidence. If \code{'separate'},
#' each row in \code{evidence} is a separate conditioning event for which \code{n_synth} synthetic samples
#' are generated. If \code{'or'}, the rows are combined with a logical or; see Examples.
#' @param round Round continuous variables to their respective maximum precision in the real data set?
#' @param sample_NAs Sample NAs according to the probability of missing values in the original data.
#' @param stepsize Stepsize defining the number of evidence rows handled in each step.
#' Defaults to nrow(evidence)/num_registered_workers for \code{parallel == TRUE}.
@@ -99,6 +100,7 @@ forge <- function(
n_synth,
evidence = NULL,
evidence_row_mode = c("separate", "or"),
round = TRUE,
sample_NAs = FALSE,
stepsize = 0,
parallel = TRUE) {
@@ -254,15 +256,15 @@
}

# Clean up, export
x_synth <- post_x(x_synth, params)
x_synth <- post_x(x_synth, params, round)

if (sample_NAs) {
setDT(x_synth)
NA_share <- rbind(NA_share_cnt, NA_share_cat)
setorder(NA_share[,variable := factor(variable, levels = params$meta[,variable])], variable, idx)
NA_share[,dat := rbinom(.N, 1, prob = NA_share)]
x_synth[dcast(NA_share,formula = idx ~ variable, value.var = "dat")[,-"idx"] == 1] <- NA
x_synth <- post_x(x_synth, params)
x_synth <- post_x(x_synth, params, round)
}

if (evidence_row_mode == "separate" & any(omega[, is.na(f_idx)])) {
@@ -279,7 +281,7 @@
x_synth[, idx := rep(indices_sampled, each = n_synth)]
x_synth <- rbind(x_synth, rows_na, fill = T)
setorder(x_synth, idx)[, idx := NULL]
x_synth <- post_x(x_synth, params)
x_synth <- post_x(x_synth, params, round)
}
x_synth
}
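forge() gains the same argument, but with round = TRUE as the default, so existing callers keep receiving synthetic data rounded to the precision of the training data. A sketch of opting out, reusing the psi object from the expct sketch above (object names are illustrative):

# New default (previous behaviour): continuous values rounded to real-data precision
x_synth <- forge(psi, n_synth = 100, parallel = FALSE)

# Keep the raw continuous draws instead
x_synth_raw <- forge(psi, n_synth = 100, round = FALSE, parallel = FALSE)
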
21 changes: 11 additions & 10 deletions R/utils.R
@@ -29,7 +29,7 @@ col_rename <- function(cn, old_name) {
#' @keywords internal

col_rename_all <- function(cn) {

if ('y' %in% cn) {
cn[which(cn == 'y')] <- col_rename(cn, 'y')
}
@@ -130,11 +130,12 @@ prep_x <- function(x) {
#'
#' @param x Input data.frame.
#' @param params Circuit parameters learned via \code{\link{forde}}.
#' @param round Round continuous variables to their respective maximum precision in the real data set?
#'
#' @import data.table
#' @keywords internal

post_x <- function(x, params) {
post_x <- function(x, params, round = TRUE) {

# To avoid data.table check issues
variable <- val <- NULL
@@ -150,7 +151,7 @@
idx_integer <- meta_tmp[, which(class == 'integer')]

# Recode
if (sum(idx_numeric) > 0L) {
if (sum(idx_numeric) > 0L & round) {
x[, idx_numeric] <- lapply(idx_numeric, function(j) {
round(as.numeric(x[[j]]), meta_tmp$decimals[j])
})
@@ -170,7 +171,7 @@
}
if (sum(idx_integer) > 0L) {
x[, idx_integer] <- lapply(idx_integer, function(j) {
if (is.numeric(x[[j]])) {
if (is.numeric(x[[j]]) & round) {
as.integer(round(x[[j]]))
} else {
as.integer(as.character(x[[j]]))
@@ -385,7 +386,7 @@ cforde <- function(params, evidence, row_mode = c("separate", "or"), stepsize =

# Calculate updates for cat params matching cat conditions
cat_new <- merge(merge(relevant_leaves, cat_conds, by = "c_idx", allow.cartesian = T), cat, by = c("f_idx","variable", "val"))

# Ensure probabilities sum to 1
cat_new[, cvg_factor := sum(prob), by = .(f_idx, c_idx, variable)]
cat_new[, prob := prob/cvg_factor]
@@ -407,7 +408,7 @@
relevant_leaves <- updates_relevant_leaves$relevant_leaves[,`:=` (f_idx = .I, f_idx_uncond = f_idx)][]
cnt_new <- setcolorder(merge(relevant_leaves, updates_relevant_leaves$cnt_new, by.x = c("c_idx", "f_idx_uncond"), by.y = c("c_idx", "f_idx"), sort = F), c("f_idx","c_idx","variable","min","max","val","cvg_factor"))[]
cat_new <- setcolorder(merge(relevant_leaves, updates_relevant_leaves$cat_new, by.x = c("c_idx", "f_idx_uncond"), by.y = c("c_idx", "f_idx"), sort = F), c("f_idx","c_idx","variable","val","prob","cvg_factor"))[]

# Check for conditions with no matching leaves and handle this according to row_mode
if (relevant_leaves[,uniqueN(c_idx)] < nconds_conditioned) {
if (relevant_leaves[,uniqueN(c_idx)] == 0 & row_mode == "or") {
@@ -424,8 +425,8 @@
setnames(forest_new, "cvg", "cvg_arf")

cvg_new <- unique(rbind(cat_new[, .(f_idx, c_idx, variable, cvg_factor)],
cnt_new[, .(f_idx, c_idx, variable, cvg_factor)]),
by = c("f_idx", "variable"))[,-"variable"]
cnt_new[, .(f_idx, c_idx, variable, cvg_factor)]),
by = c("f_idx", "variable"))[,-"variable"]

if (nrow(cvg_new) > 0) {
# Use log transformation to avoid overflow
@@ -480,7 +481,7 @@ cforde <- function(params, evidence, row_mode = c("separate", "or"), stepsize =
forest_new_unconditioned[, `:=` (c_idx = rep(conds_unconditioned,each = nrow(forest)), f_idx_uncond = f_idx, cvg_arf = cvg)]
forest_new <- rbind(forest_new, forest_new_unconditioned)
}

setorder(setcolorder(forest_new,c("f_idx","c_idx","f_idx_uncond","tree","leaf","cvg_arf","cvg")), c_idx, f_idx, f_idx_uncond, tree, leaf)

list(evidence_input = evidence, evidence_prepped = condition_long, cnt = cnt_new, cat = cat_new, forest = forest_new)
@@ -553,7 +554,7 @@ prep_cond <- function(evidence, params, row_mode) {

# Interval syntax, e.g. (X,Inf)
condition_long[(variable %in% cnt_cols) & str_detect(val, "\\("),
c("val", "min", "max") := cbind(c(NA_real_, transpose(strsplit(substr(val, 2, nchar(val) - 1), split = ","))))]
c("val", "min", "max") := cbind(c(NA_real_, transpose(strsplit(substr(val, 2, nchar(val) - 1), split = ","))))]

# >, < syntax
condition_long[(variable %in% cnt_cols) & str_detect(val, "<"),
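The rounding itself happens in post_x(): with round = TRUE, each numeric column is rounded to the per-column decimal count recorded in the model metadata (meta_tmp$decimals in the hunk above), and numeric values destined for integer columns are rounded before coercion. A standalone sketch of the idea, with made-up column names and a made-up decimals vector standing in for that metadata:

x <- data.frame(a = c(1.23456, 2.34567), b = c(10.123, 20.987))
decimals <- c(a = 1, b = 2)  # per-column maximum precision observed in the real data
x[] <- Map(function(col, d) round(col, d), x, decimals[names(x)])
# a becomes 1.2, 2.3; b becomes 10.12, 20.99
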
6 changes: 3 additions & 3 deletions man/expct.Rd

3 changes: 3 additions & 0 deletions man/forge.Rd

4 changes: 3 additions & 1 deletion man/post_x.Rd

13 changes: 13 additions & 0 deletions tests/testthat/test-return_types.R
@@ -80,6 +80,19 @@ test_that("FORGE returns same column types", {
expect_equal(classes, classes_synth)
})

test_that("FORGE does not round to real data set precision if 'round == FALSE'", {
arf <- adversarial_rf(iris, num_trees = 2, verbose = FALSE, parallel = FALSE)
psi <- forde(arf, iris, parallel = FALSE)
x_synth <- forge(psi, n_synth = 20, round = FALSE, parallel = FALSE)
x_synth_rounded <- arf:::post_x(x_synth, psi, round = TRUE)

# Check if continuous variables were not rounded
expect_false(all(x_synth[,1:4] == x_synth_rounded[,1:4]))
expect_equal(data.frame(lapply(x_synth[,1:4], round, 1)), x_synth_rounded[,1:4])
})



# test_that("MAP returns proper column types", {
# n <- 50
# dat <- data.frame(numeric = rnorm(n),
4 changes: 2 additions & 2 deletions tests/testthat/test_expct.R
@@ -6,11 +6,11 @@ test_that("expct returns correct values", {
# For all classes
res <- expct(psi, query = "Sepal.Length", parallel = FALSE)
expect_equal(colnames(res), "Sepal.Length")
expect_equal(res$Sepal.Length, round(mean(iris$Sepal.Length), 1), tolerance = .2)
expect_equal(res$Sepal.Length, mean(iris$Sepal.Length))

# Only for setosa
res <- expct(psi, query = "Sepal.Length", evidence = data.frame(Species = "setosa"), parallel = FALSE)
expect_equal(res$Sepal.Length, round(mean(iris[iris$Species == "setosa", "Sepal.Length"]), 1))
expect_equal(res$Sepal.Length, mean(iris[iris$Species == "setosa", "Sepal.Length"]), tolerance = .1)
})

test_that("expct works for vectorized evidence", {
