Select models with lowest RMSE

In the following example, how do I keep, for each group (i.e. each Type), only the model with the lowest RMSE? The goal is to end up with a mable that contains just the selected models.

suppressMessages(suppressWarnings(library(fpp3)))

# Data: the "General" concession series for ATC code A01 from PBS,
# with the Scripts column dropped
toy_data <- PBS %>%
  filter(ATC1 == "A" & ATC2 == "A01" & Concession == "General") %>%
  select(-Scripts)

# Training set: every observation from the start up to and including 2005 Dec
train_data <- toy_data %>%
  filter_index(. ~ "2005 Dec")

# Fit two candidate models to each series, then forecast 36 steps ahead
forecasts <- train_data %>%
  model(
    # Candidate: STL decomposition, with ARIMA on the seasonally adjusted part
    `STL + ARIMA` = decomposition_model(
      STL(Cost ~ trend(window = 21) + season(window = 13), robust = TRUE),
      ARIMA(season_adjust)
    ),
    # Candidate: ARIMA fitted directly to Cost
    ARIMA = ARIMA(Cost)
  ) %>%
  forecast(h = 36)

# Test-set accuracy (RMSE and friends) for every model/series pair
accuracy(forecasts, toy_data)
#> Warning: The future dataset is incomplete, incomplete out-of-sample data will be treated as missing. 
#> 6 observations are missing between 2008 Jul and 2008 Dec
#> # A tibble: 4 x 13
#>   .model Concession Type  ATC1  ATC2  .type     ME   RMSE    MAE   MPE  MAPE
#>   <chr>  <chr>      <chr> <chr> <chr> <chr>  <dbl>  <dbl>  <dbl> <dbl> <dbl>
#> 1 ARIMA  General    Co-p~ A     A01   Test   -43.2   55.9   51.4 -Inf   Inf 
#> 2 ARIMA  General    Safe~ A     A01   Test   340.   694.   578.   131.  139.
#> 3 STL +~ General    Co-p~ A     A01   Test   -44.6   59.4   51.8 -Inf   Inf 
#> 4 STL +~ General    Safe~ A     A01   Test  2293.  2403.  2293.   497.  497.
#> # ... with 2 more variables: MASE <dbl>, ACF1 <dbl>

Created on 2020-10-27 by the reprex package (v0.3.0)

Thanks for providing a reproducible example. It makes answering much easier. Here is some code to do what you want.

suppressMessages(suppressWarnings(library(fpp3)))

# Data: the "General" concession series for ATC code A01 from PBS,
# with the Scripts column dropped
toy_data <- PBS %>%
  filter(ATC1 == "A" & ATC2 == "A01" & Concession == "General") %>%
  select(-Scripts)

# Training set: every observation from the start up to and including 2005 Dec
train_data <- toy_data %>%
  filter_index(. ~ "2005 Dec")

# Fit every candidate model to each series in the training data.
# The fitted mable is kept so the best models can be extracted from it later.
fit <- model(
  train_data,
  # Candidate: STL decomposition, with ARIMA on the seasonally adjusted part
  `STL + ARIMA` = decomposition_model(
    STL(Cost ~ trend(window = 21) + season(window = 13), robust = TRUE),
    ARIMA(season_adjust)
  ),
  # Candidate: ARIMA fitted directly to Cost
  ARIMA = ARIMA(Cost)
)

# 36-step-ahead forecasts from every fitted model
forecasts <- forecast(fit, h = 36)

# Find the best model for each series using test-set RMSE.
# slice_min(..., with_ties = FALSE) keeps exactly one model per series:
# - filter(RMSE == min(RMSE)) would keep every tied model, and later leave
#   more than one fit per series when building a single-model mable;
# - min() without na.rm returns NA when any RMSE is NA, so that filter would
#   silently drop the whole series, whereas slice_min ranks NA values last.
# ungroup() so the grouping does not leak into later joins.
bestrmse <- accuracy(forecasts, toy_data) %>%
  group_by(Concession, Type, ATC1, ATC2) %>%
  slice_min(RMSE, n = 1, with_ties = FALSE) %>%
  ungroup() %>%
  select(.model:ATC2)

# Keep only the forecasts produced by each series' selected model.
# An explicit `by` documents the join keys and silences the "Joining, by" message.
bestfc <- forecasts %>%
  right_join(
    bestrmse,
    by = c(".model", "Concession", "Type", "ATC1", "ATC2")
  )

# Modify the mable so it only keeps the best model for each series:
# 1. Stack the model columns into one long `fit` column labelled by `.model`.
#    In this mable every non-key column is a model column, so excluding the
#    keys avoids hard-coding the model names. Adding a model to `fit` then
#    needs no change here.
# 2. Keep only the (series, model) pairs listed in `bestrmse`.
# 3. Relabel the survivors as "best" and spread them back into one column.
#    `id_cols` is named because tidyr >= 1.3 no longer accepts it positionally.
# 4. Restore the mable class, with `best` as the model column.
bestfits <- fit %>%
  pivot_longer(
    cols = -c(Concession, Type, ATC1, ATC2),
    names_to = ".model",
    values_to = "fit"
  ) %>%
  right_join(
    bestrmse,
    by = c(".model", "Concession", "Type", "ATC1", "ATC2")
  ) %>%
  mutate(.model = "best") %>%
  pivot_wider(
    id_cols = Concession:ATC2,
    names_from = ".model",
    values_from = "fit"
  ) %>%
  as_mable(key = c(Concession, Type, ATC1, ATC2), model = best)

Created on 2020-10-28 by the reprex package (v0.3.0)

3 Likes

Thanks a lot! Just what I wanted.

This topic was automatically closed 7 days after the last reply. New replies are no longer allowed.

If you have a query related to it or one of the replies, start a new topic and refer back with a link.