# Source code for diffalign.models.epsnet.diffusion

# ===== Standard library =====
import math
from typing import Optional

# ===== Third-party =====
import torch
from torch import autograd
from torch import nn
import torch.nn.functional as F
from torch_geometric.data import Batch
from uff_torch import UFFTorch, build_uff_inputs, merge_uff_inputs

# ===== Local (project) =====
from ..encoder.egnn import EGNN
from ..encoder.edge import MLPEdgeEncoder
from ..encoder.cross_attention import CrossGraphAligner
from ..common import extend_graph_order_radius


# ---------------- Schedules ----------------

def linear_beta_schedule(
    num_timesteps: int,
    beta_start: float = 1e-4,
    beta_end: float = 0.02,
) -> torch.Tensor:
    """Linear DDPM noise schedule.

    Returns a 1-D float32 tensor of ``num_timesteps`` betas spaced evenly from
    ``beta_start`` to ``beta_end`` (both endpoints included).
    """
    betas = torch.linspace(
        start=beta_start,
        end=beta_end,
        steps=num_timesteps,
        dtype=torch.float32,
    )
    return betas
def cosine_beta_schedule(num_timesteps: int, s: float = 0.008) -> torch.Tensor:
    """Cosine noise schedule of Nichol & Dhariwal (2021).

    Reference: https://arxiv.org/abs/2102.09672

    Evaluates alpha_bar(t) = cos^2(((t / T) + s) / (1 + s) * pi / 2) on the grid
    t = 0..T, normalizes it so alpha_bar(0) == 1, and turns consecutive ratios
    into betas clipped to [0, 0.999].

    Returns:
        float32 tensor of length ``num_timesteps``.
    """
    grid = torch.linspace(0, num_timesteps, num_timesteps + 1, dtype=torch.float32)
    angle = ((grid / num_timesteps) + s) / (1 + s) * math.pi * 0.5
    abar = torch.cos(angle) ** 2
    abar = abar / abar[0]
    ratios = abar[1:] / abar[:-1]
    return torch.clip(1. - ratios, 0, 0.999).float()
# ---------------- Positional / time encoders ----------------
class SinusoidalPosEmb(nn.Module):
    """Sine/cosine timestep embedding (float32).

    For an input ``t`` with at least one dimension, returns a tensor of shape
    ``(*t.shape, dim)``. The first ``dim // 2`` channels are ``sin(t * f_k)``
    and the last ``dim // 2`` channels are ``cos(t * f_k)``, with frequencies
    ``f_k = exp(-k * log(10000) / max(dim // 2 - 1, 1))`` for
    ``k = 0 .. dim // 2 - 1``. A 0-dim ``t`` yields shape ``(1, dim)``.

    Raises:
        ValueError: if ``dim`` is odd.
    """

    def __init__(self, dim: int):
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(f"Embedding dimension ({dim}) must be even.")
        self.dim = dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        device = t.device
        t = t.float()
        half_dim = self.dim // 2
        # With dim == 2 (half_dim == 1) the denominator ``half_dim - 1`` would be
        # zero and raise ZeroDivisionError. Clamping it to 1 gives the single
        # frequency exp(0) = 1 and leaves every dim >= 4 unchanged.
        scale = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=device).float() * -scale)
        pos = t.unsqueeze(-1) * freqs.unsqueeze(0)
        return torch.cat((pos.sin(), pos.cos()), dim=-1).float()
class DDPMTimeEncoder(nn.Module):
    """Encode timesteps as sinusoidal features refined by a two-layer MLP.

    Args:
        embed_dim: output width. The sinusoidal stage uses ``embed_dim`` when it
            is even and ``embed_dim - 1`` otherwise, because ``SinusoidalPosEmb``
            only accepts even widths.
        activation: activation class, instantiated once between the two
            linear layers.
    """

    def __init__(self, embed_dim: int, activation=nn.SiLU):
        super().__init__()
        if embed_dim % 2 == 0:
            sine_embed_dim = embed_dim
        else:
            sine_embed_dim = embed_dim - 1
        self.pos_emb = SinusoidalPosEmb(sine_embed_dim)
        # Same layer order as before (Linear, act, Linear) so state_dict keys
        # ``mlp.0`` / ``mlp.2`` are unchanged.
        layers = [
            nn.Linear(sine_embed_dim, embed_dim),
            activation(),
            nn.Linear(embed_dim, embed_dim),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Return float32 embeddings with a trailing ``embed_dim`` axis."""
        sinusoid = self.pos_emb(t)
        return self.mlp(sinusoid).float()
# ---------------- Geometry helpers ----------------
def get_distance(pos: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
    """Euclidean length of every edge.

    Args:
        pos: node coordinates, indexed along dim 0.
        edge_index: edge list; row 0 holds source node indices, row 1 holds
            target node indices.

    Returns:
        One distance per edge. With no edges, returns an empty tensor that has
        ``pos``'s dtype and device.
    """
    if not edge_index.numel():
        return pos.new_empty((0,))
    src = edge_index[0]
    dst = edge_index[1]
    delta = pos[src] - pos[dst]
    return delta.norm(dim=-1)
# ---------------- Batch merge ----------------
def _pick_device(*objs) -> torch.device:
    """Return the first CUDA device found among ``objs``, else CPU.

    Tensors are checked directly. For ``Batch`` objects the attributes
    ``pos``, ``x``, ``atom_type`` and ``edge_index`` are probed in that order,
    and the first one that is a CUDA tensor decides. Other objects are ignored.
    """
    probe_attrs = ("pos", "x", "atom_type", "edge_index")
    for obj in objs:
        if isinstance(obj, torch.Tensor):
            if obj.is_cuda:
                return obj.device
            continue
        if not isinstance(obj, Batch):
            continue
        for name in probe_attrs:
            candidate = getattr(obj, name, None)
            if isinstance(candidate, torch.Tensor) and candidate.is_cuda:
                return candidate.device
    return torch.device("cpu")
def merge_graphs_in_batch(batch1: Batch, batch2: Batch, device: Optional[torch.device] = None) -> Batch:
    """Interleave two batches graph-by-graph as ``[Q1, R1, Q2, R2, ...]``.

    Every ``Data`` object is moved to ``device`` before collation, so the merged
    batch never mixes CPU and CUDA tensors. When ``device`` is None it is chosen
    by ``_pick_device`` (first CUDA device found, else CPU).

    Two node-level index vectors are attached to the result:
      * ``graph_idx``: position of the node's graph in the interleaved list
        (even = graph from ``batch1``, odd = graph from ``batch2``);
      * ``batch``: pair index ``graph_idx // 2``. It overwrites the per-graph
        ``batch`` vector produced by ``Batch.from_data_list``.

    Args:
        batch1: query graphs.
        batch2: reference graphs; must hold the same number of graphs.
        device: target device, or None to auto-select.

    Returns:
        The merged ``Batch``. For empty inputs, returns an empty ``Batch`` with
        ``num_nodes = 0`` and empty ``graph_idx`` / ``batch`` vectors on
        ``device``.

    Raises:
        ValueError: if the two batches hold a different number of graphs.
    """
    if device is None:
        device = _pick_device(batch1, batch2)

    graphs1 = batch1.to_data_list()
    graphs2 = batch2.to_data_list()
    if len(graphs1) != len(graphs2):
        # zip() would silently drop the surplus graphs and misalign the pairs.
        raise ValueError(
            f"Cannot pair batches with {len(graphs1)} and {len(graphs2)} graphs."
        )

    data_list = []
    for d1, d2 in zip(graphs1, graphs2):
        data_list.append(d1.to(device))
        data_list.append(d2.to(device))

    if not data_list:
        # Attach the attributes downstream code reads (num_nodes, graph_idx,
        # batch) so callers can detect and handle the empty case uniformly.
        empty = Batch()
        empty.num_nodes = 0
        empty.graph_idx = torch.empty(0, dtype=torch.long, device=device)
        empty.batch = torch.empty(0, dtype=torch.long, device=device)
        return empty

    merge_batch = Batch.from_data_list(data_list)  # all parts already on `device`

    counts = torch.tensor([d.num_nodes for d in data_list], dtype=torch.long, device=device)
    graph_ids = torch.arange(len(data_list), dtype=torch.long, device=device)
    merge_batch.graph_idx = torch.repeat_interleave(graph_ids, counts)
    merge_batch.batch = merge_batch.graph_idx // 2  # pair-level index
    return merge_batch
# ---------------- Main (Isotropic DiffAlign) ----------------
class DiffAlign(nn.Module):
    """
    Isotropic Gaussian Diffusion (v-parameterization; T steps)

    - Backbone: EGNN -> CrossGraphAligner -> EGNN. Only query coordinates are
      meant to move; reference coordinates are masked via ``coord_mask``.
    - ``forward`` output: v_t for every node, in merged (Q, R) order.
    - Training loss (per the original notes; not implemented by the methods
      shown here): v MSE + x0 anchor + optional repulsion, configured by the
      ``repulsion_*`` hyperparameters stored in ``__init__``.

    Registered diffusion buffers, each of length ``num_timesteps``:
    ``betas``, ``alphas``, ``alphas_cumprod``, ``alphas_cumprod_prev``,
    ``sqrt_alphas``, ``sqrt_alphas_cumprod``, ``sqrt_one_minus_alphas_cumprod``,
    ``posterior_variance``, ``posterior_mean_coef1``, ``posterior_mean_coef2``.
    """

    def __init__(
        self,
        node_feature_dim: int = 64,
        time_embed_dim: int = 32,
        query_embed_dim: int = 32,
        edge_encoder_dim: int = 64,
        gnn_hidden_dim: int = 128,
        gnn_layers_intra: int = 12,
        gnn_layers_intra_2: int = 4,
        gnn_layers_inter: int = 8,
        max_atom_types: int = 100,
        # Diffusion
        num_timesteps: int = 32,
        beta_start: float = 1e-4,
        beta_end: float = 0.02,
        schedule_type: str = 'cosine',
        # Repulsion
        repulsion_weight: float = 1e-2,
        repulsion_margin: float = 1.2,
        repulsion_exclude_hops: int = 3,
    ):
        """Build diffusion buffers and the encoder/EGNN/aligner stack.

        ``schedule_type`` is ``'linear'`` (uses ``beta_start``/``beta_end``) or
        ``'cosine'``. Any other value raises ValueError.
        """
        super().__init__()

        # ---- Diffusion buffers (isotropic) ----
        self.num_timesteps = int(num_timesteps)
        if schedule_type == 'linear':
            betas = linear_beta_schedule(self.num_timesteps, beta_start, beta_end)
        elif schedule_type == 'cosine':
            betas = cosine_beta_schedule(self.num_timesteps)
        else:
            raise ValueError(f"Unknown beta schedule: {schedule_type}")

        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
        # Shared denominator of the posterior q(x_{t-1} | x_t, x_0) terms; the
        # 1e-12 keeps the division finite if alphas_cumprod ever reaches 1.
        posterior_denom = 1. - alphas_cumprod + 1e-12

        self.register_buffer('betas', betas)
        self.register_buffer('alphas', alphas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
        self.register_buffer('sqrt_alphas', torch.sqrt(alphas))
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        self.register_buffer('posterior_variance', betas * (1. - alphas_cumprod_prev) / posterior_denom)
        self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / posterior_denom)
        self.register_buffer('posterior_mean_coef2', torch.sqrt(alphas) * (1. - alphas_cumprod_prev) / posterior_denom)

        # ---- Encoders ----
        # NOTE: construction order is kept as-is so seeded weight initialization
        # and state_dict keys do not change.
        self.edge_encoder = MLPEdgeEncoder(edge_encoder_dim, "relu")
        self.edge_encoder2 = MLPEdgeEncoder(edge_encoder_dim, "relu")
        self.node_encoder = nn.Sequential(
            nn.Embedding(max_atom_types, node_feature_dim),
            nn.SiLU(),
            nn.Linear(node_feature_dim, node_feature_dim),
        )
        self.time_encoder = DDPMTimeEncoder(time_embed_dim, activation=nn.SiLU)
        self.query_encoder = nn.Sequential(
            nn.Embedding(2, query_embed_dim),  # 0=ref, 1=query
            nn.SiLU(),
            nn.Linear(query_embed_dim, query_embed_dim),
        )

        gnn_in_node_dim = node_feature_dim + time_embed_dim + query_embed_dim
        self.intra_encoder = EGNN(
            in_node_nf=gnn_in_node_dim,
            in_edge_nf=edge_encoder_dim,
            hidden_nf=gnn_hidden_dim,
            n_layers=gnn_layers_intra,
            attention=True,
        )
        self.cross_aligner = CrossGraphAligner(
            dim=gnn_hidden_dim,
            heads=4,
            dropout=0.1,
            coord_update=True,
            num_layers=gnn_layers_inter,
            recompute_each=1,
        )
        self.intra_encoder_2 = EGNN(
            in_node_nf=gnn_hidden_dim,
            in_edge_nf=edge_encoder_dim,
            hidden_nf=gnn_hidden_dim,
            n_layers=gnn_layers_intra_2,
            attention=True,
        )

        # Repulsion hyperparams
        self.repulsion_weight = float(repulsion_weight)
        self.repulsion_margin = float(repulsion_margin)
        self.repulsion_exclude_hops = int(repulsion_exclude_hops)

    # -------------- Forward: predict v_t --------------
    def _encode_edges(self, merged_batch, pos, edge_encoder):
        """Build the extended edge set at coordinates ``pos`` and encode it.

        Calls ``extend_graph_order_radius`` with order=3, cutoff=8 and both
        extensions enabled, batched per graph via ``graph_idx``. Then encodes
        each edge from its length and type with ``edge_encoder``.

        Returns:
            ``(edge_index, edge_attr)``.
        """
        edge_index, edge_type = extend_graph_order_radius(
            num_nodes=merged_batch.atom_type.size(0),
            pos=pos,
            edge_index=merged_batch.edge_index,
            edge_type=merged_batch.edge_type,
            batch=merged_batch.graph_idx,
            order=3,
            cutoff=8,
            extend_order=True,
            extend_radius=True,
        )
        edge_length = get_distance(pos, edge_index).unsqueeze(-1)
        edge_attr = edge_encoder(edge_length=edge_length, edge_type=edge_type)
        return edge_index, edge_attr

    def forward(self, query_batch: Batch, reference_batch: Batch, t: torch.Tensor, condition: bool = True) -> torch.Tensor:
        """Predict v_t for merged (query, reference) batches.

        Args:
            query_batch: query graphs (their coordinates are updated).
            reference_batch: reference graphs, paired graph-by-graph with the
                queries.
            t: timesteps in ``0..T-1``, one per query/reference pair (shape
                ``[G]``). A 0-dim tensor or Python int is broadcast to every
                node. ``t`` is moved to the merged batch's device.
            condition: when False the cross-graph aligner is skipped; the
                sampler uses this as the unconditional branch for
                classifier-free guidance.

        Returns:
            ``x_out - x_in`` in merged ``[Q1, R1, Q2, R2, ...]`` node order,
            one 3-vector per node; a ``(0, 3)`` tensor for empty input.
            Reference rows should be ~0 if the EGNN/aligner layers honor
            ``coord_mask`` (NOTE(review): not verifiable from this file).
        """
        merged_batch = merge_graphs_in_batch(query_batch, reference_batch)
        # An empty merge may report num_nodes as None rather than 0.
        if not merged_batch.num_nodes:
            pos = getattr(query_batch, 'pos', None)
            out_device = pos.device if isinstance(pos, torch.Tensor) else torch.device('cpu')
            return torch.zeros((0, 3), device=out_device)

        x_in = merged_batch.pos
        num_nodes = merged_batch.atom_type.size(0)

        # (A) Query mask: even graph_idx = query graph, odd = reference graph.
        is_query = (merged_batch.graph_idx % 2) == 0
        coord_mask = is_query.float().unsqueeze(-1)  # only query coordinates move

        # (B) Node embeddings: atom type || timestep || query/reference flag.
        node_feat = self.node_encoder(merged_batch.atom_type)
        t = torch.as_tensor(t, device=x_in.device)
        if t.dim() == 0:
            t_nodes = t.expand(num_nodes)
        else:
            t_nodes = t[merged_batch.batch]  # expand per-pair timestep to nodes
        time_emb = self.time_encoder(t_nodes)
        query_emb = self.query_encoder(is_query.long())  # 1=query, 0=ref
        h = torch.cat([node_feat, time_emb, query_emb], dim=-1)

        # (C) Intra-graph (stage 1): edges from the input coordinates.
        edge_index, e = self._encode_edges(merged_batch, x_in, self.edge_encoder)
        h, x = self.intra_encoder(
            h=h, x=x_in, edges=edge_index, edge_attr=e, coord_mask=coord_mask,
        )

        # (D) Cross-graph alignment (query moves); skipped for the unconditional branch.
        if condition:
            h, x = self.cross_aligner(h, x, batch=merged_batch.batch, graph_idx=merged_batch.graph_idx)

        # (E) Intra-graph (stage 2): edges rebuilt from the updated coordinates.
        edge_index2, e2 = self._encode_edges(merged_batch, x, self.edge_encoder2)
        h, x = self.intra_encoder_2(
            h=h, x=x, edges=edge_index2, edge_attr=e2, coord_mask=coord_mask,
        )

        # (F) Output: v = x - x_in
        return x - x_in

    # -------------- Samplers --------------
    @staticmethod
    def _uff_refine(uff_model, coords_q0, pocket_coords_fixed, *, inner_steps, step_scale, force_clamp, device):
        """Run clamped-force descent on ligand coordinates under the UFF energy.

        The pocket coordinates are concatenated after the ligand and never
        updated. Nonbonded candidates are refreshed once, from the starting
        ligand + pocket coordinates.

        Args:
            uff_model: UFFTorch model over the merged ligand+pocket inputs.
            coords_q0: starting ligand coordinates, ``[B, Nq, 3]``.
            pocket_coords_fixed: pocket coordinates, ``[B, Np, 3]``.
            inner_steps: number of descent steps (>= 1).
            step_scale: step size multiplying the clamped forces.
            force_clamp: per-component clamp applied to the forces.
            device: device used if the model returns a non-tensor energy.

        Returns:
            ``(coords_q, forces_q)``: refined ligand coordinates and the forces
            of the last step.
        """
        coords_full = torch.cat([coords_q0, pocket_coords_fixed], dim=1)
        with torch.no_grad():
            uff_model._refresh_nonbond_candidates(coords_full)

        coords_q = coords_q0.clone()
        forces_q = None
        # The sampler runs under no_grad; energy gradients need autograd locally.
        with torch.enable_grad():
            for _ in range(inner_steps):
                coords_q_req = coords_q.clone().requires_grad_(True)
                # Ligand needs grads; pocket stays fixed.
                coords_cat = torch.cat([coords_q_req, pocket_coords_fixed], dim=1)
                energy = uff_model(coords_cat)
                if not isinstance(energy, torch.Tensor):
                    energy = torch.as_tensor(energy, device=device, dtype=coords_cat.dtype)
                grad_q, = autograd.grad(energy.sum(), coords_q_req, create_graph=False)
                forces_q = (-grad_q).detach().clamp_(-force_clamp, force_clamp)
                coords_q = (coords_q + step_scale * forces_q).detach()
        return coords_q, forces_q

    @torch.no_grad()
    def DDPM_Sampling_UFF(
        self,
        query_batch: Batch,
        reference_batch: Batch,
        *,
        clamp: float = 1e-10,               # currently unused; kept for API compatibility
        cfg_scale: float = 1.0,
        # ---- UFF(PyTorch) options ----
        query_mols=None,                    # RDKit Mol list; None disables UFF
        pocket_mols=None,                   # RDKit Mol list; None disables pocket guidance
        uff_guidance_scale: float = 0.0,    # 0 disables UFF
        uff_inner_steps: int = 8,           # inner gradient steps for UFF
        uff_clamp: float = 1.0,             # per-component clamp on forces
        uff_start_ratio: float = 0.0,       # skip UFF if (t/(T-1)) < ratio
        snr_gate_gamma: float = 1.0,        # gate(t) = (1 - σ_t)^γ
        # ---- Temperature option ----
        noise_temperature: float = 0.3,     # posterior noise temperature τ
        # ---- UFF dynamic nonbonded params ----
        uff_vdw_multiplier: float = 10.0,   # dynamic cutoff multiplier
        debug_log: bool = False,
    ):
        """Ancestral DDPM sampling (v-parameterization) with optional UFF steering on x0.

        For t = T-1 .. 0:
          1. Predict v with ``forward``. When ``cfg_scale != 1`` use
             classifier-free guidance: v = v_u + cfg_scale * (v_c - v_u).
          2. Reconstruct x0 = sqrt(abar_t) * x_t - sqrt(1 - abar_t) * v.
          3. Optionally refine the ligand part of x0. The refinement runs
             ``uff_inner_steps`` steps of clamped-force descent on the UFFTorch
             energy of ligand + fixed pocket. The step size is
             ``uff_guidance_scale * gate_t / inner_steps``, with
             gate_t = clip(1 - sqrt(1 - abar_t), 0, 1) ** snr_gate_gamma.
          4. Take the posterior mean c1_t * x0 + c2_t * x_t. For t > 0, add
             Gaussian noise with variance
             noise_temperature**2 * posterior_variance[t] (floored at 1e-20).

        UFF is active only when ``uff_guidance_scale > 0`` and both
        ``query_mols`` and ``pocket_mols`` are given. It is skipped at steps
        where ``t / (T - 1) < uff_start_ratio`` or where gate_t is 0. Nonbonded
        candidates are refreshed from the merged ligand + pocket coordinates.

        Returns:
            ``(x_0, None)`` where ``x_0`` holds one 3-vector per query node.

        Raises:
            ValueError: if the mol lists do not match the number of graphs, or
                if the ligands differ in atom count.
        """
        device = self.betas.device
        qb = query_batch.to(device)
        rb = reference_batch.to(device)
        T = self.num_timesteps
        if not qb.num_nodes:
            return (torch.zeros((0, 3), device=device), None)

        # ---------- UFFTorch setup ----------
        use_uff = (
            uff_guidance_scale > 0.0
            and query_mols is not None
            and pocket_mols is not None
        )
        uff_model = None
        if use_uff:
            B = qb.num_graphs
            if len(query_mols) != B or len(pocket_mols) != B:
                raise ValueError(
                    f"Expected {B} query and pocket mols, got "
                    f"{len(query_mols)} and {len(pocket_mols)}."
                )
            q_inputs_ref = build_uff_inputs(
                query_mols,
                device=device,
                dtype=torch.float32,
                vdw_distance_multiplier=uff_vdw_multiplier,
                ignore_interfragment_interactions=False,
            )
            p_inputs_ref = build_uff_inputs(
                pocket_mols,
                device=device,
                dtype=torch.float32,
                vdw_distance_multiplier=uff_vdw_multiplier,
                ignore_interfragment_interactions=False,
            )
            qp_inputs = merge_uff_inputs(
                q_inputs_ref,
                p_inputs_ref,
                ignore_interfragment_interactions=False,
                vdw_distance_multiplier=float(uff_vdw_multiplier),
            )
            uff_model = UFFTorch(qp_inputs).to(device).eval()
            uff_model._vdw_distance_multiplier = float(uff_vdw_multiplier)

            # Query node indices, graph by graph.
            mol_slices_q = [
                (qb.batch == i).nonzero(as_tuple=True)[0] for i in range(B)
            ]
            gather_idx_q = torch.cat(mol_slices_q, dim=0).to(device)
            Nq = mol_slices_q[0].numel()
            if any(sl.numel() != Nq for sl in mol_slices_q):
                raise ValueError("Ligand atom counts must match across the batch.")
            inner = max(1, int(uff_inner_steps))

            # Pocket coordinates stay fixed for the whole trajectory.
            def _mol_to_coords_tensor(m):
                conf = m.GetConformer()
                positions = (conf.GetAtomPosition(k) for k in range(m.GetNumAtoms()))
                return torch.tensor(
                    [[p.x, p.y, p.z] for p in positions],
                    device=device,
                    dtype=torch.float32,
                )

            pocket_coords_fixed = torch.stack(
                [_mol_to_coords_tensor(m) for m in pocket_mols], dim=0
            )  # [B, Np, 3]

        # ---------- x_T init ----------
        x_t = torch.randn((qb.num_nodes, 3), device=device, dtype=torch.float32)

        # Node counts never change during sampling, so the merged-order query
        # mask is loop-invariant. Compute it once instead of re-merging per step.
        qmask = (merge_graphs_in_batch(qb, rb, device=device).graph_idx % 2) == 0

        # ---------- Reverse diffusion loop ----------
        for t in reversed(range(T)):
            t_graph = torch.full((qb.num_graphs,), t, device=device, dtype=torch.long)

            # 1) Model prediction at the current coordinates (qb is not mutated).
            cur_q = qb.clone()
            cur_q.pos = x_t
            if cfg_scale == 1.0:
                v_hat_merged = self(cur_q, rb, t_graph, condition=True)
            else:
                v_u = self(cur_q, rb, t_graph, condition=False)
                v_c = self(cur_q, rb, t_graph, condition=True)
                v_hat_merged = v_u + cfg_scale * (v_c - v_u)
            v_hat = v_hat_merged[qmask]

            # 2) x0 reconstruction (v-param).
            sqrt_abar_t = self.sqrt_alphas_cumprod[t]
            sig_t = self.sqrt_one_minus_alphas_cumprod[t]
            x0_pred = sqrt_abar_t * x_t - sig_t * v_hat

            # SNR gate in [0, 1]; smaller when sigma_t is larger (noisier steps).
            h_t = max(0.0, min(1.0, 1.0 - float(sig_t.item())))
            gate_t = h_t ** float(snr_gate_gamma)

            # 3) UFF guidance: x0_pred -> x0_star
            apply_uff = (
                use_uff
                and (t / max(1, T - 1)) >= uff_start_ratio
                and gate_t > 0.0
            )
            if apply_uff:
                coords_q0 = x0_pred.index_select(0, gather_idx_q).view(B, Nq, 3).detach()
                step_scale = (uff_guidance_scale * gate_t) / float(inner)
                coords_q, forces_q = self._uff_refine(
                    uff_model,
                    coords_q0,
                    pocket_coords_fixed,
                    inner_steps=inner,
                    step_scale=step_scale,
                    force_clamp=uff_clamp,
                    device=device,
                )
                x0_star = x0_pred.clone()
                x0_star.index_copy_(0, gather_idx_q, coords_q.reshape(B * Nq, 3))
                if debug_log:
                    f_norm = forces_q.norm(dim=-1)
                    print(f"[UFF] t={t:02d} | Force Mean={f_norm.mean():.3f}")
            else:
                x0_star = x0_pred

            # 4) Posterior mean (standard DDPM): mu_t = c1 * x0_star + c2 * x_t
            c1_t = self.posterior_mean_coef1[t]
            c2_t = self.posterior_mean_coef2[t]
            x_mean = c1_t * x0_star + c2_t * x_t

            # 5) Noise addition (temperature-scaled posterior variance).
            if t > 0:
                var_t = self.posterior_variance[t]
                if noise_temperature != 1.0:
                    var_t = (noise_temperature ** 2) * var_t
                var_t = max(float(var_t), 1e-20)
                noise = torch.randn_like(x_t)
                x_t = x_mean + math.sqrt(var_t) * noise
            else:
                x_t = x_mean

        return (x_t, None)