Skip to content

Commit

Permalink
eegbci api: allow downloading multiple subjects (mne-tools#12918)
Browse files Browse the repository at this point in the history
  • Loading branch information
sappelhoff authored Oct 28, 2024
1 parent 7250311 commit e15292f
Show file tree
Hide file tree
Showing 11 changed files with 98 additions and 47 deletions.
1 change: 1 addition & 0 deletions doc/changes/devel/12918.apichange.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Deprecate ``subject`` parameter in favor of ``subjects`` in :func:`mne.datasets.eegbci.load_data`, by `Stefan Appelhoff`_.
18 changes: 12 additions & 6 deletions doc/documentation/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -161,17 +161,23 @@ EEGBCI motor imagery
====================
:func:`mne.datasets.eegbci.load_data`

The EEGBCI dataset is documented in :footcite:`SchalkEtAl2004`. The data set is
available at PhysioNet :footcite:`GoldbergerEtAl2000`. The dataset contains
64-channel EEG recordings from 109 subjects and 14 runs on each subject in EDF+
format. The recordings were made using the BCI2000 system. To load a subject,
do::
The EEGBCI dataset is documented in :footcite:`SchalkEtAl2004` and on the
`PhysioNet documentation page <https://physionet.org/content/eegmmidb/1.0.0/>`_.
The data set is available at PhysioNet :footcite:`GoldbergerEtAl2000`.
It contains 64-channel EEG recordings from 109 subjects and 14 runs on each
subject in EDF+ format. The recordings were made using the BCI2000 system.
To load a subject, do::

from mne.io import concatenate_raws, read_raw_edf
from mne.datasets import eegbci
raw_fnames = eegbci.load_data(subject, runs)
subjects = [1] # may vary
runs = [4, 8, 12] # may vary
raw_fnames = eegbci.load_data(subjects, runs)
raws = [read_raw_edf(f, preload=True) for f in raw_fnames]
# concatenate runs from subject
raw = concatenate_raws(raws)
# make channel names follow standard conventions
eegbci.standardize(raw)

.. topic:: Examples

Expand Down
11 changes: 6 additions & 5 deletions examples/decoding/decoding_csp_eeg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
See https://en.wikipedia.org/wiki/Common_spatial_pattern and
:footcite:`Koles1991`. The EEGBCI dataset is documented in
:footcite:`SchalkEtAl2004` and is available at PhysioNet
:footcite:`GoldbergerEtAl2000`.
:footcite:`SchalkEtAl2004` and on the
`PhysioNet documentation page <https://physionet.org/content/eegmmidb/1.0.0/>`_.
The dataset is available at PhysioNet :footcite:`GoldbergerEtAl2000`.
"""
# Authors: Martin Billinger <martin.billinger@tugraz.at>
#
Expand Down Expand Up @@ -40,15 +41,15 @@
# avoid classification of evoked responses by using epochs that start 1s after
# cue onset.
tmin, tmax = -1.0, 4.0
subject = 1
subjects = 1
runs = [6, 10, 14] # motor imagery: hands vs feet

raw_fnames = eegbci.load_data(subject, runs)
raw_fnames = eegbci.load_data(subjects, runs)
raw = concatenate_raws([read_raw_edf(f, preload=True) for f in raw_fnames])
eegbci.standardize(raw) # set channel names
montage = make_standard_montage("standard_1005")
raw.set_montage(montage)
raw.annotations.rename(dict(T1="hands", T2="feet"))
raw.annotations.rename(dict(T1="hands", T2="feet")) # as documented on PhysioNet
raw.set_eeg_reference(projection=True)

# Apply band-pass filter
Expand Down
2 changes: 1 addition & 1 deletion examples/preprocessing/eeg_bridging.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
raw_data = dict() # store infos for electrode positions
for sub in range(1, 11):
print(f"Computing electrode bridges for subject {sub}")
raw_fname = mne.datasets.eegbci.load_data(subject=sub, runs=(1,))[0]
raw_fname = mne.datasets.eegbci.load_data(subjects=sub, runs=(1,))[0]
raw = mne.io.read_raw(raw_fname, preload=True, verbose=False)
mne.datasets.eegbci.standardize(raw) # set channel names
raw.set_montage(montage, verbose=False)
Expand Down
2 changes: 1 addition & 1 deletion examples/preprocessing/muscle_ica.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@

for sub in (1, 2):
raw = mne.io.read_raw_edf(
mne.datasets.eegbci.load_data(subject=sub, runs=(1,))[0], preload=True
mne.datasets.eegbci.load_data(subjects=sub, runs=(1,))[0], preload=True
)
mne.datasets.eegbci.standardize(raw) # set channel names
montage = mne.channels.make_standard_montage("standard_1005")
Expand Down
2 changes: 1 addition & 1 deletion examples/time_frequency/time_frequency_erds.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
# First, we load and preprocess the data. We use runs 6, 10, and 14 from
# subject 1 (these runs contain hand and feet motor imagery).

fnames = eegbci.load_data(subject=1, runs=(6, 10, 14))
fnames = eegbci.load_data(subjects=1, runs=(6, 10, 14))
raw = concatenate_raws([read_raw_edf(f, preload=True) for f in fnames])

raw.rename_channels(lambda x: x.strip(".")) # remove dots from channel names
Expand Down
80 changes: 55 additions & 25 deletions mne/datasets/eegbci/eegbci.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from os import path as op
from pathlib import Path

from ...utils import _url_to_local_path, logger, verbose
from ...utils import _url_to_local_path, logger, verbose, warn
from ..utils import _do_path_update, _downloader_params, _get_path, _log_time_size

EEGMI_URL = "https://physionet.org/files/eegmmidb/1.0.0/"
Expand All @@ -21,7 +21,9 @@ def data_path(url, path=None, force_update=False, update_path=None, *, verbose=N
This is a low-level function useful for getting a local copy of a remote EEGBCI
dataset :footcite:`SchalkEtAl2004`, which is also available at PhysioNet
:footcite:`GoldbergerEtAl2000`.
:footcite:`GoldbergerEtAl2000`. Metadata, such as the meaning of event markers,
may be obtained from the
`PhysioNet documentation page <https://physionet.org/content/eegmmidb/1.0.0/>`_.
Parameters
----------
Expand Down Expand Up @@ -92,8 +94,10 @@ def data_path(url, path=None, force_update=False, update_path=None, *, verbose=N

@verbose
def load_data(
subject,
runs,
subjects=None,
runs=None,
*,
subject=None,
path=None,
force_update=False,
update_path=None,
Expand All @@ -103,14 +107,19 @@ def load_data(
"""Get paths to local copies of EEGBCI dataset files.
This will fetch data for the EEGBCI dataset :footcite:`SchalkEtAl2004`, which is
also available at PhysioNet :footcite:`GoldbergerEtAl2000`.
also available at PhysioNet :footcite:`GoldbergerEtAl2000`. Metadata, such as the
meaning of event markers, may be obtained from the
`PhysioNet documentation page <https://physionet.org/content/eegmmidb/1.0.0/>`_.
Parameters
----------
subject : int
The subject to use. Can be in the range of 1-109 (inclusive).
subjects : int | list of int
The subjects to use. Can be in the range of 1-109 (inclusive).
runs : int | list of int
The runs to use (see Notes for details).
subject : int
This parameter is deprecated and will be removed in mne version 1.9.
Please use ``subjects`` instead.
path : None | path-like
Location of where to look for the EEGBCI data. If ``None``, the environment
variable or config parameter ``MNE_DATASETS_EEGBCI_PATH`` is used. If neither
Expand Down Expand Up @@ -149,20 +158,39 @@ def load_data(
For example, one could do::
>>> from mne.datasets import eegbci
>>> eegbci.load_data(1, [6, 10, 14], "~/datasets") # doctest:+SKIP
>>> eegbci.load_data([1, 2], [6, 10, 14], "~/datasets") # doctest:+SKIP
This would download runs 6, 10, and 14 (hand/foot motor imagery) runs from subject 1
in the EEGBCI dataset to "~/datasets" and prompt the user to store this path in the
config (if it does not already exist).
This would download runs 6, 10, and 14 (hand/foot motor imagery) from subjects
1 and 2 in the EEGBCI dataset to "~/datasets" and prompt the user to store this path
in the config (if it does not already exist).
References
----------
.. footbibliography::
"""
import pooch

# XXX: Remove this with mne 1.9 ↓↓↓
# Also remove the subject parameter at that point.
# Also remove the `None` default for subjects and runs params at that point.
if subject is not None:
subjects = subject
warn(
"The ``subject`` parameter is deprecated and will be removed in version "
"1.9. Use the ``subjects`` parameter (note the `s`) to suppress this "
"warning.",
FutureWarning,
)
del subject
if subjects is None or runs is None:
raise ValueError("You must pass the parameters ``subjects`` and ``runs``.")
# ↑↑↑

t0 = time.time()

if not hasattr(subjects, "__iter__"):
subjects = [subjects]

if not hasattr(runs, "__iter__"):
runs = [runs]

Expand Down Expand Up @@ -198,20 +226,22 @@ def load_data(
# fetch the file(s)
data_paths = []
sz = 0
for run in runs:
file_part = f"S{subject:03d}/S{subject:03d}R{run:02d}.edf"
destination = Path(base_path, file_part)
data_paths.append(destination)
if destination.exists():
if force_update:
destination.unlink()
else:
continue
if sz == 0: # log once
logger.info("Downloading EEGBCI data")
fetcher.fetch(file_part)
# update path in config if desired
sz += destination.stat().st_size
for subject in subjects:
for run in runs:
file_part = f"S{subject:03d}/S{subject:03d}R{run:02d}.edf"
destination = Path(base_path, file_part)
data_paths.append(destination)
if destination.exists():
if force_update:
destination.unlink()
else:
continue
if sz == 0: # log once
logger.info("Downloading EEGBCI data")
fetcher.fetch(file_part)
# update path in config if desired
sz += destination.stat().st_size

_do_path_update(path, update_path, config_key, name)
if sz > 0:
_log_time_size(t0, sz)
Expand Down
18 changes: 16 additions & 2 deletions mne/datasets/eegbci/tests/test_eegbci.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,26 @@
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import pytest

from mne.datasets import eegbci


def test_eegbci_download(tmp_path, fake_retrieve):
"""Test EEGBCI URL handling."""
for subj in range(4):
fnames = eegbci.load_data(subj + 1, runs=[3], path=tmp_path, update_path=False)
subjects = range(1, 5)
for subj in subjects:
fnames = eegbci.load_data(subj, runs=[3], path=tmp_path, update_path=False)
assert len(fnames) == 1, subj
assert fake_retrieve.call_count == 4

# XXX: remove in version 1.9
with pytest.warns(FutureWarning, match="The ``subject``"):
fnames = eegbci.load_data(
subject=subjects, runs=[3], path=tmp_path, update_path=False
)
assert len(fnames) == 4

# XXX: remove in version 1.9
with pytest.raises(ValueError, match="You must pass the parameters"):
fnames = eegbci.load_data(path=tmp_path, update_path=False)
5 changes: 2 additions & 3 deletions mne/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,9 +347,8 @@ def _download_all_example_data(verbose=True):
sleep_physionet,
)

eegbci.load_data(1, [6, 10, 14], update_path=True)
for subj in range(4):
eegbci.load_data(subj + 1, runs=[3], update_path=True)
eegbci.load_data(subjects=1, runs=[6, 10, 14], update_path=True)
eegbci.load_data(subjects=range(1, 5), runs=[3], update_path=True)
logger.info("[done eegbci]")

sleep_physionet.age.fetch_data(subjects=[0, 1], recording=[1])
Expand Down
2 changes: 1 addition & 1 deletion tutorials/forward/35_eeg_no_mri.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
# .. note:: See :ref:`plot_montage` to view all the standard EEG montages
# available in MNE-Python.

(raw_fname,) = eegbci.load_data(subject=1, runs=[6])
(raw_fname,) = eegbci.load_data(subjects=1, runs=[6])
raw = mne.io.read_raw_edf(raw_fname, preload=True)

# Clean channel names to be able to use a standard 1005 montage
Expand Down
4 changes: 2 additions & 2 deletions tutorials/preprocessing/40_artifact_correction_ica.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,9 +532,9 @@
raws = list()
icas = list()

for subj in range(4):
for subj in range(1, 5):
# EEGBCI subjects are 1-indexed; run 3 is a left/right hand movement task
fname = mne.datasets.eegbci.load_data(subj + 1, runs=[3])[0]
fname = mne.datasets.eegbci.load_data(subj, runs=[3])[0]
raw = mne.io.read_raw_edf(fname).load_data().resample(50)
# remove trailing `.` from channel names so we can set montage
mne.datasets.eegbci.standardize(raw)
Expand Down

0 comments on commit e15292f

Please sign in to comment.