r"""
Probability distributions
"""
import torch
import torch.distributions as D
import torch.nn.functional as F
from pyro.distributions import BetaBinomial as BetaBinomialDistribution
from ..num import EPS
from .nn import zero_nan_grad
# ------------------------------- Distributions --------------------------------
class MSE(D.Distribution):
    r"""
    A "sham" distribution that outputs negative MSE on ``log_prob``

    This is not a proper probability distribution: ``log_prob`` simply
    returns the negated mean squared error between ``loc`` and the
    observed value, so it can be plugged in wherever a log-likelihood
    is maximized.

    Parameters
    ----------
    loc
        Mean of the distribution
    """

    def __init__(self, loc: torch.Tensor) -> None:
        # Argument validation is explicitly disabled for this pseudo-distribution.
        super().__init__(validate_args=False)
        self.loc = loc

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        r"""
        Negative mean squared error between ``loc`` and ``value``

        Parameters
        ----------
        value
            Observed value to compare against ``loc``

        Returns
        -------
        neg_mse
            Result of ``-F.mse_loss(loc, value)`` (default reduction)
        """
        return -F.mse_loss(self.loc, value)

    @property
    def mean(self) -> torch.Tensor:
        r"""
        Mean of the distribution (identical to ``loc``)
        """
        return self.loc
class RMSE(MSE):
    r"""
    A "sham" distribution that outputs negative RMSE on ``log_prob``

    Parameters
    ----------
    loc
        Mean of the distribution
    """

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        r"""
        Negative root mean squared error between ``loc`` and ``value``

        Parameters
        ----------
        value
            Observed value to compare against ``loc``

        Returns
        -------
        neg_rmse
            Result of ``-sqrt(F.mse_loss(loc, value))`` (default reduction)
        """
        return -F.mse_loss(self.loc, value).sqrt()
class ZIN(D.Normal):
    r"""
    Zero-inflated normal distribution with subsetting support

    Parameters
    ----------
    zi_logits
        Zero-inflation logits
    loc
        Location of the normal distribution
    scale
        Scale of the normal distribution
    """

    def __init__(
        self, zi_logits: torch.Tensor, loc: torch.Tensor, scale: torch.Tensor
    ) -> None:
        # Attach the ``zero_nan_grad`` hook to every parameter that tracks
        # gradients (helper lives in ``.nn``; presumably it zeroes NaN
        # gradient entries -- confirm against its definition).
        for param in (zi_logits, loc, scale):
            if param.requires_grad:
                param.register_hook(zero_nan_grad)
        super().__init__(loc, scale, validate_args=False)
        self.zi_logits = zi_logits

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        r"""
        Element-wise zero-inflated log-probability

        Entries with ``|value| < EPS`` are treated as zeros and receive
        ``log(exp(normal_log_prob) + exp(zi_logits) + EPS) - softplus(zi_logits)``;
        all other entries receive ``normal_log_prob - softplus(zi_logits)``.

        Parameters
        ----------
        value
            Observed values; boolean masks derived from it are used to index
            ``zi_logits`` and the normal log-probability directly

        Returns
        -------
        log_prob
            Tensor shaped like the normal log-probability
        """
        normal_lp = super().log_prob(value)
        is_zero = value.abs() < EPS
        is_nonzero = ~is_zero
        logits_zero = self.zi_logits[is_zero]
        logits_nonzero = self.zi_logits[is_nonzero]

        result = torch.empty_like(normal_lp)
        # Zero entries: mixture of the inflation component and the normal density.
        mixed = normal_lp[is_zero].exp() + logits_zero.exp() + EPS
        result[is_zero] = mixed.log() - F.softplus(logits_zero)
        # Non-zero entries: normal density weighted by the non-inflation probability.
        result[is_nonzero] = normal_lp[is_nonzero] - F.softplus(logits_nonzero)
        return result
class ZILN(D.LogNormal):
    r"""
    Zero-inflated log-normal distribution with subsetting support

    Parameters
    ----------
    zi_logits
        Zero-inflation logits
    loc
        Location of the log-normal distribution
    scale
        Scale of the log-normal distribution
    """

    def __init__(
        self, zi_logits: torch.Tensor, loc: torch.Tensor, scale: torch.Tensor
    ) -> None:
        super().__init__(loc, scale)
        self.zi_logits = zi_logits

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        r"""
        Element-wise zero-inflated log-probability

        Entries with ``|value| < EPS`` are treated as zeros and receive
        ``zi_logits - softplus(zi_logits)``; for the remaining entries a
        log-normal restricted to those entries is evaluated and
        ``softplus(zi_logits)`` is subtracted from its log-probability.

        Parameters
        ----------
        value
            Observed values; boolean masks derived from it are used to index
            ``zi_logits``, ``loc`` and ``scale`` directly

        Returns
        -------
        log_prob
            Tensor shaped like ``value``
        """
        is_zero = value.abs() < EPS
        is_nonzero = ~is_zero
        logits_zero = self.zi_logits[is_zero]
        logits_nonzero = self.zi_logits[is_nonzero]

        result = torch.empty_like(value)
        # Zero entries can only come from the inflation component.
        result[is_zero] = logits_zero - F.softplus(logits_zero)
        # The log-normal density is evaluated on the non-zero subset only,
        # so zeros never reach ``LogNormal.log_prob``.
        nonzero_dist = D.LogNormal(self.loc[is_nonzero], self.scale[is_nonzero])
        nonzero_lp = nonzero_dist.log_prob(value[is_nonzero])
        result[is_nonzero] = nonzero_lp - F.softplus(logits_nonzero)
        return result
class ZINB(D.NegativeBinomial):
    r"""
    Zero-inflated negative binomial distribution

    Parameters
    ----------
    zi_logits
        Zero-inflation logits
    total_count
        Total count of the negative binomial distribution
    logits
        Logits of the negative binomial distribution
    """

    def __init__(
        self,
        zi_logits: torch.Tensor,
        total_count: torch.Tensor,
        logits: torch.Tensor = None,
    ) -> None:
        # ``logits`` is forwarded unchanged to the parent constructor.
        # NOTE(review): the parent presumably rejects ``logits=None`` since
        # no ``probs`` is supplied here -- confirm before relying on the default.
        super().__init__(total_count, logits=logits)
        self.zi_logits = zi_logits

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        r"""
        Element-wise zero-inflated log-probability

        Entries with ``|value| < EPS`` are treated as zeros and receive
        ``log(exp(nb_log_prob) + exp(zi_logits) + EPS) - softplus(zi_logits)``;
        all other entries receive ``nb_log_prob - softplus(zi_logits)``.

        Parameters
        ----------
        value
            Observed counts; boolean masks derived from it are used to index
            ``zi_logits`` and the negative binomial log-probability directly

        Returns
        -------
        log_prob
            Tensor shaped like the negative binomial log-probability
        """
        nb_lp = super().log_prob(value)
        is_zero = value.abs() < EPS
        is_nonzero = ~is_zero
        logits_zero = self.zi_logits[is_zero]
        logits_nonzero = self.zi_logits[is_nonzero]

        result = torch.empty_like(nb_lp)
        # Zero entries: mixture of the inflation component and the NB mass at zero.
        mixed = nb_lp[is_zero].exp() + logits_zero.exp() + EPS
        result[is_zero] = mixed.log() - F.softplus(logits_zero)
        # Non-zero entries: NB mass weighted by the non-inflation probability.
        result[is_nonzero] = nb_lp[is_nonzero] - F.softplus(logits_nonzero)
        return result
class Beta(D.Beta):
    r"""
    Stable beta distribution parameterized by mean and concentration

    Parameters
    ----------
    logit_mu
        Logit of mean of the beta distribution
    size
        Concentration of the beta distribution
    """

    def __init__(self, logit_mu: torch.Tensor, size: torch.Tensor) -> None:
        # Attach the ``zero_nan_grad`` hook to every parameter that tracks
        # gradients (helper lives in ``.nn``; presumably it zeroes NaN
        # gradient entries -- confirm against its definition).
        for param in (logit_mu, size):
            if param.requires_grad:
                param.register_hook(zero_nan_grad)
        mean = logit_mu.sigmoid()
        # EPS is added to both concentrations so neither becomes exactly zero.
        concentration1 = mean * size + EPS
        concentration0 = (1 - mean) * size + EPS
        super().__init__(concentration1, concentration0, validate_args=False)

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        r"""
        Log-probability of ``value`` after clamping it to ``[EPS, 1 - EPS]``

        Parameters
        ----------
        value
            Observed values

        Returns
        -------
        log_prob
            Beta log-probability evaluated at the clamped values
        """
        clamped = torch.clamp(value, min=EPS, max=1 - EPS)
        return super().log_prob(clamped)
class BetaBinomial(D.Beta):
    r"""
    Stable beta-binomial distribution parameterized by mean and concentration

    Parameters
    ----------
    logit_mu
        Logit of mean of the beta distribution
    size
        Concentration of the beta distribution
    """

    def __init__(self, logit_mu: torch.Tensor, size: torch.Tensor) -> None:
        mean = logit_mu.sigmoid()
        # EPS is added to both concentrations so neither becomes exactly zero.
        concentration1 = mean * size + EPS
        concentration0 = (1 - mean) * size + EPS
        super().__init__(concentration1, concentration0)

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        r"""
        Beta-binomial log-probability of a count packed into ``value``

        Parameters
        ----------
        value
            Tensor exposing ``.real`` and ``.imag``: the real part is passed
            as the observed count and the imaginary part as ``total_count``

        Returns
        -------
        log_prob
            Log-probability from the beta-binomial built on this
            distribution's ``concentration1`` and ``concentration0``
        """
        observed = value.real
        n_trials = value.imag
        beta_binomial = BetaBinomialDistribution(
            self.concentration1, self.concentration0, total_count=n_trials
        )
        return beta_binomial.log_prob(observed)
class Bernoulli(D.Bernoulli):
    r"""
    Bernoulli distribution parameterized by logits

    Parameters
    ----------
    logits
        Logits of the Bernoulli distribution
    """

    def __init__(self, logits: torch.Tensor) -> None:
        tracks_grad = logits.requires_grad
        if tracks_grad:
            # Helper lives in ``.nn``; presumably it zeroes NaN gradient
            # entries -- confirm against its definition.
            logits.register_hook(zero_nan_grad)
        # Argument validation is explicitly disabled.
        super().__init__(logits=logits, validate_args=False)