Source code for deepfold.utils.tensor_utils

"""Utilities related to tensor operations."""

from functools import partial
from typing import Any, Callable, Dict, List, Tuple, Type, TypeVar, Union, overload

import numpy as np
import torch


def add(m1: torch.Tensor, m2: torch.Tensor, inplace: bool) -> torch.Tensor:
    """Adds m2 to m1, optionally in place to avoid an extra allocation."""
    if not inplace:
        m1 = m1 + m2
    else:
        m1 += m2
    return m1


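# Example (an illustrative sketch; values are not from the source): with
# inplace=True, m1 is mutated and the same tensor is returned.
# >>> a = torch.ones(3)
# >>> add(a, torch.ones(3), inplace=True)
# tensor([2., 2., 2.])

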
def permute_final_dims(tensor: torch.Tensor, inds: List[int]) -> torch.Tensor:
    """Permutes the last len(inds) dimensions of the tensor according to inds."""
    zero_index = -1 * len(inds)
    first_inds = list(range(len(tensor.shape[:zero_index])))
    return tensor.permute(first_inds + [zero_index + i for i in inds])


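# Example (a sketch with illustrative shapes): reorder the last three dims,
# leaving the leading batch dims untouched.
# >>> t = torch.randn(2, 3, 4, 5)
# >>> permute_final_dims(t, [2, 0, 1]).shape
# torch.Size([2, 5, 3, 4])

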
def flatten_final_dims(t: torch.Tensor, num_dims: int) -> torch.Tensor:
    """Flattens the last num_dims dimensions of the tensor into one."""
    return t.reshape(t.shape[:-num_dims] + (-1,))


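# Example (illustrative shapes): collapse the last two dims into one.
# >>> flatten_final_dims(torch.randn(2, 3, 4), 2).shape
# torch.Size([2, 12])

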
def masked_mean(
    mask: torch.Tensor,
    value: torch.Tensor,
    dim: Union[int, Tuple[int, ...]],
    eps: float = 1e-4,
    keepdim: bool = False,
) -> torch.Tensor:
    """Computes the mean of value over dim, counting only positions where mask is set."""
    mask = mask.expand(*value.shape)
    return torch.sum(mask * value, dim=dim, keepdim=keepdim) / (eps + torch.sum(mask, dim=dim, keepdim=keepdim))


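# Example (illustrative values): only unmasked positions contribute; the eps
# term keeps the division safe when a row is fully masked.
# >>> v = torch.tensor([[1.0, 2.0, 3.0]])
# >>> m = torch.tensor([[1.0, 1.0, 0.0]])
# >>> masked_mean(m, v, dim=-1)  # ~ (1 + 2) / 2
# tensor([1.4999])

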
def one_hot(x: torch.Tensor, v_bins: torch.Tensor) -> torch.Tensor:
    """One-hot encodes each value of x against its nearest bin in v_bins."""
    reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),))
    diffs = x[..., None] - reshaped_bins
    am = torch.argmin(torch.abs(diffs), dim=-1)
    return torch.nn.functional.one_hot(am, num_classes=len(v_bins)).float()


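# Example (illustrative values): each entry of x is snapped to its nearest bin.
# >>> bins = torch.tensor([0.0, 1.0, 2.0, 3.0])
# >>> one_hot(torch.tensor([0.1, 2.6]), bins)
# tensor([[1., 0., 0., 0.],
#         [0., 0., 0., 1.]])

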
def pts_to_distogram(
    pts: torch.Tensor,
    min_bin: float = 2.2325,
    max_bin: float = 21.6875,
    num_bins: int = 64,
) -> torch.Tensor:
    """Bins pairwise Euclidean distances between points into num_bins distance bins."""
    boundaries = torch.linspace(min_bin, max_bin, steps=(num_bins - 1), device=pts.device)
    dists = torch.sqrt(torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3)) ** 2, dim=-1))
    return torch.bucketize(dists, boundaries, right=False)


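# Example (illustrative shapes): N points in 3D yield an [N, N] map of integer
# distance-bin indices in [0, num_bins - 1].
# >>> pts = torch.randn(10, 3)
# >>> pts_to_distogram(pts).shape
# torch.Size([10, 10])

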
def dict_multimap(fn: Callable, dicts: List[Dict[Any, Any]]) -> Dict[Any, Any]:
    """Applies fn to the list of values gathered under each key across dicts, recursing into sub-dicts."""
    first = dicts[0]
    new_dict = {}
    for k, v in first.items():
        all_v = [d[k] for d in dicts]
        if type(v) is dict:
            new_dict[k] = dict_multimap(fn, all_v)
        else:
            new_dict[k] = fn(all_v)
    return new_dict


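# Example (illustrative values): fn receives the list of values found under
# each key, here stacked into a single tensor.
# >>> dict_multimap(torch.stack, [{"x": torch.zeros(2)}, {"x": torch.ones(2)}])["x"]
# tensor([[0., 0.],
#         [1., 1.]])

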
def batched_gather(
    data: torch.Tensor,
    inds: torch.Tensor,
    dim: int = 0,
    num_batch_dims: int = 0,
) -> torch.Tensor:
    """Gathers slices of data along dim using per-batch indices inds.

    The first num_batch_dims dimensions of data are treated as batch dimensions.
    """
    # Build a broadcastable arange for each batch dimension so that each batch
    # element is indexed with its own row of inds.
    ranges = []
    for i, s in enumerate(data.shape[:num_batch_dims]):
        r = torch.arange(s)
        r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
        ranges.append(r)
    remaining_dims = [slice(None) for _ in range(len(data.shape) - num_batch_dims)]
    remaining_dims[dim - num_batch_dims if dim >= 0 else dim] = inds
    ranges.extend(remaining_dims)
    # Index with a tuple; list-based multidimensional indexing is deprecated in PyTorch.
    return data[tuple(ranges)]


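# Example (illustrative shapes): gather along dim 1 with a separate index set
# per batch element, i.e. out[b] == data[b, inds[b], :].
# >>> data = torch.randn(2, 5, 3)
# >>> inds = torch.tensor([[0, 4], [1, 2]])
# >>> batched_gather(data, inds, dim=1, num_batch_dims=1).shape
# torch.Size([2, 2, 3])

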
T = TypeVar("T")  # Together with tree_map below, a poor man's version of JAX's tree_map.


def dict_map(
    fn: Callable[[T], Any],
    dic: Dict[Any, Union[dict, list, tuple, T]],
    leaf_type: Type[T],
) -> Dict[Any, Union[dict, list, tuple, Any]]:
    """Recursively applies fn to every leaf of type leaf_type in a nested dict."""
    new_dict: Dict[Any, Union[dict, list, tuple, Any]] = {}
    for k, v in dic.items():
        if isinstance(v, dict):
            new_dict[k] = dict_map(fn, v, leaf_type)
        else:
            new_dict[k] = tree_map(fn, v, leaf_type)
    return new_dict


@overload
def tree_map(fn: Callable[[T], Any], tree: T, leaf_type: Type[T]) -> Any: ...


@overload
def tree_map(fn: Callable[[T], Any], tree: dict, leaf_type: Type[T]) -> dict: ...


@overload
def tree_map(fn: Callable[[T], Any], tree: list, leaf_type: Type[T]) -> list: ...


@overload
def tree_map(fn: Callable[[T], Any], tree: tuple, leaf_type: Type[T]) -> tuple: ...


def tree_map(fn, tree, leaf_type):
    """Applies fn to every leaf of type leaf_type in a nested structure of dicts, lists, and tuples."""
    if isinstance(tree, dict):
        return dict_map(fn, tree, leaf_type)
    elif isinstance(tree, list):
        return [tree_map(fn, x, leaf_type) for x in tree]
    elif isinstance(tree, tuple):
        return tuple(tree_map(fn, x, leaf_type) for x in tree)
    elif isinstance(tree, leaf_type):
        return fn(tree)
    else:
        raise ValueError(f"Tree of type {type(tree)} is not supported")


array_tree_map = partial(tree_map, leaf_type=np.ndarray)
tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)


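# Example (an illustrative sketch): apply a function to every tensor leaf of a
# nested batch while preserving its structure.
# >>> batch = {"a": torch.zeros(2), "b": [torch.ones(3)]}
# >>> tensor_tree_map(lambda t: t.to(torch.float64), batch)["b"][0].dtype
# torch.float64

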
def collate(samples: List[dict]) -> dict:
    """Converts a list of sample dicts into a batch dict; tensor values are stacked along a new leading dimension."""
    assert isinstance(samples, list)
    assert len(samples) > 0
    sample_0 = samples[0]
    assert isinstance(sample_0, dict)
    batch = {}
    for key in list(sample_0.keys()):
        batch[key] = [sample[key] for sample in samples]
        if isinstance(sample_0[key], torch.Tensor):
            batch[key] = torch.stack(batch[key], dim=0)
    return batch


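# Example (illustrative values): tensor fields are stacked; everything else
# stays a plain list.
# >>> samples = [{"x": torch.zeros(3), "id": "a"}, {"x": torch.ones(3), "id": "b"}]
# >>> out = collate(samples)
# >>> out["x"].shape, out["id"]
# (torch.Size([2, 3]), ['a', 'b'])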