Source code for scglue.models.plugins

r"""
Training plugins
"""

import pathlib
import shutil
from typing import Iterable, Optional

import ignite
import ignite.contrib.handlers.tensorboard_logger as tb
import parse
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

from ..utils import config, logged
from .base import Trainer, TrainingPlugin

EPOCH_COMPLETED = ignite.engine.Events.EPOCH_COMPLETED
TERMINATE = ignite.engine.Events.TERMINATE
COMPLETED = ignite.engine.Events.COMPLETED


class Tensorboard(TrainingPlugin):

    r"""
    Training logging via tensorboard
    """

    def attach(
            self, net: torch.nn.Module, trainer: Trainer,
            train_engine: ignite.engine.Engine,
            val_engine: ignite.engine.Engine,
            train_loader: Iterable,
            val_loader: Optional[Iterable],
            directory: pathlib.Path
    ) -> None:
        tb_directory = directory / "tensorboard"
        if tb_directory.exists():
            shutil.rmtree(tb_directory)

        tb_logger = tb.TensorboardLogger(
            log_dir=tb_directory,
            flush_secs=config.TENSORBOARD_FLUSH_SECS
        )
        tb_logger.attach(
            train_engine,
            log_handler=tb.OutputHandler(
                tag="train", metric_names=trainer.required_losses
            ), event_name=EPOCH_COMPLETED
        )
        if val_engine:
            tb_logger.attach(
                val_engine,
                log_handler=tb.OutputHandler(
                    tag="val", metric_names=trainer.required_losses
                ), event_name=EPOCH_COMPLETED
            )
        train_engine.add_event_handler(COMPLETED, tb_logger.close)
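
# --------------------------------------------------------------------------
# Illustrative sketch only, not part of scglue.models.plugins: the Tensorboard
# plugin above wraps the standard ignite logging pattern, in which an
# OutputHandler forwards selected engine metrics to TensorBoard at the end of
# every epoch. The dummy engine, the fake "loss" metric and the log directory
# below are assumptions made purely for demonstration.
def _demo_tensorboard_logging(log_dir="/tmp/tb_demo"):
    demo_engine = ignite.engine.Engine(lambda engine, batch: {})

    @demo_engine.on(EPOCH_COMPLETED)
    def _record_metrics(engine):
        engine.state.metrics["loss"] = 1.0 / engine.state.epoch  # fake decreasing loss

    tb_logger = tb.TensorboardLogger(log_dir=log_dir, flush_secs=10)
    tb_logger.attach(
        demo_engine,
        log_handler=tb.OutputHandler(tag="train", metric_names=["loss"]),
        event_name=EPOCH_COMPLETED
    )
    demo_engine.run([0], max_epochs=3)
    tb_logger.close()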


@logged
class EarlyStopping(TrainingPlugin):

    r"""
    Early stop model training when loss no longer decreases

    Parameters
    ----------
    monitor
        Loss to monitor
    patience
        Patience to stop early
    burnin
        Burn-in epochs to skip before initializing early stopping
    wait_n_lrs
        Wait n learning rate scheduling events before starting early stopping
    """

    def __init__(
            self, monitor: str, patience: int,
            burnin: int = 0, wait_n_lrs: int = 0
    ) -> None:
        super().__init__()
        self.monitor = monitor
        self.patience = patience
        self.burnin = burnin
        self.wait_n_lrs = wait_n_lrs

    def attach(
            self, net: torch.nn.Module, trainer: Trainer,
            train_engine: ignite.engine.Engine,
            val_engine: ignite.engine.Engine,
            train_loader: Iterable,
            val_loader: Optional[Iterable],
            directory: pathlib.Path
    ) -> None:
        for item in directory.glob("checkpoint_*.pt"):
            item.unlink()

        score_engine = val_engine if val_engine else train_engine
        score_function = lambda engine: -score_engine.state.metrics[self.monitor]
        event_filter = (
            lambda engine, event: event > self.burnin and engine.state.n_lrs >= self.wait_n_lrs
        ) if self.wait_n_lrs else (
            lambda engine, event: event > self.burnin
        )
        event = EPOCH_COMPLETED(event_filter=event_filter)  # pylint: disable=not-callable
        train_engine.add_event_handler(
            event, ignite.handlers.Checkpoint(
                {"net": net, "trainer": trainer},
                ignite.handlers.DiskSaver(
                    directory, atomic=True, create_dir=True, require_empty=False
                ), score_function=score_function,
                filename_pattern="checkpoint_{global_step}.pt",
                n_saved=config.CHECKPOINT_SAVE_NUMBERS,
                global_step_transform=ignite.handlers.global_step_from_engine(train_engine)
            )
        )
        train_engine.add_event_handler(
            event, ignite.handlers.EarlyStopping(
                patience=self.patience,
                score_function=score_function,
                trainer=train_engine
            )
        )

        @train_engine.on(COMPLETED | TERMINATE)
        def _(engine):
            nan_flag = any(
                not bool(torch.isfinite(item).all())
                for item in (engine.state.output or {}).values()
            )
            ckpts = sorted([
                parse.parse("checkpoint_{epoch:d}.pt", item.name).named["epoch"]
                for item in directory.glob("checkpoint_*.pt")
            ], reverse=True)
            if ckpts and nan_flag and train_engine.state.epoch == ckpts[0]:
                self.logger.warning(
                    "The most recent checkpoint \"%d\" can be corrupted by NaNs, "
                    "will thus be discarded.", ckpts[0]
                )
                ckpts = ckpts[1:]
            if ckpts:
                self.logger.info("Restoring checkpoint \"%d\"...", ckpts[0])
                loaded = torch.load(directory / f"checkpoint_{ckpts[0]}.pt")
                net.load_state_dict(loaded["net"])
                trainer.load_state_dict(loaded["trainer"])
            else:
                self.logger.info(
                    "No usable checkpoint found. "
                    "Skipping checkpoint restoration."
                )
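
# --------------------------------------------------------------------------
# Illustrative sketch only, not part of scglue.models.plugins: ignite's
# EarlyStopping handler treats larger scores as better, which is why the
# plugin above negates the monitored loss in `score_function`. The fake loss
# sequence below is an assumption made purely for demonstration.
def _demo_early_stopping(fake_losses=(1.0, 0.8, 0.8, 0.81, 0.82, 0.83)):
    demo_engine = ignite.engine.Engine(lambda engine, batch: {})

    @demo_engine.on(EPOCH_COMPLETED)
    def _fake_validation(engine):
        engine.state.metrics["loss"] = fake_losses[engine.state.epoch - 1]

    demo_engine.add_event_handler(
        EPOCH_COMPLETED,
        ignite.handlers.EarlyStopping(
            patience=2,
            score_function=lambda engine: -engine.state.metrics["loss"],  # lower loss = higher score
            trainer=demo_engine
        )
    )
    # Terminates once the loss has failed to improve for `patience` epochs.
    demo_engine.run([0], max_epochs=len(fake_losses))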


@logged
class LRScheduler(TrainingPlugin):

    r"""
    Reduce learning rate on loss plateau

    Parameters
    ----------
    *optims
        Optimizers
    monitor
        Loss to monitor
    patience
        Patience to reduce learning rate
    burnin
        Burn-in epochs to skip before initializing learning rate scheduling
    """

    def __init__(
            self, *optims: torch.optim.Optimizer, monitor: str = None,
            patience: int = None, burnin: int = 0
    ) -> None:
        super().__init__()
        if monitor is None:
            raise ValueError("`monitor` must be specified!")
        self.monitor = monitor
        if patience is None:
            raise ValueError("`patience` must be specified!")
        self.schedulers = [
            ReduceLROnPlateau(optim, patience=patience, verbose=True)
            for optim in optims
        ]
        self.burnin = burnin

    def attach(
            self, net: torch.nn.Module, trainer: Trainer,
            train_engine: ignite.engine.Engine,
            val_engine: ignite.engine.Engine,
            train_loader: Iterable,
            val_loader: Optional[Iterable],
            directory: pathlib.Path
    ) -> None:
        score_engine = val_engine if val_engine else train_engine
        event_filter = lambda engine, event: event > self.burnin
        for scheduler in self.schedulers:
            scheduler.last_epoch = self.burnin
        train_engine.state.n_lrs = 0

        @train_engine.on(EPOCH_COMPLETED(event_filter=event_filter))  # pylint: disable=not-callable
        def _():
            update_flags = set()
            for scheduler in self.schedulers:
                old_lr = scheduler.optimizer.param_groups[0]["lr"]
                scheduler.step(score_engine.state.metrics[self.monitor])
                new_lr = scheduler.optimizer.param_groups[0]["lr"]
                update_flags.add(new_lr != old_lr)
            if len(update_flags) != 1:
                raise RuntimeError("Learning rates are out of sync!")
            if update_flags.pop():
                train_engine.state.n_lrs += 1
                self.logger.info("Learning rate reduction: step %d", train_engine.state.n_lrs)
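
# --------------------------------------------------------------------------
# Illustrative sketch only, not part of scglue.models.plugins: the plugin
# above relies on torch's ReduceLROnPlateau, which multiplies the learning
# rate by `factor` once the monitored value has not improved for `patience`
# epochs. Comparing the learning rate before and after `step()` is the same
# check `attach` uses to detect a reduction and increment
# `train_engine.state.n_lrs`, which EarlyStopping can wait on via
# `wait_n_lrs`. The fake loss sequence below is an assumption made purely for
# demonstration.
def _demo_reduce_lr_on_plateau(fake_losses=(1.0, 0.9, 0.9, 0.9, 0.9, 0.9)):
    param = torch.nn.Parameter(torch.zeros(1))
    optim = torch.optim.Adam([param], lr=1e-3)
    scheduler = ReduceLROnPlateau(optim, patience=2, factor=0.1)
    for epoch, fake_loss in enumerate(fake_losses, start=1):
        old_lr = optim.param_groups[0]["lr"]
        scheduler.step(fake_loss)
        new_lr = optim.param_groups[0]["lr"]
        if new_lr != old_lr:  # a learning rate reduction event
            print(f"epoch {epoch}: lr reduced {old_lr:g} -> {new_lr:g}")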