Skip to contents
library(EM.examples)
library(ggplot2)
library(dplyr)
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union
library(pander)

Example from McLachlan & Krishnan:

rm(list = ls())
n = 50
mu = c(0,2)
sigma = 1
pi = c(.8,.2)
set.seed(1)
Z = rbinom(n = n, size = 1, p = pi[2]) |> factor(labels = c("Z = 0", "Z = 1"))
Y = rnorm(n = n, mean = ifelse(Z == "Z = 0", mu[1], mu[2]), sd = sigma)
hist(Y, breaks = 10)

data1 = tibble(Z, Y)

Here’s the observed data:

ggplot(data1, aes(x = Y)) +
  geom_histogram() +
  theme(axis.text=element_text(size=14),
    axis.title=element_text(size=18,face="bold"))
#> `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

And here’s the (latent) complete data:

ggplot(data1, aes(x = Y)) +
  geom_histogram() +
  facet_wrap(~Z, ncol = 1) +
  theme(
    strip.text.x = element_text(size = 18),
    axis.text=element_text(size=14),
    axis.title=element_text(size=18,face="bold"))
#> `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.


`p(Y=y|Z=0)` = dnorm(Y, mu[1], sd = sigma)
`p(Y=y|Z=1)` = dnorm(Y, mu[2], sd = sigma)

likelihood = function(pi1)
{
  sapply(pi1, 
    function(x)
    {
      (x * `p(Y=y|Z=0)` + (1 - x) * `p(Y=y|Z=1)`) |> prod()
    })
  
}

loglik = function(pi1) log(likelihood(pi1))

par(mfrow = c(2,1))
plot(likelihood, xlim = c(0,1), xlab = "pi_1")
plot(loglik, xlim = c(0,1), xlab = "pi_1", ylab = "log(likelihood)")

EM Algorithm


`p(Z=0)`  = .5 # initial guess for `pi-hat`
diff = Inf
tolerance = .00001
progress = tibble(
  Iteration = 0,
  `p(Z=0)` = `p(Z=0)`,
  loglik =  loglik(`p(Z=0)`)
)
max_iterations = 1000

par(mfrow = c(1,1))

plot(loglik, xlim = c(0,1), xlab = "p(Z=0)", ylab = "log(likelihood)", lwd = 2)

for(i in 1:max_iterations)
{
  # E step:
  Estep = tibble(
    Y, # observed data
    `p(Y=y|Z=0)`,
    `p(Y=y|Z=1)`,
    `p(Y=y,Z=0)` = `p(Y=y|Z=0)`*`p(Z=0)`, # `p(Z=0)` = "pi-hat"
    `p(Y=y,Z=1)` = `p(Y=y|Z=1)`*(1 - `p(Z=0)`),
    `p(Y=y)` = `p(Y=y,Z=0)` + `p(Y=y,Z=1)`,
    `p(Z=0|Y=y)` = `p(Y=y,Z=0)`/`p(Y=y)`,
    `p(Z=1|Y=y)` = 1 - `p(Z=0|Y=y)` # == `p(Y=y,Z=1)`/`p(Y=y)`)
  )
  
  # M step
  `pi-hat-prev` = `p(Z=0)` # save the previous pi-hat estimate so we can graph our progress
  `p(Z=0)` = mean(Estep$`p(Z=0|Y=y)`) # here's the new pi-hat estimate
  
  Q = function(pi1)
  {
    sapply(pi1, 
      FUN = function(x)
      {
        with(Estep, sum((log(`p(Y=y|Z=0)`) + log(x))*`p(Z=0|Y=y)` +
            (log(`p(Y=y|Z=1)`) + log(1-x))*`p(Z=1|Y=y)`))
      })
  }
  
  # plot(loglik, xlim = c(0,1), xlab = "p(Z=0)", ylab = "log(likelihood)", lwd = 2)
  points(x = `pi-hat-prev`, y = loglik(`pi-hat-prev`), col = "red", pch = 16)
  plot(Q, add = TRUE, col = 'blue', lwd = 2)
  legend(x = "topleft", col = c('black', 'blue'), lty = 1, 
    lwd = 2,
    legend = c("log-likelihood", expression(Q(pi,hat(pi)))))
  points(x = `p(Z=0)`, y = loglik(`p(Z=0)`), pch = 16, col = 'orange')
  points(x = `p(Z=0)`, y = Q(`p(Z=0)`), col = "green", pch = 16)
  
  
  
  diff = `p(Z=0)` - `pi-hat-prev` # this is wrong; should be diff of logliks
  new_results = tibble(
    Iteration = i,
    `p(Z=0)` = `p(Z=0)`,
    loglik = loglik(`p(Z=0)`),
    `diff(loglik)` = diff)
  
  progress = 
    bind_rows(progress, new_results)
  
  if(diff < tolerance) break;  
}


pander(progress)
Iteration p(Z=0) loglik diff(loglik)
0 0.5 -82.76 NA
1 0.6721 -78.3 0.1721
2 0.7499 -77.23 0.07779
3 0.7858 -76.96 0.03589
4 0.8035 -76.89 0.01763
5 0.8125 -76.87 0.009062
6 0.8173 -76.87 0.004788
7 0.8199 -76.86 0.00257
8 0.8213 -76.86 0.001392
9 0.822 -76.86 0.0007572
10 0.8224 -76.86 0.0004131
11 0.8227 -76.86 0.0002257
12 0.8228 -76.86 0.0001234
13 0.8229 -76.86 6.75e-05
14 0.8229 -76.86 3.694e-05
15 0.8229 -76.86 2.021e-05
16 0.8229 -76.86 1.106e-05
17 0.8229 -76.86 6.054e-06