step_dummy, step_zv and predict

I'm relatively new to tidymodels and have been working through the tutorials (which are very helpful!). I'm curious about some potentially strange behavior in the example code from the Preprocess your data with recipes tutorial. In that tutorial, because several of the factors have many levels, the test set can contain factor levels that never appear in the training set (e.g. dest == "LEX" appears in the test set but not in the training set in the example).
My understanding is that the problem is twofold: 1) if we drop such a level from the training set, we don't know how to make predictions for test-set rows with that level; but 2) if we leave it in, then for unregularized regressions (like glm in the example) the design matrix is rank-deficient, so we can't even fit the model.
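A minimal base-R sketch of point 2), using made-up toy data rather than the tutorial's flights: when a level is declared in the factor but never occurs in the training rows, its treatment-coded dummy column is all zeros, so the design matrix has fewer independent columns than columns.

```r
# Hypothetical toy training data: "LEX" is a declared level but never occurs
train <- data.frame(
  dest = factor(c("ABQ", "PHL", "ABQ", "PHL"), levels = c("ABQ", "LEX", "PHL"))
)

# Treatment (dummy) coding with ABQ, the first level, as the baseline
X <- model.matrix(~ dest, data = train)
print(X)

# The destLEX column is all zeros, i.e. it has zero variance
print(colSums(X))

# Rank 2 but 3 columns: the design matrix is rank-deficient
print(c(rank = qr(X)$rank, columns = ncol(X)))
```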

To handle this, the tutorial suggests creating dummy variables (using a treatment contrast) with step_dummy and then removing the zero-variance columns with step_zv. I can see how this handles 2), but the tutorial doesn't make clear how predict(flights_fit, test_data) makes predictions. My guess is that it predicts as if every dest dummy column were 0, which means it makes the same prediction as for the baseline level (here dest == "ABQ"). Since the choice of baseline level in a treatment contrast is essentially arbitrary, this does not seem like optimal behavior. Is this in fact what predict is doing here? If so, the advice (create treatment dummies, then remove zero-variance columns) seems quite poor: I can see how it would make sense for a regularized regression with one-hot encoded variables, but not in this setting. Am I missing something?
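A base-R sketch of the behavior being asked about, with made-up data. It only mimics step_dummy() followed by step_zv() (treatment-code, then keep dummies that vary in training); it does not call recipes. Once the all-zero destLEX column is dropped, a LEX row and an ABQ row encode to the same vector.

```r
levs  <- c("ABQ", "LEX", "PHL")
train <- data.frame(dest = factor(c("ABQ", "PHL", "ABQ", "PHL"), levels = levs))

# "prep": treatment-code the training data and keep only non-constant dummies
X_train <- model.matrix(~ dest, data = train)
keep <- c("(Intercept)",
          names(which(apply(X_train[, -1, drop = FALSE], 2, var) > 0)))

# "bake": apply the same encoding and column selection to new rows
new_rows <- data.frame(dest = factor(c("ABQ", "LEX", "PHL"), levels = levs))
X_new <- model.matrix(~ dest, data = new_rows)[, keep, drop = FALSE]
print(X_new)  # rows 1 (ABQ) and 2 (LEX) are identical
```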

One of the first steps in the tutorial converts all character columns to factors, so that they retain all of their levels even if some levels are not present in a given split.

So if you look under the hood, the factor levels match across the raw, test, train, and recipe-template data.
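A small sketch of why this works, using a hypothetical factor: subsetting a factor keeps all of its levels, so a level that happens to be absent from one split is still a known level there. droplevels() is what would remove it.

```r
dest <- factor(c("ABQ", "LEX", "PHL", "ABQ"))

# Remove the only LEX row, as might happen in a random training split
train_dest <- dest[dest != "LEX"]

print(levels(train_dest))              # LEX is still a level
print(table(train_dest))               # ... with a count of zero
print(levels(droplevels(train_dest)))  # droplevels() removes it
```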

A simple reprex is below:

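# Assumes train_data, test_data and flights_rec have been created as in the tutorial
library(tidymodels)
library(nycflights13)

flight_data <- 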
  flights %>% 
  mutate(
    # Convert the arrival delay to a factor
    arr_delay = ifelse(arr_delay >= 30, "late", "on_time"),
    arr_delay = factor(arr_delay),
    # We will use the date (not date-time) in the recipe below
    date = as.Date(time_hour)
  ) %>% 
  # Include the weather data
  inner_join(weather, by = c("origin", "time_hour")) %>% 
  # Only retain the specific columns we will use
  select(dep_time, flight, origin, dest, air_time, distance, 
         carrier, date, arr_delay, time_hour) %>% 
  # Exclude missing data
  na.omit() %>% 
  # For creating models, it is better to have qualitative columns
  # encoded as factors (instead of character strings)
  mutate_if(is.character, as.factor)

lapply(
  list(
    flight_data,
    train_data,
    test_data,
    flights_rec$template
  ),
  str
)
#> tibble [325,819 x 10] (S3: tbl_df/tbl/data.frame)
#>  $ dep_time : int [1:325819] 517 533 542 544 554 554 555 557 557 558 ...
#>  $ flight   : int [1:325819] 1545 1714 1141 725 461 1696 507 5708 79 301 ...
#>  $ origin   : Factor w/ 3 levels "EWR","JFK","LGA": 1 3 2 2 3 1 1 3 2 3 ...
#>  $ dest     : Factor w/ 104 levels "ABQ","ACK","ALB",..: 44 44 58 13 5 69 36 43 54 69 ...
#>  $ air_time : num [1:325819] 227 227 160 183 116 150 158 53 140 138 ...
#>  $ distance : num [1:325819] 1400 1416 1089 1576 762 ...
#>  $ carrier  : Factor w/ 16 levels "9E","AA","AS",..: 12 12 2 4 5 12 4 6 4 2 ...
#>  $ date     : Date[1:325819], format: "2013-01-01" "2013-01-01" ...
#>  $ arr_delay: Factor w/ 2 levels "late","on_time": 2 2 1 2 2 2 2 2 2 2 ...
#>  $ time_hour: POSIXct[1:325819], format: "2013-01-01 05:00:00" "2013-01-01 05:00:00" ...
#> tibble [244,365 x 10] (S3: tbl_df/tbl/data.frame)
#>  $ dep_time : int [1:244365] 517 533 544 554 554 555 557 557 558 558 ...
#>  $ flight   : int [1:244365] 1545 1714 725 461 1696 507 5708 79 301 49 ...
#>  $ origin   : Factor w/ 3 levels "EWR","JFK","LGA": 1 3 2 3 1 1 3 2 3 2 ...
#>  $ dest     : Factor w/ 104 levels "ABQ","ACK","ALB",..: 44 44 13 5 69 36 43 54 69 71 ...
#>  $ air_time : num [1:244365] 227 227 183 116 150 158 53 140 138 149 ...
#>  $ distance : num [1:244365] 1400 1416 1576 762 719 ...
#>  $ carrier  : Factor w/ 16 levels "9E","AA","AS",..: 12 12 4 5 12 4 6 4 2 4 ...
#>  $ date     : Date[1:244365], format: "2013-01-01" "2013-01-01" ...
#>  $ arr_delay: Factor w/ 2 levels "late","on_time": 2 2 2 2 2 2 2 2 2 2 ...
#>  $ time_hour: POSIXct[1:244365], format: "2013-01-01 05:00:00" "2013-01-01 05:00:00" ...
#> tibble [81,454 x 10] (S3: tbl_df/tbl/data.frame)
#>  $ dep_time : int [1:81454] 542 606 611 622 624 624 628 629 635 637 ...
#>  $ flight   : int [1:81454] 1141 1743 303 245 4626 4599 1665 4646 711 389 ...
#>  $ origin   : Factor w/ 3 levels "EWR","JFK","LGA": 2 2 2 1 1 3 1 3 3 3 ...
#>  $ dest     : Factor w/ 104 levels "ABQ","ACK","ALB",..: 58 5 90 74 61 61 50 17 31 54 ...
#>  $ air_time : num [1:81454] 160 128 366 342 190 166 366 40 248 144 ...
#>  $ distance : num [1:81454] 1089 760 2586 2133 1008 ...
#>  $ carrier  : Factor w/ 16 levels "9E","AA","AS",..: 2 5 12 13 6 10 12 15 2 4 ...
#>  $ date     : Date[1:81454], format: "2013-01-01" "2013-01-01" ...
#>  $ arr_delay: Factor w/ 2 levels "late","on_time": 1 2 2 2 2 2 2 2 1 2 ...
#>  $ time_hour: POSIXct[1:81454], format: "2013-01-01 05:00:00" "2013-01-01 06:00:00" ...
#> tibble [244,365 x 10] (S3: tbl_df/tbl/data.frame)
#>  $ dep_time : int [1:244365] 517 533 544 554 554 555 557 557 558 558 ...
#>  $ flight   : int [1:244365] 1545 1714 725 461 1696 507 5708 79 301 49 ...
#>  $ origin   : Factor w/ 3 levels "EWR","JFK","LGA": 1 3 2 3 1 1 3 2 3 2 ...
#>  $ dest     : Factor w/ 104 levels "ABQ","ACK","ALB",..: 44 44 13 5 69 36 43 54 69 71 ...
#>  $ air_time : num [1:244365] 227 227 183 116 150 158 53 140 138 149 ...
#>  $ distance : num [1:244365] 1400 1416 1576 762 719 ...
#>  $ carrier  : Factor w/ 16 levels "9E","AA","AS",..: 12 12 4 5 12 4 6 4 2 4 ...
#>  $ date     : Date[1:244365], format: "2013-01-01" "2013-01-01" ...
#>  $ time_hour: POSIXct[1:244365], format: "2013-01-01 05:00:00" "2013-01-01 05:00:00" ...
#>  $ arr_delay: Factor w/ 2 levels "late","on_time": 2 2 2 2 2 2 2 2 2 2 ...

Thanks for the response! Sorry if my question wasn't entirely clear; let me try again, with a reprex. After preprocessing flight_data and converting characters to factors, the tutorial adds a recipe:

library(tidymodels)      # for the recipes package, along with the rest of tidy-models
# Helper packages
library(nycflights13)    # for flight data
library(skimr)           # for variable summaries

set.seed(123)

flight_data <- 
  flights %>% 
  mutate(
    # Convert the arrival delay to a factor
    arr_delay = ifelse(arr_delay >= 30, "late", "on_time"),
    arr_delay = factor(arr_delay),
    # We will use the date (not date-time) in the recipe below
    date = as.Date(time_hour)
  ) %>% 
  # Include the weather data
  inner_join(weather, by = c("origin", "time_hour")) %>% 
  # Only retain the specific columns we will use
  select(dep_time, flight, origin, dest, air_time, distance, 
         carrier, date, arr_delay, time_hour) %>% 
  # Exclude missing data
  na.omit() %>% 
  # For creating models, it is better to have qualitative columns
  # encoded as factors (instead of character strings)
  mutate_if(is.character, as.factor)

# Fix the random numbers by setting the seed 
# This enables the analysis to be reproducible when random numbers are used 
set.seed(555)
# Put 3/4 of the data into the training set 
data_split <- initial_split(flight_data, prop = 3/4)

# Create data frames for the two sets:
train_data <- training(data_split)
test_data  <- testing(data_split)

flights_rec <- 
  recipe(arr_delay ~ ., data = train_data) %>% 
  update_role(flight, time_hour, new_role = "ID") %>% 
  step_date(date, features = c("dow", "month")) %>% 
  step_holiday(date, holidays = timeDate::listHolidays("US")) %>% 
  step_rm(date) %>% 
  step_dummy(all_nominal(), -all_outcomes()) %>% 
  step_zv(all_predictors())

trained_flights_rec <- flights_rec %>% prep(training = train_data)
trained_flights_rec
#> Data Recipe
#> 
#> Inputs:
#> 
#>       role #variables
#>         ID          2
#>    outcome          1
#>  predictor          7
#> 
#> Training data contained 244365 data points and no missing data.
#> 
#> Operations:
#> 
#> Date features from date [trained]
#> Holiday features from date [trained]
#> Variables removed date [trained]
#> Dummy variables from origin, dest, carrier, date_dow, date_month [trained]
#> Zero variance filter removed dest_LEX [trained]

test_bake <- trained_flights_rec %>% bake(new_data = test_data)
test_bake %>% select(starts_with("dest_L"))
#> # A tibble: 81,454 x 3
#>    dest_LAS dest_LAX dest_LGB
#>       <dbl>    <dbl>    <dbl>
#>  1        0        0        0
#>  2        0        0        0
#>  3        0        0        0
#>  4        0        0        0
#>  5        1        0        0
#>  6        0        0        0
#>  7        0        0        0
#>  8        0        0        0
#>  9        0        0        0
#> 10        0        0        0
#> # … with 81,444 more rows

As you can see, step_zv removes the dest_LEX column from both the training and the test data. This is crucial for fitting the model (an unregularized logistic regression): because LEX never appears as a destination in the training data, the estimated β for dest_LEX is not identified (the data place no constraint on it). Given that, though, the tutorial text doesn't explain what recipes does with dest == "LEX" in the test data.

LEX_test_data <- test_data %>% 
  filter(dest == "LEX")

test_LEX_bake <- trained_flights_rec %>% bake(new_data = LEX_test_data)
test_LEX_bake %>% select_if(~ !is.numeric(.) || sum(.) != 0)
#> # A tibble: 1 x 9
#>   dep_time flight air_time distance time_hour           arr_delay origin_LGA
#>      <int>  <int>    <dbl>    <dbl> <dttm>              <fct>          <dbl>
#> 1     2026   3669       90      604 2013-11-24 20:00:00 on_time            1
#> # … with 2 more variables: date_dow_Mon <dbl>, date_month_Nov <dbl>

We can see that all of the dest_ columns are 0 when we bake the recipe on test data where dest == "LEX".
But because step_dummy uses treatment coding rather than one-hot encoding, all of the dest_ columns are also 0 when dest == "ABQ".

new_ABQ_test_data = LEX_test_data  %>% mutate(dest = factor("ABQ", levels = levels(LEX_test_data$dest)))
test_ABQ_bake <- trained_flights_rec %>% bake(new_data = new_ABQ_test_data)
test_ABQ_bake %>% select_if(~ !is.numeric(.) || sum(.) != 0)
#> # A tibble: 1 x 9
#>   dep_time flight air_time distance time_hour           arr_delay origin_LGA
#>      <int>  <int>    <dbl>    <dbl> <dttm>              <fct>          <dbl>
#> 1     2026   3669       90      604 2013-11-24 20:00:00 on_time            1
#> # … with 2 more variables: date_dow_Mon <dbl>, date_month_Nov <dbl>

To show that this is specific to ABQ (the baseline level of the contrast), we can try the same thing with dest == "PHL":

new_PHL_test_data = LEX_test_data  %>% mutate(dest = factor("PHL", levels = levels(LEX_test_data$dest)))
test_PHL_bake <- trained_flights_rec %>% bake(new_data = new_PHL_test_data)
test_PHL_bake %>% select_if(~ !is.numeric(.) || sum(.) != 0)
#> # A tibble: 1 x 10
#>   dep_time flight air_time distance time_hour           arr_delay origin_LGA
#>      <int>  <int>    <dbl>    <dbl> <dttm>              <fct>          <dbl>
#> 1     2026   3669       90      604 2013-11-24 20:00:00 on_time            1
#> # … with 3 more variables: dest_PHL <dbl>, date_dow_Mon <dbl>,
#> #   date_month_Nov <dbl>

Here the dest_PHL column is nonzero. So recipes treats LEX and ABQ identically, even though ABQ was chosen as the baseline essentially arbitrarily (because it is the first factor level). It's possible that the model does something clever with the dropped zero-variance columns at prediction time, so let's check:

lr_mod <-
  logistic_reg() %>%
  set_engine("glm")

flights_wflow <-
  workflow() %>%
  add_model(lr_mod) %>%
  add_recipe(flights_rec)

flights_fit <-
  flights_wflow %>%
  fit(data = train_data)

predict(flights_fit, LEX_test_data, type = "prob")
#> # A tibble: 1 x 2
#>   .pred_late .pred_on_time
#>        <dbl>         <dbl>
#> 1      0.376         0.624
predict(flights_fit, new_ABQ_test_data, type = "prob")
#> # A tibble: 1 x 2
#>   .pred_late .pred_on_time
#>        <dbl>         <dbl>
#> 1      0.376         0.624
predict(flights_fit, new_PHL_test_data, type = "prob")
#> # A tibble: 1 x 2
#>   .pred_late .pred_on_time
#>        <dbl>         <dbl>
#> 1      0.132         0.868

The model makes the same prediction for LEX and ABQ, but not for PHL, as you might expect from how recipes baked the test data above. This seems like potentially bad model behavior to me, unless I'm missing something?
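One way to see that this follows from the encoding rather than from anything specific to predict(): a fitted logistic regression maps each encoded row x to plogis(sum(x * beta)), so two rows with identical encodings must get identical probabilities. A base-R sketch with made-up data, using glm() on hand-built dummies as a stand-in for the recipe + workflow above (the helper encode() is hypothetical):

```r
levs <- c("ABQ", "LEX", "PHL")

# Hypothetical training data: LEX never occurs
train <- data.frame(
  late = c(0, 1, 1, 0, 1, 1, 1, 0),
  dest = factor(rep(c("ABQ", "PHL"), each = 4), levels = levs)
)

# Treatment-code dest and keep only the dummy that varies in training (destPHL)
encode <- function(d) {
  X <- model.matrix(~ dest, data = d)
  data.frame(destPHL = X[, "destPHL"])
}

fit <- glm(late ~ destPHL, family = binomial,
           data = data.frame(late = train$late, encode(train)))

new_rows <- data.frame(dest = factor(levs, levels = levs))
p <- predict(fit, newdata = encode(new_rows), type = "response")
print(setNames(p, levs))  # ABQ and LEX get the same probability
```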

I appreciate the reprex, and I feel like I'm starting to get a better grasp of what you're referring to. However, when I run your code I get errors; it seems that with my seed, dest == "LEX" ends up in my training data.

That said, if I'm not mistaken, by default the contrast function leaves one level out (the reference level), as per https://bookdown.org/max/FES/creating-dummy-variables-for-unordered-categories.html

If that's the case, it may be choosing which level to exclude based on the number of occurrences, and LEX is the least frequent dest in the data set, with only one occurrence.

For example, there should be one fewer dest column in the baked data than there are unique levels in the data, which seems to be the case.

Would you agree / disagree?

library(tidymodels) # for the recipes package, along with the rest of tidy-models
library(nycflights13) # for flight data
library(skimr) # for variable summaries
library(workflows)

set.seed(123)

flight_data <-
  flights %>%
  mutate(
    arr_delay = ifelse(arr_delay >= 30, "late", "on_time"),
    arr_delay = factor(arr_delay),
    date = as.Date(time_hour)
  ) %>%
  inner_join(weather, by = c("origin", "time_hour")) %>%
  select(
    dep_time, flight, origin, dest, air_time, distance,
    carrier, date, arr_delay, time_hour
  ) %>%
  na.omit() %>%
  mutate_if(is.character, as.factor)

set.seed(555)
data_split <- initial_split(flight_data, prop = 3 / 4)

train_data <- training(data_split)
test_data <- testing(data_split)

flights_rec <-
  recipe(arr_delay ~ ., data = train_data) %>%
  update_role(flight, time_hour, new_role = "ID") %>%
  step_date(date, features = c("dow", "month")) %>%
  step_holiday(date, holidays = timeDate::listHolidays("US")) %>%
  step_rm(date) %>%
  step_dummy(all_nominal(), -all_outcomes()) %>%
  step_zv(all_predictors())

trained_flights_rec <- flights_rec %>% prep(training = train_data)
#> Warning: The `x` argument of `as_tibble.matrix()` must have column names if `.name_repair` is omitted as of tibble 2.0.0.
#> Using compatibility `.name_repair`.
#> This warning is displayed once every 8 hours.
#> Call `lifecycle::last_warnings()` to see where this warning was generated.

test_bake <- trained_flights_rec %>% bake(new_data = test_data)
train_bake <- trained_flights_rec %>% bake(new_data = train_data)
test_full_bake <- trained_flights_rec %>% bake(new_data = flight_data)

test_bake %>% 
  select(contains("dest_")) %>% 
  ncol() == (length(unique(flight_data$dest))-1)
#> [1] TRUE
  
train_bake %>% 
  select(contains("dest_")) %>% 
  ncol() == (length(unique(flight_data$dest))-1)
#> [1] TRUE

test_full_bake %>% 
  select(contains("dest_")) %>% 
  ncol() == (length(unique(flight_data$dest))-1)
#> [1] TRUE
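For reference, base R's treatment contrast picks the baseline by level order, not by frequency: the reference is the first level of the factor. A small sketch with hypothetical data in which the first level (ABQ) is the most common and LEX is the rarest:

```r
# Hypothetical data: ABQ is common, LEX occurs once
dest <- factor(c("ABQ", "ABQ", "ABQ", "LEX", "PHL", "PHL"))

X <- model.matrix(~ dest)
print(colnames(X))  # no ABQ column: ABQ, the first level, is the baseline

# relevel() changes which level is the baseline
dest2 <- relevel(dest, ref = "PHL")
print(colnames(model.matrix(~ dest2)))  # now PHL has no column
```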

Hmm, I'm not sure what the issue with set.seed was, but in any case I've copied the whole script below (created with reprex and including session info). When I run it, I get FALSE for the three checks you suggested.

I totally agree that the dummy-variable contrast function is working correctly (or at least as specified): it creates a treatment contrast with ABQ as the baseline level. I'm pretty sure LEX is being dropped because of step_zv in the recipe: LEX doesn't appear in the training set (because it is infrequent), so its dummy column has zero variance in the training set and gets removed by the recipe. If I'm not mistaken, this is also what the tutorial says.

I guess what I'm really asking is why the tutorial mentions needing step_zv to train without errors but doesn't mention the consequences for prediction. I would think a preferable solution would be to predict the intercept (i.e. the mean) for unseen levels, which is what would happen with a regularized classifier and one-hot encoding. There may be an alternative (using a sum-coding contrast after dropping unseen levels, or something similar), but I'm not sure whether that exists in recipes yet.

library(tidymodels)      # for the recipes package, along with the rest of tidymodels
#> ── Attaching packages ─────────────────────────────────────────────────────────────────────── tidymodels 0.1.1 ──
#> ✓ broom     0.7.0      ✓ recipes   0.1.13
#> ✓ dials     0.0.8      ✓ rsample   0.0.7 
#> ✓ dplyr     1.0.0      ✓ tibble    3.0.3 
#> ✓ ggplot2   3.3.2      ✓ tidyr     1.1.0 
#> ✓ infer     0.5.2      ✓ tune      0.1.1 
#> ✓ modeldata 0.0.2      ✓ workflows 0.1.2 
#> ✓ parsnip   0.1.2      ✓ yardstick 0.0.7 
#> ✓ purrr     0.3.4
#> Warning: package 'dials' was built under R version 4.0.2
#> Warning: package 'modeldata' was built under R version 4.0.1
#> Warning: package 'parsnip' was built under R version 4.0.1
#> Warning: package 'tune' was built under R version 4.0.2
#> Warning: package 'workflows' was built under R version 4.0.2
#> ── Conflicts ────────────────────────────────────────────────────────────────────────── tidymodels_conflicts() ──
#> x purrr::discard() masks scales::discard()
#> x dplyr::filter()  masks stats::filter()
#> x dplyr::lag()     masks stats::lag()
#> x recipes::step()  masks stats::step()

# Helper packages
library(nycflights13)    # for flight data
library(skimr)           # for variable summaries
#> Warning: package 'skimr' was built under R version 4.0.2

set.seed(123)

flight_data <- 
  flights %>% 
  mutate(
    # Convert the arrival delay to a factor
    arr_delay = ifelse(arr_delay >= 30, "late", "on_time"),
    arr_delay = factor(arr_delay),
    # We will use the date (not date-time) in the recipe below
    date = as.Date(time_hour)
  ) %>% 
  # Include the weather data
  inner_join(weather, by = c("origin", "time_hour")) %>% 
  # Only retain the specific columns we will use
  select(dep_time, flight, origin, dest, air_time, distance, 
         carrier, date, arr_delay, time_hour) %>% 
  # Exclude missing data
  na.omit() %>% 
  # For creating models, it is better to have qualitative columns
  # encoded as factors (instead of character strings)
  mutate_if(is.character, as.factor)

# Fix the random numbers by setting the seed 
# This enables the analysis to be reproducible when random numbers are used 
set.seed(555)
# Put 3/4 of the data into the training set 
data_split <- initial_split(flight_data, prop = 3/4)

# Create data frames for the two sets:
train_data <- training(data_split)
test_data  <- testing(data_split)

flights_rec <- 
  recipe(arr_delay ~ ., data = train_data) %>% 
  update_role(flight, time_hour, new_role = "ID") %>% 
  step_date(date, features = c("dow", "month")) %>% 
  step_holiday(date, holidays = timeDate::listHolidays("US")) %>% 
  step_rm(date) %>% 
  step_dummy(all_nominal(), -all_outcomes()) %>% 
  step_zv(all_predictors())

trained_flights_rec <- flights_rec %>% prep(training = train_data)
trained_flights_rec
#> Data Recipe
#> 
#> Inputs:
#> 
#>       role #variables
#>         ID          2
#>    outcome          1
#>  predictor          7
#> 
#> Training data contained 244365 data points and no missing data.
#> 
#> Operations:
#> 
#> Date features from date [trained]
#> Holiday features from date [trained]
#> Variables removed date [trained]
#> Dummy variables from origin, dest, carrier, date_dow, date_month [trained]
#> Zero variance filter removed dest_LEX [trained]

test_bake <- trained_flights_rec %>% bake(new_data = test_data)
train_bake <- trained_flights_rec %>% bake(new_data = train_data)
test_full_bake <- trained_flights_rec %>% bake(new_data = flight_data)

test_bake %>% 
  select(contains("dest_")) %>% 
  ncol() == (length(unique(flight_data$dest))-1)
#> [1] FALSE

train_bake %>% 
  select(contains("dest_")) %>% 
  ncol() == (length(unique(flight_data$dest))-1)
#> [1] FALSE

test_full_bake %>% 
  select(contains("dest_")) %>% 
  ncol() == (length(unique(flight_data$dest))-1)
#> [1] FALSE

test_bake <- trained_flights_rec %>% bake(new_data = test_data)
test_bake %>% select(starts_with("dest_L"))
#> # A tibble: 81,454 x 3
#>    dest_LAS dest_LAX dest_LGB
#>       <dbl>    <dbl>    <dbl>
#>  1        0        0        0
#>  2        0        0        0
#>  3        0        0        0
#>  4        0        0        0
#>  5        1        0        0
#>  6        0        0        0
#>  7        0        0        0
#>  8        0        0        0
#>  9        0        0        0
#> 10        0        0        0
#> # … with 81,444 more rows

LEX_test_data <- test_data %>% 
  filter(dest == "LEX")

test_LEX_bake <- trained_flights_rec %>% bake(new_data = LEX_test_data)
test_LEX_bake %>% select_if(~ !is.numeric(.) || sum(.) != 0)
#> # A tibble: 1 x 9
#>   dep_time flight air_time distance time_hour           arr_delay origin_LGA
#>      <int>  <int>    <dbl>    <dbl> <dttm>              <fct>          <dbl>
#> 1     2026   3669       90      604 2013-11-24 20:00:00 on_time            1
#> # … with 2 more variables: date_dow_Mon <dbl>, date_month_Nov <dbl>

new_ABQ_test_data = LEX_test_data  %>% mutate(dest = factor("ABQ", levels = levels(LEX_test_data$dest)))
test_ABQ_bake <- trained_flights_rec %>% bake(new_data = new_ABQ_test_data)
test_ABQ_bake %>% select_if(~ !is.numeric(.) || sum(.) != 0)
#> # A tibble: 1 x 9
#>   dep_time flight air_time distance time_hour           arr_delay origin_LGA
#>      <int>  <int>    <dbl>    <dbl> <dttm>              <fct>          <dbl>
#> 1     2026   3669       90      604 2013-11-24 20:00:00 on_time            1
#> # … with 2 more variables: date_dow_Mon <dbl>, date_month_Nov <dbl>

new_PHL_test_data = LEX_test_data  %>% mutate(dest = factor("PHL", levels = levels(LEX_test_data$dest)))
test_PHL_bake <- trained_flights_rec %>% bake(new_data = new_PHL_test_data)
test_PHL_bake %>% select_if(~ !is.numeric(.) || sum(.) != 0)
#> # A tibble: 1 x 10
#>   dep_time flight air_time distance time_hour           arr_delay origin_LGA
#>      <int>  <int>    <dbl>    <dbl> <dttm>              <fct>          <dbl>
#> 1     2026   3669       90      604 2013-11-24 20:00:00 on_time            1
#> # … with 3 more variables: dest_PHL <dbl>, date_dow_Mon <dbl>,
#> #   date_month_Nov <dbl>

lr_mod <-
  logistic_reg() %>%
  set_engine("glm")

flights_wflow <-
  workflow() %>%
  add_model(lr_mod) %>%
  add_recipe(flights_rec)

flights_fit <-
  flights_wflow %>%
  fit(data = train_data)

predict(flights_fit, LEX_test_data, type = "prob")
#> # A tibble: 1 x 2
#>   .pred_late .pred_on_time
#>        <dbl>         <dbl>
#> 1      0.376         0.624
predict(flights_fit, new_ABQ_test_data, type = "prob")
#> # A tibble: 1 x 2
#>   .pred_late .pred_on_time
#>        <dbl>         <dbl>
#> 1      0.376         0.624
predict(flights_fit, new_PHL_test_data, type = "prob")
#> # A tibble: 1 x 2
#>   .pred_late .pred_on_time
#>        <dbl>         <dbl>
#> 1      0.132         0.868

Created on 2020-07-15 by the reprex package (v0.3.0)

Session info
devtools::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value                       
#>  version  R version 4.0.0 (2020-04-24)
#>  os       macOS Catalina 10.15.5      
#>  system   x86_64, darwin17.0          
#>  ui       X11                         
#>  language (EN)                        
#>  collate  en_US.UTF-8                 
#>  ctype    en_US.UTF-8                 
#>  tz       America/New_York            
#>  date     2020-07-15                  
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package      * version    date       lib source        
#>  assertthat     0.2.1      2019-03-21 [1] CRAN (R 4.0.0)
#>  backports      1.1.8      2020-06-17 [1] CRAN (R 4.0.0)
#>  base64enc      0.1-3      2015-07-28 [1] CRAN (R 4.0.0)
#>  broom        * 0.7.0      2020-07-09 [1] CRAN (R 4.0.0)
#>  callr          3.4.3      2020-03-28 [1] CRAN (R 4.0.0)
#>  class          7.3-17     2020-04-26 [1] CRAN (R 4.0.0)
#>  cli            2.0.2      2020-02-28 [1] CRAN (R 4.0.0)
#>  codetools      0.2-16     2018-12-24 [1] CRAN (R 4.0.0)
#>  colorspace     1.4-1      2019-03-18 [1] CRAN (R 4.0.0)
#>  crayon         1.3.4      2017-09-16 [1] CRAN (R 4.0.0)
#>  desc           1.2.0      2018-05-01 [1] CRAN (R 4.0.0)
#>  devtools       2.3.0      2020-04-10 [1] CRAN (R 4.0.0)
#>  dials        * 0.0.8      2020-07-08 [1] CRAN (R 4.0.2)
#>  DiceDesign     1.8-1      2019-07-31 [1] CRAN (R 4.0.0)
#>  digest         0.6.25     2020-02-23 [1] CRAN (R 4.0.0)
#>  dplyr        * 1.0.0      2020-05-29 [1] CRAN (R 4.0.0)
#>  ellipsis       0.3.1      2020-05-15 [1] CRAN (R 4.0.0)
#>  evaluate       0.14       2019-05-28 [1] CRAN (R 4.0.0)
#>  fansi          0.4.1      2020-01-08 [1] CRAN (R 4.0.0)
#>  foreach        1.5.0      2020-03-30 [1] CRAN (R 4.0.0)
#>  fs             1.4.2      2020-06-30 [1] CRAN (R 4.0.1)
#>  furrr          0.1.0      2018-05-16 [1] CRAN (R 4.0.0)
#>  future         1.18.0     2020-07-09 [1] CRAN (R 4.0.0)
#>  generics       0.0.2      2018-11-29 [1] CRAN (R 4.0.0)
#>  ggplot2      * 3.3.2      2020-06-19 [1] CRAN (R 4.0.0)
#>  globals        0.12.5     2019-12-07 [1] CRAN (R 4.0.0)
#>  glue           1.4.1      2020-05-13 [1] CRAN (R 4.0.0)
#>  gower          0.2.2      2020-06-23 [1] CRAN (R 4.0.0)
#>  GPfit          1.0-8      2019-02-08 [1] CRAN (R 4.0.0)
#>  gtable         0.3.0      2019-03-25 [1] CRAN (R 4.0.0)
#>  hardhat        0.1.4      2020-07-02 [1] CRAN (R 4.0.0)
#>  highr          0.8        2019-03-20 [1] CRAN (R 4.0.0)
#>  htmltools      0.5.0      2020-06-16 [1] CRAN (R 4.0.0)
#>  infer        * 0.5.2      2020-06-14 [1] CRAN (R 4.0.0)
#>  ipred          0.9-9      2019-04-28 [1] CRAN (R 4.0.0)
#>  iterators      1.0.12     2019-07-26 [1] CRAN (R 4.0.0)
#>  jsonlite       1.7.0      2020-06-25 [1] CRAN (R 4.0.0)
#>  knitr          1.29       2020-06-23 [1] CRAN (R 4.0.0)
#>  lattice        0.20-41    2020-04-02 [1] CRAN (R 4.0.0)
#>  lava           1.6.7      2020-03-05 [1] CRAN (R 4.0.0)
#>  lhs            1.0.2      2020-04-13 [1] CRAN (R 4.0.0)
#>  lifecycle      0.2.0      2020-03-06 [1] CRAN (R 4.0.0)
#>  listenv        0.8.0      2019-12-05 [1] CRAN (R 4.0.0)
#>  lubridate      1.7.9      2020-06-08 [1] CRAN (R 4.0.0)
#>  magrittr       1.5        2014-11-22 [1] CRAN (R 4.0.0)
#>  MASS           7.3-51.6   2020-04-26 [1] CRAN (R 4.0.0)
#>  Matrix         1.2-18     2019-11-27 [1] CRAN (R 4.0.0)
#>  memoise        1.1.0      2017-04-21 [1] CRAN (R 4.0.0)
#>  modeldata    * 0.0.2      2020-06-22 [1] CRAN (R 4.0.1)
#>  munsell        0.5.0      2018-06-12 [1] CRAN (R 4.0.0)
#>  nnet           7.3-14     2020-04-26 [1] CRAN (R 4.0.0)
#>  nycflights13 * 1.0.1      2019-09-16 [1] CRAN (R 4.0.0)
#>  parsnip      * 0.1.2      2020-07-03 [1] CRAN (R 4.0.1)
#>  pillar         1.4.6      2020-07-10 [1] CRAN (R 4.0.0)
#>  pkgbuild       1.1.0      2020-07-13 [1] CRAN (R 4.0.0)
#>  pkgconfig      2.0.3      2019-09-22 [1] CRAN (R 4.0.0)
#>  pkgload        1.1.0      2020-05-29 [1] CRAN (R 4.0.0)
#>  plyr           1.8.6      2020-03-03 [1] CRAN (R 4.0.0)
#>  prettyunits    1.1.1      2020-01-24 [1] CRAN (R 4.0.0)
#>  pROC           1.16.2     2020-03-19 [1] CRAN (R 4.0.0)
#>  processx       3.4.3      2020-07-05 [1] CRAN (R 4.0.1)
#>  prodlim        2019.11.13 2019-11-17 [1] CRAN (R 4.0.0)
#>  ps             1.3.3      2020-05-08 [1] CRAN (R 4.0.0)
#>  purrr        * 0.3.4      2020-04-17 [1] CRAN (R 4.0.0)
#>  R6             2.4.1      2019-11-12 [1] CRAN (R 4.0.0)
#>  Rcpp           1.0.5      2020-07-06 [1] CRAN (R 4.0.0)
#>  recipes      * 0.1.13     2020-06-23 [1] CRAN (R 4.0.0)
#>  remotes        2.1.1      2020-02-15 [1] CRAN (R 4.0.1)
#>  repr           1.1.0      2020-01-28 [1] CRAN (R 4.0.1)
#>  rlang          0.4.7      2020-07-09 [1] CRAN (R 4.0.0)
#>  rmarkdown      2.3        2020-06-18 [1] CRAN (R 4.0.0)
#>  rpart          4.1-15     2019-04-12 [1] CRAN (R 4.0.0)
#>  rprojroot      1.3-2      2018-01-03 [1] CRAN (R 4.0.0)
#>  rsample      * 0.0.7      2020-06-04 [1] CRAN (R 4.0.0)
#>  rstudioapi     0.11       2020-02-07 [1] CRAN (R 4.0.0)
#>  scales       * 1.1.1      2020-05-11 [1] CRAN (R 4.0.0)
#>  sessioninfo    1.1.1      2018-11-05 [1] CRAN (R 4.0.0)
#>  skimr        * 2.1.2      2020-07-06 [1] CRAN (R 4.0.2)
#>  stringi        1.4.6      2020-02-17 [1] CRAN (R 4.0.0)
#>  stringr        1.4.0      2019-02-10 [1] CRAN (R 4.0.0)
#>  survival       3.2-3      2020-06-13 [1] CRAN (R 4.0.0)
#>  testthat       2.3.2      2020-03-02 [1] CRAN (R 4.0.0)
#>  tibble       * 3.0.3      2020-07-10 [1] CRAN (R 4.0.0)
#>  tidymodels   * 0.1.1      2020-07-14 [1] CRAN (R 4.0.0)
#>  tidyr        * 1.1.0      2020-05-20 [1] CRAN (R 4.0.0)
#>  tidyselect     1.1.0      2020-05-11 [1] CRAN (R 4.0.0)
#>  timeDate       3043.102   2018-02-21 [1] CRAN (R 4.0.0)
#>  tune         * 0.1.1      2020-07-08 [1] CRAN (R 4.0.2)
#>  usethis        1.6.1      2020-04-29 [1] CRAN (R 4.0.0)
#>  utf8           1.1.4      2018-05-24 [1] CRAN (R 4.0.0)
#>  vctrs          0.3.1      2020-06-05 [1] CRAN (R 4.0.0)
#>  withr          2.2.0      2020-04-20 [1] CRAN (R 4.0.0)
#>  workflows    * 0.1.2      2020-07-07 [1] CRAN (R 4.0.2)
#>  xfun           0.15       2020-06-21 [1] CRAN (R 4.0.0)
#>  yaml           2.2.1      2020-02-01 [1] CRAN (R 4.0.0)
#>  yardstick    * 0.0.7      2020-07-13 [1] CRAN (R 4.0.0)
#> 
#> [1] /Library/Frameworks/R.framework/Versions/4.0/Resources/library

Based on the vignette https://recipes.tidymodels.org/articles/Dummies.html (see section Novel Levels), all new levels found only in the testing sample will be treated as missing values. I think (I haven't tried) ABQ and LEX will be assigned the same value/class when it comes to prediction.
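A base-R sketch of the distinction (hypothetical values, including the made-up code "XYZ"): a value outside a factor's level set becomes NA when the factor is created with fixed levels, whereas LEX in this thread is a declared level (the levels were set on the full data before splitting), so it is not novel in that sense; it is simply a level with no training rows.

```r
train_levels <- c("ABQ", "LEX", "PHL")

# A genuinely new value is not in the level set and becomes NA
x <- factor(c("ABQ", "XYZ"), levels = train_levels)
print(x)

# LEX is a known level, so it is kept even if no training row uses it
y <- factor("LEX", levels = train_levels)
print(is.na(y))
```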

If a column is removed by step_zv(), it is not included in the processed version of the data (obtained via bake() or juice()). The model has no parameters for that column.

For any data set, even if the data has the level that was removed by step_zv(), it will not be included by bake().
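A rough sketch of that mechanism, with hypothetical helpers zv_prep() and zv_bake() (not the recipes source): the set of columns to drop is learned once from the training data and then applied unchanged to every data set that is baked, so a removed column never comes back, even when the new data contain that level.

```r
# "prep": record which columns have zero variance in the training data
zv_prep <- function(train_x) {
  names(which(vapply(train_x, function(col) var(col) == 0, logical(1))))
}

# "bake": drop exactly those columns from any data set
zv_bake <- function(x, removed) {
  x[, setdiff(names(x), removed), drop = FALSE]
}

train_x <- data.frame(dest_LEX = c(0, 0, 0), dest_PHL = c(1, 0, 1))
test_x  <- data.frame(dest_LEX = c(1, 0),    dest_PHL = c(0, 1))

removed <- zv_prep(train_x)
print(removed)                   # "dest_LEX"
print(zv_bake(test_x, removed))  # dest_LEX is dropped even though row 1 has a 1
```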

No matter the encoding, you don't have data to make a parameter so there's not much to do.

Thanks so much for the response! I understand why there are no parameters for the column, I guess it just seemed like a bad solution to me and something you wouldn't want to do in practice, since ABQ is a completely arbitrary choice of baseline.
One thing you could do (which would approximate what happens with regularized regression and one-hot encoding) is use sum contrast coding for dest, so that the intercept is the average over all levels of that variable. That seems like a more sensible default to me, at least. I realize that isn't implemented in recipes yet, though. Maybe it would still be worth pointing out this counterintuitive behavior (predicting that LEX is the same as ABQ), given that there's already a section of the tutorial devoted to it? Or maybe it's only counterintuitive to me?
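A base-R sketch of the sum-coding idea with made-up numbers (ordinary least squares rather than logistic regression, to keep the arithmetic checkable): with contr.sum, the intercept of a one-way model equals the unweighted mean of the level means, so "predict the intercept for an unseen level" means "predict the average of the level means", not the observation-weighted overall mean.

```r
# Hypothetical data with unequal group sizes
d <- data.frame(
  y = c(1, 3, 10, 20, 30, 5),
  g = factor(c("a", "a", "b", "b", "b", "c"))
)

fit <- lm(y ~ g, data = d, contrasts = list(g = "contr.sum"))

level_means <- tapply(d$y, d$g, mean)  # a = 2, b = 20, c = 5
print(coef(fit)[["(Intercept)"]])      # 9
print(mean(level_means))               # 9, the unweighted mean of level means
print(mean(d$y))                       # 11.5, the observation-weighted mean
```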

Being able to use sum coding in recipes would also be pretty useful to me, for some projects -- I'd be somewhat interested in trying to implement it if you have any pointers on how to do so?

step_dummy() offers a one-hot encoding option. Other than that, the help file has

To change the type of contrast being used, change the global contrast option via options.
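In base R, that global option is options(contrasts = ...). A sketch of its effect on model.matrix(); whether a given recipes version honors it is taken from the help text quoted above and is not checked by this snippet.

```r
g <- factor(c("a", "b", "c"))

# First element: unordered factors; second: ordered factors
old <- options(contrasts = c("contr.sum", "contr.poly"))
X_sum <- model.matrix(~ g)
options(old)  # restore the previous setting

print(X_sum)  # level "c" is coded as (-1, -1)
```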

Thanks! I don't think that would do exactly what I'm proposing: if some levels are unseen in the training set, that breaks the sum-to-zero property (which is what gives us the benefit of the model predicting the mean when it sees a novel factor level).

For example, if we have a factor with three levels and one observation for each:

contr.sum(3)
#>   [,1] [,2]
#> 1    1    0
#> 2    0    1
#> 3   -1   -1

and then the unseen level is the second one (which is removed due to zero variance by step_zv):

contr.sum(3)[c(1, 3), ]
#>   [,1] [,2]
#> 1    1    0
#> 3   -1   -1

The intercept will be a weighted mean where the 3rd level is counted twice as much as the second. I think you would need some way to remove the unseen level before creating the contrast?
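A sketch of that idea in base R, with made-up data and lm() for checkable arithmetic: drop the unseen level with droplevels() before building the sum contrast, so the contrast covers only the observed levels and the intercept is the unweighted mean of their level means. Using that intercept as the prediction for the unseen level "b" is a manual, hypothetical rule here, not a recipes feature.

```r
d <- data.frame(
  y = c(1, 3, 5, 9),
  g = factor(c("a", "a", "c", "c"), levels = c("a", "b", "c"))  # "b" unseen
)

# Build the sum contrast over the observed levels only
d$g_obs <- droplevels(d$g)
print(contr.sum(levels(d$g_obs)))  # a = 1, c = -1

fit <- lm(y ~ g_obs, data = d, contrasts = list(g_obs = "contr.sum"))
print(coef(fit))  # intercept 4.5 = mean(c(2, 7))

# Hypothetical prediction rule for a level with no training rows
pred_unseen <- unname(coef(fit)[["(Intercept)"]])
```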
