scglue.models.base.Trainer

class scglue.models.base.Trainer(net)

Bases: object

Abstract trainer class

Parameters:

net (Module) – Network module to be trained

Note

Subclasses should populate required_losses and define their optimizers in their constructor.
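The subclassing pattern described in the note might look roughly like the sketch below. It is a toy illustration in plain Python, not scglue's implementation: the Trainer class here is a stand-in that mirrors only the documented constructor, ToyNet and ToySGD are hypothetical replacements for a real network module and optimizer, and the train_step signature and the idea that required_losses holds the names of the reported losses are assumptions.

```python
from typing import Dict, List, Tuple


class ToyNet:
    """Hypothetical stand-in for a network module: y = w * x."""

    def __init__(self) -> None:
        self.w = 0.0

    def __call__(self, x: float) -> float:
        return self.w * x


class ToySGD:
    """Hypothetical stand-in for an optimizer: gradient descent on ToyNet.w."""

    def __init__(self, net: ToyNet, lr: float) -> None:
        self.net = net
        self.lr = lr

    def step(self, grad: float) -> None:
        self.net.w -= self.lr * grad


class Trainer:
    """Stand-in base class mirroring the documented constructor only."""

    def __init__(self, net) -> None:
        self.net = net
        self.required_losses: List[str] = []  # left empty for subclasses

    def train_step(self, data) -> Dict[str, float]:
        # Assumed signature; the real method may take other arguments.
        raise NotImplementedError


class ToyTrainer(Trainer):
    """Subclass that populates required_losses and defines its optimizer."""

    def __init__(self, net: ToyNet, lr: float = 0.1) -> None:
        super().__init__(net)
        self.required_losses = ["mse"]  # names of the losses each step reports
        self.optim = ToySGD(net, lr)    # optimizer defined in the constructor

    def train_step(self, data: Tuple[float, float]) -> Dict[str, float]:
        x, y = data
        err = self.net(x) - y
        self.optim.step(2.0 * err * x)  # gradient of (w * x - y) ** 2 w.r.t. w
        return {"mse": err * err}       # loss measured before the update


trainer = ToyTrainer(ToyNet())
losses = [trainer.train_step((1.0, 2.0)) for _ in range(50)]
```

Each step returns a dict keyed by the names in required_losses, so reporting code can rely on those keys being present. In a real subclass the optimizer would typically come from the deep learning framework that the network module is built on.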

Methods

fit

Fit network

get_losses

Get loss values for given data

load_state_dict

Load state from a state dict

report_metrics

Report loss values during training

state_dict

Get the state dict

train_step

A single training step

val_step

A single validation step

Attributes

logger