"""
This file contains the test data generation utilities.
"""
import numpy as np
import os
from collections import OrderedDict
import pickle
from miniworld.utils.ProteinClass import ProteinSequence, ProteinStructure, Protein, ProteinTemplate, ProteinMSA, ProteinTemplates, PDB_parsing
from miniworld.utils.chemical import aa2num, aa2long, AA2num, aa2AA, num2AA
import copy
import time
import gzip
import re
import json
import traceback
import torch
PDB_RELEASE_DATE = "2021-Aug-02"
PDB_TRAIN_DATA_PATH = "/public_data/ml/RF2_train/PDB-2021AUG02"
# TEST_NUM = 100
# Example_hash_code = "000/"
# sample_num = 100
Example_code = "d7"
sample_num = 100
def get_csv():
csv_path = os.path.join(PDB_TRAIN_DATA_PATH, "list_v02.csv")
with open(csv_path, "r") as f:
lines = f.readlines()
hash_seq_dict = {}
for ii, line in enumerate(lines):
if ii == 0:
continue
chain_ID, deposition, resolution, hash, cluster, seq, length = line.strip().split(",")
if chain_ID[1:3] != Example_code:
continue
if hash not in hash_seq_dict:
hash_seq_dict[hash] = [[chain_ID, seq, cluster, deposition]]
else:
hash_seq_dict[hash].append([chain_ID, seq, cluster, deposition])
return hash_seq_dict
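# Usage sketch for get_csv (illustrative only; the chain ID below is hypothetical).
# Each row of list_v02.csv is "CHAINID,DEPOSITION,RESOLUTION,HASH,CLUSTER,SEQUENCE,LENGTH",
# and get_csv() keeps only chains whose 2nd-3rd characters match Example_code ("d7",
# e.g. "1d7a_A"), grouped by hash:
#
#   hash_seq_dict = get_csv()
#   for hash, rows in hash_seq_dict.items():
#       for chain_ID, seq, cluster, deposition in rows:
#           print(hash, chain_ID, len(seq))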
def pdb_Sampling(output_path = "data/test_data/PDB_2021Aug02/"):
hash_seq_dict = get_csv()
hash_seq_dict_keys = list(hash_seq_dict.keys())
hash_seq_dict_keys.sort()
a3m_path = PDB_TRAIN_DATA_PATH + "/a3m/"
hhr_path = PDB_TRAIN_DATA_PATH + "/torch/hhr/"
pdb_path = PDB_TRAIN_DATA_PATH + "/torch/pdb/" + Example_code + "/"
a3m_file_path_list = []
hhr_file_path_list = []
for hash in hash_seq_dict_keys:
# hash : 6 digits
hash_folder = hash[:3]
a3m_file_path = a3m_path + hash_folder + "/" + hash + ".a3m.gz"
hhr_file_path = hhr_path + hash_folder + "/" + hash + ".pt"
a3m_file_path_list.append(a3m_file_path)
hhr_file_path_list.append(hhr_file_path)
# pdb_file_path_list = []
# for file in os.listdir(pdb_path):
# if file[-3:] == ".pt":
# pdb_file_path_list.append(pdb_path + file)
if not os.path.exists(output_path):
os.makedirs(output_path)
for hash, a3m_file_path, hhr_file_path in zip(hash_seq_dict_keys, a3m_file_path_list, hhr_file_path_list):
a3m_save_path = output_path + hash + ".a3m.gz"
hhr_save_path = output_path + hash + ".pt"
os.system("cp %s %s"%(a3m_file_path, a3m_save_path))
os.system("cp %s %s"%(hhr_file_path, hhr_save_path))
with open(output_path + "list_v02_sampled.csv", "w") as f:
for hash in hash_seq_dict:
value_list = hash_seq_dict[hash]
for chain_ID, seq, cluster, deposition in value_list:
f.write("%s,%s,%s,%s,%s,%s,%s\n"%(chain_ID, deposition, "0.0", hash, cluster, seq, "0"))
def gzip_files():
"""
    Decompresses every .gz file in the folder by shelling out to gzip. It's just for convenience; for large data, use the gzip library rather than this function.
"""
folder_path = "data/test_data/PDB_2021Aug02/"
file_list = os.listdir(folder_path)
for file in file_list:
if file[-3:] == ".gz":
os.system("gzip -d %s"%(folder_path + file))
def print_hhr_pt(file_path = "data/test_data/PDB_2021Aug02/001128.pt"):
pt = torch.load(file_path)
print(pt.keys())
for key, value in pt.items():
print(key)
        # if the value is a tensor
        if isinstance(value, torch.Tensor):
            print(value.shape)
        # else if the value is a list
        elif isinstance(value, list):
            print(len(value))
    print(pt['xyz'].shape)
    print(pt['mask'][0,:,-1])
print(pt["f0d"][0,0,:])
print(pt["qmap"][0,:,1])
sel = torch.where(pt['qmap'][0,:,1]==1)[0]
print(sel)
mask_t = pt['mask'][0,sel].bool()
print(mask_t[0])
pos = pt['qmap'][0,sel,0]
print(pos)
    # for aa in pt["seq"][0]:
    #     print(num2AA[int(aa)], end="")
def refactoring_hhr_file(original_hhr_file, saving_directory = None):
"""
Input : original hhr file path
    Do : Refactor the hhr file and save it
    Output : None
    The original hhr file has a qmap tensor that encodes the length and position of each template.
    I think this encoding is unnecessary and hard to understand, so I remove it and save the templates as a list of dicts of torch tensors.
"""
filename = original_hhr_file.split("/")[-1]
original_hhr_file = torch.load(original_hhr_file)
print(f"test original_hhr_file.keys() : {original_hhr_file.keys()}")
print(f"test original_hhr_file['qmap'].shape : {original_hhr_file['qmap'].shape}")
print(f"test original_hhr_file['xyz'].shape : {original_hhr_file['xyz'].shape}")
qmap = original_hhr_file["qmap"]
assert qmap.shape[0] == 1, "qmap.shape[0] != 1"
ids = original_hhr_file["ids"]
qmap = qmap[0]
print(f"test qmap.shape : {qmap.shape}")
template_idxs = qmap[:,1]
template_list = []
    xyz = original_hhr_file["xyz"]   # [B=1, sum of template lengths, 14, 3]
    mask = original_hhr_file["mask"] # [B=1, sum of template lengths, 14]
    seq = original_hhr_file["seq"]   # [B=1, sum of template lengths]
    f0d = original_hhr_file["f0d"]   # [B=1, number of templates, 8]
    f1d = original_hhr_file["f1d"]   # [B=1, sum of template lengths, 3]
for id in range(len(ids)):
qmap_pos = torch.where(template_idxs == id)[0]
position = qmap[qmap_pos,0]
xyz_id = xyz[:,qmap_pos,:,:]
mask_id = mask[:,qmap_pos,:]
seq_id = seq[:,qmap_pos]
f1d_id = f1d[:,qmap_pos,-1] # last one is a residue-wise alignment confidence
f0d_id = f0d[:,id,4] # index 4 is a sequence identity
template_list.append({"template_ID" : id, "xyz":xyz_id, "mask":mask_id, "position" : position, "sequence":seq_id, "f0d" : f0d_id, "f1d":f1d_id})
if saving_directory is not None :
if not os.path.exists(saving_directory):
os.makedirs(saving_directory)
torch.save(template_list, saving_directory + "/" + filename)
else :
return template_list
"""
template_list example
    template_ID  xyz  mask  sequence  f1d   (f0d is a per-template scalar)
0 torch.Size([1, 201, 14, 3]) torch.Size([1, 201, 14]) torch.Size([1, 201]) torch.Size([1, 201])
1 torch.Size([1, 193, 14, 3]) torch.Size([1, 193, 14]) torch.Size([1, 193]) torch.Size([1, 193])
2 torch.Size([1, 193, 14, 3]) torch.Size([1, 193, 14]) torch.Size([1, 193]) torch.Size([1, 193])
3 torch.Size([1, 193, 14, 3]) torch.Size([1, 193, 14]) torch.Size([1, 193]) torch.Size([1, 193])
4 torch.Size([1, 196, 14, 3]) torch.Size([1, 196, 14]) torch.Size([1, 196]) torch.Size([1, 196])
...
"""
    # For testing; safe to remove.
    # template_list = torch.load(saving_directory + "/" + filename)
    # total_length = 0
    # for template in template_list:
    #     total_length += template["sequence"].shape[1]
    #     print(template["template_ID"], template["xyz"].shape, template["mask"].shape, template["sequence"].shape, template["f1d"].shape)
    # print(total_length)
def PDB_parsing_from_pt(PDB_file, IS_LABEL = True, source = ""):
"""
Input : PDB file path, IS_LABEL (bool)
Output : Protein object
    This function is for .pt files; the loaded pt_file is a dictionary with keys seq, xyz, mask, bfac, occ.
"""
pt_file = torch.load(PDB_file)
sequence = ProteinSequence(sequence = pt_file["seq"])
structure = ProteinStructure(xyz = pt_file["xyz"], mask = pt_file["mask"], chain_break = sequence.chain_break)
protein_ID = PDB_file.split("/")[-1].split(".")[:-1]
protein_ID = ".".join(protein_ID)
protein = Protein(sequence = sequence, structure = structure, ID = protein_ID, IS_LABEL = IS_LABEL, source = source)
return protein
def protein_to_template(protein, position = None, f0d = None, f1d = None, IS_LABEL = False):
"""
Input : Protein object, position (torch.Tensor)
Output : ProteinTemplate object
    This function is used when an AlphaFold output is used as a template.
    protein is a Protein object and position holds the confident residue positions (filtered elsewhere, e.g. by pLDDT).
"""
assert isinstance(protein, Protein), "protein must be Protein class"
protein_template = copy.deepcopy(protein)
protein_template.__class__= ProteinTemplate
if position is None : # All residue position is confident
protein_length = protein_template.sequence.get_sequence_length()
position = torch.arange(protein_length)
protein_template.position = position
protein_template.template_ID = protein_template.ID # template_ID is used in ProteinTemplate class
protein_template.f0d = f0d
protein_template.f1d = f1d
    protein_template.IS_LABEL = IS_LABEL
return protein_template
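# Minimal usage sketch (assumes a .pt file in the seq/xyz/mask format that
# PDB_parsing_from_pt expects; the path and indices below are hypothetical):
#
#   protein = PDB_parsing_from_pt("data/test_data/example.pt", IS_LABEL=False)
#   template = protein_to_template(protein)  # all residue positions treated as confident
#   confident = torch.tensor([0, 1, 2, 5, 6])  # e.g. indices surviving a pLDDT filter
#   template = protein_to_template(protein, position=confident)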
"""
PDB and MSA (a3m) files are big, and parsing and loading them takes too much time. So I parse each PDB file into a Protein object and save it as a .pkl file using pickle.
(Refactored) hhr files are compact and easy to load, so I don't save them as .pkl files.
"""
def test_a3m():
    # open the compressed file
    with gzip.open('example.gz', 'rb') as f:
        # read the file contents
        content = f.read()
        # print what was read
        print(content)
def mmcif_line_parser(line, loop_ = None):
"""Parses a single line from an mmCIF file.::
Input : mmcif line
Output : dictionary or list
loop_ : Optional, Example) [_atom_site.group_PDB, _atom_site.id, ...]
Example :
ATOM line -> ATOM 1 N N . MET A 1 1 ? 11.242 1.210 20.525 1.00 4.07 ? 1 MET A N 1
Chem line -> ALA 'L-peptide linking' y ALANINE ? 'C3 H7 N O2' 89.093
These lines are separated by space but there is 'something something' ('L-peptide linking') with space in some lines. So I use alternative split function.
"""
pattern = r"('[^']*')|(\S+)" # This pattern means that if there is quote, it is separated by quote. If not, it is separated by space.
matches = re.finditer(pattern, line)
results = []
for match in matches:
# This line is equivalent to first trying match.group(1) and if that's None, trying match.group(2).
result = match.group(1) or match.group(2)
# remove quotes if they exist
if result.startswith("'") and result.endswith("'"):
result = result[1:-1]
results.append(result)
if loop_ is None:
return results
else :
        if len(loop_) != len(results):
            raise ValueError("loop_ length does not match results length")
return {key : value for key, value in zip(loop_, results)}
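# Worked example of the quote-aware split (the line is taken from the docstring above):
#
#   >>> mmcif_line_parser("ALA 'L-peptide linking' y ALANINE ? 'C3 H7 N O2' 89.093")
#   ['ALA', 'L-peptide linking', 'y', 'ALANINE', '?', 'C3 H7 N O2', '89.093']
#
# Passing loop_ headers of the same length returns the tokens as a dict instead.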
def mmcif_loop_parser(lines_split_by_sharp, first_key, IS_ATOM = False):
"""
CAUTION !!!
    mmCIF is not a well-structured format, so it is not easy to parse.
    Therefore I use this function to parse the loop_ blocks (and loop-less key/value blocks) of mmCIF files.
"""
for lines in lines_split_by_sharp:
if len(lines) < 2:
continue
if lines[0] == "loop_" and lines[1].startswith(first_key):
key_list = []
temp_info_list = []
for line in lines[1:]:
line = line.strip()
if line.startswith("_"):
key_list.append(line.split(".")[1])
else :
if IS_ATOM:
if "ATOM" != line[:4]:
if "HETATM" != line[:6] or "MSE" != line[21:24]:
continue
mmcif_line_parse_result = mmcif_line_parser(line)
if mmcif_line_parse_result[0] == ";" : continue
elif mmcif_line_parse_result[0][0] == ";" :
mmcif_line_parse_result[0] = mmcif_line_parse_result[0][1:]
for data in mmcif_line_parse_result:
temp_info_list.append(data)
            if len(temp_info_list) % len(key_list) != 0:
                raise ValueError("len(temp_info_list) % len(key_list) != 0")
info_list = []
for key in key_list:
info_list.append([])
for i, data in enumerate(temp_info_list):
info_list[i%len(key_list)].append(data)
# key_list : (key1, key2, key3, ...)
# info_list : [(value1, value2, value3, ...), (value1, value2, value3, ...), (value1, value2, value3, ...), ...]
dict_list = []
for i in range(len(info_list[0])):
dict_list.append({key : info_list[j][i] for j, key in enumerate(key_list)})
return dict_list
elif lines[0].startswith(first_key):
"""
In this case, there is no loop -> key \t value
However, there are freaking weird lines like below.
Ex)
_pdbx_struct_assembly_gen.asym_id_list
;A,B,C,D,E,F,G,H,I,J,K,L,M,N,O,P,Q,R,S,T,U,V,W,X,Y,Z,AA,BA,CA,DA,EA,FA,GA,HA,IA,JA,KA,LA,MA,NA,OA,PA,QA,RA,SA,TA,UA,VA,WA,XA,YA,ZA,AB,BB,CB,DB,EB,FB,GB,HB,IB,JB,KB,LB,MB,NB,OB,PB,QB,RB,SB,TB,UB,VB,WB,XB,YB,ZB,AC,BC,CC,DC,EC,FC,GC,HC,IC,JC,KC,LC,MC,NC,OC,PC,QC,RC,SC,TC,UC,VC,WC,XC,YC,ZC,AD,BD,CD,DD,ED,FD,GD,HD,ID,JD,KD,LD,MD,ND,OD,PD,QD,RD,SD,TD,UD,VD,WD,XD,YD,ZD,AE,BE,CE,DE,EE,FE,GE,HE,IE,JE,KE,LE,ME,NE,OE,PE,QE,RE,SE,TE,UE,VE,WE,XE,YE,ZE,AF,BF,CF,DF,EF,FF,GF
            ;
            I call these "freaking weird" lines (hence IS_freaking_weird below).
            """
key_list = []
info_list = []
IS_freaking_weird = False
for line in lines:
mmcif_line_parse_result = mmcif_line_parser(line)
if IS_freaking_weird:
value = mmcif_line_parse_result[0]
value = value.split(";")[-1]
key_list.append(key.split(".")[1])
info_list.append(value)
IS_freaking_weird = False
continue
if mmcif_line_parse_result[0] == ";" : continue
if len(mmcif_line_parse_result) == 1:
key = mmcif_line_parse_result[0]
IS_freaking_weird = True
continue
key, value = mmcif_line_parser(line)
key_list.append(key.split(".")[1])
info_list.append(value)
            result_dict = {key : value for key, value in zip(key_list, info_list)}
            dict_list = [result_dict]
return dict_list
return None
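# Worked example (hypothetical block, shaped like the loops shown in the big
# comment inside mmcif_parsing below):
#
#   block = ["loop_",
#            "_pdbx_struct_assembly_prop.biol_id",
#            "_pdbx_struct_assembly_prop.type",
#            "_pdbx_struct_assembly_prop.value",
#            "1 'ABSA (A^2)' 5740"]
#   mmcif_loop_parser([block], first_key="_pdbx_struct_assembly_prop.biol_id")
#   # -> [{'biol_id': '1', 'type': 'ABSA (A^2)', 'value': '5740'}]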
def mmcif_parsing(mmcif_file, IS_LABEL = True):
"""
Input : mmcif file path (~.cif or ~.cif.gz)
Output : Protein object
    In this project I exclude nucleic acids, so they are not considered here.
    Also I only consider ATOM lines (plus HETATM MSE) -> this can cause problems...
"""
ProteinID = mmcif_file.split("/")[-1].split(".")[:-1]
ProteinID = ".".join(ProteinID)
start_time = time.time()
if mmcif_file[-3:] == ".gz":
with gzip.open(mmcif_file, "rb") as f:
lines = f.readlines()
lines = [line.decode("utf-8") for line in lines]
elif mmcif_file[-4:] == ".cif":
with open(mmcif_file, "r") as f:
lines = f.readlines()
else :
raise ValueError("mmcif file must be .cif or .cif.gz file")
lines = [line.strip() for line in lines]
lines_split_by_sharp = []
temp = []
for line in lines:
if line.startswith("#"):
lines_split_by_sharp.append(temp)
temp = []
else :
temp.append(line)
lines_split_by_sharp.append(temp)
"""
#
_pdbx_struct_assembly.id 1
_pdbx_struct_assembly.details author_and_software_defined_assembly
_pdbx_struct_assembly.method_details PISA
_pdbx_struct_assembly.oligomeric_details trimeric
_pdbx_struct_assembly.oligomeric_count 3
#
_pdbx_struct_assembly_gen.assembly_id 1
_pdbx_struct_assembly_gen.oper_expression 1,2,3
_pdbx_struct_assembly_gen.asym_id_list A,B,C,D,E
#
loop_
_pdbx_struct_assembly_prop.biol_id
_pdbx_struct_assembly_prop.type
_pdbx_struct_assembly_prop.value
_pdbx_struct_assembly_prop.details
1 'ABSA (A^2)' 5740 ?
1 MORE -155 ?
1 'SSA (A^2)' 18370 ?
#
loop_
_pdbx_struct_oper_list.id
_pdbx_struct_oper_list.type
_pdbx_struct_oper_list.name
_pdbx_struct_oper_list.symmetry_operation
_pdbx_struct_oper_list.matrix[1][1]
_pdbx_struct_oper_list.matrix[1][2]
_pdbx_struct_oper_list.matrix[1][3]
_pdbx_struct_oper_list.vector[1]
_pdbx_struct_oper_list.matrix[2][1]
_pdbx_struct_oper_list.matrix[2][2]
_pdbx_struct_oper_list.matrix[2][3]
_pdbx_struct_oper_list.vector[2]
_pdbx_struct_oper_list.matrix[3][1]
_pdbx_struct_oper_list.matrix[3][2]
_pdbx_struct_oper_list.matrix[3][3]
_pdbx_struct_oper_list.vector[3]
1 'identity operation' 1_555 x,y,z 1.0000000000 0.0000000000 0.0000000000 0.0000000000 0.0000000000 1.0000000000
0.0000000000 0.0000000000 0.0000000000 0.0000000000 1.0000000000 0.0000000000
2 'crystal symmetry operation' 2_555 -y,x-y,z -0.5000000000 -0.8660254038 0.0000000000 0.0000000000 0.8660254038 -0.5000000000
0.0000000000 0.0000000000 0.0000000000 0.0000000000 1.0000000000 0.0000000000
3 'crystal symmetry operation' 3_555 -x+y,-x,z -0.5000000000 0.8660254038 0.0000000000 0.0000000000 -0.8660254038 -0.5000000000
0.0000000000 0.0000000000 0.0000000000 0.0000000000 1.0000000000 0.0000000000
#
atom_site_idxs :
_atom_site.group_PDB
_atom_site.id
_atom_site.type_symbol
_atom_site.label_atom_id
_atom_site.label_alt_id
_atom_site.label_comp_id
_atom_site.label_asym_id
_atom_site.label_entity_id
_atom_site.label_seq_id
_atom_site.pdbx_PDB_ins_code
_atom_site.Cartn_x
_atom_site.Cartn_y
_atom_site.Cartn_z
_atom_site.occupancy
_atom_site.B_iso_or_equiv
_atom_site.pdbx_formal_charge
_atom_site.auth_seq_id
_atom_site.auth_comp_id
_atom_site.auth_asym_id
_atom_site.auth_atom_id
_atom_site.pdbx_PDB_model_num
Example)
{'group_PDB': 'ATOM', 'id': '1', 'type_symbol': 'N', 'label_atom_id': 'N', 'label_alt_id': '.',
'label_comp_id': 'THR', 'label_asym_id': 'A', 'label_entity_id': '1', 'label_seq_id': '3', 'pdbx_PDB_ins_code': '?',
'Cartn_x': '8.342', 'Cartn_y': '8.244', 'Cartn_z': '-0.883', 'occupancy': '1.00', 'B_iso_or_equiv': '160.56', 'pdbx_formal_charge': '?',
'auth_seq_id': '2', 'auth_comp_id': 'THR', 'auth_asym_id': 'A', 'auth_atom_id': 'N', 'pdbx_PDB_model_num': '1'}
"""
atom_site_dict_list = mmcif_loop_parser(lines_split_by_sharp, first_key = "_atom_site.group_PDB",IS_ATOM=True) # list of dict
pdbx_struct_assembly_dict_list = mmcif_loop_parser(lines_split_by_sharp, first_key = "_pdbx_struct_assembly.id")
pdbx_struct_assembly_gen_dict_list = mmcif_loop_parser(lines_split_by_sharp, first_key = "_pdbx_struct_assembly_gen.assembly_id")
pdbx_struct_oper_list_dict_list = mmcif_loop_parser(lines_split_by_sharp, first_key = "_pdbx_struct_oper_list.id")
Symmetry_related_lines = {"_pdbx_struct_assembly" :pdbx_struct_assembly_dict_list, "_pdbx_struct_assembly_gen" :pdbx_struct_assembly_gen_dict_list, "_pdbx_struct_oper_list" :pdbx_struct_oper_list_dict_list}
    if atom_site_dict_list is None or len(atom_site_dict_list) == 0:
return "No ATOM line"
has_multiple_chains = False
model_xyz = OrderedDict() # (M, L, 14, 3)
model_atom_mask = OrderedDict() # (M, L, 14)
model_occ = OrderedDict() # (M, L, 14)
model_seq = OrderedDict() # (M, L)
model_position = OrderedDict() # (M, L) # Some mmcif files have missing residue. So I use position to indicate residue position.
model_chain = OrderedDict() # (M, L) # I use this to check chain break
    chain_entity_ID_dict = OrderedDict() # (C,) # I assume that each chain has only one entity ID.
before_residue_position = -1 # This is for full-CA mmcif files
for ii, ATOM_line in enumerate(atom_site_dict_list) :
atom, chain, occupancy, entity_ID, residue_position = ATOM_line["label_atom_id"], ATOM_line["label_asym_id"], float(ATOM_line["occupancy"]), ATOM_line['label_entity_id'], int(ATOM_line['label_seq_id'])
model = int(ATOM_line["pdbx_PDB_model_num"]) - 1
if model not in model_xyz:
model_xyz[model] = []
model_atom_mask[model] = []
model_occ[model] = []
model_seq[model] = []
model_position[model] = []
model_chain[model] = []
residue = ATOM_line["label_comp_id"]
if residue not in aa2num.keys():
print(f"residue {residue}",end = "")
# has_nucleic_acid = True
# print("Nucleic acid is included")
return "Nucleic acid is included"
residue_num = aa2num[residue]
        atom_list = aa2long[residue_num] # 0~13 : heavy atoms, 14~26 : hydrogen atoms
Heavy_atom_list = atom_list[:14]
Heavy_atom_list = [atom.replace(" ", "") if atom is not None else None for atom in Heavy_atom_list]
# Hydrogen_atom_list = atom_list[14:]
        if chain in chain_entity_ID_dict and chain_entity_ID_dict[chain] != entity_ID:
            print(f"chain {chain} has multiple entity ID")
            return "chain has multiple entity ID"
        chain_entity_ID_dict[chain] = entity_ID
if residue_position != before_residue_position:
model_seq[model].append(residue_num)
model_position[model].append(residue_position)
model_chain[model].append(chain)
# append xyz_model to empty 14,3 list
xyz_empty_array = [[0.0,0.0,0.0] for _ in range(14)]
mask_empty_array = [0 for _ in range(14)]
model_xyz[model].append(xyz_empty_array)
model_atom_mask[model].append(mask_empty_array)
model_occ[model].append([0.0 for _ in range(14)])
if atom in Heavy_atom_list:
atom_idx = Heavy_atom_list.index(atom)
x = float(ATOM_line["Cartn_x"])
y = float(ATOM_line["Cartn_y"])
z = float(ATOM_line["Cartn_z"])
model_xyz[model][-1][atom_idx] = [x,y,z]
model_atom_mask[model][-1][atom_idx] = 1
model_occ[model][-1][atom_idx] = occupancy
before_residue_position = residue_position
"""
    model_xyz[model]       # (L', 14, 3) ; L' <= L because some residues are missing
    model_atom_mask[model] # (L', 14)
    model_occ[model]       # (L', 14)
    model_seq[model]       # (L')
    model_chain[model]     # (L')
    model_position[model]  # (L')
    Using mmcif_Full_sequence_parsing, I build the full sequence below.
"""
try :
_, full_sequence_dict = mmcif_Full_sequence_parsing(mmcif_file, output = "Tensor") # full_sequence_dict : {entity_ID : sequence}
except Exception as e:
return "Nucleic acid is included"
    # some files leave label_entity_id as '?'; map it to entity '1' when present
    if '1' in full_sequence_dict:
        full_sequence_dict['?'] = full_sequence_dict['1']
full_model_xyz = [] # (M, L, 14, 3)
full_model_atom_mask = [] # (M, L, 14)
full_model_occ = [] # (M, L, 14)
full_model_masked_seq = [] # (M, L)
full_model_seq = [] # (M, L)
full_model_position_mask = [] # (M, L)
chain_break = OrderedDict()
for ii, model in enumerate(model_xyz.keys()):
full_xyz_of_chain = {}
full_atom_mask_of_chain = {}
full_occ_of_chain = {}
full_masked_seq_of_chain = {}
full_seq_of_chain = {}
full_position_mask = []
total_length = 0
for jj in range(len(model_seq[model])):
chain = model_chain[model][jj]
entity_ID = chain_entity_ID_dict[chain]
full_sequence = full_sequence_dict[str(entity_ID)]
if chain not in full_xyz_of_chain:
total_length += len(full_sequence)
full_xyz_of_chain[chain] = torch.full((len(full_sequence), 14, 3), 0.0) # (L_i, 14, 3)
full_atom_mask_of_chain[chain] = torch.full((len(full_sequence), 14), 0.0) # (L_i, 14)
full_occ_of_chain[chain] = torch.full((len(full_sequence), 14), 0.0) # (L_i, 14)
full_masked_seq_of_chain[chain] = torch.full((len(full_sequence),), 22) # (L_i), Filled to Mask, 22 is Mask
full_position_mask.append(torch.full((len(full_sequence),), 0.0)) # (L_i)
full_xyz_of_chain[chain][model_position[model][jj]-1] = torch.tensor(model_xyz[model][jj])
full_atom_mask_of_chain[chain][model_position[model][jj]-1] = torch.tensor(model_atom_mask[model][jj])
full_occ_of_chain[chain][model_position[model][jj]-1] = torch.tensor(model_occ[model][jj])
full_masked_seq_of_chain[chain][model_position[model][jj]-1] = torch.tensor(model_seq[model][jj])
full_seq_of_chain[chain] = full_sequence
full_position_mask[-1][model_position[model][jj]-1] = 1.0
if ii == 0:
chain_start = 0
chain_end = 0
for chain in full_xyz_of_chain.keys():
chain_end = chain_start + len(full_xyz_of_chain[chain]) - 1
chain_break[chain] = [chain_start, chain_end]
chain_start = chain_end + 1
full_xyz = torch.cat(list(full_xyz_of_chain.values())) # (L, 14, 3)
full_atom_mask = torch.cat(list(full_atom_mask_of_chain.values())) # (L, 14)
full_occ = torch.cat(list(full_occ_of_chain.values())) # (L, 14)
full_masked_seq = torch.cat(list(full_masked_seq_of_chain.values())) # (L)
full_seq = torch.cat(list(full_seq_of_chain.values())) # (L)
full_model_xyz.append(full_xyz)
full_model_atom_mask.append(full_atom_mask)
full_model_occ.append(full_occ)
full_model_masked_seq.append(full_masked_seq)
full_model_seq.append(full_seq)
full_model_position_mask.append(torch.cat(full_position_mask))
full_model_xyz = torch.stack(full_model_xyz) # (M, L, 14, 3)
full_model_atom_mask = torch.stack(full_model_atom_mask) # (M, L, 14)
full_model_occ = torch.stack(full_model_occ) # (M, L, 14)
full_model_masked_seq = torch.stack(full_model_masked_seq) # (M, L)
full_model_seq = torch.stack(full_model_seq) # (M, L)
full_model_position_mask = torch.stack(full_model_position_mask) # (M, L)
    has_multiple_chains = len(chain_break.keys()) > 1
    has_multiple_models = full_model_xyz.shape[0] > 1
Sequence = ProteinSequence(sequence = full_model_seq, masked_sequence = full_model_masked_seq, chain_break = chain_break)
Structure = ProteinStructure(xyz = full_model_xyz, atom_mask = full_model_atom_mask, position_mask = full_model_position_mask, chain_break = chain_break, has_multiple_models = has_multiple_models, has_multiple_chains = has_multiple_chains)
Protein_ID = mmcif_file.split("/")[-1].split(".")[:-1]
Protein_ID = ".".join(Protein_ID)
symmetry_values = list(Symmetry_related_lines.values())
if symmetry_values == [None, None, None]:
Symmetry_related_lines = None
Protein_object = Protein(sequence = Sequence, structure = Structure, occupancy = full_model_occ, symmetry_related_info = Symmetry_related_lines, ID = Protein_ID, IS_LABEL = IS_LABEL)
Protein_object.source = "mmcif"
Protein_object.ID = ProteinID
# print("mmcif parsing time : %.2f"%(time.time() - start_time))
# print(Protein_object)
return Protein_object
def mmcif_pickling(mmcif_folder, saving_directory = "data/pickle_data/mmcif/"):
"""
Input : mmcif folder path
Output : None
    This function parses mmCIF files and saves them as .pkl files using pickle.
    In public_data/rcsb/cif/ there are many two-letter subfolders.
"""
if not os.path.exists(saving_directory):
os.makedirs(saving_directory)
inner_folder_list = os.listdir(mmcif_folder)
for inner_folder in inner_folder_list:
if inner_folder[0] == ".":
continue
inner_folder_path = mmcif_folder + inner_folder + "/"
for file in os.listdir(inner_folder_path):
if file[-6:] == "cif.gz" or file[-3:] == "cif":
mmcif_file = inner_folder_path + file
print(f"{mmcif_file} ", end = "")
try :
Protein_object = mmcif_parsing(mmcif_file)
except Exception as e:
print(f"is not parsed, because it has error")
print(traceback.format_exc())
print(e)
continue
if Protein_object == "Nucleic acid is included":
print(f"is not parsed, because it has nucleic acid")
continue
elif Protein_object == "No ATOM line":
print(f"is not parsed, because it has no ATOM line")
continue
Protein_object.save(saving_directory = saving_directory)
print(f"is parsed")
def mmcif_Full_sequence_parsing(mmcif_file, output="String"):
"""
Input : mmcif file path
Output : Full Sequence
    I had missed that the sequence in the ATOM records is not the full sequence.
    So, I parse the full sequence from the _entity_poly_seq records.
"""
ProteinID = mmcif_file.split("/")[-1].split(".")[:-1]
ProteinID = ".".join(ProteinID)
start_time = time.time()
if mmcif_file[-3:] == ".gz":
with gzip.open(mmcif_file, "rb") as f:
lines = f.readlines()
lines = [line.decode("utf-8") for line in lines]
elif mmcif_file[-4:] == ".cif":
with open(mmcif_file, "r") as f:
lines = f.readlines()
else :
raise ValueError("mmcif file must be .cif or .cif.gz file")
lines = [line.strip() for line in lines]
lines_split_by_sharp = []
temp = []
for line in lines:
if line.startswith("#"):
lines_split_by_sharp.append(temp)
temp = []
else :
temp.append(line)
    entity_poly_seq_dict_list = mmcif_loop_parser(lines_split_by_sharp, first_key = "_entity_poly_seq.entity_id", IS_ATOM=False) # list of dict
    full_sequence_dict = {}
    for ii in range(len(entity_poly_seq_dict_list)):
        try :
            entity_id = entity_poly_seq_dict_list[ii]["entity_id"]
            if entity_id not in full_sequence_dict.keys():
                full_sequence_dict[entity_id] = ""
            full_sequence_dict[entity_id] += aa2AA[entity_poly_seq_dict_list[ii]["mon_id"]]
        except KeyError as e:
            abnormal_key = entity_poly_seq_dict_list[ii]["mon_id"]
if len(abnormal_key) == 3:
full_sequence_dict[entity_id] += 'X'
else :
# print(traceback.format_exc())
# print(e)
print(f"{ProteinID} has nucleic acid or something abnormal key {abnormal_key}")
return "Nucleic acid is included"
if output == "Tensor":
for key in full_sequence_dict.keys():
full_sequence_dict[key] = torch.tensor([AA2num[AA] for AA in full_sequence_dict[key]])
PDB_ID = mmcif_file.split("/")[-1].split(".")[:-1][0]
return PDB_ID, full_sequence_dict
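# Usage sketch ("3mj7.cif" also appears in the __main__ block below; the example
# sequences are hypothetical):
#
#   PDB_ID, seqs = mmcif_Full_sequence_parsing("3mj7.cif")
#   # seqs maps entity IDs to sequences, e.g. {'1': 'MKT...', '2': 'GSS...'}
#   # with output="Tensor", each sequence is a torch.LongTensor of AA2num indices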
def mmcif_full_sequence(mmcif_folder, saving_file_path = "data/PDB_ID_full_sequence.txt"):
"""
    Input : mmCIF folder path
    Output : None
    This function parses the full sequence from every mmCIF file and writes the (PDB_ID, full_sequence_dict) pairs to saving_file_path.
    In public_data/rcsb/cif/ there are many two-letter subfolders.
"""
output_list = []
inner_folder_list = os.listdir(mmcif_folder)
for inner_folder in inner_folder_list:
if inner_folder[0] == ".":
continue
inner_folder_path = mmcif_folder + inner_folder + "/"
for file in os.listdir(inner_folder_path):
if file[-6:] == "cif.gz" or file[-3:] == "cif":
mmcif_file = inner_folder_path + file
print(f"{mmcif_file} ", end = "")
try :
PDB_ID, full_sequence_dict = mmcif_Full_sequence_parsing(mmcif_file)
except Exception as e:
# print(f"is not parsed, because it has error")
# print(traceback.format_exc())
# print(e)
continue
output_list.append([PDB_ID, full_sequence_dict])
print(f"is parsed")
with (open(saving_file_path, "w")) as f:
for line in output_list:
f.write(line[0] + "," + str(line[1]) + "\n")
def pdb_pickling(pdb_folder, saving_directory = "data/pickle_data/pdb/"):
"""
Input : pdb folder path
Output : None
    This function parses pdb files and saves them as .pkl files using pickle.
"""
if not os.path.exists(saving_directory):
os.makedirs(saving_directory)
for file in os.listdir(pdb_folder):
if file[-3:] == "pdb":
pdb_file = pdb_folder + file
Protein_object = PDB_parsing(pdb_file)
Protein_object.save(saving_directory = saving_directory)
def a3m_pickling(a3m_folder, saving_directory = "data/pickle_data/a3m/"):
"""
Input : a3m folder path
Output : None
    This function parses a3m files and saves them as .pkl files using pickle.
    In /public_data/ml/RF2_train/PDB-2021AUG02/a3m/ there are many three-letter subfolders.
"""
if not os.path.exists(saving_directory):
os.makedirs(saving_directory)
inner_folder_list = os.listdir(a3m_folder)
for inner_folder in inner_folder_list:
if inner_folder[0] == ".":
continue
inner_folder_path = a3m_folder + inner_folder + "/"
for file in os.listdir(inner_folder_path):
if file[-3:] == "a3m" or file[-6:] == "a3m.gz":
a3m_file = inner_folder_path + file
print(f"{a3m_file} ", end = "")
ProteinMSA_object = ProteinMSA(a3m_file_path = a3m_file)
                try :
                    ProteinMSA_object.save(saving_directory = saving_directory)
                    print("is parsed")
                except Exception:
                    print("is not parsed because of an error")
                    continue
def hhr_refactoring(hhr_folder, saving_directory = "data/hhr_refactoring/"):
"""
Input : hhr folder path
Output : None
    This function refactors hhr files and saves them as .pt files using torch.
    In /public_data/ml/RF2_train/PDB-2021AUG02/torch/hhr/ there are many three-letter subfolders.
"""
if not os.path.exists(saving_directory):
os.makedirs(saving_directory)
inner_folder_list = os.listdir(hhr_folder)
for inner_folder in inner_folder_list:
if inner_folder[0] == ".":
continue
inner_folder_path = hhr_folder + inner_folder + "/"
for file in os.listdir(inner_folder_path):
if file[-2:] == "pt":
hhr_file = inner_folder_path + file
print(f"{hhr_file} ", end = "")
                try :
                    refactoring_hhr_file(hhr_file, saving_directory)
                    print("is refactored")
                except Exception:
                    print("is not refactored because of an error")
                    continue
def AF_Results_json_viewer(file_path):
"""
Input : AF results json path
Output : None
    This function prints the top-level keys of an AF results json file.
"""
with open(file_path, "r") as f:
json_file = json.load(f)
import pprint
pprint.pprint(json_file.keys())
def AF_Results_json_filtering(json_file_path, residue_plddt_filtering = 0.95):
"""
    Input : AF results json path, residue_plddt_filtering (float)
    Output : plddt_filtered (np.array of indices of residues passing the filter), plddt (np.array)
    This function filters AF results by per-residue pLDDT.
"""
    if residue_plddt_filtering < 1 :
        residue_plddt_filtering *= 100 # the json stores pLDDT on a 0-100 scale
json_file = json.load(open(json_file_path, "r"))
plddt = json_file["plddt"]
plddt = np.array(plddt)
mean_plddt = np.mean(plddt)
plddt_filtered = np.where(plddt > residue_plddt_filtering)[0]
return plddt_filtered, plddt
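# Minimal sketch (assumes a ColabFold-style scores json that stores per-residue
# pLDDT on a 0-100 scale under the "plddt" key; the filename is hypothetical):
#
#   idxs, plddt = AF_Results_json_filtering("scores_rank_001.json", residue_plddt_filtering=0.9)
#   # idxs : residue indices with pLDDT > 90 ; np.mean(plddt) gives the global score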
def get_query_from_a3m(a3m_file_path):
with open(a3m_file_path, "r") as f:
lines = f.readlines()
IS_query = -1
for line in lines:
if line.startswith("#"):
continue
elif line.startswith(">"):
IS_query = 1
continue
else :
if IS_query == 1:
query = line.strip()
break
return query
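# Example: for an a3m file starting with
#   >query
#   MKTAYIAK
#   >hit_1
#   MKT-YIAK
# get_query_from_a3m returns "MKTAYIAK" (the first sequence line after the first
# ">" header; "#" lines, e.g. from ColabFold a3m headers, are skipped).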
def AF_Results_filtering(file_path = "/home/psk6950/practice/string_MSA/monomer_colabfold_0_15/direct/0/",
saving_directory = "data/STRING/",
residue_plddt_filtering = 0.0,
whole_plddt_filtering = 0.7,
rank_num = 5,
IS_DIRECT = True):
"""
Input : AF results folder path
Output : None
    This function filters AF results by pLDDT and saves the MSAs and templates.
"""
if not os.path.exists(saving_directory):
os.makedirs(saving_directory)
interaction_type = "Direct" if IS_DIRECT else "Indirect"
json_list = os.listdir(file_path)
a3m_list = [file for file in json_list if file[-3:] == "a3m"]
pdb_list = [file for file in json_list if file[-3:] == "pdb"]
json_list = [file for file in json_list if file[-5:] == ".json"]
json_list = [file for file in json_list if "scores_rank" in file]
def get_ID_dict(file_list, split_str, IS_a3m):
ID_dict = {}
for file in file_list:
ID = file.split(split_str)[0]
if IS_a3m:
ID_dict[ID] = file
else:
rank = int(file.split(split_str)[1].split("_")[0])
if rank > rank_num:
continue
if ID not in ID_dict.keys():
ID_dict[ID] = []
ID_dict[ID].append(file)
return ID_dict
ID_json_dict = get_ID_dict(json_list, "_scores_rank_", IS_a3m = False)
ID_a3m_dict = get_ID_dict(a3m_list, ".a3m", IS_a3m = True)
ID_pdb_dict = get_ID_dict(pdb_list, "_unrelaxed_rank_", IS_a3m = False)
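    # Filename convention assumed above (ColabFold-style; the ID is hypothetical):
    #   "P12345_scores_rank_001_model_1.json"    -> ID "P12345", rank 1 in ID_json_dict
    #   "P12345_unrelaxed_rank_001_model_1.pdb"  -> ID "P12345", rank 1 in ID_pdb_dict
    #   "P12345.a3m"                             -> ID "P12345" in ID_a3m_dict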
ID_template_dict = {}
for ID in ID_json_dict.keys():
try :
json_file_list = ID_json_dict[ID]
pdb_file_list = ID_pdb_dict[ID]
plddt_filtered_list = []
pdb_filtered_list = []
            f0d_list = []
            f1d_list = []
for json_file,pdb_file in zip(json_file_list, pdb_file_list):
plddt_filtered, plddt = AF_Results_json_filtering(file_path + "/"+json_file, residue_plddt_filtering = residue_plddt_filtering)
f0d = torch.tensor(np.mean(plddt)) # (1,)
f1d = torch.tensor(plddt) # (L,)
if f0d < whole_plddt_filtering:
continue
                plddt_filtered_list.append(plddt_filtered)
                pdb_filtered_list.append(pdb_file)
                f0d_list.append(f0d)
                f1d_list.append(f1d)
            ID_template_dict[ID] = {"plddt_filtered_idxs" : plddt_filtered_list, "a3m" : ID_a3m_dict[ID], "pdb" : pdb_filtered_list, "f0d" : f0d_list, "f1d" : f1d_list}
except KeyError:
print(f"{ID} is not in a3m or pdb")
continue
for ID in ID_template_dict.keys():
a3m_file = ID_template_dict[ID]["a3m"]
try :
MSA_object = ProteinMSA(a3m_file_path = file_path + "/" + a3m_file)
except ValueError as e:
print(f"{ID} has no MSA")
query_sequence = get_query_from_a3m(file_path + "/" + a3m_file)
MSA_object = ProteinMSA(msa_ID = ID, query_sequence=query_sequence)
MSA_object.save(saving_directory = saving_directory + "/MSA/")
pdb_file_list = ID_template_dict[ID]["pdb"]
plddt_filtered_list = ID_template_dict[ID]["plddt_filtered_idxs"]
        f0d_list = ID_template_dict[ID]["f0d"]
        f1d_list = ID_template_dict[ID]["f1d"]
print(f"{ID} ", end = "")
Template_list = []
        for pdb_file, plddt_filtered, f0d, f1d in zip(pdb_file_list, plddt_filtered_list, f0d_list, f1d_list):
Protein_object = PDB_parsing(file_path + "/" + pdb_file, IS_LABEL=False)
Template_object = protein_to_template(Protein_object, position = plddt_filtered, f0d = f0d, f1d = f1d)
Template_list.append(Template_object)
ProteinTemplates_object = ProteinTemplates(templates = Template_list, templates_ID = ID)
        ProteinTemplates_object.save(saving_directory = saving_directory + "/" + interaction_type + "/Template/")
print(f"is filtered")
def STRING_refactoring(rank_num = 1):
directory1= "/home/psk6950/practice/string_MSA/monomer_colabfold_0_15/direct"
directory2= "/home/psk6950/practice/string_MSA/monomer_colabfold_0_15/indirect"
folder_list1 = os.listdir(directory1)
folder_list2 = os.listdir(directory2)
for direct_folder in folder_list1:
AF_Results_filtering(file_path = directory1 + "/" + direct_folder, rank_num = rank_num, IS_DIRECT = True)
for indirect_folder in folder_list2:
AF_Results_filtering(file_path = directory2 + "/" + indirect_folder, rank_num = rank_num, IS_DIRECT = False)
def get_ID_to_source_dict(sources_path_list,saving_path = "data/ID_to_source_dict.txt"):
"""
    Input : list of source directory/file paths, saving path (.txt)
Output : ID_to_source_dict (dictionary)
PDB ID = mmcif ID
STRING ID = pkl ID
"""
ID_to_source_dict = {}
for sources_path in sources_path_list:
source_type = "PDB" if "mmcif" in sources_path else "STRING"
if source_type == "PDB" :
source_list = os.listdir(sources_path)
for source in source_list:
ID = source.split(".pkl")[0]
if ".cif" in ID:
ID = ID.split(".cif")[0]
ID_to_source_dict[ID] = source_type
else :
if "INDIRECT" in sources_path :
source_type = "STRING_Indirect"
elif "DIRECT" in sources_path:
source_type = "STRING_Direct"
with open(sources_path, "r") as f:
lines = f.readlines()
for line in lines:
ID = line.split("\t")[0]
ID_to_source_dict[ID] = source_type
with open(saving_path, "w") as f:
for ID in ID_to_source_dict.keys():
f.write(f"{ID},{ID_to_source_dict[ID]}\n")
return ID_to_source_dict
def get_hash_dict(msa_folder_path, saving_path = "data/hash_dict.txt"):
file_list = os.listdir(msa_folder_path)
MSA_hash_dict = {}
for file in file_list:
with open(msa_folder_path + file, "rb") as f:
MSA_object = pickle.load(f)
        query_sequence = MSA_object.query_sequence # torch.Tensor
        # convert to string
        query_sequence = "".join([num2AA[int(num)] for num in query_sequence])
MSA_hash = file.split("_msas.pkl")[0]
MSA_hash_dict[MSA_hash] = query_sequence
with open(saving_path, "w") as f:
for MSA_hash in MSA_hash_dict.keys():
f.write(f"{MSA_hash},{MSA_hash_dict[MSA_hash]}\n")
def test_a3m_file(file_path):
start_time = time.time()
with open(file_path, "rb") as f:
MSA_object = pickle.load(f)
print(MSA_object)
print(f"pickle load time : {time.time() - start_time}")
def check_hhr_loaded(hhr_folder, saved_dir):
hhr_name_list = []
inner_folder_list = os.listdir(hhr_folder)
for inner_folder in inner_folder_list:
if inner_folder[0] == ".":
continue
inner_folder_path = hhr_folder + inner_folder + "/"
for file in os.listdir(inner_folder_path):
if file[-2:] == "pt":
hhr_name_list.append(file)
saved_hhr_name_list = []
for file in os.listdir(saved_dir):
if file[-2:] == "pt":
saved_hhr_name_list.append(file)
hhr_name_set = set(hhr_name_list)
saved_hhr_name_set = set(saved_hhr_name_list)
diff = hhr_name_set - saved_hhr_name_set
print(diff) # result : {'094172.pt'} 20230809
def hhr_dir_rerefactoring(hhr_dir = "data/hhr_refactoring", saving_dir="data/hhr_rerefactoring"):
if not os.path.isdir(saving_dir):
os.mkdir(saving_dir)
hhr_file_list = os.listdir(hhr_dir)
for hhr_file in hhr_file_list:
if hhr_file[0] == ".":
continue
hhr_file_path = hhr_dir + "/" + hhr_file
hhr_file_rerefactoring(hhr_file_path, saving_dir)
def hhr_file_rerefactoring(hhr_file_path, saving_dir):
hhr_hash = hhr_file_path.split("/")[-1].split(".")[0]
hhr_list_dict_tensor = torch.load(hhr_file_path)
key_list = ['template_ID', 'xyz', 'mask', 'position', 'sequence', 'f0d', 'f1d']
if not os.path.isdir(saving_dir + "/" + hhr_hash):
os.mkdir(saving_dir + "/" + hhr_hash)
    # collect each field across templates, then save one .pt file per field
    collected = {key : [] for key in key_list}
    for inner_dictionary in hhr_list_dict_tensor:
        for key in key_list :
            collected[key].append(inner_dictionary[key])
    # Save
    for key in key_list :
        torch.save(collected[key], saving_dir + "/" + hhr_hash + "/" + key + ".pt")
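# Loading sketch for the re-refactored layout (the hash directory name is hypothetical):
#
#   xyz_list = torch.load("data/hhr_rerefactoring/000123/xyz.pt")       # list of [1, L_i, 14, 3]
#   seq_list = torch.load("data/hhr_rerefactoring/000123/sequence.pt")  # list of [1, L_i]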
if __name__ == "__main__":
sources_path_list = ["data/pickle_data/mmcif","data/STRING_v0/STRING_DIRECT.txt", "data/STRING_v0/STRING_INDIRECT.txt"]
# sources_path_list = ["data/STRING_v0/STRING_INDIRECT.txt"]
# ID_to_source_dict = get_ID_to_source_dict(sources_path_list, saving_path = "data/ID_to_source_dict.txt")
# get_hash_dict("data/pickle_data/a3m/", saving_path = "data/hash_dict.txt")
# get_hash_dict("data/STRING/MSA/", saving_path = "data/STRING_ID_dict_2.txt")
# test_a3m_file("data/pickle_data/a3m/000150_msas.pkl")
# mmcif_Full_sequence_parsing("3mj7.cif")
# mmcif_full_sequence("/public_data/rcsb/cif/")
# mmcif_full_sequence("test_cif/")
# hhr_refactoring(hhr_folder = "/public_data/ml/RF2_train/PDB-2021AUG02/torch/hhr/")
# TODO NOW
# check_hhr_loaded(hhr_folder = PDB_TRAIN_DATA_PATH + "/torch/hhr/", saved_dir = "data/hhr_refactoring/")
# hhr_file = PDB_TRAIN_DATA_PATH + "/torch/hhr/094/094172.pt"
# saving_directory = "/home/psk6950/practice/MiniWorld/data/hhr_refactoring"
# refactoring_hhr_file(hhr_file, saving_directory)
# hhr_dir_rerefactoring()
# mmcif_pickling(mmcif_folder = "/public_data/rcsb/cif/")
# mmcif_file = "6j54.cif"
# mmcif_parsing(mmcif_file)
# mmcif_pickling(mmcif_folder = "data/test_cif/", saving_directory = "data/pickle_data/mmcif_xz/")
# protein_1a1d = mmcif_parsing("/public_data/rcsb/cif/ab/3abk.cif.gz")
# print(protein_1a1d)
# STRING_refactoring(rank_num = 5)
Protein_object = mmcif_parsing("3gpv (1).cif")
# a3m_file =
# MSA_object = ProteinMSA(a3m_file_path = a3m_file)
# mmcif_pickling(mmcif_folder = "/public_data/rcsb/cif/", saving_directory= "data/pickle_data/mmcif2/")