-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgaussianMixture.R
70 lines (58 loc) · 2.13 KB
/
gaussianMixture.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# Load required libraries
library(tidyverse)
library(mclust) # For Gaussian Mixture Model (GMM)
library(factoextra) # For visualization
library(cluster) # For silhouette analysis
library(ggplot2)
library(gridExtra)
#' Load the customer CSV and prepare it for clustering
#'
#' Reads `file_path`, drops the `CustomerID` identifier column (if present),
#' encodes `Gender` (if present) as 1 for "Male" and 0 for any other
#' non-missing value, and standardises every remaining column with `scale()`.
#'
#' @param file_path Path to a CSV file readable by `read.csv()`.
#' @return A list of two data frames with the same columns:
#'   `original` (ID dropped, Gender encoded, unscaled) and
#'   `scaled` (each column centred to mean 0 and divided by its standard
#'   deviation, stored as a plain numeric vector).
preprocess_data <- function(file_path) {
  data <- read.csv(file_path)

  # The ID is a row label, not a feature. Assigning NULL drops it and is a
  # no-op when the column does not exist.
  data$CustomerID <- NULL

  if ("Gender" %in% names(data)) {
    data$Gender <- as.numeric(data$Gender == "Male")
  }

  # scale() returns an n x 1 matrix with "scaled:center"/"scaled:scale"
  # attributes. as.numeric() strips that so each scaled column is a plain
  # numeric vector rather than a matrix column.
  scaled_data <- data
  scaled_data[] <- lapply(data, function(column) as.numeric(scale(column)))

  list(original = data, scaled = scaled_data)
}
#' Fit a Gaussian mixture model to standardised features
#'
#' Thin wrapper around `Mclust()` (mclust), which evaluates candidate
#' covariance models and numbers of components and keeps the best by BIC.
#'
#' @param scaled_data Numeric data frame or matrix of features.
#' @param ... Further arguments forwarded to `Mclust()`, e.g. `G` to restrict
#'   the candidate numbers of components or `modelNames`.
#' @return The fitted model object returned by `Mclust()`.
fit_gmm <- function(scaled_data, ...) {
  gmm_model <- Mclust(scaled_data, ...)
  # NOTE(review): Mclust() appears to return NULL (with a warning) when no
  # candidate model can be estimated. Fail loudly here rather than letting
  # `gmm_model$G` silently become NULL further down the pipeline.
  if (is.null(gmm_model)) {
    stop("Mclust() could not fit any Gaussian mixture model to the data.",
         call. = FALSE)
  }
  gmm_model
}
#' Draw the BIC and density diagnostic plots for a fitted GMM
#'
#' Temporarily switches the base-graphics font family to "Times New Roman"
#' and restores the caller's previous graphics parameters on exit, so the
#' global `par()` state is left unchanged.
#'
#' @param gmm_model A fitted model, presumably from `Mclust()`; it is passed
#'   to `plot()` with `what = "BIC"` and then `what = "density"`.
#' @return `gmm_model`, invisibly.
plot_gmm_results <- function(gmm_model) {
  # par() returns the previous values of the parameters it changes; restore
  # them however this function exits.
  old_par <- par(family = "Times New Roman")
  on.exit(par(old_par), add = TRUE)

  # NOTE(review): recent mclust versions appear to treat `main` in
  # plot.Mclust() as a logical flag, so these title strings may be ignored.
  # Confirm the titles actually render.
  plot(gmm_model, what = "BIC", main = "BIC Scores for GMM")
  plot(gmm_model, what = "density", main = "Density of Cluster Probabilities")

  invisible(gmm_model)
}
#' Average silhouette width of a clustering
#'
#' @param data Data frame whose `Cluster` column holds each row's cluster
#'   label (a factor or integer codes).
#' @param scaled_data Features for the same rows, in the same order as
#'   `data`; Euclidean distances are computed from it with `dist()`.
#' @return The mean silhouette width. Returns `NA_real_` when
#'   `cluster::silhouette()` reports the silhouette as undefined (e.g. only
#'   one cluster).
compute_silhouette <- function(data, scaled_data) {
  # Factor -> integer codes; silhouette() only needs distinct integer labels.
  cluster_codes <- as.integer(data$Cluster)
  sil <- cluster::silhouette(cluster_codes, dist(scaled_data))

  # For trivial clusterings silhouette() returns a bare NA, not a matrix.
  # Indexing that with [, ...] would error, so report "undefined" explicitly.
  if (!is.matrix(sil)) {
    return(NA_real_)
  }
  mean(sil[, "sil_width"])
}
#' Scatter plot of the GMM cluster assignments
#'
#' Passes the features and cluster labels to `fviz_cluster()` as points with
#' normal-ellipse outlines, then applies a Times New Roman text theme.
#'
#' @param scaled_data Feature data used for plotting.
#' @param clusters Cluster label for each row of `scaled_data`.
#' @return The plot object produced by `fviz_cluster()` with the theme added.
plot_clusters <- function(scaled_data, clusters) {
  cluster_input <- list(data = scaled_data, cluster = clusters)
  serif_theme <- theme(text = element_text(family = "Times New Roman"))

  base_plot <- fviz_cluster(
    cluster_input,
    geom = "point",
    ellipse.type = "norm",
    main = "Customer Segmentation (GMM)"
  )
  base_plot + serif_theme
}
# Load and preprocess data
data_info <- preprocess_data("./rawSegdata.csv")
data <- data_info$original
scaled_data <- data_info$scaled

# Fit GMM model (Mclust picks the number of components)
gmm_model <- fit_gmm(scaled_data)

# Optimal number of clusters
optimal_k <- gmm_model$G
print(paste("Optimal number of clusters:", optimal_k))

# Assign cluster labels to the dataset
data$Cluster <- as.factor(predict(gmm_model)$classification)

# Per-cluster mean of every feature. An assignment is never auto-printed,
# so print the summary explicitly or it is computed and silently discarded.
cluster_summary <- data %>%
  group_by(Cluster) %>%
  summarise(across(everything(), mean))
print(cluster_summary)

# Compute silhouette score
silhouette_mean <- compute_silhouette(data, scaled_data)
print(paste("Average Silhouette Score:", silhouette_mean))

# Plot results. plot_clusters() returns a plot object that only renders when
# printed; print it explicitly so it also appears when the script is
# source()d.
plot_gmm_results(gmm_model)
print(plot_clusters(scaled_data, data$Cluster))