# Source code for AFQ.viz.utils

import colorsys
import logging
import os.path as op
from collections import OrderedDict

import dipy.tracking.streamlinespeed as dps
import imageio as io
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
from dipy.align import resample
from dipy.io.stateful_tractogram import StatefulTractogram
from dipy.tracking.streamline import transform_streamlines
from PIL import Image, ImageChops

import AFQ.utils.streamlines as aus

__all__ = ["Viz"]


def get_distinct_shades(base_rgb, n_steps, hue_shift):
    """
    Build a list of colors that differ from a base color only in hue.

    The base color is converted to HLS; lightness and saturation are kept
    and the hue is moved in steps of ``hue_shift``, centered on the base hue
    (for an odd ``n_steps`` the middle shade keeps the base hue). Hues wrap
    around modulo 1.0.

    Parameters
    ----------
    base_rgb : sequence of 3 floats
        RGB components, unpacked into ``colorsys.rgb_to_hls``.
    n_steps : int
        Number of shades to generate.
    hue_shift : float
        Hue difference between two consecutive shades.

    Returns
    -------
    list of tuple
        ``n_steps`` RGB tuples as returned by ``colorsys.hls_to_rgb``.
    """
    hue, lightness, saturation = colorsys.rgb_to_hls(*base_rgb)
    return [
        colorsys.hls_to_rgb(
            (hue + ((step - (n_steps - 1) / 2) * hue_shift)) % 1.0,
            lightness,
            saturation,
        )
        for step in range(n_steps)
    ]


# Logger used by the helpers in this module; it is registered under the
# name "AFQ".
viz_logger = logging.getLogger("AFQ")
# Twenty RGB triples with float components in [0, 1] (the "Tableau 20"
# values referred to in the tract_generator docstring).
# gen_color_dict hands out entries from this list and gives the two members
# of a Left/Right bundle pair the consecutive (even, odd) entries, so the
# list length is expected to stay even.
tableau_20 = [
    (0.12156862745098039, 0.4666666666666667, 0.7058823529411765),
    (0.6823529411764706, 0.7803921568627451, 0.9098039215686274),
    (1.0, 0.4980392156862745, 0.054901960784313725),
    (1.0, 0.7333333333333333, 0.47058823529411764),
    (0.17254901960784313, 0.6274509803921569, 0.17254901960784313),
    (0.596078431372549, 0.8745098039215686, 0.5411764705882353),
    (0.8392156862745098, 0.15294117647058825, 0.1568627450980392),
    (1.0, 0.596078431372549, 0.5882352941176471),
    (0.5803921568627451, 0.403921568627451, 0.7411764705882353),
    (0.7725490196078432, 0.6901960784313725, 0.8352941176470589),
    (0.5490196078431373, 0.33725490196078434, 0.29411764705882354),
    (0.7686274509803922, 0.611764705882353, 0.5803921568627451),
    (0.8901960784313725, 0.4666666666666667, 0.7607843137254902),
    (0.9686274509803922, 0.7137254901960784, 0.8235294117647058),
    (0.4980392156862745, 0.4980392156862745, 0.4980392156862745),
    (0.7803921568627451, 0.7803921568627451, 0.7803921568627451),
    (0.7372549019607844, 0.7411764705882353, 0.13333333333333333),
    (0.8588235294117647, 0.8588235294117647, 0.5529411764705883),
    (0.09019607843137255, 0.7450980392156863, 0.8117647058823529),
    (0.6196078431372549, 0.8549019607843137, 0.8980392156862745),
]
# Four extra RGB triples outside tableau_20. The first two are used below
# as the base colors of the left/right Superior Longitudinal bundles; the
# last two are not referenced in this section of the file.
tableau_extension = [
    (0.95, 0.85, 0.25),  # Warm, Soft Yellow
    (0.98, 0.92, 0.5),
    (0.4, 0.5, 0.6),  # Cool, Muted Blue
    (0.6, 0.7, 0.9),
]
# Plot sizing constants. They are not used in this section of the file;
# NOTE(review): presumably font sizes (points) and a scatter marker size
# consumed by plotting code elsewhere -- confirm against the callers.
large_font = 28
medium_font = 24
small_font = 20
marker_size = 200

# Base colors for the left/right Superior Longitudinal bundles.
slf_l_base = tableau_extension[0]
slf_r_base = tableau_extension[1]

# Base colors for the left/right Vertical Occipital bundles.
vof_l_base = tableau_20[6]
vof_r_base = tableau_20[7]

# Three hue-shifted variants of each base color, used in COLOR_DICT for the
# sub-bundles. With n_steps=3 the middle shade (index 1) keeps the base hue
# and indices 0 and 2 are shifted by -hue_shift and +hue_shift.
slf_l_shades = get_distinct_shades(slf_l_base, 3, hue_shift=0.1)
slf_r_shades = get_distinct_shades(slf_r_base, 3, hue_shift=0.1)

vof_l_shades = get_distinct_shades(vof_l_base, 3, hue_shift=0.15)
vof_r_shades = get_distinct_shades(vof_r_base, 3, hue_shift=0.15)

# Default RGB color (float components in [0, 1]) for each known bundle name.
# gen_color_dict looks bundles up here first and only falls back to handing
# out tableau_20 entries for names that are missing.
# Several abbreviated keys (e.g. "C_L", "MCP", "IFOF_L", "UF_L", "AF_L")
# reuse the color of the entry listed immediately before them.
COLOR_DICT = OrderedDict(
    {
        "Left Anterior Thalamic": tableau_20[0],
        "C_L": tableau_20[0],
        "Right Anterior Thalamic": tableau_20[1],
        "C_R": tableau_20[1],
        "Left Corticospinal": tableau_20[2],
        "Right Corticospinal": tableau_20[3],
        "Left Cingulum Cingulate": tableau_20[4],
        "MCP": tableau_20[4],
        "Right Cingulum Cingulate": tableau_20[5],
        "CCMid": tableau_20[5],
        "Forceps Minor": tableau_20[8],
        "CC_ForcepsMinor": tableau_20[8],
        "Forceps Major": tableau_20[9],
        "CC_ForcepsMajor": tableau_20[9],
        "Left Inferior Fronto-occipital": tableau_20[10],
        "IFOF_L": tableau_20[10],
        "Right Inferior Fronto-occipital": tableau_20[11],
        "IFOF_R": tableau_20[11],
        "Left Inferior Longitudinal": tableau_20[12],
        "F_L": tableau_20[12],
        "Right Inferior Longitudinal": tableau_20[13],
        "F_R": tableau_20[13],
        # Superior Longitudinal: base color for the whole bundle, hue-shifted
        # shades for its I/II/III subdivisions.
        "Left Superior Longitudinal": slf_l_base,
        "Right Superior Longitudinal": slf_r_base,
        "Left Superior Longitudinal I": slf_l_shades[0],
        "Left Superior Longitudinal II": slf_l_shades[1],
        "Left Superior Longitudinal III": slf_l_shades[2],
        "Right Superior Longitudinal I": slf_r_shades[0],
        "Right Superior Longitudinal II": slf_r_shades[1],
        "Right Superior Longitudinal III": slf_r_shades[2],
        "Left Uncinate": tableau_20[16],
        "UF_L": tableau_20[16],
        "Right Uncinate": tableau_20[17],
        "UF_R": tableau_20[17],
        "Left Arcuate": tableau_20[18],
        "AF_L": tableau_20[18],
        "Right Arcuate": tableau_20[19],
        "AF_R": tableau_20[19],
        "Left Posterior Arcuate": tableau_20[14],
        "Right Posterior Arcuate": tableau_20[15],
        # Vertical Occipital: base color for the whole bundle, hue-shifted
        # shades for the related sub-bundles.
        "Left Vertical Occipital": vof_l_base,
        "Right Vertical Occipital": vof_r_base,
        "Left V1V3": vof_l_shades[0],
        "Left Posterior Vertical Occipital": vof_l_shades[1],
        "Left Anterior Vertical Occipital": vof_l_shades[2],
        "Right V1V3": vof_r_shades[0],
        "Right Posterior Vertical Occipital": vof_r_shades[1],
        "Right Anterior Vertical Occipital": vof_r_shades[2],
        # Same value as vof_l_base (tableau_20[6]).
        "median": tableau_20[6],
        # Paul Tol's palette for callosal bundles
        "Callosum Orbital": (0.2, 0.13, 0.53),
        "Callosum Anterior Frontal": (0.07, 0.47, 0.2),
        "Callosum Superior Frontal": (0.27, 0.67, 0.6),
        "Callosum Motor": (0.53, 0.8, 0.93),
        "Callosum Superior Parietal": (0.87, 0.8, 0.47),
        "Callosum Posterior Parietal": (0.8, 0.4, 0.47),
        "Callosum Occipital": (0.67, 0.27, 0.6),
        "Callosum Temporal": (0.53, 0.13, 0.33),
    }
)

# Maps bundle names to a pair of integers, each in the range 0-4.
# In the second element, left-hemisphere bundles use 0 or 1, right-hemisphere
# bundles use 3 or 4, and the forceps bundles use 2.
# NOTE(review): presumably (row, column) of the bundle's panel in a 5x5
# subplot grid -- this section of the file does not use the mapping, so
# confirm against the code that consumes it.
# Abbreviated keys repeat the position of one of the full-name entries.
POSITIONS = OrderedDict(
    {
        "Left Anterior Thalamic": (1, 0),
        "Right Anterior Thalamic": (1, 4),
        "C_L": (1, 0),
        "C_R": (1, 4),
        "Left Corticospinal": (1, 1),
        "Right Corticospinal": (1, 3),
        "Left Cingulum Cingulate": (3, 1),
        "Right Cingulum Cingulate": (3, 3),
        "MCP": (3, 1),
        "CCMid": (3, 3),
        "Forceps Minor": (4, 2),
        "Forceps Major": (0, 2),
        "CC_ForcepsMinor": (4, 2),
        "CC_ForcepsMajor": (0, 2),
        "Left Inferior Fronto-occipital": (4, 1),
        "Right Inferior Fronto-occipital": (4, 3),
        "IFOF_L": (4, 1),
        "IFOF_R": (4, 3),
        "Left Inferior Longitudinal": (3, 0),
        "Right Inferior Longitudinal": (3, 4),
        "F_L": (3, 0),
        "F_R": (3, 4),
        "Left Superior Longitudinal": (2, 1),
        "Right Superior Longitudinal": (2, 3),
        "Left Arcuate": (2, 0),
        "Right Arcuate": (2, 4),
        "AF_L": (2, 0),
        "AF_R": (2, 4),
        "Left Uncinate": (0, 1),
        "Right Uncinate": (0, 3),
        "UF_L": (0, 1),
        "UF_R": (0, 3),
    }
)

# None of the four tables below is used in this section of the file; the
# notes are reviewer assumptions based on the names and should be confirmed
# against the consuming code.

# Maps short scalar names to model-prefixed names.
# NOTE(review): presumably translates column names of CSVs produced by the
# MATLAB AFQ into the names used by pyAFQ -- confirm.
CSV_MAT_2_PYTHON = {"fa": "dti_fa", "md": "dti_md"}

# Per-scalar multiplicative factor.
# NOTE(review): presumably a unit conversion applied to MATLAB-derived
# values -- confirm the direction of the scaling with the caller.
SCALE_MAT_2_PYTHON = {"dti_md": 0.001}

# Maps model-prefixed scalar names to the bare scalar name (model removed).
SCALAR_REMOVE_MODEL = {"dti_md": "MD", "dki_md": "MD", "dki_fa": "FA", "dti_fa": "FA"}

# List of bundle abbreviations. Note that "UNC_L" is listed without a
# matching "UNC_R".
# NOTE(review): presumably bundles that need to be flipped when they come
# from a RecoBundles segmentation -- confirm.
RECO_FLIP = ["IFO_L", "IFO_R", "UNC_L", "ILF_L", "ILF_R"]

# Maps a bundle name to a 3-tuple of direction names: one of "Left"/"Right",
# one of "Front"/"Back" and one of "Top"/"Bottom", in that order.
# These are (case-insensitively) the direction values accepted by get_eye.
# NOTE(review): presumably the preferred viewing direction for the sagittal,
# coronal and axial views of each bundle -- the mapping is not used in this
# section of the file, so confirm against the consumer.
BEST_BUNDLE_ORIENTATIONS = {
    "Left Anterior Thalamic": ("Left", "Front", "Top"),
    "Right Anterior Thalamic": ("Right", "Front", "Top"),
    "Left Corticospinal": ("Left", "Front", "Top"),
    "Right Corticospinal": ("Right", "Front", "Top"),
    "Left Cingulum Cingulate": ("Left", "Front", "Top"),
    "Right Cingulum Cingulate": ("Right", "Front", "Top"),
    "Forceps Minor": ("Left", "Front", "Top"),
    "Forceps Major": ("Left", "Back", "Top"),
    "Left Inferior Fronto-occipital": ("Left", "Front", "Bottom"),
    "Right Inferior Fronto-occipital": ("Right", "Front", "Bottom"),
    "Left Inferior Longitudinal": ("Left", "Front", "Bottom"),
    "Right Inferior Longitudinal": ("Right", "Front", "Bottom"),
    "Left Superior Longitudinal": ("Left", "Front", "Top"),
    "Right Superior Longitudinal": ("Right", "Front", "Top"),
    "Left Uncinate": ("Left", "Front", "Bottom"),
    "Right Uncinate": ("Right", "Front", "Bottom"),
    "Left Arcuate": ("Left", "Front", "Top"),
    "Right Arcuate": ("Right", "Front", "Top"),
    "Left Vertical Occipital": ("Left", "Back", "Top"),
    "Right Vertical Occipital": ("Right", "Back", "Top"),
    "Left Posterior Arcuate": ("Left", "Back", "Top"),
    "Right Posterior Arcuate": ("Right", "Back", "Top"),
}


class PanelFigure:
    """
    Organize existing image files into a grid of matplotlib subplots.

    Images are added one panel at a time with ``add_img``. Panel labels
    ("A", "B", ...) are only queued by ``add_img``; they are drawn when the
    figure is written out with ``format_and_save_figure``.
    """

    def __init__(self, num_rows, num_cols, width, height, panel_label_kwargs=None):
        """
        Initialize PanelFigure.

        Parameters
        ----------
        num_rows : int
            Number of rows in figure
        num_cols : int
            Number of columns in figure
        width : int
            Width of figure in inches
        height : int
            Height of figure in inches
        panel_label_kwargs : dict
            Additional arguments for matplotlib's text method,
            which is used to add panel labels to each subplot.
            These override the defaults set below.
        """
        if panel_label_kwargs is None:
            panel_label_kwargs = {}
        self.fig = plt.figure(figsize=(width, height))
        self.grid = plt.GridSpec(num_rows, num_cols, hspace=0, wspace=0)
        # Number of images added so far; also selects the next label letter.
        self.subplot_count = 0
        self.panel_label_kwargs = dict(
            fontfamily="Helvetica",
            fontsize="xx-large",
            color="white",
            fontweight="bold",
            verticalalignment="top",
            bbox=dict(facecolor="none", edgecolor="none"),
        )
        self.panel_label_kwargs.update(panel_label_kwargs)
        # Labels waiting to be drawn by format_and_save_figure.
        self.panel_label_queue = []

    @staticmethod
    def _grid_start(coord):
        """Return the first grid index covered by an int or slice coordinate."""
        if isinstance(coord, slice):
            # An open-ended slice such as slice(None, 2) starts at index 0.
            return coord.start or 0
        return coord

    def add_img(
        self,
        fname,
        x_coord,
        y_coord,
        reduct_count=1,
        subplot_label_pos=(0.1, 0.1),
        legend=None,
        legend_kwargs=None,
        add_panel_label=True,
        panel_label_kwargs=None,
    ):
        """
        Add image from fname into figure as a panel.

        Parameters
        ----------
        fname : str
            path to image file to add to subplot
        x_coord : int or slice
            x coordinate(s) (grid column(s)) of subplot in matplotlib figure
        y_coord : int or slice
            y coordinate(s) (grid row(s)) of subplot in matplotlib figure
        reduct_count : int
            number of times to trim whitespace around image
            Default: 1
        subplot_label_pos : tuple of floats
            position of subplot label, as an (x, y) offset from the
            top-left corner of the panel in units of one grid cell
            Default: (0.1, 0.1)
        legend : dict
            dictionary of legend items, where keys are labels
            and values are colors
            Default: None
        legend_kwargs : dict
            Additional arguments for matplotlib's legend method
        add_panel_label : bool
            Whether or not to add a panel label to the subplot
            Default: True
        panel_label_kwargs : dict
            Additional arguments for matplotlib's text method,
            which is used to add panel labels to each subplot.
            These override the figure-wide defaults for this panel only.

        Returns
        -------
        The matplotlib axes that the image was drawn into.
        """
        if panel_label_kwargs is None:
            panel_label_kwargs = {}
        if legend_kwargs is None:
            legend_kwargs = {}
        ax = self.fig.add_subplot(self.grid[y_coord, x_coord])
        im1 = Image.open(fname)
        for _ in range(reduct_count):
            im1 = trim(im1)
        if legend is not None:
            patches = [
                mpatches.Patch(color=color, label=value)
                for value, color in legend.items()
            ]
            ax.legend(handles=patches, borderaxespad=0.0, **legend_kwargs)
        ax.imshow(np.asarray(im1), aspect=1)
        ax.axis("off")
        if add_panel_label:
            this_pl_kwargs = self.panel_label_kwargs.copy()
            this_pl_kwargs.update(panel_label_kwargs)
            # Only queue the label here; it is positioned in figure
            # coordinates once the layout is final.
            self.panel_label_queue.append(
                (
                    y_coord,
                    x_coord,
                    subplot_label_pos,
                    f"{chr(65 + self.subplot_count)}",
                    this_pl_kwargs,
                )
            )

        # The counter advances even for unlabeled panels, so the letter of a
        # label reflects the panel's position in insertion order.
        self.subplot_count = self.subplot_count + 1
        return ax

    def format_and_save_figure(self, fname, trim_final=True, tight_final=True):
        """
        Format and save figure to fname.

        Parameters
        ----------
        fname : str
            Path to save figure to
        trim_final : bool
            Whether or not to trim whitespace around the saved figure.
            Default: True
        tight_final : bool
            Whether or not to use tight_layout on figure.
            Default: True
        """
        if tight_final:
            self.fig.tight_layout()
        self.fig.canvas.draw()
        for (
            y_coord,
            x_coord,
            subplot_label_pos,
            label_text,
            kwargs,
        ) in self.panel_label_queue:
            # Panels that span several cells are labeled relative to their
            # first row/column.
            x_start = self._grid_start(x_coord)
            y_start = self._grid_start(y_coord)
            x_fig_pos = x_start / self.grid.ncols
            # Figure y runs bottom-to-top while grid rows run top-to-bottom.
            y_fig_pos = 1.0 - (y_start / self.grid.nrows)
            self.fig.text(
                x_fig_pos + subplot_label_pos[0] / self.grid.ncols,
                y_fig_pos - subplot_label_pos[1] / self.grid.nrows,
                label_text,
                ha="right",
                va="bottom",
                **kwargs,
            )
        self.fig.savefig(fname, dpi=300)
        if trim_final:
            # Re-open the saved file, crop the surrounding background and
            # overwrite it in place.
            im1 = Image.open(fname)
            im1 = trim(im1)
            im1.save(fname, dpi=(300, 300))


def get_eye(view, direc):
    """
    Build a unit "eye" vector for a named anatomical view and direction.

    Parameters
    ----------
    view : str
        One of "sagittal", "coronal" or "axial" (case-insensitive).
        The misspellings "sagital" and "saggital" are accepted with a
        logged warning.
    direc : str
        One of "left", "right", "top", "bottom", "front" or "back"
        (case-insensitive). Only the direction relevant to the view selects
        the sign: "left" gives -1 for sagittal, "front" gives +1 for
        coronal and "top" gives +1 for axial; any other valid direction
        gives the opposite sign.

    Returns
    -------
    dict
        Keys "x", "y" and "z"; exactly one entry is +1 or -1, the other two
        are 0.

    Raises
    ------
    ValueError
        If view or direc is not one of the accepted values.
    """
    direc = direc.lower()
    view = view.lower()

    if view in ["sagital", "saggital"]:
        viz_logger.warning("You don't know how to spell sagggitttal!")
        view = "sagittal"

    if view not in ["sagittal", "coronal", "axial"]:
        raise ValueError("View must be one of: sagittal, coronal, or axial")

    if direc not in ["left", "right", "top", "bottom", "front", "back"]:
        raise ValueError(
            "Direction must be one of: left, right, top, bottom, front, back"
        )

    eye = {"x": 0, "y": 0, "z": 0}
    if view == "sagittal":
        eye["x"] = -1 if direc == "left" else 1
    elif view == "coronal":
        eye["y"] = 1 if direc == "front" else -1
    else:  # axial
        eye["z"] = 1 if direc == "top" else -1
    return eye


def display_string(scalar_name):
    """
    Turn scalar name(s) into display form: underscores become spaces and
    the text is upper-cased.

    Parameters
    ----------
    scalar_name : str or iterable of str
        A single name, or several names.

    Returns
    -------
    str or list of str
        A string for a string input, otherwise a list with one converted
        entry per input name.
    """

    def _pretty(name):
        return name.replace("_", " ").upper()

    if not isinstance(scalar_name, str):
        return [_pretty(entry) for entry in scalar_name]
    return _pretty(scalar_name)


def gen_color_dict(bundles):
    """
    Helper function.
    Generate a color dict given a list of bundles.

    Bundles that have an entry in COLOR_DICT keep that color. Any other
    bundle gets the next unused entry of tableau_20 (wrapping around at the
    end of the list). A bundle whose name starts with "Left " or "Right " is
    treated as one half of a pair: it is assigned an even-indexed entry and
    its counterpart (the same name with the other prefix) is assigned the
    following odd-indexed entry, whether or not the counterpart appears in
    ``bundles``.

    Parameters
    ----------
    bundles : iterable of str
        Bundle names.

    Returns
    -------
    dict
        Maps bundle names to RGB tuples.
    """
    n_colors = len(tableau_20)

    def incr_color_idx(color_idx):
        return (color_idx + 1) % n_colors

    custom_color_dict = {}
    color_idx = 0
    for bundle in bundles:
        if bundle in custom_color_dict:
            # Already colored, e.g. as the counterpart of an earlier bundle.
            continue
        if bundle in COLOR_DICT:
            custom_color_dict[bundle] = COLOR_DICT[bundle]
            continue

        # Swap the hemisphere prefix, keeping the separating space, so that
        # "Left Foo" pairs with "Right Foo" and vice versa.
        if bundle.startswith("Left "):
            other_bundle = str("Right " + bundle[len("Left "):])
        elif bundle.startswith("Right "):
            other_bundle = str("Left " + bundle[len("Right "):])
        else:
            other_bundle = None

        if other_bundle is None:  # lone bundle
            custom_color_dict[bundle] = tableau_20[color_idx]
            color_idx = incr_color_idx(color_idx)
        else:  # right left pair
            # Pairs always start on an even index so that both members
            # come from the same (even, odd) slot of the palette.
            if color_idx % 2 != 0:
                color_idx = incr_color_idx(color_idx)
            custom_color_dict[bundle] = tableau_20[color_idx]
            custom_color_dict[other_bundle] = tableau_20[color_idx + 1]
            color_idx = incr_color_idx(incr_color_idx(color_idx))
    return custom_color_dict


def viz_import_msg_error(module):
    """
    Build the message that tells the user how to install a viz module.

    Parameters
    ----------
    module : str
        "plot" for the generic plotting extras; any other value is treated
        as the name of a visualization package (e.g. "fury", "plotly").

    Returns
    -------
    str
        Installation hint, intended as the text of an ImportError.
    """
    if module == "plot":
        return (
            "To make plots in pyAFQ, you will need to install "
            "the relevant plotting software packages. "
            "You can do that by installing pyAFQ with "
            "`pip install pyAFQ[plot]`"
        )

    upper_name = module.upper()
    lower_name = module.lower()
    return (
        f"To use {upper_name} visualizations in pyAFQ, you will "
        f"need to have {upper_name} installed. "
        "You can do that by installing pyAFQ with "
        f"`pip install pyAFQ[{lower_name}]`, or by "
        f"separately installing {upper_name}: "
        f"`pip install {lower_name}`."
    )


def _sls_to_t1(sls, ref_sft, t1_img):
    """
    Move streamlines into the grid of a T1 image, if one is given.

    Parameters
    ----------
    sls : streamlines
        Streamlines to convert.
    ref_sft : StatefulTractogram
        Tractogram used as the spatial reference for ``sls``.
    t1_img : Nifti1Image or None
        Target image. If None, nothing is transformed.

    Returns
    -------
    tuple
        ``(sls, ref_sft.dimensions)`` unchanged when ``t1_img`` is None;
        otherwise the streamlines (taken to RASMM and multiplied by the
        inverse of ``t1_img.affine``) together with ``t1_img.shape``.
    """
    if t1_img is not None:
        sft = StatefulTractogram.from_sft(sls, ref_sft)
        sft.to_rasmm()
        inverse_affine = np.linalg.inv(t1_img.affine)
        moved = transform_streamlines(sft.streamlines, inverse_affine)
        return moved, t1_img.shape
    return sls, ref_sft.dimensions


def tract_generator(
    trk_file, bundle, colors, n_points, t1_img, n_sls_viz=65536, n_sls_min=256
):
    """
    Yield the streamlines to draw, one group at a time.

    Helper function. Three cases are handled:

    - the tractogram has the single bundle name "whole_brain": all
      streamlines are yielded as one group named "all_bundles";
    - ``bundle`` is given: only that bundle is yielded;
    - otherwise: every non-empty bundle is yielded, in sorted name order.

    Parameters
    ----------
    trk_file : str or SegmentedSFT
        Path to a trk file or SegmentedSFT

    bundle : str or None
        The name of a bundle to select. If None, all bundles are generated.

    colors : dict, list or None
        If this is a dict, keys are bundle names and values are RGB tuples.
        If this is a list, its first item is used for every group.
        If None, colors are generated with gen_color_dict.

    n_points : int or None
        n_points to resample streamlines to before plotting. If None, no
        resampling is done.

    t1_img : Nifti1Image or None
        T1-weighted image to register streamlines to. If None, streamlines
        are yielded without this transformation.

    n_sls_viz : int
        Maximum number of streamlines to randomly select for a whole-brain
        tractogram or for a single selected bundle. When plotting all
        bundles, this budget is split proportionally to the original number
        of streamlines per bundle.
        Default: 65536

    n_sls_min : int
        Minimum number of streamlines to display per bundle when plotting
        all bundles (bundles with fewer streamlines are shown in full).
        Default: 256

    Yields
    ------
    streamlines, color, bundle name (str), dimensions
        The dimensions are ``t1_img.shape`` when a T1 image is given,
        otherwise the dimensions of the tractogram.
    """
    viz_logger.info("Loading Stateful Tractogram...")
    seg_sft = (
        aus.SegmentedSFT.fromfile(trk_file)
        if isinstance(trk_file, str)
        else trk_file
    )

    if colors is None:
        colors = gen_color_dict(seg_sft.bundle_names)

    seg_sft.sft.to_rasmm()
    streamlines = seg_sft.sft.streamlines
    viz_logger.info("Generating colorful lines from tractography...")

    # Case 1: an unsegmented tractogram, there are no bundles in here.
    if len(seg_sft.bundle_names) == 1 and seg_sft.bundle_names[0] == "whole_brain":
        if isinstance(colors, dict):
            colors = list(colors.values())
        if len(streamlines) > n_sls_viz:
            all_idx = np.arange(len(streamlines))
            keep = np.random.choice(all_idx, size=n_sls_viz, replace=False)
            streamlines = streamlines[keep]
        if n_points is not None:
            streamlines = dps.set_number_of_points(streamlines, n_points)
        sls, dim = _sls_to_t1(streamlines, seg_sft.sft, t1_img)
        yield sls, colors[0], "all_bundles", dim
        return

    # Case 2: one specific bundle was requested.
    if bundle is not None:
        selected = seg_sft.get_bundle(bundle).streamlines
        if len(selected) > n_sls_viz:
            keep = np.random.choice(len(selected), size=n_sls_viz, replace=False)
            selected = selected[keep]
        if n_points is not None:
            selected = dps.set_number_of_points(selected, n_points)
        color = colors[bundle] if isinstance(colors, dict) else colors[0]
        sls, dim = _sls_to_t1(selected, seg_sft.sft, t1_img)
        yield sls, color, bundle, dim
        return

    # Case 3: no selection, visualize all of the bundles.
    for bundle_name in sorted(seg_sft.bundle_names):
        bundle_idx = seg_sft.bundle_idxs[bundle_name]
        if len(bundle_idx) == 0:
            continue
        # Share of the budget proportional to the bundle size, but never
        # below the per-bundle minimum.
        quota = max(n_sls_min, (len(bundle_idx) * n_sls_viz) // len(streamlines))
        if len(bundle_idx) > quota:
            bundle_idx = np.random.choice(bundle_idx, size=quota, replace=False)
        bundle_sls = streamlines[bundle_idx]
        if n_points is not None:
            bundle_sls = dps.set_number_of_points(bundle_sls, n_points)
        color = colors[bundle_name] if isinstance(colors, dict) else colors[0]
        sls, dim = _sls_to_t1(bundle_sls, seg_sft.sft, t1_img)
        yield sls, color, bundle_name, dim


def bbox(img):
    """
    Find the bounding box of the non-zero content of an image.

    The last axis of ``img`` is summed away first, so a pixel counts as
    content when the sum over its last axis is non-zero.

    Parameters
    ----------
    img : array-like
        Image data with the last axis summed over (e.g. rows x columns x
        channels).

    Returns
    -------
    tuple
        ``(cmin, rmin, cmax, rmax)``: first and last column, and first and
        last row, that contain content. Both maxima are inclusive indices.

    Raises
    ------
    IndexError
        If the image has no non-zero content.
    """
    flattened = np.sum(img, axis=-1)
    row_has_content = np.any(flattened, axis=1)
    col_has_content = np.any(flattened, axis=0)
    content_rows = np.where(row_has_content)[0]
    rmin, rmax = content_rows[[0, -1]]
    content_cols = np.where(col_has_content)[0]
    cmin, cmax = content_cols[[0, -1]]
    return cmin, rmin, cmax, rmax


def trim(im):
    """
    Crop away the uniform border around an image.

    The color of the top-left pixel is taken as the background. The image
    is cropped to the smallest box containing every pixel that differs from
    that background.

    Parameters
    ----------
    im : PIL.Image.Image
        Image to trim.

    Returns
    -------
    PIL.Image.Image
        The cropped image, or ``im`` itself when every pixel equals the
        background color (there is nothing to crop to).
    """
    bg = Image.new(im.mode, im.size, im.getpixel((0, 0)))
    diff = ImageChops.difference(im, bg)
    if not np.asarray(diff).any():
        # Completely uniform image: bbox() would fail on it.
        return im
    cmin, rmin, cmax, rmax = bbox(diff)
    # bbox() reports inclusive maxima, whereas the right/lower edges of a
    # PIL crop box are exclusive; add one so that the last row and column
    # of content are kept.
    return im.crop((int(cmin), int(rmin), int(cmax) + 1, int(rmax) + 1))


def gif_from_pngs(tdir, gif_fname, n_frames, png_fname="tgif", add_zeros=False):
    """
    Helper function
    Stitches together gif from screenshots

    Frame ``i`` is read from ``<tdir>/<png_fname><padding><i>.png`` and
    every frame is repeated ``60 // n_frames`` times (at least once) in the
    output.

    Parameters
    ----------
    tdir : str
        Directory that contains the png screenshots.
    gif_fname : str
        Path of the gif file to write.
    n_frames : int
        Number of screenshots to read; they are numbered from 0.
    png_fname : str
        Common prefix of the screenshot file names.
        Default: "tgif"
    add_zeros : bool
        Whether the frame number in the file names is padded with leading
        zeros ("00000" for numbers below 10, "0000" below 100 and "000"
        otherwise).
        Default: False
    """
    # Always keep at least one copy of each frame, otherwise more than 60
    # screenshots would produce a gif without any frames.
    n_frame_copies = max(1, 60 // n_frames)

    angles = []
    for i in range(n_frames):
        if not add_zeros:
            padding = ""
        elif i < 10:
            padding = "00000"
        elif i < 100:
            padding = "0000"
        else:
            padding = "000"
        frame = io.imread(op.join(tdir, f"{png_fname}{padding}{i}.png"))
        angles.extend([frame] * n_frame_copies)

    io.mimsave(gif_fname, angles)


def prepare_roi(roi, resample_to=None):
    """
    Load an ROI and return its data, optionally resampled first.
    Helper function

    Parameters
    ----------
    roi : str, Nifti1Image or ndarray
        The ROI information.
        If str, ROI will be loaded using the str as a path.
        An ndarray is returned as it is (it cannot be resampled).

    resample_to : Nifti1Image, optional
        If not None, the ROI will be resampled to the space of this image.

    Returns
    -------
    ndarray

    Raises
    ------
    ValueError
        If resampling is requested and the ROI is not a Nifti1Image
        (after loading it from a path, if a path was given).
    """
    viz_logger.info("Preparing ROI...")
    if isinstance(roi, str):
        roi = nib.load(roi)

    if resample_to is not None:
        if isinstance(roi, nib.Nifti1Image):
            roi = resample(roi, resample_to)
        else:
            raise ValueError(
                "If resampling, roi must be a Nifti1Image or a path to one."
            )

    if isinstance(roi, np.ndarray):
        return roi
    return roi.get_fdata()


def load_volume(volume):
    """
    Return the data of a volume given as a path, an image or an array.
    Helper function

    Parameters
    ----------
    volume : ndarray, Nifti1Image or str
        3d volume to load.
        If string, it is used as a file path.
        If it is a Nifti1Image, its data is returned.
        Anything else (e.g. an ndarray) is simply returned.

    Returns
    -------
    ndarray
    """
    viz_logger.info("Loading Volume...")
    if isinstance(volume, str):
        loaded = nib.load(volume)
        return loaded.get_fdata()
    if isinstance(volume, nib.Nifti1Image):
        return volume.get_fdata()
    return volume


class Viz:
    """
    Entry point to the visualization functions of one plotting backend.

    After construction the instance exposes the chosen backend's
    ``visualize_bundles``, ``visualize_roi``, ``visualize_volume`` and
    ``create_gif`` functions as attributes. With the plotly backend,
    ``single_bundle_viz`` is available as well.
    """

    def __init__(self, backend="fury"):
        """
        Set up visualization preferences.

        Parameters
        ----------
        backend : str, optional
            Should be either "fury" or "plotly". Any string containing one
            of these names selects that backend ("fury" is checked first).
            Default: "fury"

        Raises
        ------
        ImportError
            If the module of the selected backend cannot be imported.
        TypeError
            If backend contains neither "fury" nor "plotly".
        """
        self.backend = backend
        # The backend modules are imported lazily so that only the selected
        # plotting package has to be installed.
        if "fury" in backend:
            try:
                import AFQ.viz.fury_backend
            except ImportError as e:
                raise ImportError(viz_import_msg_error("fury")) from e
            self.visualize_bundles = AFQ.viz.fury_backend.visualize_bundles
            self.visualize_roi = AFQ.viz.fury_backend.visualize_roi
            self.visualize_volume = AFQ.viz.fury_backend.visualize_volume
            self.create_gif = AFQ.viz.fury_backend.create_gif
        elif "plotly" in backend:
            try:
                import AFQ.viz.plotly_backend
            except ImportError as e:
                raise ImportError(viz_import_msg_error("plotly")) from e
            self.visualize_bundles = AFQ.viz.plotly_backend.visualize_bundles
            self.visualize_roi = AFQ.viz.plotly_backend.visualize_roi
            self.visualize_volume = AFQ.viz.plotly_backend.visualize_volume
            self.create_gif = AFQ.viz.plotly_backend.create_gif
            self.single_bundle_viz = AFQ.viz.plotly_backend.single_bundle_viz
        else:
            raise TypeError(
                "Visualization backend contain"
                + " either 'plotly' or 'fury'. "
                + "It is currently set to %s" % backend
            )