Skip to content

Commit

Permalink
Merge pull request #203 from pni-lab/FUGUE-fieldmap-correction
Browse files Browse the repository at this point in the history
FUGUE fieldmap correction
  • Loading branch information
spisakt authored Jan 17, 2025
2 parents ea7cdbf + da50825 commit 0a0665d
Show file tree
Hide file tree
Showing 3 changed files with 606 additions and 136 deletions.
12 changes: 7 additions & 5 deletions PUMI/pipelines/anat/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

@QcPipeline(inputspec_fields=['background', 'overlay'],
outputspec_fields=['out_file'])
def qc_segmentation(wf, fmri=False, **kwargs):
def qc_segmentation(wf, sinking_name=None, fmri=False, **kwargs):
"""
Create quality check images for background extraction workflows
Expand All @@ -39,7 +39,9 @@ def qc_segmentation(wf, fmri=False, **kwargs):
wf.connect('inputspec', 'overlay', plot, 'overlay')

# sinking
if fmri:
if sinking_name:
wf.connect(plot, 'out_file', 'sinker', sinking_name)
elif fmri:
wf.connect(plot, 'out_file', 'sinker', 'qc_func_segmentation')
else:
wf.connect(plot, 'out_file', 'sinker', 'qc_anat_segmentation')
Expand Down Expand Up @@ -154,7 +156,7 @@ def bet_fsl(wf, fmri=False, volume='middle', **kwargs):

@AnatPipeline(inputspec_fields=['in_file'],
outputspec_fields=['out_file', 'brain_mask', 'tiv'])
def bet_deepbet(wf, fmri=False, volume='middle', threshold=0.5, n_dilate=0, no_gpu=False, **kwargs):
def bet_deepbet(wf, fmri=False, volume='middle', threshold=0.5, n_dilate=0, no_gpu=False, sinking_name=None, **kwargs):
"""
Perform brain extraction with deepbet.
Expand Down Expand Up @@ -230,11 +232,11 @@ def run_deepbet(in_file, threshold=0.5, n_dilate=0, no_gpu=False):
background = pick_volume('qc_background', volume=volume)
wf.connect('inputspec', 'in_file', background, 'in_file')

qc = qc_segmentation(name='qc_segmentation', fmri=True, qc_dir=wf.qc_dir)
qc = qc_segmentation(name='qc_segmentation', sinking_name=sinking_name, fmri=True, qc_dir=wf.qc_dir)
wf.connect(overlay, 'out_file', qc, 'overlay')
wf.connect(background, 'out_file', qc, 'background')
else:
qc = qc_segmentation(name='qc_segmentation', qc_dir=wf.qc_dir)
qc = qc_segmentation(name='qc_segmentation', sinking_name=sinking_name, qc_dir=wf.qc_dir)
wf.connect(bet, 'out_file', qc, 'overlay')
wf.connect('inputspec', 'in_file', qc, 'background')

Expand Down
229 changes: 98 additions & 131 deletions PUMI/pipelines/func/deconfound.py
Original file line number Diff line number Diff line change
@@ -1,180 +1,147 @@
import os
from pathlib import Path
from nipype import Function
from nipype.algorithms import confounds
from nipype.interfaces import afni, fsl, utility
from PUMI.engine import NestedNode as Node, QcPipeline
from PUMI.engine import FuncPipeline
from PUMI.pipelines.anat.segmentation import bet_deepbet
from PUMI.pipelines.multimodal.image_manipulation import pick_volume, timecourse2png
from PUMI.utils import calc_friston_twenty_four, calculate_FD_Jenkinson, mean_from_txt, max_from_txt
from PUMI.plot.carpet_plot import plot_carpet


@QcPipeline(inputspec_fields=['func_1', 'func_2', 'func_corrected'],
outputspec_fields=['out_file'])
def fieldmap_correction_qc(wf, volume='middle', **kwargs):
@QcPipeline(inputspec_fields=['background', 'overlay'],
outputspec_fields=['out_file'])
def qc_fieldmap_correction_fugue(wf, overlay_volume='middle', **kwargs):
"""
Generate a quality check (QC) image for the FUGUE fieldmap correction workflow.
Quality check image generation for fieldmap correction pipeline.
Parameters:
overlay_volume (str): The volume of the overlay image to be used for the QC plot.
Options are "first", "middle", "last", or an integer specifying the volume index.
Default is "middle".
Inputs:
func_1 (str): Path to functional image (e.g. LR phase encoded rsfMRI).
func_2 (str): Path to functional image with another phase encoding than func_1 (e.g. RL phase encoded rsfMRI).
func_corrected (str): Path to fieldmap corrected functional image.
background (str): Path to the anatomical background image.
overlay (str): Path to the overlay image (e.g., the unwarped functional scan).
Outputs:
out_file (str): Path to quality check image.
out_file (str): Path to the generated QC image.
Sinking:
- Quality check image.
- Generated QC image showing the overlay on the background.
"""

def get_cut_cords(func, n_slices=10):
import nibabel as nib
import numpy as np
def create_fieldmap_plot(overlay, background):
from PUMI.utils import plot_roi

func_img = nib.load(func)
y_dim = func_img.shape[1] # y-dimension (coronal direction) is the second dimension in the image shape
_, output_file = plot_roi(roi_img=overlay, bg_img=background)

slices = np.linspace(-y_dim / 2, y_dim / 2, n_slices)
# slices might contain floats but this is not a problem since nilearn will round floats to the
# nearest integer value!
return slices
return output_file

def create_montage(vol_1, vol_2, vol_corrected, n_slices=10):
from matplotlib import pyplot as plt
from pathlib import Path
from nilearn import plotting
import os
overlay_vol = pick_volume('overlay_vol', overlay_volume)
wf.connect('inputspec', 'overlay', overlay_vol, 'in_file')

fig, axes = plt.subplots(3, 1, facecolor='black', figsize=(10, 15))
overlay_bet = bet_deepbet('overlay_bet')
wf.connect(overlay_vol, 'out_file', overlay_bet, 'in_file')

plotting.plot_anat(vol_1, display_mode='y', cut_coords=get_cut_cords(vol_1, n_slices=n_slices),
title='Image #1', black_bg=True, axes=axes[0])
plotting.plot_anat(vol_2, display_mode='y', cut_coords=get_cut_cords(vol_2, n_slices=n_slices),
title='Image #2', black_bg=True, axes=axes[1])
plotting.plot_anat(vol_corrected, display_mode='y', cut_coords=get_cut_cords(vol_corrected, n_slices=n_slices),
title='Corrected', black_bg=True, axes=axes[2])
plot = Node(Function(input_names=['overlay', 'background'],
output_names=['out_file'],
function=create_fieldmap_plot),
name='plot')

path = str(Path(os.getcwd() + '/fieldmap_correction_comparison.png'))
plt.savefig(path)
plt.close(fig)
return path
wf.connect('inputspec', 'background', plot, 'background')
wf.connect(overlay_bet, 'out_file', plot, 'overlay')

vol_1 = pick_volume('vol_1', volume=volume)
wf.connect('inputspec', 'func_1', vol_1, 'in_file')
wf.connect(plot, 'out_file', 'sinker', 'qc_fieldmap_correction')

vol_2 = pick_volume('vol_2', volume=volume)
wf.connect('inputspec', 'func_2', vol_2, 'in_file')

vol_corrected = pick_volume('vol_corrected', volume=volume)
wf.connect('inputspec', 'func_corrected', vol_corrected, 'in_file')

montage = Node(Function(
input_names=['vol_1', 'vol_2', 'vol_corrected'],
output_names=['out_file'],
function=create_montage),
name='montage_node'
)
wf.connect(vol_1, 'out_file', montage, 'vol_1')
wf.connect(vol_2, 'out_file', montage, 'vol_2')
wf.connect(vol_corrected, 'out_file', montage, 'vol_corrected')

wf.connect(montage, 'out_file', 'outputspec', 'out_file')
wf.connect(montage, 'out_file', 'sinker', 'qc_fieldmap_correction')
# output
wf.connect(plot, 'out_file', 'outputspec', 'out_file')


@FuncPipeline(inputspec_fields=['func_1', 'func_2'],
@FuncPipeline(inputspec_fields=['main_img', 'main_json', 'anat_img', 'phasediff_img', 'phasediff_json',
'magnitude_img'],
outputspec_fields=['out_file'])
def fieldmap_correction(wf, encoding_direction=['x-', 'x'], trt=[0.0522, 0.0522], tr=0.72, **kwargs):
def fieldmap_correction_fugue(wf, **kwargs):
"""
Perform fieldmap correction using FSL's FUGUE.
Fieldmap correction pipeline.
Parameters:
encoding_direction (list): List of encoding directions (default is left-right and right-left phase encoding).
trt (list): List of total readout times (default adapted to rsfMRI data of the HCP WU 1200 dataset).
Default is:
1*(10**(-3))*EchoSpacingMS*EpiFactor = 1*(10**(-3))*0.58*90 = 0.0522 (for LR and RL image)
tr (float): Repetition time (default adapted to rsfMRI data of the HCP WU 1200 dataset).
This pipeline uses the magnitude and phase-difference images to generate a fieldmap and then applies
fieldmap correction to a functional image.
Inputs:
func_1 (str): Path to functional image (e.g. LR phase encoded rsfMRI).
func_2 (str): Path to functional image with another phase encoding than func_1 (e.g. RL phase encoded rsfMRI).
main_img (str): Path to the 4D functional image to be corrected.
main_json (str): Path to the JSON metadata file for the functional image.
anat_img (str): Path to the anatomical image for QC background.
phasediff_img (str): Path to the phase-difference image.
phasediff_json (str): Path to the JSON metadata file for the phase-difference image.
magnitude_img (str): Path to the magnitude image.
Outputs:
out_file (str): 4d distortion corrected image.
out_file (str): Path to the fieldmap-corrected functional image.
Sinking:
- 4d distortion corrected image.
- Fieldmap-corrected functional image.
- QC images for the fieldmap correction.
"""

For more information:
https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/topup/ExampleTopupFollowedByApplytopup
https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/topup/Faq#How_do_I_know_what_phase-encode_vectors_to_put_into_my_--datain_text_file.3F
https://www.humanconnectome.org/storage/app/media/documentation/s1200/HCP_S1200_Release_Appendix_I.pdf
bet_magnitude_img = bet_deepbet('bet_magnitude_img', sinking_name='magnitude_img_segm')
wf.connect('inputspec', 'magnitude_img', bet_magnitude_img, 'in_file')

"""
def get_fieldmap_parameters(main_json, phasediff_json):
import json

items_to_list_function = lambda item_1, item_2: [item_1, item_2] # helper function we will need later
with open(main_json, 'r') as f:
main_metadata = json.load(f)

# We use the first volume of func_1 and the first volume of the func_2 4D-image for the estimation of the field.
first_func1_vol = pick_volume('first_func1_vol', volume='first')
wf.connect('inputspec', 'func_1', first_func1_vol, 'in_file')
with open(phasediff_json, 'r') as f:
phasediff_metadata = json.load(f)

first_func2_vol = pick_volume('first_func2_vol', volume='first')
wf.connect('inputspec', 'func_2', first_func2_vol, 'in_file')
# Extract dwell_time (EffectiveEchoSpacing)
dwell_time = main_metadata.get('EffectiveEchoSpacing') # In seconds

# We need to combine the two 3D images we extracted into one 4D image
# fsl.Merge expects a list as input, so we need to combine our two 3D images first into a list
first_volumes_to_list = Node(Function(
input_names=['item_1', 'item_2'],
output_names=['output'],
function=items_to_list_function),
name='first_volumes_to_list'
)
wf.connect(first_func1_vol, 'out_file', first_volumes_to_list, 'item_1')
wf.connect(first_func2_vol, 'out_file', first_volumes_to_list, 'item_2')

# Now combine 3D images to 4D image along the time axis
merger = Node(fsl.Merge(), name='merger')
merger.inputs.dimension = 't'
merger.inputs.output_type = 'NIFTI_GZ'
merger.inputs.tr = tr
wf.connect(first_volumes_to_list, 'output', merger, 'in_files')

# Estimate susceptibility induced distortions
topup = Node(fsl.TOPUP(), name='topup')
topup.inputs.encoding_direction = encoding_direction
topup.inputs.readout_times = trt
wf.connect(merger, 'merged_file', topup, 'in_file')

# The two original 4D files are also needed inside a list
func_files_to_list = Node(Function(
input_names=['item_1', 'item_2'],
output_names=['output'],
function=items_to_list_function),
name='func_files_to_list'
)
wf.connect('inputspec', 'func_1', func_files_to_list, 'item_1')
wf.connect('inputspec', 'func_2', func_files_to_list, 'item_2')

# Apply result of fsl.TOPUP to our original data
# Result will be one 4D distortion corrected image
apply_topup = Node(fsl.ApplyTOPUP(), name='apply_topup')
wf.connect(func_files_to_list, 'output', apply_topup, 'in_files')
wf.connect(topup, 'out_fieldcoef', apply_topup, 'in_topup_fieldcoef')
wf.connect(topup, 'out_movpar', apply_topup, 'in_topup_movpar')
wf.connect(topup, 'out_enc_file', apply_topup, 'encoding_file')

qc_fieldmap_correction = fieldmap_correction_qc('qc_fieldmap_correction')
wf.connect('inputspec', 'func_1', qc_fieldmap_correction, 'func_1')
wf.connect('inputspec', 'func_2', qc_fieldmap_correction, 'func_2')
wf.connect(topup, 'out_corrected', qc_fieldmap_correction, 'func_corrected')

wf.connect(apply_topup, 'out_corrected', 'outputspec', 'out_file')
wf.connect(apply_topup, 'out_corrected', 'sinker', 'out_file')
if dwell_time is None:
raise ValueError(f'{main_json} does not contain EffectiveEchoSpacing')

# Extract and calculate delta_TE (in ms)
echo_time_1 = phasediff_metadata.get('EchoTime1') # In seconds
echo_time_2 = phasediff_metadata.get('EchoTime2') # In seconds

if echo_time_1 is None:
raise ValueError(f'{main_json} does not contain EchoTime1')

if echo_time_2 is None:
raise ValueError(f'{main_json} does not contain EchoTime2')

asym_se_time = abs(echo_time_2 - echo_time_1) # In seconds
delta_TE = asym_se_time * 1000 # Convert to ms

return dwell_time, delta_TE, asym_se_time

get_params = Node(Function(
input_names=['main_json', 'phasediff_json'],
output_names=['dwell_time', 'delta_TE', 'asym_se_time'],
function=get_fieldmap_parameters
), name='get_params')
wf.connect('inputspec', 'phasediff_json', get_params, 'phasediff_json')
wf.connect('inputspec', 'main_json', get_params, 'main_json')

prepare_fieldmap = Node(fsl.PrepareFieldmap(), name='prepare_fieldmap')
wf.connect(get_params, 'delta_TE', prepare_fieldmap, 'delta_TE')
wf.connect(bet_magnitude_img, 'out_file', prepare_fieldmap, 'in_magnitude')
wf.connect('inputspec', 'phasediff_img', prepare_fieldmap, 'in_phase')

fugue = Node(fsl.FUGUE(), name='fugue')
wf.connect(get_params, 'dwell_time', fugue, 'dwell_time')
wf.connect(get_params, 'asym_se_time', fugue, 'asym_se_time')
wf.connect(prepare_fieldmap, 'out_fieldmap', fugue, 'fmap_in_file')
wf.connect('inputspec', 'main_img', fugue, 'in_file')

qc = qc_fieldmap_correction_fugue('qc_fieldmap_correction')
wf.connect(fugue, 'unwarped_file', qc, 'overlay')
wf.connect('inputspec', 'anat_img', qc, 'background')

wf.connect(fugue, 'unwarped_file', 'outputspec', 'out_file')


@FuncPipeline(inputspec_fields=['in_file'],
Expand Down
Loading

0 comments on commit 0a0665d

Please sign in to comment.