# Source code for miniworld.utils.DataClass

import torch
import os

class MiniWorldMSAClass():
    """Container for the MSA feature tensors of a sample.

    Every attribute is either ``None`` (empty container, when no dict is
    given) or the value stored under the key of the same name in
    ``msa_dict``.
    """

    def __init__(self, msa_dict = None):
        """Populate the MSA fields from ``msa_dict`` or leave them all ``None``.

        msa_dict (expected layout, as documented by the original author) : {
            'sample_seq':       torch.tensor, (MAXCYCLE, L_crop)
            'sample_msa_clust': torch.tensor, (MAXCYCLE, Nclust, L_crop)
            'sample_msa_seed':  torch.tensor, (MAXCYCLE, Nclust, L_crop, 23 + 23 + 2 + 2)
            'sample_msa_extra': torch.tensor, (MAXCYCLE, N_extra, L_crop, 23 + 1 + 2)
            'sample_mask_pos':  torch.tensor, (MAXCYCLE, Nclust, L_crop)
        }

        Raises:
            KeyError: if ``msa_dict`` is given but lacks one of the keys above.
        """
        provided = msa_dict is not None
        # Fields are read in a fixed order, so a missing key raises on the
        # first absent entry after the earlier ones are already assigned.
        self.sample_seq = msa_dict['sample_seq'] if provided else None
        self.sample_msa_clust = msa_dict['sample_msa_clust'] if provided else None
        self.sample_msa_seed = msa_dict['sample_msa_seed'] if provided else None
        self.sample_msa_extra = msa_dict['sample_msa_extra'] if provided else None
        self.sample_mask_pos = msa_dict['sample_mask_pos'] if provided else None

    def __IsEmpty__(self):
        """Return True only when every MSA field is ``None``."""
        field_names = (
            'sample_seq',
            'sample_msa_clust',
            'sample_msa_seed',
            'sample_msa_extra',
            'sample_mask_pos',
        )
        # Generator keeps attribute access lazy and short-circuiting.
        return all(getattr(self, field) is None for field in field_names)
class MiniWorldTemplateClass():
    """Container for the template feature tensors of a sample.

    Attributes ``xyz``, ``template_1D`` and ``template_atom_mask`` are
    either all ``None`` (no dict given) or copied from ``template_dict``.
    """

    def __init__(self, template_dict = None):
        """Populate the template fields from ``template_dict`` or set them to ``None``.

        template_dict (expected layout, as documented by the original author) : {
            'xyz':                torch.tensor, (params['N_PICK_GLOBAL'], L_crop, 27, 3)
            'template_1D':        torch.tensor, (params['N_PICK_GLOBAL'], L_crop, NUM_CLASSES + 1)
            'template_atom_mask': torch.tensor, (params['N_PICK_GLOBAL'], L_crop, 27)
        }

        Raises:
            KeyError: if ``template_dict`` is given but lacks one of the keys above.
        """
        if template_dict is not None:
            self.xyz = template_dict['xyz']
            self.template_1D = template_dict['template_1D']
            self.template_atom_mask = template_dict['template_atom_mask']
        else:
            self.xyz = self.template_1D = self.template_atom_mask = None

    def __IsEmpty__(self):
        """Return True when none of the template fields holds a value."""
        names = ('xyz', 'template_1D', 'template_atom_mask')
        return not any(getattr(self, name) is not None for name in names)
class MiniWorldLabelClass():
    """Container for the ground-truth label of a sample.

    Only ``sequence``, ``structure_xyz`` and ``structure_atom_mask`` are
    populated by ``__init__``; other entries of ``label_dict`` are ignored.
    """

    def __init__(self, label_dict = None):
        """Populate the label fields from ``label_dict`` or set them to ``None``.

        label_dict (expected layout, as documented by the original author) : {
            'sequence' : torch.tensor,                 # (L_chain, NUM_CLASSES)
            'structure' : {
                'xyz' : torch.tensor,                  # (L_chain, 27, 3)
                'atom_mask' : torch.tensor,            # (L_chain, 27)
                'position_mask' : torch.tensor,        # (L_chain, 27)  -- not stored
                'has_multiple_chains' : bool,          # not stored
                'has_multiple_models' : bool           # not stored
            },
            'occupancy' : occupancy,                   # not stored
        }

        Raises:
            KeyError: if ``label_dict`` is given but lacks 'sequence',
                'structure', or the 'xyz' / 'atom_mask' sub-keys.
        """
        if label_dict is None:
            self.sequence = None
            self.structure_xyz = None
            self.structure_atom_mask = None
        else:
            self.sequence = label_dict['sequence']
            self.structure_xyz = label_dict['structure']['xyz']
            self.structure_atom_mask = label_dict['structure']['atom_mask']

    def __IsEmpty__(self):
        """Return True when no label field holds a value.

        The position mask, multi-chain / multi-model flags and occupancy are
        never assigned by ``__init__``. They are looked up with a ``None``
        default so that an absent attribute counts as empty. Direct attribute
        access here would raise ``AttributeError`` for an empty label.
        """
        field_names = (
            'sequence',
            'structure_xyz',
            'structure_atom_mask',
            'structure_position_mask',
            'has_multiple_chains',
            'has_multiple_models',
            'occupancy',
        )
        return all(getattr(self, name, None) is None for name in field_names)
class MiniWorldDataClass():
    """A single MiniWorld sample unpacked from a nested ``data_dict``.

    MSA, template and label sub-dicts are wrapped in their dedicated
    container classes; the remaining entries are stored as plain attributes.
    """

    def __init__(self, data_dict):
        """Unpack ``data_dict`` into attributes.

        For test memory leak.

        data_dict (expected layout, as documented by the original author) : {
            # L_crop can be less than params['CROP']
            'msa' : {
                'sample_seq': torch.tensor,        (MAXCYCLE, L_crop)
                'sample_msa_clust': torch.tensor,  (MAXCYCLE, Nclust, L_crop)
                'sample_msa_seed': torch.tensor,   (MAXCYCLE, Nclust, L_crop, 23 + 23 + 2 + 2)
                'sample_msa_extra': torch.tensor,  (MAXCYCLE, N_extra, L_crop, 23 + 1 + 2)
                'sample_mask_pos': torch.tensor    (MAXCYCLE, Nclust, L_crop)
            },
            'template' : {       # optional: self.template is None when the key is absent
                'xyz' : torch.tensor                (params['N_PICK_GLOBAL'], L_crop, 27, 3)
                'template_1D' : torch.tensor        (params['N_PICK_GLOBAL'], L_crop, NUM_CLASSES + 1)
                'template_atom_mask': torch.tensor  (params['N_PICK_GLOBAL'], L_crop, 27)
            },
            'label' : {
                'sequence' : { 'sequence' : torch.tensor  # (L_chain, NUM_CLASSES) },
                # NOTE(review): MiniWorldLabelClass reads label['sequence']
                # directly, so the nested form above may be stale -- confirm.
                'structure' : {
                    'xyz' : torch.tensor,            # (L_chain, 27, 3)
                    'atom_mask' : torch.tensor,      # (L_chain, 27)
                    'position_mask' : torch.tensor,  # (L_chain, 27)
                    'has_multiple_chains' : bool,
                    'has_multiple_models' : bool
                },
                'occupancy' : occupancy,
            },
            'prev' : {
                'xyz' : torch.tensor,                # (L_crop, 27, 3)
                'atom_mask' : torch.tensor,          # (L_crop, 27)
            },
            'symmetry_related_info' : symmetry_related_info,
            'crop_idx' : torch.tensor,               # (L_crop)
            'chain_break' : dictionary,              # (N_chain, 2)
            'ID' : ID,
            'has_label_structure' : has_label_structure (bool),
            'source' : source (str),
        }

        Raises:
            KeyError: if any required key (everything except 'template') is missing.
        """
        self.msa = MiniWorldMSAClass(data_dict['msa'])
        self.template = (
            MiniWorldTemplateClass(data_dict['template'])
            if 'template' in data_dict.keys()
            else None
        )
        self.label = MiniWorldLabelClass(data_dict['label'])

        previous = data_dict['prev']
        self.prev_xyz = previous['xyz']
        self.prev_atom_mask = previous['atom_mask']

        # Remaining top-level entries are copied verbatim, in the same order.
        passthrough_keys = (
            'symmetry_related_info',
            'crop_idx',
            'chain_break',
            'ID',
            'has_label_structure',
            'source',
        )
        for key in passthrough_keys:
            setattr(self, key, data_dict[key])
class MiniWorldWrongDataClass():
    """Minimal record pairing a sample ``ID`` and ``source`` with an ``error``.

    Presumably used in place of a full data object for samples that could
    not be built. TODO confirm against the data-loading code.
    """

    def __init__(self, ID, source, error):
        """Store the three given values unchanged as attributes."""
        self.ID, self.source, self.error = ID, source, error
class MiniWorldBatchedDataClass():
    """Collate a list of MiniWorld samples into one batch.

    Per-sample metadata (``ID``, ``source``, ``has_label_structure``,
    ``symmetry_related_info``, ``chain_break``, ``label``) is gathered into
    Python lists in input order.

    Tensor fields are gathered along a new leading batch dimension:
      - the MSA fields (``self.msa``);
      - the template fields when ``use_template`` is True (``self.template``);
      - ``prev_xyz``, ``prev_atom_mask`` and ``crop_idx``.

    Each batched tensor is allocated once with ``torch.zeros`` using the
    first sample's shape and dtype. Every sample is then assigned into its
    slot, so later samples must be assignable into that shape. They are
    converted to the first sample's dtype.

    Args:
        MiniWorldData_list: sized iterable of sample objects exposing the
            attributes of ``MiniWorldDataClass``.
        use_template: when True, also batch the template tensors; otherwise
            ``self.template`` is ``None``.
    """

    # Attribute names of the batched MSA / template containers, in copy order.
    _MSA_FIELDS = (
        'sample_seq',
        'sample_msa_clust',
        'sample_msa_seed',
        'sample_msa_extra',
        'sample_mask_pos',
    )
    _TEMPLATE_FIELDS = ('xyz', 'template_1D', 'template_atom_mask')

    def __init__(self, MiniWorldData_list, use_template = False):
        self.MiniWorldData_list = MiniWorldData_list
        self.batch_size = len(MiniWorldData_list)

        # Unbatched fields: one entry per sample.
        self.ID = []
        self.source = []
        self.has_label_structure = []
        self.symmetry_related_info = []
        self.chain_break = []
        self.label = []

        # Batched tensor fields; buffers are allocated lazily in __collate__.
        self.msa = MiniWorldMSAClass()
        self.use_template = use_template
        # Always define ``template`` so attribute access does not fail when
        # templates are disabled (mirrors MiniWorldDataClass's None default).
        self.template = MiniWorldTemplateClass() if use_template else None
        self.prev_xyz = None
        self.prev_atom_mask = None
        self.crop_idx = None

        self.__collate__()

    @staticmethod
    def _batched_zeros(batch_size, like):
        """Return zeros of shape ``(batch_size,) + like.shape`` with ``like``'s dtype."""
        return torch.zeros((batch_size,) + tuple(like.shape), dtype=like.dtype)

    def __collate__(self):
        """Fill the per-sample lists and the batched tensors from the sample list."""
        for MiniWorldData in self.MiniWorldData_list:
            self.ID.append(MiniWorldData.ID)
            self.source.append(MiniWorldData.source)
            self.has_label_structure.append(MiniWorldData.has_label_structure)
            self.symmetry_related_info.append(MiniWorldData.symmetry_related_info)
            self.chain_break.append(MiniWorldData.chain_break)
            self.label.append(MiniWorldData.label)

        for batch_idx, MiniWorldData in enumerate(self.MiniWorldData_list):
            # Allocate each batched buffer once, sized from the first sample.
            if self.msa.__IsEmpty__():
                for name in self._MSA_FIELDS:
                    setattr(self.msa, name,
                            self._batched_zeros(self.batch_size, getattr(MiniWorldData.msa, name)))
            if self.use_template and self.template.__IsEmpty__():
                for name in self._TEMPLATE_FIELDS:
                    setattr(self.template, name,
                            self._batched_zeros(self.batch_size, getattr(MiniWorldData.template, name)))
            if self.prev_xyz is None:
                self.prev_xyz = self._batched_zeros(self.batch_size, MiniWorldData.prev_xyz)
                self.prev_atom_mask = self._batched_zeros(self.batch_size, MiniWorldData.prev_atom_mask)
            if self.crop_idx is None:
                self.crop_idx = self._batched_zeros(self.batch_size, MiniWorldData.crop_idx)

            # Copy this sample into its slot along the batch dimension.
            for name in self._MSA_FIELDS:
                getattr(self.msa, name)[batch_idx] = getattr(MiniWorldData.msa, name)
            if self.use_template:
                for name in self._TEMPLATE_FIELDS:
                    getattr(self.template, name)[batch_idx] = getattr(MiniWorldData.template, name)
            self.prev_xyz[batch_idx] = MiniWorldData.prev_xyz
            self.prev_atom_mask[batch_idx] = MiniWorldData.prev_atom_mask
            self.crop_idx[batch_idx] = MiniWorldData.crop_idx

    def __len__(self):
        """Return the number of samples in the batch."""
        return self.batch_size

    def __repr__(self):
        return f"MiniWorldBatchedDataClass(batch_size = {self.batch_size})"