r"""
Training plugins
"""
import pathlib
import shutil
from typing import Iterable, Optional
import ignite
import ignite.contrib.handlers.tensorboard_logger as tb
import parse
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
from ..utils import config, logged
from .base import Trainer, TrainingPlugin


EPOCH_COMPLETED = ignite.engine.Events.EPOCH_COMPLETED
TERMINATE = ignite.engine.Events.TERMINATE
COMPLETED = ignite.engine.Events.COMPLETED


class Tensorboard(TrainingPlugin):

r"""
Training logging via tensorboard
"""

    def attach(
self, net: torch.nn.Module, trainer: Trainer,
train_engine: ignite.engine.Engine,
val_engine: ignite.engine.Engine,
train_loader: Iterable,
val_loader: Optional[Iterable],
directory: pathlib.Path
) -> None:
tb_directory = directory / "tensorboard"
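        # Clear any stale logs from a previous run so TensorBoard curves
        # do not overlap across runs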
if tb_directory.exists():
shutil.rmtree(tb_directory)
tb_logger = tb.TensorboardLogger(
log_dir=tb_directory,
flush_secs=config.TENSORBOARD_FLUSH_SECS
)
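        # Log the losses required by the trainer at the end of every epoch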
tb_logger.attach(
train_engine,
log_handler=tb.OutputHandler(
tag="train", metric_names=trainer.required_losses
), event_name=EPOCH_COMPLETED
)
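        # Mirror the same losses for the validation engine, if one is provided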
if val_engine:
tb_logger.attach(
val_engine,
log_handler=tb.OutputHandler(
tag="val", metric_names=trainer.required_losses
), event_name=EPOCH_COMPLETED
)
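        # Flush pending events and close the log file once training completes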
train_engine.add_event_handler(COMPLETED, tb_logger.close)


@logged
class EarlyStopping(TrainingPlugin):
r"""
Early stop model training when loss no longer decreases
Parameters
----------
monitor
Loss to monitor
patience
Patience to stop early
burnin
Burn-in epochs to skip before initializing early stopping
wait_n_lrs
Wait n learning rate scheduling events before starting early stopping
"""
def __init__(
self, monitor: str, patience: int,
burnin: int = 0, wait_n_lrs: int = 0
) -> None:
super().__init__()
self.monitor = monitor
self.patience = patience
self.burnin = burnin
self.wait_n_lrs = wait_n_lrs

    def attach(
self, net: torch.nn.Module, trainer: Trainer,
train_engine: ignite.engine.Engine,
val_engine: ignite.engine.Engine,
train_loader: Iterable,
val_loader: Optional[Iterable],
directory: pathlib.Path
) -> None:
for item in directory.glob("checkpoint_*.pt"):
item.unlink()
score_engine = val_engine if val_engine else train_engine
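        # Ignite's Checkpoint and EarlyStopping handlers both treat larger
        # scores as better, so the monitored loss is negated here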
score_function = lambda engine: -score_engine.state.metrics[self.monitor]
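        # Both handlers below only fire after the burn-in period and, if
        # ``wait_n_lrs`` is set, only once enough learning rate reductions
        # have occurred (``n_lrs`` is maintained by the LRScheduler plugin)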
event_filter = (
lambda engine, event: event > self.burnin and engine.state.n_lrs >= self.wait_n_lrs
) if self.wait_n_lrs else (
lambda engine, event: event > self.burnin
)
event = EPOCH_COMPLETED(event_filter=event_filter) # pylint: disable=not-callable
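        # Keep the best-scoring checkpoints on disk, named by training epoch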
train_engine.add_event_handler(
event, ignite.handlers.Checkpoint(
{"net": net, "trainer": trainer},
ignite.handlers.DiskSaver(
directory, atomic=True, create_dir=True, require_empty=False
), score_function=score_function,
filename_pattern="checkpoint_{global_step}.pt",
n_saved=config.CHECKPOINT_SAVE_NUMBERS,
global_step_transform=ignite.handlers.global_step_from_engine(train_engine)
)
)
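        # Terminate training once the score has not improved for
        # ``patience`` qualifying epochs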
train_engine.add_event_handler(
event, ignite.handlers.EarlyStopping(
patience=self.patience,
score_function=score_function,
trainer=train_engine
)
)
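        # When training ends (normally or via early stopping), restore the
        # best checkpoint, discarding the latest one if it contains NaNs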
@train_engine.on(COMPLETED | TERMINATE)
def _(engine):
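            # Any non-finite value in the last output indicates that the
            # corresponding checkpoint may be corrupted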
nan_flag = any(
not bool(torch.isfinite(item).all())
for item in (engine.state.output or {}).values()
)
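            # Collect the epochs of all saved checkpoints, newest first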
ckpts = sorted([
parse.parse("checkpoint_{epoch:d}.pt", item.name).named["epoch"]
for item in directory.glob("checkpoint_*.pt")
], reverse=True)
if ckpts and nan_flag and train_engine.state.epoch == ckpts[0]:
self.logger.warning(
"The most recent checkpoint \"%d\" can be corrupted by NaNs, "
"will thus be discarded.", ckpts[0]
)
ckpts = ckpts[1:]
if ckpts:
self.logger.info("Restoring checkpoint \"%d\"...", ckpts[0])
loaded = torch.load(directory / f"checkpoint_{ckpts[0]}.pt")
net.load_state_dict(loaded["net"])
trainer.load_state_dict(loaded["trainer"])
else:
self.logger.info(
"No usable checkpoint found. "
"Skipping checkpoint restoration."
)


@logged
class LRScheduler(TrainingPlugin):
r"""
Reduce learning rate on loss plateau
Parameters
----------
*optims
Optimizers
monitor
Loss to monitor
patience
Patience to reduce learning rate
burnin
Burn-in epochs to skip before initializing learning rate scheduling
"""
def __init__(
            self, *optims: torch.optim.Optimizer, monitor: Optional[str] = None,
            patience: Optional[int] = None, burnin: int = 0
) -> None:
super().__init__()
if monitor is None:
raise ValueError("`monitor` must be specified!")
self.monitor = monitor
if patience is None:
raise ValueError("`patience` must be specified!")
self.schedulers = [
ReduceLROnPlateau(optim, patience=patience, verbose=True)
for optim in optims
]
self.burnin = burnin

    def attach(
self, net: torch.nn.Module, trainer: Trainer,
train_engine: ignite.engine.Engine,
val_engine: ignite.engine.Engine,
train_loader: Iterable,
val_loader: Optional[Iterable],
directory: pathlib.Path
) -> None:
score_engine = val_engine if val_engine else train_engine
event_filter = lambda engine, event: event > self.burnin
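        # Fast-forward the schedulers past the burn-in period so plateau
        # counting only starts once scheduling is active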
for scheduler in self.schedulers:
scheduler.last_epoch = self.burnin
train_engine.state.n_lrs = 0
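        # Count learning rate reductions on the engine state, where the
        # EarlyStopping plugin can read it via ``wait_n_lrs``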
@train_engine.on(EPOCH_COMPLETED(event_filter=event_filter)) # pylint: disable=not-callable
def _():
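            # Step every scheduler with the monitored metric and record
            # whether each of them actually reduced its learning rate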
update_flags = set()
for scheduler in self.schedulers:
old_lr = scheduler.optimizer.param_groups[0]["lr"]
scheduler.step(score_engine.state.metrics[self.monitor])
new_lr = scheduler.optimizer.param_groups[0]["lr"]
update_flags.add(new_lr != old_lr)
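            # All schedulers share the same patience, so they must reduce
            # their learning rates in lockstep; anything else is a bug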
if len(update_flags) != 1:
raise RuntimeError("Learning rates are out of sync!")
if update_flags.pop():
train_engine.state.n_lrs += 1
self.logger.info("Learning rate reduction: step %d", train_engine.state.n_lrs)