Running pyAFQ using the GPU for tractography

Running pyAFQ using the GPU for tractography is as simple as (1) installing GPUStreamlines with pip and (2) setting the jit_backend entry in the tracking parameters you pass when you create your GroupAFQ object. To install GPUStreamlines, run:

pip install git+https://github.com/dipy/GPUStreamlines.git

That’s step 1 complete! The rest of this example is the same as the GroupAFQ example, except that the jit_backend parameter is set.

from AFQ.api.group import GroupAFQ
import AFQ.data.fetch as afd
import os.path as op
import plotly
2026-05-26 22:56:54,923	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.

We start with some example data. The data we use here comes from the Stanford HARDI dataset. Later, we set up our myafq object, which we use to run tractography on the GPU.

# Fetch the Stanford HARDI data (via dipy's fetch_stanford_hardi) and organize
# it under afd.afq_home; the GroupAFQ object created below reads it from
# op.join(afd.afq_home, 'stanford_hardi').
# NOTE(review): this performs a network download, so it can take a while.
afd.organize_stanford_data()
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[2], line 1
----> 1 afd.organize_stanford_data()

File /opt/hostedtoolcache/Python/3.13.13/x64/lib/python3.13/site-packages/AFQ/data/fetch.py:1819, in organize_stanford_data(path, clear_previous_afq)
   1817 # fetches data for first subject and session
   1818 logger.info("fetching Stanford HARDI data")
-> 1819 dpd.fetch_stanford_hardi()
   1821 if path is None:
   1822     if not op.exists(afq_home):

File /opt/hostedtoolcache/Python/3.13.13/x64/lib/python3.13/site-packages/dipy/data/fetcher.py:494, in _make_fetcher.<locals>.fetcher(include_optional)
    491         continue
    492     files[str(n)] = (baseurl + f, md5_list[i] if md5_list is not None else None)
--> 494 fetch_data(files, folder, data_size=data_size, use_headers=use_headers)
    496 if msg is not None:
    497     logger.info(msg)

File /opt/hostedtoolcache/Python/3.13.13/x64/lib/python3.13/site-packages/dipy/testing/decorators.py:201, in warning_for_keywords.<locals>.decorator.<locals>.wrapper(*args, **kwargs)
    194 # Check if the current version is within the warning range
    195 if (
    196     version.parse(from_version)
    197     <= version.parse(current_version)
    198     <= version.parse(until_version)
    199 ):
    200     # Convert positional to keyword arguments and issue a warning
--> 201     return convert_positional_to_keyword(func, args, kwargs)
    203 # If the version is greater than the until_version,
    204 # pass the arguments as they are
    205 elif version.parse(current_version) > version.parse(until_version):

File /opt/hostedtoolcache/Python/3.13.13/x64/lib/python3.13/site-packages/dipy/testing/decorators.py:192, in warning_for_keywords.<locals>.decorator.<locals>.wrapper.<locals>.convert_positional_to_keyword(func, args, kwargs)
    182         warnings.warn(
    183             f"Pass {positionally_passed_kwonly_args} as keyword args. "
    184             f"From version {until_version} passing these as positional "
   (...)    187             stacklevel=3,
    188         )
    190     return func(*positional_args, **corrected_kwargs)
--> 192 return func(*args, **kwargs)

File /opt/hostedtoolcache/Python/3.13.13/x64/lib/python3.13/site-packages/dipy/data/fetcher.py:397, in fetch_data(files, folder, data_size, use_headers, raise_on_error)
    395 logger.info(f"From: {url}")
    396 try:
--> 397     _get_file_data(fullpath, url, use_headers=use_headers, stored_md5=md5)
    398     successful_downloads += 1
    399 except (FetcherError, Exception) as e:

File /opt/hostedtoolcache/Python/3.13.13/x64/lib/python3.13/site-packages/dipy/data/fetcher.py:262, in _get_file_data(fname, url, use_headers, timeout, max_retries, stored_md5)
    260 with open(fname, "wb") as data:
    261     if response_size is None:
--> 262         copyfileobj(opener, data)
    263     else:
    264         copyfileobj_withprogress(opener, data, response_size)

File /opt/hostedtoolcache/Python/3.13.13/x64/lib/python3.13/shutil.py:203, in copyfileobj(fsrc, fdst, length)
    201 fsrc_read = fsrc.read
    202 fdst_write = fdst.write
--> 203 while buf := fsrc_read(length):
    204     fdst_write(buf)

File /opt/hostedtoolcache/Python/3.13.13/x64/lib/python3.13/http/client.py:478, in HTTPResponse.read(self, amt)
    475     return b""
    477 if self.chunked:
--> 478     return self._read_chunked(amt)
    480 if amt is not None and amt >= 0:
    481     if self.length is not None and amt > self.length:
    482         # clip the read to the "end of response"

File /opt/hostedtoolcache/Python/3.13.13/x64/lib/python3.13/http/client.py:602, in HTTPResponse._read_chunked(self, amt)
    600 value = []
    601 try:
--> 602     while (chunk_left := self._get_chunk_left()) is not None:
    603         if amt is not None and amt <= chunk_left:
    604             value.append(self._safe_read(amt))

File /opt/hostedtoolcache/Python/3.13.13/x64/lib/python3.13/http/client.py:584, in HTTPResponse._get_chunk_left(self)
    582     self._safe_read(2)  # toss the CRLF at the end of the chunk
    583 try:
--> 584     chunk_left = self._read_next_chunk_size()
    585 except ValueError:
    586     raise IncompleteRead(b'')

File /opt/hostedtoolcache/Python/3.13.13/x64/lib/python3.13/http/client.py:544, in HTTPResponse._read_next_chunk_size(self)
    542 def _read_next_chunk_size(self):
    543     # Read the next chunk size from the file
--> 544     line = self.fp.readline(_MAXLINE + 1)
    545     if len(line) > _MAXLINE:
    546         raise LineTooLong("chunk size")

File /opt/hostedtoolcache/Python/3.13.13/x64/lib/python3.13/socket.py:719, in SocketIO.readinto(self, b)
    717     raise OSError("cannot read from timed out object")
    718 try:
--> 719     return self._sock.recv_into(b)
    720 except timeout:
    721     self._timeout_occurred = True

File /opt/hostedtoolcache/Python/3.13.13/x64/lib/python3.13/ssl.py:1304, in SSLSocket.recv_into(self, buffer, nbytes, flags)
   1300     if flags != 0:
   1301         raise ValueError(
   1302           "non-zero flags not allowed in calls to recv_into() on %s" %
   1303           self.__class__)
-> 1304     return self.read(nbytes, buffer)
   1305 else:
   1306     return super().recv_into(buffer, nbytes, flags)

File /opt/hostedtoolcache/Python/3.13.13/x64/lib/python3.13/ssl.py:1138, in SSLSocket.read(self, len, buffer)
   1136 try:
   1137     if buffer is not None:
-> 1138         return self._sslobj.read(len, buffer)
   1139     else:
   1140         return self._sslobj.read(len)

KeyboardInterrupt: 

Set tractography parameters

We create a tracking_params variable to define the parameters for tractography. The only parameter we need to set to use the GPU is jit_backend, which we set to “cuda”. Other available backends are “metal”, “webgpu”, and “numba”; “numba” is the default. Note that the GPU backend is only used for probabilistic tracking, which is the default.

# Tractography settings. Only ``jit_backend`` is required to run on the GPU;
# the remaining entries are ordinary pyAFQ tracking options.
tracking_params = dict(
    # Pass the seed count as an int rather than the float literal 1e7:
    # pyAFQ documents n_seeds as an int (or an array of seed coordinates).
    n_seeds=10_000_000,
    random_seeds=True,
    # NOTE(review): presumably seeds the RNG so that random seed placement
    # is reproducible across runs -- confirm against the pyAFQ docs.
    rng_seed=2025,
    # Run tracking on the GPU via CUDA. Alternatives are "metal", "webgpu",
    # or "numba" (the default). The GPU backend only applies to
    # probabilistic tracking, which is the default tracking method.
    jit_backend="cuda",
    # NOTE(review): presumably saves the tractogram in TRX format -- confirm.
    trx=True,
)

Running with the GPU

Then, run pyAFQ as usual, passing in tracking_params. That’s it!

# Build the GroupAFQ object exactly as in the CPU example; the GPU is selected
# solely through the jit_backend entry inside tracking_params.
stanford_bids_path = op.join(afd.afq_home, "stanford_hardi")
myafq = GroupAFQ(
    bids_path=stanford_bids_path,
    dwi_preproc_pipeline="vistasoft",
    t1_preproc_pipeline="freesurfer",
    tracking_params=tracking_params,
)

# Export the all-bundles figure and display the first entry stored under
# key "01" (presumably the subject label -- confirm against pyAFQ's export).
bundle_html = myafq.export("all_bundles_figure")
subject_01_figure = bundle_html["01"][0]
plotly.io.show(subject_01_figure)