Mouse Hypothalamus MERFISH 3D Reconstruction#
This tutorial presents an end-to-end DeepSpatial workflow for reconstructing a 3D volume from mouse hypothalamus MERFISH slices.
In this tutorial, you will:
Import dependencies.
Load and organize slice-level MERFISH data.
Set up inputs and train DeepSpatial.
Reconstruct the 3D tissue volume.
Visualize and inspect reconstruction results.
1. Import Dependencies#
Import DeepSpatial, utility libraries, and visualization functions used throughout this tutorial.
The tutorial below follows a complete from-scratch workflow.
import os
import re
import glob
import torch
import anndata as ad
from deepspatial import DeepSpatial
from deepspatial.vis_utils import interactive_3d_labels, interactive_3d_expression
2. Data Preparation#
Please download the mouse hypothalamus MERFISH dataset first.
This section scans all .h5ad files, sorts them by slice index, and converts the index embedded in each file name into a physical Z coordinate, so the model can learn the real inter-slice spacing.
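As a concrete illustration of this mapping, here is the index-to-Z conversion on a single hypothetical file name (the file name and `slice_gap` value are only examples):

```python
import re

slice_gap = 50.0  # physical distance between adjacent slices
fname = "merfish_3_paste.h5ad"  # hypothetical example file name

# Extract the integer slice index and convert it to a physical Z coordinate.
idx = int(re.search(r"merfish_(\d+)", fname).group(1))
z = idx * slice_gap
print(idx, z)  # → 3 150.0
```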
# 2. Data preparation and path parsing
# Set this to your local dataset directory before running
data_dir = "/data/merfish_mouse_hypothalamus"
file_paths = sorted(
    glob.glob(os.path.join(data_dir, "merfish_*_paste.h5ad")),
    key=lambda x: int(re.search(r'merfish_(\d+)', x).group(1)),
)
# Physical distance between adjacent slices (same unit as the dataset)
slice_gap = 50.0
# Load slices and inject physical Z coordinates
adata_list = []
for p in file_paths:
    adata = ad.read_h5ad(p)
    # Parse slice index from the file name and write the physical height into adata.obs
    idx = int(re.search(r'merfish_(\d+)', p).group(1))
    adata.obs['z_coord'] = float(idx * slice_gap)
    adata_list.append(adata)
print(f"Loaded {len(adata_list)} slices with physical Z-coordinates in adata.obs['z_coord']")
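Note that the numeric sort key in the glob above is not cosmetic: plain lexicographic sorting misorders slice indices of ten and above, which would scramble the Z axis. A minimal sketch with hypothetical file names:

```python
import re

# Hypothetical file names covering slice indices 1, 2, and 10
names = ["merfish_10_paste.h5ad", "merfish_2_paste.h5ad", "merfish_1_paste.h5ad"]

# Plain string sort compares character by character, so "10" lands before "2"
print(sorted(names))

# A numeric key recovers the true slice order: 1, 2, 10
print(sorted(names, key=lambda x: int(re.search(r"merfish_(\d+)", x).group(1))))
```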
Check Loaded Data#
Inspect adata_list to verify that all slices were loaded correctly.
adata_list
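Beyond eyeballing the list, it is worth checking that the injected Z coordinates are strictly increasing with uniform spacing. A minimal sketch (the helper below is not part of DeepSpatial; the commented line shows how it would be applied to the real `adata_list`):

```python
def check_z_spacing(z_values, expected_gap, tol=1e-6):
    """Verify that Z coordinates increase strictly and with uniform spacing."""
    gaps = [b - a for a, b in zip(z_values, z_values[1:])]
    assert all(g > 0 for g in gaps), "Z coordinates must be strictly increasing"
    assert all(abs(g - expected_gap) < tol for g in gaps), "Non-uniform slice spacing"
    return gaps

# With the loaded slices this would be:
# check_z_spacing([float(a.obs['z_coord'].iloc[0]) for a in adata_list], slice_gap)
print(check_z_spacing([50.0, 100.0, 150.0, 200.0], 50.0))  # → [50.0, 50.0, 50.0]
```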
3. Setup and Train#
This section configures the model inputs, builds the architecture, and runs training.
model = DeepSpatial()
# 3. Data setup: normalization is handled internally without overwriting raw coordinates
candidate_label_keys = ['cell_type', 'cell_class', 'annotation', 'Harmony_labels']
label_key = next((k for k in candidate_label_keys if k in adata_list[0].obs.columns), None)
if label_key is None:
    raise ValueError(f'No valid label key found. Tried: {candidate_label_keys}')
print('Using label key:', label_key)
model.setup_data(
    adata_list=adata_list,
    spatial_key='spatial',
    z_key='z_coord',
    label_key=label_key,
    batch_size=2048,
)
Build the Model#
Define architecture and optimizer-related hyperparameters, then train the model. You can adjust settings based on data size, GPU memory, and training goals.
# Build model: configure architecture hyperparameters
model.build_model(
    patch_size=8,
    hidden_size=256,
    depth=6,
    lr=2e-4,
)
Train the Model#
Checkpoints are saved to save_dir during training. In real experiments, increase max_epochs based on convergence.
# Train: set checkpoint directory and device options
accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'
devices = [0] if accelerator == 'gpu' else 1
print('Training accelerator:', accelerator)
model.fit(
    max_epochs=10,
    save_dir="./checkpoints/deepspatial_run",
    accelerator=accelerator,
    devices=devices,
    save_ckpt=True,
)
4. Reconstruct#
Generate a full 3D volume from all slices after training.
adata_3d = model.reconstruct_full_volume(
    adata_list,
    thickness=10,
)
Save Reconstruction#
Persist the reconstructed result for future analysis and reproducibility.
os.makedirs("output", exist_ok=True)
adata_3d.write_h5ad("output/deepspatial_3d.h5ad")
print("Saved reconstruction to output/deepspatial_3d.h5ad")
5. Visualization#
Use both interactive and static plots to inspect 3D structure, spatial patterns, and label or gene expression distributions.
adata_3d
3D Label Visualization#
Color by the selected label column to quickly verify whether reconstructed cellular patterns are spatially reasonable.
label_candidates = ['cell_type', 'cell_class', 'annotation', 'Harmony_labels', 'Region']
vis_label = next((k for k in label_candidates if k in adata_3d.obs.columns), None)
if vis_label is None:
    raise ValueError(f'No visualization label column found. Tried: {label_candidates}')
interactive_3d_labels(
    adata_3d,
    color_col=vis_label,
    title=f'DeepSpatial MERFISH Reconstruction ({vis_label})',
    width=1000,
    height=1000,
)
3D Gene Expression Visualization#
Map a representative marker gene into 3D space to identify local enrichment patterns.
preferred_markers = ['Ucn3', 'PanCK', 'CD45', 'Ki67']
gene_name = next((g for g in preferred_markers if g in adata_3d.var_names), adata_3d.var_names[0])
interactive_3d_expression(
    adata_3d,
    gene_name=gene_name,
    title=f'MERFISH Reconstruction Expression: {gene_name}',
    width=1000,
    height=1000,
)
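Interactive 3D viewers can become sluggish once the reconstruction holds hundreds of thousands of cells. A common workaround is to plot a random subset; the helper below is only a sketch (it is not part of the DeepSpatial API), and the commented usage shows how it could feed the viewers above:

```python
import numpy as np

def subsample_indices(n_obs, n_cells, seed=0):
    """Pick random cell indices so interactive plots stay responsive."""
    rng = np.random.default_rng(seed)
    n = min(n_cells, n_obs)  # never request more cells than exist
    return rng.choice(n_obs, size=n, replace=False)

# Hypothetical usage with the reconstruction above:
# idx = subsample_indices(adata_3d.n_obs, 50_000)
# interactive_3d_labels(adata_3d[idx], color_col=vis_label)
idx = subsample_indices(200_000, 50_000)
print(len(idx))  # → 50000
```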