scglue.models.glue.GLUETrainer.fit

GLUETrainer.fit(data, graph, val_split=None, data_batch_size=None, graph_batch_size=None, align_burnin=None, safe_burnin=True, max_epochs=None, patience=None, reduce_lr_patience=None, wait_n_lrs=None, random_seed=None, directory=None, plugins=None)

Fit the network.

Parameters
  • data (ArrayDataset) – Data dataset

  • graph (GraphDataset) – Graph dataset

  • val_split (Optional[float]) – Validation split

  • data_batch_size (Optional[int]) – Number of samples in each data minibatch

  • graph_batch_size (Optional[int]) – Number of edges in each graph minibatch

  • align_burnin (Optional[int]) – Number of epochs to wait before starting alignment

  • safe_burnin (bool) – Whether to postpone learning rate scheduling and early stopping until after the burn-in stage

  • max_epochs (Optional[int]) – Maximum number of epochs

  • patience (Optional[int]) – Patience of early stopping

  • reduce_lr_patience (Optional[int]) – Patience to reduce learning rate

  • wait_n_lrs (Optional[int]) – Number of learning rate scheduling events to wait for before early stopping starts

  • random_seed (Optional[int]) – Random seed

  • directory (Optional[PathLike]) – Directory to store checkpoints and tensorboard logs

  • plugins (Optional[List[TrainingPlugin]]) – Optional list of training plugins

Return type

None
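
Example

The parameters align_burnin, safe_burnin, patience, reduce_lr_patience and wait_n_lrs all shape when training slows down or stops, and their interplay can be hard to picture from the descriptions alone. The following is a minimal, library-independent sketch of one plausible reading of those descriptions. It is not the scglue implementation: the function name simulate, the lr and factor arguments, and the exact counter logic are assumptions made for illustration only.

```python
from typing import Optional, Sequence, Tuple


def simulate(
    val_losses: Sequence[float],
    align_burnin: int = 0,
    safe_burnin: bool = True,
    patience: int = 3,
    reduce_lr_patience: int = 2,
    wait_n_lrs: int = 1,
    lr: float = 1e-3,       # hypothetical starting learning rate
    factor: float = 0.1,    # hypothetical LR reduction factor
) -> Tuple[Optional[int], float, int]:
    """Replay a list of per-epoch validation losses.

    Returns (stop_epoch, final_lr, n_lr_reductions); stop_epoch is None
    if early stopping never triggers. Illustrative only.
    """
    best = float("inf")
    lr_wait = 0    # epochs without improvement since the last LR event
    stop_wait = 0  # epochs without improvement since the best loss
    n_lrs = 0      # learning rate scheduling events so far

    for epoch, loss in enumerate(val_losses):
        # With safe_burnin, scheduling and early stopping are postponed
        # until the burn-in stage is over.
        if safe_burnin and epoch < align_burnin:
            continue

        if loss < best:
            best = loss
            lr_wait = stop_wait = 0
            continue

        lr_wait += 1
        stop_wait += 1

        # Reduce the learning rate after reduce_lr_patience stale epochs.
        if lr_wait >= reduce_lr_patience:
            lr *= factor
            n_lrs += 1
            lr_wait = 0

        # Early stopping only becomes active after wait_n_lrs LR events.
        if n_lrs >= wait_n_lrs and stop_wait >= patience:
            return epoch, lr, n_lrs

    return None, lr, n_lrs


if __name__ == "__main__":
    losses = [1.0, 9.0, 9.0, 9.0, 9.0, 3.0, 2.0]
    print(simulate(losses, align_burnin=5, safe_burnin=True))
    print(simulate(losses, align_burnin=5, safe_burnin=False))
```

In this sketch, a loss spike during the burn-in stage does not trigger early stopping when safe_burnin is True, whereas with safe_burnin set to False the same spike ends the run before alignment starts. Consult the library source for the actual behavior.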