import torch
import torch.nn as nn
import pytorch_lightning as pl
from copy import deepcopy
from collections import OrderedDict
from .transport import create_transport, Sampler
class DeepSpatialModule(pl.LightningModule):
    """
    DeepSpatial Module for Training & Inference.

    Bundles the trainable network, a frozen exponential-moving-average (EMA)
    copy of it, the transport object returned by ``create_transport`` and a
    ``Sampler`` built on top of that transport.

    Parameters
    ----------
    args : dict or argparse.Namespace
        Configuration containing hyperparameters. ``path_type``,
        ``prediction`` and ``lr`` are read as attributes and must be present.
        Optional entries and the defaults used when they are missing:
        ``ema_decay`` (0.999), ``train_eps`` (0.02), ``sample_eps`` (0.02),
        ``weight_decay`` (1e-5), ``lambda_g`` (0.1), ``lambda_c`` (10.0),
        ``sampling_method`` ('dopri5'), ``atol`` (1e-5), ``rtol`` (1e-5).
    model : torch.nn.Module
        The core neural network architecture (e.g., the GiT model) that predicts
        the velocity fields. It is called with the keyword arguments
        ``xt, gt, t, zt, delta_z, ct`` and must return three tensors
        (spatial, gene and cell-type velocities).
    """

    def __init__(self, args, model):
        """
        Initializes a LightningModule instance for DeepSpatial.
        """
        super().__init__()
        self.save_hyperparameters(args)

        # Core Model
        self.model = model

        # EMA Setup: an independent, frozen copy that is only ever written
        # by `_update_ema` and only ever read by `sample`.
        self.ema_decay = self.hparams.get('ema_decay', 0.999)
        self.ema_model = deepcopy(self.model)
        self._freeze(self.ema_model)

        # Transport & Path Setup
        self.transport = create_transport(
            path_type=self.hparams.path_type,
            prediction=self.hparams.prediction,
            train_eps=self.hparams.get('train_eps', 0.02),
            sample_eps=self.hparams.get('sample_eps', 0.02),
        )
        self.sampler = Sampler(self.transport)

    def _freeze(self, module):
        """Freeze model parameters for EMA or evaluation."""
        for param in module.parameters():
            param.requires_grad = False
        module.eval()

    def configure_optimizers(self):
        """Initialize AdamW optimizer with weight decay."""
        # Only the trainable network is optimised; the EMA copy is excluded.
        return torch.optim.AdamW(
            self.model.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.get('weight_decay', 1e-5),
        )

    # ============================================================
    # Training & Validation Logic
    # ============================================================
    def _shared_step(self, batch):
        """
        Computes joint Flow Matching losses across spatial and molecular dimensions.

        Parameters
        ----------
        batch : dict
            Must provide the keys ``x0``/``x1``, ``g0``/``g1``, ``c0``/``c1``,
            ``z0``/``z1`` and ``delta_z``.

        Returns
        -------
        dict
            ``'loss'`` (weighted total) plus the individual terms
            ``'loss_x'``, ``'loss_g'`` and ``'loss_c'``.
        """
        x0, x1 = batch['x0'], batch['x1']
        g0, g1 = batch['g0'], batch['g1']
        c0, c1 = batch['c0'], batch['c1']
        z0, z1 = batch['z0'], batch['z1']
        delta_z = batch['delta_z']

        # Sample time steps (one shared `t` is reused for every modality)
        t, _, _ = self.transport.sample(x1)

        # Plan paths (interpolation); each call yields the state at time t
        # and the target velocity for that modality.
        _, xt, ux_t = self.transport.path_sampler.plan(t, x0, x1)
        _, gt, ug_t = self.transport.path_sampler.plan(t, g0, g1)
        _, ct, uc_t = self.transport.path_sampler.plan(t, c0, c1)
        # Z is only a conditioning signal, so its target velocity is unused.
        _, zt, _ = self.transport.path_sampler.plan(t, z0, z1)

        # Predict velocity fields
        vx_pred, vg_pred, vc_pred = self.model(
            xt=xt, gt=gt, t=t, zt=zt, delta_z=delta_z, ct=ct
        )

        # Compute losses via the transport's loss function
        # Spatial loss (X, Y)
        loss_x = self.transport.loss_fn(vx_pred, x0, xt, t, ux_t).mean()
        # Gene loss
        loss_g = self.transport.loss_fn(vg_pred, g0, gt, t, ug_t).mean()
        # Cell type loss (on one-hot/continuous space)
        loss_c = self.transport.loss_fn(vc_pred, c0, ct, t, uc_t).mean()

        # Weighted total loss
        lambda_g = self.hparams.get('lambda_g', 0.1)
        lambda_c = self.hparams.get('lambda_c', 10.0)
        loss_total = loss_x + (lambda_g * loss_g) + (lambda_c * loss_c)

        return {
            'loss': loss_total,
            'loss_x': loss_x,
            'loss_g': loss_g,
            'loss_c': loss_c,
        }

    def training_step(self, batch, batch_idx):
        """
        Run one training step and log the total and per-modality losses.

        Returns the dictionary produced by `_shared_step`; its ``'loss'``
        entry is the value to optimise.
        """
        loss = self._shared_step(batch)
        self.log('loss', loss['loss'], prog_bar=True, on_step=True, on_epoch=True)
        self.log('loss_x', loss['loss_x'], on_epoch=True)
        self.log('loss_g', loss['loss_g'], on_epoch=True)
        self.log('loss_c', loss['loss_c'], on_epoch=True)
        return loss

    # ============================================================
    # EMA Update Logic
    # ============================================================
    def on_train_batch_end(self, outputs, batch, batch_idx):
        """Update EMA parameters after every training batch."""
        # NOTE(review): this fires once per batch, not once per optimizer
        # step; with gradient accumulation the EMA is also updated on batches
        # where the weights did not change - confirm this is intended.
        self._update_ema()

    @torch.no_grad()
    def _update_ema(self):
        """Update EMA model weights with exponential decay."""
        for ema_param, model_param in zip(self.ema_model.parameters(), self.model.parameters()):
            ema_param.data.mul_(self.ema_decay).add_(model_param.data, alpha=1 - self.ema_decay)
        # Buffers are not returned by `parameters()`; mirror them so the EMA
        # copy does not keep the values it had when it was deep-copied.
        for ema_buffer, model_buffer in zip(self.ema_model.buffers(), self.model.buffers()):
            ema_buffer.copy_(model_buffer)

    def on_load_checkpoint(self, checkpoint):
        """
        Ensure the EMA model can be restored from the checkpoint.

        If the checkpoint's ``state_dict`` has no ``ema_model.*`` entries,
        they are created as clones of the matching ``model.*`` entries so the
        EMA copy starts from the loaded weights. Checkpoints that already
        contain EMA weights are left untouched.
        """
        # NOTE(review): relies on Lightning calling this hook *before* it
        # loads `checkpoint['state_dict']` into the module, which is why the
        # checkpoint is patched instead of touching the live weights here -
        # confirm against the pinned Lightning version.
        state_dict = checkpoint.get('state_dict')
        if not state_dict:
            return

        ema_prefix, model_prefix = 'ema_model.', 'model.'
        if any(key.startswith(ema_prefix) for key in state_dict):
            return

        seeded = {
            ema_prefix + key[len(model_prefix):]: value.clone()
            for key, value in state_dict.items()
            if key.startswith(model_prefix)
        }
        state_dict.update(seeded)

    # ============================================================
    # Inference / Sampling Logic
    # ============================================================
    @torch.no_grad()
    def sample(self, batch, mode="ODE", steps=20):
        """
        Integrates the learned flow field to reconstruct intermediate biological states.

        Parameters
        ----------
        batch : dict
            A dictionary containing the initial states (`x0`, `g0`, `c0`), the physical
            Z-depth conditions (`z0`, `z1`, `delta_z`), and other necessary tensors.
        mode : str, optional
            The integration mode, either `"ODE"` (Ordinary Differential Equation) or
            `"SDE"` (Stochastic Differential Equation), matched case-insensitively.
            By default `"ODE"`.
        steps : int, optional
            The number of integration steps from the source to the target slice.
            Higher values yield more accurate but slower trajectories. By default 20.

        Returns
        -------
        dict
            A dictionary containing the full integration trajectories:
            - `'x_traj'` : torch.Tensor of shape `(Steps, Batch, 2)`
            - `'g_traj'` : torch.Tensor of shape `(Steps, Batch, Gene_Dim)`
            - `'c_traj_discrete'` : torch.Tensor of shape `(Steps, Batch)` containing discrete cell type labels.

        Raises
        ------
        ValueError
            If `mode` is neither `"ODE"` nor `"SDE"`.
        """
        # Reject unknown modes instead of silently falling back to SDE.
        mode_key = str(mode).upper()
        if mode_key not in ("ODE", "SDE"):
            raise ValueError(f"Unknown sampling mode {mode!r}; expected 'ODE' or 'SDE'.")

        self.ema_model.eval()

        # Configure the ODE/SDE sampler based on hyperparameters
        # NOTE(review): the same keyword set is handed to both samplers;
        # confirm `Sampler.sample_sde` accepts `atol`/`rtol` and the default
        # 'dopri5' method.
        sample_config = {
            'num_steps': steps,
            'sampling_method': self.hparams.get('sampling_method', 'dopri5'),
            'atol': self.hparams.get('atol', 1e-5),
            'rtol': self.hparams.get('rtol', 1e-5),
        }

        # Instantiate the integration function
        if mode_key == "ODE":
            sampler_fn = self.sampler.sample_ode(**sample_config)
        else:
            sampler_fn = self.sampler.sample_sde(**sample_config)

        # Initial state at t=0
        x0, g0, c0 = batch['x0'], batch['g0'], batch['c0']
        x_dim, g_dim = x0.shape[-1], g0.shape[-1]
        g_end = x_dim + g_dim
        z0, z1, delta_z = batch['z0'], batch['z1'], batch['delta_z']

        # Concatenate for joint integration
        init_state = torch.cat([x0, g0, c0], dim=-1)

        def velocity_field_wrapper(joint_state_t, t):
            # Unpack modalities
            xt = joint_state_t[..., :x_dim]
            gt = joint_state_t[..., x_dim:g_end]
            ct = joint_state_t[..., g_end:]

            # Broadcast the solver time to one value per sample. For tensor
            # input only the first element is used, and it is kept on-device
            # (no `.item()`), so no host sync happens per solver evaluation.
            batch_size = xt.shape[0]
            if torch.is_tensor(t):
                t_tensor = (
                    t.reshape(-1)[:1]
                    .to(device=xt.device, dtype=xt.dtype)
                    .repeat(batch_size)
                )
            else:
                t_tensor = torch.full((batch_size,), t, device=xt.device, dtype=xt.dtype)

            # Interpolate normalized Z coordinate
            _, zt, _ = self.transport.path_sampler.plan(t_tensor, z0, z1)

            # Forward pass through EMA model
            vx, vg, vc = self.ema_model(
                xt=xt, gt=gt, t=t_tensor, zt=zt, delta_z=delta_z, ct=ct
            )
            return torch.cat([vx, vg, vc], dim=-1)

        # Compute trajectory: Shape [steps, batch, dim]
        trajectory = sampler_fn(init_state, velocity_field_wrapper)
        if isinstance(trajectory, (list, tuple)):
            trajectory = torch.stack(trajectory, dim=0)

        # Unpack integrated results
        x_traj = trajectory[..., :x_dim]
        g_traj = trajectory[..., x_dim:g_end]
        c_traj_cont = trajectory[..., g_end:]

        return {
            'x_traj': x_traj,
            'g_traj': g_traj,
            # Discretise cell types by taking the largest continuous component.
            'c_traj_discrete': torch.argmax(c_traj_cont, dim=-1),
        }