Skip to content

Commit

Permalink
FIX: Recalculate BOLD reference after HM/ST/SD corrections
Browse files Browse the repository at this point in the history
  • Loading branch information
mgxd committed Sep 8, 2021
1 parent e9059f7 commit 4b0a580
Showing 1 changed file with 54 additions and 74 deletions.
128 changes: 54 additions & 74 deletions nibabies/workflows/bold/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False):
"""
from niworkflows.engine.workflows import LiterateWorkflow as Workflow
from niworkflows.interfaces.utility import DictMerge
from niworkflows.workflows.epi.refmap import init_epi_reference_wf

mem_gb = {'filesize': 1, 'resampled': 1, 'largemem': 1}
bold_tlen = 10
Expand Down Expand Up @@ -479,17 +480,27 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False):
omp_nthreads=omp_nthreads,
name='bold_t2smap_wf')

# Mask BOLD reference image
final_boldref_masker = pe.Node(BrainExtraction(), name='final_boldref_masker')
# Mask input BOLD reference image
initial_boldref_mask = pe.Node(BrainExtraction(), name='initial_boldref_mask')

# This final boldref will be calculated after bold_bold_trans_wf, which includes one or more:
# HMC (head motion correction)
# STC (slice time correction)
# SDC (susceptibility distortion correction)
final_boldref_wf = init_epi_reference_wf(
auto_bold_nss=True, omp_nthreads=omp_nthreads, name="final_boldref_wf",
)
final_boldref_mask = pe.Node(BrainExtraction(), name='final_boldref_mask')

# MAIN WORKFLOW STRUCTURE #######################################################
# fmt: off
workflow.connect([
# BOLD buffer has slice-time corrected if it was run, original otherwise
(boldbuffer, bold_split, [('bold_file', 'in_file')]),
# HMC
(inputnode, bold_hmc_wf, [
('bold_ref', 'inputnode.raw_ref_image')]),
(inputnode, final_boldref_masker, [('bold_ref', 'in_file')]),
(inputnode, initial_boldref_mask, [('bold_ref', 'in_file')]),
(val_bold, bold_hmc_wf, [
(("out_file", pop_file), 'inputnode.bold_file')]),
(inputnode, summary, [
Expand Down Expand Up @@ -542,14 +553,18 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False):
('outputnode.xforms', 'inputnode.hmc_xforms')]),
# Summary
(outputnode, summary, [('confounds', 'confounds_file')]),
(final_boldref_wf, final_boldref_mask, [('outputnode.epi_ref_file', 'in_file')]),
])
# fmt: on

# for standard EPI data, pass along correct file
if not multiecho:
# TODO: Add SDC
workflow.connect([
(inputnode, func_derivatives_wf, [
('bold_file', 'inputnode.source_file')]),
(bold_bold_trans_wf, final_boldref_wf, [
('outputnode.bold', 'inputnode.in_files')]),
(bold_bold_trans_wf, bold_confounds_wf, [
('outputnode.bold', 'inputnode.bold')]),
(bold_split, bold_t1_trans_wf, [
Expand All @@ -564,8 +579,8 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False):
(('bold_file', combine_meepi_source), 'inputnode.source_file')]),
(bold_bold_trans_wf, join_echos, [
('outputnode.bold', 'bold_files')]),
# (join_echos, final_boldref_wf, [
# ('bold_files', 'inputnode.bold_file')]),
(join_echos, final_boldref_wf, [
('bold_files', 'inputnode.in_files')]),
# TODO: Check with multi-echo data
(bold_bold_trans_wf, skullstrip_bold_wf, [
('outputnode.bold', 'inputnode.in_file')]),
Expand Down Expand Up @@ -604,9 +619,6 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False):

if nonstd_spaces.intersection(('func', 'run', 'bold', 'boldref', 'sbref')):
workflow.connect([
(inputnode, func_derivatives_wf, [
('bold_ref', 'inputnode.bold_native_ref'),
]),
(bold_bold_trans_wf if not multiecho else bold_t2s_wf, outputnode, [
('outputnode.bold', 'bold_native')])
])
Expand Down Expand Up @@ -816,41 +828,39 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False):
workflow.get_node(node).inputs.base_directory = nibabies_dir
workflow.get_node(node).inputs.source_file = ref_file

# Distortion correction
if not has_fieldmap:
# Finalize workflow with fieldmap-less connections
summary.inputs.distortion_correction = "None"
# fmt: off
# fmt: off
workflow.connect([
(final_boldref_mask, bold_t1_trans_wf, [
('out_mask', 'inputnode.ref_bold_mask'),
('out_file', 'inputnode.ref_bold_brain'),
]),
(final_boldref_mask, bold_reg_wf, [
('out_file', 'inputnode.ref_bold_brain'),
]),
(final_boldref_mask, bold_confounds_wf, [
('out_mask', 'inputnode.bold_mask')
]),
])
if nonstd_spaces.intersection(('T1w', 'anat')):
workflow.connect([
(final_boldref_masker, bold_t1_trans_wf, [
('out_mask', 'inputnode.ref_bold_mask'),
('out_file', 'inputnode.ref_bold_brain'),
]),
(final_boldref_masker, bold_reg_wf, [
('out_file', 'inputnode.ref_bold_brain'),
]),
(final_boldref_masker, bold_confounds_wf, [
('out_mask', 'inputnode.bold_mask')
]),
(final_boldref_mask, boldmask_to_t1w, [
('out_mask', 'input_image')]),
])
if nonstd_spaces.intersection(('func', 'run', 'bold', 'boldref', 'sbref')):
workflow.connect([
(final_boldref_mask, func_derivatives_wf, [
('out_file', 'inputnode.bold_native_ref'),
('out_mask', 'inputnode.bold_mask_native')]),
])
if spaces.get_spaces(nonstandard=False, dim=(3,)):
workflow.connect([
(final_boldref_mask, bold_std_trans_wf, [
('out_mask', 'inputnode.bold_mask')]),
])
# fmt: on

if nonstd_spaces.intersection(('T1w', 'anat')):
workflow.connect([
(final_boldref_masker, boldmask_to_t1w, [
('out_mask', 'input_image')]),
])
if nonstd_spaces.intersection(('func', 'run', 'bold', 'boldref', 'sbref')):
workflow.connect([
(final_boldref_masker, func_derivatives_wf, [
('out_file', 'inputnode.bold_native_ref'),
('out_mask', 'inputnode.bold_mask_native')]),
])
if spaces.get_spaces(nonstandard=False, dim=(3,)):
workflow.connect([
(final_boldref_masker, bold_std_trans_wf, [
('out_mask', 'inputnode.bold_mask')]),
])
# fmt: on
if not has_fieldmap:
summary.inputs.distortion_correction = "None"
return workflow

# SDC
Expand Down Expand Up @@ -900,7 +910,6 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False):
name="ds_report_sdc",
run_without_submitting=True,
)
unwarp_masker = pe.Node(BrainExtraction(), name='unwarp_masker')

# fmt: off
workflow.connect([
Expand All @@ -917,8 +926,8 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False):
(output_select, summary, [("sdc_method", "distortion_correction")]),
(inputnode, coeff2epi_wf, [
("bold_ref", "inputnode.target_ref")]),
(final_boldref_masker, coeff2epi_wf, [
("out_file", "inputnode.target_mask")]),
(initial_boldref_mask, coeff2epi_wf, [
("out_file", "inputnode.target_mask")]), # skull-stripped brain
(inputnode, unwarp_wf, [("bold_ref", "inputnode.distorted")]),
(coeff2epi_wf, unwarp_wf, [
("outputnode.fmap_coeff", "inputnode.fmap_coeff")]),
Expand All @@ -927,46 +936,17 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False):
("outputnode.corrected_mask", "wm_seg")]),
(inputnode, ds_report_sdc, [("bold_file", "source_file")]),
(sdc_report, ds_report_sdc, [("out_report", "in_file")]),
# remaining workflow connections
(unwarp_wf, bold_bold_trans_wf, [
# ('outputnode.corrected_mask', 'inputnode.bold_mask'),
('outputnode.fieldwarp', 'inputnode.fieldwarp'),
]),
(unwarp_wf, unwarp_masker, [('outputnode.corrected', 'in_file')]),
(unwarp_masker, bold_confounds_wf, [('out_mask', 'inputnode.bold_mask')]),
(unwarp_masker, bold_t1_trans_wf, [
('out_mask', 'inputnode.ref_bold_mask'),
('out_file', 'inputnode.ref_bold_brain')]),
(unwarp_masker, bold_reg_wf, [
('out_file', 'inputnode.ref_bold_brain')]),
(unwarp_wf, bold_bold_trans_wf, [('outputnode.fieldwarp', 'inputnode.fieldwarp')]),
])

if nonstd_spaces.intersection(('T1w', 'anat')):
workflow.connect([
(unwarp_masker, boldmask_to_t1w, [
('out_mask', 'input_image')]),
])

if nonstd_spaces.intersection(('func', 'run', 'bold', 'boldref', 'sbref')):
workflow.connect([
(unwarp_masker, func_derivatives_wf, [
('out_file', 'inputnode.bold_native_ref'),
('out_mask', 'inputnode.bold_mask_native')]),
])

if spaces.get_spaces(nonstandard=False, dim=(3,)):
workflow.connect([
(unwarp_masker, bold_std_trans_wf, [
('out_mask', 'inputnode.bold_mask')]),
])

if not multiecho:
workflow.connect([
(unwarp_wf, bold_t1_trans_wf, [
('outputnode.fieldwarp', 'inputnode.fieldwarp')]),
(unwarp_wf, bold_std_trans_wf, [
('outputnode.fieldwarp', 'inputnode.fieldwarp')]),
])
# fmt: on

return workflow

Expand Down

0 comments on commit 4b0a580

Please sign in to comment.