Source code for bindingrmsd.model.model

import torch, dgl
import torch.nn as nn

from dgl.nn.pytorch.glob import SumPooling

from .GatedGCNLSPE import GatedGCNLSPELayer

class PredictionRMSD(nn.Module):
    """Graph network that outputs one scalar (the RMSD head) per complex.

    Three parallel stacks of ``GatedGCNLSPELayer`` are run in lockstep: one on
    the protein graph, one on the ligand graph and one on the complex graph.
    After every protein/ligand layer the node features of both graphs are
    merged into the complex node ordering, passed through the matching complex
    layer, split back, and combined with the initial (normalised) embeddings
    through a residual addition.

    Args:
        in_size: Width of the ligand node features ``gl.ndata['feat']``.
        emb_size: Hidden width used everywhere. The residue and atom token
            embeddings each take ``emb_size // 2`` channels and are
            concatenated, so an odd value yields ``emb_size - 1`` channels and
            does not match ``LayerNorm(emb_size)``; pass an even number.
        intra_edge_size: Width of the ligand edge features
            ``gl.edata['feat']``.
        inter_edge_size: Accepted for interface compatibility but not used;
            the complex edge encoder input width is fixed (see
            ``_DIST_FEAT_SIZE``).
        pose_size: Width of ``ndata['pos_enc']`` on the protein and ligand
            graphs.
        num_layers: Number of layers in each of the three stacks.
        dropout_ratio: Dropout probability inside the output MLP only; the
            graph layers are built with a fixed dropout of 0.2.
    """

    # Vocabulary sizes of the embeddings applied to gp.ndata['token_res'] and
    # gp.ndata['token_atom'] respectively.
    _NUM_RES_TOKENS = 22
    _NUM_ATOM_TOKENS = 175
    # Input width of the encoders applied to gp.edata['dist'] / gc.edata['dist'].
    _DIST_FEAT_SIZE = 15

    def __init__(self, in_size, emb_size, intra_edge_size, inter_edge_size,
                 pose_size, num_layers, dropout_ratio=0.15):
        super().__init__()

        # NOTE: submodules are created in the same order and under the same
        # attribute names as before, so state_dict keys and seeded parameter
        # initialisation are unaffected.
        half_emb = emb_size // 2
        self.res_token_encoder = nn.Embedding(self._NUM_RES_TOKENS, half_emb)
        self.atom_token_encoder = nn.Embedding(self._NUM_ATOM_TOKENS, half_emb)
        self.protein_edge_encoder = nn.Linear(self._DIST_FEAT_SIZE, emb_size)

        self.ligand_node_encoder = nn.Linear(in_size, emb_size)
        self.ligand_edge_encoder = nn.Linear(intra_edge_size, emb_size)

        self.protein_pose_encoder = nn.Linear(pose_size, emb_size)
        self.ligand_pose_encoder = nn.Linear(pose_size, emb_size)

        self.complex_edge_encoder = nn.Linear(self._DIST_FEAT_SIZE, emb_size)

        self.protein_norm = nn.LayerNorm(emb_size)
        self.ligand_norm = nn.LayerNorm(emb_size)

        self.protein_block = self._make_block(emb_size, num_layers)
        self.ligand_block = self._make_block(emb_size, num_layers)
        self.complex_block = self._make_block(emb_size, num_layers)

        self.mlp_rmsd = nn.Sequential(
            nn.Linear(emb_size, emb_size),
            nn.BatchNorm1d(emb_size),
            nn.ELU(),
            nn.Dropout(p=dropout_ratio),
            nn.Linear(emb_size, 1),
        )

        self.sum_pooling = SumPooling()

    @staticmethod
    def _make_block(emb_size, num_layers):
        """Build one stack of ``num_layers`` width-preserving graph layers."""
        return nn.ModuleList(
            [
                GatedGCNLSPELayer(
                    input_dim=emb_size,
                    output_dim=emb_size,
                    dropout=0.2,
                    batch_norm=True,
                )
                for _ in range(num_layers)
            ]
        )

    @staticmethod
    def _merge_per_graph(protein_feat, ligand_feat, protein_sizes, ligand_sizes):
        """Interleave per-graph rows as [protein_0, ligand_0, protein_1, ...].

        ``protein_sizes`` / ``ligand_sizes`` are lists of ints giving the
        number of rows that belong to each graph of the batch.
        """
        protein_chunks = torch.split(protein_feat, protein_sizes)
        ligand_chunks = torch.split(ligand_feat, ligand_sizes)
        merged = []
        for protein_chunk, ligand_chunk in zip(protein_chunks, ligand_chunks):
            merged.append(protein_chunk)
            merged.append(ligand_chunk)
        return torch.cat(merged)

    @staticmethod
    def _split_per_graph(complex_feat, protein_sizes, ligand_sizes):
        """Inverse of ``_merge_per_graph``: return ``(protein, ligand)`` rows."""
        interleaved_sizes = [
            size
            for pair in zip(protein_sizes, ligand_sizes)
            for size in pair
        ]
        chunks = torch.split(complex_feat, interleaved_sizes)
        # Even positions hold protein rows, odd positions hold ligand rows.
        return torch.cat(chunks[0::2]), torch.cat(chunks[1::2])

    def forward(self, gp, gl, gc):
        """Run the three graph stacks and return the pooled prediction.

        Args:
            gp: Batched protein graph. Reads ``ndata['token_res']``,
                ``ndata['token_atom']``, ``ndata['pos_enc']`` and
                ``edata['dist']``.
            gl: Batched ligand graph. Reads ``ndata['feat']``,
                ``ndata['pos_enc']`` and ``edata['feat']``.
            gc: Batched complex graph. Reads ``edata['dist']``. Its nodes are
                assumed to be ordered, per complex, as all protein nodes
                followed by all ligand nodes, because that is the order in
                which the node features are assembled here.

        Returns:
            Output of ``mlp_rmsd`` applied to the sum-pooled complex node
            features (last dimension of size 1).

        Raises:
            RuntimeError: If the model was built with ``num_layers=0``, in
                which case there are no complex features to pool.
        """
        # Initial protein embedding: residue and atom token embeddings side by side.
        hp = torch.cat(
            [
                self.res_token_encoder(gp.ndata['token_res']),
                self.atom_token_encoder(gp.ndata['token_atom']),
            ],
            1,
        )
        ep = self.protein_edge_encoder(gp.edata['dist'])
        pp = self.protein_pose_encoder(gp.ndata['pos_enc'])

        hl = self.ligand_node_encoder(gl.ndata['feat'])
        el = self.ligand_edge_encoder(gl.edata['feat'])
        pl = self.ligand_pose_encoder(gl.ndata['pos_enc'])

        ec = self.complex_edge_encoder(gc.edata['dist'])

        hp = self.protein_norm(hp)
        hl = self.ligand_norm(hl)

        # Kept for the residual connection applied after every layer.
        hp_raw = hp
        hl_raw = hl

        # Per-graph node counts do not change across layers: convert them to
        # plain ints once instead of re-deriving slice bounds in every layer.
        protein_sizes = gp.batch_num_nodes().tolist()
        ligand_sizes = gl.batch_num_nodes().tolist()

        hc = None
        layer_triples = zip(self.protein_block, self.ligand_block, self.complex_block)
        for protein_layer, ligand_layer, complex_layer in layer_triples:
            hp, pp, ep = protein_layer(gp, hp, pp, ep)
            hl, pl, el = ligand_layer(gl, hl, pl, el)

            hc = self._merge_per_graph(hp, hl, protein_sizes, ligand_sizes)
            pc = self._merge_per_graph(pp, pl, protein_sizes, ligand_sizes)

            # The positional output of the complex layer is not propagated:
            # pc is rebuilt from pp / pl at the start of the next iteration.
            hc, pc, ec = complex_layer(gc, hc, pc, ec)

            hp, hl = self._split_per_graph(hc, protein_sizes, ligand_sizes)

            # Residual connection to the initial normalised embeddings.
            hp = hp + hp_raw
            hl = hl + hl_raw

        if hc is None:
            raise RuntimeError(
                "PredictionRMSD.forward requires at least one layer "
                "(model was built with num_layers=0)"
            )

        pooled = self.sum_pooling(gc, hc)
        return self.mlp_rmsd(pooled)