r"""
Miscellaneous utilities
"""
import os
import logging
import signal
import subprocess
import sys
from collections import defaultdict
from multiprocessing import Process
from typing import Any, List, Mapping, Optional
import numpy as np
import pandas as pd
import torch
from pybedtools.helpers import set_bedtools_path
from .typehint import RandomState, T
AUTO = "AUTO" # Flag for using automatically determined hyperparameters
#------------------------------ Global containers ------------------------------
processes: Mapping[int, Mapping[int, Process]] = defaultdict(dict) # id -> pid -> process
#-------------------------------- Meta classes ---------------------------------
#--------------------------------- Log manager ---------------------------------
class _CriticalFilter(logging.Filter):
def filter(self, record: logging.LogRecord) -> bool:
return record.levelno >= logging.WARNING
class _NonCriticalFilter(logging.Filter):
def filter(self, record: logging.LogRecord) -> bool:
return record.levelno < logging.WARNING
[文档]class LogManager(metaclass=SingletonMeta):
r"""
Manage loggers used in the package
"""
def __init__(self) -> None:
self._loggers = {}
self._log_file = None
self._console_log_level = logging.INFO
self._file_log_level = logging.DEBUG
self._file_fmt = \
"%(asctime)s.%(msecs)03d [%(levelname)s] %(name)s: %(message)s"
self._console_fmt = \
"[%(levelname)s] %(name)s: %(message)s"
self._date_fmt = "%Y-%m-%d %H:%M:%S"
@property
def log_file(self) -> str:
r"""
Configure log file
"""
return self._log_file
@property
def file_log_level(self) -> int:
r"""
Configure logging level in the log file
"""
return self._file_log_level
@property
def console_log_level(self) -> int:
r"""
Configure logging level printed in the console
"""
return self._console_log_level
def _create_file_handler(self) -> logging.FileHandler:
file_handler = logging.FileHandler(self.log_file)
file_handler.setLevel(self.file_log_level)
file_handler.setFormatter(logging.Formatter(
fmt=self._file_fmt, datefmt=self._date_fmt))
return file_handler
def _create_console_handler(self, critical: bool) -> logging.StreamHandler:
if critical:
console_handler = logging.StreamHandler(sys.stderr)
console_handler.addFilter(_CriticalFilter())
else:
console_handler = logging.StreamHandler(sys.stdout)
console_handler.addFilter(_NonCriticalFilter())
console_handler.setLevel(self.console_log_level)
console_handler.setFormatter(logging.Formatter(fmt=self._console_fmt))
return console_handler
[文档] def get_logger(self, name: str) -> logging.Logger:
r"""
Get a logger by name
"""
if name in self._loggers:
return self._loggers[name]
new_logger = logging.getLogger(name)
new_logger.setLevel(logging.DEBUG) # lowest level
new_logger.addHandler(self._create_console_handler(True))
new_logger.addHandler(self._create_console_handler(False))
if self.log_file:
new_logger.addHandler(self._create_file_handler())
self._loggers[name] = new_logger
return new_logger
@log_file.setter
def log_file(self, file_name: os.PathLike) -> None:
self._log_file = file_name
for logger in self._loggers.values():
for idx, handler in enumerate(logger.handlers):
if isinstance(handler, logging.FileHandler):
logger.handlers[idx].close()
if self.log_file:
logger.handlers[idx] = self._create_file_handler()
else:
del logger.handlers[idx]
break
else:
if file_name:
logger.addHandler(self._create_file_handler())
@file_log_level.setter
def file_log_level(self, log_level: int) -> None:
self._file_log_level = log_level
for logger in self._loggers.values():
for handler in logger.handlers:
if isinstance(handler, logging.FileHandler):
handler.setLevel(self.file_log_level)
break
@console_log_level.setter
def console_log_level(self, log_level: int) -> None:
self._console_log_level = log_level
for logger in self._loggers.values():
for handler in logger.handlers:
if type(handler) is logging.StreamHandler: # pylint: disable=unidiomatic-typecheck
handler.setLevel(self.console_log_level)
log = LogManager()
[文档]def logged(obj: T) -> T:
r"""
Add logger as an attribute
"""
obj.logger = log.get_logger(obj.__name__)
return obj
#---------------------------- Configuration Manager ----------------------------
[文档]@logged
class ConfigManager(metaclass=SingletonMeta):
r"""
Global configurations
"""
def __init__(self) -> None:
self.TMP_PREFIX = "GLUETMP"
self.ANNDATA_KEY = "__scglue__"
self.CPU_ONLY = False
self.CUDNN_MODE = "repeatability"
self.MASKED_GPUS = []
self.ARRAY_SHUFFLE_NUM_WORKERS = 0
self.GRAPH_SHUFFLE_NUM_WORKERS = 1
self.FORCE_TERMINATE_WORKER_PATIENCE = 60
self.DATALOADER_NUM_WORKERS = 0
self.DATALOADER_FETCHES_PER_WORKER = 4
self.DATALOADER_PIN_MEMORY = True
self.CHECKPOINT_SAVE_INTERVAL = 10
self.CHECKPOINT_SAVE_NUMBERS = 3
self.PRINT_LOSS_INTERVAL = 10
self.TENSORBOARD_FLUSH_SECS = 5
self.ALLOW_TRAINING_INTERRUPTION = True
self.BEDTOOLS_PATH = ""
@property
def TMP_PREFIX(self) -> str:
r"""
Prefix of temporary files and directories created.
Default values is ``"GLUETMP"``.
"""
return self._TMP_PREFIX
@TMP_PREFIX.setter
def TMP_PREFIX(self, tmp_prefix: str) -> None:
self._TMP_PREFIX = tmp_prefix
@property
def ANNDATA_KEY(self) -> str:
r"""
Key in ``adata.uns`` for storing dataset configurations.
Default value is ``"__scglue__"``
"""
return self._ANNDATA_KEY
@ANNDATA_KEY.setter
def ANNDATA_KEY(self, anndata_key: str) -> None:
self._ANNDATA_KEY = anndata_key
@property
def CPU_ONLY(self) -> bool:
r"""
Whether computation should use only CPUs.
Default value is ``False``.
"""
return self._CPU_ONLY
@CPU_ONLY.setter
def CPU_ONLY(self, cpu_only: bool) -> None:
self._CPU_ONLY = cpu_only
if self._CPU_ONLY and self._DATALOADER_NUM_WORKERS:
self.logger.warning(
"It is recommended to set `DATALOADER_NUM_WORKERS` to 0 "
"when using CPU_ONLY mode. Otherwise, deadlocks may happen "
"occationally."
)
@property
def CUDNN_MODE(self) -> str:
r"""
CuDNN computation mode, should be one of {"repeatability", "performance"}.
Default value is ``"repeatability"``.
Note
----
As of now, due to the use of :meth:`torch.Tensor.scatter_add_`
operation, the results are not completely reproducible even when
``CUDNN_MODE`` is set to ``"repeatability"``, if GPU is used as
computation device. Exact repeatability can only be achieved on CPU.
The situtation might change with new releases of :mod:`torch`.
"""
return self._CUDNN_MODE
@CUDNN_MODE.setter
def CUDNN_MODE(self, cudnn_mode: str) -> None:
if cudnn_mode not in ("repeatability", "performance"):
raise ValueError("Invalid mode!")
self._CUDNN_MODE = cudnn_mode
torch.backends.cudnn.deterministic = self._CUDNN_MODE == "repeatability"
torch.backends.cudnn.benchmark = self._CUDNN_MODE == "performance"
@property
def MASKED_GPUS(self) -> List[int]:
r"""
A list of GPUs that should not be used when selecting computation device.
This must be set before initializing any model, otherwise would be ineffective.
Default value is ``[]``.
"""
return self._MASKED_GPUS
@MASKED_GPUS.setter
def MASKED_GPUS(self, masked_gpus: List[int]) -> None:
if masked_gpus:
import pynvml
pynvml.nvmlInit()
device_count = pynvml.nvmlDeviceGetCount()
for item in masked_gpus:
if item >= device_count:
raise ValueError(f"GPU device \"{item}\" is non-existent!")
self._MASKED_GPUS = masked_gpus
@property
def ARRAY_SHUFFLE_NUM_WORKERS(self) -> int:
r"""
Number of background workers for array data shuffling.
Default value is ``0``.
"""
return self._ARRAY_SHUFFLE_NUM_WORKERS
@ARRAY_SHUFFLE_NUM_WORKERS.setter
def ARRAY_SHUFFLE_NUM_WORKERS(self, array_shuffle_num_workers: int) -> None:
self._ARRAY_SHUFFLE_NUM_WORKERS = array_shuffle_num_workers
@property
def GRAPH_SHUFFLE_NUM_WORKERS(self) -> int:
r"""
Number of background workers for graph data shuffling.
Default value is ``1``.
"""
return self._GRAPH_SHUFFLE_NUM_WORKERS
@GRAPH_SHUFFLE_NUM_WORKERS.setter
def GRAPH_SHUFFLE_NUM_WORKERS(self, graph_shuffle_num_workers: int) -> None:
self._GRAPH_SHUFFLE_NUM_WORKERS = graph_shuffle_num_workers
@property
def FORCE_TERMINATE_WORKER_PATIENCE(self) -> int:
r"""
Seconds to wait before force terminating unresponsive workers.
Default value is ``60``.
"""
return self._FORCE_TERMINATE_WORKER_PATIENCE
@FORCE_TERMINATE_WORKER_PATIENCE.setter
def FORCE_TERMINATE_WORKER_PATIENCE(self, force_terminate_worker_patience: int) -> None:
self._FORCE_TERMINATE_WORKER_PATIENCE = force_terminate_worker_patience
@property
def DATALOADER_NUM_WORKERS(self) -> int:
r"""
Number of worker processes to use in data loader.
Default value is ``0``.
"""
return self._DATALOADER_NUM_WORKERS
@DATALOADER_NUM_WORKERS.setter
def DATALOADER_NUM_WORKERS(self, dataloader_num_workers: int) -> None:
if dataloader_num_workers > 8:
self.logger.warning(
"Worker number 1-8 is generally sufficient, "
"too many workers might have negative impact on speed."
)
self._DATALOADER_NUM_WORKERS = dataloader_num_workers
@property
def DATALOADER_FETCHES_PER_WORKER(self) -> int:
r"""
Number of fetches per worker per batch to use in data loader.
Default value is ``4``.
"""
return self._DATALOADER_FETCHES_PER_WORKER
@DATALOADER_FETCHES_PER_WORKER.setter
def DATALOADER_FETCHES_PER_WORKER(self, dataloader_fetches_per_worker: int) -> None:
self._DATALOADER_FETCHES_PER_WORKER = dataloader_fetches_per_worker
@property
def DATALOADER_FETCHES_PER_BATCH(self) -> int:
r"""
Number of fetches per batch in data loader (read-only).
"""
return max(1, self.DATALOADER_NUM_WORKERS) * self.DATALOADER_FETCHES_PER_WORKER
@property
def DATALOADER_PIN_MEMORY(self) -> bool:
r"""
Whether to use pin memory in data loader.
Default value is ``True``.
"""
return self._DATALOADER_PIN_MEMORY
@DATALOADER_PIN_MEMORY.setter
def DATALOADER_PIN_MEMORY(self, dataloader_pin_memory: bool):
self._DATALOADER_PIN_MEMORY = dataloader_pin_memory
@property
def CHECKPOINT_SAVE_INTERVAL(self) -> int:
r"""
Automatically save checkpoints every n epochs.
Default value is ``10``.
"""
return self._CHECKPOINT_SAVE_INTERVAL
@CHECKPOINT_SAVE_INTERVAL.setter
def CHECKPOINT_SAVE_INTERVAL(self, checkpoint_save_interval: int) -> None:
self._CHECKPOINT_SAVE_INTERVAL = checkpoint_save_interval
@property
def CHECKPOINT_SAVE_NUMBERS(self) -> int:
r"""
Maximal number of checkpoints to preserve at any point.
Default value is ``3``.
"""
return self._CHECKPOINT_SAVE_NUMBERS
@CHECKPOINT_SAVE_NUMBERS.setter
def CHECKPOINT_SAVE_NUMBERS(self, checkpoint_save_numbers: int) -> None:
self._CHECKPOINT_SAVE_NUMBERS = checkpoint_save_numbers
@property
def PRINT_LOSS_INTERVAL(self) -> int:
r"""
Print loss values every n epochs.
Default value is ``10``.
"""
return self._PRINT_LOSS_INTERVAL
@PRINT_LOSS_INTERVAL.setter
def PRINT_LOSS_INTERVAL(self, print_loss_interval: int) -> None:
self._PRINT_LOSS_INTERVAL = print_loss_interval
@property
def TENSORBOARD_FLUSH_SECS(self) -> int:
r"""
Flush tensorboard logs to file every n seconds.
Default values is ``5``.
"""
return self._TENSORBOARD_FLUSH_SECS
@TENSORBOARD_FLUSH_SECS.setter
def TENSORBOARD_FLUSH_SECS(self, tensorboard_flush_secs: int) -> None:
self._TENSORBOARD_FLUSH_SECS = tensorboard_flush_secs
@property
def ALLOW_TRAINING_INTERRUPTION(self) -> bool:
r"""
Allow interruption before model training converges.
Default values is ``True``.
"""
return self._ALLOW_TRAINING_INTERRUPTION
@ALLOW_TRAINING_INTERRUPTION.setter
def ALLOW_TRAINING_INTERRUPTION(self, allow_training_interruption: bool) -> None:
self._ALLOW_TRAINING_INTERRUPTION = allow_training_interruption
@property
def BEDTOOLS_PATH(self) -> str:
r"""
Path to bedtools executable.
Default value is ``bedtools``.
"""
return self._BEDTOOLS_PATH
@BEDTOOLS_PATH.setter
def BEDTOOLS_PATH(self, bedtools_path: str) -> None:
self._BEDTOOLS_PATH = bedtools_path
set_bedtools_path(bedtools_path)
config = ConfigManager()
#---------------------------- Interruption handling ----------------------------
[文档]@logged
class DelayedKeyboardInterrupt: # pragma: no cover
r"""
Shield a code block from keyboard interruptions, delaying handling
till the block is finished (adapted from
`https://stackoverflow.com/a/21919644
<https://stackoverflow.com/a/21919644>`__).
"""
def __init__(self):
self.signal_received = None
self.old_handler = None
def __enter__(self):
self.signal_received = False
self.old_handler = signal.signal(signal.SIGINT, self._handler)
def _handler(self, sig, frame):
self.signal_received = (sig, frame)
self.logger.debug("SIGINT received, delaying KeyboardInterrupt...")
def __exit__(self, exc_type, exc_val, exc_tb):
signal.signal(signal.SIGINT, self.old_handler)
if self.signal_received:
self.old_handler(*self.signal_received)
#--------------------------- Constrained data frame ----------------------------
[文档]@logged
class ConstrainedDataFrame(pd.DataFrame):
r"""
Data frame with certain format constraints
Note
----
Format constraints are checked and maintained automatically.
"""
def __init__(self, *args, **kwargs) -> None:
df = pd.DataFrame(*args, **kwargs)
df = self.rectify(df)
self.verify(df)
super().__init__(df)
def __setitem__(self, key, value) -> None:
super().__setitem__(key, value)
self.verify(self)
@property
def _constructor(self) -> type:
return type(self)
[文档] @classmethod
def rectify(cls, df: pd.DataFrame) -> pd.DataFrame:
r"""
Rectify data frame for format integrity
Parameters
----------
df
Data frame to be rectified
Returns
-------
rectified_df
Rectified data frame
"""
return df
[文档] @classmethod
def verify(cls, df: pd.DataFrame) -> None:
r"""
Verify data frame for format integrity
Parameters
----------
df
Data frame to be verified
"""
@property
def df(self) -> pd.DataFrame:
r"""
Convert to regular data frame
"""
return pd.DataFrame(self)
def __repr__(self) -> str:
r"""
Note
----
We need to explicitly call :func:`repr` on the regular data frame
to bypass integrity verification, because when the terminal is
too narrow, :mod:`pandas` would split the data frame internally,
causing format verification to fail.
"""
return repr(self.df)
#--------------------------- Other utility functions ---------------------------
[文档]def get_chained_attr(x: Any, attr: str) -> Any:
r"""
Get attribute from an object, with support for chained attribute names.
Parameters
----------
x
Object to get attribute from
attr
Attribute name
Returns
-------
attr_value
Attribute value
"""
for k in attr.split("."):
if not hasattr(x, k):
raise AttributeError(f"{attr} not found!")
x = getattr(x, k)
return x
[文档]def get_rs(x: RandomState = None) -> np.random.RandomState:
r"""
Get random state object
Parameters
----------
x
Object that can be converted to a random state object
Returns
-------
rs
Random state object
"""
if isinstance(x, int):
return np.random.RandomState(x)
if isinstance(x, np.random.RandomState):
return x
return np.random
[文档]@logged
def run_command(
command: str,
stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
log_command: bool = True, print_output: bool = True,
err_message: Optional[Mapping[int, str]] = None, **kwargs
) -> Optional[List[str]]:
r"""
Run an external command and get realtime output
Parameters
----------
command
A string containing the command to be executed
stdout
Where to redirect stdout
stderr
Where to redirect stderr
echo_command
Whether to log the command being printed (log level is INFO)
print_output
Whether to print stdout of the command.
If ``stdout`` is PIPE and ``print_output`` is set to False,
the output will be returned as a list of output lines.
err_message
Look up dict of error message (indexed by error code)
**kwargs
Other keyword arguments to be passed to :class:`subprocess.Popen`
Returns
-------
output_lines
A list of output lines (only returned if ``stdout`` is PIPE
and ``print_output`` is False)
"""
if log_command:
run_command.logger.info("Executing external command: %s", command)
executable = command.split(" ")[0]
with subprocess.Popen(command, stdout=stdout, stderr=stderr,
shell=True, **kwargs) as p:
if stdout == subprocess.PIPE:
prompt = f"{executable} ({p.pid}): "
output_lines = []
def _handle(line):
line = line.strip().decode()
if print_output:
print(prompt + line)
else:
output_lines.append(line)
while True:
_handle(p.stdout.readline())
ret = p.poll()
if ret is not None:
# Handle output between last readlines and successful poll
for line in p.stdout.readlines():
_handle(line)
break
else:
output_lines = None
ret = p.wait()
if ret != 0:
err_message = err_message or {}
if ret in err_message:
err_message = " " + err_message[ret]
elif "__default__" in err_message:
err_message = " " + err_message["__default__"]
else:
err_message = ""
raise RuntimeError(
f"{executable} exited with error code: {ret}.{err_message}")
if stdout == subprocess.PIPE and not print_output:
return output_lines