diff --git a/str/associatr/plotters/qqplotter.py b/str/associatr/plotters/qqplotter.py index 6868a97f..86bf8e48 100644 --- a/str/associatr/plotters/qqplotter.py +++ b/str/associatr/plotters/qqplotter.py @@ -6,7 +6,7 @@ --output-dir "str/associatr/tob_n1055_and_bioheart_n990" --memory=8G \ qqplotter.py \ --input-dir=gs://cpg-bioheart-test/str/associatr/tob_n1055_and_bioheart_n990/DL_random_model/raw_pval_extractor \ - --cell-types=CD4_TCM,CD4_Naive,CD4_TEM,CD4_CTL,CD4_Proliferating,CD4_TCM_permuted,NK,NK_CD56bright,NK_Proliferating,CD8_TEM,CD8_TCM,CD8_Proliferating,CD8_Naive,Treg,B_naive,B_memory,B_intermediate,Plasmablast,CD14_Mono,CD16_Mono,cDC1,cDC2,pDC,dnT,gdT,MAIT,ASDC,HSPC,ILC \ + --cell-types=CD4_TCM,CD4_Naive,NK,CD8_TEM,B_naive,CD8_Naive,CD14_Mono,CD4_TEM,CD8_TCM,B_intermediate,B_memory,Treg,CD4_CTL,gdT,CD16_Mono,MAIT,NK_CD56bright,cDC2,NK_Proliferating,dnT,pDC,Plasmablast,ILC,HSPC,CD8_Proliferating,cDC1,CD4_Proliferating,ASDC \ --title='associaTR BioHEART' --ylim=315 @@ -22,6 +22,37 @@ from cpg_utils import to_path from cpg_utils.hail_batch import init_batch, output_path +# Define the color mapping +color_mapping = { + 'CD4_TCM': '#0C46A0FF', + 'CD4_Naive': '#1976D2FF', + 'CD4_TEM': '#2096F2FF', + 'CD4_CTL': '#64B4F6FF', + 'Treg': '#90CAF8FF', + 'CD4_Proliferating': '#BADEFAFF', + 'gdT': '#817717FF', + 'MAIT': '#AEB32BFF', + 'dnT': '#CCDC39FF', + 'ILC': '#DCE674FF', + 'CD8_TEM': '#311A92FF', + 'CD8_Naive': '#5E34B1FF', + 'CD8_TCM': '#7E57C1FF', + 'CD8_Proliferating': '#D1C4E9FF', + 'NK': '#AC1357FF', + 'NK_CD56bright': '#E91E63FF', + 'NK_Proliferating': '#F38EB1FF', + 'B_naive': '#F47F17FF', + 'B_intermediate': '#FABF2CFF', + 'B_memory': '#FFEB3AFF', + 'Plasmablast': '#FFF176FF', + 'CD14_Mono': '#388D3BFF', + 'CD16_Mono': '#80C684FF', + 'cDC2': '#5D3F37FF', + 'pDC': '#795447FF', + 'cDC1': '#A0877FFF', + 'ASDC': '#D7CCC7FF', + 'HSPC': '#BDBDBDFF', +} @click.option('--title', help='Title of the QQ plot') @click.option('--ylim', help='Y-axis limit for the QQ plot', default=335) @@ -83,51 +114,25 @@ def main(input_dir, cell_types, title, ylim): plt.figure(figsize=(10, 8)) fig, ax = plt.subplots(figsize=(10, 8)) - # Define a list of colors - # Define a color mapping dictionary for each cell type - color_mapping = { - 'CD4_TCM': '#0C46A0FF', - 'CD4_Naive': '#1976D2FF', - 'CD4_TEM': '#2096F2FF', - 'CD4_CTL': '#64B4F6FF', - 'Treg': '#90CAF8FF', - 'CD4_Proliferating': '#BADEFAFF', - 'gdT': '#817717FF', - 'MAIT': '#AEB32BFF', - 'dnT': '#CCDC39FF', - 'ILC': '#DCE674FF', - 'CD8_TEM': '#311A92FF', - 'CD8_Naive': '#5E34B1FF', - 'CD8_TCM': '#7E57C1FF', - 'CD8_Proliferating': '#D1C4E9FF', - 'NK': '#AC1357FF', - 'NK_CD56bright': '#E91E63FF', - 'NK_Proliferating': '#F38EB1FF', - 'B_naive': '#F47F17FF', - 'B_intermediate': '#FABF2CFF', - 'B_memory': '#FFEB3AFF', - 'Plasmablast': '#FFF176FF', - 'CD14_Mono': '#388D3BFF', - 'CD16_Mono': '#80C684FF', - 'cDC2': '#5D3F37FF', - 'pDC': '#795447FF', - 'cDC1': '#A0877FFF', - 'ASDC': '#D7CCC7FF', - 'HSPC': '#BDBDBDFF', - } + # Set default color for permuted control or any cell type not in color_mapping + default_color = '#808080' # grey color for unmapped cell types # Pre-calculate sorted values expected_sorted_values = { - cell_type: np.sort(globals()[f'expected_log_pvals_{cell_type}']) for cell_type in cell_type_list + cell_type: np.sort(globals()[f'expected_log_pvals_{cell_type}']) + for cell_type in cell_type_list } observed_sorted_values = { - cell_type: np.sort(globals()[f'observed_log_pvals_{cell_type}']) for cell_type in cell_type_list + cell_type: np.sort(globals()[f'observed_log_pvals_{cell_type}']) + for cell_type in cell_type_list } - # Loop through each cell type and plot the scatter plot - for i, cell_type in enumerate(cell_type_list): + # Plot each cell type in the order specified by cell_type_list + for cell_type in cell_type_list: output_label = cell_type_mapping.get(cell_type, cell_type) - color = color_mapping.get(cell_type, 'grey') # Get the index of the color to use for the current cell type + # Use color from mapping if available, otherwise use default color + color = color_mapping.get(cell_type, default_color) + ax.scatter( expected_sorted_values[cell_type], observed_sorted_values[cell_type], @@ -136,29 +141,34 @@ def main(input_dir, cell_types, title, ylim): s=9, ) - # Create a legend for all items except "permuted control" + # Create a legend for permuted control and other items separately handles, labels = ax.get_legend_handles_labels() - permuted_control_handle = [h for h, l in zip(handles, labels) if l == "Permuted control"] - other_handles = [h for h, l in zip(handles, labels) if l != "Permuted control"] - other_labels = [l for l in labels if l != "Permuted control"] - ax.add_artist( - ax.legend( - permuted_control_handle, - ['Permuted control'], - bbox_to_anchor=(1.05, 0), - loc='upper left', - fontsize=11, - ), - ) - sns.despine() - # Create the main legend with other items + permuted_control_idx = [i for i, l in enumerate(labels) if l == "Permuted control"] + other_idx = [i for i, l in enumerate(labels) if l != "Permuted control"] + + # Add permuted control legend if it exists + if permuted_control_idx: + ax.add_artist( + ax.legend( + [handles[permuted_control_idx[0]]], + ["Permuted control"], + bbox_to_anchor=(1.05, 0), + loc='upper left', + fontsize=12, + ) + ) + + # Create the main legend with other items, maintaining the order from cell_type_list + other_handles = [handles[i] for i in other_idx] + other_labels = [labels[i] for i in other_idx] ax.legend(other_handles, other_labels, bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=11) - ax.set_xlabel('Expected -log₁₀(p-value)', fontsize=18) - ax.set_ylabel('Expected -log₁₀(p-value)', fontsize=18) + sns.despine() + ax.set_xlabel('Expected -log₁₀(p-value)', fontsize=17) + ax.set_ylabel('Expected -log₁₀(p-value)', fontsize=17) - plt.xticks(fontsize=14) - plt.yticks(fontsize=14) + plt.xticks(fontsize=15) + plt.yticks(fontsize=15) ax.set_ylim(0, ylim) ax.plot([0, 7], [0, 7], color='grey', linestyle='--') # Add a reference line @@ -168,6 +178,5 @@ def main(input_dir, cell_types, title, ylim): fig.savefig('qqplot.png') hl.hadoop_copy('qqplot.png', gcs_output_path) - if __name__ == '__main__': - main() + main() \ No newline at end of file