Source code for AFQ.recognition.cleaning

import logging

import dipy.tracking.streamline as dts
import numpy as np
from dipy.io.stateful_tractogram import StatefulTractogram
from dipy.stats.analysis import assignment_map
from scipy.stats import zscore
from sklearn.ensemble import IsolationForest

import AFQ.recognition.utils as abu
from AFQ._fixes import gaussian_weights

# Package-level "AFQ" logger; the cleaning functions in this module report
# their progress and early exits through it.
logger = logging.getLogger(name="AFQ")
def clean_by_orientation(streamlines, primary_axis, core_only=0.6):
    """
    Retain streamlines that are oriented along the primary axis.

    A streamline passes only if both of these hold:

    1. Strictly more than half of the steps in its core (or in the whole
       streamline, see ``core_only``) have their largest absolute component
       along the primary axis.
    2. The displacement between its first and last node is largest along
       the primary axis.

    Parameters
    ----------
    streamlines : sequence of N by 3 arrays
        Where N is number of nodes in the array, the collection of
        streamlines to filter down to.
    primary_axis : str
        One of "L/R", "P/A" or "I/S"; the axis the streamlines are expected
        to run along.
    core_only : float, optional
        If non-zero, only the core of each streamline is used for the
        step-orientation test (criterion 1). The core is the middle
        ``core_only`` fraction of the nodes, so the default of 0.6 keeps the
        middle 60% and lets streamlines deviate in the starting and ending
        20%. This is useful for allowing more diverse endpoints. If zero,
        every step of the streamline is used.
        Default: 0.6

    Returns
    -------
    cleaned_idx : ndarray of bool
        One entry per streamline, True for streamlines that passed cleaning.

    Raises
    ------
    ValueError
        If ``primary_axis`` is not one of "L/R", "P/A" or "I/S".
    """
    axes_names = ["L/R", "P/A", "I/S"]
    if primary_axis not in axes_names:
        raise ValueError(
            f"Primary axis must be one of {axes_names}, got {primary_axis}"
        )
    primary_axis = abu.axes_dict[primary_axis]

    # core_only == 0 means "no cropping": test every step of the streamline.
    # (Skipping the step test instead would leave every streamline rejected.)
    crop_edge = (1.0 - core_only) / 2 if core_only != 0 else 0.0

    core_accepted_idx = np.zeros(len(streamlines), dtype=bool)
    endpoint_diff = np.zeros((len(streamlines), 3))
    for ii, sl in enumerate(streamlines):
        n_points = len(sl)
        along_diff = np.abs(
            np.diff(
                sl[int(n_points * crop_edge) : int(n_points * (1 - crop_edge))],
                axis=0,
            )
        )
        # The majority of steps must be in the primary axis direction:
        core_accepted_idx[ii] = np.sum(
            np.argmax(along_diff, axis=1) == primary_axis
        ) > (len(along_diff) / 2)
        # Absolute first-to-last displacement, used for criterion 2:
        endpoint_diff[ii, :] = np.abs(sl[0, :] - sl[-1, :])

    end_accepted_idx = np.argmax(endpoint_diff, axis=1) == primary_axis

    return np.logical_and(end_accepted_idx, core_accepted_idx)
def clean_by_orientation_mahalanobis(
    streamlines,
    n_points=100,
    core_only=0,
    min_sl=20,
    distance_threshold=3,
    length_threshold=4,
    clean_rounds=5,
    remove_lengths="long",
):
    """
    Clean streamlines by the Mahalanobis distance of their step vectors.

    Streamlines are resampled to ``n_points`` nodes, and the differences
    between consecutive nodes (the step vectors) are passed to
    ``gaussian_weights(..., return_mahalanobis=True)``. That call also
    receives the node assignment indices from
    ``dipy.stats.analysis.assignment_map`` of the bundle against itself.
    A streamline is removed if any of its steps is at or above
    ``distance_threshold``. It is also removed if its length z-score falls
    outside ``length_threshold`` on the side(s) chosen by
    ``remove_lengths``. This is repeated for up to ``clean_rounds`` rounds.

    Parameters
    ----------
    streamlines : sequence of N by 3 arrays
        The bundle to clean.
    n_points : int, optional
        Number of points to resample streamlines to. Default: 100
    core_only : float, optional
        If non-zero, only the middle ``core_only`` fraction of the resampled
        nodes is used. Default: 0 (use all nodes).
    min_sl : int, optional
        If a round would keep fewer than this many streamlines, the
        ``min_sl`` streamlines with the smallest summed Mahalanobis distance
        are kept instead and cleaning stops. Default: 20
    distance_threshold : float, optional
        Mahalanobis distance threshold, in standard deviations. Default: 3
    length_threshold : float, optional
        Length threshold, in standard deviations of the number of nodes of
        the input streamlines. 0 disables length-based cleaning. Default: 4
    clean_rounds : int, optional
        Maximum number of cleaning rounds. Default: 5
    remove_lengths : {"long", "short", "both"}, optional
        Which length outliers to remove. Default: "long"

    Returns
    -------
    idx : ndarray of int
        Indices into ``streamlines`` of the streamlines that were kept.

    Raises
    ------
    ValueError
        If ``remove_lengths`` is not "long", "short" or "both".
    """
    # Validate up front so a bad value is reported even if no cleaning
    # round ends up needing it.
    if remove_lengths not in ("long", "short", "both"):
        raise ValueError(
            f"Invalid value for remove_lengths: {remove_lengths}. "
            "Expected 'long', 'short', or 'both'."
        )
    if length_threshold == 0:
        length_threshold = np.inf

    fgarray = abu.resample_tg(streamlines, n_points)
    _, assignment_idxs = np.asarray(assignment_map(fgarray, fgarray, n_points))
    assignment_idxs = assignment_idxs.reshape((len(fgarray), n_points))
    fgarray = np.asarray(fgarray)
    if core_only != 0:
        crop_edge = (1.0 - core_only) / 2
        core = slice(int(n_points * crop_edge), int(n_points * (1 - crop_edge)))
        fgarray = fgarray[:, core, :]
        # Crop the assignment map identically so it stays node-aligned with
        # fgarray (and therefore with the step vectors computed below).
        assignment_idxs = assignment_idxs[:, core]

    # Step k runs from node k to node k + 1; drop the first node's
    # assignment so that assignments and steps line up.
    fgarray_dists = fgarray[:, 1:, :] - fgarray[:, :-1, :]
    assignment_idxs = assignment_idxs[:, 1:]

    lengths = np.array([sl.shape[0] for sl in streamlines])
    idx = np.arange(len(fgarray))
    rounds_elapsed = 0
    while rounds_elapsed < clean_rounds:
        m_dist = gaussian_weights(
            fgarray_dists,
            assignment_idxs=assignment_idxs,
            return_mahalanobis=True,
            n_points=None,
            stat=np.mean,
        )
        # zscore is NaN when all lengths are equal (zero variance); treat
        # that as "no length outliers" instead of failing every streamline.
        length_z = np.nan_to_num(zscore(lengths))
        logger.debug(f"Shape of fgarray: {np.asarray(fgarray_dists).shape}")
        logger.debug((f"Maximum m_dist for each fiber: {np.max(m_dist, axis=1)}"))

        idx_dist = np.all(m_dist < distance_threshold, axis=-1)
        if remove_lengths == "long":
            idx_len = length_z < length_threshold
        elif remove_lengths == "short":
            idx_len = length_z > -length_threshold
        else:  # "both"
            idx_len = np.abs(length_z) < length_threshold

        # Nothing left to remove on either criterion (honoring
        # remove_lengths for the length criterion):
        if not np.any(m_dist >= distance_threshold) and np.all(idx_len):
            break

        idx_belong = np.logical_and(idx_dist, idx_len)
        if np.sum(idx_belong) < min_sl:
            # Keep exactly min_sl: the least-distant streamlines overall.
            idx = idx[np.argsort(np.sum(m_dist, axis=-1))[:min_sl].astype(int)]
            idx = np.sort(idx)
            logger.debug(
                (f"At rounds elapsed {rounds_elapsed}, minimum streamlines reached")
            )
            break

        idx = idx[idx_belong]
        fgarray_dists = fgarray_dists[idx_belong]
        lengths = lengths[idx_belong]
        assignment_idxs = assignment_idxs[idx_belong]
        rounds_elapsed += 1
        logger.debug((f"Rounds elapsed: {rounds_elapsed}, num kept: {len(idx)}"))
    logger.debug(f"Kept indices: {idx}")
    return idx
def clean_bundle(
    tg,
    n_points=100,
    clean_rounds=5,
    distance_threshold=4,
    length_threshold=4,
    min_sl=20,
    stat=np.mean,
    core_only=0.6,
    return_idx=False,
    remove_lengths="long",
):
    """
    Clean a segmented fiber group based on the Mahalanobis distance of
    each streamline.

    Parameters
    ----------
    tg : StatefulTractogram class instance or ArraySequence
        The segmented bundle to clean.
    n_points : int, optional
        Number of points to resample streamlines to.
        Default: 100
    clean_rounds : int, optional.
        Maximum number of rounds of cleaning based on the Mahalanobis
        distance from the mean of extracted bundles. Default: 5
    distance_threshold : float, optional.
        Threshold of cleaning based on the Mahalanobis distance (the units
        are standard deviations). A streamline is removed if any of its
        nodes is at or above this distance. Default: 4.
    length_threshold: float, optional
        Threshold for cleaning based on length (in standard deviations of
        the number of nodes of the input streamlines). Which side(s) of the
        length distribution are trimmed is set by ``remove_lengths``.
        0 disables length-based cleaning. Default: 4.
    min_sl : int, optional.
        Number of streamlines in a bundle under which we will not bother
        with cleaning outliers. If a cleaning round would keep fewer than
        this many, the ``min_sl`` streamlines with the smallest summed
        Mahalanobis distance are kept instead and cleaning stops.
        Default: 20.
    stat : callable or str, optional.
        The statistic of each node relative to which the Mahalanobis is
        calculated. A string is looked up as an attribute of numpy.
        Default: `np.mean` (but can also use median, etc.)
    core_only : float, optional
        If non-zero, only the core of the bundle is used for cleaning. The
        core is the middle ``core_only`` fraction of each resampled
        streamline, so the default of 0.6 keeps the middle 60% and lets
        streamlines deviate in the starting and ending 20%. This is useful
        for allowing more diverse endpoints.
        Default: 0.6
    return_idx : bool
        Whether to also return indices in the original streamlines.
        Default: False.
    remove_lengths : str
        Specifies which streamlines to remove based on their length.
        Options are "long" (remove long streamlines), "short" (remove short
        streamlines), or "both" (remove both long and short streamlines).
        Default: "long"

    Returns
    -------
    out : StatefulTractogram or Streamlines
        The streamlines whose nodes all have a Mahalanobis distance smaller
        than ``distance_threshold`` and that pass the length criterion.
        This is a StatefulTractogram if ``tg`` has a ``streamlines``
        attribute, otherwise the indexed ``Streamlines``. If ``tg`` has
        fewer than ``min_sl`` streamlines, ``tg`` itself is returned.
    idx : ndarray of int
        Only when ``return_idx`` is True: indices of the kept streamlines
        in the original ``tg``.

    Raises
    ------
    ValueError
        If ``remove_lengths`` is not "long", "short" or "both".
    """
    # Validate up front so a bad value is reported even if no cleaning
    # round ends up needing it.
    if remove_lengths not in ("long", "short", "both"):
        raise ValueError(
            f"Invalid value for remove_lengths: {remove_lengths}. "
            "Expected 'long', 'short', or 'both'."
        )

    # Convert string to callable, if that's what you got.
    if isinstance(stat, str):
        stat = getattr(np, stat)

    is_sft = hasattr(tg, "streamlines")
    if is_sft:
        streamlines = tg.streamlines
    else:
        streamlines = dts.Streamlines(tg)

    # We don't even bother if there aren't enough streamlines:
    if len(streamlines) < min_sl:
        logger.warning(
            ("Mahalanobis cleaning halted early due to low streamline count")
        )
        if return_idx:
            return tg, np.arange(len(streamlines))
        else:
            return tg

    if length_threshold == 0:
        length_threshold = np.inf

    # Resample once up-front:
    fgarray = np.asarray(abu.resample_tg(streamlines, n_points))
    if core_only != 0:
        crop_edge = (1.0 - core_only) / 2
        # Keep only the middle ``core_only`` fraction of the nodes:
        fgarray = fgarray[
            :, int(n_points * crop_edge) : int(n_points * (1 - crop_edge)), :
        ]

    # Keep this around, so you can use it for indexing at the very end:
    idx = np.arange(len(fgarray))
    # Length of each streamline, measured in nodes of the input streamlines:
    lengths = np.array([sl.shape[0] for sl in streamlines])

    rounds_elapsed = 0
    while rounds_elapsed < clean_rounds:
        # This calculates the Mahalanobis for each streamline/node:
        m_dist = gaussian_weights(
            fgarray, return_mahalanobis=True, n_points=None, stat=stat
        )
        logger.debug(f"Shape of fgarray: {np.asarray(fgarray).shape}")
        logger.debug(f"Shape of m_dist: {m_dist.shape}")
        logger.debug(f"Maximum m_dist: {np.max(m_dist)}")
        logger.debug((f"Maximum m_dist for each fiber: {np.max(m_dist, axis=1)}"))

        # zscore is NaN when all lengths are equal (zero variance); treat
        # that as "no length outliers" instead of failing every streamline.
        length_z = np.nan_to_num(zscore(lengths))
        logger.debug(f"Shape of length_z: {length_z.shape}")
        logger.debug(f"Maximum length_z: {np.max(length_z)}")
        logger.debug((f"length_z for each fiber: {length_z}"))

        # Select the fibers that have Mahalanobis smaller than the
        # threshold for all their nodes:
        idx_dist = np.all(m_dist < distance_threshold, axis=-1)
        if remove_lengths == "long":
            idx_len = length_z < length_threshold
        elif remove_lengths == "short":
            idx_len = length_z > -length_threshold
        else:  # "both"
            idx_len = np.abs(length_z) < length_threshold

        # Nothing left to remove on either criterion (honoring
        # remove_lengths for the length criterion):
        if not np.any(m_dist >= distance_threshold) and np.all(idx_len):
            break

        idx_belong = np.logical_and(idx_dist, idx_len)
        if np.sum(idx_belong) < min_sl:
            # need to sort and return exactly min_sl:
            idx = idx[np.argsort(np.sum(m_dist, axis=-1))[:min_sl].astype(int)]
            idx = np.sort(idx)
            logger.debug(
                (f"At rounds elapsed {rounds_elapsed}, minimum streamlines reached")
            )
            break

        # Update by selection:
        idx = idx[idx_belong]
        fgarray = fgarray[idx_belong]
        lengths = lengths[idx_belong]
        rounds_elapsed += 1
        logger.debug((f"Rounds elapsed: {rounds_elapsed}, num kept: {len(idx)}"))
    logger.debug(f"Kept indices: {idx}")

    # Select based on the variable that was keeping track of things for us:
    if is_sft:
        out = StatefulTractogram(tg.streamlines[idx], tg, tg.space)
    else:
        out = streamlines[idx]
    if return_idx:
        return out, idx
    else:
        return out
def clean_by_isolation_forest(
    tg,
    n_points=100,
    distance_threshold=3,
    length_threshold=4,
    n_rounds=5,
    min_sl=20,
    n_jobs=None,
    random_state=None,
):
    """
    Use Isolation Forest (IF) to clean streamlines.

    Each resampled node is described by 6 features: its position and the
    step vector leading into it. IF assigns every node an anomaly score
    (``IsolationForest.score_samples``; lower means more anomalous). Each
    streamline is scored by its most anomalous node. A streamline is
    removed if that score is more than ``distance_threshold`` s.d. below the
    mean node score. It is also removed if its length is more than
    ``length_threshold`` s.d. above the mean length. This is repeated for up
    to ``n_rounds`` rounds. This is better for cleaning bundles that are not
    tube-like.

    Parameters
    ----------
    tg : StatefulTractogram class instance or ArraySequence
        The segmented bundle to clean.
    n_points : int, optional
        Number of points to resample streamlines to.
        Default: 100
    distance_threshold : int, optional
        Streamlines whose minimum node anomaly score is more than this many
        s.d. below the mean node anomaly score are removed.
        Default: 3
    length_threshold: float, optional
        Threshold for cleaning based on length (in standard deviations of
        the number of nodes of the input streamlines). Length of any
        streamline should not be *more* than this number of stdevs from
        the mean length. 0 disables length-based cleaning.
        Default: 4.
    n_rounds : int, optional.
        Maximum number of rounds of cleaning based on Isolation Forest.
        Default: 5
    min_sl : int, optional.
        Number of streamlines in a bundle under which we will not bother
        with cleaning outliers. If a round would keep fewer than this many,
        the ``min_sl`` streamlines with the highest (least anomalous) score
        are kept instead and cleaning stops.
        Default: 20.
    n_jobs : int, optional
        Number of parallel jobs to use for the Isolation Forest.
        Default: None (single-threaded).
    random_state : int, optional
        Random state for IsolationForest.
        Default: None

    Returns
    -------
    idx : ndarray of int
        Indices of streamlines that passed cleaning. If there are fewer
        than ``min_sl`` streamlines, all indices are returned.
    """
    if hasattr(tg, "streamlines"):
        streamlines = tg.streamlines
    else:
        streamlines = dts.Streamlines(tg)

    # We don't even bother if there aren't enough streamlines:
    if len(streamlines) < min_sl:
        logger.warning(
            ("Isolation Forest cleaning not performed due to low streamline count")
        )
        # Same return type as the main path: integer indices (all kept).
        return np.arange(len(streamlines))

    if length_threshold == 0:
        length_threshold = np.inf

    # Resample once up-front:
    fgarray = np.asarray(abu.resample_tg(streamlines, n_points))
    # Step vector into each node; node 0 reuses node 1's step.
    fgarray_dists = np.zeros_like(fgarray)
    fgarray_dists[:, 1:, :] = fgarray[:, 1:, :] - fgarray[:, :-1, :]
    fgarray_dists[:, 0, :] = fgarray_dists[:, 1, :]
    # Concatenate on the last axis so each node's 6 features (x, y, z,
    # dx, dy, dz) stay together: shape (n_streamlines, n_points, 6).
    X_ = np.concatenate((fgarray, fgarray_dists), axis=-1)

    lengths = np.array([sl.shape[0] for sl in streamlines])
    idx = np.arange(len(fgarray))
    rounds_elapsed = 0
    while rounds_elapsed < n_rounds:
        # This calculates the Isolation Forest outlier score for each node:
        forest = IsolationForest(n_jobs=n_jobs, random_state=random_state)
        node_features = X_.reshape(-1, 6)
        forest.fit(node_features)
        outliers = forest.score_samples(node_features)
        outliers = outliers.reshape((len(idx), n_points))
        # Lower scores are more anomalous; judge each streamline by its
        # most anomalous node.
        sl_outliers = np.min(outliers, axis=1)
        mean_outlier = np.mean(outliers)
        sd_outlier = np.std(outliers)
        logger.debug((f"Mean outlier: {mean_outlier}, SD outlier: {sd_outlier}, "))

        # zscore is NaN when all lengths are equal (zero variance); treat
        # that as "no length outliers" instead of failing every streamline.
        length_z = np.nan_to_num(zscore(lengths))
        logger.debug(f"Shape of length_z: {length_z.shape}")
        logger.debug(f"Maximum length_z: {np.max(length_z)}")
        logger.debug((f"length_z for each fiber: {length_z}"))

        idx_len = length_z < length_threshold
        idx_dist = sl_outliers > mean_outlier - distance_threshold * sd_outlier
        idx_belong = np.logical_and(idx_dist, idx_len)
        # Nothing left to remove:
        if np.all(idx_belong):
            break

        if np.sum(idx_belong) < min_sl:
            # need to sort and return exactly min_sl (least anomalous first):
            idx = idx[np.argsort(-sl_outliers)[:min_sl].astype(int)]
            idx = np.sort(idx)
            logger.debug(
                (f"At rounds elapsed {rounds_elapsed}, minimum streamlines reached")
            )
            break

        # Update by selection (all per-streamline arrays stay aligned):
        idx = idx[idx_belong]
        X_ = X_[idx_belong]
        lengths = lengths[idx_belong]
        rounds_elapsed += 1
        logger.debug((f"Rounds elapsed: {rounds_elapsed}, num kept: {len(idx)}"))
    logger.debug(f"Kept indices: {idx}")
    return idx