biolord.BiolordClassifyModule

class biolord.BiolordClassifyModule(categorical_attributes_missing=None, classify_all=False, logits=False, bias=True, classification_penalty=0.1, classifier_penalty=0.0001, classifier_nn_width=128, classifier_nn_depth=2, classifier_dropout_rate=0.1, loss_regression='normal', **kwargs)

The biolord-classify module.

A BiolordModule accompanied by regressors for ordered classes and classifiers for categorical classes.

Parameters:
  • categorical_attributes_missing (Optional[dict[str, str]]) – For each categorical attribute, the category that marks unlabeled cells.

  • classify_all (bool) – Whether to classify all classes or only semi-supervised classes.

  • logits (bool) – Classifier output type: whether the classifier outputs logits rather than probabilities.

  • bias (bool) – Whether to add bias to the regressor.

  • classification_penalty (float) – Classification penalty strength.

  • classifier_nn_width (int) – Classifier’s layer width.

  • classifier_nn_depth (int) – Classifier’s number of layers.

  • classifier_dropout_rate (float) – Classifier’s dropout rate.

  • loss_regression (Literal['normal', 'mse']) – Loss function for the regressors.

  • kwargs (Any) – Keyword arguments for BiolordModule.
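
The classifier-specific defaults listed in the signature can be collected into a keyword dictionary and validated before the module is constructed. The following is a minimal sketch, not part of biolord: `classify_module_kwargs` and `DEFAULTS` are hypothetical names, the attribute name `cell_type` and the category `unknown` are made up for illustration, and the commented-out constructor call assumes the remaining `BiolordModule` arguments are supplied separately.

```python
# Defaults copied from the class signature above.
DEFAULTS = {
    "categorical_attributes_missing": None,
    "classify_all": False,
    "logits": False,
    "bias": True,
    "classification_penalty": 0.1,
    "classifier_penalty": 0.0001,
    "classifier_nn_width": 128,
    "classifier_nn_depth": 2,
    "classifier_dropout_rate": 0.1,
    "loss_regression": "normal",
}


def classify_module_kwargs(**overrides):
    """Hypothetical helper: merge overrides into the documented defaults."""
    unknown = set(overrides) - set(DEFAULTS)
    if unknown:
        raise TypeError(f"Unknown classifier arguments: {sorted(unknown)}")
    params = {**DEFAULTS, **overrides}
    # loss_regression is documented as Literal['normal', 'mse'].
    if params["loss_regression"] not in ("normal", "mse"):
        raise ValueError("loss_regression must be 'normal' or 'mse'")
    return params


params = classify_module_kwargs(
    classify_all=True,
    loss_regression="mse",
    # Illustrative attribute name and missing-label category.
    categorical_attributes_missing={"cell_type": "unknown"},
)

# Assumed usage; module_kwargs stands for the BiolordModule arguments,
# which are not covered on this page:
# module = biolord.BiolordClassifyModule(**params, **module_kwargs)
```

The helper only covers the arguments documented on this page; anything destined for `BiolordModule` would be passed alongside it rather than through it.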

Methods table

  • classify(genes) – Run classification.

  • loss(tensors, inference_outputs, ...) – Compute the loss.

Methods

classify

BiolordClassifyModule.classify(genes)

Run classification.

Parameters:

genes (Tensor) – Gene expression used for classification.

Return type:

dict[str, Tensor]

Returns:

Classification output: the class probabilities for each categorical attribute and the regression value for each ordered attribute.
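
As an illustration of how such a dictionary might be consumed, the sketch below turns per-attribute outputs into predicted class indices and regression values. It rests on several assumptions not stated on this page: categorical entries are shaped cells × classes, `logits=True` means raw scores that still need a softmax, and the attribute names `cell_type` and `age` are invented. Plain Python lists stand in for tensors so the example runs without biolord or PyTorch; `summarize` is a hypothetical helper.

```python
import math


def softmax(row):
    """Numerically stable softmax over one row of scores."""
    peak = max(row)
    exps = [math.exp(value - peak) for value in row]
    total = sum(exps)
    return [e / total for e in exps]


def summarize(output, categorical, logits=False):
    """Hypothetical helper: reduce classify-style output per attribute.

    output      -- dict mapping attribute name to per-cell values
    categorical -- names of the categorical attributes
    logits      -- assumed: categorical values are raw scores, not probabilities
    """
    summary = {}
    for name, values in output.items():
        if name in categorical:
            rows = [softmax(row) for row in values] if logits else values
            # Index of the most probable class for each cell.
            summary[name] = [max(range(len(row)), key=row.__getitem__) for row in rows]
        else:
            # Ordered attributes: keep the regression value as is.
            summary[name] = list(values)
    return summary


# Stand-in for the dict returned by classify(genes); values are made up.
output = {
    "cell_type": [[0.1, 0.7, 0.2], [0.8, 0.1, 0.1]],
    "age": [3.2, 5.0],
}
summary = summarize(output, categorical={"cell_type"})
```

With real tensors the same reduction would typically use the tensor library's own softmax and argmax along the class dimension.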

loss

BiolordClassifyModule.loss(tensors, inference_outputs, generative_outputs)

Compute the loss.

Parameters:
  • tensors (dict[str, Tensor]) – Considered model inputs.

  • inference_outputs (dict[Literal['latent_unknown_attributes'], Tensor]) – Inference step outputs.

  • generative_outputs (dict[Literal['distribution', 'means', 'variances'], Tensor]) – Generative step outputs.

Return type:

dict[str, float]

Returns:

The loss elements.