deepspatial.core.DeepSpatial.build_model
- DeepSpatial.build_model(patch_size: int = 8, hidden_size: int = 256, depth: int = 6, num_heads: int = 8, mlp_ratio: float = 4.0, path_type: str = 'Linear', lr: float = 0.0002, weight_decay: float = 1e-05, lambda_g: float = 0.1, lambda_c: float = 10.0, sampling_method: str = 'dopri5', atol: float = 1e-05, rtol: float = 1e-05)
Instantiates the GiT network architecture and Flow Matching logic.
- Parameters:
patch_size – Tokenization patch size for spatial coordinates.
hidden_size – Transformer embedding dimension.
depth – Number of transformer layers.
num_heads – Number of attention heads.
mlp_ratio – Expansion ratio for MLP layers.
path_type – Probability path type for Flow Matching.
lr – Learning rate for training.
weight_decay – L2 regularization weight.
lambda_g – Loss weight for gene expression reconstruction.
lambda_c – Loss weight for cell type classification.
sampling_method – ODE solver for inference (e.g., 'dopri5', 'euler').
atol – Absolute tolerance for ODE solver.
rtol – Relative tolerance for ODE solver.
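
The following is a minimal sketch of how the arguments above could be collected and sanity-checked before calling build_model. The helper check_build_kwargs and the derived names head_dim and mlp_hidden are hypothetical and not part of deepspatial. The divisibility check assumes standard multi-head attention, where hidden_size is split evenly across num_heads; the library may or may not enforce this itself.

```python
# Hypothetical helper, not part of deepspatial: sanity-checks keyword
# arguments before they are passed to DeepSpatial.build_model.

# Defaults copied from the documented signature.
BUILD_MODEL_DEFAULTS = {
    "patch_size": 8,
    "hidden_size": 256,
    "depth": 6,
    "num_heads": 8,
    "mlp_ratio": 4.0,
    "path_type": "Linear",
    "lr": 2e-4,
    "weight_decay": 1e-5,
    "lambda_g": 0.1,
    "lambda_c": 10.0,
    "sampling_method": "dopri5",
    "atol": 1e-5,
    "rtol": 1e-5,
}


def check_build_kwargs(**overrides):
    """Merge overrides into the documented defaults and derive layer widths."""
    unknown = set(overrides) - set(BUILD_MODEL_DEFAULTS)
    if unknown:
        raise TypeError(f"unexpected build_model arguments: {sorted(unknown)}")
    cfg = {**BUILD_MODEL_DEFAULTS, **overrides}
    # Assumption: standard multi-head attention splits hidden_size
    # evenly across heads, so the remainder must be zero.
    if cfg["hidden_size"] % cfg["num_heads"] != 0:
        raise ValueError("hidden_size must be divisible by num_heads")
    derived = {
        # Per-head width under the assumption above.
        "head_dim": cfg["hidden_size"] // cfg["num_heads"],
        # MLP hidden width implied by the expansion ratio.
        "mlp_hidden": int(cfg["hidden_size"] * cfg["mlp_ratio"]),
    }
    return cfg, derived


cfg, derived = check_build_kwargs()
print(derived)  # {'head_dim': 32, 'mlp_hidden': 1024}
```

Given an already constructed DeepSpatial instance (its constructor is not covered in this section), the checked dictionary could then be passed as build_model(**cfg).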