scglue.models.base.Trainer

class scglue.models.base.Trainer(net)

Bases: object

Abstract trainer class

Parameters:

net (Module) – Network module to be trained

Notes

Subclasses should populate required_losses and additionally define their optimizers here, in the constructor.
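The subclass contract in the note above can be illustrated with a minimal, framework-free sketch. BaseTrainer, ScalarNet, SGD, and ToyTrainer are all hypothetical stand-ins (no PyTorch and no scglue import). They only show where required_losses and an optimizer would be set up, and are not scglue's actual implementation.

```python
from typing import Dict, List, Tuple


class BaseTrainer:
    """Hypothetical stand-in for an abstract trainer base class."""

    def __init__(self, net) -> None:
        self.net = net
        # Left empty here; subclasses are expected to fill it in.
        self.required_losses: List[str] = []

    def train_step(self, data) -> Dict[str, float]:
        raise NotImplementedError


class ScalarNet:
    """Hypothetical 'network' with a single trainable weight."""

    def __init__(self, weight: float = 0.0) -> None:
        self.weight = weight


class SGD:
    """Hypothetical optimizer: plain gradient descent on one weight."""

    def __init__(self, net: ScalarNet, lr: float) -> None:
        self.net = net
        self.lr = lr

    def step(self, grad: float) -> None:
        self.net.weight -= self.lr * grad


class ToyTrainer(BaseTrainer):
    def __init__(self, net: ScalarNet, lr: float = 0.1) -> None:
        super().__init__(net)
        # Populate required_losses and define the optimizer in the constructor.
        self.required_losses = ["mse"]
        self.optimizer = SGD(net, lr)

    def train_step(self, data: List[Tuple[float, float]]) -> Dict[str, float]:
        # data: (x, y) pairs; the model predicts y_hat = weight * x.
        n = len(data)
        w = self.net.weight
        loss = sum((w * x - y) ** 2 for x, y in data) / n
        grad = sum(2 * (w * x - y) * x for x, y in data) / n
        self.optimizer.step(grad)
        return {"mse": loss}
```

For example, `ToyTrainer(ScalarNet()).train_step([(1.0, 2.0), (2.0, 4.0)])` returns `{"mse": 10.0}` and moves the weight from 0.0 to 1.0. Keying the returned losses by the names in required_losses lets a generic base class check or report them without knowing the subclass's details.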

Methods

fit

Train the model

get_losses

Get loss values for given data

load_state_dict

Load state from a state dict

report_metrics

Report loss values during training

state_dict

State dict

train_step

A single training step

val_step

A single validation step
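One plausible way the methods listed above cooperate is sketched below: fit loops over epochs, calls train_step on each training batch, uses get_losses (built on val_step) for validation, and hands both to report_metrics, while state_dict/load_state_dict round-trip the trainable state. ToyLineTrainer and all of its signatures are hypothetical assumptions for illustration; this is not scglue's implementation, which may differ substantially.

```python
from typing import Dict, Iterable, List, Mapping, Optional, Tuple

Batch = List[Tuple[float, float]]


class ToyLineTrainer:
    """Hypothetical trainer fitting y = w * x by gradient descent."""

    def __init__(self, lr: float = 0.1) -> None:
        self.weight = 0.0
        self.lr = lr
        self.history: List[Dict[str, float]] = []

    def _loss_and_grad(self, batch: Batch) -> Tuple[float, float]:
        n = len(batch)
        loss = sum((self.weight * x - y) ** 2 for x, y in batch) / n
        grad = sum(2 * (self.weight * x - y) * x for x, y in batch) / n
        return loss, grad

    def train_step(self, batch: Batch) -> Dict[str, float]:
        # Compute the loss and gradient, then take one descent step.
        loss, grad = self._loss_and_grad(batch)
        self.weight -= self.lr * grad
        return {"loss": loss}

    def val_step(self, batch: Batch) -> Dict[str, float]:
        # Evaluation only: parameters are not updated.
        loss, _ = self._loss_and_grad(batch)
        return {"loss": loss}

    def get_losses(self, batches: Iterable[Batch]) -> Dict[str, float]:
        # Average the validation loss over all batches.
        values = [self.val_step(b)["loss"] for b in batches]
        return {"loss": sum(values) / len(values)}

    def report_metrics(self, epoch: int, train: Mapping[str, float],
                       val: Optional[Mapping[str, float]]) -> None:
        # Record per-epoch losses (a real trainer might log them instead).
        record = {"epoch": epoch, "train_loss": train["loss"]}
        if val is not None:
            record["val_loss"] = val["loss"]
        self.history.append(record)

    def fit(self, train_batches: List[Batch],
            val_batches: Optional[List[Batch]] = None,
            max_epochs: int = 10) -> None:
        for epoch in range(max_epochs):
            losses = [self.train_step(b)["loss"] for b in train_batches]
            train = {"loss": sum(losses) / len(losses)}
            val = self.get_losses(val_batches) if val_batches else None
            self.report_metrics(epoch, train, val)

    def state_dict(self) -> Dict[str, float]:
        return {"weight": self.weight}

    def load_state_dict(self, state: Mapping[str, float]) -> None:
        self.weight = state["weight"]
```

With `data = [[(1.0, 2.0), (2.0, 4.0)]]`, calling `trainer.fit(data, data, max_epochs=50)` drives the weight toward 2.0, and `load_state_dict(trainer.state_dict())` on a fresh instance restores the same weight. Keeping train_step and val_step separate means validation never mutates parameters, which is why fit can call both within the same epoch.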

Attributes

logger