import logging
import os.path as op
from time import time
import nibabel as nib
import numpy as np
from dipy.align import affine_registration, syn_registration
from dipy.align.streamlinear import whole_brain_slr
import AFQ.registration as reg
from AFQ._fixes import get_simplified_transform
from AFQ.definitions.utils import Definition, find_file
from AFQ.tasks.utils import get_fname
from AFQ.utils.path import space_from_fname, write_json
# fslpy is an optional dependency used only by FnirtMap; record whether it
# could be imported so FnirtMap.__init__ can raise a helpful ImportError.
try:
    from fsl.data.image import Image
    from fsl.transform.fnirt import readFnirt
    from fsl.transform.nonlinear import applyDeformation
except ModuleNotFoundError:
    has_fslpy = False
else:
    has_fslpy = True

# Mapping definitions exported by this module.
__all__ = ["FnirtMap", "SynMap", "SlrMap", "AffMap", "IdentityMap"]

# Log under the "AFQ" logger name (not __name__).
logger = logging.getLogger("AFQ")
# For map definitions, get_for_subses should return only the mapping,
# where the mapping has transform and transform_inverse methods,
# each of which accepts (data, **kwargs).
class FnirtMap(Definition):
    """
    Use an existing FNIRT map. Expects a warp file
    and an image file for each subject / session; image file
    is used as src space for warp.

    Parameters
    ----------
    warp_path : str, optional
        path to file to get warp from. Use this or warp_suffix.
        Default: None
    space_path : str, optional
        path to the image file that defines the src space of the warp.
        Use this or space_suffix.
        Default: None
    warp_suffix : str, optional
        suffix to pass to bids_layout.get() to identify the warp file.
        Default: None
    space_suffix : str, optional
        suffix to pass to bids_layout.get() to identify the space file.
        Default: None
    warp_filters : dict, optional
        Additional filters to pass to bids_layout.get() to identify
        the warp file.
        Default: {}
    space_filters : dict, optional
        Additional filters to pass to bids_layout.get() to identify
        the space file.
        Default: {}

    Raises
    ------
    ImportError
        If fslpy is not installed.
    ValueError
        If neither a path nor a suffix is given for the warp or the space
        file, or if only one of ``warp_path`` / ``space_path`` is given.

    Notes
    -----
    If you have an existing mapping calculated using Fnirt,
    you can pass bids filters to :class:`AFQ.definitions.mapping.FnirtMap`
    and pyAFQ will find and use that mapping.
    ``warp_path`` and ``space_path`` must be given together; when they are
    not given, both files are located per subject/session using the
    suffixes and filters.

    Examples
    --------
    fnirt_map = FnirtMap(
        warp_suffix="warp",
        space_suffix="MNI",
        warp_filters={"scope": "TBSS"},
        space_filters={"scope": "TBSS"})
    api.GroupAFQ(mapping=fnirt_map)
    """

    def __init__(
        self,
        warp_path=None,
        space_path=None,
        warp_suffix=None,
        space_suffix=None,
        warp_filters=None,
        space_filters=None,
    ):
        if space_filters is None:
            space_filters = {}
        if warp_filters is None:
            warp_filters = {}
        if not has_fslpy:
            raise ImportError("Please install fslpy if you want to use FnirtMap")
        if warp_path is None and warp_suffix is None:
            raise ValueError(
                "One of `warp_path` or `warp_suffix` should be set "
                "to a value other than None."
            )
        if space_path is None and space_suffix is None:
            raise ValueError("One of space_path or space_suffix must not be None.")
        # Explicit paths are all-or-nothing: exactly one of them being set is
        # an error, whichever one is missing.
        if (warp_path is None) != (space_path is None):
            raise ValueError(
                "If passing a value for `warp_path`, "
                "you must also pass a value for `space_path`, and vice versa"
            )

        if warp_path is not None:
            # Fixed files, used for every subject/session.
            self._from_path = True
            self.fnames = (warp_path, space_path)
        else:
            # Files are located later by find_path and stored in this dict,
            # keyed by the `from_path` given to find_path (read back in
            # get_for_subses via `dwi_data_file`).
            self._from_path = False
            self.warp_suffix = warp_suffix
            self.warp_filters = warp_filters
            self.space_suffix = space_suffix
            self.space_filters = space_filters
            self.fnames = {}

    def find_path(self, bids_layout, from_path, subject, session, required=True):
        """
        Locate the warp and space files for one subject/session.

        Does nothing when explicit paths were given at construction.
        Otherwise stores ``(warp_file, space_file)`` in ``self.fnames``
        under the key ``from_path``.

        Parameters
        ----------
        bids_layout, subject, session, required
            Passed through to ``find_file`` for both lookups.
        from_path : str
            Passed to ``find_file`` and used as the ``self.fnames`` key.
        """
        if self._from_path:
            return
        nearest_warp = find_file(
            bids_layout,
            from_path,
            self.warp_filters,
            self.warp_suffix,
            session,
            subject,
            required=required,
        )
        nearest_space = find_file(
            bids_layout,
            from_path,
            self.space_filters,
            self.space_suffix,
            session,
            subject,
            required=required,
        )
        self.fnames[from_path] = (nearest_warp, nearest_space)

    def get_for_subses(
        self, base_fname, dwi, dwi_data_file, reg_subject, reg_template, tmpl_name
    ):
        """
        Build the FNIRT-based mapping for one subject/session.

        ``base_fname``, ``reg_subject`` and ``tmpl_name`` are accepted for
        interface compatibility with the other mapping definitions and are
        not used here.

        Returns
        -------
        ConformedFnirtMapping
            Wraps the warp read from the warp file, using
            ``reg_template.affine`` as the reference affine.
        """
        if self._from_path:
            nearest_warp, nearest_space = self.fnames
        else:
            # Requires find_path to have been called with this dwi file.
            nearest_warp, nearest_space = self.fnames[dwi_data_file]
        subj = Image(dwi)
        their_templ = Image(nearest_space)
        # readFnirt(warp file, src image, ref image): the space image is the
        # warp's src space, the subject dwi its reference.
        warp = readFnirt(nearest_warp, their_templ, subj)
        return ConformedFnirtMapping(warp, reg_template.affine)
class ConformedFnirtMapping:
    """
    Adapter that exposes an fslpy FNIRT warp through the generic
    mapping API (``transform`` / ``transform_inverse``).

    Parameters
    ----------
    warp
        The warp object produced by ``readFnirt``.
    ref_affine
        Affine attached to data passed to ``transform``.
    """

    def __init__(self, warp, ref_affine):
        self.ref_affine = ref_affine
        self.warp = warp

    def transform(self, data, **kwargs):
        """Warp ``data`` (cast to float32) with the FNIRT deformation."""
        as_float = data.astype(np.float32)
        source_img = Image(nib.Nifti1Image(as_float, self.ref_affine))
        warped = applyDeformation(source_img, self.warp)
        return np.asarray(warped.data)

    def transform_pts(self, pts):
        """
        Map point coordinates through the warp.

        NOTE: intended only for curvature analysis; the original author
        noted the results may still need to be shifted.
        """
        voxel_to_world = self.warp.src.getAffine("voxel", "world")
        in_world = nib.affines.apply_affine(voxel_to_world, pts)
        in_ref_voxels = nib.affines.apply_affine(
            np.linalg.inv(self.ref_affine), in_world
        )
        return self.warp.transform(in_ref_voxels, "fsl", "world")

    def transform_inverse(self, data, **kwargs):
        """Not supported for FNIRT warps; always raises NotImplementedError."""
        raise NotImplementedError(
            "Fnirt based mappings can currently only transform from template to subject space"
        )
class GeneratedMapMixin:
    """
    Helper mixin for mappings that pyAFQ computes itself.

    Provides the naming scheme for cached mapping files and an affine
    pre-alignment step. Classes that call ``prealign`` must define
    ``self.affine_kwargs``.
    """

    def get_fnames(self, extension, base_fname, sub_name, tmpl_name):
        """
        Build the mapping file name and its JSON sidecar name.

        Returns
        -------
        tuple of (str, str)
            ``(stem + extension, stem + ".json")``, where the stem is derived
            from ``base_fname`` and describes a mapping from ``sub_name`` to
            ``tmpl_name``.
        """
        desc = f"_desc-mapping_from-{sub_name}_to-{tmpl_name}_xform"
        stem = get_fname(base_fname, desc)
        return stem + extension, f"{stem}.json"

    def prealign(self, reg_subject, reg_template):
        """
        Run dipy's ``affine_registration`` (subject as first argument,
        template as second) with ``self.affine_kwargs`` and return only
        the resulting affine.
        """
        logger.info("Calculating affine pre-alignment...")
        _transformed, affine = affine_registration(
            reg_subject, reg_template, **self.affine_kwargs
        )
        return affine
class AffineMapMixin(GeneratedMapMixin):
    """
    Helper mixin for pyAFQ-generated mappings that are a single affine.

    Subclasses provide ``gen_mapping(reg_subject, reg_template,
    subject_sls, template_sls)`` returning the affine matrix; this mixin
    caches it to disk as ``.npy`` with a JSON sidecar and loads it back.
    """

    def get_for_subses(
        self,
        base_fname,
        dwi,
        dwi_data_file,
        reg_subject,
        reg_template,
        tmpl_name,
        subject_sls=None,
        template_sls=None,
    ):
        """
        Return the affine mapping for one subject/session.

        The affine is computed with ``self.gen_mapping`` only when the
        cached ``.npy`` file does not exist yet; it is then saved together
        with a JSON sidecar recording its type, timing, what it depends on
        ("dwi" without streamlines, "trk" with them) and any inputs given
        as path strings. The result is always read back via
        ``reg.read_affine_mapping``.
        """
        space = space_from_fname(dwi_data_file)
        mapping_file, meta_fname = self.get_fnames(".npy", base_fname, space, tmpl_name)

        if not op.exists(mapping_file):
            t0 = time()
            affine = self.gen_mapping(reg_subject, reg_template, subject_sls, template_sls)
            elapsed = time() - t0

            logger.info(f"Saving {mapping_file}")
            np.save(mapping_file, affine)

            meta = {
                "type": "affine",
                "timing": elapsed,
                "dependent": "dwi" if subject_sls is None else "trk",
            }
            # Only path inputs can be recorded in the sidecar.
            for key, value in (("reg_subject", reg_subject), ("reg_template", reg_template)):
                if isinstance(value, str):
                    meta[key] = value
            write_json(meta_fname, meta)

        return reg.read_affine_mapping(mapping_file, dwi, reg_template)
class SynMap(GeneratedMapMixin, Definition):
    """
    Calculate a Syn registration for each subject/session
    using reg_subject and reg_template.

    Parameters
    ----------
    use_prealign : bool
        Whether to perform a linear pre-registration.
        Default: True
    affine_kwargs : dictionary, optional
        Parameters to pass to affine_registration
        in dipy.align, which does the linear pre-alignment.
        Only used if use_prealign is True.
        Default: {}
    syn_kwargs : dictionary, optional
        Parameters to pass to syn_registration
        in dipy.align, which does the SyN alignment.
        Default: {}

    Notes
    -----
    The default mapping class is to
    use Symmetric Diffeomorphic Image Registration (SyN).
    This is done with an optional linear pre-alignment by default.
    The parameters of the pre-alignment can be specified when
    initializing the SynMap.

    Examples
    --------
    api.GroupAFQ(mapping=SynMap())
    """

    def __init__(self, use_prealign=True, affine_kwargs=None, syn_kwargs=None):
        if syn_kwargs is None:
            syn_kwargs = {}
        if affine_kwargs is None:
            affine_kwargs = {}
        self.use_prealign = use_prealign
        self.affine_kwargs = affine_kwargs
        self.syn_kwargs = syn_kwargs

    def get_for_subses(
        self,
        base_fname,
        dwi,
        dwi_data_file,
        reg_subject,
        reg_template,
        tmpl_name,
        subject_sls=None,
        template_sls=None,
    ):
        """
        Return the SyN mapping for one subject/session.

        If either the forward or backward displacement-field file is
        missing, the registration is (re)computed and both fields are saved
        as NIfTI files with JSON sidecars. The mapping is then read back
        with ``reg.read_syn_mapping``.

        ``reg_subject`` and ``reg_template`` may be nibabel images or paths
        to image files; paths are recorded in the sidecar and loaded with
        ``nib.load`` before registering. ``dwi``, ``subject_sls`` and
        ``template_sls`` are accepted for interface compatibility and are
        not used here.
        """
        sub_space = space_from_fname(dwi_data_file)
        mapping_file_forward, meta_forward_fname = self.get_fnames(
            ".nii.gz", base_fname, sub_space, tmpl_name
        )
        mapping_file_backward, meta_backward_fname = self.get_fnames(
            ".nii.gz", base_fname, tmpl_name, sub_space
        )
        if not op.exists(mapping_file_forward) or not op.exists(mapping_file_backward):
            meta = dict(type="displacementfield")
            meta["dependent"] = "dwi"
            if isinstance(reg_subject, str):
                meta["reg_subject"] = reg_subject
            if isinstance(reg_template, str):
                meta["reg_template"] = reg_template

            # The registration below needs in-memory images (get_fdata and
            # affine), so load any inputs given as file paths. Previously
            # path inputs failed with AttributeError after the prealignment.
            if isinstance(reg_subject, str):
                reg_subject = nib.load(reg_subject)
            if isinstance(reg_template, str):
                reg_template = nib.load(reg_template)

            start_time = time()
            if self.use_prealign:
                reg_prealign = self.prealign(reg_subject, reg_template)
            else:
                reg_prealign = None
            logger.info("Calculating SyN registration...")
            _, mapping = syn_registration(
                reg_subject.get_fdata(),
                reg_template.get_fdata(),
                moving_affine=reg_subject.affine,
                static_affine=reg_template.affine,
                prealign=reg_prealign,
                **self.syn_kwargs,
            )
            mapping = get_simplified_transform(mapping)
            total_time = time() - start_time
            meta["total_time"] = total_time

            logger.info(f"Saving {mapping_file_forward}")
            nib.save(
                nib.Nifti1Image(mapping.forward, reg_subject.affine),
                mapping_file_forward,
            )
            write_json(meta_forward_fname, meta)
            logger.info(f"Saving {mapping_file_backward}")
            nib.save(
                nib.Nifti1Image(mapping.backward, reg_template.affine),
                mapping_file_backward,
            )
            write_json(meta_backward_fname, meta)
        mapping = reg.read_syn_mapping(mapping_file_forward, mapping_file_backward)
        return mapping
class SlrMap(AffineMapMixin, Definition):
    """
    Calculate a SLR registration for each subject/session
    using reg_subject and reg_template.

    Parameters
    ----------
    slr_kwargs : dictionary, optional
        Parameters to pass to whole_brain_slr
        in dipy, which does the SLR alignment.
        By default ``x0="affine"`` and ``verbose=False`` are used;
        both can be overridden through this dictionary.
        Default: {}

    Notes
    -----
    Use this class to tell pyAFQ to use
    Streamline-based Linear Registration (SLR)
    for registration. Note that the reg_template and reg_subject
    parameters passed to :class:`AFQ.api.group.GroupAFQ` should
    be streamlines when using this registration.

    Examples
    --------
    api.GroupAFQ(mapping=SlrMap())
    """

    def __init__(self, slr_kwargs=None):
        if slr_kwargs is None:
            slr_kwargs = {}
        self.slr_kwargs = slr_kwargs

    def gen_mapping(
        self,
        reg_subject,
        reg_template,
        subject_sls,
        template_sls,
    ):
        """
        Compute the SLR affine between the subject and template streamlines.

        ``subject_sls`` is passed to ``whole_brain_slr`` as its first
        (static) argument and ``template_sls`` as its second (moving)
        argument; the returned transform is the second element of
        ``whole_brain_slr``'s result. ``reg_subject`` and ``reg_template``
        are unused.
        """
        # Defaults first, so user-supplied slr_kwargs override them instead
        # of raising "got multiple values for keyword argument".
        slr_kwargs = {"x0": "affine", "verbose": False, **self.slr_kwargs}
        _, transform, _, _ = whole_brain_slr(subject_sls, template_sls, **slr_kwargs)
        return transform
class AffMap(AffineMapMixin, Definition):
    """
    Calculate an affine registration for each subject/session
    using reg_subject and reg_template.

    Parameters
    ----------
    affine_kwargs : dictionary, optional
        Parameters to pass to affine_registration
        in dipy.align, which performs the linear registration.
        Default: {}

    Notes
    -----
    Only a linear (affine) alignment is performed.

    Examples
    --------
    api.GroupAFQ(mapping=AffMap())
    """

    def __init__(self, affine_kwargs=None):
        self.affine_kwargs = {} if affine_kwargs is None else affine_kwargs

    def gen_mapping(self, reg_subject, reg_template, subject_sls, template_sls):
        """
        Return the inverse of the affine found by ``self.prealign``
        for ``reg_subject`` and ``reg_template``; the streamline
        arguments are ignored.
        """
        forward_affine = self.prealign(reg_subject, reg_template)
        return np.linalg.inv(forward_affine)
class IdentityMap(AffineMapMixin, Definition):
    """
    Skip registration: use an identity affine wherever pyAFQ would
    normally transform between MNI/template and subject space.

    Examples
    --------
    my_example_mapping = IdentityMap()
    api.GroupAFQ(mapping=my_example_mapping)
    """

    def __init__(self):
        """Take no options; there is nothing to configure."""

    def gen_mapping(self, reg_subject, reg_template, subject_sls, template_sls):
        """Return a 4x4 identity matrix; all inputs are ignored."""
        return np.eye(4)