Skip to content

Commit

Permalink
Merge pull request #14 from bips-hb/x_tilde_list
Browse files Browse the repository at this point in the history
Allow lists of knockoff matrices to improve stability
  • Loading branch information
mnwright authored Nov 5, 2024
2 parents e43d710 + 7efef29 commit d54f9be
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 22 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ Description: A general test for conditional independence in supervised learning
License: GPL (>= 3)
Encoding: UTF-8
LazyData: true
RoxygenNote: 7.1.2
RoxygenNote: 7.3.2
URL: https://github.com/bips-hb/cpi,
https://bips-hb.github.io/cpi/
BugReports: https://github.com/bips-hb/cpi/issues
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

# cpi 0.1.5
* Allow a list of knockoff matrices to improve stability
* Remove BEST dependency (removed from CRAN)

# cpi 0.1.4
Expand Down
67 changes: 47 additions & 20 deletions R/cpi.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
#' @param B Number of permutations for Fisher permutation test.
#' @param alpha Significance level for confidence intervals.
#' @param x_tilde Knockoff matrix or data.frame. If not given (the default), it will be
#' created with the function given in \code{knockoff_fun}.
#' created with the function given in \code{knockoff_fun}.
#' Also accepts a list of matrices or data.frames, in which case the results
#' are aggregated over the list elements with \code{aggr_fun}.
#' @param aggr_fun Function used to aggregate the losses over the knockoff
#' replicates, i.e. over the matrices in \code{x_tilde}. Default: \code{mean}.
#' @param knockoff_fun Function to generate knockoffs. Default:
#' \code{knockoff::\link{create.second_order}} with matrix argument.
#' @param groups (Named) list with groups. Set to \code{NULL} (default) for no
Expand Down Expand Up @@ -144,6 +146,7 @@ cpi <- function(task, learner,
B = 1999,
alpha = 0.05,
x_tilde = NULL,
aggr_fun = mean,
knockoff_fun = function(x) knockoff::create.second_order(as.matrix(x)),
groups = NULL,
verbose = FALSE) {
Expand Down Expand Up @@ -229,8 +232,9 @@ cpi <- function(task, learner,
if (is.null(test_data)) {
x_tilde <- knockoff_fun(task$data(cols = task$feature_names))
} else {
test_data_x_tilde <- knockoff_fun(test_data[, task$feature_names])
x_tilde <- knockoff_fun(test_data[, task$feature_names])
}
x_tilde <- list(x_tilde)
} else if (is.matrix(x_tilde) | is.data.frame(x_tilde)) {
if (is.null(test_data)) {
if (any(dim(x_tilde) != dim(task$data(cols = task$feature_names)))) {
Expand All @@ -240,34 +244,57 @@ cpi <- function(task, learner,
if (any(dim(x_tilde) != dim(test_data[, task$feature_names]))) {
stop("Size of 'x_tilde' must match dimensions of data.")
}
test_data_x_tilde <- x_tilde
}
x_tilde <- list(x_tilde)
} else if (is.list(x_tilde)) {
if (length(x_tilde) < 1) {
stop("If 'x_tilde' is a list, it cannot be empty.")
}
if (is.null(test_data)) {
      # FIXME: Only the first list element is checked; check dims of all elements
if (any(dim(x_tilde[[1]]) != dim(task$data(cols = task$feature_names)))) {
stop("Size of 'x_tilde' must match dimensions of data.")
}
} else {
      # FIXME: Only the first list element is checked; check dims of all elements
if (any(dim(x_tilde[[1]]) != dim(test_data[, task$feature_names]))) {
stop("Size of 'x_tilde' must match dimensions of data.")
}
}
} else {
stop("Argument 'x_tilde' must be a matrix, data.frame or NULL.")
}

# For each feature, fit reduced model and return difference in error
cpi_fun <- function(i) {
if (is.null(test_data)) {
reduced_test_data <- NULL
reduced_data <- as.data.frame(task$data())
reduced_data[, task$feature_names[i]] <- x_tilde[, task$feature_names[i]]
if (task$task_type == "regr") {
reduced_task <- as_task_regr(reduced_data, target = task$target_names)
} else if (task$task_type == "classif") {
reduced_task <- as_task_classif(reduced_data, target = task$target_names)
err_reduced <- sapply(x_tilde, function(x_tilde_i) {
if (is.null(test_data)) {
reduced_test_data <- NULL
reduced_data <- as.data.frame(task$data())
reduced_data[, task$feature_names[i]] <- x_tilde_i[, task$feature_names[i]]
if (task$task_type == "regr") {
reduced_task <- as_task_regr(reduced_data, target = task$target_names)
} else if (task$task_type == "classif") {
reduced_task <- as_task_classif(reduced_data, target = task$target_names)
} else {
stop("Unknown task type.")
}
} else {
stop("Unknown task type.")
reduced_task <- NULL
reduced_test_data <- test_data
reduced_test_data[, task$feature_names[i]] <- x_tilde_i[, task$feature_names[i]]
}
} else {
reduced_task <- NULL
reduced_test_data <- test_data
reduced_test_data[, task$feature_names[i]] <- test_data_x_tilde[, task$feature_names[i]]
}

# Predict with knockoff data
pred_reduced <- predict_learner(fit_full, reduced_task, resampling = resampling, test_data = reduced_test_data)
err_reduced <- compute_loss(pred_reduced, measure)

err_reduced
})

    # Aggregate results over the different knockoffs with aggr_fun (default: mean)
err_reduced <- apply(err_reduced, 1, aggr_fun)

# Predict with knockoff data
pred_reduced <- predict_learner(fit_full, reduced_task, resampling = resampling, test_data = reduced_test_data)
err_reduced <- compute_loss(pred_reduced, measure)
if (log) {
dif <- log(err_reduced / err_full)
} else {
Expand Down
6 changes: 5 additions & 1 deletion man/cpi.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit d54f9be

Please sign in to comment.