deepspatial.module.DeepSpatialModule.sample

DeepSpatialModule.sample(batch, mode='ODE', steps=20)

Integrates the learned flow field from the source slice to the target slice to reconstruct the intermediate biological states between them.
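
This page does not document which integrator the method uses. As a minimal sketch, assuming a fixed-step scheme with depth playing the role of time, the NumPy code below contrasts the two modes: forward Euler for "ODE" and Euler-Maruyama (drift step plus scaled Gaussian noise) for "SDE". The function `integrate_flow`, the `velocity` callable standing in for the learned field, and the noise scale `sigma` are all hypothetical.

```python
import numpy as np

def integrate_flow(velocity, x0, z0, z1, mode="ODE", steps=20, sigma=0.1, rng=None):
    """Hypothetical sketch, not the DeepSpatial implementation.

    Integrates dx/dz = velocity(x, z) from depth z0 to depth z1 using
    forward Euler ("ODE") or Euler-Maruyama ("SDE").
    Returns an array of shape (steps + 1, *x0.shape), including the start state.
    """
    if mode not in ("ODE", "SDE"):
        raise ValueError(f"mode must be 'ODE' or 'SDE', got {mode!r}")
    if steps < 1:
        raise ValueError("steps must be >= 1")
    rng = np.random.default_rng() if rng is None else rng
    dz = (z1 - z0) / steps                      # fixed step size along depth
    x = np.asarray(x0, dtype=float)
    traj = [x.copy()]
    for k in range(steps):
        z = z0 + k * dz
        x = x + velocity(x, z) * dz             # deterministic drift (Euler step)
        if mode == "SDE":
            # Assumed isotropic noise with scale sigma; Brownian increment ~ sqrt(|dz|)
            x = x + sigma * np.sqrt(abs(dz)) * rng.standard_normal(x.shape)
        traj.append(x.copy())
    return np.stack(traj)
```

This sketch stores steps + 1 states (the start plus one per step); whether the documented Steps dimension includes the initial state is not specified on this page.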

Parameters:
  • batch (dict) – A dictionary containing the initial states (x0, g0, c0), the physical Z-depth conditions (z0, z1, delta_z), and other necessary tensors.

  • mode (str, optional) – The integration mode, either “ODE” (Ordinary Differential Equation) or “SDE” (Stochastic Differential Equation). By default “ODE”.

  • steps (int, optional) – The number of integration steps from the source to the target slice. Higher values yield more accurate but slower trajectories. By default 20.
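
The accuracy/speed trade-off described for steps is typical of fixed-step integrators. As a generic illustration only (not the module's own solver, which this page does not name), forward Euler on the toy ODE dx/dt = -x shows the final-state error shrinking as the step count grows, while the number of function evaluations grows linearly with it. The helper `euler_final` is hypothetical.

```python
import math

def euler_final(f, x0, t0, t1, steps):
    """Forward Euler: return the state at t1 after `steps` equal steps."""
    h = (t1 - t0) / steps
    x, t = x0, t0
    for _ in range(steps):
        x = x + h * f(x, t)   # one evaluation of f per step
        t += h
    return x

# Toy ODE dx/dt = -x with x(0) = 1, whose exact solution at t = 1 is exp(-1).
decay = lambda x, t: -x
exact = math.exp(-1.0)
for n in (5, 20, 80):
    err = abs(euler_final(decay, 1.0, 0.0, 1.0, n) - exact)
    print(f"steps={n:3d}  abs error={err:.2e}")
```

Because Euler is first-order, quadrupling the steps roughly quarters the error here; higher-order solvers trade more work per step for faster convergence.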

Returns:

A dictionary containing the full integration trajectories:

  • ‘x_traj’ : torch.Tensor of shape (Steps, Batch, 2).

  • ‘g_traj’ : torch.Tensor of shape (Steps, Batch, Gene_Dim).

  • ‘c_traj_discrete’ : torch.Tensor of shape (Steps, Batch) containing discrete cell-type labels.

Return type:

dict
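
A minimal sketch of consuming the returned dictionary, assuming the tensors have first been converted with .detach().cpu().numpy() and that cell-type labels are integers in [0, n_types). The helper `summarize_trajectories` and the parameter `n_types` are hypothetical; random arrays with the documented shapes stand in for real output.

```python
import numpy as np

def summarize_trajectories(out, n_types):
    """Hypothetical helper: summarize a dict shaped like the documented return value.

    Assumes NumPy arrays and integer labels in [0, n_types).
    """
    x_traj = out["x_traj"]            # (Steps, Batch, 2) spatial coordinates
    g_traj = out["g_traj"]            # (Steps, Batch, Gene_Dim) expression
    c_traj = out["c_traj_discrete"]   # (Steps, Batch) integer cell-type labels
    final_xy = x_traj[-1]             # positions at the last stored step
    # Straight-line distance each cell moved between first and last step: (Batch,)
    displacement = np.linalg.norm(x_traj[-1] - x_traj[0], axis=-1)
    final_expr = g_traj[-1]           # (Batch, Gene_Dim)
    # Number of cells carrying each label at every step: (Steps, n_types)
    type_counts = np.stack([np.bincount(c, minlength=n_types) for c in c_traj])
    return {"final_xy": final_xy, "displacement": displacement,
            "final_expr": final_expr, "type_counts": type_counts}

# Stand-in data with the documented shapes (random, for illustration only).
rng = np.random.default_rng(0)
S, B, G, K = 21, 8, 50, 5
fake = {"x_traj": rng.normal(size=(S, B, 2)),
        "g_traj": rng.random((S, B, G)),
        "c_traj_discrete": rng.integers(0, K, size=(S, B))}
summary = summarize_trajectories(fake, K)
```

Tracking type_counts across steps gives a quick view of how predicted cell-type composition shifts between the two slices.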