Mouse Brain Deep STARmap 3D Validation#
This tutorial presents an end-to-end DeepSpatial workflow for reconstructing a 3D volume from Deep STARmap slices, followed by validation against true 3D reference data.
In this tutorial, you will:
Import dependencies.
Prepare slice-level training data from true 3D coordinates.
Set up and train DeepSpatial.
Reconstruct the 3D tissue volume.
Validate and visualize reconstruction results.
1. Import Dependencies#
Import DeepSpatial, data utilities, and visualization functions used in this tutorial.
import os
import numpy as np
import pandas as pd
import anndata as ad
from deepspatial import DeepSpatial
from deepspatial.vis_utils import (
interactive_3d_labels,
interactive_3d_expression,
plot_z_distribution,
plot_orthogonal_projections
)
2. Data Preparation#
Please download the mouse brain Deep STARmap dataset first. In this workflow, the 2D training slices are extracted from the true 3D coordinates using segmented Z ranges, so the true volume can later serve as a validation reference.
# Set this to your local true-3D reference file
gt_path = '/data/yuhangyang/DeepSpatial/data/deepstarmap_mouse_brain.h5ad'
# Load the full 3D reference AnnData; 3D coordinates are read later from
# .obsm['spatial'] (columns X, Y, Z)
adata_gt = ad.read_h5ad(gt_path)
# Notebook-style inspection: display the AnnData summary
adata_gt
# Extract 2D slice inputs from true 3D data using segmented Z ranges.
# Each (z_start, z_end) pair below defines one 10-unit-thick slab; slabs are
# spaced 30 units apart, so the model must learn to interpolate between them.
z_coords = adata_gt.obsm['spatial'][:, 2]
ranges = [
    (5, 15),
    (35, 45),
    (65, 75),
    (95, 105),
    (125, 135),
    (155, 165),
    (185, 195)
]
adata_list = []
# NOTE: the original loop used `enumerate(ranges, start=1)` but never used
# the index; iterate the ranges directly instead.
for z_start, z_end in ranges:
    # Cells whose true Z falls inside this slab form one training "slice"
    mask = (z_coords >= z_start) & (z_coords <= z_end)
    sub_adata = adata_gt[mask].copy()
    # Use XY as 2D coordinates for model input
    sub_adata.obsm['spatial'] = sub_adata.obsm['spatial'][:, :2]
    # Assign the slab midpoint as the slice's scalar Z position
    midpoint = float((z_start + z_end) / 2)
    sub_adata.obs['z_coord'] = midpoint
    adata_list.append(sub_adata)
Check Loaded Data#
Inspect extracted slices before training.
adata_list
3. Setup and Train#
Configure model inputs, build the architecture, and train DeepSpatial.
# Instantiate the DeepSpatial model wrapper
model = DeepSpatial()
# Register training data: a list of 2D slices, each carrying a scalar Z
model.setup_data(
    adata_list=adata_list,
    spatial_key='spatial',  # key in .obsm holding the 2D (XY) coordinates
    z_key='z_coord',  # per-cell scalar Z in .obs (slab midpoint set above)
    label_key='Harmony_labels',  # cell-type labels column in .obs
    batch_size=512
)
# Build the network; hyperparameters follow the tutorial defaults —
# see the DeepSpatial documentation for their exact semantics
model.build_model(
    patch_size=8,
    hidden_size=256,
    depth=6,
    lr=2e-4
)
# Train and checkpoint the model.
# NOTE(review): devices=[5] pins training to GPU index 5 — adjust to match
# the GPUs available on your machine.
model.fit(
    max_epochs=10,
    save_dir='./checkpoints/deepspatial_run',
    accelerator='gpu',
    devices=[5],
    save_ckpt=True
)
4. Reconstruct#
Generate a full 3D volume from all training slices.
# Reconstruct a full 3D volume from all training slices.
# `thickness` presumably controls the extruded depth per slice — confirm
# against the DeepSpatial API docs.
adata_3d = model.reconstruct_full_volume(
    adata_list,
    thickness=10
)
# Inspect the reconstructed AnnData
adata_3d
# Persist the reconstruction for downstream analysis
os.makedirs('output', exist_ok=True)
adata_3d.write_h5ad('output/deepspatial_3d_starmap_brain.h5ad')
5. Validation and Visualization#
Compare reconstructed results with true 3D reference data using the same DeepSpatial visualization functions.
# Build comparable 3D coordinate matrix for reconstruction:
# stack the model-input XY with the predicted per-cell Z into (n_cells, 3)
xy_coords = adata_3d.obsm['spatial']
z_col = adata_3d.obs['z_coord'].to_numpy().reshape(-1, 1)
adata_3d.obsm['spatial_3d_aligned'] = np.column_stack((xy_coords, z_col))
# Restrict GT to reconstructed z-range for fair comparison.
# 10..190 is the span of the slab midpoints used for training (10, 40, ..., 190).
z_gt = adata_gt.obsm['spatial'][:, 2]
mask_gt = (z_gt >= 10) & (z_gt <= 190)
adata_gt_sub = adata_gt[mask_gt].copy()
# Build two comparable AnnData views for visualization functions
adata_gt_vis = adata_gt_sub.copy()
# IMPORTANT: copy Z into .obs BEFORE the next line truncates .obsm['spatial']
# to its first two columns — swapping these statements would lose Z.
adata_gt_vis.obs['z_coord'] = adata_gt_vis.obsm['spatial'][:, 2].astype(float)
adata_gt_vis.obsm['spatial'] = adata_gt_vis.obsm['spatial'][:, :2].copy()
adata_recon_vis = adata_3d.copy()
# Use a fixed label key and a shared palette across all label-based plots,
# so the same cell type gets the same color in GT and reconstruction figures.
label_col = 'Harmony_labels'
label_missing = (
    label_col not in adata_recon_vis.obs.columns
    or label_col not in adata_gt_vis.obs.columns
)
if label_missing:
    raise KeyError("'Harmony_labels' must exist in both GT and reconstruction .obs")
# Union of label values from both objects, compared as strings
gt_label_set = set(adata_gt_vis.obs[label_col].astype(str).unique())
recon_label_set = set(adata_recon_vis.obs[label_col].astype(str).unique())
sorted_labels = sorted(gt_label_set | recon_label_set)
base_colors = [
    '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
    '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf',
    '#393b79', '#637939', '#8c6d31', '#843c39', '#7b4173',
    '#3182bd', '#31a354', '#756bb1', '#636363', '#e6550d'
]
# Assign colors deterministically, cycling when labels outnumber colors
common_palette = {}
for color_idx, label in enumerate(sorted_labels):
    common_palette[label] = base_colors[color_idx % len(base_colors)]
# Ground truth: interactive 3D scatter colored by cell-type label,
# using the shared palette for side-by-side comparability
interactive_3d_labels(
    adata_gt_vis,
    color_col=label_col,
    palette=common_palette,
    title='Ground Truth: 3D Label Distribution',
    width=1000,
    height=900
)
# Reconstruction: same plot function and palette as the GT figure above
interactive_3d_labels(
    adata_recon_vis,
    color_col=label_col,
    palette=common_palette,
    title='Reconstruction: 3D Label Distribution',
    width=1000,
    height=900
)
# Compare XY/XZ/YZ orthogonal projections with the same DeepSpatial function
plot_orthogonal_projections(adata_gt_vis, color_col=label_col, palette=common_palette)
plot_orthogonal_projections(adata_recon_vis, color_col=label_col, palette=common_palette)
# Compare gene expression with the same DeepSpatial function.
# Prefer 'Reln' when both objects carry it; otherwise fall back to the first
# gene present in BOTH objects. (The original fallback used
# adata_recon_vis.var_names[0], which is not guaranteed to exist in the GT
# object and could make the GT plot fail.)
shared_genes = [g for g in adata_recon_vis.var_names if g in adata_gt_vis.var_names]
if not shared_genes:
    raise ValueError('Reconstruction and ground truth share no genes to plot')
gene_name = 'Reln' if 'Reln' in shared_genes else shared_genes[0]
interactive_3d_expression(
    adata_gt_vis,
    gene_name=gene_name,
    title=f'Ground Truth: Expression {gene_name}',
    width=1000,
    height=900
)
interactive_3d_expression(
    adata_recon_vis,
    gene_name=gene_name,
    title=f'Reconstruction: Expression {gene_name}',
    width=1000,
    height=900
)