Bayesian Clustering

Prof. Sam Berchuck (developed with Braden Scherting)

Apr 08, 2025

Learning Objectives

  1. We will introduce the basic mixture modeling framework as a mechanism for model-based clustering and describe computational and inferential challenges.
  2. Variations of the popular finite Gaussian mixture model (GMM) will be introduced to cluster patients according to ED length-of-stay.
  3. We present an implementation of mixture modeling in Stan and discuss challenges therein.
  4. Finally, various posterior summaries will be explored.

Finding subtypes

We revisit data on patients admitted to the emergency department (ED) from the MIMIC-IV-ED demo.

Can we identify subgroups within this population?

The usual setup

Most models introduced in this course are of the form:

\[f\left(Y_i\mid X_i\right) = f\left(Y_i\mid \boldsymbol{\theta}_i\left(X_i\right)\right).\]

  • \(f(\cdot)\) is the density or distribution function of an assumed family (e.g., Gaussian, binomial),

  • \(\boldsymbol{\theta}_i\) is a parameter (or parameters) that may depend on individual covariates \(X_i\).

The usual setup

\[f\left(Y_i\mid X_i\right) = f\left(Y_i\mid \boldsymbol{\theta}_i\left(X_i\right)\right)\]

Linear regression:

  • \(f\) is the Gaussian density function, and \(\boldsymbol{\theta}_i(X_i)=(X_i\beta, \sigma^2)^\top\)

Binary classification:

  • \(f\) is the Bernoulli mass function, and \(\boldsymbol{\theta}_i(X_i)=\text{logit}^{-1}(X_i\beta)\)

Limitations of the usual setup

Suppose patients \(i=1,\dots,n\) are administered a diagnostic test. Their outcome \(Y_i\) depends only on whether they have previously received treatment: \(X_i=1\) if yes and \(X_i=0\) otherwise. Suppose the diagnostic test has Gaussian-distributed measurement error, so \[Y_i\mid X_i \sim N(\alpha + \beta X_i, \sigma^2).\]

Now suppose past treatment information is not included in patients’ records, so we cannot condition on \(X_i\). Marginalizing over \(X_i\), \[\begin{align*} f(Y_i) &= P(X_i=1)\times N(Y_i\mid \alpha + \beta, \sigma^2) \\ & +P(X_i=0)\times N(Y_i\mid \alpha, \sigma^2). \end{align*}\]
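As a quick numerical check of this marginal density, the sketch below evaluates the two-component mixture in R and confirms that it integrates to one. The values of `alpha`, `beta`, `sigma`, and `p_treated` are illustrative assumptions, not estimates from any data.

```r
# Illustrative (assumed) values; none of these come from real data.
alpha <- 1        # mean outcome for untreated patients (X = 0)
beta <- 3.5       # shift in the mean for treated patients (X = 1)
sigma <- 1        # measurement-error standard deviation
p_treated <- 0.5  # P(X = 1)

# Marginal density of Y after summing over the unobserved X
f_marginal <- function(y) {
  p_treated * dnorm(y, alpha + beta, sigma) +
    (1 - p_treated) * dnorm(y, alpha, sigma)
}

# A valid density integrates to one
total_mass <- integrate(f_marginal, -Inf, Inf)$value
```

Changing `p_treated` or `beta` changes the shape of the marginal (for example, from bimodal to skewed), but the total mass stays at one.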

Limitations of the usual setup

library(ggplot2)
# simulate outcomes from two equally likely, unobserved treatment groups
n <- 500; mu <- c(1, 4.5); s2 <- 1
x <- sample(1:2, n, TRUE); y <- rnorm(n, mu[x], sqrt(s2))
ggplot(data.frame(y = y), aes(x = y)) + 
  geom_histogram() + 
  labs(x = "Y", y = "Count")

Limitations of the usual setup

fit <- lm(y ~ 1)
ggplot(data.frame(residuals = fit$residuals), aes(x = residuals)) + 
  geom_histogram() + 
  labs(x = "Residuals", y = "Count")

Normality of residuals?
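One hedged way to put a number on this visual impression is a Shapiro-Wilk test on the residuals. The sketch re-simulates data as in the example above; the seed is an arbitrary assumption, and the test is a supplement to the histogram rather than part of the original analysis.

```r
set.seed(1)  # arbitrary seed, assumed here for reproducibility
n <- 500; mu <- c(1, 4.5); s2 <- 1
x <- sample(1:2, n, TRUE)          # unobserved group labels
y <- rnorm(n, mu[x], sqrt(s2))     # outcomes from the two-group mixture

fit <- lm(y ~ 1)                   # intercept-only model ignores the groups
sw <- shapiro.test(residuals(fit)) # null hypothesis: residuals are Gaussian
sw$p.value
```

A small p-value is evidence against Gaussian residuals, which is what we expect when a single Gaussian is fit to bimodal data.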

Mixture Model

Motivation for using a mixture model: Standard distributional families are not sufficiently expressive.

  • The inflexibility of the model may be due to unobserved heterogeneity (e.g., unrecorded treatment history).

Generically, \[f(Y_i) = \sum_{h = 1}^k \pi_h \times f_h(Y_i).\]

Uses of mixture models:

  1. Modeling weird densities/distributions (e.g., bimodal).
  2. Learning latent groups/clusters.

Mixture Model

\[f(Y_i) = \sum_{h=1}^k \pi_h\times f_h(Y_i)\]

  • This mixture is composed of \(k\) components indexed by \(h=1,\dots,k\). For each component, we have a probability density (or mass) function \(f_h\) and a mixture weight \(\pi_h\), where \(\sum_{h=1}^k \pi_h=1\).

  • When \(k\) is finite, we call this a finite mixture model for \(Y_i\).

  • It is common to let \[f_h(Y_i) = f(Y_i\mid \boldsymbol{\theta}_h).\]

    • The component densities share a functional form and differ in their parameters.

Gaussian Mixture Model

Letting \(f_h(Y_i) = N(Y_i\mid \mu_h, \sigma^2_h)\) for \(h=1,\dots,k\), yields the Gaussian mixture model. For \(Y_i\in\mathbb{R}\), \[f(Y_i) = \sum_{h=1}^k \pi_h N\left({Y}_i\mid \mu_h, \sigma^2_h\right)\]

For multivariate outcomes \(\mathbf{Y}_i\in\mathbb{R}^p\),

\[ f(\mathbf{Y}_i) = \sum_{h=1}^k \pi_h N_p\left(\mathbf{Y}_i\mid\boldsymbol{\mu}_h, \boldsymbol{\Sigma}_h\right).\]

Gaussian Mixture Model

Consider a mixture model with 3 groups:

  • Mixture 1: \(\mu_1 = -1.5, \sigma_1 = 1\).
  • Mixture 2: \(\mu_2 = 0, \sigma_2 = 1.5\).
  • Mixture 3: \(\mu_3 = 2, \sigma_3 = 0.6\).

Notice that both the means \(\mu_h\) and the variances \(\sigma^2_h\) vary across components.
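The density of this three-component example can be evaluated directly, as in the sketch below. The example does not state the mixture weights, so equal weights are assumed here purely for illustration; `dgmm` and `y_grid` are hypothetical names.

```r
# Component parameters from the three-group example
mu <- c(-1.5, 0, 2)
sigma <- c(1, 1.5, 0.6)
# ASSUMPTION: equal weights; the example does not specify them
pi_h <- rep(1 / 3, 3)

# Mixture density: weighted sum of Gaussian component densities
dgmm <- function(y, pi_h, mu, sigma) {
  comps <- sapply(seq_along(pi_h), function(h) {
    pi_h[h] * dnorm(y, mu[h], sigma[h])
  })
  # one row per value of y, one column per component
  rowSums(matrix(comps, nrow = length(y)))
}

y_grid <- seq(-6, 6, length.out = 401)
dens <- dgmm(y_grid, pi_h, mu, sigma)
# plot(y_grid, dens, type = "l", xlab = "Y", ylab = "Density")
```

Uncommenting the `plot()` call draws the overall density; different weights would change the relative heights of the three bumps.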

Generative perspective on GMM

To simulate from a \(k\)-component Gaussian mixture with means \(\mu_1,\dots,\mu_k\), variances \(\sigma_1^2,\dots,\sigma^2_k\), and weights \(\pi_1,\dots,\pi_k\):

  1. Sample the component indicator \(z_i\in \{1, \dots,k\}\) with probabilities: \[P(z_i=h) = \pi_h \iff z_i \sim \text{Categorical}(k, \{\pi_1,\ldots,\pi_k\}).\]
  2. Given \(z_i\), sample \(Y_i\) from the appropriate component: \[\left(Y_i\mid z_i =h\right) \sim N\left(\mu_h, \sigma^2_h\right).\]

Generative perspective on GMM

n <- 500 
mu <- c(1, 4.5)
s2 <- 1 
# implicit: pi = c(0.5, 0.5)
z <- sample(1:2, n, TRUE)
y <- rnorm(n, mu[z], sqrt(s2))

This is essentially the code used to simulate the missing treatment history example.
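The same two steps extend to any number of components with unequal weights and component-specific variances. The following is a minimal sketch; the function name `rgmm` and the weights passed to it are assumptions made for illustration.

```r
# Draw n values from a k-component Gaussian mixture.
rgmm <- function(n, pi_h, mu, sigma) {
  # Step 1: component indicators z_i ~ Categorical(pi_1, ..., pi_k)
  z <- sample(seq_along(pi_h), size = n, replace = TRUE, prob = pi_h)
  # Step 2: Y_i | z_i = h ~ N(mu_h, sigma_h^2)
  y <- rnorm(n, mean = mu[z], sd = sigma[z])
  data.frame(z = z, y = y)
}

set.seed(42)  # arbitrary seed
sim <- rgmm(2000,
            pi_h = c(0.2, 0.5, 0.3),   # assumed weights
            mu = c(-1.5, 0, 2),
            sigma = c(1, 1.5, 0.6))
prop.table(table(sim$z))  # empirical weights, close to pi_h for large n
```

Keeping `z` alongside `y` is useful in simulations, because the true labels can later be compared with inferred clusters.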

Marginalizing Component Indicators

The label \(z_i\) indicates which component \(Y_i\) is drawn from—think of this as the cluster label: \(f\left(Y_i\mid z_i=h\right) = N\left(Y_i\mid \mu_h,\sigma^2_h \right).\)

But \(z_i\) is unknown, so we marginalize to obtain:

\[\begin{align*} f(Y_i) &= \int_\mathcal{Z}f\left(Y_i\mid z\right) f(z)dz \\ &= \sum_{h=1}^k f\left(Y_i\mid z=h\right) P(z=h) \\ &= \sum_{h=1}^k N\left(Y_i\mid \mu_h,\sigma^2_h \right) \times \pi_h. \end{align*}\]

This is key to implementing in Stan.

Gaussian mixture in Stan

Component indicators \(z_i\) are discrete parameters, which Stan's gradient-based samplers (HMC/NUTS) cannot sample directly, so we work with the marginal likelihood instead. As before, suppose \(f(Y_i) = \sum_{h=1}^k \pi_h N\left(Y_i\mid \mu_h,\sigma^2_h \right)\).

The log-likelihood contribution of observation \(i\) is:

\[\begin{align*} \log f(Y_i) &= \log \sum_{h=1}^k \exp \left(\log\left[\pi_h N\left(Y_i\mid \mu_h,\sigma^2_h \right) \right]\right)\\ &= \verb|log_sum_exp| \left[\log\pi_1 + \log N\left(Y_i\mid \mu_1,\sigma^2_1 \right),\right. \\ &\quad\quad\quad\quad\quad\quad\quad\dots, \\ &\quad\quad\quad\quad\quad\quad\quad \left.\log\pi_k + \log N\left(Y_i\mid \mu_k,\sigma^2_k \right) \right], \end{align*}\]

log_sum_exp is a built-in Stan function that evaluates \(\log\sum_{h}\exp(a_h)\) in a numerically stable way.
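To see why the stable form matters, here is a small R sketch of the usual max-shift trick. This is an illustrative re-implementation, not Stan's source code, and the values in `lps` are made up.

```r
# log(sum(exp(x))) computed by factoring out the largest term
log_sum_exp <- function(x) {
  m <- max(x)
  m + log(sum(exp(x - m)))
}

lps <- c(-1000, -1001, -1002)  # made-up log-scale terms

naive <- log(sum(exp(lps)))    # exp() underflows to 0, so this is -Inf
stable <- log_sum_exp(lps)     # finite, slightly above the largest term
```

Log-densities of this magnitude are common when an observation sits far from every component, which is exactly when the naive computation fails.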

Gaussian mixture in Stan

// saved in mixture1.stan
data {
  int<lower = 1> k;          // number of mixture components
  int<lower = 1> n;          // number of data points
  array[n] real Y;           // observations
}
parameters {
  simplex[k] pi; // mixing proportions
  ordered[k] mu; // means of the mixture components
  vector<lower=0>[k] sigma; // sds of the mixture components
}
model {
  target += normal_lpdf(mu |0.0, 10.0);
  target += exponential_lpdf(sigma | 1.0);
  vector[k] log_probs = log(pi);
  for (i in 1:n){
    vector[k] lps = log_probs;
    for (h in 1:k){
      lps[h] += normal_lpdf(Y[i] | mu[h], sigma[h]);
    }
    target += log_sum_exp(lps);
  }
}

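Of note: the simplex type constrains \(\boldsymbol{\pi}\) to be nonnegative and sum to one, and the ordered type constrains \(\mu_1 < \dots < \mu_k\), which is intended to prevent label switching.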
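The reason an ordering constraint is needed can be checked numerically: the mixture likelihood does not change when the component labels are permuted. The sketch below uses a hypothetical helper `gmm_loglik` and made-up data and parameter values.

```r
# Marginal log-likelihood of a Gaussian mixture (z_i summed out)
gmm_loglik <- function(y, pi_h, mu, sigma) {
  sum(sapply(y, function(yi) {
    lps <- log(pi_h) + dnorm(yi, mu, sigma, log = TRUE)
    m <- max(lps)
    m + log(sum(exp(lps - m)))  # stable log-sum-exp over components
  }))
}

set.seed(7)  # arbitrary seed; data are simulated for illustration
y <- c(rnorm(50, -2, 1), rnorm(50, 3, 1.5))

# Same mixture, with components 1 and 2 relabeled
ll_original <- gmm_loglik(y, c(0.4, 0.6), c(-2, 3), c(1, 1.5))
ll_swapped  <- gmm_loglik(y, c(0.6, 0.4), c(3, -2), c(1.5, 1))
```

Because `ll_original` and `ll_swapped` agree, an unconstrained posterior has (at least) \(k!\) equivalent modes; requiring ordered means keeps only one labeling.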

First fit

library(rstan)
ed <- read.csv("exam1_data.csv")
dat <- list(Y = (ed$los - mean(ed$los)),
            n = length(ed$los),
            k = 2)
mod1 <- stan_model("mixture1.stan")
fit1 <- sampling(mod1, data=dat, chains=4, iter=5000, control=list("adapt_delta"=0.99))
print(fit1, pars = c("pi", "mu", "sigma"), probs = c(0.025, 0.975))
Inference for Stan model: anon_model.
4 chains, each with iter=5000; warmup=2500; thin=1; 
post-warmup draws per chain=2500, total post-warmup draws=10000.

          mean se_mean   sd  2.5% 97.5% n_eff Rhat
pi[1]     0.37    0.17 0.24  0.16  0.83     2 6.35
pi[2]     0.63    0.17 0.24  0.17  0.84     2 6.35
mu[1]    -3.17    0.09 0.51 -4.49 -2.49    33 1.05
mu[2]     0.45    3.96 5.69 -3.17 13.08     2 5.14
sigma[1] 13.25    4.48 6.46  2.09 20.11     2 4.97
sigma[2]  5.03    3.28 4.68  2.03 14.78     2 7.56

Samples were drawn using NUTS(diag_e) at Sat Mar 22 13:54:38 2025.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).

What is going on?

pairs(fit1, pars = c("mu", "sigma"))

Bimodal posterior

  • In one mode, \(\sigma^2_1 \ll \sigma^2_2\) and in the other, \(\sigma^2_1\gg\sigma^2_2\)

Bimodal posterior

The Gaussian clusters have light tails, so outlying values of \(Y\) force large values of \(\sigma^2_h\). When \(\sigma^2_h\) is large, small changes to \(\mu_h\) have little impact on the log-likelihood, and the ordering constraint is not sufficient to identify the clusters.

Things to consider when your mixture model is mixed up

Mixture modeling, especially when clusters are of interest, can be fickle.

  1. Different mixtures can give similar fit to data, leading to multimodal posteriors that are difficult to sample from (previous slides).
  2. Clusters will depend on your choice of \(f_h\)—a Gaussian mixture model can only find Gaussian-shaped clusters.
  3. Increasing \(k\) often improves fit, but may muddle cluster interpretation.

Things to consider when your mixture model is mixed up

  1. Employ informative priors.
  2. Vary the number of clusters.
  3. Change the form of the kernel.

Updated model

// saved in mixture2.stan
data {
  int<lower = 1> k;          // number of mixture components
  int<lower = 1> n;          // number of data points
  array[n] real Y;         // observations
}
parameters {
  simplex[k] pi; // mixing proportions
  ordered[k] mu; // means of the mixture components
  vector<lower = 0>[k] sigma; // sds of the mixture components
  vector<lower = 1>[k] nu; // degrees of freedom of the Student-t components
}
model {
  target += normal_lpdf(mu | 0.0, 10.0);
  target += normal_lpdf(sigma | 2.0, 0.5);
  target += gamma_lpdf(nu | 5.0, 0.5);
  vector[k] log_probs = log(pi);
  for (i in 1:n){
    vector[k] lps = log_probs;
    for (h in 1:k){
      lps[h] += student_t_lpdf(Y[i] | nu[h], mu[h], sigma[h]);
    }
    target += log_sum_exp(lps);
  }
}
  • Informative prior on the scale \(\sigma_h\).
  • Mixture of Student-t.
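To illustrate why a heavier-tailed kernel helps, the sketch below compares the log-density that a Gaussian and a Student-t component assign to an outlying observation. The helper `student_t_lpdf` is an R re-implementation written to mirror Stan's location-scale parameterization, and the numerical values are assumptions chosen for illustration.

```r
# Log-density of a location-scale Student-t with nu degrees of freedom
student_t_lpdf <- function(y, nu, mu, sigma) {
  dt((y - mu) / sigma, df = nu, log = TRUE) - log(sigma)
}

# Assumed values: an observation 10 scale units from the component center
y_out <- 20; mu <- 0; sigma <- 2

lp_normal <- dnorm(y_out, mu, sigma, log = TRUE)
lp_t <- student_t_lpdf(y_out, nu = 2, mu = mu, sigma = sigma)
c(normal = lp_normal, student_t = lp_t)
```

The Student-t component penalizes the outlier far less, so it can accommodate the point without inflating \(\sigma_h\).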

Updated model fit

mod2 <- stan_model("mixture2.stan")
fit2 <- sampling(mod2, data=dat, chains=4, iter=5000, control=list("adapt_delta"=0.99))
print(fit2, pars=c("pi", "mu", "sigma", "nu"))
Inference for Stan model: anon_model.
4 chains, each with iter=5000; warmup=2500; thin=1; 
post-warmup draws per chain=2500, total post-warmup draws=10000.

          mean se_mean   sd  2.5%   25%   50%   75% 97.5% n_eff Rhat
pi[1]     0.81    0.00 0.04  0.73  0.79  0.81  0.83  0.87  5567    1
pi[2]     0.19    0.00 0.04  0.13  0.17  0.19  0.21  0.27  5567    1
mu[1]    -2.96    0.00 0.22 -3.39 -3.10 -2.96 -2.81 -2.53  7318    1
mu[2]     8.30    0.02 1.10  6.03  7.66  8.35  9.01 10.28  4306    1
sigma[1]  2.20    0.00 0.17  1.88  2.09  2.20  2.31  2.55  5987    1
sigma[2]  2.82    0.00 0.40  2.06  2.55  2.81  3.09  3.64  6745    1
nu[1]    11.98    0.04 4.44  5.14  8.75 11.38 14.57 22.22  9860    1
nu[2]     1.54    0.00 0.43  1.03  1.25  1.46  1.74  2.50  9286    1

Samples were drawn using NUTS(diag_e) at Sat Mar 29 12:52:38 2025.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).

Updated model results

From marginal mixture model to clusters

Stan cannot directly infer categorical component indicators \(z_i\). Instead, for each individual, we compute

\[\begin{align*} P\left(z_i = h \mid Y_i, \boldsymbol{\mu},\boldsymbol{\sigma},\boldsymbol{\pi} \right) &= \frac{f(Y_i\mid z_i = h, \mu_h,\sigma_h)P(z_i=h\mid \pi_h)}{\sum_{h'=1}^k f(Y_i\mid z_i = h', \mu_{h'},\sigma_{h'})P(z_i=h'\mid \pi_{h'})}\\ &= \frac{N(Y_i \mid \mu_{h},\sigma^2_{h})\pi_{h}}{\sum_{h' = 1}^k N(Y_i \mid \mu_{h'},\sigma^2_{h'})\pi_{h'}} = p_{ih}. \end{align*}\]

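Given these cluster membership probabilities, we can recover cluster indicators through simulation: \[(z_i\mid -) \sim \text{Categorical}\left(k, \left\{ p_{i1},\dots,p_{ik} \right\}\right).\] For the Student-t mixture, the Gaussian density \(N\) in \(p_{ih}\) is replaced by the corresponding Student-t density.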
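For a single set of parameter values (one posterior draw, say), the membership probabilities and a simulated label can be computed as in this sketch. It uses the Gaussian kernel from the displayed formula; the function name `membership_probs` and all numerical values are assumptions for illustration.

```r
# n x k matrix of P(z_i = h | Y_i, mu, sigma, pi) for a Gaussian mixture
membership_probs <- function(y, pi_h, mu, sigma) {
  # unnormalized log-probabilities: log(pi_h) + log N(y_i | mu_h, sigma_h^2)
  lp <- sapply(seq_along(pi_h), function(h) {
    log(pi_h[h]) + dnorm(y, mu[h], sigma[h], log = TRUE)
  })
  lp <- matrix(lp, nrow = length(y))
  # normalize each row on the log scale, then exponentiate
  row_lse <- apply(lp, 1, function(r) max(r) + log(sum(exp(r - max(r)))))
  exp(lp - row_lse)
}

y <- c(-3, 0, 2.5, 9)  # made-up observations
p <- membership_probs(y, pi_h = c(0.8, 0.2), mu = c(-3, 8), sigma = c(2, 3))

set.seed(3)  # arbitrary seed
# one simulated label per observation: z_i ~ Categorical(p_i1, ..., p_ik)
z <- apply(p, 1, function(p_i) sample(seq_along(p_i), 1, prob = p_i))
```

Repeating this for every posterior draw yields posterior samples of \(z_i\), which is what the Stan generated quantities block does internally.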

From marginal mixture model to clusters

...

generated quantities {
  matrix[n,k] lPrZik;                 // log membership probabilities
  array[n] int<lower=1, upper=k> z;   // sampled cluster labels
  for (i in 1:n){
    for (h in 1:k){
      lPrZik[i,h] = log(pi[h]) + student_t_lpdf(Y[i] | nu[h], mu[h], sigma[h]);
    }
    lPrZik[i] -= log_sum_exp(lPrZik[i]);  // normalize on the log scale
    z[i] = categorical_rng(exp(lPrZik[i]'));
  }
}

Co-clustering probabilities

Recovering \(z_i\) allows us to make the following pairwise comparison: what is the probability that unit \(i\) and unit \(j\) are in the same cluster? This is the “co-clustering probability”.

It is common to arrange these probabilities in a co-clustering matrix \(\mathbf{C}\), where the \(i,j\) entry is given by \[C_{ij}=P\left( z_i=z_j\mid- \right)\approx \frac{1}{S}\sum_{s=1}^S \mathbb{1}\left[z_i^{(s)}=z_j^{(s)}\right].\]
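The Monte Carlo average in this formula is straightforward to compute from a matrix of sampled labels. The sketch below assumes the draws are stored as an \(S \times n\) matrix; the function name `coclustering` and the toy draws are made up for illustration.

```r
# z_draws: S x n integer matrix; row s holds one posterior draw of (z_1, ..., z_n)
coclustering <- function(z_draws) {
  S <- nrow(z_draws)
  n <- ncol(z_draws)
  C <- matrix(0, n, n)
  for (s in seq_len(S)) {
    # indicator matrix: entry (i, j) is TRUE when z_i = z_j in draw s
    C <- C + outer(z_draws[s, ], z_draws[s, ], "==")
  }
  C / S
}

# Toy example (made up): 4 draws of labels for 3 units
z_draws <- rbind(c(1, 1, 2),
                 c(1, 1, 2),
                 c(1, 2, 2),
                 c(2, 2, 1))
C <- coclustering(z_draws)
```

With rstan, a matrix of this shape could plausibly be obtained from `rstan::extract(fit, pars = "z")$z`, assuming `z` is declared in generated quantities. Note that \(C_{ij}\) depends only on whether two labels agree, so it is unaffected by relabeling of the components.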

Co-clustering probabilities

How do our results change when we use more components?

\(k=3\)

How do our results change when we use more components?

\(k=4\)

Co-clusterings across \(k\)

The same general pattern persists when more clusters are used, indicating that \(k=2\) is a reasonable choice.

Characterizing the Clusters

Prepare for next class

  1. Reminder: On Thursday, we will have an in-class live-coding exercise.

  2. Begin working on Exam 02, which is due for feedback on April 15.