Set up hyperparameters for TabNet training.

Usage

setup_TabNet(
  batch_size = 1024^2,
  penalty = 0.001,
  clip_value = NULL,
  loss = "auto",
  epochs = 50L,
  drop_last = FALSE,
  decision_width = NULL,
  attention_width = NULL,
  num_steps = 3L,
  feature_reusage = 1.3,
  mask_type = "sparsemax",
  virtual_batch_size = 256^2,
  valid_split = 0,
  learn_rate = 0.02,
  optimizer = "adam",
  lr_scheduler = NULL,
  lr_decay = 0.1,
  step_size = 30,
  checkpoint_epochs = 10L,
  cat_emb_dim = 1L,
  num_independent = 2L,
  num_shared = 2L,
  num_independent_decoder = 1L,
  num_shared_decoder = 1L,
  momentum = 0.02,
  pretraining_ratio = 0.5,
  device = "auto",
  importance_sample_size = NULL,
  early_stopping_monitor = "auto",
  early_stopping_tolerance = 0,
  early_stopping_patience = 0,
  num_workers = 0L,
  skip_importance = FALSE,
  ifw = FALSE
)
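
batch_size and virtual_batch_size control training granularity: virtual_batch_size is the chunk size used for ghost batch normalization, so keeping it a divisor of batch_size is the usual convention (an assumption based on how ghost batch normalization is typically applied, not something setup_TabNet enforces). For smaller datasets, both defaults can be reduced, e.g.:

tabnet_small <- setup_TabNet(
  batch_size = 256L,        # the 1024^2 default is sized for large data
  virtual_batch_size = 64L  # divides batch_size evenly (assumed convention)
)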

Arguments

batch_size

(Tunable) Positive integer: Batch size.

penalty

(Tunable) Numeric: Regularization penalty.

clip_value

Numeric: Gradient clipping threshold; NULL disables clipping.

loss

Character: Loss function; "auto" selects a loss appropriate to the outcome type.

epochs

(Tunable) Positive integer: Number of epochs.

drop_last

Logical: If TRUE, drop the last incomplete batch during training.

decision_width

(Tunable) Positive integer: Width of the decision prediction layer.

attention_width

(Tunable) Positive integer: Width of the attention embedding for each mask.

num_steps

(Tunable) Positive integer: Number of decision steps in the network.

feature_reusage

(Tunable) Numeric: Coefficient for feature reuse in the masks; values closer to 1 make feature selection less correlated across steps.

mask_type

Character: Mask type: "sparsemax" or "entmax".

virtual_batch_size

(Tunable) Positive integer: Virtual batch size.

valid_split

Numeric: Fraction of the data to use for validation; 0 disables the validation set.

learn_rate

(Tunable) Numeric: Learning rate.

optimizer

Character or torch function: Optimizer; "adam" or a torch optimizer function.

lr_scheduler

Character or torch function: Learning rate scheduler; one of "step" or "reduce_on_plateau".

lr_decay

Numeric: Multiplicative factor by which the scheduler decays the learning rate.

step_size

Positive integer: Number of epochs between learning rate decays for the step scheduler.

checkpoint_epochs

(Tunable) Positive integer: Save a model checkpoint every this many epochs.

cat_emb_dim

(Tunable) Positive integer: Categorical embedding dimension.

num_independent

(Tunable) Positive integer: Number of independent Gated Linear Units (GLU) at each step of the encoder.

num_shared

(Tunable) Positive integer: Number of shared Gated Linear Units (GLU) at each step of the encoder.

num_independent_decoder

(Tunable) Positive integer: Number of independent GLU layers for pretraining.

num_shared_decoder

(Tunable) Positive integer: Number of shared GLU layers for pretraining.

momentum

(Tunable) Numeric: Momentum for batch normalization.

pretraining_ratio

(Tunable) Numeric: Fraction of input features to mask during pretraining.

device

Character: Device to train on: "auto", "cpu", or "cuda".

importance_sample_size

Positive integer: Number of rows sampled when estimating variable importance.

early_stopping_monitor

Character: Metric to monitor for early stopping: "valid_loss", "train_loss", or "auto".

early_stopping_tolerance

Numeric: Minimum relative improvement to reset the patience counter.

early_stopping_patience

Positive integer: Number of epochs without improvement before stopping.

num_workers

Positive integer: Number of subprocesses for data loading.

skip_importance

Logical: If TRUE, skip importance calculation.

ifw

Logical: If TRUE, use Inverse Frequency Weighting in classification.
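
Early stopping on the validation loss requires a nonzero valid_split, since there is no validation loss to monitor otherwise. A minimal sketch combining the relevant arguments (values are illustrative, not recommendations):

tabnet_es <- setup_TabNet(
  valid_split = 0.2,                      # hold out 20% for validation
  early_stopping_monitor = "valid_loss",  # monitor validation loss
  early_stopping_patience = 10L,          # stop after 10 epochs without improvement
  early_stopping_tolerance = 0.001        # minimum relative improvement
)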

Value

TabNetHyperparameters object.

Author

EDG

Examples

tabnet_hyperparams <- setup_TabNet(epochs = 100L, learn_rate = 0.01)
tabnet_hyperparams
#> <TabNetHyperparameters>
#>         hyperparameters: 
#>                                        batch_size: <nmr> 1e+06
#>                                           penalty: <nmr> 1e-03
#>                                        clip_value: <NUL> NULL
#>                                              loss: <chr> auto
#>                                            epochs: <int> 100
#>                                         drop_last: <lgc> FALSE
#>                                    decision_width: <NUL> NULL
#>                                   attention_width: <NUL> NULL
#>                                         num_steps: <int> 3
#>                                   feature_reusage: <nmr> 1.30
#>                                         mask_type: <chr> sparsemax
#>                                virtual_batch_size: <nmr> 65536.00
#>                                       valid_split: <nmr> 0.00
#>                                        learn_rate: <nmr> 0.01
#>                                         optimizer: <chr> adam
#>                                      lr_scheduler: <NUL> NULL
#>                                          lr_decay: <nmr> 0.10
#>                                         step_size: <nmr> 30.00
#>                                 checkpoint_epochs: <int> 10
#>                                       cat_emb_dim: <int> 1
#>                                   num_independent: <int> 2
#>                                        num_shared: <int> 2
#>                           num_independent_decoder: <int> 1
#>                                num_shared_decoder: <int> 1
#>                                          momentum: <nmr> 0.02
#>                                 pretraining_ratio: <nmr> 0.50
#>                                            device: <chr> auto
#>                            importance_sample_size: <NUL> NULL
#>                            early_stopping_monitor: <chr> auto
#>                          early_stopping_tolerance: <nmr> 0.00
#>                           early_stopping_patience: <nmr> 0.00
#>                                       num_workers: <int> 0
#>                                   skip_importance: <lgc> FALSE
#>                                               ifw: <lgc> FALSE
#> tunable_hyperparameters: <chr> batch_size, penalty, clip_value, loss, epochs, drop_last, decision_width, attention_width, num_steps, feature_reusage, mask_type, virtual_batch_size, valid_split, learn_rate, optimizer, lr_scheduler, lr_decay, step_size, checkpoint_epochs, cat_emb_dim, num_independent, num_shared, num_independent_decoder, num_shared_decoder, momentum, pretraining_ratio, importance_sample_size, early_stopping_monitor, early_stopping_tolerance, early_stopping_patience, num_workers, skip_importance, ifw
#>   fixed_hyperparameters: <chr> device, num_workers, skip_importance
#>                   tuned: <int> -1
#>               resampled: <int> 0
#>               n_workers: <int> 1
#> 
#>   No search values defined for tunable hyperparameters.
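
A step learning rate schedule can be configured with lr_scheduler, lr_decay, and step_size; a sketch with illustrative values:

tabnet_sched <- setup_TabNet(
  learn_rate = 0.02,
  lr_scheduler = "step",  # decay the learning rate on a fixed schedule
  lr_decay = 0.5,         # multiply the learning rate by 0.5 ...
  step_size = 20,         # ... every 20 epochs
  device = "cuda"         # train on GPU if available
)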