Skip to content

Commit

Permalink
MAINT: Check for shadowing and mutable defaults
Browse files Browse the repository at this point in the history
  • Loading branch information
larsoner committed Jan 22, 2024
1 parent 566c6ea commit 957f926
Show file tree
Hide file tree
Showing 37 changed files with 115 additions and 84 deletions.
1 change: 1 addition & 0 deletions doc/changes/devel/bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix bug where :func:`mne.preprocessing.compute_proj_ecg` and :func:`mne.preprocessing.compute_proj_eog` could modify the default ``reject`` and ``flat`` arguments on multiple calls based on channel types present, by `Eric Larson`_.
2 changes: 1 addition & 1 deletion mne/_fiff/open.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def show_fiff(
return out


def _find_type(value, fmts=["FIFF_"], exclude=["FIFF_UNIT"]):
def _find_type(value, fmts=("FIFF_",), exclude=("FIFF_UNIT",)):
"""Find matching values."""
value = int(value)
vals = [
Expand Down
10 changes: 5 additions & 5 deletions mne/_fiff/pick.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def channel_type(info, idx):


@verbose
def pick_channels(ch_names, include, exclude=[], ordered=None, *, verbose=None):
def pick_channels(ch_names, include, exclude=(), ordered=None, *, verbose=None):
"""Pick channels by names.
Returns the indices of ``ch_names`` in ``include`` but not in ``exclude``.
Expand Down Expand Up @@ -706,7 +706,7 @@ def _has_kit_refs(info, picks):

@verbose
def pick_channels_forward(
orig, include=[], exclude=[], ordered=None, copy=True, *, verbose=None
orig, include=(), exclude=(), ordered=None, copy=True, *, verbose=None
):
"""Pick channels from forward operator.
Expand Down Expand Up @@ -797,8 +797,8 @@ def pick_types_forward(
seeg=False,
ecog=False,
dbs=False,
include=[],
exclude=[],
include=(),
exclude=(),
):
"""Pick by channel type and names from a forward operator.
Expand Down Expand Up @@ -893,7 +893,7 @@ def channel_indices_by_type(info, picks=None):

@verbose
def pick_channels_cov(
orig, include=[], exclude="bads", ordered=None, copy=True, *, verbose=None
orig, include=(), exclude="bads", ordered=None, copy=True, *, verbose=None
):
"""Pick channels from covariance matrix.
Expand Down
4 changes: 2 additions & 2 deletions mne/channels/tests/test_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,8 +438,8 @@ def test_1020_selection():
raw = raw.rename_channels(dict(zip(raw.ch_names, montage.ch_names)))
raw.set_montage(montage)

for input in ("a_string", 100, raw, [1, 2]):
pytest.raises(TypeError, make_1020_channel_selections, input)
for input_ in ("a_string", 100, raw, [1, 2]):
pytest.raises(TypeError, make_1020_channel_selections, input_)

sels = make_1020_channel_selections(raw.info)
# are all frontal channels placed before all occipital channels?
Expand Down
12 changes: 7 additions & 5 deletions mne/commands/mne_setup_source_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def run():
subjects_dir = options.subjects_dir
spacing = options.spacing
ico = options.ico
oct = options.oct
oct_ = options.oct
surface = options.surface
n_jobs = options.n_jobs
add_dist = options.add_dist
Expand All @@ -130,20 +130,22 @@ def run():
overwrite = True if options.overwrite is not None else False

# Parse source spacing option
spacing_options = [ico, oct, spacing]
spacing_options = [ico, oct_, spacing]
n_options = len([x for x in spacing_options if x is not None])
use_spacing = "oct6"
if n_options > 1:
raise ValueError("Only one spacing option can be set at the same time")
elif n_options == 0:
# Default to oct6
use_spacing = "oct6"
pass
elif n_options == 1:
if ico is not None:
use_spacing = "ico" + str(ico)
elif oct is not None:
use_spacing = "oct" + str(oct)
elif oct_ is not None:
use_spacing = "oct" + str(oct_)
elif spacing is not None:
use_spacing = spacing
del ico, oct_, spacing
# Generate filename
if fname is None:
if subject_to is None:
Expand Down
2 changes: 1 addition & 1 deletion mne/cov.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def __iadd__(self, cov):
def plot(
self,
info,
exclude=[],
exclude=(),
colorbar=True,
proj=False,
show_svd=True,
Expand Down
4 changes: 2 additions & 2 deletions mne/datasets/sleep_physionet/tests/test_physionet.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,12 @@ def _check_mocked_function_calls(mocked_func, call_fname_hash_pairs, base_path):
# order.
for idx, current in enumerate(call_fname_hash_pairs):
_, call_kwargs = mocked_func.call_args_list[idx]
hash_type, hash = call_kwargs["known_hash"].split(":")
hash_type, hash_ = call_kwargs["known_hash"].split(":")
assert call_kwargs["url"] == _get_expected_url(current["name"]), idx
assert Path(call_kwargs["path"], call_kwargs["fname"]) == _get_expected_path(
base_path, current["name"]
)
assert hash == current["hash"]
assert hash_ == current["hash"]
assert hash_type == "sha1"


Expand Down
2 changes: 1 addition & 1 deletion mne/decoding/tests/test_ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@


def simulate_data(
freqs_sig=[9, 12],
freqs_sig=(9, 12),
n_trials=100,
n_channels=20,
n_samples=500,
Expand Down
4 changes: 2 additions & 2 deletions mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3662,7 +3662,7 @@ def _is_good(
reject,
flat,
full_report=False,
ignore_chs=[],
ignore_chs=(),
verbose=None,
):
"""Test if data segment e is good according to reject and flat.
Expand Down Expand Up @@ -4631,7 +4631,7 @@ def make_fixed_length_epochs(
reject_by_annotation=True,
proj=True,
overlap=0.0,
id=1,
id=1, # noqa: A002
verbose=None,
):
"""Divide continuous raw data into equal-sized consecutive epochs.
Expand Down
8 changes: 7 additions & 1 deletion mne/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,7 +925,13 @@ def shift_time_events(events, ids, tshift, sfreq):

@fill_doc
def make_fixed_length_events(
raw, id=1, start=0, stop=None, duration=1.0, first_samp=True, overlap=0.0
raw,
id=1, # noqa: A002
start=0,
stop=None,
duration=1.0,
first_samp=True,
overlap=0.0,
):
"""Make a set of :term:`events` separated by a fixed duration.
Expand Down
2 changes: 1 addition & 1 deletion mne/evoked.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ def plot_topo(
scalings=None,
title=None,
proj=False,
vline=[0.0],
vline=(0.0,),
fig_background=None,
merge_grads=False,
legend=True,
Expand Down
4 changes: 2 additions & 2 deletions mne/export/_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,9 @@ def _infer_check_export_fmt(fmt, fname, supported_formats):

if fmt not in supported_formats:
supported = []
for format, extensions in supported_formats.items():
for fmt, extensions in supported_formats.items():
ext_str = ", ".join(f"*.{ext}" for ext in extensions)
supported.append(f"{format} ({ext_str})")
supported.append(f"{fmt} ({ext_str})")

supported_str = ", ".join(supported)
raise ValueError(
Expand Down
2 changes: 1 addition & 1 deletion mne/inverse_sparse/mxne_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,7 +799,7 @@ def __call__(self, x): # noqa: D105
else:
return np.hstack([x @ op for op in self.ops]) / np.sqrt(self.n_dicts)

def norm(self, z, ord=2):
def norm(self, z, ord=2): # noqa: A002
"""Squared L2 norm if ord == 2 and L1 norm if order == 1."""
if ord not in (1, 2):
raise ValueError(
Expand Down
2 changes: 1 addition & 1 deletion mne/io/bti/bti.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def __init__(self, target):
def __enter__(self): # noqa: D105
return self.target

def __exit__(self, type, value, tb): # noqa: D105
def __exit__(self, exception_type, value, tb): # noqa: D105
pass


Expand Down
6 changes: 3 additions & 3 deletions mne/io/curry/curry.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,10 @@ def _read_curry_parameters(fname):
if any(var_name in line for var_name in var_names):
key, val = line.replace(" ", "").replace("\n", "").split("=")
param_dict[key.lower().replace("_", "")] = val
for type in CHANTYPES:
if "DEVICE_PARAMETERS" + CHANTYPES[type] + " START" in line:
for key, type_ in CHANTYPES.items():
if f"DEVICE_PARAMETERS{type_} START" in line:
data_unit = next(fid)
unit_dict[type] = (
unit_dict[key] = (
data_unit.replace(" ", "").replace("\n", "").split("=")[-1]
)

Expand Down
4 changes: 3 additions & 1 deletion mne/io/edf/edf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1454,7 +1454,9 @@ def _read_gdf_header(fname, exclude, include=None):


def _check_stim_channel(
stim_channel, ch_names, tal_ch_names=["EDF Annotations", "BDF Annotations"]
stim_channel,
ch_names,
tal_ch_names=("EDF Annotations", "BDF Annotations"),
):
"""Check that the stimulus channel exists in the current datafile."""
DEFAULT_STIM_CH_NAMES = ["status", "trigger"]
Expand Down
2 changes: 1 addition & 1 deletion mne/io/fieldtrip/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def get_epochs(system):
else:
event_id = [int(cfg_local["eventvalue"])]

event_id = [id for id in event_id if id in events[:, 2]]
event_id = [id_ for id_ in event_id if id_ in events[:, 2]]

epochs = mne.Epochs(
raw_data,
Expand Down
10 changes: 5 additions & 5 deletions mne/io/fieldtrip/tests/test_fieldtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,18 +254,18 @@ def test_one_channel_elec_bug(version):
@pytest.mark.filterwarnings("ignore:.*number of bytes.*:RuntimeWarning")
@pytest.mark.parametrize("version", all_versions)
@pytest.mark.parametrize("type", ["averaged", "epoched", "raw"])
def test_throw_exception_on_cellarray(version, type):
def test_throw_exception_on_cellarray(version, type_):
"""Test for a meaningful exception when the data is a cell array."""
fname = get_data_paths("cellarray") / f"{type}_{version}.mat"
fname = get_data_paths("cellarray") / f"{type_}_{version}.mat"
info = get_raw_info("CNT")
with pytest.raises(
RuntimeError, match="Loading of data in cell arrays " "is not supported"
):
if type == "averaged":
if type_ == "averaged":
mne.read_evoked_fieldtrip(fname, info)
elif type == "epoched":
elif type_ == "epoched":
mne.read_epochs_fieldtrip(fname, info)
elif type == "raw":
elif type_ == "raw":
mne.io.read_raw_fieldtrip(fname, info)


Expand Down
4 changes: 1 addition & 3 deletions mne/preprocessing/nirs/_tddr.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ def _TDDR(signal, sample_rate):
tune = 4.685
D = np.sqrt(np.finfo(signal.dtype).eps)
mu = np.inf
iter = 0

# Step 1. Compute temporal derivative of the signal
deriv = np.diff(signal_low)
Expand All @@ -120,8 +119,7 @@ def _TDDR(signal, sample_rate):
w = np.ones(deriv.shape)

# Step 3. Iterative estimation of robust weights
while iter < 50:
iter = iter + 1
for _ in range(50):
mu0 = mu

# Step 3a. Estimate weighted mean
Expand Down
14 changes: 9 additions & 5 deletions mne/preprocessing/ssp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .._fiff.reference import make_eeg_average_ref_proj
from ..epochs import Epochs
from ..proj import compute_proj_epochs, compute_proj_evoked
from ..utils import logger, verbose, warn
from ..utils import _validate_type, logger, verbose, warn
from .ecg import find_ecg_events
from .eog import find_eog_events

Expand Down Expand Up @@ -112,7 +112,10 @@ def _compute_exg_proj(
my_info["bads"] += bads

# Handler rejection parameters
_validate_type(reject, (None, dict), "reject")
_validate_type(flat, (None, dict), "flat")
if reject is not None: # make sure they didn't pass None
reject = reject.copy() # must make a copy or we modify default!
if (
len(
pick_types(
Expand Down Expand Up @@ -170,6 +173,7 @@ def _compute_exg_proj(
):
_safe_del_key(reject, "eog")
if flat is not None: # make sure they didn't pass None
flat = flat.copy()
if (
len(
pick_types(
Expand Down Expand Up @@ -300,9 +304,9 @@ def compute_proj_ecg(
filter_length="10s",
n_jobs=None,
ch_name=None,
reject=dict(grad=2000e-13, mag=3000e-15, eeg=50e-6, eog=250e-6),
reject=dict(grad=2000e-13, mag=3000e-15, eeg=50e-6, eog=250e-6), # noqa: B006
flat=None,
bads=[],
bads=(),
avg_ref=False,
no_proj=False,
event_id=999,
Expand Down Expand Up @@ -461,9 +465,9 @@ def compute_proj_eog(
average=True,
filter_length="10s",
n_jobs=None,
reject=dict(grad=2000e-13, mag=3000e-15, eeg=500e-6, eog=np.inf),
reject=dict(grad=2000e-13, mag=3000e-15, eeg=500e-6, eog=np.inf), # noqa: B006
flat=None,
bads=[],
bads=(),
avg_ref=False,
no_proj=False,
event_id=998,
Expand Down
3 changes: 3 additions & 0 deletions mne/preprocessing/tests/test_ssp.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ def test_compute_proj_ecg(short_raw, average):
# XXX: better tests

# without setting a bad channel, this should throw a warning
# (first with a call that makes sure we copy the mutable default "reject")
with pytest.warns(RuntimeWarning, match="longer than the signal"):
compute_proj_ecg(raw.copy().pick("mag"), l_freq=None, h_freq=None)
with pytest.warns(RuntimeWarning, match="No good epochs found"):
projs, events, drop_log = compute_proj_ecg(
raw,
Expand Down
2 changes: 1 addition & 1 deletion mne/report/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -2998,7 +2998,7 @@ def __enter__(self):
"""Do nothing when entering the context block."""
return self

def __exit__(self, type, value, traceback):
def __exit__(self, exception_type, value, traceback):
"""Save the report when leaving the context block."""
if self.fname is not None:
self.save(self.fname, open_browser=False, overwrite=True)
Expand Down
10 changes: 8 additions & 2 deletions mne/source_estimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2476,7 +2476,7 @@ def save_as_volume(
src,
dest="mri",
mri_resolution=False,
format="nifti1",
format="nifti1", # noqa: A002
*,
overwrite=False,
verbose=None,
Expand Down Expand Up @@ -2525,7 +2525,13 @@ def save_as_volume(
)
nib.save(img, fname)

def as_volume(self, src, dest="mri", mri_resolution=False, format="nifti1"):
def as_volume(
self,
src,
dest="mri",
mri_resolution=False,
format="nifti1", # noqa: A002
):
"""Export volume source estimate as a nifti object.
Parameters
Expand Down
3 changes: 2 additions & 1 deletion mne/source_space/_source_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -2047,11 +2047,12 @@ def _make_volume_source_space(
volume_labels=None,
do_neighbors=True,
n_jobs=None,
vol_info={},
vol_info=None,
single_volume=False,
):
"""Make a source space which covers the volume bounded by surf."""
# Figure out the grid size in the MRI coordinate frame
vol_info = {} if vol_info is None else vol_info
if "rr" in surf:
mins = np.min(surf["rr"], axis=0)
maxs = np.max(surf["rr"], axis=0)
Expand Down
2 changes: 1 addition & 1 deletion mne/time_frequency/csd.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

@verbose
def pick_channels_csd(
csd, include=[], exclude=[], ordered=None, copy=True, *, verbose=None
csd, include=(), exclude=(), ordered=None, copy=True, *, verbose=None
):
"""Pick channels from cross-spectral density matrix.
Expand Down
Loading

0 comments on commit 957f926

Please sign in to comment.