Skip to content

Commit

Permalink
Merge pull request #203 from Bihaqo/bilinear_ab
Browse files Browse the repository at this point in the history
Bilinear form with two matrices
  • Loading branch information
Bihaqo authored Feb 23, 2020
2 parents 0361b80 + bc7d781 commit 0ce4d12
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 0 deletions.
69 changes: 69 additions & 0 deletions t3f/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1118,6 +1118,75 @@ def bilinear_form(A, b, c, name='t3f_bilinear_form'):
return tf.squeeze(res)


def bilinear_form_two_mat(x, A, B, y, name='t3f_bilinear_xaby'):
"""Bilinear form x^t A B y; A are B are TT-matrices, x and y can be batches.
Args:
x: `TensorTrain` object containing a TT-matrix of size N x 1
or `TensorTrainBatch` with a batch of TT-matrices of size N x 1.
A: `TensorTrain` object containing a TT-matrix of size N x M.
B: `TensorTrain` object containing a TT-matrix of size M x K.
y: `TensorTrain` object containing a TT-matrix of size K x 1
or `TensorTrainBatch` with a batch of TT-matrices of size K x 1.
name: string, name of the Op.
Returns:
A number, the value of the bilinear form if all the arguments are
`TensorTrain`s.
OR tf.Tensor of size batch_size if at least one of the arguments is
`TensorTrainBatch`
Raises:
ValueError if the arguments are not TT-matrices or if the shapes are
not consistent.
"""
for matrix in [A, B]:
if not isinstance(matrix, TensorTrainBase) or not matrix.is_tt_matrix():
raise ValueError('The arguments should be a TT-matrix.')

# TODO: support tf.Tensor as x and y.
for vec in [x, y]:
if not isinstance(vec, TensorTrainBase) or not vec.is_tt_matrix():
raise ValueError('The arguments should be a TT-matrix.')

x_is_batch = isinstance(x, TensorTrainBatch)
y_is_batch = isinstance(x, TensorTrainBatch)
x_bs_str = 'p' if x_is_batch else ''
y_bs_str = 'p' if y_is_batch else ''
out_bs_str = 'p' if x_is_batch or y_is_batch else ''
all_cores = x.tt_cores + A.tt_cores + B.tt_cores + y.tt_cores
with tf.name_scope(name, values=all_cores):
ndims = A.ndims()
curr_core_1 = x.tt_cores[0]
curr_core_2 = y.tt_cores[0]
curr_matrix_core_1 = A.tt_cores[0]
curr_matrix_core_2 = B.tt_cores[0]
# We enumerate the dummy dimension (that takes 1 value) with `k`.
# You may think that using two different k would be faster, but in my
# experience it's even a little bit slower (but neglectable in general).
einsum_str = '{0}elnf,glph,ipoj,{1}aomb->{2}fhjb'.format(x_bs_str, y_bs_str,
out_bs_str)
res = tf.einsum(einsum_str, curr_core_1, curr_matrix_core_1, curr_matrix_core_2,
curr_core_2)
for core_idx in range(1, ndims):
curr_core_1 = x.tt_cores[core_idx]
curr_core_2 = y.tt_cores[core_idx]
curr_matrix_core_1 = A.tt_cores[core_idx]
curr_matrix_core_2 = B.tt_cores[core_idx]
einsum_str = '{2}egia,{0}elnf,glph,ipoj,{1}aomb->{2}fhjb'.format(x_bs_str,
y_bs_str,
out_bs_str)
res = tf.einsum(einsum_str, res, curr_core_1,
curr_matrix_core_1, curr_matrix_core_2,
curr_core_2)

# Squeeze to make the result a number instead of 1 x 1 for NON batch case
# and to make the result a tensor of size
# batch_size
# instead of
# batch_size x 1 x 1
# in the batch case.
return tf.squeeze(res)


def cast(tt, dtype, name='t3f_cast'):
"""Casts a tt-tensor to a new type.
Expand Down
24 changes: 24 additions & 0 deletions t3f/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,30 @@ def testBilinearFormBatch(self):
self.assertAllClose(res_actual_val, np.squeeze(res_desired),
atol=1e-5, rtol=1e-5)

def testBilinearFormTwoMat(self):
# Test bilinear_form_two_mat.
shape_list = (((2, 2), (3, 4)),
((2, 3, 4), (2, 2, 2)))
rank_list = (1, 2)
with self.test_session() as sess:
for tensor_shape in shape_list:
for rank in rank_list:
A = initializers.random_matrix(tensor_shape, tt_rank=rank,
dtype=self.dtype)
B = initializers.random_matrix(tensor_shape, tt_rank=rank,
dtype=self.dtype)
B = ops.transpose(B)
x = initializers.random_matrix((tensor_shape[0], None), tt_rank=rank,
dtype=self.dtype)
y = initializers.random_matrix((tensor_shape[0], None), tt_rank=rank,
dtype=self.dtype)
res_actual = ops.bilinear_form_two_mat(x, A, B, y)
vars = [res_actual, ops.full(x), ops.full(A), ops.full(B), ops.full(y)]
res_actual_val, x_val, A_val, B_val, y_val = sess.run(vars)
res_desired = x_val.T.dot(A_val).dot(B_val).dot(y_val)
self.assertAllClose(res_actual_val, np.squeeze(res_desired),
atol=1e-5, rtol=1e-5)

def testCastFloat(self):
# Test cast function for float tt-matrices and vectors.

Expand Down

0 comments on commit 0ce4d12

Please sign in to comment.