# Source code for AFQ.registration

"""
Registration tools
"""
import numpy as np
import nibabel as nib
from dipy.align.imwarp import DiffeomorphicMap

from dipy.align import syn_registration

import dipy.core.gradients as dpg
from dipy.align.streamlinear import whole_brain_slr


# Names exported by ``from AFQ.registration import *``. The helper
# ``reduce_shape`` is not part of this list.
__all__ = ["syn_register_dwi", "write_mapping", "read_mapping",
           "slr_registration"]


def reduce_shape(shape):
    """
    Trim a shape to (at most) its first three entries.

    Anything that supports slicing (tuples, lists, ...) is cut down to its
    first three items. Objects that cannot be sliced, such as ``None`` or a
    bare integer, are handed back unchanged.
    """
    leading_three = slice(None, 3)
    try:
        reduced = shape[leading_three]
    except TypeError:
        # Not sliceable (e.g. ``None``): pass it through untouched.
        reduced = shape
    return reduced


def syn_register_dwi(dwi, gtab, template=None, **syn_kwargs):
    """
    Register DWI data to a template.

    The mean of the b0 volumes of ``dwi`` is used as the moving image and
    ``template`` as the static image for :func:`syn_registration`.

    Parameters
    -----------
    dwi : nifti image or str
        Image containing DWI data, or full path to a nifti file with DWI.
    gtab : GradientTable or sequence
        The gradients associated with the DWI data. If this is not a
        GradientTable, it is unpacked as positional arguments to
        :func:`dipy.core.gradients.gradient_table`, e.g. ``[fbval, fbvec]``.
    template : nifti image or str, optional
        Static image for the registration, or full path to a nifti file.
        Defaults to ``AFQ.data.fetch.read_mni_template()``.
    syn_kwargs : key-word arguments for :func:`syn_registration`

    Returns
    -------
    warped_b0 : ndarray
        The mean b0 image as warped by :func:`syn_registration`.
    mapping : DiffeomorphicMap
        The mapping computed by :func:`syn_registration`.

    Raises
    ------
    ValueError
        If ``gtab`` marks no volume as a b0 volume, so no mean b0 image can
        be formed.
    """
    if template is None:
        # Imported lazily: only needed when falling back to the MNI template.
        import AFQ.data.fetch as afd
        template = afd.read_mni_template()
    if isinstance(template, str):
        template = nib.load(template)
    template_data = template.get_fdata()
    template_affine = template.affine

    if isinstance(dwi, str):
        dwi = nib.load(dwi)
    if not isinstance(gtab, dpg.GradientTable):
        gtab = dpg.gradient_table(*gtab)

    b0_mask = gtab.b0s_mask
    # Fail before loading the (potentially large) DWI data: averaging over
    # zero volumes would yield an all-NaN image and a meaningless warp.
    if not np.any(b0_mask):
        raise ValueError(
            "No b0 volumes found in gtab; cannot compute a mean b0 image "
            "for registration.")

    dwi_affine = dwi.affine
    dwi_data = dwi.get_fdata()
    # Average the b0 volumes along the last (volume) axis.
    mean_b0 = np.mean(dwi_data[..., b0_mask], -1)
    warped_b0, mapping = syn_registration(mean_b0, template_data,
                                          moving_affine=dwi_affine,
                                          static_affine=template_affine,
                                          **syn_kwargs)
    return warped_b0, mapping


def write_mapping(mapping, fname):
    """
    Write out a registration mapping to file.

    Parameters
    ----------
    mapping : DiffeomorphicMap or affine mapping object
        A :class:`DiffeomorphicMap` (e.g. from :func:`syn_registration`) is
        written as a nifti image whose data stacks the forward and backward
        displacement fields along a new trailing axis (index 0 = forward,
        index 1 = backward). Any other object is expected to expose an
        ``affine`` attribute, which is written with :func:`numpy.save`.
    fname : str
        Destination path. For the non-diffeomorphic case note that
        :func:`numpy.save` appends ``.npy`` if ``fname`` lacks that suffix.
    """
    if not isinstance(mapping, DiffeomorphicMap):
        np.save(fname, mapping.affine)
        return

    displacements = np.stack([mapping.forward, mapping.backward], axis=-1)
    # The image affine is the codomain world-to-grid matrix rather than a
    # grid-to-world matrix; read_mapping inverts it when loading.
    field_img = nib.Nifti1Image(displacements, mapping.codomain_world2grid)
    nib.save(field_img, fname)


def read_mapping(disp, domain_img, codomain_img, prealign=None):
    """
    Read a syn registration mapping from a nifti file.

    Parameters
    ----------
    disp : str, os.PathLike, Nifti1Image, or ndarray
        If a path, it must name either a nifti file (``.nii`` or
        ``.nii.gz``), which is loaded with :func:`nibabel.load`, or a file
        readable by :func:`numpy.load`.
        If an image, it contains the mapping displacement field in each
        voxel, with shape (x, y, z, 3, 2); index 0 of the last axis is the
        forward field and index 1 the backward field.
        If an ndarray, it contains the affine transformation used for the
        mapping.
    domain_img : str, os.PathLike, or Nifti1Image
        Image (or path to one) defining the domain grid.
    codomain_img : str, os.PathLike, or Nifti1Image
        Image (or path to one) defining the codomain grid.
    prealign : ndarray, optional
        Forwarded as ``prealign`` to :class:`DiffeomorphicMap`. Ignored when
        ``disp`` holds an affine.

    Returns
    -------
    A :class:`DiffeomorphicMap` when ``disp`` is a displacement-field image,
    otherwise an ``AFQ.definitions.mapping.ConformedAffineMapping``.
    """
    import os

    if isinstance(disp, (str, os.PathLike)):
        disp = os.fspath(disp)
        # Dispatch on the extension instead of a substring test, so that
        # uncompressed ``.nii`` files load as images and ``.npy`` files whose
        # path merely contains "nii.gz" (e.g. ``x.nii.gz.npy``) load as arrays.
        if disp.lower().endswith((".nii", ".nii.gz")):
            disp = nib.load(disp)
        else:
            disp = np.load(disp)
    if isinstance(domain_img, (str, os.PathLike)):
        domain_img = nib.load(domain_img)
    if isinstance(codomain_img, (str, os.PathLike)):
        codomain_img = nib.load(codomain_img)

    if isinstance(disp, nib.Nifti1Image):
        mapping = DiffeomorphicMap(
            3, disp.shape[:3],
            # write_mapping stores the codomain world-to-grid matrix as the
            # image affine, so invert it to recover a grid-to-world matrix.
            disp_grid2world=np.linalg.inv(disp.affine),
            # Only the three spatial axes describe the grids; drop any
            # trailing (e.g. 4D volume) axis for both images alike.
            domain_shape=domain_img.shape[:3],
            domain_grid2world=domain_img.affine,
            codomain_shape=codomain_img.shape[:3],
            codomain_grid2world=codomain_img.affine,
            prealign=prealign)
        disp_data = disp.get_fdata().astype(np.float32)
        mapping.forward = disp_data[..., 0]
        mapping.backward = disp_data[..., 1]
        # Mark the map as inverted: dipy's DiffeomorphicMap then swaps the
        # roles of the forward/backward fields in transform/transform_inverse.
        mapping.is_inverse = True
    else:
        # Imported lazily, as in the rest of this module.
        from AFQ.definitions.mapping import ConformedAffineMapping
        mapping = ConformedAffineMapping(
            disp,
            domain_grid_shape=reduce_shape(domain_img.shape),
            domain_grid2world=domain_img.affine,
            codomain_grid_shape=reduce_shape(codomain_img.shape),
            codomain_grid2world=codomain_img.affine)
    return mapping


def slr_registration(moving_data, static_data,
                     moving_affine=None, static_affine=None,
                     moving_shape=None, static_shape=None, **kwargs):
    """Register a source tractography (moving) to a target (static).

    Parameters
    ----------
    moving_data : streamlines
        The source tractography data to be registered.
    static_data : streamlines
        The target tractography data for registration.
    moving_affine : ndarray, optional
        The affine associated with the moving (source) data.
    static_affine : ndarray, optional
        The affine associated with the static (target) data.
    moving_shape : tuple, optional
        The shape of the space associated with the moving (source) data.
        Only the first three entries are used.
    static_shape : tuple, optional
        The shape of the space associated with the static (target) data.
        Only the first three entries are used.
    **kwargs :
        Passed to :func:`dipy.align.streamlinear.whole_brain_slr`.

    Returns
    -------
    ConformedAffineMapping
        Affine mapping built from the transform found by ``whole_brain_slr``,
        with the moving space as domain and the static space as codomain.
    """
    # Imported lazily, as in the rest of this module.
    from AFQ.definitions.mapping import ConformedAffineMapping

    # whole_brain_slr takes (static, moving) and returns
    # (moved, transform, centroids1, centroids2); only the transform is used.
    _, transform, _, _ = whole_brain_slr(
        static_data, moving_data, x0='affine', verbose=False, **kwargs)
    return ConformedAffineMapping(
        transform,
        codomain_grid_shape=reduce_shape(static_shape),
        codomain_grid2world=static_affine,
        domain_grid_shape=reduce_shape(moving_shape),
        domain_grid2world=moving_affine)