Skip to content

Commit

Permalink
Move input validation out of tests, into pipeline code itself
Browse files Browse the repository at this point in the history
  • Loading branch information
mattsolo1 committed Oct 2, 2023
1 parent 5a3c628 commit d6221b0
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 23 deletions.
Original file line number Diff line number Diff line change
@@ -1,20 +1,17 @@
import pytest
from cattrs import structure, structure_attrs_fromdict
import hail as hl
import json

from data_pipeline.pipelines.gnomad_v4_variants import (
pipeline as gnomad_v4_variant_pipeline,
)
from loguru import logger

from data_pipeline.pipeline import Pipeline

from data_pipeline.datasets.gnomad_v4.types.initial_globals import Globals
from data_pipeline.datasets.gnomad_v4.types.initial_variant import InitialVariant
from data_pipeline.datasets.gnomad_v4.types.prepare_variants_step1 import Variant as Step1Variant
from data_pipeline.datasets.gnomad_v4.types.prepare_variants_step2 import Variant as Step2Variant
from data_pipeline.datasets.gnomad_v4.types.prepare_variants_step3 import Variant as Step3Variant

step1_task = gnomad_v4_variant_pipeline.get_task("prepare_gnomad_v4_exome_variants")


def ht_to_json(ht: hl.Table, field: str = "row"):
if field == "row":
Expand All @@ -32,45 +29,43 @@ def ht_to_json(ht: hl.Table, field: str = "row"):
return objs


@pytest.mark.mock_data
def test_globals_input_validation():
input_path = gnomad_v4_variant_pipeline.get_task("prepare_gnomad_v4_exome_variants").get_inputs()["input_path"]
def validate_globals_input(pipeline: Pipeline):
input_path = pipeline.get_task("prepare_gnomad_v4_exome_variants").get_inputs()["input_path"]
ht = hl.read_table(input_path)
result = ht_to_json(ht, "globals")[0]
# logger.info(result)
structure(result, Globals)
logger.info("Validated prepare_gnomad_v4_exome_variants input globals")


@pytest.mark.mock_data
def test_validate_variant_input():
input_path = gnomad_v4_variant_pipeline.get_task("prepare_gnomad_v4_exome_variants").get_inputs()["input_path"]
def validate_variant_input(pipeline: Pipeline):
input_path = pipeline.get_task("prepare_gnomad_v4_exome_variants").get_inputs()["input_path"]
ht = hl.read_table(input_path)
result = ht_to_json(ht)
[structure_attrs_fromdict(variant, InitialVariant) for variant in result]
logger.info("Validated prepare_gnomad_v4_exome_variants input variants")


@pytest.mark.mock_data
def test_validate_step1_output():
output_path = gnomad_v4_variant_pipeline.get_task("prepare_gnomad_v4_exome_variants").get_output_path()
def validate_step1_output(pipeline: Pipeline):
output_path = pipeline.get_task("prepare_gnomad_v4_exome_variants").get_output_path()
ht = hl.read_table(output_path)
# ht = ht.sample(0.1, seed=1234)
result = ht_to_json(ht)
[structure_attrs_fromdict(variant, Step1Variant) for variant in result]
logger.info("Validated prepare_gnomad_v4_exome_variants (step 1) output")


@pytest.mark.mock_data
def test_validate_step2_output():
output_path = gnomad_v4_variant_pipeline.get_task("annotate_gnomad_v4_exome_variants").get_output_path()
def validate_step2_output(pipeline: Pipeline):
output_path = pipeline.get_task("annotate_gnomad_v4_exome_variants").get_output_path()
ht = hl.read_table(output_path)
result = ht_to_json(ht)
[structure_attrs_fromdict(variant, Step2Variant) for variant in result]
logger.info("Validated annotate_gnomad_v4_exome_variants (step 2) output")


@pytest.mark.mock_data
def test_validate_step3_output():
output_path = gnomad_v4_variant_pipeline.get_task(
"annotate_gnomad_v4_exome_transcript_consequences"
).get_output_path()
def validate_step3_output(pipeline: Pipeline):
output_path = pipeline.get_task("annotate_gnomad_v4_exome_transcript_consequences").get_output_path()
ht = hl.read_table(output_path)
result = ht_to_json(ht)
[structure_attrs_fromdict(variant, Step3Variant) for variant in result]
logger.info("Validated annotate_gnomad_v4_exome_transcript_consequences (step 3) output")
15 changes: 15 additions & 0 deletions data-pipeline/src/data_pipeline/pipelines/gnomad_v4_variants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from loguru import logger


from data_pipeline.config import PipelineConfig, get_data_environment, DataEnvironment
Expand All @@ -12,6 +13,13 @@
annotate_variants,
annotate_transcript_consequences,
)
from data_pipeline.pipelines.gnomad_v4_validation import (
validate_globals_input,
validate_step1_output,
validate_step2_output,
validate_step3_output,
validate_variant_input,
)

DATA_ENV = os.getenv("DATA_ENV", "mock")

Expand Down Expand Up @@ -106,3 +114,10 @@

if __name__ == "__main__":
run_pipeline(pipeline)

logger.info("Validating pipeline IO formats")
validate_globals_input(pipeline)
validate_variant_input(pipeline)
validate_step1_output(pipeline)
validate_step2_output(pipeline)
validate_step3_output(pipeline)

0 comments on commit d6221b0

Please sign in to comment.