Source code for miniworld.utils.template_parser

import numpy as np
import string
from miniworld import utils as util
import gzip
import torch
from miniworld.utils.ffindex import *
from miniworld.utils.chemical import INIT_CRDS

# Three-letter residue name -> one-letter code for the 20 standard amino acids.
to1letter = dict(
    ALA='A',
    ARG='R',
    ASN='N',
    ASP='D',
    CYS='C',
    GLN='Q',
    GLU='E',
    GLY='G',
    HIS='H',
    ILE='I',
    LEU='L',
    LYS='K',
    MET='M',
    PHE='F',
    PRO='P',
    SER='S',
    THR='T',
    TRP='W',
    TYR='Y',
    VAL='V',
)

# read A3M and convert letters into
# integers in the 0..20 range,
# also keep track of insertions
def parse_a3m(filename, max_seq=5000):
    """Read an A3M alignment into integer-encoded MSA and insertion arrays.

    Lowercase letters (insertions) are removed from every sequence; the
    remaining characters are mapped to indices in "ARNDCQEGHILKMFPSTWYV-"
    and every other character is treated as a gap (20).  A '/' inside a
    sequence line separates chains; the per-chain blocks are concatenated
    along the last axis.

    Args:
        filename: path to the A3M file; a ``.gz`` extension is read via gzip.
        max_seq: stop reading once this many sequences have been collected.

    Returns:
        msa: uint8 array (n_seq, L) with values in 0..20.
        ins: uint8 array (n_seq, L); number of inserted characters recorded
            at each position of the cleaned sequence (saturated at 255).
        Ls: list with the length of each chain, taken from the first sequence.

    Raises:
        ValueError: if the file contains no sequence lines.
    """
    # translation table that deletes all lowercase letters
    table = str.maketrans(dict.fromkeys(string.ascii_lowercase))

    opener = gzip.open if filename.split('.')[-1] == 'gz' else open

    msa = []  # one list of cleaned sequences per chain
    ins = []  # one insertion-count row per accepted sequence
    Ls = []   # per-chain lengths defined by the first sequence

    with opener(filename, 'rt') as fp:
        for line in fp:
            # skip labels
            if line[0] == '>':
                continue

            # remove right whitespaces
            line = line.rstrip()
            if len(line) == 0:
                continue

            # remove lowercase letters and split into chains
            msas_i = line.translate(table).split('/')

            if len(msa) == 0:
                # first sequence defines chain count and lengths
                Ls = [len(x) for x in msas_i]
                msa = [[s] for s in msas_i]
            else:
                isgood = len(msas_i) == len(msa) and all(
                    len(chain[0]) == len(s) for chain, s in zip(msa, msas_i)
                )
                if not isgood:
                    print("Len error", filename, len(msa[0]))
                    # skip the whole record so that msa and ins stay row-aligned
                    continue
                for chain, s in zip(msa, msas_i):
                    chain.append(s)

            # total sequence length
            L = sum(Ls)

            # 0 - match or gap; 1 - insertion
            # NOTE(review): `line` still contains the '/' chain separators, so
            # each separator is counted as one inserted character at the chain
            # boundary -- confirm whether that is intended.
            a = np.array([0 if c.isupper() or c == '-' else 1 for c in line])
            i = np.zeros((L))

            if np.sum(a) > 0:
                # positions of insertions
                pos = np.where(a == 1)[0]

                # shift by occurrence -> position in the cleaned sequence
                a = pos - np.arange(pos.shape[0])

                # position of insertions in cleaned sequence and their length
                pos, num = np.unique(a, return_counts=True)

                # an insertion after the last kept column has no slot in the row
                keep = pos < L
                i[pos[keep]] = num[keep]

            ins.append(i)

            if len(msa[0]) >= max_seq:
                break

    if not msa:
        raise ValueError("No sequences found in %s" % filename)

    # per-chain character matrices viewed as bytes, concatenated along length
    msa = [np.array([list(s) for s in t], dtype='|S1').view(np.uint8) for t in msa]
    msa = np.concatenate(msa, axis=-1)

    # convert letters into numbers
    alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-"), dtype='|S1').view(np.uint8)
    for i in range(alphabet.shape[0]):
        msa[msa == alphabet[i]] = i

    # treat all unknown characters as gaps
    msa[msa > 20] = 20

    # saturate instead of wrapping when an insertion is longer than 255
    ins = np.minimum(np.array(ins), 255).astype(np.uint8)

    return msa, ins, Ls
# read and extract xyz coords of up to 14 atoms per residue
# (4 backbone + up to 10 side-chain) from a PDB file
def parse_pdb(filename):
    """Read a PDB file and parse its ATOM records with `parse_pdb_lines`.

    Args:
        filename: path to the PDB file.

    Returns:
        Whatever `parse_pdb_lines` returns for the lines of the file.
    """
    # context manager so the handle is closed even if reading fails
    with open(filename, 'r') as fp:
        lines = fp.readlines()
    return parse_pdb_lines(lines)
#'''
def parse_pdb_lines(lines):
    """Extract per-residue atom coordinates from PDB ATOM records.

    Residues are defined by the CA records, in file order.  Atom names are
    matched against ``util.aa2long`` for the residue type to find the atom
    slot.

    Args:
        lines: iterable of PDB-format text lines.

    Returns:
        xyz: float32 array (n_res, 14, 3); missing atoms are set to 0.0.
        mask: bool array (n_res, 14); True where an atom was read.
        idx_s: int array (n_res,) of residue numbers taken from the CA records.
    """
    n_atoms = 14  # 4 BB + up to 10 SC atoms

    # indices of residues observed in the structure
    idx_s = [int(l[22:26]) for l in lines
             if l[:4] == "ATOM" and l[12:16].strip() == "CA"]

    # residue number -> first row carrying that number; this is the row
    # list.index() would return, without the linear scan for every atom
    row_of = {}
    for row, resNo in enumerate(idx_s):
        row_of.setdefault(resNo, row)

    xyz = np.full((len(idx_s), n_atoms, 3), np.nan, dtype=np.float32)
    for l in lines:
        if l[:4] != "ATOM":
            continue
        resNo, atom, aa = int(l[22:26]), l[12:16], l[17:20]
        idx = row_of.get(resNo)
        if idx is None:
            # residue has no CA record, hence no row in the output
            continue
        # unknown residue names fall back to index 20, the same fallback the
        # other parsers in this module use
        aa_num = util.aa2num[aa] if aa in util.aa2num else 20
        # only the first n_atoms names can be stored in xyz
        for i_atm, tgtatm in enumerate(util.aa2long[aa_num][:n_atoms]):
            if tgtatm == atom:
                xyz[idx, i_atm, :] = [float(l[30:38]), float(l[38:46]), float(l[46:54])]
                break

    # save atom mask
    mask = np.logical_not(np.isnan(xyz[..., 0]))
    xyz[np.isnan(xyz[..., 0])] = 0.0

    return xyz, mask, np.array(idx_s)
def parse_pdb_w_seq(filename, chnid=None):
    """Read a PDB file and parse it with `parse_pdb_lines_w_seq`.

    Args:
        filename: path to the PDB file.
        chnid: optional chain identifier; forwarded to `parse_pdb_lines_w_seq`.

    Returns:
        Whatever `parse_pdb_lines_w_seq` returns for the lines of the file.
    """
    with open(filename, 'r') as fp:
        lines = fp.readlines()
    return parse_pdb_lines_w_seq(lines, chnid=chnid)


def parse_pdb_lines_w_seq(lines, chnid=None):
    """Extract per-residue atom coordinates and sequence from PDB ATOM records.

    Residues are defined by the CA records, in file order.  Atom names are
    matched against ``util.aa2long`` for the residue type to find the atom
    slot.  Residue names missing from ``util.aa2num`` are encoded as 20.

    Args:
        lines: iterable of PDB-format text lines.
        chnid: if given, only ATOM records whose chain column (column 22)
            equals this identifier are used; ``None`` keeps every chain.

    Returns:
        xyz: float32 array (n_res, 27, 3); missing atoms are left as NaN.
        mask: bool array (n_res, 27); True where an atom was read.
        idx_s: int array (n_res,) of residue numbers taken from the CA records.
        seq: int array (n_res,) of residue type indices.
    """
    n_atoms = 27  # width of the per-residue atom table filled here

    def wanted(l):
        # ATOM records of the requested chain (or of any chain)
        return l[:4] == "ATOM" and (chnid is None or l[21] == chnid)

    # residues observed in the structure: (chain, residue number, residue name)
    res = [(l[21], int(l[22:26]), l[17:20]) for l in lines
           if wanted(l) and l[12:16].strip() == "CA"]
    idx_s = [r[1] for r in res]
    seq = [util.aa2num[r[2]] if r[2] in util.aa2num else 20 for r in res]

    # (chain, residue number) -> first row for that residue; keying on the
    # chain keeps equally numbered residues of different chains apart
    row_of = {}
    for row, (chn, resNo, _) in enumerate(res):
        row_of.setdefault((chn, resNo), row)

    xyz = np.full((len(idx_s), n_atoms, 3), np.nan, dtype=np.float32)
    for l in lines:
        if not wanted(l):
            continue
        chn, resNo, atom, aa = l[21], int(l[22:26]), l[12:16], l[17:20]
        idx = row_of.get((chn, resNo))
        if idx is None:
            # residue has no CA record, hence no row in the output
            continue
        # same unknown-residue fallback as used for `seq` above
        aa_num = util.aa2num[aa] if aa in util.aa2num else 20
        for i_atm, tgtatm in enumerate(util.aa2long[aa_num][:n_atoms]):
            if tgtatm == atom:
                xyz[idx, i_atm, :] = [float(l[30:38]), float(l[38:46]), float(l[46:54])]
                break

    # save atom mask (coordinates of missing atoms intentionally stay NaN)
    mask = np.logical_not(np.isnan(xyz[..., 0]))

    return xyz, mask, np.array(idx_s), np.array(seq)
# Legacy implementation, kept commented out for reference.
# def parse_templates(item, params):
#     # init FFindexDB of templates
#     ### and extract template IDs
#     ### present in the DB
#     ffdb = FFindexDB(read_index(params['FFDB']+'_pdb.ffindex'),
#                      read_data(params['FFDB']+'_pdb.ffdata'))
#     #ffids = set([i.name for i in ffdb.index])
#     # process tabulated hhsearch output to get
#     # matched positions and positional scores
#     infile = params['DIR']+'/hhr/'+item[-2:]+'/'+item+'.atab'
#     hits = []
#     for l in open(infile, "r").readlines():
#         if l[0]=='>':
#             key = l[1:].split()[0]
#             hits.append([key,[],[]])
#         elif "score" in l or "dssp" in l:
#             continue
#         else:
#             hi = l.split()[:5]+[0.0,0.0,0.0]
#             hits[-1][1].append([int(hi[0]),int(hi[1])])
#             hits[-1][2].append([float(hi[2]),float(hi[3]),float(hi[4])])
#     # get per-hit statistics from an .hhr file
#     # (!!! assume that .hhr and .atab have the same hits !!!)
#     # [Probab, E-value, Score, Aligned_cols,
#     # Identities, Similarity, Sum_probs, Template_Neff]
#     lines = open(infile[:-4]+'hhr', "r").readlines()
#     pos = [i+1 for i,l in enumerate(lines) if l[0]=='>']
#     for i,posi in enumerate(pos):
#         hits[i].append([float(s) for s in re.sub('[=%]',' ',lines[posi]).split()[1::2]])
#     # parse templates from FFDB
#     for hi in hits:
#         #if hi[0] not in ffids:
#         #    continue
#         entry = get_entry_by_name(hi[0], ffdb.index)
#         if entry == None:
#             continue
#         data = read_entry_lines(entry, ffdb.data)
#         hi += list(parse_pdb_lines(data))
#     # process hits
#     counter = 0
#     xyz,qmap,mask,f0d,f1d,ids = [],[],[],[],[],[]
#     for data in hits:
#         if len(data)<7:
#             continue
#         qi,ti = np.array(data[1]).T
#         _,sel1,sel2 = np.intersect1d(ti, data[6], return_indices=True)
#         ncol = sel1.shape[0]
#         if ncol < 10:
#             continue
#         ids.append(data[0])
#         f0d.append(data[3])
#         f1d.append(np.array(data[2])[sel1])
#         xyz.append(data[4][sel2])
#         mask.append(data[5][sel2])
#         qmap.append(np.stack([qi[sel1]-1,[counter]*ncol],axis=-1))
#         counter += 1
#     xyz = np.vstack(xyz).astype(np.float32)
#     mask = np.vstack(mask).astype(np.bool)
#     qmap = np.vstack(qmap).astype(np.long)
#     f0d = np.vstack(f0d).astype(np.float32)
#     f1d = np.vstack(f1d).astype(np.float32)
#     ids = ids
#     return xyz,mask,qmap,f0d,f1d,ids
def parse_templates_raw(ffdb, hhr_fn, atab_fn, templ_to_use, max_templ=20):
    """Collect aligned template coordinates from an hhsearch .atab file.

    For every accepted hit the aligned (query, template) position pairs and
    their three per-position scores are read from ``atab_fn``; the template
    structure is then fetched from ``ffdb`` and restricted to the template
    positions that are both aligned and present in the structure.  Hits with
    fewer than 10 such positions are dropped.

    Args:
        ffdb: object with ``index`` and ``data`` attributes, passed to
            ``get_entry_by_name`` / ``read_entry_lines``.
        hhr_fn: unused by this function.
        atab_fn: path to the tabulated hhsearch alignment.
        templ_to_use: collection of template IDs; when it has more than one
            element, hits whose ID is not in it are skipped.
        max_templ: maximum number of hits read from the file.

    Returns:
        xyz: float32 tensor (n_pos, 27, 3) of template atom coordinates.
        mask: float32 tensor (n_pos, 27), 1.0 where an atom is present.
        qmap: int64 tensor (n_pos, 2): 0-based query position, template number.
        f1d: float32 tensor (n_pos, 3) of per-position alignment scores.
        seq: int64 tensor (n_pos,) of template residue types.
        ids: list of IDs of the templates kept.
        When no template is kept, the tensors are returned with n_pos == 0.
    """
    # process tabulated hhsearch output to get
    # matched positions and positional scores
    hits = []
    read_stat = False  # True while the rows of an accepted hit are being read
    with open(atab_fn, "r") as fp:
        for l in fp:
            if not l.strip():
                # blank lines carry no alignment data
                continue
            if l[0] == '>':
                read_stat = False
                if len(hits) == max_templ:
                    break
                key = l[1:].split()[0]
                # NOTE(review): the filter only activates for more than one
                # entry, so a single-entry list is ignored -- confirm whether
                # `> 0` was intended.
                if len(templ_to_use) > 1 and key not in templ_to_use:
                    continue
                read_stat = True
                hits.append([key, [], []])
            elif "score" in l or "dssp" in l:
                continue
            elif not read_stat:
                continue
            else:
                # pad so that rows with fewer than five columns get zero scores
                hi = l.split()[:5] + [0.0, 0.0, 0.0]
                hits[-1][1].append([int(hi[0]), int(hi[1])])
                hits[-1][2].append([float(hi[2]), float(hi[3]), float(hi[4])])

    # parse templates from FFDB
    for hi in hits:
        entry = get_entry_by_name(hi[0], ffdb.index)
        if entry is None:
            print("Failed to find %s in *_pdb.ffindex" % hi[0])
            continue
        data = read_entry_lines(entry, ffdb.data)
        hi += list(parse_pdb_lines_w_seq(data))  # adds xyz, mask, idx, seq

    # process hits
    counter = 0
    xyz, qmap, mask, f1d, ids, seq = [], [], [], [], [], []
    for data in hits:
        # hits without a structure (len < 7) or without aligned rows are unusable
        if len(data) < 7 or len(data[1]) == 0:
            continue
        qi, ti = np.array(data[1]).T
        # template positions that are aligned AND observed in the structure
        _, sel1, sel2 = np.intersect1d(ti, data[5], return_indices=True)
        ncol = sel1.shape[0]
        if ncol < 10:
            continue

        ids.append(data[0])
        f1d.append(np.array(data[2])[sel1])
        xyz.append(data[3][sel2])
        mask.append(data[4][sel2])
        seq.append(data[-1][sel2])
        qmap.append(np.stack([qi[sel1] - 1, [counter] * ncol], axis=-1))
        counter += 1

    if not ids:
        # nothing usable: return empty tensors instead of failing in np.vstack
        return (torch.zeros((0, 27, 3), dtype=torch.float32),
                torch.zeros((0, 27), dtype=torch.float32),
                torch.zeros((0, 2), dtype=torch.int64),
                torch.zeros((0, 3), dtype=torch.float32),
                torch.zeros((0,), dtype=torch.int64),
                ids)

    xyz = np.vstack(xyz).astype(np.float32)
    mask = np.vstack(mask).astype(np.float32)
    qmap = np.vstack(qmap).astype(np.int32)
    f1d = np.vstack(f1d).astype(np.float32)
    seq = np.hstack(seq).astype(np.int32)

    return torch.from_numpy(xyz), torch.from_numpy(mask), torch.from_numpy(qmap).long(), \
        torch.from_numpy(f1d), torch.from_numpy(seq).long(), ids
def read_templates(qlen, ffdb, hhr_fn, atab_fn, templ_to_use=None, offset=0, n_templ=10, random_noise=5.0):
    """Build per-template coordinate and 1D feature tensors for a query.

    Templates are parsed with `parse_templates_raw` and the first
    ``min(n_templ, n_found)`` of them are placed on the query frame.
    Positions not covered by a template keep the initial coordinates
    (``INIT_CRDS`` plus uniform noise scaled by ``random_noise``), the gap
    class (20) and zero confidence.  When no template is available a single
    blank template with the same layout is returned.

    Args:
        qlen: query length.
        ffdb, hhr_fn, atab_fn: forwarded to `parse_templates_raw`.
        templ_to_use: optional collection of template IDs forwarded to
            `parse_templates_raw`; ``None`` means an empty list.
        offset: added to every 0-based query position before placement.
        n_templ: maximum number of templates to return.
        random_noise: scale of the uniform noise added to the initial coords.

    Returns:
        xyz: float tensor (n, qlen, 27, 3).
        f1d: float tensor (n, qlen, 24): 23-class one-hot residue type plus
            the third per-position score column of the atab file
            (presumably the alignment confidence -- confirm).
        mask_t: bool tensor (n, qlen, 27); True for valid atom, False for
            missing atom.
        Here n = max(1, min(n_templ, number of templates found)).
    """
    templ_to_use = [] if templ_to_use is None else templ_to_use

    xyz_t, mask, qmap, t1d, seq, ids = parse_templates_raw(
        ffdb, hhr_fn, atab_fn, templ_to_use, max_templ=max(n_templ, 20))

    npick = min(n_templ, len(ids))
    # always allocate at least one slot so that the "no templates" case has
    # the same tuple layout and channel count as the regular case
    n_rows = max(npick, 1)

    xyz = INIT_CRDS.reshape(1, 1, 27, 3).repeat(n_rows, qlen, 1, 1) \
        + torch.rand(n_rows, qlen, 1, 3) * random_noise
    mask_t = torch.full((n_rows, qlen, 27), False)  # True for valid atom, False for missing atom
    f1d = torch.full((n_rows, qlen), 20).long()     # all gaps
    f1d_val = torch.zeros((n_rows, qlen, 1)).float()

    for i in range(npick):
        # rows of the raw arrays that belong to template i
        sel = torch.where(qmap[:, 1] == i)[0]
        pos = qmap[sel, 0] + offset
        xyz[i, pos] = xyz_t[sel]
        mask_t[i, pos] = mask[sel].bool()
        f1d[i, pos] = seq[sel]
        f1d_val[i, pos] = t1d[sel, 2].unsqueeze(-1)

        xyz[i] = util.center_and_realign_missing(xyz[i], mask_t[i])

    f1d = torch.nn.functional.one_hot(f1d, num_classes=23).float()
    f1d = torch.cat((f1d, f1d_val), dim=-1)

    return xyz, f1d, mask_t
def read_template_pdb(qlen, templ_fn, align_conf=1.0):
    """Read backbone coordinates of a single template directly from a PDB file.

    The residue number of each ATOM record, minus one, is used as the query
    position.  Only the first three atom names listed in ``util.aa2long`` for
    the residue type are read.

    Args:
        qlen: query length.
        templ_fn: path to the template PDB file.
        align_conf: confidence value assigned to positions for which all
            three atoms were read.

    Returns:
        xyz: float tensor (1, qlen, 27, 3); atoms that were not read are NaN.
        t1d: float tensor (1, qlen, 22): 21-class one-hot residue type
            (20 = gap) followed by the confidence.
    """
    xyz = torch.full((qlen, 27, 3), np.nan).float()
    seq = torch.full((qlen,), 20).long()  # all gaps

    with open(templ_fn) as fp:
        for l in fp:
            if l[:4] != "ATOM":
                continue
            resNo, atom, aa = int(l[22:26]), l[12:16], l[17:20]

            idx = resNo - 1
            if not 0 <= idx < qlen:
                # outside the query frame: a negative index would silently
                # wrap to the end of the tensor, a large one would raise
                continue

            aa_idx = util.aa2num[aa] if aa in util.aa2num else 20
            for i_atm, tgtatm in enumerate(util.aa2long[aa_idx][:3]):
                if tgtatm == atom:
                    xyz[idx, i_atm, :] = torch.tensor(
                        [float(l[30:38]), float(l[38:46]), float(l[46:54])])
                    break
            else:
                # not one of the three atoms read here
                continue
            seq[idx] = aa_idx

    mask = torch.logical_not(torch.isnan(xyz[:, :3, 0]))  # (qlen, 3)
    mask = mask.all(dim=-1)  # (qlen)
    conf = torch.where(mask, torch.tensor(align_conf).float(),
                       torch.zeros(qlen, dtype=xyz.dtype))
    seq = torch.nn.functional.one_hot(seq, num_classes=21).float()
    t1d = torch.cat((seq, conf[:, None]), -1)
    return xyz[None], t1d[None]
def get_args(argv=None):
    """Parse the command-line options of the prediction script.

    Args:
        argv: optional list of argument strings; ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the attributes ``inputs``, ``db``, ``prefix``,
        ``symm`` and ``model``.
    """
    import argparse

    # RawTextHelpFormatter keeps the explicit line breaks of the -inputs help
    parser = argparse.ArgumentParser(
        description="RoseTTAFold2NA",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "-inputs",
        help="Input data in format A:B:C, with\n"
             "   A = multiple sequence alignment file\n"
             "   B = hhpred hhr file\n"
             "   C = hhpred atab file\n"
             "Spaces separate multiple inputs.  The last two arguments may be omitted\n",
        default=None,
        nargs='+',
    )
    parser.add_argument(
        "-db",
        help="HHpred database location",
        default='/public_data/pdb100/pdb100_2020Mar11/pdb100_2020Mar11',
    )
    parser.add_argument("-prefix", help="Output prefix", type=str, default="S")
    parser.add_argument(
        "-symm",
        default="C1",
        help="Symmetry. IF PROVIDED, 'input' should cover the asymmetric unit",
    )
    parser.add_argument(
        "-model",
        default="/home/mink/RF2_variants/RoseTTAFold2/network/weights/RF2_apr23.pt",
        help="Model weights",
    )
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Manual smoke test: load the templates of one target from hard-coded
    # paths and drop into the debugger to inspect the resulting tensors.
    qlen = 144

    FFDB = '/public_data/pdb100/pdb100_2020Mar11/pdb100_2020Mar11'
    FFindexDB = namedtuple("FFindexDB", "index, data")
    ffdb = FFindexDB(
        read_index(FFDB+'_pdb.ffindex'),
        read_data(FFDB+'_pdb.ffdata'),
    )
    # ffdb = None

    hhr_file = "/home/casp16/run/TS.human/T1207/MiniWorld/templates/T1207_.hhr"
    atab_file = "/home/casp16/run/TS.human/T1207/MiniWorld/templates/T1207_.atab"

    xyz, f1d, mask_t = read_templates(
        qlen,
        ffdb,
        hhr_file,
        atab_file,
        offset=0,
        n_templ=10,
        random_noise=5.0,
    )
    breakpoint()