import os
import re
import subprocess

import matplotlib.pyplot as plt

from miniworld.utils.ProteinClass import *
from miniworld.utils.chemical import num2aa
# Default output directory for tensor_to_pdb / protein_to_pdb / output_to_pdb.
# File paths are built by plain string concatenation (save_dir + ID + ".pdb"),
# so keep the trailing slash.
save_dir = "model_output_visualize/"
[docs]
def TMalign(candidate, answer, executable="TMscore"):
    """Score ``candidate`` against ``answer`` with the TM-score program.

    Despite its name, this runs the ``TMscore`` executable (not ``TMalign``),
    exactly as before.

    Args:
        candidate: Path of the model structure file.
        answer: Path of the reference structure file.
        executable: Program to run. Defaults to ``"TMscore"`` looked up on PATH.

    Returns:
        The TM-score exactly as printed by the program (e.g. ``"0.8765"``), or
        ``""`` if no ``TM-score = ...`` line appears in its output.

    Raises:
        OSError: (e.g. FileNotFoundError) if ``executable`` cannot be launched.
    """
    # Pass an argument list instead of a shell string, so paths containing
    # spaces or shell metacharacters are handled safely.
    result = subprocess.run(
        [executable, str(candidate), str(answer)],
        stdout=subprocess.PIPE,
        universal_newlines=True,
        check=False,
    )
    # Parse the labelled result line instead of slicing a fixed character
    # offset, which breaks as soon as path lengths or banner text change.
    match = re.search(r"TM-score\s*=\s*(\d+\.\d+)", result.stdout)
    return match.group(1) if match else ""
[docs]
def atom_to_pdb_line(atom, atom_xyz, residue_idx, residue_stirng, atom_mask, atom_serial=None):
    """Format one atom as a fixed-column PDB ``ATOM`` record on chain A.

    Args:
        atom: Atom name, either bare (``"CA"``) or already padded to 4
            characters (``" CA "``).
        atom_xyz: Indexable ``(x, y, z)``; each item must be ``float``-convertible.
        residue_idx: Written to the residue sequence number field and, unless
            ``atom_serial`` is given, also to the atom serial field.
        residue_stirng: Three-letter residue name.
        atom_mask: Accepted for interface compatibility; not used (occupancy is
            always written as 1.00).
        atom_serial: Optional atom serial number. Defaults to ``residue_idx``,
            as before.

    Returns:
        A newline-terminated 78-column record (wwPDB ATOM layout: serial 7-11,
        name 13-16, resName 18-20, chain 22, resSeq 23-26, x/y/z 31-54,
        occupancy 55-60, B-factor 61-66, element 77-78).
    """
    serial = residue_idx if atom_serial is None else atom_serial
    # PDB convention: names shorter than 4 characters start in column 14
    # ("CA" -> " CA "); 4-character names fill columns 13-16 as given.
    name = atom if len(atom) >= 4 else " " + atom.ljust(3)
    # Element = first letter of the name (N, C, O, S, H for protein atoms);
    # writing the whole name would make "CA" read as calcium.
    element = next((ch for ch in atom if ch.isalpha()), "")
    x, y, z = (float(atom_xyz[k]) for k in range(3))
    return (
        "ATOM  {:>5d} {:<4.4s} {:>3s} A{:>4d}    "
        "{:>8.3f}{:>8.3f}{:>8.3f}{:>6.2f}{:>6.2f}          {:>2s}\n"
    ).format(int(serial), name, residue_stirng, int(residue_idx), x, y, z, 1.0, 0.0, element)
[docs]
def tensor_to_pdb(ID, xyz, sequence, atom_mask, crop_idx=None, save_dir=save_dir):
    """Write the first model of a batched structure to ``save_dir + str(ID) + ".pdb"``.

    Only model index 0 of the leading (batch/model) axis is exported.

    Args:
        ID: Identifier used as the output file stem.
        xyz: Coordinates indexed ``[model, residue, atom, coord]``, presumably
            (N, L, 14, 3). Every atom slot along axis 2 is considered.
        sequence: Residue types with a leading model axis. Either integer class
            indices (N, L), or per-class scores / one-hot rows (N, L, C), in
            which case the arg-max class is used.
            NOTE(review): an older comment described this as (L,) / (L, 23)
            without a model axis, but the code has always indexed
            ``sequence[0, ...]``; confirm with callers.
        atom_mask: Per-atom mask indexed ``[model, residue, atom]``. Atoms whose
            mask equals 0 are skipped.
        crop_idx: Optional residue indices selecting the residues to write.
        save_dir: Path prefix, joined by plain string concatenation (keep the
            trailing slash). The directory is created if missing.

    Residues are numbered from 0 in the output, in the order written.
    """
    if crop_idx is None:
        xyz, sequence, atom_mask = xyz[0], sequence[0], atom_mask[0]
    else:
        # No trailing ':' on sequence, so both (N, L) and (N, L, C) inputs work.
        xyz, sequence, atom_mask = xyz[0, crop_idx], sequence[0, crop_idx], atom_mask[0, crop_idx]

    # Per-class rows cannot be passed to int(); reduce them to class indices.
    if sequence.ndim == 2:
        sequence = sequence.argmax(-1)

    n_atoms = xyz.shape[1]
    pdb_lines = []
    for ii in range(xyz.shape[0]):
        residue = int(sequence[ii])
        residue_string = num2aa[residue]
        for jj, atom in enumerate(aa2long[residue][:n_atoms]):
            if atom is None:
                # Remaining slots for this residue type are empty.
                break
            atom_mask_ = atom_mask[ii, jj]
            if atom_mask_ == 0:
                continue
            pdb_lines.append(atom_to_pdb_line(atom, xyz[ii, jj, :], ii, residue_string, atom_mask_))
    pdb_lines.append("END\n")

    path = save_dir + str(ID) + ".pdb"
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(path, "w") as f:
        f.write("".join(pdb_lines))
[docs]
def protein_to_pdb(ID, protein, crop_idx=None, save_dir=save_dir):
    """Write ``protein`` to ``save_dir + str(ID) + ".pdb"`` via ``tensor_to_pdb``.

    ``protein`` must expose ``protein.structure.xyz``,
    ``protein.structure.atom_mask`` and ``protein.sequence.sequence``
    (presumably the project's Protein class; verify against callers).
    ``crop_idx`` and ``save_dir`` are forwarded unchanged.
    """
    # Previously read ``protein.struture`` (typo), which raised AttributeError
    # on every call.
    structure = protein.structure
    tensor_to_pdb(ID, structure.xyz, protein.sequence.sequence, structure.atom_mask, crop_idx, save_dir)
[docs]
def align_two_protein(protein1, protein2):
    """Placeholder for aligning two proteins; not implemented yet.

    Both arguments are ignored and ``None`` is returned.
    """
    return None
[docs]
def output_to_pdb(output_dict, save_dir=save_dir):
    """Write a model-output dict to ``save_dir + str(ID) + ".pdb"`` via ``tensor_to_pdb``.

    Required keys: ``'ID'``, ``'crop_idx'``, ``'output_type'``, plus either
    ``'xyz'`` / ``'sequence'`` / ``'atom_mask'`` (``output_type == "Tensor"``) or
    ``'protein'`` (``output_type == "Protein"``). The protein is indexed as
    ``protein['structure']['xyz']``, ``protein['sequence']`` and
    ``protein['structure']['atom_mask']``.

    Raises:
        KeyError: if a required key is missing.
        ValueError: if ``output_type`` is neither ``"Tensor"`` nor ``"Protein"``.
    """
    ID, _requested_crop_idx, output_type = (
        output_dict['ID'], output_dict['crop_idx'], output_dict['output_type'])
    # 'crop_idx' must be present but is deliberately not used: the whole chain
    # is always written. NOTE(review): this override predates this change;
    # confirm whether the requested crop should be honoured.
    crop_idx = None
    if output_type == "Tensor":
        xyz = output_dict['xyz']
        sequence = output_dict['sequence']
        atom_mask = output_dict['atom_mask']
    elif output_type == "Protein":
        protein = output_dict['protein']
        xyz = protein['structure']['xyz']
        sequence = protein['sequence']
        atom_mask = protein['structure']['atom_mask']
    else:
        # Previously fell through and failed later with an UnboundLocalError on xyz.
        raise ValueError(f"Unknown output_type {output_type!r}; expected 'Tensor' or 'Protein'")
    tensor_to_pdb(ID, xyz, sequence, atom_mask, crop_idx, save_dir)
[docs]
def visualize_2D_tensor(tensor, saving_path="test.png"):
    """Draw a 2-D tensor with ``imshow`` on the current figure and save it.

    Args:
        tensor: A torch-style tensor (``detach``/``cpu``/``numpy`` are called).
        saving_path: Output image path passed to ``plt.savefig``.
    """
    import matplotlib.pyplot as plt
    array = tensor.detach().cpu().numpy()
    plt.imshow(array)
    plt.savefig(saving_path)
    # Release the figure. Without this, repeated calls keep drawing onto the
    # same still-open figure and memory grows with every call.
    plt.close()
[docs]
def visualize_2D_heatmap(heatmap, file_name="heatmap", heatmap_dir="opt_visualize/"):
    """Save ``heatmap`` with a colorbar to ``heatmap_dir + file_name + ".png"``.

    Args:
        heatmap: 2-D torch-style tensor (``detach``/``cpu``/``numpy`` are called).
        file_name: Output file stem.
        heatmap_dir: Directory prefix, joined by concatenation (keep the
            trailing slash). Created if missing.
    """
    # exist_ok avoids the check-then-create race when several processes
    # save into the same directory.
    os.makedirs(heatmap_dir, exist_ok=True)
    heatmap = heatmap.detach().cpu().numpy()
    plt.imshow(heatmap)
    plt.colorbar()
    plt.savefig(heatmap_dir + f"{file_name}.png")
    plt.close()
[docs]
def visualize_2D_heatmaps(heatmap1, heatmap2, first_title="True", second_title="Pred", file_name="heatmap", heatmap_dir="opt_visualize/"):
    """Save two heatmaps stacked vertically to ``heatmap_dir + file_name + ".png"``.

    Args:
        heatmap1, heatmap2: 2-D torch-style tensors (``detach``/``cpu``/``numpy``
            are called), drawn top and bottom with their own colorbars.
        first_title, second_title: Subplot titles.
        file_name: Output file stem.
        heatmap_dir: Directory prefix, joined by concatenation (keep the
            trailing slash). Created if missing.
    """
    os.makedirs(heatmap_dir, exist_ok=True)
    top = heatmap1.detach().cpu().numpy()
    bottom = heatmap2.detach().cpu().numpy()
    plt.figure(figsize=(20, 20))
    # Plot the two maps vertically.
    plt.subplot(2, 1, 1)
    plt.title(first_title)
    plt.imshow(top)
    plt.colorbar()
    plt.subplot(2, 1, 2)
    plt.title(second_title)
    plt.imshow(bottom)
    plt.colorbar()
    plt.savefig(heatmap_dir + f"{file_name}.png")
    # Previously left open: each call leaked a 20x20 figure, and later plt.*
    # calls drew onto it.
    plt.close()
[docs]
def visualize_t_map(t_map_list, l_idx, atom_idx=1, file_name="t_map", t_map_dir="LADE_visualize/"):
    '''
    Save two translation maps as stacked images to ``t_map_dir + file_name + ".png"``.

    t_map_list : list of exactly 2 torch-style tensors, indexed as
                 ``t_map[l_idx, :, atom_idx, :]`` (presumably (L, L, n_atoms, 3)).
    l_idx      : row index selected from the first axis.
    atom_idx   : atom index selected from the third axis.
    Each selected (L, 3) slice is transposed to (3, L) and shown in its own
    subplot with a colorbar.

    Raises ValueError if t_map_list does not contain exactly 2 maps.
    '''
    # Validate before touching the filesystem. A real exception, unlike
    # `assert`, is not stripped under `python -O`.
    if len(t_map_list) != 2:
        raise ValueError(f"visualize_t_map expects 2 maps, got {len(t_map_list)}")
    os.makedirs(t_map_dir, exist_ok=True)
    plt.figure(figsize=(20, 10))
    for ii, t_map in enumerate(t_map_list):
        t_map = t_map.detach().cpu().numpy()
        t_map = t_map[l_idx, :, atom_idx, :]  # (L, 3)
        plt.subplot(2, 1, ii + 1)
        plt.imshow(t_map.T)
        plt.colorbar()
    # Save once, after both subplots are drawn.
    plt.savefig(t_map_dir + f"{file_name}.png")
    plt.close()