deepspatial.data_utils.DeepSpatialDataset

class deepspatial.data_utils.DeepSpatialDataset(adata_list: list, spatial_key: str = 'spatial_norm', z_key: str = 'z_norm', label_key: str = 'cell_class', n_samples_base: int = 50000, alpha_spatial: float = 0.5, uot_reg: float = 0.8, uot_tau: float = 0.05, mode: str = 'fit')

DeepSpatial Global Trajectory Dataset. Constructs cross-slice cell pairs using Unbalanced Optimal Transport (UOT) to provide continuous training samples for Flow Matching.
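The pairing idea can be pictured with a minimal NumPy sketch. It assumes a precomputed, max-normalized cost matrix between two slices, uniform cell weights, and that `uot_reg` and `uot_tau` act as the entropic strength and KL marginal penalty of the standard unbalanced Sinkhorn scaling iterations; none of this is confirmed as DeepSpatial's actual implementation. Pairs are drawn in proportion to transported mass, and each pair becomes a flow-matching sample by linear interpolation at a random time t with target velocity x1 - x0. The helpers `unbalanced_sinkhorn`, `sample_pairs` and `interpolate` are hypothetical names.

```python
import numpy as np

def unbalanced_sinkhorn(cost, reg=0.8, reg_m=0.05, n_iter=500):
    """Entropic UOT with KL-relaxed marginals (hypothetical stand-in).

    reg ~ uot_reg; treating uot_tau as the KL penalty reg_m is an assumption.
    """
    n, m = cost.shape
    a = np.full(n, 1.0 / n)      # uniform weights on source-slice cells
    b = np.full(m, 1.0 / m)      # uniform weights on target-slice cells
    K = np.exp(-cost / reg)      # Gibbs kernel
    fi = reg_m / (reg_m + reg)   # damping exponent from the KL relaxation
    u, v = np.ones(n), np.ones(m)
    for _ in range(n_iter):
        u = (a / (K @ v)) ** fi
        v = (b / (K.T @ u)) ** fi
    return u[:, None] * K * v[None, :]  # transport plan, shape (n, m)

def sample_pairs(plan, n_pairs, rng):
    # Draw (source, target) index pairs proportionally to transported mass
    p = plan.ravel() / plan.sum()
    flat = rng.choice(plan.size, size=n_pairs, p=p)
    return np.unravel_index(flat, plan.shape)

def interpolate(x0, x1, rng):
    # Linear probability path: x_t = (1 - t) x0 + t x1, velocity x1 - x0
    t = rng.uniform(size=(x0.shape[0], 1))
    return t, (1 - t) * x0 + t * x1, x1 - x0

rng = np.random.default_rng(0)
src = rng.normal(size=(30, 3))        # toy slice A: normalized (x, y, z)
tgt = rng.normal(size=(40, 3)) + 0.5  # toy slice B
cost = ((src[:, None, :] - tgt[None, :, :]) ** 2).sum(-1)
cost /= cost.max()
plan = unbalanced_sinkhorn(cost)
i, j = sample_pairs(plan, 100, rng)   # 100 plays the role of a share of n_samples_base
t, xt, vel = interpolate(src[i], tgt[j], rng)
print(xt.shape, vel.shape)  # (100, 3) (100, 3)
```

Because the marginals are only softly enforced, cells without a plausible counterpart in the other slice can receive little mass instead of being forced into a match.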

__init__(adata_list: list, spatial_key: str = 'spatial_norm', z_key: str = 'z_norm', label_key: str = 'cell_class', n_samples_base: int = 50000, alpha_spatial: float = 0.5, uot_reg: float = 0.8, uot_tau: float = 0.05, mode: str = 'fit')
Parameters:
  • adata_list – List of AnnData objects.

  • spatial_key – Key in .obsm containing normalized XY coordinates.

  • z_key – Key in .obs containing normalized Z coordinates.

  • label_key – Key in .obs for cell labels/types.

  • n_samples_base – Target total number of trajectory pairs.

  • alpha_spatial – Weight balancing spatial distance against gene-expression distance in the UOT cost.

  • uot_reg – Entropic regularization strength for the UOT solver.

  • uot_tau – Marginal relaxation parameter for the UOT solver.

  • mode – Either ‘fit’, for full multi-slice training, or ‘predict’, for a limited set of slice pairs.
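One plausible reading of `alpha_spatial` is a convex combination of a spatial and a gene-expression distance, each rescaled to [0, 1] so the weight is meaningful. The sketch assumes squared Euclidean distances on the normalized XY coordinates and on some per-cell expression features (for example a PCA embedding); the function `combined_cost` and this exact formula are hypothetical, not taken from DeepSpatial.

```python
import numpy as np

def combined_cost(xy_a, xy_b, expr_a, expr_b, alpha_spatial=0.5):
    """Hypothetical UOT cost: convex mix of spatial and expression distances."""
    def sq_dist(p, q):
        # Pairwise squared Euclidean distances, shape (len(p), len(q))
        return ((p[:, None, :] - q[None, :, :]) ** 2).sum(-1)

    d_space = sq_dist(xy_a, xy_b)
    d_gene = sq_dist(expr_a, expr_b)
    # Rescale each term to [0, 1] so neither dominates by units alone
    d_space /= np.maximum(d_space.max(), 1e-12)
    d_gene /= np.maximum(d_gene.max(), 1e-12)
    return alpha_spatial * d_space + (1.0 - alpha_spatial) * d_gene

rng = np.random.default_rng(0)
xy_a, xy_b = rng.random((5, 2)), rng.random((6, 2))        # normalized XY
expr_a, expr_b = rng.random((5, 10)), rng.random((6, 10))  # e.g. PCA features
cost = combined_cost(xy_a, xy_b, expr_a, expr_b, alpha_spatial=0.5)
print(cost.shape)  # (5, 6)
```

Under this reading, alpha_spatial = 1.0 matches cells on position alone and alpha_spatial = 0.0 on expression alone.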

Methods

__init__(adata_list[, spatial_key, z_key, ...])

decode_label(one_hot_vec) – Converts the model's one-hot output back to the original string label.
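A minimal sketch of what one-hot label decoding typically involves, assuming a fixed category order in which column k of the one-hot vector corresponds to the k-th category. The `LabelCodec` class, its `encode_label` helper and the sorted category order are hypothetical; DeepSpatial may derive the order differently (for example from a pandas categorical in `.obs[label_key]`).

```python
import numpy as np

class LabelCodec:
    """Hypothetical stand-in for the label handling behind decode_label."""

    def __init__(self, labels):
        # Assumed fixed order: position k is the k-th one-hot column
        self.categories = sorted(set(labels))
        self.index = {c: k for k, c in enumerate(self.categories)}

    def encode_label(self, label):
        # String label -> one-hot float vector
        vec = np.zeros(len(self.categories), dtype=np.float32)
        vec[self.index[label]] = 1.0
        return vec

    def decode_label(self, one_hot_vec):
        # argmax also handles soft outputs such as softmax probabilities
        return self.categories[int(np.argmax(one_hot_vec))]

codec = LabelCodec(["Neuron", "Astro", "Neuron", "Oligo"])
print(codec.decode_label([0.1, 0.7, 0.2]))  # Neuron
```

Using argmax rather than an exact match on 1.0 lets the same method decode continuous model outputs.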