"""Source code for AFQ.viz.plotly_backend: Plotly-based visualization backend."""

import tempfile
import enum
import logging

import numpy as np
import pandas as pd

import AFQ.viz.utils as vut

from dipy.tracking.streamline import set_number_of_points

try:
    import plotly
    import plotly.graph_objs as go
    import plotly.io as pio
    from plotly.subplots import make_subplots
    import plotly.express as px
    from plotly.colors import hex_to_rgb
except (ImportError, ModuleNotFoundError):
    raise ImportError(vut.viz_import_msg_error("plotly"))


# Kaleido export scope; create_gif calls ``scope._shutdown_kaleido()`` after
# writing each frame.
scope = pio.kaleido.scope

# Logger used by the visualization helpers in this module.
viz_logger = logging.getLogger(name="AFQ")
def _inline_interact(figure, show, show_inline):
    """
    Optionally display ``figure``, then hand it back to the caller.

    Helper shared by the public ``visualize_*`` functions.

    Parameters
    ----------
    figure : Plotly Figure object
        Figure to display.
    show : bool
        If True, render the figure through ``plotly.offline.plot``.
    show_inline : bool
        If True, initialize notebook mode and render the figure through
        ``plotly.offline.iplot``.

    Returns
    -------
    The same Plotly Figure object that was passed in.
    """
    if show:
        viz_logger.info("Creating interactive figure in HTML file...")
        plotly.offline.plot(figure)

    if not show_inline:
        return figure

    viz_logger.info("Creating interactive figure inline...")
    plotly.offline.init_notebook_mode()
    plotly.offline.iplot(figure)
    return figure
[docs]def _to_color_range(num): if num < 0: num = 0 if num >= 0.999: num = 0.999 if num <= 0.001: num = 0.001 return num
def _color_arr2str(color_arr, opacity=1.0):
    """
    Format the first three entries of ``color_arr`` plus ``opacity`` as an
    ``rgba(r, g, b, a)`` string.

    Each component is passed through ``_to_color_range`` before formatting;
    entries of ``color_arr`` beyond the third are ignored.

    NOTE(review): the channels are emitted on a 0-1 scale, whereas CSS-style
    ``rgba()`` strings conventionally carry 0-255 channels -- confirm plotly
    renders these as intended.
    """
    components = (color_arr[0], color_arr[1], color_arr[2], opacity)
    body = ", ".join(f"{_to_color_range(c)}" for c in components)
    return f"rgba({body})"
def set_layout(figure, color=None):
    """
    Apply the shared 3D scene styling to ``figure`` in place.

    For ``scene1``, the background, tick labels and title of the x, y and
    z axes are hidden and ``aspectmode`` is set to ``'data'``.

    Parameters
    ----------
    figure : Plotly Figure object
        Figure whose layout is updated.
    color : str, optional
        Plot background color.
        Default: fully transparent, ``"rgba(0,0,0,0)"``.
    """
    background = "rgba(0,0,0,0)" if color is None else color
    scene = {
        f"{axis}axis": dict(
            showbackground=False, showticklabels=False, title='')
        for axis in "xyz"
    }
    scene["aspectmode"] = 'data'
    figure.update_layout(plot_bgcolor=background, scene1=scene)
def _draw_streamlines(figure, sls, dimensions, color, name, cbv=None,
                      cbs=None, sbv_lims=(None, None),
                      flip_axes=(False, False, False), opacity=1.0):
    """
    Add a set of streamlines to ``figure`` as a single ``Scatter3d`` trace.

    All streamlines are concatenated into one polyline. When there is more
    than one streamline, a NaN point is inserted after each of them so no
    segment is drawn between consecutive streamlines.

    Parameters
    ----------
    figure : Plotly Figure object
        Figure to draw into (row 1, col 1).
    sls : streamline sequence
        Streamlines to draw. Only its ``_data``, ``_offsets`` and
        ``_lengths`` attributes are read.
    dimensions : sequence of int
        Extent of the volume along each axis; used when flipping axes.
    color : array-like
        Base color of the bundle.
    name : str
        Bundle name, used for the legend entry and legend group.
    cbv : ndarray, optional
        Volume used to shade each point by the voxel value at its
        integer-truncated coordinates.
    cbs : array-like, optional
        Per-streamline colors, one row per streamline; overrides ``color``.
    sbv_lims : sequence of two values, optional
        Lower/upper shading limits for ``cbv``. ``None`` entries default to
        0 and ``cbv.max()`` respectively. The object passed in is never
        modified.
    flip_axes : sequence of three bool, optional
        For each flipped axis, coordinates become ``dimensions[i] - c``.
    opacity : float, optional
        Trace opacity.

    Returns
    -------
    The color constant used for the last streamline drawn: the row of
    ``cbs`` when given, otherwise the (possibly ``cbv``-scaled) bundle color.
    """
    color = np.asarray(color)
    # Private copy: None entries are filled in below, and writing into the
    # caller's list (or the shared default) would leak limits computed from
    # one volume into later calls with another volume.
    sbv_lims = list(sbv_lims)
    separate_sls = len(sls._offsets) > 1
    if separate_sls:
        # one extra slot per streamline for the NaN separator point
        plotting_shape = sls._data.shape[0] + sls._offsets.shape[0]
    else:
        plotting_shape = sls._data.shape[0]

    # float arrays so NaN can mark the breaks between streamlines
    x_pts = np.zeros(plotting_shape)
    y_pts = np.zeros(plotting_shape)
    z_pts = np.zeros(plotting_shape)

    if cbs is not None:
        cbs = np.asarray(cbs)
        line_color = np.zeros((plotting_shape, cbs.shape[1]))
        color = cbs[0, :]
        # replaced per streamline below; set here so an empty bundle
        # still has a value to return
        color_constant = color
    elif cbv is not None:
        if sbv_lims[0] is None:
            sbv_lims[0] = 0
        if sbv_lims[1] is None:
            sbv_lims[1] = cbv.max()
        color_constant = (color / color.max())\
            * (1.4 / (sbv_lims[1] - sbv_lims[0])) + sbv_lims[0]
        line_color = np.zeros((plotting_shape, 3))
    else:
        color_constant = color
        line_color = np.zeros((plotting_shape, 3))

    customdata_tp = np.zeros(plotting_shape)
    customdata_nodes = np.zeros(plotting_shape)

    for sl_index, plotting_offset in enumerate(sls._offsets):
        sl_length = sls._lengths[sl_index]
        sl = sls._data[plotting_offset:plotting_offset + sl_length]

        # every earlier streamline contributed one NaN separator point
        total_offset = plotting_offset + sl_index
        seg = slice(total_offset, total_offset + sl_length)
        x_pts[seg] = sl[:, 0]
        y_pts[seg] = sl[:, 1]
        z_pts[seg] = sl[:, 2]

        # don't draw between streamlines
        if separate_sls:
            x_pts[total_offset + sl_length] = np.nan
            y_pts[total_offset + sl_length] = np.nan
            z_pts[total_offset + sl_length] = np.nan

        if cbs is not None:
            color_constant = cbs[sl_index]
        if cbv is not None:
            brightness = cbv[
                sl[:, 0].astype(int),
                sl[:, 1].astype(int),
                sl[:, 2].astype(int)
            ]
            line_color[seg, :] = np.outer(brightness, color_constant)
            customdata_tp[seg] = brightness
        else:
            line_color[seg, :] = color_constant
            customdata_tp[seg] = 1
        customdata_nodes[seg] = np.arange(sl_length)
        if line_color.shape[1] > 3:
            # don't shade alpha values
            line_color[seg, 3] = color_constant[3]
        if separate_sls:
            line_color[total_offset + sl_length, :] = 0

    if flip_axes[0]:
        x_pts = dimensions[0] - x_pts
    if flip_axes[1]:
        y_pts = dimensions[1] - y_pts
    if flip_axes[2]:
        z_pts = dimensions[2] - z_pts

    hovertext = [
        f'TP: {i1}<br>Node ID: {i2}' for i1, i2 in zip(
            customdata_tp, customdata_nodes)]

    figure.add_trace(
        go.Scatter3d(
            x=x_pts,
            y=y_pts,
            z=z_pts,
            name=vut.display_string(name),
            legendgroup=vut.display_string(name),
            marker=dict(
                size=0.0001,
                color=_color_arr2str(color)
            ),  # this is necessary to add color to legend
            line=dict(
                width=8,
                color=line_color,
            ),
            hovertext=hovertext,
            hoverinfo='all',
            opacity=opacity
        ),
        row=1, col=1
    )
    return color_constant
def _plot_profiles(profiles, bundle_name, color, fig, scalar):
    """
    Draw a tract profile as a colored line in the second scene of ``fig``.

    Parameters
    ----------
    profiles : pandas.DataFrame or array-like
        If a DataFrame, the rows with ``tractID == bundle_name`` are used:
        ``nodeID`` is plotted against the ``scalar`` column, and each point
        is colored ``scalar_value * color``. Otherwise the values are plotted
        against ``0 .. len(profiles) - 1`` and ``color`` must provide one
        color per point.
    bundle_name : str
        Bundle name used for row selection, legend entry and legend group.
    color : array-like
        Base color (DataFrame case) or sequence of per-point colors.
    fig : Plotly Figure object
        Figure with a scene subplot at row 1, col 2.
    scalar : str
        Scalar name: DataFrame column to plot and y-axis title.
    """
    if isinstance(profiles, pd.DataFrame):
        selected = profiles[profiles.tractID == bundle_name]
        x = selected["nodeID"]
        y = selected[scalar]
        line_color = [
            _color_arr2str(value * color)
            for value in selected[scalar].to_numpy()
        ]
    else:
        x = np.arange(len(profiles))
        y = profiles
        line_color = [_color_arr2str(point_color) for point_color in color]

    legend_name = vut.display_string(bundle_name)
    fig.add_trace(
        go.Scatter3d(
            x=x,
            y=y,
            z=np.zeros(len(y)),
            name=legend_name,
            line=dict(color=line_color, width=15),
            mode="lines",
            legendgroup=legend_name),
        row=1, col=2)

    axis_font = dict(size=20, family="Overpass")
    # Orthographic camera looking down the z axis, so the 3D scene reads as
    # a flat 2D line plot.
    top_down_camera = dict(
        projection=dict(type="orthographic"),
        up=dict(x=0, y=1, z=0),
        eye=dict(x=0, y=0, z=1),
        center=dict(x=0, y=0, z=0))
    fig.update_layout(
        margin={"t": 15, "b": 0, "l": 0, "r": 0},
        scene2=dict(
            camera=top_down_camera,
            zaxis=dict(visible=False),
            dragmode=False,
            xaxis_title=dict(text="Location", font=axis_font),
            yaxis_title=dict(
                text=vut.display_string(scalar), font=axis_font)))
def visualize_bundles(sft, n_points=None, bundle=None, colors=None,
                      shade_by_volume=None, color_by_streamline=None,
                      n_sls_viz=3600, sbv_lims=(None, None),
                      include_profiles=(None, None),
                      flip_axes=(False, False, False), opacity=1.0,
                      figure=None, background=(1, 1, 1), interact=False,
                      inline=False, **kwargs):
    """
    Visualize bundles in 3D

    Parameters
    ----------
    sft : Stateful Tractogram, str
        A Stateful Tractogram containing streamline information
        or a path to a trk file.
        In order to visualize individual bundles, the Stateful Tractogram
        must contain a bundle key in its data_per_streamline which is a list
        of bundle `'uid'`.
    n_points : int or None
        n_points to resample streamlines to before plotting. If None, no
        resampling is done.
    bundle : str, optional
        The name of a bundle to select or an integer for selection from the
        sft metadata.
    colors : dict or list
        If this is a dict, keys are bundle names and values are RGB tuples.
        If this is a list, each item is an RGB tuple. Defaults to a dict
        from bundles to Tableau 20 RGB values.
    shade_by_volume : ndarray or str, optional
        3d volume used to shade the bundles. If None, no shading
        is performed. Only works when using the plotly backend.
        Default: None
    color_by_streamline : ndarray or dict, optional
        N by 3 array, where N is the number of streamlines in sft;
        for each streamline you specify 3 values between 0 and 1 for rgb.
        If sft has multiple bundles, then use a dict for color_by_streamline,
        where keys are bundle names and values are n by 3 arrays.
        Overrides colors for bundles in the keys of the dict if passing a
        dict, or for all streamlines if using ndarray. Bundles that are not
        keys of the dict keep their regular color.
        Default: None
    n_sls_viz : int
        Number of streamlines to randomly select if plotting
        all bundles. Selections will be proportional to the original number
        of streamlines per bundle.
        Default: 3600
    sbv_lims : sequence of two values
        Of the form (lower bound, upper bound). Shading based on
        shade_by_volume will only differentiate values within these bounds.
        If lower bound is None, will default to 0.
        If upper bound is None, will default to the maximum value in
        shade_by_volume. The object passed in is not modified.
        Default: (None, None)
    include_profiles : Tuple of Pandas Dataframe and string
        The first element of the tuple is a Pandas Dataframe containing
        profiles in the standard pyAFQ output format for the bundle(s) being
        displayed. It will be used to generate a graph of the tract profiles
        for each bundle, with colors corresponding to the colors on the
        bundles. The string is the scalar to use from the profile.
        If these are None, no tract profiles will be graphed.
        Default: (None, None)
    flip_axes : sequence of three bool
        Which axes to flip, to orient the image as RAS, which is how we
        visualize.
        For example, if the input image is LAS, use [True, False, False].
        Default: (False, False, False)
    opacity : float
        Float between 0 and 1 defining the opacity of the bundle.
        Default: 1.0
    background : tuple, optional
        RGB values for the background. Only applied when a new Figure is
        created.
        Default: (1, 1, 1), which is white background.
    figure : Plotly Figure object, optional
        If provided, the visualization will be added to this Figure.
        Default: Initialize a new Figure.
    interact : bool
        Whether to open the visualization in an interactive window.
        Default: False
    inline : bool
        Whether to embed the interactive visualization inline in a notebook.
        Only works in the notebook context. Default: False.
    **kwargs
        Ignored by this backend.

    Returns
    -------
    Plotly Figure object
    """
    if shade_by_volume is not None:
        shade_by_volume = vut.load_volume(shade_by_volume)

    # One private copy shared by every bundle of this call: the None limits
    # may be filled in while drawing, and that must not leak into the
    # caller's object (or the default) for later calls.
    sbv_lims = list(sbv_lims)

    if figure is None:
        if include_profiles[0] is None:
            figure = make_subplots(
                rows=1, cols=1,
                specs=[[{"type": "scene"}]])
        else:
            figure = make_subplots(
                rows=1, cols=2,
                specs=[[{"type": "scene"}, {"type": "scene"}]])
        set_layout(figure, color=_color_arr2str(background))

    for (sls, color, name, dimensions) in vut.tract_generator(
            sft, bundle, colors, n_points, n_sls_viz=n_sls_viz):
        if isinstance(color_by_streamline, dict):
            # bundles without an entry fall back to their regular color
            cbs = color_by_streamline.get(name)
        else:
            cbs = color_by_streamline

        color_constant = _draw_streamlines(
            figure, sls, dimensions, color, name,
            cbv=shade_by_volume, cbs=cbs,
            sbv_lims=sbv_lims, flip_axes=flip_axes, opacity=opacity)

        if include_profiles[0] is not None:
            _plot_profiles(
                include_profiles[0], name, color_constant,
                figure, include_profiles[1])

    figure.update_layout(legend=dict(itemsizing="constant"))
    return _inline_interact(figure, interact, inline)
def create_gif(figure,
               file_name,
               n_frames=30,
               zoom=2.5,
               z_offset=0.5,
               size=(600, 600)):
    """
    Convert a Plotly Figure object into a gif of one full camera orbit.

    Parameters
    ----------
    figure: Plotly Figure object
        Figure to be converted to a gif. Its ``scene_camera`` layout is
        updated while the frames are rendered.
    file_name: str
        File to save gif to.
    n_frames: int, optional
        Number of frames in gif. Will be evenly distributed throughout the
        rotation.
        Default: 30
    zoom: float, optional
        Distance of the camera eye from the z axis; larger values move the
        camera further away.
        Default: 2.5
    z_offset: float, optional
        Height (z coordinate) of the camera eye.
        Default: 0.5
    size: tuple, optional
        (width, height) in pixels of each rendered frame.
        Default: (600, 600)
    """
    # A private temporary directory avoids clashing with other processes
    # writing tgif*.png into the shared temp dir, and is removed afterwards.
    with tempfile.TemporaryDirectory() as tdir:
        for i in range(n_frames):
            theta = 2 * np.pi * i / n_frames
            camera = dict(
                eye=dict(x=np.cos(theta) * zoom,
                         y=np.sin(theta) * zoom,
                         z=z_offset)
            )
            figure.update_layout(scene_camera=camera)
            figure.write_image(
                tdir + f"/tgif{i}.png", width=size[0], height=size[1])
            scope._shutdown_kaleido()  # temporary fix for memory leak

        vut.gif_from_pngs(tdir, file_name, n_frames,
                          png_fname="tgif", add_zeros=False)
def _draw_roi(figure, roi, name, color, opacity, dimensions, flip_axes):
    """
    Draw the voxels of ``roi`` that equal 1 as a ``Scatter3d`` trace.

    Voxel indices are shifted by +1; along every axis ``i`` with
    ``flip_axes[i]`` set, index ``c`` is drawn at ``dimensions[i] - (c + 1)``
    instead. The trace goes to row 1, col 1 of ``figure``, with a
    transparent connecting line.
    """
    voxel_idx = np.where(roi == 1)
    pts = [
        dimensions[ax] - (voxel_idx[ax] + 1) if flip else voxel_idx[ax] + 1
        for ax, flip in enumerate(flip_axes)
    ]

    trace = go.Scatter3d(
        x=pts[0], y=pts[1], z=pts[2],
        name=name,
        marker=dict(color=_color_arr2str(color, opacity=opacity)),
        line=dict(color="rgba(0,0,0,0)")
    )
    figure.add_trace(trace, row=1, col=1)
def visualize_roi(roi, affine_or_mapping=None, static_img=None,
                  roi_affine=None, static_affine=None, reg_template=None,
                  name='ROI', figure=None, flip_axes=[False, False, False],
                  color=np.array([0.9999, 0, 0]), opacity=1.0,
                  interact=False, inline=False):
    """
    Render a region of interest as a 3D scatter of its voxels.

    Parameters
    ----------
    roi : str or Nifti1Image
        The ROI information
    affine_or_mapping : ndarray, Nifti1Image, or str, optional
        An affine transformation or mapping to apply to the ROIs before
        visualization. Default: no transform.
    static_img: str or Nifti1Image, optional
        Template to resample roi to.
        Default: None
    roi_affine: ndarray, optional
        Default: None
    static_affine: ndarray, optional
        Default: None
    reg_template: str or Nifti1Image, optional
        Template to use for registration.
        Default: None
    name: str, optional
        Name of ROI for the legend.
        Default: 'ROI'
    color : ndarray, optional
        RGB color for ROI.
        Default: np.array([0.9999, 0, 0])
    opacity : float, optional
        Opacity of ROI.
        Default: 1.0
    flip_axes : ndarray
        Which axes to flip, to orient the image as RAS, which is how we
        visualize.
        For example, if the input image is LAS, use [True, False, False].
        Default: [False, False, False]
    figure : Plotly Figure object, optional
        If provided, the visualization will be added to this Figure.
        Default: Initialize a new Figure.
    interact : bool
        Whether to open the visualization in an interactive window.
        Default: False
    inline : bool
        Whether to embed the interactive visualization inline in a notebook.
        Only works in the notebook context. Default: False.

    Returns
    -------
    Plotly Figure object
    """
    roi_data = vut.prepare_roi(roi, affine_or_mapping, static_img,
                               roi_affine, static_affine, reg_template)

    if figure is None:
        figure = make_subplots(rows=1, cols=1, specs=[[{"type": "scene"}]])
        set_layout(figure)

    # the prepared ROI's own shape is the extent used for axis flipping
    _draw_roi(figure, roi_data, name, color, opacity,
              dimensions=roi_data.shape, flip_axes=flip_axes)
    return _inline_interact(figure, interact, inline)
class Axes(enum.IntEnum):
    """Array-axis indices used to select the slicing direction of a volume."""

    X = 0  # first array axis; _name_from_enum calls it "Sagittal"
    Y = 1  # second array axis; "Coronal"
    Z = 2  # third array axis; "Axial"
def _draw_slice(figure, axis, volume, opacity=0.3, pos=0.5,
                colorscale="greys", invert_colorscale=False):
    """
    Add one axis-aligned slice of ``volume`` to ``figure`` as a go.Volume.

    Slice values are normalized to [0, 1] using the minimum and maximum of
    the whole volume.

    Parameters
    ----------
    figure : Plotly Figure object
        Figure to draw into (row 1, col 1).
    axis : Axes or int
        Axis perpendicular to the slice: 0, 1 or 2.
    volume : ndarray
        3D volume to slice.
    opacity : float, optional
        Opacity of the slice. Default: 0.3
    pos : float, optional
        Fractional position of the slice along ``axis``. The resulting index
        is clamped to the valid range, so 1.0 selects the last slice.
        Default: 0.5
    colorscale : str, optional
        Plotly colorscale. Default: "greys"
    invert_colorscale : bool, optional
        If True, normalized values are mapped to ``1 - value``.
        Default: False

    Raises
    ------
    ValueError
        If ``axis`` is not 0, 1 or 2.
    """
    if axis not in (0, 1, 2):
        raise ValueError(
            f"axis must be 0, 1 or 2 (Axes.X/Y/Z), got {axis!r}")
    axis_idx = int(axis)

    n_slices = volume.shape[axis_idx]
    # int(n * 1.0) == n would index past the end; keep within [0, n - 1]
    height = min(max(int(n_slices * pos), 0), n_slices - 1)

    # Full extent on the two in-plane axes, a one-voxel slab along `axis`.
    # The coordinate grid and the slab have the same shape, so both flatten
    # in the same (C) order.
    region = [slice(0, n) for n in volume.shape[:3]]
    region[axis_idx] = slice(height, height + 1)
    region = tuple(region)
    X, Y, Z = np.mgrid[region]
    values = volume[region].flatten()

    v_min = volume.min()
    value_range = volume.max() - v_min
    if value_range == 0:
        # constant volume: avoid 0 / 0 -> NaN and render it uniformly
        values = np.zeros(values.shape)
    else:
        values = (values - v_min) / value_range
    if invert_colorscale:
        values = 1 - values

    figure.add_trace(
        go.Volume(
            x=X.flatten(),
            y=Y.flatten(),
            z=Z.flatten(),
            value=values,
            colorscale=colorscale,
            surface_count=1,
            showscale=False,
            opacity=opacity,
            name=_name_from_enum(axis),
            hoverinfo='skip',
            showlegend=True
        ),
        row=1, col=1
    )
def _name_from_enum(axis):
    """
    Return the anatomical plane name for an ``Axes`` value.

    ``Axes.X`` -> "Sagittal", ``Axes.Y`` -> "Coronal",
    ``Axes.Z`` -> "Axial"; any other value yields None.
    """
    plane_names = (
        (Axes.X, "Sagittal"),
        (Axes.Y, "Coronal"),
        (Axes.Z, "Axial"),
    )
    for member, plane in plane_names:
        if axis == member:
            return plane
    return None
def visualize_volume(volume, figure=None, x_pos=0.5, y_pos=0.5,
                     z_pos=0.5, interact=False, inline=False, opacity=0.3,
                     colorscale="gray", invert_colorscale=False,
                     flip_axes=(False, False, False)):
    """
    Visualize a volume as up to three orthogonal slices.

    Parameters
    ----------
    volume : ndarray or str
        3d volume to visualize.
    figure : Plotly Figure object, optional
        If provided, the visualization will be added to this Figure.
        Default: Initialize a new Figure.
    x_pos : float or None, optional
        Where to show the Sagittal slice (perpendicular to the first array
        axis). If None, the slice is not shown.
        Should be a decimal between 0 and 1. Indicates the fractional
        position along the axis perpendicular to the slice.
        Default: 0.5
    y_pos : float or None, optional
        Where to show the Coronal slice (perpendicular to the second array
        axis). If None, the slice is not shown.
        Should be a decimal between 0 and 1. Indicates the fractional
        position along the axis perpendicular to the slice.
        Default: 0.5
    z_pos : float or None, optional
        Where to show the Axial slice (perpendicular to the third array
        axis). If None, the slice is not shown.
        Should be a decimal between 0 and 1. Indicates the fractional
        position along the axis perpendicular to the slice.
        Default: 0.5
    opacity : float, optional
        Opacity of slices.
        Default: 0.3
    colorscale : string, optional
        Plotly colorscale to use to color slices.
        Default: "gray"
    invert_colorscale : bool, optional
        Whether to invert colorscale.
        Default: False
    flip_axes : sequence of three bool
        Which axes to flip, to orient the image as RAS, which is how we
        visualize.
        For example, if the input image is LAS, use [True, False, False].
        Default: (False, False, False)
    interact : bool
        Whether to open the visualization in an interactive window.
        Default: False
    inline : bool
        Whether to embed the interactive visualization inline in a notebook.
        Only works in the notebook context. Default: False.

    Returns
    -------
    Plotly Figure object
    """
    volume = vut.load_volume(volume)
    for i, flip in enumerate(flip_axes):
        if flip:
            volume = np.flip(volume, axis=i)

    if figure is None:
        figure = make_subplots(rows=1, cols=1, specs=[[{"type": "scene"}]])
        set_layout(figure)

    for pos, axis in [(x_pos, Axes.X), (y_pos, Axes.Y), (z_pos, Axes.Z)]:
        if pos is not None:
            _draw_slice(
                figure, axis, volume, opacity=opacity, pos=pos,
                colorscale=colorscale, invert_colorscale=invert_colorscale)

    return _inline_interact(figure, interact, inline)
def _draw_core(sls, n_points, figure, bundle_name, indiv_profile,
               labelled_points, dimensions, flip_axes):
    """
    Draw a bundle's core line, colored by a tract profile, into ``figure``.

    The core is the node-wise median of the streamlines after resampling
    each one to ``n_points`` nodes.

    Parameters
    ----------
    sls : streamlines
        Streamlines of one bundle.
    n_points : int
        Number of nodes; must equal ``len(indiv_profile)``.
    figure : Plotly Figure object
        Figure to draw into (row 1, col 1).
    bundle_name : str
        Bundle name; the trace is named after ``bundle_name + "_core"``.
    indiv_profile : array-like
        One scalar per node. It is mapped linearly onto the Viridis
        colormap and used as hover text.
    labelled_points : sequence of int
        Node indices to annotate with text. -1 marks the last node, which
        is labelled ``str(n_points)``.
    dimensions : sequence of int
        Extent of the volume along each axis; used when flipping axes.
    flip_axes : sequence of three bool
        For each flipped axis, coordinates become ``dimensions[i] - c``.

    Returns
    -------
    ndarray, shape (n_points, 4)
        Per-node RGBA colors in [0, 1] from the Viridis mapping, before
        direction-dependent shading; alpha is 1.
    """
    fgarray = np.asarray(set_number_of_points(sls, n_points))
    fgarray = np.median(fgarray, axis=0)

    # Viridis hex colors -> RGB rows in [0, 1]
    colormap = np.asarray(
        [hex_to_rgb(c) for c in px.colors.sequential.Viridis]) / 255
    xp = np.linspace(
        np.min(indiv_profile), np.max(indiv_profile), num=len(colormap))
    line_color = np.ones((n_points, 4))
    for i in range(3):
        line_color[:, i] = np.interp(indiv_profile, xp, colormap[:, i])
    line_color_untouched = line_color.copy()

    # Shade each segment by the alignment between its direction and the
    # direction from the node toward the origin: (cos + 3) / 4 scales the
    # RGB by a factor in [0.5, 1]. The last node has no outgoing segment.
    for i in range(n_points - 1):
        direc = fgarray[i + 1] - fgarray[i]
        direc = direc / np.linalg.norm(direc)
        light_direc = -fgarray[i] / np.linalg.norm(fgarray[i])
        direc_adjust = (np.dot(direc, light_direc) + 3) / 4
        line_color[i, 0:3] = line_color[i, 0:3] * direc_adjust

    text = [None] * n_points
    for label in labelled_points:
        text[label] = str(n_points) if label == -1 else str(label)

    for ax in range(3):
        if flip_axes[ax]:
            fgarray[:, ax] = dimensions[ax] - fgarray[:, ax]

    figure.add_trace(
        go.Scatter3d(
            x=fgarray[:, 0],
            y=fgarray[:, 1],
            z=fgarray[:, 2],
            name=vut.display_string(bundle_name + "_core"),
            line=dict(
                width=25,
                color=line_color,
            ),
            hovertext=indiv_profile,
            hoverinfo='all',
            text=text,
            textfont=dict(size=20, family="Overpass"),
            textposition="top right",
            mode="lines+text"
        ),
        row=1, col=1
    )

    return line_color_untouched
def single_bundle_viz(indiv_profile, sft,
                      bundle, scalar_name,
                      flip_axes=[False, False, False],
                      labelled_nodes=[0, -1],
                      figure=None,
                      include_profile=False):
    """
    Visualize a single bundle in 3D with core bundle and associated profile

    Parameters
    ----------
    indiv_profile : ndarray
        A numpy array containing a tract profile for this bundle for a
        scalar. Its length sets the number of core nodes.
    sft : Stateful Tractogram, str
        A Stateful Tractogram containing streamline information.
        If bundle is an int, the Stateful Tractogram must contain a bundle
        key in its data_per_streamline which is a list of bundle `'uid'`.
        Otherwise, the entire Stateful Tractogram will be used as the bundle
        for the visualization.
    bundle : str or int
        The name of the bundle to be used as the label for the plot,
        and for selection from the sft metadata.
    scalar_name : str
        The name of the scalar being used.
    flip_axes : ndarray
        Which axes to flip, to orient the image as RAS, which is how we
        visualize.
        For example, if the input image is LAS, use [True, False, False].
        Default: [False, False, False]
    labelled_nodes : list or ndarray
        Which nodes to label. -1 indicates the last node.
        Default: [0, -1]
    figure : Plotly Figure object, optional
        If provided, the visualization will be added to this Figure.
        Default: Initialize a new Figure.
    include_profile : bool, optional
        If true, also plot the tract profile.
        Default: False

    Returns
    -------
    Plotly Figure object
    """
    if figure is None:
        n_cols = 2 if include_profile else 1
        figure = make_subplots(
            rows=1, cols=n_cols,
            specs=[[{"type": "scene"} for _ in range(n_cols)]])
        set_layout(figure)

    n_points = len(indiv_profile)
    sls, _, bundle_name, dimensions = next(
        vut.tract_generator(sft, bundle, None, n_points))

    core_colors = _draw_core(
        sls, n_points, figure, bundle_name, indiv_profile,
        labelled_nodes, dimensions, flip_axes)

    if include_profile:
        _plot_profiles(indiv_profile, bundle_name + "_profile",
                       core_colors, figure, scalar_name)

    return figure