Source code for miniworld.MiniWorld_train_multi_deep_v1_5_use_interaction

from contextlib import ExitStack
from copy import deepcopy
from collections import OrderedDict
import torch
import torch.nn as nn
from miniworld.utils.kinematics import xyz_to_t2d
from miniworld.models_MiniWorld_v1_5_use_interaction.MiniWorld import MiniWorldModule
from miniworld.utils.util_module import XYZConverter
from miniworld.utils.data_refactoring import PDB_parsing, mmcif_parsing
import traceback
from miniworld.utils.chemical import INIT_CRDS
from miniworld.utils import kalign_mapping
from miniworld.utils.template_parser import read_templates
from miniworld.utils import hhpred_parser
from miniworld.utils.util import *
# distributed data parallel
import torch.multiprocessing as mp
from miniworld.utils.output_visualize import visualize_2D_heatmap
from miniworld.utils.ProteinClass import ProteinMSA
from miniworld.utils.ffindex import read_index, read_data

from miniworld.feature.MiniWorld_featuring_species import MSA_featurize_wo_statistics

torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
import matplotlib.pyplot as plt

import os
import copy
import gzip
import string
import pickle

import numpy as np

from collections import namedtuple

USE_AMP = False
NUM_CLASSES = 23

def a3m_parse(a3m_file_path):  # Code from RF2 parsers.py
    """
    Input : a3m_file_path (.a3m or .a3m.gz)
    """
    msa = []
    insertion_list = []
    table = str.maketrans(dict.fromkeys(string.ascii_lowercase))
    if a3m_file_path.split('.')[-1] == 'gz':
        f = gzip.open(a3m_file_path, 'rt')
    else:
        f = open(a3m_file_path, 'r')

    # read file line by line
    chain_break = {}
    for line in f:
        if line[0] == '#':  # skip '#' header lines
            continue
        if line[0] == '>':  # skip labels
            continue
        if chain_break == {} and ':' in line:
            chains = line.strip().split(':')
            chain_start = 0
            for chain_idx, chain in enumerate(chains):
                chain_end = chain_start + len(chain) - 1
                chain_break[chain_idx] = (chain_start, chain_end)
                chain_start = chain_end + 1
        # remove chain separators and trailing whitespace
        line = line.replace(':', '')
        line = line.rstrip()
        if len(line) == 0:
            continue

        # remove lowercase letters (insertions) and append to MSA
        msa_i = line.translate(table)
        msas_i = msa_i.split('/')
        if len(msa) == 0:
            # first sequence
            Ls = [len(x) for x in msas_i]
            msa = [[s] for s in msas_i]
        else:
            nchains = len(msas_i)
            isgood = all([len(msa[i][0]) == len(msas_i[i]) for i in range(nchains)])
            if isgood:
                for i in range(nchains):
                    msa[i].append(msas_i[i])
            else:
                raise ValueError("Len error", a3m_file_path, len(msa[0]))

        # sequence length
        L = sum(Ls)

        # 0 - match or gap; 1 - insertion
        lower_case = np.array([0 if c.isupper() or c == '-' else 1 for c in line])
        insertion = np.zeros((L))
        if np.sum(lower_case) > 0:
            # positions of insertions
            pos = np.where(lower_case == 1)[0]
            # shift by occurrence
            lower_case = pos - np.arange(pos.shape[0])
            # position of insertions in the cleaned sequence and their length
            pos, num = np.unique(lower_case, return_counts=True)
            # append to the matrix of insertions
            insertion[pos] = num
        insertion_list.append(insertion)

    # concatenate chains
    msa = [np.array([list(s) for s in t], dtype='|S1').view(np.uint8) for t in msa]
    msa = np.concatenate(msa, axis=-1)
    query_sequence = msa[0]

    # convert letters into numbers
    alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-"), dtype='|S1').view(np.uint8)
    for i in range(alphabet.shape[0]):
        msa[msa == alphabet[i]] = i
    # treat all unknown characters as gaps
    msa[msa > 20] = 20
    # re-convert numbers back to characters
    for i in range(alphabet.shape[0]):
        msa[msa == i] = alphabet[i]
    # uint8 -> string, np.array -> list of strings
    msa = msa.view('|S1').reshape(msa.shape)
    msa = msa.tolist()
    msa = ["".join([char.decode('utf-8') for char in sublist]) for sublist in msa]
    return msa
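# Usage sketch for a3m_parse (hypothetical file name; the '#' header line and
# ':' chain separators follow the ColabFold a3m convention the parser expects):
#
#   msa = a3m_parse("example_complex.a3m")   # list of aligned strings
#   # msa[0] is the query; every row has the same length because lowercase
#   # insertion columns were stripped and ':'-separated chains concatenated.
#   assert all(len(row) == len(msa[0]) for row in msa)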
def visualize_2D_tensor(tensor, sav_dir="permutation_test/", name="test"):
    fig, ax = plt.subplots()
    im = ax.imshow(tensor)
    fig.colorbar(im)
    plt.savefig(sav_dir + name + ".png")
    plt.close()
    torch.save(tensor, sav_dir + name + ".pt")
def add_weight_decay(model, l2_coeff):
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # if len(param.shape) == 1 or name.endswith(".bias"):
        if "norm" in name or name.endswith(".bias"):
            no_decay.append(param)
        else:
            decay.append(param)
    return [{'params': no_decay, 'weight_decay': 0.0},
            {'params': decay, 'weight_decay': l2_coeff}]
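# Usage sketch: the two parameter groups feed straight into a torch optimizer,
# so normalization weights and biases are excluded from weight decay
# (values below are illustrative, not the configuration used in this file):
#
#   param_groups = add_weight_decay(model, l2_coeff=0.01)
#   optimizer = torch.optim.Adam(param_groups, lr=1e-3)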
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
class EMA(nn.Module):  # From RF2
    def __init__(self, model, decay):
        super().__init__()
        self.decay = decay
        self.model = model
        self.shadow = deepcopy(self.model)
        for param in self.shadow.parameters():
            param.detach_()
    @torch.no_grad()
    def update(self):
        if not self.training:
            print("EMA update should only be called during training")
            return

        model_params = OrderedDict(self.model.named_parameters())
        shadow_params = OrderedDict(self.shadow.named_parameters())

        # check that both models contain the same set of keys
        assert model_params.keys() == shadow_params.keys()

        for name, param in model_params.items():
            # see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
            # shadow_variable -= (1 - decay) * (shadow_variable - variable)
            shadow_params[name].sub_((1. - self.decay) * (shadow_params[name] - param))  # in-place subtraction

        model_buffers = OrderedDict(self.model.named_buffers())
        shadow_buffers = OrderedDict(self.shadow.named_buffers())

        # check that both models contain the same set of keys
        assert model_buffers.keys() == shadow_buffers.keys()

        for name, buffer in model_buffers.items():
            # buffers are copied verbatim
            shadow_buffers[name].copy_(buffer)
    def forward(self, *args, **kwargs):
        if self.training:
            return self.model(*args, **kwargs)
        else:
            return self.shadow(*args, **kwargs)
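# Usage sketch for the EMA wrapper (illustrative only): forward passes route to
# the live model in train mode and to the smoothed shadow copy in eval mode, so
# update() is called once per optimizer step.
#
#   ema = EMA(net, decay=0.99)
#   ema.train()
#   loss = criterion(ema(x), y)   # runs the live model
#   loss.backward(); optimizer.step()
#   ema.update()                  # shadow -= (1 - decay) * (shadow - param)
#   ema.eval()
#   pred = ema(x)                 # runs the shadow model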
class ModelVisulizer():
    def __init__(self, saved_model_path, saving_dir, device,
                 model_param={}, input_param={}, batch_size=1, maxcycle=4):
        self.saved_model_path = saved_model_path
        self.saving_dir = saving_dir
        self.device = device
        self.model_param = model_param  # required below by MiniWorldModule(**self.model_param)
        self.input_param = input_param
        self.batch_size = batch_size

        # for all-atom str loss
        self.l2a = long2alt
        self.aamask = allatom_mask
        self.num_bonds = num_bonds

        # converts xyz coordinates to torsions and back
        self.xyz_converter = XYZConverter()
        self.xyz_converter.to(self.device)

        self.maxcycle = maxcycle
        self.model = EMA(MiniWorldModule(**self.model_param).to(device), 0.99)

        # load model
        loaded_epoch, best_valid_loss = self.load_model(self.model, self.saved_model_path)
        self.model.eval()
    def load_model(self, model, saved_model_path):
        loaded_epoch = -1
        best_valid_loss = 999999.9
        if not os.path.exists(saved_model_path):
            print('no model found', saved_model_path)
            return -1, best_valid_loss
        print('loading model', saved_model_path)
        map_location = {"cuda:%d" % 0: "cuda:%d" % 0}
        checkpoint = torch.load(saved_model_path, map_location=map_location)
        model.model.load_state_dict(checkpoint['model_state_dict'], strict=False)
        model.shadow.load_state_dict(checkpoint['model_state_dict'], strict=False)
        return loaded_epoch, best_valid_loss
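    # Checkpoint layout assumed above (sketch): a dict with at least
    # 'model_state_dict'; both the live model and the EMA shadow start from the
    # same weights, and strict=False tolerates renamed or missing keys.
    #
    #   torch.save({'model_state_dict': net.state_dict()}, path)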
    def _prepare_input(self, inputs, gpu):
        """
        Input
        {
            'msa' : {
                'sample_seq'         : torch.tensor,  # (Batch, MAXCYCLE, params['CROP'])
                'sample_msa_species' : torch.tensor,  # (Batch, MAXCYCLE, params['CROP'])
                'sample_msa_clust'   : torch.tensor,  # (Batch, MAXCYCLE, params['MAX_SEED_MSA'], params['CROP'])
                'sample_msa_seed'    : torch.tensor,  # (Batch, MAXCYCLE, params['MAX_SEED_MSA'], params['CROP'], 23 + 1)
                'sample_mask_pos'    : torch.tensor,  # (Batch, MAXCYCLE, params['MAX_SEED_MSA'], params['CROP'])
                'chain_to_msa_depth' : list of dictionary,  # non-batched (chain_idx : msa_depth)
            },
            'template' : {
                'xyz'                : torch.tensor,  # (Batch, params['N_PICK_GLOBAL'], params['CROP'], 27, 3)
                'template_1D'        : torch.tensor,  # (Batch, params['N_PICK_GLOBAL'], params['CROP'], NUM_CLASSES + 1)
                'template_atom_mask' : torch.tensor,  # (Batch, params['N_PICK_GLOBAL'], params['CROP'], 27)
                'use_chain_mask'     : bool,
            },
            'crop_idx'              : torch.tensor,  # (Batch, params['CROP'])
            'symmetry_related_info' : list,                # non-batched
            'chain_break'           : list of dictionary,  # non-batched, (N_chain, 2)
            'ID'                    : list of string,      # non-batched
            'source'                : list of string,      # non-batched
        }
        """
        sample_seq = inputs['msa']['sample_seq'].float()
        sample_msa_species = inputs['msa']['sample_msa_species'].float()
        sample_msa_cluster = inputs['msa']['sample_msa_clust'].float()
        sample_msa_seed = inputs['msa']['sample_msa_seed'].float()
        sample_mask_pos = inputs['msa']['sample_mask_pos'].float()
        # parentheses required: '|' binds tighter than '==' in Python
        sample_zero_pos = (sample_msa_cluster == UNK_IDX) | sample_mask_pos.bool()  # (B, MAX_CYCLE, MAX_SEED_MSA, CROP)

        # transfer inputs to device
        B, _, N, L = sample_msa_cluster.shape  # _ for MAX_CYCLE

        if 'template' in inputs and inputs['template'] is not None:
            template_xyz = inputs['template']['xyz']
            template_1D = inputs['template']['template_1D']
            template_atom_mask = inputs['template']['template_atom_mask']
            use_chain_mask_at_template = inputs['template']['use_chain_mask']
        else:
            template_xyz = INIT_CRDS.reshape(1, 1, 1, 27, 3).repeat(sample_seq.shape[0], self.input_param['N_PICK_GLOBAL'], L, 1, 1)
            template_atom_mask = torch.zeros((sample_seq.shape[0], self.input_param['N_PICK_GLOBAL'], L, 27))
            template_1D = torch.zeros((sample_seq.shape[0], self.input_param['N_PICK_GLOBAL'], L, NUM_CLASSES + 1))
            template_1D[:, :, :, GAP_IDX] = 1.0
            use_chain_mask_at_template = True

        crop_idx = inputs['crop_idx']
        chain_break_list = inputs['chain_break']
        ID_list = inputs['ID']
        ignore_interchain_list = []
        interaction_type_list = []
        data_source_list = []

        xyz_prev = inputs['prev']['xyz']
        mask_prev = inputs['prev']['atom_mask']
        if self.input_param['USE_MULTIPLE_MODELS']:
            pass
        else:
            if len(mask_prev.shape) == 4:
                # xyz_prev  : (B, M, L, 27, 3)
                # mask_prev : (B, M, L, 27)
                xyz_prev = xyz_prev[:, 0, :, :, :]
                mask_prev = mask_prev[:, 0, :, :]

        crop_idx = crop_idx.float().to(gpu, non_blocking=True)  # (B, L)

        # generate chain_mask (B, L, L)
        chain_mask = torch.zeros((B, L, L), dtype=torch.bool)
        model_input_idx = crop_idx.clone()[:, :L]
        chain_break_idx_added = 100  # TODO : Check
        idx_to_chain_list = []
        for bb, chain_break in enumerate(chain_break_list):
            inner_crop_idx = crop_idx[bb].to(torch.int).tolist()
            crop_idx_set = set(inner_crop_idx)
            temp_chain_break = {key: np.arange(value[0], value[1] + 1) for key, value in chain_break.items()}
            idx_to_chain_dict = {}
            for ii, (key, value) in enumerate(temp_chain_break.items()):
                for idx in value:
                    idx_to_chain_dict[idx] = key
                intersection_list = crop_idx_set.intersection(set(value))
                idxs = [inner_crop_idx.index(intersection) for intersection in intersection_list]
                idxs = torch.tensor(idxs, dtype=torch.long)  # (N_chain, )
                chain_mask[bb, idxs[:, None], idxs] = True
                if ii == 0 or len(intersection_list) == 0:
                    continue
                model_input_idx[bb, idxs[0]:] += chain_break_idx_added
            idx_to_chain_list.append(idx_to_chain_dict)

        chain_mask = chain_mask.to(gpu, non_blocking=True)
        template_xyz = template_xyz.to(gpu, non_blocking=True)
        template_1D = template_1D.to(gpu, non_blocking=True)
        template_atom_mask = template_atom_mask.to(gpu, non_blocking=True)
        xyz_prev = xyz_prev.float().to(gpu, non_blocking=True)
        mask_prev = mask_prev.float().to(gpu, non_blocking=True)
        sample_seq = sample_seq.long().to(gpu, non_blocking=True)
        sample_msa_species = sample_msa_species.long().to(gpu, non_blocking=True)
        sample_msa_cluster = sample_msa_cluster.float().to(gpu, non_blocking=True)
        sample_msa_seed = sample_msa_seed.float().to(gpu, non_blocking=True)
        sample_zero_pos = sample_zero_pos.bool().to(gpu, non_blocking=True)
        sample_mask_pos = sample_mask_pos.float().to(gpu, non_blocking=True)

        # processing template features
        template_mask_2D = template_atom_mask[:, :, :, :3].all(dim=-1)  # (B, T, L)
        template_mask_2D = template_mask_2D[:, :, None] * template_mask_2D[:, :, :, None]  # (B, T, L, L)
        # (ignore inter-chain region)
        try:
            if isinstance(use_chain_mask_at_template, torch.Tensor):
                use_chain_mask_at_template = use_chain_mask_at_template.float().to(gpu, non_blocking=True)
                template_mask_2D = template_mask_2D.float() * use_chain_mask_at_template[:, None, :, :]  # (B, T, L, L)
            elif use_chain_mask_at_template is True:
                template_mask_2D = template_mask_2D.float() * chain_mask.float()[:, None, :, :]  # (B, T, L, L)
        except Exception as e:
            print(f"Error \n{e}")
            print(f"query length : {L}")
            print(f"sample_seq.shape : {sample_seq.shape}")
            print(f"sample_msa_cluster.shape : {sample_msa_cluster.shape}")
            print(f"ID_list : {ID_list}")
            raise
        # mask_t_2d = mask_t_2d.float() * same_chain.float()[:,None]

        template_2D = xyz_to_t2d(template_xyz, template_mask_2D)  # unchanged from RF2, only renamed. PSK

        seq_tmp = template_1D[..., :-1].argmax(dim=-1).reshape(-1, L)
        alpha, _, alpha_mask, _ = self.xyz_converter.get_torsions(
            template_xyz.reshape(-1, L, 27, 3), seq_tmp,
            mask_in=template_atom_mask.reshape(-1, L, 27))
        alpha = alpha.reshape(B, -1, L, 10, 2)
        alpha_mask = alpha_mask.reshape(B, -1, L, 10, 1)
        template_alpha = torch.cat((alpha, alpha_mask), dim=-1).reshape(B, -1, L, 30)

        network_input = {}
        network_input['msa_latent'] = sample_msa_seed
        network_input['msa_species'] = sample_msa_species
        network_input['msa_zero_pos'] = sample_zero_pos
        network_input['query_sequence'] = sample_seq  # 20230911 PSK
        network_input['idx'] = model_input_idx
        network_input['template_1D'] = template_1D
        network_input['template_2D'] = template_2D
        network_input['template_xyz'] = template_xyz[:, :, :, 1]
        network_input['template_alpha'] = template_alpha
        network_input['mask_t'] = template_mask_2D
        network_input['chain_mask'] = chain_mask.float()
        network_input['interaction_points'] = inputs['interaction_points'].long().to(gpu, non_blocking=True)

        mask_recycle = mask_prev[:, :, :3].bool().all(dim=-1)
        mask_recycle = mask_recycle[:, :, None] * mask_recycle[:, None, :]  # (B, L, L)
        # mask_recycle = chain_mask.float()*mask_recycle.float()

        return network_input, xyz_prev, mask_recycle, sample_msa_cluster, ID_list, idx_to_chain_list, crop_idx

    def _get_model_input(self, network_input, output_i, mask_recycle, i_cycle,
                         return_raw=False, use_checkpoint=False):
        # Key names changed relative to RF2. PSK
        input_i = {}
        for key in network_input:
            if key in ['msa_latent', 'msa_species', 'msa_zero_pos', 'query_sequence']:
                input_i[key] = network_input[key][:, i_cycle]
            else:
                input_i[key] = network_input[key]
        logits, logits_aa, logits_exp, logits_pae, p_bind, xyz_prev, alpha, symmsub, lddt, msa_prev, pair_prev, state_prev = output_i
        if len(xyz_prev.shape) == 5:
            xyz_prev = xyz_prev[-1]
        input_i['msa_prev'] = msa_prev
        input_i['pair_prev'] = pair_prev[-1].unsqueeze(0) if pair_prev is not None else None
        input_i['state_prev'] = state_prev
        input_i['xyz'] = xyz_prev
        input_i['mask_recycle'] = mask_recycle
        input_i['return_raw'] = return_raw
        input_i['use_checkpoint'] = use_checkpoint
        return input_i
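    # Recycling contract between the two helpers above (sketch): _prepare_input
    # runs once, then each cycle re-packages the previous 12-tuple of outputs
    # as new inputs, with xyz_prev at index 5 of the tuple.
    #
    #   network_input, xyz_prev, mask_recycle, msa, ids, idx2chain, crop = \
    #       self._prepare_input(inputs, gpu)
    #   output_i = (None,) * 5 + (xyz_prev,) + (None,) * 6
    #   for i_cycle in range(self.maxcycle):
    #       input_i = self._get_model_input(network_input, output_i, mask_recycle, i_cycle)
    #       output_i = self.model(**input_i)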
    def draw_distogram(self, logit_s, saving_path):
        PARAMS = {
            "DMIN": 2.0,
            "DMAX": 20.0,
            "DBINS": 36,
            "ABINS": 36,
        }
        # logit_s[0] : (I, 1, 37, L, L)
        distogram = logit_s[0].squeeze(1)  # (I, 37, L, L)
        nbins = PARAMS['DBINS'] + 1
        distance = torch.linspace(PARAMS['DMIN'], PARAMS['DMAX'], nbins).unsqueeze(0).to(distogram.device)  # (1, DBINS + 1)
        avg_distogram = torch.sum(distogram * distance.unsqueeze(-1).unsqueeze(-1), dim=1)  # (I, L, L)
        avg_distogram = avg_distogram.cpu().numpy()

        # the model emits I distograms (one per structure block), so draw I panels in one row
        I = avg_distogram.shape[0]
        fig, axs = plt.subplots(1, I, figsize=(I * 5, 5))
        for i in range(I):
            axs[i].imshow(avg_distogram[i], cmap='hot', interpolation='nearest')
            axs[i].set_title(f"Distogram {i}")
        plt.savefig(saving_path)
        plt.close()
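    # Worked example of the expectation above (assuming the 37 channels are
    # probabilities over bin centers 2.0, 2.5, ..., 20.0 A): for a residue pair
    # with p = 0.5 at 8.0 A and 0.5 at 8.5 A, the expected distance is
    # 0.5 * 8.0 + 0.5 * 8.5 = 8.25 A. Note the weighted sum assumes logit_s[0]
    # is already softmax-normalized over dim=1; otherwise apply softmax first.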
    def write_pae_plddt(self, save_path, ID, pae, plddt):
        with open(save_path, "a") as f:
            f.write(f"{ID} {pae} {plddt}\n")
    def model_inference(self, inputs, device, saving_dir=None, saving_intermediate=False, only_structure=True):
        # move some global data to cuda device
        self.l2a = self.l2a.to(device)
        self.aamask = self.aamask.to(device)
        self.xyz_converter = self.xyz_converter.to(device)
        self.num_bonds = self.num_bonds.to(device)

        initial_crds_list = []
        pred_crds_list = []
        sequence_list = []
        ID_list = []
        crop_idx_list = []
        idx_to_chain_list = []

        with torch.no_grad():
            network_input, xyz_prev, mask_recycle, msa, inner_ID_list, inner_idx_to_chain_list, crop_idx = self._prepare_input(inputs, device)

            N_cycle = self.maxcycle  # number of recycling

            query_sequence = network_input['query_sequence'][0]
            ID = inner_ID_list[0]
            ID_list.append(ID)
            sequence_list.append(query_sequence)

            output_i = (None, None, None, None, None, xyz_prev, None, None, None, None, None, None)
            for i_cycle in range(N_cycle):
                with ExitStack() as stack:
                    # at inference every cycle runs without gradients; the outer
                    # torch.no_grad() already covers the final cycle as well
                    stack.enter_context(torch.no_grad())
                    stack.enter_context(torch.cuda.amp.autocast(enabled=USE_AMP))
                    return_raw = False
                    use_checkpoint = False

                    input_i = self._get_model_input(network_input, output_i, mask_recycle, i_cycle,
                                                    return_raw=return_raw, use_checkpoint=use_checkpoint)
                    output_i = self.model(**input_i)

                    B = network_input['query_sequence'].shape[0]
                    logit_s, logit_aa_s, logit_exp, logit_pae, p_bind, pred_crds, alphas, symmsubs, pred_lddts, _, pair, state = output_i

                    pae = torch.nn.functional.softmax(logit_pae, dim=1)  # (B, 64, L, L)
                    nbin = pae.shape[1]
                    bin_value = 0.5 * torch.arange(nbin, dtype=torch.float32, device=pae.device)
                    expected_error = torch.einsum('bnij, n -> bij', pae, bin_value)  # (B, L, L)

                    nbin = pred_lddts.shape[1]
                    bin_step = 1.0 / nbin
                    lddt_bins = torch.linspace(bin_step, 1.0, nbin, dtype=pred_lddts.dtype, device=pred_lddts.device)
                    plddt = torch.nn.functional.softmax(pred_lddts, dim=1)  # (B, 64, L)
                    expected_lddt = torch.einsum('bni, n -> bi', plddt, lddt_bins)  # (B, L)

                    average_pae = torch.mean(expected_error, dim=(1, 2))  # (B, )
                    average_lddt = torch.mean(expected_lddt, dim=1)       # (B, )
                    average_pae = average_pae.item()
                    average_lddt = average_lddt.item()

                    B, _, N, L = msa.shape
                    logit_aa_s = torch.nn.functional.pad(logit_aa_s, (0, 0, 0, 1), value=0)  # (B, 22, CROP*CROP)
                    # swap logit_aa_s 20 <-> 21
                    logit_aa_s[:, 21] = logit_aa_s[:, 20]
                    # logit_aa_s[:,21] = -99999
                    logit_aa_s[:, 20] = 0

                    msa_in = msa.long()[0, i_cycle]  # (N, L)

                    if not only_structure:
                        visualize_2D_heatmap(msa_in, file_name=ID + "_msa", heatmap_dir=saving_dir + "msa/")
                        self.write_pae_plddt(saving_dir + "pae_plddt.txt", ID, average_pae, average_lddt)

                    I = pred_crds.shape[0]
                    if saving_intermediate:
                        for block_idx in range(I):
                            predRs_all, pred_all = self.xyz_converter.compute_all_atom(
                                query_sequence[0:1].to(torch.int), pred_crds[block_idx], alphas[block_idx])
                            self.write_pdb([pred_all[0, :, :14, :]], [query_sequence[0]], [ID], expected_lddt,
                                           inner_idx_to_chain_list, crop_idx,
                                           File_suffix="pred_" + str(i_cycle) + "_" + str(block_idx),
                                           saving_dir=saving_dir)
                    else:
                        if i_cycle == N_cycle - 1:
                            predRs_all, pred_all = self.xyz_converter.compute_all_atom(
                                query_sequence[0:1].to(torch.int), pred_crds[-1], alphas[-1])
                            self.draw_distogram(logit_s, saving_dir + str(ID) + "_distogram_" + str(i_cycle) + ".png")
                            self.write_pdb([pred_all[0, :, :14, :]], [query_sequence[0]], [ID], expected_lddt,
                                           inner_idx_to_chain_list, crop_idx,
                                           File_suffix="pred_" + str(i_cycle),
                                           saving_dir=saving_dir)

                    if i_cycle == N_cycle - 1:
                        if not only_structure:
                            if not os.path.exists(saving_dir + "pae/"):
                                os.makedirs(saving_dir + "pae/")
                            draw_pae(logit_pae=logit_pae, mask=None,
                                     saving_path=saving_dir + "pae/" + str(ID) + "_pae_from_logit_pae_" + str(i_cycle) + ".png")

            initial_crds_list.append(xyz_prev[0, :, :14, :])  # (256, 14, 3)
            pred_crds_list.append(pred_all[0, :, :14, :])     # (256, 14, 3)
            idx_to_chain_list.append(inner_idx_to_chain_list[0])
            crop_idx_list.append(crop_idx[0])
            torch.cuda.empty_cache()

        return pred_crds_list, sequence_list, ID_list
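    # Sketch of the confidence expectations above: PAE bins are 0.5 A wide
    # starting at 0, so 64 bins span 0 .. 31.5 A of expected error; pLDDT bins
    # are uniform on (0, 1]. For one residue with plddt = [0.1, 0.2, 0.7] over
    # bins [1/3, 2/3, 1.0], expected lddt = 0.1/3 + 0.2*2/3 + 0.7 = 0.8667.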
    def ATOM_REPR(self, atom):
        atom = atom.strip()        # remove surrounding whitespace
        atom = atom.replace(' ', '')
        return atom[0]             # element symbol: first character of the atom name
    def write_pdb(self, pred_crds_list, seq_list, ID_list, lddt_list, idx_to_chain_list, crop_idx, File_suffix, saving_dir=None):
        saving_dir = self.saving_dir + "pdb/" if saving_dir is None else saving_dir
        if not os.path.exists(saving_dir):
            os.makedirs(saving_dir)

        for xyz, seq, ID, idx_to_chain, crop_idx, lddt in zip(pred_crds_list, seq_list, ID_list, idx_to_chain_list, crop_idx, lddt_list):
            crop_idx = crop_idx.tolist()
            lines = []
            atom_idx = 1
            residue_idx = 1
            before_chain = ""
            start = True
            for ll in range(xyz.shape[0]):
                idx = crop_idx[ll]
                if idx == -1:
                    continue
                chain = idx_to_chain[idx]
                if chain != before_chain:
                    if start:
                        start = False
                    else:
                        lines.append("TER\n")
                        residue_idx = 1
                    before_chain = chain
                xyz_ll = xyz[ll]  # (14, 3)
                seq_ll = int(seq[ll].item())
                residue = num2aa[seq_ll]
                atom_list = aa2long[seq_ll][:14]
                lddt_i = lddt[ll].item()
                lddt_i = round(lddt_i * 100, 2)  # 0.00 ~ 100.00
                lddt_i = "{:.2f}".format(lddt_i)  # e.g. 16.6 -> "16.60"
                # remove None entries (residues with fewer than 14 heavy atoms)
                atom_list = [item for item in atom_list if item is not None]
                for aa, atom in enumerate(atom_list):
                    atom_xyz = xyz_ll[aa].cpu().numpy().tolist()
                    atom_xyz = [str(round(item, 3)) for item in atom_xyz]
                    atom_x, atom_y, atom_z = atom_xyz
                    if float(atom_x) == 0 and float(atom_y) == 0 and float(atom_z) == 0:
                        continue  # skip unresolved atoms left at the origin
                    one_atom = self.ATOM_REPR(atom)
                    atom_line = f"ATOM {atom_idx:>5} {atom:>4} {residue:>3}{chain:>2}{residue_idx:>4} {atom_x:>8}{atom_y:>8}{atom_z:>8} 1.00 {lddt_i:>3} {one_atom:>2}\n"
                    lines.append(atom_line)
                    atom_idx += 1
                residue_idx += 1

            saving_path = saving_dir + "/" + ID + '_{}.pdb'.format(File_suffix)
            with open(saving_path, 'w') as f:
                f.writelines(lines)
            print(f"Saving {ID} done.")
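    # The record above approximates the fixed-column PDB ATOM format (serial,
    # atom name, residue name, chain, residue number, x/y/z, occupancy, then
    # the B-factor column reused for per-residue pLDDT, element). Sketch of one
    # emitted line (exact spacing depends on the f-string above):
    #
    #   ATOM      1    N MET A   1      1.234   2.345   3.456 1.00 98.76  N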
    def get_msa(self, a3m_path):
        msa = ProteinMSA(a3m_file_path=a3m_path)
        return msa
    def get_fasta(self, fasta_path):  # Code from RF2 parsers.py
        """
        Input : fasta_path (plain or gzipped fasta; ':' separates chains)
        """
        msa = []
        table = str.maketrans(dict.fromkeys(string.ascii_lowercase))
        if fasta_path.split('.')[-1] == 'gz':
            f = gzip.open(fasta_path, 'rt')
        else:
            f = open(fasta_path, 'r')

        # read file line by line
        chain_break = {}
        for line in f:
            if line[0] == '#':  # skip '#' header lines
                continue
            if line[0] == '>':  # skip labels
                continue
            if len(chain_break) == 0 and ':' in line:
                chains = line.strip().split(':')
                chain_start = 0
                for chain_idx, chain in enumerate(chains):
                    chain_end = chain_start + len(chain) - 1
                    chain_break[chain_idx] = (chain_start, chain_end)
                    chain_start = chain_end + 1
                self.chain_break = chain_break
            # remove chain separators and trailing whitespace
            line = line.replace(':', '')
            line = line.rstrip()
            if len(line) == 0:
                continue

            # remove lowercase letters and append
            seq_i = line.translate(table)
            seq_i = seq_i.split('/')
            if len(msa) == 0:
                # first sequence
                Ls = [len(x) for x in seq_i]
                msa = [[s] for s in seq_i]
            else:
                nchains = len(seq_i)
                isgood = all([len(msa[i][0]) == len(seq_i[i]) for i in range(nchains)])
                if isgood:
                    for i in range(nchains):
                        msa[i].append(seq_i[i])
                else:
                    raise ValueError("Len error", fasta_path, len(msa[0]))

            # sequence length
            L = sum(Ls)
            break  # only the query sequence is needed

        if len(chain_break) == 0:
            chain_break = {0: (0, L - 1)}

        # concatenate
        msa = [np.array([list(s) for s in t], dtype='|S1').view(np.uint8) for t in msa]
        msa = np.concatenate(msa, axis=-1)

        # convert letters into numbers
        alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-"), dtype='|S1').view(np.uint8)
        for i in range(alphabet.shape[0]):
            msa[msa == alphabet[i]] = i
        msa[msa > 20] = 20

        query_sequence = msa[0]
        query_sequence = torch.tensor(query_sequence)  # (L, ) torch.Tensor
        chain_break = {str(key): value for key, value in chain_break.items()}
        return query_sequence, chain_break
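    # Example (hypothetical two-chain fasta): a query line "MKV...:GHE..." with
    # chain lengths 100 and 80 yields
    #   chain_break == {'0': (0, 99), '1': (100, 179)}
    # i.e. 0-based inclusive (start, end) ranges over the concatenated sequence.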
    def load_initial_xyz(self, initial_xyz_path, is_protein=True):
        if initial_xyz_path is None:
            return None, None
        if '.pdb' in initial_xyz_path:
            protein = PDB_parsing(initial_xyz_path)
            initial_xyz = protein.structure.xyz
            initial_atom_mask = protein.structure.atom_mask
        elif is_protein:
            with open(initial_xyz_path, 'rb') as f:
                protein = pickle.load(f)
            initial_xyz = protein.structure.xyz
            initial_atom_mask = protein.structure.atom_mask
        else:
            initial = torch.load(initial_xyz_path)
            initial_xyz = initial['xyz']
            initial_atom_mask = initial['atom_mask']
        return initial_xyz, initial_atom_mask
    def symmetric_a3m_to_full_a3m(self, a3m_file_path, fasta_file_path, saving_path):
        with open(a3m_file_path, "r") as f:
            lines = f.readlines()
        first_line = lines[0].strip()
        input_seq = lines[2].strip()
        parsed_msa = a3m_parse(a3m_file_path)

        # drop the leading '#' of the ColabFold header line
        first_line = first_line[1:]
        seq_len_list, n_list = first_line.split("\t")
        seq_len_list = seq_len_list.split(",")
        n_list = n_list.split(",")
        chain_to_seq_len = {ii: int(n) for ii, n in enumerate(seq_len_list)}
        chain_to_n = {ii: n for ii, n in enumerate(n_list)}

        chain_break = {}
        seq_to_chain = {}
        with open(fasta_file_path, "r") as f:
            fasta_lines = f.readlines()
        full_input_seq = fasta_lines[1].strip().split(":")

        split = 0
        for idx in chain_to_n.keys():
            chain_start, chain_end = split, split + int(chain_to_seq_len[idx]) - 1
            chain_break[idx] = (chain_start, chain_end)
            split += int(chain_to_seq_len[idx])
            seq_to_chain[input_seq[chain_start:chain_end + 1]] = idx

        chain_to_msa = {}
        for chain in chain_to_n.keys():
            chain_to_msa[chain] = []
        for line in parsed_msa:
            line = line.strip()
            for chain in chain_break.keys():
                chain_start, chain_end = chain_break[chain]
                chain_to_msa[chain].append(line[chain_start:chain_end + 1])

        new_lines = []
        add_idx = 0
        for line_idx in range(len(lines)):
            if line_idx == 0:
                new_lines.append(lines[line_idx])
                continue
            line = lines[line_idx]
            if line.startswith(">"):
                new_lines.append(line)
            else:
                new_line = ""
                for seq in full_input_seq:
                    chain = seq_to_chain[seq]
                    new_line += chain_to_msa[chain][add_idx]
                new_line += "\n"
                new_lines.append(new_line)
                add_idx += 1

        with open(saving_path, "w") as f:
            for line in new_lines:
                f.write(line)
    def symmetric_a3m_to_diagonal_a3m(self, a3m_file_path, fasta_file_path, saving_path):
        with open(a3m_file_path, "r") as f:
            lines = f.readlines()
        first_line = lines[0].strip()
        input_seq = lines[2].strip()
        parsed_msa = a3m_parse(a3m_file_path)

        # drop the leading '#' of the ColabFold header line
        first_line = first_line[1:]
        seq_len_list, n_list = first_line.split("\t")
        seq_len_list = seq_len_list.split(",")
        n_list = n_list.split(",")
        chain_to_seq_len = {ii: int(n) for ii, n in enumerate(seq_len_list)}
        chain_to_n = {ii: n for ii, n in enumerate(n_list)}

        chain_break = {}
        seq_to_chain = {}
        with open(fasta_file_path, "r") as f:
            fasta_lines = f.readlines()
        full_input_seq = fasta_lines[1].strip().split(":")
        total_len = sum([len(seq) for seq in full_input_seq])

        split = 0
        for idx in chain_to_n.keys():
            chain_start, chain_end = split, split + int(chain_to_seq_len[idx]) - 1
            chain_break[idx] = (chain_start, chain_end)
            split += int(chain_to_seq_len[idx])
            seq_to_chain[input_seq[chain_start:chain_end + 1]] = idx

        chain_to_msa = {}
        for chain in chain_to_n.keys():
            chain_to_msa[chain] = []
        for line in parsed_msa:
            line = line.strip()
            for chain in chain_break.keys():
                chain_start, chain_end = chain_break[chain]
                chain_to_msa[chain].append(line[chain_start:chain_end + 1])

        # map each unique chain sequence to all residue indices it occupies in the full complex
        output = {}
        current_index = 0
        for seq in full_input_seq:
            # positions covered by the current copy of this sequence
            indices = list(range(current_index, current_index + len(seq)))
            if seq in output:
                # append indices of this additional copy
                output[seq].extend(indices)
            else:
                # start a new entry with the current indices
                output[seq] = indices
            # update current index for the next sequence
            current_index += len(seq)

        full_seq = "".join(full_input_seq)

        new_lines = []
        new_lines.append(">query\n")
        new_lines.append("".join(full_input_seq) + "\n")
        chain_start = 0
        for seq, idxs in output.items():
            chain = seq_to_chain[seq]
            chain_end = chain_start + chain_to_seq_len[chain] - 1
            for chain_msa_seq in chain_to_msa[chain]:
                chain_idxs = copy.deepcopy(idxs)
                rest = len(chain_idxs)
                new_seq = ""
                # place the chain's MSA row at every copy of this chain, gaps elsewhere
                while rest > 0:
                    new_seq += "-" * (chain_idxs[0] - len(new_seq))
                    new_seq += chain_msa_seq
                    chain_idxs = chain_idxs[len(chain_msa_seq):]
                    rest = len(chain_idxs)
                new_seq += "-" * (total_len - len(new_seq))
                new_lines.append(">\n")
                new_lines.append(new_seq + "\n")
            chain_start = chain_end + 1

        with open(saving_path, "w") as f:
            for line in new_lines:
                f.write(line)
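    # Sketch of the diagonal layout produced above for a heterodimer A:B with
    # lengths 3 and 4: each chain's MSA rows are padded with '-' over the other
    # chain, e.g. a row "MKV" for chain A becomes "MKV----" and a row "GHEA"
    # for chain B becomes "---GHEA", so emitted rows never mix chains.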
    @torch.no_grad()
    def inference_from_raw_data(self, raw_file_dict,
                                template_mode='use_hhr',
                                a3m_mode='use_raw_a3m',
                                saving_dir=None,
                                saving_intermediate=False,
                                only_structure=True,
                                symmetry=None,
                                msa_save_path=None):
        ID = raw_file_dict['ID']
        fasta_path = raw_file_dict['fasta']
        if a3m_mode == "single_sequence" or 'a3m' not in raw_file_dict.keys():
            a3m_path = raw_file_dict['fasta']
        elif a3m_mode == "use_colabfold_paired_a3m":
            a3m_path = raw_file_dict['a3m']
            self.symmetric_a3m_to_full_a3m(a3m_path, fasta_path, self.saving_dir + "temp.a3m")
            a3m_path = self.saving_dir + "temp.a3m"
        elif a3m_mode == "use_colabfold_diagonal_a3m":
            a3m_path = raw_file_dict['a3m']
            self.symmetric_a3m_to_diagonal_a3m(a3m_path, fasta_path, self.saving_dir + "temp.a3m")
            a3m_path = self.saving_dir + "temp.a3m"
        elif a3m_mode == "use_raw_a3m":
            a3m_path = raw_file_dict['a3m']

        # Load MSA
        try:
            msa = self.get_msa(a3m_path)
            query_sequence, chain_break = self.get_fasta(fasta_path)
        except Exception as e:
            # TODO: there can be no MSA or template because the chain is too short; handle this case.
            print(traceback.format_exc())
            print(e, end=" | ")
            return f"Error : {e}"

        # Load templates
        query_length = query_sequence.shape[0]
        template_path_dict = raw_file_dict['templates'] if 'templates' in raw_file_dict.keys() else None  # {chain_idx : template_path}
        n_pick_global = self.input_param['N_PICK_GLOBAL']
        template_featurized_dict = {
            'xyz': INIT_CRDS.reshape(1, 1, 1, 27, 3).repeat(1, n_pick_global, query_length, 1, 1),  # (B, N_template, L, 27, 3)
            'template_1D': torch.full((1, n_pick_global, query_length, NUM_CLASSES + 1), GAP_IDX).float(),  # (B, N_template, L_query, NUM_CLASSES + 1)
            'template_atom_mask': torch.full((1, n_pick_global, query_length, 27), False),  # (B, N_template, L_query, 27)
            'use_chain_mask': True,
        }
        if template_path_dict is not None:
            if template_mode == 'use_hhr':
                for chain_idx, template_path in template_path_dict.items():
                    chain_start, chain_end = chain_break[chain_idx]
                    qlen = chain_end - chain_start + 1
                    hhr_path = template_path['hhr']
                    atab_path = template_path['atab']
                    xyz, f1d, mask_t = read_templates(qlen, ffdb, hhr_path, atab_path, offset=0,
                                                      n_templ=n_pick_global, random_noise=5.0)
                    template_featurized_dict['xyz'][:, :, chain_start:(chain_end + 1), :, :] = xyz
                    template_featurized_dict['template_1D'][:, :, chain_start:(chain_end + 1), :] = f1d
                    template_featurized_dict['template_atom_mask'][:, :, chain_start:(chain_end + 1), :] = mask_t
                template_featurized_dict['use_chain_mask'] = True

            elif template_mode == 'use_complex_pdb':
                # templates are already aligned to the query complex
                pdb_path_list = template_path_dict['pdb_path_list']
                N_template = len(pdb_path_list)
                seq_identity = template_path_dict['seq_identity'] if 'seq_identity' in template_path_dict.keys() else None  # (N_template, L)
                if seq_identity is None:
                    seq_identity = torch.ones((N_template, query_length))
                for ii, pdb_path in enumerate(pdb_path_list):
                    protein = PDB_parsing(pdb_path)
                    xyz = protein.structure.xyz[0]  # (L, 27, 3)
                    seq = protein.sequence.sequence[0]  # (L, )
                    seq_onehot = nn.functional.one_hot(seq, num_classes=NUM_CLASSES).float()  # (L, NUM_CLASSES)
                    f1d = torch.concat([seq_onehot, seq_identity[ii].unsqueeze(-1)], dim=-1)  # (L, NUM_CLASSES + 1)
                    mask_t = protein.structure.atom_mask  # (L, 27)
                    template_featurized_dict['xyz'][:, ii, :, :14, :] = xyz.unsqueeze(0)
                    template_featurized_dict['template_1D'][:, ii, :, :] = f1d.unsqueeze(0)
                    template_featurized_dict['template_atom_mask'][:, ii, :, :14] = mask_t.unsqueeze(0)
                template_featurized_dict['use_chain_mask'] = False

            elif template_mode == 'use_pdb_per_chain':
                # use only one template per chain
                template_chain_mask = torch.zeros((1, query_length, query_length), dtype=torch.bool)
                seq_alignment_score = template_path_dict['seq_alignment_score'] if 'seq_alignment_score' in template_path_dict.keys() else 0.9
                cif_path_dict = template_path_dict['cif_path_dict']
                template_dict = {}
                for template_ID, cif_path in cif_path_dict.items():
                    template_protein = mmcif_parsing(cif_path)
                    template_xyz = template_protein.structure.xyz[0]  # (L_template, 27, 3)
                    template_seq = template_protein.sequence.sequence[0]  # (L_template, )
                    template_atom_mask = template_protein.structure.atom_mask[0]  # (L_template, 27)
                    template_chain_break = template_protein.structure.chain_break  # {chain_idx : (start, end)}
                    template_dict[template_ID] = {
                        'xyz': template_xyz,
                        'seq': template_seq,
                        'atom_mask': template_atom_mask,
                        'chain_break': template_chain_break,
                    }
                per_chain_dict = template_path_dict['per_chain_dict']
                valid_idx_group = {}
                before_chain_idx = -1
                for chain_idx, item in per_chain_dict.items():
                    if 'template' not in item.keys():
                        # AF2 prediction (for F chain)
                        chain_start, chain_end = chain_break[chain_idx]
                        pdb_path = item['pdb_path']
                        template_protein, lddt = PDB_parsing(pdb_path, return_lddt=True)
                        lddt = lddt / 100  # for test
                        template_xyz = template_protein.structure.xyz[0]  # (L_template, 27, 3)
                        template_seq = template_protein.sequence.sequence[0]  # (L_template, )
                        template_atom_mask = template_protein.structure.atom_mask[0]  # (L_template, 27)
                        template_seq_onehot = nn.functional.one_hot(template_seq, num_classes=NUM_CLASSES).float()  # (L_template, NUM_CLASSES)
                        template_featurized_dict['xyz'][:, 0, chain_start:chain_end + 1, :14, :] = template_xyz
                        template_featurized_dict['template_1D'][:, 0, chain_start:chain_end + 1, :NUM_CLASSES] = template_seq_onehot
                        template_featurized_dict['template_1D'][:, 0, chain_start:chain_end + 1, NUM_CLASSES] = lddt
                        template_featurized_dict['template_atom_mask'][:, 0, chain_start:chain_end + 1, :14] = template_atom_mask
                        chain_valid_idx = torch.arange(chain_start, chain_end + 1)
                        valid_idx_grid_x, valid_idx_grid_y = torch.meshgrid(chain_valid_idx, chain_valid_idx)
                        template_chain_mask[0, valid_idx_grid_x, valid_idx_grid_y] = True
                        continue
                    template_PDB_ID, template_chain_idx = item['template'].split("_")
                    kalign_path = item['kalign_result_path']
                    try:
                        chain_start, chain_end = chain_break[chain_idx]
                    except KeyError:
                        breakpoint()
                    template_xyz = template_dict[template_PDB_ID]['xyz']
                    template_seq = template_dict[template_PDB_ID]['seq']
                    template_atom_mask = template_dict[template_PDB_ID]['atom_mask']
                    template_chain_break = template_dict[template_PDB_ID]['chain_break']
                    template_chain_start, template_chain_end = template_chain_break[template_chain_idx]
                    if valid_idx_group.get(template_PDB_ID) is None:
                        valid_idx_group[template_PDB_ID] = []
                    valid_idx_group[template_PDB_ID].append(torch.arange(chain_start, chain_end + 1))
                    query_map_idx, target_map_idx = kalign_mapping(kalign_path)
                    query_map_idx = query_map_idx + before_chain_idx + 1
                    before_chain_idx = chain_end
                    template_seq_alignment_score = torch.full_like(query_map_idx, seq_alignment_score).unsqueeze(0).float()
                    template_chain_xyz = template_xyz[template_chain_start:template_chain_end + 1]  # (L_chain, 27, 3)
                    template_chain_seq = template_seq[template_chain_start:template_chain_end + 1]  # (L_chain, )
                    template_chain_atom_mask = template_atom_mask[template_chain_start:template_chain_end + 1]  # (L_chain, 27)
                    template_chain_seq_onehot = nn.functional.one_hot(template_chain_seq, num_classes=NUM_CLASSES).float()  # (L_chain, NUM_CLASSES)
                    template_featurized_dict['xyz'][:, 0, query_map_idx, :14, :] = template_chain_xyz[target_map_idx]
                    template_featurized_dict['template_1D'][:, 0, query_map_idx, :NUM_CLASSES] = template_chain_seq_onehot[target_map_idx].unsqueeze(0)
                    template_featurized_dict['template_1D'][:, 0, query_map_idx, NUM_CLASSES] = template_seq_alignment_score
                    template_featurized_dict['template_atom_mask'][:, 0, query_map_idx, :14] = template_chain_atom_mask[target_map_idx].bool()
                for template_PDB_ID, valid_idx in valid_idx_group.items():
                    valid_idx = torch.cat(valid_idx)  # (L,)
                    valid_idx_grid_x, valid_idx_grid_y = torch.meshgrid(valid_idx, valid_idx, indexing='ij')
                    template_chain_mask[0, valid_idx_grid_x, valid_idx_grid_y] = True
                template_featurized_dict['use_chain_mask'] = template_chain_mask

            elif template_mode == 'use_complex_prediction_pdb':
                # templates are already aligned to the query complex
                pdb_path_list = template_path_dict['pdb_path_list']
                print(f'pdb_path_list : {pdb_path_list}')
                for ii, pdb_path in enumerate(pdb_path_list):
                    if ii >= n_pick_global:
                        break  # template slots are limited to n_pick_global
                    protein, lddt = PDB_parsing(pdb_path, return_lddt=True)
                    lddt = lddt[0].unsqueeze(1) / 100  # (L, 1)
                    xyz = protein.structure.xyz[0]
                    seq = protein.sequence.sequence[0]
                    seq_onehot = nn.functional.one_hot(seq, num_classes=NUM_CLASSES).float()  # (L, NUM_CLASSES)
                    f1d = torch.concat([seq_onehot, lddt], dim=-1)  # (L, NUM_CLASSES + 1)
                    mask_t = protein.structure.atom_mask[0]  # (L, 27)
                    template_featurized_dict['xyz'][0, ii, :, :14, :] = xyz
                    template_featurized_dict['template_1D'][0, ii, :, :] = f1d
                    template_featurized_dict['template_atom_mask'][0, ii, :, :14] = mask_t
                template_featurized_dict['use_chain_mask'] = torch.ones((1, query_length, query_length), dtype=torch.bool)

            elif template_mode == 'use_hhpred_output':
                # aligned by hhpred; requires a pdb_path for each template
                pdb_path_list = template_path_dict['pdb_path_list']
                query_ID = template_path_dict['query_ID_used_in_hhpred']
                hhpred_hhr_file = template_path_dict['hhpred_hhr_file']
                parsed_info_list = hhpred_parser(hhpred_hhr_file, query_ID)
                # filter out templates that are not in pdb_path_list
                pdb_ID_list = []
                for pdb_path in pdb_path_list:
                    pdb_ID_list.append(pdb_path.split("/")[-1].split(".")[0])
                ID_to_pdb_and_parsed_info = {}
                for parsed_info in parsed_info_list:
                    for pdb_ID in pdb_ID_list:
                        if pdb_ID in parsed_info['target_ID']:
                            ID_to_pdb_and_parsed_info[pdb_ID] = {
                                'parsed_info': parsed_info,
                                'pdb_path': pdb_path_list[pdb_ID_list.index(pdb_ID)],
                            }
                if len(ID_to_pdb_and_parsed_info) == 0:
                    print("=== WARNING === : No template is found in pdb_path_list")
                template_idx = 0
                for pdb_ID, item_dict in ID_to_pdb_and_parsed_info.items():
                    try:
                        pdb_path = item_dict['pdb_path']
                        parsed_info = item_dict['parsed_info']
                        protein = mmcif_parsing(pdb_path)
                        # TODO: a symmetry operation can be needed.
                        template_xyz = protein.structure.xyz[0]  # (L, 27, 3)
                        template_seq = protein.sequence.sequence[0]  # (L, )
                        query_map_idx = parsed_info['query_map_idx']  # (L_match, )
                        target_map_idx = parsed_info['target_map_idx']  # (L_match, )
                        confidence = parsed_info['confidence']  # (L_match, )
                        template_seq_onehot = nn.functional.one_hot(template_seq, num_classes=NUM_CLASSES).float()  # (L, NUM_CLASSES)
                        template_atom_mask = protein.structure.atom_mask[0]  # (L, 27)
                        template_featurized_dict['xyz'][:, template_idx, query_map_idx, :14, :] = template_xyz[target_map_idx]
                        template_featurized_dict['template_1D'][:, template_idx, query_map_idx, :NUM_CLASSES] = template_seq_onehot[target_map_idx].unsqueeze(0)
                        template_featurized_dict['template_1D'][:, template_idx, query_map_idx, NUM_CLASSES] = confidence
                        template_featurized_dict['template_atom_mask'][:, template_idx, query_map_idx, :14] = template_atom_mask[target_map_idx].bool()
                    except Exception as e:
                        breakpoint()

        # Load initial xyz
        if 'initial_xyz' in raw_file_dict.keys():
            initial_xyz, initial_atom_mask = self.load_initial_xyz(raw_file_dict['initial_xyz'])
        else:
            initial_xyz = None
            initial_atom_mask = None

        if symmetry is not None:
            # TODO
            msa_tensor = msa.msa
            self.save_clean_msa(msa_tensor, symmetry, msa_save_path)

        # Load interaction points
        interaction_dict = raw_file_dict['interaction_points'] if 'interaction_points' in raw_file_dict.keys() else None
        interaction_points_input = torch.zeros((1, query_length))
        if interaction_dict is not None:
            before_chain_idx = -1
            positive_dict = interaction_dict['positive']
            negative_dict = interaction_dict['negative']
            for chain_idx in positive_dict.keys():
                pos_idx_list = positive_dict[chain_idx]
                neg_idx_list = negative_dict[chain_idx]
                for idx in pos_idx_list:
                    idx += before_chain_idx + 1
                    interaction_points_input[0, idx] = 1
                for idx in neg_idx_list:
                    idx += before_chain_idx + 1
                    interaction_points_input[0, idx] = 2
                chain_start, chain_end = chain_break[chain_idx]
                before_chain_idx = chain_end

        query_length = query_sequence.shape[0]
        sequence_idxs = []
        for chain_idx, chain_break_tuple in chain_break.items():
            chain_start, chain_end = chain_break_tuple
            sequence_idxs += list(range(chain_start, chain_end + 1))
        sequence_idxs = torch.Tensor(sequence_idxs).long()
        out_of_sequence_idxs = torch.Tensor(list(set(range(query_length)) - set(sequence_idxs.tolist()))).long()
        crop_idx = sequence_idxs

        inputs = {}
        msa_feature_dict = MSA_featurize_wo_statistics(msa.msa, msa.insertion, chain_break, self.input_param)
        # batchify
        for key in msa_feature_dict.keys():
            if isinstance(msa_feature_dict[key], torch.Tensor):
                msa_feature_dict[key] = msa_feature_dict[key].unsqueeze(0)
            else:
                msa_feature_dict[key] = [msa_feature_dict[key]]
        inputs["msa"] = msa_feature_dict
        inputs['template'] = template_featurized_dict
        inputs['interaction_points'] = interaction_points_input

        # cropping
        inputs["crop_idx"] = crop_idx.unsqueeze(0)  # (1, L_crop)
        inputs["chain_break"] = [chain_break]  # (N_chain : 2)

        # TODO : prev
        # N, Ca, C initial coords are in chemical.py
        if initial_xyz is None:
            random_xyz, atom_mask = generate_initial_xyz(query_sequence.unsqueeze(0), chain_break)
            inputs["prev"] = {
                "xyz": random_xyz.unsqueeze(0),       # (1, L_crop, 27, 3)
                'atom_mask': atom_mask.unsqueeze(0),  # (1, L_crop, 27)
            }
        else:
            inputs["prev"] = {
                "xyz": initial_xyz.unsqueeze(0),              # (1, L_crop, 27, 3)
                'atom_mask': initial_atom_mask.unsqueeze(0),  # (1, L_crop, 27)
            }
        inputs["ID"] = [raw_file_dict['ID']]
        inputs["source"] = ["PDB"]
        inputs["symmetry_related_info"] = [None]

        if saving_dir is None:
            saving_dir = self.saving_dir
        self.model_inference(inputs, self.device, saving_dir,
                             saving_intermediate=saving_intermediate,
                             only_structure=only_structure)
    def save_clean_msa(self, msa, symmetry, save_path):
        # msa : (N, L)
        N, L = msa.shape
        # TODO : symmetry (currently duplicates the MSA once, i.e. a C2 layout)
        new_msa = torch.zeros((N, 2 * L)).long()
        new_msa[:, :L] = msa
        new_msa[:, L:] = msa
        with open(save_path, 'w') as f:
            for n in range(N):
                f.write(f">{n}\n")
                f.write("".join([num2AA[aa] for aa in new_msa[n].tolist()]) + "\n")
if __name__ == "__main__":
    from miniworld.utils import get_args

    args, model_param, input_json = get_args()

    # I use my own dataloader, so input_params are below. PSK
    seed_num = args.seed_num
    num_cycle = args.maxcycle
    num_cycle = 5  # NOTE: overrides args.maxcycle
    seed_list = [args.seed + i for i in range(seed_num)]

    # ffdb for loading templates
    FFDB = '/public_data/pdb100/pdb100_2020Mar11/pdb100_2020Mar11'
    FFindexDB = namedtuple("FFindexDB", "index, data")
    ffdb = FFindexDB(read_index(FFDB + '_pdb.ffindex'),
                     read_data(FFDB + '_pdb.ffdata'))

    input_param = {
        "N_PICK": 1,
        "N_PICK_GLOBAL": 1,
        "PICK_TOP": True,
        "MIN_SEED_MSA_PER_CHAIN": 64,
        "MAX_SEED_MSA": 512,
        "USE_MASK": True,
        "MAX_MSA_CYCLE": num_cycle,
        "MAX_SEQ": 1024,
        "SEED": args.seed,
        "BLOCKCUT": 3000,
        'USE_MULTIPLE_MODELS': False,
    }
    model_param['n_main_block'] = 16
    model_param['n_ref_block'] = 0
    saved_model_path = "./model_weights/Model_1_5_nblock16_epoch29_use_interaction.pt"

    # print usable gpu num
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Device :', device)
    print('Usable GPU num :', torch.cuda.device_count())
    rank = 0

    output_dir = args.output_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    Visualizer = ModelVisulizer(
        saved_model_path=saved_model_path,
        saving_dir=output_dir,
        device=device,
        model_param=model_param,
        input_param=input_param,
        batch_size=args.batch_size,
        maxcycle=num_cycle)

    for seed_idx, seed in enumerate(seed_list):
        args.seed = seed
        Visualizer.input_param['SEED'] = args.seed
        with torch.autograd.set_detect_anomaly(False):
            # set random seed
            torch.manual_seed(args.seed)
            torch.cuda.manual_seed(args.seed)
            torch.cuda.manual_seed_all(args.seed)
            torch.backends.cudnn.deterministic = True
            np.random.seed(args.seed)
            mp.freeze_support()
            for data in input_json:
                Visualizer.inference_from_raw_data(data,
                                                   only_structure=False,
                                                   template_mode="use_pdb_per_chain",
                                                   a3m_mode='use_colabfold_diagonal_a3m')