deepspatial.core.DeepSpatial.build_model

DeepSpatial.build_model(patch_size: int = 8, hidden_size: int = 256, depth: int = 6, num_heads: int = 8, mlp_ratio: float = 4.0, path_type: str = 'Linear', lr: float = 0.0002, weight_decay: float = 1e-05, lambda_g: float = 0.1, lambda_c: float = 10.0, sampling_method: str = 'dopri5', atol: float = 1e-05, rtol: float = 1e-05)

Instantiates the GiT network architecture and Flow Matching logic.

Parameters:
  • patch_size – Tokenization patch size for spatial coordinates.

  • hidden_size – Transformer embedding dimension.

  • depth – Number of transformer layers.

  • num_heads – Number of attention heads.

  • mlp_ratio – Expansion ratio for MLP layers.

  • path_type – Probability path type for Flow Matching.

  • lr – Learning rate for training.

  • weight_decay – L2 regularization weight.

  • lambda_g – Loss weight for gene expression reconstruction.

  • lambda_c – Loss weight for cell type classification.

  • sampling_method – ODE solver for inference (e.g., 'dopri5', 'euler').

  • atol – Absolute tolerance for ODE solver.

  • rtol – Relative tolerance for ODE solver.
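
Example (a minimal sketch, not taken from the DeepSpatial source): the snippet below collects the documented defaults into a dictionary and merges user overrides before they are passed to build_model. The helper name make_build_kwargs and its sanity checks are hypothetical; the divisibility check reflects the usual multi-head attention convention that hidden_size is split evenly across num_heads, and DeepSpatial may or may not enforce it. How a DeepSpatial instance is constructed is not covered on this page, so the final call is left as a comment.

```python
# Keyword arguments mirroring the defaults in the documented signature.
BUILD_MODEL_DEFAULTS = {
    "patch_size": 8,
    "hidden_size": 256,
    "depth": 6,
    "num_heads": 8,
    "mlp_ratio": 4.0,
    "path_type": "Linear",
    "lr": 2e-4,
    "weight_decay": 1e-5,
    "lambda_g": 0.1,
    "lambda_c": 10.0,
    "sampling_method": "dopri5",
    "atol": 1e-5,
    "rtol": 1e-5,
}


def make_build_kwargs(**overrides):
    """Merge overrides into the documented defaults and sanity-check them.

    Hypothetical helper: the checks follow common transformer and
    ODE-solver conventions and are assumptions, not DeepSpatial behavior.
    """
    unknown = set(overrides) - set(BUILD_MODEL_DEFAULTS)
    if unknown:
        raise TypeError(f"unexpected argument(s): {sorted(unknown)}")

    kwargs = {**BUILD_MODEL_DEFAULTS, **overrides}

    # Multi-head attention normally splits hidden_size evenly across heads.
    if kwargs["hidden_size"] % kwargs["num_heads"] != 0:
        raise ValueError("hidden_size should be divisible by num_heads")

    # Sizes, learning rate, and solver tolerances should be strictly positive.
    for name in ("patch_size", "hidden_size", "depth", "num_heads",
                 "lr", "atol", "rtol"):
        if kwargs[name] <= 0:
            raise ValueError(f"{name} should be positive")

    return kwargs


# A wider, deeper model that uses a fixed-step solver at inference time.
kwargs = make_build_kwargs(hidden_size=384, depth=8, sampling_method="euler")

# With an existing DeepSpatial instance `ds` (construction not shown here):
# ds.build_model(**kwargs)
```

Arguments that are not overridden keep their documented defaults, so only the values that differ from the signature need to be spelled out.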