Using repair_call() in a custom function

I am using partykit to sample from the nodes of an rpart model fitted with tidymodels. I need repair_call() so that as.party() works on the output of parsnip::fit(). This breaks when I move the code inside a custom function:

library(tidymodels)
library(partykit)

cart_model <- parsnip::decision_tree() %>%
  parsnip::set_engine("rpart") %>%
  parsnip::set_mode("regression")

parsnip_model <- fit(cart_model, mpg ~ ., data = mtcars)

predict_sample_rpart <- function(object, old_data, new_data) {
  
  repaired_model <- repair_call(object, data = old_data)
  
  node_ecdf <- predict(as.party(repaired_model$fit), newdata = new_data, type = "prob")
  
  sample(environment(node_ecdf[["1"]])[["x"]], 1)
  
}

predict_sample_rpart(parsnip_model, old_data = mtcars, new_data = mtcars)
#>  Error in is.data.frame(data) : object 'old_data' not found 

repair_call() records the data in repaired_model's call as old_data instead of mtcars, so predict() no longer works.
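Here is a minimal base-R sketch of what is likely going on. It assumes repair_call() copies its data argument into the stored call via match.call(), and printing parsnip::repair_call in your session is a quick way to check that. store_data_arg() and mock_fit are hypothetical stand-ins, not parsnip objects. match.call() records the expression that was passed rather than its value, so when the call goes through a wrapper, the stored call ends up holding the symbol old_data instead of the data frame:

```r
# Hypothetical stand-in for repair_call() (assumption: it uses match.call()).
# It copies the *expression* matched to `data` into x$fit$call.
store_data_arg <- function(x, data) {
  cl <- match.call()
  x$fit$call$data <- cl$data  # an unevaluated expression, not the data itself
  x
}

# Mock model object whose stored call has an unresolved `data` symbol
mock_fit <- list(fit = list(call = quote(stats::lm(formula = mpg ~ wt, data = data))))

# At top level, the recorded expression is the symbol `mtcars`
top_level <- store_data_arg(mock_fit, data = mtcars)
print(top_level$fit$call$data)

# Through a wrapper, the recorded expression is the symbol `old_data`
wrapper <- function(old_data) store_data_arg(mock_fit, data = old_data)
inside <- wrapper(mtcars)
print(inside$fit$call$data)

# Re-evaluating that call at top level fails because `old_data` is not there
result <- tryCatch(eval(inside$fit$call, globalenv()), error = conditionMessage)
print(result)
```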

  1. Any help fixing this would be greatly appreciated.
  2. Any suggestions for better ways to sample from conditional distributions created by lm(), rpart(), or ranger() would be doubly appreciated.
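For the second question, here is one option for lm() that never re-evaluates the stored call. If you assume Gaussian errors with constant variance, a draw from the conditional distribution is the predicted mean plus normal noise with standard deviation sigma(fit). sample_lm() is a hypothetical helper, and this sketch ignores uncertainty in the estimated coefficients:

```r
# Hypothetical helper: one draw per row of new_data from an lm() fit.
# Assumes Gaussian errors with constant variance; coefficient uncertainty
# is ignored.
sample_lm <- function(fit, new_data) {
  mu <- predict(fit, newdata = new_data)         # conditional means
  rnorm(length(mu), mean = mu, sd = sigma(fit))  # add residual noise
}

lm_fit <- lm(mpg ~ wt + hp, data = mtcars)

set.seed(1)
draws <- sample_lm(lm_fit, head(mtcars))
print(draws)
```

predict() with newdata works from the terms and coefficients stored in the fit, so this approach should not depend on where the training data frame lives.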

This illustrates the problem with using the call object in computations: it assumes that the data used to create the model are available in the same scope/environment as the code that later evaluates the call.
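The same scoping assumption shows up in base R whenever a stored call is re-evaluated. In this minimal sketch, fit_inside() is a hypothetical helper that fits lm() on a locally named data frame d. update() later re-evaluates the recorded call from the top level, where d does not exist:

```r
# Hypothetical helper that fits a model on a locally named data frame
fit_inside <- function(d) lm(mpg ~ wt, data = d)

fit <- fit_inside(mtcars)
print(fit$call)  # the call records the symbol `d`, not the data itself

# update() re-evaluates the stored call in the caller's environment,
# where `d` is not defined, so refitting fails
refit <- tryCatch(update(fit, . ~ . + hp), error = conditionMessage)
print(refit)

# Supplying the data explicitly gives update() something it can find
refit_ok <- update(fit, . ~ . + hp, data = mtcars)
print(coef(refit_ok))
```

Passing data = explicitly in the last step plays roughly the same role that repair_call() plays for parsnip fits.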

Inside your function, the call that repair_call() builds refers to old_data, a name that only exists inside the function, so it cannot be found when the call is evaluated later. Outside the function it works fine (I think you meant type = "node"):

library(tidymodels)
#> ── Attaching packages ────────────────────────────────────────────────── tidymodels 0.1.1 ──
#> ✓ broom     0.7.0      ✓ recipes   0.1.13
#> ✓ dials     0.0.9      ✓ rsample   0.0.8 
#> ✓ dplyr     1.0.2      ✓ tibble    3.0.3 
#> ✓ ggplot2   3.3.2      ✓ tidyr     1.1.2 
#> ✓ infer     0.5.2      ✓ tune      0.1.1 
#> ✓ modeldata 0.0.2      ✓ workflows 0.2.0 
#> ✓ parsnip   0.1.3      ✓ yardstick 0.0.7 
#> ✓ purrr     0.3.4
#> ── Conflicts ───────────────────────────────────────────────────── tidymodels_conflicts() ──
#> x purrr::discard() masks scales::discard()
#> x dplyr::filter()  masks stats::filter()
#> x dplyr::lag()     masks stats::lag()
#> x recipes::step()  masks stats::step()
library(partykit)
#> Loading required package: grid
#> Loading required package: libcoin
#> Loading required package: mvtnorm

cart_model <- parsnip::decision_tree() %>%
  parsnip::set_engine("rpart") %>%
  parsnip::set_mode("regression")

parsnip_model <- fit(cart_model, mpg ~ ., data = mtcars)

repaired_model <- repair_call(parsnip_model, data = mtcars)

node_ecdf <- predict(as.party(repaired_model$fit), newdata = head(mtcars), type = "node")
node_ecdf
#>         Mazda RX4     Mazda RX4 Wag        Datsun 710    Hornet 4 Drive 
#>                 4                 4                 5                 4 
#> Hornet Sportabout           Valiant 
#>                 4                 4

Created on 2020-10-07 by the reprex package (v0.3.0)

I'm not sure what the solution is for using it inside a function, so I would try to avoid using it that way. That's probably unsatisfying, but the problem is baked into how these packages use the call object.
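One workaround worth trying is to build the call with do.call(), whose argument list already holds the evaluated data frame. match.call() then records the data frame itself instead of the local name, so the repaired call can be evaluated from any environment. The sketch below uses hypothetical stand-ins (store_data_arg(), mock_fit and repair_inside() are illustrative, not parsnip code). It assumes repair_call() captures data with match.call(), and the trade-off is that the whole data set gets embedded in the stored call:

```r
# Hypothetical stand-in for repair_call() (assumption: it uses match.call()).
# It copies the expression matched to `data` into x$fit$call.
store_data_arg <- function(x, data) {
  cl <- match.call()
  x$fit$call$data <- cl$data
  x
}

# Mock model object whose stored call has an unresolved `data` symbol
mock_fit <- list(fit = list(call = quote(stats::lm(formula = mpg ~ wt, data = data))))

# do.call() builds the call from already-evaluated arguments, so match.call()
# sees the data frame itself rather than the local name `old_data`.
repair_inside <- function(object, old_data) {
  do.call(store_data_arg, list(object, data = old_data))
}

repaired <- repair_inside(mock_fit, mtcars)
print(is.data.frame(repaired$fit$call$data))

# The repaired call no longer depends on where it is evaluated
refit <- eval(repaired$fit$call, globalenv())
print(coef(refit))
```

If that assumption about repair_call() holds, the corresponding change in predict_sample_rpart() would be replacing repair_call(object, data = old_data) with do.call(repair_call, list(object, data = old_data)). This has not been checked against the parsnip version in the reprex, so treat it as a starting point rather than a confirmed fix.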
