scpdac.tl.HierarchicalClassifier

class scpdac.tl.HierarchicalClassifier(root, malignant, non_malignant, device='cpu')

Bases: object

Wraps the three MLPs (multilayer perceptrons) that make up the hierarchical classifier for a single species.

Construct an instance with from_species() to load the packaged weights, or pass the three checkpoint dicts directly to the constructor (useful for testing).

Parameters:
  • root (dict) – Checkpoint dict for the root (malignant vs non-malignant) model.

  • malignant (dict) – Checkpoint dict for the malignant Level-4 sub-classifier.

  • non_malignant (dict) – Checkpoint dict for the non-malignant Level-4 sub-classifier.

  • device (str (default: 'cpu')) – Torch device to run inference on.

classmethod from_species(species, device='cpu')

Load the packaged classifier checkpoints for the given species and return a classifier built from them.

Return type:

HierarchicalClassifier
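The two construction paths (the from_species() factory versus passing three checkpoint dicts to the constructor) follow the common alternate-constructor pattern. The sketch below illustrates that pattern only; it is not the scpdac implementation. ToyHierarchicalClassifier, the _PACKAGED_CHECKPOINTS registry, the "toy_species" key, and the contents of the checkpoint dicts are all hypothetical, since the real storage layout and checkpoint format are not documented here.

```python
from __future__ import annotations

# Hypothetical stand-in for the packaged weights: species -> three checkpoint dicts.
# The real package's storage layout and checkpoint contents are not shown in these docs.
_PACKAGED_CHECKPOINTS: dict[str, dict[str, dict]] = {
    "toy_species": {
        "root": {"name": "root"},
        "malignant": {"name": "malignant"},
        "non_malignant": {"name": "non_malignant"},
    },
}


class ToyHierarchicalClassifier:
    """Illustrates the two construction paths; not the scpdac class."""

    def __init__(self, root: dict, malignant: dict, non_malignant: dict, device: str = "cpu"):
        # Direct path: the caller supplies all three checkpoint dicts (handy in tests).
        self.root = root
        self.malignant = malignant
        self.non_malignant = non_malignant
        self.device = device

    @classmethod
    def from_species(cls, species: str, device: str = "cpu") -> ToyHierarchicalClassifier:
        # Factory path: look up the packaged checkpoints, then delegate to __init__.
        try:
            ckpts = _PACKAGED_CHECKPOINTS[species]
        except KeyError:
            raise ValueError(f"no packaged checkpoints for species {species!r}") from None
        return cls(ckpts["root"], ckpts["malignant"], ckpts["non_malignant"], device=device)


# Both paths produce equivalent objects.
packaged = ToyHierarchicalClassifier.from_species("toy_species")
direct = ToyHierarchicalClassifier({"name": "root"}, {"name": "malignant"}, {"name": "non_malignant"})
```

Keeping all loading logic in the classmethod means the constructor stays a plain container, which is what makes the direct-dict path convenient for testing.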

predict(adata, layer='log1p_norm')

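Run the two-step hierarchy: the root model separates malignant from non-malignant cells, and the Level-4 sub-classifiers then assign cell-type labels.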

Return type:

tuple[ndarray, ndarray]

Returns:

A tuple (malignant_labels, celltype_labels) of two arrays, each of length adata.n_obs.
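A toy sketch of the two-step flow and of the returned tuple's shape follows. It uses neither scpdac nor PyTorch: predict_hierarchy and the three stand-in models are hypothetical, the routing rule (cells the root model flags as malignant go to the malignant sub-classifier, all others to the non-malignant one) is an assumption, and the real malignant_labels may not be booleans. Reading expression values from the layer argument is not modelled.

```python
import numpy as np


def predict_hierarchy(X, root_model, malignant_model, non_malignant_model):
    """Toy two-step hierarchy over an (n_obs, n_features) matrix.

    Assumption: root_model returns a boolean "is malignant" flag per cell, and
    each sub-model maps a 2-D array of cells to one label per row.
    """
    X = np.asarray(X)
    # Step 1: root split into malignant vs non-malignant.
    malignant_labels = np.asarray(root_model(X), dtype=bool)
    # Step 2: route each subset to its own sub-classifier.
    celltype_labels = np.empty(X.shape[0], dtype=object)
    if malignant_labels.any():
        celltype_labels[malignant_labels] = malignant_model(X[malignant_labels])
    if (~malignant_labels).any():
        celltype_labels[~malignant_labels] = non_malignant_model(X[~malignant_labels])
    # Both arrays have one entry per cell (length n_obs).
    return malignant_labels, celltype_labels


# Hypothetical stand-in models on a single toy feature.
def toy_root(A):
    return A[:, 0] > 0.5


def toy_malignant(A):
    return np.array(["mal_type"] * len(A))


def toy_non_malignant(A):
    return np.where(A[:, 0] < 0.15, "type_a", "type_b")


X = np.array([[0.9], [0.1], [0.8], [0.2]])
mal_labels, ct_labels = predict_hierarchy(X, toy_root, toy_malignant, toy_non_malignant)
```

Here mal_labels is [True, False, True, False] and ct_labels is ["mal_type", "type_a", "mal_type", "type_b"]; as with predict(), both arrays align with the rows (cells) of the input.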