Set up hyperparameters for TabNet training.

Usage

setup_TabNet(
  batch_size = 1024^2,
  penalty = 0.001,
  clip_value = NULL,
  loss = "auto",
  epochs = 50L,
  drop_last = FALSE,
  decision_width = NULL,
  attention_width = NULL,
  num_steps = 3L,
  feature_reusage = 1.3,
  mask_type = "sparsemax",
  virtual_batch_size = 256^2,
  valid_split = 0,
  learn_rate = 0.02,
  optimizer = "adam",
  lr_scheduler = NULL,
  lr_decay = 0.1,
  step_size = 30,
  checkpoint_epochs = 10L,
  cat_emb_dim = 1L,
  num_independent = 2L,
  num_shared = 2L,
  num_independent_decoder = 1L,
  num_shared_decoder = 1L,
  momentum = 0.02,
  pretraining_ratio = 0.5,
  device = "auto",
  importance_sample_size = NULL,
  early_stopping_monitor = "auto",
  early_stopping_tolerance = 0,
  early_stopping_patience = 0,
  num_workers = 0L,
  skip_importance = FALSE,
  ifw = FALSE
)
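
For example, a call overriding a few of these defaults might look like the following sketch (the argument names are from the signature above; the chosen values are illustrative, and the downstream training function that consumes the returned object is not shown):

```r
# Configure TabNet hyperparameters, overriding a few defaults.
hyperparams <- setup_TabNet(
  batch_size = 512L,     # smaller batches for a modest dataset
  epochs = 100L,         # train longer than the default 50
  learn_rate = 0.01,
  mask_type = "entmax",  # alternative to the default "sparsemax"
  device = "cpu"
)
```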

Arguments

batch_size

(Tunable) Positive integer: Batch size.

penalty

(Tunable) Numeric: Regularization penalty.

clip_value

Numeric: Value at which to clip gradients during training. If NULL, gradients are not clipped.

loss

Character: Loss function. "auto" selects a loss appropriate to the outcome type.

epochs

(Tunable) Positive integer: Number of epochs.

drop_last

Logical: If TRUE, drop the last batch when it is incomplete.

decision_width

(Tunable) Positive integer: Width of the decision prediction layer.

attention_width

(Tunable) Positive integer: Width of the attention embedding for each mask.

num_steps

(Tunable) Positive integer: Number of decision steps in the architecture.

feature_reusage

(Tunable) Numeric: Coefficient for feature reusage in the masks; values close to 1 make feature selection least correlated across decision steps.

mask_type

Character: Mask type: "sparsemax" or "entmax".

virtual_batch_size

(Tunable) Positive integer: Size of the mini-batches used for ghost batch normalization.

valid_split

Numeric: Fraction of the training data to hold out for validation; 0 disables internal validation.

learn_rate

(Tunable) Numeric: Learning rate.

optimizer

Character or torch function: Optimizer.

lr_scheduler

Character or torch function: Learning rate scheduler: "step" or "reduce_on_plateau".

lr_decay

Numeric: Multiplicative factor applied to the learning rate at each scheduler step.

step_size

Positive integer: Number of epochs between learning rate decay steps.
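
For instance, to decay the learning rate tenfold every 30 epochs using the "step" scheduler, one might write (values are the documented defaults for `lr_decay` and `step_size`):

```r
# Step scheduler: multiply the learning rate by lr_decay every step_size epochs.
hyperparams <- setup_TabNet(
  lr_scheduler = "step",
  lr_decay = 0.1,
  step_size = 30
)
```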

checkpoint_epochs

(Tunable) Positive integer: Save a model checkpoint every this many epochs.

cat_emb_dim

(Tunable) Positive integer: Embedding dimension for categorical features.

num_independent

(Tunable) Positive integer: Number of independent Gated Linear Units (GLU) at each step of the encoder.

num_shared

(Tunable) Positive integer: Number of shared Gated Linear Units (GLU) at each step of the encoder.

num_independent_decoder

(Tunable) Positive integer: Number of independent GLU layers for pretraining.

num_shared_decoder

(Tunable) Positive integer: Number of shared GLU layers for pretraining.

momentum

(Tunable) Numeric: Momentum for batch normalization.

pretraining_ratio

(Tunable) Numeric: Proportion of input features to mask for reconstruction during pretraining.

device

Character: Device to train on: "auto", "cpu", or "cuda".

importance_sample_size

Positive integer: Number of rows to sample when computing variable importance.

early_stopping_monitor

Character: Metric to monitor for early stopping: "valid_loss", "train_loss", or "auto".

early_stopping_tolerance

Numeric: Minimum relative improvement to reset the patience counter.

early_stopping_patience

Positive integer: Number of epochs without improvement before stopping.

num_workers

Positive integer: Number of subprocesses for data loading.

skip_importance

Logical: If TRUE, skip variable importance calculation.

ifw

Logical: If TRUE, use Inverse Frequency Weighting in classification.

Value

TabNetHyperparameters object.

Author

EDG