Source code for AFQ.recognition.criteria

import logging
from time import time

import dipy.tracking.streamline as dts
import nibabel as nib
import numpy as np
from dipy.io.stateful_tractogram import Space, StatefulTractogram
from dipy.io.streamline import load_tractogram
from dipy.segment.bundles import RecoBundles
from dipy.segment.clustering import QuickBundles
from dipy.segment.featurespeed import ResampleFeature
from dipy.segment.metricspeed import AveragePointwiseEuclideanMetric
from scipy.ndimage import distance_transform_edt
from tqdm import tqdm
from trx.io import load as load_trx

import AFQ.recognition.cleaning as abc
import AFQ.recognition.curvature as abv
import AFQ.recognition.other_bundles as abo
import AFQ.recognition.roi as abr
import AFQ.recognition.utils as abu
from AFQ.api.bundle_dict import apply_to_roi_dict
from AFQ.recognition.clustering import subcluster_by_atlas
from AFQ.recognition.preprocess import PreprocPlan
from AFQ.recognition.utils import resample_tg
from AFQ.utils.streamlines import move_streamlines

# Criteria that are purely per-streamline and safe to run on a chunk
# without needing to see the rest of the tractogram. These run in the
# chunk-local phase.
# The chunk-local phase walks this list front to back.
criteria_order_chunk_local = [
    "length",
    "endpoint_dists",
    "cross_midline",
    "start",
    "end",
    "prob_map",
    "primary_axis",
    "include",
    "exclude",
    "curvature",
]
# RecoBundles needs the whole candidate pool for a bundle, so it runs
# in the global phase even though it's nominally a "pre-other-bundles"
# criterion.
criteria_order_pre_other_bundles = [*criteria_order_chunk_local, "recobundles"]

# Criteria applied, in this order, once cleaning against other
# recognized bundles is done.
criteria_order_post_other_bundles = [
    "orient_mahal",
    "isolation_forest",
    "qb_thresh",
]

# Keys accepted in a bundle definition that are options rather than
# criteria; criterion validation lets them through.
valid_noncriterion = [
    "space",
    "mahal",
    "inc_addtol",
    "exc_addtol",
    "exact_endpoints",
    "ORG_spectral_subbundles",
    "cluster_IDs",
    "startpoint_location",
    "endpoint_location",
    "primary_axis_core_only",
]

logger = logging.getLogger("AFQ")
def prob_map(b_sls, bundle_def, preproc_plan, prob_threshold, img, **kwargs):
    """
    Keep streamlines whose mean probability-map value exceeds a threshold.

    Parameters
    ----------
    b_sls : object
        Selection tracker exposing ``initiate_selection``,
        ``selected_fiber_idxs`` and ``select``.
    bundle_def : dict
        Bundle definition; ``bundle_def["prob_map"]`` must provide
        ``get_fdata()``.
    preproc_plan : object
        Provides ``fgarray``, which is indexed by the selected fiber
        indices.
    prob_threshold : float
        A streamline passes when its mean sampled value is strictly
        greater than this.
    img : object
        Reference image; only ``img.affine`` is used.
    **kwargs
        Ignored, so all criteria can be called with the same inputs.
    """
    label = "Prob. Map"
    b_sls.initiate_selection(label)

    volume = bundle_def["prob_map"].get_fdata()
    candidates = preproc_plan.fgarray[b_sls.selected_fiber_idxs]
    sampled = dts.values_from_volume(volume, candidates, img.affine)

    # Collapse the last axis to get one mean value per streamline.
    mean_probability = np.mean(sampled, -1)
    b_sls.select(mean_probability > prob_threshold, label)
def cross_midline(b_sls, bundle_def, preproc_plan, **kwargs):
    """
    Select streamlines by whether they cross the midline.

    When ``bundle_def["cross_midline"]`` is truthy the streamlines flagged
    in ``preproc_plan.crosses`` are kept; otherwise the complement is kept.

    Parameters
    ----------
    b_sls : object
        Selection tracker exposing ``initiate_selection``,
        ``selected_fiber_idxs`` and ``select``.
    bundle_def : dict
        Bundle definition containing the ``"cross_midline"`` flag.
    preproc_plan : object
        Provides ``crosses``, indexed by the selected fiber indices.
    **kwargs
        Ignored.
    """
    label = "Cross Mid."
    b_sls.initiate_selection(label)

    crossing = preproc_plan.crosses[b_sls.selected_fiber_idxs]
    keep = crossing if bundle_def["cross_midline"] else np.invert(crossing)
    b_sls.select(keep, label)
def start(b_sls, bundle_def, preproc_plan, **kwargs):
    """
    Select streamlines by the start ROI, orienting them if still needed.

    The first node (index 0) of every selected streamline is tested
    against ``bundle_def["start"]``. If the bundle has no orientation yet,
    the last node (index -1) is tested as well: streamlines matching at
    either end are kept, and the "matched at the far end" mask is handed
    to ``b_sls.reorient``. Streamlines matching at both ends are resolved
    with ``abu.manual_orient_sls``.

    Parameters
    ----------
    b_sls : object
        Selection tracker (``initiate_selection``, ``selected_fiber_idxs``,
        ``sls_flipped``, ``oriented_yet``, ``reorient``, ``select``).
    bundle_def : dict
        Bundle definition with a ``"start"`` ROI and an optional
        ``"exact_endpoints"`` flag.
    preproc_plan : object
        Provides ``fgarray``, indexed by the selected fiber indices.
    **kwargs
        ``dist_to_atlas`` is read as the tolerance unless
        ``exact_endpoints`` is set, in which case the tolerance is 0.
    """
    label = "Startpoint"
    b_sls.initiate_selection(label)

    tol = 0 if bundle_def.get("exact_endpoints", False) else kwargs["dist_to_atlas"]

    hits_at_first = abr.clean_by_endpoints(
        preproc_plan.fgarray[b_sls.selected_fiber_idxs],
        bundle_def["start"],
        0,
        tol=tol,
        flip_sls=b_sls.sls_flipped,
    )

    if b_sls.oriented_yet:
        # Orientation is already fixed, so only the first node counts.
        b_sls.select(hits_at_first, label)
        return

    hits_at_last = abr.clean_by_endpoints(
        preproc_plan.fgarray[b_sls.selected_fiber_idxs],
        bundle_def["start"],
        -1,
        tol=tol,
    )

    # Take both masks before hits_at_last is edited in place below.
    hits_either_end = np.logical_or(hits_at_last, hits_at_first)
    hits_both_ends = np.logical_and(hits_at_first, hits_at_last)

    # Streamlines touching the ROI at both ends are ambiguous; let the
    # manual orientation helper decide whether they get flipped.
    hits_at_last[hits_both_ends] = abu.manual_orient_sls(
        preproc_plan.fgarray[b_sls.selected_fiber_idxs][hits_both_ends]
    )
    b_sls.reorient(hits_at_last)

    b_sls.select(hits_either_end, label)
def end(b_sls, bundle_def, preproc_plan, **kwargs):
    """
    Select streamlines by the end ROI, orienting them if still needed.

    The last node (index -1) of every selected streamline is tested
    against ``bundle_def["end"]``. If the bundle has no orientation yet,
    the first node (index 0) is tested as well: streamlines matching at
    either end are kept, and the "matched at the near end" mask is handed
    to ``b_sls.reorient``. Streamlines matching at both ends are resolved
    with ``abu.manual_orient_sls``.

    Parameters
    ----------
    b_sls : object
        Selection tracker (``initiate_selection``, ``selected_fiber_idxs``,
        ``sls_flipped``, ``oriented_yet``, ``reorient``, ``select``).
    bundle_def : dict
        Bundle definition with an ``"end"`` ROI and an optional
        ``"exact_endpoints"`` flag.
    preproc_plan : object
        Provides ``fgarray``, indexed by the selected fiber indices.
    **kwargs
        ``dist_to_atlas`` is read as the tolerance unless
        ``exact_endpoints`` is set, in which case the tolerance is 0.
    """
    label = "endpoint"
    b_sls.initiate_selection(label)

    tol = 0 if bundle_def.get("exact_endpoints", False) else kwargs["dist_to_atlas"]

    hits_at_last = abr.clean_by_endpoints(
        preproc_plan.fgarray[b_sls.selected_fiber_idxs],
        bundle_def["end"],
        -1,
        tol=tol,
        flip_sls=b_sls.sls_flipped,
    )

    if b_sls.oriented_yet:
        # Orientation is already fixed, so only the last node counts.
        b_sls.select(hits_at_last, label)
        return

    hits_at_first = abr.clean_by_endpoints(
        preproc_plan.fgarray[b_sls.selected_fiber_idxs],
        bundle_def["end"],
        0,
        tol=tol,
    )

    # Take both masks before hits_at_first is edited in place below.
    hits_either_end = np.logical_or(hits_at_first, hits_at_last)
    hits_both_ends = np.logical_and(hits_at_last, hits_at_first)

    # Streamlines touching the ROI at both ends are ambiguous; let the
    # manual orientation helper decide whether they get flipped.
    hits_at_first[hits_both_ends] = abu.manual_orient_sls(
        preproc_plan.fgarray[b_sls.selected_fiber_idxs][hits_both_ends]
    )
    b_sls.reorient(hits_at_first)

    b_sls.select(hits_either_end, label)
def length(b_sls, bundle_def, preproc_plan, **kwargs):
    """
    Keep streamlines whose length lies within ``[min_len, max_len]``.

    Parameters
    ----------
    b_sls : object
        Selection tracker exposing ``initiate_selection``,
        ``selected_fiber_idxs`` and ``select``.
    bundle_def : dict
        Bundle definition; ``bundle_def["length"]`` may hold ``"min_len"``
        (default 0) and ``"max_len"`` (default infinity). Both bounds are
        inclusive.
    preproc_plan : object
        Provides ``lengths``, one value per streamline known to the plan.
    **kwargs
        Ignored.
    """
    label = "length"
    b_sls.initiate_selection(label)

    limits = bundle_def["length"]
    min_len = limits.get("min_len", 0)
    max_len = limits.get("max_len", np.inf)

    # Restrict to the currently selected streamlines so the mask lines up
    # with the selection even if another criterion has filtered first.
    sl_lens = preproc_plan.lengths[b_sls.selected_fiber_idxs]

    b_sls.select((sl_lens >= min_len) & (sl_lens <= max_len), label)
def endpoint_dists(b_sls, bundle_def, preproc_plan, **kwargs):
    """
    Keep streamlines whose endpoint distance lies within given bounds.

    Parameters
    ----------
    b_sls : object
        Selection tracker exposing ``initiate_selection``,
        ``selected_fiber_idxs`` and ``select``.
    bundle_def : dict
        Bundle definition; ``bundle_def["endpoint_dists"]`` may hold
        ``"min_dist"`` (default 0) and ``"max_dist"`` (default infinity).
        Both bounds are inclusive.
    preproc_plan : object
        Provides ``endpoint_dists``, indexed by the selected fiber indices.
    **kwargs
        Ignored.
    """
    label = "endpoint_dists"
    b_sls.initiate_selection(label)

    bounds = bundle_def["endpoint_dists"]
    lower = bounds.get("min_dist", 0)
    upper = bounds.get("max_dist", np.inf)

    dists = preproc_plan.endpoint_dists[b_sls.selected_fiber_idxs]
    within_bounds = (dists >= lower) & (dists <= upper)
    b_sls.select(within_bounds, label)
def primary_axis(b_sls, bundle_def, **kwargs):
    """
    Select streamlines by primary axis via ``abc.clean_by_orientation``.

    Parameters
    ----------
    b_sls : object
        Selection tracker exposing ``initiate_selection``,
        ``get_selected_sls`` and ``select``.
    bundle_def : dict
        Bundle definition with a ``"primary_axis"`` entry and an optional
        ``"primary_axis_core_only"`` entry (0.6 when absent); both are
        forwarded unchanged to the cleaning helper.
    **kwargs
        Ignored.
    """
    label = "orientation"
    b_sls.initiate_selection(label)

    core_only = bundle_def.get("primary_axis_core_only", 0.6)
    keep = abc.clean_by_orientation(
        b_sls.get_selected_sls(), bundle_def["primary_axis"], core_only
    )
    b_sls.select(keep, label)
def include(b_sls, bundle_def, **kwargs):
    """
    Select streamlines that satisfy every inclusion (waypoint) ROI.

    Besides filtering, this records for each streamline the per-ROI
    closest-node indices and distances on ``b_sls`` (``roi_closest`` and
    ``roi_dists``, one row per streamline). When the bundle has more than
    one inclusion ROI and is not oriented yet, the ROI order is also used
    to orient the streamlines.

    Parameters
    ----------
    b_sls : object
        Selection tracker (``initiate_selection``, ``get_selected_sls``,
        ``oriented_yet``, ``reorient``, ``select``; supports ``len``).
    bundle_def : dict
        Bundle definition with an ``"include"`` sequence of ROIs and an
        optional ``"inc_addtol"`` sequence of per-ROI extra tolerances.
    **kwargs
        ``tol`` is always read; ``vox_dim`` is read when ``inc_addtol``
        is present (each extra tolerance is divided by it).
    """
    accept_idx = b_sls.initiate_selection("include")
    n_inc = len(bundle_def["include"])
    flip_using_include = n_inc > 1 and not b_sls.oriented_yet

    # Tolerances are handed over squared.
    if "inc_addtol" in bundle_def:
        include_roi_tols = [
            (inc_tol / kwargs["vox_dim"] + kwargs["tol"]) ** 2
            for inc_tol in bundle_def["inc_addtol"]
        ]
    else:
        include_roi_tols = [kwargs["tol"] ** 2] * n_inc

    inc_results = abr.check_sls_with_inclusion(
        b_sls.get_selected_sls(), bundle_def["include"], include_roi_tols
    )

    roi_closest = np.zeros((n_inc, len(b_sls)), dtype=np.int32)
    roi_dists = np.zeros((n_inc, len(b_sls)), dtype=np.float32)
    if flip_using_include:
        to_flip = np.ones_like(accept_idx, dtype=np.bool_)

    for sl_idx, (sl_accepted, sl_closest, sl_dists) in enumerate(inc_results):
        if not sl_accepted:
            continue

        roi_closest[:, sl_idx] = sl_closest
        roi_dists[:, sl_idx] = sl_dists

        if len(sl_closest) > 1:
            # With several ROIs, require the first and last closest nodes
            # to be more than one node apart.
            if abs(sl_closest[0] - sl_closest[-1]) > 1:
                if flip_using_include:
                    # Flip when the streamline reaches the last ROI
                    # before the first one.
                    to_flip[sl_idx] = sl_closest[0] > sl_closest[-1]
                    if to_flip[sl_idx]:
                        roi_closest[:, sl_idx] = np.flip(sl_closest)
                        roi_dists[:, sl_idx] = np.flip(sl_dists)
                accept_idx[sl_idx] = 1
        else:
            accept_idx[sl_idx] = 1

    b_sls.roi_closest = roi_closest.T
    b_sls.roi_dists = roi_dists.T
    if flip_using_include:
        b_sls.reorient(to_flip)
    b_sls.select(accept_idx, "include")
def curvature(b_sls, bundle_def, mapping, img, save_intermediates, **kwargs):
    """
    Filter streamlines by how well they match a reference curve in
    orientation and shape, but not scale.

    The reference is ``bundle_def["curvature"]["sft"]`` when present,
    otherwise the tractogram loaded from ``bundle_def["curvature"]["path"]``.
    It is moved to subject space and only its first streamline is used.

    Parameters
    ----------
    b_sls : object
        Selection tracker (``initiate_selection``, ``get_selected_sls``,
        ``select``).
    bundle_def : dict
        Bundle definition; ``bundle_def["curvature"]`` may also hold
        ``"thresh"`` (degrees, default 10) and ``"cut"`` (default True).
    mapping, img, save_intermediates
        Forwarded to ``move_streamlines``.
    **kwargs
        Ignored.
    """
    accept_idx = b_sls.initiate_selection("curvature")
    curve_def = bundle_def["curvature"]

    if "sft" in curve_def:
        reference = curve_def["sft"]
    else:
        reference = load_tractogram(
            curve_def["path"], "same", bbox_valid_check=False
        )

    moved_reference = move_streamlines(
        reference, "subject", mapping, img, save_intermediates=save_intermediates
    )
    ref_sl = moved_reference.streamlines[0]
    n_ref_points = len(ref_sl)
    ref_curve = abv.sl_curve(ref_sl, n_ref_points)

    # The threshold is configured in degrees; convert it to radians.
    max_curve_dist = np.radians(curve_def.get("thresh", 10))
    cut = curve_def.get("cut", True)

    for idx, sl in enumerate(b_sls.get_selected_sls(cut=cut, flip=True)):
        if len(sl) <= 1:
            # Streamlines with fewer than two nodes are never accepted.
            continue
        candidate_curve = abv.sl_curve(sl, n_ref_points)
        if abv.sl_curve_dist(candidate_curve, ref_curve) <= max_curve_dist:
            accept_idx[idx] = 1

    b_sls.select(accept_idx, "curvature", cut=cut)
def exclude(b_sls, bundle_def, **kwargs):
    """
    Select streamlines that pass the exclusion-ROI check.

    Parameters
    ----------
    b_sls : object
        Selection tracker (``initiate_selection``, ``get_selected_sls``,
        ``select``).
    bundle_def : dict
        Bundle definition with an ``"exclude"`` sequence of ROIs and an
        optional ``"exc_addtol"`` sequence of per-ROI extra tolerances.
    **kwargs
        ``tol`` is read; ``vox_dim`` is read when ``exc_addtol`` is
        present (each extra tolerance is divided by it).
    """
    accept_idx = b_sls.initiate_selection("exclude")

    # Tolerances are handed over squared.
    if "exc_addtol" in bundle_def:
        exclude_roi_tols = [
            (exc_tol / kwargs["vox_dim"] + kwargs["tol"]) ** 2
            for exc_tol in bundle_def["exc_addtol"]
        ]
    else:
        exclude_roi_tols = [kwargs["tol"] ** 2] * len(bundle_def["exclude"])

    for sl_idx, sl in enumerate(b_sls.get_selected_sls()):
        passes = abr.check_sl_with_exclusion(
            sl, bundle_def["exclude"], exclude_roi_tols
        )
        if passes:
            accept_idx[sl_idx] = 1

    b_sls.select(accept_idx, "exclude")
def recobundles(
    b_sls,
    mapping,
    bundle_def,
    reg_template,
    img,
    refine_reco,
    save_intermediates,
    rng,
    rb_recognize_params,
    **kwargs,
):
    """
    Select streamlines recognized by RecoBundles against a model bundle.

    The selected streamlines are moved to template space, matched against
    ``bundle_def["recobundles"]["sl"]`` and optionally refined. If the
    bundle is not oriented yet, the recognized streamlines are oriented
    against the first entry of ``bundle_def["recobundles"]["centroid"]``.

    Parameters
    ----------
    b_sls : object
        Selection tracker (``initiate_selection``, ``get_selected_sls``,
        ``oriented_yet``, ``reorient``, ``select``).
    mapping, reg_template, img, save_intermediates
        Used to move the streamlines into template space.
    bundle_def : dict
        Bundle definition with a ``"recobundles"`` entry holding ``"sl"``
        and ``"centroid"``.
    refine_reco : bool
        Whether to run ``RecoBundles.refine`` after recognition.
    rng
        Random state passed to ``RecoBundles``.
    rb_recognize_params : dict
        Keyword arguments for ``recognize`` (and ``refine``).
    **kwargs
        Ignored.
    """
    b_sls.initiate_selection("Recobundles")

    moved_sl = move_streamlines(
        StatefulTractogram(b_sls.get_selected_sls(), img, Space.RASMM),
        "template",
        mapping,
        reg_template,
        to_space=Space.RASMM,
        save_intermediates=save_intermediates,
    ).streamlines

    rb = RecoBundles(moved_sl, verbose=True, rng=rng)
    _, rec_labels = rb.recognize(
        bundle_def["recobundles"]["sl"], **rb_recognize_params
    )

    if refine_reco:
        # The resampled copy is only needed for refinement, so it is
        # built only on this path.
        moved_sl_resampled = abu.resample_tg(moved_sl, 100)
        _, rec_labels = rb.refine(
            bundle_def["recobundles"]["sl"],
            moved_sl_resampled[rec_labels],
            **rb_recognize_params,
        )

    # rec_labels is used as an index array, so "anything recognized" is a
    # length check; summing the indices would miss a lone hit at index 0.
    if not b_sls.oriented_yet and len(rec_labels) > 0:
        standard_sl = next(iter(bundle_def["recobundles"]["centroid"]))
        oriented_idx = abu.orient_by_streamline(moved_sl[rec_labels], standard_sl)
        b_sls.reorient(rec_labels[oriented_idx])

    rec_labels = sorted(rec_labels)
    b_sls.select(rec_labels, "Recobundles")
def qb_thresh(b_sls, bundle_def, clip_edges, **kwargs):
    """
    Keep only the streamlines of the largest QuickBundles cluster.

    Streamlines are clustered with ``bundle_def["qb_thresh"]`` as the
    QuickBundles threshold, using an average pointwise Euclidean metric
    over 12 resampled points.

    Parameters
    ----------
    b_sls : object
        Selection tracker (``initiate_selection``, ``get_selected_sls``,
        ``select``).
    bundle_def : dict
        Bundle definition holding the ``"qb_thresh"`` value.
    clip_edges : bool
        When truthy, or when ``"bundlesection"`` is in the definition,
        streamlines are requested and selected with ``cut`` enabled.
    **kwargs
        Ignored.
    """
    label = "qb_thresh"
    b_sls.initiate_selection(label)
    cut = clip_edges or ("bundlesection" in bundle_def)

    metric = AveragePointwiseEuclideanMetric(ResampleFeature(nb_points=12))
    clusterer = QuickBundles(bundle_def["qb_thresh"], metric)
    cluster_map = clusterer.cluster(b_sls.get_selected_sls(cut=cut, flip=True))

    biggest = np.argmax(cluster_map.clusters_sizes())
    b_sls.select(cluster_map[biggest].indices, label, cut=cut)
def clean_by_other_bundle(
    b_sls, bundle_def, img, other_bundle_name, other_bundle_sls, **kwargs
):
    """
    Clean the current selection relative to an already recognized bundle.

    Up to three checks, configured under ``bundle_def[other_bundle_name]``,
    are combined with a logical AND:

    - ``"overlap"``: ``abo.clean_by_overlap`` with ``remove=False``
    - ``"node_thresh"``: ``abo.clean_by_overlap`` with ``remove=True``
    - ``"core"``: ``abo.clean_relative_to_other_core`` on both bundles
      resampled to 100 points

    If none of the three keys is configured, every streamline is kept.

    Parameters
    ----------
    b_sls : object
        Selection tracker (``initiate_selection``, ``get_selected_sls``,
        ``select``).
    bundle_def : dict
        Bundle definition containing an entry for ``other_bundle_name``.
    img : object
        Reference image passed to the overlap check.
    other_bundle_name : str
        Name of the other bundle; also used as the selection label.
    other_bundle_sls
        Streamlines of the other bundle.
    **kwargs
        ``vox_dim`` is read when ``"core"`` is configured with a numeric
        ``"consideration"`` value.
    """
    # Start from "keep everything" with one entry per selected streamline,
    # so select() always receives a mask even if no check is configured.
    cleaned_idx = np.ones_like(
        b_sls.initiate_selection(other_bundle_name), dtype=np.bool_
    )

    other_def = bundle_def[other_bundle_name]
    flipped_sls = b_sls.get_selected_sls(flip=True)

    if "overlap" in other_def:
        cleaned_idx_overlap = abo.clean_by_overlap(
            flipped_sls,
            other_bundle_sls,
            other_def["overlap"],
            img,
            remove=False,
            project=other_def.get("project", None),
        )
        cleaned_idx = np.logical_and(cleaned_idx, cleaned_idx_overlap)

    if "node_thresh" in other_def:
        cleaned_idx_node_thresh = abo.clean_by_overlap(
            flipped_sls,
            other_bundle_sls,
            other_def["node_thresh"],
            img,
            remove=True,
            project=other_def.get("project", None),
        )
        cleaned_idx = np.logical_and(cleaned_idx, cleaned_idx_node_thresh)

    if "core" in other_def:
        consideration = other_def.get("consideration", 10.0)
        if isinstance(consideration, (int, float)):
            # Numeric values are divided by the voxel dimension;
            # non-numeric values are passed through unchanged.
            consideration = float(consideration) / kwargs["vox_dim"]
        cleaned_idx_core = abo.clean_relative_to_other_core(
            other_def["core"].lower(),
            np.array(abu.resample_tg(flipped_sls, 100)),
            np.array(abu.resample_tg(other_bundle_sls, 100)),
            consideration=consideration,
        )
        cleaned_idx = np.logical_and(cleaned_idx, cleaned_idx_core)

    b_sls.select(cleaned_idx, other_bundle_name)
def orient_mahal(b_sls, bundle_def, **kwargs):
    """
    Clean the selection with ``abc.clean_by_orientation_mahalanobis``.

    Parameters
    ----------
    b_sls : object
        Selection tracker (``initiate_selection``, ``get_selected_sls``,
        ``select``).
    bundle_def : dict
        Bundle definition; ``bundle_def["orient_mahal"]`` (if any) is
        unpacked as keyword arguments for the cleaning helper.
    **kwargs
        Ignored.
    """
    label = "orient_mahal"
    b_sls.initiate_selection(label)

    cleaning_kwargs = bundle_def.get("orient_mahal", {})
    keep = abc.clean_by_orientation_mahalanobis(
        b_sls.get_selected_sls(), **cleaning_kwargs
    )
    b_sls.select(keep, label)
def isolation_forest(b_sls, bundle_def, rng, **kwargs):
    """
    Clean the selection with ``abc.clean_by_isolation_forest``.

    Parameters
    ----------
    b_sls : object
        Selection tracker (``initiate_selection``, ``get_selected_sls``,
        ``select``).
    bundle_def : dict
        Bundle definition; ``bundle_def["isolation_forest"]`` may hold
        ``"distance_threshold"`` (default 3) and ``"n_rounds"``
        (default 5).
    rng
        Passed to the cleaning helper as ``random_state``.
    **kwargs
        Ignored.
    """
    label = "isolation_forest"
    b_sls.initiate_selection(label)

    options = bundle_def["isolation_forest"]
    keep = abc.clean_by_isolation_forest(
        b_sls.get_selected_sls(),
        distance_threshold=options.get("distance_threshold", 3),
        n_rounds=options.get("n_rounds", 5),
        random_state=rng,
    )
    b_sls.select(keep, label)
def mahalanobis(b_sls, bundle_def, clip_edges, cleaning_params, **kwargs):
    """
    Clean the selection with ``abc.clean_bundle``.

    Parameters
    ----------
    b_sls : object
        Selection tracker (``initiate_selection``, ``get_selected_sls``,
        ``select``).
    bundle_def : dict
        Bundle definition; ``bundle_def["mahal"]`` (if any) overrides
        entries of ``cleaning_params``.
    clip_edges : bool
        When truthy, or when ``"bundlesection"`` is in the definition,
        streamlines are requested and selected with ``cut`` enabled.
    cleaning_params : dict
        Default keyword arguments for ``abc.clean_bundle``.
    **kwargs
        Ignored.
    """
    label = "Mahalanobis"
    b_sls.initiate_selection(label)

    # Per-bundle settings win over the defaults; return_idx is always
    # forced on because the kept indices are what gets selected.
    clean_params = {
        **cleaning_params,
        **bundle_def.get("mahal", {}),
        "return_idx": True,
    }
    cut = clip_edges or ("bundlesection" in bundle_def)

    candidates = b_sls.get_selected_sls(cut=cut, flip=True)
    _, cleaned_idx = abc.clean_bundle(candidates, **clean_params)
    b_sls.select(cleaned_idx, label, cut=cut)
def _prepare_bundle_def(bundle_dict, bundle_name, mapping, img):
    """
    Warp ROIs and apply distance-transform conversion.

    Parameters
    ----------
    bundle_dict : object
        Provides ``get_b_info`` and ``transform_rois``.
    bundle_name : str
        Name of the bundle to prepare.
    mapping, img
        Forwarded to ``bundle_dict.transform_rois``; ``img.affine`` is
        also compared against each ROI's affine.

    Returns
    -------
    dict
        Copy of the bundle info updated with the transformed ROIs, whose
        ROI images have been replaced by Euclidean distance transforms.
    """
    tqdm.write(f"Preparing ROIs for {bundle_name}")
    t_begin = time()

    bundle_def = dict(bundle_dict.get_b_info(bundle_name))
    transformed = bundle_dict.transform_rois(
        bundle_name, mapping, img, apply_to_recobundles=True
    )
    bundle_def.update(transformed)

    def _warn_on_affine_mismatch(roi):
        # Warn only; the ROI itself is left untouched.
        if np.allclose(img.affine, roi.affine):
            return
        logger.warning(
            "Resampling set to False in case where affines "
            "do not match. This is likely due to subject space ROIs"
            " not being in the right space. This found for bundle "
            f"{bundle_name}"
        )

    def _to_distance_image(roi_img):
        # Distance of every voxel to the nearest non-zero ROI voxel
        # (zero inside the ROI).
        outside_roi = np.where(roi_img.get_fdata() == 0, 1, 0)
        return nib.Nifti1Image(distance_transform_edt(outside_roi), roi_img.affine)

    apply_to_roi_dict(
        bundle_def, _warn_on_affine_mismatch, dry_run=True, apply_to_prob_map=True
    )
    apply_to_roi_dict(
        bundle_def,
        _to_distance_image,
        dry_run=False,
        apply_to_recobundles=False,
        apply_to_prob_map=False,
    )

    tqdm.write(f"Time to prep ROIs: {time() - t_begin}s")
    return bundle_def
def _validate_criteria(bundle_def, bundle_name, bundle_dict, recognized_bundles_dict):
    """
    Check that every key of a bundle definition is a known criterion.

    A key is accepted when it is a pre- or post-other-bundles criterion,
    the name of an already recognized bundle, or a valid non-criterion
    option.

    Parameters
    ----------
    bundle_def : dict
        Bundle definition whose keys are checked.
    bundle_name : str
        Name of the bundle, used in error messages.
    bundle_dict : object
        Provides ``bundle_names``.
    recognized_bundles_dict : dict
        Bundles recognized so far, keyed by name.

    Raises
    ------
    ValueError
        If a key names a bundle that has not been recognized, or is not a
        valid criterion at all.
    """
    for potential_criterion in bundle_def.keys():
        is_known = (
            potential_criterion in criteria_order_post_other_bundles
            or potential_criterion in criteria_order_pre_other_bundles
            or potential_criterion in recognized_bundles_dict.keys()
            or potential_criterion in valid_noncriterion
        )
        if is_known:
            continue

        if potential_criterion in bundle_dict.bundle_names:
            raise ValueError(
                f"Bundle {potential_criterion} is being used as a criterion in "
                f"the definition of bundle {bundle_name}, however this bundle "
                "was not found. This could be because of insufficient streamlines"
            )

        raise ValueError(
            "Invalid criterion in bundle definition:\n"
            f"{potential_criterion} in bundle {bundle_name}.\n"
            "Valid criteria are:\n"
            f"{criteria_order_pre_other_bundles}\n"
            f"{criteria_order_post_other_bundles}\n"
            f"{recognized_bundles_dict.keys()}\n"
            f"{valid_noncriterion}\n"
        )
def _run_chunk_local(
    bundle_def,
    chunk_streamlines,
    bundle_name,
    img,
    preproc_plan,
    save_intermediates,
    vox_dim,
    tol,
    dist_to_atlas,
    **segmentation_params,
):
    """
    Apply the chunk-local criteria of one bundle to one chunk.

    Criteria are looked up by name among this module's globals and run in
    the order of ``criteria_order_chunk_local``; a criterion is skipped
    when it is absent from ``bundle_def`` or the selection is empty.

    Parameters
    ----------
    bundle_def : dict
        Prepared bundle definition.
    chunk_streamlines
        Streamlines of the current chunk.
    bundle_name : str
        Name of the bundle being recognized.
    img
        Reference image.
    preproc_plan
        Per-chunk preprocessing object handed to the criteria.
    save_intermediates
        Forwarded to the selection tracker and the criteria.
    vox_dim, tol, dist_to_atlas
        Scalars handed to the criteria.
    **segmentation_params
        Extra keyword inputs made available to every criterion.

    Returns
    -------
    abu.SlsBeingRecognized
        The selection tracker after all applicable criteria have run.
    """
    n_include_rois = len(bundle_def.get("include", []))
    b_sls = abu.SlsBeingRecognized(
        chunk_streamlines, save_intermediates, bundle_name, img, n_include_rois
    )

    inputs = dict(
        b_sls=b_sls,
        preproc_plan=preproc_plan,
        bundle_def=bundle_def,
        img=img,
        save_intermediates=save_intermediates,
        vox_dim=vox_dim,
        tol=tol,
        dist_to_atlas=dist_to_atlas,
    )
    inputs.update(segmentation_params)

    for criterion in criteria_order_chunk_local:
        if not b_sls or criterion not in bundle_def:
            continue
        criterion_fn = globals()[criterion]
        # The result is stored under the criterion's name, so it reaches
        # later criteria as an extra keyword argument.
        inputs[criterion] = criterion_fn(**inputs)

    return b_sls
def _run_global_phase(
    bundle_def,
    bundle_name,
    b_sls,
    fgarray_for_candidates,
    candidate_global_idx,
    mapping,
    img,
    reg_template,
    preproc_scalars,
    recognized_bundles_dict,
    vox_dim,
    tol,
    dist_to_atlas,
    is_subbundle=False,
    **segmentation_params,
):
    """
    Run the criteria that need a bundle's whole candidate pool.

    In order: RecoBundles, cleaning against already recognized bundles,
    the post-other-bundles criteria, Mahalanobis cleaning, the
    wrong-side-of-midline cleanup, and finally either sub-bundle
    recognition (``"ORG_spectral_subbundles"``) or registration of the
    result in ``recognized_bundles_dict``. Nothing is registered if the
    selection becomes empty along the way.

    Parameters
    ----------
    bundle_def : dict
        Prepared bundle definition.
    bundle_name : str
        Name of the bundle being recognized.
    b_sls : object
        Selection tracker holding the merged chunk-local survivors.
    fgarray_for_candidates : array or None
        Resampled candidates in candidate-local order; only needed for
        the wrong-side-of-midline cleanup.
    candidate_global_idx : array or None
        Global indices of those candidates; passed to ``np.searchsorted``
        and therefore expected to be sorted.
    mapping, img, reg_template
        Registration inputs handed to the criteria.
    preproc_scalars
        Handed to the criteria under the name ``preproc_plan``.
    recognized_bundles_dict : dict
        Bundles recognized so far; updated in place.
    vox_dim, tol, dist_to_atlas
        Scalars handed to the criteria.
    is_subbundle : bool, optional
        True when called recursively for a spectral sub-bundle.
    **segmentation_params
        Extra keyword inputs made available to every criterion.

    Raises
    ------
    ValueError
        If the surviving streamlines were never oriented, or if
        sub-bundles are nested.
    """
    if not b_sls:
        return

    inputs = {
        "b_sls": b_sls,
        "preproc_plan": preproc_scalars,
        "bundle_def": bundle_def,
        "mapping": mapping,
        "img": img,
        "reg_template": reg_template,
        "vox_dim": vox_dim,
        "tol": tol,
        "dist_to_atlas": dist_to_atlas,
    }
    inputs.update(segmentation_params)

    if "recobundles" in bundle_def:
        recobundles(**inputs)

    for o_bundle_name in recognized_bundles_dict.keys():
        # Checked on every pass: an earlier cleanup may already have
        # removed every remaining candidate.
        if not b_sls:
            break
        if o_bundle_name in bundle_def.keys():
            clean_by_other_bundle(
                **inputs,
                other_bundle_name=o_bundle_name,
                other_bundle_sls=recognized_bundles_dict[
                    o_bundle_name
                ].get_selected_sls(flip=True),
            )

    for criterion in criteria_order_post_other_bundles:
        if b_sls and criterion in bundle_def:
            inputs[criterion] = globals()[criterion](**inputs)

    if b_sls:
        # Mahalanobis cleaning runs when requested explicitly, or as the
        # fallback when no alternative cleaning/sub-bundling is configured.
        if "mahal" in bundle_def or (
            "isolation_forest" not in bundle_def
            and "orient_mahal" not in bundle_def
            and "ORG_spectral_subbundles" not in bundle_def
        ):
            mahalanobis(**inputs)

    # Wrong-side-of-midline cleanup. fgarray_for_candidates is in
    # candidate-local order; b_sls.selected_fiber_idxs is in global
    # order. searchsorted translates between them.
    if (
        b_sls
        and not is_subbundle
        and "cross_midline" in bundle_def
        and not bundle_def["cross_midline"]
        and fgarray_for_candidates is not None
        and candidate_global_idx is not None
    ):
        pos = np.searchsorted(candidate_global_idx, b_sls.selected_fiber_idxs)
        b_sls.initiate_selection("Wrong side of mid.")
        # Sign of the mean first coordinate gives each streamline's side;
        # keep only those on the side most streamlines are on.
        avg_side = np.sign(np.mean(fgarray_for_candidates[pos, :, 0], axis=1))
        majority_side = np.sign(np.sum(avg_side))
        b_sls.select(avg_side == majority_side, "Wrong side of mid.")

    if b_sls and not b_sls.oriented_yet:
        raise ValueError(
            "pyAFQ was unable to consistently orient streamlines "
            f"in bundle {bundle_name} using the provided ROIs. "
            "This can be fixed by including at least 2 "
            "waypoint ROIs, or by using endpoint ROIs."
        )

    if not b_sls:
        return

    if "ORG_spectral_subbundles" in bundle_def:
        if is_subbundle:
            raise ValueError("Nested ORG_spectral_subbundles are not supported.")

        subdict = bundle_def["ORG_spectral_subbundles"]
        b_sls.initiate_selection(
            f"ORG spectral clustering, {len(subdict.bundle_names)} "
            "subbundles being recognized"
        )
        sub_sft = StatefulTractogram(
            b_sls.get_selected_sls(flip=True), img, Space.RASMM
        )
        cluster_labels = subcluster_by_atlas(
            sub_sft, mapping, img, subdict.all_cluster_IDs, n_points=40
        )

        for sub_b_name in subdict.bundle_names:
            c_ids = subdict._dict[sub_b_name]["cluster_IDs"]
            n_roi = len(subdict._dict[sub_b_name].get("include", []))
            cluster_b_sls = b_sls.copy(sub_b_name, n_roi)

            # A streamline belongs to the sub-bundle if its cluster label
            # matches any of the sub-bundle's cluster IDs.
            selected = np.zeros(len(b_sls), dtype=bool)
            for c_id in c_ids:
                selected = np.logical_or(selected, cluster_labels == c_id)
            cluster_b_sls.select(selected, f"Clusters {c_ids}")

            sub_bundle_def = _prepare_bundle_def(subdict, sub_b_name, mapping, img)
            _validate_criteria(
                sub_bundle_def, sub_b_name, subdict, recognized_bundles_dict
            )
            _run_global_phase(
                sub_bundle_def,
                sub_b_name,
                cluster_b_sls,
                None,
                None,
                mapping,
                img,
                reg_template,
                preproc_scalars,
                recognized_bundles_dict,
                vox_dim,
                tol,
                dist_to_atlas,
                is_subbundle=True,
                **segmentation_params,
            )
    else:
        b_sls.bundle_def = bundle_def
        recognized_bundles_dict[bundle_name] = b_sls
def recognize_bundles(
    tg,
    bundle_dict,
    mapping,
    img,
    reg_template,
    chunk_size,
    dist_to_waypoint,
    dist_to_atlas,
    save_intermediates,
    **segmentation_params,
):
    """
    Recognize every bundle of ``bundle_dict`` in a tractogram.

    Recognition runs in two phases. The chunk-local phase walks the
    tractogram in chunks of ``chunk_size`` streamlines and applies the
    per-streamline criteria of every bundle to each chunk. The global
    phase then merges each bundle's survivors and applies the criteria
    that need the whole candidate pool.

    Parameters
    ----------
    tg : str or tractogram
        Either a path, loaded with ``load_trx``, or an already loaded
        object exposing ``streamlines`` and supporting ``len``. When a
        path is given, the file is reopened for every chunk and closed
        again once the chunk has been copied out.
    bundle_dict : object
        Bundle dictionary providing ``bundle_names`` and the bundle
        definitions.
    mapping, img, reg_template
        Registration inputs handed to the criteria.
    chunk_size : int
        Number of streamlines per chunk.
    dist_to_waypoint, dist_to_atlas
        Tolerances, converted by ``abu.tolerance_mm_to_vox``.
    save_intermediates
        Forwarded to the selection trackers and the criteria.
    **segmentation_params
        Extra keyword inputs made available to every criterion.

    Returns
    -------
    recognized_bundles_dict : dict
        Recognized bundles keyed by bundle name.
    n_streamlines : int
        Number of streamlines in the input tractogram.
    """
    if isinstance(tg, str):
        tg_path = tg
        tg = load_trx(tg_path, img)
    else:
        tg_path = None

    n_streamlines = len(tg)
    recognized_bundles_dict = {}

    tqdm.write(
        f"Recognizing bundles over {n_streamlines} streamlines "
        f"in chunks of {chunk_size}"
    )

    tol, dist_to_atlas, vox_dim = abu.tolerance_mm_to_vox(
        img, dist_to_waypoint, dist_to_atlas
    )
    preproc_scalars = {
        "vox_dim": vox_dim,
        "tol": tol,
        "dist_to_atlas": dist_to_atlas,
    }

    # Prepare every bundle definition once, up front.
    bundle_defs = {}
    survivor_dicts = {}
    for bundle_name in bundle_dict.bundle_names:
        bundle_defs[bundle_name] = _prepare_bundle_def(
            bundle_dict, bundle_name, mapping, img
        )
        survivor_dicts[bundle_name] = []

    total_chunks = (n_streamlines + chunk_size - 1) // chunk_size
    for chunk_start in tqdm(
        range(0, n_streamlines, chunk_size),
        total=total_chunks,
        desc="Batched Portion of Recognition",
    ):
        chunk_end = min(chunk_start + chunk_size, n_streamlines)
        tqdm.write(
            f"Processing chunk {chunk_start}:{chunk_end} of {n_streamlines} "
            f"({(chunk_end / n_streamlines) * 100:.2f}%)"
        )

        if tg_path is not None and tg is None:
            tg = load_trx(tg_path, img)
        chunk_streamlines = tg.streamlines[chunk_start:chunk_end].copy()
        if tg_path is not None:
            # The chunk has been copied, so the file can be released
            # while the chunk is processed.
            tg.close()
            tg = None

        chunk_preproc = PreprocPlan(chunk_streamlines)

        for bundle_name in bundle_dict.bundle_names:
            tqdm.write(f"Running chunk-local phase for bundle {bundle_name}")
            chunk_b_sls = _run_chunk_local(
                bundle_defs[bundle_name],
                chunk_streamlines,
                bundle_name,
                img,
                chunk_preproc,
                save_intermediates,
                mapping=mapping,
                reg_template=reg_template,
                vox_dim=vox_dim,
                tol=tol,
                dist_to_atlas=dist_to_atlas,
                **segmentation_params,
            )
            survivor_dicts[bundle_name].append(chunk_b_sls.export_selected(chunk_start))
            del chunk_b_sls

        del chunk_preproc, chunk_streamlines

    # Reopen the file for the global phase only if the chunk loop closed
    # it. With an empty tractogram the loop never runs and the handle
    # opened above is still live; reloading would drop it unclosed.
    if tg_path is not None and tg is None:
        tg = load_trx(tg_path, img)

    for bundle_name in bundle_dict.bundle_names:
        tqdm.write(f"Running global phase for bundle {bundle_name}")
        bundle_def = bundle_defs[bundle_name]

        merged = abu.SlsBeingRecognized.from_selected(
            survivor_dicts[bundle_name],
            tg.streamlines,
            save_intermediates,
            bundle_name,
            img,
            len(bundle_def.get("include", [])),
        )
        survivor_dicts[bundle_name] = None  # free per-chunk dicts

        if merged is None:
            tqdm.write(
                f"Bundle {bundle_name}: 0 candidates after chunk-local filtering"
            )
            continue

        _validate_criteria(
            bundle_def, bundle_name, bundle_dict, recognized_bundles_dict
        )

        tqdm.write(
            f"Bundle {bundle_name}: {len(merged)} candidates after "
            "chunk-local filtering"
        )

        # The resampled candidates are only needed for the
        # wrong-side-of-midline cleanup of non-crossing bundles.
        need_fgarray = "cross_midline" in bundle_def and not bundle_def["cross_midline"]
        if need_fgarray:
            candidate_global_idx = np.array(merged.selected_fiber_idxs, dtype=np.int64)
            cand_streamlines = [tg.streamlines[int(i)] for i in candidate_global_idx]
            start_time = time()
            fgarray_for_candidates = np.asarray(
                resample_tg(cand_streamlines, 20), dtype=np.float32
            )
            tqdm.write(f"Resampling took {time() - start_time:.2f} seconds")
            del cand_streamlines
        else:
            candidate_global_idx = None
            fgarray_for_candidates = None

        _run_global_phase(
            bundle_def,
            bundle_name,
            merged,
            fgarray_for_candidates,
            candidate_global_idx,
            mapping,
            img,
            reg_template,
            preproc_scalars,
            recognized_bundles_dict,
            vox_dim,
            tol,
            dist_to_atlas,
            save_intermediates=save_intermediates,
            **segmentation_params,
        )

    return recognized_bundles_dict, n_streamlines