diff --git a/t3f/ops.py b/t3f/ops.py
index b414a92..e292369 100644
--- a/t3f/ops.py
+++ b/t3f/ops.py
@@ -1118,6 +1118,75 @@ def bilinear_form(A, b, c, name='t3f_bilinear_form'):
   return tf.squeeze(res)
 
 
+def bilinear_form_two_mat(x, A, B, y, name='t3f_bilinear_xaby'):
+  """Bilinear form x^T A B y; A and B are TT-matrices, x and y can be batches.
+
+  Args:
+    x: `TensorTrain` object containing a TT-matrix of size N x 1
+      or `TensorTrainBatch` with a batch of TT-matrices of size N x 1.
+    A: `TensorTrain` object containing a TT-matrix of size N x M.
+    B: `TensorTrain` object containing a TT-matrix of size M x K.
+    y: `TensorTrain` object containing a TT-matrix of size K x 1
+      or `TensorTrainBatch` with a batch of TT-matrices of size K x 1.
+    name: string, name of the Op.
+
+  Returns:
+    A number, the value of the bilinear form if all the arguments are
+      `TensorTrain`s.
+    OR a tf.Tensor of size batch_size if at least one of the arguments is a
+      `TensorTrainBatch`.
+
+  Raises:
+    ValueError if the arguments are not TT-matrices or if the shapes are
+      not consistent.
+  """
+  for matrix in [A, B]:
+    if not isinstance(matrix, TensorTrainBase) or not matrix.is_tt_matrix():
+      raise ValueError('The arguments should be a TT-matrix.')
+
+  # TODO: support tf.Tensor as x and y.
+  for vec in [x, y]:
+    if not isinstance(vec, TensorTrainBase) or not vec.is_tt_matrix():
+      raise ValueError('The arguments should be a TT-matrix.')
+
+  x_is_batch = isinstance(x, TensorTrainBatch)
+  y_is_batch = isinstance(y, TensorTrainBatch)
+  x_bs_str = 'p' if x_is_batch else ''
+  y_bs_str = 'p' if y_is_batch else ''
+  out_bs_str = 'p' if x_is_batch or y_is_batch else ''
+  all_cores = x.tt_cores + A.tt_cores + B.tt_cores + y.tt_cores
+  with tf.name_scope(name, values=all_cores):
+    ndims = A.ndims()
+    curr_core_1 = x.tt_cores[0]
+    curr_core_2 = y.tt_cores[0]
+    curr_matrix_core_1 = A.tt_cores[0]
+    curr_matrix_core_2 = B.tt_cores[0]
+    # The dummy column dimensions of x and y (each of size 1) are labeled
+    # `n` and `m`; einsum sums them out, which costs nothing since they take
+    # a single value.
+    einsum_str = '{0}elnf,glph,ipoj,{1}aomb->{2}fhjb'.format(x_bs_str, y_bs_str,
+                                                             out_bs_str)
+    res = tf.einsum(einsum_str, curr_core_1, curr_matrix_core_1,
+                    curr_matrix_core_2, curr_core_2)
+    for core_idx in range(1, ndims):
+      curr_core_1 = x.tt_cores[core_idx]
+      curr_core_2 = y.tt_cores[core_idx]
+      curr_matrix_core_1 = A.tt_cores[core_idx]
+      curr_matrix_core_2 = B.tt_cores[core_idx]
+      einsum_str = '{2}egia,{0}elnf,glph,ipoj,{1}aomb->{2}fhjb'.format(x_bs_str,
+                                                                       y_bs_str,
+                                                                       out_bs_str)
+      res = tf.einsum(einsum_str, res, curr_core_1,
+                      curr_matrix_core_1, curr_matrix_core_2,
+                      curr_core_2)
+
+    # Squeeze to make the result a number instead of 1 x 1 for the non-batch
+    # case and to make the result a tensor of size
+    #   batch_size
+    # instead of
+    #   batch_size x 1 x 1
+    # in the batch case.
+    return tf.squeeze(res)
+
+
 def cast(tt, dtype, name='t3f_cast'):
   """Casts a tt-tensor to a new type.
 
diff --git a/t3f/ops_test.py b/t3f/ops_test.py
index 6f5a458..1d65082 100644
--- a/t3f/ops_test.py
+++ b/t3f/ops_test.py
@@ -401,6 +401,30 @@ def testBilinearFormBatch(self):
     self.assertAllClose(res_actual_val, np.squeeze(res_desired),
                         atol=1e-5, rtol=1e-5)
 
+  def testBilinearFormTwoMat(self):
+    # Test bilinear_form_two_mat.
+    shape_list = (((2, 2), (3, 4)),
+                  ((2, 3, 4), (2, 2, 2)))
+    rank_list = (1, 2)
+    with self.test_session() as sess:
+      for tensor_shape in shape_list:
+        for rank in rank_list:
+          A = initializers.random_matrix(tensor_shape, tt_rank=rank,
+                                         dtype=self.dtype)
+          B = initializers.random_matrix(tensor_shape, tt_rank=rank,
+                                         dtype=self.dtype)
+          B = ops.transpose(B)
+          x = initializers.random_matrix((tensor_shape[0], None), tt_rank=rank,
+                                         dtype=self.dtype)
+          y = initializers.random_matrix((tensor_shape[0], None), tt_rank=rank,
+                                         dtype=self.dtype)
+          res_actual = ops.bilinear_form_two_mat(x, A, B, y)
+          vars = [res_actual, ops.full(x), ops.full(A), ops.full(B), ops.full(y)]
+          res_actual_val, x_val, A_val, B_val, y_val = sess.run(vars)
+          res_desired = x_val.T.dot(A_val).dot(B_val).dot(y_val)
+          self.assertAllClose(res_actual_val, np.squeeze(res_desired),
+                              atol=1e-5, rtol=1e-5)
+
   def testCastFloat(self):
     # Test cast function for float tt-matrices and vectors.
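Note (not part of the patch): a minimal usage sketch of the new op. It assumes the TF1 graph-mode API that the patch itself targets, imports `t3f.ops` and `t3f.initializers` directly as the test does, and uses illustrative shapes (N = 2 * 3, M = 4 * 5, K = 2 * 3); all variable names are hypothetical.

import numpy as np
import tensorflow as tf
from t3f import initializers, ops

# A is N x M, B is M x K; x and y are N x 1 and K x 1 columns stored as
# TT-matrices, as required by bilinear_form_two_mat.
A = initializers.random_matrix(((2, 3), (4, 5)), tt_rank=2)
B = initializers.random_matrix(((4, 5), (2, 3)), tt_rank=2)
x = initializers.random_matrix(((2, 3), None), tt_rank=2)
y = initializers.random_matrix(((2, 3), None), tt_rank=2)

# A scalar tf.Tensor holding x^T A B y; if x or y were a TensorTrainBatch,
# the result would instead have shape (batch_size,).
form = ops.bilinear_form_two_mat(x, A, B, y)

# Dense reference check, mirroring the new unit test.
reference = tf.matmul(tf.matmul(ops.full(x), ops.full(A), transpose_a=True),
                      tf.matmul(ops.full(B), ops.full(y)))

with tf.Session() as sess:
  form_val, ref_val = sess.run([form, reference])
  print(form_val, np.squeeze(ref_val))

The dense check converts each TT argument to a full matrix with `ops.full` and multiplies them directly, which is feasible only for the small illustrative shapes above; the point of the TT op is to avoid that materialization.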