from abc import ABCMeta, abstractmethod
from typing import List, Optional, Union

import numpy as np
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from torch.utils.data import DataLoader
try:
from IPython import get_ipython
ip = get_ipython()
if ip is not None:
from tqdm import tqdm_notebook as tqdm
else:
from tqdm import tqdm
except ImportError:
from tqdm import tqdm
from neural_pipeline.data_producer.data_producer import DataProducer
from neural_pipeline.data_processor.data_processor import TrainDataProcessor
__all__ = ['TrainConfig', 'TrainStage', 'ValidationStage', 'AbstractMetric', 'MetricsGroup', 'MetricsProcessor', 'AbstractStage',
'StandardStage']
class AbstractMetric(metaclass=ABCMeta):
    """
    Abstract class for metrics. When used in neural_pipeline, it stores the metric value on every call of :meth:`calc`

    :param name: name of the metric. The name will be used in monitors, so avoid unsupported characters
    """
def __init__(self, name: str):
self._name = name
self._values = np.array([])
    @abstractmethod
    def calc(self, output: Tensor, target: Tensor) -> Union[np.ndarray, float]:
        """
        Calculate the metric from the model output and the target

        :param output: output from the model
        :param target: ground truth
        """
    def _calc(self, output: Tensor, target: Tensor):
        """
        Calculate the metric and accumulate its value. Method for internal use

        :param output: output from the model
        :param target: ground truth
        """
self._values = np.append(self._values, self.calc(output, target))
    def name(self) -> str:
        """
        Get the name of the metric

        :return: metric name
        """
return self._name
    def get_values(self) -> np.ndarray:
        """
        Get the array of metric values

        :return: array of values
        """
return self._values
    def reset(self) -> None:
        """
        Reset the array of metric values
        """
self._values = np.array([])
    @staticmethod
    def min_val() -> float:
        """
        Get the minimum value of the metric. This is used for correct histogram visualisation in some monitors

        :return: minimum value
        """
return 0
    @staticmethod
    def max_val() -> float:
        """
        Get the maximum value of the metric. This is used for correct histogram visualisation in some monitors

        :return: maximum value
        """
return 1
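
# Illustrative sketch, not part of the library: a minimal concrete metric.
# It assumes binary-classification outputs in [0, 1] and thresholds them at
# 0.5 before comparing with the target.
class _ExampleBinaryAccuracy(AbstractMetric):
    def __init__(self):
        super().__init__('accuracy')

    def calc(self, output: Tensor, target: Tensor) -> float:
        # fraction of thresholded predictions that match the target
        predictions = (output > 0.5).float()
        return (predictions == target.float()).float().mean().item()
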
class MetricsGroup:
    """
    Class that unites metrics or other :class:`MetricsGroup` objects in one namespace.

    Note: a MetricsGroup may contain at most 2 levels of nested :class:`MetricsGroup` objects, so
    ``MetricsGroup().add(MetricsGroup().add(MetricsGroup()))`` will raise :class:`MGException`

    :param name: group name. The name will be used in monitors, so avoid unsupported characters
    """
    class MGException(Exception):
        """
        Exception raised by :class:`MetricsGroup`
        """
def __init__(self, msg: str):
self.__msg = msg
def __str__(self):
return self.__msg
def __init__(self, name: str):
self.__name = name
self.__metrics = []
self.__metrics_groups = []
self.__lvl = 1
    def add(self, item: Union[AbstractMetric, 'MetricsGroup']) -> 'MetricsGroup':
        """
        Add an :class:`AbstractMetric` or a :class:`MetricsGroup`

        :param item: object to add
        :return: self object
        :rtype: :class:`MetricsGroup`
        """
if isinstance(item, type(self)):
item._set_level(self.__lvl + 1)
self.__metrics_groups.append(item)
else:
self.__metrics.append(item)
return self
    def metrics(self) -> List[AbstractMetric]:
        """
        Get the list of metrics

        :return: list of metrics
        """
return self.__metrics
    def groups(self) -> List['MetricsGroup']:
        """
        Get the list of metrics groups

        :return: list of metrics groups
        """
return self.__metrics_groups
    def name(self) -> str:
        """
        Get the group name

        :return: name
        """
return self.__name
    def have_groups(self) -> bool:
        """
        Check whether this group contains other metrics groups

        :return: True if it contains any, otherwise False
        """
return len(self.__metrics_groups) > 0
    def _set_level(self, level: int) -> None:
        """
        Internal method to set the metrics group level

        TODO: if a metrics group is contained in two groups with different levels, this is an undefined case

        :param level: parent group level
        """
        if level > 2:
            raise self.MGException("Metrics group '{}' got level {}; there must be no more than 2 levels"
                                   .format(self.__name, level))
self.__lvl = level
for group in self.__metrics_groups:
group._set_level(self.__lvl + 1)
    def calc(self, output: Tensor, target: Tensor) -> None:
        """
        Recursively calculate all metrics in this group and in all nested groups

        :param output: predicted value
        :param target: target value
        """
for metric in self.__metrics:
metric._calc(output, target)
for group in self.__metrics_groups:
group.calc(output, target)
    def reset(self) -> None:
        """
        Recursively reset all metrics in this group and in all nested groups
        """
for metric in self.__metrics:
metric.reset()
for group in self.__metrics_groups:
group.reset()
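
# Illustrative sketch, not part of the library: composing metrics groups.
# `_ExampleBinaryAccuracy` is the example metric defined above; adding a third
# level of nesting would raise MetricsGroup.MGException.
def _example_build_metrics_group() -> MetricsGroup:
    inner = MetricsGroup('thresholded').add(_ExampleBinaryAccuracy())
    return MetricsGroup('classification').add(inner)
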
class MetricsProcessor:
    """
    Collection for all :class:`AbstractMetric` and :class:`MetricsGroup` objects
    """
def __init__(self):
self._metrics = []
self._metrics_groups = []
    def add_metric(self, metric: AbstractMetric) -> AbstractMetric:
        """
        Add an :class:`AbstractMetric` object

        :param metric: metric to add
        :return: the metric object
        :rtype: :class:`AbstractMetric`
        """
self._metrics.append(metric)
return metric
    def add_metrics_group(self, group: MetricsGroup) -> MetricsGroup:
        """
        Add a :class:`MetricsGroup` object

        :param group: metrics group to add
        :return: the metrics group object
        :rtype: :class:`MetricsGroup`
        """
self._metrics_groups.append(group)
return group
    def calc_metrics(self, output: Tensor, target: Tensor) -> None:
        """
        Recursively calculate all metrics

        :param output: predicted value
        :param target: target value
        """
        for metric in self._metrics:
            # _calc() both computes and accumulates the value, mirroring
            # how MetricsGroup.calc() treats its metrics
            metric._calc(output, target)
for group in self._metrics_groups:
group.calc(output, target)
    def reset_metrics(self) -> None:
        """
        Recursively reset all metrics values
        """
for metric in self._metrics:
metric.reset()
for group in self._metrics_groups:
group.reset()
    def get_metrics(self) -> dict:
        """
        Get metrics and groups as a dict

        :return: dict of metrics and groups with keys ['metrics', 'groups']
        """
return {'metrics': self._metrics, 'groups': self._metrics_groups}
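
# Illustrative sketch, not part of the library: wiring plain metrics and a
# group into one processor; the helper names are assumptions of this example.
def _example_build_metrics_processor() -> MetricsProcessor:
    processor = MetricsProcessor()
    processor.add_metric(_ExampleBinaryAccuracy())
    processor.add_metrics_group(_example_build_metrics_group())
    return processor
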
class AbstractStage(metaclass=ABCMeta):
    """
    Stage of the training process. For example, there may be 2 stages: train and validation.
    Every epoch in the train loop iterates over the stages.

    :param name: name of the stage
    """
def __init__(self, name: str):
self._name = name
    def name(self) -> str:
        """
        Get the name of the stage

        :return: name
        """
return self._name
    def metrics_processor(self) -> Optional[MetricsProcessor]:
        """
        Get the metrics processor

        :return: :class:`MetricsProcessor` object or None
        """
return None
    @abstractmethod
    def run(self, data_processor: TrainDataProcessor) -> None:
        """
        Run the stage
        """
    def get_losses(self) -> Optional[np.ndarray]:
        """
        Get the losses from this stage

        :return: array of losses, or None if this stage doesn't compute losses
        """
return None
    def on_epoch_end(self) -> None:
        """
        Callback called at the end of every train epoch
        """
pass
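
# Illustrative sketch, not part of the library: a custom stage only needs to
# implement run(). This hypothetical stage computes no losses and no metrics,
# which is legal because the default metrics_processor() and get_losses()
# both return None.
class _ExampleNoopStage(AbstractStage):
    def run(self, data_processor: TrainDataProcessor) -> None:
        print('running stage:', self.name())
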
class StandardStage(AbstractStage):
    """
    Standard stage for the train process.
    When :meth:`run` is called, it iterates :meth:`process_batch` of the data processor over the data loader.
    After iteration stops, the stage accumulates losses from the :class:`DataProcessor`.

    :param data_producer: :class:`DataProducer` object
    :param metrics_processor: :class:`MetricsProcessor`
    """

    def __init__(self, stage_name: str, is_train: bool, data_producer: DataProducer,
                 metrics_processor: Optional[MetricsProcessor] = None):
super().__init__(name=stage_name)
self.data_loader = None
self.data_producer = data_producer
self._metrics_processor = metrics_processor
self._losses = None
self._is_train = is_train
    def run(self, data_processor: TrainDataProcessor) -> None:
        """
        Run the stage. This iterates over the DataProducer and shows progress in stdout

        :param data_processor: :class:`DataProcessor` object
        """
if self.data_loader is None:
self.data_loader = self.data_producer.get_loader()
self._run(self.data_loader, self.name(), data_processor)
    def _run(self, data_loader: DataLoader, name: str, data_processor: TrainDataProcessor):
        with tqdm(data_loader, desc=name, leave=False) as t:
            self._losses = None
            for batch in t:
                self._process_batch(batch, data_processor)
                t.set_postfix({'loss': '[{:.4f}]'.format(np.mean(self._losses))})
def _process_batch(self, batch, data_processor: TrainDataProcessor):
cur_loss = data_processor.process_batch(batch, metrics_processor=self.metrics_processor(), is_train=self._is_train)
if self._losses is None:
self._losses = cur_loss
else:
self._losses = np.append(self._losses, cur_loss)
    def metrics_processor(self) -> Optional[MetricsProcessor]:
        """
        Get the metrics processor of this stage

        :return: :class:`MetricsProcessor` if specified, otherwise None
        """
return self._metrics_processor
    def get_losses(self) -> np.ndarray:
        """
        Get the losses from this stage

        :return: array of losses
        """
return self._losses
    def on_epoch_end(self) -> None:
        """
        Method that is called after every epoch
        """
self._losses = None
metrics_processor = self.metrics_processor()
if metrics_processor is not None:
metrics_processor.reset_metrics()
class TrainStage(StandardStage):
    """
    Standard training stage.
    When :meth:`run` is called, it iterates :meth:`process_batch` of the data processor over the data loader
    with the ``is_train=True`` flag.
    After iteration stops, the stage accumulates losses from the :class:`DataProcessor`.

    :param data_producer: :class:`DataProducer` object
    :param metrics_processor: :class:`MetricsProcessor`
    :param name: name of the stage, ``'train'`` by default
    """
    class _HardNegativesTrainStage(StandardStage):
        def __init__(self, stage_name: str, data_producer: DataProducer, part: float):
            super().__init__(stage_name, True, data_producer)
            self._part = part

        def exec(self, data_processor: TrainDataProcessor, losses: np.ndarray, indices: list) -> None:
            num_losses = int(losses.size * self._part)
            idxs = np.argpartition(losses, -num_losses)[-num_losses:]
            self._run(self.data_producer.get_loader([indices[i] for i in idxs]), self.name(), data_processor)
    def __init__(self, data_producer: DataProducer, metrics_processor: Optional[MetricsProcessor] = None, name: str = 'train'):
        super().__init__(name, True, data_producer, metrics_processor)
        self.hnm = None
        self.hn_indices = []
        self._dp_pass_indices_earlier = False
    def enable_hard_negative_mining(self, part: float) -> 'TrainStage':
        """
        Enable hard negative mining. Hard negatives are selected by their loss values

        :param part: fraction of the data to repeat after the train stage
        :return: self object
        """
        if not 0 < part < 1:
            raise ValueError('Value of part for hard negative mining is out of range (0, 1)')
self.hnm = self._HardNegativesTrainStage(self.name() + '_hnm', self.data_producer, part)
self._dp_pass_indices_earlier = self.data_producer._is_passed_indices()
self.data_producer.pass_indices(True)
return self
    def disable_hard_negative_mining(self) -> 'TrainStage':
        """
        Disable hard negative mining.

        :return: self object
        """
self.hnm = None
if not self._dp_pass_indices_earlier:
self.data_producer.pass_indices(False)
return self
    def run(self, data_processor: TrainDataProcessor) -> None:
        """
        Run the stage

        :param data_processor: :class:`TrainDataProcessor` object
        """
super().run(data_processor)
if self.hnm is not None:
self.hnm.exec(data_processor, self._losses, self.hn_indices)
self.hn_indices = []
    def _process_batch(self, batch, data_processor: TrainDataProcessor) -> None:
        """
        Internal method to process one batch

        :param batch: batch
        :param data_processor: :class:`TrainDataProcessor` instance
        """
if self.hnm is not None:
self.hn_indices.append(batch['data_idx'])
super()._process_batch(batch, data_processor)
    def on_epoch_end(self):
        """
        Method that is called after every epoch
        """
super().on_epoch_end()
if self.hnm is not None:
self.hnm.on_epoch_end()
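
# Illustrative sketch, not part of the library: enabling hard negative mining
# on a train stage. `producer` is assumed to be a DataProducer built elsewhere;
# the hardest 10% of samples are repeated after each train pass.
def _example_train_stage_with_hnm(producer: DataProducer) -> TrainStage:
    return TrainStage(producer).enable_hard_negative_mining(part=0.1)
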
class ValidationStage(StandardStage):
    """
    Standard validation stage.
    When :meth:`run` is called, it iterates :meth:`process_batch` of the data processor over the data loader
    with the ``is_train=False`` flag.
    After iteration stops, the stage accumulates losses from the :class:`DataProcessor`.

    :param data_producer: :class:`DataProducer` object
    :param metrics_processor: :class:`MetricsProcessor`
    :param name: name of the stage, ``'validation'`` by default
    """

    def __init__(self, data_producer: DataProducer, metrics_processor: Optional[MetricsProcessor] = None, name: str = 'validation'):
        super().__init__(name, False, data_producer, metrics_processor)
class TrainConfig:
    """
    Storage for train process settings

    :param train_stages: list of stages for the train loop
    :param loss: loss criterion
    :param optimizer: optimizer object
    """

    def __init__(self, train_stages: List[AbstractStage], loss: Module, optimizer: Optimizer):
self._train_stages = train_stages
self.__loss = loss
self.__optimizer = optimizer
    def loss(self) -> Module:
        """
        Get the loss object

        :return: loss object
        """
return self.__loss
    def optimizer(self) -> Optimizer:
        """
        Get the optimizer object

        :return: optimizer object
        """
return self.__optimizer
    def stages(self) -> List[AbstractStage]:
        """
        Get the list of stages

        :return: list of stages
        """
return self._train_stages
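
# Illustrative sketch, not part of the library: assembling a TrainConfig from
# the pieces above. `model`, `train_producer` and `val_producer` are assumed
# to be constructed elsewhere.
def _example_build_train_config(model: Module, train_producer: DataProducer,
                                val_producer: DataProducer) -> TrainConfig:
    import torch

    metrics = _example_build_metrics_processor()
    stages = [TrainStage(train_producer, metrics),
              ValidationStage(val_producer, metrics)]
    return TrainConfig(train_stages=stages,
                       loss=torch.nn.BCELoss(),
                       optimizer=torch.optim.Adam(model.parameters(), lr=1e-3))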