## Linear discriminant analysis (LDA) demo: simulate a three-class bivariate
## normal sample, estimate class means / pooled covariance / priors, and draw
## the three discriminant functions as surfaces in an rgl scene.
library("rgl")
library("mvtnorm")
library("dplyr")

set.seed(42)

## simulate data from a three component, bivariate normal mixture
n_per_class <- 10
mu <- list(c(0, 0), c(0, 1), c(1, 0))
sigma <- (diag(2) + 0.25) / 30
dat <- as.data.frame(do.call(
  rbind,
  lapply(mu, function(m) rmvnorm(n_per_class, m, sigma))
))
names(dat) <- c("x1", "x2")
dat$class <- rep(seq_along(mu), each = n_per_class)

cols <- c("#887733", "#882255", "#332288")
plot(
  x = dat$x1, y = dat$x2,
  xlab = expression(x[1]), ylab = expression(x[2]),
  col = cols[dat$class], pch = 19
)

## define ranges and a grid of x (input) values
xlm1 <- c(-0.75, 1.75)
xlm2 <- c(-0.75, 1.75)
xrg1 <- seq(xlm1[1], xlm1[2], length.out = 10)
xrg2 <- seq(xlm2[1], xlm2[2], length.out = 10)
zlm <- c(0, 1.1)

## class labels in order of first appearance; the k-th entry of mu_hat and
## pi_hat below belongs to the k-th label
classes <- unique(dat$class)
n_total <- nrow(dat)
n_classes <- length(classes)

## estimate class means
mu_hat <- lapply(classes, function(cl) {
  dat %>%
    filter(class == cl) %>%
    select(x1, x2) %>%
    colMeans()
})

## estimate pooled variance: sum_k (n_k - 1) * S_k / (N - K).
## The weights are derived from the data rather than hard-coded so they stay
## correct when the simulated sample size changes.
sigma_hat <- lapply(classes, function(cl) {
  obs <- dat %>%
    filter(class == cl) %>%
    select(x1, x2)
  cov(obs) * (nrow(obs) - 1) / (n_total - n_classes)
}) %>%
  Reduce("+", .)

## estimate pi (prior) as the observed class proportions
pi_hat <- lapply(classes, function(cl) mean(dat$class == cl))

## Discriminant function for class k.
##
## x      : numeric vector (one point) or matrix with one point per row
## k      : index of the class within muh / pih
## sigmah : common covariance matrix (defaults to the pooled estimate)
## muh    : list of class mean vectors
## pih    : list of class prior probabilities
##
## Returns log N(x; mu_k, Sigma) + log(pi_k), one value per point.
## The classical linear form
##   t(x) %*% isg %*% muk - 0.5 * t(muk) %*% isg %*% muk + log(pik)
## (with isg = solve(sigmah)) differs from this only by terms that do not
## depend on k, so it gives the same classification.
delta <- function(x, k, sigmah = sigma_hat, muh = mu_hat, pih = pi_hat) {
  dmvnorm(x, mean = muh[[k]], sigma = sigmah, log = TRUE) + log(pih[[k]])
}

## compute value of three discriminant functions on a grid
## (dmvnorm treats each row of a matrix as one point)
grid_pts <- as.matrix(expand.grid(xrg1, xrg2))
z1 <- delta(grid_pts, k = 1)
z2 <- delta(grid_pts, k = 2)
z3 <- delta(grid_pts, k = 3)

## normalize those values for plotting purposes: map the range seen on the
## grid onto [0, 1]
zm <- range(c(z1, z2, z3))
rescale_z <- function(z) (z - zm[1]) / (zm[2] - zm[1])
z1 <- rescale_z(z1)
z2 <- rescale_z(z2)
z3 <- rescale_z(z3)

## create 3d graphic
par3d(FOV = 1, userMatrix = diag(4))

## plot observed data; color according to class
with(dat, plot3d(
  x1, x2, 1,
  zlim = zlm, xlim = xlm1, ylim = xlm2, zlab = "",
  type = "s", radius = 0.075, axes = FALSE,
  col = cols[class], lwd = 0
))

## outline the plotting square at z = 1 and z = 0
lines3d(xlm1[c(1, 2, 2, 1, 1)], xlm2[c(1, 1, 2, 2, 1)], 1)
lines3d(xlm1[c(1, 2, 2, 1, 1)], xlm2[c(1, 1, 2, 2, 1)], 0)

## draw discriminant functions: one quad per class, spanned by the four
## corners of the plotting square
corner_x <- xlm1[c(1, 2, 2, 1)]
corner_y <- xlm2[c(1, 1, 2, 2)]
for (cls in seq_along(mu_hat)) {
  corner_z <- rescale_z(delta(cbind(corner_x, corner_y), cls))
  quads3d(
    x = corner_x, y = corner_y, z = corner_z,
    color = cols[cls], emission = cols[cls], specular = cols[cls]
  )
}