Skip to content

Commit

Permalink
Merge pull request #199 from pni-lab/197-change-default-registration-…
Browse files Browse the repository at this point in the history
…template

Introduce ventricle mask generation workflow; Change default templates; Make templateflow template retrieval MUCH more time-efficient; Redo anat_prov, get_references and more
  • Loading branch information
spisakt authored Jan 17, 2025
2 parents cffbc40 + 3172720 commit 07d8f48
Show file tree
Hide file tree
Showing 5 changed files with 336 additions and 130 deletions.
78 changes: 52 additions & 26 deletions PUMI/pipelines/anat/anat_proc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import nipype.interfaces.ants as ants
from PUMI.engine import NestedNode as Node
from PUMI.pipelines.anat.anat2mni import anat2mni_fsl, anat2mni_ants_hardcoded
from PUMI.pipelines.multimodal.masks import create_ventricle_mask
from PUMI.pipelines.anat.segmentation import bet_fsl, tissue_segmentation_fsl, bet_hd, bet_deepbet
from PUMI.engine import AnatPipeline
from PUMI.utils import get_reference
Expand All @@ -29,7 +30,7 @@ def anat_proc(wf, bet_tool='FSL', reg_tool='ANTS', **kwargs):
Inputs:
brain (str): Path to the brain which should be segmented.
stand2anat_xfm (str): Path to standard2input matrix calculated by FSL FLIRT.
Only necessary when using prior probability maps!
Only necessary when using FSL's prior probability maps!
Outputs:
brain (str): brain extracted image in subject space
Expand All @@ -52,6 +53,8 @@ def anat_proc(wf, bet_tool='FSL', reg_tool='ANTS', **kwargs):
"""

# Step 1: Extract brain / skull removal

if bet_tool == 'FSL':
bet_wf = bet_fsl('bet_fsl')
elif bet_tool == 'HD-BET':
Expand All @@ -61,7 +64,9 @@ def anat_proc(wf, bet_tool='FSL', reg_tool='ANTS', **kwargs):
else:
raise ValueError('bet_tool can be \'FSL\', \'HD-BET\' or \'deepbet\' but not ' + bet_tool)

tissue_segmentation_wf = tissue_segmentation_fsl('tissue_segmentation_fsl')
wf.connect('inputspec', 'in_file', bet_wf, 'in_file')

# Step 2: Transform head from anatomical space to MNI space

if reg_tool == 'FSL':
anat2mni_wf = anat2mni_fsl('anat2mni_fsl')
Expand All @@ -70,50 +75,71 @@ def anat_proc(wf, bet_tool='FSL', reg_tool='ANTS', **kwargs):
else:
raise ValueError('reg_tool can be \'ANTS\' or \'FSL\' but not ' + reg_tool)


# resample 2mm-std ventricle to the actual standard space
resample_std_ventricle = Node(interface=afni.Resample(outputtype='NIFTI_GZ',
in_file=get_reference(wf, 'ventricle_mask')),
name='resample_std_ventricle')

# transform std ventricle mask to anat space, applying the invers warping filed
if reg_tool == 'FSL':
unwarp_ventricle = Node(interface=fsl.ApplyWarp(), name='unwarp_ventricle')
elif reg_tool == 'ANTS':
unwarp_ventricle = Node(interface=ants.ApplyTransforms(), name='unwarp_ventricle')

# mask csf segmentation with anat-space ventricle mask
ventricle_mask = Node(fsl.ImageMaths(op_string=' -mas'), name="ventricle_mask")

wf.connect('inputspec', 'in_file', bet_wf, 'in_file')
wf.connect('inputspec', 'in_file', anat2mni_wf, 'head')
wf.connect(bet_wf, 'out_file', anat2mni_wf, 'brain')

# Step 3: Apply tissue segmentation

tissue_segmentation_wf = tissue_segmentation_fsl('tissue_segmentation_fsl')
wf.connect(bet_wf, 'out_file', tissue_segmentation_wf, 'brain')
wf.connect(bet_wf, 'out_file', anat2mni_wf, 'brain')
wf.connect(anat2mni_wf, 'inv_linear_xfm', tissue_segmentation_wf, 'stand2anat_xfm')
wf.connect(anat2mni_wf, 'inv_linear_xfm', tissue_segmentation_wf, 'stand2anat_xfm') # Used to transform FSL priors to subject space

# Step 4: If needed, create ventricle mask, afterward resample ventricle mask to MNI space

std_ventricle_mask_file = wf.cfg_parser.get('TEMPLATES', 'ventricle_mask', fallback='')
std_csf_probseg_file = wf.cfg_parser.get('TEMPLATES', 'csf_probseg', fallback='')

if std_ventricle_mask_file:
resample_std_ventricle = Node(
interface=afni.Resample(outputtype='NIFTI_GZ', in_file=get_reference(wf, 'ventricle_mask')),
name='resample_std_ventricle'
)
elif std_csf_probseg_file:
create_ventricle_mask_wf = create_ventricle_mask(name='create_ventricle_mask_wf')
create_ventricle_mask_wf.get_node('inputspec').inputs.csf_probseg = get_reference(wf, 'csf_probseg')
create_ventricle_mask_wf.get_node('inputspec').inputs.template = get_reference(wf, 'brain')

resample_std_ventricle = Node(
interface=afni.Resample(outputtype='NIFTI_GZ'),
name='resample_std_ventricle'
)
wf.connect(create_ventricle_mask_wf, 'out_file', resample_std_ventricle, 'in_file')
else:
raise ValueError("Either 'ventricle_mask' or 'csf_probseg' must be specified in settings.ini!")
wf.connect(anat2mni_wf, 'std_template', resample_std_ventricle, 'master')
wf.connect(tissue_segmentation_wf, 'probmap_csf', ventricle_mask, 'in_file')

# Step 5: Transform ventricle mask from MNI space to anat space

if reg_tool == 'FSL':
unwarp_ventricle = Node(interface=fsl.ApplyWarp(), name='unwarp_ventricle')
wf.connect(resample_std_ventricle, 'out_file', unwarp_ventricle, 'in_file')
wf.connect('inputspec', 'in_file', unwarp_ventricle, 'ref_file')
wf.connect(anat2mni_wf, 'inv_nonlinear_xfm', unwarp_ventricle, 'field_file')
wf.connect(anat2mni_wf, 'inv_nonlinear_xfm', 'outputspec', 'mni2anat_warpfield')
wf.connect(unwarp_ventricle, 'out_file', ventricle_mask, 'in_file2')
elif reg_tool == 'ANTS':
unwarp_ventricle = Node(interface=ants.ApplyTransforms(), name='unwarp_ventricle')
wf.connect(resample_std_ventricle, 'out_file', unwarp_ventricle, 'input_image')
wf.connect('inputspec', 'in_file', unwarp_ventricle, 'reference_image')
wf.connect(anat2mni_wf, 'inv_nonlinear_xfm', unwarp_ventricle, 'transforms')
wf.connect(anat2mni_wf, 'inv_nonlinear_xfm', 'outputspec', 'mni2anat_warpfield')
wf.connect(unwarp_ventricle, 'output_image', ventricle_mask, 'in_file2')

# Step 6: Mask csf segmentation with anat-space ventricle mask

anat_ventricle_mask = Node(fsl.ImageMaths(op_string=' -mas'), name='anat_ventricle_mask')
wf.connect(tissue_segmentation_wf, 'probmap_csf', anat_ventricle_mask, 'in_file')
if reg_tool == 'FSL':
wf.connect(unwarp_ventricle, 'out_file', anat_ventricle_mask, 'in_file2')
elif reg_tool == 'ANTS':
wf.connect(unwarp_ventricle, 'output_image', anat_ventricle_mask, 'in_file2')

# Outputs

wf.connect('inputspec', 'in_file', 'outputspec', 'head')
wf.connect(bet_wf, 'out_file', 'outputspec', 'brain')
wf.connect(bet_wf, 'brain_mask', 'outputspec', 'brain_mask')
wf.connect(anat2mni_wf, 'inv_nonlinear_xfm', 'outputspec', 'mni2anat_warpfield')
wf.connect(anat2mni_wf, 'nonlinear_xfm', 'outputspec', 'anat2mni_warpfield')
wf.connect(anat2mni_wf, 'output_brain', 'outputspec', 'std_brain')
wf.connect(anat2mni_wf, 'std_template', 'outputspec', 'std_template')
wf.connect(ventricle_mask, 'out_file', 'outputspec', 'probmap_ventricle')
wf.connect(anat_ventricle_mask, 'out_file', 'outputspec', 'probmap_ventricle')
wf.connect(tissue_segmentation_wf, 'partial_volume_map', 'outputspec', 'parvol_map')
wf.connect(tissue_segmentation_wf, 'probmap_csf', 'outputspec', 'probmap_csf')
wf.connect(tissue_segmentation_wf, 'probmap_gm', 'outputspec', 'probmap_gm')
Expand Down
10 changes: 1 addition & 9 deletions PUMI/pipelines/anat/segmentation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from PUMI.engine import AnatPipeline, QcPipeline
from PUMI.engine import QcPipeline
from PUMI.interfaces.HDBet import HDBet
from PUMI.utils import create_segmentation_qc
from nipype.interfaces import fsl
Expand Down Expand Up @@ -324,11 +324,6 @@ def tissue_segmentation_fsl(wf, priormap=True, **kwargs):
"""

if priormap:
priorprob_csf = os.path.join(os.environ['FSLDIR'], '/data/standard/tissuepriors/avg152T1_csf.hdr')
priorprob_gm = os.path.join(os.environ['FSLDIR'], '/data/standard/tissuepriors/avg152T1_gray.hdr')
priorprob_wm = os.path.join(os.environ['FSLDIR'], '/data/standard/tissuepriors/avg152T1_white.hdr')

fast = Node(interface=fsl.FAST(), name='fast')
fast.inputs.img_type = 1
fast.inputs.segments = True
Expand Down Expand Up @@ -357,9 +352,6 @@ def tissue_segmentation_fsl(wf, priormap=True, **kwargs):
wf.connect(split_probability_maps, 'out3', 'sinker', 'fast_wm')

# output
wf.get_node('outputspec').inputs.probmap_csf = priorprob_csf
wf.get_node('outputspec').inputs.probmap_gm = priorprob_gm
wf.get_node('outputspec').inputs.probmap_wm = priorprob_wm
wf.connect(fast, 'mixeltype', 'outputspec', 'mixeltype')
wf.connect(fast, 'partial_volume_map', 'outputspec', 'partial_volume_map')
wf.connect(split_probability_maps, 'out1', 'outputspec', 'probmap_csf')
Expand Down
132 changes: 132 additions & 0 deletions PUMI/pipelines/multimodal/masks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
from PUMI.engine import QcPipeline, GroupPipeline
from PUMI.engine import NestedNode as Node
import nipype.interfaces.fsl as fsl
from nipype.interfaces.utility import Function
from PUMI.utils import get_reference


@QcPipeline(inputspec_fields=['ventricle_mask', 'template'],
outputspec_fields=['out_file'])
def qc(wf, **kwargs):
"""
Pipeline for generating QC images for ventricle mask creation.
Inputs:
ventricle_mask (str): Path to the ventricle mask file.
template (str): Path to the template image file.
Outputs:
out_file (str): Path to the output QC image.
Parameters:
wf (Workflow): The workflow object.
kwargs (dict): Additional arguments for the workflow.
"""

def create_image(ventricle_mask, template):
"""
Create a QC image overlaying the ventricle mask on the template.
Parameters:
ventricle_mask (str): Path to the ventricle mask file.
template (str): Path to the template image file.
Returns:
str: Path to the output QC image.
"""
from PUMI.utils import plot_roi

_, out_file = plot_roi(roi_img=ventricle_mask, bg_img=template, cmap='winter', save_img=True)
return out_file

plot = Node(Function(input_names=['ventricle_mask', 'template'],
output_names=['out_file'],
function=create_image),
name='plot')
wf.connect('inputspec', 'ventricle_mask', plot, 'ventricle_mask')
wf.connect('inputspec', 'template', plot, 'template')

# Sinking
wf.connect(plot, 'out_file', 'sinker', 'qc_ventricle_mask')

# Output
wf.connect(plot, 'out_file', 'outputspec', 'out_file')


@GroupPipeline(inputspec_fields=['csf_probseg', 'template'],
outputspec_fields=['out_file'])
def create_ventricle_mask(wf, fallback_threshold=0, fallback_dilate_mask=0, **kwargs):
"""
Pipeline for generating a ventricle mask based on CSF probability segmentation and an atlas.
This pipeline generates a ventricle mask by thresholding the CSF probability map,
combining it with atlas-defined ventricle labels, and optionally dilating the resulting mask.
Inputs:
csf_probseg (str): Path to the CSF probability segmentation file.
template (str): Path to the template image file used for QC.
Outputs:
out_file (str): Path to the generated ventricle mask file.
Parameters:
wf (Workflow): The workflow object.
fallback_threshold (float, optional): Default threshold for the CSF probability map thresholding.
Used if not defined in the settings file (default: 0).
fallback_dilate_mask (int, optional): Default dilation value for the ventricle mask.
Used if not defined in the settings file (default: 0).
kwargs (dict): Additional arguments for the workflow.
Raises:
ValueError: If ventricle labels are not defined in the configuration file.
"""

# Load ventricle labels from settings.ini
ventricle_labels = wf.cfg_parser.get('TEMPLATES', 'ventricle_labels', fallback='')
ventricle_labels = [int(label) for label in ventricle_labels.replace(' ', '').split(',')]
if len(ventricle_labels) == 0:
raise ValueError('You need to define ventricle labels in settings.ini!')

# Threshold ventricle labels individually
threshold_nodes = []
atlas = get_reference(wf, 'atlas')

for label in ventricle_labels:
node = Node(fsl.ImageMaths(op_string=f"-thr {label} -uthr {label} -bin"), name=f'threshold_ventricle_{label}')
node.inputs.in_file = atlas
wf.add_nodes([node])
threshold_nodes.append(node)

# Use MultiImageMaths to combine all ventricle masks using -max
combine_ventricles_op_string = " ".join(["-add %s"] * (len(threshold_nodes) - 1))
combine_ventricles = Node(fsl.MultiImageMaths(op_string=combine_ventricles_op_string), name='combine_ventricles')
wf.connect(threshold_nodes[0], 'out_file', combine_ventricles, 'in_file')
for n in threshold_nodes[1:]:
wf.connect(n, 'out_file', combine_ventricles, 'operand_files')

# Dilate ventricle mask
dilate_mask_value = int(wf.cfg_parser.get('FSL', 'ImageMaths_dilate_ventricle_mask', fallback=fallback_dilate_mask))
dilate_mask_op_string = '-dilM ' * dilate_mask_value + '-bin'
dilate_ventricle_mask = Node(fsl.ImageMaths(op_string=dilate_mask_op_string), name='dilate_ventricle_mask')
wf.connect(combine_ventricles, 'out_file', dilate_ventricle_mask, 'in_file')

# Threshold CSF probability map
threshold = float(wf.cfg_parser.get('FSL', 'ImageMaths_ventricle_threshold', fallback=fallback_threshold))
threshold_csf = Node(fsl.ImageMaths(op_string=f'-thr {threshold} -bin'), name='threshold_csf')
wf.connect('inputspec', 'csf_probseg', threshold_csf, 'in_file')

# Multiply the combined ventricle mask with the CSF mask
combine_csf_ventricles = Node(fsl.MultiImageMaths(op_string='-mul %s'), name='combine_csf_ventricles')
wf.connect(threshold_csf, 'out_file', combine_csf_ventricles, 'in_file')
wf.connect(dilate_ventricle_mask, 'out_file', combine_csf_ventricles, 'operand_files')

# QC
qc_create_ventricle_mask = qc(name='qc_create_ventricle_mask', qc_dir=wf.qc_dir)
wf.connect(combine_csf_ventricles, 'out_file', qc_create_ventricle_mask, 'ventricle_mask')
wf.connect('inputspec', 'template', qc_create_ventricle_mask, 'template')

# Sinking
wf.connect(combine_csf_ventricles, 'out_file', 'sinker', 'create_ventricle_mask')

# Outputspec
wf.connect(combine_csf_ventricles, 'out_file', 'outputspec', 'out_file')
22 changes: 16 additions & 6 deletions PUMI/settings.ini
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ qc_dir = qc
bet_frac_anat = 0.5
bet_frac_func = 0.3
bet_vertical_gradient = -0.3
ImageMaths_ventricle_threshold = 0
ImageMaths_dilate_ventricle_mask = 0
# fnirt_config = /usr/local/fsl/etc/flirtsch/TI_2_MNI152_2mm.cnf

[HD-Bet]
Expand All @@ -28,9 +30,17 @@ overwrite_existing = 1
num_volumes = 5

[TEMPLATES]
head = data/standard/MNI152_T1_2mm.nii.gz
#also okay: head = tpl-MNI152Lin/tpl-MNI152Lin_res-02_T1w.nii.gz; source=templateflow
brain = data/standard/MNI152_T1_2mm_brain.nii.gz
brain_mask = data/standard/MNI152_T1_2mm_brain_mask_dil.nii.gz
#also okay: brain_mask = MNI152Lin/tpl-MNI152Lin_res-02_desc-head_mask.nii.gz; source=tf
ventricle_mask = data/standard/MNI152_T1_2mm_VentricleMask.nii.gz
head = tpl-MNI152NLin2009cAsym_res-02_T1w.nii.gz; source=templateflow
brain = tpl-MNI152NLin2009cAsym_res-02_desc-brain_T1w.nii.gz; source=templateflow
brain_mask = tpl-MNI152NLin2009cAsym_res-02_desc-brain_mask.nii.gz; source=templateflow
csf_probseg = tpl-MNI152NLin2009cAsym_res-02_label-CSF_probseg.nii.gz; source=templateflow
atlas = tpl-MNI152NLin2009cAsym_res-02_atlas-HOSPA_desc-th0_dseg.nii.gz; source=templateflow
ventricle_labels = 3, 14
# ---
# Some other possibilities:
# head = data/standard/MNI152_T1_2mm.nii.gz; source=fsl
# head = tpl-MNI152Lin/tpl-MNI152Lin_res-02_T1w.nii.gz; source=templateflow
# brain = data/standard/MNI152_T1_2mm_brain.nii.gz; source=fsl
# brain_mask = data/standard/MNI152_T1_2mm_brain_mask_dil.nii.gz; source=fsl
# brain_mask = MNI152Lin/tpl-MNI152Lin_res-02_desc-head_mask.nii.gz; source=tf
# ventricle_mask = data/standard/MNI152_T1_2mm_VentricleMask.nii.gz; source=fsl
Loading

0 comments on commit 07d8f48

Please sign in to comment.