Source code for AFQ.recognition.utils

import copy
import os.path as op
from time import time

import dipy.tracking.streamline as dts
import dipy.tracking.streamlinespeed as dps
import numpy as np
from dipy.io.stateful_tractogram import Space, StatefulTractogram
from dipy.io.streamline import save_tractogram
from dipy.tracking.distances import bundles_distances_mdf
from tqdm import tqdm

# Map an anatomical axis label (a full "X/Y" pair or either of its single
# letters) to the array axis index it refers to: 0 for left/right,
# 1 for posterior/anterior, 2 for inferior/superior.
axes_dict = {
    label: axis_index
    for axis_index, labels in enumerate(
        (
            ("L/R", "L", "R"),
            ("P/A", "P", "A"),
            ("I/S", "I", "S"),
        )
    )
    for label in labels
}
def manual_orient_sls(fgarray):
    """
    Helper function to manually orient streamlines by their endpoints,
    according to LPI+ pyAFQ standard assuming streamlines are in RASMM.

    Parameters
    ----------
    fgarray : ndarray
        Streamlines indexed as ``fgarray[streamline, point, coordinate]``;
        only the first and last points of each streamline are used.

    Returns
    -------
    ndarray of bool
        One entry per streamline. For each streamline, the axis with the
        largest absolute first-to-last endpoint displacement is chosen, and
        the entry is True when the first point's coordinate on that axis is
        smaller than the last point's.
    """
    start_minus_end = fgarray[:, 0, :] - fgarray[:, -1, :]
    dominant_axis = np.abs(start_minus_end).argmax(axis=1)
    row = np.arange(len(fgarray))
    return start_minus_end[row, dominant_axis] < 0
def tolerance_mm_to_vox(img, dist_to_waypoint, input_dist_to_atlas):
    """
    Convert recognition distance tolerances using the image's voxel size.

    Parameters
    ----------
    img
        Object with an ``affine`` attribute (a 4x4 matrix) describing the
        voxel grid.
    dist_to_waypoint : float or None
        Waypoint tolerance in mm. If None, the tolerance is taken from
        ``dipy.tracking.streamline.dist_to_corner(img.affine)`` instead.
    input_dist_to_atlas : float
        Atlas tolerance in mm.

    Returns
    -------
    tol : float
        Waypoint tolerance. It is NOT squared.
    dist_to_atlas : int
        ``input_dist_to_atlas / vox_dim``, truncated toward zero.
    vox_dim : float
        Mean of the voxel edge lengths derived from the affine.
    """
    # We need the size of a voxel to transform from mm to voxel units.
    # For an affine without shear, the diagonal of the Cholesky factor of
    # R^T R equals the voxel edge lengths (column norms of R).
    R = img.affine[0:3, 0:3]
    vox_dim = np.mean(np.diag(np.linalg.cholesky(R.T.dot(R))))

    if dist_to_waypoint is None:
        # NOTE(review): dipy documents dist_to_corner as a distance in mm
        # (center-to-corner of a voxel), whereas the other branch divides
        # by vox_dim to produce voxels -- confirm which unit callers expect.
        tol = dts.dist_to_corner(img.affine)
    else:
        tol = dist_to_waypoint / vox_dim
    dist_to_atlas = int(input_dist_to_atlas / vox_dim)
    return tol, dist_to_atlas, vox_dim
def flip_sls(select_sl, idx_to_flip, in_place=False):
    """
    Helper function to flip streamlines.

    Parameters
    ----------
    select_sl
        Indexable sequence of streamlines.
    idx_to_flip
        Per-streamline truthy flags; flagged streamlines are reversed
        with ``[::-1]``.
    in_place : bool
        If True, write results back into ``select_sl`` and return it;
        otherwise return a new list. Unflipped entries are the same
        objects as in ``select_sl`` (no copy is made).
    """
    out = select_sl if in_place else [None] * len(select_sl)
    for position, streamline in enumerate(select_sl):
        out[position] = streamline[::-1] if idx_to_flip[position] else streamline
    return out
def cut_sls_by_closest(select_sl, roi_closest, roi_idxs, in_place=False):
    """
    Helper function to cut streamlines according to which points are
    closest to certain rois.

    Parameters
    ----------
    select_sl
        streamlines to cut
    roi_closest
        indices into given streamline of points nearest to inclusion rois;
        ``roi_closest[i, j]`` is read for streamline ``i`` and roi ``j``
    roi_idxs
        two indices into the list of inclusion rois to use for the cut.
        ``-1`` in the first slot cuts from the first point; ``-1`` in the
        second slot cuts through the last point.
    in_place
        whether to modify select_sl

    Returns
    -------
    The cut streamlines: ``select_sl`` itself when ``in_place`` is True,
    otherwise a new list. Each streamline is sliced as ``[start:end]``,
    so the point closest to the later roi is itself excluded.
    """
    cut_sl = select_sl if in_place else [None] * len(select_sl)
    for idx, this_sl in enumerate(select_sl):
        if roi_idxs[0] == -1:
            min0 = 0
        else:
            min0 = roi_closest[idx, roi_idxs[0]]
        if roi_idxs[1] == -1:
            min1 = len(this_sl)
        else:
            min1 = roi_closest[idx, roi_idxs[1]]

        # handle if sls not flipped
        if min0 > min1:
            min0, min1 = min1, min0

        # If the point that is closest to the first ROI is the same as the
        # point closest to the second ROI, keep a two-point segment around
        # it instead of an empty one. The start is clamped at 0 because a
        # negative start would wrap to the end of the streamline and
        # produce an empty slice.
        if min0 == min1:
            min0 = max(min0 - 1, 0)
            min1 = min0 + 2

        cut_sl[idx] = this_sl[min0:min1]
    return cut_sl
def orient_by_streamline(sls, template_sl):
    """
    Decide, per streamline, whether it runs opposite to a template.

    Parameters
    ----------
    sls
        Streamlines to compare (passed to ``bundles_distances_mdf``).
    template_sl
        Reference streamline; it is compared both as given and reversed.

    Returns
    -------
    ndarray of bool
        True where a streamline's MDF distance to the template as given
        is larger than its distance to the reversed template.
    """
    candidates = [template_sl, template_sl[::-1]]
    mdf = bundles_distances_mdf(sls, candidates)
    to_forward = mdf[:, 0]
    to_reversed = mdf[:, 1]
    return to_forward > to_reversed
def resample_tg(tg, n_points):
    """
    Resample streamlines so that each has ``n_points`` points.

    Parameters
    ----------
    tg
        One of:
        - an ndarray with more than two dimensions, treated as a stack of
          streamlines (one per entry along axis 0);
        - an ndarray with two or fewer dimensions, passed through unchanged
          (presumably a single streamline -- confirm with callers);
        - an object with a ``streamlines`` attribute, whose streamlines
          are used;
        - anything else, passed through unchanged.
    n_points : int
        Number of points per resampled streamline.

    Returns
    -------
    The result of ``dipy.tracking.streamlinespeed.set_number_of_points``.
    """
    # reformat for dipy's set_number_of_points
    if isinstance(tg, np.ndarray):
        if tg.ndim > 2:
            # Split the dense array into one array per streamline. Going
            # through tolist() rebuilds each one with NumPy's default
            # dtype for Python scalars (e.g. float64).
            streamlines = [np.asarray(item) for item in tg.tolist()]
        else:
            # Lower-dimensional arrays are handed to dipy as-is; without
            # this branch `streamlines` would be unbound here.
            streamlines = tg
    elif hasattr(tg, "streamlines"):
        streamlines = tg.streamlines
    else:
        streamlines = tg
    return dps.set_number_of_points(streamlines, n_points)
class SlsBeingRecognized:
    """
    Track which streamlines survive each step of bundle recognition.

    The full set of reference streamlines is kept in ``ref_sls`` and the
    survivors are represented by ``selected_fiber_idxs``, an index array
    into it. Per-survivor bookkeeping (``sls_flipped`` and, once a
    recognition step sets them, ``roi_closest`` / ``roi_dists``) is
    filtered in lock-step with those indices.

    Parameters
    ----------
    sls
        Reference streamlines. Must support integer-array indexing
        (``sls[index_array]``), which :meth:`get_selected_sls` uses.
    save_intermediates
        Directory in which to save intermediate ``.trk`` files, or None to
        disable saving.
    b_name
        Bundle name, used in intermediate file names.
    ref
        Spatial reference passed to ``StatefulTractogram`` when saving.
    n_roi
        Number of inclusion ROIs.
    """

    def __init__(self, sls, save_intermediates, b_name, ref, n_roi):
        # reorient() may only be applied once; this records whether it has.
        self.oriented_yet = False
        # Indices into self.ref_sls of the streamlines still selected.
        self.selected_fiber_idxs = np.arange(len(sls), dtype=np.uint32)
        # One flag per selected streamline: True means "reverse it".
        self.sls_flipped = np.zeros(len(sls), dtype=np.bool_)
        # Set by initiate_selection(); -1 until the first selection starts.
        self.start_time = -1
        self.save_intermediates = save_intermediates
        self.b_name = b_name
        self.ref_sls = sls
        self.ref = ref
        self.n_roi = n_roi

    def initiate_selection(self, clean_name):
        """
        Start timing a filtering step named ``clean_name``.

        Returns
        -------
        ndarray of bool
            All-False array with one entry per currently selected
            streamline (presumably filled in by the caller and then passed
            to :meth:`select`).
        """
        self.start_time = time()
        tqdm.write(f"Filtering by {clean_name}")
        return np.zeros(len(self.selected_fiber_idxs), dtype=np.bool_)

    def select(self, idx, clean_name, cut=False):
        """
        Keep only the currently selected streamlines picked by ``idx``.

        ``idx`` (a boolean mask or index array over the current selection)
        is applied to every per-streamline array so they stay aligned. The
        time since :meth:`initiate_selection` is logged, and an
        intermediate tractogram may be saved (see below); ``cut`` is
        forwarded to :meth:`get_selected_sls` for that save.
        """
        self.selected_fiber_idxs = self.selected_fiber_idxs[idx]
        self.sls_flipped = self.sls_flipped[idx]
        if hasattr(self, "roi_closest"):
            self.roi_closest = self.roi_closest[idx]
        if hasattr(self, "roi_dists"):
            self.roi_dists = self.roi_dists[idx]
        time_taken = time() - self.start_time
        tqdm.write(
            f"After filtering by {clean_name} (time: {time_taken}s), "
            f"{len(self)} streamlines remain."
        )
        # Only save intermediates once fewer than 10% of the original
        # streamlines remain; saving earlier is impractical.
        if self.save_intermediates is not None and len(self) < 0.1 * len(
            self.ref_sls
        ):
            save_tractogram(
                StatefulTractogram(
                    self.get_selected_sls(cut=cut), self.ref, Space.RASMM
                ),
                op.join(
                    self.save_intermediates,
                    f"sls_after_{clean_name}_for_{self.b_name}.trk",
                ),
                bbox_valid_check=False,
            )

    def get_selected_sls(self, cut=False, flip=False):
        """
        Return the currently selected streamlines.

        Parameters
        ----------
        cut : bool
            If True, ``roi_closest`` has been set, and there is more than
            one inclusion ROI, cut each streamline between the points
            closest to the first and the last inclusion ROI.
        flip : bool
            If True, reverse the streamlines marked in ``sls_flipped``.
        """
        selected_sls = self.ref_sls[self.selected_fiber_idxs]
        if cut and hasattr(self, "roi_closest") and self.n_roi > 1:
            selected_sls = cut_sls_by_closest(
                selected_sls, self.roi_closest, (0, self.n_roi - 1), in_place=False
            )
        if flip:
            selected_sls = flip_sls(selected_sls, self.sls_flipped, in_place=False)
        return selected_sls

    def reorient(self, idx):
        """
        Mark the selected streamlines at ``idx`` as flipped.

        Raises
        ------
        RuntimeError
            If ``oriented_yet`` is already True.
        """
        if self.oriented_yet:
            raise RuntimeError(
                "Attempted to orient streamlines "
                "that were already oriented. "
                "This is a bug in the implementation of a "
                "bundle recognition procedure. "
            )
        self.oriented_yet = True
        self.sls_flipped[idx] = True

    def __bool__(self):
        """True while at least one streamline remains selected."""
        return len(self) > 0

    def __len__(self):
        """Number of currently selected streamlines."""
        return len(self.selected_fiber_idxs)

    def copy(self, new_name, n_roi):
        """
        Return a copy named ``new_name`` with independent selection state.

        ``ref_sls`` and ``ref`` are shared with this instance; the
        per-streamline arrays are copied so the two can be filtered
        separately. A positive ``n_roi`` replaces the ROI count.

        Raises
        ------
        NotImplementedError
            If ``n_roi > 0`` while this instance already has ROIs.
        """
        new_copy = copy.copy(self)
        new_copy.b_name = new_name
        if n_roi > 0:
            if self.n_roi > 0:
                raise NotImplementedError(
                    "You cannot have includes in the original bundle and"
                    " subbundles; only one or the other."
                )
            new_copy.n_roi = n_roi
        new_copy.selected_fiber_idxs = self.selected_fiber_idxs.copy()
        new_copy.sls_flipped = self.sls_flipped.copy()
        if hasattr(self, "roi_closest"):
            new_copy.roi_closest = self.roi_closest.copy()
        if hasattr(self, "roi_dists"):
            new_copy.roi_dists = self.roi_dists.copy()
        return new_copy

    def export_selected(self, chunk_offset):
        """
        Export the selection state as a dict of copies.

        ``global_idx`` is ``selected_fiber_idxs`` shifted by
        ``chunk_offset`` (as int64). ``roi_closest`` / ``roi_dists`` are
        None when not set. The result can be merged with
        :meth:`from_selected`.
        """
        return {
            "global_idx": (
                self.selected_fiber_idxs.astype(np.int64) + int(chunk_offset)
            ).copy(),
            "sls_flipped": self.sls_flipped.copy(),
            "oriented_yet": self.oriented_yet,
            "roi_closest": (
                self.roi_closest.copy() if hasattr(self, "roi_closest") else None
            ),
            "roi_dists": (
                self.roi_dists.copy() if hasattr(self, "roi_dists") else None
            ),
        }

    @staticmethod
    def _concat_chunk_field(chunks, key, b_name):
        """Concatenate optional array ``key`` across chunks (None if absent in all)."""
        present = [chunk[key] is not None for chunk in chunks]
        if not any(present):
            return None
        if not all(present):
            raise RuntimeError(
                f"Inconsistent {key} across chunks for bundle "
                f"{b_name}: some chunks have it, some don't. This is a"
                " bug in chunked recognition."
            )
        return np.concatenate([chunk[key] for chunk in chunks], axis=0)

    @classmethod
    def from_selected(
        cls,
        survivor_dicts,
        full_streamlines,
        save_intermediates,
        b_name,
        ref,
        n_roi,
    ):
        """
        Merge :meth:`export_selected` dicts from several chunks into one
        instance over ``full_streamlines``.

        Chunks with no survivors are ignored. Survivors are sorted by
        global index, and every per-streamline array is permuted the same
        way. ``oriented_yet`` is True if any chunk was oriented.

        Returns
        -------
        SlsBeingRecognized or None
            None when no chunk has survivors.

        Raises
        ------
        RuntimeError
            If ``roi_closest`` or ``roi_dists`` is present in some
            non-empty chunks but not others.
        """
        non_empty = [d for d in survivor_dicts if d["global_idx"].size > 0]
        if not non_empty:
            return None

        global_idx = np.concatenate([d["global_idx"] for d in non_empty])
        sls_flipped = np.concatenate([d["sls_flipped"] for d in non_empty])
        oriented_yet = any(d["oriented_yet"] for d in non_empty)
        roi_closest = cls._concat_chunk_field(non_empty, "roi_closest", b_name)
        roi_dists = cls._concat_chunk_field(non_empty, "roi_dists", b_name)

        # Chunks may be passed in any order: restore ascending global order.
        if global_idx.size > 1 and not np.all(np.diff(global_idx) > 0):
            order = np.argsort(global_idx, kind="stable")
            global_idx = global_idx[order]
            sls_flipped = sls_flipped[order]
            if roi_closest is not None:
                roi_closest = roi_closest[order]
            if roi_dists is not None:
                roi_dists = roi_dists[order]

        inst = cls(
            full_streamlines,
            save_intermediates,
            b_name,
            ref,
            n_roi,
        )
        inst.selected_fiber_idxs = global_idx.astype(np.uint32, copy=False)
        inst.sls_flipped = sls_flipped
        inst.oriented_yet = oriented_yet
        if roi_closest is not None:
            inst.roi_closest = roi_closest
        if roi_dists is not None:
            inst.roi_dists = roi_dists
        return inst