Mouse Hypothalamus MERFISH 3D Reconstruction#

This tutorial presents an end-to-end DeepSpatial workflow for reconstructing a 3D volume from mouse hypothalamus MERFISH slices.

In this tutorial, you will:

  1. Import dependencies.

  2. Load and organize slice-level MERFISH data.

  3. Set up inputs and train DeepSpatial.

  4. Reconstruct the 3D tissue volume.

  5. Visualize and inspect reconstruction results.

1. Import Dependencies#

Import DeepSpatial, utility libraries, and visualization functions used throughout this tutorial.

The tutorial below follows a complete from-scratch workflow.

import os
import re
import glob
import torch
import anndata as ad

from deepspatial import DeepSpatial
from deepspatial.vis_utils import interactive_3d_labels, interactive_3d_expression

2. Data Preparation#

Download the mouse hypothalamus MERFISH dataset first. This section scans all h5ad files in the dataset directory and converts the slice index embedded in each file name into a physical Z coordinate, so the model can learn the real inter-slice spacing.

# 2. Data preparation and path parsing
# Set this to your local dataset directory before running
data_dir = "/data/merfish_mouse_hypothalamus"
file_paths = sorted(
    glob.glob(os.path.join(data_dir, "merfish_*_paste.h5ad")),
    key=lambda x: int(re.search(r'merfish_(\d+)', x).group(1)),
)

# Physical distance between adjacent slices (same unit as the dataset)
slice_gap = 50.0

# Load slices and inject physical Z coordinates
adata_list = []
for p in file_paths:
    adata = ad.read_h5ad(p)
    # Parse slice index from file name and write physical height into adata.obs
    idx = int(re.search(r'merfish_(\d+)', p).group(1))
    adata.obs['z_coord'] = float(idx * slice_gap)
    adata_list.append(adata)

print(f"Loaded {len(adata_list)} slices with physical Z-coordinates in adata.obs['z_coord']")
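The index-to-Z mapping above can be sanity-checked in isolation on synthetic file names. The regex, sort key, and gap value below mirror the loading cell; only the fake paths are made up for illustration.

```python
import re

slice_gap = 50.0
fake_paths = [
    "/data/merfish_mouse_hypothalamus/merfish_3_paste.h5ad",
    "/data/merfish_mouse_hypothalamus/merfish_1_paste.h5ad",
    "/data/merfish_mouse_hypothalamus/merfish_2_paste.h5ad",
]

# Sort by the parsed slice index, exactly as in the loading cell above
ordered = sorted(fake_paths, key=lambda p: int(re.search(r'merfish_(\d+)', p).group(1)))
z_coords = [int(re.search(r'merfish_(\d+)', p).group(1)) * slice_gap for p in ordered]
print(z_coords)  # [50.0, 100.0, 150.0]
```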

Check Loaded Data#

Inspect adata_list to verify that all slices were loaded correctly.

adata_list

3. Setup and Train#

This section configures the model inputs, builds the architecture, and runs training.

model = DeepSpatial()

# 3. Data setup: normalization is handled internally without overwriting raw coordinates
candidate_label_keys = ['cell_type', 'cell_class', 'annotation', 'Harmony_labels']
label_key = next((k for k in candidate_label_keys if k in adata_list[0].obs.columns), None)
if label_key is None:
    raise ValueError(f'No valid label key found. Tried: {candidate_label_keys}')
print('Using label key:', label_key)

model.setup_data(
    adata_list=adata_list,
    spatial_key='spatial',
    z_key='z_coord',
    label_key=label_key,
    batch_size=2048
)
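Note that the detection above only inspects the first slice. Before training it is worth confirming the label key exists in every slice; the helper below (a sketch, the function name is ours) operates on per-slice obs tables and is demonstrated on synthetic data. On real data, pass [a.obs for a in adata_list].

```python
import pandas as pd

def assert_label_in_all(obs_tables, label_key):
    """Verify every slice's obs table carries the chosen label column."""
    missing = [i for i, obs in enumerate(obs_tables) if label_key not in obs.columns]
    if missing:
        raise ValueError(f"label key '{label_key}' missing from slices: {missing}")
    return True

# Synthetic per-slice obs tables standing in for [a.obs for a in adata_list]
demo = [pd.DataFrame({'cell_type': ['Inhibitory', 'Excitatory']}) for _ in range(2)]
print(assert_label_in_all(demo, 'cell_type'))  # True
```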

Build the Model#

Define architecture and optimizer-related hyperparameters, then train the model. You can adjust settings based on data size, GPU memory, and training goals.

# Build model: configure architecture hyperparameters
model.build_model(
    patch_size=8,
    hidden_size=256,
    depth=6,
    lr=2e-4
)

Train the Model#

After training, checkpoints are saved to the directory given by save_dir. The 10 epochs used here are for demonstration only; in real experiments, increase max_epochs until the loss converges.

# Train: set checkpoint directory and device options
accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'
devices = [0] if accelerator == 'gpu' else 1
print('Training accelerator:', accelerator)

model.fit(
    max_epochs=10,
    save_dir="./checkpoints/deepspatial_run",
    accelerator=accelerator,
    devices=devices,
    save_ckpt=True
)
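After fit returns, it is useful to confirm that checkpoint files actually landed in save_dir. The .ckpt extension below is an assumption based on common PyTorch Lightning conventions; adjust if DeepSpatial uses a different suffix. The sketch simulates a save_dir in a temp directory so the listing pattern is reproducible.

```python
import glob
import os
import tempfile

# Simulate a save_dir with one checkpoint; replace with your real save_dir
save_dir = tempfile.mkdtemp()
open(os.path.join(save_dir, "epoch=9.ckpt"), "w").close()

ckpts = sorted(glob.glob(os.path.join(save_dir, "*.ckpt")))
print(f"Found {len(ckpts)} checkpoint(s)")  # Found 1 checkpoint(s)
```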

4. Reconstruct#

After training, generate a full 3D volume from all slices.

adata_3d = model.reconstruct_full_volume(
    adata_list,
    thickness=10
)
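A quick sanity check on the result is that reconstructed Z values stay within the physical range spanned by the input slices. Where exactly adata_3d stores its 3D coordinates depends on the DeepSpatial version, so the sketch below uses a synthetic (N, 3) numpy array to show only the check itself; substitute the real coordinate array from adata_3d.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for reconstructed 3D coordinates; in practice pull the array
# stored on adata_3d (the exact key depends on the DeepSpatial version)
n_slices, slice_gap = 12, 50.0
coords_3d = np.column_stack([
    rng.random(200) * 1000,                             # x
    rng.random(200) * 1000,                             # y
    rng.uniform(slice_gap, n_slices * slice_gap, 200),  # z within slice range
])

z_min, z_max = coords_3d[:, 2].min(), coords_3d[:, 2].max()
assert slice_gap <= z_min and z_max <= n_slices * slice_gap
print(f"Z range: [{z_min:.1f}, {z_max:.1f}]")
```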

Save Reconstruction#

Persist the reconstructed result for future analysis and reproducibility.

os.makedirs("output", exist_ok=True)
adata_3d.write_h5ad("output/deepspatial_3d.h5ad")
print("Saved reconstruction to output/deepspatial_3d.h5ad")

5. Visualization#

Use both interactive and static plots to inspect 3D structure, spatial patterns, and label or gene expression distributions.

adata_3d

3D Label Visualization#

Color by the selected label column to quickly verify whether reconstructed cellular patterns are spatially reasonable.

label_candidates = ['cell_type', 'cell_class', 'annotation', 'Harmony_labels', 'Region']
vis_label = next((k for k in label_candidates if k in adata_3d.obs.columns), None)
if vis_label is None:
    raise ValueError(f'No visualization label column found. Tried: {label_candidates}')

interactive_3d_labels(
    adata_3d,
    color_col=vis_label,
    title=f'DeepSpatial MERFISH Reconstruction ({vis_label})',
    width=1000,
    height=1000,
)
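When the interactive widget is too heavy, for example in headless environments or for static figures, a matplotlib 3D scatter is a lightweight alternative. The sketch below uses random points and integer label codes in place of adata_3d coordinates and categories; swap in the real arrays.

```python
import os

import matplotlib
matplotlib.use("Agg")  # headless backend; no display needed
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
xyz = rng.random((500, 3)) * 100   # stand-in for 3D cell coordinates
labels = rng.integers(0, 4, 500)   # stand-in for categorical label codes

fig = plt.figure(figsize=(6, 6))
ax = fig.add_subplot(projection="3d")
ax.scatter(xyz[:, 0], xyz[:, 1], xyz[:, 2], c=labels, s=3, cmap="tab10")
ax.set_xlabel("x"); ax.set_ylabel("y"); ax.set_zlabel("z")

out_path = "static_3d_labels.png"
fig.savefig(out_path, dpi=150)
print("saved", out_path)
```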

3D Gene Expression Visualization#

Map a representative marker gene into 3D space to identify local enrichment patterns.

# Prefer a known marker gene when present, otherwise fall back to the first gene
marker_candidates = ['Ucn3', 'PanCK', 'CD45', 'Ki67']
gene_name = next(
    (g for g in marker_candidates if g in adata_3d.var_names),
    adata_3d.var_names[0],
)

interactive_3d_expression(
    adata_3d,
    gene_name=gene_name,
    title=f'MERFISH Reconstruction Expression: {gene_name}',
    width=1000,
    height=1000,
)
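A few highly expressed outlier cells can wash out the expression colormap so that most of the volume appears uniformly dim. Clipping values at an upper percentile before plotting keeps spatial gradients visible; a minimal numpy sketch on synthetic counts:

```python
import numpy as np

rng = np.random.default_rng(1)
expr = rng.exponential(scale=1.0, size=1000)  # stand-in for one gene's expression
expr[:5] = 100.0                              # a few extreme outlier cells

# Clip at the 99th percentile before mapping to colors
upper = np.percentile(expr, 99)
clipped = np.clip(expr, 0, upper)
print(clipped.max() <= upper)  # True
```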