# Original source: github.com/SlicerDMRI/whitematteranalysis
# Copyright 2026 BWH and 3D Slicer contributors
# Licensed under 3D Slicer license (BSD style; https://github.com/SlicerDMRI/whitematteranalysis/blob/master/License.txt) # noqa
# Modified by John Kruper for pyAFQ
# Modifications:
# 1. Only mean distance included, and mean distance replaced with numba version.
# 2. Uses atlas data from dictionary and numpy files rather than pickled files,
# to avoid additional dependencies.
# 3. Added function to move template streamlines
# to subject space to calculate distances.
import numpy as np
import scipy
from dipy.io.stateful_tractogram import Space
from numba import njit, prange
import AFQ.data.fetch as afd
import AFQ.recognition.utils as abu
import AFQ.utils.streamlines as aus
@njit(parallel=True, fastmath=True, cache=True)
[docs]
def _compute_mean_euclidean_matrix(group_n, group_m):
len_n = group_n.shape[0]
len_m = group_m.shape[0]
num_points = group_n.shape[1]
dist_matrix = np.empty((len_n, len_m), dtype=np.float32)
for i in prange(len_n):
for j in range(len_m):
sum_dist = 0.0
sum_dist_ref = 0.0
for k in range(num_points):
dx = group_n[i, k, 0] - group_m[j, k, 0]
dx_ref = group_n[i, k, 0] + group_m[j, k, 0]
dy = group_n[i, k, 1] - group_m[j, k, 1]
dz = group_n[i, k, 2] - group_m[j, k, 2]
sum_dist += np.sqrt(dx * dx + dy * dy + dz * dz)
sum_dist_ref += np.sqrt(dx_ref * dx_ref + dy * dy + dz * dz)
mean_d = sum_dist / num_points
mean_d_ref = sum_dist_ref / num_points
final_d = min(mean_d, mean_d_ref)
dist_matrix[i, j] = final_d * final_d
return dist_matrix.T
[docs]
def _distance_to_similarity(distance, sigmasq):
similarities = np.exp(-distance / (sigmasq))
return similarities
[docs]
def _rectangular_similarity_matrix(fgarray_sub, fgarray_atlas, sigma):
distances = _compute_mean_euclidean_matrix(fgarray_sub, fgarray_atlas)
sigmasq = sigma * sigma
similarity_matrix = _distance_to_similarity(distances, sigmasq)
return similarity_matrix
[docs]
def spectral_atlas_label(
sub_fgarray,
atlas_fgarray,
atlas_data=None,
sigma_multiplier=1.0,
cluster_indices=None,
):
"""
Use an existing atlas to label a new streamlines.
Parameters
----------
sub_fgarray : ndarray
Resampled fiber group to be labeled.
atlas_fgarray : ndarray
Resampled atlas to use for labelling.
atlas_data : dict, optional
Precomputed atlas data formatted as a dictionary of arrays and floats.
See `afd.read_org800_templates` as a reference.
sigma_multiplier : float, optional
Multiplier for the sigma value used in computing the similarity
matrix. Default is 1.0.
cluster_indices : list of int, optional
If provided, only these cluster indices from the atlas will be used
for labeling. Default is None, which uses all clusters.
Returns
-------
tuple of (ndarray, ndarray)
Cluster indices for all the fibers and their embedding
"""
if atlas_data is None:
atlas_data = afd.read_org800_templates(load_trx=False)
number_fibers = sub_fgarray.shape[0]
sz = atlas_fgarray.shape[0]
# Compute fiber similarities.
B = _rectangular_similarity_matrix(
sub_fgarray, atlas_fgarray, sigma=atlas_data["sigma"] * sigma_multiplier
)
# Do Normalized Cuts transform of similarity matrix.
# row sum estimate for current B part of the matrix
row_sum_2 = np.sum(B, axis=0) + np.dot(atlas_data["row_sum_matrix"], B)
# This happens plenty in our cases. Why?
# Maybe a probabilistic vs UKF thing?
# In practice, this is not an issue since we just set to a small value.
if any(row_sum_2 <= 0):
row_sum_2[row_sum_2 < 0] = 1e-4
# Normalized cuts normalization
row_sum = np.concatenate((atlas_data["row_sum_1"], row_sum_2))
dhat = np.sqrt(np.divide(1, row_sum))
B = np.multiply(B, np.outer(dhat[0:sz], dhat[sz:].T))
# Compute embedding using eigenvectors
V = np.dot(
np.dot(B.T, atlas_data["e_vec"]), np.diag(np.divide(1.0, atlas_data["e_val"]))
)
V = np.divide(V, atlas_data["e_vec_norm"])
n_eigen = int(atlas_data["number_of_eigenvectors"])
embed = np.zeros((number_fibers, n_eigen))
for i in range(0, n_eigen):
embed[:, i] = np.divide(V[:, -(i + 2)], V[:, -1])
# Label streamlines using centroids from atlas
if cluster_indices is not None:
centroids = atlas_data["centroids"][cluster_indices, :]
cluster_idx, _ = scipy.cluster.vq.vq(embed, centroids)
cluster_idx = np.array([cluster_indices[i] for i in cluster_idx])
else:
cluster_idx, _ = scipy.cluster.vq.vq(embed, atlas_data["centroids"])
return cluster_idx, embed
[docs]
def subcluster_by_atlas(
sub_trk,
mapping,
dwi_ref,
cluster_indices,
atlas_data=None,
n_points=20,
batch_size=int(5e4),
):
"""
Use an existing atlas to label a new set of streamlines, and return the
cluster indices for each streamline.
Parameters
----------
sub_trk : StatefulTractogram
streamlines to be labeled.
mapping : DIPY or pyAFQ mapping
Mapping to use to move streamlines.
dwi_ref : Nifti1Image
Image defining reference for where the atlas streamlines move to.
cluster_indices : list of int
Cluster indices from the atlas to use for labeling.
atlas_data : dict, optional
Precomputed atlas data formatted as a dictionary of arrays and floats.
See `afd.read_org800_templates` as a reference.
n_points : int, optional
Number of points to resample streamlines to for labeling. Default is 20.
batch_size : int, optional
Number of streamlines to process in a batch. Default is 50,000.
"""
if atlas_data is None:
atlas_data = afd.read_org800_templates()
atlas_sft = atlas_data["tracks_reoriented"]
moved_atlas_sft = aus.move_streamlines(
atlas_sft, "subject", mapping, dwi_ref, to_space=Space.RASMM
)
atlas_fgarray = np.array(abu.resample_tg(moved_atlas_sft.streamlines, n_points))
sub_trk.to_rasmm()
n_sub = len(sub_trk.streamlines)
if n_sub <= batch_size:
sub_fgarray = np.asarray(
abu.resample_tg(sub_trk.streamlines, n_points), dtype=np.float32
)
cluster_idxs, _ = spectral_atlas_label(
sub_fgarray,
atlas_fgarray,
atlas_data=atlas_data,
cluster_indices=cluster_indices,
)
return cluster_idxs
all_idxs = np.empty(n_sub, dtype=np.int64)
for start in range(0, n_sub, batch_size):
end = min(start + batch_size, n_sub)
batch_sls = sub_trk.streamlines[start:end]
batch_fgarray = np.asarray(
abu.resample_tg(batch_sls, n_points), dtype=np.float32
)
batch_idxs, _ = spectral_atlas_label(
batch_fgarray,
atlas_fgarray,
atlas_data=atlas_data,
cluster_indices=cluster_indices,
)
all_idxs[start:end] = batch_idxs
return all_idxs