Source code for AFQ.nn.multiaxial

import logging
import os.path as op

import nibabel as nib
import numpy as np
from tqdm import tqdm

from AFQ.data.fetch import afq_home, fetch_multiaxial_models
from AFQ.nn.utils import prepare_t1_for_nn, resample_output

logger = logging.getLogger("AFQ")


__all__ = ["run_multiaxial", "extract_brain_mask", "multiaxial"]


[docs] def multiaxial( ort, img, model_sagittal, model_axial, model_coronal, consensus_model, onnx_kwargs ): """ Perform multiaxial segmentation using three ONNX models and a consensus model [1]. Parameters ---------- img : ndarray 3D T1 image to segment. model_sagittal : str Path to sagittal ONNX model. model_axial : str Path to axial ONNX model. model_coronal : str Path to coronal ONNX model. consensus_model : str Path to consensus ONNX model. onnx_kwargs : dict ONNX kwargs to use for inference. Returns ------- pred : ndarray Segmentation labels for each coordinate. References ---------- [1] Birnbaum, Andrew M., et al. "Full-head segmentation of MRI with abnormal brain anatomy: model and data release." Journal of Medical Imaging 12.5 (2025): 054001-054001. """ img = img.astype(np.float32) coords = _create_coord_grid().astype(np.float32) pbar = tqdm(total=4) input_ = img[..., None] sagittal_results = _run_onnx_model(ort, model_sagittal, input_, coords, onnx_kwargs) pbar.update(1) input_ = np.swapaxes(img, 0, 1)[..., None] coronal_results = np.swapaxes( _run_onnx_model(ort, model_coronal, input_, coords, onnx_kwargs), 0, 1 ) pbar.update(1) input_ = np.transpose(img, (2, 0, 1))[..., None] axial_results = np.transpose( _run_onnx_model(ort, model_axial, input_, coords, onnx_kwargs), (1, 2, 0, 3) ) pbar.update(1) X = np.concatenate( [img[..., None], sagittal_results, coronal_results, axial_results], -1 ) sess = ort.InferenceSession(consensus_model, **onnx_kwargs) input_name = sess.get_inputs()[0].name output_name = sess.get_outputs()[0].name yhat = sess.run([output_name], {input_name: X[None, ...]})[0] pbar.update(1) pbar.close() pred = np.argmax(yhat[0], -1) return pred
def _run_onnx_model(ort, model, input_, coords, onnx_kwargs): sess = ort.InferenceSession(model, **onnx_kwargs) input_name = sess.get_inputs()[0].name coord_name = sess.get_inputs()[1].name output_name = sess.get_outputs()[0].name results = np.zeros((256, 256, 256, 7), np.float32) for ii in tqdm(range(input_.shape[0]), leave=False): results[ii] = sess.run( [output_name], {input_name: input_[ii : ii + 1], coord_name: coords[ii : ii + 1]}, )[0] return results def _create_coord_grid(): x, y, z = (256, 256, 256) ac = (128, 128, 128) # assume anterior commissure meshgrid = np.meshgrid( np.linspace(0, x - 1, x), np.linspace(0, y - 1, y), np.linspace(0, z - 1, z), indexing="ij", ) coordinates = np.stack(meshgrid, axis=-1) - np.array(ac) coords = np.concatenate( [ coordinates, np.ones( (coordinates.shape[0], coordinates.shape[1], coordinates.shape[2], 1) ), ], axis=-1, ) coords = coords[:, :, :, :3] coords = coords / 256.0 return coords.astype(np.int16) def _get_multiaxial_model(): model_dict = {} for model_name in [ "sagittal_model", "axial_model", "coronal_model", "consensus_model", ]: model_path = op.join(afq_home, "multiaxial_models_onnx", model_name + ".onnx") if not op.exists(model_path): fetch_multiaxial_models() model_dict[model_name] = model_path return model_dict
[docs] def run_multiaxial(ort, t1_img, onnx_kwargs): """ Run the multiaxial model. """ model_dict = _get_multiaxial_model() t1_data, conformed_affine = prepare_t1_for_nn(t1_img) logger.info("Running multiaxial T1w segmentation...") output = multiaxial( ort, t1_data, model_dict["sagittal_model"], model_dict["axial_model"], model_dict["coronal_model"], model_dict["consensus_model"], onnx_kwargs, ) output_img = resample_output(output, conformed_affine, t1_img) return output_img
[docs] def extract_brain_mask(predictions): """ Extract brain mask from multiaxial predictions. Parameters ---------- predictions : Nifti1Image Multiaxial segmentation predictions. Returns ------- bm_data : ndarray Brain mask data. """ gm = predictions.get_fdata() == 2 wm = predictions.get_fdata() == 3 csf = predictions.get_fdata() == 4 bm_data = wm | gm | csf return bm_data
def extract_pve(prediction): """ Extract PVE maps from multiaxial predictions. Parameters ---------- prediction : Nifti1Image Multiaxial segmentation predictions. Returns ------- pve_img : Nifti1Image PVE image with CSF, GM, and WM segmentations. """ gm = prediction.get_fdata() == 2 wm = prediction.get_fdata() == 3 csf = prediction.get_fdata() == 4 pve_data = np.zeros(prediction.get_fdata().shape + (3,), dtype=np.float32) pve_data[..., 0] = csf.astype(np.float32) pve_data[..., 1] = gm.astype(np.float32) pve_data[..., 2] = wm.astype(np.float32) pve_img = nib.Nifti1Image(pve_data.astype(np.float32), prediction.affine) return pve_img