Skip to content

Commit

Permalink
Merge pull request #715 from imbs-hl/save_max_depth
Browse files Browse the repository at this point in the history
Save max.depth in ranger object
  • Loading branch information
mnwright authored May 16, 2024
2 parents 24a24bf + dc09de6 commit 858bfda
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 7 deletions.
5 changes: 5 additions & 0 deletions R/ranger.R
Original file line number Diff line number Diff line change
Expand Up @@ -1037,6 +1037,11 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL,
result$dependent.variable.name <- dependent.variable.name
result$status.variable.name <- status.variable.name

## Save max.depth
if (!is.null(max.depth)) {
result$max.depth <- max.depth
}

class(result) <- "ranger"

## Prepare quantile prediction
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test_classification.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ rg.class <- ranger(Species ~ ., data = iris)
rg.mat <- ranger(dependent.variable.name = "Species", data = dat, classification = TRUE)

## Basic tests (for all random forests equal)
test_that("classification result is of class ranger with 15 elements", {
test_that("classification result is of class ranger with 16 elements", {
expect_is(rg.class, "ranger")
expect_equal(length(rg.class), 15)
expect_equal(length(rg.class), 16)
})

test_that("classification prediction returns factor", {
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test_print.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ expect_that(print(rf$forest), prints_text("Ranger forest object"))
expect_that(print(predict(rf, iris)), prints_text("Ranger prediction"))

## Test str ranger function
expect_that(str(rf), prints_text("List of 15"))
expect_that(str(rf), prints_text("List of 16"))

## Test str forest function
expect_that(str(rf$forest), prints_text("List of 9"))
4 changes: 2 additions & 2 deletions tests/testthat/test_regression.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ context("ranger_reg")
rg.reg <- ranger(Sepal.Length ~ ., data = iris)

## Basic tests (for all random forests equal)
test_that("regression result is of class ranger with 15 elements", {
test_that("regression result is of class ranger with 16 elements", {
expect_is(rg.reg, "ranger")
expect_equal(length(rg.reg), 15)
expect_equal(length(rg.reg), 16)
})

test_that("regression prediction returns numeric vector", {
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test_survival.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ context("ranger_surv")
rg.surv <- ranger(Surv(time, status) ~ ., data = veteran, num.trees = 10)

## Basic tests (for all random forests equal)
test_that("survival result is of class ranger with 17 elements", {
test_that("survival result is of class ranger with 18 elements", {
expect_is(rg.surv, "ranger")
expect_equal(length(rg.surv), 17)
expect_equal(length(rg.surv), 18)
})

test_that("results have right number of trees", {
Expand Down

0 comments on commit 858bfda

Please sign in to comment.