import logging
import os.path as op
from time import time
import immlib
import nibabel as nib
import numpy as np
import pandas as pd
import AFQ.utils.streamlines as aus
import AFQ.utils.volume as auv
from AFQ._fixes import gaussian_weights
from AFQ.recognition.recognize import recognize
from AFQ.tasks.decorators import as_file
from AFQ.tasks.utils import get_default_args, get_fname, get_tp, str_to_desc, with_name
from AFQ.utils.path import drop_extension, write_json
try:
from trx.io import load as load_trx
from trx.io import save as save_trx
from trx.trx_file_memmap import TrxFile
except ModuleNotFoundError:
has_trx = False
import gzip
import shutil
from tempfile import mkdtemp
from dipy.io.streamline import load_tractogram, save_tractogram
from dipy.stats.analysis import afq_profile
from dipy.tracking.streamline import set_number_of_points, values_from_volume
from dipy.tracking.utils import length
# Shared "AFQ" logger used by every task in this module for progress
# messages and warnings.
logger = logging.getLogger(name="AFQ")
def _load_tractogram_for_recognition(streamlines, reference, segmentation_params):
    """Load the tractography file ``streamlines``; return ``(tg, is_trx)``.

    For ``.trx`` input the file is only loaded when ``nb_streamlines`` or
    ``nb_points`` is set in ``segmentation_params``; otherwise the path
    itself is returned and handed on to ``recognize``.

    Raises
    ------
    ValueError
        If the file extension is not one of .trk, .tck, .vtk, .trx, .tck.gz.
    """
    if streamlines.endswith((".trk", ".tck", ".vtk")):
        return load_tractogram(streamlines, reference, bbox_valid_check=False), False
    if streamlines.endswith(".trx"):
        if segmentation_params["nb_streamlines"] or segmentation_params["nb_points"]:
            return load_trx(streamlines, reference), True
        return streamlines, True
    if streamlines.endswith(".tck.gz"):
        # uncompress tck.gz to a temporary tck that load_tractogram can read
        temp_dir = mkdtemp()
        temp_tck = op.join(temp_dir, op.basename(streamlines)[: -len(".gz")])
        logger.info(f"Temporary tck file created at: {temp_tck}")
        try:
            with gzip.open(streamlines, "rb") as f_in, open(temp_tck, "wb") as f_out:
                shutil.copyfileobj(f_in, f_out)
            # initialize stateful tractogram from tck file:
            tg = load_tractogram(temp_tck, reference, bbox_valid_check=False)
        finally:
            # The tractogram is held in memory once loaded (dipy's
            # load_tractogram reads non-lazily), so the temporary copy is
            # removed instead of being left behind in the temp directory.
            shutil.rmtree(temp_dir, ignore_errors=True)
        return tg, False
    raise ValueError(
        f"Unsupported tractography file format: {streamlines}. "
        "Expected one of .trk, .tck, .vtk, .trx or .tck.gz"
    )


def _serialize_seg_params(segmentation_params):
    """Convert segmentation parameters into JSON-friendly values.

    Scalars (int/float/bool/str) are kept as-is, lists/tuples become lists
    of strings, dict-valued parameters are flattened into the top level
    with stringified values, and anything else is stringified.
    """
    seg_params_out = {}
    for arg_name, value in segmentation_params.items():
        if isinstance(value, (int, float, bool, str)):
            seg_params_out[arg_name] = value
        elif isinstance(value, (list, tuple)):
            seg_params_out[arg_name] = [str(v) for v in value]
        elif isinstance(value, dict):
            # NOTE(review): flattening means a nested key can overwrite a
            # top-level parameter of the same name.
            for k, v in value.items():
                seg_params_out[k] = str(v)
        else:
            seg_params_out[arg_name] = str(value)
    return seg_params_out


@immlib.calc("bundles")
@as_file("_desc-bundles_tractography")
def segment(
    structural_imap, data_imap, mapping_imap, tractography_imap, segmentation_params
):
    """
    full path to a trk/trx file containing
    segmented streamlines, labeled by bundle

    Parameters
    ----------
    segmentation_params : dict, optional
        The parameters for segmentation.
        Defaults to using the default behavior of the seg.Segmentation object.
    """
    bundle_dict = data_imap["bundle_dict"]
    reg_template = data_imap["reg_template"]
    streamlines = tractography_imap["streamlines"]
    tg, is_trx = _load_tractogram_for_recognition(
        streamlines, data_imap["dwi"], segmentation_params
    )
    # Invalid-streamline removal is only applied to StatefulTractograms;
    # TRX input (a TrxFile or a path) is passed to recognize unchanged.
    if not is_trx:
        indices_to_remove, _ = tg.remove_invalid_streamlines()
        if len(indices_to_remove) > 0:
            logger.warning(f"{len(indices_to_remove)} invalid streamlines removed")

    start_time = time()
    bundles, bundle_meta = recognize(
        tg,
        data_imap["dwi"],
        mapping_imap["mapping"],
        bundle_dict,
        reg_template,
        **segmentation_params,
    )
    seg_sft = aus.SegmentedSFT(bundles)
    if len(seg_sft.sft) < 1:
        raise ValueError("Fatal: No bundles recognized.")

    if is_trx:
        # Cast to the dtypes used for the TRX output before conversion.
        seg_sft.sft.dtype_dict = {"positions": np.float32, "offsets": np.uint32}
        tgram = TrxFile.from_sft(seg_sft.sft)
        tgram.groups = seg_sft.bundle_idxs
    else:
        tgram = seg_sft.sft

    meta = seg_sft.sidecar_info
    meta["source"] = streamlines
    meta["Recognition Parameters"] = _serialize_seg_params(segmentation_params)
    meta["Bundle Parameters"] = bundle_meta
    meta["Timing"] = time() - start_time
    return tgram, meta
@immlib.calc("indiv_bundles")
def export_bundles(base_fname, output_dir, bundles, tracking_params):
    """
    full path to the directory holding one tractography file per bundle
    (trk, or trx when tracking_params["trx"] is set), each exported bundle
    accompanied by a JSON sidecar with its recognition parameters
    """
    is_trx = tracking_params.get("trx", False)
    extension = ".trx" if is_trx else ".trk"
    base_fname = op.join(output_dir, op.split(base_fname)[1])
    seg_sft = aus.SegmentedSFT.fromfile(bundles)
    for bundle in seg_sft.bundle_names:
        fname = get_fname(
            base_fname,
            f"_desc-{str_to_desc(bundle)}_tractography{extension}",
            subfolder="bundles",
        )
        if op.exists(fname):
            logger.info(f"Bundle {bundle} already exists at {fname}. Skipping export.")
            continue

        bundle_sft = seg_sft.get_bundle(bundle)
        if len(bundle_sft) > 0:
            logger.info(f"Saving {fname}")
            if is_trx:
                # Cast the bundle that is actually converted (not the full
                # tractogram it was sliced from) to TRX-compatible dtypes.
                bundle_sft.dtype_dict = {
                    "positions": np.float32,
                    "offsets": np.uint32,
                }
                save_trx(TrxFile.from_sft(bundle_sft), fname)
            else:
                save_tractogram(bundle_sft, fname, bbox_valid_check=False)
        else:
            logger.info(f"No bundle to save for {bundle}")

        meta = dict(source=bundles, params=seg_sft.get_bundle_param_info(bundle))
        meta_fname = drop_extension(fname) + ".json"
        write_json(meta_fname, meta)
    # NOTE(review): fname is the last bundle's path, so this assumes at least
    # one bundle name (presumably guaranteed because segment() raises when
    # nothing is recognized).
    return op.dirname(fname)
@immlib.calc("sl_counts")
@as_file("_desc-slCount_tractography.csv", subfolder="stats")
def export_sl_counts(bundles):
    """
    full path to a CSV file containing streamline counts
    """
    # One row per bundle plus a final "Total Recognized" row holding the
    # number of streamlines in the whole segmented tractogram.
    seg_sft = aus.SegmentedSFT.fromfile(bundles)
    bundle_names = list(seg_sft.bundle_names)
    sl_counts = [len(seg_sft.get_bundle(name).streamlines) for name in bundle_names]
    sl_counts.append(len(seg_sft.sft.streamlines))
    counts_df = pd.DataFrame(
        data=dict(n_streamlines=sl_counts),
        index=bundle_names + ["Total Recognized"],
    )
    return counts_df, dict(source=bundles)
def _add_length_stats(len_data, label, lengths):
    """Store median/min/max of ``lengths`` in ``len_data`` under ``label``.

    Keys are ``"{label} Median"``, ``"{label} Min"`` and ``"{label} Max"``;
    all three are 0 when ``lengths`` is empty.
    """
    if len(lengths) > 0:
        len_data[f"{label} Median"] = np.median(lengths)
        len_data[f"{label} Min"] = np.min(lengths)
        len_data[f"{label} Max"] = np.max(lengths)
    else:
        len_data[f"{label} Median"] = 0
        len_data[f"{label} Min"] = 0
        len_data[f"{label} Max"] = 0


@immlib.calc("bundle_lengths")
@as_file("_desc-medianBundleLengths_tractography.csv", subfolder="stats")
def export_bundle_lengths(bundles):
    """
    full path to a CSV file containing median + min + max bundle lengths
    """
    len_data = {}
    seg_sft = aus.SegmentedSFT.fromfile(bundles)
    for bundle in seg_sft.bundle_names:
        these_lengths = list(length(seg_sft.get_bundle(bundle).streamlines))
        _add_length_stats(len_data, bundle, these_lengths)
    # Total row uses the same arc-length measure as the per-bundle rows,
    # rather than the ArraySequence's private ``_lengths`` (which holds the
    # number of points per streamline, a different unit).
    all_lengths = list(length(seg_sft.sft.streamlines))
    _add_length_stats(len_data, "Total Recognized", all_lengths)
    lengths_df = pd.DataFrame(
        data=len_data,
        index=[0],
    )
    return lengths_df, dict(source=bundles)
@immlib.calc("density_maps")
@as_file("_desc-density_tractography.nii.gz")
def export_density_maps(bundles, data_imap):
    """
    full path to 4d nifti file containing streamline counts per voxel
    per bundle, where the 4th dimension encodes the bundle
    """
    seg_sft = aus.SegmentedSFT.fromfile(bundles)
    names = seg_sft.bundle_names
    spatial_shape = data_imap["data"].shape[:3]
    # One volume per bundle, stacked along the last axis in bundle order.
    stacked = np.zeros((*spatial_shape, len(names)))
    for volume_idx, name in enumerate(names):
        stacked[..., volume_idx] = auv.density_map(
            seg_sft.get_bundle(name)
        ).get_fdata()
    density_img = nib.Nifti1Image(stacked, data_imap["dwi_affine"])
    return density_img, dict(source=bundles, bundles=list(names))
def _validate_profile_weights(profile_weights):
    """Check ``profile_weights`` and return it (lower-cased if a string).

    Raises TypeError for unsupported types or unknown string values.
    """
    if not (
        profile_weights is None
        or isinstance(profile_weights, str)
        or callable(profile_weights)
        or hasattr(profile_weights, "__len__")
    ):
        raise TypeError(
            "profile_weights must be string, None, callable, or a 1D or 2D array"
        )
    if isinstance(profile_weights, str):
        profile_weights = profile_weights.lower()
        if profile_weights not in ("gauss", "median"):
            raise TypeError(
                "if profile_weights is a string, it must be 'gauss' or 'median'"
            )
    return profile_weights


def _make_median_weight(scalar_data, affine, n_points_profile):
    """Build a weights callable selecting the middle value at each node.

    The returned function puts weight 1 on the streamline holding the
    middle-ranked value at each node (the upper median for an even count)
    and 0 elsewhere, so a weighted average reduces to that value.
    ``scalar_data`` is bound explicitly to avoid late-binding closures.
    """

    def _median_weight(bundle):
        fgarray = set_number_of_points(bundle, n_points_profile)
        values = np.array(values_from_volume(scalar_data, fgarray, affine))
        weights = np.zeros(values.shape)
        middle_rows = np.argsort(values, axis=0)[len(values) // 2, :]
        weights[middle_rows, np.arange(values.shape[1])] = 1
        return weights

    return _median_weight


@immlib.calc("profiles")
@as_file("_desc-profiles_tractography.csv")
def tract_profiles(
    bundles, scalar_dict, data_imap, profile_weights="gauss", n_points_profile=100
):
    """
    full path to a CSV file containing tract profiles

    Parameters
    ----------
    profile_weights : str, 1D array, 2D array, or callable, optional
        How to weight each streamline (1D) or each node (2D)
        when calculating the tract-profiles. If callable, this is a
        function that calculates weights. If None, no weighting will
        be applied. If "gauss", gaussian weights will be used.
        If "median", the median of values at each node will be used
        instead of a mean or weighted mean.
        Default: "gauss"
    n_points_profile : int, optional
        Number of points to resample each streamline to before
        calculating the tract-profiles.
        Default: 100
    """
    profile_weights = _validate_profile_weights(profile_weights)
    # isinstance guards keep array-valued weights out of `== "gauss"`,
    # which would produce an ambiguous elementwise truth value.
    use_gauss = isinstance(profile_weights, str) and profile_weights == "gauss"
    use_median = isinstance(profile_weights, str) and profile_weights == "median"
    affine = data_imap["dwi_affine"]

    if not scalar_dict:
        raise ValueError("scalar_dict must contain at least one scalar")

    bundle_names = []
    node_numbers = []
    profiles = [[] for _ in scalar_dict]

    # The first scalar serves as the spatial reference for the bundles; it
    # may be given as a path or as an already-loaded image.
    first_scalar = next(iter(scalar_dict.values()))
    reference = nib.load(first_scalar) if isinstance(first_scalar, str) else first_scalar
    seg_sft = aus.SegmentedSFT.fromfile(bundles, reference=reference)
    seg_sft.sft.to_rasmm()

    for bundle_name in seg_sft.bundle_names:
        this_sl = seg_sft.get_bundle(bundle_name).streamlines
        if len(this_sl) == 0:
            continue
        if use_gauss:
            # calculate only once per bundle
            bundle_profile_weights = np.asarray(
                gaussian_weights(this_sl, n_points=n_points_profile)
            )
        for ii, (scalar, scalar_file) in enumerate(scalar_dict.items()):
            if isinstance(scalar_file, str):
                scalar_file = nib.load(scalar_file)
            scalar_data = scalar_file.get_fdata()

            if use_gauss:
                this_prof_weights = bundle_profile_weights
            elif use_median:
                this_prof_weights = _make_median_weight(
                    scalar_data, affine, n_points_profile
                )
            elif profile_weights is None or callable(profile_weights):
                # Passed straight to afq_profile (None = unweighted).
                this_prof_weights = profile_weights
            else:
                this_prof_weights = np.asarray(profile_weights)

            if isinstance(this_prof_weights, np.ndarray) and np.any(
                np.isnan(this_prof_weights)
            ):  # fit failed
                logger.warning(
                    (
                        f"Even weighting used for "
                        f"bundle {bundle_name}, scalar {scalar} "
                        f"in profiling due inability to estimate weights. "
                        "This is often caused by low streamline count or "
                        "low variance in the scalar data."
                    )
                )
                # Uniform weights normalized to sum to 1 across streamlines
                # (axis 0); the weighted mean is unchanged by the scaling,
                # and normalized weights also satisfy a sum-to-one check
                # if afq_profile performs one.
                this_prof_weights = (
                    np.ones_like(this_prof_weights, dtype=float)
                    / this_prof_weights.shape[0]
                )

            profile = afq_profile(
                scalar_data,
                this_sl,
                affine,
                weights=this_prof_weights,
                n_points=n_points_profile,
            )
            profiles[ii].extend(np.asarray(profile, dtype=float).tolist())
        bundle_names.extend([bundle_name] * n_points_profile)
        node_numbers.extend(range(n_points_profile))

    profile_dict = dict()
    profile_dict["tractID"] = bundle_names
    profile_dict["nodeID"] = node_numbers
    for ii, scalar in enumerate(scalar_dict.keys()):
        profile_dict[scalar] = profiles[ii]
    profile_dframe = pd.DataFrame(profile_dict)
    meta = dict(
        source=bundles,
        parameters=get_default_args(afq_profile),
        scalars=list(scalar_dict.keys()),
        bundles=list(seg_sft.bundle_names),
    )
    return profile_dframe, meta
@immlib.calc("scalar_dict")
def get_scalar_dict(
    structural_imap,
    data_imap,
    tissue_imap,
    mapping_imap,
    t1_file,
    scalars=None,
):
    """
    dictionary mapping scalar names
    to their respective file paths

    Parameters
    ----------
    scalars : list of strings and/or scalar definitions, optional
        List of scalars to use.
        Can be any of: "dti_fa", "dti_md", "dki_fa", "dki_md", "dki_awf",
        "dki_mk", or other scalars found in AFQ.tasks.data.
        Can also be a scalar from AFQ.definitions.image.
        Finally, can also be "t1w".
        Defaults for single shell data to ["dti_fa", "dti_md", "t1w"],
        and for multi-shell data to ["dki_fa", "dki_md", "dki_kfa",
        "dki_mk", "t1w"].
        Default: ['dti_fa', 'dti_md', 't1w']
    """
    # Some scalar preprocessing happens in the mapping plan before this step;
    # scalar definitions are therefore looked up in mapping_imap below.
    # NOTE(review): this function itself always falls back to the DTI list;
    # presumably the multi-shell default is chosen by the caller — confirm.
    requested = ["dti_fa", "dti_md", "t1w"] if scalars is None else scalars
    found = {}
    for entry in requested:
        if isinstance(entry, str):
            key = entry.lower()
            if key == "t1w":
                found[key] = t1_file
            else:
                found[key] = get_tp(key, structural_imap, data_imap, tissue_imap)
            continue
        name = entry.get_name()
        # Definitions that the mapping plan did not produce are skipped.
        if f"{name}" in mapping_imap:
            found[name] = mapping_imap[f"{name}"]
    return {"scalar_dict": found}
def get_segmentation_plan(kwargs):
    """Build the immlib plan for segmentation and its derived outputs.

    ``kwargs["segmentation_params"]`` is updated in place: the defaults of
    ``recognize`` are filled in and overridden by any user-supplied values.

    Raises
    ------
    TypeError
        If ``segmentation_params`` is given but is not a dict.
    ValueError
        If ``cleaning_params`` is given at the top level.
    """
    if "segmentation_params" in kwargs and not isinstance(
        kwargs["segmentation_params"], dict
    ):
        raise TypeError("segmentation_params must be a dict")
    if "cleaning_params" in kwargs:
        raise ValueError(
            "cleaning_params should be passed inside of segmentation_params"
        )
    segmentation_tasks = with_name(
        [
            get_scalar_dict,
            export_sl_counts,
            export_bundle_lengths,
            export_bundles,
            export_density_maps,
            segment,
            tract_profiles,
        ]
    )
    default_seg_params = get_default_args(recognize)
    # User-provided values override the defaults of recognize().
    default_seg_params.update(kwargs.get("segmentation_params", {}))
    kwargs["segmentation_params"] = default_seg_params
    return immlib.plan(**segmentation_tasks)