Predict age from white matter features¶

This example uses data from the HBN POD2 dataset, which includes 1867 subjects ages 5-21. We will use the sparse group lasso (SGL) implemented in AFQ-Insight to fit a predictive model that uses tractometry features to predict each subject's age. Because white matter develops dramatically during childhood and adolescence, this model accounts for a substantial proportion of the variance in age in a held-out dataset.
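As a reminder of the model, the sparse group lasso solves (in its standard formulation; groupyr's implementation may additionally rescale the group penalty, e.g., by group size):

$$\hat{\beta} = \underset{\beta}{\arg\min}\; \frac{1}{2n}\,\lVert y - X\beta\rVert_2^2 + \alpha\left[\lambda\,\lVert\beta\rVert_1 + (1-\lambda)\sum_g \lVert\beta_g\rVert_2\right]$$

where each group $g$ collects the 100 nodes of one bundle/metric profile, $\lambda$ is the ``l1_ratio`` parameter used below, and $\alpha$ is the overall penalty strength selected by cross-validation.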

In [1]:
import os.path as op
from paths import afq_home

import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.impute import SimpleImputer
from groupyr.decomposition import GroupPCA

from afqinsight.neurocombat_sklearn import CombatModel
from afqinsight import make_afq_regressor_pipeline
from afqinsight import AFQDataset

Read the data¶

The nodes.csv file, which is the input here, is the output of pyAFQ processing. The subjects.tsv file is a BIDS-compliant participants file that includes subject identifiers matching those stored in the pyAFQ output. This allows AFQ-Insight to merge the data from the two files.
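To see why these two files can be merged, it helps to peek at them directly. This is a minimal sketch, assuming the files have already been downloaded under ``afq_home``; the column names in the comments are illustrative:

import pandas as pd

# pyAFQ tract profiles: one row per (subject, bundle, node)
nodes = pd.read_csv(op.join(afq_home, "afq-insight/hbn/nodes.csv"))
# BIDS participants file: one row per subject, with phenotypic columns
subjects = pd.read_csv(op.join(afq_home, "afq-insight/hbn/subjects.tsv"), sep="\t")
print(nodes.columns.tolist())     # expect subjectID, tractID, nodeID, and the metrics
print(subjects.columns.tolist())  # expect subject_id, age, sex, scan_site_id, ...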

In [2]:
afqdata = AFQDataset.from_files(
    fn_nodes=op.join(afq_home, "afq-insight/hbn/nodes.csv"),
    fn_subjects=op.join(afq_home, "afq-insight/hbn/subjects.tsv"),
    dwi_metrics=["dki_md", "dki_fa"],
    target_cols=["age", "sex", "scan_site_id"],
    label_encode_cols=["sex", "scan_site_id"],
    index_col="subject_id"
)
In [3]:
afqdata.drop_target_na()
print(afqdata)
AFQDataset(n_samples=1867, n_features=4800, n_targets=3, targets=['age', 'sex', 'scan_site_id'])

Train / test split¶

We can pass the AFQDataset class instance to scikit-learn's train_test_split function, just as we would with an array.

In [4]:
dataset_train, dataset_test = train_test_split(afqdata, test_size=0.25)
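Each half of the split is itself an AFQDataset, so we can inspect it the same way we inspected the full dataset above:

print(dataset_train)
print(dataset_test)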

Impute missing values¶

Next we impute missing values using median imputation (some values are missing because of noisy MRI scans). We fit the imputer using the training set and then use it to transform both the training and test sets.

In [5]:
imputer = dataset_train.model_fit(SimpleImputer(strategy="median"))
dataset_train = dataset_train.model_transform(imputer)
dataset_test = dataset_test.model_transform(imputer)
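As a quick check (not part of the original example), we can verify that no missing values remain after imputation:

assert not np.isnan(dataset_train.X).any()
assert not np.isnan(dataset_test.X).any()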

Harmonize the sites¶

The HBN dataset contains measurements from four different sites, and there are substantial scan-site differences in both the FA and MD profiles. We use neuroComBat (Fortin et al., 2017) to harmonize these site differences.

In [6]:
# Fit the ComBat transformer to the training set.
# The target columns are ordered [age, sex, scan_site_id], so y[:, 2] is the
# site (the batch variable), y[:, 1] is sex (a discrete covariate to
# preserve), and y[:, 0] is age (a continuous covariate to preserve).

combat = CombatModel()
combat.fit(
    dataset_train.X,
    dataset_train.y[:, 2][:, np.newaxis],
    dataset_train.y[:, 1][:, np.newaxis],
    dataset_train.y[:, 0][:, np.newaxis],
)


# And then transform a copy of the test set and a copy of the train set:
harmonized_test = dataset_test.copy()
harmonized_test.X = combat.transform(
    dataset_test.X,
    dataset_test.y[:, 2][:, np.newaxis],
    dataset_test.y[:, 1][:, np.newaxis],
    dataset_test.y[:, 0][:, np.newaxis],
)

harmonized_train = dataset_train.copy()
harmonized_train.X = combat.transform(
    dataset_train.X,
    dataset_train.y[:, 2][:, np.newaxis],
    dataset_train.y[:, 1][:, np.newaxis],
    dataset_train.y[:, 0][:, np.newaxis],
)
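As an illustrative sanity check (not part of the original example), we can compare the spread of per-site feature means before and after harmonization; it should shrink after ComBat:

sites = dataset_train.y[:, 2]
for label, X in [("raw", dataset_train.X), ("harmonized", harmonized_train.X)]:
    # Mean feature profile within each site, then average range across sites
    site_means = np.stack([X[sites == s].mean(axis=0) for s in np.unique(sites)])
    print(label, np.ptp(site_means, axis=0).mean())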

Create an analysis pipeline¶

Finally, we can use the imputed and harmonized data to fit a predictive model. AFQ-Insight implements complex pipelines that include multiple analysis steps. Helper functions (such as make_afq_regressor_pipeline) create scikit-learn-compatible pipelines that can then be used to fit, predict, and score the model.

In [7]:
do_group_pca = True

if do_group_pca:
    n_components = 10

    # The next two lines retrieve the group structure of the group-wise PCA
    # and store it in ``groups_pca``. We do not use this GroupPCA transformer
    # for anything else; ``groups_out_`` depends only on the input groups and
    # ``n_components``, not on the data.
    gpca = GroupPCA(n_components=n_components, groups=afqdata.groups)
    groups_pca = gpca.fit(harmonized_train.X).groups_out_

    transformer = GroupPCA
    transformer_kwargs = {"groups": afqdata.groups, "n_components": n_components}
else:
    transformer = False
    transformer_kwargs = None

pipe = make_afq_regressor_pipeline(
    imputer_kwargs={"strategy": "median"},  # Use median imputation
    use_cv_estimator=True,  # Automatically determine the best hyperparameters
    scaler="standard",  # Standard scale the features before regression
    feature_transformer=transformer,  # See note above about group PCA
    feature_transformer_kwargs=transformer_kwargs,
    groups=(
        groups_pca if do_group_pca else afqdata.groups
    ),  # SGL will use the original feature groups or the PCA feature groups depending on the choice above # noqa E501
    verbose=0,  # Be quiet!
    pipeline_verbosity=False,  # No really, be quiet!
    tuning_strategy="bayes",  # Use BayesSearchCV to determine optimal hyperparameters
    n_bayes_iter=20,  # Consider this many points in hyperparameter space
    cv=3,  # Use three CV splits to evaluate each hyperparameter combination
    l1_ratio=[0.0, 1.0],  # Explore the entire range of ``l1_ratio``
    eps=5e-2,  # This is the ratio of the smallest to largest ``alpha`` value
)
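The helper returns a standard scikit-learn Pipeline, so the usual Pipeline API applies; for example, we can list its named steps before fitting:

print([name for name, _ in pipe.steps])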
In [8]:
pipe.fit(harmonized_train.X, harmonized_train.y[:, 0])
Out[8]:
Pipeline(steps=[('impute', SimpleImputer(strategy='median')),
                ('scale', StandardScaler()),
                ('feature_transform',
                 GroupPCA(groups=[array([ 0,  1, ..., 99]), ...],
                          n_components=10)),
                ('estimate',
                 SGLCV(cv=3, eps=0.05,
                       groups=[array([0, 1, ..., 9]), ...],
                       l1_ratio=[0.0, 1.0], n_bayes_iter=20,
                       tuning_strategy='bayes', verbose=0))])
In [9]:
pred_age = pipe.predict(harmonized_test.X)
In [10]:
fig, ax = plt.subplots()
ax.scatter(harmonized_test.y[:, 0], pred_age)
Out[10]:
<matplotlib.collections.PathCollection at 0x307d77dd0>
(Figure: scatter plot of predicted age against true age in the held-out test set.)
In [11]:
pipe.score(harmonized_test.X, harmonized_test.y[:, 0])
Out[11]:
0.5101207020779658
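For a regressor pipeline, score returns R², so the cross-validated model explains roughly half of the variance in age among held-out subjects. If you also want an error in interpretable units (years), a natural follow-up (not part of the original example) is:

from sklearn.metrics import mean_absolute_error

# Mean absolute error of the age predictions, in years
print(mean_absolute_error(harmonized_test.y[:, 0], pred_age))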