Source code for promptbind.models.cross_att

import torch
from torch import nn
from torch.nn import LayerNorm, Linear

from .model_utils import permute_final_dims, Attention, Transition, InteractionModule

class CrossAttentionModule(nn.Module):
    def __init__(self, node_hidden_dim, pair_hidden_dim, rm_layernorm=False,
                 keep_trig_attn=False, dist_hidden_dim=32, normalize_coord=None):
        super().__init__()
        self.pair_hidden_dim = pair_hidden_dim
        self.keep_trig_attn = keep_trig_attn
        if keep_trig_attn:
            # Optional triangle-style attention updates on the pair representation.
            self.triangle_block_row = RowTriangleAttentionBlock(
                pair_hidden_dim, dist_hidden_dim, rm_layernorm=rm_layernorm)
            self.triangle_block_column = RowTriangleAttentionBlock(
                pair_hidden_dim, dist_hidden_dim, rm_layernorm=rm_layernorm)
        # Pair-biased cross-attention between the two node tracks.
        self.p_attention_block = RowAttentionBlock(node_hidden_dim, pair_hidden_dim, rm_layernorm=rm_layernorm)
        self.c_attention_block = RowAttentionBlock(node_hidden_dim, pair_hidden_dim, rm_layernorm=rm_layernorm)
        # Feed-forward transitions for each track and for the pair representation.
        self.p_transition = Transition(node_hidden_dim, 2, rm_layernorm=rm_layernorm)
        self.c_transition = Transition(node_hidden_dim, 2, rm_layernorm=rm_layernorm)
        self.pair_transition = Transition(pair_hidden_dim, 2, rm_layernorm=rm_layernorm)
        # Updates the pair representation from the two node tracks.
        self.inter_layer = InteractionModule(node_hidden_dim, pair_hidden_dim, 32, opm=False, rm_layernorm=rm_layernorm)

    def forward(self, p_embed_batched, p_mask, c_embed_batched, c_mask,
                pair_embed, pair_mask, c_c_dist_embed=None, p_p_dist_embed=None):
        if self.keep_trig_attn:
            # Triangle attention along both axes of the pair representation
            # (the second call works on the transposed pair tensor and mask).
            pair_embed = self.triangle_block_row(pair_embed=pair_embed,
                                                 pair_mask=pair_mask,
                                                 dist_embed=c_c_dist_embed)
            pair_embed = self.triangle_block_row(pair_embed=pair_embed.transpose(-2, -3),
                                                 pair_mask=pair_mask.transpose(-1, -2),
                                                 dist_embed=p_p_dist_embed).transpose(-2, -3)
        # Each node track attends to the other, biased by the pair representation.
        p_embed_batched = self.p_attention_block(node_embed_i=p_embed_batched,
                                                 node_embed_j=c_embed_batched,
                                                 pair_embed=pair_embed,
                                                 pair_mask=pair_mask,
                                                 node_mask_i=p_mask)
        c_embed_batched = self.c_attention_block(node_embed_i=c_embed_batched,
                                                 node_embed_j=p_embed_batched,
                                                 pair_embed=pair_embed.transpose(-2, -3),
                                                 pair_mask=pair_mask.transpose(-1, -2),
                                                 node_mask_i=c_mask)
        # Residual transitions on both node tracks.
        p_embed_batched = p_embed_batched + self.p_transition(p_embed_batched)
        c_embed_batched = c_embed_batched + self.c_transition(c_embed_batched)
        # Update the pair representation from the node tracks, then apply the
        # pair transition and re-mask padded positions.
        pair_embed = pair_embed + self.inter_layer(p_embed_batched, c_embed_batched, p_mask, c_mask)[0]
        pair_embed = self.pair_transition(pair_embed) * pair_mask.to(torch.float).unsqueeze(-1)
        return p_embed_batched, c_embed_batched, pair_embed
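
The sketch below is not part of the original module; it is a minimal, hedged guess at how CrossAttentionModule might be driven, with tensor shapes inferred from the shape comments in the attention blocks further down: two node tracks of shape (B, I, C_node) and (B, J, C_node), a pair representation of shape (B, I, J, C_pair), and boolean masks of shape (B, I), (B, J), and (B, I, J). The wrapper function name and all sizes are illustrative assumptions.

# Illustrative only -- not part of promptbind.models.cross_att.
def _cross_attention_usage_sketch():
    # Arbitrary (assumed) sizes: batch B, I nodes in the "p" track, J in the "c" track.
    B, I, J, node_dim, pair_dim = 2, 16, 8, 128, 64
    module = CrossAttentionModule(node_hidden_dim=node_dim, pair_hidden_dim=pair_dim)

    p_embed = torch.randn(B, I, node_dim)                 # (*, I, C_node)
    c_embed = torch.randn(B, J, node_dim)                 # (*, J, C_node)
    pair_embed = torch.randn(B, I, J, pair_dim)           # (*, I, J, C_pair)
    p_mask = torch.ones(B, I, dtype=torch.bool)           # (*, I)
    c_mask = torch.ones(B, J, dtype=torch.bool)           # (*, J)
    pair_mask = p_mask[:, :, None] & c_mask[:, None, :]   # (*, I, J)

    p_embed, c_embed, pair_embed = module(p_embed, p_mask, c_embed, c_mask,
                                          pair_embed, pair_mask)
    return p_embed.shape, c_embed.shape, pair_embed.shape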

class RowTriangleAttentionBlock(nn.Module):
    inf = 1e9

    def __init__(self, pair_hidden_dim, dist_hidden_dim, attention_hidden_dim=32,
                 no_heads=4, dropout=0.1, rm_layernorm=False):
        super(RowTriangleAttentionBlock, self).__init__()
        self.no_heads = no_heads
        self.attention_hidden_dim = attention_hidden_dim
        self.dist_hidden_dim = dist_hidden_dim
        self.pair_hidden_dim = pair_hidden_dim
        self.rm_layernorm = rm_layernorm
        if not self.rm_layernorm:
            self.layernorm = LayerNorm(pair_hidden_dim)
        self.linear = Linear(dist_hidden_dim, self.no_heads)
        self.linear_g = Linear(dist_hidden_dim, self.no_heads)
        self.dropout = nn.Dropout(dropout)
        self.mha = Attention(
            pair_hidden_dim, pair_hidden_dim, pair_hidden_dim, attention_hidden_dim, no_heads
        )

    def forward(self, pair_embed, pair_mask, dist_embed):
        if not self.rm_layernorm:
            pair_embed = self.layernorm(pair_embed)                                      # (*, I, J, C_pair)
        mask_bias = (self.inf * (pair_mask.to(torch.float) - 1))[..., :, None, None, :]  # (*, I, 1, 1, J)
        dist_bias = self.linear(dist_embed) * self.linear_g(dist_embed).sigmoid()        # (*, J, J, H)
        dist_bias = permute_final_dims(dist_bias, [2, 1, 0])[..., None, :, :, :]         # (*, 1, H, J, J)
        pair_embed = pair_embed + self.dropout(self.mha(
            q_x=pair_embed,                  # (*, I, J, C_pair)
            kv_x=pair_embed,                 # (*, I, J, C_pair)
            biases=[mask_bias, dist_bias]    # list of (*, I, H, J, J)
        )) * pair_mask.to(torch.float).unsqueeze(-1)                                     # (*, I, J, C_pair)
        return pair_embed
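
As an aside, the mask term above follows the usual additive-masking pattern: a {0, 1} mask is turned into a bias that is 0 at valid positions and -1e9 at padded ones, so that after the softmax inside Attention the padded keys receive essentially zero weight. The torch-only snippet below is a standalone sketch of that arithmetic; the helper name and the tiny tensor sizes are made up for illustration.

# Illustrative only -- shows the mask-to-bias arithmetic used above.
def _mask_bias_sketch():
    inf = 1e9
    pair_mask = torch.tensor([[1, 1, 0],
                              [1, 0, 0]], dtype=torch.float)   # (I, J), 1 = valid
    mask_bias = inf * (pair_mask - 1)                          # 0 where valid, -1e9 where padded
    logits = torch.randn(2, 3)                                 # raw attention scores
    attn = torch.softmax(logits + mask_bias, dim=-1)           # padded keys get ~0 weight
    return attn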

class RowAttentionBlock(nn.Module):
    inf = 1e9

    def __init__(self, node_hidden_dim, pair_hidden_dim, attention_hidden_dim=32,
                 no_heads=4, dropout=0.1, rm_layernorm=False):
        super(RowAttentionBlock, self).__init__()
        self.no_heads = no_heads
        self.attention_hidden_dim = attention_hidden_dim
        self.pair_hidden_dim = pair_hidden_dim
        self.node_hidden_dim = node_hidden_dim
        self.rm_layernorm = rm_layernorm
        if not self.rm_layernorm:
            self.layernorm_node_i = LayerNorm(node_hidden_dim)
            self.layernorm_node_j = LayerNorm(node_hidden_dim)
            self.layernorm_pair = LayerNorm(pair_hidden_dim)
        self.linear = Linear(pair_hidden_dim, self.no_heads)
        self.linear_g = Linear(pair_hidden_dim, self.no_heads)
        self.dropout = nn.Dropout(dropout)
        self.mha = Attention(node_hidden_dim, node_hidden_dim, node_hidden_dim, attention_hidden_dim, no_heads)

    def forward(self, node_embed_i, node_embed_j, pair_embed, pair_mask, node_mask_i):
        if not self.rm_layernorm:
            node_embed_i = self.layernorm_node_i(node_embed_i)                     # (*, I, C_node)
            node_embed_j = self.layernorm_node_j(node_embed_j)                     # (*, J, C_node)
            pair_embed = self.layernorm_pair(pair_embed)                           # (*, I, J, C_pair)
        mask_bias = (self.inf * (pair_mask.to(torch.float) - 1))[..., None, :, :]  # (*, 1, I, J)
        pair_bias = self.linear(pair_embed) * self.linear_g(pair_embed).sigmoid()  # (*, I, J, H)
        pair_bias = permute_final_dims(pair_bias, [2, 0, 1])                       # (*, H, I, J)
        node_embed_i = node_embed_i + self.dropout(self.mha(
            q_x=node_embed_i,
            kv_x=node_embed_j,
            biases=[mask_bias, pair_bias]
        )) * node_mask_i.to(torch.float).unsqueeze(-1)
        return node_embed_i
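
For reference, the gated pair bias above projects the pair representation to one scalar per attention head, gates it with a sigmoid of a second projection, and then moves the head dimension in front of the (I, J) grid so it can be added to the attention logits; the shape comments indicate the result is (*, H, I, J). The snippet below reproduces that computation with plain tensor ops, using Tensor.permute in place of permute_final_dims; the layer sizes are arbitrary assumptions.

# Illustrative only -- mirrors the gated pair-bias computation above.
def _pair_bias_sketch():
    B, I, J, pair_dim, no_heads = 2, 16, 8, 64, 4
    linear = Linear(pair_dim, no_heads)
    linear_g = Linear(pair_dim, no_heads)

    pair_embed = torch.randn(B, I, J, pair_dim)                      # (*, I, J, C_pair)
    pair_bias = linear(pair_embed) * linear_g(pair_embed).sigmoid()  # (*, I, J, H)
    pair_bias = pair_bias.permute(0, 3, 1, 2)                        # (*, H, I, J)
    return pair_bias.shape                                           # torch.Size([2, 4, 16, 8])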