import gc
import numpy as np
import scipy.sparse as sp
import torch
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset
from tqdm import tqdm
from .uot_solver import compute_uot_coupling
class DeepSpatialDataset(Dataset):
    """
    DeepSpatial Global Trajectory Dataset.

    Constructs cross-slice cell pairs using Unbalanced Optimal Transport (UOT)
    to provide continuous training samples for Flow Matching.

    Every item is a dict of tensors keyed by ``x0, g0, c0, z0`` (cell sampled
    from slice ``k``), ``x1, g1, c1, z1`` (cell sampled from slice ``k + 1``)
    and ``delta_z`` (``z1 - z0``).
    """

    def __init__(self,
                 adata_list: list,
                 spatial_key: str = 'spatial_norm',
                 z_key: str = 'z_norm',
                 label_key: str = 'cell_class',
                 n_samples_base: int = 50000,
                 alpha_spatial: float = 0.5,
                 uot_reg: float = 0.8,
                 uot_tau: float = 0.05,
                 mode: str = 'fit'):
        """
        Args:
            adata_list: List of AnnData objects.
            spatial_key: Key in .obsm containing normalized XY coordinates.
            z_key: Key in .obs containing normalized Z coordinates.
            label_key: Key in .obs for cell labels/types.
            n_samples_base: Target total number of trajectory pairs. The
                per-pair share is floored, so the final count may be
                slightly smaller.
            alpha_spatial: Balance between spatial and gene distances for UOT.
            uot_reg: Entropy regularization for UOT solver.
            uot_tau: Marginal relaxation for UOT solver.
            mode: 'fit' for full multi-slice training; any other value
                (e.g. 'predict') restricts the dataset to the first two slices.
        """
        super().__init__()
        self.adata_list = adata_list
        self.spatial_key = spatial_key
        self.z_key = z_key
        self.label_key = label_key
        self.n_samples_base = n_samples_base
        self.alpha_spatial = alpha_spatial
        self.uot_reg = uot_reg
        self.uot_tau = uot_tau

        # --- 1. Global Label Encoding ---
        # The encoder is fitted on ALL slices (before any truncation below)
        # so class ids stay consistent regardless of `mode`.
        self.label_encoder = LabelEncoder()
        all_labels = []
        for adata in adata_list:
            if label_key in adata.obs:
                all_labels.extend(adata.obs[label_key].astype(str).tolist())

        if all_labels:
            self.label_encoder.fit(all_labels)
            self.num_classes = len(self.label_encoder.classes_)
            self.id2label = dict(enumerate(self.label_encoder.classes_))
        else:
            # No slice carries labels: fall back to a single dummy class.
            self.num_classes = 1
            self.id2label = {0: "unknown"}

        if mode != 'fit':
            self.adata_list = self.adata_list[:2]

        # Containers for trajectory pairs (lists of per-pair numpy arrays).
        self.trajectory_pairs = {
            'x0': [], 'g0': [], 'c0': [], 'z0': [],
            'x1': [], 'g1': [], 'c1': [], 'z1': [],
            'delta_z': [],
        }

        self._build_trajectory_dataset()
        self._convert_to_tensors()

    def _get_data_arrays(self, adata):
        """Extracts spatial, gene expression, and one-hot labels from AnnData.

        Returns:
            Tuple ``(x, g, c_onehot, z)`` where ``x`` is ``obsm[spatial_key]``
            as float32, ``g`` is ``adata.X`` densified to float32,
            ``c_onehot`` is a float32 ``(n_obs, num_classes)`` matrix (all
            zeros when the slice has no ``label_key`` column) and ``z`` is a
            Python float.
        """
        x = np.asarray(adata.obsm[self.spatial_key], dtype=np.float32)

        expression = adata.X.toarray() if sp.issparse(adata.X) else adata.X
        g = np.asarray(expression, dtype=np.float32)

        if self.label_key in adata.obs:
            raw_labels = adata.obs[self.label_key].astype(str).values
            indices = np.asarray(self.label_encoder.transform(raw_labels))
            # Fill the one-hot matrix directly instead of indexing a dense
            # float64 identity matrix.
            c_onehot = np.zeros((len(indices), self.num_classes), dtype=np.float32)
            c_onehot[np.arange(len(indices)), indices] = 1.0
        else:
            c_onehot = np.zeros((adata.n_obs, self.num_classes), dtype=np.float32)

        # Z coordinate is assumed uniform across a single slice, so only the
        # first row is read.
        z = float(adata.obs[self.z_key].iloc[0])
        return x, g, c_onehot, z

    def _build_trajectory_dataset(self):
        """Pairs cells across adjacent slices using UOT coupling.

        For every adjacent slice pair ``(k, k + 1)`` a coupling matrix is
        computed and index pairs are drawn (with replacement) with
        probability proportional to the coupling mass. Results are appended
        to ``self.trajectory_pairs``.
        """
        import warnings

        num_slices = len(self.adata_list)

        # Allocate samples based on slice product weights.
        pair_sizes = [
            src.n_obs * dst.n_obs
            for src, dst in zip(self.adata_list, self.adata_list[1:])
        ]
        total_weight = sum(pair_sizes)
        if total_weight <= 0:
            # Fewer than two slices, or every pair involves an empty slice:
            # nothing can be sampled (and dividing by zero must be avoided).
            return

        sampling_counts = [
            int(self.n_samples_base * (weight / total_weight))
            for weight in pair_sizes
        ]

        for k in tqdm(range(num_slices - 1), desc="DeepSpatial: Building Trajectories"):
            n_to_sample = sampling_counts[k]
            if n_to_sample <= 0:
                continue

            x0, g0, c0, z0 = self._get_data_arrays(self.adata_list[k])
            x1, g1, c1, z1 = self._get_data_arrays(self.adata_list[k + 1])
            delta_z = z1 - z0

            # Compute Unbalanced Optimal Transport (UOT).
            # NOTE(review): `pi` is presumably an array of shape
            # (n_cells_k, n_cells_k+1); confirm against compute_uot_coupling.
            pi = compute_uot_coupling(
                x0, g0, c0, x1, g1, c1,
                alpha_spatial=self.alpha_spatial,
                uot_reg=self.uot_reg,
                uot_tau=self.uot_tau
            )

            # Sampling logic: draw flat indices proportionally to coupling
            # mass, then map them back to (source, target) cell indices.
            pi_flat = pi.ravel()
            pi_sum = pi_flat.sum()
            if np.isfinite(float(pi_sum)) and pi_sum > 0:
                pi_prob = pi_flat / pi_sum
                idx_flat = np.random.choice(len(pi_flat), size=n_to_sample, p=pi_prob, replace=True)
                idx0, idx1 = np.unravel_index(idx_flat, pi.shape)

                # Store trajectory endpoints.
                pairs = self.trajectory_pairs
                pairs['x0'].append(x0[idx0])
                pairs['g0'].append(g0[idx0])
                pairs['c0'].append(c0[idx0])
                pairs['z0'].append(np.full((n_to_sample, 1), z0, dtype=np.float32))
                pairs['x1'].append(x1[idx1])
                pairs['g1'].append(g1[idx1])
                pairs['c1'].append(c1[idx1])
                pairs['z1'].append(np.full((n_to_sample, 1), z1, dtype=np.float32))
                pairs['delta_z'].append(np.full((n_to_sample, 1), delta_z, dtype=np.float32))
            else:
                # A zero / NaN / infinite total mass cannot be normalised
                # into a probability vector; skip the pair but say so.
                warnings.warn(
                    f"UOT coupling between slices {k} and {k + 1} has "
                    f"non-positive or non-finite total mass; no pairs sampled.",
                    RuntimeWarning,
                    stacklevel=2,
                )

            # Free the (potentially very large) per-pair arrays before the
            # next iteration allocates its own.
            del pi, pi_flat, x0, g0, c0, x1, g1, c1
            gc.collect()

    def _convert_to_tensors(self):
        """Aggregates list of arrays into final PyTorch tensors.

        Populates ``self.tensors`` (only for keys that received data),
        empties ``self.trajectory_pairs`` and sets ``self.num_samples``.
        """
        self.tensors = {}
        for key, chunks in self.trajectory_pairs.items():
            if chunks:
                # Use np.concatenate for efficiency before converting to tensor.
                self.tensors[key] = torch.from_numpy(np.concatenate(chunks, axis=0))

        self.trajectory_pairs.clear()
        gc.collect()

        if 'x0' in self.tensors:
            self.num_samples = self.tensors['x0'].shape[0]
        else:
            self.num_samples = 0

    def __len__(self):
        """Returns the number of sampled trajectory pairs."""
        return self.num_samples

    def __getitem__(self, idx):
        """Returns a dict with the ``idx``-th entry of every stored tensor."""
        return {k: v[idx] for k, v in self.tensors.items()}

    def decode_label(self, one_hot_vec):
        """Converts model one-hot output back to original string label.

        Args:
            one_hot_vec: Scores/one-hot values with classes on the last axis.

        Returns:
            A single label string for a 1-D input, or a flat list of labels
            (row-major order) when the input carries leading batch dimensions.
        """
        indices = torch.argmax(torch.as_tensor(one_hot_vec), dim=-1)
        if indices.dim() == 0:
            return self.id2label[indices.item()]
        return [self.id2label[i] for i in indices.flatten().tolist()]