from typing import Any, Literal, Optional, Union
import numpy as np
import torch
from scvi import REGISTRY_KEYS, settings
from scvi.distributions import NegativeBinomial, Poisson
from scvi.module import Classifier
from scvi.module.base import BaseModuleClass, auto_move_data
from scvi.nn import Decoder, DecoderSCVI, FCLayers
from sklearn.metrics import mean_squared_error, r2_score
from torch import nn
from torch.distributions import Categorical, Normal
from ._constants import LOSS_KEYS
__all__ = ["RegularizedEmbedding", "BiolordModule", "BiolordClassifyModule"]
class RegularizedEmbedding(nn.Module):
"""Regularized embedding module."""
def __init__(
self,
n_input: int,
n_output: int,
sigma: float,
embed: bool = True,
):
super().__init__()
self.embedding = nn.Embedding(
num_embeddings=n_input,
embedding_dim=n_output,
)
self.sigma = sigma if embed else 0
self.embed = embed
def forward(self, x):
"""Forward pass."""
x_ = self.embedding(x)
if self.training and self.sigma != 0:
noise = torch.zeros_like(x_)
noise.normal_(mean=0, std=self.sigma)
x_ = x_ + noise
        # multiplying by the boolean `embed` zeroes the output when embeddings are disabled
        x_ = x_ * self.embed
return x_
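# Illustrative usage of RegularizedEmbedding (shapes and values are assumptions,
# not taken from a real run):
#
#   emb = RegularizedEmbedding(n_input=100, n_output=32, sigma=0.1)
#   emb.train()
#   z_noisy = emb(torch.arange(4))  # (4, 32): embedding plus N(0, 0.1^2) noise
#   emb.eval()
#   z_clean = emb(torch.arange(4))  # (4, 32): deterministic embedding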
class BiolordModule(BaseModuleClass):
"""The :mod:`biolord` module.
Parameters
----------
n_genes
Number of input genes.
    n_samples
        Number of samples; sets the size of the per-sample unknown-attribute embedding.
x_loc
The expression data location.
ordered_attributes_map
Dictionary of ordered classes and their dimensions.
categorical_attributes_map
Dictionary for categorical classes, containing categorical values with keys as each category name and values
as the categorical integer assignment.
n_latent
Latent dimension.
n_latent_attribute_ordered
Latent dimension of ordered attributes.
n_latent_attribute_categorical
Latent dimension of categorical attributes.
    gene_likelihood
        Likelihood used to model gene expression; one of ``"normal"``, ``"nb"``, or ``"poisson"``.
    reconstruction_penalty
        Weight of the MSE term added to the reconstruction loss.
    unknown_attribute_penalty
        Weight of the penalty on the unknown-attribute embeddings.
use_batch_norm
Use batch norm in layers.
use_layer_norm
Use layer norm in layers.
unknown_attribute_noise_param
Noise strength added to encoding of unknown attributes.
    unknown_attributes
        Whether to learn a latent representation for unknown attributes.
    attribute_dropout_rate
        Dropout rate of the ordered-attribute networks; a single value or a per-attribute mapping.
    attribute_nn_width
        Hidden-layer width of the ordered-attribute networks; a single value or a per-attribute mapping.
    attribute_nn_depth
        Number of hidden layers in the ordered-attribute networks; a single value or a per-attribute mapping.
    attribute_nn_activation
        Whether to use an activation function in the ordered-attribute networks.
decoder_width
Decoder layers' width.
decoder_depth
Decoder number of layers.
    decoder_activation
        Whether to use an activation function in the decoder.
eval_r2_ordered
Evaluate the R2 w.r.t. the ordered attribute. Set to `True` only if ordered attributes are binned.
decoder_dropout_rate
Decoder dropout rate.
seed
Random seed.
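    Examples
    --------
    A minimal sketch; attribute names, categories, and sizes below are illustrative,
    not taken from a real dataset::

        module = BiolordModule(
            n_genes=2000,
            n_samples=500,
            x_loc="X",
            categorical_attributes_map={"cell_type": {"B": 0, "T": 1, "NK": 2}},
            ordered_attributes_map={"dose": 1},
            gene_likelihood="normal",
        )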
"""
def __init__(
self,
n_genes: int,
n_samples: int,
x_loc: str,
ordered_attributes_map: Optional[dict[str, int]] = None,
categorical_attributes_map: Optional[dict[str, dict]] = None,
n_latent: int = 32,
n_latent_attribute_categorical: int = 4,
n_latent_attribute_ordered: int = 16,
gene_likelihood: Literal["normal", "nb", "poisson"] = "normal",
reconstruction_penalty: float = 1e2,
unknown_attribute_penalty: float = 1e1,
use_batch_norm: bool = True,
use_layer_norm: bool = False,
unknown_attribute_noise_param: float = 1e-1,
unknown_attributes: bool = True,
        attribute_dropout_rate: Optional[Union[dict[str, float], float]] = None,
decoder_width: int = 512,
decoder_depth: int = 4,
decoder_activation: bool = True,
        attribute_nn_width: Optional[Union[dict[str, int], int]] = None,
        attribute_nn_depth: Optional[Union[dict[str, int], int]] = None,
attribute_nn_activation: bool = True,
eval_r2_ordered: bool = False,
decoder_dropout_rate: float = 0.1,
seed: int = 0,
):
super().__init__()
gene_likelihood = gene_likelihood.lower()
assert gene_likelihood in ["normal", "nb", "poisson"], gene_likelihood
default_width = 256
default_depth = 2
torch.manual_seed(seed)
np.random.seed(seed)
settings.seed = seed
self.ae_loss_fn = nn.GaussianNLLLoss()
self.ae_loss_mse_fn = nn.MSELoss()
self.reconstruction_penalty = reconstruction_penalty
self.unknown_attribute_penalty = unknown_attribute_penalty
self.mm_regression_loss_fn = nn.BCEWithLogitsLoss()
self.n_genes = n_genes
self.n_latent = n_latent
self.x_loc = x_loc
self.n_latent_attribute_categorical = n_latent_attribute_categorical
self.n_latent_attribute_ordered = n_latent_attribute_ordered
self.gene_likelihood = gene_likelihood
self.use_batch_norm = use_batch_norm
self.use_layer_norm = use_layer_norm
self.eval_r2_ordered = eval_r2_ordered
        self.categorical_attributes_map = (
            categorical_attributes_map if isinstance(categorical_attributes_map, dict) else {}
        )
        self.ordered_attributes_map = ordered_attributes_map if isinstance(ordered_attributes_map, dict) else {}
        # normalize the maps before sizing the decoder input, so the `None` defaults do not crash `len()`
        self.n_decoder_input = n_latent + (
            n_latent_attribute_categorical * len(self.categorical_attributes_map)
            + n_latent_attribute_ordered * len(self.ordered_attributes_map)
        )
if isinstance(attribute_nn_width, dict):
self.attribute_nn_width = attribute_nn_width
elif attribute_nn_width is None:
self.attribute_nn_width = {attribute_: default_width for attribute_ in self.ordered_attributes_map}
else:
self.attribute_nn_width = {attribute_: attribute_nn_width for attribute_ in self.ordered_attributes_map}
if isinstance(attribute_nn_depth, dict):
self.attribute_nn_depth = attribute_nn_depth
elif attribute_nn_depth is None:
self.attribute_nn_depth = {attribute_: default_depth for attribute_ in self.ordered_attributes_map}
else:
self.attribute_nn_depth = {attribute_: attribute_nn_depth for attribute_ in self.ordered_attributes_map}
if isinstance(attribute_dropout_rate, dict):
self.attribute_dropout_rate = attribute_dropout_rate
elif attribute_dropout_rate is None:
self.attribute_dropout_rate = {
attribute_: decoder_dropout_rate for attribute_ in self.ordered_attributes_map
}
else:
self.attribute_dropout_rate = {
attribute_: attribute_dropout_rate for attribute_ in self.ordered_attributes_map
}
self.latent_codes = RegularizedEmbedding(
n_input=n_samples, n_output=n_latent, sigma=unknown_attribute_noise_param, embed=unknown_attributes
)
# Create Embeddings
# 1. ordered classes
reps_ordered = []
self.ordered_networks = nn.ModuleDict()
for attribute_, len_ in self.ordered_attributes_map.items():
if "_rep" in attribute_:
reps_ordered.append(attribute_)
else:
self.ordered_networks[attribute_] = FCLayers(
n_in=len_,
n_out=self.n_latent_attribute_ordered,
n_layers=self.attribute_nn_depth[attribute_],
n_hidden=self.attribute_nn_width[attribute_],
dropout_rate=self.attribute_dropout_rate[attribute_],
bias=False,
use_activation=attribute_nn_activation,
)
for attribute_ in reps_ordered:
self.ordered_networks[attribute_] = self.ordered_networks[attribute_.split("_rep")[0]]
# 2. categorical classes
self.categorical_embeddings = nn.ModuleDict()
reps_categorical = []
for attribute_, unique_categories in self.categorical_attributes_map.items():
if "_rep" in attribute_:
reps_categorical.append(attribute_)
else:
self.categorical_embeddings[attribute_] = torch.nn.Embedding(
len(unique_categories),
n_latent_attribute_categorical,
)
for attribute_ in reps_categorical:
self.categorical_embeddings[attribute_] = self.categorical_embeddings[attribute_.split("_rep")[0]]
# Decoder components
if self.gene_likelihood in ["nb", "poisson"]:
self.decoder = DecoderSCVI(
n_input=self.n_decoder_input,
n_output=n_genes,
n_hidden=decoder_width,
n_layers=decoder_depth,
use_batch_norm=use_batch_norm,
use_layer_norm=use_layer_norm,
scale_activation="softmax",
)
self.px_r = torch.nn.Parameter(torch.randn(n_genes))
else:
self.decoder = Decoder(
n_input=self.n_decoder_input,
n_output=n_genes,
n_hidden=decoder_width,
n_layers=decoder_depth,
use_batch_norm=use_batch_norm,
use_layer_norm=use_layer_norm,
use_activation=decoder_activation,
)
def _get_inference_input(self, tensors: dict[Any, Any], **kwargs):
x = tensors[self.x_loc] # batch_size, n_genes
sample_indices = tensors[REGISTRY_KEYS.INDICES_KEY].long().ravel()
categorical_attribute_dict = {}
for attribute_ in self.categorical_attributes_map:
categorical_attribute_dict[attribute_] = tensors[attribute_].view(
-1,
)
ordered_attribute_dict = {}
for attribute_ in self.ordered_attributes_map:
ordered_attribute_dict[attribute_] = tensors[attribute_]
input_dict = {
"genes": x,
"sample_indices": sample_indices,
"categorical_attribute_dict": categorical_attribute_dict,
"ordered_attribute_dict": ordered_attribute_dict,
}
return input_dict
@auto_move_data
def _inference_attribute_embeddings(
self,
genes,
categorical_attribute_dict,
ordered_attribute_dict,
nullify_attribute=None,
):
"""Inference over attribute embeddings."""
nullify_attribute = [] if nullify_attribute is None else nullify_attribute
inference_output = {}
batch_size = genes.shape[0]
for attribute_, embedding_ in self.categorical_embeddings.items():
latent_i = embedding_(categorical_attribute_dict[attribute_].long())
latent_i = latent_i.view(batch_size, self.n_latent_attribute_categorical).unsqueeze(
0
) # 1, batch_size, n_latent_attribute_categorical
if attribute_ in nullify_attribute:
latent_i = torch.zeros_like(latent_i)
inference_output[attribute_] = latent_i
for attribute_, network_ in self.ordered_networks.items():
latent_i = network_(ordered_attribute_dict[attribute_])
latent_i = latent_i.view(batch_size, self.n_latent_attribute_ordered).unsqueeze(0)
if attribute_ in nullify_attribute:
latent_i = torch.zeros_like(latent_i)
inference_output[attribute_] = latent_i
return inference_output
def _get_latent_unknown_attributes(
self,
sample_indices,
):
"""Get the module's latent unknown attributes representation."""
latent_unknown_attributes = self.latent_codes(sample_indices)
return latent_unknown_attributes
@auto_move_data
def inference(
self,
genes: torch.Tensor,
sample_indices: torch.Tensor,
categorical_attribute_dict: dict[Any, Any],
ordered_attribute_dict: dict[Any, Any],
nullify_attribute: Optional[list] = None,
) -> dict[str, Any]:
"""Apply module inference.
Parameters
----------
genes
Input expression.
sample_indices
Indices in the :class:`~anndata.AnnData` object of the input samples.
categorical_attribute_dict
Dictionary with categorical attributes as keys and the attribute sample labels as values.
ordered_attribute_dict
Dictionary with ordered attributes as keys and the attribute sample values as values.
nullify_attribute
Attributes to exclude from inferred latent space.
        Returns
        -------
        Dictionary with the inference outputs: the per-attribute latents, the concatenated
        ``latent``, ``latent_unknown_attributes``, and ``library``.
        """
nullify_attribute = [] if nullify_attribute is None else nullify_attribute
inference_output = {}
x_ = genes
library = torch.log(genes.sum(1)).unsqueeze(1)
latent_unknown_attributes = self._get_latent_unknown_attributes(sample_indices=sample_indices)
latent_classes = self._inference_attribute_embeddings(
genes=x_,
categorical_attribute_dict=categorical_attribute_dict,
ordered_attribute_dict=ordered_attribute_dict,
nullify_attribute=nullify_attribute,
)
latent_vecs = [latent_unknown_attributes.squeeze()]
for key_, latent_ in latent_classes.items():
latent_vecs.append(latent_.squeeze())
inference_output[key_] = latent_.squeeze()
latent = torch.cat(latent_vecs, dim=-1)
inference_output["latent"] = latent
inference_output["latent_unknown_attributes"] = latent_unknown_attributes
inference_output["library"] = library
return inference_output
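    # Worked example of the concatenated latent size (illustrative numbers): with
    # n_latent=32, two categorical attributes (n_latent_attribute_categorical=4),
    # and one ordered attribute (n_latent_attribute_ordered=16), `latent` has
    # dimension 32 + 2 * 4 + 1 * 16 = 56, matching `self.n_decoder_input`.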
def _get_generative_input(self, tensors, inference_outputs, **kwargs):
input_dict = {
"latent": inference_outputs["latent"],
"library": inference_outputs["library"],
}
return input_dict
@auto_move_data
def generative(
self,
latent: torch.Tensor,
        library: Optional[torch.Tensor] = None,
) -> dict[str, Any]:
"""Runs the generative step.
Parameters
----------
latent
The concatenated decomposed latent space.
library
Library sizes for each cell.
Returns
-------
Dictionary with the generative predictions of the expression distribution.
"""
if self.gene_likelihood in ["nb", "poisson"]:
px_scale, _, px_rate, _ = self.decoder(
dispersion="gene",
z=latent,
library=library,
)
px_r = torch.exp(self.px_r)
            px = (
                NegativeBinomial(mu=px_rate, theta=px_r, scale=px_scale)
                if self.gene_likelihood == "nb"
                else Poisson(px_rate)
            )
return {
"means": px.mean,
"variances": px.variance,
"distribution": px,
"samples": px.sample().squeeze(0),
}
else:
p_m, p_v = self.decoder(x=latent)
px = Normal(loc=p_m, scale=p_v.sqrt())
return {
"means": px.loc,
"variances": px.variance,
"distribution": px,
"samples": px.sample().squeeze(0),
}
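    # For the count likelihoods the decoder follows the scVI convention:
    # `px_scale` is a softmax over genes, `px_rate = exp(library) * px_scale`
    # (so predicted means respect each cell's observed library size), and
    # `px_r = exp(self.px_r)` is a per-gene dispersion shared across cells.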
@auto_move_data
def loss(
self,
tensors: dict[str, torch.Tensor],
inference_outputs: dict[Literal["latent_unknown_attributes"], torch.Tensor],
generative_outputs: dict[Literal["distribution", "means", "variances"], torch.Tensor],
    ) -> dict[str, torch.Tensor]:
"""Computes the module's loss.
Parameters
----------
tensors
Considered model inputs.
inference_outputs
Inference step outputs.
generative_outputs
Generative step outputs.
Returns
-------
The loss elements.
"""
x_ = tensors[self.x_loc]
means = generative_outputs["means"]
variances = generative_outputs["variances"]
if self.gene_likelihood in ["nb", "poisson"]:
reconstruction_loss = -generative_outputs["distribution"].log_prob(x_).sum(-1)
reconstruction_loss = reconstruction_loss.mean()
else:
reconstruction_loss = self.ae_loss_fn(input=means, target=x_, var=variances)
reconstruction_loss += self.reconstruction_penalty * self.ae_loss_mse_fn(input=means, target=x_)
unknown_attribute_penalty_loss_val = self.unknown_attribute_penalty_loss(
inference_outputs["latent_unknown_attributes"]
)
return {
LOSS_KEYS.RECONSTRUCTION: reconstruction_loss,
LOSS_KEYS.UNKNOWN_ATTRIBUTE_PENALTY: unknown_attribute_penalty_loss_val,
}
    @staticmethod
    def unknown_attribute_penalty_loss(latent_unknown_attributes: torch.Tensor) -> torch.Tensor:
        """Compute the penalty on the unknown-attribute (content) embeddings."""
        return torch.sum(latent_unknown_attributes**2, dim=1).mean()
@torch.no_grad()
def r2_metric(
self,
tensors: dict[str, torch.Tensor],
generative_outputs: dict[str, torch.Tensor],
) -> tuple[float, float]:
"""Evaluate the :math:`R^2` metric over gene expression.
Parameters
----------
tensors
Considered inputs.
generative_outputs
Generative model outputs.
        Returns
        -------
        The :math:`R^2` of the mean and variance predictions of the gene expression.
        """
x = tensors[self.x_loc].detach().cpu().numpy() # batch_size, n_genes
batch_size = x.shape[0]
indices = torch.zeros(batch_size).to(self.device)
if self.eval_r2_ordered:
for ordered_attribute_, len_ in self.ordered_attributes_map.items():
if len_ > 1:
attribute_vals = tensors[ordered_attribute_] # (batch_size, n_class)
indices += (
attribute_vals
* torch.arange(attribute_vals.shape[1])
.view(1, -1)
.repeat(batch_size, 1)
.to(attribute_vals.device)
).sum(dim=1)
else:
indices += tensors[ordered_attribute_].view(
-1,
) # (batch_size,)
for categorical_attribute_ in self.categorical_attributes_map:
indices += tensors[categorical_attribute_].view(
-1,
) # (batch_size,)
unique_indices = indices.unique()
r2_mean = 0.0
r2_var = 0.0
k = 0
pred_x_mean = (
torch.nan_to_num(generative_outputs["means"], nan=0, neginf=0, posinf=100).detach().cpu().numpy()
) # batch_size, n_genes
pred_x_var = (
torch.nan_to_num(generative_outputs["variances"], nan=0, neginf=0, posinf=100).detach().cpu().numpy()
) # batch_size, n_genes
for index in unique_indices:
index_mask = (indices == index).detach().cpu().numpy()
if index_mask.sum() > 2:
x_index = x[index_mask]
means_index = pred_x_mean[index_mask]
variances_index = pred_x_var[index_mask]
true_mean_index = np.nanmean(x_index, axis=0)
pred_mean_index = np.nanmean(means_index, axis=0)
true_var_index = np.nanvar(x_index, axis=0)
pred_var_index = (
np.nanvar(means_index, axis=0)
if self.gene_likelihood in ["nb", "poisson"]
else np.nanmean(variances_index, axis=0)
)
r2_mean += r2_score(true_mean_index, pred_mean_index)
r2_var += r2_score(true_var_index, pred_var_index)
k += 1
else:
continue
if k > 0:
return r2_mean / k, r2_var / k
else:
return r2_mean, r2_var
@torch.no_grad()
def get_expression(
self, tensors: dict[str, torch.Tensor], **inference_kwargs: Any
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute gene expression means and variances.
Parameters
----------
tensors
Considered inputs.
inference_kwargs
Additional arguments.
Returns
-------
        Prediction of the gene expression mean and variance.
"""
_, generative_outputs = self.forward(
tensors,
compute_loss=False,
inference_kwargs=inference_kwargs,
)
        means = torch.nan_to_num(generative_outputs["means"], nan=0, neginf=0, posinf=100)  # batch_size, n_genes
        variances = torch.nan_to_num(generative_outputs["variances"], nan=0, neginf=0, posinf=100)  # batch_size, n_genes
        return means, variances
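    # Illustrative call (tensor keys are assumed to match the registered fields):
    #
    #   tensors = {"X": x, REGISTRY_KEYS.INDICES_KEY: idx, "cell_type": ct, "dose": dose}
    #   means, variances = module.get_expression(tensors)  # each (batch_size, n_genes)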
class BiolordClassifyModule(BiolordModule):
"""The `biolord-classify` module.
A :class:`~biolord.BiolordModule` accompanied by regressors for ordered classes and
classifiers for categorical classes.
Parameters
----------
    categorical_attributes_missing
        Category, per categorical attribute, that marks unlabeled cells.
classify_all
Whether to classify all classes or only semi-supervised classes.
    logits
        Whether the classifiers return logits instead of probabilities.
bias
Whether to add bias to the regressor.
    classification_penalty
        Classification penalty strength.
    classifier_penalty
        Classifier penalty strength.
classifier_nn_width
Classifier's layer width.
classifier_nn_depth
Classifier's number of layers.
classifier_dropout_rate
Classifier's dropout rate.
    loss_regression
        Loss function for the ordered-attribute regressors (``"normal"`` or ``"mse"``).
kwargs
Keyword arguments for :class:`~biolord.BiolordModule`.
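    Examples
    --------
    A minimal sketch; names are illustrative, with ``"unknown"`` marking unlabeled
    cells of the ``"cell_type"`` attribute::

        module = BiolordClassifyModule(
            categorical_attributes_missing={"cell_type": "unknown"},
            n_genes=2000,
            n_samples=500,
            x_loc="X",
            categorical_attributes_map={"cell_type": {"B": 0, "T": 1, "unknown": 2}},
            ordered_attributes_map={"dose": 1},
        )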
"""
def __init__(
self,
categorical_attributes_missing: Optional[dict[str, str]] = None,
classify_all: bool = False,
logits: bool = False,
bias: bool = True,
classification_penalty: float = 1e-1,
classifier_penalty: float = 1e-4,
classifier_nn_width: int = 128,
classifier_nn_depth: int = 2,
classifier_dropout_rate: float = 1e-1,
loss_regression: Literal["normal", "mse"] = "normal",
**kwargs: Any,
):
super().__init__(**kwargs)
loss_regression = loss_regression.lower()
assert loss_regression in ["normal", "mse"], loss_regression
self.classification_penalty = classification_penalty
self.classifier_penalty = classifier_penalty
self.classification_loss_fn = nn.CrossEntropyLoss()
self.regression_loss_fn = nn.MSELoss() if loss_regression == "mse" else nn.GaussianNLLLoss()
self.loss_regression = loss_regression
self.mm_regression_loss_fn = nn.BCEWithLogitsLoss()
self.classify_all = classify_all
if isinstance(categorical_attributes_missing, dict):
self.categorical_attributes_missing = categorical_attributes_missing
elif categorical_attributes_missing is None:
self.categorical_attributes_missing = {attribute_: None for attribute_ in self.categorical_attributes_map}
else:
self.categorical_attributes_missing = {
attribute_: categorical_attributes_missing for attribute_ in self.categorical_attributes_map
}
self.categorical_embeddings = nn.ModuleDict()
for attribute_, unique_categories in self.categorical_attributes_map.items():
if self.categorical_attributes_missing[attribute_] is not None:
self.categorical_embeddings[attribute_] = torch.nn.Embedding(
len(unique_categories),
self.n_latent_attribute_categorical,
padding_idx=self.categorical_attributes_map[attribute_][
self.categorical_attributes_missing[attribute_]
],
)
else:
self.categorical_embeddings[attribute_] = torch.nn.Embedding(
len(unique_categories),
self.n_latent_attribute_categorical,
)
# Create classifiers
self.ordered_regressors = nn.ModuleDict()
if self.classify_all:
if self.loss_regression == "mse":
self.ordered_regressors = nn.ModuleDict(
{
attribute_: nn.Linear(
in_features=self.n_genes,
out_features=len_,
bias=bias,
)
for attribute_, len_ in self.ordered_attributes_map.items()
}
)
else:
self.ordered_regressors = nn.ModuleDict(
{
attribute_: Decoder(
n_input=self.n_genes,
n_output=len_,
n_hidden=classifier_nn_width,
n_layers=classifier_nn_depth,
use_batch_norm=self.use_batch_norm,
use_layer_norm=self.use_layer_norm,
)
for attribute_, len_ in self.ordered_attributes_map.items()
}
)
self.categorical_classifiers = nn.ModuleDict()
for attribute_, unique_categories in self.categorical_attributes_map.items():
if self.categorical_attributes_missing[attribute_] is not None:
self.categorical_classifiers[attribute_] = Classifier(
n_input=self.n_genes,
n_labels=len(unique_categories) - 1,
n_hidden=classifier_nn_width,
n_layers=classifier_nn_depth,
dropout_rate=classifier_dropout_rate,
logits=logits,
)
elif self.classify_all:
self.categorical_classifiers[attribute_] = Classifier(
n_input=self.n_genes,
n_labels=len(unique_categories),
n_hidden=classifier_nn_width,
n_layers=classifier_nn_depth,
dropout_rate=classifier_dropout_rate,
logits=logits,
)
def _get_inference_input(self, tensors: dict, **kwargs):
x = tensors[self.x_loc] # batch_size, n_genes
sample_indices = tensors[REGISTRY_KEYS.INDICES_KEY].long().ravel()
classification = self.classify(x)
categorical_attribute_dict = {}
for attribute_ in self.categorical_attributes_map:
categorical_attribute_dict[attribute_] = tensors[attribute_].view(
-1,
)
            if self.categorical_attributes_missing[attribute_] is not None:
                # map the missing-category label to its integer code before comparing,
                # then fill unlabeled entries with the classifier's predicted code
                missing_code = self.categorical_attributes_map[attribute_][
                    self.categorical_attributes_missing[attribute_]
                ]
                idx_ = categorical_attribute_dict[attribute_] == missing_code
                categorical_attribute_dict[attribute_][idx_] = (
                    classification[attribute_].argmax(dim=-1)[idx_].to(categorical_attribute_dict[attribute_].dtype)
                )
ordered_attribute_dict = {}
for attribute_ in self.ordered_attributes_map:
ordered_attribute_dict[attribute_] = tensors[attribute_]
input_dict = {
"genes": x,
"sample_indices": sample_indices,
"categorical_attribute_dict": categorical_attribute_dict,
"ordered_attribute_dict": ordered_attribute_dict,
}
return input_dict
@auto_move_data
def classify(
self,
genes: torch.Tensor,
) -> dict[str, torch.Tensor]:
"""Run classification.
Parameters
----------
genes
Gene expression used for classification.
        Returns
        -------
        Classification outputs: class probabilities (or logits) for each categorical
        attribute and regression values for each ordered attribute.
        """
classification = {}
for attribute_, classifier_ in self.categorical_classifiers.items():
classification[attribute_] = classifier_(genes)
for attribute_, regressor_ in self.ordered_regressors.items():
classification[attribute_] = regressor_(genes)
return classification
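    # Illustrative output shapes (attribute names assumed):
    #
    #   out = module.classify(x)  # x: (batch_size, n_genes)
    #   out["cell_type"]          # (batch_size, n_labels) probabilities or logits
    #   out["dose"]               # (batch_size, 1) for "mse", or a (mean, variance)
    #                             # pair when loss_regression == "normal"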
@auto_move_data
def _classification_loss(self, tensors: dict[str, torch.Tensor]):
"""Get module classification loss."""
x = tensors[self.x_loc] # batch_size, n_genes
classification_loss = torch.tensor([0.0]).to(self.device)
classification = self.classify(x)
for attribute_ in self.categorical_classifiers:
attribute_vals = tensors[attribute_].view(-1).long()
if self.categorical_attributes_missing[attribute_] is not None:
idx_ = (
attribute_vals
== self.categorical_attributes_map[attribute_][self.categorical_attributes_missing[attribute_]]
)
assignment = Categorical(classification[attribute_][idx_, :])
loss_fn = nn.CrossEntropyLoss(
ignore_index=self.categorical_attributes_map[attribute_][
self.categorical_attributes_missing[attribute_]
]
)
classification_loss += (
loss_fn(
classification[attribute_],
attribute_vals,
)
+ assignment.entropy().mean()
)
else:
classification_loss += self.classification_loss_fn(
classification[attribute_],
attribute_vals,
)
        for attribute_, len_ in self.ordered_attributes_map.items():
            if attribute_ not in classification:
                continue
            if self.loss_regression == "mse":
                classification_loss += self.regression_loss_fn(
                    classification[attribute_], tensors[attribute_].float()
                )
                if len_ > 1:
                    # auxiliary multi-label term for binned ordered attributes
                    classification_loss += self.classification_penalty * self.mm_regression_loss_fn(
                        classification[attribute_], tensors[attribute_].gt(0).float()
                    )
            else:
                # GaussianNLLLoss over the regressor's (mean, variance) outputs
                classification_loss += self.regression_loss_fn(
                    classification[attribute_][0],
                    tensors[attribute_].float(),
                    classification[attribute_][1],
                )
return classification_loss
@auto_move_data
def loss(
self,
tensors: dict[str, torch.Tensor],
inference_outputs: dict[Literal["latent_unknown_attributes"], torch.Tensor],
generative_outputs: dict[Literal["distribution", "means", "variances"], torch.Tensor],
    ) -> dict[str, torch.Tensor]:
"""Compute the loss.
Parameters
----------
tensors
Considered model inputs.
inference_outputs
Inference step outputs.
generative_outputs
Generative step outputs.
Returns
-------
The loss elements.
"""
losses = super().loss(
tensors=tensors,
inference_outputs=inference_outputs,
generative_outputs=generative_outputs,
)
losses[LOSS_KEYS.CLASSIFICATION] = self._classification_loss(tensors)
return losses
@torch.no_grad()
def _classification_accuracy(self, tensors):
"""Get module classification accuracy."""
accuracy_dict = {}
accuracy_val = 0
r2_dict = {}
r2_val = 0
mse_dict = {}
mse_val = 0
        x = tensors[self.x_loc]  # batch_size, n_genes
classification = self.classify(x)
for attribute_ in self.categorical_classifiers:
attribute_vals_pred = classification[attribute_].argmax(dim=-1).cpu().numpy()
attribute_vals = tensors[attribute_].view(-1).long().cpu().numpy()
if self.categorical_attributes_missing[attribute_] is not None:
idx_ = (
attribute_vals
!= self.categorical_attributes_map[attribute_][self.categorical_attributes_missing[attribute_]]
)
if idx_.sum() > 0:
accuracy_dict[attribute_] = np.mean(attribute_vals_pred[idx_] == attribute_vals[idx_])
accuracy_val += accuracy_dict[attribute_]
else:
accuracy_dict[attribute_] = np.mean(attribute_vals_pred == attribute_vals)
accuracy_val += accuracy_dict[attribute_]
        for attribute_ in self.ordered_regressors:
            attribute_vals_pred = (
                classification[attribute_] if self.loss_regression == "mse" else classification[attribute_][0]
            )
attribute_vals = tensors[attribute_].cpu().numpy()
attribute_vals_pred = (
torch.nan_to_num(attribute_vals_pred, nan=0, neginf=0, posinf=100).detach().cpu().numpy()
)
r2_dict[attribute_] = r2_score(attribute_vals, attribute_vals_pred) if attribute_vals.shape[0] > 2 else 0
mse_dict[attribute_] = mean_squared_error(attribute_vals, attribute_vals_pred)
r2_val += r2_dict[attribute_]
mse_val += mse_dict[attribute_]
accuracy_mean = 0
r2_mean = 0
mse_mean = 0
if len(accuracy_dict):
accuracy_mean = accuracy_val / len(accuracy_dict)
if len(r2_dict):
r2_mean = r2_val / len(r2_dict)
mse_mean = mse_val / len(mse_dict)
return accuracy_dict, accuracy_mean, r2_dict, r2_mean, mse_dict, mse_mean