# Source code for bapred.model.model

import torch, dgl
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn

from dgl.nn.pytorch.glob import SumPooling

from .GraphGPS import GraphGPS

class PredictionPKD(nn.Module):
    """Predict binding affinity from protein, ligand and complex graphs.

    Protein and ligand graphs are first embedded separately (node, edge and
    positional features are each projected to ``emb_size``). They are then
    refined by ``num_layers`` rounds of: a protein GraphGPS layer, a ligand
    GraphGPS layer, and a GraphGPS layer over the complex graph, whose node
    features are the protein and ligand node features of each sample. The
    final ligand node embeddings are sum-pooled per graph and passed through
    an MLP with a single output unit.

    Args:
        in_size: Width of ``ndata['feats']`` (protein and ligand graphs).
        emb_size: Hidden embedding width used throughout the model.
        intra_edge_size: Width of ``edata['feats']`` (protein and ligand).
        inter_edge_size: Width of ``edata['feats']`` on the complex graph.
        pose_size: Width of ``ndata['pos_enc']`` (protein and ligand).
        num_layers: Number of GraphGPS layers in each of the three stacks.
        dropout_ratio: Dropout probability inside the affinity MLP.
    """

    def __init__(self, in_size, emb_size, intra_edge_size, inter_edge_size,
                 pose_size, num_layers, dropout_ratio=0.15):
        super().__init__()
        # NOTE: attribute names, submodule structure and construction order
        # are kept unchanged so existing checkpoints (state_dict keys) and
        # seeded parameter initialisation remain compatible.
        self.protein_node_encoder = nn.Linear(in_size, emb_size)
        self.protein_edge_encoder = nn.Linear(intra_edge_size, emb_size)
        self.protein_pose_encoder = nn.Linear(pose_size, emb_size)

        self.ligand_node_encoder = nn.Linear(in_size, emb_size)
        self.ligand_edge_encoder = nn.Linear(intra_edge_size, emb_size)
        self.ligand_pose_encoder = nn.Linear(pose_size, emb_size)

        self.complex_edge_encoder = nn.Linear(inter_edge_size, emb_size)

        self.protein_norm = nn.LayerNorm(emb_size)
        self.ligand_norm = nn.LayerNorm(emb_size)

        # Three independent stacks of ``num_layers`` GraphGPS layers
        # (protein, ligand, complex). The second GraphGPS argument is
        # presumably the number of attention heads -- TODO confirm.
        blocks = [
            nn.ModuleList([GraphGPS(emb_size, 4) for _ in range(num_layers)])
            for _ in range(3)
        ]
        self.protein_block, self.ligand_block, self.complex_block = blocks

        self.mlp_binding_affinity = nn.Sequential(
            nn.Linear(emb_size, emb_size),
            nn.BatchNorm1d(emb_size),
            nn.ELU(),
            nn.Dropout(p=dropout_ratio),
            nn.Linear(emb_size, 1),
        )

        self.sum_pooling = SumPooling()

    @staticmethod
    def _complex_node_layout(gp_sizes, gl_sizes):
        """Build index tensors relating protein/ligand rows to complex order.

        The complex ordering concatenates, for each sample ``b`` in the
        batch, the ``b``-th protein's nodes followed by the ``b``-th
        ligand's nodes.

        Args:
            gp_sizes: 1-D tensor of node counts per protein graph.
            gl_sizes: 1-D tensor of node counts per ligand graph.

        Returns:
            ``(to_complex, protein_pos, ligand_pos)``, where:
            ``torch.cat([hp, hl])[to_complex]`` is ordered like the complex
            graph; ``hc[protein_pos]`` recovers the protein rows of a
            complex-ordered tensor ``hc``; and ``hc[ligand_pos]`` recovers
            its ligand rows.

        Raises:
            ValueError: If the two batches contain different numbers of
                graphs.
        """
        gp_sizes = gp_sizes.long()
        gl_sizes = gl_sizes.to(device=gp_sizes.device, dtype=torch.long)
        if gp_sizes.numel() != gl_sizes.numel():
            raise ValueError(
                f"protein batch has {gp_sizes.numel()} graphs but ligand "
                f"batch has {gl_sizes.numel()}"
            )
        device = gp_sizes.device

        # Exclusive prefix sums: index of each sample's first node.
        gp_start = torch.cumsum(gp_sizes, 0) - gp_sizes
        gl_start = torch.cumsum(gl_sizes, 0) - gl_sizes
        complex_sizes = gp_sizes + gl_sizes
        complex_start = torch.cumsum(complex_sizes, 0) - complex_sizes

        num_protein = int(gp_sizes.sum())
        num_ligand = int(gl_sizes.sum())

        # Shift each node's batch-global index by its sample's offset into
        # the complex ordering; ligand nodes also skip past their protein.
        protein_pos = torch.arange(num_protein, device=device) + (
            torch.repeat_interleave(complex_start - gp_start, gp_sizes)
        )
        ligand_pos = torch.arange(num_ligand, device=device) + (
            torch.repeat_interleave(
                complex_start + gp_sizes - gl_start, gl_sizes
            )
        )

        # Inverse map: complex position -> row of torch.cat([hp, hl]).
        to_complex = torch.empty(
            num_protein + num_ligand, dtype=torch.long, device=device
        )
        to_complex[protein_pos] = torch.arange(num_protein, device=device)
        to_complex[ligand_pos] = num_protein + torch.arange(
            num_ligand, device=device
        )
        return to_complex, protein_pos, ligand_pos

    def forward(self, gp, gl, gc):
        """Predict binding affinity for a batch of protein-ligand complexes.

        Args:
            gp: Batched protein graph providing ``ndata['feats']``,
                ``ndata['pos_enc']`` and ``edata['feats']``.
            gl: Batched ligand graph with the same fields as ``gp``.
            gc: Batched complex graph providing ``edata['feats']``. Its nodes
                are assumed to be ordered per sample as the sample's protein
                nodes followed by its ligand nodes, which is how node
                features are assembled here. TODO: confirm against the
                graph-construction code.

        Returns:
            Output of ``mlp_binding_affinity`` applied to the sum-pooled
            ligand node embeddings (one output unit per pooled graph).

        Raises:
            ValueError: If ``gp`` and ``gl`` hold different numbers of graphs.
        """
        hp = self.protein_node_encoder(gp.ndata['feats'])
        ep = self.protein_edge_encoder(gp.edata['feats'])
        pp = self.protein_pose_encoder(gp.ndata['pos_enc'])

        hl = self.ligand_node_encoder(gl.ndata['feats'])
        el = self.ligand_edge_encoder(gl.edata['feats'])
        pl = self.ligand_pose_encoder(gl.ndata['pos_enc'])

        ec = self.complex_edge_encoder(gc.edata['feats'])

        hp = self.protein_norm(hp)
        hl = self.ligand_norm(hl)

        # The protein/ligand <-> complex node layout depends only on the
        # batch composition, so the gather indices are built once here.
        # This replaces re-slicing every sample in Python at every layer;
        # on CUDA each such slice forced a host sync.
        to_complex, protein_pos, ligand_pos = self._complex_node_layout(
            gp.batch_num_nodes().to(hp.device),
            gl.batch_num_nodes().to(hp.device),
        )

        for protein_layer, ligand_layer, complex_layer in zip(
            self.protein_block, self.ligand_block, self.complex_block
        ):
            # Each layer is called as layer(g, h, p, e) and returns new
            # (node, positional, edge) features.
            hp, pp, ep = protein_layer(gp, hp, pp, ep)
            hl, pl, el = ligand_layer(gl, hl, pl, el)

            hc = torch.cat([hp, hl])[to_complex]
            pc = torch.cat([pp, pl])[to_complex]
            hc, pc, ec = complex_layer(gc, hc, pc, ec)

            # Only node features flow back to the per-molecule streams; the
            # complex layer's positional output (pc) is not written back to
            # pp / pl.
            hp = hc[protein_pos]
            hl = hc[ligand_pos]

        h = self.sum_pooling(gl, hl)
        return self.mlp_binding_affinity(h)