Source code for neural_pipeline.train

"""
The main module for the training process
"""
import json

import torch
from torch.nn import Module

from neural_pipeline.data_processor import TrainDataProcessor
from neural_pipeline.utils import FileStructManager, CheckpointsManager
from neural_pipeline.train_config.train_config import TrainConfig
from neural_pipeline.monitoring import MonitorHub, ConsoleMonitor

__all__ = ['Trainer']


class LearningRate:
    """
    Basic learning rate class
    """

    def __init__(self, value: float):
        self._value = value

    def value(self) -> float:
        """
        Get value of current learning rate

        :return: current value
        """
        return self._value

    def set_value(self, value) -> None:
        """
        Set lr value

        :param value: lr value
        """
        self._value = value


class DecayingLR(LearningRate):
    """
    This class provides learning rate decay driven by a metric value (obtained via :arg:`target_value_clbk`).
    If the metric value doesn't update its minimum within the defined number of steps (:arg:`patience`), the lr is
    decayed by the defined coefficient (:arg:`decay_coefficient`).

    :param start_value: start value
    :param decay_coefficient: decay coefficient
    :param patience: number of steps before decay
    :param target_value_clbk: callable that returns the target value used for lr decaying
    """

    def __init__(self, start_value: float, decay_coefficient: float, patience: int, target_value_clbk: callable):
        super().__init__(start_value)

        self._decay_coefficient = decay_coefficient
        self._patience = patience
        self._cur_step = 1
        self._target_value_clbk = target_value_clbk
        self._cur_min_target_val = None

    def value(self) -> float:
        """
        Get value of current learning rate

        :return: learning rate value
        """
        metric_val = self._target_value_clbk()
        if metric_val is None:
            return self._value

        if self._cur_min_target_val is None:
            self._cur_min_target_val = metric_val

        if metric_val < self._cur_min_target_val:
            self._cur_step = 1
            self._cur_min_target_val = metric_val

        if self._cur_step > 0 and (self._cur_step % self._patience) == 0:
            self._value *= self._decay_coefficient
            self._cur_min_target_val = None
            self._cur_step = 1
            return self._value

        self._cur_step += 1
        return self._value

    def set_value(self, value):
        self._value = value
        self._cur_step = 0
        self._cur_min_target_val = None
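
# A usage sketch (not part of the library): `val_losses` is a hypothetical list that a training
# loop appends to. The callback returns the most recent value, and DecayingLR multiplies the
# learning rate by `decay_coefficient` when its internal step counter reaches a multiple of
# `patience` without a new minimum being observed.
#
#     val_losses = []
#     lr = DecayingLR(start_value=0.1, decay_coefficient=0.5, patience=3,
#                     target_value_clbk=lambda: val_losses[-1] if val_losses else None)
#     for epoch_loss in [1.0, 0.9, 0.95, 0.97, 0.96]:
#         val_losses.append(epoch_loss)
#         current_lr = lr.value()
#     # by the last call the value has decayed once: 0.1 -> 0.05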


class Trainer:
    """
    Class that drives the training process.

    Trainer gets a list of training stages and loops over it every epoch.

    Training process looks like:

    .. highlight:: python
    .. code-block:: python

        for epoch in epochs_num:
            for stage in training_stages:
                stage.run()
                monitor_hub.update_metrics(stage.metrics_processor().get_metrics())
            save_state()
            on_epoch_end_callback()

    :param model: model for training
    :param train_config: :class:`TrainConfig` object
    :param fsm: :class:`FileStructManager` object
    :param device: device for training process
    """
    class TrainerException(Exception):
        def __init__(self, msg):
            super().__init__()
            self._msg = msg

        def __str__(self):
            return self._msg
    def __init__(self, model: Module, train_config: TrainConfig, fsm: FileStructManager,
                 device: torch.device = None):
        self._fsm = fsm
        self.monitor_hub = MonitorHub()

        self._checkpoint_manager = CheckpointsManager(self._fsm)

        self.__epoch_num = 100
        self._resume_from = None
        self._on_epoch_end = []
        self._best_state_rule = None

        self.__train_config = train_config
        self._device = device
        self._data_processor = TrainDataProcessor(model, self.__train_config, self._device) \
            .set_checkpoints_manager(self._checkpoint_manager)
        self._lr = LearningRate(self._data_processor.get_lr())
    def set_epoch_num(self, epoch_number: int) -> 'Trainer':
        """
        Define the number of epochs for training. One epoch is one iteration over all train stages

        :param epoch_number: number of training epochs
        :return: self object
        """
        self.__epoch_num = epoch_number
        return self
    def resume(self, from_best_checkpoint: bool) -> 'Trainer':
        """
        Resume training from a checkpoint

        :param from_best_checkpoint: whether to continue from the best checkpoint instead of the last one
        :return: self object
        """
        self._resume_from = 'last' if from_best_checkpoint is False else 'best'
        return self
    def enable_lr_decaying(self, coeff: float, patience: int, target_val_clbk: callable) -> 'Trainer':
        """
        Enable learning rate decaying. The learning rate decays when the value returned by `target_val_clbk`
        doesn't update its minimum for `patience` steps

        :param coeff: lr decay coefficient
        :param patience: number of steps
        :param target_val_clbk: callback which returns the value that is used for lr decaying
        :return: self object
        """
        self._lr = DecayingLR(self._data_processor.get_lr(), coeff, patience, target_val_clbk)
        return self
    def train(self) -> None:
        """
        Run training process
        """
        if len(self.__train_config.stages()) < 1:
            raise self.TrainerException("There are no stages for training")

        best_checkpoints_manager = None
        cur_best_state = None
        if self._best_state_rule is not None:
            best_checkpoints_manager = CheckpointsManager(self._fsm, 'best')

        start_epoch_idx = 1
        if self._resume_from is not None:
            start_epoch_idx += self._resume()

        self.monitor_hub.add_monitor(ConsoleMonitor())

        with self.monitor_hub:
            for epoch_idx in range(start_epoch_idx, self.__epoch_num + start_epoch_idx):
                self.monitor_hub.set_epoch_num(epoch_idx)
                for stage in self.__train_config.stages():
                    stage.run(self._data_processor)

                    if stage.metrics_processor() is not None:
                        self.monitor_hub.update_metrics(stage.metrics_processor().get_metrics())

                new_best_state = self._save_state(self._checkpoint_manager, best_checkpoints_manager,
                                                  cur_best_state, epoch_idx)
                if new_best_state is not None:
                    cur_best_state = new_best_state

                self._data_processor.update_lr(self._lr.value())

                for clbk in self._on_epoch_end:
                    clbk()

                self._update_losses()
                self.__iterate_by_stages(lambda s: s.on_epoch_end())
    def _resume(self) -> int:
        if self._resume_from == 'last':
            ckpts_manager = self._checkpoint_manager
        elif self._resume_from == 'best':
            ckpts_manager = CheckpointsManager(self._fsm, 'best')
        else:
            raise NotImplementedError("Resume parameter may be only 'last' or 'best' not {}"
                                      .format(self._resume_from))

        ckpts_manager.unpack()
        self._data_processor.load()

        with open(ckpts_manager.trainer_file(), 'r') as file:
            start_epoch_idx = json.load(file)['last_epoch'] + 1

        ckpts_manager.pack()
        return start_epoch_idx

    def _save_state(self, ckpts_manager: CheckpointsManager, best_ckpts_manager: CheckpointsManager or None,
                    cur_best_state: float or None, epoch_idx: int) -> float or None:
        """
        Internal method used to save states after an epoch ends

        :param ckpts_manager: ordinal checkpoints manager
        :param best_ckpts_manager: checkpoints manager used for storing best states
        :param cur_best_state: current best state metric value
        :return: new best state metric value or None if it wasn't updated
        """

        def save_trainer(ckp_manager):
            with open(ckp_manager.trainer_file(), 'w') as out:
                json.dump({'last_epoch': epoch_idx}, out)

        if self._best_state_rule is not None:
            new_best_state = self._best_state_rule()
            if cur_best_state is None:
                self._data_processor.save_state()
                save_trainer(ckpts_manager)
                ckpts_manager.pack()
                return new_best_state
            else:
                if new_best_state <= cur_best_state:
                    self._data_processor.set_checkpoints_manager(best_ckpts_manager)
                    self._data_processor.save_state()
                    save_trainer(best_ckpts_manager)
                    best_ckpts_manager.pack()
                    self._data_processor.set_checkpoints_manager(ckpts_manager)
                    return new_best_state

        self._data_processor.save_state()
        save_trainer(ckpts_manager)
        ckpts_manager.pack()
        return None

    def _update_losses(self) -> None:
        """
        Update losses procedure
        """
        losses = {}
        for stage in self.__train_config.stages():
            if stage.get_losses() is not None:
                losses[stage.name()] = stage.get_losses()
        self.monitor_hub.update_losses(losses)
    def data_processor(self) -> TrainDataProcessor:
        """
        Get the data processor object

        :return: data processor
        """
        return self._data_processor
    def enable_best_states_saving(self, rule: callable) -> 'Trainer':
        """
        Enable best states saving

        The best state will be saved when the value returned by `rule` updates the minimum

        :param rule: callback which returns the value that is used to define when the best state needs to be stored
        :return: self object
        """
        self._best_state_rule = rule
        return self
    def disable_best_states_saving(self) -> 'Trainer':
        """
        Disable best states saving

        :return: self object
        """
        self._best_state_rule = None
        return self
    def add_on_epoch_end_callback(self, callback: callable) -> 'Trainer':
        """
        Add a callback that will be called after every epoch ends

        :param callback: method that will be called. This method must not take any parameters
        :return: self object
        """
        self._on_epoch_end.append(callback)
        return self
    def __iterate_by_stages(self, func: callable) -> None:
        """
        Internal method used for iterating over the stages

        :param func: callback that is called for every stage
        """
        for stage in self.__train_config.stages():
            func(stage)
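
# A usage sketch (assumptions, not from this module): `MyTrainConfig` is a hypothetical
# TrainConfig subclass with stages defined, and the FileStructManager constructor arguments
# shown here are assumed. It only illustrates the intended fluent-style wiring of Trainer.
#
#     model = torch.nn.Linear(10, 2)
#     fsm = FileStructManager(base_dir='checkpoints', is_continue=False)  # assumed args
#     train_config = MyTrainConfig(model)  # hypothetical
#     trainer = Trainer(model, train_config, fsm, torch.device('cpu')) \
#         .set_epoch_num(20) \
#         .enable_lr_decaying(coeff=0.5, patience=5, target_val_clbk=lambda: None) \
#         .add_on_epoch_end_callback(lambda: print('epoch finished'))
#     trainer.train()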