# Source code for deepspatial.data_utils.uot_solver

import gc
import numpy as np
import ot
import scipy.sparse as sp

def _as_numpy(arr):
    """Return *arr* as a dense NumPy array (torch tensors are detached and moved to CPU)."""
    if sp.issparse(arr):
        return arr.toarray()
    if hasattr(arr, 'detach'):
        # Tensors that track gradients refuse a direct .numpy() call.
        arr = arr.detach()
    if hasattr(arr, 'cpu'):
        # Tensors living on an accelerator must be copied to host memory first.
        arr = arr.cpu()
    if hasattr(arr, 'numpy'):
        return arr.numpy()
    return np.asarray(arr)


def compute_cost_matrix(x0, g0, c0, x1, g1, c1, alpha_spatial=0.5, class_penalty=10.0):
    """
    Computes a hybrid cost matrix balancing spatial distance, gene expression, and cell types.

    All inputs are converted to NumPy arrays up front, so the result is a
    NumPy array regardless of whether arrays or tensors were passed in.

    Parameters
    ----------
    x0 : numpy.ndarray or torch.Tensor
        Spatial coordinates for the source slice, shape `(N0, 2)`.
    g0 : numpy.ndarray or torch.Tensor
        Gene expression matrix for the source slice, shape `(N0, G)`.
    c0 : numpy.ndarray or torch.Tensor
        One-hot encoded cell types for the source slice, shape `(N0, C)`.
    x1 : numpy.ndarray or torch.Tensor
        Spatial coordinates for the target slice, shape `(N1, 2)`.
    g1 : numpy.ndarray or torch.Tensor
        Gene expression matrix for the target slice, shape `(N1, G)`.
    c1 : numpy.ndarray or torch.Tensor
        One-hot encoded cell types for the target slice, shape `(N1, C)`.
    alpha_spatial : float, optional
        Weight balancing the spatial vs. gene distance (between 0 and 1).
        By default 0.5.
    class_penalty : float, optional
        Coefficient applied to the cell-type mismatch cost. A value that is
        large relative to the normalized spatial/gene terms makes a type
        mismatch act almost like a hard constraint. By default 10.0.

    Returns
    -------
    numpy.ndarray
        The combined cost matrix of shape `(N0, N1)`.
    """
    eps = 1e-9

    # Work in NumPy throughout: the cosine metric is delegated to SciPy and
    # the documented return type is numpy.ndarray.
    x0, g0, c0 = _as_numpy(x0), _as_numpy(g0), _as_numpy(c0)
    x1, g1, c1 = _as_numpy(x1), _as_numpy(g1), _as_numpy(c1)

    # Spatial distance (Euclidean): [N0, 2] x [N1, 2] -> [N0, N1],
    # rescaled to [0, 1] unless every distance is zero.
    cost_spatial = ot.dist(x0, x1, metric='euclidean')
    s_max = cost_spatial.max()
    if s_max > 0:
        cost_spatial = cost_spatial / (s_max + eps)

    # Gene expression distance (Cosine).
    cost_gene = ot.dist(g0, g1, metric='cosine')
    # Cosine distance is undefined for an all-zero expression vector and may
    # come back as NaN. Treat such pairs as uncorrelated (distance 1) so a
    # single empty cell cannot poison the normalization and the whole matrix.
    undefined = np.isnan(cost_gene)
    if undefined.any():
        cost_gene[undefined] = 1.0
    g_max = cost_gene.max()
    if g_max > 0:
        cost_gene = cost_gene / (g_max + eps)

    # Cell type cost: 1 - dot(c0, c1.T) is 0 for matching one-hot rows and
    # 1 for different ones; clipping guards against non-one-hot inputs.
    cost_class = np.clip(1.0 - np.dot(c0, c1.T), 0, 1)

    # Weighted sum of the three terms.
    cost = (
        (alpha_spatial * cost_spatial)
        + ((1 - alpha_spatial) * cost_gene)
        + (class_penalty * cost_class)
    )

    # Best-effort release of the large intermediate distance matrices.
    del cost_spatial, cost_gene, cost_class
    gc.collect()

    return cost
def compute_uot_coupling(x0, g0, c0, x1, g1, c1, alpha_spatial=0.5, uot_reg=0.8, uot_tau=0.05,
                         *, num_iter_max=100, stop_thr=1e-3):
    """
    Computes the Unbalanced Optimal Transport (UOT) coupling matrix between two slices.

    Parameters
    ----------
    x0 : numpy.ndarray or torch.Tensor
        Spatial coordinates for the source slice.
    g0 : numpy.ndarray or torch.Tensor
        Gene expression matrix for the source slice.
    c0 : numpy.ndarray or torch.Tensor
        One-hot encoded cell types for the source slice.
    x1 : numpy.ndarray or torch.Tensor
        Spatial coordinates for the target slice.
    g1 : numpy.ndarray or torch.Tensor
        Gene expression matrix for the target slice.
    c1 : numpy.ndarray or torch.Tensor
        One-hot encoded cell types for the target slice.
    alpha_spatial : float, optional
        Weight for spatial distance in the underlying cost matrix.
        By default 0.5.
    uot_reg : float, optional
        Entropy regularization parameter. Higher values lead to smoother,
        more dispersed couplings; lower values yield sparser, more
        deterministic matchings. Values below 0.01 are raised to 0.01
        before being handed to the solver. By default 0.8.
    uot_tau : float, optional
        Marginal relaxation weight (KL divergence penalty). Controls how
        strictly the mass conservation is enforced. Smaller values allow
        more mass creation/destruction. By default 0.05.
    num_iter_max : int, optional
        Maximum number of Sinkhorn iterations. By default 100.
    stop_thr : float, optional
        Convergence threshold passed to the solver. By default 1e-3.

    Returns
    -------
    numpy.ndarray
        The calculated optimal transport coupling matrix `pi` of shape `(N0, N1)`.

    Raises
    ------
    ValueError
        If either slice contains no cells.
    """
    n0, n1 = x0.shape[0], x1.shape[0]
    if n0 == 0 or n1 == 0:
        # Uniform marginals are undefined for an empty slice (division by zero).
        raise ValueError(
            f"Both slices must contain at least one cell (got N0={n0}, N1={n1})."
        )

    # Generate the cost matrix
    C = compute_cost_matrix(x0, g0, c0, x1, g1, c1, alpha_spatial)

    # Define marginal weights (uniform distribution over the cells of each slice)
    a = np.full(n0, 1.0 / n0)
    b = np.full(n1, 1.0 / n1)

    # Solve Unbalanced Sinkhorn.
    # reg_m (uot_tau) controls how much we allow marginals to be violated;
    # reg is floored at 0.01 to keep the entropic kernel from degenerating.
    pi = ot.unbalanced.sinkhorn_unbalanced(
        a, b, C,
        reg=max(uot_reg, 0.01),
        reg_m=uot_tau,
        numItermax=num_iter_max,
        stopThr=stop_thr,
        verbose=False,
    )

    # Best-effort memory cleanup of the cost matrix and marginals
    del C, a, b
    gc.collect()

    return pi