# Source code for AFQ.viz.plot

import os.path as op
import os
import logging
import tempfile

import numpy as np
from scipy.stats import sem
import pandas as pd
from tqdm.auto import tqdm

import AFQ.viz.utils as vut
from AFQ.viz.utils import display_string
from AFQ.utils.stats import contrast_index as calc_contrast_index
from AFQ.data.utils import BUNDLE_RECO_2_AFQ, BUNDLE_MAT_2_PYTHON

try:
    from pingouin import intraclass_corr, corr
    import seaborn as sns
    import matplotlib.pyplot as plt
    from matplotlib.lines import Line2D
    from matplotlib.patches import Patch
    import IPython.display as display
except (ImportError, ModuleNotFoundError):
    raise ImportError(vut.viz_import_msg_error("plot"))

__all__ = ["visualize_tract_profiles", "visualize_gif_inline"]


def visualize_tract_profiles(tract_profiles, scalar="dti_fa", ylim=None,
                             n_boot=1000,
                             file_name=None,
                             positions=vut.POSITIONS):
    """
    Visualize all tract profiles for a scalar in one plot.

    The CSV is loaded without renaming model-prefixed scalar columns
    (e.g. "dti_fa" stays "dti_fa") and without any scalar thresholding.
    Only the across-subject mean (with bootstrapped CI) is drawn per
    bundle; individual subject lines are not plotted.

    Parameters
    ----------
    tract_profiles : string
        Path to CSV containing tract_profiles.

    scalar : string, optional
       Scalar to use in plots. Default: "dti_fa".

    ylim : list of 2 floats, optional
        Minimum and maximum value used for y-axis bounds.
        If None, ylim is not set.
        Default: None

    n_boot : int, optional
        Number of bootstrap resamples for seaborn to use
        to estimate the ci.
        Default: 1000

    file_name : string, optional
        Path, without extension, to save the figure to; ".png" and
        ".svg" files are written. Bundles that have no grid position
        are saved separately as "<file_name>_<bundle>".
        If None, the figure is still saved, as "my_tract_profiles"
        inside a "tract_profiles/<scalar>" folder created relative to
        the current working directory.
        Default: None

    positions : dictionary, optional
        Dictionary that maps bundle names to position in plot.
        Default: vut.POSITIONS

    Returns
    -------
    Whatever ``GroupCSVComparison.tract_profiles`` returns. That method
    saves the figure to disk and closes it instead of returning it, so
    this is currently None.
    """
    csv_comparison = GroupCSVComparison(
        None,
        [tract_profiles],
        ["my_tract_profiles"],
        remove_model=False,
        # Empty bounds: plot the raw profiles without masking any values.
        scalar_bounds={'lb': {}, 'ub': {}})

    result = csv_comparison.tract_profiles(
        scalar=scalar,
        ylim=ylim,
        positions=positions,
        out_file=file_name,
        plot_subject_lines=False,
        n_boot=n_boot)

    return result
class BrainAxes(): ''' Helper class. Creates and handles a grid of axes. Each axis corresponds to a bundle. Axis placement should roughly correspond to the actual bundle placement in the brain. ''' def __init__(self, size=(5, 5), positions=vut.POSITIONS, fig=None): self.size = size self.positions = positions self.on_grid = np.zeros((5, 5), dtype=bool) if fig is None: self.fig = plt.figure() label = "1" self.twinning = False else: # we are twinning with another BrainAxes self.fig = fig label = "2" self.twinning = True self.axes = self.fig.subplots( self.size[0], self.size[1], subplot_kw={"label": label, "frame_on": False}) self.fig.subplots_adjust( left=None, bottom=None, right=None, top=None, wspace=0.4, hspace=0.6) self.fig.set_size_inches((18, 18)) self.temp_fig, self.temp_axis = plt.subplots() self.temp_axis_owner = None def get_axis(self, bundle, axes_dict={}): ''' Given a bundle, turn on and get an axis. If bundle not positioned, will claim the temporary axis. If bundle in axes_dict, onyl return relevant axis. 
''' if bundle in axes_dict.keys(): return axes_dict[bundle] elif bundle in self.positions.keys(): self.on_grid[self.positions[bundle]] = True return self.axes[self.positions[bundle]] else: if self.temp_axis_owner != bundle: plt.close(self.temp_fig) self.temp_fig, self.temp_axis = plt.subplots() self.temp_axis_owner = bundle return self.temp_axis def plot_line(self, bundle, x, y, data, ylabel, ylim, n_boot, alpha, lineplot_kwargs, plot_subject_lines=True, ax=None): ''' Given a dataframe data with at least columns x, y, and subjectID, plot the mean of subjects with ci of 95 in alpha and the individual subjects in alpha-0.2 using sns.lineplot() ''' if ax is None: ax = self.get_axis(bundle) lineplot_kwargs_mean = lineplot_kwargs.copy() sns.set(style="whitegrid", rc={"lines.linewidth": 1}) ax.hlines(0, 0, 95, linestyles='dashed', color='red') if plot_subject_lines: lineplot_kwargs_mean["hue"] = None lineplot_kwargs_mean["palette"] = None lineplot_kwargs_mean["color"] = "lightgray" sns.set(style="whitegrid", rc={"lines.linewidth": 0.5}) sns.lineplot( x=x, y=y, data=data, ci=None, estimator=None, units='subjectID', legend=False, ax=ax, alpha=alpha - 0.2, style=[True] * len(data.index), **lineplot_kwargs) sns.set(style="whitegrid", rc={"lines.linewidth": 3}) sns.lineplot( x=x, y=y, data=data, estimator='mean', errorbar=('ci', 95), n_boot=n_boot, legend=False, ax=ax, alpha=alpha, style=[True] * len(data.index), **lineplot_kwargs_mean) ax.set_title(display_string(bundle), fontsize=vut.large_font) ax.set_ylabel(ylabel, fontsize=vut.medium_font) ax.set_ylim(ylim) def format(self, disable_x=True, disable_y=True): ''' Call this functions once after all axes that you intend to use have been plotted on. Automatically formats brain axes. 
''' for i in range(self.size[0]): for j in range(self.size[1]): if self.twinning: self.axes[i, j].xaxis.tick_top() self.axes[i, j].yaxis.tick_right() self.axes[i, j].xaxis.set_label_position('top') self.axes[i, j].yaxis.set_label_position('right') self.axes[i, j].tick_params( axis='y', which='major', labelsize=vut.small_font) self.axes[i, j].tick_params( axis='x', which='major', labelsize=vut.small_font) if not self.on_grid[i, j]: self.axes[i, j].axis("off") if self.twinning: if j != self.size[1] - 1 and self.on_grid[i][j + 1]: self.axes[i, j].set_yticklabels([]) self.axes[i, j].set_ylabel("") self.axes[i, j].set_xticklabels([]) self.axes[i, j].set_xlabel("") else: if disable_y and (j != 0 and self.on_grid[i][j - 1]): self.axes[i, j].set_yticklabels([]) self.axes[i, j].set_ylabel("") if disable_x or (i != self.size[0] - 1 and self.on_grid[i + 1][j]): self.axes[i, j].set_xticklabels([]) self.axes[i, j].set_xlabel("") self.fig.tight_layout() def save_temp_fig(self, o_folder, o_file, save_func): ''' If using a temporary axis, save it out and clear it. Returns True if temporary axis was saved, false if no temporary axis was in use ''' if self.temp_axis_owner is None: return False self.temp_fig.tight_layout() save_func(self.temp_fig, o_folder, o_file) plt.close(self.temp_fig) self.temp_axis_owner = None return True def is_using_temp_axis(self): return (self.temp_axis_owner is not None) def close_all(self): ''' Close all associated figures. ''' plt.close(self.temp_fig) plt.close(self.fig) class GroupCSVComparison(): """ Compare different CSVs, using: tract profiles, contrast indices, scan-rescan reliability using ICC. 
""" def __init__(self, out_folder, csv_fnames, names, is_special="", subjects=None, scalar_bounds={'lb': {'FA': 0.2}, 'ub': {'MD': 0.002}}, bundles=None, percent_nan_tol=10, percent_edges_removed=10, remove_model=True, mat_bundle_converter=BUNDLE_MAT_2_PYTHON, mat_column_converter=vut.CSV_MAT_2_PYTHON, mat_scale_converter=vut.SCALE_MAT_2_PYTHON, bundle_converter=BUNDLE_RECO_2_AFQ, ICC_func="ICC2"): """ Load in csv files, converting from matlab if necessary. Parameters ---------- out_folder : path Folder where outputs of this class's methods will be saved. csv_fnames : list of filenames Filenames for the two CSVs containing tract profiles to compare. Will obtain subject list from the first file. names : list of strings Name to use to identify each CSV dataset. is_special : str or list of strs, optional Whether or not the csv needs special attention. Can be "", "mat" if the csv was generated using mAFQ, or "reco" if the csv was generated using Recobundles. Default: "" subjects : list of num, optional List of subjects to consider. If None, will use all subjects in first dataset. Default: None scalar_bounds : dictionary, optional A dictionary with a lower bound and upper bound containting a series of scalar / threshold pairs used as a white-matter mask on the profiles (any values outside of the threshold will be marked NaN and not used or set to 0, depending on the case). Default: {'lb': {'FA': 0.2}, 'ub': {'MD': 0.002}} bundles : list of strings, optional Bundles to compare. If None, use all bundles in the first profile group. Default: None percent_nan_tol : int, optional Percentage of NaNs tolerable. If a profile has less than this percentage of NaNs, NaNs are interpolated. If it has more, the profile is thrown out. Default: 10 percent_edges_removed : int, optional Percentage of nodes to remove from the edges of the profile. Scalar values often change dramatically at the boundary between the grey matter and the white matter, and these effects can dominate plots. 
However, they are generally not interesting to us, and have low intersubject reliability. In a profile of 100 nodes, percent_edges_removed=10 would remove 5 nodes from each edge. Default: 10 remove_model : bool, optional Whether to remove prefix of scalars which specify model i.e., dti_fa => FA. Default: True mat_bundle_converter : dictionary, optional Dictionary that maps matlab bundle names to python bundle names. Default: BUNDLE_MAT_2_PYTHON mat_column_converter : dictionary, optional Dictionary that maps matlab column names to python column names. Default: CSV_MAT_2_PYTHON mat_scale_converter : dictionary, optional Dictionary that maps scalar names to how they should be scaled to match pyAFQ's scale for that scalar. Default: SCALE_MAT_2_PYTHON bundle_converter : dictionary, optional Dictionary that maps bundle names to more standard bundle names. Unlike mat_bundle_converter, this converter is applied to all CSVs Default: BUNDLE_RECO_2_AFQ ICC_func : string, optional ICC function to use to calculate correlations. Can be 'ICC1, 'ICC2', 'ICC3', 'ICC1k', 'ICC2k', 'ICC3k'. 
Default: "ICC2" """ self.logger = logging.getLogger('AFQ') self.ICC_func = ICC_func if "k" in self.ICC_func: self.ICC_func_name = f"ICC({self.ICC_func[3]},k)" else: self.ICC_func_name = f"ICC({self.ICC_func[3]},1)" self.out_folder = out_folder self.percent_nan_tol = percent_nan_tol if not isinstance(is_special, list): is_special = [is_special] * len(csv_fnames) self.profile_dict = {} for i, fname in enumerate(csv_fnames): profile = pd.read_csv(fname) if 'subjectID' in profile.columns: profile['subjectID'] = \ profile['subjectID'].apply( lambda x: int( ''.join(c for c in x if c.isdigit()) ) if isinstance(x, str) else x) else: profile['subjectID'] = 0 if is_special[i] == "mat": profile.rename( columns=mat_column_converter, inplace=True) profile['tractID'] = \ profile['tractID'].apply( lambda x: mat_bundle_converter[x]) for scalar, scale in mat_scale_converter.items(): profile[scalar] = \ profile[scalar].apply(lambda x: x * scale) profile.replace({"tractID": bundle_converter}, inplace=True) if is_special[i] == "reco": def reco_flip(df): if df.tractID in vut.RECO_FLIP: return 99 - df.nodeID else: return df.nodeID profile["nodeID"] = profile.apply(reco_flip, axis=1) if remove_model: profile.rename( columns=vut.SCALAR_REMOVE_MODEL, inplace=True) for bound, constraint in scalar_bounds.items(): for scalar, threshold in constraint.items(): profile[scalar] = \ profile[scalar].apply( lambda x: self._threshold_scalar( bound, threshold, x)) if percent_edges_removed > 0: profile = profile.drop(profile[np.logical_or( (profile["nodeID"] < percent_nan_tol // 2), (profile["nodeID"] >= 100 - (percent_nan_tol // 2)) )].index) self.profile_dict[names[i]] = profile if subjects is None: self.subjects = self.profile_dict[names[0]]['subjectID'].unique() else: self.subjects = subjects self.prof_len = 100 - (percent_nan_tol // 2) * 2 if bundles is None: self.bundles = self.profile_dict[names[0]]['tractID'].unique() self.bundles.sort() else: self.bundles = bundles self.color_dict = 
vut.gen_color_dict([*self.bundles, "median"]) # TODO: make these parameters self.scalar_markers = ["o", "x"] self.patterns = ( None, '/', 'o', 'x', '-', '.', '+', '//', '\\', '*', 'O', '|') def _threshold_scalar(self, bound, threshold, val): """ Threshold scalars by a lower and upper bound. """ if bound == "lb": if val > threshold: return val else: return np.nan elif bound == "ub": if val < threshold: return val else: return np.nan else: raise RuntimeError("scalar_bounds dictionary " + " formatted incorrectly. See" + " the default for reference") def _save_fig(self, fig, folder, f_name): """ Get file to save to, and generate the folder if it does not exist. """ if self.out_folder is None and folder is None: f_folder = None elif self.out_folder is None: f_folder = folder elif folder is None: f_folder = self.out_folder else: f_folder = op.join( self.out_folder, folder) if f_folder is None: fig.savefig(f_name + ".png") fig.savefig(f_name + ".svg", format='svg', dpi=300) else: os.makedirs(f_folder, exist_ok=True) fig.savefig(op.join(f_folder, f_name) + ".png") fig.savefig(op.join(f_folder, f_name) + ".svg", format='svg', dpi=300) def _get_profile(self, name, bundle, subject, scalar): """ Get a single profile, then handle not found / NaNs """ profile = self.profile_dict[name] single_profile = profile[ (profile['subjectID'] == subject) & (profile['tractID'] == bundle) ].sort_values("nodeID")[scalar].to_numpy() nans = np.isnan(single_profile) percent_nan = (np.sum(nans) * 100) // self.prof_len if len(single_profile) < 1: self.logger.warning( 'No scalars found for scalar ' + scalar + ' for subject ' + str(subject) + ' for bundle ' + bundle + ' for CSV ' + name) return None if percent_nan > 0: message = ( f'{percent_nan}% NaNs found in scalar ' + scalar + ' for subject ' + str(subject) + ' for bundle ' + bundle + ' for CSV ' + name) if np.sum(nans) > self.percent_nan_tol: self.logger.warning(message + '. Profile ignored. ') return None else: self.logger.info(message + '. 
NaNs interpolated. ') non_nan = np.logical_not(nans) single_profile[nans] = np.interp( nans.nonzero()[0], non_nan.nonzero()[0], single_profile[non_nan]) return single_profile def _alpha(self, alpha): ''' Keep alpha in a reasonable range Useful when calculating alpha automatically ''' if alpha < 0.3: return 0.3 if alpha > 1: return 1 return alpha def _array_to_df(self, arr): ''' Converts a 2xn array to a pandas dataframe with columns x, y Useful for plotting with seaborn. ''' df = pd.DataFrame() df['x'] = arr[0] df['y'] = arr[1] return df def masked_corr(self, arr, corrtype): ''' Mask arr for NaNs before calling np.corrcoef ''' mask = np.logical_not( np.logical_or( np.isnan(arr[0, ...]), np.isnan(arr[1, ...]))) if np.sum(mask) < 2: return np.nan, np.nan, np.nan arr = arr[:, mask] if corrtype == "ICC": data = pd.DataFrame({ "targets": np.concatenate( (np.arange(arr.shape[1]), np.arange(arr.shape[1]))), "raters": np.concatenate( (np.zeros(arr.shape[1]), np.ones(arr.shape[1]))), "ratings": np.concatenate( (arr[0], arr[1]))}) stats = intraclass_corr( data=data, targets="targets", raters="raters", ratings="ratings") row = stats[stats["Type"] == self.ICC_func].iloc[0] return row["ICC"], row["ICC"] - row["CI95%"][0], \ row["CI95%"][1] - row["ICC"] elif corrtype == "Srho": stats = corr( x=arr[0], y=arr[1], method="spearman") row = stats.iloc[0] return row["r"], row["r"] - row["CI95%"][0], \ row["CI95%"][1] - row["r"] else: raise ValueError("corrtype not recognized") def tract_profiles(self, names=None, scalar="FA", ylim=[0.0, 1.0], show_plots=False, positions=vut.POSITIONS, out_file=None, n_boot=1000, plot_subject_lines=True, axes_dict={}): """ Compare all tract profiles for a scalar from different CSVs. Plots tract profiles for all in one plot. Bundles taken from positions argument. Parameters ---------- names : list of strings, optional Names of datasets to plot profiles of. If None, all datasets are used. Default: None scalar : string, optional Scalar to use in plots. 
Default: "FA". ylim : list of 2 floats, optional Minimum and maximum value used for y-axis bounds. If None, ylim is not set. Default: [0.0, 1.0] out_file : str, optional Path to save the figure to. If None, use the default naming convention in self.out_folder Default: None n_boot : int, optional Number of bootstrap resamples for seaborn to use to estimate the ci. Default: 1000 show_plots : bool, optional Whether to show plots if in an interactive environment. Default: False positions : dictionary, optional Dictionary that maps bundle names to position in plot. Default: vut.POSITIONS plot_subject_lines : bool, optional Whether to plot individual subject lines with a smaller width. Default: True axes_dict : dictionary of axes, optional Plot contrast index for bundles that are keys of axes_dict on the corresponding axis. Default: {} """ if not show_plots: plt.ioff() if names is None: names = list(self.profile_dict.keys()) if out_file is None: o_folder = f"tract_profiles/{scalar}" o_file = f"{'_'.join(names)}" else: o_folder = None o_file = out_file ba = BrainAxes(positions=positions) labels = [] self.logger.info("Calculating means and CIs...") for j, bundle in enumerate(tqdm(self.bundles)): labels_temp = [] for i, name in enumerate(names): if i == 0: plot_kwargs = { "hue": "tractID", "palette": [self.color_dict[bundle]]} else: plot_kwargs = { "dashes": [(2**(i - 1), 2**(i - 1))], "hue": "tractID", "palette": [self.color_dict[bundle]]} profile = self.profile_dict[name] profile = profile[profile['tractID'] == bundle] ba.plot_line( bundle, "nodeID", scalar, profile, display_string(scalar), ylim, n_boot, self._alpha( 0.6 + 0.2 * i), plot_kwargs, plot_subject_lines=plot_subject_lines, ax=axes_dict.get(bundle)) if j == 0: line = Line2D( [], [], color=[0, 0, 0]) line.set_dashes((2**(i + 1), 2**(i + 1))) if ba.is_using_temp_axis(): labels_temp.append(line) else: labels.append(line) if ba.is_using_temp_axis(): ba.temp_fig.legend( labels_temp, names, fontsize=vut.medium_font) 
ba.save_temp_fig( o_folder, f"{o_file}_{bundle}", self._save_fig) if len(names) > 1: ba.fig.legend( labels, names, loc='center', fontsize=vut.medium_font) ba.format() self._save_fig(ba.fig, o_folder, o_file) if not show_plots: ba.close_all() plt.ion() def _contrast_index_df_maker(self, bundles, names, scalar): ci_df = pd.DataFrame(columns=["subjectID", "nodeID", "diff"]) for subject in self.subjects: profiles = [None] * 2 both_found = True for i, name in enumerate(names): for j, bundle in enumerate(bundles): profiles[i + j] = self._get_profile( name, bundle, subject, scalar) if profiles[i + j] is None: both_found = False if both_found: this_contrast_index = \ calc_contrast_index(profiles[0], profiles[1]) for i, diff in enumerate(this_contrast_index): ci_df = ci_df.append({ "subjectID": subject, "nodeID": i, "diff": diff}, ignore_index=True) return ci_df def contrast_index(self, names=None, scalar="FA", show_plots=False, n_boot=1000, ylim=(-0.5, 0.5), show_legend=False, positions=vut.POSITIONS, plot_subject_lines=True, axes_dict={}): """ Calculate the contrast index for each bundle in two datasets. Parameters ---------- names : list of strings, optional Names of datasets to plot profiles of. If None, all datasets are used. Should be a total of only two datasets. Default: None scalar : string, optional Scalar to use for the contrast index. Default: "FA". show_plots : bool, optional Whether to show plots if in an interactive environment. Default: False n_boot : int, optional Number of bootstrap resamples for seaborn to use to estimate the ci. Default: 1000 ylim : list of 2 floats, optional Minimum and maximum value used for y-axis bounds. If None, ylim is not set. Default: None show_legend : bool, optional Show legend in center with single entry denoting the scalar used. Default: False positions : dictionary, optional Dictionary that maps bundle names to position in plot. 
Default: vut.POSITIONS plot_subject_lines : bool, optional Whether to plot individual subject lines with a smaller width. Default: True axes_dict : dictionary of axes, optional Plot contrast index for bundles that are keys of axes_dict on the corresponding axis. Default: {} """ if not show_plots: plt.ioff() if names is None: names = list(self.profile_dict.keys()) if len(names) != 2: self.logger.error("To calculate the contrast index, " + "only two dataset names should be given") return None ba = BrainAxes(positions=positions) ci_all_df = {} for j, bundle in enumerate(tqdm(self.bundles)): ci_df = self._contrast_index_df_maker( [bundle], names, scalar) ba.plot_line( bundle, "nodeID", "diff", ci_df, "ACI", ylim, n_boot, 1.0, {"color": self.color_dict[bundle]}, plot_subject_lines=plot_subject_lines, ax=axes_dict.get(bundle)) ci_all_df[bundle] = ci_df ba.save_temp_fig( f"contrast_plots/{scalar}/", f"{names[0]}_vs_{names[1]}_contrast_index_{bundle}", self._save_fig) if show_legend: ba.fig.legend([scalar], loc='center', fontsize=vut.medium_font) ba.format() self._save_fig( ba.fig, f"contrast_plots/{scalar}/", f"{names[0]}_vs_{names[1]}_contrast_index") if not show_plots: ba.close_all() plt.ion() return ba.fig, ba.axes, ci_all_df def lateral_contrast_index(self, name, scalar="FA", show_plots=False, n_boot=1000, ylim=(-1, 1), positions=vut.POSITIONS, plot_subject_lines=True): """ Calculate the lateral contrast index for each bundle in a given dataset, for each dataset in names. Parameters ---------- name : string Names of dataset to plot profiles of. scalar : string, optional Scalar to use for the contrast index. Default: "FA". show_plots : bool, optional Whether to show plots if in an interactive environment. Default: False n_boot : int, optional Number of bootstrap resamples for seaborn to use to estimate the ci. Default: 1000 ylim : list of 2 floats, optional Minimum and maximum value used for y-axis bounds. If None, ylim is not set. 
Default: None positions : dictionary, optional Dictionary that maps bundle names to position in plot. Default: vut.POSITIONS plot_subject_lines : bool, optional Whether to plot individual subject lines with a smaller width. Default: True """ if not show_plots: plt.ioff() ba = BrainAxes(positions=positions) for j, bundle in enumerate(tqdm(self.bundles)): other_bundle = list(bundle) if other_bundle[-1] == 'L': other_bundle[-1] = 'R' elif other_bundle[-1] == 'R': other_bundle[-1] = 'L' else: continue other_bundle = "".join(other_bundle) if other_bundle not in self.bundles: continue ci_df = self._contrast_index_df_maker( [bundle, other_bundle], [name], scalar) ba.plot_line( bundle, "nodeID", "diff", ci_df, "ACI", ylim, n_boot, 1.0, {"color": self.color_dict[bundle]}, plot_subject_lines=plot_subject_lines) ba.save_temp_fig( f"contrast_plots/{scalar}/", f"{name}_lateral_contrast_index_{bundle}", self._save_fig) ba.fig.legend([scalar], loc='center', fontsize=vut.medium_font) ba.format() self._save_fig( ba.fig, f"contrast_plots/{scalar}/", f"{name}_lateral_contrast_index") if not show_plots: ba.close_all() if not show_plots: plt.ion() def reliability_plots(self, names=None, scalars=["FA", "MD"], ylims=[0.0, 1.0], show_plots=False, only_plot_above_thr=None, rotate_y_labels=False, rtype="Reliability", positions=vut.POSITIONS, fig_axes=None, prof_axes_dict={}, sub_axes_dict={}): """ Plot the scan-rescan reliability using ICC for 2 scalars. Parameters ---------- names : list of strings, optional Names of datasets to plot profiles of. If None, all datasets are used. Should be a total of only two datasets. Default: None scalars : list of strings, optional Scalars to correlate. Default: ["FA", "MD"]. ylims : 2-tuple of floats, optional Limits of the y-axis. Useful to synchronize axes across graphs. Default: [0.0, 1.0]. show_plots : bool, optional Whether to show plots if in an interactive environment. 
Default: False only_plot_above_thr : int or None, optional Only plot bundles with intrersubject reliability above this threshold on the final reliability bar plots. If None, plot all. Default: None rotate_y_labels : bool, optional Rotate y labels on final reliability plots. Default: False rtype : str, optional Type of reliability to name the y axis of the reliability bar charts. Default: "Reliability" positions : dictionary, optional Dictionary that maps bundle names to position in plot. Default: vut.POSITIONS fig_axes : tuple of matplotlib figure and axes, optional If not None, the resulting reliability plots will use this figure and axes. Default: None prof_axes_dict : dictionary of axes, optional Plot profile reliability histograms for bundles that are keys of prof_axes_dict on the corresponding axis. Default: {} sub_axes_dict : dictionary of axes, optional Plot subject reliability scatter plots for bundles that are keys of sub_axes_dict on the corresponding axis. Default: {} Returns ------- Returns 8 objects: 1. Matplotlib figure 2. Matplotlib axes 3. A dictionary containing the number of missing bundles for each dataset. 4. A list of bundles with sufficient correlation 5. A pandas dataframe describing the intersubject reliabilities, per bundle 6. A numpy array describing the intersubject reliability errors, per bundle 7. A pandas dataframe desribing the profile reliabilities, per bundle 8. 
A numpy array desribing the profile reliability errors, per bundle """ if not show_plots: plt.ioff() if names is None: names = list(self.profile_dict.keys()) if len(names) != 2: self.logger.error("To plot correlations, " + "only two dataset names should be given") return None # extract relevant statistics / data from profiles N = len(self.subjects) all_sub_coef = np.zeros((len(scalars), len(self.bundles))) all_sub_coef_err = np.zeros((len(scalars), len(self.bundles), 2)) all_sub_means = np.zeros( (len(scalars), len(self.bundles), 2, N)) all_profile_coef = \ np.zeros((len(scalars), len(self.bundles), N)) all_node_coef = np.zeros( (len(scalars), len(self.bundles), self.prof_len)) miss_counts = pd.DataFrame(0, index=self.bundles, columns=[ f"miss_count{names[0]}", f"miss_count{names[1]}"]) for m, scalar in enumerate(scalars): for k, bundle in enumerate(tqdm(self.bundles)): bundle_profiles =\ np.zeros((2, N, self.prof_len)) for j, name in enumerate(names): for i, subject in enumerate(self.subjects): single_profile = self._get_profile( name, bundle, subject, scalar) if single_profile is None: bundle_profiles[j, i] = np.nan miss_counts.at[bundle, f"miss_count{name}"] =\ miss_counts.at[ bundle, f"miss_count{name}"] + 1 else: bundle_profiles[j, i] = single_profile all_sub_means[m, k] = np.nanmean(bundle_profiles, axis=2) all_sub_coef[m, k], all_sub_coef_err[m, k, 0], \ all_sub_coef_err[m, k, 1] =\ self.masked_corr(all_sub_means[m, k], "Srho") if np.isnan(all_sub_coef[m, k]).all(): self.logger.error(( f"Not enough non-nan profiles" f"for scalar {scalar} for bundle {bundle}")) all_sub_coef[m, k] = 0 bundle_coefs = np.zeros(N) for i in range(N): bundle_coefs[i], _, _ = \ self.masked_corr(bundle_profiles[:, i, :], "ICC") all_profile_coef[m, k] = bundle_coefs node_coefs = np.zeros(self.prof_len) for i in range(self.prof_len): node_coefs[i], _, _ =\ self.masked_corr(bundle_profiles[:, :, i], "ICC") all_node_coef[m, k] = node_coefs # plot histograms of subject ICC maxi = 
np.nanmax(all_profile_coef) mini = np.nanmin(all_profile_coef) bins = np.linspace(mini, maxi, 10) ba = BrainAxes(positions=positions) for k, bundle in enumerate(self.bundles): ax = ba.get_axis(bundle, axes_dict=prof_axes_dict) for m, scalar in enumerate(scalars): bundle_coefs = all_profile_coef[m, k] bundle_coefs = bundle_coefs[~np.isnan(bundle_coefs)] sns.set(style="whitegrid") sns.histplot( data=bundle_coefs, bins=bins, alpha=0.5, color=self.color_dict[bundle], hatch=self.patterns[m], label=scalar, ax=ax) ax.set_title(display_string(bundle), fontsize=vut.large_font) ax.set_xlabel(self.ICC_func_name, fontsize=vut.medium_font) ax.set_ylabel("Subject count", fontsize=vut.medium_font) ba.temp_fig.legend( display_string(scalars), fontsize=vut.medium_font) ba.save_temp_fig( f"rel_plots/{'_'.join(scalars)}/verbose", (f"{names[0]}_vs_{names[1]}_profile_r_distributions" f"_{bundle}"), self._save_fig) legend_labels = [] for m, _ in enumerate(scalars): legend_labels.append(Patch( facecolor='k', hatch=self.patterns[m])) ba.fig.legend( legend_labels, display_string(scalars), loc='center', fontsize=vut.medium_font) ba.format(disable_x=False) self._save_fig( ba.fig, f"rel_plots/{'_'.join(scalars)}/verbose", f"{names[0]}_vs_{names[1]}_profile_r_distributions") if not show_plots: ba.close_all() # plot node reliability profile all_node_coef[np.isnan(all_node_coef)] = 0 if ylims is None: maxi = all_node_coef.max() mini = all_node_coef.min() else: maxi = ylims[1] mini = ylims[0] ba = BrainAxes(positions=positions) for k, bundle in enumerate(self.bundles): ax = ba.get_axis(bundle) for m, scalar in enumerate(scalars): sns.set(style="whitegrid") sns.lineplot( data=all_node_coef[m, k], label=scalar, color=vut.tableau_20[m * 2], ax=ax, legend=False, ci=None, estimator=None) ax.set_ylim([mini, maxi]) ax.set_title(display_string(bundle), fontsize=vut.large_font) ax.set_ylabel(self.ICC_func_name, fontsize=vut.medium_font) ba.temp_fig.legend( display_string(scalars), 
fontsize=vut.medium_font) ba.save_temp_fig( f"rel_plots/{'_'.join(scalars)}/verbose", (f"{names[0]}_vs_{names[1]}_node_profiles" f"_{bundle}"), self._save_fig) ba.fig.legend(display_string(scalars), loc='center', fontsize=vut.medium_font) ba.format() self._save_fig( ba.fig, f"rel_plots/{'_'.join(scalars)}/verbose", f"{names[0]}_vs_{names[1]}_node_profiles") if not show_plots: ba.close_all() # plot mean profile scatter plots for m, scalar in enumerate(scalars): maxi = np.nanmax(all_sub_means[m]) mini = np.nanmin(all_sub_means[m]) if len(scalars) == 2: twinning_next = (m == 0) twinning = (m == 1) else: twinning = False twinning_next = False if twinning: ba = BrainAxes(positions=positions, fig=ba.fig) else: ba = BrainAxes(positions=positions) for k, bundle in enumerate(self.bundles): if twinning: fc = 'w' ec = self.color_dict[bundle] else: fc = self.color_dict[bundle] ec = 'w' ax = ba.get_axis(bundle, axes_dict=sub_axes_dict) sns.set(style="whitegrid") if not twinning: ax.plot( [[0, 0], [1, 1]], [[0, 0], [1, 1]], '--', color='red') ax.scatter( all_sub_means[m, k, 0], all_sub_means[m, k, 1], label=scalar, marker=self.scalar_markers[m - twinning], facecolors=fc, edgecolors=ec, s=vut.marker_size, linewidth=1) if twinning or twinning_next: twinning_color = 'k' if twinning_next: ax.spines['bottom'].set_color(twinning_color) ax.spines['left'].set_color(twinning_color) else: ax.spines['top'].set_color(twinning_color) ax.spines['right'].set_color(twinning_color) ax.xaxis.label.set_color(twinning_color) ax.tick_params(axis='x', colors=twinning_color) ax.xaxis.label.set_color(twinning_color) ax.yaxis.label.set_color(twinning_color) ax.tick_params(axis='y', colors=twinning_color) ax.yaxis.label.set_color(twinning_color) if not twinning: ax.set_title( display_string(bundle), fontsize=vut.large_font) ax.set_xlabel(names[0], fontsize=vut.medium_font) ax.set_ylabel(names[1], fontsize=vut.medium_font) ax.set_ylim([mini, maxi]) ax.set_xlim([mini, maxi]) ba.temp_fig.legend( [scalar], 
fontsize=vut.medium_font) ba.save_temp_fig( f"rel_plots/{'_'.join(scalars)}/verbose", (f"{names[0]}_vs_{names[1]}_{scalar}_mean_profiles" f"_{bundle}"), self._save_fig) if twinning: legend_labels = [ Line2D( [], [], markerfacecolor='k', markeredgecolor='w', marker=self.scalar_markers[0], linewidth=0, markersize=15), Line2D( [], [], markeredgecolor='k', markerfacecolor='w', marker=self.scalar_markers[0], linewidth=0, markersize=15)] ba.fig.legend( legend_labels, display_string(scalars), loc='center', fontsize=vut.medium_font) elif not twinning_next: ba.fig.legend([scalar], loc='center', fontsize=vut.medium_font) self._save_fig( ba.fig, f"rel_plots/{'_'.join(scalars)}/verbose", f"{names[0]}_vs_{names[1]}_{scalar}_mean_profiles") ba.format(disable_x=False) if not (show_plots or twinning_next): ba.close_all() # plot bar plots of ICC if fig_axes is None: fig, axes = plt.subplots(2, 1) fig.set_size_inches((8, 8)) else: fig = fig_axes[0] axes = fig_axes[1] bundle_prof_means = np.nanmean(all_profile_coef, axis=2) bundle_prof_stds = 1.95 * \ sem(all_profile_coef, axis=2, nan_policy='omit') if ylims is None: maxi = np.maximum(bundle_prof_means.max(), all_sub_coef.max()) mini = np.minimum(bundle_prof_means.min(), all_sub_coef.min()) else: maxi = ylims[1] mini = ylims[0] if only_plot_above_thr is not None: is_removed_bundle =\ np.logical_not( np.logical_and( np.all(all_sub_coef > only_plot_above_thr, axis=0), np.all(bundle_prof_means > only_plot_above_thr, axis=0))) removal_idx = np.where(is_removed_bundle)[0] bundle_prof_means_removed = np.delete( bundle_prof_means, removal_idx, axis=1) bundle_prof_stds_removed = np.delete( bundle_prof_stds, removal_idx, axis=1) all_sub_coef_removed = np.delete( all_sub_coef, removal_idx, axis=1) all_sub_coef_err_removed = np.delete( all_sub_coef_err, removal_idx, axis=1) else: is_removed_bundle = [False] * len(self.bundles) bundle_prof_means_removed = bundle_prof_means bundle_prof_stds_removed = bundle_prof_stds all_sub_coef_err_removed = 
all_sub_coef_err all_sub_coef_removed = all_sub_coef updated_bundles = [] for k, bundle in enumerate(self.bundles): if not is_removed_bundle[k]: if bundle == "CC_ForcepsMinor": updated_bundles.append("CC_FMi") else: updated_bundles.append(bundle) updated_bundles.append("median") sns.set(style="whitegrid") width = 0.6 spacing = 1.5 x = np.arange(len(updated_bundles)) * spacing x_shift = np.linspace(-0.5 * width, 0.5 * width, num=len(scalars)) bundle_prof_means_removed = np.pad( bundle_prof_means_removed, [(0, 0), (0, 1)]) bundle_prof_stds_removed = np.pad( bundle_prof_stds_removed, [(0, 0), (0, 1)]) all_sub_coef_removed = np.pad( all_sub_coef_removed, [(0, 0), (0, 1)]) all_sub_coef_err_removed = np.transpose(np.pad( all_sub_coef_err_removed, [(0, 0), (0, 1), (0, 0)])) for m, scalar in enumerate(scalars): bundle_prof_means_removed[m, -1] = np.median( bundle_prof_means_removed[m, :-1]) all_sub_coef_removed[m, -1] = np.median( all_sub_coef_removed[m, :-1]) # This code can be used as a baseline to make violin plots # # mask = ~np.isnan(all_profile_coef[m].T) # all_profile_coef_m_removed =\ # [d[k] for d, k in zip(all_profile_coef[m], mask.T)] # vl_parts = axes[0].violinplot( # all_profile_coef_m_removed, # positions=x[:-1] + x_shift[m], # showmedians=True # ) # color_list = list(self.color_dict.values()) # for c, pc in enumerate(vl_parts['bodies']): # pc.set_facecolor(color_list[c]) # pc.set_edgecolor(color_list[c]) # pc.set_alpha(1) # pc.set_hatch(self.patterns[m]) # vl_parts['cbars'].set_color(color_list) axes[0].bar( x + x_shift[m], bundle_prof_means_removed[m], width, label=scalar, yerr=bundle_prof_stds_removed[m], hatch=self.patterns[m], color=self.color_dict.values()) axes[1].bar( x + x_shift[m], all_sub_coef_removed[m], width, label=scalar, yerr=all_sub_coef_err_removed[:, :, m], hatch=self.patterns[m], color=self.color_dict.values()) if len(updated_bundles) > 20: xaxis_font_size = vut.small_font - 6 else: xaxis_font_size = vut.small_font axes[0].set_title("A", 
fontsize=vut.large_font) axes[0].set_ylabel(f'Profile {rtype}', fontsize=vut.medium_font) axes[0].set_ylim([mini, maxi]) axes[0].set_xlabel("") axes[0].set_yticklabels( [0, 0.2, 0.4, 0.6, 0.8, 1.0], fontsize=vut.small_font - 8) axes[0].set_xticks(x + 0.5) axes[0].set_xticklabels( display_string(updated_bundles), fontsize=xaxis_font_size) axes[1].set_title("B", fontsize=vut.large_font) axes[1].set_ylabel(f'Subject {rtype}', fontsize=vut.medium_font) axes[1].set_ylim([mini, maxi]) axes[1].set_xlabel("") axes[1].set_yticklabels( [0, 0.2, 0.4, 0.6, 0.8, 1.0], fontsize=vut.small_font - 8) axes[1].set_xticks(x + 0.5) axes[1].set_xticklabels( display_string(updated_bundles), fontsize=xaxis_font_size) plt.setp(axes[0].get_xticklabels(), rotation=65, horizontalalignment='right') plt.setp(axes[1].get_xticklabels(), rotation=65, horizontalalignment='right') if rotate_y_labels: plt.setp(axes[0].get_yticklabels(), rotation=90) plt.setp(axes[1].get_yticklabels(), rotation=90) fig.tight_layout() legend_labels = [] for m, _ in enumerate(scalars): legend_labels.append(Patch( facecolor='k', hatch=self.patterns[m])) fig.legend( legend_labels, display_string(scalars), fontsize=vut.small_font, bbox_to_anchor=(1.25, 0.5)) self._save_fig( fig, f"rel_plots/{'_'.join(scalars)}", f"{names[0]}_vs_{names[1]}") if not show_plots: plt.close(fig) plt.ion() return fig, axes, miss_counts, updated_bundles, \ all_sub_coef_removed, all_sub_coef_err_removed, \ bundle_prof_means_removed, bundle_prof_stds_removed def compare_reliability(self, reliability1, reliability2, analysis_label1, analysis_label2, bundles, errors1=None, errors2=None, scalars=["FA", "MD"], rtype="Subject Reliability", show_plots=False, show_legend=True, fig_ax=None): """ Plot a comparison of scan-rescan reliability between two analyses. Parameters ---------- reliability1, reliability2 : numpy arrays numpy arrays of reliabilities. Typically, each of this will be outputs of separate calls to reliability_plots. 
analysis_label1, analysis_label2 : Strings Names of the analyses used to obtain each dataset. Used to label the x and y axes. bundles : list of str List of bundles that correspond to the second dimension of the reliability arrays. errors1, errors2 : numpy arrays or None Numpy arrays describing the errors. Typically, each of this will be outputs of separate calls to reliability_plots. If None, errors are not shown. Default is None. scalars : list of str, optional Lsit of scalars that correspond to the first dimension of the reliability arrays. Default: ["FA", "MD"] rtype : str type of reliability. Can be any string; used in x axis lavel. Default: Subject Reliability show_plots : bool, optional Whether to show plots if in an interactive environment. Default: False show_legend : bool, optional Show legend for the plot, off to the right hand side. Default: True fig_ax : tuple of matplotlib figure and axis, optional If not None, the resulting reliability plots will use this figure and axis. Default: None Returns ------- Returns a Matplotlib figure and axes. 
""" show_error = ((errors1 is not None) and (errors2 is not None)) if fig_ax is None: fig, ax = plt.subplots() else: fig = fig_ax[0] ax = fig_ax[1] legend_labels = [] for i, scalar in enumerate(scalars): marker = self.scalar_markers[i] if marker == "x": marker = marker.upper() for j, bundle in enumerate(bundles): ax.scatter( reliability1[i, j], reliability2[i, j], s=vut.marker_size, c=[self.color_dict[bundle]], marker=marker ) if show_error: if len(errors1.shape) > 2: xerr = errors1[:, j, i].reshape((2, 1)) else: xerr = errors1[i, j] if len(errors2.shape) > 2: yerr = errors2[:, j, i].reshape((2, 1)) else: yerr = errors2[i, j] ax.errorbar( reliability1[i, j], reliability2[i, j], xerr=xerr, yerr=yerr, c=[self.color_dict[bundle]], alpha=0.5, fmt="none" ) if i == 0: legend_labels.append(Patch( facecolor=self.color_dict[bundle], label=bundle)) legend_labels.append(Line2D( [0], [0], marker=marker, color='k', lw=0, markersize=10, label=scalar)) ax.set_xlabel(f"{analysis_label1} {rtype}", fontsize=vut.medium_font) ax.set_ylabel(f"{analysis_label2} {rtype}", fontsize=vut.medium_font) ax.tick_params( axis='x', which='major', labelsize=vut.medium_font) ax.tick_params( axis='y', which='major', labelsize=vut.medium_font) ax.set_ylim(0.2, 1) ax.set_xlim(0.2, 1) ax.plot([[0, 0], [1, 1]], [[0, 0], [1, 1]], '--', color='red') legend_labels.append(Line2D( [0], [0], linewidth=3, linestyle='--', color='red', label='X=Y')) if show_legend: fig.legend( handles=legend_labels, fontsize=vut.small_font - 6, bbox_to_anchor=(1.5, 2.0)) fig.tight_layout() return fig, ax
def visualize_gif_inline(fname, use_s3fs=False):
    """Display a GIF inline, optionally fetching it from S3 first.

    Parameters
    ----------
    fname : str
        Path to the GIF. When ``use_s3fs`` is True this is a remote path
        passed to ``s3fs.S3FileSystem().get``; otherwise it is given
        directly to ``IPython.display.Image``.
    use_s3fs : bool, optional
        Whether to download ``fname`` with s3fs before displaying it.
        Default: False
    """
    if not use_s3fs:
        display.display(display.Image(fname))
        return

    import s3fs
    fs = s3fs.S3FileSystem()
    # Download into a private temporary directory instead of a fixed,
    # shared "<tmp>/fig.gif" path. Concurrent calls then cannot overwrite
    # each other's file, and the download is removed afterwards.
    with tempfile.TemporaryDirectory() as tdir:
        local_fname = op.join(tdir, "fig.gif")
        fs.get(fname, local_fname)
        # Display inside the with-block so the file still exists when
        # IPython reads it.
        display.display(display.Image(local_fname))
def show_anatomical_slices(img_data, title):
    """
    Display the middle axial, coronal and sagittal slices of a volume.

    The axial slice occupies the top two rows of a 3x2 grid. The coronal
    and sagittal slices sit side by side in the bottom row. Each slice is
    transposed and drawn in grayscale with ``origin="lower"``, following
    https://nipy.org/nibabel/coordinate_systems.html

    Parameters
    ----------
    img_data : array
        3-D image volume; the middle index along each of its first three
        axes selects the displayed slices.
    title : str
        Title placed above the figure.
    """
    mid_z = img_data.shape[2] // 2
    axial = img_data[:, :, mid_z]
    mid_y = img_data.shape[1] // 2
    coronal = img_data[:, mid_y, :]
    mid_x = img_data.shape[0] // 2
    sagittal = img_data[mid_x, :, :]

    fig = plt.figure(constrained_layout=False)
    grid = fig.add_gridspec(nrows=3, ncols=2, wspace=0.01, hspace=0.01)

    # (grid cell, image) pairs, in the order the panels are created
    layout = (
        (grid[:-1, :], axial),
        (grid[2, 0], coronal),
        (grid[2, 1], sagittal),
    )
    for cell, plane in layout:
        panel = fig.add_subplot(cell)
        panel.imshow(plane.T, cmap="gray", origin="lower")
        panel.axis('off')

    fig.suptitle(title)
    plt.show()