from contextlib import ExitStack
from copy import deepcopy
from collections import OrderedDict
import torch
import torch.nn as nn
import numpy as np
from miniworld.utils.kinematics import xyz_to_t2d
from miniworld.models_MiniWorld_v1_5_use_interaction.MiniWorld import MiniWorldModule
from miniworld.utils.util_module import XYZConverter
from miniworld.utils.data_refactoring import PDB_parsing, mmcif_parsing
import traceback
from miniworld.utils.chemical import INIT_CRDS
from miniworld.utils import kalign_mapping
from miniworld.utils.template_parser import read_templates
from miniworld.utils import hhpred_parser
from miniworld.utils.util import *
# distributed data parallel
import torch.multiprocessing as mp
from miniworld.utils.output_visualize import visualize_2D_heatmap
from miniworld.utils.ProteinClass import ProteinMSA
from miniworld.utils.ffindex import read_index, read_data
from miniworld.feature.MiniWorld_featuring_species import MSA_featurize_wo_statistics
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
import matplotlib.pyplot as plt
import os
import gzip
import string
import pickle
from collections import namedtuple
USE_AMP = False
NUM_CLASSES = 23
def a3m_parse(a3m_file_path):
# Code from RF2 parsers.py
"""
Input : a3m_file_path (.a3m or .a3m.gz)
"""
msa = []
insertion_list = []
table = str.maketrans(dict.fromkeys(string.ascii_lowercase))
#print(filename)
if a3m_file_path.split('.')[-1] == 'gz':
f = gzip.open(a3m_file_path, 'rt')
else:
f = open(a3m_file_path, 'r')
# read file line by line
chain_break = {}
for line in f:
if line[0] == '#': # skip # line
continue
# skip labels
if line[0] == '>':
continue
if chain_break == {} and ':' in line:
chains = line.strip().split(':')
chain_start = 0
for chain_idx, chain in enumerate(chains):
chain_end = chain_start + len(chain) - 1
chain_break[chain_idx] = (chain_start, chain_end)
chain_start = chain_end + 1
# remove :
line = line.replace(':','')
# remove right whitespaces
line = line.rstrip()
if len(line) == 0:
continue
# remove lowercase letters and append to MSA
msa_i = line.translate(table)
msas_i = msa_i.split('/')
if (len(msa)==0):
# first seq
Ls = [len(x) for x in msas_i]
msa = [[s] for s in msas_i]
else:
nchains = len(msas_i)
isgood = all([ len(msa[i][0]) == len(msas_i[i]) for i in range(nchains) ])
if isgood:
for i in range(nchains):
msa[i].append(msas_i[i])
else:
raise ValueError("Len error", a3m_file_path, len(msa[0]) )
# sequence length
L = sum(Ls)
# 0 - match or gap; 1 - insertion
lower_case = np.array([0 if c.isupper() or c=='-' else 1 for c in line])
insertion = np.zeros((L))
if np.sum(lower_case) > 0:
# positions of insertions
pos = np.where(lower_case==1)[0]
# shift by occurrence
lower_case = pos - np.arange(pos.shape[0])
# position of insertions in cleaned sequence
# and their length
pos,num = np.unique(lower_case, return_counts=True)
            # append to the matrix of insertions
insertion[pos] = num
insertion_list.append(insertion)
# concatenate
msa = [np.array([list(s) for s in t], dtype='|S1').view(np.uint8) for t in msa]
msa = np.concatenate(msa,axis=-1)
query_sequence = msa[0]
# convert letters into numbers
alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-"), dtype='|S1').view(np.uint8)
for i in range(alphabet.shape[0]):
msa[msa == alphabet[i]] = i
# treat all unknown characters as gaps
msa[msa > 20] = 20
# re-convert to character
for i in range(alphabet.shape[0]):
msa[msa == i] = alphabet[i]
# uint8 -> string, np.array -> list
msa = msa.view('|S1').reshape(msa.shape)
msa = msa.tolist()
msa = ["".join([char.decode('utf-8') for char in sublist]) for sublist in msa]
return msa
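# Minimal usage sketch for a3m_parse ("example.a3m" is a hypothetical path,
# not part of this repo):
#
#     msa = a3m_parse("example.a3m")
#     print(len(msa), "sequences of aligned length", len(msa[0]))
#     # msa[0] is the query row; all rows share the query's aligned length.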
def visualize_2D_tensor(tensor, sav_dir="permutation_test/", name="test"):
    import matplotlib.pyplot as plt
    os.makedirs(sav_dir, exist_ok=True) # make sure the output directory exists
    fig, ax = plt.subplots()
    im = ax.imshow(tensor)
    fig.colorbar(im)
    plt.savefig(os.path.join(sav_dir, name + ".png"))
    plt.close()
    torch.save(tensor, os.path.join(sav_dir, name + ".pt"))
def add_weight_decay(model, l2_coeff):
decay, no_decay = [], []
for name, param in model.named_parameters():
if not param.requires_grad:
continue
#if len(param.shape) == 1 or name.endswith(".bias"):
if "norm" in name or name.endswith(".bias"):
no_decay.append(param)
else:
decay.append(param)
return [{'params': no_decay, 'weight_decay': 0.0}, {'params': decay, 'weight_decay': l2_coeff}]
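# Usage sketch (assumes an existing torch.nn.Module `model`; the optimizer and
# hyperparameters are illustrative, not prescribed by this file):
#
#     param_groups = add_weight_decay(model, l2_coeff=0.01)
#     optimizer = torch.optim.AdamW(param_groups, lr=1e-3)
#     # norm weights and biases get weight_decay=0.0, all other params 0.01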
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
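# Quick sanity check after building a model (hypothetical `model` variable):
#
#     print(f"{count_parameters(model)/1e6:.1f}M trainable parameters")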
class EMA(nn.Module):
# From RF2
def __init__(self, model, decay):
super().__init__()
self.decay = decay
self.model = model
self.shadow = deepcopy(self.model)
for param in self.shadow.parameters():
param.detach_()
@torch.no_grad()
def update(self):
if not self.training:
# print("EMA update should only be called during training", file=stderr, flush=True)
print("EMA update should only be called during training")
return
model_params = OrderedDict(self.model.named_parameters())
shadow_params = OrderedDict(self.shadow.named_parameters())
# check if both model contains the same set of keys
assert model_params.keys() == shadow_params.keys()
for name, param in model_params.items():
# see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
# shadow_variable -= (1 - decay) * (shadow_variable - variable)
shadow_params[name].sub_((1. - self.decay) * (shadow_params[name] - param)) # in-place subtraction
model_buffers = OrderedDict(self.model.named_buffers())
shadow_buffers = OrderedDict(self.shadow.named_buffers())
# check if both model contains the same set of keys
assert model_buffers.keys() == shadow_buffers.keys()
for name, buffer in model_buffers.items():
# buffers are copied
shadow_buffers[name].copy_(buffer)
def forward(self, *args, **kwargs):
if self.training:
return self.model(*args, **kwargs)
else:
return self.shadow(*args, **kwargs)
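# Sketch of the intended EMA life cycle (the training loop pieces `loader`,
# `loss_fn`, `optimizer`, and `model_param` below are placeholders, not
# defined in this file):
#
#     ema = EMA(MiniWorldModule(**model_param), decay=0.99)
#     ema.train()
#     for batch in loader:
#         loss = loss_fn(ema(batch))   # training forwards use the live model
#         loss.backward(); optimizer.step(); optimizer.zero_grad()
#         ema.update()                 # pull shadow weights toward the model
#     ema.eval()                       # eval forwards use the shadow copy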
class ModelVisulizer():
    def __init__(self,
                 saved_model_path,
                 saving_dir,
                 device,
                 model_param={}, input_param={}, batch_size=1, maxcycle=4):
self.saved_model_path = saved_model_path
self.saving_dir = saving_dir
self.device = device
#
self.model_param = model_param
self.input_param = input_param
self.batch_size = batch_size
# for all-atom str loss
self.l2a = long2alt
self.aamask = allatom_mask
self.num_bonds = num_bonds
# from xyz to get xxxx or from xxxx to xyz
self.xyz_converter = XYZConverter()
self.xyz_converter.to(self.device)
self.maxcycle = maxcycle
self.model = EMA(MiniWorldModule(**self.model_param).to(device), 0.99)
# load model
loaded_epoch, best_valid_loss = self.load_model(self.model, self.saved_model_path)
self.model.eval()
def load_model(self, model, saved_model_path):
loaded_epoch = -1
best_valid_loss = 999999.9
if not os.path.exists(saved_model_path):
print ('no model found', saved_model_path)
return -1, best_valid_loss
print ('loading model', saved_model_path)
        map_location = {"cuda:%d"%0: "cuda:%d"%0} # identity map; checkpoints are assumed saved from cuda:0
checkpoint = torch.load(saved_model_path, map_location=map_location)
model.model.load_state_dict(checkpoint['model_state_dict'], strict=False)
model.shadow.load_state_dict(checkpoint['model_state_dict'], strict=False)
return loaded_epoch, best_valid_loss
def _prepare_input(self, inputs, gpu):
"""
Input
{
'msa' : {
'sample_seq': torch.tensor, # (Batch, MAXCYCLE, params['CROP'])
'sample_msa_species': torch.tensor, # (Batch, MAXCYCLE, params['CROP'])
'sample_msa_clust': torch.tensor, # (Batch, MAXCYCLE, params['MAX_SEED_MSA'], params['CROP'])
'sample_msa_seed': torch.tensor, # (Batch, MAXCYCLE, params['MAX_SEED_MSA'], params['CROP'], 23 + 1)
'sample_mask_pos': torch.tensor, # (Batch, MAXCYCLE, params['MAX_SEED_MSA'], params['CROP'])
'chain_to_msa_depth': list of dictionary, # non-batched (chain_idx : msa_depth)
},
'template' : {
'xyz' : torch.tensor # (Batch, params['N_PICK_GLOBAL'], params['CROP'], 27, 3)
'template_1D' : torch.tensor # (Batch, params['N_PICK_GLOBAL'], params['CROP'], NUM_CLASSES + 1)
'template_atom_mask': torch.tensor # (Batch, params['N_PICK_GLOBAL'], params['CROP'], 27)
'use_chain_mask' : bool
},
'crop_idx' : torch.tensor, # (Batch, params['CROP'])
'symmetry_related_info' : list, # non-batched
'chain_break' : list of dictionary, # non-batched, (N_chain, 2)
'ID' : list of string, # non-batched
'source' : list of string, # non-batched
}
"""
sample_seq = inputs['msa']['sample_seq'].float()
sample_msa_species = inputs['msa']['sample_msa_species'].float()
sample_msa_cluster = inputs['msa']['sample_msa_clust'].float()
sample_msa_seed = inputs['msa']['sample_msa_seed'].float()
sample_mask_pos = inputs['msa']['sample_mask_pos'].float()
        sample_zero_pos = (sample_msa_cluster == UNK_IDX) | sample_mask_pos.bool() # (B, MAX_CYCLE, MAX_SEED_MSA, CROP); parentheses required since `|` binds tighter than `==`
# transfer inputs to device
B, _, N, L = sample_msa_cluster.shape # _ for MAX_CYCLE
if 'template' in inputs and inputs['template'] is not None:
template_xyz = inputs['template']['xyz']
template_1D = inputs['template']['template_1D']
template_atom_mask = inputs['template']['template_atom_mask']
use_chain_mask_at_template = inputs['template']['use_chain_mask']
else :
template_xyz = INIT_CRDS.reshape(1, 1,1,27,3).repeat(sample_seq.shape[0], self.input_param['N_PICK_GLOBAL'], L ,1,1)
template_atom_mask = torch.zeros((sample_seq.shape[0], self.input_param['N_PICK_GLOBAL'], L, 27))
template_1D = torch.zeros((sample_seq.shape[0], self.input_param['N_PICK_GLOBAL'], L, NUM_CLASSES + 1))
template_1D[:,:,:,GAP_IDX] = 1.0
use_chain_mask_at_template = True
crop_idx = inputs['crop_idx']
chain_break_list = inputs['chain_break']
ID_list = inputs['ID']
ignore_interchain_list = []
interaction_type_list = []
data_source_list = []
xyz_prev = inputs['prev']['xyz']
mask_prev = inputs['prev']['atom_mask']
        if not self.input_param['USE_MULTIPLE_MODELS'] :
            if len(mask_prev.shape) == 4 :
                # xyz_prev : (B, M, L, 27, 3)
                # mask_prev : (B, M, L, 27)
                xyz_prev = xyz_prev[:,0,:,:,:]
                mask_prev = mask_prev[:,0,:,:]
crop_idx = crop_idx.float().to(gpu, non_blocking=True) # (B, L)
# generate chain_mask (B, L, L)
chain_mask = torch.zeros((B, L, L), dtype=torch.bool)
model_input_idx = crop_idx.clone()[:,:L]
chain_break_idx_added = 100 # TODO : Check
idx_to_chain_list = []
for bb, chain_break in enumerate(chain_break_list) :
inner_crop_idx = crop_idx[bb].to(torch.int).tolist()
crop_idx_set = set(inner_crop_idx)
temp_chain_break = {key : np.arange(value[0],value[1]+1) for key, value in chain_break.items()}
idx_to_chain_dict = {}
for ii, (key, value) in enumerate(temp_chain_break.items()):
for idx in value:
idx_to_chain_dict[idx] = key
intersection_list = crop_idx_set.intersection(set(value))
idxs = [inner_crop_idx.index(intersection) for intersection in intersection_list]
idxs = torch.tensor(idxs, dtype=torch.long) # (N_chain, )
chain_mask[bb, idxs[:, None], idxs] = True
if ii == 0 or len(intersection_list) == 0 : continue
model_input_idx[bb, idxs[0]:] += chain_break_idx_added
idx_to_chain_list.append(idx_to_chain_dict)
chain_mask = chain_mask.to(gpu, non_blocking=True)
template_xyz = template_xyz.to(gpu, non_blocking=True)
template_1D = template_1D.to(gpu, non_blocking=True)
template_atom_mask = template_atom_mask.to(gpu, non_blocking=True)
xyz_prev = xyz_prev.float().to(gpu, non_blocking=True)
mask_prev = mask_prev.float().to(gpu, non_blocking=True)
sample_seq = sample_seq.long().to(gpu, non_blocking=True)
sample_msa_species = sample_msa_species.long().to(gpu, non_blocking=True)
sample_msa_cluster = sample_msa_cluster.float().to(gpu, non_blocking=True)
sample_msa_seed = sample_msa_seed.float().to(gpu, non_blocking=True)
sample_zero_pos = sample_zero_pos.bool().to(gpu, non_blocking=True)
sample_mask_pos = sample_mask_pos.float().to(gpu, non_blocking=True)
        # processing template features
template_mask_2D = template_atom_mask[:,:,:,:3].all(dim=-1) # (B, T, L)
template_mask_2D = template_mask_2D[:,:,None]*template_mask_2D[:,:,:,None] # (B, T, L, L)
# (ignore inter-chain region)
try :
if isinstance(use_chain_mask_at_template, torch.Tensor):
use_chain_mask_at_template = use_chain_mask_at_template.float().to(gpu, non_blocking=True)
template_mask_2D = template_mask_2D.float() * use_chain_mask_at_template[:, None, :, :] # (B, T, L, L)
elif use_chain_mask_at_template is True:
template_mask_2D = template_mask_2D.float() * chain_mask.float()[:, None, :, :] # (B, T, L, L)
except Exception as e:
print(f"Error \n{e}")
print(f"query length : {L}")
print(f"sample_seq.shape : {sample_seq.shape}")
print(f"sample_msa_cluster.shape : {sample_msa_cluster.shape}")
print(f"ID_list : {ID_list}")
            raise # re-raise after logging instead of asserting
# mask_t_2d = mask_t_2d.float() * same_chain.float()[:,None]
template_2D = xyz_to_t2d(template_xyz, template_mask_2D) # I don't change this part. (rename) PSK
seq_tmp = template_1D[...,:-1].argmax(dim=-1).reshape(-1,L)
alpha, _, alpha_mask, _ = self.xyz_converter.get_torsions(template_xyz.reshape(-1,L,27,3), seq_tmp, mask_in=template_atom_mask.reshape(-1,L,27))
alpha = alpha.reshape(B,-1,L,10,2)
alpha_mask = alpha_mask.reshape(B,-1,L,10,1)
template_alpha = torch.cat((alpha, alpha_mask), dim=-1).reshape(B, -1, L, 30)
network_input = {}
network_input['msa_latent'] = sample_msa_seed
network_input['msa_species'] = sample_msa_species
network_input['msa_zero_pos'] = sample_zero_pos
network_input['query_sequence'] = sample_seq
# 20230911 PSK
network_input['idx'] = model_input_idx
network_input['template_1D'] = template_1D
network_input['template_2D'] = template_2D
network_input['template_xyz'] = template_xyz[:,:,:,1]
network_input['template_alpha'] = template_alpha
network_input['mask_t'] = template_mask_2D
network_input['chain_mask'] = chain_mask.float()
network_input['interaction_points'] = inputs['interaction_points'].long().to(gpu, non_blocking=True)
mask_recycle = mask_prev[:,:,:3].bool().all(dim=-1)
mask_recycle = mask_recycle[:,:,None]*mask_recycle[:,None,:] # (B, L, L)
# mask_recycle = chain_mask.float()*mask_recycle.float()
return network_input, xyz_prev, mask_recycle, sample_msa_cluster, ID_list, idx_to_chain_list, crop_idx
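    # _prepare_input returns, in order: the network_input dict consumed by
    # _get_model_input, the previous-structure coordinates, the recycle mask,
    # the sampled MSA clusters, and bookkeeping (IDs, idx->chain maps, crop
    # indices).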
def _get_model_input(self, network_input, output_i, mask_recycle, i_cycle, return_raw=False, use_checkpoint=False):
# I changed name of keys. PSK
input_i = {}
for key in network_input:
if key in ['msa_latent', 'msa_species', 'msa_zero_pos', 'query_sequence']:
input_i[key] = network_input[key][:,i_cycle]
else:
input_i[key] = network_input[key]
logits, logits_aa, logits_exp, logits_pae, p_bind, xyz_prev, alpha, symmsub, lddt, msa_prev, pair_prev, state_prev = output_i
if len(xyz_prev.shape) == 5:
xyz_prev = xyz_prev[-1]
input_i['msa_prev'] = msa_prev
input_i['pair_prev'] = pair_prev[-1].unsqueeze(0) if pair_prev is not None else None
input_i['state_prev'] = state_prev
input_i['xyz'] = xyz_prev
input_i['mask_recycle'] = mask_recycle
input_i['return_raw'] = return_raw
input_i['use_checkpoint'] = use_checkpoint
return input_i
def draw_distogram(self, logit_s, saving_path):
PARAMS = {
"DMIN" : 2.0,
"DMAX" : 20.0,
"DBINS" : 36,
"ABINS" : 36,
}
        # logit_s[0] : (I,1,37,L,L)
        distogram = logit_s[0].squeeze(1) # (I,37,L,L)
        distogram = torch.nn.functional.softmax(distogram, dim=1) # logits -> per-bin probabilities, so the sum below is an expectation
        nbins = PARAMS['DBINS'] + 1
        distance = torch.linspace(PARAMS['DMIN'], PARAMS['DMAX'], nbins).unsqueeze(0).to(distogram.device) # (1, DBINS+1)
        avg_distogram = torch.sum(distogram * distance.unsqueeze(-1).unsqueeze(-1), dim=1) # (I, L, L)
avg_distogram = avg_distogram.cpu().numpy()
# Distogram has I number of distograms, so draw I images I columns
        I = avg_distogram.shape[0]
        fig, axs = plt.subplots(1, I, figsize=(I*5, 5), squeeze=False) # squeeze=False keeps axs 2D even when I == 1
        for i in range(I):
            axs[0, i].imshow(avg_distogram[i], cmap='hot', interpolation='nearest')
            axs[0, i].set_title(f"Distogram {i}")
plt.savefig(saving_path)
plt.close()
def write_pae_plddt(self, save_path, ID, pae, plddt):
with open(save_path, "a") as f:
f.write(f"{ID} {pae} {plddt}\n")
def model_inference(self, inputs, device, saving_dir = None, saving_intermediate = False, only_structure = True):
# move some global data to cuda device
self.l2a = self.l2a.to(device)
self.aamask = self.aamask.to(device)
self.xyz_converter = self.xyz_converter.to(device)
self.num_bonds = self.num_bonds.to(device)
initial_crds_list = []
pred_crds_list = []
sequence_list = []
ID_list = []
crop_idx_list = []
idx_to_chain_list = []
with torch.no_grad():
network_input, xyz_prev, mask_recycle, msa, inner_ID_list, inner_idx_to_chain_list, crop_idx = self._prepare_input(inputs, device)
N_cycle = self.maxcycle # number of recycling
query_sequence = network_input['query_sequence'][0]
ID = inner_ID_list[0]
ID_list.append(ID)
sequence_list.append(query_sequence)
output_i = (None, None, None, None, None, xyz_prev, None, None, None, None, None, None)
for i_cycle in range(N_cycle):
                with ExitStack() as stack:
                    # the whole pass already runs under torch.no_grad(); the two
                    # branches only matter in training, where early cycles are
                    # detached from the autograd graph
                    if i_cycle < N_cycle - 1:
                        stack.enter_context(torch.no_grad())
                        stack.enter_context(torch.cuda.amp.autocast(enabled=USE_AMP))
                        return_raw = False
                        use_checkpoint = False
                    else:
                        stack.enter_context(torch.cuda.amp.autocast(enabled=USE_AMP))
                        return_raw = False
                        use_checkpoint = False
input_i = self._get_model_input(network_input, output_i, mask_recycle, i_cycle, return_raw=return_raw, use_checkpoint=use_checkpoint)
output_i = self.model(**input_i)
B = network_input['query_sequence'].shape[0]
logit_s, logit_aa_s, logit_exp, logit_pae, p_bind, pred_crds, alphas, symmsubs, pred_lddts, _, pair, state = output_i
# print(f"xyz_prev : {xyz_prev}")
# print(f"pred_crds[-1,0,:,:14,:] : {pred_crds[0,0,:,:14,:]}")
pae = torch.nn.functional.softmax(logit_pae, dim=1) # (B, 64, L, L)
nbin = pae.shape[1]
bin_value = 0.5 * torch.arange(nbin, dtype=torch.float32, device=pae.device)
expected_error = torch.einsum('bnij, n -> bij', pae, bin_value) # (B, L, L)
nbin = pred_lddts.shape[1]
bin_step = 1.0 / nbin
lddt_bins = torch.linspace(bin_step, 1.0, nbin, dtype=pred_lddts.dtype, device=pred_lddts.device)
plddt = torch.nn.functional.softmax(pred_lddts, dim=1) # (B, 64, L)
expected_lddt = torch.einsum('bni, n -> bi', plddt, lddt_bins) # (B, L)
average_pae = torch.mean(expected_error, dim=(1, 2)) # (B, )
average_lddt = torch.mean(expected_lddt, dim=1) # (B, )
average_pae = average_pae.item()
average_lddt = average_lddt.item()
B, _, N, L = msa.shape
logit_aa_s = torch.nn.functional.pad(logit_aa_s,(0,0,0,1),value=0) # (B, 22, CROP*CROP)
# And logit_aa_s 20 <-> 21
logit_aa_s[:,21] = logit_aa_s[:,20]
# logit_aa_s[:,21] = -99999
logit_aa_s[:,20] = 0
msa_in = msa.long()[0,i_cycle] # (N, L)
if not only_structure:
visualize_2D_heatmap(msa_in, file_name = ID + "_msa", heatmap_dir = saving_dir + "msa/")
self.write_pae_plddt(saving_dir + "pae_plddt.txt", ID, average_pae, average_lddt)
I = pred_crds.shape[0]
if saving_intermediate :
for block_idx in range(I):
predRs_all, pred_all = self.xyz_converter.compute_all_atom(query_sequence[0:1].to(torch.int), pred_crds[block_idx], alphas[block_idx])
self.write_pdb([pred_all[0,:,:14,:]], [query_sequence[0]], [ID], expected_lddt, inner_idx_to_chain_list, crop_idx, File_suffix="pred_" +str(i_cycle)+"_"+str(block_idx), saving_dir = saving_dir)
else :
if i_cycle == N_cycle - 1:
predRs_all, pred_all = self.xyz_converter.compute_all_atom(query_sequence[0:1].to(torch.int), pred_crds[-1], alphas[-1])
# self.draw_distogram(logit_s, saving_dir + str(ID) + "_distogram_" + str(i_cycle) + ".png")
self.write_pdb([pred_all[0,:,:14,:]], [query_sequence[0]], [ID], expected_lddt, inner_idx_to_chain_list, crop_idx, File_suffix="pred_" +str(i_cycle), saving_dir = saving_dir)
# draw_pae(logit_pae = value, mask= None, saving_path = saving_dir + "pae/" + str(ID) + "_pae_from_value_" + str(i_cycle) + ".png")
if i_cycle == N_cycle - 1:
if not only_structure:
if not os.path.exists(saving_dir + "pae/"):
os.makedirs(saving_dir + "pae/")
draw_pae(logit_pae = logit_pae, mask= None, saving_path = saving_dir + "pae/" + str(ID) + "_pae_from_logit_pae_" + str(i_cycle) + ".png")
initial_crds_list.append(xyz_prev[0,:,:14,:]) # (256,14,3)
pred_crds_list.append(pred_all[0,:,:14,:]) # (256,14,3)
idx_to_chain_list.append(inner_idx_to_chain_list[0])
crop_idx_list.append(crop_idx[0])
torch.cuda.empty_cache()
return pred_crds_list, sequence_list, ID_list
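    # model_inference returns per-target lists: predicted coordinates for the
    # first 14 atoms of each residue, query sequences, and IDs; PDB files and
    # optional diagnostics are written to saving_dir as a side effect.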
def ATOM_REPR(self, atom):
atom = atom.strip()
        # remove empty space
atom = atom.replace(' ', '')
return atom[0]
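    # e.g. self.ATOM_REPR(" CA ") -> "C": a first-letter element guess, which
    # is sufficient for the standard amino-acid atoms (C/N/O/S) written by the
    # PDB writer below.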
def write_pdb(self, pred_crds_list, seq_list, ID_list, lddt_list, idx_to_chain_list, crop_idx, File_suffix, saving_dir = None):
saving_dir = self.saving_dir + "pdb/" if saving_dir is None else saving_dir
if not os.path.exists(saving_dir):
os.makedirs(saving_dir)
        for xyz, seq, ID, idx_to_chain, crop_idx_i, lddt in zip(pred_crds_list, seq_list, ID_list, idx_to_chain_list, crop_idx, lddt_list):
            crop_idx_i = crop_idx_i.tolist() # renamed to avoid shadowing the crop_idx argument
lines = []
atom_idx = 1
residue_idx = 1
before_chain = ""
start = True
for ll in range(xyz.shape[0]):
                idx = crop_idx_i[ll]
if idx == -1 : continue
chain = idx_to_chain[idx]
if chain != before_chain:
if start :
start = False
else :
lines.append("TER\n")
residue_idx = 1
before_chain = chain
xyz_ll = xyz[ll] # (14,3)
seq_ll = int(seq[ll].item())
residue = num2aa[seq_ll]
atom_list = aa2long[seq_ll][:14]
lddt_i = lddt[ll].item()
lddt_i = round(lddt_i*100, 2) # 0.00 ~ 100.00
# if lddt_i = 16.6 -> 16.60
lddt_i = "{:.2f}".format(lddt_i)
# remove None
atom_list = [item for item in atom_list if item is not None]
for aa, atom in enumerate(atom_list):
atom_xyz = xyz_ll[aa]
atom_xyz = atom_xyz.cpu().numpy()
atom_xyz = atom_xyz.tolist()
atom_xyz = [round(item,3) for item in atom_xyz]
atom_xyz = [str(item) for item in atom_xyz]
atom_x = atom_xyz[0]
atom_y = atom_xyz[1]
atom_z = atom_xyz[2]
if float(atom_x) == 0 and float(atom_y) == 0 and float(atom_z) == 0:
continue
one_atom = self.ATOM_REPR(atom)
atom_line = f"ATOM {atom_idx:>5} {atom:>4} {residue:>3}{chain:>2}{residue_idx:>4} {atom_x:>8}{atom_y:>8}{atom_z:>8} 1.00 {lddt_i:>3} {one_atom:>2}\n"
lines.append(atom_line)
atom_idx += 1
residue_idx += 1
saving_path = saving_dir + "/" + ID + '_{}.pdb'.format(File_suffix)
with open(saving_path, 'w') as f:
f.writelines(lines)
print(f"Saving {ID} done.")
def get_msa(self, a3m_path):
msa = ProteinMSA(a3m_file_path=a3m_path)
return msa
def get_fasta(self, fasta_path):
# Code from RF2 parsers.py
"""
Input : a3m_file_path (.a3m or .a3m.gz)
"""
msa = []
table = str.maketrans(dict.fromkeys(string.ascii_lowercase))
#print(filename)
if fasta_path.split('.')[-1] == 'gz':
f = gzip.open(fasta_path, 'rt')
else:
f = open(fasta_path, 'r')
# read file line by line
chain_break = {}
for line in f:
if line[0] == '#': # skip # line
continue
# skip labels
if line[0] == '>':
continue
if len(chain_break) == 0 and ':' in line:
chains = line.strip().split(':')
chain_start = 0
for chain_idx, chain in enumerate(chains):
chain_end = chain_start + len(chain) - 1
chain_break[chain_idx] = (chain_start, chain_end)
chain_start = chain_end + 1
self.chain_break = chain_break
# remove :
line = line.replace(':','')
# remove right whitespaces
line = line.rstrip()
if len(line) == 0:
continue
# remove lowercase letters and append to MSA
seq_i = line.translate(table)
seq_i = seq_i.split('/')
if (len(msa)==0):
# first seq
Ls = [len(x) for x in seq_i]
msa = [[s] for s in seq_i]
else:
nchains = len(seq_i)
isgood = all([ len(msa[i][0]) == len(seq_i[i]) for i in range(nchains) ])
if isgood:
for i in range(nchains):
msa[i].append(seq_i[i])
else:
raise ValueError("Len error", fasta_path, len(msa[0]) )
# sequence length
L = sum(Ls)
break
if len(chain_break) == 0:
chain_break = {0 : (0, L-1)}
# concatenate
msa = [np.array([list(s) for s in t], dtype='|S1').view(np.uint8) for t in msa]
msa = np.concatenate(msa,axis=-1)
# convert letters into numbers
alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-"), dtype='|S1').view(np.uint8)
for i in range(alphabet.shape[0]):
msa[msa == alphabet[i]] = i
msa[msa>20] = 20
query_sequence = msa[0]
query_sequence = torch.tensor(query_sequence) # (L, ) torch.Tensor
chain_break = {str(key) : value for key, value in chain_break.items()}
return query_sequence, chain_break
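    # Usage sketch ("complex.fasta" is a hypothetical multi-chain FASTA whose
    # sequence line uses ':' between chains):
    #
    #     query_sequence, chain_break = self.get_fasta("complex.fasta")
    #     # query_sequence : (L,) integer tensor, residues 0-19, gap/unknown -> 20
    #     # chain_break    : {"0": (0, L0-1), "1": (L0, L0+L1-1), ...}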
def load_initial_xyz(self, initial_xyz_path, is_protein=True):
if initial_xyz_path is None :
return None, None
if '.pdb' in initial_xyz_path :
protein = PDB_parsing(initial_xyz_path)
initial_xyz = protein.structure.xyz
initial_atom_mask = protein.structure.atom_mask
elif is_protein :
with open(initial_xyz_path, 'rb') as f:
protein = pickle.load(f)
initial_xyz = protein.structure.xyz
initial_atom_mask = protein.structure.atom_mask
else :
initial = torch.load(initial_xyz_path)
initial_xyz = initial['xyz']
initial_atom_mask = initial['atom_mask']
return initial_xyz, initial_atom_mask
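    # Accepts a .pdb file, a pickled ProteinClass-style object (protein case),
    # or a torch-saved dict with 'xyz'/'atom_mask' keys; returns (None, None)
    # when no initial structure path is supplied.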
def symmetric_a3m_to_full_a3m(self, a3m_file_path, fasta_file_path, saving_path):
with open(a3m_file_path, "r") as f:
lines = f.readlines()
first_line = lines[0].strip()
        input_seq = lines[2].strip() # lines[0] is the '#...' size header, lines[1] the '>' label, lines[2] the query sequence
parsed_msa = a3m_parse(a3m_file_path)
# remove #
first_line = first_line[1:]
seq_len_list, n_list = first_line.split("\t")
seq_len_list = seq_len_list.split(",")
n_list = n_list.split(",")
chain_to_seq_len = {ii: int(n) for ii, n in enumerate(seq_len_list)}
chain_to_n = {ii: n for ii, n in enumerate( n_list)}
chain_break = {}
seq_to_chain = {}
with open(fasta_file_path, "r") as f:
fasta_lines = f.readlines()
full_input_seq = fasta_lines[1].strip().split(":")
split = 0
for idx in chain_to_n.keys():
chain_start, chain_end = split, split + int(chain_to_seq_len[idx]) -1
chain_break[idx] = (chain_start, chain_end)
split += int(chain_to_seq_len[idx])
seq_to_chain[input_seq[chain_start:chain_end+1]] = idx
chain_to_msa = {}
for chain in chain_to_n.keys():
chain_to_msa[chain] = []
for line in parsed_msa:
line = line.strip()
for chain in chain_break.keys():
chain_start, chain_end = chain_break[chain]
chain_to_msa[chain].append(line[chain_start:chain_end + 1])
new_lines = []
add_idx = 0
for line_idx in range(len(lines)):
if line_idx == 0:
new_lines.append(lines[line_idx])
continue
line = lines[line_idx]
if line.startswith(">"):
new_lines.append(line)
else :
new_line = ""
for seq in full_input_seq:
chain = seq_to_chain[seq]
new_line += chain_to_msa[chain][add_idx]
new_line += "\n"
new_lines.append(new_line)
add_idx += 1
with open(saving_path, "w") as f:
for line in new_lines:
f.write(line)
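    # Sketch of the transformation (hypothetical paired a3m over unique chains
    # A,B and a fasta listing A:B:A:B): each aligned row "ab" over the unique
    # chain set is tiled to "abab" so every copy in the assembly is covered.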
def symmetric_a3m_to_diagonal_a3m(self, a3m_file_path, fasta_file_path, saving_path):
with open(a3m_file_path, "r") as f:
lines = f.readlines()
first_line = lines[0].strip()
input_seq = lines[2].strip()
parsed_msa = a3m_parse(a3m_file_path)
# remove #
first_line = first_line[1:]
seq_len_list, n_list = first_line.split("\t")
seq_len_list = seq_len_list.split(",")
n_list = n_list.split(",")
chain_to_seq_len = {ii: int(n) for ii, n in enumerate(seq_len_list)}
chain_to_n = {ii: n for ii, n in enumerate(n_list)}
chain_break = {}
seq_to_chain = {}
with open(fasta_file_path, "r") as f:
fasta_lines = f.readlines()
full_input_seq = fasta_lines[1].strip().split(":")
total_len = sum([len(seq) for seq in full_input_seq])
split = 0
for idx in chain_to_n.keys():
chain_start, chain_end = split, split + int(chain_to_seq_len[idx]) -1
chain_break[idx] = (chain_start, chain_end)
split += int(chain_to_seq_len[idx])
seq_to_chain[input_seq[chain_start:chain_end+1]] = idx
chain_to_msa = {}
for chain in chain_to_n.keys():
chain_to_msa[chain] = []
for line in parsed_msa:
line = line.strip()
for chain in chain_break.keys():
chain_start, chain_end = chain_break[chain]
chain_to_msa[chain].append(line[chain_start:chain_end + 1])
output = {}
current_index = 0
for seq in full_input_seq:
# Calculate positions for the current sequence
indices = list(range(current_index, current_index + len(seq)))
# Check if sequence already has entries in the dictionary
if seq in output:
# Append new indices
output[seq].extend(indices)
else:
# Start new entry in dictionary with current indices
output[seq] = indices
# Update current index for the next sequence
current_index += len(seq)
new_lines = []
add_idx = 0
new_lines.append(">query\n")
new_lines.append("".join(full_input_seq)+"\n")
chain_start = 0
for seq, idxs in output.items():
chain = seq_to_chain[seq]
chain_end = chain_start + chain_to_seq_len[chain] - 1
for chain_msa_seq in chain_to_msa[chain]:
                chain_idxs = deepcopy(idxs) # the `copy` module itself is not imported; use the imported deepcopy
rest = len(chain_idxs)
new_seq = ""
while rest > 0 :
new_seq += "-" * (chain_idxs[0] - len(new_seq))
new_seq += chain_msa_seq
chain_idxs = chain_idxs[len(chain_msa_seq):]
rest = len(chain_idxs)
new_seq += "-" * (total_len - len(new_seq))
new_lines.append(">\n")
new_lines.append(new_seq+"\n")
chain_start = chain_end + 1
with open(saving_path, "w") as f:
for line in new_lines:
f.write(line)
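    # Sketch of the "diagonal" layout this writes (hypothetical two-chain
    # complex A:B): each chain's rows are padded with '-' so they cover only
    # that chain's columns, e.g. "AAAA----" for chain A and "----BBBB" for
    # chain B, with the full query sequence written first.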
@torch.no_grad()
def inference_from_raw_data(self, raw_file_dict, template_mode = 'use_hhr', a3m_mode='use_raw_a3m', saving_dir = None, saving_intermediate = False, only_structure = True, symmetry = None, msa_save_path = None,):
ID = raw_file_dict['ID']
fasta_path = raw_file_dict['fasta']
if a3m_mode == "single_sequence" or 'a3m' not in raw_file_dict.keys():
a3m_path = raw_file_dict['fasta']
elif a3m_mode == "use_colabfold_paired_a3m":
a3m_path = raw_file_dict['a3m']
self.symmetric_a3m_to_full_a3m(a3m_path, fasta_path, self.saving_dir + "temp.a3m")
a3m_path = self.saving_dir + "temp.a3m"
elif a3m_mode == "use_colabfold_diagonal_a3m":
a3m_path = raw_file_dict['a3m']
self.symmetric_a3m_to_diagonal_a3m(a3m_path, fasta_path, self.saving_dir + "temp.a3m")
a3m_path = self.saving_dir + "temp.a3m"
elif a3m_mode == "use_raw_a3m":
a3m_path = raw_file_dict['a3m']
# Load MSA
try:
msa = self.get_msa(a3m_path)
query_sequence, chain_break = self.get_fasta(fasta_path)
except Exception as e:
# TODO There can be no msa or template because the chain is too short. We should think about this case.
print(traceback.format_exc())
print(e, end = " | ")
return f"Error : {e}"
# Load templates
query_length = query_sequence.shape[0]
template_path_dict = raw_file_dict['templates'] if 'templates' in raw_file_dict.keys() else None # {chain_idx : template_path}
n_pick_global = self.input_param['N_PICK_GLOBAL']
template_featurized_dict = {
# 'xyz' : torch.zeros((n_pick_global,query_length,27,3)), # (N_template, L_query, 27, 3)
'xyz' : INIT_CRDS.reshape(1,1,1,27,3).repeat(1,n_pick_global, query_length ,1,1) , # (B, N_template,L, 27, 3)
'template_1D' : torch.full((1,n_pick_global,query_length,NUM_CLASSES + 1), GAP_IDX).float(), # (B, N_template, L_query, NUM_CLASSES + 1)
'template_atom_mask' : torch.full((1,n_pick_global,query_length,27), False), # (B, N_template, L_query, 27)
'use_chain_mask' : True,
}
if template_path_dict is not None :
if template_mode == 'use_hhr':
for chain_idx, template_path in template_path_dict.items():
chain_start, chain_end = chain_break[chain_idx]
qlen = chain_end - chain_start + 1
hhr_path = template_path['hhr']
atab_path = template_path['atab']
                    xyz, f1d, mask_t = read_templates(qlen, ffdb, hhr_path, atab_path, offset=0, n_templ=n_pick_global, random_noise=5.0) # ffdb is the module-level FFindexDB built in __main__
template_featurized_dict['xyz'][:, :, chain_start:(chain_end+1), :, :] = xyz
template_featurized_dict['template_1D'][:, :, chain_start:(chain_end+1), :] = f1d
template_featurized_dict['template_atom_mask'][:, :, chain_start:(chain_end+1), :] = mask_t
template_featurized_dict['use_chain_mask'] = True
elif template_mode == 'use_complex_pdb':
# already aligned
pdb_path_list = template_path_dict['pdb_path_list']
N_template = len(pdb_path_list)
seq_identity = template_path_dict['seq_identity'] if 'seq_identity' in template_path_dict.keys() else None # (N_template, L)
if seq_identity is None:
seq_identity = torch.ones((N_template, query_length))
for ii, pdb_path in enumerate(pdb_path_list):
protein = PDB_parsing(pdb_path)
xyz = protein.structure.xyz[0] # (L, 27, 3)
seq = protein.sequence.sequence[0] # (L, )
seq_onehot = nn.functional.one_hot(seq, num_classes=NUM_CLASSES).float() # (L, NUM_CLASSES)
f1d = torch.concat([seq_onehot, seq_identity[ii].unsqueeze(-1)], dim=-1) # (L, NUM_CLASSES + 1)
mask_t = protein.structure.atom_mask # (L, 27)
template_featurized_dict['xyz'][:, ii, :, :14, :] = xyz.unsqueeze(0)
template_featurized_dict['template_1D'][:, ii, :, :] = f1d.unsqueeze(0)
template_featurized_dict['template_atom_mask'][:, ii, :, :14] = mask_t.unsqueeze(0)
template_featurized_dict['use_chain_mask'] = False
elif template_mode == 'use_pdb_per_chain':
# use only one template
template_chain_mask = torch.zeros((1, query_length, query_length), dtype=torch.bool)
seq_alignment_score = template_path_dict['seq_alignment_score'] if 'seq_alignment_score' in template_path_dict.keys() else 0.9
cif_path_dict = template_path_dict['cif_path_dict']
template_dict = {}
for ID, cif_path in cif_path_dict.items():
template_protein = mmcif_parsing(cif_path)
template_xyz = template_protein.structure.xyz[0] # (L_template, 27, 3)
template_seq = template_protein.sequence.sequence[0] # (L_template, )
template_atom_mask = template_protein.structure.atom_mask[0] # (L_template, 27)
template_chain_break = template_protein.structure.chain_break # {chain_idx : (start, end)}
template_dict[ID] = {
'xyz' : template_xyz,
'seq' : template_seq,
'atom_mask' : template_atom_mask,
'chain_break' : template_chain_break,
}
per_chain_dict = template_path_dict['per_chain_dict']
valid_idx_group = {}
before_chain_idx = -1
for chain_idx, item in per_chain_dict.items():
if 'template' not in item.keys():
# AF2 prediction (for F chain)
chain_start, chain_end = chain_break[chain_idx]
pdb_path = item['pdb_path']
template_protein, lddt = PDB_parsing(pdb_path, return_lddt=True)
lddt = lddt / 100 # for test
template_xyz = template_protein.structure.xyz[0] # (L_template, 27, 3)
template_seq = template_protein.sequence.sequence[0] # (L_template, )
template_atom_mask = template_protein.structure.atom_mask[0] # (L_template, 27)
template_seq_onehot = nn.functional.one_hot(template_seq, num_classes=NUM_CLASSES).float() # (L_template, NUM_CLASSES)
template_featurized_dict['xyz'][:, 0, chain_start:chain_end+1, :14, :] = template_xyz
template_featurized_dict['template_1D'][:, 0, chain_start:chain_end+1, :NUM_CLASSES] = template_seq_onehot
template_featurized_dict['template_1D'][:, 0, chain_start:chain_end+1, NUM_CLASSES] = lddt
template_featurized_dict['template_atom_mask'][:, 0, chain_start:chain_end+1, :14] = template_atom_mask
chain_valid_idx = torch.arange(chain_start, chain_end+1)
                        valid_idx_grid_x, valid_idx_grid_y = torch.meshgrid(chain_valid_idx, chain_valid_idx, indexing='ij')
template_chain_mask[0, valid_idx_grid_x, valid_idx_grid_y] = True
continue
template_PDB_ID, template_chain_idx = item['template'].split("_")
kalign_path = item['kalign_result_path']
                    try:
                        chain_start, chain_end = chain_break[chain_idx]
                    except KeyError:
                        raise KeyError(f"chain {chain_idx} not found in chain_break keys {list(chain_break.keys())}")
template_xyz = template_dict[template_PDB_ID]['xyz']
template_seq = template_dict[template_PDB_ID]['seq']
template_atom_mask = template_dict[template_PDB_ID]['atom_mask']
template_chain_break = template_dict[template_PDB_ID]['chain_break']
template_chain_start, template_chain_end = template_chain_break[template_chain_idx]
if valid_idx_group.get(template_PDB_ID) is None:
valid_idx_group[template_PDB_ID] = []
valid_idx_group[template_PDB_ID].append(torch.arange(chain_start, chain_end+1))
query_map_idx, target_map_idx = kalign_mapping(kalign_path)
query_map_idx = query_map_idx + before_chain_idx + 1
before_chain_idx = chain_end
template_seq_alignment_score = torch.full_like(query_map_idx, seq_alignment_score).unsqueeze(0).float()
template_chain_xyz = template_xyz[template_chain_start:template_chain_end+1] # (L_chain, 27, 3)
template_chain_seq = template_seq[template_chain_start:template_chain_end+1] # (L_chain, )
template_chain_atom_mask = template_atom_mask[template_chain_start:template_chain_end+1] # (L_chain, 27)
template_chain_seq_onehot = nn.functional.one_hot(template_chain_seq, num_classes=NUM_CLASSES).float() # (L_chain, NUM_CLASSES)
template_featurized_dict['xyz'][:, 0, query_map_idx, :14, :] = template_chain_xyz[target_map_idx]
template_featurized_dict['template_1D'][:, 0, query_map_idx, :NUM_CLASSES] = template_chain_seq_onehot[target_map_idx].unsqueeze(0)
template_featurized_dict['template_1D'][:, 0, query_map_idx, NUM_CLASSES] = template_seq_alignment_score
template_featurized_dict['template_atom_mask'][:, 0, query_map_idx, :14] = template_chain_atom_mask[target_map_idx].bool()
for template_PDB_ID, valid_idx in valid_idx_group.items():
valid_idx = torch.cat(valid_idx) # (L,)
valid_idx_grid_x, valid_idx_grid_y = torch.meshgrid(valid_idx, valid_idx, indexing='ij')
template_chain_mask[0, valid_idx_grid_x, valid_idx_grid_y] = True
template_featurized_dict['use_chain_mask'] = template_chain_mask
elif template_mode == 'use_complex_prediction_pdb':
# already aligned
pdb_path_list = template_path_dict['pdb_path_list']
print(f'pdb_path_list : {pdb_path_list}')
                for ii, pdb_path in enumerate(pdb_path_list):
                    if ii >= n_pick_global: # stop before writing past the n_pick_global template slots
                        break
protein, lddt = PDB_parsing(pdb_path, return_lddt=True)
lddt = lddt[0].unsqueeze(1)/100 # (L, 1)
xyz = protein.structure.xyz[0]
seq = protein.sequence.sequence[0]
seq_onehot = nn.functional.one_hot(seq, num_classes=NUM_CLASSES).float() # (L, NUM_CLASSES)
f1d = torch.concat([seq_onehot, lddt], dim=-1) # (L, NUM_CLASSES + 1)
mask_t = protein.structure.atom_mask[0] # (L, 27)
template_featurized_dict['xyz'][0, ii, :, :14, :] = xyz
template_featurized_dict['template_1D'][0, ii, :, :] = f1d
template_featurized_dict['template_atom_mask'][0, ii, :, :14] = mask_t
template_featurized_dict['use_chain_mask'] = torch.ones((1, query_length, query_length), dtype=torch.bool)
elif template_mode == 'use_hhpred_output':
# aligned by hhpred
# it requires pdb_path for each template
pdb_path_list = template_path_dict['pdb_path_list']
query_ID = template_path_dict['query_ID_used_in_hhpred']
hhpred_hhr_file = template_path_dict['hhpred_hhr_file']
parsed_info_list = hhpred_parser(hhpred_hhr_file, query_ID)
# filter out templates that are not in pdb_path_list
filtered_parsed_info_list = []
pdb_ID_list = []
for pdb_path in pdb_path_list:
pdb_ID_list.append(pdb_path.split("/")[-1].split(".")[0])
ID_to_pdb_and_parsed_info = {}
for parsed_info in parsed_info_list:
for pdb_ID in pdb_ID_list:
if pdb_ID in parsed_info['target_ID']:
ID_to_pdb_and_parsed_info[pdb_ID] = {
'parsed_info' : parsed_info,
'pdb_path' : pdb_path_list[pdb_ID_list.index(pdb_ID)]
}
if len(ID_to_pdb_and_parsed_info) == 0:
print(f"=== WARNING === : No template is found in pdb_path_list")
template_idx = 0
for pdb_ID, item_dict in ID_to_pdb_and_parsed_info.items():
try:
pdb_path = item_dict['pdb_path']
parsed_info = item_dict['parsed_info']
# protein = PDB_parsing(pdb_path)
protein = mmcif_parsing(pdb_path) # TODO symmetry operation can be needed.
template_xyz = protein.structure.xyz[0] # (L, 27, 3)
template_seq = protein.sequence.sequence[0] # (L, )
query_map_idx = parsed_info['query_map_idx'] # (L_match, )
target_map_idx = parsed_info['target_map_idx'] # (L_match, )
confidence = parsed_info['confidence'] # (L_match, )
template_seq_onehot = nn.functional.one_hot(template_seq, num_classes=NUM_CLASSES).float() # (L, NUM_CLASSES)
template_atom_mask = protein.structure.atom_mask[0] # (L, 27)
template_featurized_dict['xyz'][:, template_idx, query_map_idx, :14, :] = template_xyz[target_map_idx]
template_featurized_dict['template_1D'][:, template_idx, query_map_idx, :NUM_CLASSES] = template_seq_onehot[target_map_idx].unsqueeze(0)
template_featurized_dict['template_1D'][:, template_idx, query_map_idx, NUM_CLASSES] = confidence
template_featurized_dict['template_atom_mask'][:, template_idx, query_map_idx, :14] = template_atom_mask[target_map_idx].bool()
                    except Exception as e:
                        print(f"=== WARNING === : template featurization failed for {pdb_ID} : {e}")
                        continue
# Load initial xyz
if 'initial_xyz' in raw_file_dict.keys():
initial_xyz, initial_atom_mask = self.load_initial_xyz(raw_file_dict['initial_xyz'])
else :
initial_xyz = None
initial_atom_mask = None
if symmetry is not None :
# TODO
msa_tensor = msa.msa
self.save_clean_msa(msa_tensor, symmetry, msa_save_path)
# Load interaction points
interaction_dict = raw_file_dict['interaction_points'] if 'interaction_points' in raw_file_dict.keys() else None
interaction_points_input = torch.zeros((1, query_length))
if interaction_dict is not None :
before_chain_idx = -1
positive_dict = interaction_dict['positive']
negative_dict = interaction_dict['negative']
for chain_idx in positive_dict.keys():
pos_idx_list = positive_dict[chain_idx]
neg_idx_list = negative_dict[chain_idx]
for idx in pos_idx_list:
idx += before_chain_idx + 1
interaction_points_input[0, idx] = 1
for idx in neg_idx_list:
idx += before_chain_idx + 1
interaction_points_input[0, idx] = 2
chain_start, chain_end = chain_break[chain_idx]
before_chain_idx = chain_end
# print(f"test chain_break : {chain_break}")
query_length = query_sequence.shape[0]
sequence_idxs = []
for chain_idx, chain_break_tuple in chain_break.items():
chain_start, chain_end = chain_break_tuple
sequence_idxs += list(range(chain_start, chain_end + 1))
sequence_idxs = torch.Tensor(sequence_idxs).long()
out_of_sequence_idxs = torch.Tensor(list(set(range(query_length)) - set(sequence_idxs.tolist()))).long()
crop_idx = sequence_idxs
inputs = {}
msa_feature_dict = MSA_featurize_wo_statistics(msa.msa, msa.insertion, chain_break, self.input_param)
# batchify
for key in msa_feature_dict.keys():
if isinstance(msa_feature_dict[key], torch.Tensor):
msa_feature_dict[key] = msa_feature_dict[key].unsqueeze(0)
else :
msa_feature_dict[key] = [msa_feature_dict[key]]
inputs["msa"] = msa_feature_dict
inputs['template'] = template_featurized_dict
inputs['interaction_points'] = interaction_points_input
# cropping
inputs["crop_idx"] = crop_idx.unsqueeze(0) # (1,L_crop)
inputs["chain_break"] = [chain_break] # (N_chain : 2)
# TODO : prev
# N, Ca, C initial coord is in chemical.py
if initial_xyz is None:
random_xyz, atom_mask = generate_initial_xyz(query_sequence.unsqueeze(0), chain_break)
inputs["prev"] = {
"xyz" : random_xyz.unsqueeze(0), # (1, L_crop, 27, 3)
'atom_mask' : atom_mask.unsqueeze(0), # (1, L_crop, 27)
}
else :
inputs["prev"] = {
"xyz" : initial_xyz.unsqueeze(0), # (1, L_crop, 27, 3)
'atom_mask' : initial_atom_mask.unsqueeze(0), # (1, L_crop, 27)
}
inputs["ID"] = [raw_file_dict['ID']]
inputs["source"] = ["PDB"]
inputs["symmetry_related_info"] = [None]
if saving_dir is None :
saving_dir = self.saving_dir
self.model_inference(inputs, self.device, saving_dir, saving_intermediate = saving_intermediate, only_structure = only_structure)
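    # Usage sketch (paths and the ID below are hypothetical placeholders):
    #
    #     visualizer.inference_from_raw_data(
    #         {'ID': 'complex1', 'fasta': 'complex1.fasta', 'a3m': 'complex1.a3m'},
    #         template_mode='use_hhr', a3m_mode='use_raw_a3m')
    #
    # raw_file_dict may additionally carry 'templates', 'initial_xyz', and
    # 'interaction_points' entries, consumed as shown in the body above.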
def save_clean_msa(self, msa, symmetry, save_path):
# msa : (N,L)
N, L = msa.shape
# TODO : symmetry
new_msa = torch.zeros((N,2*L)).long()
new_msa[:,:L] = msa
new_msa[:,L:] = msa
with open(save_path, 'w') as f:
for n in range(N):
f.write(f">{n}\n")
f.write("".join([num2AA[aa] for aa in new_msa[n].tolist()]) + "\n")
if __name__ == "__main__":
from miniworld.utils import get_args
args, model_param, input_json = get_args()
# I use my own dataloader, so input_params are below. PSK
seed_num = args.seed_num
num_cycle = args.maxcycle
    num_cycle = 5 # NOTE: hard-coded override of args.maxcycle
seed_list = [args.seed + i for i in range(seed_num)]
# ffdb for load template
FFDB = '/public_data/pdb100/pdb100_2020Mar11/pdb100_2020Mar11'
FFindexDB = namedtuple("FFindexDB", "index, data")
ffdb = FFindexDB(read_index(FFDB+'_pdb.ffindex'),
read_data(FFDB+'_pdb.ffdata'))
input_param = {
"N_PICK" : 1,
"N_PICK_GLOBAL" : 1,
"PICK_TOP" : True,
"MIN_SEED_MSA_PER_CHAIN" : 64,
"MAX_SEED_MSA" : 512,
"USE_MASK" : True,
"MAX_MSA_CYCLE" : num_cycle,
"MAX_SEQ" : 1024,
"SEED" : args.seed,
"BLOCKCUT" : 3000,
'USE_MULTIPLE_MODELS' : False,
}
model_param['n_main_block'] = 16
model_param['n_ref_block'] = 0
saved_model_path = "./model_weights/Model_1_5_nblock16_epoch29_use_interaction.pt"
# print usable gpu num
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device :', device)
print('Usable GPU num :', torch.cuda.device_count())
rank = 0
output_dir = args.output_dir
if not os.path.exists(output_dir):
os.makedirs(output_dir)
Visualizer = ModelVisulizer(
saved_model_path=saved_model_path,
saving_dir = output_dir,
device=device,
model_param=model_param, input_param=input_param,
batch_size=args.batch_size,
maxcycle=num_cycle)
for seed_idx, seed in enumerate(seed_list):
args.seed = seed
Visualizer.input_param['SEED'] = args.seed
with torch.autograd.set_detect_anomaly(False):
# set random seed
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
torch.backends.cudnn.deterministic = True
np.random.seed(args.seed)
mp.freeze_support()
for data in input_json:
Visualizer.inference_from_raw_data(data, only_structure=False,
template_mode="use_pdb_per_chain",
a3m_mode = 'use_colabfold_diagonal_a3m')