Trainer
The main module for the training process.
-
class neural_pipeline.train.Trainer(model: torch.nn.modules.module.Module, train_config: neural_pipeline.train_config.train_config.TrainConfig, fsm: neural_pipeline.utils.file_structure_manager.FileStructManager, device: torch.device = None)
Class that drives the training process.
Trainer gets a list of training stages and loops over it every epoch.
The training process looks like this:

    for epoch in range(epochs_num):
        for stage in training_stages:
            stage.run()
            monitor_hub.update_metrics(stage.metrics_processor().get_metrics())
        save_state()
        on_epoch_end_callback()
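The loop above can be sketched as plain, runnable Python. This is a simplified illustration, not the library's implementation: `HypotheticalStage`, `train_loop` and the callables passed to it are made-up stand-ins, and a stage exposes `get_metrics()` directly instead of going through `metrics_processor()`.

```python
from typing import Callable, Dict, List


class HypotheticalStage:
    """Stand-in for a training stage; run() only produces fake metrics."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.runs = 0
        self._metrics: Dict[str, float] = {}

    def run(self) -> None:
        self.runs += 1
        # Pretend the loss shrinks every time the stage runs.
        self._metrics = {self.name + "_loss": 1.0 / self.runs}

    def get_metrics(self) -> Dict[str, float]:
        return self._metrics


def train_loop(epochs_num: int,
               training_stages: List[HypotheticalStage],
               update_metrics: Callable[[Dict[str, float]], None],
               save_state: Callable[[int], None],
               on_epoch_end_callback: Callable[[], None]) -> None:
    for epoch in range(epochs_num):
        # Every stage (e.g. train, then validation) runs once per epoch.
        for stage in training_stages:
            stage.run()
            update_metrics(stage.get_metrics())
        # State is saved and callbacks fire only after all stages finished.
        save_state(epoch)
        on_epoch_end_callback()


# Usage: record every call in order to see the loop structure.
log: list = []
stages = [HypotheticalStage("train"), HypotheticalStage("validation")]
train_loop(2, stages,
           update_metrics=log.append,
           save_state=lambda epoch: log.append(("saved", epoch)),
           on_epoch_end_callback=lambda: log.append("epoch_end"))
```

Each epoch produces one metrics update per stage, followed by exactly one save and one epoch-end callback.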
Parameters:
- model – model for training
- train_config – TrainConfig object
- fsm – FileStructManager object
- device – device for the training process
-
add_on_epoch_end_callback(callback: callable) → neural_pipeline.train.Trainer
Add a callback that will be called at the end of every epoch.
Parameters: callback – method that will be called. It is called without any arguments, so it must not require parameters.
Returns: self object
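Because the method returns the trainer itself, calls can be chained. A minimal sketch of that pattern, assuming a hypothetical `HypotheticalTrainer` class (not the real `Trainer`) that invokes each registered callback with no arguments at the end of every epoch:

```python
from typing import Callable, List


class HypotheticalTrainer:
    """Toy class mimicking the documented 'returns self object' style."""

    def __init__(self, epochs_num: int) -> None:
        self._epochs_num = epochs_num
        self._on_epoch_end: List[Callable[[], None]] = []

    def add_on_epoch_end_callback(self, callback: Callable[[], None]) -> "HypotheticalTrainer":
        self._on_epoch_end.append(callback)
        return self  # returning self allows chained calls

    def train(self) -> None:
        for _ in range(self._epochs_num):
            # ... training stages would run here ...
            for callback in self._on_epoch_end:
                callback()  # called with no arguments


events: List[str] = []
trainer = (HypotheticalTrainer(epochs_num=3)
           .add_on_epoch_end_callback(lambda: events.append("log"))
           .add_on_epoch_end_callback(lambda: events.append("checkpoint")))
trainer.train()
```

Callbacks run in registration order, once per epoch.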
-
data_processor() → neural_pipeline.data_processor.data_processor.TrainDataProcessor
Get the data processor object.
Returns: data processor
-
disable_best_states_saving() → neural_pipeline.train.Trainer
Disable best states saving.
Returns: self object
-
enable_best_states_saving(rule: callable) → neural_pipeline.train.Trainer
Enable best states saving.
The best state is saved whenever the value returned by rule reaches a new minimum.
Parameters: rule – callback that returns the value used to decide when to store the best state
Returns: self object
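A minimal sketch of the "save on new minimum" rule described above, assuming a hypothetical `BestStateTracker` helper; the real `Trainer` may differ in details such as how the first epoch is treated. Here `rule` is a callback returning, for example, a validation loss, and `save_best` is a hypothetical stand-in for writing a checkpoint.

```python
from typing import Callable, List, Optional


class BestStateTracker:
    """Calls save_best() whenever rule() returns a new minimum (hypothetical helper)."""

    def __init__(self, rule: Callable[[], float], save_best: Callable[[float], None]) -> None:
        self._rule = rule
        self._save_best = save_best
        self._best: Optional[float] = None

    def on_epoch_end(self) -> bool:
        value = self._rule()
        # Assumption: the very first value always counts as the best so far.
        if self._best is None or value < self._best:
            self._best = value
            self._save_best(value)
            return True
        return False


# Usage: a made-up sequence of validation losses, one per epoch.
val_losses = iter([0.9, 0.7, 0.8, 0.6, 0.65])
saved: List[float] = []
tracker = BestStateTracker(rule=lambda: next(val_losses), save_best=saved.append)
flags = [tracker.on_epoch_end() for _ in range(5)]
```

Only epochs that lower the minimum (0.9, 0.7, 0.6) trigger a save; the increases to 0.8 and 0.65 are ignored.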
-
enable_lr_decaying(coeff: float, patience: int, target_val_clbk: callable) → neural_pipeline.train.Trainer
Enable learning rate decaying. The learning rate decays when the value returned by target_val_clbk does not reach a new minimum for patience steps.
Parameters:
- coeff – lr decay coefficient
- patience – number of steps without a new minimum before the learning rate is decayed
- target_val_clbk – callback that returns the value used for lr decaying
Returns: self object
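The decay rule above can be sketched as follows. `PlateauLrDecay` is a hypothetical helper, not part of neural_pipeline, and it assumes that the learning rate is multiplied by `coeff` once the value has failed to improve for `patience` consecutive steps, after which the counter restarts.

```python
from typing import Callable


class PlateauLrDecay:
    """Multiply lr by coeff when target_val_clbk() stops reaching new minima (hypothetical)."""

    def __init__(self, lr: float, coeff: float, patience: int,
                 target_val_clbk: Callable[[], float]) -> None:
        self.lr = lr
        self._coeff = coeff
        self._patience = patience
        self._target = target_val_clbk
        self._best = float("inf")
        self._bad_steps = 0

    def step(self) -> float:
        value = self._target()
        if value < self._best:
            # New minimum: remember it and reset the patience counter.
            self._best = value
            self._bad_steps = 0
        else:
            self._bad_steps += 1
            if self._bad_steps >= self._patience:
                self.lr *= self._coeff
                self._bad_steps = 0  # restart the count after each decay
        return self.lr


# Usage: the value stalls for two steps (0.85, 0.9), so lr is halved once.
values = iter([1.0, 0.8, 0.85, 0.9, 0.95, 0.7])
decay = PlateauLrDecay(lr=0.1, coeff=0.5, patience=2,
                       target_val_clbk=lambda: next(values))
lrs = [decay.step() for _ in range(6)]
```

PyTorch ships a similar scheduler, `torch.optim.lr_scheduler.ReduceLROnPlateau` (with `factor` and `patience` arguments), whose exact counting rules may differ from this sketch.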
-
resume(from_best_checkpoint: bool) → neural_pipeline.train.Trainer
Resume training from the last checkpoint.
Parameters: from_best_checkpoint – whether to continue from the best checkpoint rather than the last one
Returns: self object