"""Rigid3Array Transformations represented by a Matrix and a Vector."""
from __future__ import annotations
import dataclasses
from typing import List, Union
import torch
from deepfold.utils.geometry import rotation_matrix, vector
# Scalar-like multiplicative factor, e.g. for `Rigid3Array.scale_translation`:
# either a plain Python float or a torch tensor.
Float = Union[float, torch.Tensor]
@dataclasses.dataclass(frozen=True)
class Rigid3Array:
    """Rigid Transformation, i.e. element of special euclidean group.

    Represents the map ``x -> rotation(x) + translation`` (see
    ``apply_to_point``). Instances are immutable (``frozen=True``); every
    operation returns a new ``Rigid3Array``.

    ``shape``, ``dtype`` and ``device`` are read from the rotation's ``xx``
    component only; the translation is assumed (not checked) to match.
    """

    rotation: rotation_matrix.Rot3Array
    translation: vector.Vec3Array

    def __matmul__(self, other: Rigid3Array) -> Rigid3Array:
        """Compose transforms so that ``(self @ other)(x) == self(other(x))``."""
        # R = R_self @ R_other ; t = R_self(t_other) + t_self
        new_rotation = self.rotation @ other.rotation
        new_translation = self.apply_to_point(other.translation)
        return Rigid3Array(new_rotation, new_translation)

    def __getitem__(self, index) -> Rigid3Array:
        """Index rotation and translation with the same (batch) index."""
        return Rigid3Array(
            self.rotation[index],
            self.translation[index],
        )

    def __mul__(self, other: torch.Tensor) -> Rigid3Array:
        """Multiply both rotation and translation by ``other``.

        The meaning of ``*`` is delegated to ``Rot3Array`` / ``Vec3Array``
        (presumably componentwise, e.g. for masking -- confirm there).
        """
        return Rigid3Array(
            self.rotation * other,
            self.translation * other,
        )

    def map_tensor_fn(self, fn) -> Rigid3Array:
        """Apply ``fn`` to the underlying tensors.

        Delegates to the ``map_tensor_fn`` of the rotation and of the
        translation and wraps the results in a new ``Rigid3Array``.
        """
        return Rigid3Array(
            self.rotation.map_tensor_fn(fn),
            self.translation.map_tensor_fn(fn),
        )

    def inverse(self) -> Rigid3Array:
        """Return Rigid3Array corresponding to inverse transform.

        For ``x -> R x + t`` the inverse is ``x -> R^-1 x - R^-1 t``.
        """
        inv_rotation = self.rotation.inverse()
        inv_translation = inv_rotation.apply_to_point(-self.translation)
        return Rigid3Array(inv_rotation, inv_translation)

    def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
        """Apply Rigid3Array transform to point: ``rotation(point) + translation``."""
        return self.rotation.apply_to_point(point) + self.translation

    def apply(self, point: torch.Tensor) -> torch.Tensor:
        """Tensor-in / tensor-out wrapper around ``apply_to_point``.

        ``point`` is converted with ``Vec3Array.from_array`` (presumably a
        tensor with a trailing axis of size 3 -- confirm in ``vector``), and
        the result is converted back with ``Vec3Array.to_tensor``.
        """
        return self.apply_to_point(vector.Vec3Array.from_array(point)).to_tensor()

    def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
        """Apply inverse Rigid3Array transform to point.

        Computes ``rotation^-1(point - translation)`` without materialising
        the inverse transform.
        """
        new_point = point - self.translation
        return self.rotation.apply_inverse_to_point(new_point)

    def invert_apply(self, point: torch.Tensor) -> torch.Tensor:
        """Tensor-in / tensor-out wrapper around ``apply_inverse_to_point``."""
        return self.apply_inverse_to_point(vector.Vec3Array.from_array(point)).to_tensor()

    def compose_rotation(self, other_rotation):
        """Right-compose a pure rotation.

        Returns ``Rigid3Array(self.rotation @ other_rotation, translation)``,
        i.e. ``other_rotation`` is applied first, then ``self``. The
        translation is cloned, not shared.
        """
        rot = self.rotation @ other_rotation
        return Rigid3Array(rot, self.translation.clone())

    def compose(self, other_rigid):
        """Alias for ``self @ other_rigid``."""
        return self @ other_rigid

    def unsqueeze(self, dim: int):
        """Insert a size-1 dimension at ``dim`` in rotation and translation."""
        return Rigid3Array(
            self.rotation.unsqueeze(dim),
            self.translation.unsqueeze(dim),
        )

    @property
    def shape(self) -> torch.Size:
        """Batch shape, taken from the rotation's ``xx`` component."""
        return self.rotation.xx.shape

    @property
    def dtype(self) -> torch.dtype:
        """Dtype, taken from the rotation's ``xx`` component."""
        return self.rotation.xx.dtype

    @property
    def device(self) -> torch.device:
        """Device, taken from the rotation's ``xx`` component."""
        return self.rotation.xx.device

    @classmethod
    def identity(cls, shape, device) -> Rigid3Array:
        """Return identity Rigid3Array of given shape.

        Identity rotation and zero translation, both created with ``shape``
        and ``device`` forwarded to ``Rot3Array.identity`` / ``Vec3Array.zeros``.
        """
        return cls(
            rotation_matrix.Rot3Array.identity(shape, device),
            vector.Vec3Array.zeros(shape, device),
        )

    @classmethod
    def cat(cls, rigids: List[Rigid3Array], dim: int) -> Rigid3Array:
        """Concatenate ``rigids`` along batch dimension ``dim``."""
        return cls(
            rotation_matrix.Rot3Array.cat([r.rotation for r in rigids], dim=dim),
            vector.Vec3Array.cat([r.translation for r in rigids], dim=dim),
        )

    def scale_translation(self, factor: Float) -> Rigid3Array:
        """Scale translation in Rigid3Array by 'factor'.

        The rotation object is shared with ``self``, not copied.
        """
        return Rigid3Array(self.rotation, self.translation * factor)

    def to_tensor(self) -> torch.Tensor:
        """Return the transform as a homogeneous ``(..., 4, 4)`` tensor.

        Layout:
        - the top-left 3x3 block holds the rotation;
        - the top-right column holds the translation;
        - the bottom row is ``[0, 0, 0, 1]``.

        The output takes the rotation tensor's dtype and device. The
        translation is therefore cast to the rotation's dtype.
        """
        rot_array = self.rotation.to_tensor()
        vec_array = self.translation.to_tensor()
        # Batch dims come from the rotation tensor (its last two axes are 3x3).
        array = torch.zeros(
            (*rot_array.shape[:-2], 4, 4),
            device=rot_array.device,
            dtype=rot_array.dtype,
        )
        array[..., :3, :3] = rot_array
        array[..., :3, 3] = vec_array
        array[..., 3, 3] = 1.0
        return array

    def to_tensor_4x4(self) -> torch.Tensor:
        """Alias for ``to_tensor``."""
        return self.to_tensor()

    def reshape(self, new_shape) -> Rigid3Array:
        """Reshape the batch dimensions of rotation and translation to ``new_shape``."""
        rots = self.rotation.reshape(new_shape)
        trans = self.translation.reshape(new_shape)
        return Rigid3Array(rots, trans)

    def stop_rot_gradient(self) -> Rigid3Array:
        """Return a copy whose rotation goes through ``Rot3Array.stop_gradient``.

        Presumably this detaches the rotation from autograd. The translation
        is passed through untouched.
        """
        return Rigid3Array(
            self.rotation.stop_gradient(),
            self.translation,
        )

    @classmethod
    def from_array(cls, array):
        """Build from a tensor whose top 3x4 block is ``[R | t]``.

        Reads ``array[..., :3, :3]`` as the rotation and ``array[..., :3, 3]``
        as the translation. Any bottom homogeneous row is ignored.
        """
        rot = rotation_matrix.Rot3Array.from_array(
            array[..., :3, :3],
        )
        vec = vector.Vec3Array.from_array(array[..., :3, 3])
        return cls(rot, vec)

    @classmethod
    def from_tensor_4x4(cls, array):
        """Alias for ``from_array``."""
        return cls.from_array(array)

    @classmethod
    def from_array4x4(cls, array: torch.Tensor) -> Rigid3Array:
        """Construct Rigid3Array from homogeneous 4x4 array.

        Builds the rotation element by element from ``array[..., i, j]`` for
        ``i, j < 3``. The translation comes from ``array[..., :3, 3]``. The
        bottom row is not read.
        """
        rotation = rotation_matrix.Rot3Array(
            array[..., 0, 0],
            array[..., 0, 1],
            array[..., 0, 2],
            array[..., 1, 0],
            array[..., 1, 1],
            array[..., 1, 2],
            array[..., 2, 0],
            array[..., 2, 1],
            array[..., 2, 2],
        )
        translation = vector.Vec3Array(array[..., 0, 3], array[..., 1, 3], array[..., 2, 3])
        return cls(rotation, translation)

    def cuda(self) -> Rigid3Array:
        """Return a copy moved to the current CUDA device.

        Round-trips through the dense ``(..., 4, 4)`` tensor of
        ``to_tensor_4x4``, so the translation ends up in the rotation's dtype.
        """
        return Rigid3Array.from_tensor_4x4(self.to_tensor_4x4().cuda())