r"""
Base classes for model definition and training
"""
import os
import pathlib
import tempfile
from abc import abstractmethod
from typing import Any, Iterable, List, Mapping, Optional
import dill
import ignite
import torch
from ..utils import DelayedKeyboardInterrupt, config, logged
EPOCH_STARTED = ignite.engine.Events.EPOCH_STARTED
EPOCH_COMPLETED = ignite.engine.Events.EPOCH_COMPLETED
ITERATION_COMPLETED = ignite.engine.Events.ITERATION_COMPLETED
EXCEPTION_RAISED = ignite.engine.Events.EXCEPTION_RAISED
COMPLETED = ignite.engine.Events.COMPLETED
@logged
class Trainer:
r"""
Abstract trainer class
Parameters
----------
net
Network module to be trained
Note
----
    Subclasses should populate ``required_losses`` and define their
    optimizers in the constructor.
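    Examples
    --------
    A minimal sketch of a concrete subclass (``MSETrainer`` and the
    regression loss below are purely illustrative, not part of this module)::

        class MSETrainer(Trainer):
            def __init__(self, net: torch.nn.Module) -> None:
                super().__init__(net)
                self.required_losses = ["loss"]
                self.optim = torch.optim.Adam(net.parameters(), lr=1e-3)

            def train_step(self, engine, data):
                x, y = data
                loss = torch.nn.functional.mse_loss(self.net(x), y)
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()
                return {"loss": loss.detach()}

            @torch.no_grad()
            def val_step(self, engine, data):
                x, y = data
                return {"loss": torch.nn.functional.mse_loss(self.net(x), y)}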
"""
def __init__(self, net: torch.nn.Module) -> None:
self.net = net
self.required_losses: List[str] = []
    @abstractmethod
def train_step(
self, engine: ignite.engine.Engine, data: List[torch.Tensor]
) -> Mapping[str, torch.Tensor]:
r"""
A single training step
Parameters
----------
engine
Training engine
data
Data of the training step
Returns
-------
loss_dict
Dict containing training loss values
"""
raise NotImplementedError # pragma: no cover
    @abstractmethod
def val_step(
self, engine: ignite.engine.Engine, data: List[torch.Tensor]
) -> Mapping[str, torch.Tensor]:
r"""
A single validation step
Parameters
----------
engine
Validation engine
data
Data of the validation step
Returns
-------
loss_dict
Dict containing validation loss values
"""
raise NotImplementedError # pragma: no cover
    def report_metrics(
self, train_state: ignite.engine.State,
val_state: Optional[ignite.engine.State]
) -> None:
r"""
Report loss values during training
Parameters
----------
train_state
Training engine state
val_state
Validation engine state
"""
if train_state.epoch % config.PRINT_LOSS_INTERVAL:
return
train_metrics = {
key: float(f"{val:.3f}")
for key, val in train_state.metrics.items()
}
val_metrics = {
key: float(f"{val:.3f}")
for key, val in val_state.metrics.items()
} if val_state else None
self.logger.info(
"[Epoch %d] train=%s, val=%s, %.1fs elapsed",
train_state.epoch, train_metrics, val_metrics,
train_state.times["EPOCH_COMPLETED"] # Also includes validator time
)
    def fit(
self, train_loader: Iterable, val_loader: Optional[Iterable] = None,
max_epochs: int = 100, random_seed: int = 0,
directory: Optional[os.PathLike] = None,
plugins: Optional[List["TrainingPlugin"]] = None
) -> None:
r"""
Fit network
Parameters
----------
train_loader
Training data loader
val_loader
Validation data loader
max_epochs
            Maximum number of epochs
random_seed
Random seed
directory
Training directory
plugins
Optional list of training plugins
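        Examples
        --------
        Typical usage through a concrete subclass (``MSETrainer``, the data
        loaders and the directory below are placeholders)::

            trainer = MSETrainer(net)
            trainer.fit(
                train_loader, val_loader=val_loader,
                max_epochs=200, directory="./train_dir"
            )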
"""
interrupt_delayer = DelayedKeyboardInterrupt()
directory = pathlib.Path(directory or tempfile.mkdtemp(prefix=config.TMP_PREFIX))
self.logger.info("Using training directory: \"%s\"", directory)
# Construct engines
train_engine = ignite.engine.Engine(self.train_step)
val_engine = ignite.engine.Engine(self.val_step) if val_loader else None
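        # Delay keyboard interrupts while an epoch (and its completion
        # handlers, registered below) is running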
delay_interrupt = interrupt_delayer.__enter__
train_engine.add_event_handler(EPOCH_STARTED, delay_interrupt)
train_engine.add_event_handler(COMPLETED, delay_interrupt)
# Exception handling
train_engine.add_event_handler(ITERATION_COMPLETED, ignite.handlers.TerminateOnNan())
@train_engine.on(EXCEPTION_RAISED)
def _handle_exception(engine, e):
if isinstance(e, KeyboardInterrupt) and config.ALLOW_TRAINING_INTERRUPTION:
self.logger.info("Stopping training due to user interrupt...")
engine.terminate()
else:
raise e
# Compute metrics
for item in self.required_losses:
ignite.metrics.Average(
output_transform=lambda output, item=item: output[item]
).attach(train_engine, item)
if val_engine:
ignite.metrics.Average(
output_transform=lambda output, item=item: output[item]
).attach(val_engine, item)
if val_engine:
@train_engine.on(EPOCH_COMPLETED)
def _validate(engine):
val_engine.run(
val_loader, max_epochs=engine.state.epoch
) # Bumps max_epochs by 1 per training epoch, so validator resumes for 1 epoch
@train_engine.on(EPOCH_COMPLETED)
def _report_metrics(engine):
self.report_metrics(engine.state, val_engine.state if val_engine else None)
for plugin in plugins or []:
plugin.attach(
net=self.net, trainer=self,
train_engine=train_engine, val_engine=val_engine,
train_loader=train_loader, val_loader=val_loader,
directory=directory
)
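        # Lift the interrupt delay only after all per-epoch handlers above have run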
restore_interrupt = lambda: interrupt_delayer.__exit__(None, None, None)
train_engine.add_event_handler(EPOCH_COMPLETED, restore_interrupt)
train_engine.add_event_handler(COMPLETED, restore_interrupt)
# Start engines
torch.manual_seed(random_seed)
train_engine.run(train_loader, max_epochs=max_epochs)
torch.cuda.empty_cache() # Works even if GPU is unavailable
    def get_losses(self, loader: Iterable) -> Mapping[str, float]:
r"""
Get loss values for given data
Parameters
----------
loader
Data loader
Returns
-------
loss_dict
Dict containing loss values
"""
engine = ignite.engine.Engine(self.val_step)
for item in self.required_losses:
ignite.metrics.Average(
output_transform=lambda output, item=item: output[item]
).attach(engine, item)
engine.run(loader, max_epochs=1)
torch.cuda.empty_cache() # Works even if GPU is unavailable
return engine.state.metrics
    def state_dict(self) -> Mapping[str, Any]:
r"""
State dict
Returns
-------
state_dict
State dict
"""
return {}
    def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
r"""
Load state from a state dict
Parameters
----------
state_dict
State dict
"""
@logged
class Model:
r"""
Abstract model class
Parameters
----------
net
Network type
*args
Positional arguments are passed to the network constructor
**kwargs
Keyword arguments are passed to the network constructor
Note
----
Subclasses may override arguments for API definition.
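    Examples
    --------
    Illustrative wiring of a concrete model (``MyNet`` and ``MSETrainer``
    stand in for user-defined classes)::

        class MyModel(Model):
            NET_TYPE = MyNet           # a torch.nn.Module subclass
            TRAINER_TYPE = MSETrainer  # a Trainer subclass

        model = MyModel(in_features=16, out_features=1)
        model.compile()  # constructs the trainer from TRAINER_TYPE
        model.fit(train_loader, val_loader, max_epochs=100)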
"""
NET_TYPE = torch.nn.Module
TRAINER_TYPE = Trainer
def __init__(self, *args, **kwargs) -> None:
self._net = self.NET_TYPE(*args, **kwargs)
self._trainer: Optional[Trainer] = None # Constructed upon compile
@property
def net(self) -> torch.nn.Module:
r"""
Neural network module in the model (read-only)
"""
return self._net
@property
def trainer(self) -> Trainer:
r"""
Trainer of the neural network module (read-only)
"""
if self._trainer is None:
raise RuntimeError(
"No trainer has been registered! "
"Please call `.compile()` first."
)
return self._trainer
    def compile(self, *args, **kwargs) -> None:
r"""
Prepare model for training
Parameters
----------
trainer
Trainer type
*args
Positional arguments are passed to the trainer constructor
**kwargs
Keyword arguments are passed to the trainer constructor
Note
----
Subclasses may override arguments for API definition.
"""
if self._trainer:
self.logger.warning(
"`compile` has already been called. "
"Previous trainer will be overwritten!"
)
self._trainer = self.TRAINER_TYPE(self.net, *args, **kwargs)
    def fit(self, *args, **kwargs) -> None:
r"""
Alias of ``.trainer.fit``.
Parameters
----------
*args
Positional arguments are passed to the ``.trainer.fit`` method
**kwargs
Keyword arguments are passed to the ``.trainer.fit`` method
Note
----
Subclasses may override arguments for API definition.
"""
self.trainer.fit(*args, **kwargs)
    def get_losses(self, *args, **kwargs) -> Mapping[str, float]:
r"""
Alias of ``.trainer.get_losses``.
Parameters
----------
*args
Positional arguments are passed to the ``.trainer.get_losses`` method
**kwargs
Keyword arguments are passed to the ``.trainer.get_losses`` method
Returns
-------
loss_dict
Dict containing loss values
"""
return self.trainer.get_losses(*args, **kwargs)
    def save(self, fname: os.PathLike) -> None:
r"""
Save model to file
Parameters
----------
        fname
            Specifies path to the file
Note
----
Only the network is saved but not the trainer
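        Examples
        --------
        The model is pickled with :mod:`dill`, so it can be loaded back with
        :func:`dill.load` (the path below is illustrative)::

            import dill
            model.save("model.dill")
            with open("model.dill", "rb") as f:
                model = dill.load(f)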
"""
fname = pathlib.Path(fname)
trainer_backup, self._trainer = self._trainer, None
device_backup, self.net.device = self.net.device, torch.device("cpu")
with fname.open("wb") as f:
dill.dump(self, f, protocol=4, byref=False, recurse=True)
self.net.device = device_backup
self._trainer = trainer_backup
    def upgrade(self) -> None:
r"""
Upgrade the model if generated by older versions
"""
@logged
class TrainingPlugin:
r"""
Plugin used to extend the training process with certain functions
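    Examples
    --------
    A minimal sketch of a plugin that logs every completed training epoch
    (illustrative only)::

        class EpochLogger(TrainingPlugin):
            def attach(
                    self, net, trainer, train_engine, val_engine,
                    train_loader, val_loader, directory
            ) -> None:
                @train_engine.on(ignite.engine.Events.EPOCH_COMPLETED)
                def _log_epoch(engine):
                    self.logger.info("Epoch %d completed", engine.state.epoch)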
"""
    @abstractmethod
def attach(
self, net: torch.nn.Module, trainer: Trainer,
train_engine: ignite.engine.Engine,
val_engine: ignite.engine.Engine,
train_loader: Iterable,
val_loader: Optional[Iterable],
directory: pathlib.Path
) -> None:
r"""
Attach custom handlers to training or validation engine
Parameters
----------
net
Network module
trainer
Trainer object
train_engine
Training engine
val_engine
Validation engine
train_loader
Training data loader
val_loader
Validation data loader
directory
Training directory
"""
raise NotImplementedError # pragma: no cover