Skip to content

Commit

Permalink
Merge pull request #172 from slds-lmu/mi_measures
Browse files Browse the repository at this point in the history
add correlation and mi to correlation plots
  • Loading branch information
chriskolb authored Dec 20, 2023
2 parents de34918 + 58be677 commit f74764c
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
Binary file modified slides/information-theory/figure/correlation_plot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
26 changes: 24 additions & 2 deletions slides/information-theory/rsrc/make_correlation_plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
library(ggplot2)
library(gridExtra)
library(mvtnorm)
library(infotheo)

set.seed(123)

# DATA -------------------------------------------------------------------------
n <- 400
Expand Down Expand Up @@ -36,16 +39,35 @@ df6 <- data.frame(x = xy6[,1], y = xy6[,2])
# PLOTS ------------------------------------------------------------------------

make_plot <- function(df, xlimit = NULL) {

df <- na.omit(df)
# Calculate Pearson correlation
corr <- cor(df$x, df$y, method = "pearson")

# Discretize for mutual information calculation
df_discrete <- df
num_bins <- ceiling(sqrt(nrow(df_discrete))) # Using the square root heuristic
df_discrete$x <- cut(df_discrete$x, breaks = num_bins, labels = FALSE)
df_discrete$y <- cut(df_discrete$y, breaks = num_bins, labels = FALSE)
mi <- mutinformation(df_discrete$x, df_discrete$y)

# Create the plot
p <- ggplot(df, aes(x = x, y = y)) +
geom_point(shape = 1, size = 2, stroke = 1) +
theme_bw() +
theme(axis.title = element_blank())
ggtitle(paste("Corr: ", round(corr, 2),
", MI: ", round(mi, 2))) +
theme(axis.title = element_blank(), plot.title = element_text(size = 10))

# Set x limits if specified
if (!is.null(xlimit)) {
p <- p + xlim(xlimit[1], xlimit[2])
}

p
}


p1 <- make_plot(df1)
p2 <- make_plot(df2, xlimit = c(-1, 1))
p3 <- make_plot(df3)
Expand All @@ -54,4 +76,4 @@ p5 <- make_plot(df5)
p6 <- make_plot(df6)

p <- grid.arrange(grobs = list(p1, p2, p3, p4, p5, p6), nrow = 1, ncol = 6)
ggsave(filename = "../figure/correlation_plot.png", plot = p, width = 12, height = 2)
ggsave(filename = "../figure/correlation_plot.png", plot = p, width = 12, height = 2)

0 comments on commit f74764c

Please sign in to comment.