import nibabel as nib
import numpy as np
import logging
from time import time
import os.path as op
from AFQ.definitions.utils import Definition, find_file
from dipy.align import syn_registration, affine_registration
import AFQ.registration as reg
from AFQ.utils.path import write_json
from AFQ.tasks.utils import get_fname
from dipy.align.imaffine import AffineMap
try:
from fsl.data.image import Image
from fsl.transform.fnirt import readFnirt
from fsl.transform.nonlinear import applyDeformation
has_fslpy = True
except ModuleNotFoundError:
has_fslpy = False
__all__ = ["FnirtMap", "SynMap", "SlrMap", "AffMap", "IdentityMap"]
logger = logging.getLogger('AFQ')
# For map definitions, get_for_subses should return only the mapping,
# where the mapping has transform and transform_inverse functions
# which each accept data, **kwargs
class FnirtMap(Definition):
    """
    Use an existing FNIRT map. Expects a warp file
    and an image file for each subject / session; image file
    is used as src space for warp.

    Parameters
    ----------
    warp_path : str, optional
        path to file to get warp from. Use this or warp_suffix.
        Default: None
    space_path : str, optional
        path to file to get space from. Use this or space_suffix.
        Default: None
    warp_suffix : str, optional
        suffix to pass to bids_layout.get() to identify the warp file.
        Default: None
    space_suffix : str, optional
        suffix to pass to bids_layout.get() to identify the space file.
        Default: None
    warp_filters : dict, optional
        Additional filters to pass to bids_layout.get() to identify
        the warp file.
        Default: {}
    space_filters : dict, optional
        Additional filters to pass to bids_layout.get() to identify
        the space file.
        Default: {}

    Raises
    ------
    ImportError
        If fslpy could not be imported.
    ValueError
        If neither a path nor a suffix is given for the warp or for the
        space file, or if only one of `warp_path` / `space_path` is given.

    Notes
    -----
    If you have an existing mapping calculated using Fnirt,
    you can pass bids filters to :class:`AFQ.definitions.mapping.FnirtMap`
    and pyAFQ will find and use that mapping.

    Examples
    --------
    fnirt_map = FnirtMap(
        warp_suffix="warp",
        space_suffix="MNI",
        warp_filters={"scope": "TBSS"},
        space_filters={"scope": "TBSS"})
    api.GroupAFQ(mapping=fnirt_map)
    """

    def __init__(self, warp_path=None, space_path=None,
                 warp_suffix=None, space_suffix=None,
                 warp_filters={}, space_filters={}):
        if not has_fslpy:
            raise ImportError(
                "Please install fslpy if you want to use FnirtMap")
        if warp_path is None and warp_suffix is None:
            raise ValueError((
                "One of `warp_path` or `warp_suffix` should be set "
                "to a value other than None."))
        if space_path is None and space_suffix is None:
            raise ValueError(
                "One of space_path or space_suffix must not be None.")
        # Explicit paths must be given as a pair: exactly one of them
        # being None is an error.
        if (warp_path is None) != (space_path is None):
            raise ValueError((
                "If passing a value for `warp_path`, "
                "you must also pass a value for `space_path`"))

        if warp_path is not None:
            self._from_path = True
            self.fnames = (warp_path, space_path)
        else:
            self._from_path = False
            self.warp_suffix = warp_suffix
            # Copy the filter dicts so instances never share (or leak
            # changes into) the mutable default arguments.
            self.warp_filters = \
                {} if warp_filters is None else dict(warp_filters)
            self.space_suffix = space_suffix
            self.space_filters = \
                {} if space_filters is None else dict(space_filters)
            # Filled by find_path: {from_path: (warp file, space file)}
            self.fnames = {}

    def find_path(self, bids_layout, from_path,
                  subject, session, required=True):
        """
        Look up the warp and space files for one scan and remember them.

        Does nothing when the map was built from explicit paths.
        Otherwise both files are searched with ``find_file`` and the
        resulting pair is stored in ``self.fnames`` under ``from_path``,
        which is the key ``get_for_subses`` later reads back using its
        ``dwi_data_file`` argument.

        Parameters
        ----------
        bids_layout, from_path, subject, session
            Forwarded to ``find_file``.
        required : bool, optional
            Forwarded to ``find_file``.
            Default: True
        """
        if self._from_path:
            return

        nearest_warp = find_file(
            bids_layout, from_path, self.warp_filters, self.warp_suffix,
            session, subject, required=required)
        nearest_space = find_file(
            bids_layout, from_path, self.space_filters, self.space_suffix,
            session, subject, required=required)

        self.fnames[from_path] = (nearest_warp, nearest_space)

    def get_for_subses(self, base_fname, dwi, dwi_data_file, reg_subject,
                       reg_template):
        """
        Build the mapping object for one subject / session.

        Reads the FNIRT warp with fslpy, using the space file as the
        warp's source image and ``dwi`` as its reference image.

        Returns
        -------
        ConformedFnirtMapping
            Wraps the warp together with ``reg_template.affine``.
        """
        if self._from_path:
            nearest_warp, nearest_space = self.fnames
        else:
            nearest_warp, nearest_space = self.fnames[dwi_data_file]

        subj = Image(dwi)
        their_templ = Image(nearest_space)
        warp = readFnirt(nearest_warp, their_templ, subj)

        return ConformedFnirtMapping(warp, reg_template.affine)
class ConformedFnirtMapping:
    """
    Adapter exposing an fslpy FNIRT warp through the generic mapping API
    (``transform`` / ``transform_inverse``).
    """

    def __init__(self, warp, ref_affine):
        """Store the warp and the affine used to wrap incoming arrays."""
        self.warp = warp
        self.ref_affine = ref_affine

    def transform_inverse(self, data, **kwargs):
        """
        Apply the warp to ``data`` and return the result as an array.

        ``data`` is cast to float32 and wrapped in a NIfTI image that
        carries ``ref_affine`` before the deformation is applied.
        Extra keyword arguments are accepted but ignored.
        """
        as_nifti = nib.Nifti1Image(
            data.astype(np.float32), self.ref_affine)
        warped = applyDeformation(Image(as_nifti), self.warp)
        return np.asarray(warped.data)

    def transform_inverse_pts(self, pts):
        """Run ``pts`` through two affines and then the warp itself."""
        # This should only be used for curvature analysis,
        # Because I think the results still need to be shifted
        src_affine = self.warp.src.getAffine('voxel', 'world')
        pts = nib.affines.apply_affine(src_affine, pts)
        ref_inverse = np.linalg.inv(self.ref_affine)
        pts = nib.affines.apply_affine(ref_inverse, pts)
        return self.warp.transform(pts, 'fsl', "world")

    def transform(self, data, **kwargs):
        """Not supported for FNIRT warps; always raises."""
        raise NotImplementedError(
            "Fnirt based mappings can currently"
            " only transform from template to subject space")
class IdentityMap(Definition):
    """
    Mapping that leaves data untransformed in the places where pyAFQ
    would normally move between MNI and subject space.

    Examples
    --------
    my_example_mapping = IdentityMap()
    api.GroupAFQ(mapping=my_example_mapping)
    """

    def __init__(self):
        pass

    def get_for_subses(self, base_fname, dwi, dwi_data_file, reg_subject,
                       reg_template):
        """
        Return a ConformedAffineMapping built around a 4x4 identity matrix.

        The domain grid is described by ``reg_subject`` and the codomain
        grid by ``reg_template``; the other arguments are unused.
        """
        grids = dict(
            domain_grid_shape=reg.reduce_shape(reg_subject.shape),
            domain_grid2world=reg_subject.affine,
            codomain_grid_shape=reg.reduce_shape(reg_template.shape),
            codomain_grid2world=reg_template.affine)
        return ConformedAffineMapping(np.identity(4), **grids)
class GeneratedMapMixin:
    """
    Helper Class
    Shared logic for maps that pyAFQ computes itself (rather than reads).

    Classes using this mixin provide ``extension``, ``use_prealign`` and
    ``gen_mapping``; ``affine_kwargs`` is needed by ``prealign``.
    """

    def get_fnames(self, extension, base_fname):
        """Return ``(mapping_file, meta_fname)`` derived from base_fname."""
        stem = get_fname(
            base_fname,
            '_desc-mapping_from-DWI_to-MNI_xform')
        meta_fname = f'{stem}.json'
        return stem + extension, meta_fname

    def prealign(self, base_fname, reg_subject, reg_template, save=True):
        """
        Get the affine pre-alignment, computing it if no file exists yet.

        Returns
        -------
        The path of the ``.npy`` file when ``save`` is true; otherwise the
        affine itself (freshly computed, or loaded from an existing file).
        """
        desc = "_desc-prealign_from-DWI_to-MNI_xform"
        prealign_file = get_fname(base_fname, f'{desc}.npy')

        # A previous run already produced the file: reuse it.
        if op.exists(prealign_file):
            if save:
                return prealign_file
            return np.load(prealign_file)

        began = time()
        _, aff = affine_registration(
            reg_subject,
            reg_template,
            **self.affine_kwargs)
        elapsed = time() - began
        if not save:
            return aff

        logger.info(f"Saving {prealign_file}")
        np.save(prealign_file, aff)
        write_json(
            get_fname(base_fname, f'{desc}.json'),
            dict(type="rigid", dependent="dwi", timing=elapsed))
        return prealign_file

    def get_for_subses(self, base_fname, dwi, dwi_data_file, reg_subject,
                       reg_template, subject_sls=None, template_sls=None):
        """
        Return the mapping for one subject / session.

        If the mapping file does not exist it is generated with
        ``self.gen_mapping``, written to disk, and described in a JSON
        sidecar. The mapping is then always read back from the file.
        """
        mapping_file, meta_fname = self.get_fnames(
            self.extension, base_fname)

        reg_prealign = None
        if self.use_prealign:
            prealign_path = self.prealign(
                base_fname, reg_subject, reg_template)
            reg_prealign = np.load(prealign_path)

        if not op.exists(mapping_file):
            began = time()
            generated = self.gen_mapping(
                base_fname, reg_subject, reg_template,
                subject_sls, template_sls,
                reg_prealign)
            total_time = time() - began

            logger.info(f"Saving {mapping_file}")
            reg.write_mapping(generated, mapping_file)

            meta = {
                "type": "displacementfield",
                "timing": total_time,
                "dependent": "dwi" if subject_sls is None else "trk",
            }
            # Only string inputs are recorded in the sidecar.
            if isinstance(reg_subject, str):
                meta["reg_subject"] = reg_subject
            if isinstance(reg_template, str):
                meta["reg_template"] = reg_template
            write_json(meta_fname, meta)

        if self.use_prealign:
            reg_prealign_inv = np.linalg.inv(reg_prealign)
        else:
            reg_prealign_inv = None

        return reg.read_mapping(
            mapping_file,
            dwi,
            reg_template,
            prealign=reg_prealign_inv)
class SynMap(GeneratedMapMixin, Definition):
    """
    Calculate a Syn registration for each subject/session
    using reg_subject and reg_template.

    Parameters
    ----------
    use_prealign : bool
        Whether to perform a linear pre-registration.
        Default: True
    affine_kwargs : dictionary, optional
        Parameters to pass to affine_registration
        in dipy.align, which does the linear pre-alignment.
        Only used if use_prealign is True.
        Default: {}
    syn_kwargs : dictionary, optional
        Parameters to pass to syn_registration
        in dipy.align, which does the SyN alignment.
        Default: {}

    Notes
    -----
    The default mapping class is to
    use Symmetric Diffeomorphic Image Registration (SyN).
    This is done with an optional linear pre-alignment by default.
    The parameters of the pre-alignment can be specified when
    initializing the SynMap.

    Examples
    --------
    api.GroupAFQ(mapping=SynMap())
    """

    def __init__(self, use_prealign=True, affine_kwargs={}, syn_kwargs={}):
        self.extension = ".nii.gz"
        self.use_prealign = use_prealign
        self.affine_kwargs = affine_kwargs
        self.syn_kwargs = syn_kwargs

    def gen_mapping(self, base_fname, reg_subject, reg_template,
                    subject_sls, template_sls,
                    reg_prealign):
        """
        Run syn_registration with the subject image as moving and the
        template image as static, and return the resulting mapping.

        ``base_fname``, ``subject_sls`` and ``template_sls`` are unused.
        """
        moving = reg_subject.get_fdata()
        static = reg_template.get_fdata()
        _warped, diffeo = syn_registration(
            moving,
            static,
            moving_affine=reg_subject.affine,
            static_affine=reg_template.affine,
            prealign=reg_prealign,
            **self.syn_kwargs)
        if self.use_prealign:
            # Overwrite this attribute with the inverted pre-alignment.
            diffeo.codomain_world2grid = np.linalg.inv(reg_prealign)
        return diffeo
class SlrMap(GeneratedMapMixin, Definition):
    """
    Calculate a SLR registration for each subject/session
    using reg_subject and reg_template.

    Parameters
    ----------
    slr_kwargs : dictionary, optional
        Parameters to pass to whole_brain_slr
        in dipy, which does the SLR alignment.
        Default: {}

    Notes
    -----
    Use this class to tell pyAFQ to use
    Streamline-based Linear Registration (SLR)
    for registration. Note that the reg_template and reg_subject
    parameters passed to :class:`AFQ.api.group.GroupAFQ` should
    be streamlines when using this registration.

    Examples
    --------
    api.GroupAFQ(mapping=SlrMap())
    """

    def __init__(self, slr_kwargs={}):
        # Keep the caller's options; previously they were discarded and
        # an empty dict was always used instead.
        self.slr_kwargs = slr_kwargs
        self.use_prealign = False
        self.extension = ".npy"

    def gen_mapping(self, base_fname, reg_subject, reg_template,
                    subject_sls, template_sls, reg_prealign):
        """
        Register the subject streamlines to the template streamlines.

        The parameter order matches the positional call made by
        ``GeneratedMapMixin.get_for_subses`` (subject first, then
        template), so the subject supplies the moving affine/shape and
        the template the static affine/shape.

        ``base_fname`` and ``reg_prealign`` are unused.

        Returns
        -------
        Whatever ``reg.slr_registration`` returns for these inputs.
        """
        return reg.slr_registration(
            subject_sls, template_sls,
            moving_affine=reg_subject.affine,
            moving_shape=reg_subject.shape,
            static_affine=reg_template.affine,
            static_shape=reg_template.shape,
            **self.slr_kwargs)
class AffMap(GeneratedMapMixin, Definition):
    """
    Calculate an affine registration for each subject/session
    using reg_subject and reg_template.

    Parameters
    ----------
    affine_kwargs : dictionary, optional
        Parameters to pass to affine_registration
        in dipy.align, which does the linear pre-alignment.
        Default: {}

    Notes
    -----
    This will only perform a linear alignment for registration.

    Examples
    --------
    api.GroupAFQ(mapping=AffMap())
    """

    def __init__(self, affine_kwargs={}):
        self.extension = ".npy"
        self.use_prealign = False
        self.affine_kwargs = affine_kwargs

    def gen_mapping(self, base_fname, reg_subject, reg_template,
                    subject_sls, template_sls,
                    reg_prealign):
        """
        Build a ConformedAffineMapping from the inverse of the affine
        returned by ``self.prealign(..., save=False)``.

        ``subject_sls``, ``template_sls`` and ``reg_prealign`` are unused.
        """
        prealign_affine = self.prealign(
            base_fname, reg_subject, reg_template, save=False)
        inverted = np.linalg.inv(prealign_affine)
        return ConformedAffineMapping(
            inverted,
            domain_grid_shape=reg.reduce_shape(reg_subject.shape),
            domain_grid2world=reg_subject.affine,
            codomain_grid_shape=reg.reduce_shape(reg_template.shape),
            codomain_grid2world=reg_template.affine)
class ConformedAffineMapping(AffineMap):
    """
    Modifies AffineMap API to match DiffeomorphicMap API.
    Important for SLR maps API to be indistinguishable from SYN maps API.

    ``transform`` delegates to the parent's ``transform_inverse`` and
    vice versa, and the ``interpolation`` keyword is forwarded to the
    parent under the name ``interp``.
    """

    def transform(self, *args, interpolation='linear', **kwargs):
        # Any caller-supplied 'interp' is overridden by `interpolation`.
        forwarded = {**kwargs, 'interp': interpolation}
        return super().transform_inverse(*args, **forwarded)

    def transform_inverse(self, *args, interpolation='linear', **kwargs):
        # Any caller-supplied 'interp' is overridden by `interpolation`.
        forwarded = {**kwargs, 'interp': interpolation}
        return super().transform(*args, **forwarded)