import copy
import os.path as op
from time import time
import dipy.tracking.streamline as dts
import dipy.tracking.streamlinespeed as dps
import numpy as np
from dipy.io.stateful_tractogram import Space, StatefulTractogram
from dipy.io.streamline import save_tractogram
from dipy.tracking.distances import bundles_distances_mdf
from tqdm import tqdm
[docs]
# Lookup from an axis label to an integer axis index. Each group of labels
# (the combined "X/Y" form and each single letter) shares one index.
# NOTE(review): the labels look like anatomical directions (left/right,
# posterior/anterior, inferior/superior) -- confirm against callers.
axes_dict = {
    label: axis_idx
    for axis_idx, labels in enumerate(
        (("L/R", "L", "R"), ("P/A", "P", "A"), ("I/S", "I", "S"))
    )
    for label in labels
}
[docs]
def manual_orient_sls(fgarray):
    """
    Flag streamlines for flipping based on the direction between endpoints.

    For every streamline the displacement from its last point to its first
    point is taken, the coordinate axis with the largest absolute
    displacement is selected, and the streamline is flagged when the
    displacement along that axis is negative. Per the original author this
    follows the LPI+ pyAFQ standard and assumes streamlines are in RASMM.

    Parameters
    ----------
    fgarray : ndarray
        Indexed as ``fgarray[streamline, point, axis]``.

    Returns
    -------
    ndarray of bool
        One entry per streamline; True where the dominant-axis displacement
        (first point minus last point) is negative.
    """
    first_pts = fgarray[:, 0, :]
    last_pts = fgarray[:, -1, :]
    delta = first_pts - last_pts
    # Axis along which each streamline travels the furthest.
    dominant = np.abs(delta).argmax(axis=1)
    rows = np.arange(len(fgarray))
    return delta[rows, dominant] < 0
[docs]
def tolerance_mm_to_vox(img, dist_to_waypoint, input_dist_to_atlas):
    """
    Convert millimeter distances into voxel units using ``img.affine``.

    Parameters
    ----------
    img
        Object exposing an ``affine`` array; only its upper-left 3x3 part
        is used to estimate the voxel size.
    dist_to_waypoint : float or None
        Waypoint tolerance in mm. When None, the tolerance is taken from
        ``dts.dist_to_corner(img.affine)`` instead.
    input_dist_to_atlas : float
        Atlas distance in mm.

    Returns
    -------
    tol : float
        ``dist_to_waypoint / vox_dim``, or the distance-to-corner fallback.
    dist_to_atlas : int
        ``input_dist_to_atlas / vox_dim`` truncated by ``int``.
    vox_dim : float
        Mean voxel size derived from the affine.
    """
    linear_part = img.affine[0:3, 0:3]
    # The diagonal of the Cholesky factor of R^T R is averaged into a
    # single voxel size used for the mm -> voxel conversion.
    gram = linear_part.T.dot(linear_part)
    vox_dim = np.mean(np.diag(np.linalg.cholesky(gram)))

    tol = (
        dts.dist_to_corner(img.affine)
        if dist_to_waypoint is None
        else dist_to_waypoint / vox_dim
    )
    return tol, int(input_dist_to_atlas / vox_dim), vox_dim
[docs]
def flip_sls(select_sl, idx_to_flip, in_place=False):
    """
    Reverse the point order of the flagged streamlines.

    Parameters
    ----------
    select_sl
        Sequence of streamlines.
    idx_to_flip
        Per-streamline flags, indexed by position; truthy entries are
        reversed with ``sl[::-1]``, the others are passed through as-is.
    in_place : bool
        When True the results are written back into ``select_sl`` and that
        same object is returned; otherwise a new list is returned.

    Returns
    -------
    The container holding the (possibly reversed) streamlines.
    """
    out = select_sl if in_place else [None] * len(select_sl)
    for pos, streamline in enumerate(select_sl):
        out[pos] = streamline[::-1] if idx_to_flip[pos] else streamline
    return out
[docs]
def cut_sls_by_closest(select_sl, roi_closest, roi_idxs, in_place=False):
    """
    Helper function to cut streamlines according to which points
    are closest to certain rois.

    Parameters
    ----------
    select_sl
        streamlines to cut
    roi_closest
        indices into given streamline of points nearest
        to inclusion rois, indexed as ``roi_closest[streamline, roi]``
    roi_idxs
        two indices into the list of inclusion rois to use for the cut;
        -1 in the first slot means "start of the streamline", -1 in the
        second slot means "end of the streamline"
    in_place
        whether to modify select_sl

    Returns
    -------
    The container of cut streamlines (``select_sl`` itself when
    ``in_place`` is True, otherwise a new list). Each streamline is the
    half-open slice ``sl[lo:hi]`` between the two selected indices.
    """
    cut_sl = select_sl if in_place else [None] * len(select_sl)

    for idx, this_sl in enumerate(select_sl):
        # Cast to plain Python ints: the index arithmetic below would wrap
        # around instead of going negative on unsigned numpy integers.
        if roi_idxs[0] == -1:
            min0 = 0
        else:
            min0 = int(roi_closest[idx, roi_idxs[0]])
        if roi_idxs[1] == -1:
            min1 = len(this_sl)
        else:
            min1 = int(roi_closest[idx, roi_idxs[1]])

        # handle if sls not flipped
        if min0 > min1:
            min0, min1 = min1, min0

        # If the point that is closest to the first ROI
        # is the same as the point closest to the second ROI,
        # include the surrounding points to make a streamline.
        if min0 == min1:
            min1 = min1 + 1
            min0 = min0 - 1
            if min0 < 0:
                # The shared point is the very first one. A start of -1
                # would be read as "last point" and give an empty slice,
                # so shift the two-point window forward instead.
                min0 = 0
                min1 = min1 + 1

        cut_sl[idx] = this_sl[min0:min1]

    return cut_sl
[docs]
def orient_by_streamline(sls, template_sl):
    """
    Flag streamlines that match the reversed template better than the template.

    ``bundles_distances_mdf`` is evaluated between every streamline in
    ``sls`` and two references: ``template_sl`` and its point-reversed copy.

    Parameters
    ----------
    sls
        Streamlines to test.
    template_sl
        Reference streamline.

    Returns
    -------
    Boolean array, one entry per streamline; True where the distance to
    ``template_sl`` is strictly greater than the distance to
    ``template_sl[::-1]``.
    """
    references = [template_sl, template_sl[::-1]]
    dist_mat = bundles_distances_mdf(sls, references)
    to_forward = dist_mat[:, 0]
    to_reversed = dist_mat[:, 1]
    return to_forward > to_reversed
[docs]
def resample_tg(tg, n_points):
    """
    Resample streamlines to ``n_points`` points each via dipy's
    ``set_number_of_points``.

    Parameters
    ----------
    tg
        One of:

        - an ndarray with more than two dimensions, which is split along
          its first axis into a list of per-streamline arrays;
        - an object with a ``streamlines`` attribute, whose streamlines
          are used;
        - anything else (including an ndarray with two or fewer
          dimensions), which is handed to dipy unchanged.
    n_points : int
        Number of points per resampled streamline.

    Returns
    -------
    Whatever ``dps.set_number_of_points`` returns for the given input.
    """
    # reformat for dipy's set_number_of_points
    if isinstance(tg, np.ndarray) and tg.ndim > 2:
        streamlines = [np.asarray(item) for item in tg.tolist()]
    elif hasattr(tg, "streamlines"):
        streamlines = tg.streamlines
    else:
        # Also covers ndarrays with <= 2 dimensions, which previously left
        # `streamlines` unassigned and raised UnboundLocalError.
        streamlines = tg

    return dps.set_number_of_points(streamlines, n_points)
[docs]
class SlsBeingRecognized:
    """
    Bookkeeping for the streamlines still being considered for one bundle.

    The full set of streamlines is held once in ``ref_sls``; the current
    survivors are tracked as indices into it (``selected_fiber_idxs``)
    along with a per-survivor flip flag (``sls_flipped``). Optional
    per-survivor arrays ``roi_closest`` and ``roi_dists`` are not created
    by the constructor; when present on the instance they are filtered
    and copied together with the other per-survivor state.
    """

    def __init__(self, sls, save_intermediates, b_name, ref, n_roi):
        """
        Parameters
        ----------
        sls
            All candidate streamlines; must support ``len`` and indexing
            with an integer index array.
        save_intermediates : str or None
            Directory for intermediate tractograms, or None to disable.
        b_name : str
            Bundle name, used in intermediate file names.
        ref
            Reference passed to ``StatefulTractogram`` when saving.
        n_roi : int
            Number of inclusion ROIs for this bundle.
        """
        self.oriented_yet = False
        self.selected_fiber_idxs = np.arange(len(sls), dtype=np.uint32)
        self.sls_flipped = np.zeros(len(sls), dtype=np.bool_)
        # The methods below read all of these attributes, so they have to
        # be stored here rather than only consumed for their length.
        self.ref_sls = sls
        self.save_intermediates = save_intermediates
        self.b_name = b_name
        self.ref = ref
        self.n_roi = n_roi
        # Gives `select` a valid timer even if `initiate_selection` was
        # not called first.
        self.start_time = time()

    def initiate_selection(self, clean_name):
        """
        Start timing a filtering step and return an all-False mask with
        one entry per currently selected streamline.
        """
        self.start_time = time()
        tqdm.write(f"Filtering by {clean_name}")
        return np.zeros(len(self.selected_fiber_idxs), dtype=np.bool_)

    def select(self, idx, clean_name, cut=False):
        """
        Keep only the survivors picked out by ``idx``.

        Parameters
        ----------
        idx
            Mask or index array applied to every per-survivor array.
        clean_name : str
            Name of the filtering step, used for logging and file names.
        cut : bool
            Forwarded to ``get_selected_sls`` when saving intermediates.
        """
        self.selected_fiber_idxs = self.selected_fiber_idxs[idx]
        self.sls_flipped = self.sls_flipped[idx]
        if hasattr(self, "roi_closest"):
            self.roi_closest = self.roi_closest[idx]
        if hasattr(self, "roi_dists"):
            self.roi_dists = self.roi_dists[idx]
        time_taken = time() - self.start_time
        tqdm.write(
            f"After filtering by {clean_name} (time: {time_taken}s), "
            f"{len(self)} streamlines remain."
        )

        # Only save intermediates after 90% of the streamlines have been
        # filtered out; otherwise it is impractical.
        if self.save_intermediates is not None and len(self) < 0.1 * len(
            self.ref_sls
        ):
            save_tractogram(
                StatefulTractogram(
                    self.get_selected_sls(cut=cut), self.ref, Space.RASMM
                ),
                op.join(
                    self.save_intermediates,
                    f"sls_after_{clean_name}_for_{self.b_name}.trk",
                ),
                bbox_valid_check=False,
            )

    def get_selected_sls(self, cut=False, flip=False):
        """
        Return the currently selected streamlines.

        Parameters
        ----------
        cut : bool
            When True, ``roi_closest`` is present, and there is more than
            one ROI, cut each streamline between the first and last ROI.
        flip : bool
            When True, reverse the streamlines flagged in ``sls_flipped``.
        """
        selected_sls = self.ref_sls[self.selected_fiber_idxs]
        if cut and hasattr(self, "roi_closest") and self.n_roi > 1:
            selected_sls = cut_sls_by_closest(
                selected_sls,
                self.roi_closest,
                (0, self.n_roi - 1),
                in_place=False,
            )
        if flip:
            selected_sls = flip_sls(
                selected_sls, self.sls_flipped, in_place=False
            )
        return selected_sls

    def reorient(self, idx):
        """
        Mark the survivors selected by ``idx`` as flipped.

        Raises
        ------
        RuntimeError
            If the streamlines were already oriented once.
        """
        if self.oriented_yet:
            raise RuntimeError(
                (
                    "Attempted to oriented streamlines "
                    "that were already oriented. "
                    "This is a bug in the implementation of a "
                    "bundle recognition procedure. "
                )
            )
        self.oriented_yet = True
        self.sls_flipped[idx] = True

    def __bool__(self):
        """True while at least one streamline is still selected."""
        return len(self) > 0

    def __len__(self):
        """Number of currently selected streamlines."""
        return len(self.selected_fiber_idxs)

    def copy(self, new_name, n_roi):
        """
        Create an independent selection state that shares ``ref_sls``.

        Parameters
        ----------
        new_name : str
            Bundle name for the copy.
        n_roi : int
            ROI count for the copy; values <= 0 keep the current ``n_roi``.

        Raises
        ------
        NotImplementedError
            If both this object and the requested copy have inclusion ROIs.
        """
        new_copy = copy.copy(self)
        new_copy.b_name = new_name
        if n_roi > 0:
            if self.n_roi > 0:
                raise NotImplementedError(
                    (
                        "You cannot have includes in the original bundle and"
                        " subbundles; only one or the other."
                    )
                )
            else:
                new_copy.n_roi = n_roi
        # The shallow copy shares the arrays; duplicate the per-survivor
        # ones so later filtering of either object does not affect the other.
        new_copy.selected_fiber_idxs = self.selected_fiber_idxs.copy()
        new_copy.sls_flipped = self.sls_flipped.copy()
        if hasattr(self, "roi_closest"):
            new_copy.roi_closest = self.roi_closest.copy()
        if hasattr(self, "roi_dists"):
            new_copy.roi_dists = self.roi_dists.copy()
        return new_copy

    def export_selected(self, chunk_offset):
        """
        Snapshot the per-survivor state as a dict of copies.

        ``chunk_offset`` is added to the local indices (as int64) to form
        ``global_idx``. ``roi_closest`` / ``roi_dists`` are None when the
        corresponding attribute is absent.
        """
        return {
            "global_idx": (
                self.selected_fiber_idxs.astype(np.int64) + int(chunk_offset)
            ).copy(),
            "sls_flipped": self.sls_flipped.copy(),
            "oriented_yet": self.oriented_yet,
            "roi_closest": (
                self.roi_closest.copy()
                if hasattr(self, "roi_closest")
                else None
            ),
            "roi_dists": (
                self.roi_dists.copy() if hasattr(self, "roi_dists") else None
            ),
        }

    @classmethod
    def from_selected(
        cls,
        survivor_dicts,
        full_streamlines,
        save_intermediates,
        b_name,
        ref,
        n_roi,
    ):
        """
        Rebuild one instance from dicts produced by ``export_selected``.

        Parameters
        ----------
        survivor_dicts : list of dict
            Per-chunk exports; chunks with no survivors are ignored.
        full_streamlines
            The streamlines that ``global_idx`` values index into.
        save_intermediates, b_name, ref, n_roi
            Forwarded to the constructor.

        Returns
        -------
        SlsBeingRecognized or None
            None when no chunk has any survivor. Otherwise an instance
            whose per-survivor arrays are ordered by ascending global index.

        Raises
        ------
        RuntimeError
            If only some of the non-empty chunks carry ``roi_closest``.
        """
        non_empty = [d for d in survivor_dicts if d["global_idx"].size > 0]
        if not non_empty:
            return None

        global_idx = np.concatenate([d["global_idx"] for d in non_empty])
        sls_flipped = np.concatenate([d["sls_flipped"] for d in non_empty])
        oriented_yet = any(d["oriented_yet"] for d in non_empty)

        # Only sort when the concatenated indices are not already strictly
        # increasing; the same permutation is reused for every array.
        order = None
        if global_idx.size > 1 and not np.all(np.diff(global_idx) > 0):
            order = np.argsort(global_idx, kind="stable")
            global_idx = global_idx[order]
            sls_flipped = sls_flipped[order]

        roi_closest = None
        roi_dists = None
        has_roi = [d["roi_closest"] is not None for d in non_empty]
        if any(has_roi):
            if not all(has_roi):
                raise RuntimeError(
                    "Inconsistent roi_closest across chunks for bundle "
                    f"{b_name}: some chunks have it, some don't. This is a"
                    " bug in chunked recognition."
                )
            roi_closest = np.concatenate(
                [d["roi_closest"] for d in non_empty], axis=0
            )
            dists = [d["roi_dists"] for d in non_empty]
            # `roi_dists` is exported independently of `roi_closest`, so it
            # may be None everywhere; only merge it when it exists.
            if any(item is not None for item in dists):
                roi_dists = np.concatenate(dists, axis=0)
            if order is not None:
                roi_closest = roi_closest[order]
                if roi_dists is not None:
                    roi_dists = roi_dists[order]

        inst = cls(
            full_streamlines,
            save_intermediates,
            b_name,
            ref,
            n_roi,
        )
        inst.selected_fiber_idxs = global_idx.astype(np.uint32, copy=False)
        inst.sls_flipped = sls_flipped
        inst.oriented_yet = oriented_yet
        if roi_closest is not None:
            inst.roi_closest = roi_closest
        if roi_dists is not None:
            inst.roi_dists = roi_dists
        return inst