import numpy as np
import os.path as op
import os
import logging
import dipy.tracking.streamlinespeed as dps
from dipy.io.stateful_tractogram import StatefulTractogram, Space
import AFQ.recognition.utils as abu
from AFQ.api.bundle_dict import BundleDict
from AFQ.recognition.criteria import run_bundle_rec_plan
from AFQ.recognition.preprocess import get_preproc_plan
# Module-level logger shared by the recognition code.
logger = logging.getLogger('AFQ')
def recognize(
        tg,
        img,
        mapping,
        bundle_dict,
        reg_template,
        nb_points=False,
        nb_streamlines=False,
        clip_edges=False,
        parallel_segmentation={"engine": "serial"},
        rb_recognize_params=dict(
            model_clust_thr=1.25,
            reduction_thr=25,
            pruning_thr=12),
        refine_reco=False,
        prob_threshold=0,
        dist_to_waypoint=None,
        rng=None,
        return_idx=False,
        filter_by_endpoints=True,
        dist_to_atlas=4,
        save_intermediates=None,
        cleaning_params={}):
    """
    Segment streamlines into bundles.

    Parameters
    ----------
    tg : str, StatefulTractogram
        Tractogram to segment.
    img : str, nib.Nifti1Image
        Image for reference.
    mapping : MappingDefinition
        Mapping from subject to template.
    bundle_dict : dict or AFQ.api.BundleDict
        Dictionary of bundles to segment.
    reg_template : str, nib.Nifti1Image
        Template image for registration.
    nb_points : int, boolean
        Resample streamlines to nb_points number of points.
        If False, no resampling is done. Default: False
    nb_streamlines : int, boolean
        Subsample streamlines to nb_streamlines.
        If False, no subsampling is done. Default: False
    clip_edges : bool
        Whether to clip the streamlines to be only in between the ROIs.
        Default: False
    parallel_segmentation : dict
        How to parallelize segmentation across processes when performing
        waypoint ROI segmentation. Set to {"engine": "serial"} to not
        perform parallelization. Some engines may cause errors, depending
        on the system. See ``dipy.utils.parallel.paramap`` for
        details.
        Default: {"engine": "serial"}
    rb_recognize_params : dict
        RecoBundles parameters for the recognize function.
        Default: dict(model_clust_thr=1.25, reduction_thr=25, pruning_thr=12)
    refine_reco : bool
        Whether to refine the RecoBundles segmentation.
        Default: False
    prob_threshold : float.
        Using AFQ Algorithm.
        Initial cleaning of fiber groups is done using probability maps
        from [Hua2008]_. Here, we choose an average probability that
        needs to be exceeded for an individual streamline to be retained.
        Default: 0.
    dist_to_waypoint : float.
        The distance that a streamline node has to be from the waypoint
        ROI in order to be included or excluded.
        If set to None (default), will be calculated as the
        center-to-corner distance of the voxel in the diffusion data.
        If a bundle has inc_addtol or exc_addtol in its bundle_dict, that
        tolerance will be added to this distance.
        For example, if you wanted to increase tolerance for the right
        arcuate waypoint ROIs by 3 each, you could make the following
        modification to your bundle_dict:
        bundle_dict["Right Arcuate"]["inc_addtol"] = [3, 3]
        Additional tolerances can also be negative.
        Default: None.
    rng : RandomState or int
        If None, creates RandomState.
        If int, creates RandomState with seed rng.
        Used in RecoBundles Algorithm.
        Default: None.
    return_idx : bool
        Whether to return the indices in the original streamlines as part
        of the output of segmentation.
        Default: False.
    filter_by_endpoints: bool
        Whether to filter the bundles based on their endpoints.
        Default: True.
    dist_to_atlas : float
        If filter_by_endpoints is True, this is the required distance
        from the endpoints to the atlas ROIs.
        Default: 4
    save_intermediates : str, optional
        The full path to a folder into which intermediate products
        are saved. Default: None, means no saving of intermediates.
    cleaning_params : dict, optional
        Cleaning params to pass to seg.clean_bundle. This will
        override the default parameters of that method. However, this
        can be overridden by setting the cleaning parameters in the
        bundle_dict. Default: {}.

    Returns
    -------
    fiber_groups : dict
        Maps each bundle name (or bundle-section name, for bundles that
        define "bundlesection") to a StatefulTractogram in voxel space.
        If return_idx is True, each value is instead a dict with keys
        'sl' (the StatefulTractogram) and 'idx' (indices of the selected
        streamlines in the segmented tractogram). Bundles with no
        streamlines get an empty tractogram.
    meta : dict
        Filled in by ``_add_bundle_to_meta`` for every bundle (or bundle
        section) that received at least one streamline.

    References
    ----------
    .. [Hua2008] Hua K, Zhang J, Wakana S, Jiang H, Li X, et al. (2008)
        Tract probability maps in stereotaxic spaces: analyses of white
        matter anatomy and tract-specific quantification. Neuroimage 39:
        336-347
    .. [Yeatman2012] Yeatman, Jason D., Robert F. Dougherty, Nathaniel J.
        Myall, Brian A. Wandell, and Heidi M. Feldman. 2012. "Tract Profiles
        of White Matter Properties: Automating Fiber-Tract Quantification"
        PloS One 7 (11): e49790.
    .. [Garyfallidis2018] Garyfallidis et al. Recognition of white matter
        bundles using local and global streamline-based registration and
        clustering, Neuroimage, 2017.
    """
    if rng is None:
        rng = np.random.RandomState()
    elif isinstance(rng, int):
        rng = np.random.RandomState(rng)

    if save_intermediates is not None:
        # exist_ok=True makes a separate existence check unnecessary.
        os.makedirs(save_intermediates, exist_ok=True)

    logger.info("Preprocessing Streamlines")
    tg = abu.read_tg(tg, nb_streamlines)

    # If resampling over-write the sft:
    if nb_points:
        tg = StatefulTractogram(
            dps.set_number_of_points(tg.streamlines, nb_points),
            tg, tg.space)

    if not isinstance(bundle_dict, BundleDict):
        bundle_dict = BundleDict(bundle_dict)

    tg.to_vox()

    n_streamlines = len(tg)
    # Per-streamline, per-bundle bookkeeping arrays that
    # run_bundle_rec_plan receives for every bundle.
    bundle_decisions = np.zeros(
        (n_streamlines, len(bundle_dict)),
        dtype=np.bool_)
    bundle_to_flip = np.zeros(
        (n_streamlines, len(bundle_dict)),
        dtype=np.bool_)
    # NOTE(review): negating unsigned ones wraps around to the maximum
    # uint32 value; presumably that serves as a "no ROI node" sentinel --
    # confirm against the code that fills/reads this array.
    bundle_roi_closest = -np.ones(
        (
            n_streamlines,
            len(bundle_dict),
            bundle_dict.max_includes),
        dtype=np.uint32)

    fiber_groups = {}
    meta = {}

    preproc_imap = get_preproc_plan(img, tg, dist_to_waypoint, dist_to_atlas)

    logger.info("Assigning Streamlines to Bundles")
    for bundle_idx, bundle_name in enumerate(
            bundle_dict.bundle_names):
        logger.info(f"Finding Streamlines for {bundle_name}")
        run_bundle_rec_plan(
            bundle_dict, tg, mapping, img, reg_template, preproc_imap,
            bundle_name, bundle_idx, bundle_to_flip, bundle_roi_closest,
            bundle_decisions,
            clip_edges=clip_edges,
            parallel_segmentation=parallel_segmentation,
            rb_recognize_params=rb_recognize_params,
            prob_threshold=prob_threshold,
            refine_reco=refine_reco,
            rng=rng,
            return_idx=return_idx,
            filter_by_endpoints=filter_by_endpoints,
            save_intermediates=save_intermediates,
            cleaning_params=cleaning_params)

    if save_intermediates is not None:
        os.makedirs(save_intermediates, exist_ok=True)
        bc_path = op.join(save_intermediates,
                          "sls_bundle_decisions.npy")
        np.save(bc_path, bundle_decisions)

    # A streamline claimed by more than one bundle is a conflict.
    conflicts = np.sum(np.sum(bundle_decisions, axis=1) > 1)
    if conflicts > 0:
        logger.warning((
            "Conflicts in bundle assignment detected. "
            f"{conflicts} conflicts detected in total out of "
            f"{n_streamlines} total streamlines. "
            "Defaulting to whichever bundle appears first "
            "in the bundle_dict."))

    # Append an always-set column so that argmax (which returns the first
    # maximum) yields the first claiming bundle, or len(bundle_dict) for
    # streamlines that no bundle claimed.
    bundle_decisions = np.concatenate((
        bundle_decisions, np.ones((n_streamlines, 1))), axis=1)
    bundle_decisions = np.argmax(bundle_decisions, -1)

    # We do another round through, so that we can:
    # 1. Clip streamlines according to ROIs
    # 2. Re-orient streamlines
    logger.info("Re-orienting streamlines to consistent directions")
    for bundle_idx, bundle in enumerate(bundle_dict.bundle_names):
        logger.info(f"Processing {bundle}")

        select_idx = np.where(bundle_decisions == bundle_idx)[0]
        b_info = bundle_dict.get_b_info(bundle)

        if len(select_idx) == 0:
            # There's nothing here, set and move to the next bundle:
            if "bundlesection" in b_info:
                for sb_name in b_info["bundlesection"]:
                    _return_empty(sb_name, return_idx, fiber_groups, img)
            else:
                _return_empty(bundle, return_idx, fiber_groups, img)
            continue

        # Use a list here, because ArraySequence doesn't support item
        # assignment:
        select_sl = list(tg.streamlines[select_idx])
        roi_closest = bundle_roi_closest[select_idx, bundle_idx, :]
        n_includes = len(b_info.get("include", []))
        if clip_edges and n_includes > 1:
            logger.info("Clipping Streamlines by ROI")
            select_sl = abu.cut_sls_by_closest(
                select_sl, roi_closest,
                (0, n_includes - 1), in_place=True)

        to_flip = bundle_to_flip[select_idx, bundle_idx]
        # Definition of the bundle currently being processed (not the
        # leftover loop variable from the assignment loop above).
        b_def = dict(b_info)
        if "bundlesection" in b_def:
            for sb_name, sb_include_cuts in b_def["bundlesection"].items():
                bundlesection_select_sl = abu.cut_sls_by_closest(
                    select_sl, roi_closest,
                    sb_include_cuts, in_place=False)
                _add_bundle_to_fiber_group(
                    sb_name, bundlesection_select_sl, select_idx,
                    to_flip, return_idx, fiber_groups, img)
                _add_bundle_to_meta(sb_name, b_def, meta)
        else:
            _add_bundle_to_fiber_group(
                bundle, select_sl, select_idx, to_flip,
                return_idx, fiber_groups, img)
            _add_bundle_to_meta(bundle, b_def, meta)

    return fiber_groups, meta
# Helper functions for formatting the results
def _return_empty(bundle_name, return_idx, fiber_groups, img):
    """
    Store an empty result for a bundle that received no streamlines.

    Parameters
    ----------
    bundle_name : hashable
        Key under which the empty result is stored in ``fiber_groups``.
    return_idx : bool
        If True, store a dict with an empty tractogram under 'sl' and an
        empty index array under 'idx'. Otherwise store the empty
        tractogram directly.
    fiber_groups : dict
        Output mapping; modified in place.
    img : object
        Reference passed on to ``StatefulTractogram``.

    Returns
    -------
    None
    """
    empty_sft = StatefulTractogram([], img, Space.VOX)
    if return_idx:
        fiber_groups[bundle_name] = {
            'sl': empty_sft,
            # Integer dtype (np.array([]) would default to float64) so
            # the empty result can still be used as an index array and
            # does not upcast integer indices when concatenated.
            'idx': np.array([], dtype=np.intp),
        }
    else:
        fiber_groups[bundle_name] = empty_sft
def _add_bundle_to_fiber_group(b_name, sl, idx, to_flip,
                               return_idx, fiber_groups, img):
    """
    Helper function to add a bundle to a fiber group.

    Parameters
    ----------
    b_name : hashable
        Key under which the bundle is stored in ``fiber_groups``.
    sl : sequence of streamlines
        Streamlines selected for this bundle.
    idx : array-like
        Indices of the selected streamlines; only stored when
        ``return_idx`` is True.
    to_flip : array-like
        Passed to ``abu.flip_sls`` to choose which streamlines to flip.
    return_idx : bool
        If True, store a dict with keys 'sl' and 'idx'; otherwise store
        the tractogram directly.
    fiber_groups : dict
        Output mapping; modified in place.
    img : object
        Reference passed on to ``StatefulTractogram``.

    Returns
    -------
    None
    """
    # in_place=False: request a new result rather than modifying ``sl``.
    flipped = abu.flip_sls(sl, to_flip, in_place=False)
    sft = StatefulTractogram(flipped, img, Space.VOX)
    if return_idx:
        fiber_groups[b_name] = {'sl': sft, 'idx': idx}
    else:
        fiber_groups[b_name] = sft