deepspatial.data_utils.DeepSpatialDataset
- class deepspatial.data_utils.DeepSpatialDataset(adata_list: list, spatial_key: str = 'spatial_norm', z_key: str = 'z_norm', label_key: str = 'cell_class', n_samples_base: int = 50000, alpha_spatial: float = 0.5, uot_reg: float = 0.8, uot_tau: float = 0.05, mode: str = 'fit')
DeepSpatial Global Trajectory Dataset. Constructs cross-slice cell pairs using Unbalanced Optimal Transport (UOT) to provide continuous training samples for Flow Matching.
- __init__(adata_list: list, spatial_key: str = 'spatial_norm', z_key: str = 'z_norm', label_key: str = 'cell_class', n_samples_base: int = 50000, alpha_spatial: float = 0.5, uot_reg: float = 0.8, uot_tau: float = 0.05, mode: str = 'fit')
- Parameters:
adata_list – List of AnnData objects.
spatial_key – Key in .obsm containing normalized XY coordinates.
z_key – Key in .obs containing normalized Z coordinates.
label_key – Key in .obs for cell labels/types.
n_samples_base – Target total number of trajectory pairs.
alpha_spatial – Balance between spatial and gene distances for UOT.
uot_reg – Entropy regularization for UOT solver.
uot_tau – Marginal relaxation for UOT solver.
mode – 'fit' for full multi-slice training, 'predict' for limited slice pairs.
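To make the roles of alpha_spatial, uot_reg, and uot_tau more concrete, the following is a minimal NumPy sketch of how cross-slice pairs could be drawn from an entropic UOT coupling. It is not the library's implementation. The helper names (blended_cost, unbalanced_sinkhorn, sample_pairs), the convex combination of spatial and expression distances, the uniform marginals, and the mapping of uot_tau onto the exponent tau / (tau + reg) are all assumptions made for illustration.

```python
import numpy as np


def blended_cost(xy_a, xy_b, expr_a, expr_b, alpha_spatial=0.5):
    """Blend spatial and expression distances (assumed convex combination)."""
    def sq_dists(p, q):
        # Pairwise squared Euclidean distances, shape (len(p), len(q)).
        return ((p[:, None, :] - q[None, :, :]) ** 2).sum(axis=-1)

    d_xy = sq_dists(xy_a, xy_b)
    d_expr = sq_dists(expr_a, expr_b)
    # Scale each term to [0, 1] so alpha_spatial weighs comparable quantities.
    d_xy = d_xy / max(d_xy.max(), 1e-12)
    d_expr = d_expr / max(d_expr.max(), 1e-12)
    return alpha_spatial * d_xy + (1.0 - alpha_spatial) * d_expr


def unbalanced_sinkhorn(cost, reg=0.8, tau=0.05, n_iter=200):
    """Entropic UOT with KL-relaxed marginals, solved by scaling iterations."""
    n, m = cost.shape
    a = np.full(n, 1.0 / n)  # assumed uniform mass on source cells
    b = np.full(m, 1.0 / m)  # assumed uniform mass on target cells
    kernel = np.exp(-cost / reg)
    u = np.ones(n)
    v = np.ones(m)
    # Smaller tau means weaker enforcement of the marginals.
    power = tau / (tau + reg)
    for _ in range(n_iter):
        u = (a / (kernel @ v)) ** power
        v = (b / (kernel.T @ u)) ** power
    return u[:, None] * kernel * v[None, :]


def sample_pairs(plan, n_pairs, rng):
    """Draw (source, target) index pairs with probability proportional to the plan."""
    probs = (plan / plan.sum()).ravel()
    flat = rng.choice(plan.size, size=n_pairs, p=probs)
    return np.unravel_index(flat, plan.shape)


# Toy data standing in for two adjacent slices (random, for illustration only).
rng = np.random.default_rng(0)
xy_a, xy_b = rng.random((30, 2)), rng.random((40, 2))
expr_a, expr_b = rng.random((30, 8)), rng.random((40, 8))

cost = blended_cost(xy_a, xy_b, expr_a, expr_b, alpha_spatial=0.5)
plan = unbalanced_sinkhorn(cost, reg=0.8, tau=0.05)
src_idx, dst_idx = sample_pairs(plan, n_pairs=100, rng=rng)
```

In this sketch, alpha_spatial = 1.0 would pair cells purely by XY proximity and 0.0 purely by expression similarity; whether the library uses the same convention is not stated in the parameter description above.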
Methods
- __init__(adata_list[, spatial_key, z_key, ...])
- decode_label(one_hot_vec): Converts model one-hot output back to the original string label.
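As an illustration of what decoding a one-hot vector typically involves, here is a small standalone sketch. It does not call the library: decode_label_sketch and the classes list are hypothetical, and the real method presumably uses a label vocabulary built from label_key during construction.

```python
import numpy as np

# Hypothetical label vocabulary, for illustration only.
classes = ["astrocyte", "microglia", "neuron"]


def decode_label_sketch(one_hot_vec, classes):
    """Return the class name at the position of the largest entry."""
    return classes[int(np.argmax(one_hot_vec))]


# Works for exact one-hot vectors and for soft model outputs alike.
decoded = decode_label_sketch(np.array([0.1, 0.2, 0.7]), classes)
```

Using argmax rather than an exact match against 1.0 lets the same function handle probability-like outputs that are not strictly one-hot.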