Develop pipeline config
mattsolo1 committed Sep 26, 2023
1 parent d5c6e07 commit 620d360
Showing 9 changed files with 219 additions and 45 deletions.
65 changes: 49 additions & 16 deletions data-pipeline/src/data_pipeline/config.py
@@ -1,5 +1,6 @@
import os
import attr
from enum import Enum
from pathlib import Path

DATA_ENV = os.getenv("DATA_ENV", "local")
@@ -31,25 +32,57 @@ def make_local_folder(self):
Path(self.root).mkdir(parents=True, exist_ok=True)


# @attr.define
# class GnomadV4
# gnomad_v4_exome_variants_sites_ht_path: str = "external_datasets/mock_v4_release.ht"
class ComputeEnvironment(Enum):
local = "local"
cicd = "cicd"
dataproc = "dataproc"


@attr.define
class PipelineConfig:
data_paths: DataPaths
compute_env: str = "local"
data_env: str = "tiny"
class DataEnvironment(Enum):
tiny = "tiny"
full = "full"


config = PipelineConfig(
data_env="local",
data_paths=DataPaths.create(os.path.join("data")),
)
def is_valid_fn(cls):
def is_valid(instance, attribute, value):
if not isinstance(value, cls):
raise ValueError(f"Expected {cls} enum, got {type(value)}")

return is_valid

if DATA_ENV == "dataproc":
config = PipelineConfig(
data_paths=DataPaths.create(os.path.join("gs://gnomad-matt-data-pipeline")),
)

@attr.define
class PipelineConfig:
name: str
input_paths: DataPaths
output_paths: DataPaths
data_env: DataEnvironment = attr.field(validator=is_valid_fn(DataEnvironment))
compute_env: ComputeEnvironment = attr.field(validator=is_valid_fn(ComputeEnvironment))

@classmethod
def create(
cls,
name: str,
input_root: str,
output_root: str,
data_env=DataEnvironment.tiny,
compute_env=ComputeEnvironment.local,
):
input_paths = DataPaths.create(input_root)
output_paths = DataPaths.create(output_root)
return cls(name, input_paths, output_paths, data_env, compute_env)


# config = PipelineConfig.create(
# name=
# input_root="data_in",
# output_root="data_out",
# compute_env=ComputeEnvironment.local,
# data_env=DataEnvironment.tiny,
# )


# if DATA_ENV == "dataproc":
# config = PipelineConfig(
# output_path=DataPaths.create(os.path.join("gs://gnomad-matt-data-pipeline")),
# )
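
A minimal usage sketch of the new PipelineConfig API above (illustrative, not part of the diff); it assumes attrs runs the enum validators at construction time, which @attr.define does by default:

# Sketch: exercising PipelineConfig.create with explicit enum members.
from data_pipeline.config import ComputeEnvironment, DataEnvironment, PipelineConfig

config = PipelineConfig.create(
    name="example",
    input_root="data_in",
    output_root="data_out",
    data_env=DataEnvironment.full,
    compute_env=ComputeEnvironment.dataproc,
)

# Anything that is not an enum member trips is_valid_fn's validator.
try:
    PipelineConfig.create(
        name="bad",
        input_root="data_in",
        output_root="data_out",
        data_env="full",  # plain string, not DataEnvironment.full
    )
except ValueError as err:
    print(err)  # Expected <enum 'DataEnvironment'>, got <class 'str'>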
8 changes: 3 additions & 5 deletions data-pipeline/src/data_pipeline/helpers/logging.py
@@ -6,16 +6,14 @@


def create_logger():
config = {
"handlers": [
logger.configure(
handlers=[
{
"sink": sys.stdout,
"format": "<level>{time:YYYY-MM-DDTHH:mm}</level> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>", # noqa
},
]
}

logger.configure(**config)
)
logger.level("CONFIG", no=38, icon="🐍")

# clear log file after each run
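
For reference, a self-contained sketch (not in the diff) of using the custom CONFIG level registered above, via loguru's standard API:

from loguru import logger

logger.level("CONFIG", no=38, icon="🐍")  # as registered in create_logger above
logger.log("CONFIG", "pipeline config loaded")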
2 changes: 1 addition & 1 deletion data-pipeline/src/data_pipeline/helpers/write_schemas.py
@@ -27,7 +27,7 @@ def describe_handler(text):


for pipeline in pipelines:
pipeline_name = pipeline.name
pipeline_name = pipeline.config.name
task_names = pipeline.get_all_task_names()
out_dir = os.path.join(SCHEMA_PATH, pipeline_name)

42 changes: 24 additions & 18 deletions data-pipeline/src/data_pipeline/pipeline.py
@@ -6,13 +6,13 @@
import subprocess
import tempfile
import time
from typing import List, Optional, Union
from typing import Callable, List, Optional, Union
import attr
from collections import OrderedDict

import hail as hl

from data_pipeline.config import config
from data_pipeline.config import PipelineConfig

logger = logging.getLogger("gnomad_data_pipeline")
logger.setLevel(logging.INFO)
@@ -57,23 +57,24 @@ def modified_time(path):
return file_system.modified_time(check_path)


_pipeline_config = {}
# _pipeline_config = {}

_pipeline_config["output_root"] = config.data_paths.root
# _pipeline_config["output_root"] = config.output_paths.root


@attr.define
class DownloadTask:
_config: PipelineConfig
_name: str
_url: str
_output_path: str

@classmethod
def create(cls, name, url, output_path):
return cls(name, url, output_path)
def create(cls, config: PipelineConfig, name: str, url: str, output_path: str):
return cls(config, name, url, output_path)

def get_output_path(self):
return _pipeline_config["output_root"] + self._output_path
return self._config.output_paths.root + self._output_path

def should_run(self):
output_path = self.get_output_path()
@@ -82,6 +83,9 @@ def should_run(self):

return (False, None)

def get_inputs(self):
raise NotImplementedError("Method not valid for DownloadTask")

def run(self, force=False):
output_path = self.get_output_path()
should_run, reason = (True, "Forced") if force else self.should_run()
@@ -106,17 +110,19 @@

@attr.define
class Task:
_config: PipelineConfig
_name: str
_task_function: str
_task_function: Callable
_output_path: str
_inputs: dict
_params: dict

@classmethod
def create(
cls,
config: PipelineConfig,
name: str,
task_function: str,
task_function: Callable,
output_path: str,
inputs: Optional[dict] = None,
params: Optional[dict] = None,
@@ -125,10 +131,10 @@ def create(
inputs = {}
if params is None:
params = {}
return cls(name, task_function, output_path, inputs, params)
return cls(config, name, task_function, output_path, inputs, params)

def get_output_path(self):
return _pipeline_config["output_root"] + self._output_path
return self._config.output_paths.root + self._output_path

def get_inputs(self):
paths = {}
@@ -138,7 +144,7 @@ def get_inputs(self):
paths.update({k: v.get_output_path()})
else:
logger.info(v)
paths.update({k: os.path.join(config.data_paths.root, v)})
paths.update({k: os.path.join(self._config.output_paths.root, v)})

return paths

@@ -173,14 +179,14 @@ def run(self, force=False):

@attr.define
class Pipeline:
name: str
config: PipelineConfig
_tasks: OrderedDict = OrderedDict()
_outputs: dict = {}

def add_task(
self,
name: str,
task_function: str,
task_function: Callable,
output_path: str,
inputs: Optional[dict] = None,
params: Optional[dict] = None,
@@ -189,12 +195,12 @@ def add_task(
inputs = {}
if params is None:
params = {}
task = Task.create(name, task_function, output_path, inputs, params)
task = Task.create(self.config, name, task_function, output_path, inputs, params)
self._tasks[name] = task
return task

def add_download_task(self, name, *args, **kwargs) -> DownloadTask:
task = DownloadTask.create(name, *args, **kwargs)
task = DownloadTask.create(self.config, name, *args, **kwargs)
self._tasks[name] = task
return task

@@ -232,8 +238,8 @@ def run_pipeline(pipeline):
group.add_argument("--force-all", action="store_true")
args = parser.parse_args()

if args.output_root:
_pipeline_config["output_root"] = args.output_root.rstrip("/")
# if args.output_root:
# _pipeline_config["output_root"] = args.output_root.rstrip("/")

pipeline_args = {}
if args.force_all:
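
A sketch of how tasks now resolve paths through the injected config instead of the removed module-level _pipeline_config dict (names are illustrative; assumes DataPaths.create stores the root string as given):

from data_pipeline.config import PipelineConfig
from data_pipeline.pipeline import Pipeline

def noop(input_path=None):
    pass

pipeline = Pipeline(
    config=PipelineConfig.create(name="demo", input_root="data_in", output_root="data_out")
)
task = pipeline.add_task(
    name="demo_task",
    task_function=noop,
    output_path="/demo/demo.ht",
    inputs={"input_path": "external/demo_in.ht"},
)
# Output roots now come from the pipeline's config...
assert task.get_output_path() == "data_out/demo/demo.ht"
# ...and plain-string inputs are joined against the same root (see get_inputs).
assert task.get_inputs() == {"input_path": "data_out/external/demo_in.ht"}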
data-pipeline/src/data_pipeline/pipelines/gnomad_v4_coverage.py
@@ -1,9 +1,12 @@
from data_pipeline.pipeline import Pipeline, run_pipeline
from data_pipeline.config import PipelineConfig

from data_pipeline.data_types.coverage import prepare_coverage


pipeline = Pipeline(name="gnomad_v4_coverage")
pipeline = Pipeline(
config=PipelineConfig.create(name="gnomad_v4_variants", input_root="data_in", output_root="data_out")
)

pipeline.add_task(
name="prepare_gnomad_v4_exome_coverage",
15 changes: 11 additions & 4 deletions data-pipeline/src/data_pipeline/pipelines/gnomad_v4_variants.py
@@ -1,15 +1,23 @@
from data_pipeline.config import PipelineConfig
from data_pipeline.pipeline import Pipeline, run_pipeline

from data_pipeline.datasets.gnomad_v4.gnomad_v4_variants import prepare_gnomad_v4_variants
from data_pipeline.datasets.gnomad_v4.gnomad_v4_variants import (
prepare_gnomad_v4_variants,
)

from data_pipeline.data_types.variant import annotate_variants, annotate_transcript_consequences
from data_pipeline.data_types.variant import (
annotate_variants,
annotate_transcript_consequences,
)

# from data_pipeline.pipelines.gnomad_v4_coverage import pipeline as coverage_pipeline

# from data_pipeline.pipelines.genes import pipeline as genes_pipeline


pipeline = Pipeline(name="gnomad_v4_variants")
pipeline = Pipeline(
config=PipelineConfig.create(name="gnomad_v4_variants", input_root="data_in", output_root="data_out")
)

pipeline.add_task(
name="prepare_gnomad_v4_exome_variants",
@@ -18,7 +26,6 @@
inputs={
"input_path": "external_datasets/mock_v4_release.ht",
},
# params={"sequencing_type": "exome"},
)

# pipeline.add_task(
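
A sketch of driving one of these pipeline modules (run_pipeline parses flags such as --force-all from sys.argv per the argparse wiring above; the __main__ guard is an assumption, not part of the diff):

from data_pipeline.pipeline import run_pipeline
from data_pipeline.pipelines.gnomad_v4_variants import pipeline

if __name__ == "__main__":
    run_pipeline(pipeline)  # e.g. invoked with --force-all to rerun every task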
63 changes: 63 additions & 0 deletions data-pipeline/tests/pipeline/test_config.py
@@ -0,0 +1,63 @@
# from loguru import logger
import os
import pytest
import tempfile

from data_pipeline.config import ComputeEnvironment, DataEnvironment, DataPaths, PipelineConfig

# from data_pipeline.pipeline import Pipeline


@pytest.fixture
def input_tmp():
with tempfile.TemporaryDirectory() as temp_dir:
with open(os.path.join(temp_dir, "sample_tiny.txt"), "w") as f:
f.write("tiny dataset")
with open(os.path.join(temp_dir, "sample_full.txt"), "w") as f:
f.write("full dataset")
yield temp_dir


@pytest.fixture
def output_tmp():
with tempfile.TemporaryDirectory() as temp_dir:
yield temp_dir


@pytest.mark.only
def test_config_created(input_tmp, output_tmp):
config = PipelineConfig.create(name="test", input_root=input_tmp, output_root=output_tmp)
assert isinstance(config, PipelineConfig)
assert isinstance(config.input_paths, DataPaths)
assert isinstance(config.output_paths, DataPaths)
assert isinstance(config.compute_env, ComputeEnvironment)
assert isinstance(config.data_env, DataEnvironment)


@pytest.mark.only
def test_config_read_input_file(input_tmp, output_tmp):
config = PipelineConfig.create(
name="test",
input_root=input_tmp,
output_root=output_tmp,
)
sample = os.path.join(config.input_paths.root, "sample_tiny.txt")
with open(sample, "r") as f:
assert f.read() == "tiny dataset"


# @pytest.mark.only
# def test_pipeline_tasks(ht_1_fixture: TestHt, ht_2_fixture: TestHt):
# def task_1_fn():
# pass

# pipeline = Pipeline("p1")

# pipeline.add_task(
# name="task_1_join_hts",
# task_function=task_1_fn,
# output_path="/gnomad_v4/gnomad_v4_exome_variants_base.ht",
# inputs={
# "input_ht_1": ht_1_fixture.path,
# },
# )
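
One caveat on the tests above: pytest warns about the unregistered `only` marker (PytestUnknownMarkWarning) unless it is declared. A hypothetical conftest.py sketch, not part of this commit:

def pytest_configure(config):
    # Register the custom marker so `pytest -m only` selects these
    # focused tests without warnings.
    config.addinivalue_line("markers", "only: focused tests to run during development")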