From dc09de6d9c11e43a27e9f8c855e4d2eb7a5371a7 Mon Sep 17 00:00:00 2001 From: Marvin Wright Date: Mon, 8 Jan 2024 06:49:22 +0100 Subject: [PATCH] save max.depth in ranger object --- R/ranger.R | 5 +++++ tests/testthat/test_classification.R | 4 ++-- tests/testthat/test_print.R | 2 +- tests/testthat/test_regression.R | 4 ++-- tests/testthat/test_survival.R | 4 ++-- 5 files changed, 12 insertions(+), 7 deletions(-) diff --git a/R/ranger.R b/R/ranger.R index 6d56d4d4..fc78e4d0 100644 --- a/R/ranger.R +++ b/R/ranger.R @@ -1037,6 +1037,11 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL, result$dependent.variable.name <- dependent.variable.name result$status.variable.name <- status.variable.name + ## Save max.depth + if (!is.null(max.depth)) { + result$max.depth <- max.depth + } + class(result) <- "ranger" ## Prepare quantile prediction diff --git a/tests/testthat/test_classification.R b/tests/testthat/test_classification.R index 0d015c9d..4690afe5 100644 --- a/tests/testthat/test_classification.R +++ b/tests/testthat/test_classification.R @@ -10,9 +10,9 @@ rg.class <- ranger(Species ~ ., data = iris) rg.mat <- ranger(dependent.variable.name = "Species", data = dat, classification = TRUE) ## Basic tests (for all random forests equal) -test_that("classification result is of class ranger with 15 elements", { +test_that("classification result is of class ranger with 16 elements", { expect_is(rg.class, "ranger") - expect_equal(length(rg.class), 15) + expect_equal(length(rg.class), 16) }) test_that("classification prediction returns factor", { diff --git a/tests/testthat/test_print.R b/tests/testthat/test_print.R index 3ca91b4a..8563b1e3 100644 --- a/tests/testthat/test_print.R +++ b/tests/testthat/test_print.R @@ -16,7 +16,7 @@ expect_that(print(rf$forest), prints_text("Ranger forest object")) expect_that(print(predict(rf, iris)), prints_text("Ranger prediction")) ## Test str ranger function -expect_that(str(rf), prints_text("List of 15")) +expect_that(str(rf), prints_text("List of 16")) ## Test str forest function expect_that(str(rf$forest), prints_text("List of 9")) diff --git a/tests/testthat/test_regression.R b/tests/testthat/test_regression.R index dd3bdd4e..8949f82d 100644 --- a/tests/testthat/test_regression.R +++ b/tests/testthat/test_regression.R @@ -7,9 +7,9 @@ context("ranger_reg") rg.reg <- ranger(Sepal.Length ~ ., data = iris) ## Basic tests (for all random forests equal) -test_that("regression result is of class ranger with 15 elements", { +test_that("regression result is of class ranger with 16 elements", { expect_is(rg.reg, "ranger") - expect_equal(length(rg.reg), 15) + expect_equal(length(rg.reg), 16) }) test_that("regression prediction returns numeric vector", { diff --git a/tests/testthat/test_survival.R b/tests/testthat/test_survival.R index 6226eb6f..358a4096 100644 --- a/tests/testthat/test_survival.R +++ b/tests/testthat/test_survival.R @@ -8,9 +8,9 @@ context("ranger_surv") rg.surv <- ranger(Surv(time, status) ~ ., data = veteran, num.trees = 10) ## Basic tests (for all random forests equal) -test_that("survival result is of class ranger with 17 elements", { +test_that("survival result is of class ranger with 18 elements", { expect_is(rg.surv, "ranger") - expect_equal(length(rg.surv), 17) + expect_equal(length(rg.surv), 18) }) test_that("results have right number of trees", {