scglue.models.scclue.SCCLUETrainer

class scglue.models.scclue.SCCLUETrainer(net, lam_data=None, lam_kl=None, lam_align=None, lam_sup=None, lam_joint_cross=None, lam_real_cross=None, lam_cos=None, normalize_u=None, modality_weight=None, optim=None, lr=None, **kwargs)[source]

Bases: Trainer

Methods

compute_losses

Return type:

typing.Mapping[str, torch.Tensor]
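compute_losses returns a mapping from loss names to tensors, and the constructor exposes several lam_* weights. A common way to use such weights is to combine the named terms into a single weighted-sum objective. The sketch below shows that general pattern with plain floats in place of torch.Tensor. The helper combine_losses, the key names, and the assumption that the total is a weighted sum are all hypothetical and not taken from scglue.

```python
from typing import Mapping


def combine_losses(
    losses: Mapping[str, float], weights: Mapping[str, float]
) -> float:
    """Weighted sum of named loss terms (hypothetical helper, not part of scglue).

    Terms without an explicit weight default to 1.0.
    """
    return sum(weights.get(name, 1.0) * value for name, value in losses.items())


# Hypothetical component names, loosely mirroring the lam_* constructor arguments.
losses = {"data": 2.0, "kl": 0.5, "align": 1.0}
weights = {"data": 1.0, "kl": 0.2, "align": 0.05}
total = combine_losses(losses, weights)
print(total)
```

Defaulting missing weights to 1.0 keeps unweighted terms in the objective rather than silently dropping them.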

fit

Train the model

format_data

Return type:

typing.Tuple[typing.Mapping[str, torch.Tensor], typing.Mapping[str, torch.Tensor], typing.Mapping[str, torch.Tensor], typing.Mapping[str, torch.Tensor], typing.Mapping[str, torch.Tensor], typing.Mapping[str, torch.Tensor], torch.Tensor]

load_state_dict

Load state from a state dict

state_dict

State dict
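state_dict and load_state_dict follow the familiar PyTorch convention of capturing an object's state as a dictionary and restoring it later, for example when checkpointing. The sketch below illustrates only that round-trip pattern. ToyTrainer and its lr/epoch fields are hypothetical and do not reflect what SCCLUETrainer actually stores.

```python
import copy
from typing import Any, Dict


class ToyTrainer:
    """Hypothetical stand-in illustrating the state_dict/load_state_dict pattern."""

    def __init__(self, lr: float) -> None:
        self.lr = lr
        self.epoch = 0

    def state_dict(self) -> Dict[str, Any]:
        # Return a copy so later mutations of the trainer do not leak into the snapshot.
        return copy.deepcopy({"lr": self.lr, "epoch": self.epoch})

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Restore every field captured by state_dict().
        self.lr = state_dict["lr"]
        self.epoch = state_dict["epoch"]


trainer = ToyTrainer(lr=2e-3)
trainer.epoch = 5
snapshot = trainer.state_dict()

restored = ToyTrainer(lr=1e-3)
restored.load_state_dict(snapshot)
print(restored.lr, restored.epoch)  # 0.002 5
```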

train_step

A single training step

val_step

A single validation step
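train_step and val_step usually differ in one respect: a training step updates parameters after computing the losses, while a validation step only evaluates them. The sketch below shows that distinction on a toy one-parameter regression using only the standard library. ToyRegressionTrainer is hypothetical, and it reuses the method names from this page purely for illustration. It is not scglue's implementation.

```python
from typing import Dict, List, Tuple

Batch = List[Tuple[float, float]]


class ToyRegressionTrainer:
    """Hypothetical trainer fitting y = w * x by gradient descent (not scglue code)."""

    def __init__(self, lr: float = 0.1) -> None:
        self.w = 0.0
        self.lr = lr

    def compute_losses(self, batch: Batch) -> Dict[str, float]:
        # Mean squared error over the batch.
        mse = sum((self.w * x - y) ** 2 for x, y in batch) / len(batch)
        return {"mse": mse}

    def train_step(self, batch: Batch) -> Dict[str, float]:
        # Compute losses, then take one gradient step on w.
        losses = self.compute_losses(batch)
        grad = sum(2 * (self.w * x - y) * x for x, y in batch) / len(batch)
        self.w -= self.lr * grad
        return losses

    def val_step(self, batch: Batch) -> Dict[str, float]:
        # Evaluate only; parameters are left unchanged.
        return self.compute_losses(batch)


data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]  # generated from y = 2x
trainer = ToyRegressionTrainer(lr=0.05)
for _ in range(200):
    trainer.train_step(data)
w_before = trainer.w
val = trainer.val_step(data)
assert trainer.w == w_before  # val_step did not update w
print(round(trainer.w, 3), val["mse"] < 1e-6)  # 2.0 True
```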

Attributes

logger