library(rtemis)
.:rtemis 0.99.94 🌊 aarch64-apple-darwin20
library(data.table)
For this example, we shall use the BreastCancer dataset from the mlbench package:
data(BreastCancer, package = "mlbench")
In rtemis, the last column is the outcome variable. We optionally convert the dataset to a data.table; train() supports data.frame, data.table, or tibble inputs:
dat <- as.data.table(BreastCancer)
dat
Id Cl.thickness Cell.size Cell.shape Marg.adhesion Epith.c.size
<char> <ord> <ord> <ord> <ord> <ord>
1: 1000025 5 1 1 1 2
2: 1002945 5 4 4 5 7
3: 1015425 3 1 1 1 2
4: 1016277 6 8 8 1 3
5: 1017023 4 1 1 3 2
---
695: 776715 3 1 1 1 3
696: 841769 2 1 1 1 2
697: 888820 5 10 10 3 7
698: 897471 4 8 6 4 3
699: 897471 4 8 8 5 4
Bare.nuclei Bl.cromatin Normal.nucleoli Mitoses Class
<fctr> <fctr> <fctr> <fctr> <fctr>
1: 1 3 1 1 benign
2: 10 3 2 1 benign
3: 2 3 1 1 benign
4: 4 3 7 1 benign
5: 1 3 1 1 benign
---
695: 2 1 1 1 benign
696: 1 1 1 1 benign
697: 3 8 10 2 malignant
698: 4 10 6 1 malignant
699: 5 10 4 1 malignant
Also optionally, we clean the dataset, in this case replacing periods with underscores in column names:
dt_set_clean_all(dat)
dat
dt_* functions operate on data.table objects. dt_set_* functions modify their input in-place.
Class is already the last column; otherwise we could use set_outcome() to move it, as sketched below.
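A hypothetical call, assuming set_outcome() takes the data and the outcome column name (check its help page for the actual signature):
# Assumed signature: data first, outcome column name second.
dat <- set_outcome(dat, "Class")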
For classification, the outcome variable must be a factor. For binary classification, the second factor level is considered the positive case.
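For example, to ensure malignant is the second level, and therefore the positive class, the levels can be set explicitly:
# Reorder factor levels so "malignant" is second, i.e. the positive class.
dat[, Class := factor(Class, levels = c("benign", "malignant"))]
levels(dat$Class)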
check_data(dat)
dat: A data.table with 699 rows and 11 columns.
Data types
* 0 numeric features
* 0 integer features
* 10 factors, of which 5 are ordered
* 1 character feature
* 0 date features
Issues
* 0 constant features
* 8 duplicate cases
* 1 feature includes 'NA' values; 16 'NA' values total
* 1 factor
Recommendations
* Consider converting character features to factors or excluding them.
* Consider removing the duplicate cases.
* Consider imputing missing values or using algorithms that can handle missingness.
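We keep the full dataset here, but the recommendations could be addressed with standard data.table operations (not run in this example):
# Drop the character Id column in-place:
dat[, Id := NULL]
# Remove the 8 duplicate cases:
dat <- unique(dat)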
Next, we use resample() to split the data into training and test sets, here with a single stratified subsample:
res <- resample(dat, setup_Resampler(1L, "StratSub"))
res
StratSub Resampler
resamples:
Subsample_1: <int> 1, 2, 4, 5...
parameters:
StratSub ResamplerParameters
n: <int> 1
train_p: <nmr> 0.75
stratify_var: <NUL> NULL
strat_n_bins: <int> 2
id_strat: <NUL> NULL
seed: <NUL> NULL
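Note that seed is NULL above. For a reproducible split, a seed can presumably be passed to setup_Resampler(); this is an assumption based on the parameters printed above, not a documented call:
# Assumed parameter, inferred from the "seed" field in the output above:
res <- resample(dat, setup_Resampler(1L, "StratSub", seed = 2024L))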
dat_training <- dat[res$Subsample_1, ]
dat_test <- dat[-res$Subsample_1, ]
size(dat_training)
523 x 11
size(dat_test)
176 x 11
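Stratified subsampling aims to preserve the outcome distribution across the split, which we can check:
# Class proportions should be similar in the training and test sets:
prop.table(table(dat_training$Class))
prop.table(table(dat_test$Class))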
As an example, we use LightRF to train a random forest model:
mod_lightrf <- train(
  dat_training,
  dat_test = dat_test,
  algorithm = "LightRF"
)
Input data summary:
Training set: 523 cases x 10 features.
Test set: 176 cases x 10 features.
.:Classification Model
LightRF (LightGBM Random Forest)
Training Classification Metrics
           Predicted
Reference  malignant  benign
malignant        164      16
benign            10     333
Overall
Sensitivity 0.911
Specificity 0.971
Balanced_Accuracy 0.941
PPV 0.943
NPV 0.954
F1 0.927
Accuracy 0.950
AUC 0.987
Brier_Score 0.063
Positive Class malignant
Test Classification Metrics
           Predicted
Reference  malignant  benign
malignant         50      11
benign             3     112
Overall
Sensitivity 0.820
Specificity 0.974
Balanced_Accuracy 0.897
PPV 0.943
NPV 0.911
F1 0.877
Accuracy 0.920
AUC 0.986
Brier_Score 0.076
Positive Class malignant
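To get predictions on new data, the standard predict() generic should work on the trained model; this is an assumption about the rtemis API, so check the package documentation:
# Assumed API: predict() on an rtemis model with new data.
predicted <- predict(mod_lightrf, dat_test)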
describe(mod_lightrf)
LightGBM Random Forest was used for classification.
Balanced accuracy was 0.94 on the training set and 0.90 in the test set.
plot(mod_lightrf)
plot_roc(mod_lightrf)
present() combines describe() with plot() or plot_roc() (the default):
present(mod_lightrf)
LightGBM Random Forest was used for classification.
Balanced accuracy was 0.94 on the training set and 0.90 in the test set.
type defaults to "ROC", but can be set to "confusion" to show the training and test confusion matrices side by side:
present(mod_lightrf, type = "confusion")
LightGBM Random Forest was used for classification.
Balanced accuracy was 0.94 on the training set and 0.90 in the test set.
plot_varimp(mod_lightrf)
To train on multiple resamples, we use the outer_resampling argument:
resmod_lightrf <- train(
  dat_training,
  algorithm = "LightRF",
  outer_resampling = setup_Resampler(n_resamples = 10L, type = "KFold")
)
Input data summary:
Training set: 523 cases x 10 features.
.:Resampled Classification Model
LightRF (LightGBM Random Forest)
⟳ Tested using 10-fold crossvalidation.
Resampled Classification Training Metrics
Showing mean (sd) across resamples.
Sensitivity: 0.914 (0.008)
Specificity: 0.971 (4e-03)
Balanced_Accuracy: 0.943 (0.006)
PPV: 0.943 (0.008)
NPV: 0.956 (4.2e-03)
F1: 0.928 (0.007)
Accuracy: 0.951 (4.8e-03)
AUC: 0.987 (1.9e-03)
Brier_Score: 0.065 (1.8e-03)
Resampled Classification Test Metrics
Showing mean (sd) across resamples.
Sensitivity: 0.911 (0.054)
Specificity: 0.971 (0.036)
Balanced_Accuracy: 0.941 (0.031)
PPV: 0.947 (0.061)
NPV: 0.955 (0.026)
F1: 0.927 (0.041)
Accuracy: 0.950 (0.028)
AUC: 0.984 (0.019)
Brier_Score: 0.068 (0.012)
Now, train() produced a ClassificationRes object:
class(resmod_lightrf)
[1] "rtemis::ClassificationRes" "rtemis::SupervisedRes"
[3] "S7_object"
describe(resmod_lightrf)
LightGBM Random Forest was used for classification. Mean Balanced accuracy was 0.94 in the training set and 0.94 in the test set across 10 independent folds.
The plot() method for ClassificationRes objects draws boxplots of the training and test set metrics:
plot(resmod_lightrf)
The present() method for ClassificationRes objects combines the describe() and plot() methods:
present(resmod_lightrf)
LightGBM Random Forest was used for classification. Mean Balanced accuracy was 0.94 in the training set and 0.94 in the test set across 10 independent folds.
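Finally, trained and resampled models are regular R objects and can be saved to disk for later use:
# Save the resampled model and reload it in a future session:
saveRDS(resmod_lightrf, "resmod_lightrf.rds")
resmod_lightrf <- readRDS("resmod_lightrf.rds")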