Source code for AFQ.definitions.mapping

import logging
import os.path as op
from time import time

import nibabel as nib
import numpy as np
from dipy.align import affine_registration, syn_registration
from dipy.align.streamlinear import whole_brain_slr

import AFQ.registration as reg
from AFQ._fixes import get_simplified_transform
from AFQ.definitions.utils import Definition, find_file
from AFQ.tasks.utils import get_fname
from AFQ.utils.path import space_from_fname, write_json

try:
    from fsl.data.image import Image
    from fsl.transform.fnirt import readFnirt
    from fsl.transform.nonlinear import applyDeformation

    has_fslpy = True
except ModuleNotFoundError:
    has_fslpy = False

# Public mapping definitions exported by this module.
__all__ = [
    "FnirtMap",
    "SynMap",
    "SlrMap",
    "AffMap",
    "IdentityMap",
]


# Logs go to the package-level "AFQ" logger rather than a per-module name.
logger = logging.getLogger(name="AFQ")


# For map definitions, get_for_subses should return only the mapping
# object, which must provide ``transform`` and ``transform_inverse``
# methods that each accept ``(data, **kwargs)``.


class FnirtMap(Definition):
    """
    Use an existing FNIRT map. Expects a warp file
    and an image file for each subject / session; image file
    is used as src space for warp.

    Parameters
    ----------
    warp_path : str, optional
        path to file to get warp from. Use this or warp_suffix.
        Default: None
    space_path : str, optional
        path to the image file that defines the source space of the
        warp. Use this or space_suffix.
        Default: None
    warp_suffix : str, optional
        suffix to pass to bids_layout.get() to identify the warp file.
        Default: None
    space_suffix : str, optional
        suffix to pass to bids_layout.get() to identify the space file.
        Default: None
    warp_filters : dict, optional
        Additional filters to pass to bids_layout.get() to identify
        the warp file.
        Default: {}
    space_filters : dict, optional
        Additional filters to pass to bids_layout.get() to identify
        the space file.
        Default: {}

    Raises
    ------
    ImportError
        If fslpy is not installed.
    ValueError
        If neither a path nor a suffix is given for the warp or for the
        space file, or if only one of `warp_path` / `space_path` is given.

    Notes
    -----
    If you have an existing mapping calculated using Fnirt,
    you can pass bids filters to :class:`AFQ.definitions.mapping.FnirtMap`
    and pyAFQ will find and use that mapping.

    Examples
    --------
    fnirt_map = FnirtMap(
        warp_suffix="warp",
        space_suffix="MNI",
        warp_filters={"scope": "TBSS"},
        space_filters={"scope": "TBSS"})
    api.GroupAFQ(mapping=fnirt_map)
    """

    def __init__(
        self,
        warp_path=None,
        space_path=None,
        warp_suffix=None,
        space_suffix=None,
        warp_filters=None,
        space_filters=None,
    ):
        if space_filters is None:
            space_filters = {}
        if warp_filters is None:
            warp_filters = {}
        if not has_fslpy:
            raise ImportError("Please install fslpy if you want to use FnirtMap")

        if warp_path is None and warp_suffix is None:
            raise ValueError(
                (
                    "One of `warp_path` or `warp_suffix` should be set "
                    "to a value other than None."
                )
            )
        if space_path is None and space_suffix is None:
            raise ValueError("One of space_path or space_suffix must not be None.")

        # Explicit paths must be given as a pair: it is an error for
        # exactly one of them to be None.
        if (warp_path is None) != (space_path is None):
            raise ValueError(
                (
                    "If passing a value for `warp_path`, "
                    "you must also pass a value for `space_path`"
                )
            )

        if warp_path is not None:
            # Fixed files shared by every subject/session; no BIDS lookup.
            self._from_path = True
            self.fnames = (warp_path, space_path)
        else:
            self._from_path = False
            self.warp_suffix = warp_suffix
            self.warp_filters = warp_filters
            self.space_suffix = space_suffix
            self.space_filters = space_filters
            # Filled by find_path: data-file path -> (warp file, space file).
            self.fnames = {}

    def find_path(self, bids_layout, from_path, subject, session, required=True):
        """
        Locate the warp and space files for one subject / session.

        Does nothing when explicit paths were given at construction.
        Otherwise, looks up both files with ``find_file`` and stores the
        ``(warp, space)`` pair in ``self.fnames`` under `from_path`.

        Parameters
        ----------
        bids_layout : object
            BIDS layout passed through to ``find_file``.
        from_path : str
            Path of the data file the lookup is relative to; also the key
            later used by ``get_for_subses``.
        subject, session : str
            Subject / session identifiers passed through to ``find_file``.
        required : bool, optional
            Passed through to ``find_file``. Default: True
        """
        if self._from_path:
            return

        nearest_warp = find_file(
            bids_layout,
            from_path,
            self.warp_filters,
            self.warp_suffix,
            session,
            subject,
            required=required,
        )
        nearest_space = find_file(
            bids_layout,
            from_path,
            self.space_filters,
            self.space_suffix,
            session,
            subject,
            required=required,
        )
        # Keyed only by from_path: get_for_subses looks entries up by the
        # dwi data file, so no per-session placeholder entry is stored.
        self.fnames[from_path] = (nearest_warp, nearest_space)

    def get_for_subses(
        self,
        base_fname,
        dwi,
        dwi_data_file,
        reg_subject,
        reg_template,
        tmpl_name,
        subject_sls=None,
        template_sls=None,
    ):
        """
        Build the FNIRT-based mapping for one subject / session.

        Parameters
        ----------
        base_fname, reg_subject, tmpl_name :
            Unused; accepted for the shared ``get_for_subses`` signature.
        dwi :
            Subject image; passed to fslpy ``Image`` and used as the
            reference space of the warp.
        dwi_data_file : str
            Key into ``self.fnames`` when files were found via BIDS.
        reg_template :
            Template image; its ``affine`` becomes the reference affine of
            the returned mapping.
        subject_sls, template_sls : optional
            Unused; accepted for signature compatibility with the other
            mapping definitions in this module.

        Returns
        -------
        ConformedFnirtMapping
        """
        if self._from_path:
            nearest_warp, nearest_space = self.fnames
        else:
            nearest_warp, nearest_space = self.fnames[dwi_data_file]

        subj = Image(dwi)
        their_templ = Image(nearest_space)
        # The space image is the warp's source; the subject image is its
        # reference.
        warp = readFnirt(nearest_warp, their_templ, subj)
        return ConformedFnirtMapping(warp, reg_template.affine)
class ConformedFnirtMapping:
    """
    Adapter exposing an fslpy FNIRT warp through pyAFQ's generic mapping
    API (``transform`` / ``transform_inverse``).
    """

    def __init__(self, warp, ref_affine):
        self.ref_affine = ref_affine
        self.warp = warp

    def transform(self, data, **kwargs):
        """
        Resample `data` through the FNIRT warp.

        `data` is wrapped as a float32 NIfTI image with ``self.ref_affine``
        before ``applyDeformation`` is applied; extra kwargs are ignored.
        """
        as_float = data.astype(np.float32)
        source_img = Image(nib.Nifti1Image(as_float, self.ref_affine))
        warped = applyDeformation(source_img, self.warp)
        return np.asarray(warped.data)

    def transform_pts(self, pts):
        """
        Map point coordinates through the warp.

        NOTE(review): per the original author, this should only be used for
        curvature analysis, because the results may still need to be shifted.
        """
        vox_to_world = self.warp.src.getAffine("voxel", "world")
        pts = nib.affines.apply_affine(vox_to_world, pts)
        world_to_ref_vox = np.linalg.inv(self.ref_affine)
        pts = nib.affines.apply_affine(world_to_ref_vox, pts)
        return self.warp.transform(pts, "fsl", "world")

    def transform_inverse(self, data, **kwargs):
        """Not supported: FNIRT mappings here only go template -> subject."""
        raise NotImplementedError(
            "Fnirt based mappings can currently only transform from "
            "template to subject space"
        )


class GeneratedMapMixin:
    """
    Shared helpers for mapping definitions that pyAFQ generates itself,
    as opposed to reading a pre-existing map.
    """

    def get_fnames(self, extension, base_fname, sub_name, tmpl_name):
        """
        Return ``(mapping_file, meta_fname)``.

        mapping_file is the mapping filename ending in `extension`; meta_fname
        is its JSON sidecar, built from the same stem without `extension`.
        """
        stem = get_fname(
            base_fname, f"_desc-mapping_from-{sub_name}_to-{tmpl_name}_xform"
        )
        return stem + extension, f"{stem}.json"

    def prealign(self, reg_subject, reg_template):
        """
        Run dipy's ``affine_registration`` with ``self.affine_kwargs`` and
        return only the resulting affine matrix.
        """
        logger.info("Calculating affine pre-alignment...")
        _resampled, affine = affine_registration(
            reg_subject, reg_template, **self.affine_kwargs
        )
        return affine


class AffineMapMixin(GeneratedMapMixin):
    """
    Helper for generated mappings that reduce to one affine matrix,
    cached as a ``.npy`` file.

    Subclasses supply ``gen_mapping(reg_subject, reg_template,
    subject_sls, template_sls)`` returning that matrix.
    """

    def get_for_subses(
        self,
        base_fname,
        dwi,
        dwi_data_file,
        reg_subject,
        reg_template,
        tmpl_name,
        subject_sls=None,
        template_sls=None,
    ):
        """
        Load this subject/session's affine mapping.

        If the cached ``.npy`` file is missing, first compute the matrix
        with ``gen_mapping``, then save it along with a JSON sidecar that
        records the type, timing, dependency, and (string) registration
        inputs.
        """
        sub_space = space_from_fname(dwi_data_file)
        mapping_file, meta_fname = self.get_fnames(
            ".npy", base_fname, sub_space, tmpl_name
        )

        if not op.exists(mapping_file):
            tic = time()
            affine = self.gen_mapping(
                reg_subject, reg_template, subject_sls, template_sls
            )
            elapsed = time() - tic

            logger.info(f"Saving {mapping_file}")
            np.save(mapping_file, affine)

            meta = {
                "type": "affine",
                "timing": elapsed,
                # "trk" whenever streamlines were supplied, else "dwi".
                "dependent": "dwi" if subject_sls is None else "trk",
            }
            if isinstance(reg_subject, str):
                meta["reg_subject"] = reg_subject
            if isinstance(reg_template, str):
                meta["reg_template"] = reg_template
            write_json(meta_fname, meta)

        return reg.read_affine_mapping(mapping_file, dwi, reg_template)
class SynMap(GeneratedMapMixin, Definition):
    """
    Calculate a Syn registration for each subject/session
    using reg_subject and reg_template.

    Parameters
    ----------
    use_prealign : bool
        Whether to perform a linear pre-registration.
        Default: True
    affine_kwargs : dictionary, optional
        Parameters to pass to affine_registration
        in dipy.align, which does the linear pre-alignment.
        Only used if use_prealign is True.
        Default: {}
    syn_kwargs : dictionary, optional
        Parameters to pass to syn_registration
        in dipy.align, which does the SyN alignment.
        Default: {}

    Notes
    -----
    The default mapping class is to
    use Symmetric Diffeomorphic Image Registration (SyN).
    This is done with an optional linear pre-alignment by default.
    The parameters of the pre-alignment can be specified when
    initializing the SynMap.

    Examples
    --------
    api.GroupAFQ(mapping=SynMap())
    """

    def __init__(self, use_prealign=True, affine_kwargs=None, syn_kwargs=None):
        self.use_prealign = use_prealign
        self.affine_kwargs = {} if affine_kwargs is None else affine_kwargs
        self.syn_kwargs = {} if syn_kwargs is None else syn_kwargs

    def get_for_subses(
        self,
        base_fname,
        dwi,
        dwi_data_file,
        reg_subject,
        reg_template,
        tmpl_name,
        subject_sls=None,
        template_sls=None,
    ):
        """
        Load this subject/session's SyN mapping, computing it if needed.

        If either the forward or the backward displacement-field file is
        missing, SyN is run (optionally after an affine pre-alignment). The
        simplified forward and backward fields are then saved as NIfTI
        files, each with a JSON sidecar. The mapping is always returned
        via ``reg.read_syn_mapping`` from those files. `dwi`, `subject_sls`
        and `template_sls` are unused.
        """
        sub_space = space_from_fname(dwi_data_file)
        fwd_file, fwd_meta = self.get_fnames(
            ".nii.gz", base_fname, sub_space, tmpl_name
        )
        bwd_file, bwd_meta = self.get_fnames(
            ".nii.gz", base_fname, tmpl_name, sub_space
        )

        already_saved = op.exists(fwd_file) and op.exists(bwd_file)
        if not already_saved:
            meta = {"type": "displacementfield", "dependent": "dwi"}
            if isinstance(reg_subject, str):
                meta["reg_subject"] = reg_subject
            if isinstance(reg_template, str):
                meta["reg_template"] = reg_template

            tic = time()
            prealign_affine = (
                self.prealign(reg_subject, reg_template)
                if self.use_prealign
                else None
            )

            logger.info("Calculating SyN registration...")
            _, syn_map = syn_registration(
                reg_subject.get_fdata(),
                reg_template.get_fdata(),
                moving_affine=reg_subject.affine,
                static_affine=reg_template.affine,
                prealign=prealign_affine,
                **self.syn_kwargs,
            )
            syn_map = get_simplified_transform(syn_map)
            meta["total_time"] = time() - tic

            logger.info(f"Saving {fwd_file}")
            nib.save(nib.Nifti1Image(syn_map.forward, reg_subject.affine), fwd_file)
            write_json(fwd_meta, meta)

            logger.info(f"Saving {bwd_file}")
            nib.save(
                nib.Nifti1Image(syn_map.backward, reg_template.affine), bwd_file
            )
            write_json(bwd_meta, meta)

        return reg.read_syn_mapping(fwd_file, bwd_file)
class SlrMap(AffineMapMixin, Definition):
    """
    Calculate a SLR registration for each subject/session
    using reg_subject and reg_template.

    Parameters
    ----------
    slr_kwargs : dictionary, optional
        Parameters to pass to whole_brain_slr
        in dipy, which does the SLR alignment.
        ``x0="affine"`` and ``verbose=False`` are passed by default;
        matching keys in `slr_kwargs` override those defaults.
        Default: {}

    Notes
    -----
    Use this class to tell pyAFQ to use
    Streamline-based Linear Registration (SLR)
    for registration. Note that the reg_template and reg_subject
    parameters passed to :class:`AFQ.api.group.GroupAFQ` should
    be streamlines when using this registration.

    Examples
    --------
    api.GroupAFQ(mapping=SlrMap())
    """

    def __init__(self, slr_kwargs=None):
        if slr_kwargs is None:
            slr_kwargs = {}
        self.slr_kwargs = slr_kwargs

    def gen_mapping(
        self,
        reg_subject,
        reg_template,
        subject_sls,
        template_sls,
    ):
        """
        Compute the SLR affine between subject and template streamlines.

        Parameters
        ----------
        reg_subject, reg_template :
            Unused; present for the shared ``gen_mapping`` signature.
        subject_sls, template_sls :
            Streamlines passed to ``whole_brain_slr`` as its first (static)
            and second (moving) arguments, respectively.

        Returns
        -------
        The ``transform`` element (second item) of whole_brain_slr's result.

        Raises
        ------
        ValueError
            If either set of streamlines is None.
        """
        if subject_sls is None or template_sls is None:
            raise ValueError(
                "SlrMap requires both subject_sls and template_sls; got None."
            )
        # Merge defaults first so user-supplied slr_kwargs override them,
        # instead of raising "got multiple values for keyword argument".
        slr_kwargs = {"x0": "affine", "verbose": False, **self.slr_kwargs}
        _, transform, _, _ = whole_brain_slr(subject_sls, template_sls, **slr_kwargs)
        return transform
class AffMap(AffineMapMixin, Definition):
    """
    Calculate an affine registration for each subject/session
    using reg_subject and reg_template.

    Parameters
    ----------
    affine_kwargs : dictionary, optional
        Parameters to pass to affine_registration
        in dipy.align, which computes the linear alignment.
        Default: {}

    Notes
    -----
    This will only perform a linear alignment for registration.

    Examples
    --------
    api.GroupAFQ(mapping=AffMap())
    """

    def __init__(self, affine_kwargs=None):
        self.affine_kwargs = {} if affine_kwargs is None else affine_kwargs

    def gen_mapping(self, reg_subject, reg_template, subject_sls, template_sls):
        """
        Return the inverse of the affine from ``self.prealign`` on the two
        images. The streamline arguments are unused.
        """
        prealign_affine = self.prealign(reg_subject, reg_template)
        return np.linalg.inv(prealign_affine)
class IdentityMap(AffineMapMixin, Definition):
    """
    Does not perform any transformations from MNI to subject where
    pyAFQ normally would.

    Examples
    --------
    my_example_mapping = IdentityMap()
    api.GroupAFQ(mapping=my_example_mapping)
    """

    def __init__(self):
        """Nothing to configure: the mapping is always the identity."""

    def gen_mapping(self, reg_subject, reg_template, subject_sls, template_sls):
        """Return the 4x4 identity matrix; all inputs are ignored."""
        return np.eye(4)