Skip to content

Commit

Permalink
adapt test to new function structure
Browse files Browse the repository at this point in the history
  • Loading branch information
florian-huber committed Apr 24, 2023
1 parent af3bb67 commit 80b050c
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions tests/test_data_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from ms2deepscore import SpectrumBinner
from ms2deepscore.data_generators import (DataGeneratorAllInchikeys,
DataGeneratorAllSpectrums,
DataGeneratorBase)
_exclude_nans_from_labels,
_validate_labels)
from tests.test_user_worfklow import load_processed_spectrums, get_reference_scores


Expand Down Expand Up @@ -316,11 +317,11 @@ def test_validate_labels():
# Test case 1: reference_scores_df with different index and column names
ref_scores = pd.DataFrame({'A1': [0.5, 0.6], 'A2': [0.7, 0.8]}, index=['B1', 'B2'])
with pytest.raises(ValueError):
DataGeneratorBase._validate_labels(ref_scores)
_validate_labels(ref_scores)

# Test case 2: reference_scores_df with identical index and column names
ref_scores = pd.DataFrame({'A1': [0.5, 0.6], 'A2': [0.7, 0.8]}, index=['A1', 'A2'])
DataGeneratorBase._validate_labels(ref_scores) # Should not raise ValueError
_validate_labels(ref_scores) # Should not raise ValueError


def test_exclude_nans_from_labels():
Expand All @@ -334,7 +335,7 @@ def test_exclude_nans_from_labels():
reference_scores_df = pd.DataFrame(data, index=["A", "B", "C", "D"])

# Call the _exclude_nans_from_labels method
clean_df = DataGeneratorBase._exclude_nans_from_labels(reference_scores_df)
clean_df = _exclude_nans_from_labels(reference_scores_df)

# Expected DataFrame after removing rows and columns with NaN values
expected_data = {
Expand Down

0 comments on commit 80b050c

Please sign in to comment.