# Source code for bindingrmsd.model.GatedGCNLSPE

import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn

import dgl

class GatedGCNLSPELayer(nn.Module):
    """GatedGCN layer with learnable structural and positional encodings (LSPE).

    One layer updates three feature streams carried on a DGL graph: node
    features ``h``, positional features ``p`` and edge features ``e``.

    For every edge, a gate ``sigmoid(B1 h_src + B2 h_dst + B3 e)`` is computed.
    Each gate is divided by ``1e-6`` plus the sum of the gates arriving at the
    same destination node. The normalised gates then weight the neighbour
    messages used to update both ``h`` and ``p``. The pre-sigmoid gate becomes
    the new ``e``.

    Args:
        input_dim: Feature width of ``h``, ``p`` and ``e`` on input.
        output_dim: Feature width of the three streams on output.
        dropout: Dropout probability applied to each of the three outputs.
        batch_norm: If truthy, batch-normalise the ``h`` and ``e`` outputs
            (``p`` is never batch-normalised by this layer).
        use_lapeig_loss: Stored on the instance only; this class never reads it.
        residual: Add the inputs back onto the outputs. Forced to ``False``
            when ``input_dim != output_dim``.
    """

    def __init__(self, input_dim, output_dim, dropout, batch_norm,
                 use_lapeig_loss=False, residual=True):
        super().__init__()
        self.in_channels = input_dim
        self.out_channels = output_dim
        self.dropout = dropout
        self.batch_norm = batch_norm
        # An identity shortcut is only shape-compatible when the widths match.
        self.residual = residual if input_dim == output_dim else False
        self.use_lapeig_loss = use_lapeig_loss

        # NOTE: keep the creation order below. Reordering changes the random
        # initialisation under a fixed seed and the order of entries in
        # parameters() / state_dict().
        concat_dim = 2 * input_dim
        # A1 / A2 act on the concatenation [h || p]:
        # A1 on the node itself, A2 on the source neighbour.
        self.A1 = nn.Linear(concat_dim, output_dim)
        self.A2 = nn.Linear(concat_dim, output_dim)
        # B1 / B2 / B3 build the edge gate from source h, destination h and e.
        self.B1 = nn.Linear(input_dim, output_dim)
        self.B2 = nn.Linear(input_dim, output_dim)
        self.B3 = nn.Linear(input_dim, output_dim)
        # C1 / C2 update p: C1 on the node itself, C2 on the source neighbour.
        self.C1 = nn.Linear(input_dim, output_dim)
        self.C2 = nn.Linear(input_dim, output_dim)
        # Despite its name, bn_node_e normalises the edge stream.
        self.bn_node_h = nn.BatchNorm1d(output_dim)
        self.bn_node_e = nn.BatchNorm1d(output_dim)

    def message_func_for_vij(self, edges):
        """Edge UDF: neighbour message ``v_ij = A2([h_j || p_j])`` from source node j."""
        source = edges.src
        joined = torch.cat((source['h'], source['p']), -1)
        return {'v_ij': self.A2(joined)}

    def message_func_for_pj(self, edges):
        """Edge UDF: positional message ``C2(p_j)`` from source node j."""
        return {'C2_pj': self.C2(edges.src['p'])}

    def compute_normalized_eta(self, edges):
        """Edge UDF: divide each sigmoid gate by its destination node's gate sum.

        The ``1e-6`` added to the sum keeps the denominator strictly positive.
        """
        gate = edges.data['sigma_hat_eta']
        total = edges.dst['sum_sigma_hat_eta']
        return {'eta_ij': gate / (total + 1e-6)}

    def forward(self, g, h, p, e):
        """Apply one GatedGCN-LSPE update.

        Args:
            g: DGL graph. Features are attached only inside ``g.local_scope()``,
                so the scratch features written here do not persist on ``g``.
            h: Node features, one row per node.
            p: Positional features, one row per node.
            e: Edge features, one row per edge.

        Returns:
            Tuple ``(h, p, e)`` of updated node, positional and edge features.
        """
        with g.local_scope():
            h_new, p_new, e_new = self._gated_message_passing(g, h, p, e)
        return self._post_process(h_new, p_new, e_new, h, p, e)

    def _gated_message_passing(self, g, h, p, e):
        """Run the gated aggregation on ``g`` and return the raw ``(h, p, e)`` updates.

        Writes scratch features onto ``g``, so call it inside ``g.local_scope()``.
        """
        # Stage the inputs and their linear projections on the graph.
        g.ndata['h'] = h
        g.ndata['A1_h'] = self.A1(torch.cat((h, p), -1))
        g.ndata['B1_h'] = self.B1(h)
        g.ndata['B2_h'] = self.B2(h)
        g.ndata['p'] = p
        g.ndata['C1_p'] = self.C1(p)
        g.edata['e'] = e
        g.edata['B3_e'] = self.B3(e)

        # Gate: hat_eta = B1 h_src + B2 h_dst + B3 e, then squashed by a sigmoid.
        g.apply_edges(fn.u_add_v('B1_h', 'B2_h', 'B1_B2_h'))
        g.edata['hat_eta'] = g.edata['B1_B2_h'] + g.edata['B3_e']
        g.edata['sigma_hat_eta'] = torch.sigmoid(g.edata['hat_eta'])

        # Normalise every gate by the sum of gates entering its destination node.
        g.update_all(fn.copy_e('sigma_hat_eta', 'm'), fn.sum('m', 'sum_sigma_hat_eta'))
        g.apply_edges(self.compute_normalized_eta)

        # Node stream: A1 [h || p] + sum over in-edges of eta * A2 [h_src || p_src].
        g.apply_edges(self.message_func_for_vij)
        g.edata['eta_mul_v'] = g.edata['eta_ij'] * g.edata['v_ij']
        g.update_all(fn.copy_e('eta_mul_v', 'm'), fn.sum('m', 'sum_eta_v'))
        h_new = g.ndata['A1_h'] + g.ndata['sum_eta_v']

        # Positional stream: C1 p + sum over in-edges of eta * C2 p_src.
        g.apply_edges(self.message_func_for_pj)
        g.edata['eta_mul_p'] = g.edata['eta_ij'] * g.edata['C2_pj']
        g.update_all(fn.copy_e('eta_mul_p', 'm'), fn.sum('m', 'sum_eta_p'))
        p_new = g.ndata['C1_p'] + g.ndata['sum_eta_p']

        # The pre-sigmoid gate is returned as the new edge features.
        return h_new, p_new, g.edata['hat_eta']

    def _post_process(self, h, p, e, h_in, p_in, e_in):
        """Apply optional batch norm, activations, optional residual and dropout."""
        if self.batch_norm:
            # p is not batch-normalised by this layer.
            h = self.bn_node_h(h)
            e = self.bn_node_e(e)
        h = F.relu(h)
        e = F.relu(e)
        p = torch.tanh(p)
        if self.residual:
            h, p, e = h_in + h, p_in + p, e_in + e
        # Dropout masks are drawn in the order h, p, e.
        return tuple(
            F.dropout(x, self.dropout, training=self.training) for x in (h, p, e)
        )

    def __repr__(self):
        return (f'{type(self).__name__}(in_channels={self.in_channels}, '
                f'out_channels={self.out_channels})')