Source code for AFQ.tractography.tractography

import logging
from math import radians
from time import time

import dipy.data as dpd
import nibabel as nib
import numba
import numpy as np
from dipy.align import resample
from dipy.core.sphere import HemiSphere
from dipy.io.stateful_tractogram import Space, StatefulTractogram
from dipy.reconst import shm
from dipy.reconst.dti import decompose_tensor, from_lower_triangular
from dipy.tracking.stopping_criterion import ActStoppingCriterion
from dipy.tracking.tracker import (
    deterministic_tracking,
    pft_tracking,
)
from nibabel.streamlines.tractogram import LazyTractogram
from skimage.segmentation import find_boundaries
from tqdm import tqdm

from AFQ._fixes import tensor_odf
from AFQ.tractography.utils import gen_seeds


[docs] def track( params_file, pve, n_threads, directions="prob", max_angle=30.0, sphere="repulsion724", seed_mask=None, seed_threshold=0.5, gm_threshold=0.4, thresholds_as_percentages=False, n_seeds=1e7, random_seeds=True, rng_seed=None, step_size=0.5, minlen=20, maxlen=500, odf_model="CSD_AODF", basis_type="descoteaux07", legacy=True, trx=True, jit_backend="numba", jit_chunk_size=None, ): """ Tractography Parameters ---------- params_file : str, nibabel img. Full path to a nifti file containing CSD spherical harmonic coefficients, or nibabel img with model params. pve : str, nibabel img Full path to a nifti file containing tissue probability maps, or nibabel img with tissue probability maps. This should be of the order (pve_csf, pve_gm, pve_wm). n_threads : int The number of threads to use in tracking. If 0 or -1, uses all available threads. directions : str How tracking directions are determined. One of: {"det" | "prob" | "pft"} pft refers to Particle Filtering Tracking ([Girard2014]_). Default: "prob" max_angle : float, optional. The maximum turning angle in each step. Default: 30 sphere : str or DIPY Sphere The discretization of the ODF. Can be a DIPY Sphere or a string name of a DIPY Sphere. Default: "repulsion724" seed_mask : array, optional. Float or binary mask describing the ROI within which we seed for tracking. Default to the entire volume (all ones). seed_threshold : float, optional. A value of the seed_mask above which tracking is seeded. Default to 0. gm_threshold : float, optional. A value of the pve_gm_data above which we consider a voxel to be GM for the purposes of ACT stopping criterion. Default: 0.4. n_seeds : int or 2D array, optional. The seeding density: if this is an int, it is is how many seeds in each voxel on each dimension (for example, 2 => [2, 2, 2]). If this is a 2D array, these are the coordinates of the seeds in RASMM. Unless random_seeds is set to True, in which case this is the total number of random seeds to generate within the mask. Default: 1e7 random_seeds : bool Whether to generate a total of n_seeds random seeds in the mask. Default: True rng_seed : int random seed used to generate random seeds if random_seeds is set to True. Default: None thresholds_as_percentages : bool, optional Interpret seed_threshold as percentages of the total non-nan voxels in the seed mask to include (between 0 and 100), instead of as a threshold on the values themselves. Default: False step_size : float, optional. The size of a step (in mm) of tractography. Default: 0.5 minlen: int, optional The minimal length (mm) in a streamline. Default: 20 maxlen: int, optional The maximum length (mm) in a streamline. Default: 250 odf_model : str or Definition, optional Can be either a string or Definition. If a string, it must be one of {"DTI", "CSD", "DKI", "GQ", "RUMBA", "MSMT_AODF", "CSD_AODF", "MSMTCSD"}. If a Definition, we assume it is a definition of a file containing Spherical Harmonics coefficients. Defaults to use "CSD_AODF" basis_type : str, optional The spherical harmonic basis type used to represent the coefficients. One of {"descoteaux07", "tournier07"}. Default: "descoteaux07" legacy : bool, optional Whether the legacy SH basis definition should be used. See Dipy documentation for more details. Default: True trx : bool, optional Whether to return the streamlines compatible with input to TRX file (i.e., as a LazyTractogram class instance). Default: True jit_backend : str, optional If directions is "prob" or "ptt", the JIT backend to use. One of {"auto", "cuda", "metal", "webgpu", or "numba"}. Default: "numba" jit_chunk_size : int, optional If directions is "prob" or "ptt", the chunk size to use for JIT tracking. If None, chooses 25000 for numba backend and 5000 for other backends. Default: None Returns ------- list of streamlines () References ---------- .. [Girard2014] Girard, G., Whittingstall, K., Deriche, R., & Descoteaux, M. Towards quantitative connectivity analysis: reducing tractography biases. NeuroImage, 98, 266-278, 2014. .. [Smith2012] Smith RE, Tournier JD, Calamante F, Connelly A. Anatomically-constrained tractography: improved diffusion MRI streamlines tractography through effective use of anatomical information. Neuroimage. 2012 Sep;62(3):1924-38. doi: 10.1016/j.neuroimage.2012.06.005. Epub 2012 Jun 13. """ logger = logging.getLogger("AFQ") logger.info("Loading Image...") if isinstance(params_file, str): params_img = nib.load(params_file) else: params_img = params_file if isinstance(pve, str): pve_img = nib.load(pve) if isinstance(pve, nib.Nifti1Image): pve_img = pve pve_data = pve_img.get_fdata() model_params = params_img.get_fdata() if isinstance(odf_model, str): odf_model = odf_model.upper() directions = directions.lower() if n_threads == -1: n_threads = 0 if seed_mask is None: seed_mask = np.ones(params_img.shape[:3]) if isinstance(sphere, str): sphere = dpd.get_sphere(name=sphere) if not len(pve_data.shape) == 4 or pve_data.shape[3] != 3: raise RuntimeError( "For pve, expected pve_data with shape [x, y, z, 3]. " f"Instead, got {pve_data.shape}." ) pve_csf_data = pve_data[..., 0] pve_gm_data = pve_data[..., 1] pve_wm_data = pve_data[..., 2] pve_csf_data = resample( pve_csf_data, model_params[..., 0], moving_affine=pve_img.affine, static_affine=params_img.affine, ).get_fdata() pve_gm_data = resample( pve_gm_data, model_params[..., 0], moving_affine=pve_img.affine, static_affine=params_img.affine, ).get_fdata() pve_wm_data = resample( pve_wm_data, model_params[..., 0], moving_affine=pve_img.affine, static_affine=params_img.affine, ).get_fdata() # here we treat wm that borders the edge of the brain mask as gm # this is so that streamlines that hit the end of the # (presumably masked) fodf are treated as valid # (think brain stem) brain_mask = np.any(model_params != 0, axis=-1).astype(np.uint8) edge = find_boundaries(brain_mask, mode="inner") pve_gm_data[edge] = 1.0 pve_wm_data[edge] = 0.0 pve_csf_data[edge] = 0.0 # We relax ACT stopping criterion here to allow streamlines closer # to the WM/GM boundary. pve_gm_data *= 0.5 / gm_threshold stopping_criterion = ActStoppingCriterion.from_pve( pve_wm_data, pve_gm_data, pve_csf_data ) if odf_model == "DTI" or odf_model == "DKI": evals, evecs = decompose_tensor(from_lower_triangular(model_params)) odf = tensor_odf(evals, evecs, sphere) model_params = shm.sf_to_sh( odf, sphere, basis_type=basis_type, legacy=legacy, full_basis=True ) tracking_kwargs = {} if directions == "pft" and (odf_model == "DTI" or odf_model == "DKI"): tracking_kwargs["sf"] = odf else: sym_order = (-3.0 + np.sqrt(1.0 + 8.0 * model_params.shape[3])) / 2.0 if sym_order.is_integer(): sh_order_max = sym_order full_basis = False else: full_order = np.sqrt(model_params.shape[3]) - 1.0 sh_order_max = full_order full_basis = True pmf = shm.sh_to_sf( model_params, sphere, sh_order_max=sh_order_max, full_basis=full_basis ) pmf[pmf < 0] = 0 tracking_kwargs["sf"] = pmf if rng_seed is not None: tracking_kwargs["random_seed"] = int(rng_seed) else: tracking_kwargs["random_seed"] = np.random.randint(0, 2**31 - 1) seeds = gen_seeds( seed_mask, seed_threshold, n_seeds, thresholds_as_percentages, random_seeds, rng_seed, params_img.affine, ) if directions == "prob" or directions == "ptt": jit_backend = jit_backend.lower() if jit_backend == "auto": from cuslines import ( ProbDirectionGetter, PttDirectionGetter, Tracker, ) elif jit_backend == "cuda": from cuslines.cuda_python import ( GPUTracker as Tracker, ) from cuslines.cuda_python import ( ProbDirectionGetter, PttDirectionGetter, ) elif jit_backend == "metal": from cuslines.metal import ( MetalGPUTracker as Tracker, ) from cuslines.metal import ( MetalProbDirectionGetter as ProbDirectionGetter, ) from cuslines.metal import ( MetalPttDirectionGetter as PttDirectionGetter, ) elif jit_backend == "webgpu": from cuslines.webgpu import ( WebGPUProbDirectionGetter as ProbDirectionGetter, ) from cuslines.webgpu import ( WebGPUPttDirectionGetter as PttDirectionGetter, ) from cuslines.webgpu import ( WebGPUTracker as Tracker, ) elif jit_backend == "numba": from cuslines.numba import ( CPUProbDirectionGetter as ProbDirectionGetter, ) from cuslines.numba import ( CPUPttDirectionGetter as PttDirectionGetter, ) from cuslines.numba import ( CPUTracker as Tracker, ) else: raise ValueError( "jit_backend must be one of 'auto', 'cuda', " f"'metal', 'numba', or 'webgpu', not {jit_backend}" ) if directions == "ptt": dg = PttDirectionGetter() else: dg = ProbDirectionGetter() inv_affine = np.linalg.inv(params_img.affine) seeds = np.dot(seeds, inv_affine[:3, :3].T) seeds += inv_affine[:3, 3] minlen = int(minlen / step_size) maxlen = int(maxlen / step_size) R = params_img.affine[0:3, 0:3] vox_dim = np.mean(np.diag(np.linalg.cholesky(R.T.dot(R)))) step_size = step_size / vox_dim if jit_chunk_size is None: if jit_backend == "numba": jit_chunk_size = 25000 else: jit_chunk_size = 5000 if n_threads != 0: old_numba_n_threads = numba.get_num_threads() numba.set_num_threads(n_threads) with Tracker( dg, tracking_kwargs["sf"], pve_wm_data, gm_threshold, sphere.vertices, sphere.edges, sphere_symm=isinstance(sphere, HemiSphere), max_angle=radians(max_angle), step_size=step_size, min_pts=minlen, max_pts=maxlen, rng_seed=tracking_kwargs["random_seed"], chunk_size=jit_chunk_size, ) as jit_tracker: if trx: res = jit_tracker.generate_trx(seeds, params_img) else: res = jit_tracker.generate_sft(seeds, params_img) if n_threads != 0: numba.set_num_threads(old_numba_n_threads) return res else: if directions == "det": tracker = deterministic_tracking elif directions == "pft": tracker = pft_tracking else: raise ValueError(f"Unrecognized direction '{directions}'.") logger.info("Note there will be a long initial delay as seeds are initialized") start_time = time() tracker = tqdm( tracker( seeds, stopping_criterion, params_img.affine, max_angle=max_angle, sphere=sphere, basis_type=basis_type, legacy=legacy, step_size=step_size, min_len=minlen, max_len=maxlen, return_all=False, nbr_threads=int(n_threads), **tracking_kwargs, ), total=len(seeds), desc="Tracking, note that the total is an overestimate...", ) logger.info((f"Seed initialization took {time() - start_time:.2f} seconds.")) if trx: return LazyTractogram(lambda: tracker, affine_to_rasmm=params_img.affine) else: return StatefulTractogram(tracker, params_img, Space.RASMM)