# Source code for AFQ.recognition.preprocess

import numpy as np
import nibabel as nib
import pimms
from time import time
import logging

import dipy.tracking.streamline as dts

import AFQ.recognition.utils as abu


# Module-level logger: every message emitted by this module goes to the
# logger named 'AFQ'.
logger = logging.getLogger('AFQ')
@pimms.calc("tol", "dist_to_atlas", "vox_dim")
def tolerance_mm_to_vox(img, dist_to_waypoint, input_dist_to_atlas):
    """
    Convert distance tolerances given in mm into voxel units.

    Parameters
    ----------
    img : object with an ``affine`` attribute
        Only ``img.affine`` is read; its upper-left 3x3 block is used to
        estimate the voxel size. Presumably a nibabel image -- confirm
        against the caller.
    dist_to_waypoint : float or None
        Waypoint tolerance. When None, the tolerance falls back to
        ``dts.dist_to_corner(img.affine)``; otherwise it is divided by
        the mean voxel size.
    input_dist_to_atlas : float
        Atlas distance; divided by the mean voxel size and truncated to
        an ``int``.

    Returns
    -------
    tol : float
        Waypoint tolerance.
    dist_to_atlas : int
        Atlas distance in voxel units (truncated toward zero).
    vox_dim : float
        Mean of the per-axis voxel sizes derived from the affine.
    """
    # We need to calculate the size of a voxel, so we can transform
    # from mm to voxel units:
    R = img.affine[0:3, 0:3]
    vox_dim = np.mean(np.diag(np.linalg.cholesky(R.T.dot(R))))

    # Tolerance is set to the square of the distance to the corner
    # because we are using the squared Euclidean distance in calls to
    # `cdist` to make those calls faster.
    # NOTE(review): nothing is squared in this function; if a squared
    # tolerance is required, the squaring has to happen downstream --
    # confirm against the consumers of `tol`.
    if dist_to_waypoint is None:
        # NOTE(review): the else-branch explicitly divides by vox_dim;
        # confirm dist_to_corner returns a value in the same units.
        tol = dts.dist_to_corner(img.affine)
    else:
        tol = dist_to_waypoint / vox_dim

    dist_to_atlas = int(input_dist_to_atlas / vox_dim)
    return tol, dist_to_atlas, vox_dim
@pimms.calc("fgarray")
def fgarray(tg):
    """
    Streamlines resampled to 20 points.

    Parameters
    ----------
    tg : tractogram
        Passed unchanged to ``abu.resample_tg``; the accepted types are
        defined by that helper (not visible here).

    Returns
    -------
    fg_array : numpy.ndarray
        ``np.array`` of whatever ``abu.resample_tg(tg, 20)`` returns.
    """
    logger.info("Resampling Streamlines...")
    start_time = time()
    fg_array = np.array(abu.resample_tg(tg, 20))
    # Wall-clock duration of the resampling step, reported for diagnostics.
    logger.info((
        "Streamlines Resampled "
        f"(time: {time()-start_time}s)"))
    return fg_array
@pimms.calc("crosses")
def crosses(fgarray, img):
    """
    Classify the streamlines by whether they cross the midline.
    Creates a crosses attribute which is an array of booleans. Each boolean
    corresponds to a streamline, and is whether or not that streamline
    crosses the midline.

    Parameters
    ----------
    fgarray : numpy.ndarray
        3-D array indexed as ``fgarray[:, :, axis]``; the result has one
        entry per index of the first axis. Presumably shaped
        (n_streamlines, n_points, 3) -- confirm against the producer.
    img : object with an ``affine`` attribute
        Only ``img.affine`` is read. Presumably a nibabel image.

    Returns
    -------
    numpy.ndarray of bool
        True where a streamline has at least one point on each side of
        the midline along the left-right axis.
    """
    # What is the x,y,z coordinate of 0,0,0 in the template space?
    zero_coord = np.dot(np.linalg.inv(img.affine), np.array([0, 0, 0, 1]))

    # Find which array axis runs left-right; fall back to axis 0 when the
    # orientation codes contain neither 'L' nor 'R'.
    orientation = nib.orientations.aff2axcodes(img.affine)
    lr_axis = next(
        (idx for idx, axis_label in enumerate(orientation)
         if axis_label in ['L', 'R']),
        0)

    lr_coords = fgarray[:, :, lr_axis]
    midline = zero_coord[lr_axis]
    # A streamline crosses when it has points strictly on both sides.
    return np.logical_and(
        np.any(lr_coords > midline, axis=1),
        np.any(lr_coords < midline, axis=1))
# Things that can be calculated for multiple bundles at once
# (i.e., for a whole tractogram) go here
def get_preproc_plan(img, tg, dist_to_waypoint, dist_to_atlas):
    """
    Build the preprocessing plan and bind it to the given inputs.

    The plan combines the ``tolerance_mm_to_vox``, ``fgarray`` and
    ``crosses`` calculations of this module.

    Parameters
    ----------
    img : object with an ``affine`` attribute
        Forwarded as the ``img`` input of the plan.
    tg : tractogram
        Forwarded as the ``tg`` input of the plan.
    dist_to_waypoint : float or None
        Forwarded as the ``dist_to_waypoint`` input of the plan.
    dist_to_atlas : float
        Forwarded as ``input_dist_to_atlas``; the plan itself exposes a
        converted ``dist_to_atlas`` output, hence the different name.

    Returns
    -------
    The object returned by calling the ``pimms.Plan`` with these inputs.
    """
    preproc_plan = pimms.Plan(
        tolerance_mm_to_vox=tolerance_mm_to_vox,
        fgarray=fgarray,
        crosses=crosses)
    return preproc_plan(
        img=img,
        tg=tg,
        dist_to_waypoint=dist_to_waypoint,
        input_dist_to_atlas=dist_to_atlas)