diffalign.models.epsnet package

Submodules

diffalign.models.epsnet.diffusion module

class diffalign.models.epsnet.diffusion.DDPMTimeEncoder(embed_dim: int, activation=torch.nn.SiLU)

Bases: Module

SinusoidalPosEmb + MLP for timestep embeddings.

forward(t: Tensor) → Tensor

Compute the timestep embedding of t.

Note

As with any torch.nn.Module, call the module instance (for example, encoder(t)) rather than forward() directly: calling the instance runs any registered hooks, whereas calling forward() silently skips them.
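The class summary (sinusoidal features followed by an MLP) can be sketched as below. This is a minimal illustration under stated assumptions, not the DiffAlign source: the class names SinusoidalTimeEmbedding and TimeEncoderSketch are hypothetical, and the 10000 frequency base and two-layer MLP shape are common choices assumed here.

```python
import math

import torch
from torch import nn


class SinusoidalTimeEmbedding(nn.Module):
    """Hypothetical stand-in for SinusoidalPosEmb: sin/cos features of t in float32."""

    def __init__(self, dim: int):
        super().__init__()
        assert dim % 2 == 0, "dim must be even (half sine, half cosine)"
        self.dim = dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        half = self.dim // 2
        # Geometric frequency ladder from 1 down towards 1/10000 (assumed convention).
        freqs = torch.exp(
            -math.log(10000.0)
            * torch.arange(half, dtype=torch.float32, device=t.device)
            / half
        )
        args = t.float()[:, None] * freqs[None, :]  # [G, half]
        return torch.cat([args.sin(), args.cos()], dim=-1)  # [G, dim]


class TimeEncoderSketch(nn.Module):
    """Sinusoidal features followed by a two-layer MLP (assumed layout)."""

    def __init__(self, embed_dim: int, activation=nn.SiLU):
        super().__init__()
        self.net = nn.Sequential(
            SinusoidalTimeEmbedding(embed_dim),
            nn.Linear(embed_dim, embed_dim),
            activation(),
            nn.Linear(embed_dim, embed_dim),
        )

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        return self.net(t)


encoder = TimeEncoderSketch(embed_dim=32)
emb = encoder(torch.tensor([0, 5, 31]))  # call the instance, not .forward()
print(emb.shape)  # torch.Size([3, 32])
```

The MLP lets the network learn a nonlinear reparameterization of the fixed sinusoidal features before they are injected into the GNN.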

class diffalign.models.epsnet.diffusion.DiffAlign(node_feature_dim: int = 64, time_embed_dim: int = 32, query_embed_dim: int = 32, edge_encoder_dim: int = 64, gnn_hidden_dim: int = 128, gnn_layers_intra: int = 12, gnn_layers_intra_2: int = 4, gnn_layers_inter: int = 8, max_atom_types: int = 100, num_timesteps: int = 32, beta_start: float = 0.0001, beta_end: float = 0.02, schedule_type: str = 'cosine', repulsion_weight: float = 0.01, repulsion_margin: float = 1.2, repulsion_exclude_hops: int = 3)

Bases: Module

Isotropic Gaussian diffusion with v-parameterization over T steps.

  • Backbone: EGNN + CrossGraphAligner (only query coordinates move)

  • Output: v_t in merged (Q, R) order

  • Loss: v MSE + x0 anchor + optional repulsion
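A minimal sketch of the v-parameterization relations (Salimans & Ho, 2022) and of a "v MSE + x0 anchor" loss follows. It assumes the x0 anchor is an MSE on the clean coordinates recovered from the predicted v, and that both terms are restricted to query nodes because only query coordinates move; the helper names and x0_weight are hypothetical, and the repulsion term is omitted.

```python
import torch


def q_sample(x0: torch.Tensor, eps: torch.Tensor, alpha_bar: torch.Tensor) -> torch.Tensor:
    """Forward diffusion: x_t = sqrt(a_bar) * x0 + sqrt(1 - a_bar) * eps."""
    return alpha_bar.sqrt() * x0 + (1 - alpha_bar).sqrt() * eps


def v_target(x0, eps, alpha_bar):
    """v-parameterization target: v = sqrt(a_bar) * eps - sqrt(1 - a_bar) * x0."""
    return alpha_bar.sqrt() * eps - (1 - alpha_bar).sqrt() * x0


def x0_from_v(x_t, v, alpha_bar):
    """Invert the v-parameterization: x0 = sqrt(a_bar) * x_t - sqrt(1 - a_bar) * v."""
    return alpha_bar.sqrt() * x_t - (1 - alpha_bar).sqrt() * v


def v_loss_with_x0_anchor(v_pred, x_t, x0, eps, alpha_bar, query_mask, x0_weight=1.0):
    """v-MSE plus an x0 'anchor' MSE, both restricted to query (moving) nodes.

    The anchor form and x0_weight are hypothetical; the repulsion term is omitted.
    """
    v_true = v_target(x0, eps, alpha_bar)
    x0_pred = x0_from_v(x_t, v_pred, alpha_bar)
    v_mse = ((v_pred - v_true) ** 2)[query_mask].mean()
    x0_mse = ((x0_pred - x0) ** 2)[query_mask].mean()
    return v_mse + x0_weight * x0_mse


torch.manual_seed(0)
x0 = torch.randn(5, 3)                # 5 nodes with 3-D coordinates
eps = torch.randn(5, 3)
alpha_bar = torch.full((5, 1), 0.7)   # per-node alpha_bar, broadcast over xyz
query_mask = torch.tensor([True, True, True, False, False])  # first 3 nodes are query atoms
x_t = q_sample(x0, eps, alpha_bar)
v = v_target(x0, eps, alpha_bar)
loss = v_loss_with_x0_anchor(v, x_t, x0, eps, alpha_bar, query_mask)  # near 0 for an exact v
```

Because x0 is a linear function of x_t and v, the anchor term adds supervision directly in coordinate space at no extra network cost.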

DDPM_Sampling_UFF(query_batch: Batch, reference_batch: Batch, *, clamp: float = 1e-10, cfg_scale: float = 1.0, query_mols=None, pocket_mols=None, uff_guidance_scale: float = 0.0, uff_inner_steps: int = 8, uff_clamp: float = 1.0, uff_start_ratio: float = 0.0, snr_gate_gamma: float = 1.0, noise_temperature: float = 0.3, uff_vdw_multiplier: float = 10.0, debug_log: bool = False)

Standard DDPM sampling (v-parameterization) with UFF force-field steering of the x0 estimate, computed with UFFTorch. Note: _refresh_nonbond_candidates is called on the merged ligand + pocket coordinates.
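One reverse step of this kind of sampler can be sketched as below: recover x0 from the v prediction, optionally adjust it, then draw x_{t-1} from the DDPM posterior q(x_{t-1} | x_t, x0) of Ho et al. (2020). This is a sketch under assumptions, not the DDPM_Sampling_UFF implementation: cfg_combine uses the usual classifier-free guidance formula, noise_temperature is assumed to scale the injected noise, and steer_x0 with toy_energy is a hypothetical stand-in for UFFTorch steering (SNR gating and uff_start_ratio are omitted).

```python
import torch


def cfg_combine(v_cond, v_uncond, cfg_scale: float):
    """Classifier-free guidance on v; cfg_scale = 1.0 returns the conditional prediction."""
    return v_uncond + cfg_scale * (v_cond - v_uncond)


def steer_x0(x0_hat, energy_fn, scale: float, inner_steps: int, clamp: float):
    """Hypothetical x0 steering: gradient-descent steps on an energy, clamped per coordinate."""
    x = x0_hat.detach()
    with torch.enable_grad():
        for _ in range(inner_steps):
            x.requires_grad_(True)
            (grad,) = torch.autograd.grad(energy_fn(x), x)
            x = (x - scale * grad.clamp(-clamp, clamp)).detach()
    return x


def ddpm_step_v(x_t, v_pred, t: int, betas, noise_temperature=1.0, x0_hook=None, generator=None):
    """One ancestral DDPM step x_t -> x_{t-1} driven by a v prediction."""
    alphas = 1.0 - betas
    alpha_bars = torch.cumprod(alphas, dim=0)
    a_bar_t = alpha_bars[t]
    a_bar_prev = alpha_bars[t - 1] if t > 0 else torch.tensor(1.0, dtype=betas.dtype)
    # Clean-coordinate estimate implied by v.
    x0_hat = a_bar_t.sqrt() * x_t - (1 - a_bar_t).sqrt() * v_pred
    if x0_hook is not None:
        x0_hat = x0_hook(x0_hat)
    # Posterior mean coefficients of q(x_{t-1} | x_t, x0).
    coef_x0 = betas[t] * a_bar_prev.sqrt() / (1 - a_bar_t)
    coef_xt = (1 - a_bar_prev) * alphas[t].sqrt() / (1 - a_bar_t)
    mean = coef_x0 * x0_hat + coef_xt * x_t
    if t == 0:
        return mean
    var = betas[t] * (1 - a_bar_prev) / (1 - a_bar_t)
    noise = torch.randn(x_t.shape, generator=generator, dtype=x_t.dtype, device=x_t.device)
    # Assumed meaning: noise_temperature < 1 shrinks the injected noise.
    return mean + noise_temperature * var.sqrt() * noise


def toy_energy(x):
    # Stand-in for a UFF energy: penalizes distance from the origin.
    return (x ** 2).sum()


def hook(x0):
    return steer_x0(x0, toy_energy, scale=0.1, inner_steps=8, clamp=1.0)


torch.manual_seed(0)
betas = torch.linspace(1e-4, 0.02, 32)
x_t = torch.randn(4, 3)
v_cond, v_uncond = torch.randn(4, 3), torch.randn(4, 3)
v = cfg_combine(v_cond, v_uncond, cfg_scale=1.0)
x_prev = ddpm_step_v(x_t, v, t=31, betas=betas, noise_temperature=0.3, x0_hook=hook)
```

Steering the x0 estimate, rather than x_t directly, keeps the force-field correction in clean-coordinate space, where a physical energy is meaningful.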

forward(query_batch: Batch, reference_batch: Batch, t: Tensor, condition: bool = True) → Tensor

Predict v_t for the merged (query, reference) batches.

  • query_batch, reference_batch: torch_geometric.data.Batch

  • t: [G] timesteps (0..T-1), one per graph
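The t argument holds one timestep per graph. A common way to consume such a tensor inside a graph network is to index it with the node-to-graph vector that a torch_geometric Batch exposes as `batch`; the sketch below assumes that pattern, and the helper names are hypothetical rather than part of DiffAlign.

```python
import torch


def sample_timesteps(num_graphs: int, num_timesteps: int, generator=None) -> torch.Tensor:
    """Draw one integer timestep in [0, num_timesteps - 1] per graph."""
    return torch.randint(0, num_timesteps, (num_graphs,), generator=generator)


def per_node_timesteps(t: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    """Broadcast per-graph timesteps [G] to nodes [N] with a node-to-graph index."""
    return t[batch]


# Two graphs with 3 and 2 nodes; `batch` mimics the node-to-graph vector of a PyG Batch.
batch = torch.repeat_interleave(torch.arange(2), torch.tensor([3, 2]))
t = torch.tensor([5, 17])
print(per_node_timesteps(t, batch).tolist())  # [5, 5, 5, 17, 17]
```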

class diffalign.models.epsnet.diffusion.SinusoidalPosEmb(dim: int)

Bases: Module

Sine/cosine timestep embedding (float32).

forward(t: Tensor) → Tensor

Compute the sine/cosine embedding of t.

Note

As with any torch.nn.Module, call the module instance (for example, emb(t)) rather than forward() directly: calling the instance runs any registered hooks, whereas calling forward() silently skips them.
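For reference, the widely used sine/cosine convention (as in Transformer positional encodings and many DDPM codebases) is written out below. The frequency base 10000, placing all sines before all cosines, and d = dim are assumptions, since the docstring does not fix them.

```latex
\begin{aligned}
\omega_i &= 10000^{-i/h}, \qquad i = 0, \dots, h-1, \qquad h = d/2, \\
\mathrm{emb}(t) &= \big[\sin(t\,\omega_0), \dots, \sin(t\,\omega_{h-1}),\;
                   \cos(t\,\omega_0), \dots, \cos(t\,\omega_{h-1})\big] \in \mathbb{R}^{d}.
\end{aligned}
```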

diffalign.models.epsnet.diffusion.cosine_beta_schedule(num_timesteps: int, s: float = 0.008) → Tensor

Cosine noise schedule from Nichol & Dhariwal (2021), https://arxiv.org/abs/2102.09672. Returns betas of length T = num_timesteps (float32).
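A minimal sketch of the cosine schedule follows, computing alpha_bar(t) = f(t) / f(0) with f(t) = cos²(((t/T + s) / (1 + s)) · π/2) and converting consecutive ratios to betas. The 0.999 clip mirrors the reference implementation of Nichol & Dhariwal; whether DiffAlign clips at the same value is an assumption. The linear variant is the usual evenly spaced reading of linear_beta_schedule, also assumed. Function names are hypothetical.

```python
import math

import torch


def cosine_betas_sketch(num_timesteps: int, s: float = 0.008, max_beta: float = 0.999) -> torch.Tensor:
    """Cosine schedule (Nichol & Dhariwal, 2021): alpha_bar(t) = f(t) / f(0)."""
    steps = torch.arange(num_timesteps + 1, dtype=torch.float64)
    f = torch.cos((steps / num_timesteps + s) / (1 + s) * math.pi / 2) ** 2
    alpha_bar = f / f[0]
    # beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}; clipping keeps the last step from reaching 1.
    betas = 1 - alpha_bar[1:] / alpha_bar[:-1]
    return betas.clamp(max=max_beta).float()


def linear_betas_sketch(num_timesteps: int, beta_start: float = 1e-4, beta_end: float = 0.02) -> torch.Tensor:
    """Evenly spaced betas from beta_start to beta_end."""
    return torch.linspace(beta_start, beta_end, num_timesteps, dtype=torch.float32)


betas = cosine_betas_sketch(32)
print(betas.shape, betas.dtype)  # torch.Size([32]) torch.float32
```

Computing in float64 and casting at the end avoids precision loss in the ratio near t = T, where alpha_bar approaches zero.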

diffalign.models.epsnet.diffusion.get_distance(pos: Tensor, edge_index: Tensor) → Tensor

diffalign.models.epsnet.diffusion.linear_beta_schedule(num_timesteps: int, beta_start: float = 0.0001, beta_end: float = 0.02) → Tensor

diffalign.models.epsnet.diffusion.merge_graphs_in_batch(batch1: Batch, batch2: Batch, device: device | None = None) → Batch

Merge two batches into one, interleaved as [Q1, R1, Q2, R2, …], and attach graph_idx (even = query, odd = reference) and a pair-level batch vector. All Data objects are moved to device beforehand to avoid mixing CPU and CUDA tensors.
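The interleaving and the index bookkeeping can be sketched with plain lists and tensors. With torch_geometric, the graph lists would plausibly come from Batch.to_data_list() and be recombined with Batch.from_data_list(); the helper names here are hypothetical, and storing graph_idx per node is an assumption.

```python
import torch


def interleave_pairs(queries: list, references: list) -> list:
    """Order graphs as [Q1, R1, Q2, R2, ...]; queries and references pair up 1:1."""
    if len(queries) != len(references):
        raise ValueError("query and reference batches must contain the same number of graphs")
    merged = []
    for q, r in zip(queries, references):
        merged.extend([q, r])
    return merged


def merged_indices(query_sizes: torch.Tensor, ref_sizes: torch.Tensor):
    """Per-node graph_idx (even = query, odd = reference) and pair-level batch index."""
    sizes = torch.stack([query_sizes, ref_sizes], dim=1).reshape(-1)  # Q1, R1, Q2, R2, ...
    graph_idx = torch.repeat_interleave(torch.arange(sizes.numel()), sizes)
    pair_batch = torch.div(graph_idx, 2, rounding_mode="floor")  # both pair members share an index
    is_query = graph_idx % 2 == 0
    return graph_idx, pair_batch, is_query


# Two pairs: queries with 3 and 2 atoms, references with 4 and 1 atoms.
graph_idx, pair_batch, is_query = merged_indices(torch.tensor([3, 2]), torch.tensor([4, 1]))
print(graph_idx.tolist())   # [0, 0, 0, 1, 1, 1, 1, 2, 2, 3]
print(pair_batch.tolist())  # [0, 0, 0, 0, 0, 0, 0, 1, 1, 1]
```

The parity of graph_idx gives a cheap query/reference mask, while the pair-level index lets cross-graph operations stay within each query/reference pair.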

Module contents

Entry points for EPSNet-based DiffAlign models.

class diffalign.models.epsnet.DiffAlign

Package-level entry point for diffalign.models.epsnet.diffusion.DiffAlign; its constructor signature, DDPM_Sampling_UFF, and forward are identical to those documented under the diffusion module.