scpdac.tl.HierarchicalClassifier
- class scpdac.tl.HierarchicalClassifier(root, malignant, non_malignant, device='cpu')
Bases: object
Wraps the three MLPs of the hierarchical classifier for one species.
Construct via from_species() to load the packaged weights, or pass three checkpoint dicts directly (useful for testing).
- Parameters:
  - root (dict) – Checkpoint dict for the root (malignant vs. non-malignant) model.
  - malignant (dict) – Checkpoint dict for the malignant Level-4 sub-classifier.
  - non_malignant (dict) – Checkpoint dict for the non-malignant Level-4 sub-classifier.
  - device (str, default: 'cpu') – Torch device to run inference on.
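The routing this class wraps can be pictured with a minimal, self-contained sketch. It assumes that the root model's label (malignant vs. non-malignant) selects which Level-4 sub-classifier runs next, and it replaces the three MLPs and the packaged checkpoints with plain Python callables. `ToyHierarchicalClassifier`, `_TOY_REGISTRY`, `predict`, the species name "toy", and the labels `M1`/`N1` are all hypothetical and not part of scpdac.

```python
from typing import Callable, Dict, Sequence, Tuple

# A "model" is any callable that maps one feature vector to a label.
# In scpdac the three models are MLPs restored from checkpoint dicts;
# here they are hypothetical plain-Python stand-ins so the sketch runs anywhere.
Model = Callable[[Sequence[float]], str]


class ToyHierarchicalClassifier:
    """Hypothetical two-level router; not the scpdac implementation."""

    def __init__(self, root: Model, malignant: Model, non_malignant: Model,
                 device: str = "cpu") -> None:
        self.root = root
        # Assumption: the root's coarse label selects the sub-classifier.
        self.subs: Dict[str, Model] = {
            "malignant": malignant,
            "non_malignant": non_malignant,
        }
        self.device = device  # unused here; mirrors the documented signature

    @classmethod
    def from_species(cls, species: str,
                     device: str = "cpu") -> "ToyHierarchicalClassifier":
        # Alternate constructor: look up bundled models by species name.
        # _TOY_REGISTRY is a hypothetical stand-in for packaged checkpoints.
        if species not in _TOY_REGISTRY:
            raise ValueError(f"no packaged models for species {species!r}")
        root, malignant, non_malignant = _TOY_REGISTRY[species]
        return cls(root, malignant, non_malignant, device=device)

    def predict(self, x: Sequence[float]) -> Tuple[str, str]:
        coarse = self.root(x)        # root: malignant vs. non-malignant
        fine = self.subs[coarse](x)  # Level-4 label within that branch
        return coarse, fine


# Hypothetical toy models keyed by a made-up species name.
_TOY_REGISTRY: Dict[str, Tuple[Model, Model, Model]] = {
    "toy": (
        lambda x: "malignant" if x[0] > 0.5 else "non_malignant",
        lambda x: "M1" if x[1] > 0 else "M2",
        lambda x: "N1" if x[1] > 0 else "N2",
    ),
}

clf = ToyHierarchicalClassifier.from_species("toy")
print(clf.predict([0.9, 1.0]))   # ('malignant', 'M1')
print(clf.predict([0.1, -1.0]))  # ('non_malignant', 'N2')
```

Keeping the constructor free of any loading logic and putting the packaged-weight lookup in a classmethod mirrors the split described above: tests can inject models directly, while normal use goes through from_species().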
- classmethod from_species(species, device='cpu')
Load the packaged classifier checkpoints for the given species.
- Return type: