deepfold.utils.rigid_utils.Rigid

class deepfold.utils.rigid_utils.Rigid(rots: Rotation | None, trans: Tensor | None)

A class representing a rigid transformation. It is little more than a wrapper around two objects: a Rotation object and a [*, 3] translation tensor. It is designed to behave approximately like a single torch tensor whose shape is the shared batch shape [*] of its two components.

__init__(rots: Rotation | None, trans: Tensor | None)
Parameters:
  • rots – A Rotation object with batch shape [*] (wrapping, for example, a [*, 3, 3] rotation matrix)

  • trans – A corresponding [*, 3] translation tensor

Methods

__init__(rots, trans)

apply(pts)

Applies the transformation to a coordinate tensor.
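For a single frame (empty batch shape), applying the transformation presumably amounts to rotating each point and then adding the translation, x' = R x + t. The pure-Python sketch below illustrates that arithmetic under this assumed convention; plain tuples stand in for the [*, 3, 3] rotation and [*, 3] translation tensors, and the helpers mat_vec and apply_rigid are hypothetical, not part of deepfold.

```python
def mat_vec(m, v):
    """Multiply a 3x3 matrix (tuple of rows) by a 3-vector."""
    return tuple(sum(m[i][j] * v[j] for j in range(3)) for i in range(3))


def apply_rigid(rot, trans, pts):
    """Sketch of apply() for one frame: rotate each point, then translate it."""
    out = []
    for p in pts:
        rp = mat_vec(rot, p)
        out.append(tuple(rp[i] + trans[i] for i in range(3)))
    return out


# 90-degree rotation about z, followed by a shift of +1 along x.
rot_z90 = ((0.0, -1.0, 0.0),
           (1.0, 0.0, 0.0),
           (0.0, 0.0, 1.0))
print(apply_rigid(rot_z90, (1.0, 0.0, 0.0), [(1.0, 0.0, 0.0)]))  # [(1.0, 1.0, 0.0)]
```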

apply_rot_fn(fn)

Applies a Rotation -> Rotation function to the stored rotation object.

apply_trans_fn(fn)

Applies a Tensor -> Tensor function to the stored translation.

cat(ts, dim)

Concatenates transformations along an existing batch dimension dim, analogous to torch.cat.

compose(r)

Composes the current rigid object with another.
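A minimal sketch of the composition arithmetic, assuming compose(r) returns the transformation that applies r first and then the current one, so that (R1, t1) composed with (R2, t2) gives (R1 R2, R1 t2 + t1). The final check confirms that applying the composed frame matches applying the two frames in sequence. All helper names are hypothetical, and the tuples stand in for tensors with an empty batch shape.

```python
def mat_vec(m, v):
    """Multiply a 3x3 matrix (tuple of rows) by a 3-vector."""
    return tuple(sum(m[i][j] * v[j] for j in range(3)) for i in range(3))


def mat_mul(a, b):
    """Product of two 3x3 matrices."""
    return tuple(
        tuple(sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3))
        for i in range(3)
    )


def apply(rot, trans, p):
    """x' = R x + t for a single point."""
    rp = mat_vec(rot, p)
    return tuple(rp[i] + trans[i] for i in range(3))


def compose(rot1, trans1, rot2, trans2):
    """Frame equivalent to applying (rot2, trans2) first, then (rot1, trans1)."""
    rot = mat_mul(rot1, rot2)
    rt2 = mat_vec(rot1, trans2)
    return rot, tuple(rt2[i] + trans1[i] for i in range(3))


rot_z90 = ((0.0, -1.0, 0.0), (1.0, 0.0, 0.0), (0.0, 0.0, 1.0))
t1, t2, p = (1.0, 0.0, 0.0), (0.0, 2.0, 0.0), (1.0, 0.0, 0.0)

composed = apply(*compose(rot_z90, t1, rot_z90, t2), p)
sequential = apply(rot_z90, t1, apply(rot_z90, t2, p))
print(composed == sequential)  # True
```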

compose_q_update_vec(q_update_vec)

Composes the transformation with a quaternion update vector of shape [*, 6]. The first three of the final six columns are the x, y, and z components of a quaternion of the form (1, x, y, z); the last three are a 3D translation.
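The sketch below illustrates one reading of this update: the three quaternion columns x, y, z form the quaternion (1, x, y, z), which is normalized and converted to a rotation matrix, and the resulting update frame is composed onto the current one (update applied first). The normalization step and the composition order are assumptions, and quat_to_rot and compose_q_update_vec here are hypothetical pure-Python stand-ins for the tensor code.

```python
import math


def quat_to_rot(w, x, y, z):
    """Rotation matrix of the quaternion (w, x, y, z) after normalizing it."""
    n = math.sqrt(w * w + x * x + y * y + z * z)
    w, x, y, z = w / n, x / n, y / n, z / n
    return (
        (1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)),
        (2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)),
        (2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)),
    )


def mat_vec(m, v):
    return tuple(sum(m[i][j] * v[j] for j in range(3)) for i in range(3))


def mat_mul(a, b):
    return tuple(
        tuple(sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3))
        for i in range(3)
    )


def compose_q_update_vec(rot, trans, update):
    """Compose (rot, trans) with a 6-vector update (qx, qy, qz, tx, ty, tz)."""
    qx, qy, qz, tx, ty, tz = update
    upd_rot = quat_to_rot(1.0, qx, qy, qz)  # quaternion of the form (1, x, y, z)
    new_rot = mat_mul(rot, upd_rot)          # update rotation is applied first
    rt = mat_vec(rot, (tx, ty, tz))          # rotate the local translation update into the global frame
    return new_rot, tuple(rt[i] + trans[i] for i in range(3))


identity = ((1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0))
rot, trans = compose_q_update_vec(identity, (0.0, 0.0, 0.0), (0.0, 0.0, 1.0, 1.0, 0.0, 0.0))
print([[round(c, 6) for c in row] for row in rot], trans)
```

With the update (0, 0, 1, 1, 0, 0) applied to the identity, the quaternion is (1, 0, 0, 1), a 90-degree rotation about z, and the translation becomes (1, 0, 0).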

cuda()

Moves the transformation object to GPU memory.

from_3_points(p_neg_x_axis, origin, p_xy_plane)

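Implements Algorithm 21 (rigidFrom3Points) of the AlphaFold 2 supplementary information: constructs frames from three sets of points by Gram-Schmidt orthogonalization, with origin as the frame origin, p_neg_x_axis on the negative x-axis, and p_xy_plane in the xy-plane.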
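A pure-Python sketch of the Gram-Schmidt construction, following the parameter names: the x-axis points from p_neg_x_axis towards origin, the y-axis is the component of (p_xy_plane - origin) orthogonal to x, and the z-axis is the cross product of x and y. The rotation matrix takes these unit vectors as its columns and the translation is origin. The small eps inside the norm and all helper names are assumptions of this sketch, not necessarily deepfold's exact implementation.

```python
import math


def sub(a, b):
    return tuple(a[i] - b[i] for i in range(3))


def dot(a, b):
    return sum(a[i] * b[i] for i in range(3))


def scale(a, s):
    return tuple(c * s for c in a)


def normalize(a, eps=1e-8):
    # eps guards against division by zero for degenerate (collinear) inputs
    return scale(a, 1.0 / math.sqrt(dot(a, a) + eps))


def cross(a, b):
    return (a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0])


def frame_from_3_points(p_neg_x_axis, origin, p_xy_plane):
    """Hypothetical single-frame version of from_3_points."""
    e0 = normalize(sub(origin, p_neg_x_axis))       # x-axis
    v = sub(p_xy_plane, origin)
    e1 = normalize(sub(v, scale(e0, dot(e0, v))))   # Gram-Schmidt step: y-axis
    e2 = cross(e0, e1)                              # z-axis completes a right-handed frame
    rot = tuple((e0[i], e1[i], e2[i]) for i in range(3))  # columns are e0, e1, e2
    return rot, origin


rot, t = frame_from_3_points((1.0, 0.0, 1.0), (1.0, 1.0, 1.0), (0.0, 3.0, 1.0))
print([[round(c, 6) for c in row] for row in rot], t)
```

In this example the x-axis comes out along +y and the y-axis along -x, so the recovered rotation is (approximately) a 90-degree rotation about z, with the frame origin at (1, 1, 1).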

from_tensor_4x4(t)

Constructs a transformation from a [*, 4, 4] homogeneous transformation tensor.

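from_tensor_7(t[, normalize_quats])

Constructs a transformation from a [*, 7] tensor whose final columns are a quaternion (four columns) followed by a translation (three columns), the layout produced by to_tensor_7. If normalize_quats is set, the quaternions are normalized first.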

get_rots()

Getter for the rotation.

get_trans()

Getter for the translation.

identity(shape[, dtype, device, ...])

Constructs an identity transformation.

invert()

Inverts the transformation.

invert_apply(pts)

Applies the inverse of the transformation to a coordinate tensor.
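For a rotation matrix R, the inverse is its transpose, so invert() can be expressed as (R^T, -R^T t) and invert_apply(pts) as R^T (pts - t). The sketch below checks both forms with a round trip; it assumes the x' = R x + t convention for apply, and its helper names are hypothetical single-frame stand-ins.

```python
def mat_vec(m, v):
    return tuple(sum(m[i][j] * v[j] for j in range(3)) for i in range(3))


def transpose(m):
    return tuple(tuple(m[j][i] for j in range(3)) for i in range(3))


def apply(rot, trans, p):
    rp = mat_vec(rot, p)
    return tuple(rp[i] + trans[i] for i in range(3))


def invert(rot, trans):
    """Inverse frame: rotation R^T and translation -R^T t."""
    rot_inv = transpose(rot)
    return rot_inv, tuple(-c for c in mat_vec(rot_inv, trans))


def invert_apply(rot, trans, p):
    """R^T (x - t): maps global coordinates back into the local frame."""
    return mat_vec(transpose(rot), tuple(p[i] - trans[i] for i in range(3)))


rot_z90 = ((0.0, -1.0, 0.0), (1.0, 0.0, 0.0), (0.0, 0.0, 1.0))
t = (1.0, 2.0, 3.0)
p = (4.0, 5.0, 6.0)
moved = apply(rot_z90, t, p)
print(invert_apply(rot_z90, t, moved))   # (4.0, 5.0, 6.0)
print(apply(*invert(rot_z90, t), moved))  # (4.0, 5.0, 6.0)
```

Using the transpose rather than a general matrix inverse is both cheaper and numerically safer, since it relies only on R being orthonormal.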

make_transform_from_reference(n_xyz, ca_xyz, ...)

Returns a transformation object from reference coordinates.

map_tensor_fn(fn)

Applies a Tensor -> Tensor function to the underlying translation and rotation tensors, mapping it over the translation and rotation dimensions, respectively.

scale_translation(trans_scale_factor)

Scales the translation by a constant factor.

stop_rot_gradient()

Detaches the underlying rotation object, so that no gradients flow through the rotation.

to_tensor_4x4()

Converts a transformation to a [*, 4, 4] homogeneous transformation tensor.
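The homogeneous form stacks the rotation and translation into [[R, t], [0, 0, 0, 1]], so that multiplying by [x, y, z, 1] reproduces R x + t. The sketch below assumes this standard layout; its hypothetical helpers to_4x4, from_4x4, and apply_4x4 operate on a single frame and show the round trip between the 4x4 layout and the (rotation, translation) pair.

```python
def to_4x4(rot, trans):
    """Homogeneous [[R, t], [0, 0, 0, 1]] matrix as a list of rows."""
    rows = [list(rot[i]) + [trans[i]] for i in range(3)]
    rows.append([0.0, 0.0, 0.0, 1.0])
    return rows


def from_4x4(m):
    """Split a homogeneous matrix back into (rotation, translation)."""
    rot = tuple(tuple(m[i][:3]) for i in range(3))
    trans = tuple(m[i][3] for i in range(3))
    return rot, trans


def apply_4x4(m, p):
    """Multiply m by [x, y, z, 1] and drop the trailing homogeneous entry."""
    h = (p[0], p[1], p[2], 1.0)
    return tuple(sum(m[i][j] * h[j] for j in range(4)) for i in range(3))


rot_z90 = ((0.0, -1.0, 0.0), (1.0, 0.0, 0.0), (0.0, 0.0, 1.0))
t = (1.0, 2.0, 3.0)
m = to_4x4(rot_z90, t)
print(from_4x4(m) == (rot_z90, t))    # True
print(apply_4x4(m, (4.0, 5.0, 6.0)))  # (-4.0, 6.0, 9.0)
```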

to_tensor_7()

Converts a transformation to a tensor with 7 final columns, four for the quaternion followed by three for the translation.

unsqueeze(dim)

Analogous to torch.unsqueeze.

Attributes

device

Returns the device on which the Rigid's tensors are located.

dtype

Returns the dtype of the Rigid tensors.

shape

Returns the shape of the shared dimensions of the rotation and the translation.