Source code for miniworld.feature.MiniWorld_featuring_species

from miniworld.utils.ProteinClass import *
import time
import random
import sys
from numbers import Number
from collections import deque, OrderedDict
from collections.abc import Set, Mapping
import tracemalloc
import linecache

import torch
import numpy as np

from miniworld.utils.chemical import INIT_CRDS

UNK_IDX = 20
GAP_IDX = 21
MASK_IDX = 22
NUM_OF_CLASS = 23

def display_top(snapshot, key_type='lineno', limit=10):
    snapshot = snapshot.filter_traces((
        tracemalloc.Filter(False, "<frozen importlib._bootstrap>"),
        tracemalloc.Filter(False, "<unknown>"),
    ))
    top_stats = snapshot.statistics(key_type)

    print("Top %s lines" % limit)
    for index, stat in enumerate(top_stats[:limit], 1):
        frame = stat.traceback[0]
        print("#%s: %s:%s: %.1f KiB"
              % (index, frame.filename, frame.lineno, stat.size / 1024))
        line = linecache.getline(frame.filename, frame.lineno).strip()
        if line:
            print('    %s' % line)

    other = top_stats[limit:]
    if other:
        size = sum(stat.size for stat in other)
        print("%s other: %.1f KiB" % (len(other), size / 1024))
    total = sum(stat.size for stat in top_stats)
    print("Total allocated size: %.1f KiB" % (total / 1024))
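
# Usage sketch for display_top (illustrative only; the allocation workload
# below is an arbitrary stand-in, not part of the training pipeline):
def _example_display_top():
    tracemalloc.start()
    _ = [bytes(1000) for _ in range(1000)]  # some allocations to measure
    snapshot = tracemalloc.take_snapshot()
    tracemalloc.stop()
    display_top(snapshot, limit=5)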
def cluster_sum(data, assignment, N_seq, N_res):
    # Get statistics from clustering results (clustering extra sequences with seed sequences)
    csum = torch.zeros(
        (N_seq, N_res, data.shape[-1]), device=data.device
    ).scatter_add(
        0, assignment.view(-1, 1, 1).expand(-1, N_res, data.shape[-1]), data.float()
    )
    return csum
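
# Usage sketch for cluster_sum (toy shapes chosen only for illustration):
# each row of `data` is summed into the seed sequence given by `assignment`.
def _example_cluster_sum():
    data = torch.ones(5, 3, 2)                  # 5 sequences, N_res=3, feature dim 2
    assignment = torch.tensor([0, 1, 0, 1, 0])  # sequence -> seed index
    csum = cluster_sum(data, assignment, N_seq=2, N_res=3)
    assert torch.allclose(csum[0], torch.full((3, 2), 3.0))  # rows 0, 2, 4
    assert torch.allclose(csum[1], torch.full((3, 2), 2.0))  # rows 1, 3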
def chain_break_cropping(chain_break, crop_idx):
    crop_idx = crop_idx.tolist()
    crop_idx_set = set(crop_idx)
    new_chain_break = OrderedDict()
    chain_to_crop_idx = OrderedDict()
    for chain_idx, (chain_start, chain_end) in chain_break.items():
        chain_idx_list = list(range(chain_start, chain_end + 1))
        intersection = crop_idx_set.intersection(chain_idx_list)
        intersection = list(intersection)
        if len(intersection) == 0:
            continue
        chain_start = crop_idx.index(min(intersection))
        chain_end = crop_idx.index(max(intersection))
        new_chain_break[chain_idx] = (chain_start, chain_end)
        chain_to_crop_idx[chain_idx] = torch.sort(torch.tensor(intersection))[0]
    return new_chain_break, chain_to_crop_idx
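
# Usage sketch for chain_break_cropping (toy indices; two assumed 10-residue
# chains): chain boundaries are re-expressed in crop-local coordinates.
def _example_chain_break_cropping():
    chain_break = OrderedDict([("A", (0, 9)), ("B", (10, 19))])
    crop_idx = torch.tensor([5, 6, 7, 12, 13])
    new_break, chain_to_crop = chain_break_cropping(chain_break, crop_idx)
    assert new_break == OrderedDict([("A", (0, 2)), ("B", (3, 4))])
    assert chain_to_crop["A"].tolist() == [5, 6, 7]  # original residue indices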
def MSA_block_deletion(msa, insertion, nb=5):
    '''
    Down-sample a given MSA by randomly deleting blocks of sequences
    Input: MSA/Insertion having shape (N, L)
    Output: new MSA/Insertion with block deletion (N', L)
    '''
    N, L = msa.shape
    block_size = max(int(N * 0.1), 1)
    block_start = np.random.randint(low=1, high=N, size=nb)  # (nb)
    to_delete = block_start[:, None] + np.arange(block_size)[None, :]
    to_delete = np.unique(np.clip(to_delete, 1, N - 1))
    # mask = np.ones(N, np.bool)
    mask = np.ones(N, np.bool_)  # for numpy 1.24.1
    mask[to_delete] = 0

    return msa[mask], insertion[mask]
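
# Usage sketch for MSA_block_deletion (random toy MSA; real inputs come from
# the data pipeline): the query row 0 is always kept.
def _example_msa_block_deletion():
    msa = torch.randint(0, 20, (100, 64))  # N=100 sequences, L=64 residues
    ins = torch.zeros_like(msa)
    msa_del, ins_del = MSA_block_deletion(msa, ins, nb=5)
    assert msa_del.shape[0] <= 100 and msa_del.shape[1] == 64
    assert torch.equal(msa_del[0], msa[0])  # query survives block deletion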
@torch.no_grad()
def MSA_featurize_wo_statistics_by_chain(msa, insertion, N_clust, params):
    '''
    I modified the RF2 version. (just changed the names of variables)
    Input: full MSA information (after block deletion if necessary) & full insertion information
    Output: seed MSA features & extra sequences

    msa : (N, L) torch.LongTensor
    ins : (N, L) torch.LongTensor
    params : dict of parameters
        p_mask : probability of masking
        eps : small number to avoid zero division
    chain_break : dictionary of chain idx {chain_id: (start, end)}

    Seed MSA features:
        - aatype of seed sequence (20 regular aa + 1 gap + 1 unknown + 1 mask)
        - profile of clustered sequences (23) => removed
        - insertion statistics (2) => removed statistics, only use insertion_clust
        - N-term or C-term? (2)
    Extra sequence features:
        - aatype of extra sequence (23)
        - insertion info (1)
        - N-term or C-term? (2)
    '''
    NUM_OF_CLASS = 23
    N, _ = msa.shape
    if N == 0:
        print(f"msa.shape: {msa.shape}")
        print(f"insertion.shape: {insertion.shape}")
        raise ValueError("Error in MSA_featurize_wo_statistics_by_chain")

    p_mask = params['PROB_MASK'] if 'PROB_MASK' in params else 0.15
    eps = params['EPS'] if 'EPS' in params else 1e-6

    # TODO Question: Why should I add term_info?
    # ----------------------------------------
    # I removed the term_info-related code, and it seems to be fine.
    # term_info = torch.zeros((L,2), device=msa.device).float()
    # chain_break = chain_break_cropping(chain_break, crop_idx)
    # for chain_id, (start, end) in chain_break.items():
    #     term_info[start,0] = 1.0
    #     term_info[end,1] = 1.0
    # ----------------------------------------
    # TODO 20231001 PSK
    # Cropping.
    # if crop_idx is not None:
    #     msa = msa[:, crop_idx]              # (N, L_crop)
    #     insertion = insertion[:, crop_idx]  # (N, L_crop)
    L = msa.shape[1]

    # raw MSA profile
    msa = msa.long()
    try:
        full_raw_profile = torch.nn.functional.one_hot(msa, num_classes=NUM_OF_CLASS - 1)  # (N, L, 23 - 1), without MASK
    except:
        print(f"msa.shape: {msa.shape} | msa dtype: {msa.dtype} | msa device: {msa.device}")
        raise ValueError("Error in MSA_featurize_wo_statistics")
    raw_profile = full_raw_profile.float().mean(dim=0)  # (L, 23 - 1)

    # Nclust sequences will be selected randomly as a seed MSA (aka latent MSA)
    # - the first sequence is always the query sequence
    # - the rest of the sequences are selected randomly
    USE_MASK = params['USE_MASK']
    for i_cycle in range(params['MAX_MSA_CYCLE']):
        sample = torch.randperm(N - 1, device=msa.device)  # (N-1)
        # 20231218 PSK
        msa_species_idx = torch.cat((torch.tensor([0]), sample[:N_clust - 1]))  # (Nclust)
        msa_clust = torch.cat((msa[:1, :], msa[1:, :][sample[:N_clust - 1]]), dim=0)  # (Nclust, L)
        insertion_clust = torch.cat((insertion[:1, :], insertion[1:, :][sample[:N_clust - 1]]), dim=0)  # (Nclust, L)

        if USE_MASK:
            # 15% random masking
            # - 10%: aa replaced with a uniformly sampled random amino acid
            # - 10%: aa replaced with an amino acid sampled from the MSA profile
            # - 10%: not replaced
            # - 70%: replaced with a special token ("mask")
            random_aa = torch.tensor([[0.05] * 20 + [0.0] * 2], device=msa.device)  # (1, 22)
            same_aa = torch.nn.functional.one_hot(msa_clust, num_classes=NUM_OF_CLASS - 1)  # (Nclust, L, 22)
            probs = 0.1 * random_aa + 0.1 * raw_profile + 0.1 * same_aa  # (Nclust, L, 22)
            probs = torch.nn.functional.pad(probs, (0, 1), "constant", 0.7)  # (Nclust, L, 23)

            sampler = torch.distributions.categorical.Categorical(probs=probs)
            mask_sample = sampler.sample()
            mask_pos = torch.rand(msa_clust.shape, device=msa_clust.device) < p_mask  # (Nclust, L)
            # 20231115 PSK : don't mask at the UNK positions
            UNK_pos = (msa_clust == UNK_IDX)  # (Nclust, L)
            mask_pos = mask_pos & ~UNK_pos
            msa_masked = torch.where(mask_pos, mask_sample, msa_clust)  # (Nclust, L)
        else:
            mask_pos = torch.zeros_like(msa_clust, dtype=torch.bool, device=msa_clust.device)
            msa_masked = msa_clust.clone()

        # Tensor shape summary
        #   msa_clust:        (N_clust, L)
        #   mask_pos:         (N_clust, L)
        #   msa_masked:       (N_clust, L)
        #   insertion_clust:  (N_clust, L)
        #   msa_clust_onehot: (N_clust, L, 23)
        #   term_info:        (L, 2) => handled separately

        # 1. one-hot encoded aatype: msa_clust_onehot
        msa_clust_onehot = torch.nn.functional.one_hot(msa_masked, num_classes=NUM_OF_CLASS)  # (N_clust, L, 23)
        insertion_clust = (2.0 / np.pi) * torch.arctan(insertion_clust.float() / 3.0)  # (from 0 to 1)
        insertion_clust = insertion_clust.unsqueeze(-1)  # (N_clust, L, 1)

        # seed MSA features (one-hot aa, insertion)
        # msa_seed = torch.cat((msa_clust_onehot, insertion_clust, term_info[None].expand(N_clust,-1,-1)), dim=-1)  # TODO term_info??
        msa_seed = torch.cat((msa_clust_onehot, insertion_clust), dim=-1)  # (N_clust, L, 23 + 1)

        if i_cycle == 0:
            sample_seq = torch.zeros(params['MAX_MSA_CYCLE'], L)                                  # (MAXCYCLE, L)
            sample_msa_species = torch.zeros(params['MAX_MSA_CYCLE'], N_clust)                    # (MAXCYCLE, Nclust)
            sample_msa_clust = torch.zeros(params['MAX_MSA_CYCLE'], N_clust, L)                   # (MAXCYCLE, Nclust, L)
            sample_msa_seed = torch.zeros(params['MAX_MSA_CYCLE'], N_clust, L, NUM_OF_CLASS + 1)  # (MAXCYCLE, Nclust, L, 23 + 1)
            sample_mask_pos = torch.zeros(params['MAX_MSA_CYCLE'], N_clust, L)                    # (MAXCYCLE, Nclust, L)
        sample_seq[i_cycle] = msa[0]
        sample_msa_species[i_cycle] = msa_species_idx
        sample_msa_clust[i_cycle] = msa_clust
        sample_msa_seed[i_cycle] = msa_seed
        sample_mask_pos[i_cycle] = mask_pos

    output_dict = {
        'sample_seq': sample_seq,                  # (MAXCYCLE, L_crop)
        'sample_msa_species': sample_msa_species,  # (MAXCYCLE, Nclust)
        'sample_msa_clust': sample_msa_clust,      # (MAXCYCLE, Nclust, L_crop)
        'sample_msa_seed': sample_msa_seed,        # (MAXCYCLE, Nclust, L_crop, 23 + 1)
        'sample_mask_pos': sample_mask_pos         # (MAXCYCLE, Nclust, L_crop)
    }
    return output_dict
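
# Usage sketch for MSA_featurize_wo_statistics_by_chain. The params dict below
# is a minimal assumption (only the keys this function reads); the real
# training config supplies these values.
def _example_msa_featurize_by_chain():
    params = {"MAX_MSA_CYCLE": 2, "USE_MASK": True, "PROB_MASK": 0.15}
    msa = torch.randint(0, 20, (16, 32))  # N=16 sequences, L=32 residues
    insertion = torch.zeros_like(msa)
    out = MSA_featurize_wo_statistics_by_chain(msa, insertion, 8, params)
    assert out["sample_msa_seed"].shape == (2, 8, 32, 24)  # (MAXCYCLE, Nclust, L, 23+1)
    assert out["sample_mask_pos"].shape == (2, 8, 32)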
@torch.no_grad()
def MSA_featurize_wo_statistics(msa, insertion, chain_to_idx_dict, params):
    """I modified the RF2 version. (just changed the names of variables)

    Input: full MSA information (after block deletion if necessary) & full insertion information
    Output: seed MSA features & extra sequences.

    :param msa: Full MSA tensor.
    :type msa: torch.LongTensor
    :param insertion: Full insertion tensor.
    :type insertion: torch.LongTensor
    :param chain_to_idx_dict: Dictionary mapping chain ID to residue indices.
    :type chain_to_idx_dict: dict
    :param params: Dictionary of parameters.
    :type params: dict

    **Seed MSA features:**

    - aatype of seed sequence (20 regular aa + 1 gap + 1 unknown + 1 mask)
    - profile of clustered sequences (23) => removed
    - insertion statistics (2) => removed statistics, only use insertion_clust
    - N-term or C-term? (2)

    **Extra sequence features:**

    - aatype of extra sequence (23)
    - insertion info (1)
    - N-term or C-term? (2)
    """
    # Remove empty sequences
    mask = (msa != UNK_IDX).any(dim=-1)  # (N)
    msa = msa[mask]                      # (N', L_crop)
    insertion = insertion[mask]          # (N', L_crop)

    depth_per_chain = {}
    for chain_idx, (start, end) in chain_to_idx_dict.items():
        msa_chain = msa[:, start:end + 1]          # (N, L_chain)
        mask = (msa_chain != UNK_IDX).any(dim=-1)  # (N)
        depth = mask.sum().item()
        depth_per_chain[chain_idx] = depth

    remaining_depth = params["MAX_SEED_MSA"]
    enough_msa_chain_num = 0
    for chain_idx, depth in depth_per_chain.items():
        if depth < params["MIN_SEED_MSA_PER_CHAIN"]:
            remaining_depth -= depth
        else:
            enough_msa_chain_num += 1
    if remaining_depth < 0:
        print(f"depth_per_chain : {depth_per_chain}")
        print(f"remaining_depth : {remaining_depth}")
        raise ValueError("Error in MSA_featurize_wo_statistics")
    max_chain_depth = remaining_depth // enough_msa_chain_num if enough_msa_chain_num > 0 else remaining_depth
    depth_per_chain = {k: min(v, max_chain_depth) for k, v in depth_per_chain.items()}

    output_dict = {
        'sample_seq': torch.full((params['MAX_MSA_CYCLE'], msa.shape[1]), fill_value=UNK_IDX),           # (MAXCYCLE, L_crop)
        'sample_msa_species': torch.zeros((params['MAX_MSA_CYCLE'], 1)),                                 # (MAXCYCLE, Nclust)
        'sample_msa_clust': torch.full((params['MAX_MSA_CYCLE'], 1, msa.shape[1]), fill_value=UNK_IDX),  # (MAXCYCLE, Nclust, L_crop)
        'sample_msa_seed': torch.zeros((params['MAX_MSA_CYCLE'], 1, msa.shape[1], 24)),                  # (MAXCYCLE, Nclust, L_crop, 23 + 1)
        'sample_mask_pos': torch.zeros((params['MAX_MSA_CYCLE'], 1, msa.shape[1]))                       # (MAXCYCLE, Nclust, L_crop)
    }
    chain_to_msa_depth = {}
    for chain_idx, (start, end) in chain_to_idx_dict.items():
        depth = depth_per_chain[chain_idx]
        if depth == 0:
            continue
        chain_to_msa_depth[chain_idx] = depth
        msa_chain = msa[:, start:end + 1]          # (N, L_chain)
        mask = (msa_chain != UNK_IDX).any(dim=-1)  # (N)
        msa_chain = msa[mask]                      # (N', L_crop)
        insertion_chain = insertion[mask]          # (N', L_crop)
        chain_dict = MSA_featurize_wo_statistics_by_chain(msa_chain, insertion_chain, depth, params)
        for key, value in chain_dict.items():
            if key == "sample_seq":
                output_dict[key][:, start:end + 1] = value[:, start:end + 1]
            elif key == "sample_msa_species":
                output_dict[key] = torch.cat((output_dict[key], value[:, 1:]), dim=1)
            else:
                output_dict[key][:, 0, start:end + 1] = value[:, 0, start:end + 1]
                output_dict[key] = torch.cat((output_dict[key], value[:, 1:]), dim=1)
    output_dict["chain_to_msa_depth"] = chain_to_msa_depth
    return output_dict
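
# Usage sketch for MSA_featurize_wo_statistics with two assumed 5-residue
# chains; the params keys are the minimal set this function and its per-chain
# helper read, not the full training config.
def _example_msa_featurize():
    params = {"MAX_MSA_CYCLE": 2, "USE_MASK": False,
              "MAX_SEED_MSA": 16, "MIN_SEED_MSA_PER_CHAIN": 2}
    msa = torch.randint(0, 20, (8, 10))  # depth 8, L=10
    insertion = torch.zeros_like(msa)
    chain_to_idx = {"A": (0, 4), "B": (5, 9)}
    out = MSA_featurize_wo_statistics(msa, insertion, chain_to_idx, params)
    # Row 0 of the merged seed MSA is the query; per-chain seeds are appended.
    print(out["sample_msa_clust"].shape, out["chain_to_msa_depth"])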
ZERO_DEPTH_BASES = (str, bytes, Number, range, bytearray)
def getsize(obj_0):
    """Recursively iterate to sum size of object & members."""
    _seen_ids = set()

    def inner(obj):
        if isinstance(obj, torch.Tensor):
            return obj.element_size() * obj.nelement()
        obj_id = id(obj)
        if obj_id in _seen_ids:
            return 0
        _seen_ids.add(obj_id)
        size = sys.getsizeof(obj)
        if isinstance(obj, ZERO_DEPTH_BASES):
            pass  # bypass remaining control flow and return
        elif isinstance(obj, (tuple, list, Set, deque)):
            size += sum(inner(i) for i in obj)
        elif isinstance(obj, Mapping) or hasattr(obj, 'items'):
            size += sum(inner(k) + inner(v) for k, v in getattr(obj, 'items')())
        # Check for custom object instances - may subclass the above too
        if hasattr(obj, '__dict__'):
            size += inner(vars(obj))
        if hasattr(obj, '__slots__'):  # can have __slots__ with __dict__
            size += sum(inner(getattr(obj, s)) for s in obj.__slots__ if hasattr(obj, s))
        return size

    return inner(obj_0)
def center_and_realign_missing(xyz, mask_t):
    # I didn't modify this function.
    # xyz    : (L, 27, 3)
    # mask_t : (L, 27)
    L = xyz.shape[0]

    mask = mask_t[:, :3].all(dim=-1)  # True for residues with a complete backbone (L)

    # center the CA centroid at the origin
    center_CA = (mask[..., None] * xyz[:, 1]).sum(dim=0) / (mask[..., None].sum(dim=0) + 1e-5)  # (3)
    xyz = torch.where(mask.view(L, 1, 1), xyz - center_CA.view(1, 1, 3), xyz)

    # move missing residues to the closest valid residues
    exist_in_xyz = torch.where(mask)[0]  # (L_sub)
    seqmap = (torch.arange(L, device=xyz.device)[:, None] - exist_in_xyz[None, :]).abs()  # (L, L_sub)
    seqmap = torch.argmin(seqmap, dim=-1)  # (L)
    idx = torch.gather(exist_in_xyz, 0, seqmap)
    offset_CA = torch.gather(xyz[:, 1], 0, idx.reshape(L, 1).expand(-1, 3))
    xyz = torch.where(mask.view(L, 1, 1), xyz, xyz + offset_CA.reshape(L, 1, 3))

    return xyz
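
# Usage sketch for center_and_realign_missing (random coordinates; residues
# 3-4 marked missing as an assumption for the example): valid residues are
# centered on their CA centroid, missing ones shifted to the nearest resolved
# neighbor.
def _example_center_and_realign():
    xyz = torch.randn(10, 27, 3)
    mask_t = torch.ones(10, 27, dtype=torch.bool)
    mask_t[3:5] = False  # residues 3-4 have no resolved atoms
    out = center_and_realign_missing(xyz, mask_t)
    assert out.shape == (10, 27, 3)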
def template_featurize(input_template_dict, params):
    """I modified the RF2 version.

    In MSA_featurize, I changed the names of variables and a small part of the
    code because the shapes of the inputs (msa, insertion) are almost the same
    as in RF2. On the other hand, I totally reconstructed the template
    structure, so I changed a lot in this function.

    .. note::
        Processes template information for a single chain.

    :param input_template_dict: A dictionary containing template information.
        It should have the following keys:

        - 'xyz': torch.Tensor of shape (N_template, L_chain, 27, 3)
        - 'mask': torch.Tensor of shape (N_template, L_chain, 27)
        - 'sequence': torch.Tensor of shape (N_template, L_chain, NUM_CLASSES)
        - 'f0d': torch.Tensor of shape (N_template)
        - 'f1d': torch.Tensor of shape (N_template, L_chain)
    :type input_template_dict: dict
    :param params: Dictionary of parameters.
    :type params: dict
    :return: A dictionary with processed template features:

        - 'xyz': torch.Tensor of shape (npick_global, L_query, 27, 3)
        - 'template_1D': torch.Tensor of shape (npick_global, L_query, 23 + 1)
        - 'template_atom_mask': torch.Tensor of shape (npick_global, L_query, 27)
    :rtype: dict
    """
    npick = params["N_PICK"] if "N_PICK" in params else 1
    npick_global = params["N_PICK_GLOBAL"] if "N_PICK_GLOBAL" in params else None
    pick_top = params["PICK_TOP"] if "PICK_TOP" in params else True
    random_noise = params["RANDOM_NOISE"] if "RANDOM_NOISE" in params else 5.0
    if npick_global is None:
        npick_global = max(npick, 1)
    assert npick <= npick_global
    seqID_cut = params['SEQUENCE_IDENTITY_CUTOFF'] if 'SEQUENCE_IDENTITY_CUTOFF' in params else 100.0
    NUM_CLASSES = 23

    N_template = input_template_dict['xyz'].shape[0]
    chain_length = input_template_dict['xyz'].shape[1]

    if (N_template < 1) or (npick < 1):
        # no templates in the hhsearch file, or we don't want to use templates - return fake templates
        xyz = INIT_CRDS.reshape(1, 1, 27, 3).repeat(npick_global, chain_length, 1, 1) \
            + torch.rand(npick_global, chain_length, 1, 3) * random_noise
        template_1D = torch.nn.functional.one_hot(
            torch.full((npick_global, chain_length), GAP_IDX).long(), num_classes=NUM_CLASSES
        ).float()  # all gaps
        conf = torch.zeros((npick_global, chain_length, 1)).float()
        template_1D = torch.cat((template_1D, conf), -1)
        template_atom_mask = torch.full((npick_global, chain_length, 27), False)
        output_dict = {
            'xyz': xyz,                                # (npick_global, L_chain, 27, 3)
            'template_1D': template_1D,                # (npick_global, L_chain, NUM_CLASSES + 1)
            'template_atom_mask': template_atom_mask   # (npick_global, L_chain, 27)
        }
        return output_dict

    # ignore templates having too high a seqID
    if seqID_cut <= 100.0:
        template_valid_idx = torch.where(input_template_dict['f0d'] < seqID_cut)[0]
    else:
        template_valid_idx = torch.arange(N_template)

    # check again whether there are templates having seqID < cutoff
    N_template = template_valid_idx.shape[0]
    npick = min(npick, N_template)
    if npick < 1:  # no templates -- return fake templates
        xyz = INIT_CRDS.reshape(1, 1, 27, 3).repeat(npick_global, chain_length, 1, 1) \
            + torch.rand(npick_global, chain_length, 1, 3) * random_noise
        template_1D = torch.nn.functional.one_hot(
            torch.full((npick_global, chain_length), GAP_IDX).long(), num_classes=NUM_CLASSES
        ).float()  # all gaps
        conf = torch.zeros((npick_global, chain_length, 1)).float()
        template_1D = torch.cat((template_1D, conf), -1)
        template_atom_mask = torch.full((npick_global, chain_length, 27), False)
        output_dict = {
            'xyz': xyz,                                # (npick_global, L_chain, 27, 3)
            'template_1D': template_1D,                # (npick_global, L_chain, NUM_CLASSES + 1)
            'template_atom_mask': template_atom_mask   # (npick_global, L_chain, 27)
        }
        return output_dict

    if not pick_top:  # select randomly among all possible templates
        sample = torch.randperm(N_template)[:npick]
    else:  # only consider the top 50 templates
        sample = torch.randperm(min(50, N_template))[:npick]

    # TODO. Question in Notion 20230620
    xyz = INIT_CRDS.reshape(1, 1, 27, 3).repeat(npick_global, chain_length, 1, 1) \
        + torch.rand(npick_global, chain_length, 1, 3) * random_noise  # ??? why the same noise for all templates
    template_atom_mask = torch.full((npick_global, chain_length, 27), False)  # True for valid atom, False for missing atom
    template_1D = torch.full((npick_global, chain_length, NUM_CLASSES), GAP_IDX).long()  # (npick_global, L_chain, NUM_CLASSES)
    t1d_val = torch.zeros((npick_global, chain_length)).float()

    for i, nt in enumerate(sample):
        template_idx = template_valid_idx[nt]
        template_xyz = input_template_dict['xyz'][template_idx]       # (L_chain, 27, 3)
        template_mask = input_template_dict['mask'][template_idx]     # (L_chain, 27)
        template_seq = input_template_dict['sequence'][template_idx]  # (L_chain, NUM_CLASSES) (already one-hot encoded)
        template_f1d = input_template_dict['f1d'][template_idx]       # (L_chain)
        xyz[i] = center_and_realign_missing(template_xyz, template_mask)
        template_atom_mask[i] = template_mask
        template_1D[i] = template_seq
        t1d_val[i] = template_f1d

    template_1D = torch.cat((template_1D, t1d_val[..., None]), dim=-1)  # (npick_global, L_chain, NUM_CLASSES + 1)
    output_dict = {
        'xyz': xyz,                                # (npick_global, L_chain, 27, 3)
        'template_1D': template_1D,                # (npick_global, L_chain, NUM_CLASSES + 1)
        'template_atom_mask': template_atom_mask   # (npick_global, L_chain, 27)
    }
    return output_dict
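
# Usage sketch for template_featurize on the no-template path (the empty
# template dict below is an assumption for illustration; INIT_CRDS supplies
# the fake coordinates): the output is all-gap, zero-confidence, fully masked.
def _example_template_featurize_no_template():
    params = {"N_PICK": 1, "N_PICK_GLOBAL": 2}
    L_chain = 12
    empty = {
        "xyz": torch.zeros(0, L_chain, 27, 3),
        "mask": torch.zeros(0, L_chain, 27, dtype=torch.bool),
        "sequence": torch.zeros(0, L_chain, 23),
        "f0d": torch.zeros(0),
        "f1d": torch.zeros(0, L_chain),
    }
    out = template_featurize(empty, params)
    assert out["xyz"].shape == (2, L_chain, 27, 3)
    assert out["template_1D"].shape == (2, L_chain, 24)
    assert not out["template_atom_mask"].any()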
# slice long chains
def get_crop(chain_start, chain_end, mask, device, params, unclamp=False, ID=None):
    # chain_start : idx of the first residue in the chain
    # chain_end   : idx of the last residue in the chain
    # mask        : (Model_num, L, N_atom) atom mask
    l = chain_end - chain_start + 1
    sel = torch.arange(chain_start, chain_end + 1).to(device)
    crop_size = params['CROP']
    if l <= crop_size:
        return sel

    random_model = torch.randint(mask.shape[0], (1,)).item()
    mask = mask[random_model, chain_start:chain_end + 1]
    mask = ~(mask[:, :3].sum(dim=-1) < 3.0)  # True if all three backbone atoms exist
    exists = mask.nonzero()[:, 0]
    if len(exists) == 0:
        print(f"exists: {exists}")
        print(f"exists: {exists.shape}")
        print(f"exists len: {len(exists)}")
        print(f"mask sum : {mask.sum()}")
        print(f"ID : {ID}")
        x = 0
    else:
        x = np.random.randint(len(exists)) + 1

    if unclamp:
        # bias it toward the N-term (follow what AF did, but I don't know why)
        res_idx = exists[torch.randperm(x)[0]].item()
    else:
        res_idx = exists[torch.randperm(len(exists))[0]].item()
    res_idx += chain_start

    lower_bound = max(chain_start, res_idx - crop_size)
    upper_bound = min(chain_end - crop_size + 1, res_idx + 1)
    start = np.random.randint(lower_bound, upper_bound)
    end = min(start + crop_size, chain_end + 1)
    return torch.arange(start, end).to(device)
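
# Usage sketch for get_crop (fully valid toy mask): a 300-residue chain is
# cropped to a contiguous CROP-sized window.
def _example_get_crop():
    params = {"CROP": 128}
    mask = torch.ones(1, 300, 27)  # one model, all backbone atoms present
    sel = get_crop(0, 299, mask, torch.device("cpu"), params)
    assert len(sel) == 128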
def random_split(n, k, min_split):
    # Adjust n and the range of the random sample to account for min_split
    n_adjusted = n - min_split * k
    # Get k-1 random cut points in the range 1 to n_adjusted
    points = sorted(random.sample(range(1, n_adjusted + 1), k - 1))
    # Add 0 to the start and n_adjusted to the end
    points = [0] + points + [n_adjusted]
    # Get the differences between consecutive points and add min_split to each
    splits = [points[i + 1] - points[i] + min_split for i in range(len(points) - 1)]
    return splits
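
# Usage sketch for random_split: 128 residues split across 3 chains with at
# least 10 residues each.
def _example_random_split():
    splits = random_split(128, 3, 10)
    assert sum(splits) == 128 and len(splits) == 3 and min(splits) >= 10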
def get_complex_crop(len_s, mask, device, params):
    # Randomly select chains and crop.
    # This function only works for complexes, not for single chains.
    np.random.seed(params["SEED"])
    if mask is None:
        mask = torch.ones((1, sum(len_s), 27), device=device)
    if mask.dtype == torch.bool:
        mask = mask.float()
    if len(mask.shape) == 3:  # (Model_num, L, 27)
        tot_len = mask.shape[1]
        model_idx = np.random.randint(mask.shape[0])
        mask = mask[model_idx]
    else:  # (L, 27)
        tot_len = mask.shape[0]
    assert len(len_s) > 1, "This function only works for complexes, not for single chains."
    sel = torch.arange(tot_len, device=device)

    # filter valid chains (remove fully masked chains)
    valid_chain_idxs = list()
    preset = 0
    for ii in range(len(len_s)):
        mask_chain = ~(mask[preset:preset + len_s[ii], :3].sum(dim=-1) < 3.0)
        exist_chain = mask_chain.nonzero()[:, 0]
        if len(exist_chain) > 0:
            valid_chain_idxs.append(ii)
        preset += len_s[ii]
    chain_number = len(valid_chain_idxs)
    assert chain_number > 0, "There is no valid chain."

    random_chain_min = params["RANDOM_CHAIN_MIN"]  # 2
    random_chain_max = params["RANDOM_CHAIN_MAX"]  # 4
    random_chain = np.random.randint(random_chain_min, random_chain_max + 1)
    random_chain = min(random_chain, chain_number)
    # sample without replacement from valid_chain_idxs
    random_chain_idx = np.random.choice(valid_chain_idxs, random_chain, replace=False)

    crop_min = params["CROP_MIN"]  # 10
    crop_length = params["CROP"]   # 128
    cropped_length = random_split(crop_length, random_chain, crop_min)
    chain_crop_length_dict = dict(zip(random_chain_idx, cropped_length))

    n_added = 0
    n_remaining = sum(len_s)
    preset = 0
    sel_s = list()
    for k in range(len(len_s)):
        if k in random_chain_idx:
            n_remaining -= len_s[k]
            crop_size = chain_crop_length_dict[k]
            if crop_size > len_s[k]:
                sel_s.append(sel[preset:preset + len_s[k]])
            else:
                n_added += crop_size
                mask_chain = ~(mask[preset:preset + len_s[k], :3].sum(dim=-1) < 3.0)
                exists = mask_chain.nonzero()[:, 0]
                res_idx = exists[torch.randperm(len(exists))[0]].item()
                lower_bound = max(0, res_idx - crop_size + 1)
                upper_bound = min(len_s[k] - crop_size, res_idx) + 1
                try:
                    start = np.random.randint(lower_bound, upper_bound) + preset
                except:
                    print(lower_bound, upper_bound, preset, len_s[k], crop_size)  # e.g. 0, -2, 351, 351, 354
                    print(cropped_length)
                    raise ValueError
                sel_s.append(sel[start:start + crop_size])
        preset += len_s[k]
    return torch.cat(sel_s)
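
# Usage sketch for get_complex_crop (toy chain lengths; the params keys are
# the minimal assumed set, using the defaults noted in the comments above):
def _example_get_complex_crop():
    params = {"SEED": 0, "RANDOM_CHAIN_MIN": 2, "RANDOM_CHAIN_MAX": 4,
              "CROP_MIN": 10, "CROP": 128}
    len_s = [200, 150, 90]
    sel = get_complex_crop(len_s, None, torch.device("cpu"), params)
    print(sel.shape)  # up to CROP residues drawn from 2-3 of the chains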
def get_STRING_crop(len_s, mask, device, params):
    # Crop algorithm for STRING data.
    # The input should contain exactly two chains.
    np.random.seed(params["SEED"])
    if mask is None:
        mask = torch.ones((1, sum(len_s), 27), device=device)
    if mask.dtype == torch.bool:
        mask = mask.float()
    if len(mask.shape) == 3:  # (Model_num, L, 27)
        tot_len = mask.shape[1]
        model_idx = np.random.randint(mask.shape[0])
        mask = mask[model_idx]
    else:  # (L, 27)
        tot_len = mask.shape[0]
    assert len(len_s) == 2, "This function only works for STRING data"
    sel = torch.arange(tot_len, device=device)

    # filter valid chains (remove fully masked chains)
    valid_chain_idxs = list()
    preset = 0
    for ii in range(len(len_s)):
        mask_chain = ~(mask[preset:preset + len_s[ii], :3].sum(dim=-1) < 3.0)
        exist_chain = mask_chain.nonzero()[:, 0]
        if len(exist_chain) > 0:
            valid_chain_idxs.append(ii)
        preset += len_s[ii]  # advance the offset to the next chain
    chain_number = len(valid_chain_idxs)
    assert chain_number == 2, "This function only works for STRING data"
    chain_idxs = valid_chain_idxs

    crop_min = params["CROP_MIN"]  # 10
    # TODO PSK
    crop_min = 30
    crop_length = params["CROP"]   # 128
    cropped_length = random_split(crop_length, 2, crop_min)
    chain_crop_length_dict = dict(zip(chain_idxs, cropped_length))

    n_added = 0
    n_remaining = sum(len_s)
    preset = 0
    sel_s = list()
    for k in range(len(len_s)):
        if k in chain_idxs:
            n_remaining -= len_s[k]
            crop_size = chain_crop_length_dict[k]
            if crop_size > len_s[k]:
                sel_s.append(sel[preset:preset + len_s[k]])
            else:
                n_added += crop_size
                mask_chain = ~(mask[preset:preset + len_s[k], :3].sum(dim=-1) < 3.0)
                exists = mask_chain.float().nonzero()[:, 0]
                res_idx = exists[torch.randperm(len(exists))[0]].item()
                lower_bound = max(0, res_idx - crop_size + 1)
                upper_bound = min(len_s[k] - crop_size, res_idx) + 1
                try:
                    start = np.random.randint(lower_bound, upper_bound) + preset
                except:
                    print(lower_bound, upper_bound, preset, len_s[k], crop_size)  # e.g. 0, -2, 351, 351, 354
                    print(cropped_length)
                    raise ValueError
                sel_s.append(sel[start:start + crop_size])
        preset += len_s[k]
    return torch.cat(sel_s)
def cutoff_chain_num(sel, xyz, chain_break, params, query_chain_idx):
    chain_to_residue_idxs = {}
    for chain_idx, (chain_start, chain_end) in chain_break.items():
        chain_residue_idxs = torch.arange(chain_start, chain_end + 1)
        if torch.isin(sel, chain_residue_idxs).any():
            intersection = torch.isin(sel, chain_residue_idxs)
            chain_to_residue_idxs[chain_idx] = sel[intersection]

    max_chain = params["MAX_CHAIN"] if "MAX_CHAIN" in params else 4
    length_cutoff = params["CHAIN_LENGTH_CUTOFF"] if "CHAIN_LENGTH_CUTOFF" in params else 5
    if len(chain_to_residue_idxs) > max_chain:
        distance_between_query_chain = {}
        query_chain_xyz = xyz[chain_to_residue_idxs[query_chain_idx], 1]  # (L_query, 3)
        for chain_idx in chain_to_residue_idxs.keys():
            if chain_idx == query_chain_idx:
                continue
            if len(chain_to_residue_idxs[chain_idx]) < length_cutoff:
                continue
            chain_xyz = xyz[chain_to_residue_idxs[chain_idx], 1]  # (L_chain, 3)
            dist = torch.cdist(query_chain_xyz, chain_xyz)        # (L_query, L_chain)
            mean_dist = torch.mean(dist)                          # scalar: mean CA-CA distance to the query chain
            distance_between_query_chain[chain_idx] = mean_dist
        nearest_chains = sorted(distance_between_query_chain.items(), key=lambda x: x[1])[:max_chain - 1]
        nearest_chains = [query_chain_idx] + [chain_idx for chain_idx, _ in nearest_chains]
        filtered_sel = []
        for chain_idx in nearest_chains:
            filtered_sel.append(chain_to_residue_idxs[chain_idx])
        filtered_sel = torch.cat(filtered_sel)
        filtered_sel = torch.sort(filtered_sel)[0]
        sel = filtered_sel
    return sel
def get_spatial_crop(xyz, mask, pivot_chain_idx, chain_break, len_s, params, protein_ID, cutoff=10.0, eps=1e-6):
    # xyz  : (Model_num, L, 14, 3)
    # mask : (Model_num, L, 14) or None
    tot_len = xyz.shape[1]
    model_idx = np.random.randint(xyz.shape[0])
    device = xyz.device
    if mask is None:
        mask = torch.ones((xyz.shape[0], xyz.shape[1], xyz.shape[2]), dtype=torch.bool, device=device)
    if mask.dtype != torch.bool:
        mask = mask.bool()

    # model sampling
    xyz = xyz[model_idx]    # (L, 14, 3)
    mask = mask[model_idx]  # (L, 14)

    position_mask = mask[:, 1]  # CA mask
    valid_residue_idxs = torch.where(position_mask)[0]
    if len(valid_residue_idxs) <= params['CROP']:
        sel = cutoff_chain_num(valid_residue_idxs, xyz, chain_break, params, pivot_chain_idx)
        if len(valid_residue_idxs) < 1:
            print("ERROR: no valid residue????", protein_ID)
        return sel, "PDB_Spatial"

    sel = torch.arange(tot_len, device=device)

    # find interface residues of the pivot chain
    ifaces = list()
    chain_start, chain_end = chain_break[pivot_chain_idx]
    in_chain = torch.arange(chain_start, chain_end + 1)
    try:
        out_chain = torch.cat([torch.arange(0, chain_start), torch.arange(chain_end + 1, tot_len)])
    except Exception as e:
        print(f"protein_ID : {protein_ID}")
        print(f"chain_break : {chain_break}")
        print(f"chain_start : {chain_start}")
        print(f"chain_end : {chain_end}")
        print(f"tot_len : {tot_len}")
        print(f"error : {e}")
        raise ValueError
    cond = torch.cdist(xyz[in_chain, 1], xyz[out_chain, 1]) < cutoff  # (L_in, L_out)
    cond = torch.logical_and(cond, mask[in_chain, None, 1] * mask[None, out_chain, 1])  # (L_in, L_out)
    i, _ = torch.where(cond)        # (N_ifaces)
    chain_ifaces = i + chain_start  # (N_ifaces)
    ifaces.append(chain_ifaces)
    ifaces = torch.cat(ifaces)      # (N_ifaces)
    # keep only interface residues that belong to the pivot chain
    ifaces = ifaces[torch.isin(ifaces, in_chain)]
    ifaces = torch.sort(ifaces)[0]

    if len(ifaces) < 1:
        print("ERROR: no iface residue????", protein_ID)
        print(f"pivot_chain_idx : {pivot_chain_idx}")
        print(f"label_xyz.shape : {xyz.shape}")
        print(f"label_mask.shape : {mask.shape}")
        chain_start, chain_end = chain_break[pivot_chain_idx]
        mask = mask.unsqueeze(0)
        crop_idx = get_crop(chain_start, chain_end, mask, device, params, ID=None)
        if len(crop_idx) == 0:
            print(f"len_s : {len_s}")
            print(f"mask.shape : {mask.shape}")
            print("WTF")
        return crop_idx, "PDB_Monomer"

    # 20240130 PSK: add chain-number condition
    cnt_idx = ifaces[np.random.randint(len(ifaces))]
    dist = torch.cdist(xyz[:, 1], xyz[cnt_idx, 1][None]).reshape(-1) + torch.arange(len(xyz), device=xyz.device) * eps
    cond = mask[:, 1] * mask[cnt_idx, 1]
    dist[~cond] = 999999.9
    _, idx = torch.topk(dist, params['CROP'], largest=False)
    # TODO. Check this function.
    sel, _ = torch.sort(sel[idx])
    sel = cutoff_chain_num(sel, xyz, chain_break, params, pivot_chain_idx)
    if len(sel) == 0:
        print(f"dist.shape : {dist.shape}")
        print(f"idx.shape : {idx.shape}")
        print(f"sel.shape : {sel.shape}")
        print(f"sel : {sel}")
    return sel, "PDB_Spatial"
def find_chain_combinations(pairs):
    # Build a graph with chain points (e.g., 'A|1') as nodes
    graph = {}
    for a, b in pairs:
        if a not in graph:
            graph[a] = []
        if b not in graph:
            graph[b] = []
        graph[a].append(b)
        graph[b].append(a)

    def extract_chain_id(chain_point):
        return chain_point.split('|')[0]

    # Extract unique chain IDs
    chain_ids = set(extract_chain_id(x) for x in graph.keys())
    all_combinations = []

    def dfs(node, path):
        # Check if the current path covers all chain types
        if set(extract_chain_id(p) for p in path) == chain_ids:
            sorted_path = tuple(sorted(path, key=lambda x: extract_chain_id(x)))  # Sort the path for uniqueness
            all_combinations.append(sorted_path)
            return
        for neighbor in graph[node]:
            if neighbor not in path:  # Allow revisiting chain types but not exact nodes
                path.append(neighbor)
                dfs(neighbor, path)
                path.pop()

    # Start DFS from each node
    for start_node in graph.keys():
        dfs(start_node, [start_node])

    return list(set(all_combinations))  # Remove duplicates by converting to a set and back to a list
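
# Usage sketch for find_chain_combinations: contact pairs over chain points
# ('chain|copy') are walked to find node sets covering every chain ID once.
def _example_find_chain_combinations():
    pairs = [("A|0", "B|0"), ("B|0", "C|0")]
    combos = find_chain_combinations(pairs)
    assert ("A|0", "B|0", "C|0") in combos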
def get_same_crop_idx(xyz_full, crop_idx, chain_break, same_chain_info, cutoff=10.0):
    # xyz_full        : (Model_num, L, 14, 3)
    # crop_idx        : (L_crop)
    # chain_break     : {chain_idx : (chain_start, chain_end)}
    # same_chain_info : {chains : [chain_idx1, chain_idx2, ...], same_chain : torch.tensor (N x N)}
    cropped_chain_break, chain_to_crop_idx = chain_break_cropping(chain_break, crop_idx)
    cropped_chain_idx_dict = {}
    chain_list, same_chain = same_chain_info['chains'], same_chain_info['same_chain']
    total_idx = []
    for chain_idx in cropped_chain_break.keys():
        same_chain_idxs = same_chain[chain_list.index(chain_idx)]  # tensor (Chain num)
        same_chain_list = torch.where(same_chain_idxs == 1)[0].tolist()
        chain_start, chain_end = cropped_chain_break[chain_idx]
        cropped_length = chain_end - chain_start + 1
        for same_chain_idx in same_chain_list:
            if cropped_length != chain_break[same_chain_idx][1] - chain_break[same_chain_idx][0] + 1:
                return [crop_idx]
        original_chain_crop_idx = chain_to_crop_idx[chain_idx] - chain_start
        for ii, same_chain_idx in enumerate(same_chain_list):
            same_chain_crop_idx = original_chain_crop_idx + chain_break[same_chain_idx][0]
            cropped_chain_idx_dict[chain_idx + "|" + str(ii)] = same_chain_crop_idx
            total_idx.append(same_chain_crop_idx)
    if len(cropped_chain_break) == 1:
        return cropped_chain_idx_dict[list(cropped_chain_break.keys())[0]]
    total_idx = torch.cat(total_idx)
    cropped_ca = xyz_full[0, total_idx, 1, :]  # use the first model
    dist_map = torch.cdist(cropped_ca, cropped_ca)  # (L_crop, L_crop)
    contact_map = dist_map < cutoff                 # (L_crop, L_crop)

    num_chains = len(cropped_chain_idx_dict)
    chain_key_list = list(cropped_chain_idx_dict.keys())
    expanded_contact_map = contact_map.unsqueeze(0).unsqueeze(0).expand(num_chains, num_chains, -1, -1)
    # chain_masks marks each chain's rows within total_idx (contact_map is
    # indexed by position in total_idx, not by original residue index)
    chain_masks = torch.zeros((num_chains, contact_map.shape[0]), dtype=torch.bool)
    offset = 0
    for ii, (chain_idx, crop_idx_tensor) in enumerate(cropped_chain_idx_dict.items()):
        chain_masks[ii, offset:offset + len(crop_idx_tensor)] = True
        offset += len(crop_idx_tensor)
    chain_mask_i = chain_masks.unsqueeze(1).unsqueeze(-1).expand_as(expanded_contact_map)
    chain_mask_j = chain_masks.unsqueeze(0).unsqueeze(-1).expand_as(expanded_contact_map).transpose(-1, -2)
    masked_contact_map = expanded_contact_map * chain_mask_i * chain_mask_j
    chain_contact_map = masked_contact_map.sum(dim=(-1, -2))  # (num_chains, num_chains)
    chain_contact_map = chain_contact_map > 0

    contact_pair = []
    for i in range(num_chains):
        for j in range(i + 1, num_chains):
            if chain_contact_map[i, j]:
                contact_pair.append((chain_key_list[i], chain_key_list[j]))
    chain_combinations = find_chain_combinations(contact_pair)
    crop_idx_list = []
    for combination in chain_combinations:
        crop_idx = torch.cat([cropped_chain_idx_dict[chain_idx] for chain_idx in combination])
        crop_idx_list.append(crop_idx)
    return crop_idx_list
def generate_combinations(lst, pairs):
    combinations = set()  # Use a set to avoid duplicate combinations
    queue = [lst]         # Start with the initial list
    while queue:
        current = queue.pop(0)
        combinations.add(tuple(current))  # Store the current combination as a tuple to preserve order
        for a, b in pairs:
            if a in current and b in current:
                # Swap elements a and b
                idx_a, idx_b = current.index(a), current.index(b)
                new_combination = current.copy()
                new_combination[idx_a], new_combination[idx_b] = new_combination[idx_b], new_combination[idx_a]
                # If this new combination hasn't been seen before, add it to the queue
                if tuple(new_combination) not in combinations:
                    queue.append(new_combination)
    # Convert each combination from a tuple back to a list and return them all
    return [list(combination) for combination in combinations]
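
# Usage sketch for generate_combinations: swapping one interchangeable pair
# yields both orderings of the chain list.
def _example_generate_combinations():
    combos = generate_combinations(["A", "B", "C"], [("A", "B")])
    assert sorted(combos) == [["A", "B", "C"], ["B", "A", "C"]]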
def permute_label(protein_list, crop_idx, out_of_sequence_idxs, chain_break, same_chain_info):
    cropped_chain_break, chain_to_crop_idx = chain_break_cropping(chain_break, crop_idx)
    query_seq = protein_list[0].sequence.sequence[0]
    chain_list = same_chain_info['chains']
    same_chain = same_chain_info['same_chain']
    cropped_chain_list = list(cropped_chain_break.keys())
    chain_num = len(cropped_chain_list)

    # Chains are permutable if they are marked as the same chain, have the same
    # cropped length, and carry an identical query sequence.
    new_same_chain = torch.zeros((chain_num, chain_num), dtype=torch.bool)
    for cropped_chain1 in cropped_chain_break.keys():
        for cropped_chain2 in cropped_chain_break.keys():
            if cropped_chain1 == cropped_chain2:
                continue
            chain1 = chain_list.index(cropped_chain1)
            chain2 = chain_list.index(cropped_chain2)
            crop_idx1 = chain_to_crop_idx[cropped_chain1]
            crop_idx2 = chain_to_crop_idx[cropped_chain2]
            chain_seq1 = query_seq[crop_idx1]
            chain_seq2 = query_seq[crop_idx2]
            if same_chain[chain1, chain2] and crop_idx1.shape[0] == crop_idx2.shape[0]:
                if torch.all(chain_seq1 == chain_seq2).item():
                    new_same_chain[cropped_chain_list.index(cropped_chain1),
                                   cropped_chain_list.index(cropped_chain2)] = True

    if new_same_chain.sum() == 0:
        permuted_list = [cropped_chain_list]
    else:
        permutable_pair_list = []
        for i in range(chain_num):
            for j in range(i + 1, chain_num):
                if new_same_chain[i, j]:
                    permutable_pair_list.append((cropped_chain_list[i], cropped_chain_list[j]))
        permuted_list = generate_combinations(cropped_chain_list, permutable_pair_list)

    new_label_seq = []
    new_label_xyz = []
    new_label_atom_mask = []
    for protein in protein_list:
        protein_sequence = protein.sequence
        protein_structure = protein.structure

        label_sequence = protein_sequence.sequence     # (Model, L_total)
        label_xyz = protein_structure.xyz              # (Model, L_total, 14, 3)
        label_atom_mask = protein_structure.atom_mask  # (Model, L_total, 14)
        label_atom_mask[:, out_of_sequence_idxs, :] = 0

        chain_to_seq = {}
        chain_to_xyz = {}
        chain_to_atom_mask = {}
        for chain_idx in cropped_chain_break.keys():
            chain_crop_idx = chain_to_crop_idx[chain_idx]
            chain_to_seq[chain_idx] = label_sequence[:, chain_crop_idx]
            chain_to_xyz[chain_idx] = label_xyz[:, chain_crop_idx, :, :]
            chain_to_atom_mask[chain_idx] = label_atom_mask[:, chain_crop_idx, :]

        for permuted_chain_list in permuted_list:
            permuted_seq = []
            permuted_xyz = []
            permuted_atom_mask = []
            permuted_chain_break = {}
            chain_start = 0
            for chain_idx in permuted_chain_list:
                permuted_seq.append(chain_to_seq[chain_idx])
                permuted_xyz.append(chain_to_xyz[chain_idx])
                permuted_atom_mask.append(chain_to_atom_mask[chain_idx])
                permuted_chain_break[chain_idx] = (chain_start, chain_start + chain_to_seq[chain_idx].shape[1] - 1)
                chain_start += chain_to_seq[chain_idx].shape[1]
            permuted_seq = torch.cat(permuted_seq, dim=1)
            permuted_xyz = torch.cat(permuted_xyz, dim=1)
            permuted_atom_mask = torch.cat(permuted_atom_mask, dim=1)
            new_label_seq.append(permuted_seq)
            new_label_xyz.append(permuted_xyz)
            new_label_atom_mask.append(permuted_atom_mask)

    new_label_seq = torch.cat(new_label_seq, dim=0)
    new_label_xyz = torch.cat(new_label_xyz, dim=0)
    new_label_atom_mask = torch.cat(new_label_atom_mask, dim=0)
    new_label = Protein(
        sequence=ProteinSequence(sequence=new_label_seq),
        structure=ProteinStructure(
            xyz=new_label_xyz,
            atom_mask=new_label_atom_mask,
            chain_break=cropped_chain_break,
        ),
    )
    return new_label
if __name__ == "__main__":
    ID = "6vw2_F"
    permute_label_input = torch.load(f"permute_label_input_{ID}.pt")
    protein_list = permute_label_input['protein_list']
    crop_idx = permute_label_input['crop_idx']
    out_of_sequence_idxs = permute_label_input['out_of_sequence_idxs']
    chain_break = permute_label_input['chain_break']
    same_chain_info = permute_label_input['same_chain_info']
    new_label = permute_label(protein_list, crop_idx, out_of_sequence_idxs, chain_break, same_chain_info)