Plotting Default Regions of Interest (ROIs) to Understand the Tracts

This script visualizes the default Regions of Interest (ROIs) of the white matter tracts that pyAFQ recognizes by default. It loads the predefined tract templates in MNI space, extracts the inclusion, exclusion, start, and end ROIs of each tract, and generates multi-panel figures showing sagittal, coronal, and axial views of these ROIs overlaid on the MNI T1w template brain.

The visualization helps understand the spatial relationships between tracts and their defining ROIs.

Import libraries and load the default tract templates

import numpy as np

import matplotlib
matplotlib.use('Agg')  # Use Agg backend for headless plotting
import matplotlib.pyplot as plt

import AFQ.data.fetch as afd
import AFQ.api.bundle_dict as abd


# Combine the default bundle definitions (``default_bd``) with the callosal
# ones (``callosal_bd``) into a single BundleDict. The code below looks up each
# tract's ROIs by name in it and iterates over its ``bundle_names``.
# NOTE(review): indexing this object loads/resamples the bundle's ROIs on
# every access (see BundleDict.__getitem__), so avoid repeated lookups.
templates = abd.default_bd() + abd.callosal_bd()
/opt/hostedtoolcache/Python/3.13.13/x64/lib/python3.13/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
2026-05-19 01:07:50,030	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.

Define a function to visualize ROIs for a specific tract

def _best_slice(roi_data, axis, default):
    """
    Return the index along ``axis`` whose slice has the largest ROI sum.

    Parameters
    ----------
    roi_data : ndarray
        3-D ROI volume.
    axis : int
        Axis to pick a slice along: 0 (sagittal), 1 (coronal) or 2 (axial).
    default : int
        Index returned when the ROI sums to zero along ``axis``.

    Returns
    -------
    int
        Slice index along ``axis``.
    """
    other_axes = tuple(a for a in range(roi_data.ndim) if a != axis)
    sums = np.sum(roi_data, axis=other_axes)
    return int(np.argmax(sums)) if np.any(sums) else default


def visualize_tract_rois(tract_name):
    """
    Visualize ROIs for a specific tract overlaid on the template brain.

    Every ROI of the tract (inclusion, exclusion, start and end) gets one
    column with three rows: a sagittal, a coronal and an axial slice. Each
    slice is chosen independently as the one with the largest summed ROI
    value along that axis. When the ROI is empty along that axis, the
    middle slice is used instead.

    Parameters
    ----------
    tract_name : str
        Name of the tract, e.g. one of ``templates.bundle_names``.

    Returns
    -------
    figures : list of matplotlib.figure.Figure
        A one-element list holding the figure with the visualization.

    Raises
    ------
    ValueError
        If ``tract_name`` is not in ``templates`` or the tract has no ROIs.
    """
    # Look the bundle up exactly once. ``templates[...]`` loads and resamples
    # every ROI of the bundle. A preceding ``in`` test would go through
    # ``Mapping.__contains__``, which also calls ``__getitem__``, doing all
    # of that work twice. Catching KeyError keeps the same "not found"
    # semantics as the ``in`` test.
    try:
        bundle_info = templates[tract_name]
    except KeyError:
        raise ValueError(
            f"Tract {tract_name} not found in templates.") from None

    # Collect all ROIs together with their role.
    all_roi_images = [
        (image, "Inclusion") for image in bundle_info.get("include", [])]
    all_roi_images += [
        (image, "Exclusion") for image in bundle_info.get("exclude", [])]
    if "start" in bundle_info:
        all_roi_images.append((bundle_info["start"], "Start"))
    if "end" in bundle_info:
        all_roi_images.append((bundle_info["end"], "End"))

    if not all_roi_images:
        raise ValueError(f"No ROIs found for tract {tract_name}")

    # Only read the template once we know there is something to draw.
    template_brain = afd.read_mni_template(
        resolution=1, mask=True, weight="T1w")
    template_data = template_brain.get_fdata()

    # Colormap for each ROI role.
    roi_type_colors = {
        "Inclusion": 'Greens',
        "Exclusion": 'Reds',
        "Start": 'Blues',
        "End": 'Purples'
    }
    # Row order of the views; the row index is also the slicing axis.
    view_names = ("Sagittal", "Coronal", "Axial")

    n_rois = len(all_roi_images)
    # squeeze=False keeps ``axes`` 2-D (3 x n_rois), even for a single ROI.
    fig, axes = plt.subplots(3, n_rois,
                             figsize=(n_rois * 4, 10),
                             squeeze=False)
    fig.suptitle(f"{tract_name} ROIs", fontsize=16)

    for col, (roi_img, roi_type_name) in enumerate(all_roi_images):
        roi_data = roi_img.get_fdata()  # read once, reused for every view
        roi_color = roi_type_colors[roi_type_name]
        # Fix the colour range per ROI so that a slice containing a single
        # value is not normalised down to the colormap's lightest colour.
        roi_max = float(roi_data.max())
        vmax = roi_max if roi_max > 0 else 1.0

        for row, view_name in enumerate(view_names):
            idx = _best_slice(roi_data, row, template_data.shape[row] // 2)
            ax = axes[row, col]
            ax.imshow(np.rot90(np.take(template_data, idx, axis=row)),
                      cmap='gray')
            roi_slice = np.rot90(np.take(roi_data, idx, axis=row))
            # Mask voxels outside the ROI so they are drawn transparent.
            # Otherwise they are painted in the colormap's lowest (near-white)
            # colour at 50% opacity, which washes out the whole slice.
            ax.imshow(np.ma.masked_where(roi_slice <= 0, roi_slice),
                      alpha=0.5, cmap=roi_color, vmin=0, vmax=vmax)
            if col == 0:
                ax.set_ylabel(view_name)

        axes[0, col].set_title(f"{tract_name}\n({roi_type_name})")

    # Hide ticks and spines for a clean look. ``ax.axis('off')`` would also
    # hide the row labels set above.
    for ax in axes.flat:
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

    fig.tight_layout()

    return [fig]

Create a visualization for each tract

# Render every bundle's ROIs and write each figure to "<bundle name>_<n>.png".
for bundle_name in templates.bundle_names:
    print(f"Visualizing ROIs for tract: {bundle_name}")
    figs = visualize_tract_rois(bundle_name)
    for fig_idx, fig in enumerate(figs):
        out_png = f"{bundle_name}_{fig_idx}.png"
        fig.savefig(out_png)
        # Close after saving so open figures do not pile up across bundles.
        plt.close(fig)
Visualizing ROIs for tract: Left Optic Radiation
Visualizing ROIs for tract: Right Optic Radiation
Visualizing ROIs for tract: Left Anterior Thalamic
Visualizing ROIs for tract: Right Anterior Thalamic
Visualizing ROIs for tract: Left Cingulum Cingulate
Visualizing ROIs for tract: Right Cingulum Cingulate
Visualizing ROIs for tract: Left Corticospinal
Visualizing ROIs for tract: Right Corticospinal
Visualizing ROIs for tract: Left Inferior Fronto-occipital
Visualizing ROIs for tract: Right Inferior Fronto-occipital
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[3], line 3
      1 for bundle_name in templates.bundle_names:
      2     print(f"Visualizing ROIs for tract: {bundle_name}")
----> 3     figs = visualize_tract_rois(bundle_name)
      4     for ii, fig in enumerate(figs):
      5         fig.savefig(f"{bundle_name}_{ii}.png")
      6         plt.close(fig)

Cell In[2], line 23, in visualize_tract_rois(tract_name)
     19 
     20     figures = []
     21 
     22     # Get the ROIs for this tract and hemisphere
---> 23     if tract_name not in templates:
     24         raise ValueError(f"Tract {tract_name} not found in templates.")
     25     bundle_info = templates[tract_name]
     26 

File <frozen _collections_abc>:818, in Mapping.__contains__(self, key)
    815 'Could not get source, probably due dynamically evaluated source code.'

File /opt/hostedtoolcache/Python/3.13.13/x64/lib/python3.13/site-packages/AFQ/api/bundle_dict.py:1369, in BundleDict.__getitem__(self, key)
   1367 if not self.keep_in_memory:
   1368     _item = self._dict[key].copy()
-> 1369     _res = self._cond_load_bundle(key, dry_run=True)
   1370     if _res is not None:
   1371         _item.update(_res)

File /opt/hostedtoolcache/Python/3.13.13/x64/lib/python3.13/site-packages/AFQ/api/bundle_dict.py:1457, in BundleDict._cond_load_bundle(self, b_name, dry_run)
   1455 else:
   1456     resample_to = self.resample_subject_to
-> 1457 return self.apply_to_rois(
   1458     b_name,
   1459     self._cond_load,
   1460     resample_to,
   1461     dry_run=dry_run,
   1462     apply_to_recobundles=True,
   1463 )

File /opt/hostedtoolcache/Python/3.13.13/x64/lib/python3.13/site-packages/AFQ/api/bundle_dict.py:1440, in BundleDict.apply_to_rois(self, b_name, *args, **kwargs)
   1431 def apply_to_rois(self, b_name, *args, **kwargs):
   1432     """
   1433     See: AFQ.api.bundle_dict.apply_to_roi_dict
   1434 
   (...)   1438         bundle name of bundle whose ROIs will be transformed.
   1439     """
-> 1440     return apply_to_roi_dict(self._dict[b_name], *args, **kwargs)

File /opt/hostedtoolcache/Python/3.13.13/x64/lib/python3.13/site-packages/AFQ/api/bundle_dict.py:1736, in apply_to_roi_dict(dict_, func, dry_run, apply_to_recobundles, apply_to_prob_map, *args, **kwargs)
   1734             changed_rois = []
   1735             for _roi in dict_[roi_type]:
-> 1736                 changed_rois.append(func(_roi, *args, **kwargs))
   1737             return_vals[roi_type] = changed_rois
   1738 if apply_to_recobundles and "recobundles" in dict_:

File /opt/hostedtoolcache/Python/3.13.13/x64/lib/python3.13/site-packages/AFQ/api/bundle_dict.py:1312, in BundleDict._cond_load(self, roi_or_sl, resample_to)
   1310 if isinstance(roi_or_sl, str):
   1311     if ".nii" in roi_or_sl:
-> 1312         return afd.read_resample_roi(roi_or_sl, resample_to=resample_to)
   1313     else:
   1314         return load_tractogram(
   1315             roi_or_sl, "same", bbox_valid_check=False
   1316         ).streamlines

File /opt/hostedtoolcache/Python/3.13.13/x64/lib/python3.13/site-packages/AFQ/data/fetch.py:655, in read_resample_roi(roi, resample_to, threshold)
    652     logger.info("Resampling skipped as affines already match.")
    653     return roi
--> 655 as_array = resample(
    656     roi.get_fdata(),
    657     resample_to,
    658     moving_affine=roi.affine,
    659     static_affine=resample_to.affine,
    660 ).get_fdata()
    661 if threshold:
    662     as_array = (as_array > threshold).astype(int)

File /opt/hostedtoolcache/Python/3.13.13/x64/lib/python3.13/site-packages/dipy/testing/decorators.py:201, in warning_for_keywords.<locals>.decorator.<locals>.wrapper(*args, **kwargs)
    194 # Check if the current version is within the warning range
    195 if (
    196     version.parse(from_version)
    197     <= version.parse(current_version)
    198     <= version.parse(until_version)
    199 ):
    200     # Convert positional to keyword arguments and issue a warning
--> 201     return convert_positional_to_keyword(func, args, kwargs)
    203 # If the version is greater than the until_version,
    204 # pass the arguments as they are
    205 elif version.parse(current_version) > version.parse(until_version):

File /opt/hostedtoolcache/Python/3.13.13/x64/lib/python3.13/site-packages/dipy/testing/decorators.py:192, in warning_for_keywords.<locals>.decorator.<locals>.wrapper.<locals>.convert_positional_to_keyword(func, args, kwargs)
    182         warnings.warn(
    183             f"Pass {positionally_passed_kwonly_args} as keyword args. "
    184             f"From version {until_version} passing these as positional "
   (...)    187             stacklevel=3,
    188         )
    190     return func(*positional_args, **corrected_kwargs)
--> 192 return func(*args, **kwargs)

File /opt/hostedtoolcache/Python/3.13.13/x64/lib/python3.13/site-packages/dipy/align/_public.py:398, in resample(moving, static, moving_affine, static_affine, between_affine)
    382 static, static_affine, moving, moving_affine, between_affine = (
    383     _handle_pipeline_inputs(
    384         moving,
   (...)    389     )
    390 )
    391 affine_map = AffineMap(
    392     between_affine,
    393     domain_grid_shape=static.shape,
   (...)    396     codomain_grid2world=moving_affine,
    397 )
--> 398 resampled = affine_map.transform(moving)
    399 return nib.Nifti1Image(resampled, static_affine)

File /opt/hostedtoolcache/Python/3.13.13/x64/lib/python3.13/site-packages/dipy/testing/decorators.py:201, in warning_for_keywords.<locals>.decorator.<locals>.wrapper(*args, **kwargs)
    194 # Check if the current version is within the warning range
    195 if (
    196     version.parse(from_version)
    197     <= version.parse(current_version)
    198     <= version.parse(until_version)
    199 ):
    200     # Convert positional to keyword arguments and issue a warning
--> 201     return convert_positional_to_keyword(func, args, kwargs)
    203 # If the version is greater than the until_version,
    204 # pass the arguments as they are
    205 elif version.parse(current_version) > version.parse(until_version):

File /opt/hostedtoolcache/Python/3.13.13/x64/lib/python3.13/site-packages/dipy/testing/decorators.py:192, in warning_for_keywords.<locals>.decorator.<locals>.wrapper.<locals>.convert_positional_to_keyword(func, args, kwargs)
    182         warnings.warn(
    183             f"Pass {positionally_passed_kwonly_args} as keyword args. "
    184             f"From version {until_version} passing these as positional "
   (...)    187             stacklevel=3,
    188         )
    190     return func(*positional_args, **corrected_kwargs)
--> 192 return func(*args, **kwargs)

File /opt/hostedtoolcache/Python/3.13.13/x64/lib/python3.13/site-packages/dipy/align/imaffine.py:432, in AffineMap.transform(self, image, interpolation, image_grid2world, sampling_grid_shape, sampling_grid2world, resample_only)
    382 @warning_for_keywords()
    383 def transform(
    384     self,
   (...)    391     resample_only=False,
    392 ):
    393     """Transform the input image from co-domain to domain space.
    394 
    395     By default, the transformed image is sampled at a grid defined by
   (...)    430 
    431     """
--> 432     transformed = self._apply_transform(
    433         image,
    434         interpolation=interpolation,
    435         image_grid2world=image_grid2world,
    436         sampling_grid_shape=sampling_grid_shape,
    437         sampling_grid2world=sampling_grid2world,
    438         resample_only=resample_only,
    439         apply_inverse=False,
    440     )
    441     return np.array(transformed)

File /opt/hostedtoolcache/Python/3.13.13/x64/lib/python3.13/site-packages/dipy/testing/decorators.py:201, in warning_for_keywords.<locals>.decorator.<locals>.wrapper(*args, **kwargs)
    194 # Check if the current version is within the warning range
    195 if (
    196     version.parse(from_version)
    197     <= version.parse(current_version)
    198     <= version.parse(until_version)
    199 ):
    200     # Convert positional to keyword arguments and issue a warning
--> 201     return convert_positional_to_keyword(func, args, kwargs)
    203 # If the version is greater than the until_version,
    204 # pass the arguments as they are
    205 elif version.parse(current_version) > version.parse(until_version):

File /opt/hostedtoolcache/Python/3.13.13/x64/lib/python3.13/site-packages/dipy/testing/decorators.py:192, in warning_for_keywords.<locals>.decorator.<locals>.wrapper.<locals>.convert_positional_to_keyword(func, args, kwargs)
    182         warnings.warn(
    183             f"Pass {positionally_passed_kwonly_args} as keyword args. "
    184             f"From version {until_version} passing these as positional "
   (...)    187             stacklevel=3,
    188         )
    190     return func(*positional_args, **corrected_kwargs)
--> 192 return func(*args, **kwargs)

File /opt/hostedtoolcache/Python/3.13.13/x64/lib/python3.13/site-packages/dipy/align/imaffine.py:379, in AffineMap._apply_transform(self, image, interpolation, image_grid2world, sampling_grid_shape, sampling_grid2world, resample_only, apply_inverse)
    377 if interpolation == "linear":
    378     image = image.astype(np.float64)
--> 379 transformed = _transform_method[(dim, interpolation)](image, shape, affine=comp)
    380 return transformed

KeyboardInterrupt: