import torch
import os
[docs]
class MiniWorldMSAClass():
    """Holder for the sampled MSA tensors of one example (or of a batch).

    An instance built without a dict has every attribute set to ``None``;
    ``MiniWorldBatchedDataClass`` starts from such an empty instance and
    fills it while collating.
    """

    def __init__(self, msa_dict = None):
        """
        msa_dict :
        {
            'sample_seq': torch.tensor, (MAXCYCLE, L_crop)
            'sample_msa_clust': torch.tensor, (MAXCYCLE, Nclust, L_crop)
            'sample_msa_seed': torch.tensor, (MAXCYCLE, Nclust, L_crop, 23 + 23 + 2 + 2)
            'sample_msa_extra': torch.tensor, (MAXCYCLE, N_extra, L_crop, 23 + 1 + 2)
            'sample_mask_pos': torch.tensor (MAXCYCLE, Nclust, L_crop)
        }
        All five keys are required when a dict is given (missing keys raise
        KeyError); with ``None`` every attribute is ``None``.
        """
        filled = msa_dict is not None
        self.sample_seq = msa_dict['sample_seq'] if filled else None
        self.sample_msa_clust = msa_dict['sample_msa_clust'] if filled else None
        self.sample_msa_seed = msa_dict['sample_msa_seed'] if filled else None
        self.sample_msa_extra = msa_dict['sample_msa_extra'] if filled else None
        self.sample_mask_pos = msa_dict['sample_mask_pos'] if filled else None

    def __IsEmpty__(self):
        """Return True when none of the five MSA attributes has been set."""
        return (
            self.sample_seq is None
            and self.sample_msa_clust is None
            and self.sample_msa_seed is None
            and self.sample_msa_extra is None
            and self.sample_mask_pos is None
        )
[docs]
class MiniWorldTemplateClass():
    """Holder for the template tensors of one example (or of a batch).

    An instance built without a dict has every attribute set to ``None``.
    """

    def __init__(self, template_dict = None):
        """
        template_dict :
        {
            'xyz' : torch.tensor (params['N_PICK_GLOBAL'], L_crop, 27, 3)
            'template_1D' : torch.tensor (params['N_PICK_GLOBAL'], L_crop, NUM_CLASSES + 1)
            'template_atom_mask': torch.tensor (params['N_PICK_GLOBAL'], L_crop, 27)
        }
        All three keys are required when a dict is given; with ``None``
        every attribute is ``None``.
        """
        if template_dict is None:
            self.xyz = self.template_1D = self.template_atom_mask = None
            return
        self.xyz = template_dict['xyz']
        self.template_1D = template_dict['template_1D']
        self.template_atom_mask = template_dict['template_atom_mask']

    def __IsEmpty__(self):
        """Return True when none of the template attributes has been set."""
        return (
            self.xyz is None
            and self.template_1D is None
            and self.template_atom_mask is None
        )
[docs]
class MiniWorldLabelClass():
    """Holder for the ground-truth labels of one example."""

    def __init__(self, label_dict = None):
        """
        label_dict :
        {
            'sequence' : torch.tensor # (L_chain, NUM_CLASSES)
            'structure' : {
                'xyz' : torch.tensor, # (L_chain, 27, 3)
                'atom_mask' : torch.tensor, # (L_chain, 27)
                'position_mask' : torch.tensor, # (L_chain, 27)
                'has_multiple_chains' : bool,
                'has_multiple_models' : bool
            },
            'occupancy' : occupancy,
        }

        'sequence', 'structure', 'structure'['xyz'] and 'structure'['atom_mask']
        are required (missing keys raise KeyError). 'position_mask',
        'has_multiple_chains', 'has_multiple_models' and 'occupancy' are
        optional and default to None. Passing ``None`` builds an empty label
        whose attributes are all None.
        """
        if label_dict is None:
            self.sequence = None
            self.structure_xyz = None
            self.structure_atom_mask = None
            self.structure_position_mask = None
            self.has_multiple_chains = None
            self.has_multiple_models = None
            self.occupancy = None
            return
        self.sequence = label_dict['sequence']
        structure = label_dict['structure']
        self.structure_xyz = structure['xyz']
        self.structure_atom_mask = structure['atom_mask']
        # These were previously never stored, so __IsEmpty__ raised
        # AttributeError when it reached them. Absent keys fall back to None.
        self.structure_position_mask = structure.get('position_mask')
        self.has_multiple_chains = structure.get('has_multiple_chains')
        self.has_multiple_models = structure.get('has_multiple_models')
        self.occupancy = label_dict.get('occupancy')

    def __IsEmpty__(self):
        """Return True when every label attribute is None."""
        return (
            self.sequence is None
            and self.structure_xyz is None
            and self.structure_atom_mask is None
            and self.structure_position_mask is None
            and self.has_multiple_chains is None
            and self.has_multiple_models is None
            and self.occupancy is None
        )
[docs]
class MiniWorldDataClass():
    """One cropped example unpacked from a nested ``data_dict``."""

    def __init__(self, data_dict):
        """
        NOTE(review): the original docstring opened with "For test memory
        leak."; its intent is not evident from this code.

        data_dict :
        { # L_crop can be less than params['CROP']
            'msa' : {
                'sample_seq': torch.tensor, (MAXCYCLE, L_crop)
                'sample_msa_clust': torch.tensor, (MAXCYCLE, Nclust, L_crop)
                'sample_msa_seed': torch.tensor, (MAXCYCLE, Nclust, L_crop, 23 + 23 + 2 + 2)
                'sample_msa_extra': torch.tensor, (MAXCYCLE, N_extra, L_crop, 23 + 1 + 2)
                'sample_mask_pos': torch.tensor (MAXCYCLE, Nclust, L_crop)
            },
            'template' : {   # optional; self.template is None when absent
                'xyz' : torch.tensor (params['N_PICK_GLOBAL'], L_crop, 27, 3)
                'template_1D' : torch.tensor (params['N_PICK_GLOBAL'], L_crop, NUM_CLASSES + 1)
                'template_atom_mask': torch.tensor (params['N_PICK_GLOBAL'], L_crop, 27)
            },
            'label' : {
                'sequence' : {
                    'sequence' : torch.tensor # (L_chain, NUM_CLASSES)
                },
                'structure' : {
                    'xyz' : torch.tensor, # (L_chain, 27, 3)
                    'atom_mask' : torch.tensor, # (L_chain, 27)
                    'position_mask' : torch.tensor, # (L_chain, 27)
                    'has_multiple_chains' : bool,
                    'has_multiple_models' : bool
                },
                'occupancy' : occupancy,
            },
            'prev' : {
                'xyz' : torch.tensor, # (L_crop, 27, 3)
                'atom_mask' : torch.tensor, # (L_crop, 27)
            },
            'symmetry_related_info' : symmetry_related_info,
            'crop_idx' : torch.tensor, # (L_crop)
            'chain_break' : dictionary, # (N_chain, 2)
            'ID' : ID,
            'has_label_structure' : has_label_structure (bool),
            'source' : source (str),
        }
        """
        self.msa = MiniWorldMSAClass(data_dict['msa'])
        self.template = (
            MiniWorldTemplateClass(data_dict['template'])
            if 'template' in data_dict
            else None
        )
        self.label = MiniWorldLabelClass(data_dict['label'])

        previous = data_dict['prev']
        self.prev_xyz = previous['xyz']
        self.prev_atom_mask = previous['atom_mask']

        # Remaining entries are stored verbatim under an attribute of the same name.
        for key in ('symmetry_related_info', 'crop_idx', 'chain_break',
                    'ID', 'has_label_structure', 'source'):
            setattr(self, key, data_dict[key])
[docs]
class MiniWorldWrongDataClass():
    """Record of an example that could not be built.

    Stores the example's ``ID`` and ``source`` together with the ``error``
    object, exactly as passed in.
    """

    def __init__(self, ID, source, error):
        self.ID, self.source, self.error = ID, source, error
[docs]
class MiniWorldBatchedDataClass():
    """Collate a list of ``MiniWorldDataClass`` samples into one batch.

    Per-sample metadata (``ID``, ``source``, ``has_label_structure``,
    ``symmetry_related_info``, ``chain_break``, ``label``) is kept as Python
    lists. The MSA tensors, the template tensors (only when ``use_template``),
    ``prev_xyz`` / ``prev_atom_mask`` and ``crop_idx`` are copied into newly
    allocated tensors with a leading batch dimension, whose shape and dtype
    come from the first sample.
    """

    # Attribute names on MiniWorldMSAClass / MiniWorldTemplateClass that get batched.
    _MSA_FIELDS = (
        'sample_seq',
        'sample_msa_clust',
        'sample_msa_seed',
        'sample_msa_extra',
        'sample_mask_pos',
    )
    _TEMPLATE_FIELDS = ('xyz', 'template_1D', 'template_atom_mask')

    def __init__(self, MiniWorldData_list, use_template = False):
        self.MiniWorldData_list = MiniWorldData_list
        self.batch_size = len(MiniWorldData_list)
        # Unbatched, per-sample values (one list entry per sample).
        self.ID = []
        self.source = []
        self.has_label_structure = []
        self.symmetry_related_info = []
        self.chain_break = []
        self.label = []
        # Batched tensors; allocated from the first sample inside __collate__.
        self.msa = MiniWorldMSAClass()
        self.use_template = use_template
        # Always define ``template`` (previously missing when use_template was
        # False, so any access raised AttributeError).
        self.template = MiniWorldTemplateClass() if use_template else None
        self.prev_xyz = None
        self.prev_atom_mask = None
        self.crop_idx = None
        self.__collate__()

    def _batched_zeros(self, tensor):
        """Return zeros shaped ``(batch_size, *tensor.shape)`` with ``tensor``'s dtype.

        No ``device`` is passed, so the result lives on torch's default device.
        """
        return torch.zeros((self.batch_size,) + tuple(tensor.shape), dtype=tensor.dtype)

    def __collate__(self):
        """Fill the per-sample lists, then copy each sample into the batched tensors."""
        for sample in self.MiniWorldData_list:
            self.ID.append(sample.ID)
            self.source.append(sample.source)
            self.has_label_structure.append(sample.has_label_structure)
            self.symmetry_related_info.append(sample.symmetry_related_info)
            self.chain_break.append(sample.chain_break)
            self.label.append(sample.label)

        for batch_idx, sample in enumerate(self.MiniWorldData_list):
            # Allocate each batched tensor once, sized from the first sample seen.
            if self.msa.__IsEmpty__():
                for name in self._MSA_FIELDS:
                    setattr(self.msa, name, self._batched_zeros(getattr(sample.msa, name)))
            if self.use_template and self.template.__IsEmpty__():
                for name in self._TEMPLATE_FIELDS:
                    setattr(self.template, name, self._batched_zeros(getattr(sample.template, name)))
            if self.prev_xyz is None:
                self.prev_xyz = self._batched_zeros(sample.prev_xyz)
                self.prev_atom_mask = self._batched_zeros(sample.prev_atom_mask)
            if self.crop_idx is None:
                self.crop_idx = self._batched_zeros(sample.crop_idx)

            # Copy this sample into slot ``batch_idx`` of every batched tensor.
            for name in self._MSA_FIELDS:
                getattr(self.msa, name)[batch_idx] = getattr(sample.msa, name)
            if self.use_template:
                for name in self._TEMPLATE_FIELDS:
                    getattr(self.template, name)[batch_idx] = getattr(sample.template, name)
            self.prev_xyz[batch_idx] = sample.prev_xyz
            self.prev_atom_mask[batch_idx] = sample.prev_atom_mask
            self.crop_idx[batch_idx] = sample.crop_idx

    def __len__(self):
        return self.batch_size

    def __repr__(self):
        return "MiniWorldBatchedDataClass(batch_size = {})".format(self.batch_size)