Source code for diffalign.utils.dataloader

from collections.abc import Mapping
from typing import Any, List, Optional, Sequence, Union

import torch.utils.data
import copy
from torch.utils.data.dataloader import default_collate

from torch_geometric.data.collate import collate
from torch_geometric.data import Batch, Dataset
from torch_geometric.data.data import BaseData
from torch_geometric.data.datapipes import DatasetAdapter

class Collater:
    """Callable that merges a list of samples into one mini-batch.

    Dispatch is on the type of the first sample in the list:

    * ``BaseData``     -> ``Batch.from_data_list`` (after the key shuffle
      described in :meth:`_collate_graphs`)
    * ``torch.Tensor`` -> ``default_collate``
    * ``float``        -> ``torch.tensor(..., dtype=torch.float)``
    * ``int``          -> ``torch.tensor(...)``
    * ``str``          -> the list is returned unchanged
    * ``Mapping``      -> dict with every value collated recursively
    * named tuple      -> same named-tuple type, fields collated recursively
    * other ``Sequence`` -> list with every position collated recursively

    Args:
        dataset: Stored on the instance as ``self.dataset``; not read by
            the collation code in this class.
        follow_batch: Forwarded to ``Batch.from_data_list``.
        exclude_keys: Forwarded to ``Batch.from_data_list``.

    Raises:
        TypeError: If the first sample is none of the types listed above.
    """

    def __init__(
        self,
        dataset: Union[Dataset, Sequence[BaseData], DatasetAdapter],
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
    ):
        self.dataset = dataset
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys
        # Attributes that get the special treatment in `_collate_graphs`.
        # Each name ends in "_r"; the un-suffixed twin is obtained with
        # `key[:-2]` and must also exist on every graph sample.
        self.keys = ["atom_type_r", "edge_index_r", "edge_type_r", "pos_r", "smiles_r", "rdmol_r"]

    def __call__(self, batch: List[Any]) -> Any:
        """Collate ``batch`` according to the type of ``batch[0]``."""
        elem = batch[0]

        if isinstance(elem, BaseData):
            return self._collate_graphs(batch)
        if isinstance(elem, torch.Tensor):
            return default_collate(batch)
        if isinstance(elem, float):
            return torch.tensor(batch, dtype=torch.float)
        if isinstance(elem, int):
            return torch.tensor(batch)
        if isinstance(elem, str):
            return batch
        if isinstance(elem, Mapping):
            return {key: self([data[key] for data in batch]) for key in elem}
        if isinstance(elem, tuple) and hasattr(elem, '_fields'):
            # Named tuple: rebuild the same type from the collated fields.
            return type(elem)(*(self(s) for s in zip(*batch)))
        if isinstance(elem, Sequence) and not isinstance(elem, str):
            return [self(s) for s in zip(*batch)]

        raise TypeError(f"DataLoader found invalid type: '{type(elem)}'")

    def _collate_graphs(self, batch: List[Any]) -> Any:
        """Collate graph samples, routing the ``*_r`` attributes through a copy.

        Steps, in order:

        1. Deep-copy the samples. On every copy, for each key in
           ``self.keys``, delete both the un-suffixed attribute and the
           ``_r`` attribute, then store the original sample's ``_r`` value
           under the un-suffixed name. The ``_r`` attribute is also deleted
           from the original sample.
        2. Collate the copies with ``Batch.from_data_list``.
        3. Index that batch per sample and write each un-suffixed value back
           onto the original sample under its ``_r`` name.
        4. Collate the (restored) original samples and return the result.

        The samples in ``batch`` are mutated in place by steps 1 and 3.

        NOTE(review): presumably the detour lets the ``_r`` values pass
        through the collate/separate handling of their un-suffixed names
        before the final collate -- confirm the intent with the author.
        """
        # (suffixed, un-suffixed) name pairs, computed once per call.
        renames = [(key, key[:-2]) for key in self.keys]

        swapped = copy.deepcopy(batch)
        for original, clone in zip(batch, swapped):
            for key, base in renames:
                del clone[base], clone[key]
                clone[base] = original[key]
                del original[key]

        swapped_batch = Batch.from_data_list(
            swapped,
            follow_batch=self.follow_batch,
            exclude_keys=self.exclude_keys,
        )

        for index, original in enumerate(batch):
            # Integer-indexing a Batch rebuilds the whole example, so do it
            # once per sample instead of once per key.
            example = swapped_batch[index]
            for key, base in renames:
                original[key] = example[base]

        return Batch.from_data_list(
            batch,
            follow_batch=self.follow_batch,
            exclude_keys=self.exclude_keys,
        )


class DataLoader(torch.utils.data.DataLoader):
    """``torch.utils.data.DataLoader`` that always collates with ``Collater``.

    Args:
        dataset: Dataset to draw samples from; also handed to ``Collater``.
        batch_size: Number of samples per mini-batch.
        shuffle: Forwarded to ``torch.utils.data.DataLoader``.
        follow_batch: Passed to ``Collater`` and kept on the instance.
        exclude_keys: Passed to ``Collater`` and kept on the instance.
        **kwargs: Any further ``torch.utils.data.DataLoader`` arguments.
            A ``collate_fn`` entry is silently discarded.
    """

    def __init__(
        self,
        dataset: Union[Dataset, Sequence[BaseData], DatasetAdapter],
        batch_size: int = 1,
        shuffle: bool = False,
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
        **kwargs,
    ):
        # A caller-supplied collate_fn is dropped (needed for PyTorch
        # Lightning); the Collater built below is the one that gets used.
        kwargs.pop('collate_fn', None)

        # Kept as attributes for PyTorch Lightning < 1.6. Assigned before
        # the parent constructor runs, as in the original ordering.
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

        collater = Collater(dataset, follow_batch, exclude_keys)
        super().__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=collater,
            **kwargs,
        )