Skip to content

Commit

Permalink
MAINT: Improve NumPy 2 compat
Browse files Browse the repository at this point in the history
  • Loading branch information
larsoner committed Jul 31, 2023
1 parent 4f82954 commit 2ad03cb
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 17 deletions.
2 changes: 1 addition & 1 deletion mne/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1728,4 +1728,4 @@ def count_annotations(annotations):
{'T0': 2, 'T1': 1}
"""
types, counts = np.unique(annotations.description, return_counts=True)
return {t: count for t, count in zip(types, counts)}
return {str(t): int(count) for t, count in zip(types, counts)}
4 changes: 0 additions & 4 deletions mne/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,6 @@ def pytest_configure(config):
warning_line = warning_line.strip()
if warning_line and not warning_line.startswith("#"):
config.addinivalue_line("filterwarnings", warning_line)
# TODO: Fix this with casts?
# https://github.com/numpy/numpy/pull/22449
if check_version("numpy", "1.26"):
np.set_printoptions(legacy="1.25")


# Have to be careful with autouse=True, but this is just an int comparison
Expand Down
4 changes: 2 additions & 2 deletions mne/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -1681,7 +1681,7 @@ def count_events(events, ids=None):
{1: 2, 11: 0}
"""
counts = np.bincount(events[:, 2])
counts = {i: count for i, count in enumerate(counts) if count > 0}
counts = {i: int(count) for i, count in enumerate(counts) if count > 0}
if ids is not None:
return {id: counts.get(id, 0) for id in ids}
counts = {id_: int(counts.get(id_, 0)) for id_ in ids}
return counts
2 changes: 1 addition & 1 deletion mne/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2120,7 +2120,7 @@ def detrend(x, order=1, axis=-1):
>>> npoints = int(1e3)
>>> noise = randgen.randn(npoints)
>>> x = 3 + 2*np.linspace(0, 1, npoints) + noise
>>> (detrend(x) - noise).max() < 0.01
>>> bool((detrend(x) - noise).max() < 0.01)
True
"""
from scipy.signal import detrend
Expand Down
17 changes: 8 additions & 9 deletions mne/preprocessing/ica.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from copy import deepcopy
from numbers import Integral
from time import time
from dataclasses import dataclass
from dataclasses import dataclass, is_dataclass
from typing import Optional, List, Literal
import warnings

Expand Down Expand Up @@ -112,15 +112,14 @@

def _make_xy_sfunc(func, ndim_output=False):
"""Aux function."""
if ndim_output:

def sfunc(x, y):
return np.array([func(a, y.ravel()) for a in x])[:, 0]

else:

def sfunc(x, y):
return np.array([func(a, y.ravel()) for a in x])
def sfunc(x, y, ndim_output=ndim_output):
out = [func(a, y.ravel()) for a in x]
if len(out) and is_dataclass(out[0]): # PermutationTestResult
out = [(o.statistic, o.pvalue) for o in out]
if ndim_output:
out = np.array(out)[:, 0]
return out

sfunc.__name__ = ".".join(["score_func", func.__module__, func.__name__])
sfunc.__doc__ = func.__doc__
Expand Down

0 comments on commit 2ad03cb

Please sign in to comment.