Load and interact with an AFQ dataset
This example loads AFQ data from CSV files and manipulates that data using scikit-learn transformers and estimators. First, we fetch the Weston-Havens dataset described in Yeatman et al. [1]. This dataset contains tractometry features from 77 subjects aged 6 to 50.
Next, we split the dataset into training and test sets, impute missing values, and fit a LASSO model, all using AFQDataset methods. Predictive performance for the default LASSO model is abysmal; it is used here only to demonstrate the use of scikit-learn estimators. In a research setting, one might use more advanced estimators, such as the sparse group lasso (SGL) [2], a gradient boosting machine, or a neural network.
Finally, we convert the AFQDataset to a TensorFlow dataset and fit a basic one-dimensional CNN to predict age from the features. This last step requires that AFQ-Insight has been installed with:
pip install "afqinsight[tf]"
or that TensorFlow has been installed separately with:
pip install tensorflow
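A one-dimensional CNN slides small filters along one axis of its input. For tract profiles, a natural axis is the sequence of nodes along each tract, with the different tract/metric combinations acting as input channels. The pure-Python sketch below computes a single "valid" (unpadded) 1-D convolution, which, as in most deep learning libraries, is strictly a cross-correlation. The function name, toy profile, and filter weights are hypothetical and are not taken from AFQ-Insight.

```python
def conv1d_valid(signal, kernel, bias=0.0):
    """Slide one filter along `signal` without padding ("valid" mode).

    signal: list of steps (e.g. nodes), each a list of channel values.
    kernel: list of filter positions, each a list of per-channel weights.
    Returns one output value per position where the filter fully fits.
    """
    n_steps, width = len(signal), len(kernel)
    out = []
    for start in range(n_steps - width + 1):
        acc = bias
        for k in range(width):
            # Weighted sum over all channels at this filter position.
            for x, w in zip(signal[start + k], kernel[k]):
                acc += x * w
        out.append(acc)
    return out


# Hypothetical toy profile: 5 nodes x 2 channels (e.g. two metrics on one tract).
profile = [[0.1, 0.5], [0.2, 0.4], [0.3, 0.3], [0.4, 0.2], [0.5, 0.1]]
# A width-2 filter that differences neighboring nodes in the first channel
# and ignores the second channel.
kernel = [[-1.0, 0.0], [1.0, 0.0]]

print(conv1d_valid(profile, kernel))
```

With n nodes and a filter of width k, each filter yields n - k + 1 outputs; real convolutional layers learn many such filters and often pad the input to preserve length.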
import os.path as op
import tensorflow as tf
from sklearn.impute import SimpleImputer
from sklearn.linear_model import Lasso
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
import afqinsight.nn.tf_models as nn
from afqinsight import AFQDataset
from afqinsight.datasets import download_weston_havens
Fetch example data
The download_weston_havens() function downloads the data used in this example and places it in the ~/.cache/afq-insight/weston_havens directory, creating that directory if it does not exist. The data follows the format expected by the load_afq_data() function: a file called nodes.csv that contains AFQ tract profiles and a file called subjects.csv that contains information about the subjects. The two files are linked through the subjectID column, which should exist in both of them. For more information about this format, see also the AFQ-Browser documentation (items 2 and 3).
workdir = download_weston_havens()
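To make the link through the subjectID column concrete, the sketch below parses two tiny in-memory CSV files with Python's csv module and attaches each node row to its subject's information. In this example the library handles this step; the sketch only illustrates the idea. All column names other than subjectID (tractID, nodeID, fa, md, Age), the function name, and all values are hypothetical and do not reflect the actual Weston-Havens files.

```python
import csv
import io

# Hypothetical miniature versions of nodes.csv and subjects.csv.
NODES_CSV = """subjectID,tractID,nodeID,fa,md
sub-01,Left Arcuate,0,0.45,0.00080
sub-01,Left Arcuate,1,0.47,0.00078
sub-02,Left Arcuate,0,0.41,0.00085
"""
SUBJECTS_CSV = """subjectID,Age
sub-01,12
sub-02,35
"""


def join_on_subject(nodes_text, subjects_text):
    """Attach each subject's subjects.csv row to that subject's node rows."""
    # Index subject information by subjectID for constant-time lookup.
    subjects = {
        row["subjectID"]: row for row in csv.DictReader(io.StringIO(subjects_text))
    }
    joined = []
    for row in csv.DictReader(io.StringIO(nodes_text)):
        info = subjects.get(row["subjectID"])
        if info is None:
            continue  # node rows without matching subject information are skipped
        joined.append({**row, **info})
    return joined


for record in join_on_subject(NODES_CSV, SUBJECTS_CSV):
    print(record["subjectID"], record["nodeID"], record["fa"], record["Age"])
```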
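Read in the data
Next, we read in the data. The AFQDataset.from_files() static method expects the filenames of a nodes.csv file and a subjects.csv file, and returns a dataset object.
# Use MD and FA profiles as features and the Age column as the target
# (metric and column names assumed here; check them against your CSV files).
dataset = AFQDataset.from_files(
    fn_nodes=op.join(workdir, "nodes.csv"),
    fn_subjects=op.join(workdir, "subjects.csv"),
    dwi_metrics=["md", "fa"],
    target_cols=["Age"],
)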
Train / test split
We can use the dataset in the train_test_split() function just as we would with an array.
dataset_train, dataset_test = train_test_split(dataset, test_size=1 / 3)
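Conceptually, a random train/test split shuffles the sample indices and uses the same index lists to select both features and targets, so every subject's features stay paired with that subject's target. The pure-Python sketch below illustrates this under hypothetical data; it is not how AFQDataset or scikit-learn are implemented, although it rounds a fractional test_size up (ceil) for the test set, as scikit-learn's train_test_split does.

```python
import math
import random


def split_indices(n_samples, test_size=1 / 3, seed=0):
    """Shuffle sample indices and split them into (train, test) index lists.

    The test set gets ceil(test_size * n_samples) samples.
    """
    n_test = math.ceil(test_size * n_samples)
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    return indices[n_test:], indices[:n_test]


# Hypothetical features and targets for 77 subjects. The same indices select
# both, so each subject's features stay paired with that subject's target.
X = [[float(i), float(i) ** 2] for i in range(77)]
y = [float(i) for i in range(77)]

train_idx, test_idx = split_indices(len(X))
X_train, y_train = [X[i] for i in train_idx], [y[i] for i in train_idx]
X_test, y_test = [X[i] for i in test_idx], [y[i] for i in test_idx]
print(len(X_train), len(X_test))
```

With 77 samples and test_size=1/3, this rounding gives 51 training and 26 test samples.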
Impute missing values
Next, we fit an imputer on the training set and use it to transform the features in both the training and the test set.
imputer = dataset_train.model_fit(SimpleImputer(strategy="median"))
dataset_train = dataset_train.model_transform(imputer)
dataset_test = dataset_test.model_transform(imputer)
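The imputer learns its fill values from the training set only and then applies those same values to the test set, so no information from the test subjects leaks into the fitted pipeline. A minimal pure-Python sketch of median imputation, using NaN to mark missing values; the function names and toy matrices are hypothetical, and this is not scikit-learn's SimpleImputer implementation.

```python
import math
import statistics


def fit_median_imputer(rows):
    """Compute each column's median over its non-missing (non-NaN) values."""
    medians = []
    for column in zip(*rows):
        observed = [v for v in column if not math.isnan(v)]
        # Columns with no observed values stay NaN in this sketch.
        medians.append(statistics.median(observed) if observed else math.nan)
    return medians


def transform(rows, medians):
    """Replace NaN entries with the per-column medians learned during fitting."""
    return [
        [m if math.isnan(v) else v for v, m in zip(row, medians)] for row in rows
    ]


nan = math.nan
train = [[1.0, 10.0], [2.0, nan], [9.0, 30.0]]
test = [[nan, nan], [4.0, 20.0]]

medians = fit_median_imputer(train)  # learned from the training set only
print(transform(train, medians))
print(transform(test, medians))
```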
Fit a LASSO model
Next, we fit a LASSO estimator to the training data, predict on the test set, and print the model's score on both the training and the test sets.
estimator = dataset_train.model_fit(Lasso())
y_pred = dataset_test.model_predict(estimator)
train_score = dataset_train.model_score(estimator)
test_score = dataset_test.model_score(estimator)
print("LASSO train score:", train_score)
print("LASSO test score: ", test_score)
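For reference, scikit-learn's Lasso chooses coefficients w and an unpenalized intercept b for n training samples by minimizing a squared-error term plus an L1 penalty whose strength is set by alpha (default 1.0):

$$
\hat{w}, \hat{b} = \arg\min_{w,\,b} \; \frac{1}{2n} \sum_{i=1}^{n} \left( y_i - b - x_i^\top w \right)^2 + \alpha \lVert w \rVert_1
$$

Assuming model_score reports the estimator's own score method, the printed values are the coefficient of determination that scikit-learn regressors return:

$$
R^2 = 1 - \frac{\sum_{i} \left( y_i - \hat{y}_i \right)^2}{\sum_{i} \left( y_i - \bar{y} \right)^2}
$$

The L1 penalty can drive coefficients exactly to zero, and the default alpha is not tuned to these features, which is one plausible reason for the poor scores here; selecting alpha by cross-validation (for example with scikit-learn's LassoCV) is the usual remedy.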
Convert to TensorFlow datasets
Next, we convert the train and test datasets to TensorFlow datasets and use one of AFQ-Insight’s built-in one-dimensional CNNs to predict age. This part of the example only works if you have either installed AFQ-Insight with TensorFlow using:
pip install "afqinsight[tf]"
or installed TensorFlow separately using:
pip install tensorflow
This model also performs poorly. Predicting age accurately in this dataset turns out to require more work.
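Before training, the TensorFlow datasets are batched, so that each training step processes several subjects at once. As a rough illustration, the grouping rule can be sketched in pure Python, assuming it mirrors tf.data.Dataset.batch, which keeps a smaller final batch unless drop_remainder=True. The function name and the 26-element input (about the size of this example's test split) are hypothetical.

```python
def batch(elements, batch_size, drop_remainder=False):
    """Group consecutive elements into lists of length batch_size.

    As with tf.data.Dataset.batch's default (drop_remainder=False), the
    last batch is smaller when the element count is not a multiple of
    batch_size; with drop_remainder=True that partial batch is discarded.
    """
    batches = [
        elements[i : i + batch_size] for i in range(0, len(elements), batch_size)
    ]
    if drop_remainder and batches and len(batches[-1]) < batch_size:
        batches.pop()
    return batches


subjects = list(range(26))  # hypothetical stand-ins for 26 test subjects
print([len(b) for b in batch(subjects, 8)])
print([len(b) for b in batch(subjects, 8, drop_remainder=True)])
```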
tfset_train = dataset_train.as_tensorflow_dataset()
tfset_test = dataset_test.as_tensorflow_dataset()
# Group subjects into batches; batch_size is now actually used below.
batch_size = 8
tfset_train = tfset_train.batch(batch_size)
tfset_test = tfset_test.batch(batch_size)
print("CNN Architecture")
# input_shape is (nodes along each profile, feature channels); a single
# output with no activation makes this a regression model.
model = nn.cnn_lenet(
    input_shape=(100, 40), output_activation=None, n_classes=1, verbose=True
)
model.compile(
    loss="mean_squared_error",
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    metrics=["mean_squared_error"],
)
model.fit(tfset_train, epochs=500, validation_data=tfset_test, verbose=0)
print()
print("CNN R^2 score: ", r2_score(dataset_test.y, model.predict(tfset_test)))
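The CNN score above comes from scikit-learn's r2_score, the coefficient of determination. A minimal pure-Python version for a single target shows how to read the number: 1.0 is a perfect fit, 0.0 is no better than always predicting the mean age, and negative values are worse than that. The function name and the ages below are hypothetical.

```python
def r2(y_true, y_pred):
    """Coefficient of determination, 1 - SS_res / SS_tot (y_true must not be constant)."""
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


ages = [8.0, 12.0, 20.0, 40.0]  # hypothetical test-set ages (mean 20)
print(r2(ages, ages))  # perfect predictions
print(r2(ages, [20.0] * len(ages)))  # always predicting the mean
print(r2(ages, [40.0, 8.0, 12.0, 20.0]))  # worse than predicting the mean
```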