"""diffalign.models.epsnet.diffusion — diffusion (epsnet) model for DiffAlign."""

import math
import torch

from torch import nn
from tqdm.auto import tqdm
from torch_geometric.data import Batch

from ..encoder import EGNN, MLPEdgeEncoder
from ..common import extend_graph_order_radius, extend_to_cross_attention


def get_distance(pos, edge_index):
    """Compute the Euclidean length of every edge.

    Args:
        pos (torch.Tensor): Node coordinates of shape ``(num_nodes, D)``.
        edge_index (torch.Tensor): Edge endpoints of shape ``(2, num_edges)``.

    Returns:
        torch.Tensor: Per-edge distances of shape ``(num_edges,)``.
    """
    src, dst = edge_index[0], edge_index[1]
    return (pos[src] - pos[dst]).norm(dim=-1)
def merge_graphs_in_batch(batch1, batch2):
    """Interleave two batches graph-by-graph into one merged batch.

    Graph ``k`` of ``batch1`` is followed by graph ``k`` of ``batch2``, so in the
    merged batch even ``graph_idx`` values belong to ``batch1`` (queries) and odd
    values to ``batch2`` (references).

    Args:
        batch1 (torch_geometric.data.Batch): First batch (e.g. query molecules).
        batch2 (torch_geometric.data.Batch): Second batch (e.g. reference
            molecules); must contain the same number of graphs as ``batch1``.

    Returns:
        torch_geometric.data.Batch: Merged batch where ``graph_idx`` is the
        fine-grained per-graph index and ``batch`` maps each node back to its
        original graph pair (``graph_idx // 2``).
    """
    interleaved = [
        graph
        for pair in zip(batch1.to_data_list(), batch2.to_data_list())
        for graph in pair
    ]
    merge_batch = Batch.from_data_list(interleaved)
    # Fix: the original `torch.tensor(merge_batch.batch)` calls torch.tensor on
    # an existing tensor, which emits a UserWarning; clone().detach() is the
    # documented replacement and produces an identical independent copy.
    merge_batch.graph_idx = merge_batch.batch.clone().detach()
    # Collapse each (query, reference) pair onto a single batch index.
    merge_batch.batch = merge_batch.batch // 2
    return merge_batch
class SinusoidalTimeEmbeddings(nn.Module):
    """Sinusoidal (transformer-style) time-step embedder.

    Args:
        out_dim (int): The output dimension of the embedding (assumed even
            and > 2, since frequencies are spaced over ``out_dim // 2 - 1``).
    """

    def __init__(self, out_dim):
        super().__init__()
        self.out_dim = out_dim

    def forward(self, time):
        """Embed per-node time values.

        Args:
            time (torch.Tensor): A tensor of shape ``(num_nodes, 1)`` holding
                the time value of each node.

        Returns:
            torch.Tensor: Sinusoidal embeddings of shape
            ``(num_nodes, self.out_dim)``.
        """
        half = self.out_dim // 2
        # Log-spaced frequencies from 1 down to 1/10000, as in the standard
        # transformer positional-encoding schedule.
        scale = math.log(10000) / (half - 1)
        freqs = torch.exp(torch.arange(half, device=time.device) * -scale)
        angles = time[:, None] * freqs[None, :]
        emb = torch.cat((angles.sin(), angles.cos()), dim=-1)
        return emb.squeeze(1)
class DiffAlign(nn.Module):
    """E(3)-equivariant diffusion model that denoises query-molecule positions
    conditioned on a reference molecule.

    The model stacks three EGNN stages (bond-graph, cross-attention,
    radius-extended graph) over 128-dim node features built from atom type,
    diffusion time step, and a query/reference flag.
    """

    def __init__(self):
        super().__init__()
        # Edge-feature encoders for the first and last EGNN stages.
        self.edge_encoder = MLPEdgeEncoder(128, "relu")
        self.edge_encoder2 = MLPEdgeEncoder(128, "relu")
        # Node features: atom type (64) + time (32) + query flag (32) = 128.
        self.node_encoder = nn.Sequential(
            nn.Embedding(100, 64),
            nn.SiLU(),
            nn.Linear(64, 64),
        )
        self.time_encoder = nn.Sequential(
            SinusoidalTimeEmbeddings(32),
            nn.Linear(32, 32),
            nn.SiLU(),
            nn.Linear(32, 32),
        )
        self.query_encoder = nn.Sequential(
            nn.Embedding(2, 32),
            nn.SiLU(),
            nn.Linear(32, 32),
        )
        # Stage 1: message passing on the bond-order-extended graph.
        self.encoder = EGNN(
            in_node_nf=128,
            in_edge_nf=128,
            hidden_nf=128,
            device='cpu',
            act_fn=torch.nn.SiLU(),
            n_layers=12,
            attention=True,
            tanh=False,
            norm_constant=0,
        )
        # Stage 3: refinement on the radius-extended graph.
        self.encoder2 = EGNN(
            in_node_nf=128,
            in_edge_nf=128,
            hidden_nf=128,
            device='cpu',
            act_fn=torch.nn.SiLU(),
            n_layers=4,
            attention=True,
            tanh=False,
            norm_constant=0,
        )
        # Stage 2: cross-attention-like message passing between query and
        # reference; its edge feature is the raw edge length (1-dim).
        self.encoder_cross = EGNN(
            in_node_nf=128,
            in_edge_nf=1,
            hidden_nf=128,
            device='cpu',
            act_fn=torch.nn.SiLU(),
            n_layers=8,
            attention=True,
            tanh=False,
            norm_constant=0,
        )
        # Linear DDPM beta schedule over 1000 steps. The expression is kept
        # verbatim: `0.0001*1000/1000` is not guaranteed bit-identical to
        # `0.0001` under float rounding.
        self.betas = nn.Parameter(
            torch.linspace(0.0001 * 1000 / 1000, 0.01 * 1000 / 1000, 1000),
            requires_grad=False,
        )
        self.alphas = nn.Parameter((1. - self.betas), requires_grad=False)
        self.num_timesteps = self.betas.size(0)
[docs] def forward(self, query_batch, reference_batch, time_step, condition=True): """ Args: quey_batch (torch_geometric.data.Batch): A batch for query molecules containg atom_type, edge_index, edge_type, pos, and batch as attributes. reference_batch (torch_geometric.data.Batch): A batch for reference molecules with the same structure as query_batch. time_step (torch.tensor): A time step index vector for each node shaped `(num_batches,)`. condition (bool, optional): A flag indicating whether to apply conditioning. Defaults to `True`. Returns: torch.tensor: Predicted noise of shape `(n_nodes, 3)`. """ merged_batch = merge_graphs_in_batch(query_batch, reference_batch) edge_index, edge_type = extend_graph_order_radius( num_nodes=merged_batch.atom_type.size(0), pos=merged_batch.pos, edge_index=merged_batch.edge_index, edge_type=merged_batch.edge_type, batch=merged_batch.graph_idx, order=3, cutoff=10, extend_order=True, extend_radius=False, ) edge_index_2, edge_type_2 = extend_graph_order_radius( num_nodes=merged_batch.atom_type.size(0), pos=merged_batch.pos, edge_index=merged_batch.edge_index, edge_type=merged_batch.edge_type, batch=merged_batch.graph_idx, order=3, cutoff=10, extend_order=True, extend_radius=True, ) edge_length = get_distance(merged_batch.pos, edge_index).unsqueeze(-1) # (E, 1) query_mask = ((merged_batch.graph_idx%2)==0) # Create cross-attention-like edges to apply conditions. 
if self.training: if (torch.rand(1)<0.9).sum()==1: edge_index_a = extend_to_cross_attention(merged_batch.pos, 200, merged_batch.batch, merged_batch.graph_idx) else: edge_index_a = extend_to_cross_attention(merged_batch.pos, 0, merged_batch.batch, merged_batch.graph_idx) else: if condition: edge_index_a = extend_to_cross_attention(merged_batch.pos, 200, merged_batch.batch, merged_batch.graph_idx) else: edge_index_a = extend_to_cross_attention(merged_batch.pos, 0, merged_batch.batch, merged_batch.graph_idx) h = torch.cat([self.node_encoder(merged_batch.atom_type), self.time_encoder(time_step.index_select(0, merged_batch.batch).view(-1,1)), self.query_encoder(query_mask*1)], dim=1) x = merged_batch.pos e = self.edge_encoder( edge_length=edge_length, edge_type=edge_type ) h, x = self.encoder( h = h, x = x, edges = edge_index, edge_attr=e, coord_mask=query_mask ) edge_length_a = get_distance(x, edge_index_a).unsqueeze(-1) h, x = self.encoder_cross( h = h, x = x, edges = edge_index_a, edge_attr = edge_length_a, coord_mask = query_mask ) edge_length_2 = get_distance(x, edge_index_2).unsqueeze(-1) e_2 = self.edge_encoder2( edge_length=edge_length_2, edge_type=edge_type_2 ) h, x = self.encoder2( h = h, x = x, edges = edge_index_2, edge_attr = e_2, coord_mask=query_mask ) return x[query_mask] - merged_batch.pos[query_mask]
[docs] def get_loss(self, query_batch, reference_batch): """ Args: quey_batch (torch_geometric.data.Batch): A batch for query molecules containg atom_type, edge_index, edge_type, pos, and batch as attributes. reference_batch (torch_geometric.data.Batch): A batch for reference molecules with the same structure as query_batch. Returns: torch.tensor: Calculated loss, a scalar tensor. """ time_step = torch.randint(0, self.num_timesteps, size=(query_batch.num_graphs, ), device=query_batch.pos.device) a = self.alphas.cumprod(dim=0).index_select(0, time_step) a_pos = a.index_select(0, query_batch.batch).unsqueeze(-1) # (N, 1) # Add noise as ddpm manner pos_noise = torch.randn_like(query_batch.pos) query_batch.pos = query_batch.pos * a_pos.sqrt() + pos_noise * (1.0 - a_pos).sqrt() reference_batch.atom_type = reference_batch.atom_type + 35 x = self( query_batch=query_batch, reference_batch=reference_batch, time_step = time_step, ) loss = ((x - pos_noise).square()).mean() return loss
[docs] def DDPM_Sampling(self, query_batch, reference_batch): """ Args: quey_batch (torch_geometric.data.Batch): A batch for query molecules containg atom_type, edge_index, edge_type, pos, and batch as attributes. reference_batch (torch_geometric.data.Batch): A batch for reference molecules with the same structure as query_batch. Returns: torch.tensor: Predicted position of query molecule shaped `(num_query_nodes, 3)`. list[torch.tensor]: Trajectory of query molecule, contatining 999 tensors each shaped `(num_query_nodes, 3)`. """ reference_batch.atom_type = reference_batch.atom_type + 35 pos_traj = [] with torch.no_grad(): seq = range(0, self.num_timesteps) seq_next = [-1] + list(seq[:-1]) for i, j in tqdm(zip(reversed(seq[1:]), reversed(seq_next[1:])), desc='sample'): t = torch.full(size=(query_batch.num_graphs,), fill_value=i, dtype=torch.long, device=query_batch.pos.device) next_t = torch.full(size=(query_batch.num_graphs,), fill_value=j, dtype=torch.long, device=query_batch.pos.device) x = self( query_batch=query_batch, reference_batch=reference_batch, time_step=t, condition=True, ) eps_pos = x at = self.alphas.cumprod(dim=0).index_select(0, t[0]) at_next = self.alphas.cumprod(dim=0).index_select(0, next_t[0]) e = eps_pos # Denoising as DDPM manner x0_pred = (query_batch.pos - e*(1-at).sqrt()) / at.sqrt() sigma_t = ((1-at_next)/(1-at)*(1-(at/at_next))).sqrt() pos_next = at_next.sqrt()*x0_pred + (1-at_next-sigma_t.square()).sqrt()*e + sigma_t*torch.randn_like(query_batch.pos) query_batch.pos = pos_next if torch.isnan(query_batch.pos).any(): print('NaN detected. Please restart.') raise FloatingPointError() pos_traj.append(query_batch.pos.clone().cpu()) return query_batch.pos, pos_traj
[docs] def DDPM_CFG_Sampling(self, query_batch, reference_batch): """ Args: quey_batch (torch_geometric.data.Batch): A batch for query molecules containg atom_type, edge_index, edge_type, pos, and batch as attributes. reference_batch (torch_geometric.data.Batch): A batch for reference molecules with the same structure as query_batch. Returns: torch.tensor: Predicted position of query molecule shaped `(num_query_nodes, 3)`. list[torch.tensor]: Trajectory of query molecule, contatining 999 tensors each shaped `(num_query_nodes, 3)`. """ reference_batch.atom_type = reference_batch.atom_type + 35 pos_traj = [] with torch.no_grad(): seq = range(0, self.num_timesteps) seq_next = [-1] + list(seq[:-1]) for i, j in tqdm(zip(reversed(seq[1:]), reversed(seq_next[1:]))): t = torch.full(size=(query_batch.num_graphs,), fill_value=i, dtype=torch.long, device=query_batch.pos.device) next_t = torch.full(size=(query_batch.num_graphs,), fill_value=j, dtype=torch.long, device=query_batch.pos.device) x_g = self( query_batch=query_batch, reference_batch=reference_batch, time_step = t, condition=True ) x_f = self( query_batch=query_batch, reference_batch=reference_batch, time_step = t, condition=False ) eps_pos = 10*x_g - 9*x_f at = self.alphas.cumprod(dim=0).index_select(0, t[0]) at_next = self.alphas.cumprod(dim=0).index_select(0, next_t[0]) e = eps_pos # Denoising as DDPM manner x0_pred = (query_batch.pos - e*(1-at).sqrt()) / at.sqrt() sigma_t = ((1-at_next)/(1-at)*(1-(at/at_next))).sqrt() pos_next = at_next.sqrt()*x0_pred + (1-at_next-sigma_t.square()).sqrt()*e + sigma_t*torch.randn_like(query_batch.pos) query_batch.pos = pos_next if torch.isnan(query_batch.pos).any(): print('NaN detected. Please restart.') raise FloatingPointError() pos_traj.append(query_batch.pos.clone().cpu()) return query_batch.pos, pos_traj
[docs] def DDIM_CFG_Sampling(self, query_batch, reference_batch, n_steps): """ Args: quey_batch (torch_geometric.data.Batch): A batch for query molecules containg atom_type, edge_index, edge_type, pos, and batch as attributes. reference_batch (torch_geometric.data.Batch): A batch for reference molecules with the same structure as query_batch. n_steps (int): A number of steps for DDIM sampling. Returns: torch.tensor: Predicted position of query molecule shaped `(num_query_nodes, 3)`. list[torch.tensor]: Trajectory of query molecule, contatining 999 tensors each shaped `(num_query_nodes, 3)`. """ reference_batch.atom_type = reference_batch.atom_type + 35 pos_traj = [] with torch.no_grad(): t_max = self.num_timesteps - 1 seq = torch.linspace(0, 1, n_steps) * t_max seq_prev = torch.cat([torch.tensor([-1]), seq[:-1]], dim=0) timesteps = reversed(seq[1:]) timesteps_prev = reversed(seq_prev[1:]) for i, j in tqdm(zip(timesteps, timesteps_prev), desc='sample'): t = torch.full(size=(query_batch.num_graphs,), fill_value=i, dtype=torch.long, device=query_batch.pos.device) next_t = torch.full(size=(query_batch.num_graphs,), fill_value=j, dtype=torch.long, device=query_batch.pos.device) x_g = self( query_batch=query_batch, reference_batch=reference_batch, time_step = t, condition=True ) x_f = self( query_batch=query_batch, reference_batch=reference_batch, time_step = t, condition=False ) eps_pos = 10*x_g - 9*x_f at = self.alphas.cumprod(dim=0).index_select(0, t[0]) at_next = self.alphas.cumprod(dim=0).index_select(0, next_t[0]) e = eps_pos # Denoising as DDIM manner x0_pred = (query_batch.pos - e*(1-at).sqrt()) / at.sqrt() pos_next = at_next.sqrt()*x0_pred + (1-at_next).sqrt()*e query_batch.pos = pos_next if torch.isnan(query_batch.pos).any(): print('NaN detected. Please restart.') raise FloatingPointError() pos_traj.append(query_batch.pos.clone().cpu()) return query_batch.pos, pos_traj