Source code for AFQ.recognition.criteria

import logging
from time import time

import numpy as np
import nibabel as nib

from scipy.ndimage import distance_transform_edt

import dipy.tracking.streamline as dts
from dipy.utils.parallel import paramap
from dipy.segment.clustering import QuickBundles
from dipy.segment.metricspeed import AveragePointwiseEuclideanMetric
from dipy.segment.featurespeed import ResampleFeature
from dipy.io.streamline import load_tractogram
from dipy.segment.bundles import RecoBundles
from dipy.io.stateful_tractogram import StatefulTractogram, Space

from AFQ.api.bundle_dict import apply_to_roi_dict
import AFQ.recognition.utils as abu
import AFQ.recognition.cleaning as abc
import AFQ.recognition.curvature as abv
import AFQ.recognition.roi as abr
import AFQ.recognition.other_bundles as abo

# Criteria names in the order run_bundle_rec_plan applies them. Each name is
# both a key looked up in a bundle definition and the name of the module-level
# function that implements the criterion.
bundle_criterion_order = [
    "prob_map",
    "cross_midline",
    "start",
    "end",
    "length",
    "primary_axis",
    "include",
    "exclude",
    "curvature",
    "recobundles",
    "qb_thresh",
]

# Keys that may appear in a bundle definition without being criteria
# themselves; run_bundle_rec_plan accepts them during validation.
valid_noncriterion = [
    "space",
    "mahal",
    "primary_axis_percentage",
    "inc_addtol",
    "exc_addtol",
]

# Logger shared by this module, registered under the name 'AFQ'.
logger = logging.getLogger('AFQ')
def prob_map(b_sls, bundle_def, preproc_imap, prob_threshold, **kwargs):
    """
    Keep streamlines whose mean probability-map value exceeds a threshold.

    Parameters
    ----------
    b_sls : object
        Selection state; ``initiate_selection`` and ``select`` are called
        on it.
    bundle_def : dict
        Must contain ``"prob_map"``, an image-like object exposing
        ``get_fdata()``.
    preproc_imap : dict
        Must contain ``"fgarray"``, the streamline array sampled against
        the probability volume.
    prob_threshold : float
        A streamline is kept when its mean sampled value is strictly
        greater than this.
    **kwargs
        Ignored; accepted so every criterion can be called with the same
        inputs.
    """
    step_name = "Prob. Map"
    b_sls.initiate_selection(step_name)

    # The whole fgarray is sampled (not only the current selection) only
    # because this is the first step.
    volume = bundle_def["prob_map"].get_fdata()
    sampled = dts.values_from_volume(
        volume, preproc_imap["fgarray"], np.eye(4))

    # Collapse the last axis so there is one mean value per streamline.
    mean_probability = np.mean(sampled, -1)
    b_sls.select(mean_probability > prob_threshold, step_name)
def cross_midline(b_sls, bundle_def, preproc_imap, **kwargs):
    """
    Keep streamlines whose midline-crossing flag matches the bundle definition.

    Parameters
    ----------
    b_sls : object
        Selection state; its ``selected_fiber_idxs`` indexes
        ``preproc_imap["crosses"]``.
    bundle_def : dict
        ``bundle_def["cross_midline"]`` is truthy when the bundle requires
        crossing streamlines, falsy when it requires non-crossing ones.
    preproc_imap : dict
        Must contain ``"crosses"``, an array of per-streamline flags.
    **kwargs
        Ignored.
    """
    step_name = "Cross Mid."
    b_sls.initiate_selection(step_name)

    crossing = preproc_imap["crosses"][b_sls.selected_fiber_idxs]
    if bundle_def["cross_midline"]:
        keep = crossing
    else:
        # The bundle must NOT cross: keep the complement.
        keep = np.invert(crossing)

    b_sls.select(keep, step_name)
def start(b_sls, bundle_def, preproc_imap, **kwargs):
    """
    Filter streamlines by the bundle's startpoint ROI.

    If the streamlines have not been oriented yet, the opposite end is
    tested as well, the result of that second test is passed to
    ``b_sls.reorient``, and the final mask is the XOR of the two tests.

    Parameters
    ----------
    b_sls : object
        Selection state being filtered.
    bundle_def : dict
        Must contain ``"start"``, the ROI handed to
        ``abr.clean_by_endpoints``.
    preproc_imap : dict
        Must contain ``"dist_to_atlas"``, used as the tolerance.
    **kwargs
        Ignored.
    """
    step_name = "Startpoint"
    keep = b_sls.initiate_selection(step_name)
    roi = bundle_def["start"]
    tolerance = preproc_imap["dist_to_atlas"]

    # The return value is intentionally unused here: the result is expected
    # to be written into `keep` through `accepted_idxs`.
    # NOTE(review): confirm clean_by_endpoints fills accepted_idxs in place.
    abr.clean_by_endpoints(
        b_sls.get_selected_sls(), roi, 0,
        tol=tolerance,
        flip_sls=b_sls.sls_flipped,
        accepted_idxs=keep)

    if not b_sls.oriented_yet:
        # Test the other end (index -1) of each streamline against the ROI.
        keep_other_end = abr.clean_by_endpoints(
            b_sls.get_selected_sls(), roi, -1, tol=tolerance)
        b_sls.reorient(keep_other_end)
        # Accept streamlines that matched at exactly one of the two ends.
        keep = np.logical_xor(keep_other_end, keep)

    b_sls.select(keep, step_name)
def end(b_sls, bundle_def, preproc_imap, **kwargs):
    """
    Filter streamlines by the bundle's endpoint ROI.

    If the streamlines have not been oriented yet, the opposite end is
    tested as well, the result of that second test is passed to
    ``b_sls.reorient``, and the final mask is the XOR of the two tests.

    Parameters
    ----------
    b_sls : object
        Selection state being filtered.
    bundle_def : dict
        Must contain ``"end"``, the ROI handed to
        ``abr.clean_by_endpoints``.
    preproc_imap : dict
        Must contain ``"dist_to_atlas"``, used as the tolerance.
    **kwargs
        Ignored.
    """
    step_name = "endpoint"
    keep = b_sls.initiate_selection(step_name)
    roi = bundle_def["end"]
    tolerance = preproc_imap["dist_to_atlas"]

    # The return value is intentionally unused here: the result is expected
    # to be written into `keep` through `accepted_idxs`.
    # NOTE(review): confirm clean_by_endpoints fills accepted_idxs in place.
    abr.clean_by_endpoints(
        b_sls.get_selected_sls(), roi, -1,
        tol=tolerance,
        flip_sls=b_sls.sls_flipped,
        accepted_idxs=keep)

    if not b_sls.oriented_yet:
        # Test the other end (index 0) of each streamline against the ROI.
        keep_other_end = abr.clean_by_endpoints(
            b_sls.get_selected_sls(), roi, 0, tol=tolerance)
        b_sls.reorient(keep_other_end)
        # Accept streamlines that matched at exactly one of the two ends.
        keep = np.logical_xor(keep_other_end, keep)

    b_sls.select(keep, step_name)
def length(b_sls, bundle_def, preproc_imap, **kwargs):
    """
    Keep streamlines whose path length lies within ``[min_len, max_len]``.

    Parameters
    ----------
    b_sls : object
        Selection state being filtered.
    bundle_def : dict
        ``bundle_def["length"]`` may define ``"min_len"`` (default 0) and
        ``"max_len"`` (default infinity).
    preproc_imap : dict
        Must contain ``"vox_dim"``; both limits are divided by it before
        comparison (presumably converting them into the units of the
        streamline coordinates -- confirm against the caller).
    **kwargs
        Ignored.
    """
    step_name = "length"
    keep = b_sls.initiate_selection(step_name)

    limits = bundle_def["length"]
    vox_dim = preproc_imap["vox_dim"]
    shortest = limits.get("min_len", 0) / vox_dim
    longest = limits.get("max_len", np.inf) / vox_dim

    for pos, sl in enumerate(b_sls.get_selected_sls()):
        # Path length = sum of distances between consecutive points.
        segment_lengths = np.linalg.norm(np.diff(sl, axis=0), axis=1)
        if shortest <= np.sum(segment_lengths) <= longest:
            keep[pos] = 1

    b_sls.select(keep, step_name)
def primary_axis(b_sls, bundle_def, img, **kwargs):
    """
    Filter streamlines with ``abc.clean_by_orientation``.

    Parameters
    ----------
    b_sls : object
        Selection state being filtered.
    bundle_def : dict
        Must contain ``"primary_axis"``; may contain
        ``"primary_axis_percentage"`` (``None`` is passed when absent).
    img : object
        Image whose ``affine`` is forwarded to the cleaning routine.
    **kwargs
        Ignored.
    """
    step_name = "orientation"
    b_sls.initiate_selection(step_name)

    axis_percentage = bundle_def.get("primary_axis_percentage", None)
    keep = abc.clean_by_orientation(
        b_sls.get_selected_sls(),
        bundle_def["primary_axis"],
        img.affine,
        axis_percentage)

    b_sls.select(keep, step_name)
def include(b_sls, bundle_def, preproc_imap, max_includes,
            parallel_segmentation, **kwargs):
    """
    Filter streamlines by the bundle's inclusion ROIs.

    Also records, per streamline, the values reported as "closest" for each
    inclusion ROI (stored on ``b_sls.roi_closest``) and, when there is more
    than one inclusion ROI and the streamlines are not yet oriented, uses
    the order of those values to decide which streamlines to flip.

    Parameters
    ----------
    b_sls : object
        Selection state being filtered. ``roi_closest`` is set on it.
    bundle_def : dict
        Must contain ``"include"`` (sequence of ROIs). May contain
        ``"inc_addtol"``, one additional tolerance per ROI.
    preproc_imap : dict
        Must contain ``"tol"`` and ``"vox_dim"``.
    max_includes : int
        Number of rows allocated in the ``roi_closest`` table.
    parallel_segmentation : dict
        Must contain ``"engine"``; forwarded as keyword arguments to
        ``paramap`` when the engine is not ``"serial"``.
    **kwargs
        Ignored.
    """
    accept_idx = b_sls.initiate_selection("include")
    # Orientation can only be inferred from inclusion ROIs when there are
    # at least two of them and no earlier step has oriented the streamlines.
    flip_using_include = len(bundle_def["include"]) > 1\
        and not b_sls.oriented_yet

    # Per-ROI tolerances are stored squared.
    if f'inc_addtol' in bundle_def:
        include_roi_tols = []
        for inc_tol in bundle_def["inc_addtol"]:
            include_roi_tols.append((
                inc_tol / preproc_imap["vox_dim"] + preproc_imap["tol"])**2)
    else:
        include_roi_tols = [preproc_imap["tol"]**2] * len(
            bundle_def["include"])

    # with parallel segmentation, the first for loop will
    # only collect streamlines and does not need tqdm
    if parallel_segmentation["engine"] != "serial":
        inc_results = paramap(
            abr.check_sl_with_inclusion, b_sls.get_selected_sls(),
            func_args=[
                bundle_def["include"], include_roi_tols],
            **parallel_segmentation)
    else:
        inc_results = abr.check_sls_with_inclusion(
            b_sls.get_selected_sls(),
            bundle_def["include"],
            include_roi_tols)

    # -1 marks "no value recorded" for a given (ROI, streamline) pair.
    # NOTE(review): assumes len(b_sls) equals the number of currently
    # selected streamlines -- confirm against SlsBeingRecognized.
    roi_closest = -np.ones(
        (max_includes, len(b_sls)),
        dtype=np.int32)
    if flip_using_include:
        to_flip = np.ones_like(accept_idx, dtype=np.bool_)
    # Each result unpacks into (accepted flag, per-ROI closest values).
    for sl_idx, inc_result in enumerate(inc_results):
        sl_accepted, sl_closest = inc_result

        if sl_accepted:
            if len(sl_closest) > 1:
                roi_closest[:len(sl_closest), sl_idx] = sl_closest
                # Only accept SLs that, when cut, are meaningful
                # NOTE(review): the nesting here was reconstructed from a
                # whitespace-collapsed listing; under this nesting the
                # `len(sl_closest) < 2` test can never be true -- confirm
                # against the upstream file.
                if (len(sl_closest) < 2) or abs(
                        sl_closest[0] - sl_closest[-1]) > 1:
                    # Flip sl if it is close to second ROI
                    # before it is close to the first ROI
                    if flip_using_include:
                        to_flip[sl_idx] =\
                            sl_closest[0] > sl_closest[-1]
                        if to_flip[sl_idx]:
                            roi_closest[:len(sl_closest), sl_idx] =\
                                np.flip(sl_closest)
                    accept_idx[sl_idx] = 1
            else:
                # Zero or one recorded value: nothing to order, accept.
                accept_idx[sl_idx] = 1

    # Shut down the reusable loky executor after a parallel run;
    # see https://github.com/joblib/joblib/issues/945
    if (
        (parallel_segmentation.get(
            "engine", "joblib") != "serial")
        and (parallel_segmentation.get(
            "backend", "loky") == "loky")):
        from joblib.externals.loky import get_reusable_executor
        get_reusable_executor().shutdown(wait=True)

    # Stored transposed: one row per streamline, one column per ROI.
    b_sls.roi_closest = roi_closest.T
    if flip_using_include:
        b_sls.reorient(to_flip)
    b_sls.select(accept_idx, "include")
def curvature(b_sls, bundle_def, mapping, img, save_intermediates, **kwargs):
    '''
    Keep streamlines whose curve is close to that of a reference streamline
    (comparing orientation and shape, but not scale).

    Parameters
    ----------
    b_sls : object
        Selection state being filtered.
    bundle_def : dict
        ``bundle_def["curvature"]`` holds either ``"sft"`` (an already
        loaded tractogram) or ``"path"`` (loaded with ``load_tractogram``),
        plus optional ``"thresh"`` (degrees, default 10) and ``"cut"``
        (default True).
    mapping, img :
        Forwarded to ``abu.move_streamlines`` to bring the reference into
        subject space.
    save_intermediates :
        Forwarded to ``abu.move_streamlines``.
    **kwargs
        Ignored.
    '''
    keep = b_sls.initiate_selection("curvature")
    spec = bundle_def["curvature"]

    if "sft" in spec:
        reference = spec["sft"]
    else:
        reference = load_tractogram(
            spec["path"], "same", bbox_valid_check=False)

    reference = abu.move_streamlines(
        reference, "subject", mapping, img,
        save_intermediates=save_intermediates)
    reference.to_vox()

    # Only the first streamline of the reference tractogram is used.
    ref_points = reference.streamlines[0]
    n_points = len(ref_points)
    ref_curve = abv.sl_curve(ref_points, n_points)

    # The threshold is given in degrees and compared in radians.
    max_dist = np.radians(spec.get("thresh", 10))
    cut = spec.get("cut", True)

    candidates = b_sls.get_selected_sls(cut=cut, flip=True)
    for pos, sl in enumerate(candidates):
        # Streamlines with fewer than two points are never accepted.
        if len(sl) <= 1:
            continue
        candidate_curve = abv.sl_curve(sl, n_points)
        if abv.sl_curve_dist(candidate_curve, ref_curve) <= max_dist:
            keep[pos] = 1

    b_sls.select(keep, "curvature", cut=cut)
def exclude(b_sls, bundle_def, preproc_imap, **kwargs):
    """
    Filter streamlines by the bundle's exclusion ROIs.

    A streamline is kept when ``abr.check_sl_with_exclusion`` returns a
    truthy value for it.

    Parameters
    ----------
    b_sls : object
        Selection state being filtered.
    bundle_def : dict
        Must contain ``"exclude"`` (sequence of ROIs). May contain
        ``"exc_addtol"``, one additional tolerance per ROI.
    preproc_imap : dict
        Must contain ``"tol"`` and ``"vox_dim"``.
    **kwargs
        Ignored.
    """
    step_name = "exclude"
    keep = b_sls.initiate_selection(step_name)

    base_tol = preproc_imap["tol"]
    rois = bundle_def["exclude"]

    # Per-ROI tolerances are stored squared.
    if "exc_addtol" in bundle_def:
        squared_tols = [
            (extra / preproc_imap["vox_dim"] + base_tol)**2
            for extra in bundle_def["exc_addtol"]]
    else:
        squared_tols = [base_tol**2] * len(rois)

    for pos, sl in enumerate(b_sls.get_selected_sls()):
        if abr.check_sl_with_exclusion(sl, rois, squared_tols):
            keep[pos] = 1

    b_sls.select(keep, step_name)
def recobundles(b_sls, mapping, bundle_def, reg_template, img, refine_reco,
                save_intermediates, rng, rb_recognize_params, **kwargs):
    """
    Select streamlines with RecoBundles against the bundle's model.

    The selected streamlines are moved to template space, recognized
    against ``bundle_def['recobundles']['sl']`` and optionally refined. If
    the streamlines are not yet oriented they are oriented relative to the
    first entry of ``bundle_def['recobundles']['centroid']``.

    Parameters
    ----------
    b_sls : object
        Selection state being filtered.
    mapping, reg_template :
        Forwarded to ``abu.move_streamlines``.
    bundle_def : dict
        Must contain ``"recobundles"`` with keys ``"sl"`` and
        ``"centroid"``.
    img :
        Reference used to build the ``StatefulTractogram`` (voxel space).
    refine_reco : bool
        When truthy, ``RecoBundles.refine`` is run after ``recognize``.
    save_intermediates :
        Forwarded to ``abu.move_streamlines``.
    rng :
        Forwarded to ``RecoBundles``.
    rb_recognize_params : dict
        Keyword arguments for both ``recognize`` and ``refine``.
    **kwargs
        Ignored.
    """
    step_name = "Recobundles"
    b_sls.initiate_selection(step_name)

    subject_sft = StatefulTractogram(
        b_sls.get_selected_sls(), img, Space.VOX)
    template_sls = abu.move_streamlines(
        subject_sft, "template", mapping, reg_template,
        save_intermediates=save_intermediates).streamlines

    recognizer = RecoBundles(template_sls, verbose=True, rng=rng)
    model_sls = bundle_def['recobundles']['sl']
    _, labels = recognizer.recognize(model_sls, **rb_recognize_params)

    if refine_reco:
        _, labels = recognizer.refine(
            model_sls, template_sls[labels], **rb_recognize_params)

    if not b_sls.oriented_yet:
        # Use the first centroid as the orientation standard.
        standard = next(iter(bundle_def['recobundles']['centroid']))
        aligned = abu.orient_by_streamline(template_sls[labels], standard)
        b_sls.reorient(labels[aligned])

    b_sls.select(labels, step_name)
def qb_thresh(b_sls, bundle_def, preproc_imap, clip_edges, **kwargs):
    """
    Keep only the streamlines in the largest QuickBundles cluster.

    Parameters
    ----------
    b_sls : object
        Selection state being filtered.
    bundle_def : dict
        Must contain ``"qb_thresh"``, the clustering threshold; it is
        divided by ``preproc_imap["vox_dim"]`` before use.
    preproc_imap : dict
        Must contain ``"vox_dim"``.
    clip_edges :
        When truthy, or when ``"bundlesection"`` is in ``bundle_def``, the
        streamlines are requested and selected with ``cut`` enabled.
    **kwargs
        Ignored.
    """
    step_name = "qb_thresh"
    b_sls.initiate_selection(step_name)
    cut = clip_edges or ("bundlesection" in bundle_def)

    threshold = bundle_def["qb_thresh"] / preproc_imap["vox_dim"]
    # Streamlines are resampled to 12 points before distances are computed.
    metric = AveragePointwiseEuclideanMetric(ResampleFeature(nb_points=12))
    clusterer = QuickBundles(threshold, metric)

    clusters = clusterer.cluster(
        b_sls.get_selected_sls(cut=cut, flip=True))
    largest = np.argmax(clusters.clusters_sizes())

    b_sls.select(clusters[largest].indices, step_name, cut=cut)
def clean_by_other_bundle(b_sls, bundle_def, img, preproc_imap,
                          other_bundle_name, other_bundle_sls, **kwargs):
    """
    Filter streamlines relative to another, already recognized bundle.

    The rules come from ``bundle_def[other_bundle_name]``. Each rule that
    is present produces a mask, and the masks are combined with a logical
    AND. If no known rule is present, every currently selected streamline
    is kept.

    Parameters
    ----------
    b_sls : object
        Selection state being filtered.
    bundle_def : dict
        ``bundle_def[other_bundle_name]`` may contain ``"node_thresh"``
        (passed to ``abo.clean_by_other_density_map``) and/or ``"core"``
        (lower-cased and passed to ``abo.clean_relative_to_other_core``).
    img :
        Image forwarded to the density-map rule; its ``affine`` is
        forwarded to the core rule.
    preproc_imap : dict
        Must contain ``"fgarray"`` when the ``"core"`` rule is used.
    other_bundle_name : str
        Name of the other bundle; also used as the name of this step.
    other_bundle_sls :
        Streamlines of the other bundle.
    **kwargs
        Ignored.
    """
    # Start from a full-length "keep everything" mask. A bare scalar 1 here
    # would reach b_sls.select as an integer index whenever neither rule
    # below applies.
    cleaned_idx = np.ones_like(
        b_sls.initiate_selection(other_bundle_name), dtype=np.bool_)
    rules = bundle_def[other_bundle_name]

    if 'node_thresh' in rules:
        cleaned_idx_node_thresh = abo.clean_by_other_density_map(
            b_sls.get_selected_sls(),
            other_bundle_sls,
            rules["node_thresh"],
            img)
        cleaned_idx = np.logical_and(cleaned_idx, cleaned_idx_node_thresh)

    if 'core' in rules:
        cleaned_idx_core = abo.clean_relative_to_other_core(
            rules['core'].lower(),
            preproc_imap["fgarray"][b_sls.selected_fiber_idxs],
            # The other bundle is resampled to 20 points per streamline.
            np.array(abu.resample_tg(other_bundle_sls, 20)),
            img.affine)
        cleaned_idx = np.logical_and(cleaned_idx, cleaned_idx_core)

    b_sls.select(cleaned_idx, other_bundle_name)
def mahalanobis(b_sls, bundle_def, clip_edges, cleaning_params, **kwargs):
    """
    Filter streamlines with ``abc.clean_bundle``.

    Parameters
    ----------
    b_sls : object
        Selection state being filtered.
    bundle_def : dict
        May contain ``"mahal"``, a dict of per-bundle overrides for the
        cleaning parameters.
    clip_edges :
        When truthy, or when ``"bundlesection"`` is in ``bundle_def``, the
        streamlines are requested and selected with ``cut`` enabled.
    cleaning_params : dict
        Default keyword arguments for ``abc.clean_bundle``; not modified.
    **kwargs
        Ignored.
    """
    step_name = "Mahalanobis"
    b_sls.initiate_selection(step_name)

    # Global defaults first, per-bundle overrides second; return_idx is
    # always forced on because the indices are what gets selected.
    options = dict(cleaning_params)
    options.update(bundle_def.get("mahal", {}))
    options["return_idx"] = True

    cut = clip_edges or ("bundlesection" in bundle_def)
    _, keep = abc.clean_bundle(
        b_sls.get_selected_sls(cut=cut, flip=True), **options)

    b_sls.select(keep, step_name, cut=cut)
def run_bundle_rec_plan(
        bundle_dict, tg, mapping, img, reg_template, preproc_imap,
        bundle_name, bundle_idx, bundle_to_flip, bundle_roi_closest,
        bundle_decisions,
        **segmentation_params):
    """
    Run every applicable recognition criterion for one bundle.

    The bundle definition is prepared (ROIs transformed, then replaced by
    Euclidean distance maps), validated, and then filtered step by step:
    the criteria in ``bundle_criterion_order``, then cleaning against other
    bundles named in the definition, then ``mahalanobis``. The surviving
    streamlines are written into the output arrays in place.

    Parameters
    ----------
    bundle_dict : object
        Provides ``get_b_info``, ``transform_rois``, ``max_includes`` and
        ``bundle_names``.
    tg : object
        Tractogram; its ``streamlines`` are the candidates.
    mapping, reg_template :
        Forwarded to the criteria that need them.
    img : object
        Reference image; its ``affine`` is used when transforming ROIs.
    preproc_imap : dict
        Precomputed per-tractogram values forwarded to the criteria.
    bundle_name : str
        Name of the bundle to recognize.
    bundle_idx : int
        Column of this bundle in the output arrays.
    bundle_to_flip, bundle_roi_closest, bundle_decisions : ndarray
        Output arrays, modified in place at
        ``[b_sls.selected_fiber_idxs, bundle_idx]``. ``bundle_decisions``
        is also read to find streamlines already assigned to other
        bundles.
    **segmentation_params
        Extra inputs forwarded to every criterion; must contain
        ``"save_intermediates"``.

    Raises
    ------
    ValueError
        If the bundle definition contains a key that is neither a
        criterion, a bundle name, nor a valid non-criterion key; or if
        streamlines remain but could not be oriented.
    """
    # Warp ROIs
    logger.info(f"Preparing ROIs for {bundle_name}")
    start_time = time()
    bundle_def = dict(bundle_dict.get_b_info(bundle_name))
    bundle_def.update(bundle_dict.transform_rois(
        bundle_name,
        mapping,
        img.affine,
        apply_to_recobundles=True))

    # Replace each ROI image by the Euclidean distance of every voxel to
    # the nearest nonzero ROI voxel (zero-valued voxels become foreground
    # for the distance transform).
    apply_to_roi_dict(
        bundle_def,
        lambda roi_img: nib.Nifti1Image(
            distance_transform_edt(
                np.where(roi_img.get_fdata() == 0, 1, 0)),
            roi_img.affine),
        dry_run=False,
        apply_to_recobundles=False,
        apply_to_prob_map=False)
    logger.info(f"Time to prep ROIs: {time()-start_time}s")

    b_sls = abu.SlsBeingRecognized(
        tg.streamlines, logger,
        segmentation_params["save_intermediates"],
        bundle_name,
        img, len(bundle_def.get("include", [])))

    # Every criterion is called with the same keyword inputs and picks the
    # ones it needs.
    inputs = {
        "b_sls": b_sls,
        "preproc_imap": preproc_imap,
        "bundle_def": bundle_def,
        "max_includes": bundle_dict.max_includes,
        "mapping": mapping,
        "img": img,
        "reg_template": reg_template,
    }
    inputs.update(segmentation_params)

    for potential_criterion in bundle_def.keys():
        if (potential_criterion not in bundle_criterion_order) and\
                (potential_criterion not in bundle_dict.bundle_names) and\
                (potential_criterion not in valid_noncriterion):
            raise ValueError((
                "Invalid criterion in bundle definition:\n"
                f"{potential_criterion} in bundle {bundle_name}.\n"
                "Valid criteria are:\n"
                f"{bundle_criterion_order}\n"
                f"{bundle_dict.bundle_names}\n"
                f"{valid_noncriterion}\n"))

    # Criteria are module-level functions named after their key; stop
    # applying them as soon as no streamlines remain (b_sls is falsy).
    for criterion in bundle_criterion_order:
        if b_sls and criterion in bundle_def:
            inputs[criterion] = globals()[criterion](**inputs)

    if b_sls:
        # The loop variable is deliberately not called `bundle_name`, so
        # the error message below still names the bundle being processed.
        for ii, other_bundle_name in enumerate(bundle_dict.bundle_names):
            if other_bundle_name in bundle_def.keys():
                # Streamlines already assigned to the other bundle.
                idx = np.where(bundle_decisions[:, ii])[0]
                clean_by_other_bundle(
                    **inputs,
                    other_bundle_name=other_bundle_name,
                    other_bundle_sls=tg.streamlines[idx])

    if b_sls:
        mahalanobis(**inputs)

    if b_sls and not b_sls.oriented_yet:
        raise ValueError(
            "pyAFQ was unable to consistently orient streamlines "
            f"in bundle {bundle_name} using the provided ROIs. "
            "This can be fixed by including at least 2 "
            "waypoint ROIs, or by using "
            "endpoint ROIs.")

    if b_sls:
        bundle_to_flip[
            b_sls.selected_fiber_idxs,
            bundle_idx] = b_sls.sls_flipped.copy()
        bundle_decisions[
            b_sls.selected_fiber_idxs,
            bundle_idx] = 1
        # roi_closest is only set on b_sls by the "include" criterion.
        if hasattr(b_sls, "roi_closest"):
            bundle_roi_closest[
                b_sls.selected_fiber_idxs,
                bundle_idx,
                :
            ] = b_sls.roi_closest.copy()