import io
import itertools
import logging
import os
from collections.abc import Sequence
from typing import Any, Literal, Optional, Union
import numpy as np
import pandas as pd
import rich
import torch
from anndata import AnnData
from lightning.pytorch.callbacks import ModelCheckpoint
from scvi import REGISTRY_KEYS
from scvi.data import AnnDataManager
from scvi.data.fields import (
CategoricalObsField,
LayerField,
NumericalObsField,
ObsmField,
)
from scvi.dataloaders import DataSplitter
from scvi.model.base import BaseModelClass
from scvi.train import TrainRunner
from scvi.utils import setup_anndata_dsp
from tqdm import tqdm
from ._data import AnnDataSplitter
from ._module import BiolordClassifyModule, BiolordModule
from ._train import biolordClassifyTrainingPlan, biolordTrainingPlan
from ._utils import repeat_n
__all__ = ["Biolord"]

# Module-level logger; propagation to ancestor loggers is switched off, so
# records are not forwarded to handlers configured on parent/root loggers.
logger: logging.Logger = logging.getLogger(__name__)
logger.propagate = False

# Default root directory for training logs (``Biolord.train`` appends a
# per-model sub-directory to it).
logging_dir: str = "./biolord_log/"
class Biolord(BaseModelClass):
    """The biolord model class.

    Parameters
    ----------
    adata
        Annotated data object.
    model_name
        Name of the model.
    module_params
        Hyperparameters for the model's module initialization, e.g, :class:`~biolord.BiolordModule` or
        :class:`~biolord.BiolordClassifyModule`.
    n_latent
        Number of latent dimensions used for the latent embedding.
    train_classifiers
        Whether to activate a :class:`~biolord.BiolordClassifyModule`.
    split_key
        Key in :attr:`anndata.AnnData.obs` used to split the data between train, test and validation.
    train_split
        Value in :attr:`anndata.AnnData.obs` ``['{split_key}']`` marking the train set.
    valid_split
        Value in :attr:`anndata.AnnData.obs` ``['{split_key}']`` marking the validation set.
    test_split
        Value in :attr:`anndata.AnnData.obs` ``['{split_key}']`` marking the test set.

    Examples
    --------
    .. code-block:: python

        import scanpy as sc
        import biolord

        adata = sc.read(...)
        biolord.Biolord.setup_anndata(
            adata, ordered_attributes_keys=["time"], categorical_attributes_keys=["cell_type"]
        )
        model = biolord.Biolord(adata, n_latent=256, split_key="split")
        model.train(max_epochs=200, batch_size=256)
    """

    def __init__(
        self,
        adata: AnnData,
        model_name: Optional[str] = None,
        module_params: Optional[dict[str, Any]] = None,
        n_latent: int = 128,
        train_classifiers: bool = False,
        split_key: Optional[str] = None,
        train_split: str = "train",
        valid_split: str = "test",
        test_split: str = "ood",
    ):
        super().__init__(adata)
        self.categorical_attributes_map = {}
        self.ordered_attributes_map = {}
        self.retrieval_attribute_dict = {}
        self.categorical_attributes_missing = self.registry_["setup_args"]["categorical_attributes_missing"]
        self.x_loc = None
        # Populates the attribute maps above and sets `x_loc` from the registered expression field.
        self._set_attributes_maps()
        self.n_latent = n_latent
        self.n_genes = adata.n_vars
        self.split_key = split_key
        self.scores = {}
        self._module = None
        self._training_plan = None
        self._data_splitter = None

        # Positional indices of each split. They stay `None` when no `split_key` is given,
        # in which case `train` falls back to a random train/validation split.
        train_indices, valid_indices, test_indices = None, None, None
        if split_key is not None:
            split_values = adata.obs.loc[:, split_key]
            train_indices = np.where(split_values == train_split)[0]
            valid_indices = np.where(split_values == valid_split)[0]
            test_indices = np.where(split_values == test_split)[0]
        self.train_indices = train_indices
        self.valid_indices = valid_indices
        self.test_indices = test_indices

        self.n_samples = adata.n_obs
        self.train_classifiers = train_classifiers
        module_params = module_params if isinstance(module_params, dict) else {}
        if self.train_classifiers:
            self.module = BiolordClassifyModule(
                n_genes=self.n_genes,
                n_samples=self.n_samples,
                x_loc=self.x_loc,
                categorical_attributes_map=self.categorical_attributes_map,
                ordered_attributes_map=self.ordered_attributes_map,
                categorical_attributes_missing=self.categorical_attributes_missing,
                n_latent=self.n_latent,
                **module_params,
            ).float()
        else:
            self.module = BiolordModule(
                n_genes=self.n_genes,
                n_samples=self.n_samples,
                x_loc=self.x_loc,
                ordered_attributes_map=self.ordered_attributes_map,
                categorical_attributes_map=self.categorical_attributes_map,
                n_latent=self.n_latent,
                **module_params,
            ).float()
        self._model_summary_string = self.__class__.__name__
        self._model_name = model_name
        # Records the constructor arguments from this frame's locals; note that
        # `module_params` is captured after the normalisation above.
        self.init_params_ = self._get_init_params(locals())
        self.epoch_history = None

    def _set_attributes_maps(self):
        """Set attributes' maps.

        Fills ``categorical_attributes_map`` (category -> integer code, taken from the registry),
        ``ordered_attributes_map`` (attribute -> input dimension), ``retrieval_attribute_dict`` and
        ``x_loc`` (registry key of the expression field).
        """
        setup_args = self.registry_["setup_args"]
        for attribute_ in setup_args["categorical_attributes_keys"]:
            mapping = self.registry_["field_registries"][attribute_]["state_registry"]["categorical_mapping"]
            self.categorical_attributes_map[attribute_] = {c: i for i, c in enumerate(mapping)}

        for attribute_ in setup_args["ordered_attributes_keys"]:
            # `.obs` columns are scalar (dimension 1); `.obsm` entries contribute their column count.
            if attribute_ in self.adata.obs:
                self.ordered_attributes_map[attribute_] = 1
            elif attribute_ in self.adata.obsm:
                self.ordered_attributes_map[attribute_] = self.adata.obsm[attribute_].shape[1]
            else:
                raise KeyError(f"class {attribute_} not found in `adata.obs` or `adata.obsm`.")

        if setup_args["retrieval_attribute_key"] is not None:
            self.retrieval_attribute_dict = {
                "retrieval_attribute_key": len(np.unique(self.adata.obs[setup_args["retrieval_attribute_key"]]))
            }
        self.x_loc = setup_args["FIELD"].attr_name

    @property
    def training_plan(self):
        """The model's training plan."""
        return self._training_plan

    @training_plan.setter
    def training_plan(self, plan):
        self._training_plan = plan

    @property
    def data_splitter(self):
        """Data splitter."""
        return self._data_splitter

    @data_splitter.setter
    def data_splitter(self, data_splitter):
        self._data_splitter = data_splitter

    @property
    def module(self) -> BiolordModule:
        """Model's module."""
        return self._module

    @module.setter
    def module(self, module: BiolordModule):
        self._module = module

    @property
    def model_name(self) -> str:
        """Model's name."""
        return self._model_name

    @model_name.setter
    def model_name(self, model_name: str):
        self._model_name = model_name

    @classmethod
    @setup_anndata_dsp.dedent
    def setup_anndata(
        cls,
        adata: AnnData,
        ordered_attributes_keys: Optional[list[str]] = None,
        categorical_attributes_keys: Optional[list[str]] = None,
        categorical_attributes_missing: Optional[dict[str, str]] = None,
        retrieval_attribute_key: Optional[str] = None,
        layer: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Setup function.

        Parameters
        ----------
        adata
            Annotated data object.
        ordered_attributes_keys
            Valid :attr:`anndata.AnnData.obs` or :attr:`anndata.AnnData.obsm` keys for the ordered attributes.
        categorical_attributes_keys
            Valid :attr:`anndata.AnnData.obs` keys for the categorical attributes.
        categorical_attributes_missing
            Categories representing missing labels. Only used if ``train_classifiers=True``.
            Each listed category is moved to the last position of its attribute's categories.
        retrieval_attribute_key
            Valid :attr:`anndata.AnnData.obs` key for an attribute to evaluate retrieval performance over.
        layer
            Expression layer in :attr:`anndata.AnnData.layers` to use. If :obj:`None`, use :attr:`anndata.AnnData.X`.
        kwargs
            Keyword arguments for :meth:`~scvi.data.AnnDataManager.register_fields`.

        Returns
        -------
        Nothing, just sets up ``adata``. As a side effect, ``adata.obs['_indices']`` is (re)written.
        """
        # NOTE: every local defined before the `_get_setup_method_args(**locals())` call below is
        # persisted as a setup argument (`Biolord.__init__` reads "FIELD", "categorical_attributes_keys",
        # "ordered_attributes_keys", ... back from the registry), so do not rename or add locals above it.
        if layer is not None:
            if layer not in adata.layers:
                raise KeyError(f"{layer} is not a valid key in `adata.layers`.")
            logger.info(f"Using data from adata.layers[{layer!r}]")
            FIELD = LayerField(
                registry_key="layers",
                layer=layer,
                is_count_data=True,
            )
        else:
            logger.info("Using data from `adata.X`.")
            FIELD = LayerField(registry_key="X", layer=None, is_count_data=False)

        ordered_attributes_keys = ordered_attributes_keys if isinstance(ordered_attributes_keys, list) else []
        categorical_attributes_keys = (
            categorical_attributes_keys if isinstance(categorical_attributes_keys, list) else []
        )
        if categorical_attributes_missing is not None:
            for attribute_, val_ in categorical_attributes_missing.items():
                if val_ is not None:
                    # Move the "missing" category to the end of the category list.
                    adata.obs[attribute_] = adata.obs[attribute_].astype("category")
                    cats = adata.obs[attribute_].cat.categories
                    idx_ = cats.isin([val_])
                    ordered = list(cats[~idx_]) + [val_]
                    adata.obs[attribute_] = adata.obs[attribute_].cat.reorder_categories(ordered)

        # set retrieval class
        retrieval_attribute_dict = {}
        if retrieval_attribute_key is not None:
            retrieval_attribute_dict = {"retrieval_attribute_key": len(np.unique(adata.obs[retrieval_attribute_key]))}

        # Split the ordered attributes by where they live (`.obs` scalars vs. `.obsm` matrices).
        ordered_attributes_obs = []
        ordered_attributes_obsm = []
        for attribute_ in ordered_attributes_keys:
            if attribute_ in adata.obs:
                ordered_attributes_obs.append(attribute_)
            elif attribute_ in adata.obsm:
                ordered_attributes_obsm.append(attribute_)
            else:
                raise KeyError(f"class {attribute_} not found in `adata.obs` or `adata.obsm`.")

        setup_method_args = cls._get_setup_method_args(**locals())
        adata.obs["_indices"] = np.arange(adata.n_obs)
        anndata_fields = (
            [
                FIELD,
                NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"),
            ]
            + [
                CategoricalObsField(registry_key=attribute_, attr_key=attribute_)
                for attribute_ in categorical_attributes_keys
            ]
            + [
                NumericalObsField(
                    attribute_,
                    attribute_,
                )
                for attribute_ in ordered_attributes_obs
            ]
            + [
                ObsmField(
                    attribute_,
                    attribute_,
                    is_count_data=False,
                    correct_data_format=True,
                )
                for attribute_ in ordered_attributes_obsm
            ]
        )
        if retrieval_attribute_key is not None:
            anndata_fields += [
                CategoricalObsField(
                    registry_key=retrieval_attribute_key,
                    attr_key=retrieval_attribute_key,
                )
            ]
        adata_manager = AnnDataManager(
            fields=anndata_fields,
            setup_method_args=setup_method_args,
        )
        adata_manager.register_fields(adata, **kwargs)
        cls.register_manager(adata_manager)

    @torch.no_grad()
    def get_latent_representation_adata(
        self,
        adata: Optional[AnnData] = None,
        indices: Optional[Sequence[int]] = None,
        batch_size: Optional[int] = 512,
        nullify_attribute: Optional[list[str]] = None,
    ) -> tuple[AnnData, AnnData]:
        """Return the unknown attributes latent space and full latent variable.

        Parameters
        ----------
        adata
            Annotated data object. If :obj:`None`, the model's registered object is used.
        indices
            Optional indices. If :obj:`None`, all cells are used.
        batch_size
            Batch size to use.
        nullify_attribute
            Attribute to nullify in the latent space.

        Returns
        -------
        Two :class:`~anndata.AnnData` objects providing the unknown attributes latent space and
        the concatenated decomposed latent respectively.

        Raises
        ------
        RuntimeError
            If the model has not been trained.
        """
        if not self.is_trained_:
            raise RuntimeError("Please train the model first.")
        nullify_attribute = [] if nullify_attribute is None else nullify_attribute
        # Inference only: make sure train-mode layers are disabled, as in `predict`.
        self.module.eval()
        adata = self._validate_anndata(adata)
        if indices is None:
            indices = np.arange(adata.n_obs)
        scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size, shuffle=False)

        latent_unknown_attributes = []
        latent = []
        for tensors in scdl:
            inference_inputs = self.module.get_inference_input(tensors)
            outputs = self.module.inference(**inference_inputs, nullify_attribute=nullify_attribute)
            latent_unknown_attributes.append(outputs["latent_unknown_attributes"].cpu().numpy())
            latent.append(outputs["latent"].cpu().numpy())

        obs_subset = adata[indices].obs
        latent_unknown_attributes_adata = AnnData(
            X=np.concatenate(latent_unknown_attributes, axis=0), obs=obs_subset.copy()
        )
        latent_unknown_attributes_adata.obs_names = obs_subset.index
        latent_adata = AnnData(X=np.concatenate(latent, axis=0), obs=obs_subset.copy())
        latent_adata.obs_names = obs_subset.index
        return latent_unknown_attributes_adata, latent_adata

    @torch.no_grad()
    def get_dataset(
        self,
        adata: Optional[AnnData] = None,
        indices: Optional[Sequence[int]] = None,
    ) -> dict[str, torch.Tensor]:
        """Processes :class:`~anndata.AnnData` object into valid input tensors for the model.

        Parameters
        ----------
        adata
            Annotated data object.
        indices
            Optional indices. If :obj:`None`, all cells are used.

        Returns
        -------
        A dictionary of tensors which can be passed as input to the model.
        """
        adata = self._validate_anndata(adata)
        if indices is None:
            indices = np.arange(adata.n_obs)
        # A single batch holding every requested cell; no need to materialise the loader as a list.
        scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=len(indices), shuffle=False)
        return next(iter(scdl))

    @torch.no_grad()
    def predict(
        self,
        adata: Optional[AnnData] = None,
        indices: Optional[Sequence[int]] = None,
        batch_size: Optional[int] = 512,
        nullify_attribute: Optional[list[str]] = None,
    ) -> tuple[AnnData, AnnData]:
        """The model's gene expression prediction for a given :class:`~anndata.AnnData` object.

        Parameters
        ----------
        adata
            Annotated data object.
        indices
            Optional indices. If :obj:`None`, all cells are used.
        batch_size
            Batch size to use.
        nullify_attribute
            Attribute to nullify in latent space.

        Returns
        -------
        Two :class:`~anndata.AnnData` objects representing the model's prediction of the expression mean
        and variance respectively, with one row per selected cell.
        """
        nullify_attribute = [] if nullify_attribute is None else nullify_attribute
        self.module.eval()
        adata = self._validate_anndata(adata)
        if indices is None:
            indices = np.arange(adata.n_obs)
        scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size, shuffle=False)

        mus = []
        stds = []
        for tensors in scdl:
            _mus, _stds = self.module.get_expression(tensors, nullify_attribute=nullify_attribute)
            # A batch of one cell may come back 1-D; restore the cell axis.
            _mus = _mus if _mus.ndim > 1 else _mus[None, :]
            _stds = _stds if _stds.ndim > 1 else _stds[None, :]
            mus.append(_mus.detach().cpu().numpy())
            stds.append(_stds.detach().cpu().numpy())

        # Only the selected cells were predicted, so annotate with the matching `obs` subset.
        obs_subset = adata[indices].obs
        pred_adata_mean = AnnData(X=np.concatenate(mus, axis=0), obs=obs_subset.copy())
        pred_adata_var = AnnData(X=np.concatenate(stds, axis=0), obs=obs_subset.copy())
        pred_adata_mean.obs_names = obs_subset.index
        pred_adata_var.obs_names = obs_subset.index
        pred_adata_mean.var_names = adata.var_names
        pred_adata_var.var_names = adata.var_names
        return pred_adata_mean, pred_adata_var

    @torch.no_grad()
    def get_ordered_attribute_embedding(
        self,
        attribute_key: str,
        vals: Optional[Union[float, str, np.ndarray]] = None,
    ) -> np.ndarray:
        """Compute embedding of an ordered attribute.

        Parameters
        ----------
        attribute_key
            The key of the desired attribute.
        vals
            Values of interest. Defaults to ``1.0``. Python/numpy scalars are wrapped into a 1-element
            tensor; arrays, lists, tuples and tensors are converted to float tensors on the model device;
            anything else is passed to the network unchanged.

        Returns
        -------
        Array of the attribute's embedding.
        """
        self.module.eval()
        vals = vals if vals is not None else 1.0
        if isinstance(vals, (int, float, np.number)):
            batch = torch.tensor([vals]).to(self.device).float()
        elif isinstance(vals, (np.ndarray, list, tuple)):
            batch = torch.tensor(np.asarray(vals)).to(self.device).float()
        elif isinstance(vals, torch.Tensor):
            batch = vals.to(self.device).float()
        else:
            batch = vals
        return self.module.ordered_networks[attribute_key](batch).detach().cpu().numpy()

    @torch.no_grad()
    def get_categorical_attribute_embeddings(
        self, attribute_key: str, attribute_category: Optional[str] = None
    ) -> np.ndarray:
        """Compute embedding of a categorical attribute.

        Parameters
        ----------
        attribute_key
            The key of the desired attribute.
        attribute_category
            A specific category for embedding computation. If :obj:`None`, all categories are embedded,
            in the order of their integer codes.

        Returns
        -------
        Array of the attribute's embedding.
        """
        if attribute_category is None:
            cat_ids = torch.arange(len(self.categorical_attributes_map[attribute_key]), device=self.device)
        else:
            cat_ids = torch.LongTensor([self.categorical_attributes_map[attribute_key][attribute_category]]).to(
                self.device
            )
        return self.module.categorical_embeddings[attribute_key](cat_ids).detach().cpu().numpy()

    def save(
        self,
        dir_path: Optional[str] = None,
        overwrite: bool = False,
        save_anndata: bool = False,
        **anndata_save_kwargs: Any,
    ) -> None:
        """Save the model.

        Parameters
        ----------
        dir_path
            Directory where to save the model. If :obj:`None`, it will be determined automatically.
        overwrite
            Whether to overwrite an existing model.
        save_anndata
            Whether to also save :class:`~anndata.AnnData`.
        anndata_save_kwargs
            Keyword arguments for :meth:`scvi.model.base.BaseModelClass.save`.

        Returns
        -------
        Nothing, just saves the model (and, when available, the training history as ``history.csv``).
        """
        if dir_path is None:
            dir_path = (
                f"./{self.__class__.__name__}_model/"
                if self.model_name is None
                else f"./{self.__class__.__name__}_{self.model_name}_model/"
            )
        super().save(
            dir_path=dir_path,
            overwrite=overwrite,
            save_anndata=save_anndata,
            **anndata_save_kwargs,
        )
        # `training_plan` is None for a model that was loaded but not (re)trained in this session.
        plan_history = getattr(self.training_plan, "epoch_history", None)
        if isinstance(plan_history, dict):
            self.epoch_history = pd.DataFrame.from_dict(plan_history)
        if isinstance(self.epoch_history, pd.DataFrame):
            self.epoch_history.to_csv(os.path.join(dir_path, "history.csv"), index=False)

    @classmethod
    def load(
        cls,
        dir_path: str,
        adata: Optional[AnnData] = None,
        accelerator: str = "auto",
        device: Union[int, list[int], str] = "auto",
        **kwargs: Any,
    ) -> "Biolord":
        """Load a saved model.

        Parameters
        ----------
        dir_path
            Directory where the model is saved.
        adata
            AnnData organized in the same way as data used to train model.
        accelerator
            Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "hpu",
            "mps", "auto") as well as custom accelerator instances.
        device
            The device to use. Can be set to a positive number (int or str), or ``"auto"``
            for automatic selection based on the chosen accelerator.
        kwargs
            Keyword arguments for :meth:`scvi.model.base.BaseModelClass.load`.

        Returns
        -------
        The loaded model.
        """
        model = super().load(dir_path, adata, accelerator=accelerator, device=device, **kwargs)
        # NOTE(review): kept for backward compatibility; this writes *class-level* attributes that are
        # shadowed by the instance attributes each `__init__` sets.
        Biolord.categorical_attributes_map = model.categorical_attributes_map
        Biolord.ordered_attributes_map = model.ordered_attributes_map
        fname = os.path.join(dir_path, "history.csv")
        if os.path.isfile(fname):
            model.epoch_history = pd.read_csv(fname)
        else:
            logger.warning(f"The history file `{fname}` was not found")
        return model

    def train(
        self,
        max_epochs: Optional[int] = None,
        accelerator: str = "auto",
        device: Union[int, list[int], str] = "auto",
        train_size: float = 0.9,
        validation_size: Optional[float] = None,
        plan_kwargs: Optional[dict[str, Any]] = None,
        batch_size: int = 128,
        early_stopping: bool = False,
        **trainer_kwargs: Any,
    ) -> None:
        """Train the :class:`~biolord.Biolord` model.

        Parameters
        ----------
        max_epochs
            Maximum number of epochs for training. If :obj:`None`, ``min(round(20000 / n_cells * 400), 400)``.
        accelerator
            Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "hpu",
            "mps", "auto") as well as custom accelerator instances.
        device
            The device to use. Can be set to a positive number (int or str), or ``"auto"``
            for automatic selection based on the chosen accelerator.
        train_size
            Fraction of training data in the case of randomly splitting dataset to train/validation
            if :attr:`split_key` is not set in model's constructor.
        validation_size
            Fraction of validation data in the case of randomly splitting dataset to train/validation
            if :attr:`split_key` is not set in model's constructor.
        plan_kwargs
            Keyword arguments for :class:`~scvi.train.TrainingPlan`.
        batch_size
            Size of mini-batches for training.
        early_stopping
            If `True`, early stopping will be used during training on validation dataset.
        trainer_kwargs
            Keyword arguments for :class:`~scvi.train.TrainRunner`. The following keys are consumed here:
            ``monitor`` (default ``"val_biolord_metric"``), ``save_ckpt_every_n_epoch`` (default 20),
            ``enable_checkpointing`` (default :obj:`False`) and ``num_workers`` (default 0). Defaults are
            also provided for ``check_val_every_n_epoch`` (1), ``early_stopping_patience`` (20) and
            ``default_root_dir``.

        Returns
        -------
        Nothing, just trains the :class:`~biolord.Biolord` model.
        """
        plan_kwargs = plan_kwargs if plan_kwargs is not None else {}
        plan_cls = biolordClassifyTrainingPlan if self.train_classifiers else biolordTrainingPlan
        self.training_plan = plan_cls(module=self.module, **plan_kwargs)

        monitor = trainer_kwargs.pop("monitor", "val_biolord_metric")
        save_ckpt_every_n_epoch = trainer_kwargs.pop("save_ckpt_every_n_epoch", 20)
        enable_checkpointing = trainer_kwargs.pop("enable_checkpointing", False)

        # Work on a fresh list so the caller's callbacks are never mutated (repeated `train` calls
        # would otherwise keep appending checkpoint callbacks to the same list).
        callbacks = trainer_kwargs.get("callbacks")
        if callbacks is None:
            callbacks = []
        elif isinstance(callbacks, (list, tuple)):
            callbacks = list(callbacks)
        else:
            callbacks = [callbacks]
        if enable_checkpointing:
            callbacks.append(ModelCheckpoint(monitor=monitor, every_n_epochs=save_ckpt_every_n_epoch))
        trainer_kwargs["callbacks"] = callbacks

        num_workers = trainer_kwargs.pop("num_workers", 0)
        if max_epochs is None:
            n_cells = self.adata.n_obs
            max_epochs = min(round((20000 / n_cells) * 400), 400)

        # Use the user-defined split when all three index sets were derived from `split_key`.
        manual_splitting = (
            (self.valid_indices is not None) and (self.train_indices is not None) and (self.test_indices is not None)
        )
        if manual_splitting:
            self.data_splitter = AnnDataSplitter(
                self.adata_manager,
                train_indices=self.train_indices,
                valid_indices=self.valid_indices,
                test_indices=self.test_indices,
                batch_size=batch_size,
                num_workers=num_workers,
            )
        else:
            self.data_splitter = DataSplitter(
                self.adata_manager,
                train_size=train_size,
                validation_size=validation_size,
                batch_size=batch_size,
                num_workers=num_workers,
            )

        # `early_stopping` is a named parameter, so it can never also arrive through `trainer_kwargs`.
        trainer_kwargs["early_stopping"] = early_stopping
        trainer_kwargs.setdefault("check_val_every_n_epoch", 1)
        trainer_kwargs.setdefault("early_stopping_patience", 20)
        root_dir = (
            os.path.join(logging_dir, f"{self.__class__.__name__}/")
            if self.model_name is None
            else os.path.join(logging_dir, f"{self.model_name}_{self.__class__.__name__}/")
        )
        trainer_kwargs.setdefault("default_root_dir", root_dir)

        runner = TrainRunner(
            self,
            training_plan=self.training_plan,
            data_splitter=self.data_splitter,
            max_epochs=max_epochs,
            accelerator=accelerator,
            devices=device,
            early_stopping_monitor=monitor,
            early_stopping_mode="max",
            enable_checkpointing=enable_checkpointing,
            checkpointing_monitor=monitor,
            **trainer_kwargs,
        )
        return runner()

    @torch.no_grad()
    def evaluate_retrieval(
        self,
        batch_size: Optional[int] = None,
        eval_set: Literal["test", "validation"] = "test",
        k: int = 1,
    ) -> float:
        """Returns the accuracy of the retrieval task over the registered retrieval attribute.

        Each evaluation cell is matched to its ``k`` nearest training cells in the unknown-attributes
        latent space; the score is the mean fraction of neighbours sharing the cell's retrieval label.

        Parameters
        ----------
        batch_size
            Batch size to use. If :obj:`None`, all cells are processed in a single batch.
        eval_set
            Evaluation dataset.
        k
            Number of nearest training neighbours to compare against.

        Returns
        -------
        Retrieval accuracy over the evaluation dataset.

        Raises
        ------
        RuntimeError
            If the model is not trained or ``eval_set`` is not supported.
        ValueError
            If no ``retrieval_attribute_key`` was registered in :meth:`setup_anndata`.
        """
        if not self.is_trained_:
            raise RuntimeError("Please train the model first.")
        retrieval_key = self.registry_["setup_args"]["retrieval_attribute_key"]
        if retrieval_key is None:
            raise ValueError("No `retrieval_attribute_key` was registered in `setup_anndata`.")
        if eval_set == "test":
            eval_indices = self.test_indices
        elif eval_set == "validation":
            eval_indices = self.valid_indices
        else:
            raise RuntimeError(f"Supports `eval_type` `test` or `validation` but {eval_set} was provided.")

        batch_size = batch_size if batch_size is not None else self.adata.n_obs
        latent_train_adata, _ = self.get_latent_representation_adata(
            adata=self.adata, indices=self.train_indices, batch_size=batch_size
        )
        latent_eval_adata, _ = self.get_latent_representation_adata(
            adata=self.adata, indices=eval_indices, batch_size=batch_size
        )
        # The latent objects carry a copy of `adata.obs`, so the retrieval labels are found under the
        # user's registered obs column.
        return self._retrieval_accuracy(
            latent_eval_adata.obs[retrieval_key],
            latent_eval_adata.X,
            latent_train_adata.obs[retrieval_key],
            latent_train_adata.X,
            k=k,
        )

    @torch.no_grad()
    def _retrieval_accuracy(
        self,
        retrieval_class,
        latent_unknown_attributes,
        retrieval_attribute_train,
        latent_unknown_attributes_train,
        k=1,
    ) -> float:
        """Retrieval accuracy score.

        Parameters
        ----------
        retrieval_class
            Labels of the evaluation cells (array-like, one per row of ``latent_unknown_attributes``).
        latent_unknown_attributes
            Latent coordinates of the evaluation cells.
        retrieval_attribute_train
            Labels of the reference (training) cells.
        latent_unknown_attributes_train
            Latent coordinates of the reference cells.
        k
            Number of nearest reference neighbours per evaluation cell.

        Returns
        -------
        Mean, over evaluation cells, of the fraction of the ``k`` nearest reference cells whose label
        equals the evaluation cell's label.
        """
        from sklearn.neighbors import NearestNeighbors

        # Compare positionally: pandas Series passed here are indexed by obs names, so integer `[]`
        # lookups on them would rely on pandas' deprecated positional fallback.
        eval_labels = np.asarray(retrieval_class)
        train_labels = np.asarray(retrieval_attribute_train)
        nbrs = NearestNeighbors(n_neighbors=k).fit(latent_unknown_attributes_train)
        _, neighbor_idx = nbrs.kneighbors(latent_unknown_attributes)
        matches = train_labels[neighbor_idx] == eval_labels[:, None]
        return float(np.mean(matches))

    @torch.no_grad()
    def _compute_prediction(
        self,
        adata,
        dataset_source,
        target_attributes,
    ) -> tuple[dict[tuple[Any], Any], Any]:
        """Expression prediction over given inputs.

        Parameters
        ----------
        adata
            An annotated data object containing possible values of the `target_attributes`.
        dataset_source
            Dataset of cells to "shift", make predictions over.
        target_attributes
            Attributes to make predictions over.

        Returns
        -------
        The prediction dict, keyed by each combination of ``target_attributes`` categories, and the
        original expression prediction.
        """
        # Every combination of categories across the target attributes.
        keys = list(
            itertools.product(
                *[list(self.categorical_attributes_map[attribute_].keys()) for attribute_ in target_attributes]
            )
        )
        # Registry key of the expression field ("X" or "layers", as chosen in `setup_anndata`).
        layer = "X" if "X" in dataset_source else "layers"
        pred_original, _ = self.module.get_expression(dataset_source)

        # For each observed category, keep the encoded attribute tensor of one representative cell.
        classes_dataset = {}
        for attribute_ in target_attributes:
            categories_index = pd.Index(adata.obs[attribute_].values, dtype="category")
            classes_dataset[attribute_] = {}
            for key_ in tqdm(np.unique(categories_index.values)):
                bool_category = categories_index.get_loc(key_)
                adata_cur = adata[bool_category, :].copy()
                dataset = self.get_dataset(adata_cur)
                classes_dataset[attribute_][key_] = dataset[attribute_][0, :]

        n_obs = dataset_source[layer].size(0)
        # Device transfer is loop-invariant; each combination then overrides only the target attributes.
        source_on_device = {key_dataset: tensor.to(self.device) for key_dataset, tensor in dataset_source.items()}
        predictions_dict = {}
        for key_ in keys:
            dataset_comb = dict(source_on_device)
            for ci, attribute_ in enumerate(target_attributes):
                dataset_comb[attribute_] = repeat_n(classes_dataset[attribute_][key_[ci]], n_obs)
            pred, _ = self.module.get_expression(dataset_comb)
            predictions_dict[key_] = pred
        return predictions_dict, pred_original

    def compute_prediction_adata(
        self,
        adata: AnnData,
        adata_source: AnnData,
        target_attributes: list[str],
        add_attributes: Optional[list[str]] = None,
    ) -> AnnData:
        """Expression prediction over given inputs.

        Parameters
        ----------
        adata
            Annotated data object containing possible values of the ``target_attributes``.
        adata_source
            Annotated data object we wish to make predictions over, e.g., change their ``target_attributes``.
        target_attributes
            Attributes to make predictions over.
        add_attributes
            Additional attributes to add to :attr:`anndata.AnnData.obs` from the original adata
            to the prediction adata object.

        Returns
        -------
        Annotated data object containing predictions of the cells in all combinations of the
        ``target_attributes``: one block of ``adata_source.n_obs`` rows per combination, with obs names
        ``"{source_obs_name}_{value_1}_{value_2}..."``.
        """
        dataset_source = self.get_dataset(adata_source)
        predictions_dict, _ = self._compute_prediction(
            adata=adata, dataset_source=dataset_source, target_attributes=target_attributes
        )
        keys = list(predictions_dict.keys())
        preds = [predictions_dict[key_].cpu().numpy() for key_ in keys]
        block_sizes = [pred.shape[0] for pred in preds]
        preds_ = np.concatenate(preds)

        obs_names = [
            obs_name + "_" + "_".join([str(k) for k in key_])
            for key_ in keys
            for obs_name in adata_source.obs_names
        ]
        obs = pd.DataFrame(index=pd.Index(obs_names))
        for ci, attribute_ in enumerate(target_attributes):
            # One value per combination, repeated over that combination's block of cells.
            per_key_values = np.array([key_[ci] for key_ in keys], dtype=object)
            obs[attribute_] = pd.Categorical(np.repeat(per_key_values, block_sizes))
        adata_preds = AnnData(X=preds_, obs=obs)

        if add_attributes is not None:
            # Every prediction block lists the source cells in their original order.
            n_repeats = adata_preds.n_obs // adata_source.n_obs
            for attribute_ in add_attributes:
                source_values = adata_source.obs[attribute_]
                tiled = np.tile(source_values.to_numpy(), n_repeats)
                adata_preds.obs[attribute_] = pd.Series(tiled, index=adata_preds.obs.index).astype(
                    source_values.dtype
                )
                colors_key = f"{attribute_}_colors"
                if colors_key in adata_source.uns:
                    adata_preds.uns[colors_key] = adata_source.uns[colors_key]
        adata_preds.var_names = adata_source.var_names
        return adata_preds

    def __repr__(self) -> str:
        # Imported explicitly: `import rich` alone does not guarantee `rich.console` is loaded.
        from rich.console import Console

        status = "[green]Trained[/]" if self.is_trained else "[red]Not trained[/]"
        summary_string = f"{self._model_summary_string} training status: {status}"
        # Render the rich markup into a captured string instead of printing it.
        console = Console(file=io.StringIO())
        with console.capture() as capture:
            console.print(summary_string)
        return capture.get()