, graph, edge_weight='weight', edge_sign='sign', neg_samples=10, val_split=0.1, data_batch_size=128, graph_batch_size=-1, align_burnin=-1, safe_burnin=True, max_epochs=-1, patience=-1, reduce_lr_patience=-1, wait_n_lrs=1, directory=None)

Fit model on given datasets

  • adatas (Mapping[str, AnnData]) – Datasets (indexed by domain name)

  • graph (Graph) – Prior graph

  • edge_weight (str) – Key of edge attribute for edge weight

  • edge_sign (str) – Key of edge attribute for edge sign

  • neg_samples (int) – Number of negative samples for each edge

  • val_split (float) – Validation split

  • data_batch_size (int) – Number of cells in each data minibatch

  • graph_batch_size (int) – Number of edges in each graph minibatch

  • align_burnin (int) – Number of epochs to wait before starting alignment

  • safe_burnin (bool) – Whether to postpone learning rate scheduling and early stopping until after the burn-in stage

  • max_epochs (int) – Maximum number of epochs

  • patience (Optional[int]) – Patience of early stopping

  • reduce_lr_patience (Optional[int]) – Patience to reduce learning rate

  • wait_n_lrs (int) – Number of learning rate scheduling events to wait for before early stopping is enabled

  • directory (Optional[PathLike]) – Directory to store checkpoints and tensorboard logs
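The parameters align_burnin, safe_burnin, patience, reduce_lr_patience and wait_n_lrs interact, and a toy simulation may make one plausible reading of their descriptions easier to follow. The sketch below is not the library's implementation: the function simulate_schedule, the lr_factor decay value and the exact counting rules are assumptions made purely for illustration, and the special value -1 used by the defaults is not modelled.

```python
from typing import List, Optional, Tuple


def simulate_schedule(
    val_losses: List[float],
    align_burnin: int = 0,
    safe_burnin: bool = True,
    patience: Optional[int] = 3,
    reduce_lr_patience: Optional[int] = 2,
    wait_n_lrs: int = 1,
    lr: float = 1e-3,
    lr_factor: float = 0.1,  # hypothetical decay factor, not taken from the docs
) -> Tuple[int, float, int]:
    """Return (epochs_run, final_lr, n_lr_reductions) for a toy loss curve."""
    best = float("inf")
    stale_lr = 0    # epochs without improvement since the last LR event
    stale_stop = 0  # epochs without improvement since the last best loss
    n_reductions = 0
    epochs_run = 0
    for epoch, loss in enumerate(val_losses):
        epochs_run = epoch + 1
        # With safe_burnin, scheduling and stopping ignore the burn-in epochs.
        if safe_burnin and epoch < align_burnin:
            continue
        if loss < best:
            best = loss
            stale_lr = 0
            stale_stop = 0
            continue
        stale_lr += 1
        stale_stop += 1
        # Reduce the learning rate after reduce_lr_patience stale epochs.
        if reduce_lr_patience is not None and stale_lr >= reduce_lr_patience:
            lr *= lr_factor
            n_reductions += 1
            stale_lr = 0
        # Early stopping is only armed after wait_n_lrs LR reductions.
        if (
            patience is not None
            and n_reductions >= wait_n_lrs
            and stale_stop >= patience
        ):
            break
    return epochs_run, lr, n_reductions


# A loss curve that plateaus after the third epoch.
losses = [1.0, 0.9, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8]
print(simulate_schedule(losses))
print(simulate_schedule(losses, align_burnin=4))
```

Under these assumptions, a longer burn-in or a larger wait_n_lrs delays the point at which training stops on the same plateau, which is the practical effect to keep in mind when tuning these arguments together.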

Return type