diff --git a/src/baskerville/snps.py b/src/baskerville/snps.py index a4b2e25..d9d2acc 100644 --- a/src/baskerville/snps.py +++ b/src/baskerville/snps.py @@ -164,10 +164,20 @@ def score_write(ref_preds, alt_preds, si): si = 0 with concurrent.futures.ThreadPoolExecutor() as executor: - for sc in tqdm(snp_clusters): - snp_1hot_list = sc.get_1hots(genome_open) + # initialize 1 hot encoding + sc0 = snp_clusters[0] + s1l = executor.submit(sc0.get_1hots, genome_open) + + for ci, sc in enumerate(tqdm(snp_clusters)): + # pull latest 1 hot encoding + snp_1hot_list = s1l.result() ref_1hot = np.expand_dims(snp_1hot_list[0], axis=0) + # submit next 1 hot encoding + if ci + 1 < len(snp_clusters): + sc1 = snp_clusters[ci + 1] + s1l = executor.submit(sc1.get_1hots, genome_open) + # predict reference ref_preds = [] for shift in options.shifts: