11  Classification

11.1 Setup

11.1.1 Packages

library(rtemis)
  .:rtemis 0.99.94 🌊 aarch64-apple-darwin20
library(data.table)

11.1.2 Data

For this example, we shall use the BreastCancer dataset from the mlbench package:

data(BreastCancer, package = "mlbench")
Important

In rtemis, the last column is the outcome variable.

We optionally convert the dataset to a data.table:

Note

train() supports data.frame, data.table, or tibble inputs.

dat <- as.data.table(BreastCancer)
dat
          Id Cl.thickness Cell.size Cell.shape Marg.adhesion Epith.c.size
      <char>        <ord>     <ord>      <ord>         <ord>        <ord>
  1: 1000025            5         1          1             1            2
  2: 1002945            5         4          4             5            7
  3: 1015425            3         1          1             1            2
  4: 1016277            6         8          8             1            3
  5: 1017023            4         1          1             3            2
 ---                                                                     
695:  776715            3         1          1             1            3
696:  841769            2         1          1             1            2
697:  888820            5        10         10             3            7
698:  897471            4         8          6             4            3
699:  897471            4         8          8             5            4
     Bare.nuclei Bl.cromatin Normal.nucleoli Mitoses     Class
          <fctr>      <fctr>          <fctr>  <fctr>    <fctr>
  1:           1           3               1       1    benign
  2:          10           3               2       1    benign
  3:           2           3               1       1    benign
  4:           4           3               7       1    benign
  5:           1           3               1       1    benign
 ---                                                          
695:           2           1               1       1    benign
696:           1           1               1       1    benign
697:           3           8              10       2 malignant
698:           4          10               6       1 malignant
699:           5          10               4       1 malignant

Optionally, we can also clean the dataset; here, we replace periods with underscores in column names:

dt_set_clean_all(dat)
dat
Tip

dt_* functions operate on data.table objects. dt_set_* functions modify their input in-place.

Class is already the last column; otherwise, we could use set_outcome() to move it.
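
For illustration, an outcome column can also be moved to the last position with base R column indexing (toy data frame below; set_outcome() is the rtemis way):

```r
# Toy data frame whose outcome ("Class") is not the last column.
df <- data.frame(
  Class = factor(c("benign", "malignant")),
  Cl_thickness = c(5, 3)
)
# Reorder so the outcome comes last, as rtemis expects.
df <- df[, c(setdiff(names(df), "Class"), "Class")]
names(df)  # "Cl_thickness" "Class"
```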

Important

For classification, the outcome variable must be a factor. For binary classification, the second factor level is considered the positive case.
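
To illustrate the positive-class convention with base R: BreastCancer's Class has levels c("benign", "malignant"), so "malignant", the second level, is the positive case.

```r
y <- factor(c("benign", "malignant", "benign"))
levels(y)  # "benign" "malignant" -- "malignant" is the positive case
# To designate a different level as positive, reorder the levels explicitly:
y_flipped <- factor(y, levels = c("malignant", "benign"))
levels(y_flipped)  # "malignant" "benign" -- now "benign" is the positive case
```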

11.2 Check data

check_data(dat)
  dat: A data.table with 699 rows and 11 columns.

  Data types
  * 0 numeric features
  * 0 integer features
  * 10 factors, of which 5 are ordered
  * 1 character feature
  * 0 date features

  Issues
  * 0 constant features
  * 8 duplicate cases
  * 1 feature includes 'NA' values; 16 'NA' values total
    * 1 factor

  Recommendations
  * Consider converting character features to factors or excluding them.
  * Consider removing the duplicate cases.
  * Consider imputing missing values or using algorithms that can handle missingness. 
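
If we wanted to act on these recommendations, a minimal base R sketch on a toy data frame might look like this (illustrative only; the tutorial below proceeds with dat as-is):

```r
df <- data.frame(
  Id   = c("1000025", "1000025", "1002945"),  # character feature
  Size = c(1, 1, 4),
  stringsAsFactors = FALSE
)
df$Id <- factor(df$Id)  # convert the character feature to a factor
df <- unique(df)        # drop duplicate cases
nrow(df)  # 2
```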

11.3 Train a single model

11.3.1 Resample

res <- resample(dat, setup_Resampler(1L, "StratSub"))
2025-06-06 17:43:19 Input contains more than one column; will stratify on last. [resample]
2025-06-06 17:43:19 Using max n bins possible = 2 [strat_sub]
2025-06-06 17:43:19 Updated strat_n_bins from 4 to 2 in ResamplerParameters object. [resample]
res
StratSub Resampler
   resamples: 
              Subsample_1: <int> 1, 2, 4, 5...
  parameters:  
              StratSub ResamplerParameters
                           n: <int> 1
                     train_p: <nmr> 0.75
                stratify_var: <NUL> NULL
                strat_n_bins: <int> 2
                    id_strat: <NUL> NULL
                        seed: <NUL> NULL
dat_training <- dat[res$Subsample_1, ]
dat_test <- dat[-res$Subsample_1, ]
size(dat_training)
523 x 11 
size(dat_test)
176 x 11 
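
Stratified subsampling aims to preserve the outcome's class proportions in both splits. A toy base R sketch of the idea (resample() handles this for us; the 0.75 below mirrors train_p):

```r
set.seed(1)
y <- factor(rep(c("benign", "malignant"), times = c(60, 40)))
# Sample 75% of the indices within each class, then combine.
idx <- unlist(lapply(
  split(seq_along(y), y),
  function(i) sample(i, floor(0.75 * length(i)))
))
prop.table(table(y[idx]))  # benign 0.6, malignant 0.4 -- same as the full data
```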

11.3.2 Train model

As an example, we use the LightRF algorithm to train a random forest model:

mod_lightrf <- train(
  dat_training,
  dat_test = dat_test,
  algorithm = "LightRF"
)
2025-06-06 17:43:19 👽Hello. [train]
  Input data summary:
  Training set: 523 cases x 10 features.
   Test set: 176 cases x 10 features.
2025-06-06 17:43:19 Training LightRF Classification... [train]
2025-06-06 17:43:19 Checking data is ready for training... [check_supervised_data]

.:Classification Model
  LightRF (LightGBM Random Forest)

  Training Classification Metrics

                   Predicted 
        Reference  malignant  benign  
        malignant        164      16
           benign         10     333

                   Overall  
      Sensitivity  0.911  
      Specificity  0.971  
Balanced_Accuracy  0.941  
              PPV  0.943  
              NPV  0.954  
               F1  0.927  
         Accuracy  0.950  
              AUC  0.987  
      Brier_Score  0.063  

   Positive Class  malignant 

  Test Classification Metrics

                   Predicted 
        Reference  malignant  benign  
        malignant         50      11
           benign          3     112

                   Overall  
      Sensitivity  0.820  
      Specificity  0.974  
Balanced_Accuracy  0.897  
              PPV  0.943  
              NPV  0.911  
               F1  0.877  
         Accuracy  0.920  
              AUC  0.986  
      Brier_Score  0.076  

   Positive Class  malignant 
2025-06-06 17:43:20 Done in 0.01 minutes (Real: 0.77; User: 0.70; System: 0.07). [train]
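
The test metrics above follow directly from the confusion matrix. Recomputing them by hand in base R (rows are Reference, columns are Predicted; the positive class is malignant):

```r
TP <- 50   # malignant predicted malignant
FN <- 11   # malignant predicted benign
FP <- 3    # benign predicted malignant
TN <- 112  # benign predicted benign

sensitivity <- TP / (TP + FN)                         # 0.820
specificity <- TN / (TN + FP)                         # 0.974
balanced_accuracy <- (sensitivity + specificity) / 2  # 0.897
ppv <- TP / (TP + FP)                                 # 0.943
npv <- TN / (TN + FN)                                 # 0.911
f1  <- 2 * ppv * sensitivity / (ppv + sensitivity)    # 0.877
accuracy <- (TP + TN) / (TP + FN + FP + TN)           # 0.920
```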

11.3.3 Describe model

describe(mod_lightrf)
LightGBM Random Forest was used for classification.
Balanced accuracy was 0.94 on the training set and 0.90 in the test set.

11.3.4 Plot Confusion Matrix

plot(mod_lightrf)

11.3.5 Plot ROC Curve

plot_roc(mod_lightrf)

11.3.6 Present model

present() combines describe() with either plot() or plot_roc() (the default):

present(mod_lightrf)
LightGBM Random Forest was used for classification.
Balanced accuracy was 0.94 on the training set and 0.90 in the test set.

The type argument defaults to "ROC" but can be set to "confusion" to show the training and test confusion matrices side by side:

present(mod_lightrf, type = "confusion")
LightGBM Random Forest was used for classification.
Balanced accuracy was 0.94 on the training set and 0.90 in the test set.

11.3.7 Plot Variable Importance

plot_varimp(mod_lightrf)

11.4 Train on multiple training/test resamples

To train on multiple resamples, we use the outer_resampling argument:

resmod_lightrf <- train(
  dat_training,
  algorithm = "LightRF",
  outer_resampling = setup_Resampler(n_resamples = 10L, type = "KFold")
)
2025-06-06 17:43:21 👽Hello. [train]
  Input data summary:
  Training set: 523 cases x 10 features.
2025-06-06 17:43:21 Training LightRF Classification by cross-validation... [train]
2025-06-06 17:43:21 Input contains more than one column; will stratify on last. [resample]
2025-06-06 17:43:21 Using max n bins possible = 2. [kfold]
2025-06-06 17:43:25 Crossvalidation done. [train]
.:Resampled Classification Model
  LightRF (LightGBM Random Forest)
   Tested using 10-fold crossvalidation.

  Resampled Classification Training Metrics
  Showing mean (sd) across resamples.

        Sensitivity: 0.914 (0.008)
        Specificity: 0.971 (4e-03)
  Balanced_Accuracy: 0.943 (0.006)
                PPV: 0.943 (0.008)
                NPV: 0.956 (4.2e-03)
                 F1: 0.928 (0.007)
           Accuracy: 0.951 (4.8e-03)
                AUC: 0.987 (1.9e-03)
        Brier_Score: 0.065 (1.8e-03)

  Resampled Classification Test Metrics
  Showing mean (sd) across resamples.

        Sensitivity: 0.911 (0.054)
        Specificity: 0.971 (0.036)
  Balanced_Accuracy: 0.941 (0.031)
                PPV: 0.947 (0.061)
                NPV: 0.955 (0.026)
                 F1: 0.927 (0.041)
           Accuracy: 0.950 (0.028)
                AUC: 0.984 (0.019)
        Brier_Score: 0.068 (0.012)
2025-06-06 17:43:25 Done in 0.07 minutes (Real: 4.31; User: 10.20; System: 1.17). [train]
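
Each fold's held-out cases yield one set of test metrics; the summary above reports their mean (sd). A toy base R sketch of how 10-fold assignment partitions the 523 training cases (resample() handles this internally, with stratification):

```r
set.seed(1)
n <- 523
# Assign each case to one of 10 folds of near-equal size.
folds <- sample(rep(1:10, length.out = n))
table(folds)  # each fold holds 52 or 53 cases
```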

This time, train() produced a ClassificationRes object:

class(resmod_lightrf)
[1] "rtemis::ClassificationRes" "rtemis::SupervisedRes"    
[3] "S7_object"                

11.4.1 Describe

describe(resmod_lightrf)
LightGBM Random Forest was used for classification. Mean Balanced accuracy was 0.94 in the training set and 0.94 in the test set across 10 independent folds. 

11.4.2 Plot

The plot() method for ClassificationRes objects plots boxplots of the training and test set metrics:

plot(resmod_lightrf)

11.4.3 Present

The present() method for ClassificationRes objects combines the describe() and plot() methods:

present(resmod_lightrf)
LightGBM Random Forest was used for classification. Mean Balanced accuracy was 0.94 in the training set and 0.94 in the test set across 10 independent folds. 
© 2025 E.D. Gennatas