Source code for AFQ.viz.altair

import altair as alt
import numpy as np
import scipy.stats as stats

from AFQ.viz.utils import COLOR_DICT


def altair_color_dict(names_to_include=None):
    """
    Build an Altair-ready mapping from bundle name to CSS color string.

    Parameters
    ----------
    names_to_include : collection of str, optional
        Bundle names to keep. When None (the default), every entry of
        COLOR_DICT is kept.

    Returns
    -------
    dict
        Maps each retained name to a string of the form "rgb(R,G,B)",
        where each channel is the corresponding COLOR_DICT component
        multiplied by 255 and truncated with ``int``.
    """

    def _as_rgb_string(color):
        # Only the first three components are used; any further
        # component (e.g. alpha) is ignored.
        return (
            f"rgb({int(color[0] * 255)},"
            f"{int(color[1] * 255)},"
            f"{int(color[2] * 255)})"
        )

    return {
        name: _as_rgb_string(color)
        for name, color in COLOR_DICT.items()
        if names_to_include is None or name in names_to_include
    }
def combined_profiles_df_to_altair_df(profiles, tissue_properties=None):
    """
    Summarize a multi-subject profiles dataframe into the long format
    used for plotting with Altair.

    Parameters
    ----------
    profiles : pandas.DataFrame
        Must contain the columns "tractID", "nodeID", "subjectID" and one
        column per entry of `tissue_properties`. A "sessionID" column is
        carried through the reshaping step if present. The input is not
        modified; a copy is made.
    tissue_properties : list of str, optional
        Columns to summarize. Default: ["dti_fa", "dti_md"].
        If "dki_md" or "dti_md" is requested, that column is multiplied
        by 1000 before summarizing.

    Returns
    -------
    pandas.DataFrame
        One row per (tractID, nodeID, TP) with columns:
        "mean", "CI_lower", "CI_upper" (95% normal-approximation interval
        of the mean), "IQR_lower", "IQR_upper" (25th / 75th percentiles),
        "Hemi" ("Left", "Right" or "Callosal"), "Bundle Name" (tractID
        with a leading "Left " / "Right " removed) and "TP" (tissue
        property name upper-cased with underscores replaced by spaces).
    """
    if tissue_properties is None:
        tissue_properties = ["dti_fa", "dti_md"]

    profiles = profiles.copy()
    for md_column in ("dki_md", "dti_md"):
        if md_column in tissue_properties:
            profiles[md_column] = profiles[md_column] * 1000.0

    id_vars = ["tractID", "nodeID", "subjectID"]
    if "sessionID" in profiles.columns:
        id_vars.append("sessionID")

    profiles = profiles.melt(
        id_vars=id_vars,
        value_vars=tissue_properties,
        var_name="TP",
        value_name="Value",
    )

    def calculate_95CI(x):
        # 95% interval of the mean under a normal approximation, using
        # the standard error std(x) / sqrt(n).
        return stats.norm.interval(
            0.95, loc=np.mean(x), scale=np.std(x) / np.sqrt(len(x))
        )

    # Distinctly named callables (rather than several anonymous lambdas)
    # keep the named aggregation unambiguous.
    def ci_lower(x):
        return calculate_95CI(x)[0]

    def ci_upper(x):
        return calculate_95CI(x)[1]

    def iqr_lower(x):
        return x.quantile(0.25)

    def iqr_upper(x):
        return x.quantile(0.75)

    # Collapse over subjects (and sessions): one row per tract/node/TP.
    profiles = (
        profiles.groupby(["tractID", "nodeID", "TP"])["Value"]
        .agg(
            mean="mean",
            CI_lower=ci_lower,
            CI_upper=ci_upper,
            IQR_lower=iqr_lower,
            IQR_upper=iqr_upper,
        )
        .reset_index()
    )

    def get_hemi(tract_name):
        # Names of the form "Left X" / "Right X" are the ones get_bname
        # strips, so the prefix must decide the hemisphere for them;
        # looking only at the last character labelled them "Callosal".
        if tract_name.startswith("Left "):
            return "Left"
        if tract_name.startswith("Right "):
            return "Right"
        # Fallback kept from the previous behavior: a trailing "L" / "R"
        # (e.g. "ARC_L"). Slicing avoids an IndexError on an empty name.
        last_char = tract_name[-1:]
        if last_char == "L":
            return "Left"
        if last_char == "R":
            return "Right"
        return "Callosal"

    def get_bname(s):
        if s.startswith("Left "):
            return s[5:]
        elif s.startswith("Right "):
            return s[6:]
        return s

    def formal_tp(tp_name):
        return tp_name.upper().replace("_", " ")

    profiles["Hemi"] = profiles["tractID"].apply(get_hemi)
    profiles["Bundle Name"] = profiles["tractID"].apply(get_bname)
    profiles["TP"] = profiles["TP"].apply(formal_tp)

    return profiles
def altair_df_to_chart(
    profiles,
    position_domain=(20, 80),
    column_count=1,
    font_size=20,
    line_size=10,
    row_label_angle=90,
    bundle_list=None,
    legend_line_size=5,
    alt_x_kwargs=None,
    alt_y_kwargs=None,
    **kwargs,
):
    """
    Given a dataframe formatted for Altair, probably from
    combined_profiles_df_to_altair_df, return a chart.

    The chart is a grid: one row per bundle and one column per tissue
    property (sorted by name). Each cell overlays the mean profile with
    dashed IQR bounds and semi-transparent 95% CI bounds.

    Parameters
    ----------
    profiles : pandas.DataFrame
        Must contain the columns "nodeID", "TP", "Bundle Name", "mean",
        "IQR_lower", "IQR_upper", "CI_lower" and "CI_upper".
    position_domain : tuple of (lower, upper), optional
        Only rows with lower <= nodeID < upper are plotted.
        Default: (20, 80).
    column_count : int, optional
        Number of columns in the legend. Default: 1.
    font_size : int, optional
        Font size for axis labels, axis titles, legend and chart titles.
        Default: 20.
    line_size : int, optional
        Line thickness for every plotted line. Default: 10.
    row_label_angle : int, optional
        Accepted for backward compatibility; currently not used.
        Default: 90.
    bundle_list : iterable of str, optional
        "Bundle Name" values to plot, in row order. Default: every
        bundle name present after the position filter.
    legend_line_size : int, optional
        Scales the legend symbols (stroke width is 10x, symbol size is
        100x this value). Default: 5.
    alt_x_kwargs, alt_y_kwargs : dict, optional
        Extra keyword arguments for alt.X / alt.Y; they override the
        defaults set here. Default: no extras.
    **kwargs
        Additional encoding channels (e.g. color="Hemi") applied to
        every layer.

    Returns
    -------
    altair.VConcatChart

    Example
    -------
    call_results = results[results.Hemi == "Callosal"]
    stand_results = results[results.Hemi != "Callosal"]
    prof_chart = altair_df_to_chart(call_results)
    prof_chart.save("supp_chart_call.png", dpi=300)
    prof_chart = altair_df_to_chart(stand_results, column_count=2, color="Hemi")
    prof_chart.save("supp_chart_stand.png", dpi=300)
    """
    if alt_y_kwargs is None:
        alt_y_kwargs = {}
    if alt_x_kwargs is None:
        alt_x_kwargs = {}

    # NOTE: this changes Altair's global data-transformer settings.
    alt.data_transformers.disable_max_rows()

    lower, upper = position_domain[0], position_domain[1]
    profiles = profiles[(profiles.nodeID >= lower) & (profiles.nodeID < upper)]

    tp_units = {
        "DKI AWF": "",
        "DKI FA": "",
        "DKI MD": " (µm²/ms)",
        "DKI MK": "",
        "DTI FA": "",
        "DTI MD": " (µm²/ms)",
    }

    if bundle_list is None:
        bundle_list = profiles["Bundle Name"].unique()
    bundle_list = list(bundle_list)

    # The x-axis title/labels belong on the last row that is actually
    # drawn, which is determined by bundle_list, not by how many bundles
    # happen to be present in the dataframe.
    last_row = len(bundle_list) - 1
    tp_names = sorted(profiles.TP.unique())

    # (column, extra mark_line options) for the bounds drawn over the mean.
    band_layers = (
        ("IQR_lower", {"strokeDash": [1, 1]}),
        ("IQR_upper", {"strokeDash": [1, 1]}),
        ("CI_lower", {}),
        ("CI_upper", {}),
    )

    row_charts = []
    for jj, b_name in enumerate(bundle_list):
        row_dataframe = profiles[profiles["Bundle Name"] == b_name]
        charts = []
        for ii, tp in enumerate(tp_names):
            this_dataframe = row_dataframe[row_dataframe.TP == tp]

            # Column headers only on the first row; tissue properties
            # without a known unit get no unit suffix.
            title_name = tp + tp_units.get(tp, "") if jj == 0 else ""
            # Bundle name labels the y axis of the first column only.
            y_axis_title = b_name if ii == 0 else ""
            is_last_row = jj == last_row
            x_axis_title = "Position (%)" if is_last_row else ""

            y_kwargs = {
                "scale": alt.Scale(zero=False),
                "title": y_axis_title,
                **alt_y_kwargs,
            }
            x_kwargs = {
                "axis": alt.Axis(title=x_axis_title, labels=is_last_row),
                **alt_x_kwargs,
            }

            prof_chart = (
                alt.Chart(this_dataframe, title=title_name)
                .mark_line(size=line_size)
                .encode(
                    y=alt.Y("mean", **y_kwargs),
                    x=alt.X("nodeID", **x_kwargs),
                    **kwargs,
                )
            )
            for column, mark_options in band_layers:
                prof_chart = prof_chart + alt.Chart(this_dataframe).mark_line(
                    size=line_size, opacity=0.5, **mark_options
                ).encode(
                    y=alt.Y(column, **y_kwargs),
                    x=alt.X("nodeID", **x_kwargs),
                    **kwargs,
                )
            charts.append(prof_chart)
        row_charts.append(alt.HConcatChart(hconcat=charts))

    return (
        alt.VConcatChart(vconcat=row_charts)
        .configure_axis(labelFontSize=font_size, titleFontSize=font_size, labelLimit=0)
        .configure_legend(
            labelFontSize=font_size,
            titleFontSize=font_size,
            titleLimit=0,
            labelLimit=0,
            columns=column_count,
            symbolStrokeWidth=legend_line_size * 10,
            symbolSize=legend_line_size * 100,
            orient="right",
        )
        .configure_title(fontSize=font_size)
    )