{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "# Plotting Default Regions of Interest (ROIs) to Understand the Tracts\n",
    "This script visualizes the default Regions of Interest (ROIs) for the\n",
    "white matter tracts we recognize by default in pyAFQ. It loads the predefined\n",
    "tract templates in MNI space, extracts the inclusion, exclusion, start, and\n",
    "end ROIs of each tract, and generates multi-panel figures showing sagittal,\n",
    "coronal, and axial views of these ROIs overlaid on the MNI template T1w brain.\n",
    "\n",
    "The visualization helps in understanding the spatial relationships between\n",
    "tracts and their defining ROIs.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Import libraries and load the default tract templates\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "import matplotlib\n",
    "matplotlib.use('Agg')  # Use Agg backend for headless plotting\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import AFQ.data.fetch as afd\n",
    "import AFQ.api.bundle_dict as abd\n",
    "\n",
    "\n",
    "templates = abd.default_bd() + abd.callosal_bd()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define a function to visualize ROIs for a specific tract\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def visualize_tract_rois(tract_name, template_data=None):\n",
    "    \"\"\"\n",
    "    Visualize ROIs for a specific tract overlaid on the template brain.\n",
    "\n",
    "    Each ROI gets one column of the figure. The three rows show the\n",
    "    sagittal, coronal and axial slices through the ROI that contain the\n",
    "    most ROI signal (the volume midpoint is used if the ROI is empty).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    tract_name : str\n",
    "        Name of the tract; must be a key of ``templates``.\n",
    "    template_data : ndarray, optional\n",
    "        3D voxel data of the background brain. If None (default), the\n",
    "        1 mm masked T1w MNI template is read with\n",
    "        ``afd.read_mni_template``. Pass a preloaded array to avoid\n",
    "        re-reading the template for every tract.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    figures : list of matplotlib.figure.Figure\n",
    "        A one-element list holding the figure built for this tract.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``tract_name`` is not in ``templates`` or the tract has no\n",
    "        include, exclude, start or end ROIs.\n",
    "    \"\"\"\n",
    "    # Validate the tract before doing any (potentially slow) template I/O\n",
    "    if tract_name not in templates:\n",
    "        raise ValueError(f\"Tract {tract_name} not found in templates.\")\n",
    "    bundle_info = templates[tract_name]\n",
    "\n",
    "    if template_data is None:\n",
    "        template_brain = afd.read_mni_template(\n",
    "            resolution=1, mask=True, weight=\"T1w\")\n",
    "        template_data = template_brain.get_fdata()\n",
    "\n",
    "    # Collect all ROIs together with their role in defining the tract\n",
    "    all_roi_images = []\n",
    "    if 'include' in bundle_info:\n",
    "        all_roi_images.extend(\n",
    "            (image, \"Inclusion\") for image in bundle_info['include'])\n",
    "    if 'exclude' in bundle_info:\n",
    "        all_roi_images.extend(\n",
    "            (image, \"Exclusion\") for image in bundle_info['exclude'])\n",
    "    if 'start' in bundle_info:\n",
    "        all_roi_images.append((bundle_info[\"start\"], \"Start\"))\n",
    "    if 'end' in bundle_info:\n",
    "        all_roi_images.append((bundle_info[\"end\"], \"End\"))\n",
    "\n",
    "    if not all_roi_images:\n",
    "        raise ValueError(f\"No ROIs found for tract {tract_name}\")\n",
    "\n",
    "    # One column per ROI; squeeze=False keeps `axes` 2D even for one ROI\n",
    "    n_rois = len(all_roi_images)\n",
    "    fig, axes = plt.subplots(3, n_rois, figsize=(n_rois * 4, 10),\n",
    "                             squeeze=False)\n",
    "    fig.suptitle(f\"{tract_name} ROIs\", fontsize=16)\n",
    "\n",
    "    # Fallback slice index per axis: the midpoint of the template volume\n",
    "    x, y, z = template_data.shape\n",
    "    midpoints = (x // 2, y // 2, z // 2)\n",
    "\n",
    "    def get_max_slice(roi_data, axis):\n",
    "        \"\"\"Index along `axis` of the slice with the largest ROI sum.\"\"\"\n",
    "        other_axes = tuple(a for a in range(3) if a != axis)\n",
    "        sums = np.sum(roi_data, axis=other_axes)\n",
    "        return np.argmax(sums) if np.any(sums) else midpoints[axis]\n",
    "\n",
    "    # Colormap per ROI role\n",
    "    roi_type_colors = {\n",
    "        \"Inclusion\": 'Greens',\n",
    "        \"Exclusion\": 'Reds',\n",
    "        \"Start\": 'Blues',\n",
    "        \"End\": 'Purples'\n",
    "    }\n",
    "    # Row order of the figure; the row index is also the slicing axis\n",
    "    view_names = ('Sagittal', 'Coronal', 'Axial')\n",
    "\n",
    "    for col, (roi_img, roi_type_name) in enumerate(all_roi_images):\n",
    "        # Load the voxel data once; reuse it for slice search and plotting\n",
    "        roi_data = roi_img.get_fdata()\n",
    "        roi_color = roi_type_colors[roi_type_name]\n",
    "\n",
    "        for row, view_name in enumerate(view_names):\n",
    "            # Fix axis `row` at its best slice, e.g. [best_x, :, :]\n",
    "            index = [slice(None)] * 3\n",
    "            index[row] = get_max_slice(roi_data, axis=row)\n",
    "            index = tuple(index)\n",
    "\n",
    "            ax = axes[row, col]\n",
    "            ax.imshow(np.rot90(template_data[index]), cmap='gray')\n",
    "            mask = np.rot90(roi_data[index])\n",
    "            if np.any(mask):\n",
    "                # Mask zero voxels so they render transparent; otherwise the\n",
    "                # colormap's lowest color is blended over the whole slice.\n",
    "                # vmin=0 keeps nonzero ROI voxels at the strong end of the\n",
    "                # colormap, as in an unmasked overlay.\n",
    "                ax.imshow(np.ma.masked_equal(mask, 0), alpha=0.5,\n",
    "                          cmap=roi_color, vmin=0)\n",
    "\n",
    "            if col == 0:\n",
    "                ax.set_ylabel(view_name)\n",
    "            if row == 0:\n",
    "                ax.set_title(f\"{tract_name}\\n({roi_type_name})\")\n",
    "\n",
    "            # Hide ticks and frame but keep the row label; ax.axis('off')\n",
    "            # would hide the y-label as well.\n",
    "            ax.set_xticks([])\n",
    "            ax.set_yticks([])\n",
    "            for spine in ax.spines.values():\n",
    "                spine.set_visible(False)\n",
    "\n",
    "    fig.tight_layout()\n",
    "\n",
    "    return [fig]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create a visualization for each tract\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Read the background template once and reuse it for every tract\n",
    "mni_t1w_data = afd.read_mni_template(\n",
    "    resolution=1, mask=True, weight=\"T1w\").get_fdata()\n",
    "\n",
    "for bundle_name in templates.bundle_names:\n",
    "    print(f\"Visualizing ROIs for tract: {bundle_name}\")\n",
    "    figs = visualize_tract_rois(bundle_name, template_data=mni_t1w_data)\n",
    "    for ii, fig in enumerate(figs):\n",
    "        fig.savefig(f\"{bundle_name}_{ii}.png\")\n",
    "        plt.close(fig)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}