scglue.models.scglue.PairedSCGLUETrainer
- class scglue.models.scglue.PairedSCGLUETrainer(net, lam_data=None, lam_kl=None, lam_graph=None, lam_align=None, lam_sup=None, lam_joint_cross=None, lam_real_cross=None, lam_cos=None, normalize_u=None, modality_weight=None, optim=None, lr=None, **kwargs)
Paired trainer for SCGLUE.
- Parameters:
    - net (scglue.models.scglue.SCGLUE) – SCGLUE network to be trained
    - lam_data (typing.Optional[float]) – Data weight
    - lam_kl (typing.Optional[float]) – KL divergence weight
    - lam_graph (typing.Optional[float]) – Graph weight
    - lam_align (typing.Optional[float]) – Adversarial alignment weight
    - lam_sup (typing.Optional[float]) – Cell type supervision weight
    - lam_joint_cross (typing.Optional[float]) – Joint cross-prediction weight
    - lam_real_cross (typing.Optional[float]) – Real cross-prediction weight
    - lam_cos (typing.Optional[float]) – Cosine similarity weight
    - normalize_u (typing.Optional[bool]) – Whether to L2-normalize cell embeddings before passing them to the decoder
    - modality_weight (typing.Optional[typing.Mapping[str, float]]) – Relative modality weight (indexed by modality name)
    - optim (typing.Optional[str]) – Optimizer
    - lr (typing.Optional[float]) – Learning rate
    - **kwargs – Additional keyword arguments, passed to the optimizer constructor
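Two of the parameters above are easy to misread, so here is a minimal pure-Python sketch of what they are described as doing: scaling a cell embedding to unit L2 norm (normalize_u) and combining per-modality losses with relative weights keyed by modality name (modality_weight). This is not scglue's implementation. The helpers l2_normalize and weighted_modality_loss are hypothetical, the modality names "rna" and "atac" are only examples, and dividing by the summed weights (so that only the ratios matter) is an assumption.

```python
import math
from typing import List, Mapping, Sequence


def l2_normalize(u: Sequence[float], eps: float = 1e-12) -> List[float]:
    # Divide every component by the vector's Euclidean (L2) norm; eps guards
    # against division by zero for an all-zero embedding.
    norm = math.sqrt(sum(x * x for x in u))
    return [x / max(norm, eps) for x in u]


def weighted_modality_loss(
    losses: Mapping[str, float], modality_weight: Mapping[str, float]
) -> float:
    # Weighted mean of per-modality losses. Dividing by the summed weights is an
    # assumption that makes only the relative weights matter.
    total = sum(modality_weight[name] for name in losses)
    return sum(modality_weight[name] * loss for name, loss in losses.items()) / total


# Illustrative values only; "rna" and "atac" are example modality names.
print(l2_normalize([3.0, 4.0]))  # → [0.6, 0.8]
print(weighted_modality_loss({"rna": 2.0, "atac": 4.0}, {"rna": 1.0, "atac": 3.0}))  # → 3.5
```

With weights {"rna": 1.0, "atac": 3.0}, the ATAC loss counts three times as much as the RNA loss; scaling both weights by the same factor would leave the result unchanged under this normalization.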
Methods
- Compute loss functions
- Format data tensors. Return type: typing.Tuple[typing.Mapping[str, torch.Tensor], typing.Mapping[str, torch.Tensor], typing.Mapping[str, torch.Tensor], typing.Mapping[str, torch.Tensor], typing.Mapping[str, torch.Tensor], typing.Mapping[str, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor] (six str-to-tensor mappings followed by three tensors)

Attributes
logger
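The loss computation is documented above only as "Compute loss functions". As a rough mental model of how the lam_* weights could enter it, here is a minimal sketch assuming the generator objective is a weighted sum of the individual terms, with the adversarial alignment term subtracted because the encoder is trained to fool the discriminator. The LossWeights, generator_loss and discriminator_loss names and the term keys are hypothetical, not scglue's API, and the real method may differ in detail.

```python
from dataclasses import dataclass
from typing import Mapping


@dataclass
class LossWeights:
    # Hypothetical container mirroring the lam_* constructor arguments.
    lam_data: float
    lam_kl: float
    lam_graph: float
    lam_align: float
    lam_sup: float
    lam_joint_cross: float
    lam_real_cross: float
    lam_cos: float


def generator_loss(terms: Mapping[str, float], w: LossWeights) -> float:
    # Assumed form: weighted sum of the individual terms. The adversarial
    # alignment term is subtracted because the encoder is trained to increase
    # the discriminator's loss (making modalities hard to tell apart).
    return (
        w.lam_data * terms["data"]
        + w.lam_kl * terms["kl"]
        + w.lam_graph * terms["graph"]
        + w.lam_sup * terms["sup"]
        + w.lam_joint_cross * terms["joint_cross"]
        + w.lam_real_cross * terms["real_cross"]
        + w.lam_cos * terms["cos"]
        - w.lam_align * terms["align"]
    )


def discriminator_loss(terms: Mapping[str, float], w: LossWeights) -> float:
    # The discriminator itself minimizes the weighted alignment loss.
    return w.lam_align * terms["align"]
```

The weights and term values you pass in are arbitrary here; the sketch says nothing about scglue's defaults. Subtracting the alignment term in one objective while minimizing it in the other is the usual minimax setup for adversarial alignment, which is why a larger lam_align pushes the embeddings of different modalities closer together.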