import numpy as np
import torch
import os
from collections import OrderedDict
import pickle
import string
import gzip
from miniworld.utils.chemical import AA2num, num2AA, aa2num, aa2long
class ProteinStructure():
    """Store the atomic structure of a protein, possibly with several models and chains.

    Notation: M = number of models, L = number of residues::

        xyz           : (M, L, 14, 3) np.ndarray, list or torch.Tensor. 14 is the number of
                        heavy-atom slots per residue. An (L, 14, 3) input is promoted to M = 1.
        atom_mask     : (M, L, 14) np.ndarray, list or torch.Tensor; 1 where the atom is present.
                        An (L, 14) input is promoted to M = 1.
        chain_break   : dict / OrderedDict mapping chain ID -> (start, end) residue indices with
                        an INCLUSIVE end, e.g. {"A": (0, 197), "B": (198, 356)}.
                        Shared by all models.
        position_mask : (M, L) np.ndarray, list or torch.Tensor, 1 for confident position and 0
                        for missing position. Defaults to all ones; an (L,) input is promoted
                        to M = 1.
    """
    def __init__(self,
                 xyz,
                 chain_break,
                 atom_mask,
                 position_mask = None,
                 has_multiple_chains=False,
                 has_multiple_models=False,
                 ):
        assert isinstance(xyz, (np.ndarray, list, torch.Tensor)), "xyz must be np.ndarray, list or torch.Tensor"
        assert isinstance(atom_mask, (np.ndarray, list, torch.Tensor)), "atom_mask must be np.ndarray, list or torch.Tensor"
        assert isinstance(position_mask, (np.ndarray, list, torch.Tensor)) or position_mask is None, "position_mask must be np.ndarray, list or torch.Tensor"
        if isinstance(xyz, (np.ndarray, list)):
            xyz = torch.tensor(xyz)
        if isinstance(atom_mask, (np.ndarray, list)):
            atom_mask = torch.tensor(atom_mask)
        if isinstance(position_mask, (np.ndarray, list)):
            position_mask = torch.tensor(position_mask)
        # Promote single-model inputs so every tensor carries a leading model axis.
        if len(xyz.shape) == 3:
            xyz = torch.unsqueeze(xyz, 0)  # (1, L, 14, 3)
        if len(atom_mask.shape) == 2:
            atom_mask = torch.unsqueeze(atom_mask, 0)  # (1, L, 14)
        if position_mask is None:
            position_mask = torch.ones((xyz.shape[0], xyz.shape[1]))  # (M, L)
        elif len(position_mask.shape) == 1:
            position_mask = torch.unsqueeze(position_mask, 0)  # (1, L)
        self.xyz = xyz  # (M, L, 14, 3) torch.Tensor
        self.atom_mask = atom_mask  # (M, L, 14) torch.Tensor
        self.position_mask = position_mask  # (M, L) torch.Tensor
        self.chain_break = chain_break  # chain ID -> (start, end), inclusive end
        self.has_multiple_chains = has_multiple_chains  # True if this protein has multiple chains
        self.has_multiple_models = has_multiple_models  # True if this protein has multiple models

    def __repr__(self):
        return f'xyz shape : {self.xyz.shape}, mask shape : {self.atom_mask.shape}, chain_break : {self.chain_break}'

    def _chain_bounds(self, chain_id):
        """Return the inclusive (start, end) residue indices of ``chain_id``.

        Raises:
            KeyError: if ``chain_id`` is not in ``chain_break``.
        """
        if chain_id not in self.chain_break:
            raise KeyError("Chain ID {} is not in this structure".format(chain_id))
        return self.chain_break[chain_id][0], self.chain_break[chain_id][1]

    def get_chain_length_by_ID(self, chain_id):
        """Return the number of residues in chain ``chain_id`` (chain end index is inclusive)."""
        start, end = self._chain_bounds(chain_id)
        return end - start + 1

    def get_structure_by_model(self, model_id):
        """Return ``(xyz, atom_mask, position_mask, chain_break)`` of model ``model_id``.

        ``chain_break`` is normally one mapping shared by all models and is returned as-is.
        If it is a per-model list/tuple, the entry for ``model_id`` is returned instead.
        """
        assert self.has_multiple_models or model_id == 0, "This protein has only one model"
        chain_break = self.chain_break
        if isinstance(chain_break, (list, tuple)):
            chain_break = chain_break[model_id]
        return (self.xyz[model_id],
                self.atom_mask[model_id],
                self.position_mask[model_id],
                chain_break)

    def get_chain_structure_by_ID(self, chain_id):
        """Return ``(xyz, atom_mask)`` of chain ``chain_id`` for all models.

        Shapes are (M, L_chain, 14, 3) and (M, L_chain, 14).

        Raises:
            KeyError: if ``chain_id`` is not in ``chain_break``.
        """
        start, end = self._chain_bounds(chain_id)
        # The chain end is inclusive, so the slice stops one past it.
        return self.xyz[:, start:end + 1], self.atom_mask[:, start:end + 1]
class ProteinSequence():
    """Store a protein sequence as integer residue tokens.

    Notation: M = number of models, L = number of residues::

        sequence        : (M, L) torch.Tensor or str. A str is converted with ``AA2num`` into a
                          (1, L) tensor. Passing a torch.Tensor is recommended.
        masked_sequence : (M, L) torch.Tensor or str; defaults to ``sequence``.
        chain_break     : mapping chain ID -> (start, end) with an INCLUSIVE end,
                          e.g. {"A": (0, 197), "B": (198, 356)}. Defaults to a single chain "A"
                          spanning the whole sequence.
    """
    def __init__(self,
                 sequence,
                 masked_sequence = None,
                 chain_break = None,
                 ):
        assert isinstance(sequence, (str, torch.Tensor)), "sequence must be string or torch.Tensor"
        assert isinstance(masked_sequence, (str, torch.Tensor)) or masked_sequence is None, "maksed_sequence must be string or torch.Tensor"
        if isinstance(sequence, str):
            sequence = torch.tensor([AA2num[aa] for aa in sequence])  # (L,)
            sequence = torch.unsqueeze(sequence, 0)  # (1, L)
        self.sequence = sequence
        if masked_sequence is None:
            masked_sequence = sequence
        self.masked_sequence = masked_sequence
        if chain_break is not None:
            self.chain_break = chain_break
        else:
            self.chain_break = OrderedDict()
            self.chain_break["A"] = (0, self.get_sequence_length() - 1)

    def get_sequence_length(self):
        """Return L, the number of residues (last axis of a tensor, or the string length)."""
        if isinstance(self.sequence, str):
            return len(self.sequence)
        elif isinstance(self.sequence, torch.Tensor):
            # The last axis is the residue axis for both (L,) and (M, L) tensors.
            return self.sequence.shape[-1]

    def __repr__(self):
        if isinstance(self.sequence, str):
            return str(f'{self.sequence}')
        elif isinstance(self.sequence, torch.Tensor):
            return str(f'shape : {self.sequence.shape}')
        else:
            raise TypeError("sequence must be string or torch.Tensor")

    def get_sequence_by_chain_ID(self, chain_id):
        """Return the residues of chain ``chain_id``.

        For a tensor, the residue (last) axis is sliced and every model is kept. The chain end
        index is inclusive.

        Raises:
            KeyError: if ``chain_id`` is not in ``chain_break``.
        """
        if chain_id not in self.chain_break:
            raise KeyError("Chain ID {} is not in this sequence".format(chain_id))
        chain_start = self.chain_break[chain_id][0]
        chain_end = self.chain_break[chain_id][1]
        if isinstance(self.sequence, str):
            return self.sequence[chain_start:chain_end + 1]
        return self.sequence[..., chain_start:chain_end + 1]
class Protein():
    """Bundle a protein's sequence, structure and metadata.

    Args:
        sequence: ProteinSequence, raw sequence string, or None.
        structure: ProteinStructure or None.
        occupancy: stored as-is.
        symmetry_related_info: optional mapping; ``__repr__`` expects each value to be a list
            of dicts.
        ID: identifier, also used as the pickle file name by ``save``. It must be changed
            from the default "----" before saving.
        IS_LABEL: True if this protein comes from PDB.
        source: free-form provenance string.
    """
    def __init__(self,
                 sequence,
                 structure,
                 occupancy = None,
                 symmetry_related_info = None,
                 ID="----",
                 IS_LABEL=False,  # If this protein from PDB, IS_LABEL is True
                 source = "",
                 ):
        # Cheap checks come first so the sibling class names are only consulted when needed.
        assert sequence is None or isinstance(sequence, str) or isinstance(sequence, ProteinSequence), f"sequence must be ProteinSequence or string or None, but {type(sequence)} is given"
        assert structure is None or isinstance(structure, ProteinStructure), f"structure must be ProteinStructure or None, but {type(structure)} is given"
        self.ID = ID
        self.sequence = sequence
        self.structure = structure
        self.occupancy = occupancy
        self.symmetry_related_info = self.get_symmetry_related_info(symmetry_related_info)  # dictionary
        # Chain layout comes from the structure, falling back to the sequence, then to no chains.
        if structure is not None:
            self.chain_break = structure.chain_break
        elif getattr(sequence, "chain_break", None) is not None:
            self.chain_break = sequence.chain_break
        else:
            self.chain_break = OrderedDict()
        self.chain_list = list(self.chain_break.keys())
        self.IS_LABEL = IS_LABEL
        self.source = source

    def get_symmetry_related_info(self, symmetry_related_info):
        """Normalize the symmetry-related info stored on the object.

        The base implementation returns it unchanged. Subclasses may override this to
        post-process it.
        """
        return symmetry_related_info

    def __extract__(self, residue_idx_list=None):
        """Placeholder for residue extraction; not implemented yet (returns None)."""
        pass

    def __repr__(self):
        output = '\n'
        output += '-'*15 + ' Protein Class ' + '-'*15+'\n'
        output += f'ID : {self.ID}\nSource : {self.source}\nIS_LABEL : {self.IS_LABEL}\nSequence\t| {self.sequence.__repr__()}\nStructure\t| {self.structure.__repr__()}\n'
        if self.symmetry_related_info is not None:
            output += 'Symmetry Related Info\n'
            for key, value in self.symmetry_related_info.items():
                for value_ in value:
                    for key_in_value, value_in_value in value_.items():
                        # Short values go on the same line; long ones (e.g. matrices) on the next.
                        if len(str(value_in_value)) < 8:
                            output += f'{key}.{key_in_value} : {value_in_value}\n'
                        else:
                            output += f'{key}.{key_in_value} : \n {value_in_value}\n'
        output += '-'*45
        return output

    def _chain_bounds(self, chain_id):
        """Return the inclusive (start, end) residue indices of ``chain_id``.

        Raises:
            KeyError: if ``chain_id`` is not in ``chain_break``.
        """
        if chain_id not in self.chain_break:
            raise KeyError("Chain ID {} is not in this protein".format(chain_id))
        return self.chain_break[chain_id][0], self.chain_break[chain_id][1]

    def get_sequence_by_chain_ID(self, chain_id):
        """Return the residues of chain ``chain_id`` (the chain end index is inclusive).

        For a string sequence a substring is returned. For a ProteinSequence, its token tensor
        is sliced along the residue (last) axis.
        """
        chain_start, chain_end = self._chain_bounds(chain_id)
        if isinstance(self.sequence, str):
            return self.sequence[chain_start:chain_end + 1]
        return self.sequence.sequence[..., chain_start:chain_end + 1]

    def get_structure_by_chain_ID(self, chain_id):
        """Return ``(xyz, atom_mask)`` of chain ``chain_id`` for every model.

        The residue axis (dim 1) of the structure tensors is sliced; the chain end index is
        inclusive.
        """
        chain_start, chain_end = self._chain_bounds(chain_id)
        return (self.structure.xyz[:, chain_start:chain_end + 1],
                self.structure.atom_mask[:, chain_start:chain_end + 1])

    def save(self, saving_directory):
        """Pickle this object to ``<saving_directory>/<ID>.pkl``.

        The directory is created if needed.

        Raises:
            ValueError: if ``ID`` is still the default "----". Nothing is created on disk in
                that case.
        """
        if self.ID == "----":
            raise ValueError("ID must be changed")
        os.makedirs(saving_directory, exist_ok=True)
        with open(os.path.join(saving_directory, self.ID+".pkl"), 'wb') as f:
            pickle.dump(self, f)
class ProteinTemplate(Protein):
    """A structural template for a query protein.

    Attributes in addition to those of Protein::

        position    : torch.Tensor of template residue positions, documented as (M, L_template).
                      L_template is expected not to exceed the query length. When omitted, it
                      defaults to torch.arange(L) over the template sequence, i.e. every residue
                      position is treated as confident.
        template_ID : identifier of the template (stored as-is).
        f0d, f1d    : optional torch.Tensor template features (stored as-is).
    """
    def __init__(self,
                 sequence,
                 structure,
                 position = None,
                 template_ID = None,
                 f0d=None,
                 f1d=None,
                 ):
        super().__init__(sequence, structure)
        assert isinstance(position, torch.Tensor) or position is None, "position must be torch.Tensor"
        assert isinstance(f0d, torch.Tensor) or f0d is None, "f0d must be torch.Tensor or None"
        assert isinstance(f1d, torch.Tensor) or f1d is None, "f1d must be torch.Tensor or None"
        if position is None:  # It means that all residue position is confident
            position = torch.arange(sequence.get_sequence_length())
        # Stored unconditionally so __repr__ and callers can always rely on it.
        self.position = position
        self.template_ID = template_ID
        self.f0d = f0d
        self.f1d = f1d

    @classmethod
    def _from_parts(cls, sequence, structure, dictionary):
        """Instantiate ``cls`` from a ready-made sequence and structure.

        The "position", "template_ID", "f0d" and "f1d" keys are read from ``dictionary``.
        """
        return cls(
            sequence=sequence,
            structure=structure,
            position=dictionary["position"],
            template_ID=dictionary["template_ID"],
            f0d=dictionary["f0d"],
            f1d=dictionary["f1d"],
        )

    @classmethod
    def from_dict(cls, dictionary):
        """Build a template from raw arrays.

        Required keys: "sequence", "xyz", "mask", "position", "template_ID", "f0d", "f1d".
        Optional key: "chain_break" (chain ID -> inclusive (start, end)). When it is absent,
        the whole sequence is treated as a single chain "A".
        """
        sequence = ProteinSequence(dictionary["sequence"], chain_break=dictionary.get("chain_break"))
        structure = ProteinStructure(
            xyz=dictionary["xyz"],
            atom_mask=dictionary["mask"],
            chain_break=sequence.chain_break,
        )
        return cls._from_parts(sequence, structure, dictionary)

    @classmethod
    def from_dict2(cls, dictionary):
        """Build a template from a dict holding ready-made "sequence"/"structure" objects plus the feature keys."""
        return cls._from_parts(dictionary["sequence"], dictionary["structure"], dictionary)

    def __repr__(self):
        f0d_output = "None" if self.f0d is None else self.f0d.shape
        f1d_output = "None" if self.f1d is None else self.f1d.shape
        return f'Sequence : {self.sequence.__repr__()}, Structure : {self.structure.__repr__()}, Position : {self.position.shape}, f0d : {f0d_output}, f1d : {f1d_output}'
class ProteinTemplates():
    """A list of ProteinTemplate objects.

    ``templates`` is stored first. Templates are then loaded from the first source that is
    given, checked in this order:

    1. ``template_pdb_list``: PDB paths.
    2. ``hhr_file``: either a path to a ``torch.load``-able file, or an already-loaded list of
       template dicts.
    3. ``hhr_file_path``: a path to a ``torch.load``-able file.

    The loaders raise ValueError if ``templates`` was already non-empty.
    """
    def __init__(self,
                 templates = None,
                 template_pdb_list = None,
                 hhr_file = None,
                 hhr_file_path = None,
                 templates_ID = None,
                 ):
        assert hhr_file is None or isinstance(hhr_file, (str, list, tuple, torch.Tensor)), "hhr_file must be a path, a list of template dicts, torch.Tensor or None"
        assert isinstance(templates, list) or templates is None, "templates must be list"
        # A fresh list per instance; a shared `[]` default would leak templates across instances.
        self.templates = [] if templates is None else templates
        self.templates_ID = templates_ID
        self.template_ID_list = [template.template_ID for template in self.templates]
        if template_pdb_list is not None:
            self.from_pdb_file(template_pdb_list)
        elif hhr_file is not None:
            if isinstance(hhr_file, str):
                hhr_hash = self._stem(hhr_file)
                # NOTE: torch.load unpickles; only load trusted files.
                hhr_file = torch.load(hhr_file)
            else:
                # Already-loaded content carries no file name to derive an ID from.
                hhr_hash = templates_ID
            self.from_hhr_file(hhr_file, hhr_hash)
        elif hhr_file_path is not None:
            hhr_hash = self._stem(hhr_file_path)
            # NOTE: torch.load unpickles; only load trusted files.
            self.from_hhr_file(torch.load(hhr_file_path), hhr_hash)

    @staticmethod
    def _stem(path):
        """Return the file name of ``path`` without its last extension ("d/abc.hhr.pt" -> "abc.hhr")."""
        return ".".join(path.split("/")[-1].split(".")[:-1])

    def __repr__(self) -> str:
        output_str = ""
        for ii, template in enumerate(self.templates):
            if ii > 3:  # show at most four templates
                output_str += "...\n"
                break
            output_str += template.__repr__() + "\n"
        return output_str

    def from_hhr_file(self, hhr_file, hhr_hash):
        """Populate templates from a list of template dicts (one dict per template).

        Each dict is passed to ``ProteinTemplate.from_dict``. ``templates_ID`` is set to
        ``hhr_hash``.
        """
        if len(self.templates) != 0:
            raise ValueError("templates must be empty")
        self.templates_ID = hhr_hash
        self.templates = [ProteinTemplate.from_dict(template_dict) for template_dict in hhr_file]
        self.template_ID_list = [template.template_ID for template in self.templates]

    def from_pdb_file(self, template_pdb_list):
        """Populate templates by parsing each PDB path with ``PDB_parsing``.

        ``f1d`` is set to the per-residue B-factor values returned by ``PDB_parsing`` divided
        by 100. The name ``lddt`` suggests these are pLDDT scores of predicted models; confirm
        this for experimental structures.
        """
        if len(self.templates) != 0:
            raise ValueError("templates must be empty")
        templates = []
        for template_pdb in template_pdb_list:
            template, lddt = PDB_parsing(template_pdb, return_lddt = True)
            template_dict = {
                "sequence" : template.sequence.sequence,
                "xyz" : template.structure.xyz,
                "mask" : template.structure.atom_mask,
                "chain_break" : template.structure.chain_break,
                "f0d" : torch.tensor([0.0]),
                "f1d" : lddt/100,
                "position" : None,
                "template_ID" : template.ID,
            }
            templates.append(ProteinTemplate.from_dict(template_dict))
        self.templates = templates
        self.template_ID_list = [template.template_ID for template in self.templates]

    def filter_template_by_sequence_identity(self, sequence_identity):
        """Not implemented yet: currently a no-op that keeps every template."""
        # TODO: keep only templates whose sequence identity is >= ``sequence_identity``
        # and refresh ``template_ID_list`` accordingly.
        pass

    def save(self, saving_directory):
        """Pickle this object to ``<saving_directory>/<templates_ID>.pkl``.

        The directory is created if needed.

        Raises:
            ValueError: if ``templates_ID`` is None. Nothing is created on disk in that case.
        """
        if self.templates_ID is None:
            raise ValueError("templates_ID must be set before saving")
        os.makedirs(saving_directory, exist_ok=True)
        with open(os.path.join(saving_directory, self.templates_ID+".pkl"), 'wb') as f:
            pickle.dump(self, f)
class ProteinMSA():
    """Represent the Multiple Sequence Alignment (MSA) of a protein.

    Fields::

        msa_ID         : str. When not given, it is derived from ``a3m_file_path``.
        query_sequence : str or torch.Tensor of shape (1, L) or (L,). A str is converted with
                         ``AA2num``.
        msa            : torch.Tensor (N, L); N is MSA depth. None until set or loaded.
        insertion      : torch.Tensor (N, L). None until set or loaded.
        a3m_file_path  : optional .a3m / .a3m.gz path to load the MSA from.
    """
    def __init__(self,
                 msa_ID = None,
                 query_sequence = None,
                 msa_tensor = None,
                 insertion_tensor = None,
                 a3m_file_path = None  # it is None there is no MSA.
                 ):
        assert isinstance(msa_tensor, torch.Tensor) or msa_tensor is None, "msa_tensor must be torch.Tensor or None"
        assert isinstance(insertion_tensor, torch.Tensor) or insertion_tensor is None, "insertion_tensor must be torch.Tensor or None"
        # Only a file path carries a name to derive the ID from; a tensor does not.
        if msa_ID is None and a3m_file_path is not None:
            msa_ID = ".".join(a3m_file_path.split("/")[-1].split(".")[:-1])
        if isinstance(query_sequence, str):
            query_sequence = torch.tensor([AA2num[aa] for aa in query_sequence])
        self.msa_ID = msa_ID
        self.query_sequence = query_sequence
        self.chain_break = None
        self.msa = msa_tensor
        self.insertion = insertion_tensor
        if a3m_file_path is not None:
            self.from_a3m_file_path(a3m_file_path)

    def __repr__(self):
        def _shape(value):
            return None if value is None else value.shape
        # getattr keeps repr working on objects unpickled from older versions of this class.
        return (f'MSA ID : {self.msa_ID}, '
                f'Query Sequence : {_shape(getattr(self, "query_sequence", None))}, '
                f'MSA : {_shape(getattr(self, "msa", None))}, '
                f'Insertion : {_shape(getattr(self, "insertion", None))}')

    def from_msa_tensor(self, msa_tensor):
        """Set the MSA from a tensor; its first row becomes the query sequence."""
        self.msa = msa_tensor
        self.query_sequence = msa_tensor[0]

    @staticmethod
    def _insertion_counts(line, L):
        """Count, for each of the L aligned positions, how many lowercase insertions precede it.

        ``line`` is a raw a3m row with ':' already removed. '/' only separates chains, so it is
        dropped first and counts neither as a residue nor as an insertion.
        """
        line = line.replace('/', '')
        # 0 - match or gap; 1 - insertion
        lower_case = np.array([0 if c.isupper() or c == '-' else 1 for c in line])
        insertion = np.zeros((L))
        if np.sum(lower_case) > 0:
            # Raw positions of insertion characters, shifted left by the number of insertion
            # characters before them, give their position in the cleaned sequence.
            pos = np.where(lower_case == 1)[0]
            pos, num = np.unique(pos - np.arange(pos.shape[0]), return_counts=True)
            insertion[pos] = num
        return insertion

    def from_a3m_file_path(self, a3m_file_path, max_seq = 5000):
        """Load msa, insertion and query_sequence from an .a3m or .a3m.gz file.

        Adapted from RF2 parsers.py. Labels ('>') and '#' lines are skipped. Lowercase
        letters are removed from the alignment and recorded as insertion counts. '/'
        separates chains. The first line containing ':' defines ``self.chain_break`` as
        {chain index: (start, end)} with an inclusive end.

        Residues are encoded with the alphabet "ARNDCQEGHILKMFPSTWYV-" (gap = 20); unknown
        characters are mapped to 20.

        Args:
            a3m_file_path: path to the .a3m or .a3m.gz file.
            max_seq: NOTE(review): currently unused, so the whole file is read. RF2 stops after
                this many sequences; confirm whether capping is wanted before enabling it.

        Raises:
            ValueError: if a row's per-chain lengths do not match the first row.
        """
        msa = []
        insertion_list = []
        table = str.maketrans(dict.fromkeys(string.ascii_lowercase))
        opener = gzip.open if a3m_file_path.split('.')[-1] == 'gz' else open
        chain_break = {}
        with opener(a3m_file_path, 'rt') as f:
            for line in f:
                # skip comments and labels
                if line.startswith(('#', '>')):
                    continue
                if chain_break == {} and ':' in line:
                    chain_start = 0
                    for chain_idx, chain in enumerate(line.strip().split(':')):
                        chain_end = chain_start + len(chain) - 1
                        chain_break[chain_idx] = (chain_start, chain_end)
                        chain_start = chain_end + 1
                    self.chain_break = chain_break
                line = line.replace(':', '').rstrip()
                if len(line) == 0:
                    continue
                # remove lowercase letters and split into per-chain segments
                msas_i = line.translate(table).split('/')
                if len(msa) == 0:
                    # first sequence defines the per-chain lengths
                    Ls = [len(x) for x in msas_i]
                    msa = [[s] for s in msas_i]
                else:
                    nchains = len(msas_i)
                    isgood = all([len(msa[i][0]) == len(msas_i[i]) for i in range(nchains)])
                    if not isgood:
                        raise ValueError("Len error", a3m_file_path, len(msa[0]), self.msa_ID)
                    for i in range(nchains):
                        msa[i].append(msas_i[i])
                insertion_list.append(self._insertion_counts(line, sum(Ls)))
        # concatenate chains along the residue axis
        msa = [np.array([list(s) for s in t], dtype='|S1').view(np.uint8) for t in msa]
        msa = np.concatenate(msa, axis=-1)
        # convert letters into numbers (all codes < 45 = ord('-'), so no collisions in place)
        alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-"), dtype='|S1').view(np.uint8)
        for i, letter in enumerate(alphabet):
            msa[msa == letter] = i
        # treat all unknown characters as gaps
        msa[msa > 20] = 20
        insertion_list = np.array(insertion_list, dtype=np.uint8)
        self.msa = torch.tensor(msa)  # (N, L)
        self.insertion = torch.tensor(insertion_list)  # (N, L)
        self.query_sequence = torch.tensor(msa[0])  # (L,)

    def save(self, saving_directory):
        """Pickle this object to ``<saving_directory>/<msa_ID>_msas.pkl``.

        The directory is created if needed.

        Raises:
            ValueError: if ``msa_ID`` is None.
        """
        if self.msa_ID is None:
            raise ValueError("msa_ID must be set before saving")
        os.makedirs(saving_directory, exist_ok=True)
        with open(os.path.join(saving_directory, self.msa_ID+"_msas.pkl"), 'wb') as f:
            pickle.dump(self, f)

    def print_query(self):
        """Print the query sequence as one-letter codes via ``num2AA``."""
        print("".join([num2AA[int(aa)] for aa in self.query_sequence]))
def _pdb_id_from_path(PDB_file):
    """Derive the protein ID from a PDB file path.

    AlphaFold outputs (paths containing "_unrelaxed") keep the part of the file name before
    that tag. Otherwise only the last extension is dropped ("dir/1abc.pdb" -> "1abc").
    """
    file_name = PDB_file.split("/")[-1]
    if "_unrelaxed" in PDB_file:
        return file_name.split("_unrelaxed")[0]
    return ".".join(file_name.split(".")[:-1])


def _group_atom_lines_by_model(lines):
    """Group ATOM records by model.

    Returns {model_idx: [ATOM lines]}, where model_idx is the PDB MODEL serial minus 1
    (PDB models start at 1). A file without any MODEL record is a single model 0.

    Raises:
        ValueError: if an ATOM record precedes the first MODEL record of a multi-model file.
    """
    if not any(line.startswith("MODEL") for line in lines):
        return {0: [line for line in lines if line.startswith("ATOM")]}
    models = {}
    model_idx = None
    for line in lines:
        if line.startswith("MODEL"):
            model_idx = int(line.split()[1]) - 1
            models[model_idx] = []
        elif line.startswith("ATOM"):
            if model_idx is None:
                raise ValueError("ATOM record found before the first MODEL record")
            models[model_idx].append(line)
    return models


def PDB_parsing(PDB_file, IS_LABEL = True, return_lddt = False):
    """Parse a PDB file (single or multi-model, single or multi-chain) into a Protein object.

    Each residue is opened by its backbone N atom. Heavy atoms are stored in 14 slots ordered
    as in ``aa2long``. Chain boundaries are taken from the first model only; all models are
    assumed to share them. Only ATOM records are read.

    Args:
        PDB_file: path to the PDB file.
        IS_LABEL: forwarded to Protein.
        return_lddt: if True, also return the per-residue B-factor values.

    Returns:
        Protein, or (Protein, lddt) when ``return_lddt`` is set. ``lddt`` is an (M, L) tensor
        holding the B-factor column of each residue's N atom.

    Raises:
        ValueError: if the first model contains no ATOM records or no residues, or an atom
            appears before any backbone N atom.
    """
    PDB_ID = _pdb_id_from_path(PDB_file)
    with open(PDB_file, "r") as f:
        lines = [line.strip() for line in f]
    models = _group_atom_lines_by_model(lines)
    if not models or not next(iter(models.values())):
        raise ValueError(f"No ATOM records found in the first model of {PDB_file}")

    xyz = []
    mask = []
    seq = []
    lddt = []
    chain_break = OrderedDict()  # chain ID -> [start, end], inclusive, from the first model
    for ii, model_idx in enumerate(models.keys()):
        is_first_model = ii == 0
        xyz_model = []
        mask_model = []
        seq_model = []
        lddt_model = []
        before_chain = "-"  # sentinel that never equals a real 2-column chain field
        residue_idx = -1    # index of the current residue in the first model
        chain_key = None    # chain currently open in chain_break
        for line in models[model_idx]:
            residue = line[17:20]
            residue_num = aa2num[residue]
            # aa2long rows: entries 0-13 are heavy atoms (14+ were noted as hydrogens; unused here).
            Heavy_atom_list = aa2long[residue_num][:14]
            atom = line[12:16]
            chain = line[20:22]  # two columns so wide chain IDs of large PDB files are kept
            if atom.strip() == "N":
                # Each backbone N opens a new residue.
                seq_model.append(residue_num)
                lddt_model.append(float(line[60:66]))  # B-factor column
                if is_first_model:
                    residue_idx += 1
                    if chain != before_chain:
                        if chain_break:
                            # Close the previous chain on the residue just before this one.
                            chain_break[before_chain.replace(" ", "")].append(residue_idx - 1)
                        chain_key = chain.replace(" ", "")
                        chain_break[chain_key] = [residue_idx]
                    before_chain = chain
                xyz_model.append([[0.0, 0.0, 0.0] for _ in range(14)])
                mask_model.append([0] * 14)
            if atom in Heavy_atom_list:
                if not xyz_model:
                    raise ValueError(f"Atom {atom!r} of residue {residue} appears before any backbone N atom in {PDB_file}")
                atom_idx = Heavy_atom_list.index(atom)
                # NOTE: atoms go into the most recent residue, so atoms of a residue lacking
                # an N record land in the previous residue's slots.
                xyz_model[-1][atom_idx] = [float(line[30:38]), float(line[38:46]), float(line[46:54])]
                mask_model[-1][atom_idx] = 1
        if is_first_model:
            if chain_key is None:
                raise ValueError(f"No residues (backbone N atoms) found in the first model of {PDB_file}")
            # Close the last chain on the final residue of the first model.
            chain_break[chain_key].append(residue_idx)
        # TODO: optionally strip terminal penta/hexa-histidine tags. An earlier draft was
        # disabled because it does not work properly on complexes.
        xyz.append(xyz_model)
        mask.append(mask_model)
        seq.append(seq_model)
        lddt.append(lddt_model)
    xyz = torch.tensor(xyz)  # (M, L, 14, 3)
    mask = torch.tensor(mask)  # (M, L, 14)
    seq = torch.tensor(seq)  # (M, L)
    lddt = torch.tensor(lddt)  # (M, L)
    has_multiple_models = xyz.shape[0] > 1
    has_multiple_chains = len(chain_break) > 1
    # Freeze each [start, end] list into a (start, end) tuple, keeping chain order.
    chain_break = OrderedDict((key, (bounds[0], bounds[1])) for key, bounds in chain_break.items())
    Structure = ProteinStructure(xyz, atom_mask = mask, position_mask = None, chain_break = chain_break, has_multiple_models = has_multiple_models, has_multiple_chains = has_multiple_chains)
    Sequence = ProteinSequence(sequence = seq, masked_sequence = None, chain_break = chain_break)
    Protein_object = Protein(sequence = Sequence, structure = Structure, ID = PDB_ID, IS_LABEL=IS_LABEL)
    if return_lddt:
        return Protein_object, lddt
    return Protein_object