double layered splitting strategy
MattWellie committed Oct 30, 2024
1 parent 5d8da43 commit a74c5b3
Showing 1 changed file with 90 additions and 68 deletions: src/talos/CreateTalosHTML.py
@@ -8,6 +8,8 @@
Separate documents will have the variant details per-family
The variant row will offer a hyperlink to the variant details
Additional separate pages will contain metadata/panel data
We're snapping it even further!
"""

import re
@@ -29,13 +31,16 @@

JINJA_TEMPLATE_DIR = Path(__file__).absolute().parent / 'templates'
MIN_REPORT_SIZE: int = 10
MAX_REPORT_SIZE: int = 200

# above this length we trim the actual bases to just an int
MAX_INDEL_LEN: int = 10

# regex pattern - number, number, not-number
KNOWN_YEAR_PREFIX = re.compile(r'\d{2}\D')
CDNA_SQUASH = re.compile(r'(?P<type>ins|del)(?P<bases>[ACGT]+)$')
MEAN_SLASH_SAMPLE = 'Mean/sample'
GNOMAD_SV_KEY = 'gnomad_v2.1_sv_svid'
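
As a sketch of how MAX_INDEL_LEN and CDNA_SQUASH are intended to combine (the helper below is hypothetical and not part of this commit), a cDNA string whose trailing ins/del run exceeds MAX_INDEL_LEN gets its bases collapsed to a count:

# hypothetical helper, not in this commit: collapse long ins/del base runs to an int
def squash_long_indel(cdna: str) -> str:
    if (match := CDNA_SQUASH.search(cdna)) and len(match.group('bases')) > MAX_INDEL_LEN:
        return cdna[: match.start()] + f"{match.group('type')}{len(match.group('bases'))}"
    return cdna

# e.g. squash_long_indel('c.100delACGTACGTACGTA') -> 'c.100del13' (13 bases > 10)
#      squash_long_indel('c.100delACGT') -> 'c.100delACGT' (4 bases, unchanged)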


def known_date_prefix_check(all_results: ResultData) -> list[str]:
@@ -61,21 +66,64 @@ def known_date_prefix_check(all_results: ResultData) -> list[str]:
return sorted(known_prefixes.keys())
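
For orientation, KNOWN_YEAR_PREFIX matches a two-digit year followed by a non-digit, so sample IDs group like this (illustrative IDs; the grouping body itself is collapsed in this hunk):

# illustration: which sample IDs the year-prefix regex recognises
for sample_id in ('24D0001', '24D0002', '23X9999', 'FAM001'):
    match = KNOWN_YEAR_PREFIX.match(sample_id)
    print(sample_id, '->', match.group()[:2] if match else 'no prefix')
# 24D0001 -> 24
# 24D0002 -> 24
# 23X9999 -> 23
# FAM001 -> no prefix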


def split_data_into_sub_reports(data_path: str, split_samples: int) -> list[tuple[ResultData, str, str]]:
def calculate_report_size(num_results: int, max_elements: int = MAX_REPORT_SIZE) -> int:
"""
Calculate the number of samples per report
Args:
num_results (): total number of samples to report
max_elements (): maximum number of samples per report
Returns:
a number of samples to include in each report
"""

# to begin, assume a single report is fine
report_count = 1

# small cohorts fit comfortably into a single report
if num_results < max_elements:
return num_results
while True:
report_count += 1
# break when we have a number that subdivides this group well enough
if (num_results // report_count) <= max_elements:
break

# calculate the number of samples per sub-report
samples_per_report = num_results // report_count

# adjust so that we don't make tiny little report HTMLs
if num_results % samples_per_report < MIN_REPORT_SIZE:
samples_per_report += MIN_REPORT_SIZE

return samples_per_report
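
A couple of worked examples of this sizing, using the constants above (MIN_REPORT_SIZE=10, MAX_REPORT_SIZE=200):

# 150 samples: under the cap, so a single report holds everything
calculate_report_size(150)  # -> 150
# 450 samples: 450 // 3 == 150 fits under 200; 450 % 150 == 0 < 10,
# so the size is bumped to 160, giving chunks of 160 + 160 + 130
# rather than leaving a tiny trailing report
calculate_report_size(450)  # -> 160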


def split_up(all_results: ResultData, max_elements: int = MAX_REPORT_SIZE) -> list[tuple[ResultData, str, str]]:
"""
Split the data into sub-reports
Return a list of the ResultData subsets, output base path, and a subset identifier
Alright, this is going to be a bit rogue - make this splitting reactive to the number
of elements in the split.
Args:
data_path ():
split_samples ():
all_results (ResultData): the whole cohort results
max_elements ():
Returns:
tuple: a list of tuples, each containing a ResultData object, the output base path, and a subset identifier
"""
all_results = read_json_from_path(data_path, return_model=ResultData)
return_results: list[tuple[ResultData, str, str]] = []

# only interested in presenting probands with results (for now)
# so we're stripping these back to just results with variants for the smaller reports
all_results.results = {
key: val for key, val in all_results.results.items() if val.variants and not val.metadata.solved
}

partially_split: list[tuple[ResultData, str]] = []

# first check if there's a logical way to break up the results
if prefixes := known_date_prefix_check(all_results):
for prefix in prefixes:
this_rd = ResultData(
@@ -88,31 +136,27 @@ def split_data_into_sub_reports(data_path: str, split_samples: int) -> list[tuple[ResultData, str, str]]:
version=all_results.version,
)
get_logger().info(f'Found {len(this_rd.results)} with prefix {prefix}')
return_results.append((this_rd, f'subset_{prefix}.html', prefix))
return return_results
partially_split.append((this_rd, f'subset_{prefix}'))

# only interested in presenting probands with results (for now)
results_with_variants = {
key: val for key, val in all_results.results.items() if val.variants and not val.metadata.solved
}
else:
partially_split = [(all_results, 'whole_cohort')]

# calculate the number of samples per sub-report
samples_per_report = len(results_with_variants) // split_samples

# resolve dodgy remainders
if len(results_with_variants) % split_samples < MIN_REPORT_SIZE:
samples_per_report += (MIN_REPORT_SIZE // split_samples) + 1

# split the data into sub-reports
sub_reports = []
for i, chunk in enumerate(chunks(list(results_with_variants.keys()), samples_per_report), start=1):
sub_report = ResultData(
metadata=all_results.metadata,
results={key: results_with_variants[key] for key in chunk},
)
sub_reports.append((sub_report, f'subset_{i}.html', str(i)))
# ready an object to return the final results
return_results: list[tuple[ResultData, str, str]] = []

for subset_of_cases, name in partially_split:
# get the number of samples to include in this report
samples_per_report = calculate_report_size(len(subset_of_cases.results), max_elements)

# split the data into sub-reports
for i, chunk in enumerate(chunks(list(subset_of_cases.results.keys()), samples_per_report), start=1):
sub_report = ResultData(
metadata=all_results.metadata,
results={key: subset_of_cases.results[key] for key in chunk},
)
return_results.append((sub_report, f'{name}_{i}.html', f'{name}_{i}'))

return sub_reports
return return_results
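
Putting the two layers together (a sketch with assumed inputs, not code from this commit): a cohort with year-prefixed IDs is first bucketed by prefix, then any bucket over MAX_REPORT_SIZE is chunked by size, e.g. 500 unsolved '24'-prefixed cases become subset_24_1.html, subset_24_2.html and subset_24_3.html:

# sketch: 'results.json' is an illustrative path
results = read_json_from_path('results.json', return_model=ResultData)
for sub_report, report_name, subset_id in split_up(results):
    print(report_name, len(sub_report.results))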


def cli_main():
@@ -128,28 +172,27 @@ def cli_main():
if args.latest:
get_logger(__file__).warning('"--latest" argument is not in use')

main(
results=args.input,
panelapp=args.panelapp,
output=args.output,
split_samples=args.split_samples,
)
if args.split_samples:
get_logger(__file__).warning('"--split_samples" argument is not in use')

main(results=args.input, panelapp=args.panelapp, output=args.output)

def main(results: str, panelapp: str, output: str, split_samples: int | None = None):

def main(results: str, panelapp: str, output: str):
"""
Args:
results (str): path to the MOI-tested results file
panelapp (str): path to the panelapp data
output (str): where to write the HTML file
split_samples (int, optional): how many sub-reports to generate
"""

report_output_dir = Path(output).parent

results_object = read_json_from_path(results, return_model=ResultData)

# we always make this main page - we need a reliable output path to generate analysis entries [CPG]
html = HTMLBuilder(results=results, panelapp_path=panelapp)
html = HTMLBuilder(results_dict=results_object, panelapp_path=panelapp)
# if this fails with a NoVariantsFoundException, there were no variants to present in the whole cohort
# catch this, but fail gracefully so that the process overall is a success
try:
@@ -161,7 +204,7 @@ def main(results: str, panelapp: str, output: str, split_samples: int | None = None):
get_logger().warning('No Categorised variants found in this whole cohort')

# then quit if we're not splitting samples
if not split_samples:
if len(results_object.results) <= MAX_REPORT_SIZE:
return

# do something to split the output into separate datasets
@@ -171,8 +214,8 @@ def main(results: str, panelapp: str, output: str, split_samples: int | None = None):
default_report_name = Path(output).name
html_base = output.rstrip(default_report_name)

for data, report, prefix in split_data_into_sub_reports(results, split_samples):
html = HTMLBuilder(results=data, panelapp_path=panelapp, subset_id=prefix)
for data, report, prefix in split_up(results_object):
html = HTMLBuilder(results_dict=data, panelapp_path=panelapp, subset_id=prefix)
try:
output_filepath = f'{html_base}{report}'
get_logger().info(f'Attempting to create {report} at {output_filepath}')
@@ -219,10 +262,10 @@ class HTMLBuilder:
Takes the input, makes the output
"""

def __init__(self, results: str | ResultData, panelapp_path: str, subset_id: str | None = None):
def __init__(self, results_dict: ResultData, panelapp_path: str, subset_id: str | None = None):
"""
Args:
results (str | ResultData): path to the results JSON, or the results object
results_dict (ResultData): the results object
panelapp_path (str): where to read panelapp data from
subset_id (str, optional): the subset ID to use for this report
"""
@@ -264,10 +307,6 @@ def __init__(self, results: str | ResultData, panelapp_path: str, subset_id: str | None = None):
self.ext_labels: dict[str, dict] = config_retrieve(['CreateTalosHTML', 'external_labels'], {})
assert isinstance(self.ext_labels, dict)

# Read results file, or take it directly
results_dict = read_json_from_path(results, return_model=ResultData) if isinstance(results, str) else results
assert isinstance(results_dict, ResultData)

self.metadata = results_dict.metadata
self.panel_names = {panel.name for panel in self.metadata.panels}

@@ -348,7 +387,7 @@ def get_summary_stats(self) -> tuple[pd.DataFrame, list[str], list[dict]]:
'Total': sum(category_count[key]),
'Unique': len(unique_variants[key]),
'Peak #/sample': max(category_count[key]),
'Mean/sample': sum(category_count[key]) / len(category_count[key]),
MEAN_SLASH_SAMPLE: sum(category_count[key]) / len(category_count[key]),
}
for key in ordered_categories
if category_count[key]
@@ -359,7 +398,7 @@ def get_summary_stats(self) -> tuple[pd.DataFrame, list[str], list[dict]]:
raise NoVariantsFoundError('No categorised variants found')

my_df: pd.DataFrame = pd.DataFrame(summary_dicts)
my_df['Mean/sample'] = my_df['Mean/sample'].round(3)
my_df[MEAN_SLASH_SAMPLE] = my_df[MEAN_SLASH_SAMPLE].round(3)

# the table re-sorts when parsed into the DataTable
# so this forced ordering doesn't work
@@ -423,23 +462,6 @@ def write_html(self, output_filepath: str):
'type': 'whole_cohort',
}

# for title, meta_table in self.read_metadata().items():
# template_context['meta_tables'][title] = DataTable(
# id=f'{title.lower()}-table',
# heading=title,
# description='',
# columns=list(meta_table.columns),
# rows=list(meta_table.to_records(index=False)),
# )

# template_context['summary_table'] = DataTable(
# id='summary-table',
# heading='Per-Category Summary',
# description='',
# columns=list(summary_table.columns),
# rows=list(summary_table.to_records(index=False)),
# )

# write all HTML content to the output file in one go
env = jinja2.Environment(
loader=jinja2.FileSystemLoader(JINJA_TEMPLATE_DIR),
@@ -621,10 +643,10 @@ def __init__(self, report_variant: ReportVariant, sample: Sample, ext_labels: li
# this is the weird gnomad callset ID
if (
isinstance(self.var_data, StructuralVariant)
and 'gnomad_v2.1_sv_svid' in self.var_data.info
and isinstance(self.var_data.info['gnomad_v2.1_sv_svid'], str)
and GNOMAD_SV_KEY in self.var_data.info
and isinstance(self.var_data.info[GNOMAD_SV_KEY], str)
):
self.var_data.info['gnomad_key'] = self.var_data.info['gnomad_v2.1_sv_svid'].split('v2.1_')[-1]
self.var_data.info['gnomad_key'] = self.var_data.info[GNOMAD_SV_KEY].split('v2.1_')[-1]
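
The key munging above amounts to stripping the callset prefix (the ID value here is illustrative):

# illustration: trimming a gnomAD SV callset ID down to the bare variant key
svid = 'gnomAD-SV_v2.1_DEL_1_1'
print(svid.split('v2.1_')[-1])  # -> 'DEL_1_1'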

def __str__(self) -> str:
return f'{self.chrom}-{self.pos}-{self.ref}-{self.alt}'
@@ -674,7 +696,7 @@ def check_date_filter(results: str | ResultData, filter_date: str | None = None)
Extra consideration - if one part of a comp-het variant pair is new,
retain both sides in the report
deprecated for now, migrating to lightweight fitler-able reports, which should mitigate need for this
deprecated for now, migrating to lightweight filter-able reports, which should mitigate need for this
Args:
results (str): path to the results file
Expand Down
