def forward(self, activations: torch.Tensor) -> Rigid3Array:
    """Project ``activations`` to a rigid transform (rotation + translation).

    ``self.linear`` is applied to ``activations``. Its output is split along
    the last dimension into per-component tensors, which are used as follows:

    * if ``self.full_quat`` is truthy, components 0-3 are the quaternion
      ``(w, x, y, z)`` and the remaining components are the translation;
    * otherwise components 0-2 are the quaternion ``(x, y, z)``, ``w`` is set
      to ones, and the remaining components are the translation.

    The quaternion is passed to ``Rot3Array.from_quaternion`` with
    ``normalize=True``. This presumably rescales it to unit length, so the raw
    linear outputs need not be normalized.

    NOTE(review): assumes ``self.linear`` emits 7 features (full_quat) or 6
    features on its last dim, so that exactly three translation components
    remain for ``Vec3Array``. Confirm against the module's constructor.

    Args:
        activations: Input tensor fed to ``self.linear``.

    Returns:
        A ``Rigid3Array`` built from the rotation and a ``Vec3Array``
        translation. Each component tensor has the shape of ``self.linear``'s
        output with the last dimension removed.

    Raises:
        ValueError: If the last dimension of ``self.linear``'s output has too
            few entries to unpack the quaternion components.
    """
    # NOTE: During training, this needs to be run in higher precision
    rigid_flat = self.linear(activations)

    # Split the last dimension into a tuple of per-component tensors.
    rigid_flat = torch.unbind(rigid_flat, dim=-1)
    if self.full_quat:
        qw, qx, qy, qz = rigid_flat[:4]
        translation = rigid_flat[4:]
    else:
        qx, qy, qz = rigid_flat[:3]
        # Only the imaginary part is predicted; the real part is fixed to 1
        # before normalization.
        qw = torch.ones_like(qx)
        translation = rigid_flat[3:]

    rotation = Rot3Array.from_quaternion(
        qw,
        qx,
        qy,
        qz,
        normalize=True,
    )
    translation = Vec3Array(*translation)
    return Rigid3Array(rotation, translation)