import logging
from math import radians

import cuslines
import numpy as np
from tqdm import tqdm
from dipy.io.stateful_tractogram import StatefulTractogram, Space
from dipy.reconst import shm
from dipy.reconst.shm import OpdtModel, CsaOdfModel
from nibabel.streamlines.array_sequence import concatenate
from nibabel.streamlines.tractogram import Tractogram
from trx.trx_file_memmap import TrxFile

from AFQ.tractography.utils import gen_seeds, get_percentile_threshold
# Module logger; uses the package-level 'AFQ' logger name.
logger = logging.getLogger('AFQ')
# Modified from the GPUStreamlines example tracking script
# (https://github.com/dipy/GPUStreamlines).
def gpu_track(data, gtab, seed_img, stop_img,
              odf_model, sphere, directions,
              seed_threshold, stop_threshold, thresholds_as_percentages,
              max_angle, step_size, n_seeds, random_seeds, rng_seed, use_trx,
              ngpus, chunk_size):
    """
    Perform GPU tractography on DWI data.

    Parameters
    ----------
    data : ndarray
        For ``directions == "boot"``, the DWI data. For any other
        ``directions`` value it is passed to ``shm.SphHarmFit`` as spherical
        harmonic coefficients and evaluated on ``sphere``; NOTE(review):
        callers presumably pass SH coefficients in that case - confirm.
    gtab : GradientTable
        The gradient table.
    seed_img : Nifti1Image
        Float or binary mask describing the ROI within which we seed for
        tracking. It is also the spatial reference of the output.
    stop_img : Nifti1Image
        A float or binary mask that determines a stopping criterion
        (e.g. FA).
    odf_model : str
        One of {"OPDT", "CSA"} (case-insensitive). Only used when
        ``directions == "boot"``.
    sphere : Sphere
        Sphere on which ODFs are sampled; its ``theta``, ``phi``,
        ``vertices`` and ``edges`` are used.
    directions : str
        "boot" for bootstrap tracking with ``odf_model``, "prob" for
        probabilistic tracking; any other value selects PTT.
    seed_threshold : float
        The value of the seed_img above which tracking is seeded.
    stop_threshold : float
        The value of the stop_img below which tracking is terminated.
    thresholds_as_percentages : bool
        Interpret seed_threshold and stop_threshold as percentages of the
        total non-nan voxels in the seed and stop mask to include
        (between 0 and 100), instead of as a threshold on the
        values themselves.
    max_angle : float
        The maximum turning angle in each step, in degrees.
    step_size : float
        The size of a step of tractography. NOTE(review): historically
        documented as mm, but seeds are generated in voxel coordinates;
        confirm the unit the GPU tracker expects.
    n_seeds : int or 2D array
        The seeding density: if this is an int, it is how many seeds in each
        voxel on each dimension (for example, 2 => [2, 2, 2]). If this is a
        2D array, these are the coordinates of the seeds. Unless
        random_seeds is set to True, in which case this is the total number
        of random seeds to generate within the mask.
    random_seeds : bool
        If True, n_seeds is total number of random seeds to generate.
        If False, n_seeds encodes the density of seeds to generate.
    rng_seed : int
        Random seed used to generate random seeds if random_seeds is
        set to True. It does not seed the GPU tracker itself.
    use_trx : bool
        Whether to return a TrxFile instead of a StatefulTractogram.
    ngpus : int
        Number of GPUs to use.
    chunk_size : int
        Chunk size for GPU tracking; ``chunk_size * ngpus`` seeds are sent
        to the tracker per call.

    Returns
    -------
    TrxFile or StatefulTractogram
        If ``use_trx``, a TrxFile referenced to ``seed_img`` whose
        streamlines are in world (RAS+ mm) coordinates. Otherwise a
        StatefulTractogram in voxel space referenced to ``seed_img``.

    Raises
    ------
    ValueError
        If ``directions == "boot"`` and ``odf_model`` is neither OPDT nor
        CSA, or if ``chunk_size * ngpus`` is not positive.
    """
    # Validate the chunking up front so a bad value fails before the
    # model fitting below (it would otherwise divide by zero later).
    global_chunk_sz = chunk_size * ngpus
    if global_chunk_sz < 1:
        raise ValueError(
            "chunk_size * ngpus must be positive, "
            f"got {chunk_size} * {ngpus}")

    sh_order = 8

    seed_data = seed_img.get_fdata()
    stop_data = stop_img.get_fdata()

    if thresholds_as_percentages:
        stop_threshold = get_percentile_threshold(
            stop_data, stop_threshold)

    # SH basis evaluated on the sphere's vertices.
    sampling_matrix, _, _ = shm.real_sym_sh_basis(
        sh_order, sphere.theta, sphere.phi)

    model_type, data, delta_b, delta_q = _gpu_direction_inputs(
        data, gtab, sphere, directions, odf_model, sh_order, sampling_matrix)

    # Hat matrix and leveraged, centered residual matrix of the SH basis
    # evaluated on the diffusion-weighted (non-b0) gradient directions.
    b0s_mask = gtab.b0s_mask
    x, y, z = gtab.gradients[~b0s_mask].T
    _, grad_theta, grad_phi = shm.cart2sphere(x, y, z)
    B, _, _ = shm.real_sym_sh_basis(sh_order, grad_theta, grad_phi)
    H = shm.hat(B)
    R = shm.lcr_matrix(H)

    # NOTE(review): positional order follows GPUStreamlines'
    # GPUTracker(model_type, max_angle, min_signal, tc_threshold, step_size,
    # relative_peak_thresh, min_separation_angle, ...); confirm against the
    # installed cuslines build. Scalars are coerced to Python floats in case
    # the binding does not convert ints / numpy scalars implicitly.
    gpu_tracker = cuslines.GPUTracker(
        model_type,
        radians(max_angle),  # max turning angle
        1.0,  # min signal
        float(stop_threshold),  # stop where stop_data drops below this
        float(step_size),
        0.25,  # relative peak threshold
        radians(45),  # min separation angle
        data.astype(np.float64), H.astype(np.float64), R.astype(np.float64),
        delta_b.astype(np.float64), delta_q.astype(np.float64),
        b0s_mask.astype(np.int32), stop_data.astype(np.float64),
        sampling_matrix.astype(np.float64),
        sphere.vertices.astype(np.float64), sphere.edges.astype(np.int32),
        # The tracker's own RNG is fixed; rng_seed only affects gen_seeds.
        ngpus=ngpus, rng_seed=0)

    # Seeds are generated in voxel coordinates (identity affine).
    seeds = gen_seeds(
        seed_data, seed_threshold,
        n_seeds, thresholds_as_percentages,
        random_seeds, rng_seed, np.eye(4))

    if use_trx:
        return _gpu_track_to_trx(
            gpu_tracker, seeds, global_chunk_sz, seed_img)
    return _gpu_track_to_sft(gpu_tracker, seeds, global_chunk_sz, seed_img)


def _gpu_direction_inputs(data, gtab, sphere, directions, odf_model,
                          sh_order, sampling_matrix):
    """
    Pick the cuslines model type and the matrices its direction getter uses.

    Returns
    -------
    tuple
        ``(model_type, data, delta_b, delta_q)``. For bootstrap tracking
        ``data`` is returned unchanged; otherwise it is replaced by the
        non-negative ODFs sampled on ``sphere``.

    Raises
    ------
    ValueError
        If ``directions == "boot"`` and ``odf_model`` is not OPDT or CSA.
    """
    if directions == "boot":
        odf_key = odf_model.lower()
        if odf_key == "opdt":
            model = OpdtModel(
                gtab, sh_order=sh_order,
                smooth=0.006, min_signal=1)
            # OPDT's fit matrix is a pair of matrices.
            delta_b, delta_q = model._fit_matrix
            return cuslines.ModelType.OPDT, data, delta_b, delta_q
        if odf_key == "csa":
            model = CsaOdfModel(
                gtab, sh_order=sh_order,
                smooth=0.006, min_signal=1)
            fit_matrix = model._fit_matrix
            return cuslines.ModelType.CSA, data, fit_matrix, fit_matrix
        raise ValueError((
            f"odf_model must be 'opdt' or "
            f"'csa', not {odf_model}"))

    if directions == "prob":
        model_type = cuslines.ModelType.PROB
    else:
        model_type = cuslines.ModelType.PTT
    # Evaluate the SH coefficients in ``data`` on the sphere, clipping
    # negative ODF values to zero.
    model = shm.SphHarmModel(gtab)
    model.cache_set("sampling_matrix", sphere, sampling_matrix)
    odfs = shm.SphHarmFit(model, data, None).odf(sphere).clip(min=0)
    return model_type, odfs, sampling_matrix, sampling_matrix


def _seed_chunks(seeds, chunk_sz):
    """Yield consecutive slices of ``seeds`` with at most ``chunk_sz`` rows."""
    for start in range(0, seeds.shape[0], chunk_sz):
        yield seeds[start:start + chunk_sz]


def _gpu_track_to_trx(gpu_tracker, seeds, global_chunk_sz, seed_img):
    """
    Track ``seeds`` chunk by chunk, streaming the result into a TrxFile.

    Streamlines are moved to world (RAS+ mm) coordinates with
    ``seed_img.affine`` before being written.
    """
    # TODO: this code duplicated with GPUStreamlines...
    # should probably be moved up to trx or cudipy at some point

    # Initial preallocation; both memmaps are doubled when exceeded.
    sl_len_guess = 100
    sl_per_seed_guess = 3
    n_sls_guess = sl_per_seed_guess * seeds.shape[0]

    # trx files use memory mapping
    trx_file = TrxFile(
        nb_vertices=n_sls_guess * sl_len_guess,
        nb_streamlines=n_sls_guess,
        reference=seed_img)
    offsets_idx = 0  # streamlines written so far
    sls_data_idx = 0  # vertices written so far

    with tqdm(total=seeds.shape[0]) as pbar:
        for chunk in _seed_chunks(seeds, global_chunk_sz):
            tractogram = Tractogram(
                gpu_tracker.generate_streamlines(chunk),
                affine_to_rasmm=seed_img.affine)
            # TRX stores coordinates in RAS+ mm, not voxel space.
            tractogram.to_world()
            sls = tractogram.streamlines

            new_offsets_idx = offsets_idx + len(sls._offsets)
            new_sls_data_idx = sls_data_idx + len(sls._data)

            if (new_offsets_idx > trx_file.header["NB_STREAMLINES"]
                    or new_sls_data_idx > trx_file.header["NB_VERTICES"]):
                logger.info("TRX resizing...")
                trx_file.resize(
                    nb_streamlines=new_offsets_idx * 2,
                    nb_vertices=new_sls_data_idx * 2)

            # TRX uses memmaps here. Offsets index into the vertex array, so
            # the chunk's local offsets are shifted by the number of
            # vertices (not streamlines) already written.
            trx_file.streamlines._data[sls_data_idx:new_sls_data_idx] = \
                sls._data
            trx_file.streamlines._offsets[offsets_idx:new_offsets_idx] = \
                sls_data_idx + sls._offsets
            trx_file.streamlines._lengths[offsets_idx:new_offsets_idx] = \
                sls._lengths

            offsets_idx = new_offsets_idx
            sls_data_idx = new_sls_data_idx
            pbar.update(chunk.shape[0])

    return trx_file


def _gpu_track_to_sft(gpu_tracker, seeds, global_chunk_sz, seed_img):
    """Track ``seeds`` chunk by chunk into a voxel-space StatefulTractogram."""
    streamlines_ls = []
    with tqdm(total=seeds.shape[0]) as pbar:
        for chunk in _seed_chunks(seeds, global_chunk_sz):
            streamlines_ls.append(gpu_tracker.generate_streamlines(chunk))
            pbar.update(chunk.shape[0])

    # nibabel's concatenate() reads its first element, so an empty seed set
    # (no chunks at all) is special-cased to an empty tractogram.
    streamlines = concatenate(streamlines_ls, 0) if streamlines_ls else []
    return StatefulTractogram(streamlines, seed_img, Space.VOX)