Accelerating Multi-Shell Multi-Tissue CSD with Ray

Multi-shell multi-tissue constrained spherical deconvolution (MSMT-CSD) is a powerful model that simultaneously reconstructs the configuration of fibers and the volume fractions of different tissue compartments (Jeurissen et al., 2014). However, because it requires a convex optimization problem to be solved at every voxel, it can also be a performance bottleneck. This example demonstrates how to fit MSMT-CSD while using Ray to parallelize the fit and accelerate processing.
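To make the per-voxel cost concrete, here is a deliberately simplified sketch, not DIPY's algorithm: each voxel's multi-shell signal is modeled as a non-negative mixture of three isotropic toy "tissue" signals, and the fractions are estimated by non-negative least squares, a small convex problem. The mono-exponential responses, their diffusivities, and the `fit_voxel_fractions` helper are illustrative assumptions. The actual MSMT-CSD fit solves a larger constrained quadratic program over fODF spherical-harmonic coefficients (the cvxpy warnings in the output further down come from that solver). The exhaustive search over support sets used here is exact but only practical for a handful of compartments.

```python
import itertools

import numpy as np


def fit_voxel_fractions(R, y):
    """Non-negative least squares: argmin_f ||R @ f - y||^2 subject to f >= 0.

    R has one column per tissue response and one row per measurement. The
    problem is solved exactly by trying every support set (subset of tissues
    allowed to be non-zero), which is only practical for a few compartments.
    """
    n_tissues = R.shape[1]
    best_f = np.zeros(n_tissues)
    best_cost = float(np.sum(y ** 2))  # cost of the all-zero solution
    for k in range(1, n_tissues + 1):
        for support in itertools.combinations(range(n_tissues), k):
            cols = list(support)
            # Unconstrained least squares restricted to this support
            coef, *_ = np.linalg.lstsq(R[:, cols], y, rcond=None)
            if np.any(coef < 0):
                continue  # infeasible: violates the non-negativity constraint
            f = np.zeros(n_tissues)
            f[cols] = coef
            cost = float(np.sum((R @ f - y) ** 2))
            if cost < best_cost:
                best_f, best_cost = f, cost
    return best_f


# Toy isotropic "responses": mono-exponential decay with an assumed diffusivity per tissue
bvals = np.array([0.0, 500.0, 1000.0, 2000.0, 3000.0])  # s/mm^2
diffusivities = np.array([0.0007, 0.0012, 0.003])  # "WM", "GM", "CSF" in mm^2/s (illustrative)
R = np.exp(-np.outer(bvals, diffusivities))  # shape (n_measurements, n_tissues)

# Every voxel is an independent convex problem, so cost grows linearly with voxel count
rng = np.random.default_rng(0)
true_fractions = rng.dirichlet(np.ones(3), size=100)  # 100 toy voxels
signals = true_fractions @ R.T
fractions = np.array([fit_voxel_fractions(R, y) for y in signals])
```

Because every voxel is an independent problem, the list comprehension over voxels in the last line is exactly the kind of work that can be spread across parallel workers.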

We demonstrate this directly with DIPY's API (based on an example in the DIPY documentation).

In [1]:
import os.path as op
import numpy as np
import matplotlib.pyplot as plt

import AFQ.data.fetch as afd
from AFQ.models.QBallTP import anisotropic_power

from dipy.core.gradients import gradient_table, unique_bvals_tolerance
from dipy.data import get_sphere
from dipy.io.gradients import read_bvals_bvecs
from dipy.io.image import load_nifti
from dipy.reconst.mcsd import (
    MultiShellDeconvModel,
    mask_for_response_msmt,
    multi_shell_fiber_response,
    response_from_mask_msmt,
)

Download dataset

We will use multi-shell data from the HBN POD2 dataset (Richie-Halford et al., 2022). This dataset also includes T1-weighted data and tissue-type segmentations that can be used to constrain the estimation of the MSMT response functions. For simplicity, we use DIPY's data-driven response estimation without this information. For completeness, however, we point out that the segmentations could be used to restrict the voxels from which the response functions are computed, which would yield more refined response functions.

In [2]:
sphere = get_sphere(name="symmetric724")
study_dir = afd.fetch_hbn_preproc(["NDARAA948VFH"])[1]
sub_dir = op.join(study_dir, "derivatives/qsiprep/sub-NDARAA948VFH")

fraw = op.join(sub_dir, "ses-HBNsiteRU/dwi/sub-NDARAA948VFH_ses-HBNsiteRU_acq-64dir_space-T1w_desc-preproc_dwi.nii.gz")
fbval = op.join(sub_dir, "ses-HBNsiteRU/dwi/sub-NDARAA948VFH_ses-HBNsiteRU_acq-64dir_space-T1w_desc-preproc_dwi.bval")
fbvec = op.join(sub_dir, "ses-HBNsiteRU/dwi/sub-NDARAA948VFH_ses-HBNsiteRU_acq-64dir_space-T1w_desc-preproc_dwi.bvec")
t1_fname = op.join(sub_dir, "anat/sub-NDARAA948VFH_desc-preproc_T1w.nii.gz")
brain_mask = op.join(sub_dir, "anat/sub-NDARAA948VFH_desc-brain_mask.nii.gz")
gm_seg = op.join(sub_dir, "anat/sub-NDARAA948VFH_space-MNI152NLin2009cAsym_label-GM_probseg.nii.gz")
wm_seg = op.join(sub_dir, "anat/sub-NDARAA948VFH_space-MNI152NLin2009cAsym_label-WM_probseg.nii.gz")
csf_seg = op.join(sub_dir, "anat/sub-NDARAA948VFH_space-MNI152NLin2009cAsym_label-CSF_probseg.nii.gz")
In [3]:
data, affine = load_nifti(fraw)
bvals, bvecs = read_bvals_bvecs(fbval, fbvec)
gtab = gradient_table(bvals, bvecs=bvecs)
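MSMT-CSD needs one response per shell, so the b-values in `gtab` have to be grouped into shells even though measured b-values scatter slightly around their nominal values. A minimal sketch of tolerance-based grouping, similar in spirit to DIPY's `unique_bvals_tolerance` used later but not its exact implementation: the `group_shells` helper, its use of the mean as each shell's representative value, and the 20 s/mm² default tolerance are assumptions of this sketch.

```python
import numpy as np


def group_shells(bvals, tol=20.0):
    """Group b-values into shells and return one representative (the mean) per shell.

    Sorted b-values that lie within `tol` of their predecessor join the same shell.
    """
    values = np.sort(np.asarray(bvals, dtype=float))
    shells = [[values[0]]]
    for b in values[1:]:
        if b - shells[-1][-1] <= tol:
            shells[-1].append(b)  # close to the previous value: same shell
        else:
            shells.append([b])  # gap larger than tol: start a new shell
    return np.array([np.mean(shell) for shell in shells])


# Measured b-values scattered around nominal values (illustrative numbers)
example_bvals = np.array([0, 5, 995, 1000, 1005, 2000, 1990, 0])
shells = group_shells(example_bvals)
n_diffusion_shells = int(np.sum(shells > 50))  # shells other than b ~ 0
```

On the real data, `group_shells(gtab.bvals)` should return a b ≈ 0 group plus one entry per diffusion-weighted shell. Because this single-linkage rule can chain together densely sampled b-values, it is only appropriate for well-separated shells.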
In [4]:
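# Binarize the tissue probability maps at 0.5 (these masks are not used in the
# response estimation below; see the note on segmentations above)
csf = np.where(load_nifti(csf_seg)[0] > 0.5, 1, 0)
gm = np.where(load_nifti(gm_seg)[0] > 0.5, 1, 0)
wm = np.where(load_nifti(wm_seg)[0] > 0.5, 1, 0)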
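The binarized `csf`, `gm`, and `wm` arrays above are not used by the response estimation that follows. One way they could refine it, sketched here under stated assumptions, is to intersect each data-driven response mask with the corresponding thresholded tissue map. This assumes all arrays share one voxel grid; the file names indicate that the segmentations are in `space-MNI152NLin2009cAsym` while the DWI is in `space-T1w`, so in practice the segmentations would first have to be resampled into diffusion space. `refine_response_masks` is a hypothetical helper, and the demo uses synthetic arrays.

```python
import numpy as np


def refine_response_masks(data_masks, tissue_probs, threshold=0.5, min_voxels=10):
    """Intersect data-driven response masks with thresholded tissue probability maps.

    Both inputs are dicts keyed by tissue name; every array must share one voxel
    grid (an assumption of this sketch). If the intersection keeps fewer than
    `min_voxels` voxels, the data-driven mask is returned unchanged.
    """
    refined = {}
    for tissue, mask in data_masks.items():
        mask = np.asarray(mask) > 0
        agree = mask & (np.asarray(tissue_probs[tissue]) > threshold)
        refined[tissue] = agree if agree.sum() >= min_voxels else mask
    return refined


# Synthetic stand-ins for the FA/MD-based masks and the tissue probability maps
rng = np.random.default_rng(0)
shape = (8, 8, 4)
tissues = ("wm", "gm", "csf")
data_masks = {t: rng.random(shape) > 0.5 for t in tissues}
tissue_probs = {t: rng.random(shape) for t in tissues}
refined = refine_response_masks(data_masks, tissue_probs)
```

With real, co-registered arrays, the refined masks could be passed to `response_from_mask_msmt` in place of the masks returned by `mask_for_response_msmt`; we have not checked whether that function expects a particular mask dtype.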

Estimate response functions

In [5]:
mask_wm, mask_gm, mask_csf = mask_for_response_msmt(
    gtab,
    data,
    roi_radii=10,
    wm_fa_thr=0.7,
    gm_fa_thr=0.3,
    csf_fa_thr=0.15,
    gm_md_thr=0.001,
    csf_md_thr=0.0032,
)

response_wm, response_gm, response_csf = response_from_mask_msmt(
    gtab, data, mask_wm, mask_gm, mask_csf
)

print(response_wm)
print(response_gm)
print(response_csf)
/Users/john/miniconda3/envs/afq11/lib/python3.11/site-packages/dipy/testing/decorators.py:192: UserWarning: Some b-values are higher than 1200.
        The DTI fit might be affected.
  return func(*args, **kwargs)
[[1.75411039e-03 5.61681042e-04 5.61681042e-04 1.73149780e+02]
 [1.25543581e-03 4.17972190e-04 4.17972190e-04 1.73149780e+02]]
[[1.39038934e-03 1.22287510e-03 1.22287510e-03 2.69328430e+02]
 [9.84485172e-04 8.78908523e-04 8.78908523e-04 2.69328430e+02]]
[[1.36327748e-03 1.26285548e-03 1.26285548e-03 2.95602264e+02]
 [9.82441622e-04 9.19054215e-04 9.19054215e-04 2.95602264e+02]]
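According to DIPY's documentation for `response_from_mask_msmt`, as we understand it, each printed response has one row per diffusion-weighted shell (two here), holding three tensor eigenvalues followed by the mean b0 signal. The sketch below computes FA and MD from those eigenvalues as a quick sanity check. It also restates the selection rules we assume `mask_for_response_msmt` applies with the thresholds passed above (WM: FA > wm_fa_thr; GM: FA < gm_fa_thr and MD < gm_md_thr; CSF: FA < csf_fa_thr and MD > csf_md_thr). `fa_md` and `classify_voxels` are hypothetical helpers, not DIPY functions.

```python
import numpy as np


def fa_md(evals):
    """Fractional anisotropy and mean diffusivity from tensor eigenvalues on the last axis."""
    evals = np.asarray(evals, dtype=float)
    md = evals.mean(axis=-1, keepdims=True)
    # FA = sqrt(3/2) * ||evals - MD|| / ||evals||
    num = np.linalg.norm(evals - md, axis=-1)
    den = np.linalg.norm(evals, axis=-1)
    fa = np.sqrt(1.5) * num / np.maximum(den, np.finfo(float).tiny)
    return fa, md[..., 0]


def classify_voxels(fa, md, wm_fa_thr=0.7, gm_fa_thr=0.3, csf_fa_thr=0.15,
                    gm_md_thr=0.001, csf_md_thr=0.0032):
    """Assumed FA/MD rules for selecting WM, GM, and CSF response voxels."""
    wm = fa > wm_fa_thr
    gm = (fa < gm_fa_thr) & (md < gm_md_thr)
    csf = (fa < csf_fa_thr) & (md > csf_md_thr)
    return wm, gm, csf


# Eigenvalues from the first row of each response printed above
response_evals = {
    "wm": [1.75411039e-03, 5.61681042e-04, 5.61681042e-04],
    "gm": [1.39038934e-03, 1.22287510e-03, 1.22287510e-03],
    "csf": [1.36327748e-03, 1.26285548e-03, 1.26285548e-03],
}
for tissue, ev in response_evals.items():
    fa, md = fa_md(ev)
    print(f"{tissue}: FA = {float(fa):.2f}, MD = {float(md):.2e} mm^2/s")
```

The WM row has a much higher FA than the GM and CSF rows, as expected for a response meant to represent a single coherent fiber population.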

Reconstruction with MSMT-CSD

Finally, we fit the MSMT-CSD model to the data. Passing engine="ray" to the model's fit method tells DIPY to parallelize the fit across chunks of voxels, which can substantially speed up processing (see the figures in the accompanying article for details).
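The idea behind engine="ray" can be sketched without Ray: because each voxel is fitted independently, the voxels can be split into chunks, each chunk handed to a worker, and the results written back to their voxel positions in order. This sketch uses the standard library's `concurrent.futures` as a stand-in for Ray; `parallel_voxel_fit`, `mean_signal_fit`, and the chunking scheme are illustrative and are not DIPY's internal implementation. A thread pool keeps the example simple; for CPU-bound Python fits, process pools or Ray workers (which the log output below shows DIPY using) are what sidestep the global interpreter lock.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def mean_signal_fit(signals):
    """Stand-in per-voxel 'fit' (mean signal per voxel); a real fit would solve CSD here."""
    return signals.mean(axis=-1)


def parallel_voxel_fit(data, fit_fn, mask=None, n_chunks=4, max_workers=4):
    """Fit every masked voxel of `data` (shape (..., n_measurements)) in parallel chunks."""
    spatial_shape = data.shape[:-1]
    if mask is None:
        mask = np.ones(spatial_shape, dtype=bool)
    voxels = data[mask]  # (n_voxels, n_measurements)
    chunks = np.array_split(voxels, n_chunks)  # contiguous chunks keep voxel order
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        results = list(pool.map(fit_fn, chunks))  # map() yields results in input order
    fitted = np.concatenate(results)
    out = np.zeros(spatial_shape + fitted.shape[1:], dtype=fitted.dtype)
    out[mask] = fitted  # scatter results back to their voxel positions
    return out


# Synthetic stand-in for one slice: 16 x 16 voxels, 30 measurements each
rng = np.random.default_rng(0)
slice_data = rng.random((16, 16, 30))
slice_mask = slice_data[..., 0] > 0.2
fitted = parallel_voxel_fit(slice_data, mean_signal_fit, mask=slice_mask)
```

The number of chunks is a trade-off: too few leave workers idle, while too many add per-task scheduling overhead.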

In [6]:
ubvals = unique_bvals_tolerance(gtab.bvals)
response_mcsd = multi_shell_fiber_response(
    sh_order_max=8,
    bvals=ubvals,
    wm_rf=response_wm,
    gm_rf=response_gm,
    csf_rf=response_csf,
)

mcsd_model = MultiShellDeconvModel(gtab, response_mcsd)
mcsd_fit = mcsd_model.fit(data[:, :, 50], engine="ray")  # Fit only slice 50 along the third axis to keep the example fast

# We can use the anisotropic power map to visualize the fit
plt.imshow(anisotropic_power(mcsd_fit.shm_coeff))
2025-05-28 14:37:26,899	INFO worker.py:1841 -- Started a local Ray instance.
(_parallel_fit_worker pid=20868) /Users/john/miniconda3/envs/afq11/lib/python3.11/site-packages/cvxpy/problems/problem.py:1481: UserWarning: Solution may be inaccurate. Try another solver, adjusting the solver settings, or solve with verbose=True for more information.
(_parallel_fit_worker pid=20868)   warnings.warn(
(_parallel_fit_worker pid=20865) /Users/john/miniconda3/envs/afq11/lib/python3.11/site-packages/cvxpy/problems/problem.py:1481: UserWarning: Solution may be inaccurate. Try another solver, adjusting the solver settings, or solve with verbose=True for more information. [repeated 2x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)
(_parallel_fit_worker pid=20865)   warnings.warn( [repeated 2x across cluster]
(_parallel_fit_worker pid=20862) /Users/john/miniconda3/envs/afq11/lib/python3.11/site-packages/cvxpy/problems/problem.py:1481: UserWarning: Solution may be inaccurate. Try another solver, adjusting the solver settings, or solve with verbose=True for more information. [repeated 2x across cluster]
(_parallel_fit_worker pid=20862)   warnings.warn( [repeated 2x across cluster]
Out[6]:
<matplotlib.image.AxesImage at 0x1662ea110>
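With sh_order_max=8 and a symmetric (even-order) spherical-harmonic basis, each even order l contributes 2l + 1 coefficients, giving 1 + 5 + 9 + 13 + 17 = 45 coefficients per voxel; assuming `mcsd_fit.shm_coeff` holds only the white-matter fODF coefficients, its last axis should have that length. The anisotropic power map collapses those coefficients into one scalar per voxel. The sketch below follows the definition used by DIPY's `anisotropic_power` as we recall it: sum the mean squared coefficient magnitude of each order l ≥ 2, take the log relative to a normalization factor (assumed default 1e-5), and clip negative values to zero. We have not verified that pyAFQ's `AFQ.models.QBallTP.anisotropic_power` is identical; `anisotropic_power_sketch` and `n_sym_sh_coeffs` are hypothetical helpers.

```python
import numpy as np


def n_sym_sh_coeffs(sh_order_max):
    """Coefficient count of an even-order (symmetric) SH basis: sum of 2l + 1 over even l."""
    return sum(2 * l + 1 for l in range(0, sh_order_max + 1, 2))


def anisotropic_power_sketch(sh_coeffs, norm_factor=1e-5):
    """Log anisotropic power from symmetric SH coefficients on the last axis, clipped at 0.

    Assumes the coefficients are ordered by increasing even order l, with 2l + 1
    coefficients per order, starting from the single l = 0 coefficient.
    """
    sh_coeffs = np.asarray(sh_coeffs, dtype=float)
    n_coeffs = sh_coeffs.shape[-1]
    ap = np.zeros(sh_coeffs.shape[:-1])
    start, order = 1, 2  # skip the isotropic l = 0 coefficient
    while start < n_coeffs:
        stop = start + 2 * order + 1
        # Mean squared magnitude of the 2l + 1 coefficients of this order
        ap += np.sum(sh_coeffs[..., start:stop] ** 2, axis=-1) / (2 * order + 1)
        start, order = stop, order + 2
    safe_ap = np.where(ap > 0, ap, 1.0)  # avoid log(0); those voxels are set to 0 below
    log_ap = np.where(ap > 0, np.log(safe_ap) - np.log(norm_factor), 0.0)
    return np.maximum(log_ap, 0.0)


# A purely isotropic fODF has no power in orders l >= 2
iso = np.zeros((2, n_sym_sh_coeffs(8)))
iso[:, 0] = 1.0
aniso = iso.copy()
aniso[:, 3] = 0.5  # put some power into one l = 2 coefficient
```

Because the l = 0 term is excluded, the map highlights voxels whose fODFs deviate from a sphere, regardless of their overall signal level.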