Skip to content

Commit

Permalink
Merge pull request #135 from matchms/generator_tests
Browse files Browse the repository at this point in the history
use dummy data for generator tests
  • Loading branch information
florian-huber authored Apr 25, 2023
2 parents 27aed60 + 80b050c commit aed2866
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 6 deletions.
1 change: 1 addition & 0 deletions ms2deepscore/data_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ def _get_spectrum_with_inchikey(self, inchikey: str) -> BinnedSpectrumType:
inchikey) can have multiple measured spectrums in a binned spectrum dataset.
"""
matching_spectrum_id = np.where(self.spectrum_inchikeys == inchikey)[0]
assert len(matching_spectrum_id) > 0, "No matching inchikey found (note: expected first 14 characters)"
return self.binned_spectrums[np.random.choice(matching_spectrum_id)]

def __data_generation(self, spectrum_pairs: Iterator[SpectrumPair]):
Expand Down
135 changes: 129 additions & 6 deletions tests/test_data_generators.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,55 @@
import numpy as np
import pandas as pd
import pytest
import string
from matchms import Spectrum

from ms2deepscore import SpectrumBinner
from ms2deepscore.data_generators import DataGeneratorAllInchikeys
from ms2deepscore.data_generators import DataGeneratorAllSpectrums
from ms2deepscore.data_generators import (DataGeneratorAllInchikeys,
DataGeneratorAllSpectrums,
_exclude_nans_from_labels,
_validate_labels)
from tests.test_user_worfklow import load_processed_spectrums, get_reference_scores


def create_dummy_data():
    """Generate fake binned spectra and fake tanimoto-like scores for generator tests.

    Returns
    -------
    Tuple ``(binned_spectrums, tanimoto_fake)``: the binned dummy spectra
    (two per fake inchikey) and a 10x10 score DataFrame labelled by the
    first 14 characters of the fake inchikeys.
    """
    base_mz, base_intensity = 100.0, 0.1
    letters = list(string.ascii_uppercase[:10])

    # Fake pairwise similarities: letters closer in the alphabet score higher.
    pair_scores = {
        (letter_a, letter_b): (len(letters) - abs(i - j)) / len(letters)
        for i, letter_a in enumerate(letters)
        for j, letter_b in enumerate(letters)
    }
    tanimoto_fake = pd.DataFrame(pair_scores.values(),
                                 index=pair_scores.keys()).unstack()

    # Two fake spectra per fake inchikey (same peak, different intensity).
    fake_inchikeys = []
    spectrums = []
    for i, letter in enumerate(letters):
        dummy_inchikey = f"{14 * letter}-{10 * letter}-N"
        fake_inchikeys.append(dummy_inchikey)
        peak_position = np.array([base_mz + (i + 1) * 25.0])
        spectrums.append(Spectrum(mz=peak_position,
                                  intensities=np.array([base_intensity]),
                                  metadata={"inchikey": dummy_inchikey,
                                            "compound_name": letter}))
        spectrums.append(Spectrum(mz=peak_position,
                                  intensities=np.array([2 * base_intensity]),
                                  metadata={"inchikey": dummy_inchikey,
                                            "compound_name": f"{letter}-2"}))

    # Label rows and columns with the 14-character inchikey prefixes.
    short_inchikeys = [x[:14] for x in fake_inchikeys]
    tanimoto_fake.columns = short_inchikeys
    tanimoto_fake.index = short_inchikeys

    ms2ds_binner = SpectrumBinner(100, mz_min=10.0, mz_max=1000.0, peak_scaling=1)
    binned_spectrums = ms2ds_binner.fit_transform(spectrums)
    return binned_spectrums, tanimoto_fake


def create_test_data():
spectrums = load_processed_spectrums()
tanimoto_scores_df = get_reference_scores()
Expand All @@ -27,7 +70,48 @@ def collect_results(generator, batch_size, dimension):


def test_DataGeneratorAllInchikeys():
    """Test DataGeneratorAllInchikeys using generated dummy data.

    Checks the produced batch shapes, that every inchikey is picked exactly
    once per epoch, and that over many cycles the labels are roughly evenly
    distributed across the score bins.
    """
    binned_spectrums, tanimoto_scores_df = create_dummy_data()

    # Define other parameters
    batch_size = 10
    dimension = tanimoto_scores_df.shape[0]

    selected_inchikeys = tanimoto_scores_df.index
    # Create generator; all augmentation is switched off so the binned
    # peaks pass through unchanged.
    test_generator = DataGeneratorAllInchikeys(binned_spectrums=binned_spectrums,
                                               selected_inchikeys=selected_inchikeys,
                                               reference_scores_df=tanimoto_scores_df,
                                               dim=dimension, batch_size=batch_size,
                                               augment_removal_max=0.0,
                                               augment_removal_intensity=0.0,
                                               augment_intensity=0.0,
                                               augment_noise_max=0)

    A, B = test_generator.__getitem__(0)
    assert binned_spectrums[0].binned_peaks == {0: 0.1}, "Something went wrong with the binning"
    assert A[0].shape == A[1].shape == (batch_size, dimension), "Expected different data shape"
    # set(list(range(10))) simplified to set(range(10)) -- same set.
    assert set(test_generator.indexes) == set(range(10)), "Something wrong with generator indices"

    # Test if every inchikey was picked once (and only once):
    assert (A[0] > 0).sum() == 10
    assert np.all((A[0] > 0).sum(axis=1) == (A[0] > 0).sum(axis=0))

    # Test many cycles --> scores properly distributed into bins?
    counts = []
    repetitions = 100
    total = batch_size * repetitions
    for _ in range(repetitions):
        # enumerate() was unused here; iterate the generator directly.
        for batch in test_generator:
            counts.extend(list(batch[1]))
    assert (np.array(counts) > 0.5).sum() > 0.4 * total
    assert (np.array(counts) <= 0.5).sum() > 0.4 * total


def test_DataGeneratorAllInchikeys_real_data():
"""Basic first test for DataGeneratorAllInchikeys using actual data.
"""
# Get test data
binned_spectrums, tanimoto_scores_df = create_test_data()

Expand All @@ -46,7 +130,7 @@ def test_DataGeneratorAllInchikeys():
augment_intensity=0.0)

A, B = test_generator.__getitem__(0)
assert A[0].shape == A[1].shape == (10, 88), "Expected different data shape"
assert A[0].shape == A[1].shape == (batch_size, dimension), "Expected different data shape"
assert B.shape[0] == 10, "Expected different label shape."
assert test_generator.settings["num_turns"] == 1, "Expected different default."
assert test_generator.settings["augment_intensity"] == 0.0, "Expected changed value."
Expand Down Expand Up @@ -131,7 +215,7 @@ def test_DataGeneratorAllSpectrums_asymmetric_label_input():
"Expected different ValueError"


def test_DataGeneratorAllSpectrums_fixed_set_random_seed():
def test_DataGeneratorAllSpectrums_fixed_set():
"""
Test whether use_fixed_set=True toggles generating the same dataset on each epoch.
"""
Expand Down Expand Up @@ -224,4 +308,43 @@ def test_DataGeneratorAllSpectrums_additional_inputs():
batch_X, batch_y = data_generator.__getitem__(0)

assert len(batch_X) != len(batch_y), "Batchsizes from X and y are not the same."
assert len(batch_X[0]) != 3, "There are not as many inputs as specified."
assert len(batch_X[0]) != 3, "There are not as many inputs as specified."


# Test specific class methods
# ---------------------------
def test_validate_labels():
    """_validate_labels must reject score frames whose index and columns differ."""
    # Case 1: mismatching index vs. column names -> ValueError expected.
    mismatched_scores = pd.DataFrame({"A1": [0.5, 0.6], "A2": [0.7, 0.8]},
                                     index=["B1", "B2"])
    with pytest.raises(ValueError):
        _validate_labels(mismatched_scores)

    # Case 2: identical index and column names -> must pass silently.
    matched_scores = pd.DataFrame({"A1": [0.5, 0.6], "A2": [0.7, 0.8]},
                                  index=["A1", "A2"])
    _validate_labels(matched_scores)


def test_exclude_nans_from_labels():
    """_exclude_nans_from_labels must drop all rows/columns touched by NaNs."""
    # 4x4 frame with NaNs placed so that labels "C" and "D" are tainted.
    reference_scores_df = pd.DataFrame(
        {
            "A": [1, 2, np.nan, 4],
            "B": [2, 3, 4, 5],
            "C": [3, 4, 5, np.nan],
            "D": [4, 5, 6, 7],
        },
        index=["A", "B", "C", "D"],
    )

    clean_df = _exclude_nans_from_labels(reference_scores_df)

    # Only the NaN-free 2x2 upper-left block should remain.
    expected_clean_df = pd.DataFrame({"A": [1, 2], "B": [2, 3]},
                                     index=["A", "B"])
    assert np.allclose(clean_df.values, expected_clean_df.values)
    assert np.all(clean_df.index == clean_df.columns)
    assert np.all(clean_df.index == ["A", "B"])

0 comments on commit aed2866

Please sign in to comment.