deepspatial.core.DeepSpatial.fit

DeepSpatial.fit(max_epochs: int = 100, save_dir: str = './checkpoints', accelerator: str = 'auto', devices: str = 'auto', save_ckpt: bool = False, resume_ckpt_path: str | None = None)

Executes the training loop using PyTorch Lightning.
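As a rough illustration only: if `fit` forwards its arguments to Lightning in the usual way, they would correspond to keyword arguments of the `lightning.pytorch.Trainer` constructor and of `Trainer.fit(..., ckpt_path=...)`, as sketched below. The helper name `sketch_trainer_args` is hypothetical, and the exact mapping (in particular `save_dir` to `default_root_dir` and `save_ckpt` to `enable_checkpointing`) is an assumption, not taken from the DeepSpatial source.

```python
from typing import Any, Optional, Union

# devices may be 'auto', a device count, or a list of device indices.
DeviceSpec = Union[str, int, list[int]]


def sketch_trainer_args(
    max_epochs: int = 100,
    save_dir: str = "./checkpoints",
    accelerator: str = "auto",
    devices: DeviceSpec = "auto",
    save_ckpt: bool = False,
    resume_ckpt_path: Optional[str] = None,
) -> tuple[dict[str, Any], dict[str, Any]]:
    """HYPOTHETICAL mapping of DeepSpatial.fit arguments onto Lightning calls.

    Returns (trainer_kwargs, fit_kwargs): the first would go to
    lightning.pytorch.Trainer(...), the second to trainer.fit(module, ...).
    This is an assumed mapping, not DeepSpatial's actual code.
    """
    trainer_kwargs = {
        "max_epochs": max_epochs,
        "accelerator": accelerator,
        "devices": devices,
        # Assumption: save_dir becomes Lightning's root directory for outputs.
        "default_root_dir": save_dir,
        # Assumption: save_ckpt toggles Lightning's default checkpoint callback.
        "enable_checkpointing": save_ckpt,
    }
    # With ckpt_path=None, Lightning's Trainer.fit starts a fresh run;
    # with a path, it restores the saved training state before continuing.
    fit_kwargs = {"ckpt_path": resume_ckpt_path}
    return trainer_kwargs, fit_kwargs


trainer_kwargs, fit_kwargs = sketch_trainer_args(max_epochs=50, devices=[0, 1])
# With Lightning installed, these would be used roughly as:
#   trainer = lightning.pytorch.Trainer(**trainer_kwargs)
#   trainer.fit(lightning_module, **fit_kwargs)
```

Keeping resumption as a separate `ckpt_path` argument mirrors Lightning's own split: hardware and loop settings belong to the `Trainer`, while the checkpoint to restore is chosen per `fit` call.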

Parameters:
  • max_epochs – Maximum number of training epochs.

  • save_dir – Directory to save checkpoints and metadata.

  • accelerator – Hardware accelerator (‘auto’, ‘gpu’, ‘cpu’, ‘mps’).

  • devices – Number of devices to use, or a list of device indices (e.g. ‘auto’, 1, [0, 1]).

  • save_ckpt – Whether to save training checkpoints to save_dir.

  • resume_ckpt_path – Path to a checkpoint from which to resume the full training state. If None (the default), training starts from scratch.
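The forms accepted for accelerator and devices can be made concrete with a small validator. `check_hardware_args` is a hypothetical helper, not part of DeepSpatial, and it deliberately accepts only the values listed above. Lightning itself is more permissive (for instance, `devices=-1` selects all available devices), so treat this as an illustration of the documented forms rather than the library's own check.

```python
from typing import Union

# Accelerator values as listed on this page (assumption: Lightning may accept more).
VALID_ACCELERATORS = {"auto", "gpu", "cpu", "mps"}


def check_hardware_args(accelerator: str, devices: Union[str, int, list[int]]) -> str:
    """HYPOTHETICAL validator for the documented accelerator/devices forms.

    Returns a short description of how `devices` would be interpreted,
    or raises ValueError for anything outside the documented forms.
    """
    if accelerator not in VALID_ACCELERATORS:
        raise ValueError(f"unsupported accelerator: {accelerator!r}")
    if devices == "auto":
        return "auto"
    # bool is a subclass of int, so exclude it explicitly.
    if isinstance(devices, int) and not isinstance(devices, bool):
        if devices < 1:
            raise ValueError("device count must be >= 1")
        return f"count:{devices}"
    # A non-empty list of non-negative integer device indices.
    if isinstance(devices, list) and devices and all(
        isinstance(i, int) and not isinstance(i, bool) and i >= 0 for i in devices
    ):
        return "indices:" + ",".join(str(i) for i in devices)
    raise ValueError(f"unsupported devices value: {devices!r}")


print(check_hardware_args("auto", "auto"))  # auto
print(check_hardware_args("gpu", 2))        # count:2
print(check_hardware_args("cpu", [0, 1]))   # indices:0,1
```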