Shrink coefficients toward the mean based on group sample size (hierarchical models)

Below is a reprex where I use the gapminder data to fit a simple model by continent.

I am curious how people would approach the problem of "shrinking" coefficients toward one or more group-level means. I have heard that brms and lme4 are good libraries for this kind of analysis. Also, what is the difference between these shrinkage algorithms and simply taking some kind of weighted average?
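On the weighted-average question: in the simplest normal model, a partially pooled group estimate *is* a weighted average of the group's own mean and the overall mean. The difference is that the weight is not picked by hand. It follows from the group's sample size `n`, the within-group standard deviation `sigma` and the between-group standard deviation `tau`, and lme4 and brms estimate `sigma` and `tau` from the data. The sketch below assumes both are known; `partial_pool()` and the toy numbers are hypothetical and only illustrate the mechanics.

```r
# Partial pooling in a normal-normal model, ASSUMING sigma (within-group sd)
# and tau (between-group sd) are known. Hypothetical helper for illustration.
partial_pool <- function(group_mean, n, sigma, tau, grand_mean) {
  precision_data  <- n / sigma^2   # information carried by the group's own data
  precision_prior <- 1 / tau^2     # how tightly groups cluster around grand_mean
  w <- precision_data / (precision_data + precision_prior)  # weight on group_mean
  w * group_mean + (1 - w) * grand_mean
}

# Toy values: three groups with the same raw mean but very different sizes
group_mean <- c(small = 10, medium = 10, large = 10)
n          <- c(small = 2, medium = 20, large = 200)

shrunk <- partial_pool(group_mean, n, sigma = 5, tau = 1, grand_mean = 0)
round(shrunk, 2)  # approximately: small 0.74, medium 4.44, large 8.89
```

The smaller the group, the further its estimate is pulled toward the grand mean. An ad hoc weighted average has no built-in link between that weight and `n`, `sigma` and `tau`.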

library(gapminder)
library(tidyverse)
library(ggrepel)

# simple lifeExp  model
lifeExp_model <- function(df) {
  lm(lifeExp ~ year, data = df) %>% 
    broom::tidy() %>% 
    select(term, estimate) 
}

# change factors to char
gapminder$country <- as.character(gapminder$country)
gapminder$continent <- as.character(gapminder$continent)

# continent model (grouped)
continent_fit <- gapminder %>% 
  group_by(continent) %>% 
  nest() %>% 
  mutate(model = map(data, lifeExp_model)) %>% 
  unnest(model) %>% 
  rename(continent_term = term, continent_estimate = estimate)

# pooled: a single model fit to all rows
pooled_fit <- gapminder %>% 
  nest(data = everything()) %>% 
  mutate(model = map(data, lifeExp_model)) %>% 
  unnest(model) %>% 
  rename(pooled_term = term, pooled_estimate = estimate)

# number of countries per continent
n_obs <- gapminder %>% 
  group_by(continent) %>% 
  summarise(n_country = n_distinct(country))

# grouped and pooled fit
gnp_fit <- continent_fit %>% 
  inner_join(n_obs, by = "continent") %>% 
  inner_join(pooled_fit, by = c("continent_term" = "pooled_term"))

gnp_fit %>% 
  filter(continent_term == "year") %>% 
  ggplot(aes(x = n_country, y = continent_estimate)) +
  geom_hline(aes(yintercept = pooled_estimate)) +
  geom_point() +
  ggrepel::geom_label_repel(aes(label = continent)) +
  theme_classic() 

Created on 2019-02-02 by the reprex package (v0.2.1)
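The reprex above computes the two extremes: no pooling (one `lm()` per continent) and complete pooling (one `lm()` on all rows). A rough empirical-Bayes form of partial pooling can be built from exactly those two pieces: shrink each group's slope toward the pooled slope with weight `tau2 / (tau2 + se^2)`, where `se` is the slope's standard error and `tau2` is a crude method-of-moments estimate of the between-group variance. This is a sketch on simulated data (not gapminder), base R only, and an approximation: lme4 and brms estimate all of these quantities jointly, and their shrinkage target is a population-level mean slope rather than the pooled OLS slope used here.

```r
# Crude empirical-Bayes shrinkage of per-group slopes (base R only).
# Simulated data: the group sizes and true slopes below are made up.
set.seed(1)
n_per_group <- c(A = 5, B = 10, C = 25, D = 50, E = 100)
true_slope  <- rnorm(length(n_per_group), mean = 0.3, sd = 0.2)

sim <- do.call(rbind, lapply(seq_along(n_per_group), function(i) {
  x <- runif(n_per_group[[i]], 0, 10)
  data.frame(group = names(n_per_group)[i], x = x,
             y = 50 + true_slope[i] * x + rnorm(n_per_group[[i]], sd = 1))
}))

# No pooling: a separate lm() per group, keeping the slope and its standard error
no_pool <- t(sapply(split(sim, sim$group), function(d) {
  coef(summary(lm(y ~ x, data = d)))["x", c("Estimate", "Std. Error")]
}))
b  <- no_pool[, "Estimate"]
se <- no_pool[, "Std. Error"]

# Complete pooling: one lm() on all rows
b_pooled <- coef(lm(y ~ x, data = sim))[["x"]]

# Between-group variance of the slopes: observed spread minus sampling noise,
# floored at zero (a rough moment estimate)
tau2 <- max(0, var(b) - mean(se^2))

# Partial pooling: noisy slopes (large se, i.e. small groups) get small weights
w <- tau2 / (tau2 + se^2)
b_partial <- w * b + (1 - w) * b_pooled

round(cbind(n = n_per_group[names(b)], no_pool = b, weight = w,
            partial = b_partial, pooled = b_pooled), 3)
```

Groups with few rows have large standard errors, so their slopes move furthest toward the pooled slope, which is the pattern the plot above is looking for. If `tau2` comes out as zero, every weight is zero and all groups collapse onto the pooled slope.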

My partner wrote a great note on partial pooling in vtreat that has some good references to Gelman and others.


This is less an answer than a comment: a full answer to this question would run to roughly textbook length.

I recommend picking up Statistical Rethinking by Richard McElreath for a general intro to multilevel models.


After some searching and working through tutorials, here is one way to look at the problem:

suppressMessages(library(gapminder))
suppressMessages(library(tidyverse))
suppressMessages(library(ggrepel))
suppressMessages(library(recipes))
suppressMessages(library(brms))
suppressMessages(library(tidybayes))

# scale the outcome 
gapminder_rec <- recipe(lifeExp ~ year + country + continent, data = gapminder) %>%
  step_center(all_outcomes()) %>%
  step_scale(all_outcomes()) %>% 
  prep(., training = gapminder, retain = TRUE)
g <- juice(gapminder_rec)

# fit a hierarchical model where a "year" coefficient is provided for each 
# continent. I am assuming this code shrinks a grouped (aka. many models) year
# coefficient to a pooled estimate. 
fit_brms2 <- brms::brm(lifeExp ~ 1 + (year | continent), data = g,
                       control = list(adapt_delta = .99),
                       cores = getOption("mc.cores",2L),
                       chains = 2L,
                       prior = c(prior(student_t(3, 0, 1), class = sigma)))

# clean the brms estimates for joining to "many models" estimate later
h_bayes_estiamtes <- tidybayes::gather_draws(fit_brms2, `r_.*`, regex = TRUE)  %>%
  group_by(.variable) %>% 
  summarise(h_bayes = mean(.value)) %>% 
  mutate(continent = stringr::str_extract(.variable, regex("(?<=\\[)(.+)(?=,)"))) %>% 
  mutate(continent_term  = stringr::str_extract(.variable, regex("(?<=,)(.+)(?=\\])"))) %>% 
  mutate(continent_term = ifelse(continent_term == "Intercept", "(Intercept)", continent_term)) %>% 
  select(-.variable)

# Many Models lifeExp_model
lifeExp_model <- function(df) {
  lm(lifeExp ~ year, data = df) %>% 
    broom::tidy() %>% 
    select(term, estimate) 
}

# change factors to char
gapminder$country <- as.character(gapminder$country)
gapminder$continent <- as.character(gapminder$continent)

# continent model (grouped)
continent_fit <- g %>% 
  group_by(continent) %>% 
  nest() %>% 
  mutate(model = map(data, lifeExp_model)) %>% 
  unnest(model) %>% 
  rename(continent_term = term, continent_estimate = estimate)

# pooled model for an overall mean estimate of the year coefficient
pooled_fit <- g %>% 
  nest(data = everything()) %>% 
  mutate(model = map(data, lifeExp_model)) %>% 
  unnest(model) %>% 
  rename(pooled_term = term, pooled_estimate = estimate)

# number of countries per continent
n_obs <- g %>% 
  group_by(continent) %>% 
  summarise(n_country = n_distinct(country))

# grouped, pooled, and hierarchical fit
gnp_fit <- continent_fit %>% 
  inner_join(n_obs, by = "continent") %>% 
  inner_join(pooled_fit, by = c("continent_term" = "pooled_term")) %>% 
  inner_join(h_bayes_estiamtes, by = c("continent", "continent_term"))

# plot grouped, pooled, and hierarchical fit;
# hoping to see the hierarchical estimate pulled
# further toward the pooled estimate when the
# group has a smaller sample size
gnp_fit %>% 
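  filter(continent_term == "year") %>% 
  # keep only the plotted columns (drops the nested `data` list columns)
  select(continent, continent_term, n_country, pooled_estimate, continent_estimate, h_bayes) %>% 
  gather(key = estimate_type, value = year_beta, continent_estimate, h_bayes) %>% 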
  ggplot(aes(x = n_country, y = year_beta, shape = estimate_type, color = continent)) +
  geom_hline(aes(yintercept = pooled_estimate)) +
  geom_point(size = 2) +
  theme_classic() +
  scale_shape_manual(values=c(1, 16))

I am not sure this is correct. In fact, something seems off: the h_bayes estimates don't always shrink toward the pooled estimate.
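One likely reason: the formula `lifeExp ~ 1 + (year | continent)` has no population-level `year` term, so the continent-level `year` slopes are shrunk toward 0 rather than toward a common slope. Also, brms's `r_continent[...]` draws are *deviations* from the population-level effects, not the continent coefficients themselves (the `Intercept` deviations still need the population-level intercept added back). A model closer to the intent is `lifeExp ~ 1 + year + (1 + year | continent)`, with continent-level coefficients read from `coef()` (for brms, something like `coef(fit_brms2)$continent[, "Estimate", "year"]` after refitting) instead of the `r_` draws. Below is a sketch of the same structure in lme4, which fits much faster than a Stan model; it uses the raw gapminder data with `year` centered at 1952, and the variable names are my own.

```r
# Varying-intercept, varying-slope model fitted with lme4 instead of brms.
library(gapminder)
library(lme4)

d <- as.data.frame(gapminder)
d$year_c <- d$year - 1952  # centered: intercept = expected lifeExp in 1952

# Population-level intercept and slope plus continent-level deviations.
# With only five continents, a singular-fit message is possible.
fit_lmer <- lmer(lifeExp ~ 1 + year_c + (1 + year_c | continent), data = d)

fixef(fit_lmer)            # population-level coefficients (the shrinkage target)
ranef(fit_lmer)$continent  # deviations from the population-level coefficients
coef(fit_lmer)$continent   # fixef + ranef: the shrunken continent coefficients

# Side by side: separate lm() per continent vs. the partially pooled slopes
no_pool <- sapply(split(d, d$continent), function(x) {
  coef(lm(lifeExp ~ year_c, data = x))[["year_c"]]
})
partial <- coef(fit_lmer)$continent[names(no_pool), "year_c"]
cbind(no_pool, partial, population = fixef(fit_lmer)[["year_c"]])
```

Two caveats: with only five continents (Oceania has just two countries) the between-continent variance is poorly determined, so the amount of shrinkage can be weak or erratic; and the target is the population-level mean slope, which need not equal the pooled `lm()` slope drawn as the horizontal line, so some estimates can still appear to move "the wrong way" relative to that line.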

The embed package lets you do this inside a recipe, e.g. with `step_lencode_mixed()` (lme4) or `step_lencode_bayes()` (rstanarm).
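A minimal sketch with embed, assuming reasonably recent recipes and embed versions: `step_lencode_mixed()` replaces a factor with a partially pooled estimate of the outcome for each level, fitted with an lme4 random-intercept model. Note that it shrinks per-level *means* of the outcome, not per-level `year` slopes, so it covers only the intercept side of the original question.

```r
library(dplyr)     # for vars()
library(recipes)
library(embed)     # step_lencode_mixed() fits an lme4 mixed model under the hood
library(gapminder)

rec <- recipe(lifeExp ~ continent, data = gapminder) %>%
  step_lencode_mixed(continent, outcome = vars(lifeExp)) %>%
  prep(training = gapminder)

encoded <- bake(rec, new_data = NULL)  # training data with continent encoded

# One shrunken value per continent, next to the raw (unpooled) mean
raw_mean <- tapply(gapminder$lifeExp, gapminder$continent, mean)
shrunk   <- tapply(encoded$continent, gapminder$continent, function(v) v[1])
cbind(raw_mean, shrunk)
```

Continents with many rows keep encodings close to their raw means; because every continent here has at least 24 rows, the differences may be small.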

The embed documentation was very eye-opening and useful. I recently made the switch to parsnip and recipes; tidymodels is very convenient to use. Thanks!
