from typing import Union

import numpy as np
import torch
from torch.nn import Module

from neural_pipeline.utils import dict_recursive_bypass
from neural_pipeline.utils import CheckpointsManager
from neural_pipeline.data_processor.model import Model
# Public names exported by a star-import of this module.
__all__ = ['DataProcessor', 'TrainDataProcessor']
class DataProcessor:
    """
    DataProcessor manages the model, data processing and device choosing.

    Args:
        model (Module): model, that will be used for process data
        device (torch.device): what device pass data for processing
    """

    def __init__(self, model: Module, device: torch.device = None):
        self._checkpoints_manager = None
        # The raw module is wrapped; weight loading/saving is delegated to the wrapper.
        self._model = Model(model)
        # Only stored here; none of this class' own methods read it.
        self._device = device
        # By default the model input is taken from the ``data`` key of the passed object.
        self._pick_model_input = lambda data: data['data']

    def set_checkpoints_manager(self, checkpoint_manager: 'CheckpointsManager') -> 'DataProcessor':
        """
        Remember the checkpoints manager and hand it over to the wrapped model.

        :param checkpoint_manager: manager used to locate checkpoint files
        :return: self object
        """
        self._checkpoints_manager = checkpoint_manager
        self._model.set_checkpoints_manager(checkpoint_manager)
        return self

    def model(self) -> Module:
        """
        Get current module

        :return: the module returned by the wrapped model
        """
        return self._model.model()

    def predict(self, data: Union[torch.Tensor, dict]) -> object:
        """
        Make predict by data. The module is switched to ``eval()`` mode (and left in it) and the
        forward pass runs under ``torch.no_grad()``.

        :param data: object the model input is picked from; with the default picker this must
            support ``data['data']`` (e.g. a dict with key ``data``)
        :return: processed output
        :rtype: the model output type
        """
        self.model().eval()
        with torch.no_grad():
            output = self._model(self._pick_model_input(data))
        return output

    def load(self) -> None:
        """
        Load model weights from checkpoint
        """
        self._model.load_weights()

    def save_state(self) -> None:
        """
        Save model weights (delegates to the wrapped model's ``save_weights()``)
        """
        self._model.save_weights()
class TrainDataProcessor(DataProcessor):
    """
    TrainDataProcessor does everything DataProcessor does and additionally runs the training
    process: it owns the loss criterion and the optimizer taken from the train config.

    :param train_config: train config; must provide ``model()``, ``loss()`` and ``optimizer()``
    :param device: device the batches are moved to before processing; if ``None`` batches are
        passed through unchanged
    """

    class TDPException(Exception):
        """
        Error raised by :class:`TrainDataProcessor`, e.g. when no checkpoints manager is set.
        """

        def __init__(self, msg):
            # Forward the message to Exception so that ``args``, ``repr`` and pickling work.
            super().__init__(msg)
            self._msg = msg

        def __str__(self):
            return self._msg

    def __init__(self, train_config: 'TrainConfig', device: torch.device = None):
        super().__init__(train_config.model(), device)

        # Without a device the batch is used as is; otherwise it is moved to the device first.
        self._data_preprocess = (lambda data: data) if device is None else self._pass_data_to_device
        self._pick_target = lambda data: data['target']
        self._loss_input_preproc = lambda data: data

        self.__criterion = train_config.loss()
        self.__optimizer = train_config.optimizer()

    def predict(self, data, is_train=False) -> Union[torch.Tensor, dict]:
        """
        Make predict by data. If ``is_train`` is ``True`` - this operation will compute gradients. If
        ``is_train`` is ``False`` - this will work with ``model.eval()`` and ``torch.no_grad``

        :param data: data in dict
        :param is_train: is data processor need train on data or just predict
        :return: processed output
        :rtype: model return type
        """
        if is_train:
            self.model().train()
            output = self._model(self._pick_model_input(data))
        else:
            output = super().predict(data)

        return output

    def process_batch(self, batch: dict, is_train: bool,
                      metrics_processor: 'AbstractMetricsProcessor' = None) -> np.ndarray:
        """
        Process one batch of data: preprocess it, run the model, compute the loss and, when
        training, do the backward pass and the optimizer step.

        :param batch: batch as produced by the data loader; with the default callbacks it must contain
            'data' and 'target' keys
        :param is_train: is batch process for train
        :param metrics_processor: metrics processor for collect metrics after batch is processed
        :return: the value returned by the loss criterion, detached and converted to a numpy array
        """
        internal_batch = self._data_preprocess(batch)

        if is_train:
            self.__optimizer.zero_grad()

        res = self.predict(internal_batch, is_train)
        # Pick the target once and reuse it for both the loss and the metrics.
        target = self._pick_target(internal_batch)
        loss = self.__criterion(res, target)

        if is_train:
            loss.backward()
            self.__optimizer.step()

        if metrics_processor is not None:
            # Metrics must not contribute to the autograd graph.
            with torch.no_grad():
                metrics_processor.calc_metrics(res, target)

        return loss.detach().cpu().numpy()

    def update_lr(self, lr: float) -> None:
        """
        Update learning rate straight to optimizer (applied to every parameter group)

        :param lr: target learning rate
        """
        for param_group in self.__optimizer.param_groups:
            param_group['lr'] = lr

    def get_lr(self) -> float:
        """
        Get learning rate from optimizer

        :return: learning rate of the first parameter group, or ``None`` if the optimizer has no
            parameter groups
        """
        first_group = next(iter(self.__optimizer.param_groups), None)
        if first_group is None:
            return None
        return first_group['lr']

    def get_state(self) -> dict:
        """
        Get model and optimizer state dicts

        :return: dict with keys [weights, optimizer]
        """
        return {'weights': self._model.model().state_dict(), 'optimizer': self.__optimizer.state_dict()}

    def _get_checkpoints_manager(self) -> CheckpointsManager:
        """
        Return the checkpoints manager set by ``set_checkpoints_manager()``.

        :raises TDPException: if no checkpoints manager was specified
        """
        if self._checkpoints_manager is None:
            raise self.TDPException("Checkpoints manager doesn't specified. Use 'set_checkpoints_manager()'")
        return self._checkpoints_manager

    def load(self) -> None:
        """
        Load state of model and optimizer from checkpoint

        :raises TDPException: if no checkpoints manager was specified
        """
        super().load()
        cp_manager = self._get_checkpoints_manager()
        state_file = cp_manager.optimizer_state_file()

        print("Optimizer inited by file:", state_file, end='; ')
        # NOTE: torch.load unpickles the file - only load checkpoints from trusted sources.
        state = torch.load(state_file)
        print('state dict len before:', len(state), end='; ')
        # Build the optimizer state dict once instead of once per key of the loaded state.
        expected_keys = self.__optimizer.state_dict().keys()
        state = {k: v for k, v in state.items() if k in expected_keys}
        print('state dict len after:', len(state), end='; ')
        self.__optimizer.load_state_dict(state)
        print('done')

    def save_state(self) -> None:
        """
        Save model weights and the optimizer state dict to the checkpoint files

        :raises TDPException: if no checkpoints manager was specified
        """
        super().save_state()
        torch.save(self.__optimizer.state_dict(), self._get_checkpoints_manager().optimizer_state_file())

    def set_pick_target(self, pick_target: callable) -> 'TrainDataProcessor':
        """
        Set callback, that will get output from :mod:`DataLoader` and return target.

        Default mode:

        .. highlight:: python
        .. code-block:: python

            lambda data: data['target']

        Args:
            pick_target (callable): pick target callable. This callback need to get one parameter: dataset output

        Returns:
            self object

        Examples:

        .. highlight:: python
        .. code-block:: python

            data_processor.set_pick_target(lambda data: data['target'])
            data_processor.set_pick_target(lambda data: data[1])
        """
        self._pick_target = pick_target
        return self

    def set_data_preprocess(self, data_preprocess: callable) -> 'TrainDataProcessor':
        """
        Set callback, that will get output from :mod:`DataLoader` and return preprocessed data.
        For example may be used for pass data to device.

        Default mode: :meth:`_pass_data_to_device` when a device was passed to the constructor,
        otherwise the data is returned unchanged.

        Args:
            data_preprocess (callable): preprocess callable. This callback need to get one parameter: dataset output

        Returns:
            self object

        Examples:

        .. highlight:: python
        .. code-block:: python

            from neural_pipeline.utils import dict_recursive_bypass
            data_processor.set_data_preprocess(lambda data: dict_recursive_bypass(data, lambda v: v.cuda()))
        """
        self._data_preprocess = data_preprocess
        return self

    def _pass_data_to_device(self, data) -> object:
        """
        Internal method, that pass data to specified device

        :param data: data as any object type. A dict is passed to the device value by value via
            ``dict_recursive_bypass``; a :class:`torch.Tensor` is passed directly. Any other object
            is returned unchanged
        :return: data placed on the target device (or the untouched input)
        """
        if isinstance(data, dict):
            return dict_recursive_bypass(data, lambda v: v.to(self._device))
        if isinstance(data, torch.Tensor):
            return data.to(self._device)
        return data