Source code for AFQ.tasks.decorators

import functools
import inspect
import logging
import os.path as op
from time import time

import nibabel as nib
from dipy.io.streamline import save_tractogram
from dipy.io.stateful_tractogram import StatefulTractogram

try:
    from trx.trx_file_memmap import TrxFile
    from trx.io import save as save_trx
    has_trx = True
except ModuleNotFoundError:
    has_trx = False

import numpy as np

from AFQ.tasks.utils import get_fname
from AFQ.utils.path import drop_extension, write_json


# These should only be used with pimms.calc
__all__ = ["as_file", "as_fit_deriv", "as_img"]


logger = logging.getLogger('AFQ')
logger.setLevel(logging.INFO)


# get args and kwargs from function
def get_args_and_kwargs(func):
    """
    Describe the positional parameters of ``func``.

    Parameters
    ----------
    func : function
        Must expose ``__code__`` and be accepted by ``inspect.signature``.

    Returns
    -------
    param_list : tuple of str
        Names of the first ``co_argcount`` local variables of the code
        object, i.e. the parameters that can be passed positionally.
    is_param_kwarg : dict
        Maps every name in ``param_list`` to True when the signature lists
        that name with a default value, and to False otherwise (including
        when the signature does not list the name at all).
    param_dict : mapping
        ``inspect.signature(func).parameters``.
    """
    signature_params = inspect.signature(func).parameters
    code = func.__code__
    positional_names = code.co_varnames[:code.co_argcount]

    has_default = {}
    for arg_name in positional_names:
        # the signature and the code object can disagree on which names
        # exist, so a name missing from the signature is not a kwarg
        if arg_name not in signature_params:
            has_default[arg_name] = False
            continue
        parameter = signature_params[arg_name]
        has_default[arg_name] = parameter.default is not parameter.empty

    return positional_names, has_default, signature_params


# replaces *args and **kwargs with specific parameters from og_func
# so that pimms can see original parameter names after wrapping
# also adds on any args the decorator requires
# these will be extracted with extract_added_args
def has_args(og_func, needed_args):
    """
    Build a decorator that gives ``func`` an explicit, named signature.

    The returned decorator generates (with ``exec``) a function named
    ``wrapper_has_args_func`` that forwards to ``func``. Its parameters are,
    in order:

    1. the parameters of ``og_func`` that have no default,
    2. every name in ``needed_args`` not already among those,
    3. the parameters of ``og_func`` that have a default, passed to
       ``func`` by keyword.

    Parameters
    ----------
    og_func : function
        Function whose parameter names and defaults are copied.
    needed_args : list of str
        Extra positional parameter names required by the decorator that
        uses ``has_args``.

    Returns
    -------
    function
        Decorator taking ``func`` and returning the generated wrapper.
    """
    def _has_args(func):
        param_list, is_param_kwarg, param_dict = get_args_and_kwargs(og_func)

        # positional parameters of og_func come first ...
        positional = [
            name for name in param_list if not is_param_kwarg[name]]
        # ... followed by the decorator's own arguments, without repeats
        for arg in needed_args:
            if arg not in positional:
                positional.append(arg)

        keyword = [name for name in param_list if is_param_kwarg[name]]

        # Default values are handed to the generated code as objects through
        # the exec namespace rather than being rendered into source text.
        # Rendering breaks for strings containing quotes or backslashes and
        # for any value whose str() is not a valid Python expression.
        defaults = {name: param_dict[name].default for name in keyword}

        header_params = positional + [
            f"{name}=__has_args_defaults__[{name!r}]" for name in keyword]
        call_params = positional + [f"{name}={name}" for name in keyword]

        # join() keeps the source valid even when there are no parameters
        wrapper_has_args = (
            f"def wrapper_has_args_func({', '.join(header_params)}):\n"
            f"    return __has_args_func__({', '.join(call_params)})")

        # reserved names so that a parameter called ``func`` cannot shadow
        # the wrapped callable inside the generated function
        scope = {
            "__has_args_func__": func,
            "__has_args_defaults__": defaults}
        exec(wrapper_has_args, scope)
        return scope['wrapper_has_args_func']
    return _has_args


# from function where needed args (like base_fname) are added,
# return length of args before added args, and the added args
def extract_added_args(func, names, args, includes=None):
    """
    Pull the values of decorator-required arguments out of ``args``.

    Parameters
    ----------
    func : function
        Function whose parameters without defaults occupy the front of
        ``args``.
    names : list of str
        Argument names to look up, in the order they should be returned.
    args : sequence
        Positional arguments received by the wrapper.
    includes : list of bool, optional
        When given, ``names[i]`` is only looked up if ``includes[i]`` is
        truthy; otherwise None is returned in its place.

    Returns
    -------
    tuple
        The number of parameters of ``func`` without defaults, followed by
        one value per entry of ``names``.
    """
    param_list, is_param_kwarg, _ = get_args_and_kwargs(func)
    positional = [
        param for param in param_list if not is_param_kwarg[param]]
    n_positional = len(positional)

    extracted = []
    n_appended = 0  # how many values were read from past func's own args
    for idx, wanted in enumerate(names):
        if includes is not None and not includes[idx]:
            extracted.append(None)
        elif wanted in positional:
            # func already takes this argument, read it from its own slot
            extracted.append(args[positional.index(wanted)])
        else:
            # otherwise it was appended after func's arguments, in order
            extracted.append(args[n_positional + n_appended])
            n_appended += 1

    return (n_positional, *extracted)


def as_file(suffix, include_track=False, include_seg=False):
    """
    return img and meta as saved file path, with json,
    and only run if not already found

    Parameters
    ----------
    suffix : str
        Passed to get_fname to build the output file name; also checked
        for "_desc-profiles" when tagging the metadata.
    include_track : bool, optional
        Whether the wrapper also requires "tracking_params".
    include_seg : bool, optional
        Whether the wrapper also requires "segmentation_params".

    Returns
    -------
    function
        Decorator for a function returning ``(gen, meta)``. The decorated
        function returns the path of the saved file instead.
    """
    def _as_file(func):
        needed_args = ["base_fname", "output_dir"]
        if include_track:
            needed_args.append("tracking_params")
        if include_seg:
            needed_args.append("segmentation_params")

        @functools.wraps(func)
        @has_args(func, needed_args)
        def wrapper_as_file(*args, **kwargs):
            (og_arg_count, base_fname, output_dir,
             tracking_params, segmentation_params) = extract_added_args(
                func,
                ["base_fname", "output_dir",
                 "tracking_params", "segmentation_params"],
                args,
                includes=[True, True, include_track, include_seg])

            fname_kwargs = dict(
                tracking_params=tracking_params,
                segmentation_params=segmentation_params)
            this_file = get_fname(base_fname, suffix, **fname_kwargs)

            # tracking_params is defined and file has no extension
            needs_extension = (
                tracking_params is not None
                and not op.splitext(this_file)[1])
            if needs_extension:
                this_file += ".trx" if tracking_params["trx"] else ".trk"

            # nothing to compute when the derivative is already on disk
            if op.exists(this_file):
                return this_file

            gen, meta = func(*args[:og_arg_count], **kwargs)

            logger.info(f"Saving {this_file}")
            # choose the writer from the type of the returned object
            if isinstance(gen, nib.Nifti1Image):
                nib.save(gen, this_file)
            elif isinstance(gen, StatefulTractogram):
                save_tractogram(gen, this_file, bbox_valid_check=False)
            elif isinstance(gen, np.ndarray):
                np.save(this_file, gen)
            elif has_trx and isinstance(gen, TrxFile):
                save_trx(gen, this_file)
            else:
                gen.to_csv(this_file)

            # these are used to determine dependencies
            # when clobbering derivatives
            if "_desc-profiles" in suffix:
                dependent = "prof"
            elif include_seg:
                dependent = "rec"
            elif include_track:
                dependent = "trk"
            else:
                dependent = "dwi"
            meta["dependent"] = dependent

            # modify meta source to be relative
            if "source" in meta:
                meta["source"] = op.relpath(meta["source"], output_dir)

            # metadata goes into a json named after the suffix
            meta_fname = get_fname(
                base_fname, f"{drop_extension(suffix)}.json", **fname_kwargs)
            write_json(meta_fname, meta)
            return this_file

        return wrapper_as_file
    return _as_file
def as_fit_deriv(tf_name):
    """
    return data as nibabel image, meta with params information

    Parameters
    ----------
    tf_name : str
        Name used to build the required argument
        ``f"{tf_name.lower()}_params"`` and the metadata key
        ``f"{tf_name}ParamsFile"``.

    Returns
    -------
    function
        Decorator for a function returning image data. The decorated
        function returns ``(img, meta)``.
    """
    def _as_fit_deriv(func):
        params_arg_name = f"{tf_name.lower()}_params"
        needed_args = ["dwi_affine", params_arg_name]

        @functools.wraps(func)
        @has_args(func, needed_args)
        def wrapper_as_fit_deriv(*args, **kwargs):
            extracted = extract_added_args(func, needed_args, args)
            og_arg_count, dwi_affine, params = extracted

            # only func's own arguments are forwarded to it
            data = func(*args[:og_arg_count], **kwargs)
            img = nib.Nifti1Image(data, dwi_affine)
            meta = {f"{tf_name}ParamsFile": params}
            return img, meta

        return wrapper_as_fit_deriv
    return _as_fit_deriv
def as_img(func):
    """
    return data, meta as nibabel image, meta with timing

    Parameters
    ----------
    func : function
        Function returning ``(data, meta)``; ``data`` must provide
        ``astype`` and ``meta`` must support item assignment.

    Returns
    -------
    function
        Wrapper that additionally takes "dwi_affine" and returns
        ``(img, meta)``, with the elapsed time of the call to ``func``
        stored in ``meta['timing']``.
    """
    needed_args = ["dwi_affine"]

    @functools.wraps(func)
    @has_args(func, needed_args)
    def wrapper_as_img(*args, **kwargs):
        og_arg_count, affine = extract_added_args(func, needed_args, args)
        own_args = args[:og_arg_count]

        # time only the wrapped computation
        start_time = time()
        data, meta = func(*own_args, **kwargs)
        elapsed = time() - start_time
        meta['timing'] = elapsed

        img = nib.Nifti1Image(data.astype(np.float32), affine)
        return img, meta

    return wrapper_as_img