from miniworld.utils.ProteinClass import *
import time
import random
import sys
from numbers import Number
from collections import deque
from collections.abc import Set, Mapping
import tracemalloc
import linecache
from miniworld.utils.chemical import INIT_CRDS
# Token indices shared by the MSA / template featurizers in this module.
# The featurizer docstrings describe the 23-class alphabet as
# 20 regular amino acids + 1 unknown + 1 gap + 1 mask.
UNK_IDX = 20  # unknown residue; MSA rows made only of UNK are dropped as empty
GAP_IDX = 21  # gap token; placeholder templates are filled with it
MASK_IDX = 22  # mask token: the 23rd class, used by BERT-style MSA masking
NUM_OF_CLASS = 23  # size of the token alphabet (including UNK, GAP and MASK)
[docs]
def display_top(snapshot, key_type='lineno', limit=10):
    """Print the ``limit`` largest allocation sites of a tracemalloc snapshot.

    Adapted from the "pretty top" recipe in the :mod:`tracemalloc` docs.
    Traces from ``<frozen importlib._bootstrap>`` and ``<unknown>`` are
    filtered out first. Every shown site is followed by its source line
    (when available), then a summary of the remaining sites and the total.

    :param snapshot: a :class:`tracemalloc.Snapshot`.
    :param key_type: grouping key passed to ``Snapshot.statistics``.
    :param limit: number of sites to list individually.
    """
    ignored = (
        tracemalloc.Filter(False, "<frozen importlib._bootstrap>"),
        tracemalloc.Filter(False, "<unknown>"),
    )
    stats = snapshot.filter_traces(ignored).statistics(key_type)
    shown, rest = stats[:limit], stats[limit:]

    print("Top %s lines" % limit)
    for rank, stat in enumerate(shown, start=1):
        frame = stat.traceback[0]
        print("#%s: %s:%s: %.1f KiB" % (rank, frame.filename, frame.lineno, stat.size / 1024))
        source_line = linecache.getline(frame.filename, frame.lineno).strip()
        if source_line:
            print(' %s' % source_line)

    if rest:
        rest_size = sum(s.size for s in rest)
        print("%s other: %.1f KiB" % (len(rest), rest_size / 1024))
    total = sum(s.size for s in stats)
    print("Total allocated size: %.1f KiB" % (total / 1024))
[docs]
def cluster_sum(data, assignment, N_seq, N_res):
    """Accumulate per-sequence features into cluster buckets.

    Row ``i`` of ``data`` is added (as float) into bucket ``assignment[i]``
    of an ``(N_seq, N_res, D)`` zero tensor on ``data.device``.

    :param data: (N_extra, N_res, D) tensor of per-sequence features.
    :param assignment: (N_extra,) integer tensor of target bucket indices.
    :param N_seq: number of buckets (seed sequences).
    :param N_res: number of residues (second dimension of ``data``).
    :return: (N_seq, N_res, D) float tensor of bucket sums.
    """
    n_feat = data.shape[-1]
    buckets = torch.zeros((N_seq, N_res, n_feat), device=data.device)
    # Broadcast each row's bucket id over its residue and feature axes.
    scatter_index = assignment.view(-1, 1, 1).expand(-1, N_res, n_feat)
    return buckets.scatter_add(0, scatter_index, data.float())
[docs]
def chain_break_cropping(chain_break, crop_idx):
    """Re-express chain boundaries in the coordinate system of a crop.

    :param chain_break: ordered mapping ``chain_id -> (start, end)`` with
        inclusive residue indices in the uncropped numbering.
    :param crop_idx: 1-D tensor of residue indices kept by the crop, in crop order.
    :return: ``(new_chain_break, chain_to_crop_idx)``:

        - ``new_chain_break``: ``chain_id -> (first_pos, last_pos)``, the positions
          *within crop_idx* of the chain's smallest and largest kept residue index.
        - ``chain_to_crop_idx``: ``chain_id -> sorted tensor`` of the chain's
          residue indices that survive the crop.

        Chains with no residue in the crop are omitted from both mappings.
    """
    crop_list = crop_idx.tolist()
    crop_set = set(crop_list)
    # Position of the first occurrence of each residue index inside the crop.
    # This matches list.index() semantics without an O(len(crop)) scan per lookup.
    position_in_crop = {}
    for pos, res in enumerate(crop_list):
        position_in_crop.setdefault(res, pos)

    new_chain_break = OrderedDict()
    chain_to_crop_idx = OrderedDict()
    for chain_idx, (chain_start, chain_end) in chain_break.items():
        kept = sorted(crop_set.intersection(range(chain_start, chain_end + 1)))
        if not kept:
            continue
        new_chain_break[chain_idx] = (position_in_crop[kept[0]], position_in_crop[kept[-1]])
        chain_to_crop_idx[chain_idx] = torch.tensor(kept)
    return new_chain_break, chain_to_crop_idx
[docs]
def MSA_block_deletion(msa, insertion, nb=5):
    '''
    Down-sample an MSA by deleting ``nb`` random blocks of consecutive sequences.

    Each block spans ``max(int(0.1 * N), 1)`` rows and starts at a random row in
    ``[1, N)``. Blocks are clipped to ``[1, N - 1]``, so the query (row 0) is never
    deleted. An MSA with at most one row is returned without deletion, because
    there is nothing besides the query to remove.

    :param msa: (N, L) MSA array/tensor.
    :param insertion: (N, L) insertion array/tensor aligned with ``msa``.
    :param nb: number of blocks to delete (blocks may overlap).
    :return: ``(msa[keep], insertion[keep])`` with shape (N', L), N' <= N.
    '''
    N = msa.shape[0]
    # np.bool_ rather than the np.bool alias, which NumPy 1.24 removed.
    keep = np.ones(N, dtype=np.bool_)
    if N > 1:  # np.random.randint(1, N) would raise for N <= 1
        block_size = max(int(N * 0.1), 1)
        block_start = np.random.randint(low=1, high=N, size=nb)  # (nb)
        to_delete = block_start[:, None] + np.arange(block_size)[None, :]
        to_delete = np.unique(np.clip(to_delete, 1, N - 1))
        keep[to_delete] = False
    return msa[keep], insertion[keep]
[docs]
@torch.no_grad()
def MSA_featurize_wo_statistics_by_chain(msa, insertion, N_clust, params):
    '''
    Sample seed-MSA features for one chain, ``params['MAX_MSA_CYCLE']`` times.

    Adapted from the RoseTTAFold2 featurizer (variables renamed). The cluster
    profile, insertion statistics and N/C-terminal features were removed; only
    the one-hot tokens and a per-position insertion feature remain.

    In every cycle the query (row 0) is kept and ``N_clust - 1`` further rows are
    drawn at random, without replacement, from ``msa[1:]``. If
    ``params['USE_MASK']`` is truthy, BERT-style masking is applied to them.

    :param msa: (N, L) integer MSA tokens; cast to long. Values must be < 22
        because a 22-class one-hot is taken (no MASK token in the input).
    :param insertion: (N, L) insertion counts aligned with ``msa``.
    :param N_clust: number of seed sequences per cycle, query included.
    :param params: dict-like with ``'USE_MASK'``, ``'MAX_MSA_CYCLE'`` and
        optionally ``'PROB_MASK'`` (default 0.15).
    :return: dict of float tensors allocated with ``torch.zeros`` (default device):

        - ``'sample_seq'``: (MAX_MSA_CYCLE, L) query tokens
        - ``'sample_msa_species'``: (MAX_MSA_CYCLE, N_clust) sampled row ids
        - ``'sample_msa_clust'``: (MAX_MSA_CYCLE, N_clust, L) unmasked seed tokens
        - ``'sample_msa_seed'``: (MAX_MSA_CYCLE, N_clust, L, 23 + 1)
          one-hot (masked) tokens + insertion feature in [0, 1)
        - ``'sample_mask_pos'``: (MAX_MSA_CYCLE, N_clust, L), 1.0 where masking applied
    :raises ValueError: if ``msa`` has no rows or cannot be one-hot encoded.
    '''
    NUM_OF_CLASS = 23  # 20 aa + UNK + GAP + MASK
    N, L = msa.shape
    if N == 0:
        print(f"msa.shape: {msa.shape}")
        print(f"insertion.shape: {insertion.shape}")
        raise ValueError("Error in MSA_featurize_wo_statistics_by_chain")
    p_mask = params['PROB_MASK'] if 'PROB_MASK' in params else 0.15

    msa = msa.long()
    try:
        # (N, L, 22): tokens without the MASK class
        full_raw_profile = torch.nn.functional.one_hot(msa, num_classes=NUM_OF_CLASS - 1)
    except Exception as err:
        print(f"msa.shape: {msa.shape} | msa dtype: {msa.dtype} | msa device: {msa.device}")
        raise ValueError("Error in MSA_featurize_wo_statistics") from err
    # Column-wise token frequencies over all rows: (L, 22)
    raw_profile = full_raw_profile.float().mean(dim=0)

    USE_MASK = params['USE_MASK']
    n_cycle = params['MAX_MSA_CYCLE']

    # Allocate outputs up front so that n_cycle == 0 returns empty tensors
    # instead of failing on unbound names.
    sample_seq = torch.zeros(n_cycle, L)  # (MAXCYCLE, L)
    sample_msa_species = torch.zeros(n_cycle, N_clust)  # (MAXCYCLE, Nclust)
    sample_msa_clust = torch.zeros(n_cycle, N_clust, L)  # (MAXCYCLE, Nclust, L)
    sample_msa_seed = torch.zeros(n_cycle, N_clust, L, NUM_OF_CLASS + 1)  # (MAXCYCLE, Nclust, L, 23 + 1)
    sample_mask_pos = torch.zeros(n_cycle, N_clust, L)  # (MAXCYCLE, Nclust, L)

    # Must live on msa.device: it is concatenated with `sample`, which does.
    query_idx = torch.zeros(1, dtype=torch.long, device=msa.device)

    for i_cycle in range(n_cycle):
        sample = torch.randperm(N - 1, device=msa.device)  # (N-1)
        picked = sample[:N_clust - 1]
        # NOTE(review): `picked` indexes msa[1:], so msa row = picked + 1, and the
        # query's 0 coincides with the id of msa[1]. Kept as-is because downstream
        # consumers may rely on it. Confirm the intended convention.
        msa_species_idx = torch.cat((query_idx, picked))  # (Nclust)
        msa_clust = torch.cat((msa[:1, :], msa[1:, :][picked]), dim=0)  # (Nclust, L)
        insertion_clust = torch.cat((insertion[:1, :], insertion[1:, :][picked]), dim=0)  # (Nclust, L)

        if USE_MASK:
            # BERT-style masking of ~p_mask of the positions. A masked token becomes:
            #   10%: a uniformly random amino acid (one of the first 20 classes)
            #   10%: a token drawn from the column profile of the whole MSA
            #   10%: itself (unchanged)
            #   70%: the MASK token (class 22)
            random_aa = torch.tensor([[0.05] * 20 + [0.0] * 2], device=msa.device)  # (1, 22)
            same_aa = torch.nn.functional.one_hot(msa_clust, num_classes=NUM_OF_CLASS - 1)  # (Nclust, L, 22)
            probs = 0.1 * random_aa + 0.1 * raw_profile + 0.1 * same_aa  # (Nclust, L, 22)
            probs = torch.nn.functional.pad(probs, (0, 1), "constant", 0.7)  # (Nclust, L, 23)
            mask_sample = torch.distributions.categorical.Categorical(probs=probs).sample()
            mask_pos = torch.rand(msa_clust.shape, device=msa_clust.device) < p_mask  # (Nclust, L)
            # 20231115 PSK: never mask positions that hold the UNK token.
            mask_pos = mask_pos & ~(msa_clust == UNK_IDX)
            msa_masked = torch.where(mask_pos, mask_sample, msa_clust)  # (Nclust, L)
        else:
            mask_pos = torch.zeros_like(msa_clust, dtype=torch.bool, device=msa_clust.device)
            msa_masked = msa_clust.clone()

        # Seed features: one-hot token (23 classes, MASK included) + insertion feature.
        msa_clust_onehot = torch.nn.functional.one_hot(msa_masked, num_classes=NUM_OF_CLASS)  # (Nclust, L, 23)
        # 2/pi * arctan(x/3) maps counts >= 0 into [0, 1)
        insertion_feat = (2.0 / np.pi) * torch.arctan(insertion_clust.float() / 3.0)
        msa_seed = torch.cat((msa_clust_onehot, insertion_feat.unsqueeze(-1)), dim=-1)  # (Nclust, L, 23 + 1)

        sample_seq[i_cycle] = msa[0]
        sample_msa_species[i_cycle] = msa_species_idx
        sample_msa_clust[i_cycle] = msa_clust
        sample_msa_seed[i_cycle] = msa_seed
        sample_mask_pos[i_cycle] = mask_pos

    return {
        'sample_seq': sample_seq,  # (MAXCYCLE, L)
        'sample_msa_species': sample_msa_species,  # (MAXCYCLE, Nclust)
        'sample_msa_clust': sample_msa_clust,  # (MAXCYCLE, Nclust, L)
        'sample_msa_seed': sample_msa_seed,  # (MAXCYCLE, Nclust, L, 23 + 1)
        'sample_mask_pos': sample_mask_pos,  # (MAXCYCLE, Nclust, L)
    }
[docs]
@torch.no_grad()
def MSA_featurize_wo_statistics(msa, insertion, chain_to_idx_dict, params):
    """Build seed-MSA features for a multi-chain crop, one chain at a time.

    Adapted from the RoseTTAFold2 featurizer (variables renamed).

    Processing steps:

    1. Rows made only of ``UNK_IDX`` are discarded.
    2. For each chain, count the rows that have a non-UNK token inside the chain's
       column span.
    3. Chains shallower than ``params["MIN_SEED_MSA_PER_CHAIN"]`` have their depth
       reserved from the ``params["MAX_SEED_MSA"]`` budget. The remainder is split
       evenly among the other chains, and every chain's depth is capped at that share.
    4. Each chain is featurized with ``MSA_featurize_wo_statistics_by_chain``.
       Row 0 of the merged output holds every chain's query tokens in that chain's
       columns, and each chain's sampled rows are appended below it.

    :param msa: (N, L_crop) token tensor.
    :type msa: torch.LongTensor
    :param insertion: (N, L_crop) insertion counts aligned with ``msa``.
    :type insertion: torch.LongTensor
    :param chain_to_idx_dict: chain id -> inclusive column range ``(start, end)``.
    :type chain_to_idx_dict: dict
    :param params: needs ``MAX_SEED_MSA``, ``MIN_SEED_MSA_PER_CHAIN`` and
        ``MAX_MSA_CYCLE``, plus whatever the per-chain featurizer reads.
    :type params: dict
    :return: dict with the following keys:

        - ``sample_seq``: (MAXCYCLE, L_crop)
        - ``sample_msa_species``: (MAXCYCLE, Nclust_total)
        - ``sample_msa_clust``: stacked along the same Nclust_total axis
        - ``sample_msa_seed``: stacked along the same axis, last dim 23 + 1
        - ``sample_mask_pos``: stacked along the same axis
        - ``chain_to_msa_depth``: chain id -> seed depth; chains of depth 0 are omitted

        Here Nclust_total = 1 + sum over chains of (depth - 1).
    :raises ValueError: if the shallow chains alone exceed ``MAX_SEED_MSA``.
    """
    # Drop rows that contain nothing but UNK (empty sequences).
    non_empty = (msa != UNK_IDX).any(dim=-1)  # (N)
    msa = msa[non_empty]  # (N', L_crop)
    insertion = insertion[non_empty]  # (N', L_crop)

    # Rows with at least one non-UNK token in each chain's columns.
    depth_per_chain = {
        chain_idx: (msa[:, start:end + 1] != UNK_IDX).any(dim=-1).sum().item()
        for chain_idx, (start, end) in chain_to_idx_dict.items()
    }

    remaining_depth = params["MAX_SEED_MSA"]
    n_deep_chains = 0
    for depth in depth_per_chain.values():
        if depth >= params["MIN_SEED_MSA_PER_CHAIN"]:
            n_deep_chains += 1
        else:
            remaining_depth -= depth
    if remaining_depth < 0:
        print(f"chain_num : {depth_per_chain}")
        print(f"remaining_depth : {remaining_depth}")
        raise ValueError("Error in MSA_featurize_wo_statistics")

    if n_deep_chains > 0:
        max_chain_depth = remaining_depth // n_deep_chains
    else:
        max_chain_depth = remaining_depth
    depth_per_chain = {key: min(val, max_chain_depth) for key, val in depth_per_chain.items()}

    # Row 0 starts as an all-UNK query placeholder; chains fill in their columns.
    output_dict = {
        'sample_seq': torch.full((params['MAX_MSA_CYCLE'], msa.shape[1]), fill_value=UNK_IDX),  # (MAXCYCLE, L_crop)
        'sample_msa_species': torch.zeros((params['MAX_MSA_CYCLE'], 1)),  # (MAXCYCLE, Nclust)
        'sample_msa_clust': torch.full((params['MAX_MSA_CYCLE'], 1, msa.shape[1]), fill_value=UNK_IDX),  # (MAXCYCLE, Nclust, L_crop)
        'sample_msa_seed': torch.zeros((params['MAX_MSA_CYCLE'], 1, msa.shape[1], 24)),  # (MAXCYCLE, Nclust, L_crop, 23 + 1)
        'sample_mask_pos': torch.zeros((params['MAX_MSA_CYCLE'], 1, msa.shape[1])),  # (MAXCYCLE, Nclust, L_crop)
    }

    chain_to_msa_depth = {}
    for chain_idx, (start, end) in chain_to_idx_dict.items():
        depth = depth_per_chain[chain_idx]
        if depth == 0:
            continue
        chain_to_msa_depth[chain_idx] = depth

        rows = (msa[:, start:end + 1] != UNK_IDX).any(dim=-1)  # (N)
        chain_dict = MSA_featurize_wo_statistics_by_chain(msa[rows], insertion[rows], depth, params)

        cols = slice(start, end + 1)
        for key, value in chain_dict.items():
            if key == "sample_seq":
                output_dict[key][:, cols] = value[:, cols]
            elif key == "sample_msa_species":
                output_dict[key] = torch.cat((output_dict[key], value[:, 1:]), dim=1)
            else:
                # Query row: copy only this chain's columns; extra rows: append.
                output_dict[key][:, 0, cols] = value[:, 0, cols]
                output_dict[key] = torch.cat((output_dict[key], value[:, 1:]), dim=1)

    output_dict["chain_to_msa_depth"] = chain_to_msa_depth
    return output_dict
# Types getsize() does not recurse into: only their sys.getsizeof() is counted.
ZERO_DEPTH_BASES = (str, bytes, Number, range, bytearray)
[docs]
def getsize(obj_0):
    """Recursively sum the size of an object and everything it references.

    Containers (tuple/list/set/deque), mappings and anything with ``.items()``
    are walked, as are instance ``__dict__`` and ``__slots__`` attributes.
    Each object is counted once (tracked by ``id``), tensors included.
    Tensors are counted by their element storage
    (``element_size() * nelement()``) instead of ``sys.getsizeof``.

    :param obj_0: root object to measure.
    :return: total size in bytes.
    """
    _seen_ids = set()

    def inner(obj):
        obj_id = id(obj)
        if obj_id in _seen_ids:
            return 0
        _seen_ids.add(obj_id)
        # Checked after the id check so a tensor shared by several
        # containers is only counted once.
        if isinstance(obj, torch.Tensor):
            return obj.element_size() * obj.nelement()
        size = sys.getsizeof(obj)
        if isinstance(obj, ZERO_DEPTH_BASES):
            pass  # leaf types: no recursion
        elif isinstance(obj, (tuple, list, Set, deque)):
            size += sum(inner(i) for i in obj)
        elif isinstance(obj, Mapping) or hasattr(obj, 'items'):
            size += sum(inner(k) + inner(v) for k, v in obj.items())
        # Custom object instances (may also subclass the types above).
        if hasattr(obj, '__dict__'):
            size += inner(vars(obj))
        slots = getattr(obj, '__slots__', ())  # can coexist with __dict__
        if isinstance(slots, str):
            # `__slots__ = 'name'` declares one slot, not one per character.
            slots = (slots,)
        size += sum(inner(getattr(obj, s)) for s in slots if hasattr(obj, s))
        return size

    return inner(obj_0)
[docs]
def center_and_realign_missing(xyz, mask_t):
    """Center a structure on its CA atoms and move missing residues onto valid ones.

    Taken from RoseTTAFold2. A residue is valid when its first three atoms are
    present in ``mask_t``, and atom index 1 is used as CA.

    Steps:

    1. Subtract the mean CA position of the valid residues from every valid residue.
    2. Shift each invalid residue by the (centered) CA of the nearest valid residue
       in sequence; ties go to the lower index.

    :param xyz: (L, 27, 3) coordinates.
    :param mask_t: (L, 27) atom mask (bool or 0/1).
    :return: (L, 27, 3) new coordinate tensor. If no residue is valid, a copy of
        ``xyz`` is returned unchanged, since there is nothing to center on or copy from.
    """
    L = xyz.shape[0]
    mask = mask_t[:, :3].all(dim=-1)  # (L) True for residues with atoms 0-2 present
    if not bool(mask.any()):
        # Without this guard, argmin over an empty dimension below raises.
        return xyz.clone()

    # Center the centre of mass (of CA atoms) at the origin.
    center_CA = (mask[..., None] * xyz[:, 1]).sum(dim=0) / (mask[..., None].sum(dim=0) + 1e-5)  # (3)
    xyz = torch.where(mask.view(L, 1, 1), xyz - center_CA.view(1, 1, 3), xyz)

    # Move each missing residue to the closest valid residue (by sequence distance).
    exist_in_xyz = torch.where(mask)[0]  # (L_sub)
    seqmap = (torch.arange(L, device=xyz.device)[:, None] - exist_in_xyz[None, :]).abs()  # (L, L_sub)
    seqmap = torch.argmin(seqmap, dim=-1)  # (L)
    idx = torch.gather(exist_in_xyz, 0, seqmap)
    offset_CA = torch.gather(xyz[:, 1], 0, idx.reshape(L, 1).expand(-1, 3))
    xyz = torch.where(mask.view(L, 1, 1), xyz, xyz + offset_CA.reshape(L, 1, 3))
    return xyz
[docs]
def _fake_template_features(npick_global, chain_length, random_noise, num_classes):
    """Build placeholder template features for when no usable template exists.

    Contents of the placeholder:

    - coordinates: ``INIT_CRDS`` plus uniform noise in ``[0, random_noise)``,
      one offset per residue shared by its 27 atoms;
    - sequence: all ``GAP_IDX`` (one-hot), with a zero confidence channel;
    - atom mask: no atom marked valid.
    """
    xyz = (
        INIT_CRDS.reshape(1, 1, 27, 3).repeat(npick_global, chain_length, 1, 1)
        + torch.rand(npick_global, chain_length, 1, 3) * random_noise
    )
    template_1D = torch.nn.functional.one_hot(
        torch.full((npick_global, chain_length), GAP_IDX).long(), num_classes=num_classes
    ).float()  # all gaps
    conf = torch.zeros((npick_global, chain_length, 1)).float()
    template_1D = torch.cat((template_1D, conf), -1)
    template_atom_mask = torch.full((npick_global, chain_length, 27), False)
    return {
        'xyz': xyz,  # (npick_global, L_chain, 27, 3)
        'template_1D': template_1D,  # (npick_global, L_chain, num_classes + 1)
        'template_atom_mask': template_atom_mask,  # (npick_global, L_chain, 27)
    }


def template_featurize(input_template_dict, params):
    """Select and featurize templates for a single chain (adapted from RF2).

    The MSA featurizer keeps RF2's input layout. The template data structure was
    rebuilt, so this function departs further from the RF2 original.

    Selection and output:

    - Templates with ``f0d >= SEQUENCE_IDENTITY_CUTOFF`` are discarded (when the
      cutoff is <= 100).
    - Up to ``N_PICK`` of the remaining templates are drawn at random: from the
      first 50 when ``PICK_TOP`` is set, otherwise from all of them.
    - Output slots ``[0, npick)`` hold the chosen templates; the remaining
      ``npick_global - npick`` slots are placeholders (noisy ``INIT_CRDS``
      coordinates, one-hot gap sequence, zero confidence, empty atom mask).
    - If nothing can be picked, every slot is a placeholder.

    :param input_template_dict: template information with keys:

        - ``'xyz'``: (N_template, L_chain, 27, 3)
        - ``'mask'``: (N_template, L_chain, 27)
        - ``'sequence'``: (N_template, L_chain, NUM_CLASSES), already one-hot
        - ``'f0d'``: (N_template), compared against the sequence-identity cutoff
        - ``'f1d'``: (N_template, L_chain), used as the confidence channel
    :type input_template_dict: dict
    :param params: optional keys ``N_PICK`` (default 1), ``N_PICK_GLOBAL``
        (default ``max(N_PICK, 1)``), ``PICK_TOP`` (default True), ``RANDOM_NOISE``
        (default 5.0) and ``SEQUENCE_IDENTITY_CUTOFF`` (default 100.0).
    :type params: dict
    :return: dict with the following keys:

        - ``'xyz'``: (npick_global, L_chain, 27, 3)
        - ``'template_1D'``: (npick_global, L_chain, 23 + 1)
        - ``'template_atom_mask'``: (npick_global, L_chain, 27)
    :rtype: dict
    """
    npick = params["N_PICK"] if "N_PICK" in params else 1
    npick_global = params["N_PICK_GLOBAL"] if "N_PICK_GLOBAL" in params else None
    pick_top = params["PICK_TOP"] if "PICK_TOP" in params else True
    random_noise = params["RANDOM_NOISE"] if "RANDOM_NOISE" in params else 5.0
    if npick_global is None:
        npick_global = max(npick, 1)
    # Checked only after the default is applied; comparing an int with None raises TypeError.
    assert npick <= npick_global
    seqID_cut = params['SEQUENCE_IDENTITY_CUTOFF'] if 'SEQUENCE_IDENTITY_CUTOFF' in params else 100.0
    NUM_CLASSES = 23

    N_template = input_template_dict['xyz'].shape[0]
    chain_length = input_template_dict['xyz'].shape[1]
    if N_template < 1 or npick < 1:
        # No templates available, or templates not wanted.
        return _fake_template_features(npick_global, chain_length, random_noise, NUM_CLASSES)

    # Ignore templates whose sequence identity is too high.
    if seqID_cut <= 100.0:
        template_valid_idx = torch.where(input_template_dict['f0d'] < seqID_cut)[0]
    else:
        template_valid_idx = torch.arange(N_template)
    N_template = template_valid_idx.shape[0]
    npick = min(npick, N_template)
    if npick < 1:
        return _fake_template_features(npick_global, chain_length, random_noise, NUM_CLASSES)

    if not pick_top:  # sample among all remaining templates
        sample = torch.randperm(N_template)[:npick]
    else:  # sample among the first 50; presumably templates are ranked best-first (TODO confirm)
        sample = torch.randperm(min(50, N_template))[:npick]

    # Noise has shape (..., 1, 3): one random offset per residue, shared by its 27 atoms.
    xyz = (
        INIT_CRDS.reshape(1, 1, 27, 3).repeat(npick_global, chain_length, 1, 1)
        + torch.rand(npick_global, chain_length, 1, 3) * random_noise
    )
    template_atom_mask = torch.full((npick_global, chain_length, 27), False)  # True = atom present
    # Unfilled slots must be one-hot gaps, matching the placeholder path.
    template_1D = torch.nn.functional.one_hot(
        torch.full((npick_global, chain_length), GAP_IDX).long(), num_classes=NUM_CLASSES
    ).float()  # (npick_global, L_chain, NUM_CLASSES)
    t1d_val = torch.zeros((npick_global, chain_length)).float()
    for i, nt in enumerate(sample):
        template_idx = template_valid_idx[nt]
        template_mask = input_template_dict['mask'][template_idx]  # (L_chain, 27)
        xyz[i] = center_and_realign_missing(input_template_dict['xyz'][template_idx], template_mask)
        template_atom_mask[i] = template_mask
        template_1D[i] = input_template_dict['sequence'][template_idx]  # (L_chain, NUM_CLASSES)
        t1d_val[i] = input_template_dict['f1d'][template_idx]  # (L_chain)
    template_1D = torch.cat((template_1D, t1d_val[..., None]), dim=-1)  # (npick_global, L_chain, NUM_CLASSES + 1)
    return {
        'xyz': xyz,  # (npick_global, L_chain, 27, 3)
        'template_1D': template_1D,  # (npick_global, L_chain, NUM_CLASSES + 1)
        'template_atom_mask': template_atom_mask,  # (npick_global, L_chain, 27)
    }
# slice long chains
[docs]
def get_crop(chain_start, chain_end, mask, device, params, unclamp = False, ID = None):
    """Return a contiguous crop of at most ``params['CROP']`` residues of one chain.

    Chains no longer than the crop size are returned whole. Otherwise:

    1. A random model is chosen.
    2. An anchor residue is drawn among those whose first three atoms are all
       present. With ``unclamp`` the draw is restricted to a random-length prefix
       of those residues, biasing it toward the N-terminus as AlphaFold does.
    3. A random window of exactly ``CROP`` residues that lies inside the chain
       and contains the anchor is returned.

    :param chain_start: index of the first residue of the chain (inclusive).
    :param chain_end: index of the last residue of the chain (inclusive).
    :param mask: per-model atom mask, indexed ``mask[model, residue, atom]``.
        Presumably (Model_num, L, 27) — TODO confirm; an older comment said
        (Model_num, L, 27, 3).
    :param device: device for the returned index tensor.
    :param params: needs ``'CROP'``.
    :param unclamp: bias the anchor toward the N-terminus.
    :param ID: identifier used only in error messages.
    :return: 1-D long tensor of absolute residue indices.
    :raises ValueError: if the chain needs cropping but has no residue with its
        first three atoms present.
    """
    crop_size = params['CROP']
    if chain_end - chain_start + 1 <= crop_size:
        return torch.arange(chain_start, chain_end + 1).to(device)

    random_model = torch.randint(mask.shape[0], (1,)).item()
    chain_mask = mask[random_model, chain_start:chain_end + 1]
    has_backbone = ~(chain_mask[:, :3].sum(dim=-1) < 3.0)
    exists = has_backbone.nonzero()[:, 0]  # chain-relative indices of candidate anchors
    if len(exists) == 0:
        raise ValueError(
            f"get_crop: no residue with its first three atoms present in chain "
            f"[{chain_start}, {chain_end}] (ID: {ID})"
        )

    if unclamp:
        # Draw only from the first x candidates.
        x = np.random.randint(len(exists)) + 1
        res_idx = exists[torch.randperm(x)[0]].item()
    else:
        res_idx = exists[torch.randperm(len(exists))[0]].item()
    res_idx += chain_start

    # Start range (half-open for np.random.randint) such that
    # [start, start + crop_size) lies within the chain and contains res_idx.
    lower_bound = max(chain_start, res_idx - crop_size + 1)
    upper_bound = min(chain_end - crop_size + 1, res_idx) + 1
    start = np.random.randint(lower_bound, upper_bound)
    end = min(start + crop_size, chain_end + 1)
    return torch.arange(start, end).to(device)
[docs]
def random_split(n, k, min_split):
    """Randomly split ``n`` into ``k`` integer parts, each at least ``min_split``.

    After reserving ``min_split`` for every part, ``k - 1`` distinct cut points
    are drawn with ``random.sample`` from ``1 .. n - min_split*k``. The part
    sizes are the gaps between consecutive cuts plus ``min_split``.

    :return: list of ``k`` ints summing to ``n``.
    :raises ValueError: (from ``random.sample``) if there are too few cut points.
    """
    slack = n - min_split * k
    cuts = [0, *sorted(random.sample(range(1, slack + 1), k - 1)), slack]
    return [hi - lo + min_split for lo, hi in zip(cuts, cuts[1:])]
[docs]
def get_complex_crop(len_s, mask, device, params):
    """Crop a multi-chain complex by sampling a few valid chains and cropping each.

    Sampling and cropping:

    - A chain is valid if at least one residue has its first three atoms present.
    - Between ``RANDOM_CHAIN_MIN`` and ``RANDOM_CHAIN_MAX`` valid chains are drawn,
      capped at the number of valid chains.
    - The ``CROP`` budget is split among them with :func:`random_split` (each
      share >= ``CROP_MIN``).
    - A chain whose share exceeds its length is kept whole. Otherwise a window of
      exactly its share, containing a random valid residue, is kept.

    NOTE(review): NumPy's global RNG is re-seeded from ``params["SEED"]`` on every
    call, so all NumPy draws here repeat across calls; confirm this is intended.

    :param len_s: list of chain lengths, in order.
    :param mask: (Model_num, L, 27) or (L, 27) atom mask, or None (treated as all present).
    :param device: device of the returned index tensor.
    :param params: needs ``SEED``, ``RANDOM_CHAIN_MIN``, ``RANDOM_CHAIN_MAX``,
        ``CROP_MIN``, ``CROP``.
    :return: 1-D tensor of selected residue indices, grouped by chain in chain order.
    """
    np.random.seed(params["SEED"])
    if mask is None:
        mask = torch.ones((1, sum(len_s), 27), device=device)
    if mask.dtype == torch.bool:
        mask = mask.float()
    if len(mask.shape) == 3:  # (Model_num, L, 27)
        tot_len = mask.shape[1]
        model_idx = np.random.randint(mask.shape[0])
        mask = mask[model_idx]
    else:  # (L, 27)
        tot_len = mask.shape[0]
    assert len(len_s) > 1, "This function only works for complex, not for single chain."
    sel = torch.arange(tot_len, device=device)

    # Keep only chains with at least one residue whose first three atoms are present.
    valid_chain_idxs = []
    preset = 0
    for ii, chain_len in enumerate(len_s):
        mask_chain = ~(mask[preset:preset + chain_len, :3].sum(dim=-1) < 3.0)
        if bool(mask_chain.any()):
            valid_chain_idxs.append(ii)
        preset += chain_len
    chain_number = len(valid_chain_idxs)
    assert chain_number > 0, "There is no valid chain."

    random_chain_min = params["RANDOM_CHAIN_MIN"]  # e.g. 2
    random_chain_max = params["RANDOM_CHAIN_MAX"]  # e.g. 4
    random_chain = np.random.randint(random_chain_min, random_chain_max + 1)
    random_chain = min(random_chain, chain_number)
    # Sample without replacement from the valid chains.
    random_chain_idx = np.random.choice(valid_chain_idxs, random_chain, replace=False)
    crop_min = params["CROP_MIN"]  # e.g. 10
    crop_length = params["CROP"]  # e.g. 128
    cropped_length = random_split(crop_length, random_chain, crop_min)
    chain_crop_length_dict = dict(zip(random_chain_idx, cropped_length))

    preset = 0
    sel_s = []
    for k, chain_len in enumerate(len_s):
        if k in random_chain_idx:
            crop_size = chain_crop_length_dict[k]
            if crop_size > chain_len:
                sel_s.append(sel[preset:preset + chain_len])
            else:
                mask_chain = ~(mask[preset:preset + chain_len, :3].sum(dim=-1) < 3.0)
                exists = mask_chain.nonzero()[:, 0]
                res_idx = exists[torch.randperm(len(exists))[0]].item()
                # Start range (half-open) keeping the window inside the chain
                # and containing res_idx.
                lower_bound = max(0, res_idx - crop_size + 1)
                upper_bound = min(chain_len - crop_size, res_idx) + 1
                try:
                    start = np.random.randint(lower_bound, upper_bound) + preset
                except ValueError as err:
                    raise ValueError(
                        f"get_complex_crop: empty start range [{lower_bound}, {upper_bound}) "
                        f"(preset={preset}, chain_len={chain_len}, crop_size={crop_size}, "
                        f"cropped_length={cropped_length})"
                    ) from err
                sel_s.append(sel[start:start + crop_size])
        preset += chain_len
    return torch.cat(sel_s)
[docs]
def get_STRING_crop(len_s, mask, device, params):
    """Crop a two-chain (STRING) pair so both chains share the ``CROP`` budget.

    Both chains must be valid, i.e. contain at least one residue whose first three
    atoms are present. The budget is split with :func:`random_split` using a fixed
    minimum share of 30. A chain whose share exceeds its length is kept whole.
    Otherwise a window of exactly its share, containing a random valid residue,
    is kept.

    NOTE(review): NumPy's global RNG is re-seeded from ``params["SEED"]`` on every
    call, so all NumPy draws here repeat across calls; confirm this is intended.

    :param len_s: the two chain lengths.
    :param mask: (Model_num, L, 27) or (L, 27) atom mask, or None (treated as all present).
    :param device: device of the returned index tensor.
    :param params: needs ``SEED`` and ``CROP``.
    :return: 1-D tensor of selected residue indices, chain 0 first.
    """
    np.random.seed(params["SEED"])
    if mask is None:
        mask = torch.ones((1, sum(len_s), 27), device=device)
    if mask.dtype == torch.bool:
        mask = mask.float()
    if len(mask.shape) == 3:  # (Model_num, L, 27)
        tot_len = mask.shape[1]
        model_idx = np.random.randint(mask.shape[0])
        mask = mask[model_idx]
    else:  # (L, 27)
        tot_len = mask.shape[0]
    assert len(len_s) == 2, "This function only works for STRING data"
    sel = torch.arange(tot_len, device=device)

    # Validate each chain on its own residues: `preset` must advance per chain,
    # otherwise chain 1 would be judged by chain 0's rows.
    valid_chain_idxs = []
    preset = 0
    for ii, chain_len in enumerate(len_s):
        mask_chain = ~(mask[preset:preset + chain_len, :3].sum(dim=-1) < 3.0)
        if bool(mask_chain.any()):
            valid_chain_idxs.append(ii)
        preset += chain_len
    assert len(valid_chain_idxs) == 2, "This function only works for STRING data"
    chain_idxs = valid_chain_idxs

    # TODO PSK: params["CROP_MIN"] is deliberately overridden for STRING pairs.
    crop_min = 30
    crop_length = params["CROP"]  # e.g. 128
    cropped_length = random_split(crop_length, 2, crop_min)
    chain_crop_length_dict = dict(zip(chain_idxs, cropped_length))

    preset = 0
    sel_s = []
    for k, chain_len in enumerate(len_s):
        if k in chain_idxs:
            crop_size = chain_crop_length_dict[k]
            if crop_size > chain_len:
                sel_s.append(sel[preset:preset + chain_len])
            else:
                mask_chain = ~(mask[preset:preset + chain_len, :3].sum(dim=-1) < 3.0)
                exists = mask_chain.float().nonzero()[:, 0]
                res_idx = exists[torch.randperm(len(exists))[0]].item()
                # Start range (half-open) keeping the window inside the chain
                # and containing res_idx.
                lower_bound = max(0, res_idx - crop_size + 1)
                upper_bound = min(chain_len - crop_size, res_idx) + 1
                try:
                    start = np.random.randint(lower_bound, upper_bound) + preset
                except ValueError as err:
                    raise ValueError(
                        f"get_STRING_crop: empty start range [{lower_bound}, {upper_bound}) "
                        f"(preset={preset}, chain_len={chain_len}, crop_size={crop_size}, "
                        f"cropped_length={cropped_length})"
                    ) from err
                sel_s.append(sel[start:start + crop_size])
        preset += chain_len
    return torch.cat(sel_s)
[docs]
def cutoff_chain_num(sel, xyz, chain_break, params, query_chain_idx):
    """Limit a residue selection to at most ``MAX_CHAIN`` chains.

    Each chain is represented by the residues of ``sel`` falling within its
    ``chain_break`` range. If at most ``MAX_CHAIN`` (default 4) chains are present,
    ``sel`` is returned unchanged. Otherwise the result keeps:

    - the query chain;
    - the ``MAX_CHAIN - 1`` chains closest to it. Closeness is measured as the
      mean pairwise distance between atom-1 (CA) coordinates. Chains with fewer
      than ``CHAIN_LENGTH_CUTOFF`` (default 5) selected residues are not eligible.

    :param sel: 1-D tensor of selected residue indices.
    :param xyz: (L, n_atoms, 3) coordinates; atom 1 is used.
    :param chain_break: chain id -> inclusive residue range ``(start, end)``.
    :param params: optional ``MAX_CHAIN`` and ``CHAIN_LENGTH_CUTOFF``.
    :param query_chain_idx: chain id that is always kept; it must have at least
        one residue in ``sel`` when filtering is needed (KeyError otherwise).
    :return: ``sel`` itself, or a sorted 1-D tensor of the kept residue indices.
    """
    chain_to_residue_idxs = {}
    for chain_idx, (chain_start, chain_end) in chain_break.items():
        # Built on sel's device so torch.isin does not mix devices.
        chain_residue_idxs = torch.arange(chain_start, chain_end + 1, device=sel.device)
        in_chain = torch.isin(sel, chain_residue_idxs)
        if in_chain.any():
            chain_to_residue_idxs[chain_idx] = sel[in_chain]

    max_chain = params["MAX_CHAIN"] if "MAX_CHAIN" in params else 4
    length_cutoff = params["CHAIN_LENGTH_CUTOFF"] if "CHAIN_LENGTH_CUTOFF" in params else 5
    if len(chain_to_residue_idxs) <= max_chain:
        return sel

    query_chain_xyz = xyz[chain_to_residue_idxs[query_chain_idx], 1]  # (L_query, 3)
    mean_dist_to_query = {}
    for chain_idx, residue_idxs in chain_to_residue_idxs.items():
        if chain_idx == query_chain_idx or len(residue_idxs) < length_cutoff:
            continue
        chain_xyz = xyz[residue_idxs, 1]  # (L_chain, 3)
        mean_dist_to_query[chain_idx] = torch.cdist(query_chain_xyz, chain_xyz).mean()

    nearest_chains = sorted(mean_dist_to_query, key=mean_dist_to_query.get)[:max_chain - 1]
    kept_chains = [query_chain_idx] + nearest_chains
    filtered_sel = torch.cat([chain_to_residue_idxs[c] for c in kept_chains])
    return torch.sort(filtered_sel)[0]
# def get_spatial_crop(xyz, mask, chain_break, len_s, params, protein_ID, cutoff=10.0, eps=1e-6):
# # xyz : (Model_num, L, 14, 3) or (Model_num, L, 14, 3)
# # mask : (Model_num, L, 14) or None
# Leftover lines from the commented-out get_spatial_crop draft. They stay
# commented so they cannot execute: `xyz` is not defined in this scope.
# tot_len = xyz.shape[1]
# model_idx = np.random.randint(xyz.shape[0])
# device = xyz.device
# if mask is None :
# mask = torch.bool((xyz.shape[0], xyz.shape[1], xyz.shape[2]), device=device)
# if mask.dtype != torch.bool:
# mask = mask.bool()
# if "5mwe" in protein_ID : breakpoint()
# # model sampling
# xyz = xyz[model_idx] # (L, 14, 3)
# mask = mask[model_idx] # (L, 14)
# position_mask = mask[:, 1] if mask is not None else torch.ones((xyz.shape[0], xyz.shape[1]), device=xyz.device)
# valid_residue_idxs = torch.where(position_mask)[0]
# if len(valid_residue_idxs) <= params['CROP']:
# if "5mwe" in protein_ID : breakpoint()
# query_chain_idx = torch.randint(len(chain_break), (1,)).item()
# query_chain_idx = list(chain_break.keys())[query_chain_idx]
# sel = cutoff_chain_num(valid_residue_idxs, xyz, chain_break, params, query_chain_idx)
# if len(valid_residue_idxs) < 1:
# print("ERROR: no valid residue????", protein_ID)
# return sel, "PDB_Spatial"
# sel = torch.arange(tot_len, device=device)
# chain_number = len(len_s)
# chain_list = list(chain_break.keys())
# random_chain_min = params["RANDOM_CHAIN_MIN"] # 2
# random_chain_max = params["RANDOM_CHAIN_MAX"] # 4, Acutally these two values are not that important.
# random_chain_number = np.random.randint(random_chain_min, random_chain_max+1)
# random_chain_number = min(random_chain_number, chain_number)
# random_chain_idx = np.random.choice(np.arange(chain_number), random_chain_number, replace = False)
# random_chain_idx = [chain_list[i] for i in random_chain_idx]
# ifaces = list()
# for chain_idx in random_chain_idx:
# chain_start, chain_end = chain_break[chain_idx]
# in_chain = torch.arange(chain_start, chain_end+1)
# try :
# out_chain = torch.cat([torch.arange(0, chain_start), torch.arange(chain_end+1, tot_len)])
# except Exception as e:
# print(f"protein_ID : {protein_ID}")
# print(f"chain_break : {chain_break}")
# print(f"chain_start : {chain_start}")
# print(f"chain_end : {chain_end}")
# print(f"tot_len : {tot_len}")
# print(f"error : {e}")
# raise ValueError
# cond = torch.cdist(xyz[in_chain,1], xyz[out_chain,1]) < cutoff # (L_in, L_out)
# # import matplotlib.pyplot as plt
# # cond_path = f"cond_{protein_ID}_{chain_idx}.png"
# # mask_path = f"mask_{protein_ID}_{chain_idx}.png"
# # plt.imshow(cond.cpu().numpy())
# # plt.savefig(cond_path)
# # plt.imshow(mask.cpu().numpy())
# # plt.savefig(mask_path)
# cond = torch.logical_and(cond, mask[in_chain,None,1]*mask[None,out_chain,1]) # (L_in, L_out)
# i,_ = torch.where(cond) # (N_ifaces)
# chain_ifaces = i + chain_start # (N_ifaces)
# ifaces.append(chain_ifaces)
# ifaces = torch.cat(ifaces) # (N_ifaces)
# if len(ifaces) < 1:
# print ("ERROR: no iface residue????", protein_ID)
# print(f"random_chain_idx : {random_chain_idx}")
# print(f"label_xyz.shape : {xyz.shape}")
# print(f"label_mask.shape : {mask.shape}")
# chain_start, chain_end = chain_break[random_chain_idx[0]]
# mask = mask.unsqueeze(0)
# crop_idx = get_crop(chain_start, chain_end, mask, device, params, ID = None)
# # crop_idx = get_complex_crop(len_s, mask, device, params)
# if len(crop_idx) == 0:
# print(f"len_s : {len_s}")
# print(f"mask.shape : {mask.shape}")
# print("WTF")
# # return crop_idx, "PDB_Complex"
# return crop_idx, "PDB_Monomer"
# if "5mwe" in protein_ID : breakpoint()
# cnt_idx = ifaces[np.random.randint(len(ifaces))]
# # 20240130 PSK Add chain number condition
# cnt_idx = ifaces[np.random.randint(len(ifaces))]
# for chain_idx, (chain_start, chain_end) in chain_break.items():
# if chain_start <= cnt_idx <= chain_end:
# query_chain_idx = chain_idx
# break
# if "5mwe" in protein_ID : breakpoint()
# dist = torch.cdist(xyz[:,1], xyz[cnt_idx,1][None]).reshape(-1) + torch.arange(len(xyz), device=xyz.device)*eps
# cond = mask[:,1]*mask[cnt_idx,1]
# dist[~cond] = 999999.9
# _, idx = torch.topk(dist, params['CROP'], largest=False)
# if "5mwe" in protein_ID : breakpoint()
# #TODO. Check this function.
# sel, _ = torch.sort(sel[idx])
# if "5mwe" in protein_ID : breakpoint()
# sel = cutoff_chain_num(sel, xyz, chain_break, params, query_chain_idx)
# # Although dist of masked residues are assigned to 999999.9, there is a chance that the residue is selected.
# # So I added this code to remove the residue using valid_residue_idxs.
# # removed residues are filled by -1(crop mask idx).
# # sel = torch.where(torch.isin(sel, valid_residue_idxs), sel, 999999)
# # sel = torch.sort(sel)[0]
# # # 999999 -> -1
# # sel = torch.where(sel == 999999, -1, sel)
# if len(sel) == 0 :
# print(f"dist.shape : {dist.shape}")
# print(f"idx.shape : {idx.shape}")
# print(f"sel.shape : {sel.shape}")
# print(f"sel : {sel}")
# return sel, "PDB_Spatial"
[docs]
def get_spatial_crop(xyz, mask, pivot_chain_idx, chain_break, len_s, params, protein_ID, cutoff=10.0, eps=1e-6):
    """Crop residues spatially around an interface residue of the pivot chain.

    One model is sampled at random from ``xyz``. Residues of the pivot chain
    whose atom-1 position (used as CA) lies within ``cutoff`` of an atom-1
    position in any other chain are interface residues; one of them is drawn
    at random as the crop centre, the ``params['CROP']`` closest present
    residues are kept, and the result is post-processed by ``cutoff_chain_num``.

    Args:
        xyz: Coordinates, (Model_num, L, 14, 3) per the original comments.
        mask: Atom mask (Model_num, L, 14), or None to treat all atoms as present.
        pivot_chain_idx: Key into ``chain_break`` of the chain to crop around.
        chain_break: Mapping chain key -> (chain_start, chain_end), inclusive.
        len_s: Only printed as diagnostics on the fallback path.
        params: Dict; ``params['CROP']`` is the crop size.
        protein_ID: Identifier used in diagnostic messages.
        cutoff: Distance threshold defining interface residues.
        eps: Per-index increment added to distances so ties favour lower indices.

    Returns:
        ``(crop_idx, tag)``. ``tag`` is ``"PDB_Spatial"`` for spatial crops, or
        ``"PDB_Monomer"`` when the pivot chain has no interface residue and the
        crop falls back to ``get_crop`` over the pivot chain.

    Raises:
        ValueError: if the pivot chain range does not fit inside the structure.
    """
    tot_len = xyz.shape[1]
    model_idx = np.random.randint(xyz.shape[0])
    device = xyz.device
    if mask is None:
        # No mask supplied: every atom counts as present.
        mask = torch.ones(xyz.shape[:3], dtype=torch.bool, device=device)
    elif mask.dtype != torch.bool:
        mask = mask.bool()

    # Sample one model; from here on xyz is (L, atoms, 3) and mask is (L, atoms).
    xyz = xyz[model_idx]
    mask = mask[model_idx]
    ca_mask = mask[:, 1]  # atom index 1 is used as the CA position throughout

    valid_residue_idxs = torch.where(ca_mask)[0]
    if len(valid_residue_idxs) <= params['CROP']:
        # All present residues fit in the crop: no spatial selection needed.
        sel = cutoff_chain_num(valid_residue_idxs, xyz, chain_break, params, pivot_chain_idx)
        if len(valid_residue_idxs) < 1:
            print("ERROR: no valid residue????", protein_ID)
        return sel, "PDB_Spatial"

    chain_start, chain_end = chain_break[pivot_chain_idx]
    # Build index tensors on the same device as xyz/mask so later ops do not mix devices.
    in_chain = torch.arange(chain_start, chain_end + 1, device=device)
    try:
        out_chain = torch.cat([
            torch.arange(0, chain_start, device=device),
            torch.arange(chain_end + 1, tot_len, device=device),
        ])
    except Exception as e:
        # torch.arange rejects bounds that run backwards, i.e. the pivot chain
        # range does not fit inside [0, tot_len).
        raise ValueError(
            f"Invalid pivot chain range for {protein_ID}: chain {pivot_chain_idx!r} "
            f"spans ({chain_start}, {chain_end}) but tot_len={tot_len}; "
            f"chain_break={chain_break}"
        ) from e

    # Interface residues: pivot-chain residues with a present CA within `cutoff`
    # of a present CA in another chain. A residue is listed once per contact,
    # so residues with more contacts are more likely to be drawn as centre.
    contact = torch.cdist(xyz[in_chain, 1], xyz[out_chain, 1]) < cutoff  # (L_in, L_out)
    contact = contact & ca_mask[in_chain, None] & ca_mask[None, out_chain]
    rows, _ = torch.where(contact)
    ifaces = torch.sort(rows + chain_start)[0]

    if len(ifaces) < 1:
        print("ERROR: no iface residue????", protein_ID)
        print(f"random_chain_idx : {pivot_chain_idx}")
        print(f"label_xyz.shape : {xyz.shape}")
        print(f"label_mask.shape : {mask.shape}")
        # Fall back to a crop of the pivot chain, reusing its range looked up above.
        mask = mask.unsqueeze(0)
        crop_idx = get_crop(chain_start, chain_end, mask, device, params, ID=None)
        if len(crop_idx) == 0:
            print(f"len_s : {len_s}")
            print(f"mask.shape : {mask.shape}")
            print("WTF")
        return crop_idx, "PDB_Monomer"

    # 20240130 PSK Add chain number condition
    cnt_idx = ifaces[np.random.randint(len(ifaces))]
    # Distance of every CA to the centre; the eps * index term breaks ties toward lower indices.
    dist = torch.cdist(xyz[:, 1], xyz[cnt_idx, 1][None]).reshape(-1) + torch.arange(tot_len, device=device) * eps
    # Residues without a CA are pushed far away so topk prefers present ones.
    dist[~(ca_mask & ca_mask[cnt_idx])] = 999999.9
    _, idx = torch.topk(dist, params['CROP'], largest=False)
    sel, _ = torch.sort(idx)
    # TODO: check cutoff_chain_num.
    sel = cutoff_chain_num(sel, xyz, chain_break, params, pivot_chain_idx)
    if len(sel) == 0:
        print(f"dist.shape : {dist.shape}")
        print(f"idx.shape : {idx.shape}")
        print(f"sel.shape : {sel.shape}")
        print(f"sel : {sel}")
    return sel, "PDB_Spatial"
[docs]
def find_chain_combinations(pairs):
    """Find sets of contacting chain copies that together cover every chain ID.

    Nodes are labels of the form ``"<chain_id>|<copy>"``; each ``(a, b)`` pair
    is an undirected contact edge. A depth-first search over simple paths
    records a path as soon as the chain IDs along it cover every chain ID that
    appears in ``pairs``.

    NOTE: only simple paths are explored, so a covering set that is connected
    only through a branching (non-path) contact pattern is not reported.

    Args:
        pairs: Iterable of ``(label_a, label_b)`` contact pairs.

    Returns:
        Sorted list of unique tuples. Each tuple is ordered by
        ``(chain_id, label)``, so visiting the same nodes in different orders
        yields the same tuple.
    """
    graph = {}
    for a, b in pairs:
        graph.setdefault(a, []).append(b)
        graph.setdefault(b, []).append(a)

    def extract_chain_id(chain_point):
        return chain_point.split('|')[0]

    chain_ids = {extract_chain_id(node) for node in graph}
    found = set()

    def dfs(node, path):
        if {extract_chain_id(p) for p in path} == chain_ids:
            # Sort by chain ID, then full label, so equal node sets give equal tuples.
            found.add(tuple(sorted(path, key=lambda x: (extract_chain_id(x), x))))
            return
        for neighbor in graph[node]:
            if neighbor not in path:  # chain types may repeat, exact nodes may not
                path.append(neighbor)
                dfs(neighbor, path)
                path.pop()

    for start_node in graph:
        dfs(start_node, [start_node])
    # Sorting gives a deterministic order independent of string hash seeds.
    return sorted(found)
[docs]
def get_same_crop_idx(xyz_full, crop_idx, chain_break, same_chain_info, cutoff = 10.0):
    """Enumerate alternative crops built by substituting equivalent chains.

    Each chain present in ``crop_idx`` is expanded into one candidate segment
    per chain marked equivalent to it in ``same_chain_info['same_chain']``.
    Segments are keyed ``"<chain_id>|<ii>"``. Two segments are in contact when
    any pair of their atom-1 (CA) positions in model 0 of ``xyz_full`` is
    closer than ``cutoff``. ``find_chain_combinations`` then selects contacting
    segment sets that cover every chain ID, and each set is concatenated into
    one alternative crop.

    Args:
        xyz_full: Coordinates (Model_num, L, 14, 3); only model 0 is used.
        crop_idx: Residue indices of the current crop.
        chain_break: {chain_idx: (chain_start, chain_end)} for the full structure.
        same_chain_info: {'chains': [chain_idx, ...], 'same_chain': N x N tensor
            indexed by position in 'chains'; entries equal to 1 mark equivalence}.
        cutoff: CA-CA distance threshold for segment contact.

    Returns:
        A list of index tensors. ``[crop_idx]`` is returned unchanged in any of
        these cases: a cropped chain's length differs from the full length of
        an equivalent chain, no segments are produced, or no covering
        combination exists. If only one chain is cropped, every equivalent
        segment is returned as its own alternative.
    """
    cropped_chain_break, chain_to_crop_idx = chain_break_cropping(chain_break, crop_idx)
    chain_list, same_chain = same_chain_info['chains'], same_chain_info['same_chain']

    segments = {}  # "<chain_id>|<ii>" -> residue indices of the ii-th equivalent copy
    for chain_idx, (chain_start, chain_end) in cropped_chain_break.items():
        row = same_chain[chain_list.index(chain_idx)]
        # `same_chain` is indexed by position in `chain_list`; map positions back to chain keys.
        same_chain_ids = [chain_list[k] for k in torch.where(row == 1)[0].tolist()]
        cropped_length = chain_end - chain_start + 1
        for same_id in same_chain_ids:
            same_start, same_end = chain_break[same_id]
            if cropped_length != same_end - same_start + 1:
                return [crop_idx]
        # NOTE(review): subtracting the *cropped* chain_start assumes
        # chain_to_crop_idx holds positions in cropped coordinates; the length
        # check above means the whole chain is cropped — confirm against
        # chain_break_cropping.
        offsets = chain_to_crop_idx[chain_idx] - chain_start
        for ii, same_id in enumerate(same_chain_ids):
            segments[f"{chain_idx}|{ii}"] = offsets + chain_break[same_id][0]

    if not segments:
        return [crop_idx]
    if len(cropped_chain_break) == 1:
        # A single cropped chain: each equivalent copy is an alternative on its own.
        return list(segments.values())

    keys = list(segments)
    seg_by_key = {k: torch.as_tensor(segments[k]) for k in keys}
    seg_list = [seg_by_key[k] for k in keys]
    total_idx = torch.cat(seg_list)
    cropped_ca = xyz_full[0, total_idx, 1, :]  # use first model
    contact_map = (torch.cdist(cropped_ca, cropped_ca) < cutoff).float()  # (L_tot, L_tot)

    # membership[s, p] = 1 iff position p of total_idx belongs to segment s.
    membership = torch.zeros((len(keys), total_idx.shape[0]), device=contact_map.device)
    offset = 0
    for s, seg in enumerate(seg_list):
        membership[s, offset:offset + seg.shape[0]] = 1.0
        offset += seg.shape[0]
    # Entry (i, j) counts contacts between segments i and j.
    seg_contact = ((membership @ contact_map @ membership.T) > 0).tolist()

    contact_pair = [
        (keys[i], keys[j])
        for i in range(len(keys))
        for j in range(i + 1, len(keys))
        if seg_contact[i][j]
    ]
    # NOTE(review): segments from different chains may cover the same residues
    # (e.g. chain A's copy of B and B itself); combinations are not deduplicated
    # at the residue level.
    chain_combinations = find_chain_combinations(contact_pair)
    crop_idx_list = [
        torch.cat([seg_by_key[chain_key] for chain_key in combination])
        for combination in chain_combinations
    ]
    return crop_idx_list if crop_idx_list else [crop_idx]
[docs]
def generate_combinations(lst, pairs):
    """Return every ordering of ``lst`` reachable by repeatedly swapping allowed pairs.

    Starting from ``lst``, two elements listed together in ``pairs`` may be
    swapped. The closure under these swaps is explored breadth-first. Each
    ordering is marked as seen when it is enqueued, so it is expanded once.

    Args:
        lst: Initial ordering (a sequence; it is not modified).
        pairs: Iterable of ``(a, b)`` element pairs that may be swapped. Pairs
            whose elements are not both in the ordering are ignored.

    Returns:
        List of orderings (lists). The first is a copy of ``lst``; the rest
        follow in breadth-first discovery order.
    """
    pairs = list(pairs)
    start = tuple(lst)
    seen = {start}
    discovered = [start]
    queue = deque([start])
    while queue:
        current = queue.popleft()
        for a, b in pairs:
            if a in current and b in current:
                idx_a, idx_b = current.index(a), current.index(b)
                swapped = list(current)
                swapped[idx_a], swapped[idx_b] = swapped[idx_b], swapped[idx_a]
                swapped = tuple(swapped)
                if swapped not in seen:
                    seen.add(swapped)
                    discovered.append(swapped)
                    queue.append(swapped)
    # Convert each ordering from a tuple back to a list.
    return [list(combination) for combination in discovered]
[docs]
def permute_label(protein_list, crop_idx, out_of_sequence_idxs, chain_break, same_chain_info):
    """Build one label stacking every allowed chain permutation of the crop.

    Two cropped chains may be swapped only when all three conditions hold:
    ``same_chain_info['same_chain']`` marks them as equivalent, their cropped
    index arrays have equal length, and the first row of the first protein's
    sequence is identical at those indices. For every protein in
    ``protein_list`` and every reachable chain ordering, the cropped sequence,
    coordinates and atom mask are concatenated chain by chain. All results
    are stacked along dim 0.

    Args:
        protein_list: Proteins whose ``sequence.sequence`` (Model, L_total),
            ``structure.xyz`` (Model, L_total, 14, 3) and
            ``structure.atom_mask`` (Model, L_total, 14) are read.
        crop_idx: Residue indices of the crop, passed to ``chain_break_cropping``.
        out_of_sequence_idxs: Residue indices whose atom mask is zeroed in the output.
        chain_break: Mapping chain key -> (chain_start, chain_end).
        same_chain_info: Dict with 'chains' (list of chain keys) and
            'same_chain' (matrix indexed by position in 'chains').

    Returns:
        A ``Protein`` holding the stacked sequence, xyz and atom mask, with
        ``cropped_chain_break`` as its chain break. The input proteins are
        not modified.
    """
    cropped_chain_break, chain_to_crop_idx = chain_break_cropping(chain_break, crop_idx)
    query_seq = protein_list[0].sequence.sequence[0]
    chain_list = same_chain_info['chains']
    same_chain = same_chain_info['same_chain']
    cropped_chain_list = list(cropped_chain_break.keys())

    # Collect swappable (i < j) pairs of cropped chains.
    permutable_pair_list = []
    for i, chain_i in enumerate(cropped_chain_list):
        for chain_j in cropped_chain_list[i + 1:]:
            # same_chain is indexed by position in the full chain list, not the cropped one.
            if not same_chain[chain_list.index(chain_i), chain_list.index(chain_j)]:
                continue
            crop_idx_i = chain_to_crop_idx[chain_i]
            crop_idx_j = chain_to_crop_idx[chain_j]
            if crop_idx_i.shape[0] != crop_idx_j.shape[0]:
                continue
            if torch.all(query_seq[crop_idx_i] == query_seq[crop_idx_j]).item():
                permutable_pair_list.append((chain_i, chain_j))

    if permutable_pair_list:
        permuted_list = generate_combinations(cropped_chain_list, permutable_pair_list)
    else:
        permuted_list = [cropped_chain_list]

    new_label_seq = []
    new_label_xyz = []
    new_label_atom_mask = []
    for protein in protein_list:
        label_sequence = protein.sequence.sequence  # (Model, L_total)
        label_xyz = protein.structure.xyz  # (Model, L_total, 14, 3)
        # Clone so that masking out-of-sequence residues does not modify the caller's protein.
        label_atom_mask = protein.structure.atom_mask.clone()  # (Model, L_total, 14)
        label_atom_mask[:, out_of_sequence_idxs, :] = 0

        chain_to_seq = {c: label_sequence[:, chain_to_crop_idx[c]] for c in cropped_chain_list}
        chain_to_xyz = {c: label_xyz[:, chain_to_crop_idx[c], :, :] for c in cropped_chain_list}
        chain_to_atom_mask = {c: label_atom_mask[:, chain_to_crop_idx[c], :] for c in cropped_chain_list}

        for permuted_chain_list in permuted_list:
            new_label_seq.append(torch.cat([chain_to_seq[c] for c in permuted_chain_list], dim=1))
            new_label_xyz.append(torch.cat([chain_to_xyz[c] for c in permuted_chain_list], dim=1))
            new_label_atom_mask.append(torch.cat([chain_to_atom_mask[c] for c in permuted_chain_list], dim=1))

    new_label = Protein(
        sequence = ProteinSequence(sequence = torch.cat(new_label_seq, dim=0)),
        structure = ProteinStructure(
            xyz = torch.cat(new_label_xyz, dim=0),
            atom_mask = torch.cat(new_label_atom_mask, dim=0),
            chain_break = cropped_chain_break,
        ),
    )
    return new_label
if __name__ == "__main__":
    # Debug entry point: replay a permute_label call from inputs saved to disk.
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
    # Newer torch releases may default to weights_only=True, which could refuse
    # the pickled Protein objects — confirm against the installed torch version.
    ID = "6vw2_F"
    saved_inputs = torch.load(f"permute_label_input_{ID}.pt")
    protein_list, crop_idx, out_of_sequence_idxs, chain_break, same_chain_info = (
        saved_inputs[key]
        for key in (
            'protein_list',
            'crop_idx',
            'out_of_sequence_idxs',
            'chain_break',
            'same_chain_info',
        )
    )
    new_label_list = permute_label(
        protein_list,
        crop_idx,
        out_of_sequence_idxs,
        chain_break,
        same_chain_info,
    )