Source code for AFQ.recognition.sparse_decisions

import numpy as np
from scipy.sparse import csr_matrix


def compute_sparse_decisions(bundles_being_recognized, n_streamlines):
    """
    Compute a sparse matrix of distances to ROIs for the streamlines that
    are currently being recognized.

    This can be used to weight decisions by distance to ROIs, without having
    to create a dense matrix of distances for all streamlines and all
    bundles.

    Parameters
    ----------
    bundles_being_recognized : dict
        A dictionary of SlsBeingRecognized objects, keyed by bundle name.
        Row ``i`` of the result corresponds to the ``i``-th key. Bundles
        that have a ``roi_dists`` attribute are scored by distance; the
        per-streamline sum of ``roi_dists`` over its last axis must have
        one value per entry of ``selected_fiber_idxs``.
    n_streamlines : int
        The total number of streamlines in the original tractogram.

    Returns
    -------
    csr_matrix
        A sparse matrix of shape (number of bundles being recognized,
        n_streamlines), where the entry (i, j) is a score for assigning
        streamline j to bundle i:

        - bundles with ROIs get scores strictly between 2.0 and 3.0, with
          higher scores for streamlines closer to the ROIs;
        - bundles without ROIs get a score of 1.0;
        - everything else is 0.0 (implicit in sparse matrices).
    """
    epsilon = 1e-6

    # Per-streamline total ROI distance for every bundle that has ROIs,
    # computed once and reused for both the normalization and the weights.
    summed_dists = {
        name: np.sum(bundle.roi_dists, axis=-1)
        for name, bundle in bundles_being_recognized.items()
        if hasattr(bundle, "roi_dists")
    }

    # A single normalization shared by all bundles keeps ROI scores
    # comparable across bundles. Empty bundles are skipped because .max()
    # of an empty array raises ValueError.
    global_max_dist = 0.0
    for dists in summed_dists.values():
        if dists.size > 0:
            global_max_dist = max(global_max_dist, dists.max())
    # +1.0 keeps every normalized distance strictly below 1.0.
    norm_factor = global_max_dist + 1.0

    # Collect whole arrays and concatenate once instead of extending Python
    # lists element by element.
    row_chunks, col_chunks, data_chunks = [], [], []
    for b_idx, (name, bundle) in enumerate(bundles_being_recognized.items()):
        indices = np.asarray(bundle.selected_fiber_idxs)
        if name in summed_dists:
            # Floor at epsilon so every ROI weight is strictly positive;
            # presumably this avoids a zero weight being mistaken for an
            # absent sparse entry -- NOTE(review): confirm original intent.
            dists = np.maximum(summed_dists[name], epsilon)
            bundle_weights = dists / norm_factor
        else:
            bundle_weights = np.full(len(indices), 2.0, dtype=np.float32)
        row_chunks.append(np.full(len(indices), b_idx, dtype=np.intp))
        col_chunks.append(indices)
        data_chunks.append(np.asarray(bundle_weights))

    if data_chunks:
        rows = np.concatenate(row_chunks)
        cols = np.concatenate(col_chunks)
        data = np.concatenate(data_chunks)
    else:
        rows = np.empty(0, dtype=np.intp)
        cols = np.empty(0, dtype=np.intp)
        data = np.empty(0)

    sparse_scores = csr_matrix(
        (data, (rows, cols)),
        shape=(len(bundles_being_recognized), n_streamlines),
    )
    # Final decision: 3.0 - weight.
    # ROI bundles end up strictly between 2.0 and 3.0 (closer is higher);
    # non-ROI bundles end up at exactly 1.0.
    sparse_scores.data = 3.0 - sparse_scores.data
    return sparse_scores
def get_conflict_count(sparse_scores):
    """
    Count how many streamlines are being considered for more than one bundle.

    Parameters
    ----------
    sparse_scores : scipy sparse matrix
        Shape (number of bundles, n_streamlines); a stored entry at (i, j)
        means bundle i is considering streamline j. Any scipy sparse format
        is accepted; it is converted to CSR if needed.

    Returns
    -------
    int
        The number of distinct streamlines (columns) that have stored
        entries in two or more rows. A streamline claimed by three bundles
        is counted once, not twice.
    """
    # In CSR format, .indices holds the column (streamline) of every stored
    # entry; tocsr() is a no-op for CSR input.
    streamline_idxs = sparse_scores.tocsr().indices
    _, n_bundles_per_sl = np.unique(streamline_idxs, return_counts=True)
    return int(np.count_nonzero(n_bundles_per_sl > 1))
def remove_conflicts(sparse_scores, bundles_being_recognized):
    """
    Assign every contested streamline to the single bundle that scores it
    highest, updating ``bundles_being_recognized`` in place.

    Parameters
    ----------
    sparse_scores : scipy sparse matrix
        Shape (number of bundles, n_streamlines), e.g. as produced by
        ``compute_sparse_decisions``. Row ``i`` must correspond to the
        ``i``-th key of ``bundles_being_recognized``, and the stored columns
        of that row must be entries of that bundle's
        ``selected_fiber_idxs``. Higher scores win.
    bundles_being_recognized : dict
        SlsBeingRecognized objects keyed by bundle name. Every bundle gets a
        ``select(..., "conflicts")`` call with the positions (within its own
        ``selected_fiber_idxs``) of the streamlines it won; bundles that win
        no streamlines are removed from the dict.

    Returns
    -------
    None
        ``bundles_being_recognized`` is modified in place.

    Raises
    ------
    NotImplementedError
        If any bundle's ``selected_fiber_idxs`` is not sorted ascending.
        All bundles are checked before any bundle is modified.
    """
    # Snapshot the keys: bundles are popped from the dict inside the loop.
    names = list(bundles_being_recognized.keys())

    # Validate up front so a failure leaves every bundle untouched.
    for b_name in names:
        fiber_idxs = np.asarray(
            bundles_being_recognized[b_name].selected_fiber_idxs
        )
        if np.any(fiber_idxs[:-1] > fiber_idxs[1:]):
            raise NotImplementedError(
                f"Bundle '{b_name}' has unsorted selected_fiber_idxs. "
                "The searchsorted optimization requires sorted indices. "
                "This is a bug in the implementation of the bundle "
                "recognition procedure, please report it to the developers."
            )

    coo = sparse_scores.tocoo()
    # Primary key: streamline (col); secondary key: descending score.
    # lexsort is stable, so score ties go to the entry stored first.
    order = np.lexsort((-coo.data, coo.col))
    sorted_cols = coo.col[order]
    # The first entry of each streamline's run is its best-scoring bundle.
    # Built from np.ones so an empty matrix yields an empty mask.
    is_winner = np.ones(len(sorted_cols), dtype=bool)
    is_winner[1:] = np.diff(sorted_cols) != 0
    winner_rows = coo.row[order][is_winner]
    winner_cols = sorted_cols[is_winner]

    # Group winners by bundle. The stable sort keeps each bundle's winning
    # streamlines in ascending order, so the local positions passed to
    # select() below are sorted too.
    row_sort = np.argsort(winner_rows, kind="stable")
    winner_rows = winner_rows[row_sort]
    winner_cols = winner_cols[row_sort]
    # Winners of bundle i live in winner_cols[split[i]:split[i + 1]].
    split_indices = np.searchsorted(winner_rows, np.arange(len(names) + 1))

    for i, b_name in enumerate(names):
        b_sls = bundles_being_recognized[b_name]
        accept_idx = b_sls.initiate_selection(f"{b_name} conflicts")
        bundle_winners = winner_cols[split_indices[i]:split_indices[i + 1]]
        if len(bundle_winners) > 0:
            # Map global streamline indices to positions within this bundle.
            local_positions = np.searchsorted(
                b_sls.selected_fiber_idxs, bundle_winners
            )
            accept_idx[local_positions] = True
            b_sls.select(local_positions, "conflicts")
        else:
            # Bundle lost every streamline: record the empty selection and
            # drop it from further recognition.
            b_sls.select(accept_idx, "conflicts")
            bundles_being_recognized.pop(b_name)