Given \(n\) observations \(x_i\) with known standard deviations \(s_i > 0\), \(i = 1, \dots, n\), the normal means model (Robbins 1951; Efron and Morris 1972; Stephens 2017; Bhadra et al. 2019; Johnstone 2019; Sun 2020) has \[\begin{equation} x_i \overset{\text{ind.}}{\sim} \mathcal{N}(\theta_i, s_i^2), \end{equation}\] where the unknown (“true”) means \(\theta_i\) are the quantities to be estimated. Here and throughout, we use \(\mathcal{N}(\mu, \sigma^2)\) to denote the normal distribution with mean \(\mu\) and variance \(\sigma^2\).
The empirical Bayes (EB) approach to inferring \(\theta_i\) attempts to improve upon the maximum-likelihood estimate \(\hat{\theta}_i = x_i\) by “borrowing information” across observations, exploiting the fact that each observation contains information not only about its respective mean, but also about how the means are collectively distributed (Robbins 1956; Morris 1983; Efron 2010; Stephens 2017). Specifically, the empirical Bayes normal means (EBNM) approach assumes that \[\begin{equation} \theta_i \overset{\text{ind.}}{\sim} g \in \mathcal{G}, \end{equation}\] where \(\mathcal{G}\) is some family of distributions that is specified in advance and \(g \in \mathcal{G}\) is estimated using the data.
The EBNM model is fit by first using all of the observations to estimate the prior \(g \in \mathcal{G}\), and then using the estimated distribution \(\hat{g}\) to compute posteriors and/or posterior summaries for the “true” means \(\theta_i\). Commonly, \(g\) is estimated via maximum-likelihood and posterior means are used as point estimates for the unknown means. The ebnm package provides a unified interface for efficiently carrying out both steps, with a wide range of available options for the prior family \(\mathcal{G}\).
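As a minimal sketch of this interface (assuming that a vector of observations x and a vector of standard deviations s are already available), the two steps look like this:
library(ebnm)
# Step 1: estimate the prior g within the chosen family (here, point-normal).
fit <- ebnm(x, s, prior_family = "point_normal")
# Step 2: compute posterior summaries; coef() returns the posterior means.
theta_hat <- coef(fit)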
For a detailed introduction, see our ebnm paper. For further background, see, for example, John Storey’s book.
Our example data set consists of 400 data points simulated from a normal means model in which the true prior \(g\) is a mixture of (a) a point-mass at 2, with weight 0.8, and (b) a normal distribution with mean 2 and variance 1, with weight 0.2:
\[ \theta_i \sim 0.8\,\delta_2 + 0.2\,\mathcal{N}(2, 1) \]
First, we simulate the “true” means \(\theta_i\) from this prior:
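A minimal sketch of this step (the seed is arbitrary, and the vector u holds the simulated true means \(\theta_i\)):
set.seed(1)
n <- 400
# Each true mean equals 2 with probability 0.8 (the point-mass component)
# and is drawn from N(2, 1) with probability 0.2.
u <- ifelse(runif(n) < 0.8, 2, rnorm(n, mean = 2, sd = 1))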
Next, we simulate the observed means \(x_i\) as “noisy” estimates of the true means (in this example, the noise is homoskedastic):
\[ x_i \sim \mathcal{N}(\theta_i, s_i^2), \quad s_i = 1/3. \]
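A sketch of this step, with s holding the known standard deviations that are passed to ebnm() below:
s <- rep(1/3, n)           # known, homoskedastic standard deviations
x <- u + rnorm(n, sd = s)  # observed noisy estimates of the true means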
Although we know what the true means are in this example, we’ll treat them as quantities we cannot observe.
The maximum-likelihood estimates (MLEs) of the true means are simply \(\hat{\theta}_i = x_i\):
par(mar = c(4, 4, 2, 2))
lims <- c(-0.55, 5.05)
plot(u, x, pch = 4, cex = 0.75, xlim = lims, ylim = lims,
xlab = "true value", ylab = "estimate", main = "MLE")
abline(a = 0, b = 1, col = "magenta", lty = "dotted")
We can do better than the MLE by learning a prior using all the observations, then “shrinking” the estimates toward this prior; in fact, classical results on shrinkage estimation (such as the James-Stein result) guarantee that suitable shrinkage estimators improve on the MLE, in total squared error, whenever \(n \ge 3\).
Let’s illustrate this idea with a simple normal prior whose mean and variance are learned from the data. (Note that the normal prior is the wrong prior for this data set! Recall that we simulated the data using a mixture of a point-mass and a normal.)
First, we fit the prior:
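A call of the following form does this (the `mode = "estimate"` argument, which asks `ebnm()` to estimate the prior mean as well as its variance, mirrors the point-normal call later in the vignette and is an assumption here):
library(ebnm)
fit_normal <- ebnm(x, s, prior_family = "normal", mode = "estimate")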
Next we estimate the true means using posterior means \(\hat{\theta}_i = E[\theta_i \mid x_i, \hat{g}]\). We extract these posterior means using the `coef()` method:
y <- coef(fit_normal)
par(mar = c(4, 4, 2, 2))
plot(u, y, pch = 4, cex = 0.75, xlim = lims, ylim = lims,
xlab = "true value", ylab = "estimate", main = "normal prior")
abline(a = 0, b = 1, col = "magenta", lty = "dotted")
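To see why these estimates are “shrunken,” note that for a normal prior \(\hat{g} = \mathcal{N}(\hat{\mu}, \hat{\sigma}^2)\) the posterior mean has a simple closed form, \[ E[\theta_i \mid x_i, \hat{g}] = \frac{\hat{\sigma}^2}{\hat{\sigma}^2 + s_i^2}\, x_i + \frac{s_i^2}{\hat{\sigma}^2 + s_i^2}\, \hat{\mu}, \] a weighted average that pulls each observation toward the estimated prior mean \(\hat{\mu}\), with more shrinkage when the noise level \(s_i\) is large relative to \(\hat{\sigma}\).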
These “shrunken” estimates are better when the true means \(\theta_i\) are near 2, but worse when they are far from 2. Still, they substantially reduce the overall estimation error (the “root mean-squared error”, or RMSE):
err_mle <- (x - u)^2
err_shrink_normal <- (y - u)^2
print(round(digits = 4,
x = c(mle = sqrt(mean(err_mle)),
shrink_normal = sqrt(mean(err_shrink_normal)))))
# mle shrink_normal
# 0.3599 0.2868
Here’s a more detailed comparison of the estimation error:
par(mar = c(4, 4, 2, 2))
plot(err_mle, err_shrink_normal, pch = 4, cex = 0.75,
xlim = c(0, 1.2), ylim = c(0, 1.2))
abline(a = 0, b = 1, col = "magenta", lty = "dotted")
Indeed, the error increases for a few of the estimates but decreases for many others, resulting in a lower RMSE over the 400 data points.
Let’s now see what happens when we use a family of priors that is better suited to this data set: the “point-normal” family. Notice that the only change we make in our call to `ebnm()` is in the `prior_family` argument:
fit_pn <- ebnm(x, s, prior_family = "point_normal", mode = "estimate")
Now we extract the posterior mean estimates and compare to the true values:
par(mar = c(4, 4, 2, 2))
y <- coef(fit_pn)
plot(u, y, pch = 4, cex = 0.75, xlim = lims, ylim = lims,
xlab = "true value", ylab = "estimate", main = "point-normal prior")
abline(a = 0, b = 1, col = "magenta", lty = "dotted")
The added flexibility of the point-normal prior improves the accuracy of estimates for means near 2, while estimates for means far from 2 are no worse than the MLEs. The result is that the overall RMSE again sees a substantial improvement:
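As before, we can compare RMSEs (a sketch; err_shrink_pn is the squared error of the point-normal estimates, and y now holds the posterior means from fit_pn):
err_shrink_pn <- (y - u)^2
print(round(digits = 4,
            x = c(mle = sqrt(mean(err_mle)),
                  shrink_normal = sqrt(mean(err_shrink_normal)),
                  shrink_pn = sqrt(mean(err_shrink_pn)))))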
The following R version and packages were used to generate this vignette:
sessionInfo()
# R version 4.3.2 (2023-10-31)
# Platform: aarch64-apple-darwin20 (64-bit)
# Running under: macOS Monterey 12.7.4
#
# Matrix products: default
# BLAS: /Library/Frameworks/R.framework/Versions/4.3-arm64/Resources/lib/libRblas.0.dylib
# LAPACK: /Library/Frameworks/R.framework/Versions/4.3-arm64/Resources/lib/libRlapack.dylib; LAPACK version 3.11.0
#
# locale:
# [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
#
# time zone: America/New_York
# tzcode source: internal
#
# attached base packages:
# [1] stats graphics grDevices utils datasets methods base
#
# other attached packages:
# [1] ebnm_1.1-34
#
# loaded via a namespace (and not attached):
# [1] sass_0.4.8 utf8_1.2.4 generics_0.1.3 ashr_2.2-63
# [5] stringi_1.8.3 lattice_0.21-9 digest_0.6.34 magrittr_2.0.3
# [9] evaluate_0.23 grid_4.3.2 fastmap_1.1.1 jsonlite_1.8.8
# [13] Matrix_1.6-1.1 mixsqp_0.3-54 purrr_1.0.2 fansi_1.0.6
# [17] scales_1.3.0 truncnorm_1.0-9 invgamma_1.1 textshaping_0.3.7
# [21] jquerylib_0.1.4 cli_3.6.2 rlang_1.1.3 deconvolveR_1.2-1
# [25] munsell_0.5.0 splines_4.3.2 cachem_1.0.8 yaml_2.3.8
# [29] tools_4.3.2 SQUAREM_2021.1 memoise_2.0.1 dplyr_1.1.4
# [33] colorspace_2.1-0 ggplot2_3.5.0 vctrs_0.6.5 R6_2.5.1
# [37] lifecycle_1.0.4 stringr_1.5.1 fs_1.6.3 trust_0.1-8
# [41] ragg_1.2.7 irlba_2.3.5.1 pkgconfig_2.0.3 desc_1.4.3
# [45] pkgdown_2.0.7 bslib_0.6.1 pillar_1.9.0 gtable_0.3.4
# [49] glue_1.7.0 Rcpp_1.0.12 systemfonts_1.0.5 xfun_0.41
# [53] tibble_3.2.1 tidyselect_1.2.1 highr_0.10 rstudioapi_0.15.0
# [57] knitr_1.45 htmltools_0.5.7 rmarkdown_2.25 compiler_4.3.2
# [61] horseshoe_0.2.0