import torch
import torch.nn as nn
from miniworld.models_MiniWorld_v1_5_use_interaction.Embeddings import MSA_emb, Templ_emb, Recycling, InteractionEmb
from miniworld.models_MiniWorld_v1_5_use_interaction.Track_module import IterativeSimulator
from miniworld.models_MiniWorld_v1_5_use_interaction.AuxiliaryPredictor import DistanceNetwork, MaskedTokenNetwork, ExpResolvedNetwork, LDDTNetwork, PAENetwork, BinderNetwork
from miniworld.models_MiniWorld_v1_5_use_interaction.value_network import StrSeqValueNet
from miniworld.utils.util import rigid_from_3_points
from torch import einsum
# Size of the token vocabulary (presumably 20 amino acids plus special
# tokens such as mask/gap/unknown — TODO confirm against the tokenizer).
NUM_CLASSES = 23
# Input feature dimension of the latent MSA embedding: one slot per token
# class plus one extra channel.
d_init = NUM_CLASSES + 1
class MiniWorldModule(nn.Module):
    """Top-level MiniWorld structure-prediction network.

    Embeds an MSA, templates, and pairwise interaction hints; applies
    recycling of the previous round's outputs; runs an iterative SE(3)
    structure simulator; and decodes auxiliary predictions (distogram /
    orientograms, masked-token logits, pLDDT, PAE, binder probability).
    """

    def __init__(self, n_extra_block=4, n_main_block=8, n_ref_block=4,
                 d_msa=256, d_pair=128, d_templ=64, d_rbf=64,
                 n_head_msa=8, n_head_pair=4, n_head_templ=4,
                 d_hidden=32, d_hidden_templ=64,
                 p_drop=0.15,
                 SE3_param_full={'l0_in_features': 32, 'l0_out_features': 16, 'num_edge_features': 32},
                 SE3_param_topk={'l0_in_features': 32, 'l0_out_features': 16, 'num_edge_features': 32},
                 value_net_param={'d_node': 128, 'd_hidden': 128, 'num_layers': 3, 'topk': 16, 'dropout': 0.1}
                 ):
        # NOTE(review): the dict defaults are shared mutable objects (kept for
        # interface compatibility); they are only read here, but downstream
        # modules must not mutate them.
        super().__init__()
        print(f"n_extra_block : {n_extra_block} | n_main_block : {n_main_block} | n_ref_block : {n_ref_block}")

        # --- input embeddings ---
        d_state = SE3_param_topk['l0_out_features']
        self.latent_emb = MSA_emb(d_msa=d_msa, d_pair=d_pair, d_state=d_state, p_drop=p_drop)
        self.templ_emb = Templ_emb(d_pair=d_pair, d_templ=d_templ, d_state=d_state,
                                   n_head=n_head_templ,
                                   d_hidden=d_hidden_templ, p_drop=0.25)

        # Value network used while recycling previous-round outputs.  It is
        # trained elsewhere (see update_value_net), so it is frozen here.
        self.value_net = StrSeqValueNet(**value_net_param)
        for param in self.value_net.parameters():
            param.requires_grad = False
        self.recycle = Recycling(value_net=self.value_net, d_msa=d_msa, d_pair=d_pair, d_state=d_state)

        # Separate left/right embeddings make the interaction pair update
        # asymmetric in (i, j).
        self.interaction_emb_left = InteractionEmb(d_pair=d_pair)
        self.interaction_emb_right = InteractionEmb(d_pair=d_pair)

        # Iterative structure simulator (extra -> main -> refinement blocks).
        self.simulator = IterativeSimulator(value_net=self.value_net,
                                            n_extra_block=n_extra_block,
                                            n_main_block=n_main_block,
                                            n_ref_block=n_ref_block,
                                            d_msa=d_msa,
                                            d_pair=d_pair, d_hidden=d_hidden, d_rbf=d_rbf,
                                            n_head_msa=n_head_msa,
                                            n_head_pair=n_head_pair,
                                            SE3_param_full=SE3_param_full,
                                            SE3_param_topk=SE3_param_topk,
                                            p_drop=p_drop)

        # --- auxiliary prediction heads ---
        self.c6d_pred = DistanceNetwork(d_pair, p_drop=p_drop)
        self.aa_pred = MaskedTokenNetwork(d_msa, p_drop=p_drop)
        self.lddt_pred = LDDTNetwork(d_state)
        self.exp_pred = ExpResolvedNetwork(d_msa, d_state)
        self.pae_pred = PAENetwork(d_pair)
        self.bind_pred = BinderNetwork()  # fd - expose n_hidden as variable?

    def update_value_net(self, value_net):
        """Copy the weights of *value_net* into the internal value network.

        The copy is re-frozen afterwards so it keeps acting as a fixed
        critic; gradients never flow into it.
        """
        self.value_net.load_state_dict(value_net.state_dict())
        for param in self.value_net.parameters():
            param.requires_grad = False

    def forward(self, msa_latent=None, msa_species=None, msa_zero_pos=None, query_sequence=None, xyz=None, idx=None, interaction_points=None,
                template_1D=None, template_2D=None, template_xyz=None, template_alpha=None, mask_t=None, chain_mask=None,  # TODO
                msa_prev=None, pair_prev=None, state_prev=None, mask_recycle=None,
                return_raw=False, return_full=False,
                use_checkpoint=False, p2p_crop=-1, topk_crop=-1,
                symmids=None, symmsub=None, symmRs=None, symmmeta=None):
        """Run one prediction round.

        Input shapes (assumed from the original comments and the indexing
        below — TODO confirm against callers):
            msa_latent      (B, N, L, d_init)
            msa_zero_pos    (B, N, L)
            query_sequence  (B, L)
            msa_species     (B, N)

        Returns:
            When ``return_raw`` is True: ``(msa[:,0], pair, state, xyz,
            alpha[-1], None)`` — the tensors needed to recycle into the
            next round, with only the last predicted structure.
            Otherwise the full tuple ``(logits, logits_aa, logits_exp,
            logits_pae, p_bind, xyz, alpha, symmsub, lddt, msa[:,0],
            pair, state)``.
        """
        if symmids is None:
            symmids = torch.tensor([[0]], device=xyz.device)  # C1 symmetry

        B, N, L = msa_latent.shape[:3]
        dtype = msa_latent.dtype

        # Species information is deliberately disabled (zeroed) — 20231218.
        msa_species = torch.zeros_like(msa_species)  # (B, N)

        # Initial embeddings for the MSA / pair / state tracks.
        msa_latent, pair, state = self.latent_emb(msa_latent, msa_species, msa_zero_pos, query_sequence, idx, chain_mask, xyz, symmids)
        msa_latent, pair, state = msa_latent.to(dtype), pair.to(dtype), state.to(dtype)
        template_1D, template_2D, template_xyz, template_alpha, mask_t = template_1D.to(dtype), template_2D.to(dtype), template_xyz.to(dtype), template_alpha.to(dtype), mask_t.to(dtype)

        # Recycling: fold in the previous round's outputs (zeros on round 1).
        if msa_prev is None:
            msa_prev = torch.zeros_like(msa_latent[:, 0])
            pair_prev = torch.zeros_like(pair)
            state_prev = torch.zeros_like(state)
        msa_recycle, pair_recycle, state_recycle = self.recycle(query_sequence, idx, chain_mask, msa_prev, pair_prev, state_prev, xyz, mask_recycle)
        msa_recycle, pair_recycle, state_recycle = msa_recycle.to(dtype), pair_recycle.to(dtype), state_recycle.to(dtype)
        msa_latent[:, 0] = msa_latent[:, 0] + msa_recycle
        pair = pair + pair_recycle
        state = state + state_recycle

        # Template embedding updates the pair and state tracks.
        pair, state = self.templ_emb(template_1D, template_2D, template_alpha, template_xyz, mask_t, pair, state, use_checkpoint=use_checkpoint, p2p_crop=p2p_crop, symmids=symmids)

        # Interaction hints: outer sum of per-residue left/right embeddings.
        # Position 0 of the pair track is skipped here and in the heads
        # below — presumably a special/query token slot; TODO confirm.
        emb_interaction_left = self.interaction_emb_left(interaction_points)    # (B, L, d_pair)
        emb_interaction_right = self.interaction_emb_right(interaction_points)  # (B, L, d_pair)
        emb_interaction_pair = emb_interaction_left.unsqueeze(2) + emb_interaction_right.unsqueeze(1)  # (B, L, L, d_pair)
        pair[:, 1:, 1:] = pair[:, 1:, 1:] + emb_interaction_pair

        # Iterative structure refinement; produces per-iteration rigid frames.
        # NOTE(review): `pair` appears twice in this unpacking — the later
        # position overwrites the earlier one; confirm the simulator's
        # return order before relying on the first slot.
        msa, pair, R_list, T_list, pair, alpha, state, symmsub = self.simulator(
            query_sequence, msa_latent, pair, xyz, state, idx, chain_mask, symmids, symmsub, symmRs, symmmeta,
            use_checkpoint=use_checkpoint, p2p_crop=p2p_crop, topk_crop=topk_crop)

        # Re-pose the input coordinates with the predicted frames: express
        # xyz in each residue's local backbone frame (R0, T0), then apply
        # the predicted per-iteration frames (R_list, T_list).
        R0, T0 = rigid_from_3_points(xyz[:, :, 0], xyz[:, :, 1], xyz[:, :, 2])  # (B,L,3,3), (B,L,3)
        R0_inv = R0.transpose(-1, -2)
        xyz_0 = einsum('blij,blaj->blai', R0_inv, xyz - T0.unsqueeze(-2))
        xyz = einsum('rblij,blaj->rblai', R_list, xyz_0) + T_list.unsqueeze(-2)

        if return_raw:
            # Keep only the last predicted structure for recycling.
            xyz = xyz[-1]
            return msa[:, 0], pair, state, xyz, alpha[-1], None

        # Auxiliary heads (position 0 excluded, as above).
        logits_aa = self.aa_pred(msa[:, :, 1:])                  # masked-token recovery
        logits = self.c6d_pred(pair[:, 1:, 1:])                  # distogram & orientograms
        lddt = self.lddt_pred(state[:, 1:])                      # per-residue lDDT
        logits_exp = self.exp_pred(msa[:, 0, 1:], state[:, 1:])  # experimentally resolved
        logits_pae = self.pae_pred(pair[:, 1:, 1:])              # predicted aligned error
        p_bind = self.bind_pred(logits_pae[-1].unsqueeze(0), chain_mask)  # bind / no-bind

        logits = (logits[0].reshape(B, -1, L, L), logits[1].reshape(B, -1, L, L),
                  logits[2].reshape(B, -1, L, L), logits[3].reshape(B, -1, L, L))
        logits_pae = logits_pae.reshape(B, -1, L, L)
        return logits, logits_aa, logits_exp, logits_pae, p_bind, xyz, alpha, symmsub, lddt, msa[:, 0], pair, state