import os
import json
import torch
import numpy as np
import pandas as pd
import anndata as ad
import scipy.sparse
import pytorch_lightning as pl
from tqdm import tqdm
from torch.utils.data import DataLoader
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from .data_utils import DeepSpatialDataset
from .models import GiT
from .module import DeepSpatialModule
[docs]
class DeepSpatial:
"""
DeepSpatial: Reconstructing True 3D Spatial Omics at Single-Cell Resolution.
"""
[docs]
def __init__(self):
"""
Initializes an empty DeepSpatial instance.
Configurations and paths are injected during functional method calls.
"""
# Data and Dataloader
self.dataset = None
self.train_loader = None
# Model and Training state
self.gene_dim = None
self.num_classes = None
self.model = None
self.module = None
# Metadata for persistence and restoration
self.categories = None
self.spatial_stats = None
self.spatial_key = None
self.z_key = None
self.label_key = None
# Configuration dictionaries
self.model_config = {}
self.train_config = {}
def _normalize_spatial(self, adata_list: list[ad.AnnData]) -> None:
"""
Extracts coordinates from user keys and computes normalization stats.
Stores normalized data in internal 'spatial_norm' and 'z_norm' keys.
"""
z_raw_list = []
all_spatial = []
for i, adata in enumerate(adata_list):
if self.spatial_key not in adata.obsm:
raise KeyError(f"Spatial key '{self.spatial_key}' not found in slice {i}.")
if self.z_key not in adata.obs:
raise KeyError(f"Z-coord key '{self.z_key}' not found in slice {i}.")
all_spatial.append(adata.obsm[self.spatial_key])
z_raw_list.append(adata.obs[self.z_key].iloc[0])
all_spatial = np.vstack(all_spatial)
z_raw_arr = np.array(z_raw_list)
# Store stats for physical space restoration
self.spatial_stats = {
'x_min': float(all_spatial[:, 0].min()),
'x_range': float(all_spatial[:, 0].max() - all_spatial[:, 0].min() + 1e-8),
'y_min': float(all_spatial[:, 1].min()),
'y_range': float(all_spatial[:, 1].max() - all_spatial[:, 1].min() + 1e-8),
'z_min': float(z_raw_arr.min()),
'z_range': float(z_raw_arr.max() - z_raw_arr.min() + 1e-8)
}
norm_z_arr = (z_raw_arr - self.spatial_stats['z_min']) / self.spatial_stats['z_range']
# Perform non-destructive normalization
for i, adata in enumerate(adata_list):
coords = adata.obsm[self.spatial_key].copy()
coords[:, 0] = (coords[:, 0] - self.spatial_stats['x_min']) / self.spatial_stats['x_range']
coords[:, 1] = (coords[:, 1] - self.spatial_stats['y_min']) / self.spatial_stats['y_range']
adata.obsm['spatial_norm'] = coords
adata.obs['z_norm'] = norm_z_arr[i]
[docs]
def setup_data(self,
adata_list: list[ad.AnnData],
spatial_key: str = 'spatial',
z_key: str = 'z_coord',
label_key: str = 'cell_class',
batch_size: int = 128,
num_workers: int = 4,
n_samples_base: int = 50000,
alpha_spatial: float = 0.5,
uot_reg: float = 0.8,
uot_tau: float = 0.05,
mode: str = 'fit'):
"""
Prepares the data pipeline and calculates physical normalization statistics.
Args:
adata_list: List of AnnData objects (slices).
spatial_key: Key in `.obsm` for XY coordinates.
z_key: Key in `.obs` for the physical Z coordinate.
label_key: Key in `.obs` for cell type annotations.
batch_size: Number of samples per training batch.
num_workers: Multi-process data loading workers.
n_samples_base: Base number of cell pairs to sample per slice pair.
alpha_spatial: UOT spatial distance weight.
uot_reg: Entropy regularization for UOT.
uot_tau: Marginal relaxation for UOT.
mode: Dataset mode ('fit' for training, 'predict' for inference).
"""
self.spatial_key = spatial_key
self.z_key = z_key
self.label_key = label_key
self._normalize_spatial(adata_list)
self.dataset = DeepSpatialDataset(
adata_list=adata_list,
spatial_key='spatial_norm',
z_key='z_norm',
label_key=label_key,
n_samples_base=n_samples_base,
alpha_spatial=alpha_spatial,
uot_reg=uot_reg,
uot_tau=uot_tau,
mode=mode
)
self.categories = pd.Index(self.dataset.label_encoder.classes_)
self.train_loader = DataLoader(
self.dataset, batch_size=batch_size,
shuffle=True, num_workers=num_workers
)
# Infer dimensions from the first batch
batch = next(iter(self.train_loader))
self.gene_dim = batch['g0'].shape[1]
self.num_classes = batch['c0'].shape[1]
[docs]
def build_model(self,
patch_size: int = 8,
hidden_size: int = 256,
depth: int = 6,
num_heads: int = 8,
mlp_ratio: float = 4.0,
path_type: str = "Linear",
lr: float = 2e-4,
weight_decay: float = 1e-5,
lambda_g: float = 0.1,
lambda_c: float = 10.0,
sampling_method: str = "dopri5",
atol: float = 1e-5,
rtol: float = 1e-5):
"""
Instantiates the GiT network architecture and Flow Matching logic.
Args:
patch_size: Tokenization patch size for spatial coordinates.
hidden_size: Transformer embedding dimension.
depth: Number of transformer layers.
num_heads: Number of attention heads.
mlp_ratio: Expansion ratio for MLP layers.
path_type: Probability path type for Flow Matching.
lr: Learning rate for training.
weight_decay: L2 regularization weight.
lambda_g: Loss weight for gene expression reconstruction.
lambda_c: Loss weight for cell type classification.
sampling_method: ODE solver for inference (e.g., 'dopri5', 'euler').
atol: Absolute tolerance for ODE solver.
rtol: Relative tolerance for ODE solver.
"""
if self.gene_dim is None or self.num_classes is None:
raise ValueError("Dimensions unknown. Call `setup_data()` first.")
self.model_config = {
"patch_size": patch_size,
"hidden_size": hidden_size,
"depth": depth,
"num_heads": num_heads,
"mlp_ratio": mlp_ratio
}
self.train_config = {
"path_type": path_type,
"prediction": "velocity",
"train_eps": 0.02,
"sample_eps": 0.02,
"ema_decay": 0.999,
"lr": lr,
"weight_decay": weight_decay,
"lambda_g": lambda_g,
"lambda_c": lambda_c,
"sampling_method": sampling_method,
"atol": atol,
"rtol": rtol
}
self.model = GiT(
gene_dim=self.gene_dim,
num_classes=self.num_classes,
**self.model_config
)
self.module = DeepSpatialModule(self.train_config, self.model)
def _save_config(self, save_dir: str):
"""Internal helper to persist metadata and configurations."""
os.makedirs(save_dir, exist_ok=True)
config = {
'gene_dim': self.gene_dim,
'num_classes': self.num_classes,
'spatial_stats': self.spatial_stats,
'categories': self.categories.tolist() if self.categories is not None else None,
'spatial_key': self.spatial_key,
'z_key': self.z_key,
'label_key': self.label_key,
'model_config': self.model_config,
'train_config': self.train_config
}
with open(os.path.join(save_dir, 'config.json'), 'w') as f:
json.dump(config, f)
[docs]
def load_checkpoint(self, ckpt_path: str, config_path: str = None, sampling_method: str = "dopri5"):
"""
Loads model weights and metadata for inference or resuming.
Args:
ckpt_path: Path to the `.ckpt` file.
config_path: Path to `config.json`. Defaults to the same folder as ckpt_path.
sampling_method: Overrides the ODE solver for this inference session.
"""
if not os.path.exists(ckpt_path):
raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
if config_path is None:
config_path = os.path.join(os.path.dirname(ckpt_path), 'config.json')
if os.path.exists(config_path):
with open(config_path, 'r') as f:
config = json.load(f)
self.gene_dim = config['gene_dim']
self.num_classes = config['num_classes']
self.spatial_stats = config['spatial_stats']
self.spatial_key = config.get('spatial_key', 'spatial')
self.z_key = config.get('z_key', 'z_coord')
self.label_key = config['label_key']
self.categories = pd.Index(config['categories'])
self.model_config = config.get('model_config', {})
self.train_config = config.get('train_config', {})
else:
raise ValueError("Metadata 'config.json' not found. Cannot rebuild model.")
if self.module is None:
self.train_config['sampling_method'] = sampling_method
self.build_model(**self.model_config, **self.train_config)
checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
state_dict = checkpoint.get('state_dict', checkpoint)
self.module.load_state_dict(state_dict)
[docs]
def fit(self,
max_epochs: int = 100,
save_dir: str = "./checkpoints",
accelerator: str = 'auto',
devices: str = 'auto',
save_ckpt: bool = False,
resume_ckpt_path: str = None):
"""
Executes the training loop using PyTorch Lightning.
Args:
max_epochs: Max training epochs.
save_dir: Directory to save checkpoints and metadata.
accelerator: Hardware accelerator ('auto', 'gpu', 'cpu', 'mps').
devices: Number of devices or indices ('auto', 1, [0, 1]).
save_ckpt: Whether to save progress.
resume_ckpt_path: Path to resume full training state.
"""
if save_ckpt:
self._save_config(save_dir)
callbacks = []
if save_ckpt:
callbacks.append(ModelCheckpoint(
dirpath=save_dir,
monitor="loss",
filename="deepspatial-{epoch:02d}-{loss:.4f}",
save_top_k=1,
mode="min"
))
trainer = pl.Trainer(
max_epochs=max_epochs,
accelerator=accelerator,
devices=devices,
callbacks=callbacks,
enable_checkpointing=save_ckpt,
logger=False
)
trainer.fit(self.module, self.train_loader, ckpt_path=resume_ckpt_path)
def _restore_3d_physical_coords(self, adata_3d: ad.AnnData) -> ad.AnnData:
"""Internal helper to map [0,1] coordinates back to physical scale."""
if self.spatial_stats is None:
raise ValueError("Physical stats missing. Ensure setup_data or load_checkpoint was called.")
# Restore XY in obsm[spatial_key]
adata_3d.obsm[self.spatial_key][:, 0] = adata_3d.obsm[self.spatial_key][:, 0] * self.spatial_stats['x_range'] + self.spatial_stats['x_min']
adata_3d.obsm[self.spatial_key][:, 1] = adata_3d.obsm[self.spatial_key][:, 1] * self.spatial_stats['y_range'] + self.spatial_stats['y_min']
# Restore Z in obs[z_key]
adata_3d.obs[self.z_key] = adata_3d.obs[self.z_key] * self.spatial_stats['z_range'] + self.spatial_stats['z_min']
adata_3d.obs_names = [f"cell_3d_z{z:.4f}_{i}" for i, z in enumerate(adata_3d.obs[self.z_key])]
return adata_3d
[docs]
@torch.no_grad()
def reconstruct_between_slices(self,
adata0: ad.AnnData,
adata1: ad.AnnData,
thickness: float,
steps: int = 20,
chunk_size: int = 2048,
device: str = "auto") -> ad.AnnData:
"""
Generates a 3D volume segment between two specific AnnData slices.
Args:
adata0: Source slice AnnData (must contain 'spatial_norm' and 'z_norm').
adata1: Target slice AnnData (must contain 'spatial_norm' and 'z_norm').
steps: Number of integration steps for the ODE solver.
thickness: Physical distance (um) between generated cells, controlling density.
chunk_size: Batch size for ODE integration to manage VRAM usage.
device: Computing device ('auto', 'cuda', 'cpu', or specific 'cuda:n').
Returns:
An AnnData object containing the interpolated 3D segment in physical coordinates.
"""
# Device management
if device == "auto":
dev = self.module.device if self.module.device.type != 'cpu' else \
torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
dev = torch.device(device)
if self.module.device != dev:
self.module.to(dev)
self.module.eval()
# Extract features and normalized Z coordinates
x0, z0, g0, c0, x1, z1, g1, c1, target_cells, total_cells = self._setup_and_extract(
adata0, adata1, thickness, dev
)
# Execute chunked ODE integration
mix_data = self._generate_and_prune_optimized(
x0, g0, c0, x1, g1, c1, z0, z1, steps,
target_cells, total_cells, dev, chunk_size
)
# Assemble and restore to physical coordinates
adata_segment = self._assemble_fast_anndata(adata0, adata1, mix_data)
return self._restore_3d_physical_coords(adata_segment)
[docs]
@torch.no_grad()
def reconstruct_full_volume(self,
adata_list: list,
thickness: float,
steps: int = 100,
chunk_size: int = 2048,
device: str = "auto") -> ad.AnnData:
"""
High-level API to reconstruct the entire 3D volume from a list of slices.
Args:
adata_list: Ordered list of AnnData slices.
thickness: Target physical distance (um) between cells in the Z-axis.
steps: Number of ODE integration steps per gap.
chunk_size: Processing batch size to prevent OOM on large datasets.
device: Target device for the entire reconstruction process.
Returns:
A single merged AnnData object representing the continuous 3D volume.
"""
from tqdm import tqdm
segment_list = []
num_pairs = len(adata_list) - 1
if num_pairs < 1:
raise ValueError("adata_list must contain at least 2 slices.")
# Initialize progress bar
pbar = tqdm(range(num_pairs), desc="DeepSpatial: 3D Reconstruct", unit="gap")
for i in pbar:
ad0, ad1 = adata_list[i], adata_list[i+1]
# Display current physical range in progress bar
z_start = ad0.obs[self.z_key].iloc[0]
z_end = ad1.obs[self.z_key].iloc[0]
pbar.set_postfix({"range": f"{z_start:.1f}-{z_end:.1f}um"})
# Generate segment
segment = self.reconstruct_between_slices(
adata0=ad0,
adata1=ad1,
steps=steps,
thickness=thickness,
chunk_size=chunk_size,
device=device
)
segment_list.append(segment)
# Concatenate all generated segments
full_volume = ad.concat(segment_list, join='outer', uns_merge='first', index_unique='-')
# Inherit metadata from the reference slice
if hasattr(adata_list[0], 'uns'):
full_volume.uns.update(adata_list[0].uns)
return full_volume
# =========================================================================
# Internal Computation Helpers
# =========================================================================
def _setup_and_extract(self, adata0, adata1, thickness, dev):
"""Calculates sampling density and maps normalized input to device."""
avg_n_ref = (adata0.n_obs + adata1.n_obs) / 2
z0, z1 = adata0.obs[self.z_key].iloc[0], adata1.obs[self.z_key].iloc[0]
physical_gap = abs(z1 - z0)
target_cells = max(1, int(avg_n_ref * (physical_gap / thickness)))
total_cells = int(target_cells)
def extract(adata):
# Extract from normalized keys
x = torch.tensor(adata.obsm['spatial_norm'], dtype=torch.float32, device=dev)
z = adata.obs['z_norm'].iloc[0]
g_arr = adata.X.toarray() if scipy.sparse.issparse(adata.X) else adata.X
g = torch.tensor(g_arr, dtype=torch.float32, device=dev)
c_idx = torch.tensor(pd.Categorical(adata.obs[self.label_key], categories=self.categories).codes.astype(np.int64), device=dev)
return x, z, g, c_idx
x0, z0, g0, c0 = extract(adata0)
x1, z1, g1, c1 = extract(adata1)
return x0, z0, g0, c0, x1, z1, g1, c1, target_cells, total_cells
def _generate_and_prune_optimized(self, x0, g0, c0, x1, g1, c1, z0, z1, steps, target_cells, total_cells, dev, chunk_size):
"""Memory-efficient ODE integration and spatial density pruning."""
N0, N1 = float(x0.shape[0]), float(x1.shape[0])
u = torch.rand(total_cells, device=dev)
# Inverse transform sampling for time distribution
if abs(N0 - N1) < 1e-5:
t_vals = u
else:
t_vals = (-N0 + torch.sqrt(N0**2 * (1 - u) + N1**2 * u)) / (N1 - N0)
target_zs = t_vals * (z1 - z0) + z0
is_fwd = torch.rand(total_cells, device=dev) > t_vals
fwd_ids, bwd_ids = torch.where(is_fwd)[0], torch.where(~is_fwd)[0]
src_fwd = torch.randint(0, x0.shape[0], (len(fwd_ids),), device=dev)
src_bwd = torch.randint(0, x1.shape[0], (len(bwd_ids),), device=dev)
final_x = torch.zeros((total_cells, 2), device=dev)
final_g = torch.zeros((total_cells, g0.shape[1]), device=dev)
final_c = torch.zeros(total_cells, device=dev, dtype=torch.long)
def process_direction(is_forward, src_indices, child_ids, x_ref, g_ref, c_ref, z_start, z_end):
if len(child_ids) == 0: return
unique_parents, inverse_indices = torch.unique(src_indices, return_inverse=True)
c_onehot_ref = torch.nn.functional.one_hot(c_ref, num_classes=self.num_classes).float()
for i in range(0, len(unique_parents), chunk_size):
end_idx = min(i + chunk_size, len(unique_parents))
chunk_parents = unique_parents[i:end_idx]
mask_in_chunk = (inverse_indices >= i) & (inverse_indices < end_idx)
chunk_child_ids = child_ids[mask_in_chunk]
if len(chunk_child_ids) == 0: continue
local_parent_idx = inverse_indices[mask_in_chunk] - i
batch = {
'x0': x_ref[chunk_parents], 'g0': g_ref[chunk_parents], 'c0': c_onehot_ref[chunk_parents],
'z0': torch.full((len(chunk_parents), 1), z_start, device=dev),
'z1': torch.full((len(chunk_parents), 1), z_end, device=dev),
'delta_z': torch.full((len(chunk_parents), 1), z_end - z_start, device=dev)
}
res = self.module.sample(batch, mode="ODE", steps=steps)
if not is_forward:
res['x_traj'] = torch.flip(res['x_traj'], dims=[0])
res['g_traj'] = torch.flip(res['g_traj'], dims=[0])
res['c_traj_discrete'] = torch.flip(res['c_traj_discrete'], dims=[0])
chunk_t_vals = t_vals[chunk_child_ids] * (steps - 1)
idx_low = chunk_t_vals.long()
idx_high = torch.clamp(idx_low + 1, max=steps - 1)
w_high = (chunk_t_vals - idx_low.float()).unsqueeze(1)
final_x[chunk_child_ids] = (1 - w_high) * res['x_traj'][idx_low, local_parent_idx] + w_high * res['x_traj'][idx_high, local_parent_idx]
final_g[chunk_child_ids] = (1 - w_high) * res['g_traj'][idx_low, local_parent_idx] + w_high * res['g_traj'][idx_high, local_parent_idx]
final_c[chunk_child_ids] = res['c_traj_discrete'][torch.clamp(chunk_t_vals.round().long(), max=steps-1), local_parent_idx]
del res, batch
torch.cuda.empty_cache()
process_direction(True, src_fwd, fwd_ids, x0, g0, c0, z0, z1)
process_direction(False, src_bwd, bwd_ids, x1, g1, c1, z1, z0)
fwd_src_keep, bwd_src_keep = src_fwd, src_bwd
fwd_keep_indices, bwd_keep_indices = fwd_ids, bwd_ids
fwd_mask = is_fwd
# Sparsity enforcement
nnz_parent = torch.zeros(target_cells, device=dev, dtype=torch.long)
nnz_parent[fwd_mask] = (g0 > 0).sum(dim=1)[fwd_src_keep]
nnz_parent[~fwd_mask] = (g1 > 0).sum(dim=1)[bwd_src_keep]
max_k = int(nnz_parent.max().item())
if max_k > 0:
for i in range(0, target_cells, chunk_size):
end_idx = min(i + chunk_size, target_cells)
g_chunk, nnz_chunk = final_g[i:end_idx], nnz_parent[i:end_idx]
_, sorted_indices = torch.sort(g_chunk, dim=1, descending=True)
top_indices = sorted_indices[:, :max_k]
mask = torch.arange(max_k, device=dev).unsqueeze(0) < nnz_chunk.unsqueeze(1)
final_mask = torch.zeros_like(g_chunk, dtype=torch.bool).scatter_(1, top_indices, mask)
g_chunk[~final_mask] = 0.0
final_g[i:end_idx] = g_chunk
return {
'x': final_x.cpu().numpy(), 'g': final_g, 'c': final_c.cpu().numpy(),
'zs': target_zs.cpu().numpy(), 'cells': target_cells,
'fwd_src': fwd_src_keep.cpu().numpy(), 'bwd_src': bwd_src_keep.cpu().numpy(),
'fwd_idx': fwd_keep_indices.cpu().numpy(), 'bwd_idx': bwd_keep_indices.cpu().numpy()
}
def _assemble_fast_anndata(self, adata0, adata1, mix_data):
"""Constructs the resulting AnnData with sparse expression matrix."""
nz_indices = mix_data['g'].nonzero(as_tuple=True)
res_g_sparse = scipy.sparse.csr_matrix(
(mix_data['g'][nz_indices].cpu().numpy(), (nz_indices[0].cpu().numpy(), nz_indices[1].cpu().numpy())),
shape=(mix_data['cells'], mix_data['g'].shape[1])
)
obs_fwd = adata0.obs.iloc[mix_data['fwd_src']].copy()
obs_fwd.index = mix_data['fwd_idx']
obs_bwd = adata1.obs.iloc[mix_data['bwd_src']].copy()
obs_bwd.index = mix_data['bwd_idx']
mixed_obs = pd.concat([obs_fwd, obs_bwd]).sort_index()
mixed_obs[self.z_key] = mix_data['zs']
mixed_obs[self.label_key] = pd.Categorical.from_codes(mix_data['c'], categories=self.categories)
mixed_obs.index = mixed_obs.index.astype(str)
return ad.AnnData(
X=res_g_sparse,
obs=mixed_obs,
var=adata0.var.copy(),
obsm={self.spatial_key: mix_data['x']}
)