def extract_genotypes(vcf_file, loci):
    """
    Helper function to extract SNP genotypes from a VCF file as ALT-allele dosages.

    Args:
        vcf_file: path to a VCF readable by cyvcf2. Region queries are used, so
            the file's index is expected to sit alongside it.
        loci: iterable of target loci written as 'chrom:pos' (can be single or
            multiple).

    Returns:
        pandas DataFrame with a 'sample_id' column (samples in VCF header order)
        followed by one column per locus that was found, named by the locus
        string. Values are 0 (HOM REF), 1 (HET), 2 (HOM ALT) or NaN when the
        genotype is not called. Columns are float so that NaN can be stored;
        callers should drop or impute NaN before using them as covariates.
        A locus with no record at the exact position gets no column (a warning
        is emitted instead).
    """
    import warnings

    vcf_reader = VCF(vcf_file)
    try:
        columns = {'sample_id': list(vcf_reader.samples)}

        for locus in loci:
            # Split on the last ':' only, so contig names that themselves
            # contain ':' are still parsed correctly.
            chrom, pos = locus.rsplit(':', 1)
            pos = int(pos)

            for record in vcf_reader(f'{chrom}:{pos}-{pos}'):
                # A region query can also return records that merely overlap
                # the position (e.g. indels); keep only an exact match.
                if record.CHROM == chrom and record.POS == pos:
                    # Default cyvcf2 encoding: 0=HOM_REF, 1=HET, 2=UNKNOWN,
                    # 3=HOM_ALT. Build the masks from the untouched codes so
                    # that recoding HOM_ALT to 2 cannot be confused with the
                    # UNKNOWN code, which is also 2.
                    gt_types = np.asarray(record.gt_types)
                    dosage = gt_types.astype(float)
                    dosage[gt_types == 3] = 2
                    dosage[gt_types == 2] = np.nan
                    columns[locus] = dosage
                    break
            else:
                # No record matched: make the omission visible rather than
                # silently returning a frame without this locus.
                warnings.warn(
                    f'Locus {locus} not found in {vcf_file}; no genotype column added',
                    stacklevel=2,
                )
    finally:
        # Release the underlying file handle even if parsing fails.
        vcf_reader.close()

    # Build the frame in one go; dict order keeps 'sample_id' first.
    return pd.DataFrame(columns)
gene_pheno.merge(covariates, on='sample_id', how='inner') + # add SNP genotypes we would like to condition on + if snp_input is not None: + snp_genotype_df = extract_genotypes(snp_input['vcf'], snp_loci) + gene_pheno_cov = gene_pheno_cov.merge(snp_genotype_df, on='sample_id', how='inner') + # filter for samples that were assigned a CPG ID; unassigned samples after demultiplexing will not have a CPG ID gene_pheno_cov = gene_pheno_cov[gene_pheno_cov['sample_id'].str.startswith('CPG')] @@ -151,6 +186,13 @@ def manage_concurrency_for_job(job: hb.batch.job.Job): j.cpu(get_config()['get_cis_numpy']['job_cpu']) j.memory(get_config()['get_cis_numpy']['job_memory']) j.storage(get_config()['get_cis_numpy']['job_storage']) + + if get_config()['get_cis_numpy']['snp_vcf_dir'] is not None: + snp_vcf_dir = get_config()['get_cis_numpy']['snp_vcf_dir'] + snp_vcf_path = f'{snp_vcf_dir}/{chrom}_common_variants.vcf.bgz' + snp_input = get_batch().read_input_group(**{'vcf': snp_vcf_path, 'csi': snp_vcf_path + '.csi'}) + else: # no SNP VCF provided + snp_input = None j.call( cis_window_numpy_extractor, get_config()['get_cis_numpy']['input_h5ad_dir'], @@ -163,6 +205,8 @@ def manage_concurrency_for_job(job: hb.batch.job.Job): chrom_len, get_config()['get_cis_numpy']['min_pct'], get_config()['get_cis_numpy']['remove_samples_file'], + snp_input, + get_config()['get_cis_numpy']['snp_loci'], ) manage_concurrency_for_job(j)