@router.post(
    "/metadata_submission/suggest",
    tags=["metadata_submission"],
    responses=login_required_responses,
)
async def suggest_metadata(
    body: List[schemas_submission.MetadataSuggestionRequest],
    suggester: SampleMetadataSuggester = Depends(SampleMetadataSuggester),
    types: Union[List[schemas_submission.MetadataSuggestionType], None] = Query(None),
    user: models.User = Depends(get_current_user),
) -> List[schemas_submission.MetadataSuggestion]:
    """Suggest metadata values for a batch of sample rows.

    Args:
        body: The rows to inspect; each carries its row number and its slot -> value data.
        suggester: Produces the slot -> suggested value mapping for one row.
        types: Optional filter on the kinds of suggestion to produce. It is passed
            through to the suggester unchanged, so ``None`` is interpreted there.
        user: The authenticated user. Unused in the body; declaring the dependency
            is what makes the endpoint require a login.

    Returns:
        One ``MetadataSuggestion`` per suggested slot, in request order. A suggestion
        is labelled ``REPLACE`` when the row already holds a non-empty value for the
        slot and ``ADD`` otherwise.
    """
    # Imported locally so this handler does not depend on a new module-level import.
    from starlette.concurrency import run_in_threadpool

    response: List[schemas_submission.MetadataSuggestion] = []
    for item in body:
        # Suggesters may perform blocking lookups (the elevation suggester calls
        # GeoEngine, which makes network requests), so run them off the event loop.
        suggestions = await run_in_threadpool(
            suggester.get_suggestions, item.data, types=types
        )
        for slot, value in suggestions.items():
            # Use the same "has data" rule as the suggester (present AND non-empty).
            # Testing mere key presence would label a suggestion for an empty slot
            # as REPLACE, even when the caller asked only for ADD suggestions.
            if item.data.get(slot):
                suggestion_type = schemas_submission.MetadataSuggestionType.REPLACE
            else:
                suggestion_type = schemas_submission.MetadataSuggestionType.ADD
            response.append(
                schemas_submission.MetadataSuggestion(
                    type=suggestion_type,
                    row=item.row,
                    slot=slot,
                    value=value,
                )
            )
    return response
class SampleMetadataSuggester:
    """Suggest sample metadata values based on partial sample metadata.

    Each supported slot is mapped to one or more suggester methods that derive a
    value for it from the other slots of the same sample.
    """

    def __init__(self):
        # Created lazily by the ``geo_engine`` property on first use.
        self._geo_engine: Optional[GeoEngine] = None

    @property
    def geo_engine(self) -> "GeoEngine":
        """A GeoEngine instance for looking up geospatial data, created on first access."""
        if self._geo_engine is None:
            self._geo_engine = GeoEngine()
        return self._geo_engine

    def suggest_elevation_from_lat_lon(self, sample: Dict[str, str]) -> Optional[float]:
        """Suggest an elevation for a sample based on its ``lat_lon`` slot.

        Args:
            sample: Sample metadata keyed by slot name.

        Returns:
            The elevation reported by the GeoEngine, or ``None`` when ``lat_lon`` is
            missing, is not a string, does not consist of exactly two numeric
            fields, or is rejected by the GeoEngine with a ``ValueError``.
        """
        lat_lon = sample.get("lat_lon")
        if not isinstance(lat_lon, str):
            return None

        # Strip first: otherwise surrounding whitespace yields an empty leading or
        # trailing field and a perfectly valid pair is rejected. Any run of commas
        # and/or whitespace (including tabs) separates the two fields.
        fields = re.split(r"[,\s]+", lat_lon.strip())
        if len(fields) != 2:
            return None

        try:
            lat, lon = map(float, fields)
            return self.geo_engine.get_elevation((lat, lon))
        except ValueError:
            # Raised either because a field is not parseable as a float or because
            # the GeoEngine determined the values are invalid. In both cases the
            # right outcome is simply not to suggest an elevation.
            return None

    def get_suggestions(
        self, sample: Dict[str, str], *, types: "Optional[List[MetadataSuggestionType]]" = None
    ) -> Dict[str, str]:
        """Suggest metadata values for a sample.

        Args:
            sample: Sample metadata keyed by slot name.
            types: Which kinds of suggestion to produce. ``ADD`` covers slots that
                are absent or empty in ``sample``; ``REPLACE`` covers slots that
                already hold a non-empty value. ``None`` means all types.

        Returns:
            A dictionary whose keys are sample metadata slots and whose values are
            the suggested values converted to strings. Slots with no suggestion
            are omitted.
        """
        # Not explicitly supplying types implies using all types.
        if types is None:
            types = list(MetadataSuggestionType)

        do_add = MetadataSuggestionType.ADD in types
        do_replace = MetadataSuggestionType.REPLACE in types

        # Map from sample metadata slot to a list of functions that can suggest
        # values for that slot.
        suggesters: Dict[str, List[Callable[[Dict[str, str]], Optional[Any]]]] = {
            "elev": [self.suggest_elevation_from_lat_lon],
        }

        suggestions: Dict[str, str] = {}
        for target_slot, suggester_list in suggesters.items():
            # A slot counts as populated only when it is present AND non-empty.
            has_data = bool(sample.get(target_slot))
            wanted = (do_add and not has_data) or (do_replace and has_data)
            if not wanted:
                continue
            # Every suggester for the slot is consulted; when several return a
            # value, the last one in the list wins.
            for suggester_fn in suggester_list:
                suggestion = suggester_fn(sample)
                if suggestion is not None:
                    suggestions[target_slot] = str(suggestion)

        return suggestions
# ---- tests/conftest.py ----


@pytest.fixture(autouse=True)
def patch_geo_engine(monkeypatch):
    """Swap the GeoEngine methods that make external network requests for local stand-ins."""

    def fake_get_elevation(self, lat_lon):
        # Range-check the coordinates, then report a fixed elevation.
        lat, lon = lat_lon
        if not -90 <= lat <= 90:
            raise ValueError(f"Invalid Latitude: {lat}")
        if not -180 <= lon <= 180:
            raise ValueError(f"Invalid Longitude: {lon}")
        return 16.0

    def fail_if_called(self, *args, **kwargs):
        raise NotImplementedError()

    monkeypatch.setattr(GeoEngine, "get_elevation", fake_get_elevation)
    # The remaining lookups have no stand-in; any test that reaches them fails loudly.
    for method_name in ("get_fao_soil_type", "get_landuse", "get_landuse_dates"):
        monkeypatch.setattr(GeoEngine, method_name, fail_if_called)


# ---- tests/test_metadata.py ----


def test_sample_metadata_suggester_elevation():
    """Check elevation suggestions for valid, missing and invalid lat_lon values."""
    suggester = SampleMetadataSuggester()

    cases = [
        # Valid lat_lon
        ({"lat_lon": "37.875766 -122.248580"}, 16.0),
        # Be tolerant of a comma separator
        ({"lat_lon": "37.875766, -122.248580"}, 16.0),
        # No suggestion when lat_lon is missing
        ({}, None),
        # No suggestion when lat_lon is invalid
        ({"lat_lon": "91.0 -122.248580"}, None),
        ({"lat_lon": "no good"}, None),
        ({"lat_lon": "0 0 0"}, None),
    ]

    for sample, expected in cases:
        elevation = suggester.suggest_elevation_from_lat_lon(sample)
        if expected is None:
            assert elevation is None
        else:
            assert elevation == expected


# ---- tests/test_submission.py ----


@pytest.fixture
def suggest_payload():
    """Request body for the suggest endpoint, keyed here by row number."""
    data_by_row = {
        1: {"foo": "bar", "lat_lon": "44.058648, -123.095277"},
        3: {"elev": 0, "lat_lon": "44.046389 -123.051910"},
        4: {"foo": "bar"},
        5: {"lat_lon": "garbage foo bar"},
    }
    return [{"row": row, "data": data} for row, data in data_by_row.items()]


def _post_suggest(client, payload, query=""):
    """POST ``payload`` to the suggest endpoint, appending ``query`` to the URL."""
    return client.request(
        method="POST",
        url="/api/metadata_submission/suggest" + query,
        json=payload,
    )


def test_metadata_suggest(client: TestClient, suggest_payload, logged_in_user):
    """Without a ``types`` filter both add and replace suggestions are returned."""
    expected = [
        {"type": "add", "row": 1, "slot": "elev", "value": "16.0"},
        {"type": "replace", "row": 3, "slot": "elev", "value": "16.0"},
    ]
    response = _post_suggest(client, suggest_payload)
    assert response.status_code == 200
    assert response.json() == expected


def test_metadata_suggest_single_type(client: TestClient, suggest_payload, logged_in_user):
    """``types=add`` returns only the add suggestion."""
    expected = [
        {"type": "add", "row": 1, "slot": "elev", "value": "16.0"},
    ]
    response = _post_suggest(client, suggest_payload, "?types=add")
    assert response.status_code == 200
    assert response.json() == expected


def test_metadata_suggest_multiple_types(client: TestClient, suggest_payload, logged_in_user):
    """Repeating ``types`` selects several suggestion kinds at once."""
    expected = [
        {"type": "add", "row": 1, "slot": "elev", "value": "16.0"},
        {"type": "replace", "row": 3, "slot": "elev", "value": "16.0"},
    ]
    response = _post_suggest(client, suggest_payload, "?types=add&types=replace")
    assert response.status_code == 200
    assert response.json() == expected


def test_metadata_suggest_invalid_type(client: TestClient, suggest_payload, logged_in_user):
    """An unknown ``types`` value gets a 422 response."""
    response = _post_suggest(client, suggest_payload, "?types=whatever")
    assert response.status_code == 422