r"""
Data handling utilities
"""
import copy
import functools
import multiprocessing
import operator
import os
import queue
import signal
import uuid
from math import ceil
from typing import Any, List, Mapping, Optional, Tuple
import h5py
import networkx as nx
import numpy as np
import pandas as pd
import scipy.sparse
import torch
from anndata import AnnData
from anndata._core.sparse_dataset import SparseDataset
from ..num import vertex_degrees
from ..typehint import AnyArray, Array, RandomState
from ..utils import config, get_rs, logged, processes
from .nn import get_default_numpy_dtype
# Type alias for a per-dataset configuration: a mapping indexed by string keys
# (e.g. "features", "use_layer", "use_rep" as read by ``AnnDataset`` below)
DATA_CONFIG = Mapping[str, Any]
#---------------------------------- Datasets -----------------------------------
@logged
class Dataset(torch.utils.data.Dataset):

    r"""
    Abstract dataset interface extending that of :class:`torch.utils.data.Dataset`,
    adding seeded custom shuffling that can optionally be computed ahead of
    time by background worker processes.

    Subclasses are expected to implement :meth:`propose_shuffle` and
    :meth:`accept_shuffle`.

    Parameters
    ----------
    getitem_size
        Unitary fetch size for each __getitem__ call
    """

    def __init__(self, getitem_size: int = 1) -> None:
        super().__init__()
        self.getitem_size = getitem_size
        # Seed to be consumed by the next `shuffle` call
        self.shuffle_seed: Optional[int] = None
        # Main process -> workers: seeds to work on (``None`` asks a worker to exit)
        self.seed_queue: Optional[multiprocessing.Queue] = None
        # Workers -> main process: ``(seed, shuffled)`` or ``(None, pid)`` on exit
        self.propose_queue: Optional[multiprocessing.Queue] = None
        # Proposals received from workers that have not been accepted yet
        self.propose_cache: Mapping[int, Any] = {}

    @property
    def has_workers(self) -> bool:
        r"""
        Whether background shuffling workers have been registered

        Raises
        ------
        RuntimeError
            If the process registry and the two queues disagree on whether
            workers exist
        """
        flags = (
            bool(processes[id(self)]),
            self.seed_queue is not None,
            self.propose_queue is not None,
        )
        if len(set(flags)) != 1:  # All three must agree
            raise RuntimeError("Background shuffling seems broken!")
        return flags[0]

    def prepare_shuffle(self, num_workers: int = 1, random_seed: int = 0) -> None:
        r"""
        Prepare dataset for custom shuffling

        Any previously registered workers are cleaned up first.

        Parameters
        ----------
        num_workers
            Number of background workers for data shuffling
        random_seed
            Initial random seed (will increase by 1 with every shuffle call)
        """
        if self.has_workers:
            self.clean()
        self.shuffle_seed = random_seed
        if not num_workers:
            return
        registry = processes[id(self)]
        self.seed_queue = multiprocessing.Queue()
        self.propose_queue = multiprocessing.Queue()
        for offset in range(num_workers):
            worker = multiprocessing.Process(target=self.shuffle_worker)
            worker.start()
            self.logger.debug("Started background process: %d", worker.pid)
            registry[worker.pid] = worker
            # One seed per worker so the first proposals are computed up front
            self.seed_queue.put(self.shuffle_seed + offset)

    def shuffle(self) -> None:
        r"""
        Custom shuffling

        With background workers, a new seed is queued for future use and the
        proposal matching the current seed is fetched (waiting for it if
        necessary). Without workers, the proposal is computed in place.
        The seed is incremented by 1 afterwards.
        """
        if self.has_workers:
            lookahead = self.shuffle_seed + len(processes[id(self)])
            self.seed_queue.put(lookahead)  # Look ahead
            # Proposals may arrive out of order, so park them in the cache
            while self.shuffle_seed not in self.propose_cache:
                seed, proposal = self.propose_queue.get()
                self.propose_cache[seed] = proposal
            shuffled = self.propose_cache.pop(self.shuffle_seed)
        else:
            shuffled = self.propose_shuffle(self.shuffle_seed)
        self.accept_shuffle(shuffled)
        self.shuffle_seed += 1

    def shuffle_worker(self) -> None:
        r"""
        Background shuffle worker

        Ignores SIGINT, then keeps turning seeds from the seed queue into
        proposals until a ``None`` seed is received, at which point the
        worker reports its own pid and returns.
        """
        signal.signal(signal.SIGINT, signal.SIG_IGN)
        for seed in iter(self.seed_queue.get, None):
            self.propose_queue.put((seed, self.propose_shuffle(seed)))
        self.propose_queue.put((None, os.getpid()))

    def propose_shuffle(self, seed: int) -> Any:
        r"""
        Propose shuffling using a given random seed

        Parameters
        ----------
        seed
            Random seed

        Returns
        -------
        shuffled
            Shuffled result
        """
        raise NotImplementedError  # pragma: no cover

    def accept_shuffle(self, shuffled: Any) -> None:
        r"""
        Accept shuffling result

        Parameters
        ----------
        shuffled
            Shuffled result
        """
        raise NotImplementedError  # pragma: no cover

    def clean(self) -> None:
        r"""
        Clean up multi-process resources used in custom shuffling

        Every worker is asked to exit and joined once it reports back.
        Workers that do not report within
        ``config.FORCE_TERMINATE_WORKER_PATIENCE`` are terminated.
        """
        registry = processes[id(self)]
        if not self.has_workers:
            return
        for _ in registry:
            self.seed_queue.put(None)  # One exit request per worker
        self.propose_cache.clear()
        while registry:
            try:
                head, payload = self.propose_queue.get(
                    timeout=config.FORCE_TERMINATE_WORKER_PATIENCE
                )
            except queue.Empty:
                break
            if head is not None:
                continue  # A leftover proposal, not an exit notice
            pid = payload
            registry[pid].join()
            self.logger.debug("Joined background process: %d", pid)
            del registry[pid]
        # If some background processes failed to exit gracefully
        for pid, worker in list(registry.items()):
            worker.terminate()
            worker.join()
            self.logger.debug("Terminated background process: %d", pid)
            del registry[pid]
        self.propose_queue = None
        self.seed_queue = None

    def __del__(self) -> None:
        self.clean()
@logged
class ArrayDataset(Dataset):

    r"""
    Array dataset for :class:`numpy.ndarray` and :class:`scipy.sparse.spmatrix`
    objects. Different arrays are considered as unpaired, and thus do not need
    to have identical sizes in the first dimension. Smaller arrays are recycled.
    Also, data fetched from this dataset are automatically densified.

    Parameters
    ----------
    *arrays
        An arbitrary number of data arrays
    getitem_size
        Unitary fetch size for each __getitem__ call

    Note
    ----
    We keep using arrays because sparse tensors do not support slicing.
    Arrays are only converted to tensors after minibatch slicing.
    """

    def __init__(self, *arrays: Array, getitem_size: int = 1) -> None:
        super().__init__(getitem_size=getitem_size)
        self.sizes = None
        self.size = None
        self.view_idx = None
        self.shuffle_idx = None
        self.arrays = arrays  # Goes through the setter, which fills the above

    @property
    def arrays(self) -> List[Array]:
        r"""
        Internal array objects
        """
        return self._arrays

    @arrays.setter
    def arrays(self, arrays: List[Array]) -> None:
        self.sizes = [array.shape[0] for array in arrays]
        if min(self.sizes) == 0:
            raise ValueError("Empty array is not allowed!")
        self.size = max(self.sizes)  # Dataset size follows the largest array
        self.view_idx = [np.arange(s) for s in self.sizes]
        self.shuffle_idx = self.view_idx
        self._arrays = arrays

    def __len__(self) -> int:
        return ceil(self.size / self.getitem_size)

    def __getitem__(self, index: int) -> List[torch.Tensor]:
        r"""
        Fetch the ``index``-th chunk of ``getitem_size`` rows from every array

        Returns
        -------
        items
            One dense tensor per array
        """
        index = np.arange(
            index * self.getitem_size,
            min((index + 1) * self.getitem_size, self.size)
        )
        items = []
        for i, a in enumerate(self.arrays):
            # Wrap around so that arrays smaller than `self.size` are recycled
            rows = a[self.shuffle_idx[i][np.mod(index, self.sizes[i])]]
            if scipy.sparse.issparse(a) or isinstance(a, SparseDataset):
                rows = rows.toarray()
            items.append(torch.as_tensor(rows))
        return items

    def propose_shuffle(self, seed: int) -> List[np.ndarray]:
        r"""
        Independently permute the row indices of every array
        """
        rs = get_rs(seed)
        return [rs.permutation(view_idx) for view_idx in self.view_idx]

    def accept_shuffle(self, shuffled: List[np.ndarray]) -> None:
        r"""
        Install permuted row indices produced by :meth:`propose_shuffle`
        """
        self.shuffle_idx = shuffled

    def random_split(
            self, fractions: List[float], random_state: RandomState = None
    ) -> List["ArrayDataset"]:
        r"""
        Randomly split the dataset into multiple subdatasets according to
        given fractions.

        Parameters
        ----------
        fractions
            Fraction of each split
        random_state
            Random state

        Returns
        -------
        subdatasets
            A list of splitted subdatasets

        Raises
        ------
        ValueError
            If any fraction is not positive, if the fractions do not sum to 1
            (within floating point tolerance), or if some array ends up with
            no rows in one of the splits
        """
        if min(fractions) <= 0:
            raise ValueError("Fractions should be greater than 0!")
        # Tolerant comparison: e.g. 0.1 + 0.2 + 0.7 != 1 in floating point
        if not np.isclose(sum(fractions), 1):
            raise ValueError("Fractions do not sum to 1!")
        rs = get_rs(random_state)
        cum_frac = np.cumsum(fractions)
        subdatasets = [
            ArrayDataset(
                *self.arrays, getitem_size=self.getitem_size
            ) for _ in fractions
        ]
        for j, view_idx in enumerate(self.view_idx):
            view_idx = rs.permutation(view_idx)
            split_pos = np.round(cum_frac * view_idx.size).astype(int)
            split_idx = np.split(view_idx, split_pos[:-1])  # Last pos produces an extra empty split
            for sub, idx in zip(subdatasets, split_idx):
                if len(idx) == 0:
                    # Would otherwise only fail later, when fetching items
                    raise ValueError("Empty split is not allowed!")
                sub.sizes[j] = len(idx)
                sub.view_idx[j] = idx
                sub.shuffle_idx[j] = idx
        for sub in subdatasets:
            # Keep `size` (and thus `len()`) in sync with the reduced arrays
            sub.size = max(sub.sizes)
        return subdatasets
@logged
class AnnDataset(Dataset):

    r"""
    Dataset for :class:`anndata.AnnData` objects with partial pairing support.

    Parameters
    ----------
    adatas
        A list of configured :class:`anndata.AnnData` objects
    data_configs
        Data configurations, one per dataset
    mode
        Data mode, must be one of ``{"train", "eval"}``
    getitem_size
        Unitary fetch size for each __getitem__ call
    """

    def __init__(
            self, adatas: List[AnnData], data_configs: List[DATA_CONFIG],
            mode: str = "train", getitem_size: int = 1
    ) -> None:
        super().__init__(getitem_size=getitem_size)
        if mode not in ("train", "eval"):
            raise ValueError("Invalid `mode`!")
        # Order matters: the `data_configs` setter reads `mode` and `adatas`
        self.mode = mode
        self.adatas = adatas
        self.data_configs = data_configs

    @property
    def adatas(self) -> List[AnnData]:
        r"""
        Internal :class:`AnnData` objects
        """
        return self._adatas

    @property
    def data_configs(self) -> List[DATA_CONFIG]:
        r"""
        Data configuration for each dataset
        """
        return self._data_configs

    @adatas.setter
    def adatas(self, adatas: List[AnnData]) -> None:
        self.sizes = [adata.shape[0] for adata in adatas]
        if min(self.sizes) == 0:
            raise ValueError("Empty dataset is not allowed!")
        self._adatas = adatas

    @data_configs.setter
    def data_configs(self, data_configs: List[DATA_CONFIG]) -> None:
        if len(data_configs) != len(self.adatas):
            raise ValueError(
                "Number of data configs must match "
                "the number of datasets!"
            )
        self.data_idx, self.extracted_data = self._extract_data(data_configs)
        # Union of cell IDs across datasets, each ID appearing once
        self.view_idx = pd.concat(
            [data_idx.to_series() for data_idx in self.data_idx]
        ).drop_duplicates().to_numpy()
        self.size = self.view_idx.size
        self.shuffle_idx, self.shuffle_pmsk = self._get_idx_pmsk(self.view_idx)
        self._data_configs = data_configs

    def _get_idx_pmsk(
            self, view_idx: np.ndarray, random_fill: bool = False,
            random_state: RandomState = None
    ) -> Tuple[np.ndarray, np.ndarray]:
        r"""
        Translate cell IDs into per-dataset row indices and a presence mask

        Parameters
        ----------
        view_idx
            Cell IDs to look up
        random_fill
            Whether IDs missing from a dataset are replaced by randomly drawn
            rows of that dataset (otherwise existing rows are reused cyclically)
        random_state
            Random state, only used if ``random_fill`` is set

        Returns
        -------
        shuffle_idx
            Row indices, one column per dataset
        shuffle_pmsk
            Boolean mask of the same shape, marking IDs actually present
            in each dataset
        """
        rs = get_rs(random_state) if random_fill else None
        shuffle_idx, shuffle_pmsk = [], []
        for data_idx in self.data_idx:
            idx = data_idx.get_indexer(view_idx)
            pmsk = idx >= 0  # Negative entries mark IDs absent from this dataset
            n_true = pmsk.sum()
            n_false = pmsk.size - n_true
            if random_fill:
                idx[~pmsk] = rs.choice(idx[pmsk], n_false, replace=True)
            else:
                idx[~pmsk] = idx[pmsk][np.mod(np.arange(n_false), n_true)]
            shuffle_idx.append(idx)
            shuffle_pmsk.append(pmsk)
        return np.stack(shuffle_idx, axis=1), np.stack(shuffle_pmsk, axis=1)

    def __len__(self) -> int:
        return ceil(self.size / self.getitem_size)

    def __getitem__(self, index: int) -> List[torch.Tensor]:
        r"""
        Fetch the ``index``-th chunk of ``getitem_size`` cells

        Returns
        -------
        items
            Tensors for every extracted data group and every dataset
            (group-major order), followed by the presence mask of the chunk
        """
        s = slice(
            index * self.getitem_size,
            min((index + 1) * self.getitem_size, self.size)
        )
        shuffle_idx = self.shuffle_idx[s].T  # One row of indices per dataset
        shuffle_pmsk = self.shuffle_pmsk[s]
        items = [
            torch.as_tensor(self._index_array(data, idx))
            for extracted_data in self.extracted_data
            for idx, data in zip(shuffle_idx, extracted_data)
        ]
        items.append(torch.as_tensor(shuffle_pmsk))
        return items

    @staticmethod
    def _index_array(arr: AnyArray, idx: np.ndarray) -> np.ndarray:
        r"""
        Fetch rows ``idx`` of ``arr`` as a dense array

        For backed arrays, rows are read with sorted, duplicate-free indices
        and expanded back to the requested order afterwards.
        """
        if isinstance(arr, (h5py.Dataset, SparseDataset)):
            # `np.unique` yields both the sorted unique indices and the
            # positions mapping them back onto `idx`, without relying on
            # `scipy.stats`, which this module never imports
            sorted_idx, rank = np.unique(idx, return_inverse=True)
            arr = arr[sorted_idx.tolist()][rank.tolist()]  # Convert to sequential access and back
        else:
            arr = arr[idx]
        return arr.toarray() if scipy.sparse.issparse(arr) else arr

    def _extract_data(self, data_configs: List[DATA_CONFIG]) -> Tuple[
            List[pd.Index], Tuple[
                List[AnyArray], List[AnyArray], List[AnyArray],
                List[AnyArray], List[AnyArray]
            ]
    ]:
        r"""
        Extract cell IDs and data arrays according to the current ``mode``
        """
        if self.mode == "eval":
            return self._extract_data_eval(data_configs)
        return self._extract_data_train(data_configs)  # self.mode == "train"

    def _extract_data_train(self, data_configs: List[DATA_CONFIG]) -> Tuple[
            List[pd.Index], Tuple[
                List[AnyArray], List[AnyArray], List[AnyArray],
                List[AnyArray], List[AnyArray]
            ]
    ]:
        r"""
        Extract cell IDs plus all data groups (``x``, ``xrep``, ``xbch``,
        ``xlbl``, ``xdwt``), each as a list with one entry per dataset
        """
        paired = list(zip(self.adatas, data_configs))
        xuid = [self._extract_xuid(adata, cfg) for adata, cfg in paired]
        x = [self._extract_x(adata, cfg) for adata, cfg in paired]
        xrep = [self._extract_xrep(adata, cfg) for adata, cfg in paired]
        xbch = [self._extract_xbch(adata, cfg) for adata, cfg in paired]
        xlbl = [self._extract_xlbl(adata, cfg) for adata, cfg in paired]
        xdwt = [self._extract_xdwt(adata, cfg) for adata, cfg in paired]
        return xuid, (x, xrep, xbch, xlbl, xdwt)

    def _extract_data_eval(self, data_configs: List[DATA_CONFIG]) -> Tuple[
            List[pd.Index], Tuple[
                List[AnyArray], List[AnyArray], List[AnyArray],
                List[AnyArray], List[AnyArray]
            ]
    ]:
        r"""
        Extract cell IDs and data for evaluation

        ``x`` is only extracted for datasets without a configured
        representation; batch, label and weight groups are zero-width
        placeholders.
        """
        default_dtype = get_default_numpy_dtype()
        paired = list(zip(self.adatas, data_configs))
        xuid = [self._extract_xuid(adata, cfg) for adata, cfg in paired]
        xrep = [self._extract_xrep(adata, cfg) for adata, cfg in paired]
        x = [
            np.empty((adata.shape[0], 0), dtype=default_dtype)
            if xrep_.size else self._extract_x(adata, cfg)
            for (adata, cfg), xrep_ in zip(paired, xrep)
        ]
        xbch = xlbl = [
            np.empty((adata.shape[0], 0), dtype=int)
            for adata in self.adatas
        ]
        xdwt = [
            np.empty((adata.shape[0], 0), dtype=default_dtype)
            for adata in self.adatas
        ]
        return xuid, (x, xrep, xbch, xlbl, xdwt)

    def _extract_x(self, adata: AnnData, data_config: DATA_CONFIG) -> AnyArray:
        r"""
        Extract the data matrix (``adata.X`` or the configured layer),
        restricted to the configured features and cast to the default dtype
        """
        default_dtype = get_default_numpy_dtype()
        features = data_config["features"]
        use_layer = data_config["use_layer"]
        if not np.array_equal(adata.var_names, features):
            adata = adata[:, features]  # This will load all data to memory if backed
        if use_layer:
            if use_layer not in adata.layers:
                raise ValueError(
                    f"Configured data layer '{use_layer}' "
                    f"cannot be found in input data!"
                )
            x = adata.layers[use_layer]
        else:
            x = adata.X
        if x.dtype.type is not default_dtype:
            if isinstance(x, (h5py.Dataset, SparseDataset)):
                raise RuntimeError(
                    f"User is responsible for ensuring a {default_dtype} dtype "
                    f"when using backed data!"
                )
            x = x.astype(default_dtype)
        if scipy.sparse.issparse(x):
            x = x.tocsr()
        return x

    def _extract_xrep(self, adata: AnnData, data_config: DATA_CONFIG) -> AnyArray:
        r"""
        Extract the configured ``obsm`` representation, or a zero-width array
        if none is configured
        """
        default_dtype = get_default_numpy_dtype()
        use_rep = data_config["use_rep"]
        rep_dim = data_config["rep_dim"]
        if use_rep:
            if use_rep not in adata.obsm:
                raise ValueError(
                    f"Configured data representation '{use_rep}' "
                    f"cannot be found in input data!"
                )
            xrep = np.asarray(adata.obsm[use_rep]).astype(default_dtype)
            if xrep.shape[1] != rep_dim:
                raise ValueError(
                    f"Input representation dimensionality {xrep.shape[1]} "
                    f"does not match the configured {rep_dim}!"
                )
            return xrep
        return np.empty((adata.shape[0], 0), dtype=default_dtype)

    def _extract_xbch(self, adata: AnnData, data_config: DATA_CONFIG) -> AnyArray:
        r"""
        Extract batch indices (positions within the configured ``batches``),
        or all zeros if no batch column is configured
        """
        use_batch = data_config["use_batch"]
        batches = data_config["batches"]
        if use_batch:
            if use_batch not in adata.obs:
                raise ValueError(
                    f"Configured data batch '{use_batch}' "
                    f"cannot be found in input data!"
                )
            return batches.get_indexer(adata.obs[use_batch])
        return np.zeros(adata.shape[0], dtype=int)

    def _extract_xlbl(self, adata: AnnData, data_config: DATA_CONFIG) -> AnyArray:
        r"""
        Extract cell type indices (positions within the configured
        ``cell_types``), or all -1 if no cell type column is configured
        """
        use_cell_type = data_config["use_cell_type"]
        cell_types = data_config["cell_types"]
        if use_cell_type:
            if use_cell_type not in adata.obs:
                raise ValueError(
                    f"Configured cell type '{use_cell_type}' "
                    f"cannot be found in input data!"
                )
            return cell_types.get_indexer(adata.obs[use_cell_type])
        return -np.ones(adata.shape[0], dtype=int)

    def _extract_xdwt(self, adata: AnnData, data_config: DATA_CONFIG) -> AnyArray:
        r"""
        Extract discriminator sample weights rescaled to have mean 1,
        or all ones if no weight column is configured
        """
        default_dtype = get_default_numpy_dtype()
        use_dsc_weight = data_config["use_dsc_weight"]
        if use_dsc_weight:
            if use_dsc_weight not in adata.obs:
                raise ValueError(
                    f"Configured discriminator sample weight '{use_dsc_weight}' "
                    f"cannot be found in input data!"
                )
            xdwt = adata.obs[use_dsc_weight].to_numpy().astype(default_dtype)
            xdwt /= xdwt.sum() / xdwt.size  # Divide by the mean weight
        else:
            xdwt = np.ones(adata.shape[0], dtype=default_dtype)
        return xdwt

    def _extract_xuid(self, adata: AnnData, data_config: DATA_CONFIG) -> pd.Index:
        r"""
        Extract unique cell IDs: ``obs_names`` if configured, otherwise
        freshly generated random UUIDs

        Raises
        ------
        ValueError
            If the IDs are not unique within the dataset
        """
        if data_config["use_obs_names"]:
            xuid = adata.obs_names.to_numpy()
        else:  # NOTE: Assuming random UUIDs never collapse with anything
            self.logger.debug("Generating random xuid...")
            xuid = np.array([uuid.uuid4().hex for _ in range(adata.shape[0])])
        if len(set(xuid)) != xuid.size:
            raise ValueError("Non-unique cell ID!")
        return pd.Index(xuid)

    def propose_shuffle(self, seed: int) -> Tuple[np.ndarray, np.ndarray]:
        r"""
        Permute the cell IDs and rebuild indices and presence mask,
        filling missing cells with randomly drawn ones
        """
        rs = get_rs(seed)
        view_idx = rs.permutation(self.view_idx)
        return self._get_idx_pmsk(view_idx, random_fill=True, random_state=rs)

    def accept_shuffle(self, shuffled: Tuple[np.ndarray, np.ndarray]) -> None:
        r"""
        Install indices and presence mask produced by :meth:`propose_shuffle`
        """
        self.shuffle_idx, self.shuffle_pmsk = shuffled

    def random_split(
            self, fractions: List[float], random_state: RandomState = None
    ) -> List["AnnDataset"]:
        r"""
        Randomly split the dataset into multiple subdatasets according to
        given fractions.

        Parameters
        ----------
        fractions
            Fraction of each split
        random_state
            Random state

        Returns
        -------
        subdatasets
            A list of splitted subdatasets

        Raises
        ------
        ValueError
            If any fraction is not positive, or if the fractions do not
            sum to 1 (within floating point tolerance)
        """
        if min(fractions) <= 0:
            raise ValueError("Fractions should be greater than 0!")
        # Tolerant comparison: e.g. 0.1 + 0.2 + 0.7 != 1 in floating point
        if not np.isclose(sum(fractions), 1):
            raise ValueError("Fractions do not sum to 1!")
        rs = get_rs(random_state)
        cum_frac = np.cumsum(fractions)
        view_idx = rs.permutation(self.view_idx)
        split_pos = np.round(cum_frac * view_idx.size).astype(int)
        split_idx = np.split(view_idx, split_pos[:-1])  # Last pos produces an extra empty split
        subdatasets = []
        for idx in split_idx:
            sub = copy.copy(self)
            # The shallow copy shares the parent's queues and proposal cache
            # but owns no worker processes (the registry is keyed by `id`),
            # which `has_workers` would report as broken: detach them
            sub.seed_queue = None
            sub.propose_queue = None
            sub.propose_cache = {}
            sub.view_idx = idx
            sub.size = idx.size
            sub.shuffle_idx, sub.shuffle_pmsk = sub._get_idx_pmsk(idx)  # pylint: disable=protected-access
            subdatasets.append(sub)
        return subdatasets
@logged
class GraphDataset(Dataset):

    r"""
    Dataset for graphs with support for negative sampling

    Parameters
    ----------
    graph
        Graph object
    vertices
        Indexer of graph vertices
    neg_samples
        Number of negative samples per edge
    weighted_sampling
        Whether to do negative sampling based on vertex importance
    deemphasize_loops
        Whether to deemphasize self-loops when computing vertex importance
    getitem_size
        Unitary fetch size for each __getitem__ call

    Note
    ----
    Custom shuffling performs negative sampling.
    """

    def __init__(
            self, graph: nx.Graph, vertices: pd.Index,
            neg_samples: int = 1, weighted_sampling: bool = True,
            deemphasize_loops: bool = True, getitem_size: int = 1
    ) -> None:
        super().__init__(getitem_size=getitem_size)
        self.eidx, self.ewt, self.esgn = self.graph2triplet(graph, vertices)
        # Known (source, target, sign) triplets, used to reject negative samples
        self.eset = {
            (i, j, s) for (i, j), s in zip(self.eidx.T, self.esgn)
        }
        self.vnum = self.eidx.max() + 1

        if weighted_sampling:
            deg_eidx, deg_ewt = self.eidx, self.ewt
            if deemphasize_loops:
                off_diag = deg_eidx[0] != deg_eidx[1]
                deg_eidx, deg_ewt = deg_eidx[:, off_diag], deg_ewt[off_diag]
            degree = vertex_degrees(
                deg_eidx, deg_ewt, vnum=self.vnum, direction="both"
            )
        else:
            degree = np.ones(self.vnum, dtype=self.ewt.dtype)

        degree_sum = degree.sum()
        if not degree_sum:
            # Possible when `deemphasize_loops` is set on a loop-only graph
            self.vprob = np.ones(self.vnum, dtype=self.ewt.dtype) / self.vnum
        else:
            self.vprob = degree / degree_sum  # Vertex sampling probability

        effective_enum = self.ewt.sum()
        self.eprob = self.ewt / effective_enum  # Edge sampling probability
        self.effective_enum = round(effective_enum)
        self.neg_samples = neg_samples
        self.size = self.effective_enum * (1 + self.neg_samples)

        # Filled in by `accept_shuffle`
        self.samp_eidx: Optional[np.ndarray] = None
        self.samp_ewt: Optional[np.ndarray] = None
        self.samp_esgn: Optional[np.ndarray] = None

    def graph2triplet(
            self, graph: nx.Graph, vertices: pd.Index,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        r"""
        Convert graph object to graph triplet

        Parameters
        ----------
        graph
            Graph object
        vertices
            Graph vertices

        Returns
        -------
        eidx
            Vertex indices of edges (:math:`2 \times n_{edges}`)
        ewt
            Weight of edges (:math:`n_{edges}`)
        esgn
            Sign of edges (:math:`n_{edges}`)

        Raises
        ------
        ValueError
            If an edge refers to a vertex missing from ``vertices``, if a
            weight lies outside :math:`(0, 1]`, or if a sign is not ±1
        """
        graph = nx.MultiDiGraph(graph)  # Convert undirectional to bidirectional, while keeping multi-edges
        default_dtype = get_default_numpy_dtype()
        records = [
            (key[0], key[1], attr["weight"], attr["sign"])
            for key, attr in dict(graph.edges).items()
        ]
        sources = [record[0] for record in records]
        targets = [record[1] for record in records]
        weights = [record[2] for record in records]
        signs = [record[3] for record in records]

        eidx = np.stack([
            vertices.get_indexer(sources),
            vertices.get_indexer(targets)
        ]).astype(np.int64)
        if eidx.min() < 0:
            raise ValueError("Missing vertices!")
        ewt = np.asarray(weights).astype(default_dtype)
        if ewt.min() <= 0 or ewt.max() > 1:
            raise ValueError("Invalid edge weight!")
        esgn = np.asarray(signs).astype(default_dtype)
        if set(esgn).difference({-1, 1}):
            raise ValueError("Invalid edge sign!")
        return eidx, ewt, esgn

    def __len__(self) -> int:
        return ceil(self.size / self.getitem_size)

    def __getitem__(self, index: int) -> List[torch.Tensor]:
        r"""
        Fetch the ``index``-th chunk of sampled edges as
        ``[eidx, ewt, esgn]`` tensors
        """
        start = index * self.getitem_size
        stop = min((index + 1) * self.getitem_size, self.size)
        chunk = slice(start, stop)
        return [
            torch.as_tensor(self.samp_eidx[:, chunk]),
            torch.as_tensor(self.samp_ewt[chunk]),
            torch.as_tensor(self.samp_esgn[chunk])
        ]

    def propose_shuffle(
            self, seed: int
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        r"""
        Sample positive edges (weight 1) by edge probability, pair each with
        ``neg_samples`` negative edges (weight 0) obtained by redrawing the
        target vertex, and return everything in random order
        """
        sources, targets = self.eidx
        rs = get_rs(seed)

        # Positive edges
        picked = rs.choice(
            self.ewt.size, self.effective_enum, replace=True, p=self.eprob
        )
        pos_i, pos_j, pos_s = sources[picked], targets[picked], self.esgn[picked]
        pos_w = np.ones_like(self.ewt[picked])

        # Negative edges: same source and sign, randomly drawn target
        neg_i = np.tile(pos_i, self.neg_samples)
        neg_w = np.zeros(pos_w.size * self.neg_samples, dtype=pos_w.dtype)
        neg_s = np.tile(pos_s, self.neg_samples)
        neg_j = rs.choice(
            self.vnum, pos_j.size * self.neg_samples, replace=True, p=self.vprob
        )

        # Redraw targets until no negative sample coincides with a real edge
        pending = np.flatnonzero([
            triplet in self.eset
            for triplet in zip(neg_i, neg_j, neg_s)
        ])
        while pending.size:  # NOTE: Potential infinite loop if graph too dense
            redrawn = rs.choice(
                self.vnum, pending.size, replace=True, p=self.vprob
            )
            neg_j[pending] = redrawn
            still_real = [
                triplet in self.eset
                for triplet in zip(neg_i[pending], redrawn, neg_s[pending])
            ]
            pending = pending[still_real]

        idx = np.stack([
            np.concatenate([pos_i, neg_i]),
            np.concatenate([pos_j, neg_j])
        ])
        w = np.concatenate([pos_w, neg_w])
        s = np.concatenate([pos_s, neg_s])
        perm = rs.permutation(idx.shape[1])
        return idx[:, perm], w[perm], s[perm]

    def accept_shuffle(
            self, shuffled: Tuple[np.ndarray, np.ndarray, np.ndarray]
    ) -> None:
        r"""
        Install sampled edges produced by :meth:`propose_shuffle`
        """
        self.samp_eidx, self.samp_ewt, self.samp_esgn = shuffled
#-------------------------------- Data loaders ---------------------------------
class DataLoader(torch.utils.data.DataLoader):

    r"""
    Custom data loader that manually shuffles the internal dataset before each
    round of iteration (see :class:`torch.utils.data.DataLoader` for usage)

    The collate function is always chosen by this class, depending on
    whether the dataset is a :class:`GraphDataset`.
    """

    def __init__(self, dataset: Dataset, **kwargs) -> None:
        super().__init__(dataset, **kwargs)
        if isinstance(dataset, GraphDataset):
            self.collate_fn = self._collate_graph
        else:
            self.collate_fn = self._collate
        # Remember whether custom shuffling was requested
        self.shuffle = kwargs.get("shuffle", False)

    def __iter__(self) -> "DataLoader":
        if self.shuffle:
            self.dataset.shuffle()  # Customized shuffling
        return super().__iter__()

    @staticmethod
    def _collate(batch):
        r"""
        Concatenate matching tensors of all batch elements along dimension 0
        """
        return tuple(torch.cat(column, dim=0) for column in zip(*batch))

    @staticmethod
    def _collate_graph(batch):
        r"""
        Concatenate edge indices along dimension 1, and edge weights and
        signs along dimension 0
        """
        eidx, ewt, esgn = zip(*batch)
        return (
            torch.cat(eidx, dim=1),
            torch.cat(ewt, dim=0),
            torch.cat(esgn, dim=0),
        )
class ParallelDataLoader:

    r"""
    Parallel data loader

    Each step draws one batch from every data loader and concatenates
    the batches.

    Parameters
    ----------
    *data_loaders
        An arbitrary number of data loaders
    cycle_flags
        Whether each data loader should be cycled in case they are of
        different lengths, by default none of them are cycled.
    """

    def __init__(
            self, *data_loaders: DataLoader,
            cycle_flags: Optional[List[bool]] = None
    ) -> None:
        if not cycle_flags:
            cycle_flags = [False] * len(data_loaders)
        if len(cycle_flags) != len(data_loaders):
            raise ValueError("Invalid cycle flags!")
        self.cycle_flags = cycle_flags
        self.data_loaders = list(data_loaders)
        self.num_loaders = len(self.data_loaders)
        self.iterators = None  # Created by `__iter__`

    def __iter__(self) -> "ParallelDataLoader":
        self.iterators = [iter(loader) for loader in self.data_loaders]
        return self

    def _next(self, i: int) -> List[torch.Tensor]:
        r"""
        Fetch the next batch of the ``i``-th loader, restarting the loader
        once it is exhausted if its cycle flag is set
        """
        try:
            return next(self.iterators[i])
        except StopIteration:
            if not self.cycle_flags[i]:
                raise
        self.iterators[i] = iter(self.data_loaders[i])
        return next(self.iterators[i])

    def __next__(self) -> List[torch.Tensor]:
        batches = [self._next(i) for i in range(self.num_loaders)]
        return functools.reduce(operator.add, batches)