# Source code for scglue.models.prob

r"""
Probability distributions
"""

import torch
import torch.distributions as D
import torch.nn.functional as F

from ..num import EPS

#-------------------------------- Distributions --------------------------------

class MSE(D.Distribution):

    r"""
    A "sham" distribution that outputs negative MSE on ``log_prob``

    This is not a normalized probability distribution. It only wraps
    :func:`torch.nn.functional.mse_loss` behind the ``log_prob`` / ``mean``
    interface of a :class:`torch.distributions.Distribution`.

    Parameters
    ----------
    loc
        Mean of the distribution
    """

    def __init__(self, loc: torch.Tensor) -> None:
        # Argument validation is disabled; this class declares no
        # ``arg_constraints`` or ``support`` that could be checked
        super().__init__(validate_args=False)
        self.loc = loc

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        r"""
        Negative mean squared error between ``loc`` and ``value``

        Parameters
        ----------
        value
            Observed values

        Returns
        -------
        log_prob
            Scalar tensor equal to ``-mse_loss(loc, value)``. The default
            ``"mean"`` reduction of ``mse_loss`` is used, so the result is
            averaged over all elements rather than returned element-wise.
        """
        return -F.mse_loss(self.loc, value)

    @property
    def mean(self) -> torch.Tensor:
        r"""
        Mean of the distribution, i.e. ``loc``
        """
        return self.loc

class RMSE(MSE):

    r"""
    A "sham" distribution that outputs negative RMSE on ``log_prob``

    Parameters
    ----------
    loc
        Mean of the distribution
    """

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        r"""
        Negative root mean squared error between ``loc`` and ``value``

        Parameters
        ----------
        value
            Observed values

        Returns
        -------
        log_prob
            Scalar tensor equal to ``-sqrt(mse_loss(loc, value))``, where
            ``mse_loss`` uses its default ``"mean"`` reduction
        """
        # NOTE(review): d sqrt(x)/dx is infinite at x = 0, so backpropagating
        # through this may give NaN gradients when ``loc`` equals ``value``
        # exactly; confirm whether callers need an epsilon inside the sqrt
        return -F.mse_loss(self.loc, value).sqrt()

class ZIN(D.Normal):

    r"""
    Zero-inflated normal distribution with subsetting support

    Parameters
    ----------
    zi_logits
        Zero-inflation logits
    loc
        Location of the normal distribution
    scale
        Scale of the normal distribution
    """

    def __init__(
            self, zi_logits: torch.Tensor,
            loc: torch.Tensor, scale: torch.Tensor
    ) -> None:
        super().__init__(loc, scale)
        self.zi_logits = zi_logits

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        r"""
        Element-wise log-likelihood under the zero-inflated normal mixture

        Let :math:`l` denote ``zi_logits`` and :math:`\pi = \sigma(l)` the
        zero-inflation probability. Entries with ``|value| < EPS`` are treated
        as zeros and get :math:`\log(\pi + (1 - \pi) p(value))`. All other
        entries get :math:`\log((1 - \pi) p(value))`. Here :math:`p` is the
        normal density.

        Parameters
        ----------
        value
            Observed values. NOTE(review): the boolean-mask indexing below
            assumes ``value``, ``zi_logits`` and the broadcast ``loc``/``scale``
            all share one shape; confirm against callers.

        Returns
        -------
        log_prob
            Element-wise log-likelihood, same shape as ``value``
        """
        raw_log_prob = super().log_prob(value)
        zi_log_prob = torch.empty_like(raw_log_prob)
        z_mask = value.abs() < EPS
        z_zi_logits, nz_zi_logits = (
            self.zi_logits[z_mask], self.zi_logits[~z_mask]
        )
        # log(pi + (1 - pi) * p) == log(exp(l) + p) - softplus(l);
        # EPS keeps the log finite if both terms underflow to zero
        zi_log_prob[z_mask] = (
            raw_log_prob[z_mask].exp() + z_zi_logits.exp() + EPS
        ).log() - F.softplus(z_zi_logits)
        # log((1 - pi) * p) == log(p) + log(sigmoid(-l)) == log(p) - softplus(l)
        zi_log_prob[~z_mask] = raw_log_prob[~z_mask] - F.softplus(nz_zi_logits)
        return zi_log_prob

class ZILN(D.LogNormal):

    r"""
    Zero-inflated log-normal distribution with subsetting support

    Parameters
    ----------
    zi_logits
        Zero-inflation logits
    loc
        Location of the log-normal distribution
    scale
        Scale of the log-normal distribution
    """

    def __init__(
            self, zi_logits: torch.Tensor,
            loc: torch.Tensor, scale: torch.Tensor
    ) -> None:
        super().__init__(loc, scale)
        self.zi_logits = zi_logits

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        r"""
        Element-wise log-likelihood under the zero-inflated log-normal mixture

        Let :math:`l` denote ``zi_logits`` and :math:`\pi = \sigma(l)` the
        zero-inflation probability. Entries with ``|value| < EPS`` are treated
        as zeros and get :math:`\log \pi`, because the log-normal places no
        density at zero. All other entries get :math:`\log((1 - \pi) p(value))`.
        Here :math:`p` is the log-normal density.

        Parameters
        ----------
        value
            Observed values. NOTE(review): the boolean-mask indexing below
            assumes ``value``, ``zi_logits`` and the broadcast ``loc``/``scale``
            all share one shape; confirm against callers.

        Returns
        -------
        log_prob
            Element-wise log-likelihood, same shape as ``value``
        """
        zi_log_prob = torch.empty_like(value)
        z_mask = value.abs() < EPS
        z_zi_logits, nz_zi_logits = (
            self.zi_logits[z_mask], self.zi_logits[~z_mask]
        )
        # log(pi) == log(sigmoid(l)) == l - softplus(l)
        zi_log_prob[z_mask] = z_zi_logits - F.softplus(z_zi_logits)
        # The log-normal log-density involves log(value), which is undefined
        # at zero, so it is evaluated only on the non-zero subset
        # (log((1 - pi) * p) == log(p) - softplus(l))
        zi_log_prob[~z_mask] = D.LogNormal(
            self.loc[~z_mask], self.scale[~z_mask]
        ).log_prob(value[~z_mask]) - F.softplus(nz_zi_logits)
        return zi_log_prob

[docs]class ZINB(D.NegativeBinomial):

r"""
Zero-inflated negative binomial distribution

Parameters
----------
zi_logits
Zero-inflation logits
total_count
Total count of the negative binomial distribution
logits
Logits of the negative binomial distribution
"""

def __init__(
self, zi_logits: torch.Tensor,
total_count: torch.Tensor, logits: torch.Tensor = None
) -> None:
super().__init__(total_count, logits=logits)
self.zi_logits = zi_logits

[docs]    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
raw_log_prob = super().log_prob(value)
zi_log_prob = torch.empty_like(raw_log_prob)