Source code for deepspatial.models.git

import math
import torch
import torch.nn as nn
from timm.models.vision_transformer import Attention, Mlp

from .commons import modulate, TimestepEmbedder, FinalLayer, get_1d_sincos_pos_embed

class PatchEmbedder(nn.Module):
    """
    Embeds 1D vectors (e.g., gene expressions) into patch tokens via an MLP.

    Parameters
    ----------
    input_size : int
        The total number of features in the input 1D vector (e.g., `gene_dim`).
    patch_size : int
        The number of features to group into a single patch token.
    hidden_size : int
        The embedding dimension for the output patch tokens.
    """
    def __init__(self, input_size, patch_size, hidden_size):
        super().__init__()
        self.patch_size = patch_size
        self.num_patches = (input_size + patch_size - 1) // patch_size
        self.mlp = nn.Sequential(
            nn.Linear(patch_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
    def forward(self, x):
        B, L = x.shape
        # Zero-pad on the right if the feature length is not divisible by patch_size
        pad_size = (self.patch_size - (L % self.patch_size)) % self.patch_size
        if pad_size > 0:
            x = torch.nn.functional.pad(x, (0, pad_size), "constant", 0)
        # Reshape to [Batch, Num_Patches, Patch_Size] and embed each patch
        x = x.reshape(B, -1, self.patch_size)
        x = self.mlp(x)
        return x
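# A minimal usage sketch for PatchEmbedder (illustrative only; the sizes below
# are assumptions, not values from the library). 2000 features with
# patch_size=64 are zero-padded to 2048 and grouped into 32 patch tokens:
#
#   >>> emb = PatchEmbedder(input_size=2000, patch_size=64, hidden_size=512)
#   >>> g = torch.randn(8, 2000)      # [Batch, gene_dim]
#   >>> emb(g).shape
#   torch.Size([8, 32, 512])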
class GiTBlock(nn.Module):
    """
    Transformer block with adaptive layer normalization (adaLN-Zero) conditioning.

    Parameters
    ----------
    hidden_size : int
        The dimensionality of the input and output token embeddings.
    num_heads : int
        The number of attention heads in the multi-head self-attention mechanism.
    mlp_ratio : float, optional
        The expansion ratio for the hidden dimension inside the MLP, by default 4.0.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )
    def forward(self, x, c):
        # The conditioning vector c yields six per-sample modulation terms:
        # shift/scale/gate for the attention branch and for the MLP branch.
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x
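# A minimal sketch of driving a GiTBlock with a per-sample conditioning
# vector. The shapes are illustrative assumptions; `modulate` comes from
# .commons and, in DiT-style code, typically computes
# `x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)`:
#
#   >>> block = GiTBlock(hidden_size=512, num_heads=8)
#   >>> x = torch.randn(8, 33, 512)   # [Batch, tokens, hidden]
#   >>> c = torch.randn(8, 512)       # global conditioning vector
#   >>> block(x, c).shape
#   torch.Size([8, 33, 512])
#
# Because adaLN-Zero initialization zeroes the gates (see
# GiT.initialize_weights below), a freshly initialized block acts as an
# identity mapping, which stabilizes early training.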
class GiT(nn.Module):
    """
    GiT: Generative Transformer with separated streams for spatial transcriptomics.

    Parameters
    ----------
    gene_dim : int
        The number of genes (features) in the expression matrix.
    patch_size : int
        The size of the patches used to tokenize the gene expression vector.
    hidden_size : int
        The embedding dimension of the transformer backbone.
    depth : int
        The number of GiTBlocks (transformer layers).
    num_heads : int
        The number of attention heads in each block.
    num_classes : int
        The number of unique cell types for categorical conditioning.
    mlp_ratio : float, optional
        The expansion ratio for the MLP inside transformer blocks, by default 4.0.
    """
    def __init__(self, gene_dim, patch_size, hidden_size, depth, num_heads, num_classes, mlp_ratio=4.0):
        super().__init__()
        self.patch_size = patch_size

        # x is a single 2D coordinate token; g is tokenized into gene patches
        self.num_patches_x = 1
        self.num_patches_g = math.ceil(gene_dim / patch_size)
        self.num_patches = self.num_patches_x + self.num_patches_g

        # 1. Input embedders
        self.x_embedder = nn.Linear(2, hidden_size)
        self.g_embedder = PatchEmbedder(gene_dim, patch_size, hidden_size)

        # Condition embedders
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.z_embedder = TimestepEmbedder(hidden_size)
        self.c_embedder = nn.Linear(num_classes, hidden_size)

        # Positional embedding (fixed sin-cos, not trained)
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, hidden_size), requires_grad=False)

        # 2. Transformer backbone
        self.blocks = nn.ModuleList([
            GiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
        ])

        # 3. Output heads
        self.x_head = FinalLayer(hidden_size, 2, 1)           # Spatial velocity
        self.g_head = FinalLayer(hidden_size, patch_size, 1)  # Gene velocity
        self.c_head = nn.Linear(hidden_size, num_classes)     # Cell type logits

        self.initialize_weights()
    def initialize_weights(self):
        """Initializes weights using GiT/DiT standards (adaLN-Zero)."""
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize the fixed sin-cos positional embedding
        pos_embed = get_1d_sincos_pos_embed(self.pos_embed.shape[-1], self.num_patches)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Zero-out modulation and final layers for identity mapping at start
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.x_head.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.x_head.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.x_head.linear.weight, 0)
        nn.init.constant_(self.x_head.linear.bias, 0)
        nn.init.constant_(self.g_head.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.g_head.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.g_head.linear.weight, 0)
        nn.init.constant_(self.g_head.linear.bias, 0)

    def forward(self, xt, gt, t, zt, delta_z, ct):
        """
        Forward pass computing the vector fields (velocities) for the given state.

        Parameters
        ----------
        xt : torch.Tensor
            Current spatial coordinates, shape `(Batch, 2)`.
        gt : torch.Tensor
            Current gene expression, shape `(Batch, Gene_Dim)`.
        t : torch.Tensor
            Current integration timestep, shape `(Batch, 1)`.
        zt : torch.Tensor
            Physical Z-depth coordinates, shape `(Batch, 1)`.
        delta_z : torch.Tensor
            Physical distance between the source and target slices, shape `(Batch, 1)`.
        ct : torch.Tensor
            One-hot encoded cell types, shape `(Batch, Num_Classes)`.

        Returns
        -------
        tuple of torch.Tensor
            Spatial velocity `(Batch, 2)`, gene velocity `(Batch, Gene_Dim)`,
            and cell-type logits `(Batch, Num_Classes)`.
        """
        # Embed modalities separately, then concatenate into one token sequence
        gene_dim = gt.shape[1]
        x_feat = self.x_embedder(xt).unsqueeze(1)  # [B, 1, D]
        g_feat = self.g_embedder(gt)               # [B, num_patches_g, D]
        h = torch.cat([x_feat, g_feat], dim=1) + self.pos_embed

        # Global conditioning: timestep, Z-depth, slice gap, and cell type
        cond = (
            self.t_embedder(t.view(-1))
            + self.z_embedder(zt.view(-1))
            + self.z_embedder(delta_z.view(-1))
            + self.c_embedder(ct)
        )

        for block in self.blocks:
            h = block(h, cond)

        # Project back to modality-specific outputs
        x = self.x_head(h[:, :1, :], cond).squeeze(1)                              # [B, 2]
        g = self.g_head(h[:, 1:, :], cond).reshape(xt.shape[0], -1)[:, :gene_dim]  # [B, Gene_Dim]
        c = self.c_head(h.mean(dim=1))                                             # [B, Classes]
        return x, g, c
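# An end-to-end usage sketch with hypothetical sizes. TimestepEmbedder and
# FinalLayer come from .commons; the output shapes below follow from the
# forward pass above (g is truncated back to gene_dim after unpatching):
#
#   >>> model = GiT(gene_dim=2000, patch_size=64, hidden_size=512,
#   ...             depth=12, num_heads=8, num_classes=10)
#   >>> B = 8
#   >>> xt = torch.randn(B, 2)        # spatial coordinates
#   >>> gt = torch.randn(B, 2000)     # gene expression
#   >>> t = torch.rand(B, 1)          # integration timestep
#   >>> zt = torch.rand(B, 1)         # Z-depth
#   >>> dz = torch.rand(B, 1)         # slice gap
#   >>> ct = torch.nn.functional.one_hot(
#   ...     torch.randint(0, 10, (B,)), num_classes=10).float()
#   >>> vx, vg, logits = model(xt, gt, t, zt, dz, ct)
#   >>> vx.shape, vg.shape, logits.shape
#   (torch.Size([8, 2]), torch.Size([8, 2000]), torch.Size([8, 10]))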