@@ -364,4 +368,9 @@
}
{% endif %}
+
+
+
diff --git a/tests/test_active.py b/tests/test_active.py
index baf1fb25..5f7cbf95 100644
--- a/tests/test_active.py
+++ b/tests/test_active.py
@@ -77,52 +77,48 @@ def test_active_bad_keys(server, logs):
)
-@pytest.mark.parametrize("sampler", ["ARR", "CKL"])
+@pytest.mark.parametrize("sampler", ["ARR", "Random"])
def test_active_queries_generated(server, sampler, logs):
# R=1 chosen because that determines when active sampling starts; this
# test is designed to make sure no unexpected errors are thrown in
# active portion (not that it generates a good embedding)
+ # tests ARR to make sure active scores are generated;
+ # tests Random to make sure that's not a false positive and
+    # random queries are properly identified
+
n = 6
config = {
"targets": [_ for _ in range(n)],
"samplers": {sampler: {}},
- "sampling": {"common": {"d": 1, "R": 1}},
+ "sampling": {},
}
+ if sampler != "Random":
+ config["sampling"]["common"] = {"d": 1, "R": 1}
with logs:
server.authorize()
server.post("/init_exp", data={"exp": config})
- n_active_queries = 0
- for k in range(6 * n + 1):
+ active_queries_generated = False
+ for k in range(10 * n + 1):
q = server.get("/query").json()
+ query = "random" if q["score"] == -9999 else "active"
+ if query == "active":
+ active_queries_generated = True
+ break
+
+ sleep(200e-3)
ans = random.choice([q["left"], q["right"]])
ans = {"winner": ans, "puid": "foo", **q}
- print(q)
server.post("/answer", json=ans)
- if q["score"] != -9999:
- # scored queries have been posted to the database
- # now, only thing to test is popping off database
- n_active_queries += 1
- if n_active_queries == n:
- sleep(1)
- break
-
- sleep(100e-3)
if k % n == 0:
sleep(1)
- d = server.get("/responses").json()
-
- df = pd.DataFrame(d)
- random_queries = df["score"] == -9999
- active_queries = ~random_queries
- assert active_queries.sum()
- assert random_queries.sum()
-
- samplers = set(df.sampler.unique())
- assert samplers == {sampler}
+ if sampler == "Random":
+ assert not active_queries_generated
+ else:
+ assert active_queries_generated
def test_active_basics(server, logs):
@@ -137,7 +133,7 @@ def test_active_basics(server, logs):
with logs:
server.authorize()
server.post("/init_exp", data={"exp": exp.read_text()})
- for k in range(len(samplers) * 2):
+ for k in range(len(samplers) * 3):
print(k)
q = server.get("/query").json()
@@ -154,3 +150,4 @@ def test_active_basics(server, logs):
assert (df["score"] <= 1).all()
algs = df.sampler.unique()
assert set(algs) == {"TSTE", "ARR", "CKL", "tste2", "GNMDS"}
+ assert True # to see if a log error is caught in the traceback
diff --git a/tests/test_allowabe_targets.py b/tests/test_allowabe_targets.py
index 893fd3b1..209dbfff 100644
--- a/tests/test_allowabe_targets.py
+++ b/tests/test_allowabe_targets.py
@@ -14,7 +14,7 @@ def test_targets(sampler, server):
server.authorize()
server.post("/init_exp", data={"exp": config})
- for k in range(20):
+ for k in range(30):
q = server.get("/query").json()
ans = {"winner": random.choice([q["left"], q["right"]]), "puid": "foo", **q}
server.post("/answer", json=ans)
diff --git a/tests/test_offline.py b/tests/test_offline.py
index ffd15949..b8dc0cc0 100644
--- a/tests/test_offline.py
+++ b/tests/test_offline.py
@@ -1,4 +1,5 @@
from pathlib import Path
+import yaml
import numpy as np
import numpy.linalg as LA
@@ -8,10 +9,11 @@
from salmon.triplets.offline import OfflineEmbedding
from salmon.triplets.samplers import TSTE
+import salmon.triplets.offline
def test_salmon_import():
- """ This test makes sure that no errors are raised on import
+ """This test makes sure that no errors are raised on import
(non-existant directories, etc)"""
import salmon
@@ -49,7 +51,7 @@ def test_score_accurate():
# Make sure the score has the expected value (winner has minimum distance)
embed = alg.opt.embedding() * 1e3
y_hat2 = []
- for (head, left, right) in X:
+ for head, left, right in X:
ldist = LA.norm(embed[head] - embed[left])
rdist = LA.norm(embed[head] - embed[right])
@@ -108,6 +110,26 @@ def test_offline_init():
assert not np.allclose(est.embedding_, em), "Embedding didn't change"
+def test_offline_names_correct():
+ DIR = Path(__file__).absolute().parent
+ _f = DIR / "data" / "active.yaml"
+ config = yaml.load(_f.read_text(), Loader=yaml.SafeLoader)
+ n = len(config["targets"])
+ d = config["sampling"]["common"]["d"]
+
+ X = np.random.choice(n, size=(100, 3))
+ est = OfflineEmbedding(n=n, d=d)
+ est.partial_fit(X)
+
+ import salmon.triplets.offline as offline
+
+ em = offline.join(est.embedding_, config["targets"])
+ assert isinstance(em, pd.DataFrame)
+ assert len(em) == len(config["targets"])
+ assert set(em.columns) == {"x", "y", "target"}
+ assert (em["target"] == config["targets"]).all()
+
+
if __name__ == "__main__":
test_offline_init()
test_offline_embedding_random_state()