Bayesian Clustering

Prof. Sam Berchuck (developed with Braden Scherting)

Apr 08, 2025

Learning Objectives

  1. We will introduce the basic mixture modeling framework as a mechanism for model-based clustering and describe computational and inferential challenges.
  2. Variations of the popular finite Gaussian mixture model (GMM) will be introduced to cluster patients according to ED length-of-stay.
  3. We present an implementation of mixture modeling in Stan and discuss challenges therein.
  4. Finally, various posterior summaries will be explored.

Finding subtypes

Revisiting data on patients admitted to the emergency department (ED) from the MIMIC-IV-ED demo.

Can we identify subgroups within this population?

The usual setup

Most models introduced in this course are of the form:

$$f(Y_i \mid X_i) = f(Y_i \mid \theta_i(X_i)).$$

  • $f(\cdot)$ is the density or distribution function of an assumed family (e.g., Gaussian, binomial),

  • $\theta_i$ is a parameter (or vector of parameters) that may depend on individual covariates $X_i$.

The usual setup

$$f(Y_i \mid X_i) = f(Y_i \mid \theta_i(X_i))$$

Linear regression:

  • $f$ is the Gaussian density function, and $\theta_i(X_i) = (X_i\beta, \sigma^2)^\top$.

Binary classification:

  • $f$ is the Bernoulli mass function, and $\theta_i(X_i) = \text{logit}^{-1}(X_i\beta)$.

Limitations of the usual setup

Suppose patients $i = 1, \ldots, n$ are administered a diagnostic test. Their outcome $Y_i$ depends only on whether they have previously received treatment: $X_i = 1$ if yes and $X_i = 0$ otherwise. Suppose the diagnostic test has Gaussian-distributed measurement error, so $Y_i \mid X_i \sim N(\alpha + \beta X_i, \sigma^2)$. Now suppose past treatment information is not recorded in the patients' charts, so we cannot condition on $X_i$. Marginalizing over $X_i$,

$$f(Y_i) = P(X_i = 1) \times N(Y_i \mid \alpha + \beta, \sigma^2) + P(X_i = 0) \times N(Y_i \mid \alpha, \sigma^2).$$
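As a sanity check on this marginal density, the sketch below evaluates it on a grid using $P(X_i = 1) = 0.5$, $\alpha = 1$, $\beta = 3.5$, and $\sigma^2 = 1$, the values implied by the simulation on the next slide.

library(ggplot2)
# Marginal density of Y when treatment status X is unobserved
p1 <- 0.5; alpha <- 1; beta <- 3.5; s2 <- 1
y_grid <- seq(-3, 9, length.out = 500)
f_y <- p1 * dnorm(y_grid, alpha + beta, sqrt(s2)) +
  (1 - p1) * dnorm(y_grid, alpha, sqrt(s2))
ggplot(data.frame(y = y_grid, density = f_y), aes(x = y, y = density)) +
  geom_line() +
  labs(x = "Y", y = "Marginal density")

The resulting curve is bimodal, which is exactly the feature a single Gaussian cannot capture.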

Limitations of the usual setup

library(ggplot2)  # for ggplot()
n <- 500; mu <- c(1, 4.5); s2 <- 1
x <- sample(1:2, n, replace = TRUE); y <- rnorm(n, mu[x], sqrt(s2))
ggplot(data.frame(y = y), aes(x = y)) + 
  geom_histogram() + 
  labs(x = "Y", y = "Count")

Limitations of the usual setup

fit <- lm(y ~ 1)
ggplot(data.frame(residuals = fit$residuals), aes(x = residuals)) + 
  geom_histogram() + 
  labs(x = "Residuals", y = "Count")

Normality of residuals?

Mixture Model

Motivation for using a mixture model: Standard distributional families are not sufficiently expressive.

  • The inflexibility of the model may be due to unobserved heterogeneity (e.g., unrecorded treatment history).

Generically, $$f(Y_i) = \sum_{h=1}^{k} \pi_h \times f_h(Y_i).$$

Uses of mixture models:

  1. Modeling weird densities/distributions (e.g., bimodal).
  2. Learning latent groups/clusters.

Mixture Model

$$f(Y_i) = \sum_{h=1}^{k} \pi_h \times f_h(Y_i)$$

  • This mixture comprises $k$ components indexed by $h = 1, \ldots, k$. Each component has a probability density (or mass) function $f_h$ and a mixture weight $\pi_h$, where $\sum_{h=1}^{k} \pi_h = 1$.

  • When $k$ is finite, we call this a finite mixture model for $Y_i$.

  • It is common to let $f_h(Y_i) = f(Y_i \mid \theta_h)$.

    • The component densities share a functional form and differ in their parameters.

Gaussian Mixture Model

Letting $f_h(Y_i) = N(Y_i \mid \mu_h, \sigma_h^2)$ for $h = 1, \ldots, k$ yields the Gaussian mixture model. For $Y_i \in \mathbb{R}$,

$$f(Y_i) = \sum_{h=1}^{k} \pi_h N(Y_i \mid \mu_h, \sigma_h^2).$$

For multivariate outcomes $Y_i \in \mathbb{R}^p$,

$$f(Y_i) = \sum_{h=1}^{k} \pi_h N_p(Y_i \mid \mu_h, \Sigma_h).$$

Gaussian Mixture Model

Consider a mixture model with 3 groups:

  • Mixture 1: $\mu_1 = -1.5$, $\sigma_1 = 1$.
  • Mixture 2: $\mu_2 = 0$, $\sigma_2 = 1.5$.
  • Mixture 3: $\mu_3 = 2$, $\sigma_3 = 0.6$.

Notice that both the means $\mu_h$ and the variances $\sigma_h^2$ vary across clusters.
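A sketch of this three-component density is below; the equal weights $\pi_h = 1/3$ are an assumption for illustration, since the slide does not state them.

library(ggplot2)
# Density of the 3-component Gaussian mixture above (assumed pi_h = 1/3)
mu <- c(-1.5, 0, 2); sigma <- c(1, 1.5, 0.6); pi_h <- rep(1/3, 3)
y_grid <- seq(-6, 5, length.out = 500)
f_y <- rowSums(sapply(1:3, function(h) pi_h[h] * dnorm(y_grid, mu[h], sigma[h])))
ggplot(data.frame(y = y_grid, density = f_y), aes(x = y, y = density)) +
  geom_line() +
  labs(x = "Y", y = "Density")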

Generative perspective on GMM

To simulate from a $k$-component Gaussian mixture with means $\mu_1, \ldots, \mu_k$, variances $\sigma_1^2, \ldots, \sigma_k^2$, and weights $\pi_1, \ldots, \pi_k$:

  1. Sample the component indicator $z_i \in \{1, \ldots, k\}$ with probabilities $P(z_i = h) = \pi_h$, i.e., $z_i \sim \text{Categorical}(k, \{\pi_1, \ldots, \pi_k\})$.
  2. Given $z_i$, sample $Y_i$ from the corresponding component: $(Y_i \mid z_i = h) \sim N(\mu_h, \sigma_h^2)$.

Generative perspective on GMM

n <- 500 
mu <- c(1, 4.5)
s2 <- 1 
# implicit: pi = c(0.5, 0.5)
z <- sample(1:2, n, replace = TRUE)  # step 1: sample component indicators
y <- rnorm(n, mu[z], sqrt(s2))       # step 2: sample Y given z

This is essentially the code used to simulate the missing treatment history example.

Marginalizing Component Indicators

The label $z_i$ indicates which component $Y_i$ is drawn from; think of this as the cluster label: $f(Y_i \mid z_i = h) = N(Y_i \mid \mu_h, \sigma_h^2)$.

But $z_i$ is unknown, so we marginalize to obtain:

$$f(Y_i) = \int_{\mathcal{Z}} f(Y_i \mid z) f(z)\, dz = \sum_{h=1}^{k} f(Y_i \mid z = h) P(z = h) = \sum_{h=1}^{k} N(Y_i \mid \mu_h, \sigma_h^2) \times \pi_h.$$

This is key to implementing in Stan.

Gaussian mixture in Stan

Component indicators $z_i$ are discrete parameters, which cannot be sampled directly in Stan. As before, suppose $f(Y_i) = \sum_{h=1}^{k} \pi_h N(Y_i \mid \mu_h, \sigma_h^2)$.

The log-likelihood is:

$$\log f(Y_i) = \log \sum_{h=1}^{k} \exp\left(\log\left[\pi_h N(Y_i \mid \mu_h, \sigma_h^2)\right]\right) = \text{log\_sum\_exp}\left(\log \pi_1 + \log N(Y_i \mid \mu_1, \sigma_1^2), \ldots, \log \pi_k + \log N(Y_i \mid \mu_k, \sigma_k^2)\right),$$

where log_sum_exp is a built-in Stan function.
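For intuition, a minimal R sketch of the log-sum-exp trick (not the Stan source itself): subtracting the maximum before exponentiating avoids underflow when the summands are far below zero on the log scale.

# Numerically stable log(sum(exp(x))), mirroring Stan's log_sum_exp()
log_sum_exp <- function(x) {
  m <- max(x)
  m + log(sum(exp(x - m)))
}

# Log mixture density of a single observation y (illustrative values)
y <- 2; pi_h <- c(0.4, 0.6); mu <- c(1, 4.5); sigma <- c(1, 1)
lps <- log(pi_h) + dnorm(y, mu, sigma, log = TRUE)
log_sum_exp(lps)  # equals log(sum(pi_h * dnorm(y, mu, sigma)))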

Gaussian mixture in Stan

// saved in mixture1.stan
data {
  int<lower = 1> k;          // number of mixture components
  int<lower = 1> n;          // number of data points
  array[n] real Y;           // observations
}
parameters {
  simplex[k] pi; // mixing proportions
  ordered[k] mu; // means of the mixture components
  vector<lower=0>[k] sigma; // sds of the mixture components
}
model {
  target += normal_lpdf(mu | 0.0, 10.0);    // prior on component means
  target += exponential_lpdf(sigma | 1.0);  // prior on component sds
  vector[k] log_probs = log(pi);
  for (i in 1:n){
    vector[k] lps = log_probs;
    for (h in 1:k){
      lps[h] += normal_lpdf(Y[i] | mu[h], sigma[h]);
    }
    target += log_sum_exp(lps);
  }
}

Of note: simplex and ordered types.

First fit

library(rstan)  # for stan_model() and sampling()
ed <- read.csv("exam1_data.csv")
dat <- list(Y = ed$los - mean(ed$los),  # centered ED length-of-stay
            n = length(ed$los),
            k = 2)
mod1 <- stan_model("mixture1.stan")
fit1 <- sampling(mod1, data = dat, chains = 4, iter = 5000,
                 control = list(adapt_delta = 0.99))
print(fit1, pars = c("pi", "mu", "sigma"), probs = c(0.025, 0.975))
print(fit1, pars = c("pi", "mu", "sigma"), probs = c(0.025, 0.975))
Inference for Stan model: anon_model.
4 chains, each with iter=5000; warmup=2500; thin=1; 
post-warmup draws per chain=2500, total post-warmup draws=10000.

          mean se_mean   sd  2.5% 97.5% n_eff Rhat
pi[1]     0.37    0.17 0.24  0.16  0.83     2 6.35
pi[2]     0.63    0.17 0.24  0.17  0.84     2 6.35
mu[1]    -3.17    0.09 0.51 -4.49 -2.49    33 1.05
mu[2]     0.45    3.96 5.69 -3.17 13.08     2 5.14
sigma[1] 13.25    4.48 6.46  2.09 20.11     2 4.97
sigma[2]  5.03    3.28 4.68  2.03 14.78     2 7.56

Samples were drawn using NUTS(diag_e) at Sat Mar 22 13:54:38 2025.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).

What is going on?

pairs(fit1, pars = c("mu", "sigma"))

Bimodal posterior

  • In one mode, $\sigma_1^2 \ll \sigma_2^2$; in the other, $\sigma_1^2 \gg \sigma_2^2$.

Bimodal posterior

The Gaussian clusters have light tails, so outlying values of $Y$ force large values of $\sigma_h^2$. When $\sigma_h^2$ is large, small changes to $\mu_h$ have little impact on the log-likelihood, and the ordering constraint is not sufficient to identify the clusters.

Things to consider when your mixture model is mixed up

Mixture modeling, especially when clusters are of interest, can be fickle.

  1. Different mixtures can give similar fit to data, leading to multimodal posteriors that are difficult to sample from (previous slides).
  2. Clusters will depend on your choice of $f_h$: a Gaussian mixture model can only find Gaussian-shaped clusters.
  3. Increasing k often improves fit, but may muddle cluster interpretation.

Things to consider when your mixture model is mixed up

  1. Employ informative priors.
  2. Vary the number of clusters.
  3. Change the form of the kernel.

Updated model

// saved in mixture2.stan
data {
  int<lower = 1> k;          // number of mixture components
  int<lower = 1> n;          // number of data points
  array[n] real Y;         // observations
}
parameters {
  simplex[k] pi; // mixing proportions
  ordered[k] mu; // means of the mixture components
  vector<lower = 0>[k] sigma; // sds of the mixture components
  vector<lower = 1>[k] nu; // degrees of freedom of the t components
}
model {
  target += normal_lpdf(mu | 0.0, 10.0);     // prior on component means
  target += normal_lpdf(sigma | 2.0, 0.5);   // informative prior on component sds
  target += gamma_lpdf(nu | 5.0, 0.5);       // prior on degrees of freedom
  vector[k] log_probs = log(pi);
  for (i in 1:n){
    vector[k] lps = log_probs;
    for (h in 1:k){
      lps[h] += student_t_lpdf(Y[i] | nu[h], mu[h], sigma[h]);
    }
    target += log_sum_exp(lps);
  }
}
  • Informative prior on the component scales $\sigma_h$.
  • Mixture of Student-t components (see the tail comparison sketch below).
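Why swap the Gaussian kernel for a Student-t? A quick sketch of the tail behavior (the 3 degrees of freedom here are an illustrative value, not taken from the fitted model):

# Probability of an observation more than 4 sds above the mean:
# the heavy-tailed Student-t (3 df, illustrative) accommodates outliers
# without inflating the scale parameter, unlike the Gaussian
pnorm(4, lower.tail = FALSE)       # approx. 3e-05
pt(4, df = 3, lower.tail = FALSE)  # approx. 0.014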

Updated model fit

mod2 <- stan_model("mixture2.stan")
fit2 <- sampling(mod2, data=dat, chains=4, iter=5000, control=list("adapt_delta"=0.99))
print(fit2, pars=c("pi", "mu", "sigma", "nu"))
Inference for Stan model: anon_model.
4 chains, each with iter=5000; warmup=2500; thin=1; 
post-warmup draws per chain=2500, total post-warmup draws=10000.

          mean se_mean   sd  2.5%   25%   50%   75% 97.5% n_eff Rhat
pi[1]     0.81    0.00 0.04  0.73  0.79  0.81  0.83  0.87  5567    1
pi[2]     0.19    0.00 0.04  0.13  0.17  0.19  0.21  0.27  5567    1
mu[1]    -2.96    0.00 0.22 -3.39 -3.10 -2.96 -2.81 -2.53  7318    1
mu[2]     8.30    0.02 1.10  6.03  7.66  8.35  9.01 10.28  4306    1
sigma[1]  2.20    0.00 0.17  1.88  2.09  2.20  2.31  2.55  5987    1
sigma[2]  2.82    0.00 0.40  2.06  2.55  2.81  3.09  3.64  6745    1
nu[1]    11.98    0.04 4.44  5.14  8.75 11.38 14.57 22.22  9860    1
nu[2]     1.54    0.00 0.43  1.03  1.25  1.46  1.74  2.50  9286    1

Samples were drawn using NUTS(diag_e) at Sat Mar 29 12:52:38 2025.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).

Updated model results

From marginal mixture model to clusters

Stan cannot directly infer categorical component indicators $z_i$. Instead, for each individual, we compute

$$P(z_i = h \mid Y_i, \mu, \sigma, \pi) = \frac{f(Y_i \mid z_i = h, \mu_h, \sigma_h)\, P(z_i = h \mid \pi_h)}{\sum_{h'=1}^{k} f(Y_i \mid z_i = h', \mu_{h'}, \sigma_{h'})\, P(z_i = h' \mid \pi_{h'})} = \frac{N(Y_i \mid \mu_h, \sigma_h)\, \pi_h}{\sum_{h'=1}^{k} N(Y_i \mid \mu_{h'}, \sigma_{h'})\, \pi_{h'}} = p_{ih}.$$

Given these cluster membership probabilities, we can recover cluster indicators through simulation: $(z_i \mid -) \sim \text{Categorical}(k, \{p_{i1}, \ldots, p_{ik}\})$.
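As a quick plug-in illustration of this formula (the generated quantities block on the next slide does this properly, draw by draw), here is a sketch with Gaussian components and hypothetical parameter values chosen to roughly mimic the posterior means above:

# p_ih for a single observation y, using hypothetical plug-in values
# (roughly the posterior means of fit2; Gaussian kernel used here to
# match the formula above, whereas the fitted model used a Student-t)
y <- 1.5
pi_h <- c(0.8, 0.2); mu <- c(-3, 8); sigma <- c(2.2, 2.8)
num <- pi_h * dnorm(y, mu, sigma)
p_ih <- num / sum(num)
p_ih  # membership probabilities for the two components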

From marginal mixture model to clusters

...

generated quantities {
  matrix[n, k] lPrZik;               // log P(z_i = h | Y_i, parameters)
  array[n] int<lower=1, upper=k> z;  // sampled cluster indicators
  for (i in 1:n){
    for (h in 1:k){
      lPrZik[i, h] = log(pi[h]) + student_t_lpdf(Y[i] | nu[h], mu[h], sigma[h]);
    }
    lPrZik[i] -= log_sum_exp(lPrZik[i]);   // normalize on the log scale
    z[i] = categorical_rng(exp(lPrZik[i]'));
  }
}
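Back in R, a sketch of how one might recover cluster assignments from these draws, assuming mixture2.stan has been recompiled with the generated quantities block above and refit as fit2 (the refit itself is not shown here):

# Posterior draws of the cluster indicators: an S x n matrix
z_draws <- rstan::extract(fit2, pars = "z")$z
# Modal (most frequent) cluster assignment for each patient
z_hat <- apply(z_draws, 2, function(zi) as.integer(names(which.max(table(zi)))))
table(z_hat)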

Co-clustering probabilities

Recovering $z_i$ allows us to make the following pairwise comparison: what is the probability that unit $i$ and unit $j$ are in the same cluster? This is the "co-clustering probability".

It is common to arrange these probabilities in a co-clustering matrix $C$, where the $(i, j)$ entry is given by

$$C_{ij} = P(z_i = z_j \mid -) \approx \frac{1}{S} \sum_{s=1}^{S} \mathbb{1}\left[z_i^{(s)} = z_j^{(s)}\right].$$
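A minimal sketch of this Monte Carlo estimate, reusing the z_draws matrix extracted in the earlier sketch:

# Co-clustering matrix: C[i, j] estimates Pr(z_i = z_j | data)
S <- nrow(z_draws); n_obs <- ncol(z_draws)
C <- matrix(0, n_obs, n_obs)
for (s in 1:S) {
  C <- C + outer(z_draws[s, ], z_draws[s, ], "==")
}
C <- C / S
# C can then be visualized as a heatmap, e.g., with image(C)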

Co-clustering probabilities

How do our results change when we use more components?

k=3

How do our results change when we use more components?

k=4

Co-clusterings across k

The same general pattern persists when more clusters are used, indicating that k=2 is a reasonable choice.

Characterizing the Clusters

Prepare for next class

  1. Reminder: On Thursday, we will have an in-class live-coding exercise.

  2. Begin working on Exam 02, which is due for feedback on April 15.
