1 Introduction

Customer retention is a critical concern for banks, as losing customers can significantly impact revenue and long-term stability. In this project, we explore customer churn prediction for ABC Multistate Bank using a dataset that includes various customer attributes such as credit score, tenure, account balance, number of bank products, and activity status. The goal is to develop a predictive model that can identify customers at risk of leaving the bank, enabling proactive intervention strategies.

Our approach involves conducting an exploratory data analysis (EDA) to understand patterns and trends in customer behavior. We then apply statistical learning algorithms to build and evaluate models for churn prediction. The performance of these models will be assessed based on key evaluation metrics, allowing us to compare their effectiveness. Finally, we will discuss the implications of our findings and propose future improvements to enhance model accuracy and real-world applicability.

Through this project, we aim to demonstrate how data-driven insights can support strategic decision-making in the banking sector, ultimately improving customer retention and business performance.

2 Data Import
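
The report's setup chunk is not echoed; the analysis below assumes roughly the following packages are attached (the engine packages ranger, xgboost, kknn, and rpart must also be installed). This list is inferred from the functions used, not copied from the original source.

library(tidyverse)    # dplyr, ggplot2, and related helpers
library(visdat)       # vis_miss()
library(plotly)       # interactive correlation heatmap
library(caret)        # createDataPartition()
library(knitr)        # kable()
library(tidymodels)   # recipes, parsnip, workflows, rsample, yardstick, tune
library(themis)       # step_smote()
library(baguette)     # bag_tree()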

# Read the raw data (absolute path; adjust for your environment)
df <- read.csv("/Users/ethantsao/PSTAT131/Bank Customer Churn Prediction.csv")
head(df)
##   customer_id credit_score country gender age tenure   balance products_number
## 1    15634602          619  France Female  42      2      0.00               1
## 2    15647311          608   Spain Female  41      1  83807.86               1
## 3    15619304          502  France Female  42      8 159660.80               3
## 4    15701354          699  France Female  39      1      0.00               2
## 5    15737888          850   Spain Female  43      2 125510.82               1
## 6    15574012          645   Spain   Male  44      8 113755.78               2
##   credit_card active_member estimated_salary churn
## 1           1             1        101348.88     1
## 2           0             1        112542.58     0
## 3           1             0        113931.57     1
## 4           0             0         93826.63     0
## 5           1             1         79084.10     0
## 6           1             0        149756.71     1

3 Dataset Structure and Missing Values

str(df)
## 'data.frame':    10000 obs. of  12 variables:
##  $ customer_id     : int  15634602 15647311 15619304 15701354 15737888 15574012 15592531 15656148 15792365 15592389 ...
##  $ credit_score    : int  619 608 502 699 850 645 822 376 501 684 ...
##  $ country         : chr  "France" "Spain" "France" "France" ...
##  $ gender          : chr  "Female" "Female" "Female" "Female" ...
##  $ age             : int  42 41 42 39 43 44 50 29 44 27 ...
##  $ tenure          : int  2 1 8 1 2 8 7 4 4 2 ...
##  $ balance         : num  0 83808 159661 0 125511 ...
##  $ products_number : int  1 1 3 2 1 2 2 4 2 1 ...
##  $ credit_card     : int  1 0 1 0 1 1 1 1 0 1 ...
##  $ active_member   : int  1 1 0 0 1 0 1 0 1 1 ...
##  $ estimated_salary: num  101349 112543 113932 93827 79084 ...
##  $ churn           : int  1 0 1 0 0 1 0 1 0 0 ...
# Visualize missingness across all columns
vis_miss(df)

The plot above shows that there are no missing data points in the dataset, which contains a total of 10,000 records. Since there are no missing values, we can move forward with outlier handling, exploratory data analysis, and modeling.

# Per-column summary of counts, missingness, and cardinality
summary_df <- df %>%
  reframe(
    Column = names(df),
    Count = nrow(df),
    Null = sapply(df, function(x) sum(is.na(x))),
    `Null %` = sapply(df, function(x) mean(is.na(x)) * 100),
    Cardinality = sapply(df, function(x) n_distinct(x))
  )

summary_df
##              Column Count Null Null % Cardinality
## 1       customer_id 10000    0      0       10000
## 2      credit_score 10000    0      0         460
## 3           country 10000    0      0           3
## 4            gender 10000    0      0           2
## 5               age 10000    0      0          70
## 6            tenure 10000    0      0          11
## 7           balance 10000    0      0        6382
## 8   products_number 10000    0      0           4
## 9       credit_card 10000    0      0           2
## 10    active_member 10000    0      0           2
## 11 estimated_salary 10000    0      0        9999
## 12            churn 10000    0      0           2

4 Detecting Outliers

# Numeric columns, excluding the churn target
numeric_columns <- setdiff(names(df)[sapply(df, is.numeric)], "churn")

detect_outliers_iqr <- function(df, columns, threshold = 1.5) {
  outlier_indices <- list()
  
  for (col in columns) {
    Q1 <- quantile(df[[col]], 0.25, na.rm = TRUE)  # 25th percentile
    Q3 <- quantile(df[[col]], 0.75, na.rm = TRUE)  # 75th percentile
    IQR <- Q3 - Q1  # Interquartile range
    lower_bound <- Q1 - threshold * IQR
    upper_bound <- Q3 + threshold * IQR
    
    outliers <- which(df[[col]] < lower_bound | df[[col]] > upper_bound)
    outlier_indices[[col]] <- outliers
  }
  
  return(outlier_indices)
}

outliers_iqr <- detect_outliers_iqr(df, numeric_columns)

for (col in names(outliers_iqr)) {
  cat(col, ": ", length(outliers_iqr[[col]]), "outliers detected\n")
}
## customer_id :  0 outliers detected
## credit_score :  15 outliers detected
## age :  359 outliers detected
## tenure :  0 outliers detected
## balance :  0 outliers detected
## products_number :  60 outliers detected
## credit_card :  0 outliers detected
## active_member :  0 outliers detected
## estimated_salary :  0 outliers detected
remove_outliers <- function(df, columns, threshold = 1.5) {
  for(col in columns) {
    Q1 <- quantile(df[[col]], 0.25, na.rm = TRUE)
    Q3 <- quantile(df[[col]], 0.75, na.rm = TRUE)
    IQR <- Q3 - Q1
    lower_bound <- Q1 - threshold * IQR
    upper_bound <- Q3 + threshold * IQR
    
    df <- df[df[[col]] >= lower_bound & df[[col]] <= upper_bound, ]
  }
  return(df)
}

# Continuous features only: exclude the ID, binary flags, and the churn target
numeric_columns <- c("credit_score", "age", "tenure", "balance", "products_number", "estimated_salary")

df_cleaned <- remove_outliers(df, numeric_columns)

dim(df_cleaned)
## [1] 9568   12
head(df_cleaned)
##   customer_id credit_score country gender age tenure   balance products_number
## 1    15634602          619  France Female  42      2      0.00               1
## 2    15647311          608   Spain Female  41      1  83807.86               1
## 3    15619304          502  France Female  42      8 159660.80               3
## 4    15701354          699  France Female  39      1      0.00               2
## 5    15737888          850   Spain Female  43      2 125510.82               1
## 6    15574012          645   Spain   Male  44      8 113755.78               2
##   credit_card active_member estimated_salary churn
## 1           1             1        101348.88     1
## 2           0             1        112542.58     0
## 3           1             0        113931.57     1
## 4           0             0         93826.63     0
## 5           1             1         79084.10     0
## 6           1             0        149756.71     1

5 EDA and Data Insights

5.1 Distribution of Credit Score

ggplot(df_cleaned, aes(x = gender, y = credit_score, fill = gender)) +
  geom_boxplot() +
  labs(title = "Credit Score Distribution by Gender",
       x = "Gender",
       y = "Credit Score") +
  scale_fill_manual(values = c("Male" = "lightblue", "Female" = "pink")) +
  theme_minimal() +
  theme(plot.title = element_text(hjust = 0.5))

The box plot shows the distribution of credit scores by gender. The two distributions are closely matched: the medians are nearly identical, the interquartile ranges overlap substantially, and the lower and upper quartiles align closely. Overall, credit scores appear evenly distributed across genders, with no clear difference in median, spread, or extreme values.
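
As a quick numerical check of this visual impression (not part of the original analysis), a rank-based test can compare the two credit-score distributions; a large p-value would support the "no difference" reading:

# Wilcoxon rank-sum test: do credit scores differ in location by gender?
wilcox.test(credit_score ~ gender, data = df_cleaned)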

5.2 Distribution of Age

ggplot(df_cleaned, aes(x = age, fill = gender)) +
  geom_density(alpha = 0.5, color = NA) +  # Density plot with transparency
  labs(title = "Age Distribution by Gender",
       x = "Age",
       y = "Density") +
  scale_fill_manual(values = c("Male" = "lightblue", "Female" = "pink")) +  # Custom colors
  theme_minimal() +
  theme(plot.title = element_text(hjust = 0.5))  # Center the title

The density plot illustrates the age distribution by gender, comparing males (light blue) and females (pink). Both distributions are unimodal and right-skewed, with the peak around the early 30s. Males have a slightly higher density around the peak, indicating a larger concentration of individuals in this age range compared to females. The distributions follow a similar pattern, suggesting that the age structure is fairly comparable across genders. However, the female distribution appears to extend slightly more toward older ages, with a longer tail beyond 50.

5.3 Distribution of Credit Score by Churn

ggplot(df_cleaned, aes(x = credit_score, fill = factor(churn))) +
  geom_density(alpha = 0.4) +  
  scale_fill_manual(values = c("0" = "magenta", "1" = "navy"), labels = c("No Churn", "Churn")) +
  labs(
    title = "Density Distribution of Credit Score by Churn Status",
    x = "Credit Score", y = "Density",
    fill = "Churn Status"
  ) +
  theme_minimal() +
  theme(plot.title = element_text(hjust = 0.5))

The density distribution of credit scores by churn status reveals that while both churned and non-churned customers have similar overall distributions, there are slight differences. Customers with lower credit scores (below 650) have a slightly higher likelihood of churning, while those with higher scores (680–750) show a higher tendency to stay. However, the overlap suggests that credit score alone may not be a strong predictor of churn.

5.4 Credit Score Distribution by Churn (further)

ggplot(df_cleaned, aes(x = credit_score, fill = factor(churn))) +
  geom_histogram(alpha = 0.6, bins = 20, color = "black") +
  facet_wrap(~ churn, labeller = labeller(churn = c("0" = "No Churn", "1" = "Churn"))) +
  scale_fill_manual(values = c("0" = "blue", "1" = "red")) +
  labs(
    title = "Credit Score Distribution by Churn Status",
    x = "Credit Score", y = "Count",
    fill = "Churn Status"
  ) +
  theme_minimal() +
  theme(plot.title = element_text(hjust = 0.5))

This histogram shows the distribution of credit scores for customers who churned (red) versus those who did not (blue). Most non-churned customers have credit scores between 600 and 750, forming an approximately normal distribution peaking near 700. The churned group has a much lower overall count and a slightly left-skewed distribution, with most scores between 500 and 700. This again suggests that customers with lower credit scores are somewhat more likely to churn, while those with higher scores tend to stay; the churn that persists at high credit scores indicates that other factors also drive churn behavior.

5.5 Churn Distribution by Age Group (Stacked Bar Chart)

# Creating bins for age, salary, and credit score
df$age_bin <- cut(df$age, breaks = c(0, 30, 45, 60, Inf), labels = c("Young-Adults", "Professional", "Middle-Aged", "Senior"))
df$salary_bin <- cut(df$estimated_salary, breaks = c(0, 25000, 75000, 125000, Inf), labels = c("Below", "Low", "Middle", "High"))
df$credit_score_bin <- cut(df$credit_score, breaks = c(349, 583, 651, 717, Inf), labels = c("Poor", "Fair", "Good", "Excellent"))

# Counting churn occurrences within each age group
df_summary <- df %>% 
  count(age_bin, churn) %>%
  group_by(age_bin) %>%
  mutate(pct = n / sum(n) * 100)

# Stacked bar chart showing churn distribution by age group
ggplot(data = df_summary, aes(x = age_bin, y = pct, fill = factor(churn))) +
  geom_bar(stat = "identity", position = "stack") +
  ggtitle("Churn Distribution by Age Group") +
  xlab("Age Group") +
  ylab("Percentage of Churn") +
  scale_fill_manual(values = c("lightblue", "beige"), labels = c("No Churn", "Churn")) +
  theme_minimal()

The stacked bar chart shows, for each age group, the percentage of customers who stayed (light blue, "No Churn") versus those who left (beige, "Churn"). The "Young-Adults" and "Professional" groups are dominated by the no-churn segment, while the churn share is largest in the "Middle-Aged" group before tapering off for "Seniors". In other words, churn risk rises with age through middle age rather than falling, consistent with churned customers in this dataset tending to be older. This insight can help the bank tailor retention strategies, such as proactive outreach to middle-aged customers, the segment most likely to leave.
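
The chart's reading can be verified numerically by computing the churn rate within each bin (a quick check, not part of the original analysis):

# Churn rate per age group; churn is still the raw 0/1 integer in df
df %>%
  group_by(age_bin) %>%
  summarise(n = n(), churn_rate = mean(churn))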

6 Data Preprocessing and Preparation

column_to_encode <- df_cleaned %>% 
  select(where(is.character)) %>% colnames()

df_cleaned <- df_cleaned %>%
  mutate(across(all_of(column_to_encode), ~ as.numeric(factor(.))))

head(df_cleaned)
##   customer_id credit_score country gender age tenure   balance products_number
## 1    15634602          619       1      1  42      2      0.00               1
## 2    15647311          608       3      1  41      1  83807.86               1
## 3    15619304          502       1      1  42      8 159660.80               3
## 4    15701354          699       1      1  39      1      0.00               2
## 5    15737888          850       3      1  43      2 125510.82               1
## 6    15574012          645       3      2  44      8 113755.78               2
##   credit_card active_member estimated_salary churn
## 1           1             1        101348.88     1
## 2           0             1        112542.58     0
## 3           1             0        113931.57     1
## 4           0             0         93826.63     0
## 5           1             1         79084.10     0
## 6           1             0        149756.71     1
columns_to_scale <- c("credit_score", "balance", "estimated_salary")

min_max_scale <- function(x) {
  (x - min(x, na.rm = TRUE)) / (max(x, na.rm = TRUE) - min(x, na.rm = TRUE))
}

df_cleaned <- df_cleaned %>%
  mutate(across(all_of(columns_to_scale), min_max_scale))

head(df_cleaned)
##   customer_id credit_score country gender age tenure   balance products_number
## 1    15634602    0.5053533       1      1  42      2 0.0000000               1
## 2    15647311    0.4817987       3      1  41      1 0.3340315               1
## 3    15619304    0.2548180       1      1  42      8 0.6363572               3
## 4    15701354    0.6766595       1      1  39      1 0.0000000               2
## 5    15737888    1.0000000       3      1  43      2 0.5002462               1
## 6    15574012    0.5610278       3      2  44      8 0.4533944               2
##   credit_card active_member estimated_salary churn
## 1           1             1        0.5067349     1
## 2           0             1        0.5627087     0
## 3           1             0        0.5696544     1
## 4           0             0        0.4691201     0
## 5           1             1        0.3954004     0
## 6           1             0        0.7487972     1
correlation_matrix <- df_cleaned %>%
  select(where(is.numeric)) %>%
  cor(use = "complete.obs")

fig <- plot_ly(
  z = correlation_matrix,
  x = colnames(correlation_matrix),
  y = colnames(correlation_matrix),
  type = "heatmap",
  colorscale = "tempo",
  text = round(correlation_matrix, 2),
  texttemplate = "%{text}",
  hoverinfo = "x+y+z",
  width = 1000,
  height = 800
) %>%
  plotly::layout(
    title = "Correlation Heatmap",
    template = "plotly_dark"
  )

fig

The customer_id column is removed below because, as an arbitrary identifier, it provides no meaningful information for predictive modeling. The correlation heatmap confirms this: its correlations with every other variable, including the target churn, are close to zero, indicating no systematic relationship with any feature in the dataset.

df_cleaned <- df_cleaned %>%
  select(-customer_id)

6.1 Distribution of Churn

df_cleaned$churn <- as.factor(df_cleaned$churn)

# Count occurrences of each category in Churn
counts_df <- as.data.frame(table(df_cleaned$churn))
colnames(counts_df) <- c("Churn", "Counts")

# Plot the bar chart
ggplot(counts_df, aes(x = Churn, y = Counts)) +
  geom_bar(stat = "identity", fill = "lightblue") +
  ggtitle("Distribution of Churn") +
  xlab("Churn") +
  ylab("Counts") +
  theme_minimal()

The distribution of churn in the dataset is visualized with a bar chart showing the count of each category of the churn factor: churn (1) and non-churn (0). The chart reveals a pronounced imbalance: of the 9,568 records in the cleaned dataset, non-churn accounts for roughly 80% and churn for roughly 20%. This skew matters for modeling, because a classifier can score high accuracy simply by favoring the majority class.
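
The exact counts behind the chart can be read off directly rather than estimated from the axis (a quick check):

# Exact class counts and proportions in the cleaned data
df_cleaned %>%
  count(churn) %>%
  mutate(prop = n / sum(n))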

6.2 Distribution of Churn after SMOTE

df_cleaned$churn <- as.factor(df_cleaned$churn)

# Define a recipe and apply SMOTE
recipe_obj <- recipe(churn ~ ., data = df_cleaned) %>%
  step_smote(churn)

# Prepare the recipe
smote_data <- prep(recipe_obj) %>% juice()

# Count occurrences of each category after SMOTE
counts_df <- as.data.frame(table(smote_data$churn))
colnames(counts_df) <- c("Churn", "Counts")

# Plot the bar chart
ggplot(counts_df, aes(x = Churn, y = Counts)) +
  geom_bar(stat = "identity", fill = "lightblue") +
  ggtitle("Distribution of Churn After SMOTE") +
  xlab("Churn") +
  ylab("Counts") +
  theme_minimal()

After applying SMOTE, the distribution of churn becomes balanced: the previously underrepresented churn class is augmented with synthetic examples until both classes have equal counts, each matching the original majority-class count. By addressing the imbalance, SMOTE helps a model learn from both classes equally, reducing the risk of bias toward the majority class. Note, however, that smote_data is generated here only to illustrate the balancing; the modeling pipeline below does not use it, and in general SMOTE should be applied only to training data (or within resampling folds) so that synthetic observations never leak into evaluation.
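
A minimal sketch of that safer alternative (assumed; the models below use the unbalanced churn_recipe instead):

# Sketch: embedding SMOTE in the modeling recipe. step_smote() runs only when
# the recipe is trained on an analysis set and is skipped when baking new
# data, so assessment folds and the test set are never resampled.
churn_recipe_smote <- recipe(churn ~ ., data = df_cleaned) %>%
  step_dummy(all_nominal_predictors()) %>%
  step_normalize(all_numeric_predictors()) %>%
  step_smote(churn)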

7 Splitting Data and Cross Validation

set.seed(123)

df_cleaned$churn <- as.factor(df_cleaned$churn)

# Splitting the data (75% train, 25% test)
train_index <- createDataPartition(df_cleaned$churn, p = 0.75, list = FALSE)
train_data <- df_cleaned[train_index, ]
test_data <- df_cleaned[-train_index, ]

# Extracting X (features) and y (target variable)
X_train <- train_data[, -which(names(train_data) == "churn")]
y_train <- train_data$churn
X_test <- test_data[, -which(names(test_data) == "churn")]
y_test <- test_data$churn

dim_table <- data.frame(
  Dataset = c("X_Train", "X_Test", "Y_Train", "Y_Test"),
  Rows = c(dim(X_train)[1], dim(X_test)[1], length(y_train), length(y_test)),
  Columns = c(dim(X_train)[2], dim(X_test)[2], NA, NA)
)

# Display the table with kable
kable(dim_table, col.names = c("Dataset", "Rows", "Columns"), format = "pipe", align = "c")
| Dataset | Rows | Columns |
|:-------:|:----:|:-------:|
| X_Train | 7177 |   10    |
| X_Test  | 2391 |   10    |
| Y_Train | 7177 |   NA    |
| Y_Test  | 2391 |   NA    |

Before proceeding with model fitting, we split the dataset into training and testing sets. To make the results reproducible, we fixed the random seed with set.seed(123), so the split is identical on every run. The data were divided with createDataPartition, which samples within the levels of churn and therefore preserves the class ratio in both sets; 75% of the observations went to training and 25% to testing. We then separated the features from the target variable: the predictors were stored in X_train and X_test, and churn in y_train and y_test. Finally, the kable table above summarizes the dimensions of each of these objects.

churn_recipe <- recipe(churn ~ ., data = df_cleaned) %>%
  step_dummy(all_nominal_predictors()) %>%
  step_center(all_numeric_predictors()) %>%
  step_scale(all_numeric_predictors())

# Check the recipe
churn_recipe_prepped <- prep(churn_recipe)
bake(churn_recipe_prepped, new_data = train_data) %>% head()
## # A tibble: 6 × 11
##   credit_score country gender   age tenure balance products_number credit_card
##          <dbl>   <dbl>  <dbl> <dbl>  <dbl>   <dbl>           <dbl>       <dbl>
## 1      -0.444    1.51  -1.10  0.372 -1.39    0.118          -0.939      -1.55 
## 2      -1.55    -0.902 -1.10  0.486  1.03    1.33            2.70        0.647
## 3       0.503   -0.902 -1.10  0.144 -1.39   -1.22            0.880      -1.55 
## 4       2.07     1.51  -1.10  0.599 -1.04    0.786          -0.939       0.647
## 5      -0.0592   1.51   0.910 0.713  1.03    0.598           0.880       0.647
## 6      -1.56    -0.902  0.910 0.713 -0.352   1.05            0.880      -1.55 
## # ℹ 3 more variables: active_member <dbl>, estimated_salary <dbl>, churn <fct>

In this step, we defined a preprocessing recipe with churn as the target and all remaining columns as predictors. The recipe converts nominal predictors to dummy variables with step_dummy, then centers (step_center) and scales (step_scale) all numeric predictors. One caveat: because country and gender were already integer-encoded in Section 6, step_dummy finds no nominal predictors here, and those columns are simply standardized as numbers; the baked output above accordingly shows country and gender as single numeric columns. After defining the recipe, prep estimates the required statistics, and bake applies the transformations to the training data, letting us confirm the data are properly preprocessed for modeling.
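
For reference, had the character columns been left un-encoded, step_dummy would have expanded them into indicator variables; a small illustration on the raw data frame (assumed, for reference only):

# Illustration: step_dummy() on the raw character columns creates indicators
# such as country_Germany, country_Spain, and gender_Male (the first levels,
# France and Female, serve as the baseline)
recipe(~ country + gender, data = df) %>%
  step_dummy(all_nominal_predictors()) %>%
  prep() %>%
  bake(new_data = NULL) %>%
  head()

Separately, the chunk that created churn_folds and churn_metrics is not echoed in the report. The following is a minimal reconstruction, consistent with the printed folds below and the description that follows (assumed, not necessarily the author's exact code):

# Assumed reconstruction: 10 stratified CV folds on the training data, plus
# the metric set referenced as churn_metrics in the fitting code below
churn_folds <- vfold_cv(train_data, v = 10, strata = churn)
churn_metrics <- metric_set(accuracy, precision, recall, f_meas, roc_auc)
churn_folds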

## #  10-fold cross-validation using stratification 
## # A tibble: 10 × 2
##    splits             id    
##    <list>             <chr> 
##  1 <split [6459/718]> Fold01
##  2 <split [6459/718]> Fold02
##  3 <split [6459/718]> Fold03
##  4 <split [6459/718]> Fold04
##  5 <split [6459/718]> Fold05
##  6 <split [6459/718]> Fold06
##  7 <split [6459/718]> Fold07
##  8 <split [6459/718]> Fold08
##  9 <split [6460/717]> Fold09
## 10 <split [6461/716]> Fold10

We created cross-validation folds with the vfold_cv function, which splits the training data into 10 folds (v = 10) while stratifying on churn so that each fold preserves the target distribution; this yields a more reliable performance estimate. The resulting object, churn_folds, contains these splits. We also defined churn_metrics, a yardstick metric set comprising accuracy, precision, recall, F-measure, and ROC AUC, which is used to assess each model's ability to predict churn during cross-validation.

8 Model Fitting

In the model fitting process, we evaluate the predictive performance of each machine learning model by comparing their performance metrics, including accuracy, precision, recall, F1 score, and ROC AUC. These metrics provide insights into how well the model is performing across different aspects, such as overall correctness, balance between precision and recall, and the ability to distinguish between classes. We use workflows to organize each model, which include the necessary preprocessing steps, and fit these models using cross-validation on the training data. After fitting the models, we collect and print their performance metrics across the different folds.

The models we have selected for fitting are Logistic Regression, Random Forest, K-Nearest Neighbors (KNN), Boosted Trees, and Bagging.

# Logistic Regression
log_reg_model <- logistic_reg() %>%
  set_engine("glm") %>%
  set_mode("classification")

log_reg_wf <- workflow() %>%
  add_recipe(churn_recipe) %>%
  add_model(log_reg_model)

# Random Forest
rf_model <- rand_forest() %>%
  set_engine("ranger") %>%
  set_mode("classification")

rf_wf <- workflow() %>%
  add_recipe(churn_recipe) %>%
  add_model(rf_model)
# K-Nearest Neighbors (KNN)
knn_model <- nearest_neighbor(neighbors = 5) %>%
  set_engine("kknn") %>%
  set_mode("classification")

knn_wf <- workflow() %>%
  add_recipe(churn_recipe) %>%
  add_model(knn_model)
# Boosted Trees
boost_model <- boost_tree() %>%
  set_engine("xgboost") %>%
  set_mode("classification")

boost_wf <- workflow() %>%
  add_recipe(churn_recipe) %>%
  add_model(boost_model)
# Bagging
bagging_model <- bag_tree() %>%
  set_engine("rpart", times = 25) %>%
  set_mode("classification")

bagging_wf <- workflow() %>%
  add_recipe(churn_recipe) %>%
  add_model(bagging_model)

8.1 Fit and evaluate models using cross-validation

log_reg_results <- log_reg_wf %>%
  fit_resamples(
    resamples = churn_folds,
    metrics = churn_metrics,
    control = control_resamples(save_pred = TRUE)
  )
rf_results <- rf_wf %>%
  fit_resamples(
    resamples = churn_folds,
    metrics = churn_metrics,
    control = control_resamples(save_pred = TRUE)
  )
boost_results <- boost_wf %>%
  fit_resamples(
    resamples = churn_folds,
    metrics = churn_metrics,
    control = control_resamples(save_pred = TRUE)
  )
bagging_results <- bagging_wf %>%
  fit_resamples(
    resamples = churn_folds,
    metrics = churn_metrics,
    control = control_resamples(save_pred = TRUE)
  )
knn_results <- knn_wf %>%
  fit_resamples(
    resamples = churn_folds,
    metrics = churn_metrics,
    control = control_resamples(save_pred = TRUE)
  )

8.2 Collecting metrics for each model

log_reg_metrics <- collect_metrics(log_reg_results)
rf_metrics <- collect_metrics(rf_results)
boost_metrics <- collect_metrics(boost_results)
bagging_metrics <- collect_metrics(bagging_results)
knn_metrics <- collect_metrics(knn_results)

# Display the results
log_reg_metrics <- log_reg_metrics %>% mutate(Model = "Logistic Regression")
rf_metrics <- rf_metrics %>% mutate(Model = "Random Forest")
boost_metrics <- boost_metrics %>% mutate(Model = "Boosted Trees")
bagging_metrics <- bagging_metrics %>% mutate(Model = "Bagging")
knn_metrics <- knn_metrics %>% mutate(Model = "K-Nearest Neighbors")

# Combine all metric tables
all_metrics <- bind_rows(log_reg_metrics, rf_metrics, boost_metrics, bagging_metrics, knn_metrics) %>%
  select(Model, .metric, mean, std_err)  # Keep relevant columns

# Create a kable table
kable(all_metrics, caption = "Comparison of Model Performance Metrics", digits = 3)
Table: Comparison of Model Performance Metrics

|Model               |.metric   |  mean| std_err|
|:-------------------|:---------|-----:|-------:|
|Logistic Regression |accuracy  | 0.828|   0.004|
|Logistic Regression |f_meas    | 0.900|   0.002|
|Logistic Regression |precision | 0.842|   0.002|
|Logistic Regression |recall    | 0.967|   0.003|
|Logistic Regression |roc_auc   | 0.775|   0.008|
|Random Forest       |accuracy  | 0.860|   0.003|
|Random Forest       |f_meas    | 0.917|   0.002|
|Random Forest       |precision | 0.872|   0.002|
|Random Forest       |recall    | 0.968|   0.002|
|Random Forest       |roc_auc   | 0.851|   0.004|
|Boosted Trees       |accuracy  | 0.861|   0.003|
|Boosted Trees       |f_meas    | 0.918|   0.002|
|Boosted Trees       |precision | 0.876|   0.002|
|Boosted Trees       |recall    | 0.964|   0.003|
|Boosted Trees       |roc_auc   | 0.852|   0.004|
|Bagging             |accuracy  | 0.852|   0.001|
|Bagging             |f_meas    | 0.911|   0.001|
|Bagging             |precision | 0.877|   0.002|
|Bagging             |recall    | 0.949|   0.003|
|Bagging             |roc_auc   | 0.823|   0.005|
|K-Nearest Neighbors |accuracy  | 0.821|   0.003|
|K-Nearest Neighbors |f_meas    | 0.892|   0.002|
|K-Nearest Neighbors |precision | 0.863|   0.002|
|K-Nearest Neighbors |recall    | 0.922|   0.003|
|K-Nearest Neighbors |roc_auc   | 0.770|   0.007|

8.3 ROC Curve and Model Comparison

# ROC Curves
log_reg_predictions <- collect_predictions(log_reg_results)
rf_predictions <- collect_predictions(rf_results)
boost_predictions <- collect_predictions(boost_results)
bagging_predictions <- collect_predictions(bagging_results)
knn_predictions <- collect_predictions(knn_results)

bind_rows(
  log_reg_predictions %>% mutate(model = "Logistic Regression"),
  rf_predictions %>% mutate(model = "Random Forest"),
  boost_predictions %>% mutate(model = "Boosted Trees"),
  bagging_predictions %>% mutate(model = "Bagging"),
  knn_predictions %>% mutate(model = "KNN")
) %>%
  group_by(model) %>%
  # churn level "1" is the positive class, so set event_level = "second"
  roc_curve(truth = churn, .pred_1, event_level = "second") %>%
  autoplot() +
  labs(
    title = "ROC Curves for Churn Prediction Models",
    subtitle = "Comparing multiple ML models using the same preprocessing recipe"
  )

The ROC curves compare the five models, all using the same preprocessing recipe, by plotting sensitivity (true positive rate) against 1 - specificity across classification thresholds; the closer a curve hugs the top-left corner, the better the model separates churners from non-churners. Consistent with the cross-validated metrics above, Boosted Trees (mean ROC AUC 0.852) and Random Forest (0.851) perform best, Bagging follows (0.823), and Logistic Regression (0.775) and KNN (0.770) sit noticeably closer to the diagonal, indicating weaker discrimination.

all_metrics <- bind_rows(
  collect_metrics(log_reg_results) %>% mutate(model = "Logistic Regression"),
  collect_metrics(rf_results) %>% mutate(model = "Random Forest"),
  collect_metrics(boost_results) %>% mutate(model = "Boosted Trees"),
  collect_metrics(bagging_results) %>% mutate(model = "Bagging"),
  collect_metrics(knn_results) %>% mutate(model = "KNN")
)

# Plot accuracy across models
all_metrics %>%
  filter(.metric == "accuracy") %>%
  ggplot(aes(x = model, y = mean, fill = model)) +
  geom_col() +
  labs(title = "Model Comparison", y = "Accuracy", x = "Model Type") +
  theme_minimal()

From the comparison, Boosted Trees (0.861) and Random Forest (0.860) achieve the highest mean cross-validated accuracy, with Bagging close behind (0.852), while Logistic Regression (0.828) and K-Nearest Neighbors (0.821) trail by three to four percentage points. The gap in accuracy is modest because the majority class alone sets a high baseline; the ROC AUC values above separate the models more sharply.

8.4 Determining and fitting the best model

best_models <- all_metrics %>%
  filter(.metric %in% c("accuracy", "f_meas", "roc_auc")) %>%
  group_by(.metric) %>%
  slice_max(mean, n = 1) %>%
  ungroup()

print(best_models)
## # A tibble: 3 × 7
##   .metric  .estimator  mean     n std_err .config              model        
##   <chr>    <chr>      <dbl> <int>   <dbl> <chr>                <chr>        
## 1 accuracy binary     0.861    10 0.00284 Preprocessor1_Model1 Boosted Trees
## 2 f_meas   binary     0.918    10 0.00172 Preprocessor1_Model1 Boosted Trees
## 3 roc_auc  binary     0.852    10 0.00434 Preprocessor1_Model1 Boosted Trees
# Find the overall best model
best_overall <- best_models %>%
  count(model) %>%
  slice_max(n, n = 1)

best_overall
## # A tibble: 1 × 2
##   model             n
##   <chr>         <int>
## 1 Boosted Trees     3

The Boosted Trees model emerged as the best model overall, ranking first on all three evaluation metrics: accuracy (0.861), F1 score (0.918), and ROC AUC (0.852). Its count of 3 in the summary above reflects that it edged out the other models on every criterion considered.

boost_model <- boost_tree() %>%
  set_engine("xgboost") %>%
  set_mode("classification")

# Create the workflow
boost_wf <- workflow() %>%
  add_recipe(churn_recipe) %>%
  add_model(boost_model)

# Fit the model to the training data
boost_fit <- boost_wf %>%
  fit(data = train_data)

# Make predictions on the test set
boost_predictions <- boost_fit %>%
  predict(new_data = test_data) %>%
  bind_cols(test_data)

# Evaluate performance
boost_metrics <- boost_predictions %>%
  metrics(truth = churn, estimate = .pred_class)

# Display evaluation metrics
boost_metrics
## # A tibble: 2 × 3
##   .metric  .estimator .estimate
##   <chr>    <chr>          <dbl>
## 1 accuracy binary         0.858
## 2 kap      binary         0.474

The Boosted Trees model also performed well on the held-out test set, with an accuracy of 0.858, meaning it classified roughly 85.8% of test customers correctly. It achieved a Kappa statistic (kap) of 0.474, indicating moderate agreement between predicted and actual classes after correcting for chance; because Kappa adjusts for the class imbalance, it is a more conservative gauge than raw accuracy here. These results reinforce the model's strong predictive performance and its ability to capture meaningful patterns in the data.
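
Because metrics() on hard class predictions reports only accuracy and Kappa, a threshold-free measure requires class probabilities. A short sketch of how the test-set ROC AUC could be obtained (assumed; not part of the original report):

# Sketch: test-set ROC AUC from predicted class probabilities
boost_prob <- predict(boost_fit, new_data = test_data, type = "prob") %>%
  bind_cols(test_data %>% select(churn))

# churn level "1" is the event of interest, hence event_level = "second"
roc_auc(boost_prob, truth = churn, .pred_1, event_level = "second")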

9 Conclusion

In conclusion, our project demonstrated a solid approach to predicting customer churn, with five models fitted and evaluated under a shared preprocessing recipe. The Boosted Trees model stood out as the best performer, offering strong predictive power and consistent results across cross-validation folds. There is still room for improvement: additional feature engineering, hyperparameter tuning (the models above largely used default settings), and more advanced methods such as deep learning or stacked ensembles could further raise predictive accuracy. Even so, the overall workflow of data preprocessing, model selection, and cross-validation was executed carefully, yielding reliable insights and a comprehensive understanding of the churn problem. Moving forward, refining the models and exploring new features could improve accuracy and robustness even further.