Trainer
- class daart.train.Trainer(learning_rate: float = 0.0001, l2_reg: float = 0.0, min_epochs: int = 10, max_epochs: int = 200, val_check_interval: int = 10, rng_seed_train: int = 0, save_last_model: bool = False, callbacks: list = [], **kwargs)
Bases: object
Methods Summary
fit(model, data_generator, save_path) – Fit PyTorch models with stochastic gradient descent and early stopping.
Methods Documentation
- fit(model, data_generator, save_path)
Fit PyTorch models with stochastic gradient descent and early stopping.
Training parameters such as min/max epochs are specified in the class constructor.
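The interaction between min_epochs, max_epochs, and early stopping can be illustrated with a minimal, hypothetical stopping rule. This is not daart's implementation: the function should_stop and its patience argument are assumptions (patience is not a documented Trainer argument). The rule never stops before min_epochs, always stops at max_epochs, and otherwise stops once the last patience validation losses fail to improve on the best earlier loss.

```python
def should_stop(epoch, val_losses, min_epochs=10, max_epochs=200, patience=3):
    """Decide whether to stop training after `epoch` (1-indexed).

    Hypothetical patience-based rule, not daart's actual criterion.
    `val_losses` holds one entry per validation check so far.
    """
    if epoch >= max_epochs:
        return True  # hard cap on training length
    if epoch < min_epochs:
        return False  # always train for at least min_epochs
    if len(val_losses) <= patience:
        return False  # too few validation checks to judge improvement
    best_before = min(val_losses[:-patience])
    # stop if none of the last `patience` checks beat the earlier best
    return min(val_losses[-patience:]) >= best_before
```

For example, with min_epochs=2, patience=3, and one validation check per epoch, a validation loss history of 0.9, 0.7, 0.6, 0.61, 0.62, 0.63 would stop training after epoch 6.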
Training progress is monitored by calculating the model loss on both the training data and the validation data. The training loss is calculated every epoch, and the validation loss is calculated according to the hparams key ‘val_check_interval’. For example, if val_check_interval=5, the validation loss is calculated every 5 epochs. If val_check_interval=0.5, the validation loss is calculated twice per epoch: once after the first half of the batches have been processed, and again after all batches have been processed.
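The two cases of val_check_interval described above can be sketched as a function that lists the points at which validation would run. The helper validation_points is hypothetical and only mirrors the documented behavior: an integer means every N epochs, and a fraction means evenly spaced checks within each epoch. How daart places checks for fractions other than 0.5 is an assumption here.

```python
def validation_points(n_epochs, n_batches, val_check_interval):
    """Return (epoch, batches_done) pairs after which validation would run.

    Hypothetical helper, not part of daart. Epochs are 1-indexed and
    batches_done counts the batches processed so far in that epoch.
    """
    if isinstance(val_check_interval, int):
        if val_check_interval < 1:
            raise ValueError("integer interval must be >= 1")
    elif not 0 < val_check_interval <= 1:
        raise ValueError("fractional interval must be in (0, 1]")

    points = []
    for epoch in range(1, n_epochs + 1):
        if isinstance(val_check_interval, int):
            # integer: validate at the end of every N-th epoch
            if epoch % val_check_interval == 0:
                points.append((epoch, n_batches))
        else:
            # fraction: e.g. 0.5 -> two evenly spaced checks per epoch
            n_checks = min(n_batches, round(1 / val_check_interval))
            for k in range(1, n_checks + 1):
                # integer division keeps the last check at n_batches
                points.append((epoch, (k * n_batches) // n_checks))
    return points
```

With 10 batches per epoch, validation_points(2, 10, 0.5) returns [(1, 5), (1, 10), (2, 5), (2, 10)], matching the "twice per epoch" description.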
Monitored metrics are saved in a CSV file in the model directory. This logging is handled by the Logger class.
- Parameters:
model (Segmenter object) – daart model to train
data_generator (ConcatSessionsGenerator object) – data generator to serve data batches
save_path (str, optional) – absolute path to store model and training results
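The CSV logging mentioned above can be sketched with Python's standard csv module. This is an illustrative assumption, not the Logger class itself: the function save_metrics_csv, the file name metrics.csv, and the column names are hypothetical. Columns are ordered by the first time each metric appears, and metrics missing from a row (for example, a validation loss on epochs where it was not computed) are written as empty cells.

```python
import csv
import os
import tempfile


def save_metrics_csv(rows, model_dir, filename="metrics.csv"):
    """Write a list of metric dicts to model_dir/filename and return the path.

    Hypothetical sketch; daart's Logger may use a different file name,
    columns, or write strategy.
    """
    path = os.path.join(model_dir, filename)
    # preserve the order in which metric names first appear
    fieldnames = list(dict.fromkeys(key for row in rows for key in row))
    with open(path, "w", newline="") as f:
        # DictWriter fills missing keys with restval, which defaults to ""
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
    return path


rows = [
    {"epoch": 1, "tr_loss": 0.9},                   # no validation this epoch
    {"epoch": 2, "tr_loss": 0.7, "val_loss": 0.8},  # validation computed
]
with tempfile.TemporaryDirectory() as model_dir:
    with open(save_metrics_csv(rows, model_dir), newline="") as f:
        logged = list(csv.DictReader(f))
```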