Examples

Data

Heart Failure Clinical Records

Heart Failure Clinical Records. UCI Machine Learning Repository. https://archive.ics.uci.edu/dataset/519/heart+failure+clinical+records

Machine learning can predict survival of patients with heart failure from serum creatinine and ejection fraction alone. Davide Chicco, Giuseppe Jurman. BMC Medical Informatics and Decision Making. 2020. https://doi.org/10.1186/s12911-020-1023-5.

Download CSV: heart_failure.csv

Examples

This section contains the code for each example. It may differ slightly from the code used in the live demonstration.

Example 1: Pre-processing

See example
# Load R packages ---------------------------------------------------------

# tidyverse provides read_csv() plus the wrangling/plotting verbs; tidymodels
# provides rsample, recipes, parsnip, tune, workflows and yardstick.
# The later examples also need the glmnet, ranger and kernlab engines and the
# vip package installed (they are used via set_engine() and `vip::`).
library(tidyverse)
library(tidymodels)
# Prefer tidymodels' versions of functions whose names clash with others
tidymodels_prefer()


# Load data ---------------------------------------------------------------

# Path is relative to the working directory: data/heart_failure.csv
heart_failure <- read_csv(file.path("data", "heart_failure.csv"))
print(heart_failure)
# A tibble: 299 × 12
     age sex   smoking anaemia diabetes high_blood_pressure serum_creatinine
   <dbl> <chr>   <dbl>   <dbl>    <dbl>               <dbl>            <dbl>
 1    75 M           0       0        0                   1              1.9
 2    55 M           0       0        0                   0              1.1
 3    65 M           1       0        0                   0              1.3
 4    50 M           0       1        0                   0              1.9
 5    65 F           0       1        1                   0              2.7
 6    90 M           1       1        0                   1              2.1
 7    75 M           0       1        0                   0              1.2
 8    60 M           1       1        1                   0              1.1
 9    65 F           0       0        0                   0              1.5
10    80 M           1       1        0                   1              9.4
# ℹ 289 more rows
# ℹ 5 more variables: creatinine_phosphokinase <dbl>, platelets <dbl>,
#   ejection_fraction <dbl>, time <dbl>, death <dbl>
# `death` is read as a 0/1 number; classification models need a factor
# outcome, so convert it (levels "0" and "1").
heart_failure <- mutate(heart_failure, death = factor(death))

# You can also use `View()`!

# Inspect variables -------------------------------------------------------

# Bar charts of counts: outcome balance first, then sex
barplot(table(heart_failure$death))

barplot(table(heart_failure$sex))

# Distribution of age
hist(heart_failure$age)

# Split into training and testing -----------------------------------------

# initial_split() keeps 3/4 of the rows for training by default, i.e.
#   hf_split <- initial_split(heart_failure)
# Here a different split proportion is chosen: 80% training, 20% testing.
# Setting the seed right before splitting makes the split, and the CV folds
# drawn from the random stream after it, reproducible.
# NOTE(review): the outcome is imbalanced; consider `strata = death` so both
# sets keep the same death rate (this would change the results shown below).
set.seed(1234)
hf_split <- initial_split(heart_failure, prop = 0.8)
hf_train <- training(hf_split)
hf_test <- testing(hf_split)

# Create 10-fold cross-validation folds from the training set only
hf_folds <- vfold_cv(hf_train, v = 10)

# Build a recipe ----------------------------------------------------------

# Preprocessing shared by every model: dummy-encode `sex`, then centre and
# scale the continuous measurements. The 0/1 indicator columns (smoking,
# anaemia, diabetes, high_blood_pressure) are left untouched.
# The columns are named explicitly rather than via the positional range
# `serum_creatinine:time`, so the recipe cannot silently change if the CSV's
# column order changes.
# NOTE(review): `time` is described as the follow-up period in the dataset
# documentation; using it to predict death may leak outcome information.
# Confirm it is meant to be a predictor (otherwise drop it with step_rm(time)).
hf_recipe <- recipe(death ~ ., data = hf_train) |>
  step_dummy(sex) |>
  step_normalize(
    age, serum_creatinine, creatinine_phosphokinase, platelets,
    ejection_fraction, time
  )

# Workflow holding the recipe; each example below adds its own model to it
wf <- workflow() |>
  add_recipe(hf_recipe)

Example 2: Lasso regression

See example
# Specify the model -------------------------------------------------------

# Lasso: logistic regression with a pure L1 penalty (mixture = 1); the
# penalty strength is a tune() placeholder filled in below.
tune_spec_lasso <- logistic_reg(penalty = tune(), mixture = 1) |>
  set_engine("glmnet")

# Shared preprocessing workflow plus the lasso model
lasso_wf <- add_model(wf, tune_spec_lasso)


# Tune the model ----------------------------------------------------------

# Fit lots of values: 50 penalties evenly spaced over dials' default
# (log10-transformed) penalty range, each evaluated on every CV fold
lasso_penalties <- grid_regular(penalty(), levels = 50)

lasso_grid <- tune_grid(
  lasso_wf,
  resamples = hf_folds,
  grid = lasso_penalties
)

# Choose the best value: the penalty with the highest mean CV ROC AUC
highest_roc_auc_lasso <- select_best(lasso_grid, metric = "roc_auc")


# Fit the final model -----------------------------------------------------

# Replace the tune() placeholder with the selected penalty
final_lasso <- finalize_workflow(lasso_wf, highest_roc_auc_lasso)


# Model evaluation --------------------------------------------------------

# Fit on the training set and evaluate once on the held-out test set
last_fit(final_lasso, hf_split) |>
  collect_metrics()
# A tibble: 3 × 4
  .metric     .estimator .estimate .config             
  <chr>       <chr>          <dbl> <chr>               
1 accuracy    binary         0.85  Preprocessor1_Model1
2 roc_auc     binary         0.899 Preprocessor1_Model1
3 brier_class binary         0.116 Preprocessor1_Model1
# which variables were most important?
# Refit the finalised workflow on the training data and extract the
# underlying parsnip/glmnet fit
lasso_parsnip_fit <- final_lasso |>
  fit(hf_train) |>
  extract_fit_parsnip()

# Coefficient-based importance at the selected penalty. Reordering the
# factor by importance sorts the bars; `Sign` gives each coefficient's
# direction and is used for the fill colour.
lasso_importance <- lasso_parsnip_fit |>
  vip::vi(lambda = highest_roc_auc_lasso$penalty) |>
  mutate(
    Importance = abs(Importance),
    Variable = fct_reorder(Variable, Importance)
  )

ggplot(lasso_importance, aes(x = Importance, y = Variable, fill = Sign)) +
  geom_col()

Example 3: Random forests

See example
# Specify model -----------------------------------------------------------

# Random forest (ranger engine). Tuned: predictors sampled per split (mtry)
# and minimum node size (min_n). Only 20 trees are grown, presumably to keep
# the demo fast.
tune_spec_rf <- rand_forest(mtry = tune(), trees = 20, min_n = tune()) |>
  set_mode("classification") |>
  set_engine("ranger")

# Tune hyperparameters ----------------------------------------------------

# Regular grid with 5 levels requested per parameter: mtry restricted to
# 5-8, min_n over its default range c(2, 40)
rf_param_grid <- grid_regular(
  mtry(range = c(5, 8)),
  min_n(),
  levels = 5
)

rf_grid <- tune_grid(
  add_model(wf, tune_spec_rf),
  resamples = hf_folds,
  grid = rf_param_grid
)

# Fit model ---------------------------------------------------------------

# Hyperparameters with the highest mean cross-validated ROC AUC
highest_roc_auc_rf <- rf_grid |>
  select_best(metric = "roc_auc")

final_rf <- finalize_workflow(
  add_model(wf, tune_spec_rf),
  highest_roc_auc_rf
)

# Evaluate ----------------------------------------------------------------

# Fit once and keep the result. A random forest depends on the random-number
# stream, so calling last_fit() a second time could grow a different forest
# whose confusion matrix would not match the metrics printed here.
rf_last_fit <- last_fit(final_rf, hf_split)

rf_last_fit |>
  collect_metrics()
# A tibble: 3 × 4
  .metric     .estimator .estimate .config             
  <chr>       <chr>          <dbl> <chr>               
1 accuracy    binary         0.817 Preprocessor1_Model1
2 roc_auc     binary         0.855 Preprocessor1_Model1
3 brier_class binary         0.130 Preprocessor1_Model1
# create a confusion matrix from the same test-set predictions
rf_last_fit |>
  collect_predictions() |>
  conf_mat(death, .pred_class) |>
  autoplot()

Example 4: Support vector machines

See example
# Specify model -----------------------------------------------------------

# Radial-basis-function SVM (kernlab engine); only the cost is tuned.
# NOTE(review): rbf_sigma is not set, so presumably kernlab chooses the
# kernel width itself. Confirm this is intended.
tune_spec_svm <- svm_rbf(cost = tune()) |>
  set_engine("kernlab") |>
  set_mode("classification")


# Tune hyperparameters ----------------------------------------------------

# Fit lots of values: 20 cost values spread over dials' default cost range
svm_cost_grid <- grid_regular(cost(), levels = 20)

svm_grid <- tune_grid(
  add_model(wf, tune_spec_svm),
  resamples = hf_folds,
  grid = svm_cost_grid
)

# Fit model ---------------------------------------------------------------

# Cost value with the highest mean cross-validated ROC AUC
highest_roc_auc_svm <- svm_grid |>
  select_best(metric = "roc_auc")

final_svm <- finalize_workflow(
  add_model(wf, tune_spec_svm),
  highest_roc_auc_svm
)


# Evaluate ----------------------------------------------------------------

# Fit once and reuse the result for both the metrics and the confusion
# matrix. Calling last_fit() again would train a second model, which wastes
# time. kernlab's fitting may also involve random sampling, so that second
# model's predictions need not match the metrics printed below.
# NOTE: f_meas uses yardstick's default event level, the FIRST factor level
# of `death` ("0", survived), so it scores survival rather than death.
svm_last_fit <- last_fit(
  final_svm, hf_split,
  metrics = metric_set(roc_auc, accuracy, f_meas)
)

svm_last_fit |>
  collect_metrics()
# A tibble: 3 × 4
  .metric  .estimator .estimate .config             
  <chr>    <chr>          <dbl> <chr>               
1 accuracy binary         0.717 Preprocessor1_Model1
2 f_meas   binary         0.835 Preprocessor1_Model1
3 roc_auc  binary         0.858 Preprocessor1_Model1
# create a confusion matrix from the same test-set predictions
svm_last_fit |>
  collect_predictions() |>
  conf_mat(death, .pred_class) |>
  autoplot()