Source code for deepfold.utils.geometry.rotation_matrix

"""Rot3Array Matrix Class."""

from __future__ import annotations

import dataclasses
from typing import List

import torch

from deepfold.utils.geometry import utils, vector
from deepfold.utils.tensor_utils import tensor_tree_map

# Names of the nine Rot3Array matrix entries, "<row><column>", listed row by row.
COMPONENTS = [
    "xx", "xy", "xz",  # row x
    "yx", "yy", "yz",  # row y
    "zx", "zy", "zz",  # row z
]


@dataclasses.dataclass(frozen=True)
class Rot3Array:
    """Rot3Array Matrix in 3 dimensional Space implemented as struct of arrays.

    Each of the nine fields holds one entry of a 3x3 matrix, named
    ``<row><column>`` (``xy`` is row x, column y), matching the layout used by
    ``to_tensor`` / ``from_array``. Every field is a tensor over the batch
    dimensions; ``to_tensor`` stacks them, so they must share one shape there.
    Most methods act field-wise and return a new instance (the class is frozen).
    """

    # Matrix entries in row-major order (see to_tensor / from_array).
    xx: torch.Tensor = dataclasses.field(metadata={"dtype": torch.float32})
    xy: torch.Tensor
    xz: torch.Tensor
    yx: torch.Tensor
    yy: torch.Tensor
    yz: torch.Tensor
    zx: torch.Tensor
    zy: torch.Tensor
    zz: torch.Tensor

    # Opt out of NumPy's __array_ufunc__ protocol so NumPy defers mixed
    # operations to this class instead of treating it as an array-like.
    __array_ufunc__ = None

    def __getitem__(self, index) -> Rot3Array:
        """Index every component tensor with ``index`` (acts on batch dims only)."""
        return self.map_tensor_fn(lambda t: t[index])

    def __mul__(self, other: torch.Tensor) -> Rot3Array:
        """Multiply every component element-wise by ``other`` (tensor or scalar)."""
        return self.map_tensor_fn(lambda t: t * other)

    def __matmul__(self, other: Rot3Array) -> Rot3Array:
        """Composes two Rot3Arrays.

        Computes the matrix product ``self @ other``: applying the result to a
        point applies ``other`` first, then ``self``.
        """
        # Column j of the product is ``self`` applied to column j of ``other``.
        c0 = self.apply_to_point(vector.Vec3Array(other.xx, other.yx, other.zx))
        c1 = self.apply_to_point(vector.Vec3Array(other.xy, other.yy, other.zy))
        c2 = self.apply_to_point(vector.Vec3Array(other.xz, other.yz, other.zz))
        return Rot3Array(c0.x, c1.x, c2.x, c0.y, c1.y, c2.y, c0.z, c1.z, c2.z)

    def map_tensor_fn(self, fn) -> Rot3Array:
        """Return a new Rot3Array with ``fn`` applied to each component tensor.

        Args:
            fn: Callable taking and returning a tensor; called once per field,
                in field order (xx, xy, ..., zz).
        """
        return Rot3Array(
            **{f.name: fn(getattr(self, f.name)) for f in dataclasses.fields(Rot3Array)}
        )

    def inverse(self) -> Rot3Array:
        """Returns inverse of Rot3Array.

        Implemented as the transpose, which equals the inverse only when the
        matrix is orthonormal (a proper rotation).
        """
        return Rot3Array(
            self.xx, self.yx, self.zx,
            self.xy, self.yy, self.zy,
            self.xz, self.yz, self.zz,
        )

    def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
        """Applies Rot3Array to point (matrix-vector product ``R @ p``)."""
        return vector.Vec3Array(
            self.xx * point.x + self.xy * point.y + self.xz * point.z,
            self.yx * point.x + self.yy * point.y + self.yz * point.z,
            self.zx * point.x + self.zy * point.y + self.zz * point.z,
        )

    def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
        """Applies inverse Rot3Array (its transpose, see ``inverse``) to point."""
        return self.inverse().apply_to_point(point)

    def unsqueeze(self, dim: int) -> Rot3Array:
        """Insert a size-1 batch dimension at ``dim`` in every component."""
        return self.map_tensor_fn(lambda t: t.unsqueeze(dim))

    def stop_gradient(self) -> Rot3Array:
        """Return a copy whose components are detached from the autograd graph."""
        return self.map_tensor_fn(lambda t: t.detach())

    @classmethod
    def identity(cls, shape, device, dtype: torch.dtype = torch.float32) -> Rot3Array:
        """Returns identity of given shape.

        Args:
            shape: Batch shape of each component tensor.
            device: Device on which the tensors are allocated.
            dtype: Component dtype; defaults to ``torch.float32``.
        """
        ones = torch.ones(shape, dtype=dtype, device=device)
        zeros = torch.zeros(shape, dtype=dtype, device=device)
        # The same ``ones``/``zeros`` tensor objects are shared between fields,
        # so an in-place edit of one component would affect the others.
        return cls(ones, zeros, zeros, zeros, ones, zeros, zeros, zeros, ones)

    @classmethod
    def from_two_vectors(cls, e0: vector.Vec3Array, e1: vector.Vec3Array) -> Rot3Array:
        """Construct Rot3Array from two Vectors.

        Rot3Array is constructed such that in the corresponding frame 'e0' lies on
        the positive x-Axis and 'e1' lies in the xy plane with positive sign of y.

        Args:
            e0: Vector
            e1: Vector
        Returns:
            Rot3Array
        """
        # Normalize the unit vector for the x-axis, e0.
        e0 = e0.normalized()
        # Make e1 perpendicular to e0 (one Gram-Schmidt step), then normalize.
        c = e1.dot(e0)
        e1 = e1 - c * e0
        e1 = e1.normalized()
        # Compute e2 as cross product of e0 and e1.
        e2 = e0.cross(e1)
        # e0, e1, e2 become the columns of the matrix.
        return cls(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z)

    @classmethod
    def from_array(cls, array: torch.Tensor) -> Rot3Array:
        """Construct Rot3Array from a tensor of shape [..., 3, 3].

        Dimension -2 indexes matrix rows and dimension -1 indexes columns.
        """
        rows = torch.unbind(array, dim=-2)
        rc = [torch.unbind(e, dim=-1) for e in rows]
        return cls(*[e for row in rc for e in row])

    def to_tensor(self) -> torch.Tensor:
        """Convert Rot3Array to a tensor of shape [..., 3, 3] (inverse of from_array)."""
        return torch.stack(
            [
                torch.stack([self.xx, self.xy, self.xz], dim=-1),
                torch.stack([self.yx, self.yy, self.yz], dim=-1),
                torch.stack([self.zx, self.zy, self.zz], dim=-1),
            ],
            dim=-2,
        )

    @classmethod
    def from_quaternion(
        cls,
        w: torch.Tensor,
        x: torch.Tensor,
        y: torch.Tensor,
        z: torch.Tensor,
        normalize: bool = True,
        eps: float = 1e-6,
    ) -> Rot3Array:
        """Construct Rot3Array from components of quaternion.

        Args:
            w: Scalar (real) part of the quaternion.
            x, y, z: Vector (imaginary) parts of the quaternion.
            normalize: If True, rescale the quaternion to unit norm first.
            eps: Lower clamp applied to the *squared* norm before ``rsqrt``,
                guarding against division by zero.
        """
        if normalize:
            inv_norm = torch.rsqrt(torch.clamp(w**2 + x**2 + y**2 + z**2, min=eps))
            w = w * inv_norm
            x = x * inv_norm
            y = y * inv_norm
            z = z * inv_norm
        xx = 1.0 - 2.0 * (y**2 + z**2)
        xy = 2.0 * (x * y - w * z)
        xz = 2.0 * (x * z + w * y)
        yx = 2.0 * (x * y + w * z)
        yy = 1.0 - 2.0 * (x**2 + z**2)
        yz = 2.0 * (y * z - w * x)
        zx = 2.0 * (x * z - w * y)
        zy = 2.0 * (y * z + w * x)
        zz = 1.0 - 2.0 * (x**2 + y**2)
        return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz)

    def reshape(self, new_shape) -> Rot3Array:
        """Reshape every component tensor to ``new_shape`` (batch shape)."""
        return self.map_tensor_fn(lambda t: t.reshape(new_shape))

    @classmethod
    def cat(cls, rots: List[Rot3Array], dim: int) -> Rot3Array:
        """Concatenate Rot3Arrays component-wise along batch dimension ``dim``.

        ``dim`` indexes the component tensors, so it refers to a batch
        dimension, not a matrix row/column axis.
        """
        names = [f.name for f in dataclasses.fields(Rot3Array)]
        return cls(
            **{name: torch.cat([getattr(r, name) for r in rots], dim=dim) for name in names}
        )