# -*- coding: utf-8 -*-
# Original source: github.com/scilus/scilpy
# Copyright (c) 2012-- Sherbrooke Connectivity Imaging Lab [SCIL],
# Université de Sherbrooke.
# Licensed under the MIT License (https://opensource.org/licenses/MIT).
# Modified by John Kruper for pyAFQ
# OpenCL and cosine filtering removed
# Replaced with numba
import logging
import numpy as np
from dipy.data import get_sphere
from dipy.direction import peak_directions
from dipy.reconst.shm import sh_to_sf, sh_to_sf_matrix, sph_harm_ind_list
from numba import config, njit, prange, set_num_threads
from tqdm import tqdm
# Module-level logger registered under the "AFQ" name.
logger = logging.getLogger("AFQ")
# Names exported by a star-import of this module; the underscore-prefixed
# helpers defined below are intentionally left out.
__all__ = [
"unified_filtering",
"compute_asymmetry_index",
"compute_odd_power_map",
"compute_nufid_asym",
]
def _get_sh_order_and_fullness(ncoeffs):
"""
Get the order of the SH basis from the number of SH coefficients
as well as a boolean indicating if the basis is full.
"""
# the two curves (sym and full) intersect at ncoeffs = 1, in what
# case both bases correspond to order 1.
sym_order = (-3.0 + np.sqrt(1.0 + 8.0 * ncoeffs)) / 2.0
if sym_order.is_integer():
return sym_order, False
full_order = np.sqrt(ncoeffs) - 1.0
if full_order.is_integer():
return full_order, True
raise ValueError("Invalid number of coefficients for SH basis.")
def unified_filtering(
    sh_data,
    sphere,
    sh_basis="descoteaux07",
    is_legacy=True,
    sigma_spatial=1.0,
    sigma_align=0.8,
    sigma_angle=None,
    rel_sigma_range=0.2,
    n_threads=None,
    low_mem=False,
):
    """
    Unified asymmetric filtering as described in [1].

    Parameters
    ----------
    sh_data: ndarray
        SH coefficients image.
    sphere: str or DIPY sphere
        Sphere used for SH to SF projection. A string is interpreted as
        the name of a DIPY sphere and loaded with `get_sphere`.
    sh_basis: str
        SH basis definition used for input and output SH image.
        One of 'descoteaux07' or 'tournier07'.
        Default: 'descoteaux07'.
    is_legacy: bool
        Whether the legacy SH basis definition should be used.
        Default: True.
    sigma_spatial: float
        Standard deviation of spatial filter. Must be a positive number;
        `None` is rejected because this implementation has no mean
        filter fallback.
    sigma_align: float or None
        Standard deviation of alignment filter. `None` disables
        alignment filtering.
    sigma_angle: float or None
        Standard deviation of the angle filter. `None` disables
        angle filtering.
    rel_sigma_range: float or None
        Standard deviation of the range filter, relative to the
        range of SF amplitudes. `None` disables range filtering.
    n_threads: int or None
        Number of threads to use for numba. If None, the numba thread
        count is left untouched.
        Default: None.
    low_mem: bool
        Whether to use the low-memory version of the filtering.
        It will be between 50% and 100% slower.
        Default: False.

    Returns
    -------
    out_sh: ndarray
        Filtered SH coefficients. The SF to SH projection is built with
        `full_basis=True`, so the output is expressed in a full basis.

    Raises
    ------
    ValueError
        If `sigma_spatial` is None, or if any supplied sigma is <= 0.

    Notes
    -----
    Sets the global numba option `config.THREADING_LAYER` to
    "workqueue" and, when `n_threads` is given, calls
    `set_num_threads`.

    References
    ----------
    [1] Poirier and Descoteaux, 2024, "A Unified Filtering Method for
        Estimating Asymmetric Orientation Distribution Functions",
        Neuroimage, https://doi.org/10.1016/j.neuroimage.2024.120516
    """
    if isinstance(sphere, str):
        sphere = get_sphere(name=sphere)

    # Validate every parameter before any expensive work is started.
    if sigma_spatial is None:
        raise ValueError(
            "sigma_spatial cannot be None: the mean filter is not implemented."
        )
    if sigma_spatial <= 0.0:
        raise ValueError("sigma_spatial cannot be <= 0.")
    if sigma_align is not None and sigma_align <= 0.0:
        raise ValueError("sigma_align cannot be <= 0.")
    if sigma_angle is not None and sigma_angle <= 0.0:
        raise ValueError("sigma_angle cannot be <= 0.")
    if rel_sigma_range is not None and rel_sigma_range <= 0.0:
        raise ValueError("rel_sigma_range cannot be <= 0.")

    if n_threads is not None:
        set_num_threads(n_threads)

    # Work on a converted copy of the vertices so that the sphere handed
    # in by the caller is left unmodified.
    if low_mem:
        sh_data = np.ascontiguousarray(sh_data, dtype=np.float32)
        vertices = sphere.vertices.astype(np.float32)
    else:
        vertices = sphere.vertices.astype(np.float64)

    sh_order, full_basis = _get_sh_order_and_fullness(sh_data.shape[-1])

    # build filters
    config.THREADING_LAYER = "workqueue"
    uv_filter = _unified_filter_build_uv(sigma_angle, vertices)

    # The jitted builder always evaluates `sigma_align**2`, so it cannot
    # receive None: pass a placeholder value together with the disable
    # flag, which makes the alignment weight constant.
    disable_align = sigma_align is None
    nx_filter = _unified_filter_build_nx(
        vertices,
        sigma_spatial,
        1.0 if disable_align else sigma_align,
        False,
        disable_align,
    )

    # Forward projection follows the basis of the input image.
    B = sh_to_sf_matrix(
        sphere,
        sh_order_max=sh_order,
        basis_type=sh_basis,
        full_basis=full_basis,
        legacy=is_legacy,
        return_inv=False,
    )
    # Backward projection always targets a full basis.
    _, B_inv = sh_to_sf_matrix(
        sphere,
        sh_order_max=sh_order,
        basis_type=sh_basis,
        full_basis=True,
        legacy=is_legacy,
        return_inv=True,
    )

    # compute "real" sigma_range scaled by sf amplitudes
    # if rel_sigma_range is supplied
    sigma_range = None
    if rel_sigma_range is not None:
        sigma_range = rel_sigma_range * _get_sf_range(sh_data, B)

    if low_mem:
        return _unified_filter_call_lowmem(
            sh_data, nx_filter, uv_filter, sigma_range, B, B_inv, sphere
        )
    return _unified_filter_call_python(
        sh_data, nx_filter, uv_filter, sigma_range, B, B_inv, sphere
    )
@njit(fastmath=True, cache=True)
def _unified_filter_build_uv(sigma_angle, directions):
    """
    Build the angle filter, which weights each pair of sphere directions
    (u, v) according to the angle between them.

    Parameters
    ----------
    sigma_angle: float or None
        Standard deviation of the filter. Weights for angles larger than
        3 * sigma_angle are set to 0. With `None`, the identity matrix
        is returned.
    directions: ndarray
        Vertices of the sphere used for sampling the SF.

    Returns
    -------
    weights: ndarray
        Angle filter of shape (N_dirs, N_dirs).
    """
    if sigma_angle is not None:
        cosines = directions.dot(directions.T)
        # clip guards arccos against values slightly outside [-1, 1]
        angles = np.arccos(np.clip(cosines, -1.0, 1.0))
        weights = _evaluate_gaussian_distribution(angles, sigma_angle)
        # truncate the Gaussian beyond three standard deviations
        too_far = angles > (3.0 * sigma_angle)
        weights[too_far] = 0.0
        # the per-row sums are broadcast along the last axis
        weights /= np.sum(weights, axis=-1)
    else:
        weights = np.eye(len(directions), dtype=np.float32)
    return weights
@njit(fastmath=True, cache=True)
def _unified_filter_build_nx(
    directions,
    sigma_spatial,
    sigma_align,
    disable_spatial,
    disable_align,
    j_invariance=False,
):
    """
    Build the combined spatial and alignment filter.

    The result has shape (W, W, W, N_dirs) with
    W = 2 * round(3 * sigma_spatial) + 1. For every sphere direction the
    weights over the window are normalized to sum to 1.

    Original source: github.com/CHrlS98/aodf-toolkit
    Copyright (c) 2023 Charles Poirier
    Licensed under the MIT License (https://opensource.org/licenses/MIT).
    """
    directions = np.ascontiguousarray(directions.astype(np.float32))
    n_dirs = len(directions)
    half_width = int(round(3 * sigma_spatial))
    width = 2 * half_width + 1
    nx_weights = np.zeros((width, width, width, n_dirs), dtype=np.float32)

    for i in range(-half_width, half_width + 1):
        for j in range(-half_width, half_width + 1):
            for k in range(-half_width, half_width + 1):
                # offset from the window center to the current neighbour
                dxy = np.array([[i, j, k]], dtype=np.float32)
                len_xy = np.sqrt(dxy[0, 0] ** 2 + dxy[0, 1] ** 2 + dxy[0, 2] ** 2)

                # the length controls spatial weight
                if disable_spatial:
                    w_spatial = 1.0
                else:
                    w_spatial = np.exp(-(len_xy**2) / (2 * sigma_spatial**2))

                # the direction controls the align weight
                is_center = i == 0 and j == 0 and k == 0
                if is_center or disable_align:
                    # hack for main direction to have maximal weight:
                    # a zero angle turns into a weight of exp(0) below
                    w_align = np.zeros((1, n_dirs), dtype=np.float32)
                else:
                    dxy /= len_xy
                    w_align = np.arccos(
                        np.clip(np.dot(dxy, directions.T), -1.0, 1.0)
                    )  # 1, N
                w_align = np.exp(-(w_align**2) / (2 * sigma_align**2))

                nx_weights[half_width + i, half_width + j, half_width + k] = (
                    w_align * w_spatial
                )

    if j_invariance:
        # A filter is j-invariant if its prediction does not
        # depend on the content of the current voxel
        nx_weights[half_width, half_width, half_width, :] = 0.0

    # normalize the window of each sphere direction independently
    for ui in range(n_dirs):
        nx_weights[..., ui] /= np.sum(nx_weights[..., ui])

    return nx_weights
def _get_sf_range(sh_data, B_mat):
"""
Get the range of SF amplitudes for input `sh_data`.
Parameters
----------
sh_data: ndarray
Spherical harmonics coefficients image.
B_mat: ndarray
SH to SF projection matrix.
Returns
-------
sf_range: float
Range of SF amplitudes.
"""
sf = np.array([np.dot(i, B_mat) for i in sh_data], dtype=sh_data.dtype)
sf[sf < 0.0] = 0.0
sf_max = np.max(sf)
sf_min = np.min(sf)
return sf_max - sf_min
def _unified_filter_call_python(
    sh_data, nx_filter, uv_filter, sigma_range, B_mat, B_inv, sphere
):
    """
    Run the filtering one sphere direction at a time on a zero-padded,
    float64 copy of the SH image, then project the result back to SH.

    Parameters
    ----------
    sh_data: ndarray
        Input SH data.
    nx_filter: ndarray
        Combined spatial and alignment filter.
    uv_filter: ndarray
        Angle filter.
    sigma_range: float or None
        Standard deviation of range filter. None disables range filtering.
    B_mat: ndarray
        SH to SF projection matrix.
    B_inv: ndarray
        SF to SH projection matrix.
    sphere: DIPY sphere
        Sphere for SH to SF projection.

    Returns
    -------
    out_sh: ndarray
        Filtered output as SH coefficients.
    """
    n_dirs = len(sphere.vertices)
    mean_sf = np.zeros(sh_data.shape[:-1] + (n_dirs,))

    sh_data = np.ascontiguousarray(sh_data, dtype=np.float64)
    B_mat = np.ascontiguousarray(B_mat, dtype=np.float64)
    config.THREADING_LAYER = "workqueue"

    # pad each spatial axis by half the filter window, with zeros
    pad_x, pad_y, pad_z = (extent // 2 for extent in nx_filter.shape[:3])
    pad_widths = ((pad_x, pad_x), (pad_y, pad_y), (pad_z, pad_z), (0, 0))
    zero_padded = np.pad(sh_data, pad_widths, mode="constant")
    sh_data_padded = np.ascontiguousarray(zero_padded, dtype=np.float64)

    for u_sph_id in tqdm(range(n_dirs)):
        mean_sf[..., u_sph_id] = _correlate(
            sh_data, sh_data_padded, nx_filter, uv_filter, sigma_range, u_sph_id, B_mat
        )

    # SF to SH projection, slice by slice along the first axis
    return np.array(
        [np.dot(sf_slice, B_inv) for sf_slice in mean_sf], dtype=sh_data.dtype
    )
@njit(fastmath=True, parallel=True)
def _correlate(
    sh_data, sh_data_padded, nx_filter, uv_filter, sigma_range, u_index, B_mat
):
    """
    Filter the SH image along the single sphere direction selected by
    `u_index`.

    Parameters
    ----------
    sh_data: ndarray
        Input SH coefficients.
    sh_data_padded: ndarray
        Input SH coefficients, pre-padded.
    nx_filter: ndarray
        Combined spatial and alignment filter.
    uv_filter: ndarray
        Angle filter.
    sigma_range: float or None
        Standard deviation of range filter. None disables range filtering.
    u_index: int
        Index of the current sphere direction to process.
    B_mat: ndarray
        SH to SF projection matrix.

    Returns
    -------
    out_sf: ndarray
        Output SF amplitudes along the direction described by `u_index`.
    """
    # only directions with a non-zero angle weight contribute
    v_indices = np.flatnonzero(uv_filter[u_index])
    n_v = len(v_indices)

    nx_u = nx_filter[..., u_index]
    win_x, win_y, win_z = nx_u.shape[:3]
    off_x, off_y, off_z = win_x // 2, win_y // 2, win_z // 2

    out_sf = np.zeros(sh_data.shape[:3])

    # Explicit loops standing in for the projection of the padded image
    # on B_mat[:, u_index] and B_mat[:, v_indices].
    padded_shape = sh_data_padded.shape[:3]
    sf_u = np.zeros(padded_shape)
    sf_v = np.zeros(padded_shape + (n_v,))
    n_coeffs = sh_data_padded.shape[3]
    for i in prange(sh_data_padded.shape[0]):
        for j in range(sh_data_padded.shape[1]):
            for k in range(sh_data_padded.shape[2]):
                for c in range(n_coeffs):
                    coeff = sh_data_padded[i, j, k, c]
                    sf_u[i, j, k] += coeff * B_mat[c, u_index]
                    for vi in range(n_v):
                        sf_v[i, j, k, vi] += coeff * B_mat[c, v_indices[vi]]

    uv_u = uv_filter[u_index, v_indices]
    uv_u_4d = np.reshape(uv_u, (1, 1, 1, len(uv_u)))

    for ii in prange(out_sf.shape[0]):
        for jj in range(out_sf.shape[1]):
            for kk in range(out_sf.shape[2]):
                # window of neighbours; in padded coordinates the current
                # voxel sits at the window offset
                window = sf_v[ii : ii + win_x, jj : jj + win_y, kk : kk + win_z]
                center = sf_u[ii + off_x, jj + off_y, kk + off_z]
                x_range = window - center

                if sigma_range is None:
                    range_filter = np.ones_like(x_range)
                else:
                    range_filter = _evaluate_gaussian_distribution(x_range, sigma_range)

                # the resulting filter for the current voxel and v_index
                res_filter = range_filter * nx_u[..., None]
                res_filter = res_filter * uv_u_4d

                out_sf[ii, jj, kk] = np.sum(window * res_filter)
                out_sf[ii, jj, kk] /= np.sum(res_filter)

    return out_sf
def _unified_filter_call_lowmem(
    sh_data, nx_filter, uv_filter, sigma_range, B_mat, B_inv, sphere
):
    """
    Low-memory version of the filtering function.

    Works in float32 and filters one sphere direction at a time with
    `_correlate_low_mem`, then projects the SF result back to SH
    coefficients with `B_inv`.
    """
    n_dirs = len(sphere.vertices)
    mean_sf = np.zeros(sh_data.shape[:-1] + (n_dirs,), dtype=np.float32)

    sh_data = np.ascontiguousarray(sh_data, dtype=np.float32)
    B_mat = np.ascontiguousarray(B_mat, dtype=np.float32)
    config.THREADING_LAYER = "workqueue"

    for u_sph_id in tqdm(range(n_dirs)):
        mean_sf[..., u_sph_id] = _correlate_low_mem(
            sh_data, nx_filter, uv_filter, sigma_range, u_sph_id, B_mat
        )

    # SF to SH projection, slice by slice along the first axis
    projected = [np.dot(sf_slice, B_inv) for sf_slice in mean_sf]
    return np.array(projected, dtype=np.float32)
@njit(fastmath=True, parallel=True)
def _correlate_low_mem(sh_data, nx_filter, uv_filter, sigma_range, u_index, B_mat):
    """
    Low-memory version of the correlate function.

    SF amplitudes of the neighbours are recomputed on the fly instead of
    being stored for the whole volume. Neighbours that fall outside the
    image are skipped, and a voxel whose accumulated weight is not
    positive is set to 0.
    """
    # only directions with a non-zero angle weight contribute
    v_indices = np.flatnonzero(uv_filter[u_index])
    n_v = v_indices.shape[0]

    win_x = nx_filter.shape[0]
    win_y = nx_filter.shape[1]
    win_z = nx_filter.shape[2]
    off_x = win_x // 2
    off_y = win_y // 2
    off_z = win_z // 2
    nx_u = nx_filter[:, :, :, u_index]

    dim_x = sh_data.shape[0]
    dim_y = sh_data.shape[1]
    dim_z = sh_data.shape[2]
    n_coeffs = sh_data.shape[3]

    out_sf = np.zeros((dim_x, dim_y, dim_z))

    # gather the angle weights and projection columns that are needed
    uv_u = np.empty(n_v)
    proj_v = np.empty((n_coeffs, n_v))
    for vi in range(n_v):
        v_idx = v_indices[vi]
        uv_u[vi] = uv_filter[u_index, v_idx]
        for c in range(n_coeffs):
            proj_v[c, vi] = B_mat[c, v_idx]
    proj_u = np.empty(n_coeffs)
    for c in range(n_coeffs):
        proj_u[c] = B_mat[c, u_index]

    use_range = sigma_range is not None

    for ii in prange(dim_x):
        for jj in range(dim_y):
            for kk in range(dim_z):
                # SF amplitude of the current voxel along direction u
                center = 0.0
                for c in range(n_coeffs):
                    center += sh_data[ii, jj, kk, c] * proj_u[c]

                num = 0.0
                den = 0.0
                for wx in range(win_x):
                    i2 = ii + wx - off_x
                    if i2 < 0 or i2 >= dim_x:
                        continue
                    for wy in range(win_y):
                        j2 = jj + wy - off_y
                        if j2 < 0 or j2 >= dim_y:
                            continue
                        for wz in range(win_z):
                            k2 = kk + wz - off_z
                            if k2 < 0 or k2 >= dim_z:
                                continue
                            base_nx = nx_u[wx, wy, wz]
                            if base_nx == 0.0:
                                continue
                            for vi in range(n_v):
                                # SF amplitude of the neighbour along v
                                sf_v_val = 0.0
                                for c in range(n_coeffs):
                                    sf_v_val += sh_data[i2, j2, k2, c] * proj_v[c, vi]
                                if use_range:
                                    x = sf_v_val - center
                                    x_norm = x / sigma_range
                                    range_w = np.exp(-0.5 * x_norm * x_norm)
                                else:
                                    range_w = 1.0
                                w = base_nx * uv_u[vi] * range_w
                                num += sf_v_val * w
                                den += w

                out_sf[ii, jj, kk] = num / den if den > 0.0 else 0.0

    return out_sf
@njit(fastmath=True, cache=True)
def _evaluate_gaussian_distribution(x, sigma):
    """
    Evaluate the density of a zero-mean, one-dimensional Gaussian of
    standard deviation `sigma`.

    Parameters
    ----------
    x: ndarray or float
        Points where the distribution is evaluated.
    sigma: float
        Standard deviation.

    Returns
    -------
    out: ndarray or float
        Values at x.

    Raises
    ------
    ValueError
        If `sigma` is not strictly positive.
    """
    if sigma <= 0.0:
        raise ValueError("Sigma must be greater than 0.")
    norm_const = 1.0 / sigma / np.sqrt(2.0 * np.pi)
    exponent = -(x**2) / 2.0 / sigma**2
    return norm_const * np.exp(exponent)
def compute_asymmetry_index(sh_coeffs, mask):
    """
    Compute the asymmetry index (ASI) [1] of an asymmetric ODF volume
    expressed in full SH basis.

    Parameters
    ----------
    sh_coeffs: ndarray (x, y, z, ncoeffs)
        Input spherical harmonics coefficients.
    mask: ndarray (x, y, z), bool
        Mask inside which ASI should be computed.

    Returns
    -------
    asi_map: ndarray (x, y, z)
        Asymmetry index map. Voxels outside the mask, or whose
        coefficients are all zero, are set to 0.

    References
    ----------
    [1] S. Cetin Karayumak, E. Özarslan, and G. Unal,
        "Asymmetric Orientation Distribution Functions (AODFs)
        revealing intravoxel geometry in diffusion MRI"
        Magnetic Resonance Imaging, vol. 49, pp. 145-158, Jun. 2018,
        doi: https://doi.org/10.1016/j.mri.2018.03.006.
    """
    order, full_basis = _get_sh_order_and_fullness(sh_coeffs.shape[-1])
    _, l_list = sph_harm_ind_list(order, full_basis=full_basis)

    # +1 for even orders, -1 for odd orders
    parity = np.power(-1.0, l_list).reshape((1, 1, 1, len(l_list)))

    power = sh_coeffs**2
    total_power = np.sum(power, axis=-1)
    signed_power = np.sum(power * parity, axis=-1)

    # a voxel without any power would lead to a division by zero
    mask = np.logical_and(total_power > 0.0, mask)

    asi_map = np.zeros(sh_coeffs.shape[:-1])
    asi_map[mask] = signed_power[mask] / total_power[mask]

    # Negatives should not happen (amplitudes always positive)
    asi_map = np.clip(asi_map, 0.0, 1.0)
    return np.sqrt(1 - asi_map**2) * mask
def compute_odd_power_map(sh_coeffs, mask):
    """
    Compute the odd-power map [1] of an asymmetric ODF volume expressed
    in full SH basis: the L2 norm of the odd-order coefficients divided
    by the L2 norm of all coefficients.

    Parameters
    ----------
    sh_coeffs: ndarray (x, y, z, ncoeffs)
        Input spherical harmonics coefficients.
    mask: ndarray (x, y, z), bool
        Mask inside which odd-power map should be computed.

    Returns
    -------
    odd_power_map: ndarray (x, y, z)
        Odd-power map. Voxels outside the mask, or whose coefficients
        are all zero, are set to 0.

    References
    ----------
    [1] C. Poirier, E. St-Onge, and M. Descoteaux,
        "Investigating the Occurrence of Asymmetric Patterns in
        White Matter Fiber Orientation Distribution Functions"
        [Abstract], In: Proc. Intl. Soc. Mag. Reson. Med. 29 (2021),
        2021 May 15-20, Vancouver, BC, Abstract number 0865.
    """
    order, full_basis = _get_sh_order_and_fullness(sh_coeffs.shape[-1])
    _, l_list = sph_harm_ind_list(order, full_basis=full_basis)

    # boolean selector of the odd-order coefficients
    is_odd = np.reshape(l_list % 2 == 1, (1, 1, 1, -1))

    full_norm = np.linalg.norm(sh_coeffs, ord=2, axis=-1)
    odd_norm = np.linalg.norm(sh_coeffs * is_odd, ord=2, axis=-1)

    # a voxel with a zero norm would lead to a division by zero
    mask = np.logical_and(full_norm > 0, mask)

    asym_map = np.zeros(sh_coeffs.shape[:-1])
    asym_map[mask] = odd_norm[mask] / full_norm[mask]
    return asym_map
def compute_nufid_asym(
    sh_coeffs, sphere, csf, mask, sh_basis="descoteaux07", is_legacy=True
):
    """
    Number of fiber directions (nufid) map [1].

    Parameters
    ----------
    sh_coeffs: ndarray (x, y, z, ncoeffs)
        Input spherical harmonics coefficients.
    sphere: DIPY sphere
        Sphere for SH to SF projection.
    csf: ndarray (x, y, z)
        CSF probability map, used to guess the absolute threshold.
    mask: ndarray (x, y, z), bool
        Mask inside which the nufid map should be computed.
    sh_basis: str
        SH basis definition of `sh_coeffs`.
        Default: 'descoteaux07'.
    is_legacy: bool
        Whether the legacy SH basis definition should be used.
        Default: True.

    Returns
    -------
    nufid_data: ndarray (x, y, z), float32
        Number of non-zero ODF peaks per voxel; 0 outside the mask.

    References
    ----------
    [1] C. Poirier and M. Descoteaux,
        "Filtering Methods for Asymmetric ODFs:
        Where and How Asymmetry Occurs in the White Matter."
        bioRxiv. 2022 Jan 1; 2022.12.18.520881.
        doi: https://doi.org/10.1101/2022.12.18.520881
    """
    sh_order, full_basis = _get_sh_order_and_fullness(sh_coeffs.shape[-1])
    odf = sh_to_sf(
        sh_coeffs,
        sphere,
        sh_order_max=sh_order,
        basis_type=sh_basis,
        full_basis=full_basis,
        legacy=is_legacy,
    )

    # Guess at threshold from 2.0 * mean of ODF maxes in CSF
    csf_odf = odf[csf > 0.99]
    if csf_odf.shape[0] > 0:
        absolute_threshold = 2.0 * np.mean(np.max(csf_odf, axis=-1))
        odf[odf < absolute_threshold] = 0.0
    else:
        # Without CSF voxels the mean of an empty set is NaN and no
        # amplitude would compare below it: skip thresholding explicitly.
        logging.getLogger("AFQ").warning(
            "No voxel with CSF probability > 0.99; "
            "ODF amplitudes are not thresholded."
        )

    nufid_data = np.zeros(sh_coeffs.shape[:-1], dtype=np.float32)
    # only voxels inside the mask need a peak extraction
    for ii, jj, kk in tqdm(np.argwhere(mask)):
        _, peaks, _ = peak_directions(odf[ii, jj, kk], sphere, is_symmetric=False)
        nufid_data[ii, jj, kk] = np.count_nonzero(peaks)

    return nufid_data