diff --git a/src/baskerville/metrics.py b/src/baskerville/metrics.py index a1cbfaf..58d69ce 100644 --- a/src/baskerville/metrics.py +++ b/src/baskerville/metrics.py @@ -128,7 +128,7 @@ def poisson_multinomial( rescale (bool): Rescale loss after re-weighting. """ seq_len = y_true.shape[1] - + if weight_range < 1: raise ValueError("Poisson Multinomial weight_range must be >=1") elif weight_range == 1: @@ -147,8 +147,8 @@ def poisson_multinomial( y_pred = tf.math.multiply(y_pred, position_weights) # sum across lengths - s_true = tf.math.reduce_sum(y_true, axis=-2) # B x T - s_pred = tf.math.reduce_sum(y_pred, axis=-2) # B x T + s_true = tf.math.reduce_sum(y_true, axis=-2) # B x T + s_pred = tf.math.reduce_sum(y_pred, axis=-2) # B x T # total count poisson loss, mean across targets poisson_term = poisson(s_true, s_pred) # B x T @@ -159,7 +159,7 @@ def poisson_multinomial( y_pred += epsilon # normalize to sum to one - p_pred = y_pred / tf.expand_dims(s_pred, axis=-2) # B x L x T + p_pred = y_pred / tf.expand_dims(s_pred, axis=-2) # B x L x T # multinomial loss pl_pred = tf.math.log(p_pred) # B x L x T diff --git a/tests/test_metrics.py b/tests/test_metrics.py index de235fb..509384f 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -9,27 +9,29 @@ # data dimensions N, L, T = 6, 8, 4 + @pytest.fixture def sample_data(): y_true = tf.random.uniform((N, L, T), minval=0, maxval=10, dtype=tf.float32) y_pred = y_true + tf.random.normal((N, L, T), mean=0, stddev=0.1) return y_true, y_pred + def test_PearsonR(sample_data): y_true, y_pred = sample_data pearsonr = PearsonR(num_targets=T, summarize=False) pearsonr.update_state(y_true, y_pred) tf_result = pearsonr.result().numpy() - + # Compute SciPy result scipy_result = np.zeros(T) y_true_np = y_true.numpy().reshape(-1, T) y_pred_np = y_pred.numpy().reshape(-1, T) for i in range(T): scipy_result[i], _ = stats.pearsonr(y_true_np[:, i], y_pred_np[:, i]) - + np.testing.assert_allclose(tf_result, scipy_result, rtol=1e-5, atol=1e-5) - + # Test summarized result pearsonr_summarized = PearsonR(num_targets=T, summarize=True) pearsonr_summarized.update_state(y_true, y_pred) @@ -37,27 +39,31 @@ def test_PearsonR(sample_data): assert tf_result_summarized.shape == () assert np.isclose(tf_result_summarized, np.mean(scipy_result), rtol=1e-5, atol=1e-5) + def test_R2(sample_data): y_true, y_pred = sample_data r2 = R2(num_targets=T, summarize=False) r2.update_state(y_true, y_pred) tf_result = r2.result().numpy() - + # Compute sklearn result sklearn_result = np.zeros(T) y_true_np = y_true.numpy().reshape(-1, T) y_pred_np = y_pred.numpy().reshape(-1, T) for i in range(T): sklearn_result[i] = r2_score(y_true_np[:, i], y_pred_np[:, i]) - + np.testing.assert_allclose(tf_result, sklearn_result, rtol=1e-5, atol=1e-5) - + # Test summarized result r2_summarized = R2(num_targets=T, summarize=True) r2_summarized.update_state(y_true, y_pred) tf_result_summarized = r2_summarized.result().numpy() assert tf_result_summarized.shape == () - assert np.isclose(tf_result_summarized, np.mean(sklearn_result), rtol=1e-5, atol=1e-5) + assert np.isclose( + tf_result_summarized, np.mean(sklearn_result), rtol=1e-5, atol=1e-5 + ) + def test_poisson_multinomial_shape(sample_data): y_true, y_pred = sample_data