# Source code for bsitep.proteindata

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch
import numpy as np
import openbabel
import os
import pickle
from openbabel import pybel
from math import ceil, sin, cos, sqrt, pi
from itertools import combinations
from random import shuffle, choice, sample
from skimage.draw import ellipsoid
from scipy import ndimage

# Stop Open Babel's global error log (obErrorLog) from logging messages.
pybel.ob.obErrorLog.StopLogging()

class Featurizer():
    """Compute per-atom feature vectors for pybel molecules.

    The feature vector of an atom is the concatenation, in this order, of:
    a one-hot encoding of the atom class, the values of the requested
    ``pybel.Atom`` attributes, the results of the custom callables, an
    optional molecule code and one binary column per SMARTS pattern.
    ``FEATURE_NAMES`` lists the column labels in the order they are
    registered in ``__init__``.
    """

    def __init__(self, atom_codes=None, atom_labels=None,
                 named_properties=None, save_molecule_codes=True,
                 custom_properties=None, smarts_properties=None,
                 smarts_labels=None):
        """Configure which features are computed.

        Parameters
        ----------
        atom_codes : dict, optional
            Maps atomic numbers to class indices. The set of indices must
            be exactly ``0 .. n-1``. When omitted a built-in table with
            nine classes (B, C, N, O, P, S, Se, halogen, metal) is used.
        atom_labels : list, optional
            One label per atom class; only used together with
            ``atom_codes``.
        named_properties : list, tuple or np.ndarray, optional
            Names of ``pybel.Atom`` attributes to store.
        save_molecule_codes : bool
            Whether a ``molcode`` column is part of the features.
        custom_properties : iterable of callables, optional
            Each callable receives an atom and returns a number.
        smarts_properties : list, tuple or np.ndarray, optional
            SMARTS strings; defaults to five built-in patterns.
        smarts_labels : list, optional
            One label per SMARTS pattern.

        Raises
        ------
        TypeError
            If an argument has the wrong type.
        ValueError
            If codes, labels or property names are inconsistent.
        """
        self.FEATURE_NAMES = []

        if atom_codes is not None:
            if not isinstance(atom_codes, dict):
                raise TypeError('Atom codes should be dict, got %s instead'
                                % type(atom_codes))
            codes = set(atom_codes.values())
            # The class indices have to form the contiguous range 0..n-1.
            for i in range(len(codes)):
                if i not in codes:
                    raise ValueError('Incorrect atom code %s' % i)

            self.NUM_ATOM_CLASSES = len(codes)
            self.ATOM_CODES = atom_codes
            if atom_labels is not None:
                if len(atom_labels) != self.NUM_ATOM_CLASSES:
                    raise ValueError('Incorrect number of atom labels: '
                                     '%s instead of %s'
                                     % (len(atom_labels),
                                        self.NUM_ATOM_CLASSES))
            else:
                atom_labels = ['atom%s' % i
                               for i in range(self.NUM_ATOM_CLASSES)]
            self.FEATURE_NAMES += atom_labels
        else:
            self.ATOM_CODES = {}

            metals = ([3, 4, 11, 12, 13] + list(range(19, 32))
                      + list(range(37, 51)) + list(range(55, 84))
                      + list(range(87, 104)))

            # Each entry: (atomic number or list of atomic numbers, label).
            atom_classes = [
                (5, 'B'),
                (6, 'C'),
                (7, 'N'),
                (8, 'O'),
                (15, 'P'),
                (16, 'S'),
                (34, 'Se'),
                ([9, 17, 35, 53], 'halogen'),
                (metals, 'metal')
            ]

            for code, (atom, name) in enumerate(atom_classes):
                if isinstance(atom, list):
                    for a in atom:
                        self.ATOM_CODES[a] = code
                else:
                    self.ATOM_CODES[atom] = code
                self.FEATURE_NAMES.append(name)

            self.NUM_ATOM_CLASSES = len(atom_classes)

        if named_properties is not None:
            if not isinstance(named_properties, (list, tuple, np.ndarray)):
                raise TypeError('named_properties must be a list')
            allowed_props = [prop for prop in dir(pybel.Atom)
                             if not prop.startswith('__')]
            for prop_id, prop in enumerate(named_properties):
                if prop not in allowed_props:
                    # Arguments follow the order of the placeholders:
                    # first the offending name, then its position.
                    raise ValueError(
                        'named_properties must be in pybel.Atom attributes,'
                        ' %s was given at position %s' % (prop, prop_id)
                    )
            self.NAMED_PROPS = named_properties
        else:
            # pybel.Atom properties to save
            self.NAMED_PROPS = ['hyb', 'heavydegree', 'heterodegree',
                                'partialcharge']
        self.FEATURE_NAMES += self.NAMED_PROPS

        if not isinstance(save_molecule_codes, bool):
            raise TypeError('save_molecule_codes should be bool, got %s '
                            'instead' % type(save_molecule_codes))
        self.save_molecule_codes = save_molecule_codes
        if save_molecule_codes:
            # Remember if an atom belongs to the ligand or to the protein
            self.FEATURE_NAMES.append('molcode')

        self.CALLABLES = []
        if custom_properties is not None:
            for i, func in enumerate(custom_properties):
                if not callable(func):
                    raise TypeError('custom_properties should be list of'
                                    ' callables, got %s instead' % type(func))
                name = getattr(func, '__name__', '')
                if name == '':
                    name = 'func%s' % i
                self.CALLABLES.append(func)
                self.FEATURE_NAMES.append(name)

        if smarts_properties is None:
            # SMARTS definition for other properties
            self.SMARTS = [
                '[#6+0!$(*~[#7,#8,F]),SH0+0v2,s+0,S^3,Cl+0,Br+0,I+0]',
                '[a]',
                '[!$([#1,#6,F,Cl,Br,I,o,s,nX3,#7v5,#15v5,#16v4,#16v6,*+1,*+2,*+3])]',
                '[!$([#6,H0,-,-2,-3]),$([!H0;#7,#8,#9])]',
                '[r]'
            ]
            smarts_labels = ['hydrophobic', 'aromatic', 'acceptor', 'donor',
                             'ring']
        elif not isinstance(smarts_properties, (list, tuple, np.ndarray)):
            raise TypeError('smarts_properties must be a list')
        else:
            self.SMARTS = smarts_properties

        if smarts_labels is not None:
            if len(smarts_labels) != len(self.SMARTS):
                raise ValueError('Incorrect number of SMARTS labels: %s'
                                 ' instead of %s'
                                 % (len(smarts_labels), len(self.SMARTS)))
        else:
            smarts_labels = ['smarts%s' % i for i in range(len(self.SMARTS))]

        # Compile patterns
        self.compile_smarts()
        self.FEATURE_NAMES += smarts_labels

    def compile_smarts(self):
        """Build one ``pybel.Smarts`` object per string in ``self.SMARTS``."""
        self.__PATTERNS = [pybel.Smarts(smarts) for smarts in self.SMARTS]

    def encode_num(self, atomic_num):
        """Return the one-hot atom-class encoding for an atomic number.

        Parameters
        ----------
        atomic_num : int
            Atomic number of the element.

        Returns
        -------
        np.ndarray
            Vector of length ``NUM_ATOM_CLASSES``; all zeros when the
            atomic number is not a key of ``ATOM_CODES``.

        Raises
        ------
        TypeError
            If ``atomic_num`` is not an ``int``.
        """
        if not isinstance(atomic_num, int):
            raise TypeError('Atomic number must be int, %s was given'
                            % type(atomic_num))

        encoding = np.zeros(self.NUM_ATOM_CLASSES)
        try:
            encoding[self.ATOM_CODES[atomic_num]] = 1.0
        except KeyError:
            # Elements without a class are deliberately encoded as zeros.
            pass
        return encoding

    def find_smarts(self, molecule):
        """Flag, for every atom, which SMARTS patterns it matches.

        Parameters
        ----------
        molecule : pybel.Molecule

        Returns
        -------
        np.ndarray
            Array of shape ``(len(molecule.atoms), number of patterns)``
            holding 1.0 where the atom matched the pattern and 0.0
            elsewhere.

        Raises
        ------
        TypeError
            If ``molecule`` is not a ``pybel.Molecule``.
        """
        if not isinstance(molecule, pybel.Molecule):
            raise TypeError('molecule must be pybel.Molecule object, %s was given'
                            % type(molecule))

        features = np.zeros((len(molecule.atoms), len(self.__PATTERNS)))

        for (pattern_id, pattern) in enumerate(self.__PATTERNS):
            # The match tuples are transposed and unpacked into list(); the
            # indices are shifted by one before being used as row numbers.
            # NOTE(review): this unpacking appears to assume single-atom
            # matches -- confirm before adding multi-atom patterns.
            atoms_with_prop = np.array(list(*zip(*pattern.findall(molecule))),
                                       dtype=int) - 1
            features[atoms_with_prop, pattern_id] = 1.0
        return features

    def get_binary_features(self, mol):
        """Return the coordinates of all atoms and a column of ones.

        Parameters
        ----------
        mol : object exposing ``atoms``, each atom having ``coords``.

        Returns
        -------
        coords : np.ndarray
            One row of coordinates per atom.
        features : np.ndarray
            Array of ones with shape ``(number of atoms, 1)``.
        """
        coords = np.array([a.coords for a in mol.atoms])
        features = np.ones((len(coords), 1))
        return coords, features

    def get_features(self, molecule, molcode=None):
        """Compute coordinates and features for the heavy atoms of a molecule.

        Atoms with an atomic number of 1 or less are skipped.

        Parameters
        ----------
        molecule : pybel.Molecule
        molcode : float or int, optional
            Value written into the molecule-code column. Required when
            ``save_molecule_codes`` is True.

        Returns
        -------
        coords : np.ndarray
            float32 array with one row of coordinates per kept atom.
        features : np.ndarray
            One feature row per kept atom.

        Raises
        ------
        TypeError
            If ``molecule`` or ``molcode`` has the wrong type.
        ValueError
            If ``molcode`` is missing although molecule codes are saved.
        RuntimeError
            If any computed feature is NaN.
        """
        if not isinstance(molecule, pybel.Molecule):
            raise TypeError('molecule must be pybel.Molecule object,'
                            ' %s was given' % type(molecule))
        if molcode is None:
            if self.save_molecule_codes is True:
                raise ValueError('save_molecule_codes is set to True,'
                                 ' you must specify code for the molecule')
        elif not isinstance(molcode, (float, int)):
            raise TypeError('motlype must be float, %s was given'
                            % type(molcode))

        coords = []
        features = []
        heavy_atoms = []

        for i, atom in enumerate(molecule):
            if atom.atomicnum > 1:
                heavy_atoms.append(i)
                coords.append(atom.coords)

                features.append(np.concatenate((
                    self.encode_num(atom.atomicnum),
                    [getattr(atom, prop) for prop in self.NAMED_PROPS],
                    [func(atom) for func in self.CALLABLES],
                )))

        coords = np.array(coords, dtype=np.float32)
        features = np.array(features, dtype=np.float32)
        if self.save_molecule_codes:
            features = np.hstack((features,
                                  molcode * np.ones((len(features), 1))))
        # SMARTS flags are computed for every atom; keep the heavy-atom rows.
        features = np.hstack([features,
                              self.find_smarts(molecule)[heavy_atoms]])

        if np.isnan(features).any():
            raise RuntimeError('Got NaN when calculating features')

        return coords, features

    def get_features_gt(self, molecule, molcode=None):
        """Return ground-truth coordinates and binary features of a molecule.

        Performs the same argument validation as ``get_features`` and then
        delegates to ``get_binary_features``.

        Parameters
        ----------
        molecule : pybel.Molecule
        molcode : float or int, optional
            Only validated; it is not stored in the returned features.

        Returns
        -------
        coords : np.ndarray
        features : np.ndarray
            Column of ones, one row per atom.

        Raises
        ------
        TypeError
            If ``molecule`` or ``molcode`` has the wrong type.
        ValueError
            If ``molcode`` is missing although molecule codes are saved.
        RuntimeError
            If any feature is NaN.
        """
        if not isinstance(molecule, pybel.Molecule):
            raise TypeError('molecule must be pybel.Molecule object,'
                            ' %s was given' % type(molecule))
        if molcode is None:
            if self.save_molecule_codes is True:
                raise ValueError('save_molecule_codes is set to True,'
                                 ' you must specify code for the molecule')
        elif not isinstance(molcode, (float, int)):
            raise TypeError('motlype must be float, %s was given'
                            % type(molcode))

        coords, features = self.get_binary_features(molecule)

        if np.isnan(features).any():
            raise RuntimeError('Got NaN when calculating features')

        return coords, features

    def to_pickle(self, fname='featurizer.pkl'):
        """Pickle the featurizer to ``fname`` without its compiled patterns.

        The compiled SMARTS objects are removed for the duration of the
        dump and restored afterwards, even if writing fails.
        """
        patterns = self.__PATTERNS[:]
        del self.__PATTERNS
        try:
            with open(fname, 'wb') as f:
                pickle.dump(self, f)
        finally:
            self.__PATTERNS = patterns[:]

    @staticmethod
    def from_pickle(fname):
        """Load a featurizer written by ``to_pickle`` and recompile SMARTS.

        WARNING: ``pickle.load`` can execute arbitrary code; only load
        files from trusted sources.
        """
        with open(fname, 'rb') as f:
            featurizer = pickle.load(f)
        featurizer.compile_smarts()
        return featurizer
class proteinDataset(Dataset):
    """Dataset yielding voxelised protein features and pocket masks.

    Every entry of ``data_path`` is expected to be a directory containing
    ``protein.mol2`` and ``cavity6.mol2``; all molecules are read in
    ``__init__``.
    """

    def __init__(self, data_path, featurizer=None, max_dist=35, eval=True,
                 scale=0.5, max_translation=5, kfold_ind=0):
        """Read all structures and select the split.

        Parameters
        ----------
        data_path : str
            Directory whose entries are the per-structure directories.
        featurizer : Featurizer, optional
            Defaults to ``Featurizer(save_molecule_codes=False)``. It is
            created here rather than in the signature so that it is not
            built at import time and shared by every dataset.
        max_dist : float or int
            Half-size of the grid box passed to ``make_grid``.
        eval : bool
            If true, only fold ``kfold_ind`` is used.
        scale : float
            Inverse of the protein grid resolution.
        max_translation : float or int
            Upper bound of the random translation drawn per item.
        kfold_ind : int
            Index (0-3) of the fold used when ``eval`` is true.
        """
        if featurizer is None:
            featurizer = Featurizer(save_molecule_codes=False)

        self.data_path = data_path
        self.max_dist = max_dist
        self.scale = scale
        self.max_translation = max_translation
        self.eval = eval
        self.featurizer = featurizer

        self.data_list = sorted(os.path.join(self.data_path, x)
                                for x in os.listdir(self.data_path))
        self.eval_num = len(self.data_list) // 4

        footprint = ellipsoid(2, 2, 2)
        self.footprint = footprint.reshape((*footprint.shape, 1))

        n = self.eval_num
        self.data_list_list = [self.data_list[:n],
                               self.data_list[n:n * 2],
                               self.data_list[n * 2:n * 3],
                               self.data_list[-n:]]
        if self.eval:
            self.data_list = self.data_list_list[kfold_ind]
        else:
            # NOTE(review): the training split is the concatenation of all
            # four folds, i.e. it still contains fold ``kfold_ind``; the
            # original code had the exclusion
            # (``self.data_list_list.pop(kfold_ind)``) commented out.
            self.data_list = (self.data_list_list[0] + self.data_list_list[1]
                              + self.data_list_list[2]
                              + self.data_list_list[3])

        self.protein_list = [
            next(pybel.readfile('mol2', os.path.join(path, 'protein.mol2')))
            for path in self.data_list]
        self.pocket_list = [
            next(pybel.readfile('mol2', os.path.join(path, 'cavity6.mol2')))
            for path in self.data_list]

    def __len__(self):
        """Return the number of structures in the selected split."""
        return len(self.data_list)

    def __getitem__(self, index):
        """Return ``(x, y)`` tensors for one randomly transformed structure.

        A rotation index in ``0..23`` and a translation in
        ``[0, max_translation)`` per axis are drawn on every call. Both
        tensors are permuted so that the channel axis comes first.
        """
        mol1 = self.protein_list[index]
        mol2 = self.pocket_list[index]
        rot = choice(range(24))
        tr = self.max_translation * np.random.rand(1, 3)
        x, y = self.feed_data(mol1, mol2, rot, tr)
        x = torch.Tensor(x.astype(np.float32)).permute(3, 0, 1, 2)
        y = torch.Tensor(y.astype(np.float32)).permute(3, 0, 1, 2)
        return x, y

    def make_grid(self, coords, features, grid_resolution=1.0, max_dist=10.0):
        """Accumulate atom features on a cubic grid.

        Parameters
        ----------
        coords : array-like of shape (N, 3)
        features : array-like of shape (N, F)
        grid_resolution : float or int
            Distance between neighbouring grid points; must be positive.
        max_dist : float or int
            Half-size of the box around the origin; must be positive.

        Returns
        -------
        np.ndarray
            float32 array of shape ``(M, M, M, F)`` with
            ``M = ceil(2 * max_dist / grid_resolution + 1)``. Features of
            atoms rounded to the same grid point are summed; atoms outside
            the box are dropped.

        Raises
        ------
        ValueError
            For wrongly shaped arrays or non-positive parameters.
        TypeError
            If ``grid_resolution`` or ``max_dist`` is not a float or int.
        """
        try:
            coords = np.asarray(coords, dtype=np.float64)
        except ValueError as err:
            raise ValueError('coords must be an array of floats of shape (N, 3)') from err
        c_shape = coords.shape
        if len(c_shape) != 2 or c_shape[1] != 3:
            raise ValueError('coords must be an array of floats of shape (N, 3)')

        N = len(coords)
        try:
            features = np.asarray(features, dtype=np.float64)
        except ValueError as err:
            raise ValueError('features must be an array of floats of shape (N, F)') from err
        f_shape = features.shape
        if len(f_shape) != 2 or f_shape[0] != N:
            raise ValueError('features must be an array of floats of shape (N, F)')

        if not isinstance(grid_resolution, (float, int)):
            raise TypeError('grid_resolution must be float')
        if grid_resolution <= 0:
            raise ValueError('grid_resolution must be positive')

        if not isinstance(max_dist, (float, int)):
            raise TypeError('max_dist must be float')
        if max_dist <= 0:
            raise ValueError('max_dist must be positive')

        num_features = f_shape[1]
        max_dist = float(max_dist)
        grid_resolution = float(grid_resolution)

        box_size = ceil(2 * max_dist / grid_resolution + 1)

        # move all atoms to the nearest grid point
        grid_coords = (coords + max_dist) / grid_resolution
        grid_coords = grid_coords.round().astype(int)

        # remove atoms outside the box
        in_box = ((grid_coords >= 0) & (grid_coords < box_size)).all(axis=1)
        grid = np.zeros((box_size, box_size, box_size, num_features),
                        dtype=np.float32)
        for (x, y, z), f in zip(grid_coords[in_box], features[in_box]):
            grid[x, y, z] += f

        return grid

    def rotation_matrix(self, axis, theta):
        """Return the 3x3 rotation matrix for ``axis`` and angle ``theta``.

        Parameters
        ----------
        axis : array-like of shape (3,)
            Rotation axis; it is normalised internally.
        theta : float or int
            Rotation angle, passed to ``math.sin`` / ``math.cos``.

        Raises
        ------
        ValueError
            If ``axis`` is not an array of shape (3,).
        TypeError
            If ``theta`` is not a float or int.
        """
        try:
            axis = np.asarray(axis, dtype=np.float64)
        except ValueError as err:
            raise ValueError('axis must be an array of floats of shape (3,)') from err

        if axis.shape != (3,):
            raise ValueError('axis must be an array of floats of shape (3,)')

        if not isinstance(theta, (float, int)):
            raise TypeError('theta must be a float')

        axis = axis / sqrt(np.dot(axis, axis))
        a = cos(theta / 2.0)
        b, c, d = -axis * sin(theta / 2.0)
        aa, bb, cc, dd = a * a, b * b, c * c, d * d
        bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d
        return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)],
                         [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)],
                         [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]])

    def _cube_rotations(self):
        """Return the 24 predefined rotation matrices, built once per instance."""
        cached = getattr(self, '_rotation_cache', None)
        if cached is not None:
            return cached

        # identity
        rotations = [self.rotation_matrix([1, 1, 1], 0)]

        # about each coordinate axis - 9 rotations
        for a1 in range(3):
            for t in range(1, 4):
                axis = np.zeros(3)
                axis[a1] = 1
                theta = t * pi / 2.0
                rotations.append(self.rotation_matrix(axis, theta))

        # about each face diagonal - 6 rotations
        for (a1, a2) in combinations(range(3), 2):
            axis = np.zeros(3)
            axis[[a1, a2]] = 1.0
            theta = pi
            rotations.append(self.rotation_matrix(axis, theta))
            axis[a2] = -1.0
            rotations.append(self.rotation_matrix(axis, theta))

        # about each space diagonal - 8 rotations
        for t in [1, 2]:
            theta = t * 2 * pi / 3
            axis = np.ones(3)
            rotations.append(self.rotation_matrix(axis, theta))
            for a1 in range(3):
                axis = np.ones(3)
                axis[a1] = -1
                rotations.append(self.rotation_matrix(axis, theta))

        self._rotation_cache = rotations
        return rotations

    def rotate(self, coords, rotation):
        """Rotate ``coords`` by a predefined or an explicit rotation.

        Parameters
        ----------
        coords : array-like of shape (N, 3)
        rotation : int or np.ndarray of shape (3, 3)
            Index (0-23) into the predefined rotations, or a rotation
            matrix. NumPy integers are accepted as indices as well.

        Returns
        -------
        np.ndarray
            ``np.dot(coords, matrix)``.

        Raises
        ------
        ValueError
            If the index is out of range or ``rotation`` is neither an
            integer nor a 3x3 array.
        """
        if isinstance(rotation, (int, np.integer)):
            rotations = self._cube_rotations()
            if 0 <= rotation < len(rotations):
                return np.dot(coords, rotations[rotation])
            raise ValueError('Invalid rotation number %s!' % rotation)
        if isinstance(rotation, np.ndarray) and rotation.shape == (3, 3):
            return np.dot(coords, rotation)
        # Wrap in a tuple so that tuple arguments are formatted instead of
        # triggering a TypeError in the % operator.
        raise ValueError('Invalid rotation %s!' % (rotation,))

    def feed_data(self, mol1, mol2, rotation=0, translation=(0, 0, 0)):
        """Build the input grid and the pocket mask for one structure.

        Both molecules are shifted by the centroid of the protein's
        featurised atoms, rotated and translated identically. The protein
        is voxelised with resolution ``1 / scale``; the pocket with the
        default resolution of ``make_grid``, dilated with
        ``self.footprint`` and then zoomed to the protein grid size.

        Parameters
        ----------
        mol1 : pybel.Molecule
            Protein.
        mol2 : pybel.Molecule
            Pocket.
        rotation : int or np.ndarray
            Forwarded to ``rotate``.
        translation : array-like
            Added to the rotated coordinates.

        Returns
        -------
        x : np.ndarray
            Protein feature grid.
        gt : np.ndarray
            Pocket grid clipped to ``[0, 1]``.

        Raises
        ------
        TypeError
            If a molecule is not a ``pybel.Molecule``.
        ValueError
            If ``featurizer`` or ``scale`` is None.
        """
        if not isinstance(mol1, pybel.Molecule):
            raise TypeError('mol should be a pybel.Molecule object, got %s '
                            'instead' % type(mol1))
        if not isinstance(mol2, pybel.Molecule):
            raise TypeError('mol should be a pybel.Molecule object, got %s '
                            'instead' % type(mol2))
        if self.featurizer is None:
            raise ValueError('featurizer must be set to make predistions for '
                             'molecules')
        if self.scale is None:
            raise ValueError('scale must be set to make predistions')

        prot_coords, prot_features = self.featurizer.get_features(mol1)
        pocket_coords, pocket_features = self.featurizer.get_features_gt(mol2)

        centroid = prot_coords.mean(axis=0)
        prot_coords -= centroid
        prot_coords = self.rotate(prot_coords, rotation)
        prot_coords += translation
        resolution = 1. / self.scale
        x = self.make_grid(prot_coords, prot_features,
                           max_dist=self.max_dist,
                           grid_resolution=resolution)

        y_channels = pocket_features.shape[1]
        pocket_coords -= centroid
        pocket_coords = self.rotate(pocket_coords, rotation)
        pocket_coords += translation
        gt = self.make_grid(pocket_coords, pocket_features,
                            max_dist=self.max_dist)
        # Dilate the pocket mask, then bring it back to the 0..1 range.
        margin = ndimage.maximum_filter(gt, footprint=self.footprint)
        gt += margin
        gt = gt.clip(0, 1)

        # Resample the mask so that it matches the protein grid size.
        zoom = x.shape[1] / gt.shape[1]
        gt = np.expand_dims(gt, 0)
        gt = np.stack([ndimage.zoom(gt[0, ..., i], zoom)
                       for i in range(y_channels)], -1)
        gt = gt.clip(0, 1)
        return x, gt
class proteinDataset_predict(Dataset):
    """Dataset yielding voxelised protein features for prediction.

    Every entry of ``data_path`` is expected to be a directory containing
    ``protein.<file_format>``; all proteins are read in ``__init__``.
    """

    def __init__(self, data_path, featurizer=None, max_dist=35,
                 scale=0.5, max_translation=5, file_format='mol2'):
        """Read all proteins found below ``data_path``.

        Parameters
        ----------
        data_path : str
            Directory whose entries are the per-structure directories.
        featurizer : Featurizer, optional
            Defaults to ``Featurizer(save_molecule_codes=False)``. It is
            created here rather than in the signature so that it is not
            built at import time and shared by every dataset.
        max_dist : float or int
            Half-size of the grid box passed to ``make_grid``.
        scale : float
            Inverse of the grid resolution.
        max_translation : float or int
            Stored on the instance; not used by this class's methods.
        file_format : str
            File extension and format name passed to ``pybel.readfile``.
        """
        if featurizer is None:
            featurizer = Featurizer(save_molecule_codes=False)

        self.data_path = data_path
        self.max_dist = max_dist
        self.scale = scale
        self.max_translation = max_translation
        self.featurizer = featurizer
        self.file_format = file_format

        self.data_list = [os.path.join(self.data_path, x)
                          for x in sorted(os.listdir(self.data_path))]

        footprint = ellipsoid(2, 2, 2)
        self.footprint = footprint.reshape((*footprint.shape, 1))

        self.protein_list = [
            next(pybel.readfile(
                self.file_format,
                os.path.join(path, 'protein.' + self.file_format)))
            for path in self.data_list]

    def __len__(self):
        """Return the number of structures."""
        return len(self.data_list)

    def __getitem__(self, index):
        """Return the grid and bookkeeping data for one structure.

        Returns
        -------
        tuple
            ``(x, origin, step, molname, mol)``: the channel-first feature
            tensor, the grid origin and step from ``feed_data``, the
            structure's directory name and the path of its protein file.
        """
        entry = self.data_list[index]
        # basename instead of split('/') so that the name is also correct
        # on platforms whose path separator is not '/'.
        molname = os.path.basename(entry)
        mol = os.path.join(entry, 'protein.' + self.file_format)
        mol1 = self.protein_list[index]
        x, origin, step = self.feed_data(mol1)
        x = torch.Tensor(x.astype(np.float32)).permute(3, 0, 1, 2)
        return x, origin, step, molname, mol

    def make_grid(self, coords, features, grid_resolution=1.0, max_dist=10.0):
        """Accumulate atom features on a cubic grid.

        Parameters
        ----------
        coords : array-like of shape (N, 3)
        features : array-like of shape (N, F)
        grid_resolution : float or int
            Distance between neighbouring grid points; must be positive.
        max_dist : float or int
            Half-size of the box around the origin; must be positive.

        Returns
        -------
        np.ndarray
            float32 array of shape ``(M, M, M, F)`` with
            ``M = ceil(2 * max_dist / grid_resolution + 1)``. Features of
            atoms rounded to the same grid point are summed; atoms outside
            the box are dropped.

        Raises
        ------
        ValueError
            For wrongly shaped arrays or non-positive parameters.
        TypeError
            If ``grid_resolution`` or ``max_dist`` is not a float or int.
        """
        try:
            coords = np.asarray(coords, dtype=np.float64)
        except ValueError as err:
            raise ValueError('coords must be an array of floats of shape (N, 3)') from err
        c_shape = coords.shape
        if len(c_shape) != 2 or c_shape[1] != 3:
            raise ValueError('coords must be an array of floats of shape (N, 3)')

        N = len(coords)
        try:
            features = np.asarray(features, dtype=np.float64)
        except ValueError as err:
            raise ValueError('features must be an array of floats of shape (N, F)') from err
        f_shape = features.shape
        if len(f_shape) != 2 or f_shape[0] != N:
            raise ValueError('features must be an array of floats of shape (N, F)')

        if not isinstance(grid_resolution, (float, int)):
            raise TypeError('grid_resolution must be float')
        if grid_resolution <= 0:
            raise ValueError('grid_resolution must be positive')

        if not isinstance(max_dist, (float, int)):
            raise TypeError('max_dist must be float')
        if max_dist <= 0:
            raise ValueError('max_dist must be positive')

        num_features = f_shape[1]
        max_dist = float(max_dist)
        grid_resolution = float(grid_resolution)

        box_size = ceil(2 * max_dist / grid_resolution + 1)

        # move all atoms to the nearest grid point
        grid_coords = (coords + max_dist) / grid_resolution
        grid_coords = grid_coords.round().astype(int)

        # remove atoms outside the box
        in_box = ((grid_coords >= 0) & (grid_coords < box_size)).all(axis=1)
        grid = np.zeros((box_size, box_size, box_size, num_features),
                        dtype=np.float32)
        for (x, y, z), f in zip(grid_coords[in_box], features[in_box]):
            grid[x, y, z] += f

        return grid

    def feed_data(self, mol1):
        """Voxelise one protein and report where the grid sits in space.

        Parameters
        ----------
        mol1 : pybel.Molecule

        Returns
        -------
        x : np.ndarray
            Feature grid of the protein centred on the centroid of its
            featurised atoms, with resolution ``1 / scale``.
        origin : np.ndarray
            ``centroid - max_dist``.
        step : np.ndarray
            Three-element array filled with ``1 / scale``.

        Raises
        ------
        TypeError
            If ``mol1`` is not a ``pybel.Molecule``.
        ValueError
            If ``featurizer`` or ``scale`` is None.
        """
        if not isinstance(mol1, pybel.Molecule):
            raise TypeError('mol should be a pybel.Molecule object, got %s '
                            'instead' % type(mol1))
        if self.featurizer is None:
            raise ValueError('featurizer must be set to make predistions for '
                             'molecules')
        if self.scale is None:
            raise ValueError('scale must be set to make predistions')

        prot_coords, prot_features = self.featurizer.get_features(mol1)
        centroid = prot_coords.mean(axis=0)
        prot_coords -= centroid
        resolution = 1. / self.scale
        x = self.make_grid(prot_coords, prot_features,
                           max_dist=self.max_dist,
                           grid_resolution=resolution)
        origin = (centroid - self.max_dist)
        step = np.array([1.0 / self.scale] * 3)
        return x, origin, step