# Source code for AFQ.definitions.image

import numpy as np
import logging

import nibabel as nib

from dipy.segment.mask import median_otsu
from dipy.align import resample

from AFQ.definitions.utils import Definition, find_file, name_from_path

from skimage.morphology import convex_hull_image, binary_opening

# Public names exported by ``from AFQ.definitions.image import *``.
__all__ = [
    "ImageFile",
    "FullImage",
    "RoiImage",
    "B0Image",
    "LabelledImageFile",
    "ThresholdedImageFile",
    "ScalarImage",
    "ThresholdedScalarImage",
    "TemplateImage",
    "GQImage",
]


# Module logger: uses the package-wide 'AFQ' logger name rather than
# this module's __name__.
logger = logging.getLogger('AFQ')


def _restore_dtype(resampled_data, image_type):
    '''
    Helper function
    Cast resampled (floating point) data back to the original dtype.

    Boolean and integer images are rounded before the cast so that
    interpolated values snap to the nearest valid value. Without rounding,
    an integer cast would truncate and a bool cast would turn any non-zero
    value into True. Floating point images are returned unrounded:
    rounding them would collapse e.g. a probability map to 0/1 before any
    downstream threshold could be applied.
    '''
    if np.issubdtype(image_type, np.floating):
        return resampled_data.astype(image_type)
    return np.round(resampled_data).astype(image_type)


def _resample_image(image_data, dwi_data, image_affine, dwi_affine):
    '''
    Helper function
    Resamples image to dwi if necessary

    Resampling only happens when both ``dwi_data`` and ``dwi_affine`` are
    given and the spatial shape of the first DWI volume
    (``dwi_data[..., 0]``) differs from ``image_data.shape``. Otherwise
    ``image_data`` is returned unchanged (the same object).

    Parameters
    ----------
    image_data : ndarray
        Image to resample. Its dtype is preserved in the output.
    dwi_data : ndarray or None
        DWI data whose first volume defines the target grid.
    image_affine : ndarray
        Affine of ``image_data``.
    dwi_affine : ndarray or None
        Affine of ``dwi_data``.

    Returns
    -------
    ndarray
        ``image_data`` on the DWI grid, with the original dtype.
    '''
    image_type = image_data.dtype
    if ((dwi_data is not None)
        and (dwi_affine is not None)
            and (dwi_data[..., 0].shape != image_data.shape)):
        resampled = resample(
            image_data.astype(float),
            dwi_data[..., 0],
            image_affine,
            dwi_affine).get_fdata()
        return _restore_dtype(resampled, image_type)
    else:
        return image_data


class ImageDefinition(Definition):
    '''
    Abstract parent of every image definition in this module.

    Concrete subclasses override ``get_name`` (returns a name string for
    the image) and ``get_image_getter`` (returns the function that builds
    the image for a given task). Both raise NotImplementedError here.
    '''

    @staticmethod
    def _unimplemented(method_name):
        # Build the error raised by the placeholder methods below.
        return NotImplementedError(
            f"Please implement a {method_name} method")

    def get_name(self):
        raise self._unimplemented("get_name")

    def get_image_getter(self, task_name):
        raise self._unimplemented("get_image_getter")


class CombineImageMixin(object):
    """
    Helper Class
    Useful for making an image by combining different conditions

    ``combine`` (case-insensitive) selects how boolean conditions are
    merged into ``self.image_draft``:

    - "or": the draft starts all False and conditions are or'd in.
    - "and": the draft starts all True and conditions are and'd in.

    Any other value raises TypeError when the draft is reset or combined.
    Usage: ``self.image_draft = self * condition``.
    """

    # combine mode -> constructor for the initial (neutral) draft
    _DRAFT_FACTORIES = {"or": np.zeros, "and": np.ones}
    # combine mode -> elementwise combination applied by ``self * other``
    _COMBINERS = {"or": np.logical_or, "and": np.logical_and}

    def __init__(self, combine):
        self.combine = combine.lower()

    def reset_image_draft(self, shape):
        """Start a fresh boolean draft of ``shape`` (neutral element)."""
        factory = self._DRAFT_FACTORIES.get(self.combine)
        if factory is None:
            self.combine_illdefined()
        self.image_draft = factory(shape, dtype=bool)

    def __mul__(self, other_image):
        """Return the draft combined with ``other_image``; no mutation."""
        combiner = self._COMBINERS.get(self.combine)
        if combiner is None:
            self.combine_illdefined()
        return combiner(self.image_draft, other_image)

    def combine_illdefined(self):
        """Raise TypeError describing the invalid ``combine`` value."""
        message = ("combine should be either 'or' or 'and',"
                   f" you set combine to {self.combine}")
        raise TypeError(message)


class ImageFile(ImageDefinition):
    """
    Define an image based on a file.
    Does not apply any labels or thresholds;
    Generates image with floating point data.
    Useful for seed and stop images, where threshold can be applied
    after interpolation (see example).

    Parameters
    ----------
    path : str, optional
        path to file to get image from. Use this or suffix.
        Default: None
    suffix : str, optional
        suffix to pass to bids_layout.get() to identify the file.
        Default: None
    filters : dict, optional
        Additional filters to pass to bids_layout.get() to identify
        the file.
        Default: {}

    Examples
    --------
    seed_image = ImageFile(
        suffix="WM",
        filters={"scope":"dmriprep"})
    api.GroupAFQ(tracking_params={"seed_image": seed_image,
                                  "seed_threshold": 0.1})
    """

    def __init__(self, path=None, suffix=None, filters=None):
        if path is None and suffix is None:
            raise ValueError((
                "One of `path` or `suffix` must set to "
                "a value other than None."))

        # A fresh dict per instance: a literal ``{}`` default would be a
        # single dict shared by every ImageFile created without filters.
        if filters is None:
            filters = {}

        if path is not None:
            self._from_path = True
            self.fname = path
        else:
            self._from_path = False
            self.suffix = suffix
            self.filters = filters
            # from_path passed to find_path -> file found for it
            self.fnames = {}

    def find_path(self, bids_layout, from_path, subject, session,
                  required=True):
        """
        Look up this image's file for one subject/session.

        Does nothing when a ``path`` was given at construction. Otherwise
        asks ``find_file`` for a match using ``self.suffix`` and
        ``self.filters`` and caches it in ``self.fnames[from_path]``.

        Returns
        -------
        None, or False if ``find_file`` returned None.
        """
        if self._from_path:
            return
        nearest_image = find_file(
            bids_layout, from_path, self.filters, self.suffix, session,
            subject, required=required)

        if nearest_image is None:
            return False

        self.fnames[from_path] = nearest_image

    def get_path_data_affine(self, dwi_path):
        """
        Load the image associated with ``dwi_path``.

        Returns
        -------
        tuple of (image file path, image data from ``get_fdata()``,
        image affine)
        """
        if self._from_path:
            image_file = self.fname
        else:
            image_file = self.fnames[dwi_path]
        image_img = nib.load(image_file)
        return image_file, image_img.get_fdata(), image_img.affine

    # This function is set up to be overridden by other images
    def apply_conditions(self, image_data_orig, image_file):
        """
        Return the data unchanged plus metadata recording its source.

        Subclasses override this to apply labels or thresholds.
        """
        return image_data_orig, dict(source=image_file)

    def get_name(self):
        """Name derived from ``path`` if given, otherwise the suffix."""
        return name_from_path(self.fname) if self._from_path else self.suffix

    def get_image_getter(self, task_name):
        """
        Build the image getter for ``task_name``.

        For ``task_name == "data"`` the getter signature is
        ``(dwi, dwi_data_file)``; otherwise it is
        ``(data_imap, dwi_data_file)`` and the DWI image is read from
        ``data_imap["dwi"]``. The getter returns
        ``(Nifti1Image with float32 data in DWI space, meta dict)``.
        """
        def _image_getter_helper(dwi, dwi_data_file):
            # Load data
            image_file, image_data_orig, image_affine = \
                self.get_path_data_affine(dwi_data_file)

            # Apply any conditions on the data
            image_data, meta = self.apply_conditions(
                image_data_orig, image_file)

            # Resample to DWI data:
            image_data = _resample_image(
                image_data,
                dwi.get_fdata(),
                image_affine,
                dwi.affine)
            return nib.Nifti1Image(
                image_data.astype(np.float32),
                dwi.affine), meta

        if task_name == "data":
            def image_getter(dwi, dwi_data_file):
                return _image_getter_helper(dwi, dwi_data_file)
        else:
            def image_getter(data_imap, dwi_data_file):
                return _image_getter_helper(
                    data_imap["dwi"], dwi_data_file)
        return image_getter
class FullImage(ImageDefinition):
    """
    Define an image which covers a full volume.

    Examples
    --------
    brain_image_definition = FullImage()
    """

    def __init__(self):
        pass

    def get_name(self):
        return "entire_volume"

    def get_image_getter(self, task_name):
        """
        Build a getter returning an all-ones float32 image on the DWI grid.

        For ``task_name == "data"`` the getter takes ``dwi``; otherwise it
        takes ``data_imap`` and reads ``data_imap["dwi"]``.
        """
        def _image_getter_helper(dwi):
            # ``dwi.shape[:-1]`` equals ``dwi.get_fdata()[..., 0].shape``,
            # but does not load (and cache) the whole volume as float64
            # just to read its shape.
            return nib.Nifti1Image(
                np.ones(dwi.shape[:-1], dtype=np.float32),
                dwi.affine), dict(source="Entire Volume")

        if task_name == "data":
            def image_getter(dwi):
                return _image_getter_helper(dwi)
        else:
            def image_getter(data_imap):
                return _image_getter_helper(data_imap["dwi"])
        return image_getter
class RoiImage(ImageDefinition):
    """
    Define an image which is all include ROIs or'd together.

    Parameters
    ----------
    use_waypoints : bool
        Whether to use the include ROIs to generate the image.
    use_presegment : bool
        Whether to use presegment bundle dict from segmentation params
        to get ROIs. This only selects which bundle dict is used; at
        least one of use_waypoints or use_endpoints must still be True.
    use_endpoints : bool
        Whether to use the endpoints ("start" and "end") to generate
        the image.
    tissue_property : str or None
        Tissue property from `scalars` to multiply the ROI image with.
        Can be useful to limit seed mask to the core white matter.
        Note: this must be a built-in tissue property.
        Default: None
    tissue_property_n_voxel : int or None
        Threshold `tissue_property` to a boolean mask with
        tissue_property_n_voxel number of voxels set to True.
        Default: None
    tissue_property_threshold : int or None
        Threshold to threshold `tissue_property` if a boolean mask is
        desired. This threshold is interpreted as a percentile.
        Overrides tissue_property_n_voxel.
        Default: None

    Examples
    --------
    seed_image = RoiImage()
    api.GroupAFQ(tracking_params={"seed_image": seed_image})
    """

    def __init__(self, use_waypoints=True, use_presegment=False,
                 use_endpoints=False, tissue_property=None,
                 tissue_property_n_voxel=None,
                 tissue_property_threshold=None):
        self.use_waypoints = use_waypoints
        self.use_presegment = use_presegment
        self.use_endpoints = use_endpoints
        self.tissue_property = tissue_property
        self.tissue_property_n_voxel = tissue_property_n_voxel
        self.tissue_property_threshold = tissue_property_threshold
        # ROIs are only collected from endpoints and/or waypoints; with
        # both off the getter can never produce an image, whatever
        # use_presegment is.
        if not (self.use_waypoints or self.use_endpoints):
            raise ValueError((
                "One of use_waypoints or use_endpoints must be True; "
                "use_presegment only selects which bundle dict the "
                "ROIs are taken from"))

    def get_name(self):
        return "roi"

    def get_image_getter(self, task_name):
        """
        Build the ROI image getter for ``task_name``.

        Raises ValueError for ``task_name == "data"``. For "mapping" the
        getter takes ``(mapping, data_imap, segmentation_params)``;
        otherwise ``(mapping_imap, data_imap, segmentation_params)`` and
        the mapping is read from ``mapping_imap["mapping"]``.
        """
        def _image_getter_helper(mapping, data_imap, segmentation_params):
            if self.use_presegment:
                bundle_dict = segmentation_params["presegment_bundle_dict"]
            else:
                bundle_dict = data_imap["bundle_dict"]

            # Union of every selected ROI across all bundles
            image_data = None
            for bundle_name in bundle_dict:
                bundle_entry = bundle_dict.transform_rois(
                    bundle_name,
                    mapping,
                    data_imap["dwi_affine"])
                rois = []
                if self.use_endpoints:
                    rois.extend(
                        bundle_entry[end_type]
                        for end_type in ["start", "end"]
                        if end_type in bundle_entry)
                if self.use_waypoints:
                    rois.extend(bundle_entry.get('include', []))
                for roi in rois:
                    warped_roi = roi.get_fdata()
                    if image_data is None:
                        image_data = np.zeros(warped_roi.shape)
                    image_data = np.logical_or(
                        image_data,
                        warped_roi.astype(bool))

            # Checked before the tissue-property step, which would
            # otherwise fail on None with an unhelpful AttributeError.
            if image_data is None:
                raise ValueError((
                    "BundleDict does not have enough ROIs to generate "
                    f"an ROI Image: {bundle_dict._dict}"))

            if self.tissue_property is not None:
                tp = nib.load(data_imap[self.tissue_property]).get_fdata()
                image_data = image_data.astype(np.float32) * tp
                if self.tissue_property_threshold is not None:
                    # Percentile over voxels where the product is
                    # non-zero: zeros are hidden as NaN, then restored.
                    zero_mask = image_data == 0
                    image_data[zero_mask] = np.nan
                    tp_thresh = np.nanpercentile(
                        image_data,
                        100 - self.tissue_property_threshold)
                    image_data[zero_mask] = 0
                    image_data = image_data > tp_thresh
                elif self.tissue_property_n_voxel is not None:
                    # Strictly above the (n+1)-th largest value keeps the
                    # n largest voxels (when values are distinct).
                    tp_thresh = np.sort(image_data.flatten())[
                        -1 - self.tissue_property_n_voxel]
                    image_data = image_data > tp_thresh

            return nib.Nifti1Image(
                image_data.astype(np.float32),
                data_imap["dwi_affine"]), dict(source="ROIs")

        if task_name == "data":
            raise ValueError((
                "RoiImage cannot be used in this context, as they "
                "require later derivatives to be calculated"))
        elif task_name == "mapping":
            def image_getter(
                    mapping, data_imap, segmentation_params):
                return _image_getter_helper(
                    mapping, data_imap, segmentation_params)
        else:
            def image_getter(
                    mapping_imap, data_imap, segmentation_params):
                return _image_getter_helper(
                    mapping_imap["mapping"], data_imap,
                    segmentation_params)
        return image_getter
class GQImage(ImageDefinition):
    """
    Threshold the anisotropic diffusion component of the
    Generalized Q-Sampling Model to generate a brain mask
    which will include the eyes, optic nerve, and cerebrum
    but will exclude most or all of the skull.

    Examples
    --------
    api.GroupAFQ(brain_mask_definition=GQImage())
    """

    def __init__(self):
        pass

    def get_name(self):
        return "GQ"

    def get_image_getter(self, task_name):
        """
        Build the GQ mask getter.

        For ``task_name == "data"`` the getter takes the path of the GQ
        ASO map directly; otherwise it takes ``data_imap`` and reads
        ``data_imap["gq_aso"]``.
        """
        def image_getter_helper(gq_aso):
            aso_img = nib.load(gq_aso)
            # Threshold at 0.1, apply a morphological opening, then take
            # the convex hull of what remains.
            opened = binary_opening(aso_img.get_fdata() > 0.1)
            brain_mask = convex_hull_image(opened)
            meta = dict(
                source=gq_aso,
                technique="GQ ASO thresholded maps")
            return nib.Nifti1Image(
                brain_mask.astype(np.float32), aso_img.affine), meta

        if task_name != "data":
            def image_getter(data_imap):
                return image_getter_helper(data_imap["gq_aso"])
            return image_getter
        return image_getter_helper
class B0Image(ImageDefinition):
    """
    Define an image using b0 and dipy's median_otsu.

    Parameters
    ----------
    median_otsu_kwargs: dict, optional
        Optional arguments to pass into dipy's median_otsu.
        Default: {}

    Examples
    --------
    brain_image_definition = B0Image()
    api.GroupAFQ(brain_mask_definition=brain_image_definition)
    """

    def __init__(self, median_otsu_kwargs=None):
        # None -> fresh dict, so instances never share one default dict
        # (which is also stored in every output's metadata).
        if median_otsu_kwargs is None:
            median_otsu_kwargs = {}
        self.median_otsu_kwargs = median_otsu_kwargs

    def get_name(self):
        return "b0"

    def get_image_getter(self, task_name):
        """
        Build the b0 median_otsu mask getter.

        For ``task_name == "data"`` the getter takes the b0 file path;
        otherwise it takes ``data_imap`` and reads ``data_imap["b0"]``.
        Logs a warning each time it runs.
        """
        def image_getter_helper(b0):
            mean_b0_img = nib.load(b0)
            mean_b0 = mean_b0_img.get_fdata()
            logger.warning((
                "It is recommended that you provide a brain mask. "
                "It is provided with the brain_mask_definition argument. "
                "Otherwise, the default brain mask is calculated "
                "by using OTSU on the median-filtered B0 image. "
                "This can be unreliable. "))
            # Keep only the mask (second return value of median_otsu)
            _, image_data = median_otsu(mean_b0, **self.median_otsu_kwargs)
            return nib.Nifti1Image(
                image_data.astype(np.float32),
                mean_b0_img.affine), dict(
                    source=b0,
                    technique="median_otsu applied to b0",
                    median_otsu_kwargs=self.median_otsu_kwargs)

        if task_name == "data":
            return image_getter_helper
        else:
            return lambda data_imap: image_getter_helper(data_imap["b0"])
class LabelledImageFile(ImageFile, CombineImageMixin):
    """
    Define an image based on labels in a file.

    Parameters
    ----------
    path : str, optional
        path to file to get image from. Use this or suffix.
        Default: None
    suffix : str, optional
        suffix to pass to bids_layout.get() to identify the file.
        Default: None
    filters : dict, optional
        Additional filters to pass to bids_layout.get() to identify
        the file.
        Default: {}
    inclusive_labels : list of ints, optional
        The labels from the file to include from the boolean image.
        If None, no inclusive labels are applied.
        Default: None.
    exclusive_labels : list of ints, optional
        The labels from the file to exclude from the boolean image.
        If None, no exclusive labels are applied.
        Default: None.
    combine : str, optional
        How to combine the boolean images generated by inclusive_labels
        and exclusive_labels. If "and", they will be and'd together.
        If "or", they will be or'd.
        Note: in this class, you will most likely want to either set
        inclusive_labels or exclusive_labels, not both,
        so combine will not matter.
        Also note that every individual label is combined with this same
        rule: with "or", several exclusive_labels (e.g. [0, 1]) keep
        every voxel; use "and" to exclude more than one label.
        Default: "or"

    Examples
    --------
    brain_image_definition = LabelledImageFile(
        suffix="aseg",
        filters={"scope": "dmriprep"},
        exclusive_labels=[0])
    api.GroupAFQ(brain_mask_definition=brain_image_definition)
    """

    def __init__(self, path=None, suffix=None, filters=None,
                 inclusive_labels=None, exclusive_labels=None,
                 combine="or"):
        # Fresh dict per instance instead of a shared ``{}`` default
        if filters is None:
            filters = {}
        ImageFile.__init__(self, path, suffix, filters)
        CombineImageMixin.__init__(self, combine)
        self.inclusive_labels = inclusive_labels
        self.exclusive_labels = exclusive_labels

    # overrides ImageFile
    def apply_conditions(self, image_data_orig, image_file):
        """
        Turn label data into a boolean image.

        Each inclusive label contributes ``data == label`` and each
        exclusive label ``data != label``; all are merged into the draft
        using ``self.combine``.

        Returns
        -------
        tuple of (boolean ndarray, meta dict)
        """
        # For different sets of labels, extract all the voxels that
        # have any / all of these values:
        self.reset_image_draft(image_data_orig.shape)
        if self.inclusive_labels is not None:
            for label in self.inclusive_labels:
                self.image_draft = self * (image_data_orig == label)
        if self.exclusive_labels is not None:
            for label in self.exclusive_labels:
                self.image_draft = self * (image_data_orig != label)

        meta = dict(source=image_file,
                    inclusive_labels=self.inclusive_labels,
                    exclusive_labels=self.exclusive_labels,
                    combined_with=self.combine)
        return self.image_draft, meta
class ThresholdedImageFile(ImageFile, CombineImageMixin):
    """
    Define an image based on thresholding a file.
    Note that this should not be used to directly make a seed image
    or a stop image. In those cases, consider thresholding after
    interpolation, as in the example for ImageFile.

    Parameters
    ----------
    path : str, optional
        path to file to get image from. Use this or suffix.
        Default: None
    suffix : str, optional
        suffix to pass to bids_layout.get() to identify the file.
        Default: None
    filters : dict, optional
        Additional filters to pass to bids_layout.get() to identify
        the file.
        Default: {}
    lower_bound : float, optional
        Lower bound to generate boolean image from data in the file.
        If None, no lower bound is applied.
        Default: None.
    upper_bound : float, optional
        Upper bound to generate boolean image from data in the file.
        If None, no upper bound is applied.
        Default: None.
    as_percentage : bool, optional
        Interpret lower_bound and upper_bound as percentages of the
        total non-nan voxels in the image to include (between 0 and 100),
        instead of as a threshold on the values themselves.
        Default: False
    combine : str, optional
        How to combine the boolean images generated by lower_bound
        and upper_bound. If "and", they will be and'd together.
        If "or", they will be or'd.
        Default: "and"

    Examples
    --------
    brain_image_definition = ThresholdedImageFile(
        suffix="BM",
        filters={"scope":"dmriprep"},
        lower_bound=0.1)
    api.GroupAFQ(brain_mask_definition=brain_image_definition)
    """

    def __init__(self, path=None, suffix=None, filters=None,
                 lower_bound=None, upper_bound=None,
                 as_percentage=False, combine="and"):
        # Fresh dict per instance instead of a shared ``{}`` default
        if filters is None:
            filters = {}
        ImageFile.__init__(self, path, suffix, filters)
        CombineImageMixin.__init__(self, combine)
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        self.as_percentage = as_percentage

    # overrides ImageFile
    def apply_conditions(self, image_data_orig, image_file):
        """
        Turn data into a boolean image using strict bounds.

        ``data < upper`` and ``data > lower`` are merged with
        ``self.combine``. With ``as_percentage``, upper_bound is taken as
        the ``upper_bound``-th percentile and lower_bound as the
        ``(100 - lower_bound)``-th percentile (NaNs ignored).

        Returns
        -------
        tuple of (boolean ndarray, meta dict)
        """
        # Apply thresholds
        self.reset_image_draft(image_data_orig.shape)
        if self.upper_bound is not None:
            if self.as_percentage:
                upper_bound = np.nanpercentile(
                    image_data_orig,
                    self.upper_bound)
            else:
                upper_bound = self.upper_bound
            self.image_draft = self * (image_data_orig < upper_bound)
        if self.lower_bound is not None:
            if self.as_percentage:
                lower_bound = np.nanpercentile(
                    image_data_orig,
                    100 - self.lower_bound)
            else:
                lower_bound = self.lower_bound
            self.image_draft = self * (image_data_orig > lower_bound)

        # as_percentage is recorded too: without it the stored bounds are
        # ambiguous (raw values vs. percentages).
        meta = dict(source=image_file,
                    upper_bound=self.upper_bound,
                    lower_bound=self.lower_bound,
                    as_percentage=self.as_percentage,
                    combined_with=self.combine)
        return self.image_draft, meta
class ScalarImage(ImageDefinition):
    """
    Define an image based on a scalar.
    Does not apply any labels or thresholds;
    Generates image with floating point data.
    Useful for seed and stop images, where threshold can be applied
    after interpolation (see example).

    Parameters
    ----------
    scalar : str
        Scalar to use. Can be one of "dti_fa", "dti_md", "dki_fa",
        "dki_md".

    Examples
    --------
    seed_image = ScalarImage(
        "dti_fa")
    api.GroupAFQ(tracking_params={
        "seed_image": seed_image,
        "seed_threshold": 0.2})
    """

    def __init__(self, scalar):
        self.scalar = scalar

    def get_name(self):
        return self.scalar

    def get_image_getter(self, task_name):
        """
        Build a getter ``(data_imap)`` that loads ``data_imap[scalar]``.

        Raises ValueError for ``task_name == "data"``.
        """
        if task_name == "data":
            raise ValueError((
                "ScalarImage cannot be used in this context, as they "
                "require later derivatives to be calculated"))

        def image_getter(data_imap):
            return nib.load(data_imap[self.scalar]), dict(
                FromScalar=self.scalar)
        return image_getter
class ThresholdedScalarImage(ThresholdedImageFile, ScalarImage):
    """
    Define an image based on thresholding a scalar image.
    Note that this should not be used to directly make a seed image
    or a stop image. In those cases, consider thresholding after
    interpolation, as in the example for ScalarImage.

    Parameters
    ----------
    scalar : str
        Scalar to threshold.
        Can be one of "dti_fa", "dti_md", "dki_fa", "dki_md".
    lower_bound : float, optional
        Lower bound to generate boolean image from data in the file.
        If None, no lower bound is applied.
        Default: None.
    upper_bound : float, optional
        Upper bound to generate boolean image from data in the file.
        If None, no upper bound is applied.
        Default: None.
    combine : str, optional
        How to combine the boolean images generated by lower_bound
        and upper_bound. If "and", they will be and'd together.
        If "or", they will be or'd.
        Default: "and"
    as_percentage : bool, optional
        Interpret lower_bound and upper_bound as percentages of the
        non-nan voxels to include (between 0 and 100), as in
        ThresholdedImageFile.
        Default: False

    Examples
    --------
    seed_image = ThresholdedScalarImage(
        "dti_fa",
        lower_bound=0.2)
    api.GroupAFQ(tracking_params={"seed_image": seed_image})
    """

    def __init__(self, scalar, lower_bound=None, upper_bound=None,
                 combine="and", as_percentage=False):
        self.scalar = scalar
        CombineImageMixin.__init__(self, combine)
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        # Read by ThresholdedImageFile.apply_conditions whenever a bound
        # is set; previously never assigned here (AttributeError).
        self.as_percentage = as_percentage

    # The MRO is ThresholdedScalarImage -> ThresholdedImageFile ->
    # ImageFile -> ScalarImage -> ..., so without the overrides below the
    # file-based ImageFile methods would be used. Those read attributes
    # (``_from_path``, ``fname``/``suffix``, ``fnames``) that this class
    # never sets.

    def find_path(self, bids_layout, from_path, subject, session,
                  required=True):
        """No-op: the scalar comes from ``data_imap``, not the layout."""
        return None

    def get_name(self):
        return ScalarImage.get_name(self)

    def get_image_getter(self, task_name):
        """
        Build a getter ``(data_imap)`` that thresholds the scalar image.

        Loads ``data_imap[scalar]``, applies the bounds via
        ``apply_conditions`` and returns a float32 Nifti1Image with the
        scalar's affine plus the metadata. Raises ValueError for
        ``task_name == "data"``, like ScalarImage.
        """
        if task_name == "data":
            raise ValueError((
                "ThresholdedScalarImage cannot be used in this context, "
                "as they require later derivatives to be calculated"))

        def image_getter(data_imap):
            scalar_file = data_imap[self.scalar]
            scalar_img = nib.load(scalar_file)
            image_data, meta = self.apply_conditions(
                scalar_img.get_fdata(), scalar_file)
            meta["FromScalar"] = self.scalar
            return nib.Nifti1Image(
                image_data.astype(np.float32),
                scalar_img.affine), meta
        return image_getter
class PFTImage(ImageDefinition):
    """
    Define an image for use in PFT tractography. Only use
    if tracker set to 'pft' in tractography.

    Parameters
    ----------
    WM_probseg : ImageFile
        White matter segmentation file.
    GM_probseg : ImageFile
        Gray matter segmentation file.
    CSF_probseg : ImageFile
        Cerebrospinal fluid segmentation file.

    Examples
    --------
    stop_image = PFTImage(
        afm.ImageFile(suffix="WMprobseg"),
        afm.ImageFile(suffix="GMprobseg"),
        afm.ImageFile(suffix="CSFprobseg"))
    api.GroupAFQ(tracking_params={
        "stop_image": stop_image,
        "stop_threshold": "CMC",
        "tracker": "pft"})
    """

    def __init__(self, WM_probseg, GM_probseg, CSF_probseg):
        # Order matters: WM, GM, CSF (also the order of the getters)
        self.probsegs = (WM_probseg, GM_probseg, CSF_probseg)

    def find_path(self, bids_layout, from_path, subject, session,
                  required=True):
        """
        Forward the lookup to each probseg definition.

        Raises ValueError when called with a falsy ``required``.
        """
        if not required:
            raise ValueError(
                "PFTImage cannot be used in this context")
        for probseg in self.probsegs:
            probseg.find_path(
                bids_layout, from_path, subject, session,
                required=required)

    def get_name(self):
        return "pft"

    def get_image_getter(self, task_name):
        """
        Return a list of three getters (WM, GM, CSF).

        Raises ValueError for ``task_name == "data"``.
        """
        if task_name == "data":
            raise ValueError("PFTImage cannot be used in this context")
        return [probseg.get_image_getter(task_name)
                for probseg in self.probsegs]
class TemplateImage(ImageDefinition):
    """
    Define a scalar based on a template.
    This template will be transformed into subject space before use.
    The inverse transform uses nearest-neighbour interpolation.

    Parameters
    ----------
    path : str
        path to the template.

    Examples
    --------
    my_scalar = TemplateImage(
        "path/to/my_scalar_in_MNI.nii.gz")
    api.GroupAFQ(scalars=["dti_fa", "dti_md", my_scalar])
    """

    def __init__(self, path):
        self.path = path

    def get_name(self):
        return name_from_path(self.path)

    def get_image_getter(self, task_name):
        """
        Build the template getter for ``task_name``.

        Raises ValueError for ``task_name == "data"``. For "mapping" the
        getter takes ``(mapping, reg_subject, data_imap)``; otherwise
        ``(mapping_imap, data_imap)`` with the mapping and subject image
        read from ``mapping_imap``. The template is first resampled onto
        ``data_imap["reg_template"]``, then mapped into subject space.
        """
        def _image_getter_helper(mapping, reg_template, reg_subject):
            img = nib.load(self.path)
            # Bring the template onto the registration template's grid
            img_data = resample(
                img.get_fdata(),
                reg_template,
                img.affine,
                reg_template.affine).get_fdata()

            scalar_data = mapping.transform_inverse(
                img_data, interpolation='nearest')
            return nib.Nifti1Image(
                scalar_data.astype(np.float32),
                reg_subject.affine), dict(source=self.path)

        if task_name == "data":
            raise ValueError((
                "TemplateImage cannot be used in this context, as they "
                "require later derivatives to be calculated"))
        elif task_name == "mapping":
            def image_getter(mapping, reg_subject, data_imap):
                return _image_getter_helper(
                    mapping, data_imap["reg_template"],
                    reg_subject)
        else:
            def image_getter(mapping_imap, data_imap):
                return _image_getter_helper(
                    mapping_imap["mapping"],
                    data_imap["reg_template"],
                    mapping_imap["reg_subject"])
        return image_getter