# Source code for scglue.models.nn

r"""
Neural network modules, datasets & data loaders, and other utilities
"""

import functools
import os
from math import sqrt

import numpy as np
import pynvml
import torch
import torch.nn.functional as F
from torch.nn.modules.batchnorm import _NormBase

from ..utils import config, logged


#-------------------------- Neural network modules -----------------------------

class GraphConv(torch.nn.Module):

    r"""
    Graph convolution (propagation only)

    Has no learnable parameters: each vertex receives the sum of its
    in-neighbors' features, each scaled by the signed, normalized weight
    of the connecting edge.
    """

    def forward(
            self, input: torch.Tensor, eidx: torch.Tensor,
            enorm: torch.Tensor, esgn: torch.Tensor
    ) -> torch.Tensor:
        r"""
        Forward propagation

        Parameters
        ----------
        input
            Input data (:math:`n_{vertices} \times n_{features}`)
        eidx
            Vertex indices of edges (:math:`2 \times n_{edges}`)
        enorm
            Normalized weight of edges (:math:`n_{edges}`)
        esgn
            Sign of edges (:math:`n_{edges}`)

        Returns
        -------
        result
            Graph convolution result (:math:`n_{vertices} \times n_{features}`)
        """
        src, dst = eidx  # row 0: source vertices, row 1: target vertices
        # Per-edge scale (sign * normalized weight), shaped (n_edges, 1) so it
        # broadcasts across the feature dimension
        edge_scale = (esgn * enorm).unsqueeze(1)
        messages = input[src] * edge_scale  # n_edges * n_features
        # Accumulate every message into the row of its target vertex
        scatter_index = dst.unsqueeze(1).expand_as(messages)  # n_edges * n_features
        out = torch.zeros_like(input)
        return out.scatter_add_(0, scatter_index, messages)
class GraphAttent(torch.nn.Module):  # pragma: no cover

    r"""
    Graph attention

    Separate projection weights and attention heads are kept for positive
    and negative edges; the results of the two edge sets are summed.

    Parameters
    ----------
    in_features
        Input dimensionality
    out_features
        Output dimensionality

    Note
    ----
    **EXPERIMENTAL**
    """

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        self.weight = torch.nn.ParameterDict({
            "pos": torch.nn.Parameter(torch.Tensor(out_features, in_features)),
            "neg": torch.nn.Parameter(torch.Tensor(out_features, in_features))
        })
        self.head = torch.nn.ParameterDict({
            "pos": torch.nn.Parameter(torch.zeros(out_features * 2)),
            "neg": torch.nn.Parameter(torch.zeros(out_features * 2))
        })
        torch.nn.init.kaiming_uniform_(self.weight["pos"], sqrt(5))  # Following torch.nn.Linear
        torch.nn.init.kaiming_uniform_(self.weight["neg"], sqrt(5))  # Following torch.nn.Linear

    def forward(
            self, input: torch.Tensor, eidx: torch.Tensor,
            ewt: torch.Tensor, esgn: torch.Tensor
    ) -> torch.Tensor:
        r"""
        Forward propagation

        Parameters
        ----------
        input
            Input data (:math:`n_{vertices} \times in\_features`)
        eidx
            Vertex indices of edges (:math:`2 \times n_{edges}`)
        ewt
            Weight of edges (:math:`n_{edges}`)
        esgn
            Sign of edges (:math:`n_{edges}`); only edges with sign 1 or -1
            contribute

        Returns
        -------
        result
            Graph attention result (:math:`n_{vertices} \times out\_features`)
        """
        res_dict = {}
        for sgn in ("pos", "neg"):
            mask = esgn == 1 if sgn == "pos" else esgn == -1
            sidx, tidx = eidx[:, mask]
            ptr = input @ self.weight[sgn].T  # n_vertices * out_features
            alpha = torch.cat([ptr[sidx], ptr[tidx]], dim=1) @ self.head[sgn]
            alpha = F.leaky_relu(alpha, negative_slope=0.2).exp() * ewt[mask]
            # Sum of attention scores per target vertex. Allocate it with
            # alpha's dtype and device: a default-dtype buffer makes
            # scatter_add_ fail for float64/float16 models.
            normalizer = alpha.new_zeros(ptr.shape[0])
            normalizer.scatter_add_(0, tidx, alpha)
            alpha = alpha / normalizer[tidx]  # Only entries with non-zero denominators will be used
            message = ptr[sidx] * alpha.unsqueeze(1)
            res = torch.zeros_like(ptr)
            tidx = tidx.unsqueeze(1).expand_as(message)
            res.scatter_add_(0, tidx, message)
            res_dict[sgn] = res
        return res_dict["pos"] + res_dict["neg"]
#----------------------------- Utility functions -------------------------------
def freeze_running_stats(m: torch.nn.Module) -> None:
    r"""
    Selectively stops normalization layers from updating running stats

    Only modules derived from ``_NormBase`` are switched to evaluation mode;
    any other module is left untouched. The signature (module in, ``None``
    out) allows it to be passed to :meth:`torch.nn.Module.apply`.

    Parameters
    ----------
    m
        Network module
    """
    if not isinstance(m, _NormBase):
        return
    # In eval mode, _NormBase layers no longer update their running statistics
    m.eval()
def get_default_numpy_dtype() -> type:
    r"""
    Get numpy dtype matching that of the pytorch default dtype

    Returns
    -------
    dtype
        Default numpy dtype: the numpy attribute whose name equals the torch
        default dtype's name (e.g. ``np.float32`` for ``torch.float32``)
    """
    torch_name = str(torch.get_default_dtype())  # e.g. "torch.float32"
    numpy_name = torch_name.rpartition("torch.")[2]
    return getattr(np, numpy_name)
def _gpu_free_memory() -> np.ndarray:
    r"""
    Query the free memory of every GPU enumerated by NVML

    Returns
    -------
    free_mems
        Free memory of each GPU (``nvmlDeviceGetMemoryInfo(...).free``),
        indexed by NVML device index. Empty if NVML could not be initialized
        or queried.
    """
    try:
        pynvml.nvmlInit()
    except pynvml.NVMLError:
        return np.array([])
    try:
        return np.array([
            pynvml.nvmlDeviceGetMemoryInfo(
                pynvml.nvmlDeviceGetHandleByIndex(i)
            ).free for i in range(pynvml.nvmlDeviceGetCount())
        ])
    except pynvml.NVMLError:
        return np.array([])
    finally:
        # Balance the successful nvmlInit() above so NVML releases its resources
        try:
            pynvml.nvmlShutdown()
        except pynvml.NVMLError:
            pass


@logged
@functools.lru_cache(maxsize=1)
def autodevice() -> torch.device:
    r"""
    Get torch computation device automatically
    based on GPU availability and memory usage

    The GPUs with the most free memory (excluding indices listed in
    ``config.MASKED_GPUS``) are candidates, and one of them is picked with
    ``np.random.choice``. Falls back to CPU when ``config.CPU_ONLY`` is set,
    NVML is unavailable, no GPU is found, or every GPU is masked.
    When a GPU is chosen, ``CUDA_VISIBLE_DEVICES`` is set to its index.
    The result is cached, so the selection runs at most once per process.

    Returns
    -------
    device
        Computation device
    """
    used_device = -1
    if not config.CPU_ONLY:
        free_mems = _gpu_free_memory()
        n_gpus = free_mems.size
        if n_gpus:
            for item in config.MASKED_GPUS:
                # A shared config may list GPUs this machine does not have;
                # skip them instead of crashing with an IndexError
                if -n_gpus <= item < n_gpus:
                    free_mems[item] = -1
                else:
                    autodevice.logger.warning(
                        "Ignoring masked GPU index %s (only %d GPUs detected).",
                        item, n_gpus
                    )
            best_devices = np.where(free_mems == free_mems.max())[0]
            used_device = np.random.choice(best_devices, 1)[0]
            if free_mems[used_device] < 0:  # every GPU is masked
                used_device = -1
    if used_device == -1:
        autodevice.logger.info("Using CPU as computation device.")
        return torch.device("cpu")
    autodevice.logger.info("Using GPU %d as computation device.", used_device)
    # NOTE(review): this only takes effect if CUDA has not been initialized
    # yet in this process, and assumes NVML indices match CUDA's device
    # numbering (CUDA orders devices FASTEST_FIRST unless
    # CUDA_DEVICE_ORDER=PCI_BUS_ID) — confirm on multi-GPU machines.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(used_device)
    return torch.device("cuda")