import numpy as np
import os.path as op
from time import time
import logging
from dipy.io.stateful_tractogram import StatefulTractogram, Space
from dipy.io.streamline import save_tractogram
import dipy.tracking.streamlinespeed as dps
import dipy.tracking.streamline as dts
from dipy.tracking.distances import bundles_distances_mdf
from AFQ.definitions.mapping import ConformedFnirtMapping
# Module-level logger registered under the 'AFQ' name.
logger = logging.getLogger('AFQ')
def flip_sls(select_sl, idx_to_flip, in_place=False):
    """
    Reverse the point order of the flagged streamlines.

    Parameters
    ----------
    select_sl : sequence of streamlines
        Streamlines to process.
    idx_to_flip : sequence of bool
        One flag per streamline; a truthy entry means the streamline at
        that position is reversed, a falsy one means it is kept as is.
    in_place : bool, optional
        If True, results are written back into ``select_sl``; otherwise
        a new list of the same length is filled. Default: False.

    Returns
    -------
    The container holding the results (``select_sl`` itself when
    ``in_place`` is True, a new list otherwise).
    """
    target = select_sl if in_place else [None] * len(select_sl)
    for pos, streamline in enumerate(select_sl):
        target[pos] = streamline[::-1] if idx_to_flip[pos] else streamline
    return target
def cut_sls_by_dist(select_sl, roi_dists, roi_idxs,
                    in_place=False):
    """
    Helper function to cut streamlines according to which points
    are closest to certain rois.

    Parameters
    ----------
    select_sl : sequence of streamlines
        Streamlines to cut.
    roi_dists : 2-D array-like, indexed as ``[streamline, roi]``
        For each streamline and inclusion roi, the entry is converted to
        ``int`` and used as a point index into the streamline.
    roi_idxs : pair of int
        Two indices into the roi axis of ``roi_dists`` selecting the
        start and end of the cut. A value of -1 in the first position
        means "start of the streamline"; -1 in the second position means
        "end of the streamline".
    in_place : bool, optional
        Whether to write the results back into ``select_sl`` instead of
        filling a new list. Default: False.

    Returns
    -------
    The container holding the cut streamlines (``select_sl`` itself when
    ``in_place`` is True, a new list otherwise).
    """
    cut_sl = select_sl if in_place else [None] * len(select_sl)

    for idx, this_sl in enumerate(select_sl):
        if roi_idxs[0] == -1:
            min0 = 0
        else:
            min0 = int(roi_dists[idx, roi_idxs[0]])
        if roi_idxs[1] == -1:
            min1 = len(this_sl)
        else:
            min1 = int(roi_dists[idx, roi_idxs[1]])

        # The indices arrive in reverse order when the streamline runs
        # opposite to the roi ordering; slicing needs them ascending.
        if min0 > min1:
            min0, min1 = min1, min0

        # If the point that is closest to the first ROI
        # is the same as the point closest to the second ROI,
        # include a neighbouring point so that the result is still a
        # (two-point) streamline. The lower bound is clamped at 0:
        # a bare ``min0 - 1`` would become -1 when the shared point is
        # the first one, and the slice ``[-1:1]`` would wrap around and
        # come back empty.
        if min0 == min1:
            min0 = max(min0 - 1, 0)
            min1 = min0 + 2

        cut_sl[idx] = this_sl[min0:min1]

    return cut_sl
def read_tg(tg, nb_streamlines=None):
    """
    Optionally reduce a tractogram to a random subset of streamlines.

    Parameters
    ----------
    tg : StatefulTractogram
        Input tractogram.
    nb_streamlines : int, optional
        When truthy and smaller than ``len(tg)``, a random set of this
        many streamlines is selected. Default: None (no subsampling).

    Returns
    -------
    ``tg`` unchanged when no subsampling applies; otherwise a new
    StatefulTractogram built from the selected streamlines using ``tg``
    as the reference.
    """
    wants_subset = bool(nb_streamlines) and len(tg) > nb_streamlines
    if not wants_subset:
        return tg

    subset = dts.select_random_set_of_streamlines(
        tg.streamlines, nb_streamlines)
    return StatefulTractogram.from_sft(subset, tg)
def orient_by_streamline(sls, template_sl):
    """
    Flag streamlines that are closer to the reversed template.

    The distance from every streamline in ``sls`` to ``template_sl`` and
    to its point-reversed copy is computed with
    ``bundles_distances_mdf``.

    Parameters
    ----------
    sls : sequence of streamlines
        Streamlines to compare against the template.
    template_sl : streamline
        Reference streamline defining the desired orientation.

    Returns
    -------
    Boolean array with one entry per streamline, True where the distance
    to the template is larger than the distance to the reversed
    template.
    """
    both_orientations = [template_sl, template_sl[::-1]]
    dist = bundles_distances_mdf(sls, both_orientations)
    dist_forward = dist[:, 0]
    dist_reversed = dist[:, 1]
    return dist_forward > dist_reversed
def move_streamlines(tg, to, mapping, img, save_intermediates=None):
    """Move streamlines to or from template space.

    Parameters
    ----------
    tg : StatefulTractogram
        Tractogram whose streamlines are moved. Its space is switched
        temporarily (to voxel space for Fnirt mappings, RASMM otherwise)
        and restored to the original space before this function exits,
        including when an error occurs part-way through.
    to : str
        Either "template" or "subject".
    mapping : ConformedMapping
        Mapping to use to move streamlines.
    img : Nifti1Image
        Space to move streamlines to.
    save_intermediates : str, optional
        If not None, directory in which the moved streamlines are saved
        as ``sls_in_{to}.trk``. Default: None.

    Returns
    -------
    StatefulTractogram
        The moved streamlines, declared in RASMM with ``img`` as
        reference.

    Raises
    ------
    ValueError
        If ``mapping`` is a ConformedFnirtMapping and ``to`` is not
        "subject".
    """
    tg_og_space = tg.space
    uses_fnirt = isinstance(mapping, ConformedFnirtMapping)
    if uses_fnirt and to != "subject":
        # Raised before tg is touched, so nothing needs restoring.
        raise ValueError(
            "Attempted to transform streamlines to template using "
            "unsupported mapping. "
            "Use something other than Fnirt.")

    try:
        if uses_fnirt:
            tg.to_vox()
            moved_sl = [
                mapping.transform_inverse_pts(sl) for sl in tg.streamlines]
        else:
            tg.to_rasmm()
            if to == "template":
                volume = mapping.forward
            else:
                volume = mapping.backward
            # Sample the displacement volume at every streamline point
            # and add the displacement to the point coordinates.
            delta = dts.values_from_volume(
                volume,
                tg.streamlines, np.eye(4))
            moved_sl = dts.Streamlines(
                [d + s for d, s in zip(delta, tg.streamlines)])

        moved_sft = StatefulTractogram(
            moved_sl,
            img,
            Space.RASMM)
        if save_intermediates is not None:
            save_tractogram(
                moved_sft,
                op.join(save_intermediates,
                        f'sls_in_{to}.trk'),
                bbox_valid_check=False)
    finally:
        # Always hand the caller's tractogram back in the space it came
        # in, even if the transform or the save failed.
        tg.to_space(tg_og_space)
    return moved_sft
def resample_tg(tg, n_points):
    """
    Resample streamlines to a fixed number of points.

    Parameters
    ----------
    tg : ndarray, object with a ``streamlines`` attribute, or streamlines
        Input to resample. An ndarray with more than two dimensions is
        split along its first axis into a list of arrays; any other
        ndarray is passed through unchanged. An object exposing
        ``streamlines`` contributes that attribute; anything else is
        passed through as is.
    n_points : int
        Number of points requested for every streamline.

    Returns
    -------
    The result of dipy's ``set_number_of_points`` on the reformatted
    streamlines.
    """
    # reformat for dipy's set_number_of_points
    if isinstance(tg, np.ndarray):
        if len(tg.shape) > 2:
            streamlines = [np.asarray(item) for item in tg.tolist()]
        else:
            # Previously this case left `streamlines` unbound and raised
            # UnboundLocalError; hand the array over unchanged instead.
            streamlines = tg
    elif hasattr(tg, "streamlines"):
        streamlines = tg.streamlines
    else:
        streamlines = tg

    return dps.set_number_of_points(streamlines, n_points)
class SlsBeingRecognized:
    """
    Bookkeeping for a set of streamlines while successive filters narrow
    them down.

    The full set of streamlines is kept untouched in ``ref_sls``; what
    changes is ``selected_fiber_idxs`` (indices into ``ref_sls`` that are
    still selected) and ``sls_flipped`` (one boolean per selected
    streamline, marking those to be reversed on retrieval).

    An optional ``roi_dists`` attribute is not created here; when it has
    been attached to the instance it is filtered along with the
    selection and used for cutting in ``get_selected_sls``.

    Parameters
    ----------
    sls : indexable collection of streamlines
        All candidate streamlines; must support indexing with an integer
        index array.
    logger : logging.Logger
        Logger used for progress messages.
    save_intermediates : str or None
        If not None, directory where the selection is saved after every
        call to ``select``.
    b_name : str
        Bundle name, used in intermediate file names.
    ref : reference accepted by StatefulTractogram
        Reference used when saving intermediate tractograms.
    n_roi_dists : int
        Number of roi columns in ``roi_dists``; cutting is only applied
        when this is greater than 1.
    """

    def __init__(self, sls, logger, save_intermediates, b_name, ref,
                 n_roi_dists):
        self.oriented_yet = False
        self.selected_fiber_idxs = np.arange(len(sls), dtype=np.uint32)
        # np.bool_ instead of the np.bool8 alias, which is deprecated
        # since NumPy 1.24 and removed in NumPy 2.0.
        self.sls_flipped = np.zeros(len(sls), dtype=np.bool_)
        self.logger = logger
        self.start_time = -1
        self.save_intermediates = save_intermediates
        self.b_name = b_name
        self.ref_sls = sls
        self.ref = ref
        self.n_roi_dists = n_roi_dists

    def initiate_selection(self, clean_name):
        """
        Start timing a filtering step and return an all-False boolean
        mask with one entry per currently selected streamline.
        """
        self.start_time = time()
        self.logger.info(f"Filtering by {clean_name}")
        return np.zeros(len(self.selected_fiber_idxs), dtype=np.bool_)

    def select(self, idx, clean_name, cut=False):
        """
        Narrow the current selection to ``idx``.

        ``idx`` indexes the current selection (not ``ref_sls``). The
        flip flags, and ``roi_dists`` when present, are filtered the
        same way. The elapsed time since ``initiate_selection`` is
        logged, and the selection is saved to disk when
        ``save_intermediates`` is set.
        """
        self.selected_fiber_idxs = self.selected_fiber_idxs[idx]
        self.sls_flipped = self.sls_flipped[idx]
        if hasattr(self, "roi_dists"):
            self.roi_dists = self.roi_dists[idx]
        time_taken = time() - self.start_time
        self.logger.info(
            f"After filtering by {clean_name} (time: {time_taken}s), "
            f"{len(self)} streamlines remain.")
        if self.save_intermediates is not None:
            save_tractogram(
                StatefulTractogram(
                    self.get_selected_sls(cut=cut),
                    self.ref, Space.VOX),
                op.join(self.save_intermediates,
                        f'sls_after_{clean_name}_for_{self.b_name}.trk'),
                bbox_valid_check=False)

    def get_selected_sls(self, cut=False, flip=False):
        """
        Return the currently selected streamlines.

        Parameters
        ----------
        cut : bool, optional
            If True, and ``roi_dists`` is present with more than one
            roi, cut every streamline between the first and the last
            roi. Default: False.
        flip : bool, optional
            If True, reverse the streamlines flagged in
            ``sls_flipped``. Default: False.
        """
        selected_sls = self.ref_sls[self.selected_fiber_idxs]
        if cut and hasattr(self, "roi_dists") and self.n_roi_dists > 1:
            selected_sls = cut_sls_by_dist(
                selected_sls, self.roi_dists,
                (0, self.n_roi_dists - 1),
                in_place=False)
        if flip:
            selected_sls = flip_sls(
                selected_sls, self.sls_flipped,
                in_place=False)
        return selected_sls

    def reorient(self, idx):
        """
        Mark the streamlines at ``idx`` (within the current selection)
        as flipped.

        Raises
        ------
        RuntimeError
            If orientation has already been set once on this instance.
        """
        if self.oriented_yet:
            raise RuntimeError((
                "Attempted to oriented streamlines "
                "that were already oriented. "
                "This is a bug in the implementation of a "
                "bundle recognition procedure. "))
        self.oriented_yet = True
        self.sls_flipped[idx] = True

    def __bool__(self):
        """True while at least one streamline remains selected."""
        return len(self) > 0

    def __len__(self):
        """Number of streamlines currently selected."""
        return len(self.selected_fiber_idxs)