Source code for bapred.data.utils

import torch, pickle

def load_obj(name):
    """Load and return a pickled object from disk.

    The path is tried exactly as given first. If no such file exists, the
    path with a ``.pickle`` suffix appended is tried as well, so that an
    object written by ``save_obj(data, name)`` (which always appends
    ``.pickle``) can be read back with the same ``name``.

    Args:
        name: Path of the pickle file, with or without the ``.pickle``
            suffix.

    Returns:
        The object deserialized from the file.

    Raises:
        FileNotFoundError: If neither ``name`` nor ``f'{name}.pickle'``
            exists. The error reported refers to ``name`` as given.
    """
    try:
        handle = open(name, 'rb')
    except FileNotFoundError as missing:
        try:
            # Mirror the naming convention used by save_obj.
            handle = open(f'{name}.pickle', 'rb')
        except FileNotFoundError:
            # Report the path the caller actually asked for.
            raise missing from None
    with handle:
        # WARNING: unpickling can execute arbitrary code; only load files
        # that come from a trusted source.
        return pickle.load(handle)
def save_obj(data, name):
    """Pickle ``data`` to the file ``f'{name}.pickle'``.

    Args:
        data: Object to serialize.
        name: Output path without the ``.pickle`` suffix; the suffix is
            always appended.

    The file is opened in binary write mode (an existing file is
    overwritten) and the highest available pickle protocol is used.
    """
    target_path = f'{name}.pickle'
    with open(target_path, 'wb') as sink:
        pickle.dump(data, sink, protocol=pickle.HIGHEST_PROTOCOL)
def one_hot(x, allowable_set):
    """Encode ``x`` as a one-hot list over ``allowable_set``.

    Args:
        x: Value to encode.
        allowable_set: Indexable collection of permitted values. Its last
            entry is used in place of any ``x`` that is not a member.

    Returns:
        A list with one boolean per entry of ``allowable_set``, ``True``
        where the entry equals the (possibly substituted) value.
    """
    # Values outside the set are mapped onto the final entry.
    target = x if x in allowable_set else allowable_set[-1]
    return [target == candidate for candidate in allowable_set]
def is_one(x, allowable_set):
    """Return a single-element membership flag for ``x``.

    Args:
        x: Value to look up.
        allowable_set: Collection tested with the ``in`` operator.

    Returns:
        ``[1]`` if ``x`` is in ``allowable_set``, otherwise ``[0]``.
    """
    membership_flag = int(x in allowable_set)
    return [membership_flag]
def calculate_pair_distance(arr1, arr2):
    """Compute the norm of every pairwise difference between two tensors.

    A new axis is inserted at position 1 of ``arr1`` and at position 0 of
    ``arr2`` so that the subtraction broadcasts over all row pairs; the
    norm is then taken over the last axis.

    Args:
        arr1: Tensor indexed as ``arr1[:, None, :]``; presumably shaped
            ``(N, D)`` (e.g. N points with D coordinates) -- confirm
            against callers.
        arr2: Tensor indexed as ``arr2[None, :, :]``; presumably shaped
            ``(M, D)``.

    Returns:
        A ``float32`` tensor; for 2-D inputs of shapes ``(N, D)`` and
        ``(M, D)`` it has shape ``(N, M)`` and entry ``[i, j]`` is the
        Euclidean distance between ``arr1[i]`` and ``arr2[j]``.
    """
    pairwise_diff = arr1[:, None, :] - arr2[None, :, :]
    # Use the documented ``dim`` keyword rather than the NumPy-style
    # ``axis`` alias. The cast happens after the norm so the reduction
    # still runs in the input precision.
    return torch.linalg.norm(pairwise_diff, dim=-1).float()