Skip to content

Update tf.cross, preserve quaternion dtype, and new quaternion computing function #13

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions tfquaternion/test_tfquaternion.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,71 @@ def test_quaternion_to_tensor(self):
self.assertAllEqual(type(tf.convert_to_tensor(ref)), tf.Tensor)
# reduce sum internally calls tf.convert_to_tensor
self.assertAllEqual(tf.reduce_sum(ref), 1.0)

def test_get_rotation_quaternion_from_u_to_v(self):
for dtype, epsilon in ((tf.float16, 1e-2), (tf.float32, 1e-5), (tf.float64, 1e-8)):
test_vector_count = 20
u_list = tf.random.uniform((test_vector_count, 3), dtype=dtype)
u_list = tf.math.l2_normalize(u_list, axis=1)
v_list = tf.random.uniform((test_vector_count, 3), dtype=dtype)
v_list = tf.math.l2_normalize(v_list, axis=1)

# append some hard-coded test cases, to make sure that we hit the special case
# in a special case that is flipping the x-axis, as well as some others like
# no rotation.
u_manuals = tf.constant(
(
(1, 0, 0),
(1, 0, 0),
(1, 0, 0),
(-1, 0, 0),
(-1, 0, 0),
(-1, 0, 0),
(1, 1, 1),
(1, 1, 1),
),
dtype=dtype
)
v_manuals = tf.constant(
(
(1, 0, 0),
(-1, 0, 0),
(0, 1, 0),
(-1, 0, 0),
(1, 0, 0),
(0, 1, 0),
(1, 1, 1),
(-1, -1, -1),
),
dtype=dtype
)

u_list = tf.concat([u_list, u_manuals], 0)
v_list = tf.concat([v_list, v_manuals], 0)
count = u_list.shape[0]

# Test that we can correctly rotate u onto v individually
for u, v in zip(u_list, v_list):
q = tfq.get_rotation_quaternion_from_u_to_v(u, v, epsilon=epsilon)
v_computed = tfq.rotate_vector_by_quaternion(q, u)

self.assertAllLess(v - v_computed, epsilon)

# Test that flipping u works correctly
for u in u_list:
q = tfq.get_rotation_quaternion_from_u_to_v(u, -u, epsilon=epsilon)
neg_u = tfq.rotate_vector_by_quaternion(q, u)
self.assertAllLess(u + neg_u, epsilon)

# Test that we can correctly rotate u onto v in a batch
q = tfq.get_rotation_quaternion_from_u_to_v(u_list, v_list, epsilon=epsilon)
v_computed = tfq.rotate_vector_by_quaternion(q, u_list)
self.assertAllLess(v_list - v_computed, epsilon)

# Test that flipping u works in a batch
q = tfq.get_rotation_quaternion_from_u_to_v(u_list, -u_list, epsilon=epsilon)
neg_u = tfq.rotate_vector_by_quaternion(q, u_list)
self.assertAllLess(u_list + neg_u, epsilon)

class QuaternionTest(AutoEvalTestCase):
""" Tests for the member functions of class Quaternion """
Expand Down
99 changes: 70 additions & 29 deletions tfquaternion/tfquaternion.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,8 @@
# limitations under the License.

"""

This small library implements quaternion operations in tensorflow.
All operations are differentiable.

"""
import tensorflow as tf

Expand All @@ -25,24 +23,22 @@
# Quaternion module functions
def scope_wrapper(func, *args, **kwargs):
"""Create a tf name scope around the function with its name."""

def scoped_func(*args, **kwargs):
with tf.name_scope("quaternion_{}".format(func.__name__)):
return func(*args, **kwargs)

return scoped_func


@scope_wrapper
def vector3d_to_quaternion(x):
"""Convert a tensor of 3D vectors to a quaternion.

Prepends a 0 to the last dimension, i.e. [[1,2,3]] -> [[0,1,2,3]].

Args:
x: A `tf.Tensor` of rank R, the last dimension must be 3.

Returns:
A `Quaternion` of Rank R with the last dimension being 4.

Raises:
ValueError, if the last dimension of x is not 3.
"""
Expand All @@ -62,7 +58,6 @@ def quaternion_to_vector3d(q):
@scope_wrapper
def _prepare_tensor_for_div_mul(x):
"""Prepare the tensor x for division/multiplication.

This function
a) converts x to a tensor if necessary,
b) prepends a 0 in the last dimension if the last dimension is 3,
Expand All @@ -79,12 +74,9 @@ def _prepare_tensor_for_div_mul(x):
@scope_wrapper
def quaternion_multiply(a, b):
"""Multiply two quaternion tensors.

Note that this differs from tf.multiply and is not commutative.

Args:
a, b: A `tf.Tensor` with shape (..., 4).

Returns:
A `Quaternion`.
"""
Expand All @@ -102,11 +94,9 @@ def quaternion_multiply(a, b):
@scope_wrapper
def quaternion_divide(a, b):
"""Divide tensor `a` by quaternion tensor `b`. `a` may be a scalar value.

Args:
a: A scalar or `tf.Tensor` with shape (..., 4).
b: A `tf.Tensor` with shape (..., 4).

Returns:
A `Quaternion`.
"""
Expand All @@ -132,27 +122,24 @@ def quaternion_conjugate(q):
@scope_wrapper
def rotate_vector_by_quaternion(q, v, q_ndims=None, v_ndims=None):
"""Rotate a vector (or tensor with last dimension of 3) by q.

This function computes v' = q * v * conjugate(q) but faster.
Fast version can be found here:
https://blog.molecular-matters.com/2013/05/24/a-faster-quaternion-vector-multiplication/

Args:
q: A `Quaternion` or `tf.Tensor` with shape (..., 4)
v: A `tf.Tensor` with shape (..., 3)
q_ndims: The number of dimensions of q. Only necessary to specify if
the shape of q is unknown.
v_ndims: The number of dimensions of v. Only necessary to specify if
the shape of v is unknown.

Returns: A `tf.Tensor` with the broadcasted shape of v and q.
"""
v = tf.convert_to_tensor(v)
q = q.normalized()
w = q.value()[..., 0]
q_xyz = q.value()[..., 1:]
# Broadcast shapes. Todo(phil): Prepare a pull request which adds
# broadcasting support to tf.cross
# broadcasting support to tf.linalg.cross
if q_xyz.shape.ndims is not None:
q_ndims = q_xyz.shape.ndims
if v.shape.ndims is not None:
Expand All @@ -163,8 +150,67 @@ def rotate_vector_by_quaternion(q, v, q_ndims=None, v_ndims=None):
v = tf.expand_dims(v, axis=0) + tf.zeros_like(q_xyz)
q_xyz += tf.zeros_like(v)
v += tf.zeros_like(q_xyz)
t = 2 * tf.cross(q_xyz, v)
return v + tf.expand_dims(w, axis=-1) * t + tf.cross(q_xyz, t)
t = 2 * tf.linalg.cross(q_xyz, v)
return v + tf.expand_dims(w, axis=-1) * t + tf.linalg.cross(q_xyz, t)


@scope_wrapper
def get_rotation_quaternion_from_u_to_v(u, v, epsilon=1e-6, dtype=tf.float32):
"""
Return a quaternion that will rotate one vector u onto another, v.

Given v amd u, this function computes q such that v = rotate_vector_by_quaternion(q, u).

Pseudocode was adapted from https://stackoverflow.com/questions/1171849/finding-
quaternion-representing-the-rotation-from-one-vector-to-another#:~:text=One%20
solution%20is%20to%20compute,all%20the%20way%20to%20v!

Args:
u: A `tf.Tensor` with shape (..., 3)
v: A `tf.Tensor` with shape (..., 3)
epsilon: a float, optional. A small number used to avoid divide by zero, which
occurs when the rotation is a full flip (v -> -1)
dtype: The type used for the quaternion, must be a floating point
number, i.e. one of tf.float16, tf.float32, tf.float64.

Returns: A `Quaternion` with shape (..., 4)
"""
u = tf.math.l2_normalize(u, axis=-1)
v = tf.math.l2_normalize(v, axis=-1)

# The dot / cross product determination produces twice the desired rotation.
# Half the rotation can be accomplished by averaging the rotation produced here with a zero
# rotation, which is accomplished by simply adding 1 to the dot product.
dot = tf.reduce_sum(tf.multiply(u, v), -1, keepdims=True) + 1
cross = tf.linalg.cross(u, v)

# If the rotation is a pure 180 flip then the dot product will be -1 and the cross
# will be (0, 0, 0), so that quaternion generated will be a zero quaterion, which is
# not correct.
# We need to construct the quaternion (0, x, y, z) where (x, y, z) is any vector
# orthogonal to u. To get this, try u cross x-axis, and in the rare case where
# that is zero because u is already the x-axis, then crossing with the y axis is
# guarenteed to work, because there is no vector that can be orthogonal to both
x_axis = tf.cast(tf.broadcast_to((1, 0, 0), u.shape), dtype)
y_axis = tf.cast(tf.broadcast_to((0, 1, 0), u.shape), dtype)
ortho_x = tf.linalg.cross(u, x_axis)
ortho_y = tf.linalg.cross(u, y_axis)
is_parallel_to_x_axis = tf.less(tf.reduce_sum(tf.multiply(ortho_x, ortho_x), -1, keepdims=True), epsilon)
ortho = tf.where(
is_parallel_to_x_axis,
ortho_y,
ortho_x
)

q_untested = Quaternion(tf.math.l2_normalize(tf.concat(
[dot, cross], -1), axis=-1), dtype=dtype)
q_ortho = vector3d_to_quaternion(ortho)
q = tf.where(
tf.less(dot, epsilon),
q_ortho,
q_untested
)
return Quaternion(q, dtype=dtype)


# ____________________________________________________________________________
Expand All @@ -180,34 +226,32 @@ class Quaternion(object):
# https://stackoverflow.com/questions/40694380/forcing-multiplication-to-use-rmul-instead-of-numpy-array-mul-or-byp)
__array_priority__ = 1000

def __init__(self, wxyz=(1, 0, 0, 0), dtype=tf.float32, name=None):
def __init__(self, wxyz=(1, 0, 0, 0), dtype=None, name=None):
"""The quaternion constructor.

Args:
wxyz: The values for w, x, y, z, a `tf.Tensor` with shape (..., 4).
Note that quaternions only support floating point numbers.
Defaults to (1.0, 0.0, 0.0, 0.0)
dtype: The type used for the quaternion, must be a floating point
number, i.e. one of tf.float16, tf.float32, tf.float64.
name: An optional name for the tensor.

Returns:
A Quaternion.

Raises:
ValueError, if wxyz is a `tf.Tensor` and the tensors dtype differs
from the given dtype.
ValueError, if the last dimension of wxyz is not 4.
TypeError, if dtype is not a float.
"""
if dtype is None:
dtype = wxyz.dtype if isinstance(wxyz, (tf.Tensor, tf.Variable. np.ndarray)) else tf.float32
self._q = tf.convert_to_tensor(wxyz, dtype=dtype, name=name)
self.name = name if name else ""
self.validate_type(self._q)
self.validate_shape(self._q) # check that shape is (..., 4)

def value(self):
"""The `Tensor` which holds the value of the quaternion.

Note that this does not return a reference, so you can not alter the
quaternion through this.
"""
Expand Down Expand Up @@ -238,7 +282,6 @@ def graph(self):
@property
def shape(self):
"""The `TensorShape` of the variable. Is always [..., 4].

Returns:
A `TensorShape`.
"""
Expand Down Expand Up @@ -325,32 +368,30 @@ def conjugate(self):

def conj(self):
"""Compute the conjugate of self.q, i.e. [w, -x, -y, -z].

Alias for Quaternion.conjugate().
"""
return quaternion_conjugate(self)

@scope_wrapper
def inverse(self):
"""Compute the inverse of the quaternion, i.e. q.conjugate / q.norm."""
return Quaternion(tf.convert_to_tensor(self.conjugate()) / self.norm())
return Quaternion(tf.convert_to_tensor(self.conjugate()) / self.norm(), dtype=self.dtype)

@scope_wrapper
def normalized(self):
"""Compute the normalized quaternion."""
return Quaternion(tf.divide(self._q, self.abs()))
return Quaternion(tf.divide(self._q, self.abs()), dtype=self.dtype)

@scope_wrapper
def as_rotation_matrix(self):
"""Calculate the corresponding rotation matrix.

See
http://www.euclideanspace.com/maths/geometry/rotations/conversions/quaternionToMatrix/

Returns:
A `tf.Tensor` with R+1 dimensions and
shape [d_1, ..., d_(R-1), 3, 3], the rotation matrix
"""

# helper functions
def diag(a, b): # computes the diagonal entries, 1 - 2*a**2 - 2*b**2
return 1 - 2 * tf.pow(a, 2) - 2 * tf.pow(b, 2)
Expand Down