Skip to content

Commit

Permalink
Merge pull request #41 from SchlossLab/report-roc-curves
Browse files Browse the repository at this point in the history
Improve report & ROC / PRC plots
  • Loading branch information
kelly-sovacool authored Feb 1, 2023
2 parents cf3011e + db55cc7 commit 20323a3
Show file tree
Hide file tree
Showing 10 changed files with 166 additions and 34 deletions.
11 changes: 6 additions & 5 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
FROM condaforge/mambaforge:latest
LABEL io.github.snakemake.containerized="true"
LABEL io.github.snakemake.conda_env_hash="2698c18a4528bbf37c853c799661b04cfa051524df1e2288d2a1bc376c84a9d9"
LABEL io.github.snakemake.conda_env_hash="6aa289536136aae2d34bac6dce9ce47d037da888ed09e2c8ada989c90ef10658"

# Step 1: Retrieve conda environments

Expand All @@ -17,7 +17,7 @@ COPY workflow/envs/graphviz.yml /conda-envs/b42323b0ffd5d034544511c9db1bdead/env

# Conda environment:
# source: workflow/envs/mikropml.yml
# prefix: /conda-envs/67570867c99c9c3db185b41548ad6071
# prefix: /conda-envs/3f83a46ff5ea715a12fde6ee46136b0b
# name: mikropml
# channels:
# - conda-forge
Expand All @@ -31,13 +31,14 @@ COPY workflow/envs/graphviz.yml /conda-envs/b42323b0ffd5d034544511c9db1bdead/env
# - r-future.apply
# - r-import
# - r-mikropml>=1.5.0
# - r-patchwork
# - r-rmarkdown
# - r-rpart
# - r-purrr
# - r-schtools>=0.4.0
# - r-tidyverse
RUN mkdir -p /conda-envs/67570867c99c9c3db185b41548ad6071
COPY workflow/envs/mikropml.yml /conda-envs/67570867c99c9c3db185b41548ad6071/environment.yaml
RUN mkdir -p /conda-envs/3f83a46ff5ea715a12fde6ee46136b0b
COPY workflow/envs/mikropml.yml /conda-envs/3f83a46ff5ea715a12fde6ee46136b0b/environment.yaml

# Conda environment:
# source: workflow/envs/smk.yml
Expand All @@ -56,6 +57,6 @@ COPY workflow/envs/smk.yml /conda-envs/457b7b75191d44b96e5086432876e333/environm
# Step 2: Generate conda environments

RUN mamba env create --prefix /conda-envs/b42323b0ffd5d034544511c9db1bdead --file /conda-envs/b42323b0ffd5d034544511c9db1bdead/environment.yaml && \
mamba env create --prefix /conda-envs/67570867c99c9c3db185b41548ad6071 --file /conda-envs/67570867c99c9c3db185b41548ad6071/environment.yaml && \
mamba env create --prefix /conda-envs/3f83a46ff5ea715a12fde6ee46136b0b --file /conda-envs/3f83a46ff5ea715a12fde6ee46136b0b/environment.yaml && \
mamba env create --prefix /conda-envs/457b7b75191d44b96e5086432876e333 --file /conda-envs/457b7b75191d44b96e5086432876e333/environment.yaml && \
mamba clean --all -y
Binary file modified figures/example/benchmarks.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified figures/example/hp_performance_glmnet.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified figures/example/roc_curves.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
41 changes: 28 additions & 13 deletions report-example.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,25 @@
ML Results
================
2023-01-20
---
title: "ML Results"
date: "2023-01-31"
output:
html_document:
keep_md: true
self_contained: true
theme: spacelab
---

Machine learning algorithms used were glmnet and rf. Models were trained
with 10 different random partitions of the otu-large dataset into
training and testing sets using 5-fold cross validation. See
`config/config.yaml` for the full configuration.





Machine learning algorithm(s) used: glmnet and rf.
Models were trained with 10 different random
partitions of the otu-large dataset into training and
testing sets using 5-fold cross validation.
See [config/config.yaml](config/config.yaml)
for the full configuration.

## Workflow

Expand All @@ -17,16 +31,17 @@ training and testing sets using 5-fold cross validation. See

<img src="figures/example/roc_curves.png" width="80%" />

## Hyperparameter Performance

<img src="figures/example/hp_performance_glmnet.png" width="80%" /><img src="figures/example/hp_performance_rf.png" width="80%" />

## Feature Importance

<img src="figures/example/feature_importance.png" width="80%" />

## Memory Usage & Runtime

<img src="figures/example/benchmarks.png" width="80%" />

Each model training run was given 8 cores for parallelization.
Each model training run was given 8 cores
for parallelization.

## Hyperparameter Performance

<img src="figures/example/hp_performance_glmnet.png" width="80%" /><img src="figures/example/hp_performance_rf.png" width="80%" />
<img src="figures/example/benchmarks.png" width="80%" />
6 changes: 4 additions & 2 deletions workflow/Snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ results_types = ["performance", "benchmarks", "sensspec"]
if find_feature_importance:
results_types.append("feature-importance")


include: "rules/learn.smk"
include: "rules/combine.smk"
include: "rules/plot.smk"
Expand Down Expand Up @@ -86,13 +87,14 @@ rule render_report:

rule archive:
input:
expand(rules.render_report.input, dataset = dataset),
expand(rules.render_report.output, dataset = dataset),
expand(rules.render_report.input, dataset=dataset),
expand(rules.render_report.output, dataset=dataset),
expand(
"results/{dataset}/{rtype}_results.csv",
dataset=dataset,
rtype=results_types,
),
config_path,
output:
f"workflow_{dataset}.zip",
log:
Expand Down
1 change: 1 addition & 0 deletions workflow/envs/mikropml.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ dependencies:
- r-future.apply
- r-import
- r-mikropml>=1.5.0
- r-patchwork
- r-rmarkdown
- r-rpart
- r-purrr
Expand Down
1 change: 1 addition & 0 deletions workflow/scripts/calc_model_sensspec.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ mikropml::calc_model_sensspec(
test_dat,
outcome_colname
) %>%
bind_cols(schtools::get_wildcards_tbl()) %>%
write_csv(snakemake@output[["csv"]])
125 changes: 118 additions & 7 deletions workflow/scripts/plot_roc_curves.R
Original file line number Diff line number Diff line change
@@ -1,16 +1,127 @@
schtools::log_snakemake()
library(cowplot)
library(mikropml)
library(patchwork)
library(tidyverse)


dat <- read_csv(snakemake@input[["csv"]])
p <- plot_grid(
dat %>% calc_mean_roc() %>% plot_mean_roc(),
dat %>% calc_mean_prc() %>% plot_mean_prc()
)

calc_mean_perf <- function(sensspec_dat,
group_var = specificity,
sum_var = sensitivity,
custom_group_vars = NULL) {
specificity <- sensitivity <- sd <- NULL
dat_round <- sensspec_dat %>%
dplyr::mutate({{ group_var }} := round({{ group_var }}, 2))
if (!is.null(custom_group_vars)) {
dat_grouped <- dat_round %>%
dplyr::group_by({{ group_var }}, !!rlang::sym(custom_group_vars))
} else {
dat_grouped <- dat_round %>%
dplyr::group_by({{ group_var }})
}
return(
dat_grouped %>%
dplyr::summarise(
mean = mean({{ sum_var }}),
sd = stats::sd({{ sum_var }})
) %>%
dplyr::mutate(
upper = mean + sd,
lower = mean - sd,
upper = dplyr::case_when(
upper > 1 ~ 1,
TRUE ~ upper
),
lower = dplyr::case_when(
lower < 0 ~ 0,
TRUE ~ lower
)
) %>%
dplyr::rename(
"mean_{{ sum_var }}" := mean,
"sd_{{ sum_var }}" := sd
)
)
}

calc_mean_roc <- function(sensspec_dat, custom_group_vars = NULL) {
specificity <- sensitivity <- NULL
return(calc_mean_perf(sensspec_dat,
group_var = specificity,
sum_var = sensitivity,
custom_group_vars = custom_group_vars
))
}

calc_mean_prc <- function(sensspec_dat, custom_group_vars = NULL) {
sensitivity <- recall <- precision <- NULL
return(calc_mean_perf(
sensspec_dat %>%
dplyr::rename(recall = sensitivity),
group_var = recall,
sum_var = precision,
custom_group_vars = custom_group_vars
))
}

shared_ggprotos <- function(colorvar) {
return(list(
ggplot2::geom_ribbon(aes(fill = {{ colorvar }}), alpha = 0.5),
ggplot2::geom_line(aes(color = {{ colorvar }})),
ggplot2::coord_equal(),
ggplot2::scale_y_continuous(expand = c(0, 0), limits = c(-0.01, 1.01)),
ggplot2::theme_bw(),
ggplot2::theme(legend.title = ggplot2::element_blank())
))
}

plot_mean_roc <- function(dat) {
specificity <- mean_sensitivity <- lower <- upper <- NULL
dat %>%
ggplot2::ggplot(ggplot2::aes(
x = specificity, y = mean_sensitivity,
ymin = lower, ymax = upper
)) +
shared_ggprotos(colorvar = method) +
ggplot2::geom_abline(
intercept = 1, slope = 1,
linetype = "dashed", color = "grey50"
) +
ggplot2::scale_x_reverse(expand = c(0, 0), limits = c(1.01, -0.01)) +
ggplot2::labs(x = "Specificity", y = "Mean Sensitivity")
}

plot_mean_prc <- function(dat, baseline_precision = NULL) {
recall <- mean_precision <- lower <- upper <- NULL
prc_plot <- dat %>%
ggplot2::ggplot(ggplot2::aes(
x = recall, y = mean_precision,
ymin = lower, ymax = upper
)) +
shared_ggprotos(colorvar = method) +
ggplot2::scale_x_continuous(expand = c(0, 0), limits = c(-0.01, 1.01)) +
ggplot2::labs(x = "Recall", y = "Mean Precision")
if (!is.null(baseline_precision)) {
prc_plot <- prc_plot +
ggplot2::geom_hline(
yintercept = baseline_precision,
linetype = "dashed", color = "grey50"
)
}
return(prc_plot)
}
p <- (dat %>%
calc_mean_roc(custom_group_vars = "method") %>%
plot_mean_roc()) +
(dat %>%
calc_mean_prc(custom_group_vars = "method") %>%
plot_mean_prc() +
theme(legend.position = "none"))

ggsave(
filename = snakemake@output[["plot"]],
plot = p,
device = "png"
device = "png",
height = 4,
width = 6
)
15 changes: 8 additions & 7 deletions workflow/scripts/report.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ Machine learning algorithm(s) used: `r snakemake@params[['ml_methods']]`.
Models were trained with `r snakemake@params[['nseeds']]` different random
partitions of the `r snakemake@params[['dataset']]` dataset into training and
testing sets using `r snakemake@params[['kfold']]`-fold cross validation.
See ``r snakemake@params[['config_path']]`` for the full configuration.
See [`r snakemake@params[['config_path']]`](`r snakemake@params[['config_path']]`)
for the full configuration.

## Workflow

Expand All @@ -38,6 +39,12 @@ include_graphics(snakemake@input[['perf_plot']])
include_graphics(snakemake@input[['roc_plot']])
```

## Hyperparameter Performance

```{r hp_plot, out.width='80%'}
include_graphics(snakemake@input[['hp_plot']])
```

```{r feat_imp_header, results='asis'}
if (isTRUE(snakemake@params[['find_feature_importance']])) {
cat("## Feature Importance")
Expand All @@ -56,9 +63,3 @@ for parallelization.
```{r runtime_plot, out.width='80%'}
include_graphics(snakemake@input[['bench_plot']])
```

## Hyperparameter Performance

```{r hp_plot, out.width='80%'}
include_graphics(snakemake@input[['hp_plot']])
```

0 comments on commit 20323a3

Please sign in to comment.