From 864b3e63114708789a7f23d0ea5381bbe81e3302 Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Sat, 7 Oct 2023 05:21:33 +0200 Subject: [PATCH] Use constrained layout in matplotlib visualization (#12050) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Larson --- README.rst | 2 +- doc/changes/devel.rst | 1 + doc/conf.py | 1 - examples/decoding/decoding_rsa_sgskip.py | 6 +- examples/decoding/decoding_spoc_CMC.py | 3 +- ...decoding_time_generalization_conditions.py | 2 +- examples/decoding/decoding_xdawn_eeg.py | 9 +- examples/decoding/receptive_field_mtrf.py | 16 +- examples/inverse/label_source_activations.py | 6 +- .../inverse/mixed_source_space_inverse.py | 3 +- examples/inverse/source_space_snr.py | 3 +- examples/preprocessing/eeg_bridging.py | 7 +- examples/preprocessing/eeg_csd.py | 3 +- .../preprocessing/eog_artifact_histogram.py | 3 +- examples/preprocessing/eog_regression.py | 7 +- examples/preprocessing/shift_evoked.py | 3 - examples/simulation/plot_stc_metrics.py | 6 +- .../source_label_time_frequency.py | 25 +-- .../source_power_spectrum_opm.py | 1 - .../time_frequency_simulated.py | 17 +- examples/visualization/3d_to_2d.py | 3 +- examples/visualization/evoked_topomap.py | 4 +- mne/conftest.py | 38 +++-- mne/preprocessing/eyetracking/calibration.py | 2 +- mne/preprocessing/ica.py | 1 - mne/report/report.py | 30 +--- mne/time_frequency/spectrum.py | 1 - mne/time_frequency/tfr.py | 38 +++-- mne/viz/_3d.py | 26 +-- mne/viz/__init__.pyi | 2 - mne/viz/_dipole.py | 4 +- mne/viz/_figure.py | 12 +- mne/viz/_mpl_figure.py | 29 +++- mne/viz/_proj.py | 2 +- mne/viz/backends/_abstract.py | 15 +- mne/viz/backends/tests/test_utils.py | 3 + mne/viz/circle.py | 2 +- mne/viz/epochs.py | 14 +- mne/viz/evoked.py | 75 ++++----- mne/viz/ica.py | 24 +-- mne/viz/misc.py | 51 +++--- mne/viz/tests/test_epochs.py | 9 +- mne/viz/tests/test_evoked.py | 4 +- mne/viz/tests/test_topomap.py | 6 +- mne/viz/topo.py | 6 +- mne/viz/topomap.py | 106 +++++------- mne/viz/utils.py | 159 ++---------------- requirements.txt | 2 +- requirements_base.txt | 2 +- tools/github_actions_env_vars.sh | 2 +- .../epochs/60_make_fixed_length_epochs.py | 7 +- .../forward/50_background_freesurfer_mne.py | 3 +- tutorials/intro/70_report.py | 2 +- tutorials/inverse/20_dipole_fit.py | 2 +- tutorials/inverse/60_visualize_stc.py | 3 +- .../inverse/80_brainstorm_phantom_elekta.py | 2 +- tutorials/machine-learning/30_strf.py | 40 ++--- .../preprocessing/25_background_filtering.py | 10 +- .../preprocessing/30_filtering_resampling.py | 3 - .../50_artifact_correction_ssp.py | 5 +- .../preprocessing/60_maxwell_filtering_sss.py | 5 +- .../preprocessing/70_fnirs_processing.py | 15 +- tutorials/preprocessing/80_opm_processing.py | 12 +- tutorials/raw/20_event_arrays.py | 1 - tutorials/simulation/80_dics.py | 3 +- .../stats-sensor-space/10_background_stats.py | 17 +- .../40_cluster_1samp_time_freq.py | 17 +- .../50_cluster_between_time_freq.py | 3 +- .../70_cluster_rmANOVA_time_freq.py | 10 +- .../75_cluster_ftest_spatiotemporal.py | 13 +- .../time-freq/20_sensors_time_frequency.py | 2 +- 71 files changed, 351 insertions(+), 620 deletions(-) diff --git a/README.rst b/README.rst index c601e318b51..a3d35deb76a 100644 --- a/README.rst +++ b/README.rst @@ -96,7 +96,7 @@ The minimum required dependencies to run MNE-Python are: - Python >= 3.8 - NumPy >= 1.21.2 - SciPy >= 1.7.1 -- Matplotlib >= 3.4.3 +- Matplotlib >= 3.5.0 - pooch >= 1.5 - tqdm - Jinja2 diff --git a/doc/changes/devel.rst b/doc/changes/devel.rst index 81e16e8658e..b46c2a6fc60 100644 --- a/doc/changes/devel.rst +++ b/doc/changes/devel.rst @@ -37,6 +37,7 @@ Enhancements - Add :class:`~mne.time_frequency.EpochsSpectrumArray` and :class:`~mne.time_frequency.SpectrumArray` to support creating power spectra from :class:`NumPy array ` data (:gh:`11803` by `Alex Rockhill`_) - Add support for writing forward solutions to HDF5 and convenience function :meth:`mne.Forward.save` (:gh:`12036` by `Eric Larson`_) - Refactored internals of :func:`mne.read_annotations` (:gh:`11964` by `Paul Roujansky`_) +- By default MNE-Python creates matplotlib figures with ``layout='constrained'`` rather than the default ``layout='tight'`` (:gh:`12050` by `Mathieu Scheltienne`_ and `Eric Larson`_) - Enhance :func:`~mne.viz.plot_evoked_field` with a GUI that has controls for time, colormap, and contour lines (:gh:`11942` by `Marijn van Vliet`_) - Add :class:`mne.viz.ui_events.UIEvent` linking for interactive colorbars, allowing users to link figures and change the colormap and limits interactively. This supports :func:`~mne.viz.plot_evoked_topomap`, :func:`~mne.viz.plot_ica_components`, :func:`~mne.viz.plot_tfr_topomap`, :func:`~mne.viz.plot_projs_topomap`, :meth:`~mne.Evoked.plot_image`, and :meth:`~mne.Epochs.plot_image` (:gh:`12057` by `Santeri Ruuskanen`_) diff --git a/doc/conf.py b/doc/conf.py index d8c9f52ad6e..b8086500640 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -1291,7 +1291,6 @@ def reset_warnings(gallery_conf, fname): warnings.filterwarnings("default", module="sphinx") # allow these warnings, but don't show them for key in ( - "The module matplotlib.tight_layout is deprecated", # nilearn "invalid version and will not be supported", # pyxdf "distutils Version classes are deprecated", # seaborn and neo "`np.object` is a deprecated alias for the builtin `object`", # pyxdf diff --git a/examples/decoding/decoding_rsa_sgskip.py b/examples/decoding/decoding_rsa_sgskip.py index 7cc6dbfbb01..3cc8467deb3 100644 --- a/examples/decoding/decoding_rsa_sgskip.py +++ b/examples/decoding/decoding_rsa_sgskip.py @@ -150,7 +150,7 @@ ############################################################################## # Plot labels = [""] * 5 + ["face"] + [""] * 11 + ["bodypart"] + [""] * 6 -fig, ax = plt.subplots(1) +fig, ax = plt.subplots(1, layout="constrained") im = ax.matshow(confusion, cmap="RdBu_r", clim=[0.3, 0.7]) ax.set_yticks(range(len(classes))) ax.set_yticklabels(labels) @@ -159,14 +159,13 @@ ax.axhline(11.5, color="k") ax.axvline(11.5, color="k") plt.colorbar(im) -plt.tight_layout() plt.show() ############################################################################## # Confusion matrix related to mental representations have been historically # summarized with dimensionality reduction using multi-dimensional scaling [1]. # See how the face samples cluster together. -fig, ax = plt.subplots(1) +fig, ax = plt.subplots(1, layout="constrained") mds = MDS(2, random_state=0, dissimilarity="precomputed") chance = 0.5 summary = mds.fit_transform(chance - confusion) @@ -186,7 +185,6 @@ ) ax.axis("off") ax.legend(loc="lower right", scatterpoints=1, ncol=2) -plt.tight_layout() plt.show() ############################################################################## diff --git a/examples/decoding/decoding_spoc_CMC.py b/examples/decoding/decoding_spoc_CMC.py index d73e9af9bbc..4e689d338d5 100644 --- a/examples/decoding/decoding_spoc_CMC.py +++ b/examples/decoding/decoding_spoc_CMC.py @@ -68,7 +68,7 @@ y_preds = cross_val_predict(clf, X, y, cv=cv) # Plot the True EMG power and the EMG power predicted from MEG data -fig, ax = plt.subplots(1, 1, figsize=[10, 4]) +fig, ax = plt.subplots(1, 1, figsize=[10, 4], layout="constrained") times = raw.times[meg_epochs.events[:, 0] - raw.first_samp] ax.plot(times, y_preds, color="b", label="Predicted EMG") ax.plot(times, y, color="r", label="True EMG") @@ -76,7 +76,6 @@ ax.set_ylabel("EMG Power") ax.set_title("SPoC MEG Predictions") plt.legend() -mne.viz.tight_layout() plt.show() ############################################################################## diff --git a/examples/decoding/decoding_time_generalization_conditions.py b/examples/decoding/decoding_time_generalization_conditions.py index 08ca0d9c0c3..a018ebbe75b 100644 --- a/examples/decoding/decoding_time_generalization_conditions.py +++ b/examples/decoding/decoding_time_generalization_conditions.py @@ -88,7 +88,7 @@ # %% # Plot -fig, ax = plt.subplots(constrained_layout=True) +fig, ax = plt.subplots(layout="constrained") im = ax.matshow( scores, vmin=0, diff --git a/examples/decoding/decoding_xdawn_eeg.py b/examples/decoding/decoding_xdawn_eeg.py index 3bdff716228..e7fac8c52e6 100644 --- a/examples/decoding/decoding_xdawn_eeg.py +++ b/examples/decoding/decoding_xdawn_eeg.py @@ -99,14 +99,13 @@ cm_normalized = cm.astype(float) / cm.sum(axis=1)[:, np.newaxis] # Plot confusion matrix -fig, ax = plt.subplots(1) +fig, ax = plt.subplots(1, layout="constrained") im = ax.imshow(cm_normalized, interpolation="nearest", cmap=plt.cm.Blues) ax.set(title="Normalized Confusion matrix") fig.colorbar(im) tick_marks = np.arange(len(target_names)) plt.xticks(tick_marks, target_names, rotation=45) plt.yticks(tick_marks, target_names) -fig.tight_layout() ax.set(ylabel="True label", xlabel="Predicted label") # %% @@ -114,7 +113,10 @@ # cross-validation fold) can be used for visualization. fig, axes = plt.subplots( - nrows=len(event_id), ncols=n_filter, figsize=(n_filter, len(event_id) * 2) + nrows=len(event_id), + ncols=n_filter, + figsize=(n_filter, len(event_id) * 2), + layout="constrained", ) fitted_xdawn = clf.steps[0][1] info = create_info(epochs.ch_names, 1, epochs.get_channel_types()) @@ -131,7 +133,6 @@ show=False, ) axes[ii, 0].set(ylabel=cur_class) -fig.tight_layout(h_pad=1.0, w_pad=1.0, pad=0.1) # %% # References diff --git a/examples/decoding/receptive_field_mtrf.py b/examples/decoding/receptive_field_mtrf.py index 0d24d5ebfa1..e927cd3cf25 100644 --- a/examples/decoding/receptive_field_mtrf.py +++ b/examples/decoding/receptive_field_mtrf.py @@ -67,12 +67,11 @@ n_channels = len(raw.ch_names) # Plot a sample of brain and stimulus activity -fig, ax = plt.subplots() +fig, ax = plt.subplots(layout="constrained") lns = ax.plot(scale(raw[:, :800][0].T), color="k", alpha=0.1) ln1 = ax.plot(scale(speech[0, :800]), color="r", lw=2) ax.legend([lns[0], ln1[0]], ["EEG", "Speech Envelope"], frameon=False) ax.set(title="Sample activity", xlabel="Time (s)") -mne.viz.tight_layout() # %% # Create and fit a receptive field model @@ -117,12 +116,11 @@ mean_scores = scores.mean(axis=0) # Plot mean prediction scores across all channels -fig, ax = plt.subplots() +fig, ax = plt.subplots(layout="constrained") ix_chs = np.arange(n_channels) ax.plot(ix_chs, mean_scores) ax.axhline(0, ls="--", color="r") ax.set(title="Mean prediction score", xlabel="Channel", ylabel="Score ($r$)") -mne.viz.tight_layout() # %% # Investigate model coefficients @@ -134,7 +132,7 @@ # Print mean coefficients across all time delays / channels (see Fig 1) time_plot = 0.180 # For highlighting a specific time. -fig, ax = plt.subplots(figsize=(4, 8)) +fig, ax = plt.subplots(figsize=(4, 8), layout="constrained") max_coef = mean_coefs.max() ax.pcolormesh( times, @@ -155,16 +153,14 @@ xticks=np.arange(tmin, tmax + 0.2, 0.2), ) plt.setp(ax.get_xticklabels(), rotation=45) -mne.viz.tight_layout() # Make a topographic map of coefficients for a given delay (see Fig 2C) ix_plot = np.argmin(np.abs(time_plot - times)) -fig, ax = plt.subplots() +fig, ax = plt.subplots(layout="constrained") mne.viz.plot_topomap( mean_coefs[:, ix_plot], pos=info, axes=ax, show=False, vlim=(-max_coef, max_coef) ) ax.set(title="Topomap of model coefficients\nfor delay %s" % time_plot) -mne.viz.tight_layout() # %% # Create and fit a stimulus reconstruction model @@ -240,7 +236,7 @@ y_pred = sr.predict(Y[test]) time = np.linspace(0, 2.0, 5 * int(sfreq)) -fig, ax = plt.subplots(figsize=(8, 4)) +fig, ax = plt.subplots(figsize=(8, 4), layout="constrained") ax.plot( time, speech[test][sr.valid_samples_][: int(5 * sfreq)], color="grey", lw=2, ls="--" ) @@ -248,7 +244,6 @@ ax.legend([lns[0], ln1[0]], ["Envelope", "Reconstruction"], frameon=False) ax.set(title="Stimulus reconstruction") ax.set_xlabel("Time (s)") -mne.viz.tight_layout() # %% # Investigate model coefficients @@ -292,7 +287,6 @@ title="Inverse-transformed coefficients\nbetween delays %s and %s" % (time_plot[0], time_plot[1]) ) -mne.viz.tight_layout() # %% # References diff --git a/examples/inverse/label_source_activations.py b/examples/inverse/label_source_activations.py index 599fff4c2f8..035533b4b9a 100644 --- a/examples/inverse/label_source_activations.py +++ b/examples/inverse/label_source_activations.py @@ -62,7 +62,7 @@ # View source activations # ----------------------- -fig, ax = plt.subplots(1) +fig, ax = plt.subplots(1, layout="constrained") t = 1e3 * stc_label.times ax.plot(t, stc_label.data.T, "k", linewidth=0.5, alpha=0.5) pe = [ @@ -81,7 +81,6 @@ xlim=xlim, ylim=ylim, ) -mne.viz.tight_layout() # %% # Using vector solutions @@ -92,7 +91,7 @@ pick_ori = "vector" stc_vec = apply_inverse(evoked, inverse_operator, lambda2, method, pick_ori=pick_ori) data = stc_vec.extract_label_time_course(label, src) -fig, ax = plt.subplots(1) +fig, ax = plt.subplots(1, layout="constrained") stc_vec_label = stc_vec.in_label(label) colors = ["#EE6677", "#228833", "#4477AA"] for ii, name in enumerate("XYZ"): @@ -117,4 +116,3 @@ xlim=xlim, ylim=ylim, ) -mne.viz.tight_layout() diff --git a/examples/inverse/mixed_source_space_inverse.py b/examples/inverse/mixed_source_space_inverse.py index 9baac7da379..f069b5e89ac 100644 --- a/examples/inverse/mixed_source_space_inverse.py +++ b/examples/inverse/mixed_source_space_inverse.py @@ -194,9 +194,8 @@ ) # plot the times series of 2 labels -fig, axes = plt.subplots(1) +fig, axes = plt.subplots(1, layout="constrained") axes.plot(1e3 * stc.times, label_ts[0][0, :], "k", label="bankssts-lh") axes.plot(1e3 * stc.times, label_ts[0][-1, :].T, "r", label="Brain-stem") axes.set(xlabel="Time (ms)", ylabel="MNE current (nAm)") axes.legend() -mne.viz.tight_layout() diff --git a/examples/inverse/source_space_snr.py b/examples/inverse/source_space_snr.py index 12d081f5c61..c7077d091e5 100644 --- a/examples/inverse/source_space_snr.py +++ b/examples/inverse/source_space_snr.py @@ -51,10 +51,9 @@ # Plot an average SNR across source points over time: ave = np.mean(snr_stc.data, axis=0) -fig, ax = plt.subplots() +fig, ax = plt.subplots(layout="constrained") ax.plot(evoked.times, ave) ax.set(xlabel="Time (s)", ylabel="SNR MEG-EEG") -fig.tight_layout() # Find time point of maximum SNR maxidx = np.argmax(ave) diff --git a/examples/preprocessing/eeg_bridging.py b/examples/preprocessing/eeg_bridging.py index d95ac709513..30cdde8502b 100644 --- a/examples/preprocessing/eeg_bridging.py +++ b/examples/preprocessing/eeg_bridging.py @@ -88,7 +88,7 @@ bridged_idx, ed_matrix = ed_data[6] -fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4)) +fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4), layout="constrained") fig.suptitle("Subject 6 Electrical Distance Matrix") # take median across epochs, only use upper triangular, lower is NaNs @@ -110,8 +110,6 @@ ax.set_xlabel("Channel Index") ax.set_ylabel("Channel Index") -fig.tight_layout() - # %% # Examine the Distribution of Electrical Distances # ------------------------------------------------ @@ -208,7 +206,7 @@ # reflect neural or at least anatomical differences as well (i.e. the # distance from the sensors to the brain). -fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4)) +fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4), layout="constrained") fig.suptitle("Electrical Distance Distribution for EEGBCI Subjects") for ax in (ax1, ax2): ax.set_ylabel("Count") @@ -229,7 +227,6 @@ ax1.axvspan(0, 30, color="r", alpha=0.5) ax2.legend(loc=(1.04, 0)) -fig.subplots_adjust(right=0.725, bottom=0.15, wspace=0.4) # %% # For the group of subjects, let's look at their electrical distances diff --git a/examples/preprocessing/eeg_csd.py b/examples/preprocessing/eeg_csd.py index dffe94e3f1e..892f856e75e 100644 --- a/examples/preprocessing/eeg_csd.py +++ b/examples/preprocessing/eeg_csd.py @@ -78,8 +78,7 @@ # CSD has parameters ``stiffness`` and ``lambda2`` affecting smoothing and # spline flexibility, respectively. Let's see how they affect the solution: -fig, ax = plt.subplots(4, 4) -fig.subplots_adjust(hspace=0.5) +fig, ax = plt.subplots(4, 4, layout="constrained") fig.set_size_inches(10, 10) for i, lambda2 in enumerate([0, 1e-7, 1e-5, 1e-3]): for j, m in enumerate([5, 4, 3, 2]): diff --git a/examples/preprocessing/eog_artifact_histogram.py b/examples/preprocessing/eog_artifact_histogram.py index 5aa209228d7..2d51370b571 100644 --- a/examples/preprocessing/eog_artifact_histogram.py +++ b/examples/preprocessing/eog_artifact_histogram.py @@ -50,7 +50,6 @@ # %% # Plot EOG artifact distribution -fig, ax = plt.subplots() +fig, ax = plt.subplots(layout="constrained") ax.stem(1e3 * epochs.times, data) ax.set(xlabel="Times (ms)", ylabel="Blink counts (from %s trials)" % len(epochs)) -fig.tight_layout() diff --git a/examples/preprocessing/eog_regression.py b/examples/preprocessing/eog_regression.py index 6c88cb01d9a..2123974dde4 100644 --- a/examples/preprocessing/eog_regression.py +++ b/examples/preprocessing/eog_regression.py @@ -69,10 +69,9 @@ epochs_after = mne.Epochs(raw_clean, events, event_id, tmin, tmax, baseline=(tmin, 0)) evoked_after = epochs_after.average() -fig, ax = plt.subplots(nrows=3, ncols=2, figsize=(10, 7), sharex=True, sharey="row") +fig, ax = plt.subplots( + nrows=3, ncols=2, figsize=(10, 7), sharex=True, sharey="row", layout="constrained" +) evoked_before.plot(axes=ax[:, 0], spatial_colors=True) evoked_after.plot(axes=ax[:, 1], spatial_colors=True) -fig.subplots_adjust( - top=0.905, bottom=0.09, left=0.08, right=0.975, hspace=0.325, wspace=0.145 -) fig.suptitle("Before --> After") diff --git a/examples/preprocessing/shift_evoked.py b/examples/preprocessing/shift_evoked.py index 3cd70715ac8..c16becc679c 100644 --- a/examples/preprocessing/shift_evoked.py +++ b/examples/preprocessing/shift_evoked.py @@ -14,7 +14,6 @@ import matplotlib.pyplot as plt import mne -from mne.viz import tight_layout from mne.datasets import sample print(__doc__) @@ -60,5 +59,3 @@ titles=dict(grad="Absolute shift: 500 ms"), time_unit="s", ) - -tight_layout() diff --git a/examples/simulation/plot_stc_metrics.py b/examples/simulation/plot_stc_metrics.py index 750dcab0c21..105c66d7e12 100644 --- a/examples/simulation/plot_stc_metrics.py +++ b/examples/simulation/plot_stc_metrics.py @@ -234,7 +234,7 @@ ] # Plot the results -f, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, sharex="col", constrained_layout=True) +f, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, sharex="col", layout="constrained") for ax, (title, results) in zip([ax1, ax2, ax3, ax4], region_results.items()): ax.plot(thresholds, results, ".-") ax.set(title=title, ylabel="score", xlabel="Threshold", xticks=thresholds) @@ -243,7 +243,7 @@ ax1.ticklabel_format(axis="y", style="sci", scilimits=(0, 1)) # tweak RLE # Cosine score with respect to time -f, ax1 = plt.subplots(constrained_layout=True) +f, ax1 = plt.subplots(layout="constrained") ax1.plot(stc_true_region.times, cosine_score(stc_true_region, stc_est_region)) ax1.set(title="Cosine score", xlabel="Time", ylabel="Score") @@ -277,6 +277,6 @@ # Plot the results for name, results in dipole_results.items(): - f, ax1 = plt.subplots(constrained_layout=True) + f, ax1 = plt.subplots(layout="constrained") ax1.plot(thresholds, 100 * np.array(results), ".-") ax1.set(title=name, ylabel="Error (cm)", xlabel="Threshold", xticks=thresholds) diff --git a/examples/time_frequency/source_label_time_frequency.py b/examples/time_frequency/source_label_time_frequency.py index da3af06e4dc..2e7cc4d3592 100644 --- a/examples/time_frequency/source_label_time_frequency.py +++ b/examples/time_frequency/source_label_time_frequency.py @@ -76,8 +76,7 @@ # subtract the evoked response in order to exclude evoked activity epochs_induced = epochs.copy().subtract_evoked() -plt.close("all") - +fig, axes = plt.subplots(2, 2, layout="constrained") for ii, (this_epochs, title) in enumerate( zip([epochs, epochs_induced], ["evoked + induced", "induced only"]) ): @@ -99,9 +98,8 @@ ########################################################################## # View time-frequency plots - plt.subplots_adjust(0.1, 0.08, 0.96, 0.94, 0.2, 0.43) - plt.subplot(2, 2, 2 * ii + 1) - plt.imshow( + ax = axes[ii, 0] + ax.imshow( 20 * power, extent=[times[0], times[-1], freqs[0], freqs[-1]], aspect="auto", @@ -110,13 +108,10 @@ vmax=30.0, cmap="RdBu_r", ) - plt.xlabel("Time (s)") - plt.ylabel("Frequency (Hz)") - plt.title("Power (%s)" % title) - plt.colorbar() + ax.set(xlabel="Time (s)", ylabel="Frequency (Hz)", title=f"Power ({title})") - plt.subplot(2, 2, 2 * ii + 2) - plt.imshow( + ax = axes[ii, 1] + ax.imshow( itc, extent=[times[0], times[-1], freqs[0], freqs[-1]], aspect="auto", @@ -125,9 +120,5 @@ vmax=0.7, cmap="RdBu_r", ) - plt.xlabel("Time (s)") - plt.ylabel("Frequency (Hz)") - plt.title("ITC (%s)" % title) - plt.colorbar() - -plt.show() + ax.set(xlabel="Time (s)", ylabel="Frequency (Hz)", title=f"ITC ({title})") + fig.colorbar(ax.images[0], ax=axes[ii]) diff --git a/examples/time_frequency/source_power_spectrum_opm.py b/examples/time_frequency/source_power_spectrum_opm.py index 14fcfa7039f..ce2ad03f607 100644 --- a/examples/time_frequency/source_power_spectrum_opm.py +++ b/examples/time_frequency/source_power_spectrum_opm.py @@ -84,7 +84,6 @@ .plot(picks="data", exclude="bads") ) fig.suptitle(titles[kind]) - fig.subplots_adjust(0.1, 0.1, 0.95, 0.85) ############################################################################## # Alignment and forward diff --git a/examples/time_frequency/time_frequency_simulated.py b/examples/time_frequency/time_frequency_simulated.py index 46747b6ae69..c6f00a9da32 100644 --- a/examples/time_frequency/time_frequency_simulated.py +++ b/examples/time_frequency/time_frequency_simulated.py @@ -100,7 +100,7 @@ freqs = np.arange(5.0, 100.0, 3.0) vmin, vmax = -3.0, 3.0 # Define our color limits. -fig, axs = plt.subplots(1, 3, figsize=(15, 5), sharey=True) +fig, axs = plt.subplots(1, 3, figsize=(15, 5), sharey=True, layout="constrained") for n_cycles, time_bandwidth, ax, title in zip( [freqs / 2, freqs, freqs / 2], # number of cycles [2.0, 4.0, 8.0], # time bandwidth @@ -130,7 +130,6 @@ show=False, colorbar=False, ) -plt.tight_layout() ############################################################################## # Stockwell (S) transform @@ -143,7 +142,7 @@ # we control the spectral / temporal resolution by specifying different widths # of the gaussian window using the ``width`` parameter. -fig, axs = plt.subplots(1, 3, figsize=(15, 5), sharey=True) +fig, axs = plt.subplots(1, 3, figsize=(15, 5), sharey=True, layout="constrained") fmin, fmax = freqs[[0, -1]] for width, ax in zip((0.2, 0.7, 3.0), axs): power = tfr_stockwell(epochs, fmin=fmin, fmax=fmax, width=width) @@ -151,7 +150,6 @@ [0], baseline=(0.0, 0.1), mode="mean", axes=ax, show=False, colorbar=False ) ax.set_title("Sim: Using S transform, width = {:0.1f}".format(width)) -plt.tight_layout() # %% # Morlet Wavelets @@ -162,7 +160,7 @@ # temporal resolution with the ``n_cycles`` parameter, which defines the # number of cycles to include in the window. -fig, axs = plt.subplots(1, 3, figsize=(15, 5), sharey=True) +fig, axs = plt.subplots(1, 3, figsize=(15, 5), sharey=True, layout="constrained") all_n_cycles = [1, 3, freqs / 2.0] for n_cycles, ax in zip(all_n_cycles, axs): power = tfr_morlet(epochs, freqs=freqs, n_cycles=n_cycles, return_itc=False) @@ -178,7 +176,6 @@ ) n_cycles = "scaled by freqs" if not isinstance(n_cycles, int) else n_cycles ax.set_title(f"Sim: Using Morlet wavelet, n_cycles = {n_cycles}") -plt.tight_layout() # %% # Narrow-bandpass Filter and Hilbert Transform @@ -189,7 +186,7 @@ # important so that you isolate only one oscillation of interest, generally # the width of this filter is recommended to be about 2 Hz. -fig, axs = plt.subplots(1, 3, figsize=(15, 5), sharey=True) +fig, axs = plt.subplots(1, 3, figsize=(15, 5), sharey=True, layout="constrained") bandwidths = [1.0, 2.0, 4.0] for bandwidth, ax in zip(bandwidths, axs): data = np.zeros((len(ch_names), freqs.size, epochs.times.size), dtype=complex) @@ -233,7 +230,6 @@ f"bandwidth = {bandwidth}, " f"transition bandwidth = {4 * bandwidth}" ) -plt.tight_layout() # %% # Calculating a TFR without averaging over epochs @@ -277,12 +273,9 @@ ) # Baseline the output rescale(power, epochs.times, (0.0, 0.1), mode="mean", copy=False) -fig, ax = plt.subplots() +fig, ax = plt.subplots(layout="constrained") x, y = centers_to_edges(epochs.times * 1000, freqs) mesh = ax.pcolormesh(x, y, power[0], cmap="RdBu_r", vmin=vmin, vmax=vmax) ax.set_title("TFR calculated on a numpy array") ax.set(ylim=freqs[[0, -1]], xlabel="Time (ms)") fig.colorbar(mesh) -plt.tight_layout() - -plt.show() diff --git a/examples/visualization/3d_to_2d.py b/examples/visualization/3d_to_2d.py index 590cc9df639..966e97f76ac 100644 --- a/examples/visualization/3d_to_2d.py +++ b/examples/visualization/3d_to_2d.py @@ -129,8 +129,7 @@ lt = mne.channels.read_layout(layout_path / layout_name, scale=False) x = lt.pos[:, 0] * float(im.shape[1]) y = (1 - lt.pos[:, 1]) * float(im.shape[0]) # Flip the y-position -fig, ax = plt.subplots() +fig, ax = plt.subplots(layout="constrained") ax.imshow(im) ax.scatter(x, y, s=80, color="r") -fig.tight_layout() ax.set_axis_off() diff --git a/examples/visualization/evoked_topomap.py b/examples/visualization/evoked_topomap.py index 1497e91bda8..dfd6be7f0f3 100644 --- a/examples/visualization/evoked_topomap.py +++ b/examples/visualization/evoked_topomap.py @@ -94,7 +94,7 @@ # and ``'head'`` otherwise. Here we show each option: extrapolations = ["local", "head", "box"] -fig, axes = plt.subplots(figsize=(7.5, 4.5), nrows=2, ncols=3) +fig, axes = plt.subplots(figsize=(7.5, 4.5), nrows=2, ncols=3, layout="constrained") # Here we look at EEG channels, and use a custom head sphere to get all the # sensors to be well within the drawn head surface @@ -111,7 +111,6 @@ sphere=(0.0, 0.0, 0.0, 0.09), ) ax.set_title("%s %s" % (ch_type.upper(), extr), fontsize=14) -fig.tight_layout() # %% # More advanced usage @@ -123,7 +122,6 @@ fig = evoked.plot_topomap( 0.1, ch_type="mag", show_names=True, colorbar=False, size=6, res=128 ) -fig.subplots_adjust(left=0.01, right=0.99, bottom=0.01, top=0.88) fig.suptitle("Auditory response") # %% diff --git a/mne/conftest.py b/mne/conftest.py index c1e6b36a93b..a0eeaf18dfb 100644 --- a/mne/conftest.py +++ b/mne/conftest.py @@ -33,7 +33,6 @@ Bunch, _check_qt_version, _TempDir, - check_version, ) # data from sample dataset @@ -84,6 +83,7 @@ def pytest_configure(config): "slowtest", "ultraslowtest", "pgtest", + "pvtest", "allow_unclosed", "allow_unclosed_pyside2", ): @@ -104,6 +104,13 @@ def pytest_configure(config): if os.getenv("PYTEST_QT_API") is None and os.getenv("QT_API") is not None: os.environ["PYTEST_QT_API"] = os.environ["QT_API"] + # suppress: + # Debugger warning: It seems that frozen modules are being used, which may + # make the debugger miss breakpoints. Please pass -Xfrozen_modules=off + # to python to disable frozen modules. + if os.getenv("PYDEVD_DISABLE_FILE_VALIDATION") is None: + os.environ["PYDEVD_DISABLE_FILE_VALIDATION"] = "1" + # https://numba.readthedocs.io/en/latest/reference/deprecation.html#deprecation-of-old-style-numba-captured-errors # noqa: E501 if "NUMBA_CAPTURED_ERRORS" not in os.environ: os.environ["NUMBA_CAPTURED_ERRORS"] = "new_style" @@ -514,8 +521,9 @@ def pg_backend(request, garbage_collect): import mne_qt_browser mne_qt_browser._browser_instances.clear() - if check_version("mne_qt_browser", min_version="0.4"): - _assert_no_instances(MNEQtBrowser, f"Closure of {request.node.name}") + if not _test_passed(request): + return + _assert_no_instances(MNEQtBrowser, f"Closure of {request.node.name}") @pytest.fixture( @@ -541,35 +549,35 @@ def browser_backend(request, garbage_collect, monkeypatch): mne_qt_browser._browser_instances.clear() -@pytest.fixture(params=["pyvistaqt"]) +@pytest.fixture(params=[pytest.param("pyvistaqt", marks=pytest.mark.pvtest)]) def renderer(request, options_3d, garbage_collect): """Yield the 3D backends.""" with _use_backend(request.param, interactive=False) as renderer: yield renderer -@pytest.fixture(params=["pyvistaqt"]) +@pytest.fixture(params=[pytest.param("pyvistaqt", marks=pytest.mark.pvtest)]) def renderer_pyvistaqt(request, options_3d, garbage_collect): """Yield the PyVista backend.""" with _use_backend(request.param, interactive=False) as renderer: yield renderer -@pytest.fixture(params=["notebook"]) +@pytest.fixture(params=[pytest.param("notebook", marks=pytest.mark.pvtest)]) def renderer_notebook(request, options_3d): """Yield the 3D notebook renderer.""" with _use_backend(request.param, interactive=False) as renderer: yield renderer -@pytest.fixture(params=["pyvistaqt"]) +@pytest.fixture(params=[pytest.param("pyvistaqt", marks=pytest.mark.pvtest)]) def renderer_interactive_pyvistaqt(request, options_3d, qt_windows_closed): """Yield the interactive PyVista backend.""" with _use_backend(request.param, interactive=True) as renderer: yield renderer -@pytest.fixture(params=["pyvistaqt"]) +@pytest.fixture(params=[pytest.param("pyvistaqt", marks=pytest.mark.pvtest)]) def renderer_interactive(request, options_3d): """Yield the interactive 3D backends.""" with _use_backend(request.param, interactive=True) as renderer: @@ -872,6 +880,14 @@ def protect_config(): yield +def _test_passed(request): + try: + outcome = request.node.harvest_rep_call + except Exception: + outcome = "passed" + return outcome == "passed" + + @pytest.fixture() def brain_gc(request): """Ensure that brain can be properly garbage collected.""" @@ -897,11 +913,7 @@ def brain_gc(request): yield close_func() # no need to warn if the test itself failed, pytest-harvest helps us here - try: - outcome = request.node.harvest_rep_call - except Exception: - outcome = "failed" - if outcome != "passed": + if not _test_passed(request): return _assert_no_instances(Brain, "after") # Check VTK diff --git a/mne/preprocessing/eyetracking/calibration.py b/mne/preprocessing/eyetracking/calibration.py index 962299f3a84..1891ebacb30 100644 --- a/mne/preprocessing/eyetracking/calibration.py +++ b/mne/preprocessing/eyetracking/calibration.py @@ -147,7 +147,7 @@ def plot(self, show_offsets=True, axes=None, show=True): ax = axes fig = ax.get_figure() else: # create new figure and axes - fig, ax = plt.subplots(constrained_layout=True) + fig, ax = plt.subplots(layout="constrained") px, py = self["positions"].T gaze_x, gaze_y = self["gaze"].T diff --git a/mne/preprocessing/ica.py b/mne/preprocessing/ica.py index 15c1d286d6e..fdb7d920267 100644 --- a/mne/preprocessing/ica.py +++ b/mne/preprocessing/ica.py @@ -3366,7 +3366,6 @@ def corrmap( template=True, sphere=sphere, ) - template_fig.subplots_adjust(top=0.8) template_fig.canvas.draw() # first run: use user-selected map diff --git a/mne/report/report.py b/mne/report/report.py index 89154d3de76..faf12a79bd6 100644 --- a/mne/report/report.py +++ b/mne/report/report.py @@ -78,7 +78,7 @@ ) from ..viz._brain.view import views_dicts from ..viz.misc import _plot_mri_contours, _get_bem_plotting_surfaces -from ..viz.utils import _ndarray_to_fig, tight_layout +from ..viz.utils import _ndarray_to_fig from ..viz._scraper import _mne_qt_browser_screenshot from ..forward import read_forward_solution, Forward from ..epochs import read_epochs, BaseEpochs @@ -431,11 +431,6 @@ def _fig_to_img(fig, *, image_format="png", own_figure=True): # matplotlib modifies the passed dict, which is a bug mpl_kwargs["pil_kwargs"] = pil_kwargs.copy() with warnings.catch_warnings(): - warnings.filterwarnings( - action="ignore", - message=".*Axes that are not compatible with tight_layout.*", - category=UserWarning, - ) fig.savefig(output, format=image_format, dpi=dpi, **mpl_kwargs) if own_figure: @@ -1648,7 +1643,6 @@ def _add_ica_overlay(self, *, ica, inst, image_format, section, tags, replace): fig = ica.plot_overlay(inst=inst_, show=False, on_baseline="reapply") del inst_ - tight_layout(fig=fig) _constrain_fig_resolution(fig, max_width=MAX_IMG_WIDTH, max_res=MAX_IMG_RES) self._add_figure( fig=fig, @@ -1770,9 +1764,6 @@ def _add_ica_components(self, *, ica, picks, image_format, section, tags, replac if not isinstance(figs, list): figs = [figs] - for fig in figs: - tight_layout(fig=fig) - title = "ICA component topographies" if len(figs) == 1: fig = figs[0] @@ -3241,7 +3232,6 @@ def _add_raw( init_kwargs.setdefault("fmax", fmax) plot_kwargs.setdefault("show", False) fig = raw.compute_psd(**init_kwargs).plot(**plot_kwargs) - tight_layout(fig=fig) _constrain_fig_resolution(fig, max_width=MAX_IMG_WIDTH, max_res=MAX_IMG_RES) self._add_figure( fig=fig, @@ -3323,7 +3313,6 @@ def _add_projs( # hard to see how (6, 4) could work in all number-of-projs by # number-of-channel-types conditions... fig.set_size_inches((6, 4)) - tight_layout(fig=fig) _constrain_fig_resolution(fig, max_width=MAX_IMG_WIDTH, max_res=MAX_IMG_RES) self._add_figure( fig=fig, @@ -3488,6 +3477,7 @@ def _plot_one_evoked_topomap_timepoint( len(ch_types) * 2, gridspec_kw={"width_ratios": [8, 0.5] * len(ch_types)}, figsize=(2.5 * len(ch_types), 2), + layout="constrained", ) _constrain_fig_resolution(fig, max_width=MAX_IMG_WIDTH, max_res=MAX_IMG_RES) ch_type_ax_map = dict( @@ -3508,8 +3498,6 @@ def _plot_one_evoked_topomap_timepoint( ) ch_type_ax_map[ch_type][0].set_title(ch_type) - tight_layout(fig=fig) - with BytesIO() as buff: fig.savefig(buff, format="png", pad_inches=0) plt.close(fig) @@ -3616,7 +3604,7 @@ def _add_evoked_gfp( import matplotlib.pyplot as plt - fig, ax = plt.subplots(len(ch_types), 1, sharex=True) + fig, ax = plt.subplots(len(ch_types), 1, sharex=True, layout="constrained") if len(ch_types) == 1: ax = [ax] for idx, ch_type in enumerate(ch_types): @@ -3636,7 +3624,6 @@ def _add_evoked_gfp( if idx < len(ch_types) - 1: ax[idx].set_xlabel(None) - tight_layout(fig=fig) _constrain_fig_resolution(fig, max_width=MAX_IMG_WIDTH, max_res=MAX_IMG_RES) title = "Global field power" self._add_figure( @@ -3655,7 +3642,6 @@ def _add_evoked_whitened( ): """Render whitened evoked.""" fig = evoked.plot_white(noise_cov=noise_cov, show=False) - tight_layout(fig=fig) _constrain_fig_resolution(fig, max_width=MAX_IMG_WIDTH, max_res=MAX_IMG_RES) title = "Whitened" @@ -4003,7 +3989,6 @@ def _add_epochs( fig = epochs.plot_drop_log( subject=self.subject, ignore=drop_log_ignore, show=False ) - tight_layout(fig=fig) _constrain_fig_resolution( fig, max_width=MAX_IMG_WIDTH, max_res=MAX_IMG_RES ) @@ -4179,18 +4164,17 @@ def _add_stc( if backend_is_3d: brain.set_time(t) - fig, ax = plt.subplots(figsize=(4.5, 4.5)) + fig, ax = plt.subplots(figsize=(4.5, 4.5), layout="constrained") ax.imshow(brain.screenshot(time_viewer=True, mode="rgb")) ax.axis("off") - tight_layout(fig=fig) _constrain_fig_resolution( fig, max_width=stc_plot_kwargs["size"][0], max_res=MAX_IMG_RES ) figs.append(fig) plt.close(fig) else: - fig_lh = plt.figure() - fig_rh = plt.figure() + fig_lh = plt.figure(layout="constrained") + fig_rh = plt.figure(layout="constrained") brain_lh = stc.plot( views="lat", @@ -4210,8 +4194,6 @@ def _add_stc( backend="matplotlib", figure=fig_rh, ) - tight_layout(fig=fig_lh) # TODO is this necessary? - tight_layout(fig=fig_rh) # TODO is this necessary? _constrain_fig_resolution( fig_lh, max_width=stc_plot_kwargs["size"][0], diff --git a/mne/time_frequency/spectrum.py b/mne/time_frequency/spectrum.py index 52ca167ee6c..1fc2c6ce2bd 100644 --- a/mne/time_frequency/spectrum.py +++ b/mne/time_frequency/spectrum.py @@ -742,7 +742,6 @@ def plot( sphere=sphere, xlabels_list=xlabels_list, ) - fig.subplots_adjust(hspace=0.3) plt_show(show, fig) return fig diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index 1a061b8b173..83445a64690 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -70,7 +70,6 @@ figure_nobar, plt_show, _setup_cmap, - _connection_line, _prepare_joint_axes, _setup_vmin_vmax, _set_title_multiple_electrodes, @@ -141,7 +140,7 @@ def morlet(sfreq, freqs, n_cycles=7.0, sigma=None, zero_mean=False): s = w * sfreq / (2 * freq * np.pi) # from SciPy docs wavelet_sp = sp_morlet(M, s, w) * np.sqrt(2) # match our normalization - _, ax = plt.subplots(constrained_layout=True) + _, ax = plt.subplots(layout="constrained") colors = { ('MNE', 'real'): '#66CCEE', ('SciPy', 'real'): '#4477AA', @@ -1732,7 +1731,7 @@ def _plot( elif isinstance(axes, plt.Axes): figs_and_axes = [(ax.get_figure(), ax) for ax in [axes]] elif axes is None: - figs = [plt.figure() for i in range(n_picks)] + figs = [plt.figure(layout="constrained") for i in range(n_picks)] figs_and_axes = [(fig, fig.add_subplot(111)) for fig in figs] else: raise ValueError("axes must be None, plt.Axes, or list " "of plt.Axes.") @@ -1921,7 +1920,7 @@ def plot_joint( .. versionadded:: 0.16.0 """ # noqa: E501 - import matplotlib.pyplot as plt + from matplotlib.patches import ConnectionPatch ##################################### # Handle channels (picks and types) # @@ -2007,7 +2006,7 @@ def plot_joint( # Image plot # ############## - fig, tf_ax, map_ax, cbar_ax = _prepare_joint_axes(n_timefreqs) + fig, tf_ax, map_ax = _prepare_joint_axes(n_timefreqs) cmap = _setup_cmap(cmap) @@ -2162,28 +2161,32 @@ def plot_joint( ############# # Finish up # ############# - if colorbar: from matplotlib import ticker - cbar = plt.colorbar(ax.images[0], cax=cbar_ax) + cbar = fig.colorbar(ax.images[0]) if locator is None: locator = ticker.MaxNLocator(nbins=5) cbar.locator = locator cbar.update_ticks() - plt.subplots_adjust( - left=0.12, right=0.925, bottom=0.14, top=1.0 if title is not None else 1.2 - ) - # draw the connection lines between time series and topoplots - lines = [ - _connection_line( - time_, fig, tf_ax, map_ax_, y=freq_, y_source_transform="transData" + for (time_, freq_), map_ax_ in zip(timefreqs_array, map_ax): + con = ConnectionPatch( + xyA=[time_, freq_], + xyB=[0.5, 0], + coordsA="data", + coordsB="axes fraction", + axesA=tf_ax, + axesB=map_ax_, + color="grey", + linestyle="-", + linewidth=1.5, + alpha=0.66, + zorder=1, + clip_on=False, ) - for (time_, freq_), map_ax_ in zip(timefreqs_array, map_ax) - ] - fig.lines.extend(lines) + fig.add_artist(con) plt_show(show) return fig @@ -2289,7 +2292,6 @@ def _onselect( axes=ax, ) ax.set_title(ch_type) - fig.tight_layout() @verbose def plot_topo( diff --git a/mne/viz/_3d.py b/mne/viz/_3d.py index ce99f2e6352..680d52022b5 100644 --- a/mne/viz/_3d.py +++ b/mne/viz/_3d.py @@ -88,7 +88,6 @@ _get_color_list, _get_cmap, plt_show, - tight_layout, figure_nobar, _check_time_unit, ) @@ -314,7 +313,9 @@ def plot_head_positions( from mpl_toolkits.mplot3d.art3d import Line3DCollection from mpl_toolkits.mplot3d import Axes3D # noqa: F401, analysis:ignore - fig, ax = plt.subplots(1, subplot_kw=dict(projection="3d")) + fig, ax = plt.subplots( + 1, subplot_kw=dict(projection="3d"), layout="constrained" + ) # First plot the trajectory as a colormap: # http://matplotlib.org/examples/pylab_examples/multicolored_line.html @@ -374,7 +375,6 @@ def plot_head_positions( ax.set(xlabel="x", ylabel="y", zlabel="z", xlim=xlim, ylim=ylim, zlim=zlim) _set_aspect_equal(ax) ax.view_init(30, 45) - tight_layout(fig=fig) plt_show(show) return fig @@ -1901,7 +1901,7 @@ def _key_pressed_slider(event, params): time_viewer.slider.set_val(this_time) -def _smooth_plot(this_time, params): +def _smooth_plot(this_time, params, *, draw=True): """Smooth source estimate data and plot with mpl.""" from ..morph import _hemi_morph @@ -1957,7 +1957,8 @@ def _smooth_plot(this_time, params): _set_aspect_equal(ax) ax.axis("off") ax.set(xlim=[-80, 80], ylim=(-80, 80), zlim=[-80, 80]) - ax.figure.canvas.draw() + if draw: + ax.figure.canvas.draw() def _plot_mpl_stc( @@ -2022,7 +2023,8 @@ def _plot_mpl_stc( del transparent, mapdata time_label, times = _handle_time(time_label, time_unit, stc.times) - fig = plt.figure(figsize=(6, 6)) if figure is None else figure + # don't use constrained layout because Axes3D does not play well with it + fig = plt.figure(figsize=(6, 6), layout=None) if figure is None else figure try: ax = Axes3D(fig, auto_add_to_figure=False) except Exception: # old mpl @@ -2072,7 +2074,7 @@ def _plot_mpl_stc( time_label=time_label, time_unit=time_unit, ) - _smooth_plot(initial_time, params) + _smooth_plot(initial_time, params, draw=False) ax.view_init(**kwargs[hemi][views]) @@ -2100,7 +2102,6 @@ def _plot_mpl_stc( callback_key = partial(_key_pressed_slider, params=params) time_viewer.canvas.mpl_connect("key_press_event", callback_key) - time_viewer.subplots_adjust(left=0.12, bottom=0.05, right=0.75, top=0.95) fig.subplots_adjust(left=0.0, bottom=0.0, right=1.0, top=1.0) # add colorbar @@ -2932,7 +2933,7 @@ def _onclick(event, params, verbose=None): del ijk # Plot initial figure - fig, (axes, ax_time) = plt.subplots(2) + fig, (axes, ax_time) = plt.subplots(2, layout="constrained") axes.set(xticks=[], yticks=[]) marker = "o" if len(stc.times) == 1 else None ydata = stc.data[loc_idx] @@ -2943,7 +2944,6 @@ def _onclick(event, params, verbose=None): vert_legend = ax_time.legend([h], [""], title="Vertex") _update_vertlabel(loc_idx) lx = ax_time.axvline(stc.times[time_idx], color="g") - fig.tight_layout() allow_pos_lims = mode != "glass_brain" mapdata = _process_clim(clim, colormap, transparent, stc.data, allow_pos_lims) @@ -3390,7 +3390,7 @@ def plot_sparse_source_estimates( ) # Show time courses - fig = plt.figure(fig_number) + fig = plt.figure(fig_number, layout="constrained") fig.clf() ax = fig.add_subplot(111) @@ -3757,7 +3757,9 @@ def _plot_dipole_mri_orthoview( dims = len(data) # Symmetric size assumed. dd = dims // 2 if ax is None: - fig, ax = plt.subplots(1, subplot_kw=dict(projection="3d")) + fig, ax = plt.subplots( + 1, subplot_kw=dict(projection="3d"), layout="constrained" + ) else: _validate_type(ax, Axes3D, "ax", "Axes3D", extra='when mode is "orthoview"') fig = ax.get_figure() diff --git a/mne/viz/__init__.pyi b/mne/viz/__init__.pyi index e73226b6909..b709ebc2a05 100644 --- a/mne/viz/__init__.pyi +++ b/mne/viz/__init__.pyi @@ -82,7 +82,6 @@ __all__ = [ "set_3d_view", "set_browser_backend", "snapshot_brain_montage", - "tight_layout", "ui_events", "use_3d_backend", "use_browser_backend", @@ -149,7 +148,6 @@ from .topomap import ( plot_regression_weights, ) from .utils import ( - tight_layout, mne_analyze_colormap, compare_fiff, ClickableImage, diff --git a/mne/viz/_dipole.py b/mne/viz/_dipole.py index 64ab5774ba4..24fc4735f3c 100644 --- a/mne/viz/_dipole.py +++ b/mne/viz/_dipole.py @@ -53,9 +53,7 @@ def _plot_dipole_mri_outlines( _validate_type(surf, (str, None), "surf") _check_option("surf", surf, ("white", "pial", None)) if ax is None: - _, ax = plt.subplots( - 1, 3, figsize=(7, 2.5), squeeze=True, constrained_layout=True - ) + _, ax = plt.subplots(1, 3, figsize=(7, 2.5), squeeze=True, layout="constrained") _validate_if_list_of_axes(ax, 3, name="ax") dipoles = _check_concat_dipoles(dipoles) color = "r" if color is None else color diff --git a/mne/viz/_figure.py b/mne/viz/_figure.py index 82359f585ed..738bf838ce3 100644 --- a/mne/viz/_figure.py +++ b/mne/viz/_figure.py @@ -535,7 +535,7 @@ def _create_epoch_image_fig(self, pick): title = f"Epochs image ({ch_name})" fig = self._new_child_figure(figsize=(6, 4), fig_name=None, window_title=title) fig.suptitle = title - gs = GridSpec(nrows=3, ncols=10) + gs = GridSpec(nrows=3, ncols=10, figure=fig) fig.add_subplot(gs[:2, :9]) fig.add_subplot(gs[2, :9]) fig.add_subplot(gs[:2, 9]) @@ -580,16 +580,6 @@ def _create_epoch_histogram(self): ax.plot((reject, reject), (0, ax.get_ylim()[1]), color="r") # finalize fig.suptitle(title, y=0.99) - if hasattr(fig, "_inch_to_rel"): - kwargs = dict( - bottom=fig._inch_to_rel(0.5, horiz=False), - top=1 - fig._inch_to_rel(0.5, horiz=False), - left=fig._inch_to_rel(0.75), - right=1 - fig._inch_to_rel(0.25), - ) - else: - kwargs = dict() - fig.subplots_adjust(hspace=0.7, **kwargs) self.mne.fig_histogram = fig return fig diff --git a/mne/viz/_mpl_figure.py b/mne/viz/_mpl_figure.py index 2974df90958..c313bfe1edf 100644 --- a/mne/viz/_mpl_figure.py +++ b/mne/viz/_mpl_figure.py @@ -118,7 +118,7 @@ def __init__(self, **kwargs): for key in [k for k in kwargs if not hasattr(self.mne, k)]: setattr(self.mne, key, kwargs[key]) - def _close(self, event): + def _close(self, event=None): """Handle close events.""" logger.debug(f"Closing {self!r}") # remove references from parent fig to child fig @@ -886,9 +886,15 @@ def _create_ch_context_fig(self, idx): fig = super()._create_ch_context_fig(idx) plt_show(fig=fig) - def _new_child_figure(self, fig_name, **kwargs): + def _new_child_figure(self, fig_name, *, layout=None, **kwargs): """Instantiate a new MNE dialog figure (with event listeners).""" - fig = _figure(toolbar=False, parent_fig=self, fig_name=fig_name, **kwargs) + fig = _figure( + toolbar=False, + parent_fig=self, + fig_name=fig_name, + layout=layout, + **kwargs, + ) fig._add_default_callbacks() self.mne.child_figs.append(fig) if isinstance(fig_name, str): @@ -2324,8 +2330,8 @@ def _get_scale_bar_texts(self): class MNELineFigure(MNEFigure): """Interactive figure for non-scrolling line plots.""" - def __init__(self, inst, n_axes, figsize, **kwargs): - super().__init__(figsize=figsize, inst=inst, **kwargs) + def __init__(self, inst, n_axes, figsize, *, layout=None, **kwargs): + super().__init__(figsize=figsize, inst=inst, layout=layout, **kwargs) # AXES: default margins (inches) l_margin = 0.8 @@ -2372,6 +2378,8 @@ def _figure(toolbar=True, FigureClass=MNEFigure, **kwargs): from matplotlib import rc_context title = kwargs.pop("window_title", None) # extract title before init + if "layout" not in kwargs: + kwargs["layout"] = "constrained" rc = dict() if toolbar else dict(toolbar="none") with rc_context(rc=rc): fig = plt.figure(FigureClass=FigureClass, **kwargs) @@ -2379,6 +2387,14 @@ def _figure(toolbar=True, FigureClass=MNEFigure, **kwargs): fig.mne.backend = BACKEND if title is not None: _set_window_title(fig, title) + # TODO: for some reason for topomaps->_prepare_trellis the layout=constrained does + # not work the first time (maybe toolbar=False?) + if kwargs.get("layout") == "constrained": + if hasattr(fig, "set_layout_engine"): # 3.6+ + fig.set_layout_engine("constrained") + else: + fig.set_constrained_layout(True) + # add event callbacks fig._add_default_callbacks() return fig @@ -2409,6 +2425,7 @@ def _line_figure(inst, axes=None, picks=None, **kwargs): FigureClass=MNELineFigure, figsize=figsize, n_axes=n_axes, + layout=None, **kwargs, ) fig.mne.fig_size_px = fig._get_size_px() # can't do in __init__ @@ -2483,7 +2500,7 @@ def _init_browser(**kwargs): """Instantiate a new MNE browse-style figure.""" from mne.io import BaseRaw - fig = _figure(toolbar=False, FigureClass=MNEBrowseFigure, **kwargs) + fig = _figure(toolbar=False, FigureClass=MNEBrowseFigure, layout=None, **kwargs) # splash is ignored (maybe we could do it for mpl if we get_backend() and # check if it's Qt... but seems overkill) diff --git a/mne/viz/_proj.py b/mne/viz/_proj.py index 5a40df7dc03..0f8f02a3089 100644 --- a/mne/viz/_proj.py +++ b/mne/viz/_proj.py @@ -102,7 +102,7 @@ def plot_projs_joint( n_row = len(ch_types) shape = (n_row, n_col) fig = plt.figure( - figsize=(n_col * 1.1 + 0.5, n_row * 1.8 + 0.5), constrained_layout=True + figsize=(n_col * 1.1 + 0.5, n_row * 1.8 + 0.5), layout="constrained" ) ri = 0 # pick some sufficiently distinct colors (6 per proj type, e.g., ECG, diff --git a/mne/viz/backends/_abstract.py b/mne/viz/backends/_abstract.py index c2c3e08eb2b..e924e7deae9 100644 --- a/mne/viz/backends/_abstract.py +++ b/mne/viz/backends/_abstract.py @@ -7,10 +7,8 @@ # License: Simplified BSD from abc import ABC, abstractmethod, abstractclassmethod -from contextlib import nullcontext import warnings -from ..utils import tight_layout from ..ui_events import publish, TimeChange @@ -1333,19 +1331,10 @@ def _mpl_initialize(): class _AbstractMplCanvas(ABC): def __init__(self, width, height, dpi): """Initialize the MplCanvas.""" - from matplotlib import rc_context from matplotlib.figure import Figure - # prefer constrained layout here but live with tight_layout otherwise - context = nullcontext self._extra_events = ("resize",) - try: - context = rc_context({"figure.constrained_layout.use": True}) - self._extra_events = () - except KeyError: - pass - with context: - self.fig = Figure(figsize=(width, height), dpi=dpi) + self.fig = Figure(figsize=(width, height), dpi=dpi, layout="constrained") self.axes = self.fig.add_subplot(111) self.axes.set(xlabel="Time (s)", ylabel="Activation (AU)") self.manager = None @@ -1408,7 +1397,7 @@ def clear(self): def on_resize(self, event): """Handle resize events.""" - tight_layout(fig=self.axes.figure) + pass class _AbstractBrainMplCanvas(_AbstractMplCanvas): diff --git a/mne/viz/backends/tests/test_utils.py b/mne/viz/backends/tests/test_utils.py index 3bec2aafcc9..cfa0c65535f 100644 --- a/mne/viz/backends/tests/test_utils.py +++ b/mne/viz/backends/tests/test_utils.py @@ -7,6 +7,7 @@ from colorsys import rgb_to_hls from contextlib import nullcontext +import platform import numpy as np import pytest @@ -79,6 +80,8 @@ def test_theme_colors(pg_backend, theme, monkeypatch, tmp_path): return # we could add a ton of conditionals below, but KISS is_dark = _qt_is_dark(fig) # on Darwin these checks get complicated, so don't bother for now + if platform.system() == "Darwin": + pytest.skip("Problems on macOS") if theme == "dark": assert is_dark, theme elif theme == "light": diff --git a/mne/viz/circle.py b/mne/viz/circle.py index af160141741..983eef69c5c 100644 --- a/mne/viz/circle.py +++ b/mne/viz/circle.py @@ -212,7 +212,7 @@ def _plot_connectivity_circle( # Use a polar axes if ax is None: - fig = plt.figure(figsize=(8, 8), facecolor=facecolor) + fig = plt.figure(figsize=(8, 8), facecolor=facecolor, layout="constrained") ax = fig.add_subplot(polar=True) else: fig = ax.figure diff --git a/mne/viz/epochs.py b/mne/viz/epochs.py index d173c80a45b..7bd1785ada9 100644 --- a/mne/viz/epochs.py +++ b/mne/viz/epochs.py @@ -13,7 +13,6 @@ from collections import Counter from copy import deepcopy -import warnings import numpy as np from scipy.ndimage import gaussian_filter1d @@ -31,7 +30,6 @@ _VALID_CHANNEL_TYPES, ) from .utils import ( - tight_layout, _setup_vmin_vmax, plt_show, _check_cov, @@ -453,7 +451,7 @@ def _validate_fig_and_axes(fig, axes, group_by, evoked, colorbar, clear=False): rowspan = 2 if evoked else 3 shape = (3, 10) for this_group in group_by: - this_fig = figure() + this_fig = figure(layout="constrained") _set_window_title(this_fig, this_group) subplot2grid(shape, (0, 0), colspan=colspan, rowspan=rowspan, fig=this_fig) if evoked: @@ -602,8 +600,6 @@ def _plot_epochs_image( tmax = epochs.times[-1] ax_im = ax["image"] - fig = ax_im.get_figure() - # draw the image cmap = _setup_cmap(cmap, norm=norm) n_epochs = len(image) @@ -664,13 +660,10 @@ def _plot_epochs_image( ax_im.CB = DraggableColorbar( this_colorbar, im, kind="epochs_image", ch_type=unit ) - with warnings.catch_warnings(record=True): - warnings.simplefilter("ignore") - tight_layout(fig=fig) # finish plt_show(show) - return fig + return ax_im.get_figure() def plot_drop_log( @@ -733,7 +726,7 @@ def plot_drop_log( ch_names = np.array(list(scores.keys())) counts = np.array(list(scores.values())) # init figure, handle easy case (no drops) - fig, ax = plt.subplots() + fig, ax = plt.subplots(layout="constrained") title = f"{absolute} of {n_epochs_before_drop} epochs removed " f"({percent:.1f}%)" if subject is not None: title = f"{subject}: {title}" @@ -755,7 +748,6 @@ def plot_drop_log( ) ax.set_ylabel("% of epochs removed") ax.grid(axis="y") - tight_layout(pad=1, fig=fig) plt_show(show) return fig diff --git a/mne/viz/evoked.py b/mne/viz/evoked.py index 687203cad49..5886bb26db3 100644 --- a/mne/viz/evoked.py +++ b/mne/viz/evoked.py @@ -30,7 +30,6 @@ from ..defaults import _handle_default from .utils import ( _draw_proj_checkbox, - tight_layout, _check_delayed_ssp, plt_show, _process_times, @@ -41,7 +40,6 @@ _make_combine_callable, _validate_if_list_of_axes, _triage_rank_sss, - _connection_line, _get_color_list, _setup_ax_spines, _setup_plot_projector, @@ -165,7 +163,11 @@ def _line_plot_onselect( minidx = np.abs(times - xmin).argmin() maxidx = np.abs(times - xmax).argmin() fig, axarr = plt.subplots( - 1, len(ch_types), squeeze=False, figsize=(3 * len(ch_types), 3) + 1, + len(ch_types), + squeeze=False, + figsize=(3 * len(ch_types), 3), + layout="constrained", ) for idx, ch_type in enumerate(ch_types): @@ -211,7 +213,6 @@ def _line_plot_onselect( unit = "Hz" if psd else time_unit fig.suptitle("Average over %.2f%s - %.2f%s" % (xmin, unit, xmax, unit), y=0.1) - tight_layout(pad=2.0, fig=fig) plt_show() if text is not None: text.set_visible(False) @@ -332,7 +333,7 @@ def _plot_evoked( if axes is None: axes = dict() for sel in group_by: - plt.figure() + plt.figure(layout="constrained") axes[sel] = plt.axes() if not isinstance(axes, dict): raise ValueError( @@ -458,8 +459,7 @@ def _plot_evoked( fig = None if axes is None: - fig, axes = plt.subplots(len(ch_types_used), 1) - fig.subplots_adjust(left=0.125, bottom=0.1, right=0.975, top=0.92, hspace=0.63) + fig, axes = plt.subplots(len(ch_types_used), 1, layout="constrained") if isinstance(axes, plt.Axes): axes = [axes] fig.set_size_inches(6.4, 2 + len(axes)) @@ -738,6 +738,7 @@ def _plot_lines( else: y_offset = this_ylim[0] this_gfp += y_offset + ax.autoscale(False) ax.fill_between( times, y_offset, @@ -1628,7 +1629,7 @@ def whitened_gfp(x, rank=None): sharex=True, sharey=False, figsize=(8.8, 2.2 * n_rows), - constrained_layout=True, + layout="constrained", ) else: axes = np.array(axes) @@ -1772,7 +1773,7 @@ def plot_snr_estimate(evoked, inv, show=True, axes=None, verbose=None): snr, snr_est = estimate_snr(evoked, inv) _validate_type(axes, (None, plt.Axes)) if axes is None: - _, ax = plt.subplots(1, 1) + _, ax = plt.subplots(1, 1, layout="constrained") else: ax = axes del axes @@ -1858,7 +1859,7 @@ def plot_evoked_joint( ----- .. versionadded:: 0.12.0 """ - import matplotlib.pyplot as plt + from matplotlib.patches import ConnectionPatch if ts_args is not None and not isinstance(ts_args, dict): raise TypeError("ts_args must be dict or None, got type %s" % (type(ts_args),)) @@ -1955,9 +1956,8 @@ def plot_evoked_joint( # prepare axes for topomap if not got_axes: - fig, ts_ax, map_ax, cbar_ax = _prepare_joint_axes( - len(times_sec), figsize=(8.0, 4.2) - ) + fig, ts_ax, map_ax = _prepare_joint_axes(len(times_sec), figsize=(8.0, 4.2)) + cbar_ax = None else: ts_ax = ts_args["axes"] del ts_args["axes"] @@ -1995,20 +1995,10 @@ def plot_evoked_joint( old_title = ts_ax.get_title() ts_ax.set_title("") - # XXX BUG destroys ax -> fig assignment if title & axes are passed if title is not None: - title_ax = fig.add_subplot(4, 3, 2) if title == "": title = old_title - title_ax.text( - 0.5, - 0.5, - title, - transform=title_ax.transAxes, - horizontalalignment="center", - verticalalignment="center", - ) - title_ax.axis("off") + fig.suptitle(title) # topomap contours = topomap_args.get("contours", 6) @@ -2034,8 +2024,8 @@ def plot_evoked_joint( if topomap_args.get("colorbar", True): from matplotlib import ticker - cbar_ax.grid(False) # auto-removal deprecated as of 2021/10/05 - cbar = plt.colorbar(map_ax[0].images[0], cax=cbar_ax) + cbar = fig.colorbar(map_ax[0].images[0], ax=map_ax, cax=cbar_ax) + cbar.ax.grid(False) # auto-removal deprecated as of 2021/10/05 if isinstance(contours, (list, np.ndarray)): cbar.set_ticks(contours) else: @@ -2044,19 +2034,24 @@ def plot_evoked_joint( cbar.locator = locator cbar.update_ticks() - if not got_axes: - plt.subplots_adjust( - left=0.1, right=0.93, bottom=0.14, top=1.0 if title is not None else 1.2 - ) - # connection lines # draw the connection lines between time series and topoplots - lines = [ - _connection_line(timepoint, fig, ts_ax, map_ax_) - for timepoint, map_ax_ in zip(times_ts, map_ax) - ] - for line in lines: - fig.lines.append(line) + for timepoint, map_ax_ in zip(times_ts, map_ax): + con = ConnectionPatch( + xyA=[timepoint, ts_ax.get_ylim()[1]], + xyB=[0.5, 0], + coordsA="data", + coordsB="axes fraction", + axesA=ts_ax, + axesB=map_ax_, + color="grey", + linestyle="-", + linewidth=1.5, + alpha=0.66, + zorder=1, + clip_on=False, + ) + fig.add_artist(con) # mark times in time series plot for timepoint in times_ts: @@ -2941,7 +2936,9 @@ def plot_compare_evokeds( axes = ["topo"] * len(ch_types) else: if axes is None: - axes = (plt.subplots(figsize=(8, 6))[1] for _ in ch_types) + axes = ( + plt.subplots(figsize=(8, 6), layout="constrained")[1] for _ in ch_types + ) elif isinstance(axes, plt.Axes): axes = [axes] _validate_if_list_of_axes(axes, obligatory_len=len(ch_types)) @@ -3015,7 +3012,7 @@ def plot_compare_evokeds( from .topo import iter_topography from ..channels.layout import find_layout - fig = plt.figure(figsize=(18, 14)) + fig = plt.figure(figsize=(18, 14), layout=None) # Not "constrained" for topo def click_func( ax_, diff --git a/mne/viz/ica.py b/mne/viz/ica.py index a414775b635..d80ed9aec65 100644 --- a/mne/viz/ica.py +++ b/mne/viz/ica.py @@ -14,7 +14,6 @@ from scipy.stats import gaussian_kde from .utils import ( - tight_layout, _make_event_color_dict, _get_cmap, plt_show, @@ -767,7 +766,7 @@ def _plot_ica_sources_evoked(evoked, picks, exclude, title, show, ica, labels=No if title is None: title = "Reconstructed latent sources, time-locked" - fig, axes = plt.subplots(1) + fig, axes = plt.subplots(1, layout="constrained") ax = axes axes = [axes] times = evoked.times * 1e3 @@ -852,7 +851,6 @@ def _plot_ica_sources_evoked(evoked, picks, exclude, title, show, ica, labels=No ax.set(title=title, xlim=times[[0, -1]], xlabel="Time (ms)", ylabel="(NA)") if len(exclude) > 0: plt.legend(loc="best") - tight_layout(fig=fig) texts.append( ax.text( @@ -959,7 +957,9 @@ def plot_ica_scores( if figsize is None: figsize = (6.4 * n_cols, 2.7 * n_rows) - fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, sharex=True, sharey=True) + fig, axes = plt.subplots( + n_rows, n_cols, figsize=figsize, sharex=True, sharey=True, layout="constrained" + ) if isinstance(axes, np.ndarray): axes = axes.flatten() @@ -1012,11 +1012,6 @@ def plot_ica_scores( ax.set_title("(%s)" % label) ax.set_xlabel("ICA components") ax.set_xlim(-0.6, len(this_scores) - 0.4) - - tight_layout(fig=fig) - - adjust_top = 0.8 if len(fig.axes) == 1 else 0.9 - fig.subplots_adjust(top=adjust_top) fig.canvas.draw() plt_show(show) return fig @@ -1159,13 +1154,13 @@ def _plot_ica_overlay_raw(*, raw, raw_cln, picks, start, stop, title, show): ch_types = raw.get_channel_types(picks=picks, unique=True) for ch_type in ch_types: if ch_type in ("mag", "grad"): - fig, ax = plt.subplots(3, 1, sharex=True, constrained_layout=True) + fig, ax = plt.subplots(3, 1, sharex=True, layout="constrained") elif ch_type == "eeg" and not _has_eeg_average_ref_proj( raw.info, check_active=True ): - fig, ax = plt.subplots(3, 1, sharex=True, constrained_layout=True) + fig, ax = plt.subplots(3, 1, sharex=True, layout="constrained") else: - fig, ax = plt.subplots(2, 1, sharex=True, constrained_layout=True) + fig, ax = plt.subplots(2, 1, sharex=True, layout="constrained") fig.suptitle(title) # select sensors and retrieve data array @@ -1236,7 +1231,7 @@ def _plot_ica_overlay_evoked(evoked, evoked_cln, title, show): if len(ch_types_used) != len(ch_types_used_cln): raise ValueError("Raw and clean evokeds must match. Found different channels.") - fig, axes = plt.subplots(n_rows, 1) + fig, axes = plt.subplots(n_rows, 1, layout="constrained") if title is None: title = "Average signal before (red) and after (black) ICA" fig.suptitle(title) @@ -1248,9 +1243,6 @@ def _plot_ica_overlay_evoked(evoked, evoked_cln, title, show): line.set_color("r") fig.canvas.draw() evoked_cln.plot(axes=axes, show=False, time_unit="s", spatial_colors=False) - tight_layout(fig=fig) - - fig.subplots_adjust(top=0.90) fig.canvas.draw() plt_show(show) return fig diff --git a/mne/viz/misc.py b/mne/viz/misc.py index d2c1a4242dc..c903244f9ff 100644 --- a/mne/viz/misc.py +++ b/mne/viz/misc.py @@ -50,7 +50,6 @@ ) from ..filter import estimate_ringing_samples from .utils import ( - tight_layout, _get_color_list, _prepare_trellis, plt_show, @@ -172,7 +171,11 @@ def plot_cov( C = np.sqrt((C * C.conj()).real) fig_cov, axes = plt.subplots( - 1, len(idx_names), squeeze=False, figsize=(3.8 * len(idx_names), 3.7) + 1, + len(idx_names), + squeeze=False, + figsize=(3.8 * len(idx_names), 3.7), + layout="constrained", ) for k, (idx, name, _, _, _) in enumerate(idx_names): vlim = np.max(np.abs(C[idx][:, idx])) @@ -192,13 +195,14 @@ def plot_cov( cax.grid(False) # avoid mpl warning about auto-removal plt.colorbar(im, cax=cax, format="%.0e") - fig_cov.subplots_adjust(0.04, 0.0, 0.98, 0.94, 0.2, 0.26) - tight_layout(fig=fig_cov) - fig_svd = None if show_svd: fig_svd, axes = plt.subplots( - 1, len(idx_names), squeeze=False, figsize=(3.8 * len(idx_names), 3.7) + 1, + len(idx_names), + squeeze=False, + figsize=(3.8 * len(idx_names), 3.7), + layout="constrained", ) for k, (idx, name, unit, scaling, key) in enumerate(idx_names): this_C = C[idx][:, idx] @@ -233,10 +237,8 @@ def plot_cov( title=name, xlim=[0, len(s) - 1], ) - tight_layout(fig=fig_svd) plt_show(show) - return fig_cov, fig_svd @@ -321,7 +323,7 @@ def plot_source_spectrogram( time_grid, freq_grid = np.meshgrid(time_bounds, freq_bounds) # Plotting the results - fig = plt.figure(figsize=(9, 6)) + fig = plt.figure(figsize=(9, 6), layout="constrained") plt.pcolor(time_grid, freq_grid, source_power[:, source_index, :], cmap="Reds") ax = plt.gca() @@ -344,7 +346,6 @@ def plot_source_spectrogram( plt.grid(True, ls="-") if colorbar: plt.colorbar() - tight_layout(fig=fig) # Covering frequency gaps with horizontal bars for lower_bound, upper_bound in gap_bounds: @@ -481,6 +482,8 @@ def _plot_mri_contours( if slices_as_subplots: ax = axs[ai] else: + # No need for constrained layout here because we make our axes fill the + # entire figure fig = _figure_agg(figsize=figsize, dpi=dpi, facecolor="k") ax = fig.add_axes([0, 0, 1, 1], frame_on=False, facecolor="k") @@ -588,9 +591,6 @@ def _plot_mri_contours( figs.append(fig) if slices_as_subplots: - fig.subplots_adjust( - left=0.0, bottom=0.0, right=1.0, top=1.0, wspace=0.0, hspace=0.0 - ) plt_show(show, fig=fig) return fig else: @@ -848,7 +848,7 @@ def plot_events( fig = None if axes is None: - fig = plt.figure() + fig = plt.figure(layout="constrained") ax = axes if axes else plt.gca() unique_events_id = np.array(unique_events_id) @@ -948,7 +948,7 @@ def plot_dipole_amplitudes(dipoles, colors=None, show=True): if colors is None: colors = cycle(_get_color_list()) - fig, ax = plt.subplots(1, 1) + fig, ax = plt.subplots(1, 1, layout="constrained") xlim = [np.inf, -np.inf] for dip, color in zip(dipoles, colors): ax.plot(dip.times, dip.amplitude * 1e9, color=color, linewidth=1.5) @@ -1191,7 +1191,7 @@ def plot_filter( fig = None if axes is None: - fig, axes = plt.subplots(len(plot), 1) + fig, axes = plt.subplots(len(plot), 1, layout="constrained") if isinstance(axes, plt.Axes): axes = [axes] elif isinstance(axes, np.ndarray): @@ -1263,7 +1263,6 @@ def plot_filter( ) adjust_axes(axes) - tight_layout() plt_show(show) return fig @@ -1357,7 +1356,7 @@ def plot_ideal_filter( my_gain.append(gain[ii]) my_gain = 10 * np.log10(np.maximum(my_gain, 10 ** (alim[0] / 10.0))) if axes is None: - axes = plt.subplots(1)[1] + axes = plt.subplots(1, layout="constrained")[1] for transition in transitions: axes.axvspan(*transition, color=color, alpha=0.1) axes.plot( @@ -1378,7 +1377,6 @@ def plot_ideal_filter( if title: axes.set(title=title) adjust_axes(axes) - tight_layout() plt_show(show) return axes.figure @@ -1508,7 +1506,11 @@ def plot_csd( continue fig, axes = plt.subplots( - n_rows, n_cols, squeeze=False, figsize=(2 * n_cols + 1, 2.2 * n_rows) + n_rows, + n_cols, + squeeze=False, + figsize=(2 * n_cols + 1, 2.2 * n_rows), + layout="constrained", ) csd_mats = [] @@ -1535,8 +1537,6 @@ def plot_csd( ax.set_title("%.1f Hz." % freq) plt.suptitle(title) - plt.subplots_adjust(top=0.8) - if colorbar: cb = plt.colorbar(im, ax=[a for ax_ in axes for a in ax_]) if mode == "csd": @@ -1580,9 +1580,7 @@ def plot_chpi_snr(snr_dict, axes=None): ----- If you supply a list of existing `~matplotlib.axes.Axes`, then the figure legend will not be drawn automatically. If you still want it, running - ``fig.legend(loc='right', title='cHPI frequencies')`` will recreate it, - though you may also need to manually adjust the margin to make room for it - (e.g., using ``fig.subplots_adjust(right=0.8)``). + ``fig.legend(loc='right', title='cHPI frequencies')`` will recreate it. .. versionadded:: 0.24 """ @@ -1593,7 +1591,7 @@ def plot_chpi_snr(snr_dict, axes=None): full_names = dict(mag="magnetometers", grad="gradiometers") axes_was_none = axes is None if axes_was_none: - fig, axes = plt.subplots(len(valid_keys), 1, sharex=True) + fig, axes = plt.subplots(len(valid_keys), 1, sharex=True, layout="constrained") else: fig = axes[0].get_figure() if len(axes) != len(valid_keys): @@ -1627,6 +1625,5 @@ def plot_chpi_snr(snr_dict, axes=None): if axes_was_none: ax.set(xlabel="Time (s)") fig.align_ylabels() - fig.subplots_adjust(left=0.1, right=0.825, bottom=0.075, top=0.95, hspace=0.7) fig.legend(loc="right", title="cHPI frequencies") return fig diff --git a/mne/viz/tests/test_epochs.py b/mne/viz/tests/test_epochs.py index 711afdea480..bfe5d07eebf 100644 --- a/mne/viz/tests/test_epochs.py +++ b/mne/viz/tests/test_epochs.py @@ -272,14 +272,7 @@ def test_plot_epochs_nodata(browser_backend): @pytest.mark.slowtest def test_plot_epochs_image(epochs): - """Test plotting of epochs image. - - Note that some of these tests that should pass are triggering MPL - UserWarnings about tight_layout not being applied ("tight_layout cannot - make axes width small enough to accommodate all axes decorations"). Calling - `plt.close('all')` just before the offending test seems to prevent this - warning, though it's unclear why. - """ + """Test plotting of epochs image.""" figs = epochs.plot_image() assert len(figs) == 2 # one fig per ch_type (test data has mag, grad) assert len(plt.get_fignums()) == 2 diff --git a/mne/viz/tests/test_evoked.py b/mne/viz/tests/test_evoked.py index ce67febd0a9..644b2fb4e3e 100644 --- a/mne/viz/tests/test_evoked.py +++ b/mne/viz/tests/test_evoked.py @@ -231,7 +231,7 @@ def test_plot_evoked(): def test_constrained_layout(): """Test that we handle constrained layouts correctly.""" - fig, ax = plt.subplots(1, 1, constrained_layout=True) + fig, ax = plt.subplots(1, 1, layout="constrained") assert fig.get_constrained_layout() evoked = mne.read_evokeds(evoked_fname)[0] evoked.pick(evoked.ch_names[:2]) @@ -612,7 +612,7 @@ def test_plot_ctf(): fig = plt.figure() # create custom axes for topomaps, colorbar and the timeseries - gs = gridspec.GridSpec(3, 7, hspace=0.5, top=0.8) + gs = gridspec.GridSpec(3, 7, hspace=0.5, top=0.8, figure=fig) topo_axes = [ fig.add_subplot(gs[0, idx * 2 : (idx + 1) * 2]) for idx in range(len(times)) ] diff --git a/mne/viz/tests/test_topomap.py b/mne/viz/tests/test_topomap.py index 4f95f586d98..e20b1987dd1 100644 --- a/mne/viz/tests/test_topomap.py +++ b/mne/viz/tests/test_topomap.py @@ -75,8 +75,8 @@ fast_test = dict(res=8, contours=0, sensors=False) -@pytest.mark.parametrize("constrained_layout", (False, True)) -def test_plot_topomap_interactive(constrained_layout): +@pytest.mark.parametrize("layout", (None, "constrained")) +def test_plot_topomap_interactive(layout): """Test interactive topomap projection plotting.""" evoked = read_evokeds(evoked_fname, baseline=(None, 0))[0] evoked.pick(picks="mag") @@ -86,7 +86,7 @@ def test_plot_topomap_interactive(constrained_layout): evoked.add_proj(compute_proj_evoked(evoked, n_mag=1)) plt.close("all") - fig, ax = plt.subplots(constrained_layout=constrained_layout) + fig, ax = plt.subplots(layout=layout) canvas = fig.canvas kwargs = dict( diff --git a/mne/viz/topo.py b/mne/viz/topo.py index 683c22d9a6a..5a832c954a3 100644 --- a/mne/viz/topo.py +++ b/mne/viz/topo.py @@ -145,7 +145,8 @@ def _iter_topography( from ..channels.layout import find_layout if fig is None: - fig = plt.figure() + # Don't use constrained layout because we place axes manually + fig = plt.figure(layout=None) def format_coord_unified(x, y, pos=None, ch_names=None): """Update status bar with channel name under cursor.""" @@ -296,7 +297,8 @@ def _plot_topo( ) if axes is None: - fig = plt.figure() + # Don't use constrained layout because we place axes manually + fig = plt.figure(layout=None) axes = plt.axes([0.015, 0.025, 0.97, 0.95]) axes.set_facecolor(fig_facecolor) else: diff --git a/mne/viz/topomap.py b/mne/viz/topomap.py index d47ec145e07..a90400c6421 100644 --- a/mne/viz/topomap.py +++ b/mne/viz/topomap.py @@ -54,7 +54,6 @@ ) from ..utils.spectrum import _split_psd_kwargs from .utils import ( - tight_layout, _setup_vmin_vmax, _prepare_trellis, _check_delayed_ssp, @@ -301,8 +300,8 @@ def _add_colorbar( ax, im, cmap, + *, side="right", - pad=0.05, title=None, format=None, size="5%", @@ -310,14 +309,10 @@ def _add_colorbar( ch_type=None, ): """Add a colorbar to an axis.""" - import matplotlib.pyplot as plt - from mpl_toolkits.axes_grid1 import make_axes_locatable - - divider = make_axes_locatable(ax) - cax = divider.append_axes(side, size=size, pad=pad) - cbar = plt.colorbar(im, cax=cax, format=format) + cbar = ax.figure.colorbar(im, format=format) if cmap is not None and cmap[1]: ax.CB = DraggableColorbar(cbar, im, kind, ch_type) + cax = cbar.ax if title is not None: cax.set_title(title, y=1.05, fontsize=10) return cbar, cax @@ -450,7 +445,6 @@ def plot_projs_topomap( ) with warnings.catch_warnings(record=True): warnings.simplefilter("ignore") - tight_layout(fig=fig) plt_show(show) return fig @@ -1020,7 +1014,7 @@ def plot_topomap( from matplotlib.colors import Normalize if axes is None: - _, axes = plt.subplots(figsize=(size, size)) + _, axes = plt.subplots(figsize=(size, size), layout="constrained") sphere = _check_sphere(sphere, pos if isinstance(pos, Info) else None) _validate_type(cnorm, (Normalize, None), "cnorm") if cnorm is not None and (vlim[0] is not None or vlim[1] is not None): @@ -1379,9 +1373,6 @@ def _plot_topomap( size="x-small", ) - if not axes.figure.get_constrained_layout(): - axes.figure.subplots_adjust(top=0.95) - if onselect is not None: lim = axes.dataLim x0, y0, width, height = lim.x0, lim.y0, lim.width, lim.height @@ -1475,7 +1466,6 @@ def _plot_ica_topomap( axes, im, cmap, - pad=0.05, title="AU", format="%3.2f", kind="ica_topomap", @@ -1716,7 +1706,6 @@ def plot_ica_components( cmap, title="AU", side="right", - pad=0.05, format=cbar_fmt, kind="ica_comp_topomap", ch_type=ch_type, @@ -1725,9 +1714,6 @@ def plot_ica_components( cbar.set_ticks(_vlim) _hide_frame(ax) del pos - if not user_passed_axes: - tight_layout(fig=fig) - fig.subplots_adjust(top=0.88, bottom=0.0) fig.canvas.draw() # add title selection interactivity @@ -1934,7 +1920,11 @@ def plot_tfr_topomap( vlim = _setup_vmin_vmax(data, *vlim, norm) cmap = _setup_cmap(cmap, norm=norm) - axes = plt.subplots(figsize=(size, size))[1] if axes is None else axes + axes = ( + plt.subplots(figsize=(size, size), layout="constrained")[1] + if axes is None + else axes + ) fig = axes.figure _hide_frame(axes) @@ -2204,18 +2194,17 @@ def plot_evoked_topomap( if interactive: height_ratios = [5, 1] nrows = 2 - ncols = want_axes - width = size * ncols + ncols = n_times + width = size * want_axes height = size + max(0, 0.1 * (4 - size)) fig = figure_nobar(figsize=(width * 1.5, height * 1.5)) - g_kwargs = {"left": 0.2, "right": 0.8, "bottom": 0.05, "top": 0.9} - gs = GridSpec(nrows, ncols, height_ratios=height_ratios, **g_kwargs) + gs = GridSpec(nrows, ncols, height_ratios=height_ratios, figure=fig) axes = [] for ax_idx in range(n_times): axes.append(plt.subplot(gs[0, ax_idx])) elif axes is None: fig, axes, ncols, nrows = _prepare_trellis( - n_times, ncols=ncols, nrows=nrows, colorbar=colorbar, size=size + n_times, ncols=ncols, nrows=nrows, size=size ) else: nrows, ncols = None, None # Deactivate ncols when axes were passed @@ -2227,13 +2216,7 @@ def plot_evoked_topomap( f"You must provide {want_axes} axes (one for " f"each time{cbar_err}), got {len(axes)}." ) - # figure margins - if not fig.get_constrained_layout(): - side_margin = plt.rcParams["figure.subplot.wspace"] / (2 * want_axes) - top_margin = max(0.05, 0.2 / size) - fig.subplots_adjust( - left=side_margin, right=1 - side_margin, bottom=0, top=1 - top_margin - ) + del want_axes # find first index that's >= (to rounding error) to each time point time_idx = [ np.where( @@ -2336,12 +2319,10 @@ def plot_evoked_topomap( images, contours_ = [], [] # loop over times for average_idx, (time, this_average) in enumerate(zip(times, average)): - adjust_for_cbar = colorbar and ncols is not None and average_idx >= ncols - 1 - ax_idx = average_idx + 1 if adjust_for_cbar else average_idx tp, cn, interp = _plot_topomap( data[:, average_idx], pos, - axes=axes[ax_idx], + axes=axes[average_idx], mask=mask_[:, average_idx] if mask is not None else None, vmin=_vlim[0], vmax=_vlim[1], @@ -2362,13 +2343,13 @@ def plot_evoked_topomap( to_time = time_format % (tmax_ * scaling_time) axes_title = f"{from_time} – {to_time}" del from_time, to_time, tmin_, tmax_ - axes[ax_idx].set_title(axes_title) + axes[average_idx].set_title(axes_title) if interactive: # Add a slider to the figure and start publishing and subscribing to time_change # events. kwargs.update(vlim=_vlim) - axes.append(plt.subplot(gs[1, :-1])) + axes.append(fig.add_subplot(gs[1])) slider = Slider( axes[-1], "Time", @@ -2412,19 +2393,15 @@ def _slider_changed(val): ) if colorbar: - if interactive: - cax = plt.subplot(gs[0, -1]) - _resize_cbar(cax, ncols, size) - elif nrows is None or ncols is None: + if nrows is None or ncols is None: # axes were given by the user, so don't resize the colorbar cax = axes[-1] - else: # use the entire last column - cax = axes[ncols - 1] - _resize_cbar(cax, ncols, size) + else: # use the default behavior + cax = None + cbar = fig.colorbar(images[-1], ax=axes, cax=cax, format=cbar_fmt, shrink=0.6) if unit is not None: - cax.set_title(unit) - cbar = fig.colorbar(images[-1], ax=cax, cax=cax, format=cbar_fmt) + cbar.ax.set_title(unit) if cn is not None: cbar.set_ticks(contours) cbar.ax.tick_params(labelsize=7) @@ -2578,9 +2555,7 @@ def _plot_topomap_multi_cbar( ) if colorbar: - cbar, cax = _add_colorbar( - ax, im, cmap, pad=0.25, title=None, size="10%", format=cbar_fmt - ) + cbar, cax = _add_colorbar(ax, im, cmap, title=None, size="10%", format=cbar_fmt) cbar.set_ticks(_vlim) if unit is not None: cbar.ax.set_ylabel(unit, fontsize=8) @@ -2857,7 +2832,9 @@ def plot_psds_topomap( _validate_if_list_of_axes(axes, n_axes) fig = axes[0].figure else: - fig, axes = plt.subplots(1, n_axes, figsize=(2 * n_axes, 1.5)) + fig, axes = plt.subplots( + 1, n_axes, figsize=(2 * n_axes, 1.5), layout="constrained" + ) if n_axes == 1: axes = [axes] # loop over subplots/frequency bands @@ -2892,7 +2869,6 @@ def plot_psds_topomap( ) if not user_passed_axes: - tight_layout(fig=fig) fig.canvas.draw() plt_show(show) return fig @@ -2923,9 +2899,10 @@ def plot_layout(layout, picks=None, show_axes=False, show=True): """ import matplotlib.pyplot as plt - fig = plt.figure(figsize=(max(plt.rcParams["figure.figsize"]),) * 2) + fig = plt.figure( + figsize=(max(plt.rcParams["figure.figsize"]),) * 2, layout="constrained" + ) ax = fig.add_subplot(111) - fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=None, hspace=None) ax.set(xticks=[], yticks=[], aspect="equal") outlines = dict(border=([0, 1, 1, 0, 0], [0, 0, 1, 1, 0])) _draw_outlines(ax, outlines) @@ -2945,7 +2922,6 @@ def plot_layout(layout, picks=None, show_axes=False, show=True): x1, x2, y1, y2 = p[0], p[0] + p[2], p[1], p[1] + p[3] ax.plot([x1, x1, x2, x2, x1], [y1, y2, y2, y1, y1], color="k") ax.axis("off") - tight_layout(fig=fig, pad=0, w_pad=0, h_pad=0) plt_show(show) return fig @@ -3163,7 +3139,6 @@ def _init_anim( outlines_ = _draw_outlines(ax, outlines) params.update({"patch": patch_, "outlines": outlines_}) - tight_layout(fig=ax.figure) return tuple(items) + cont_collections @@ -3306,7 +3281,7 @@ def _topomap_animation( norm = np.min(data) >= 0 vmin, vmax = _setup_vmin_vmax(data, vmin, vmax, norm) - fig = plt.figure(figsize=(6, 5)) + fig = plt.figure(figsize=(6, 5), layout="constrained") shape = (8, 12) colspan = shape[1] - 1 rowspan = shape[0] - bool(butterfly) @@ -3491,8 +3466,6 @@ def _plot_corrmap( border=border, ) _hide_frame(ax) - tight_layout(fig=fig) - fig.subplots_adjust(top=0.8) fig.canvas.draw() plt_show(show) return fig @@ -3652,7 +3625,7 @@ def plot_arrowmap( ) outlines = _make_head_outlines(sphere, pos, outlines, clip_origin) if axes is None: - fig, axes = plt.subplots() + fig, axes = plt.subplots(layout="constrained") else: fig = axes.figure plot_topomap( @@ -3679,11 +3652,7 @@ def plot_arrowmap( dx, dy = _trigradient(x, y, data) dxx = dy.data dyy = -dx.data - axes.quiver(x, y, dxx, dyy, scale=scale, color="k", lw=1, clip_on=False) - axes.figure.canvas.draw_idle() - with warnings.catch_warnings(record=True): - warnings.simplefilter("ignore") - tight_layout(fig=fig) + axes.quiver(x, y, dxx, dyy, scale=scale, color="k", lw=1) plt_show(show) return fig @@ -3735,7 +3704,7 @@ def plot_bridged_electrodes( topomap_args.setdefault("contours", False) sphere = topomap_args.get("sphere", _check_sphere(None)) if "axes" not in topomap_args: - fig, ax = plt.subplots() + fig, ax = plt.subplots(layout="constrained") topomap_args["axes"] = ax else: fig = None @@ -4075,7 +4044,11 @@ def plot_regression_weights( axes_was_none = axes is None if axes_was_none: fig, axes = plt.subplots( - nrows, ncols, squeeze=False, figsize=(ncols * 2, nrows * 1.5 + 1) + nrows, + ncols, + squeeze=False, + figsize=(ncols * 2, nrows * 1.5 + 1), + layout="constrained", ) axes = axes.T.ravel() else: @@ -4143,8 +4116,5 @@ def plot_regression_weights( ) if axes_was_none: fig.suptitle(title) - fig.subplots_adjust( - top=0.88, bottom=0.06, left=0.025, right=0.911, hspace=0.2, wspace=0.5 - ) plt_show(show) return fig diff --git a/mne/viz/utils.py b/mne/viz/utils.py index 78f05ee9109..08d4e69ec48 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -21,7 +21,6 @@ import sys import tempfile import traceback -import warnings import webbrowser from decorator import decorator @@ -203,63 +202,6 @@ def _show_browser(show=True, block=True, fig=None, **kwargs): _qt_app_exec(QApplication.instance()) -def tight_layout(pad=1.2, h_pad=None, w_pad=None, fig=None): - """Adjust subplot parameters to give specified padding. - - .. note:: For plotting please use this function instead of - ``plt.tight_layout``. - - Parameters - ---------- - pad : float - Padding between the figure edge and the edges of subplots, as a - fraction of the font-size. - h_pad : float - Padding height between edges of adjacent subplots. - Defaults to ``pad_inches``. - w_pad : float - Padding width between edges of adjacent subplots. - Defaults to ``pad_inches``. - fig : instance of Figure - Figure to apply changes to. - - Notes - ----- - This will not force constrained_layout=False if the figure was created - with that method. - """ - _validate_type(pad, "numeric", "pad") - import matplotlib.pyplot as plt - - fig = plt.gcf() if fig is None else fig - - fig.canvas.draw() - constrained = fig.get_constrained_layout() - kwargs = dict(pad=pad, h_pad=h_pad, w_pad=w_pad) - if constrained: - return # no-op - try: # see https://github.com/matplotlib/matplotlib/issues/2654 - with warnings.catch_warnings(record=True) as ws: - fig.tight_layout(**kwargs) - except Exception: - try: - with warnings.catch_warnings(record=True) as ws: - if hasattr(fig, "set_layout_engine"): - fig.set_layout_engine("tight", **kwargs) - else: - fig.set_tight_layout(kwargs) - except Exception: - warn( - 'Matplotlib function "tight_layout" is not supported.' - " Skipping subplot adjustment." - ) - return - for w in ws: - w_msg = str(w.message) if hasattr(w, "message") else w.get_message() - if not w_msg.startswith("This figure includes Axes"): - warn(w_msg, w.category, "matplotlib") - - def _check_delayed_ssp(container): """Handle interactive SSP selection.""" if container.proj is True or all(p["active"] for p in container.info["projs"]): @@ -489,7 +431,6 @@ def _prepare_trellis( ncols, nrows="auto", title=False, - colorbar=False, size=1.3, sharex=False, sharey=False, @@ -517,22 +458,13 @@ def _prepare_trellis( "figure.".format(n_cells, nrows, ncols) ) - if colorbar: - ncols += 1 width = size * ncols height = (size + max(0, 0.1 * (4 - size))) * nrows + bool(title) * 0.5 - height_ratios = None fig = _figure(toolbar=False, figsize=(width * 1.5, 0.25 + height * 1.5)) - gs = GridSpec(nrows, ncols, figure=fig, height_ratios=height_ratios) + gs = GridSpec(nrows, ncols, figure=fig) axes = [] - if colorbar: - # exclude last axis of each row except top row, which is for colorbar - exclude = set(range(2 * ncols - 1, nrows * ncols, ncols)) - ax_idxs = sorted(set(range(nrows * ncols)) - exclude)[: n_cells + 1] - else: - ax_idxs = range(n_cells) - for ax_idx in ax_idxs: + for ax_idx in range(n_cells): subplot_kw = dict() if ax_idx > 0: if sharex: @@ -560,7 +492,8 @@ def _draw_proj_checkbox(event, params, draw_current_state=True): width = max([4.0, max([len(p["desc"]) for p in projs]) / 6.0 + 0.5]) height = (len(projs) + 1) / 6.0 + 1.5 - fig_proj = figure_nobar(figsize=(width, height)) + # We manually place everything here so avoid constrained layouts + fig_proj = figure_nobar(figsize=(width, height), layout=None) _set_window_title(fig_proj, "SSP projection vectors") offset = 1.0 / 6.0 / height params["fig_proj"] = fig_proj # necessary for proper toggling @@ -707,6 +640,8 @@ def figure_nobar(*args, **kwargs): old_val = rcParams["toolbar"] try: rcParams["toolbar"] = "none" + if "layout" not in kwargs: + kwargs["layout"] = "constrained" fig = plt.figure(*args, **kwargs) # remove button press catchers (for toolbar) cbs = list(fig.canvas.callbacks.callbacks["key_press_event"].keys()) @@ -1319,7 +1254,10 @@ def _plot_sensors( if kind == "3d": subplot_kw.update(projection="3d") fig, ax = plt.subplots( - 1, figsize=(max(rcParams["figure.figsize"]),) * 2, subplot_kw=subplot_kw + 1, + figsize=(max(rcParams["figure.figsize"]),) * 2, + subplot_kw=subplot_kw, + layout="constrained", ) else: fig = ax.get_figure() @@ -1367,8 +1305,6 @@ def _plot_sensors( # Equal aspect for 3D looks bad, so only use for 2D ax.set(aspect="equal") - if axes_was_none: # we'll show the plot title as the window title - fig.subplots_adjust(left=0, bottom=0, right=1, top=1) ax.axis("off") # remove border around figure del sphere @@ -1393,14 +1329,6 @@ def _plot_sensors( connect_picker = kind == "select" # make sure no names go off the edge of the canvas xmin, ymin, xmax, ymax = fig.get_window_extent().bounds - renderer = fig.canvas.get_renderer() - extents = [x.get_window_extent(renderer=renderer) for x in ax.texts] - xmaxs = np.array([x.max[0] for x in extents]) - bad_xmax_ixs = np.nonzero(xmaxs > xmax)[0] - if len(bad_xmax_ixs): - needed_space = (xmaxs[bad_xmax_ixs] - xmax).max() / xmax - fig.subplots_adjust(right=1 - 1.1 * needed_space) - if connect_picker: picker = partial( _onpick_sensor, @@ -1530,38 +1458,14 @@ def _setup_cmap(cmap, n_axes=1, norm=False): def _prepare_joint_axes(n_maps, figsize=None): - """Prepare axes for topomaps and colorbar in joint plot figure. - - Parameters - ---------- - n_maps: int - Number of topomaps to include in the figure - figsize: tuple - Figure size, see plt.figsize - - Returns - ------- - fig : matplotlib.figure.Figure - Figure with initialized axes - main_ax: matplotlib.axes._subplots.AxesSubplot - Axes in which to put the main plot - map_ax: list - List of axes for each topomap - cbar_ax: matplotlib.axes._subplots.AxesSubplot - Axes for colorbar next to topomaps - """ import matplotlib.pyplot as plt + from matplotlib.gridspec import GridSpec - fig = plt.figure(figsize=figsize) - main_ax = fig.add_subplot(212) - ts = n_maps + 2 - map_ax = [plt.subplot(4, ts, x + 2 + ts) for x in range(n_maps)] - # Position topomap subplots on the second row, starting on the - # second column - cbar_ax = plt.subplot(4, 5 * (ts + 1), 10 * (ts + 1)) - # Position colorbar at the very end of a more finely divided - # second row of subplots - return fig, main_ax, map_ax, cbar_ax + fig = plt.figure(figsize=figsize, layout="constrained") + gs = GridSpec(2, n_maps, height_ratios=[1, 2], figure=fig) + map_ax = [fig.add_subplot(gs[0, x]) for x in range(n_maps)] # first row + main_ax = fig.add_subplot(gs[1, :]) # second row + return fig, main_ax, map_ax class DraggableColorbar: @@ -1908,37 +1812,6 @@ def _merge_annotations(start, stop, description, annotations, current=()): annotations.append(onset, duration, description) -def _connection_line(x, fig, sourceax, targetax, y=1.0, y_source_transform="transAxes"): - """Connect source and target plots with a line. - - Connect source and target plots with a line, such as time series - (source) and topolots (target). Primarily used for plot_joint - functions. - """ - from matplotlib.lines import Line2D - - trans_fig = fig.transFigure - trans_fig_inv = fig.transFigure.inverted() - - xt, yt = trans_fig_inv.transform(targetax.transAxes.transform([0.5, 0.0])) - xs, _ = trans_fig_inv.transform(sourceax.transData.transform([x, 0.0])) - _, ys = trans_fig_inv.transform( - getattr(sourceax, y_source_transform).transform([0.0, y]) - ) - - return Line2D( - (xt, xs), - (yt, ys), - transform=trans_fig, - color="grey", - linestyle="-", - linewidth=1.5, - alpha=0.66, - zorder=1, - clip_on=False, - ) - - class DraggableLine: """Custom matplotlib line for moving around by drag and drop. diff --git a/requirements.txt b/requirements.txt index 39ae2c37815..90944200247 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ # requirements for full MNE-Python functionality (other than raw/epochs export) numpy>=1.15.4 scipy>=1.7.1 -matplotlib>=3.4.3 +matplotlib>=3.5.0 tqdm pooch>=1.5 decorator diff --git a/requirements_base.txt b/requirements_base.txt index 551156522c3..2e6ba6e6c80 100644 --- a/requirements_base.txt +++ b/requirements_base.txt @@ -1,7 +1,7 @@ # requirements for basic MNE-Python functionality numpy>=1.21.2 scipy>=1.7.1 -matplotlib>=3.4.3 +matplotlib>=3.5.0 tqdm pooch>=1.5 decorator diff --git a/tools/github_actions_env_vars.sh b/tools/github_actions_env_vars.sh index a0ab494a9db..ba1cac712a5 100755 --- a/tools/github_actions_env_vars.sh +++ b/tools/github_actions_env_vars.sh @@ -4,7 +4,7 @@ set -eo pipefail -x # old and minimal use conda if [[ "$MNE_CI_KIND" == "old" ]]; then echo "Setting conda env vars for old" - echo "CONDA_DEPENDENCIES=numpy=1.21.2 scipy=1.7.1 matplotlib=3.4.3 pandas=1.3.2 scikit-learn=1.0" >> $GITHUB_ENV + echo "CONDA_DEPENDENCIES=numpy=1.21.2 scipy=1.7.1 matplotlib=3.5.0 pandas=1.3.2 scikit-learn=1.0" >> $GITHUB_ENV echo "MNE_IGNORE_WARNINGS_IN_TESTS=true" >> $GITHUB_ENV echo "MNE_SKIP_NETWORK_TESTS=1" >> $GITHUB_ENV elif [[ "$MNE_CI_KIND" == "minimal" ]]; then diff --git a/tutorials/epochs/60_make_fixed_length_epochs.py b/tutorials/epochs/60_make_fixed_length_epochs.py index a3186ca25c2..9a6eace0ab9 100644 --- a/tutorials/epochs/60_make_fixed_length_epochs.py +++ b/tutorials/epochs/60_make_fixed_length_epochs.py @@ -113,13 +113,10 @@ color_lims = np.percentile(np.array(corr_matrices), [5, 95]) titles = ["First 30 Seconds", "Last 30 Seconds"] -fig, axes = plt.subplots(nrows=1, ncols=2) +fig, axes = plt.subplots(nrows=1, ncols=2, layout="constrained") fig.suptitle("Correlation Matrices from First 30 Seconds and Last 30 Seconds") for ci, corr_matrix in enumerate(corr_matrices): ax = axes[ci] mpbl = ax.imshow(corr_matrix, clim=color_lims) ax.set_xlabel(titles[ci]) -fig.subplots_adjust(right=0.8) -cax = fig.add_axes([0.85, 0.2, 0.025, 0.6]) -cbar = fig.colorbar(ax.images[0], cax=cax) -cbar.set_label("Correlation Coefficient") +cbar = fig.colorbar(ax.images[0], label="Correlation Coefficient") diff --git a/tutorials/forward/50_background_freesurfer_mne.py b/tutorials/forward/50_background_freesurfer_mne.py index 5efcc07d0d1..0150088de83 100644 --- a/tutorials/forward/50_background_freesurfer_mne.py +++ b/tutorials/forward/50_background_freesurfer_mne.py @@ -124,7 +124,7 @@ def imshow_mri(data, img, vox, xyz, suptitle): """Show an MRI slice with a voxel annotated.""" i, j, k = vox - fig, ax = plt.subplots(1, figsize=(6, 6)) + fig, ax = plt.subplots(1, figsize=(6, 6), layout="constrained") codes = nibabel.orientations.aff2axcodes(img.affine) # Figure out the title based on the code of this axis ori_slice = dict( @@ -157,7 +157,6 @@ def imshow_mri(data, img, vox, xyz, suptitle): title=f"{title} view: i={i} ({ori_names[codes[0]]}+)", ) fig.suptitle(suptitle) - fig.subplots_adjust(0.1, 0.1, 0.95, 0.85) return fig diff --git a/tutorials/intro/70_report.py b/tutorials/intro/70_report.py index b23c8852694..12f04772ce8 100644 --- a/tutorials/intro/70_report.py +++ b/tutorials/intro/70_report.py @@ -463,7 +463,7 @@ fig_array_rotated = fig_array_rotated.clip(min=0, max=1) # Create the figure - fig, ax = plt.subplots(figsize=(3, 3), constrained_layout=True) + fig, ax = plt.subplots(figsize=(3, 3), layout="constrained") ax.imshow(fig_array_rotated) ax.set_axis_off() diff --git a/tutorials/inverse/20_dipole_fit.py b/tutorials/inverse/20_dipole_fit.py index 89cf81af671..c81c16f3252 100644 --- a/tutorials/inverse/20_dipole_fit.py +++ b/tutorials/inverse/20_dipole_fit.py @@ -100,6 +100,7 @@ ncols=4, figsize=[10.0, 3.4], gridspec_kw=dict(width_ratios=[1, 1, 1, 0.1], top=0.85), + layout="constrained", ) vmin, vmax = -400, 400 # make sure each plot has same colour range @@ -119,7 +120,6 @@ "at {:.0f} ms".format(best_time * 1000.0), fontsize=16, ) -fig.tight_layout() # %% # Estimate the time course of a single dipole with fixed position and diff --git a/tutorials/inverse/60_visualize_stc.py b/tutorials/inverse/60_visualize_stc.py index 01bd0c28a84..3be86643c61 100644 --- a/tutorials/inverse/60_visualize_stc.py +++ b/tutorials/inverse/60_visualize_stc.py @@ -156,7 +156,7 @@ label_tc = stc.extract_label_time_course(fname_aseg, src=src) lidx, tidx = np.unravel_index(np.argmax(label_tc), label_tc.shape) -fig, ax = plt.subplots(1) +fig, ax = plt.subplots(1, layout="constrained") ax.plot(stc.times, label_tc.T, "k", lw=1.0, alpha=0.5) xy = np.array([stc.times[tidx], label_tc[lidx, tidx]]) xytext = xy + [0.01, 1] @@ -164,7 +164,6 @@ ax.set(xlim=stc.times[[0, -1]], xlabel="Time (s)", ylabel="Activation") for key in ("right", "top"): ax.spines[key].set_visible(False) -fig.tight_layout() # %% # We can plot several labels with the most activation in their time course diff --git a/tutorials/inverse/80_brainstorm_phantom_elekta.py b/tutorials/inverse/80_brainstorm_phantom_elekta.py index cca2c3470af..95a2a8e8f59 100644 --- a/tutorials/inverse/80_brainstorm_phantom_elekta.py +++ b/tutorials/inverse/80_brainstorm_phantom_elekta.py @@ -144,7 +144,7 @@ actual_amp = 100.0 # nAm fig, (ax1, ax2, ax3) = plt.subplots( - nrows=3, ncols=1, figsize=(6, 7), constrained_layout=True + nrows=3, ncols=1, figsize=(6, 7), layout="constrained" ) diffs = 1000 * np.sqrt(np.sum((dip.pos - actual_pos) ** 2, axis=-1)) diff --git a/tutorials/machine-learning/30_strf.py b/tutorials/machine-learning/30_strf.py index af0db4d1d20..9cc53a7a2da 100644 --- a/tutorials/machine-learning/30_strf.py +++ b/tutorials/machine-learning/30_strf.py @@ -86,12 +86,10 @@ shading="gouraud", ) -fig, ax = plt.subplots() +fig, ax = plt.subplots(layout="constrained") ax.pcolormesh(delays_sec, freqs, weights, **kwargs) ax.set(title="Simulated STRF", xlabel="Time Lags (s)", ylabel="Frequency (Hz)") plt.setp(ax.get_xticklabels(), rotation=45) -plt.autoscale(tight=True) -mne.viz.tight_layout() # %% # Simulate a neural response @@ -147,7 +145,7 @@ X_plt = scale(np.hstack(X[:2]).T).T y_plt = scale(np.hstack(y[:2])) time = np.arange(X_plt.shape[-1]) / sfreq -_, (ax1, ax2) = plt.subplots(2, 1, figsize=(6, 6), sharex=True) +_, (ax1, ax2) = plt.subplots(2, 1, figsize=(6, 6), sharex=True, layout="constrained") ax1.pcolormesh(time, freqs, X_plt, vmin=0, vmax=4, cmap="Reds", shading="gouraud") ax1.set_title("Input auditory features") ax1.set(ylim=[freqs.min(), freqs.max()], ylabel="Frequency (Hz)") @@ -158,7 +156,6 @@ xlabel="Time (s)", ylabel="Activity (a.u.)", ) -mne.viz.tight_layout() # %% @@ -197,14 +194,19 @@ best_pred = best_mod.predict(X_test)[:, 0] # Plot the original STRF, and the one that we recovered with modeling. -_, (ax1, ax2) = plt.subplots(1, 2, figsize=(6, 3), sharey=True, sharex=True) +_, (ax1, ax2) = plt.subplots( + 1, + 2, + figsize=(6, 3), + sharey=True, + sharex=True, + layout="constrained", +) ax1.pcolormesh(delays_sec, freqs, weights, **kwargs) ax2.pcolormesh(times, rf.feature_names, coefs, **kwargs) ax1.set_title("Original STRF") ax2.set_title("Best Reconstructed STRF") plt.setp([iax.get_xticklabels() for iax in [ax1, ax2]], rotation=45) -plt.autoscale(tight=True) -mne.viz.tight_layout() # Plot the actual response and the predicted response on a held out stimulus time_pred = np.arange(best_pred.shape[0]) / sfreq @@ -213,8 +215,6 @@ ax.plot(time_pred, best_pred, color="r", lw=1) ax.set(title="Original and predicted activity", xlabel="Time (s)") ax.legend(["Original", "Predicted"]) -plt.autoscale(tight=True) -mne.viz.tight_layout() # %% @@ -229,7 +229,7 @@ # in :footcite:`TheunissenEtAl2001,WillmoreSmyth2003,HoldgrafEtAl2016`. # Plot model score for each ridge parameter -fig = plt.figure(figsize=(10, 4)) +fig = plt.figure(figsize=(10, 4), layout="constrained") ax = plt.subplot2grid([2, len(alphas)], [1, 0], 1, len(alphas)) ax.plot(np.arange(len(alphas)), scores, marker="o", color="r") ax.annotate( @@ -244,7 +244,6 @@ ylabel="Score ($R^2$)", xlim=[-0.4, len(alphas) - 0.6], ) -mne.viz.tight_layout() # Plot the STRF of each ridge parameter for ii, (rf, i_alpha) in enumerate(zip(models, alphas)): @@ -252,9 +251,7 @@ ax.pcolormesh(times, rf.feature_names, rf.coef_[0], **kwargs) plt.xticks([], []) plt.yticks([], []) - plt.autoscale(tight=True) fig.suptitle("Model coefficients / scores for many ridge parameters", y=1) -mne.viz.tight_layout() # %% # Using different regularization types @@ -308,7 +305,7 @@ # This matches the "true" receptive field structure and results in a better # model fit. -fig = plt.figure(figsize=(10, 6)) +fig = plt.figure(figsize=(10, 6), layout="constrained") ax = plt.subplot2grid([3, len(alphas)], [2, 0], 1, len(alphas)) ax.plot(np.arange(len(alphas)), scores_lap, marker="o", color="r") ax.plot(np.arange(len(alphas)), scores, marker="o", color="0.5", ls=":") @@ -330,7 +327,6 @@ ylabel="Score ($R^2$)", xlim=[-0.4, len(alphas) - 0.6], ) -mne.viz.tight_layout() # Plot the STRF of each ridge parameter xlim = times[[0, -1]] @@ -346,13 +342,19 @@ if ii == 0: ax.set(ylabel="Ridge") fig.suptitle("Model coefficients / scores for laplacian regularization", y=1) -mne.viz.tight_layout() # %% # Plot the original STRF, and the one that we recovered with modeling. rf = models[ix_best_alpha] rf_lap = models_lap[ix_best_alpha_lap] -_, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(9, 3), sharey=True, sharex=True) +_, (ax1, ax2, ax3) = plt.subplots( + 1, + 3, + figsize=(9, 3), + sharey=True, + sharex=True, + layout="constrained", +) ax1.pcolormesh(delays_sec, freqs, weights, **kwargs) ax2.pcolormesh(times, rf.feature_names, rf.coef_[0], **kwargs) ax3.pcolormesh(times, rf_lap.feature_names, rf_lap.coef_[0], **kwargs) @@ -360,8 +362,6 @@ ax2.set_title("Best Ridge STRF") ax3.set_title("Best Laplacian STRF") plt.setp([iax.get_xticklabels() for iax in [ax1, ax2, ax3]], rotation=45) -plt.autoscale(tight=True) -mne.viz.tight_layout() # %% # References diff --git a/tutorials/preprocessing/25_background_filtering.py b/tutorials/preprocessing/25_background_filtering.py index a5ec433ac7c..09e5db8173e 100644 --- a/tutorials/preprocessing/25_background_filtering.py +++ b/tutorials/preprocessing/25_background_filtering.py @@ -478,7 +478,7 @@ # and the time-domain ringing is thus more pronounced for the steep-slope, # long-duration filter than the shorter, shallower-slope filter: -axes = plt.subplots(1, 2)[1] +axes = plt.subplots(1, 2, layout="constrained")[1] def plot_signal(x, offset): @@ -524,7 +524,6 @@ def plot_signal(x, offset): for text in axes[0].get_yticklabels(): text.set(rotation=45, size=8) axes[1].set(xlim=flim, ylim=(-60, 10), xlabel="Frequency (Hz)", ylabel="Magnitude (dB)") -mne.viz.tight_layout() plt.show() # %% @@ -665,7 +664,7 @@ def plot_signal(x, offset): # Now let's look at how our shallow and steep Butterworth IIR filters # perform on our Morlet signal from before: -axes = plt.subplots(1, 2)[1] +axes = plt.subplots(1, 2, layout="constrained")[1] yticks = np.arange(4) / -30.0 yticklabels = ["Original", "Noisy", "Butterworth-2", "Butterworth-8"] plot_signal(x_orig, offset=yticks[0]) @@ -684,7 +683,6 @@ def plot_signal(x, offset): text.set(rotation=45, size=8) axes[1].set(xlim=flim, ylim=(-60, 10), xlabel="Frequency (Hz)", ylabel="Magnitude (dB)") mne.viz.adjust_axes(axes) -mne.viz.tight_layout() plt.show() # %% @@ -793,7 +791,6 @@ def plot_signal(x, offset): ) mne.viz.adjust_axes(axes) -mne.viz.tight_layout() plt.show() # %% @@ -832,7 +829,7 @@ def plot_signal(x, offset): def baseline_plot(x): - all_axes = plt.subplots(3, 2)[1] + all_axes = plt.subplots(3, 2, layout="constrained")[1] for ri, (axes, freq) in enumerate(zip(all_axes, [0.1, 0.3, 0.5])): for ci, ax in enumerate(axes): if ci == 0: @@ -849,7 +846,6 @@ def baseline_plot(x): ax.set(xticks=tticks, ylim=ylim, xlim=xlim, xlabel=xlabel) ax.set_ylabel("%0.1f Hz" % freq, rotation=0, horizontalalignment="right") mne.viz.adjust_axes(axes) - mne.viz.tight_layout() plt.suptitle(title) plt.show() diff --git a/tutorials/preprocessing/30_filtering_resampling.py b/tutorials/preprocessing/30_filtering_resampling.py index 32854096194..53b1f550fcc 100644 --- a/tutorials/preprocessing/30_filtering_resampling.py +++ b/tutorials/preprocessing/30_filtering_resampling.py @@ -156,7 +156,6 @@ def add_arrows(axes): raw_notch = raw.copy().notch_filter(freqs=freqs, picks=meg_picks) for title, data in zip(["Un", "Notch "], [raw, raw_notch]): fig = data.compute_psd(fmax=250).plot(average=True, picks="data", exclude="bads") - fig.subplots_adjust(top=0.85) fig.suptitle("{}filtered".format(title), size="xx-large", weight="bold") add_arrows(fig.axes[:2]) @@ -176,7 +175,6 @@ def add_arrows(axes): ) for title, data in zip(["Un", "spectrum_fit "], [raw, raw_notch_fit]): fig = data.compute_psd(fmax=250).plot(average=True, picks="data", exclude="bads") - fig.subplots_adjust(top=0.85) fig.suptitle("{}filtered".format(title), size="xx-large", weight="bold") add_arrows(fig.axes[:2]) @@ -212,7 +210,6 @@ def add_arrows(axes): for data, title in zip([raw, raw_downsampled], ["Original", "Downsampled"]): fig = data.compute_psd().plot(average=True, picks="data", exclude="bads") - fig.subplots_adjust(top=0.9) fig.suptitle(title) plt.setp(fig.axes, xlim=(0, 300)) diff --git a/tutorials/preprocessing/50_artifact_correction_ssp.py b/tutorials/preprocessing/50_artifact_correction_ssp.py index a1ea7135d8e..55d18b276a6 100644 --- a/tutorials/preprocessing/50_artifact_correction_ssp.py +++ b/tutorials/preprocessing/50_artifact_correction_ssp.py @@ -498,7 +498,9 @@ evoked_eeg = epochs.average().pick("eeg") evoked_eeg.del_proj().add_proj(ecg_projs).add_proj(eog_projs) -fig, axes = plt.subplots(1, 3, figsize=(8, 3), sharex=True, sharey=True) +fig, axes = plt.subplots( + 1, 3, figsize=(8, 3), sharex=True, sharey=True, layout="constrained" +) for pi, proj in enumerate((False, True, "reconstruct")): ax = axes[pi] evoked_eeg.plot(proj=proj, axes=ax, spatial_colors=True) @@ -512,7 +514,6 @@ ax.yaxis.set_tick_params(labelbottom=True) for text in list(ax.texts): text.remove() -mne.viz.tight_layout() # %% # Note that here the bias in the EEG and magnetometer channels is reduced by diff --git a/tutorials/preprocessing/60_maxwell_filtering_sss.py b/tutorials/preprocessing/60_maxwell_filtering_sss.py index 191eabf2b45..a3659b1f765 100644 --- a/tutorials/preprocessing/60_maxwell_filtering_sss.py +++ b/tutorials/preprocessing/60_maxwell_filtering_sss.py @@ -163,7 +163,7 @@ ) # First, plot the "raw" scores. -fig, ax = plt.subplots(1, 2, figsize=(12, 8)) +fig, ax = plt.subplots(1, 2, figsize=(12, 8), layout="constrained") fig.suptitle( f"Automated noisy channel detection: {ch_type}", fontsize=16, fontweight="bold" ) @@ -188,9 +188,6 @@ ] ax[1].set_title("Scores > Limit", fontweight="bold") -# The figure title should not overlap with the subplots. -fig.tight_layout(rect=[0, 0.03, 1, 0.95]) - # %% # # .. note:: You can use the very same code as above to produce figures for diff --git a/tutorials/preprocessing/70_fnirs_processing.py b/tutorials/preprocessing/70_fnirs_processing.py index 1dd30c628ab..886d99fc618 100644 --- a/tutorials/preprocessing/70_fnirs_processing.py +++ b/tutorials/preprocessing/70_fnirs_processing.py @@ -110,7 +110,7 @@ # coupling index. sci = mne.preprocessing.nirs.scalp_coupling_index(raw_od) -fig, ax = plt.subplots() +fig, ax = plt.subplots(layout="constrained") ax.hist(sci) ax.set(xlabel="Scalp Coupling Index", ylabel="Count", xlim=[0, 1]) @@ -157,7 +157,6 @@ for when, _raw in dict(Before=raw_haemo_unfiltered, After=raw_haemo).items(): fig = _raw.compute_psd().plot(average=True, picks="data", exclude="bads") fig.suptitle(f"{when} filtering", weight="bold", size="x-large") - fig.subplots_adjust(top=0.88) # %% # Extract epochs @@ -172,7 +171,6 @@ events, event_dict = mne.events_from_annotations(raw_haemo) fig = mne.viz.plot_events(events, event_id=event_dict, sfreq=raw_haemo.info["sfreq"]) -fig.subplots_adjust(right=0.7) # make room for the legend # %% @@ -238,7 +236,7 @@ # pairs that we selected. All the channels in this data are located over the # motor cortex, and all channels show a similar pattern in the data. -fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(15, 6)) +fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(15, 6), layout="constrained") clims = dict(hbo=[-20, 20], hbr=[-20, 20]) epochs["Control"].average().plot_image(axes=axes[:, 0], clim=clims) epochs["Tapping"].average().plot_image(axes=axes[:, 1], clim=clims) @@ -308,7 +306,11 @@ # And we can plot the comparison at a single time point for two conditions. fig, axes = plt.subplots( - nrows=2, ncols=4, figsize=(9, 5), gridspec_kw=dict(width_ratios=[1, 1, 1, 0.1]) + nrows=2, + ncols=4, + figsize=(9, 5), + gridspec_kw=dict(width_ratios=[1, 1, 1, 0.1]), + layout="constrained", ) vlim = (-8, 8) ts = 9.0 @@ -341,13 +343,12 @@ for column, condition in enumerate(["Tapping Left", "Tapping Right", "Left-Right"]): for row, chroma in enumerate(["HbO", "HbR"]): axes[row, column].set_title("{}: {}".format(chroma, condition)) -fig.tight_layout() # %% # Lastly, we can also look at the individual waveforms to see what is # driving the topographic plot above. -fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(6, 4)) +fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(6, 4), layout="constrained") mne.viz.plot_evoked_topo( epochs["Left"].average(picks="hbo"), color="b", axes=axes, legend=False ) diff --git a/tutorials/preprocessing/80_opm_processing.py b/tutorials/preprocessing/80_opm_processing.py index 7c76499fd36..a8d30c12abd 100644 --- a/tutorials/preprocessing/80_opm_processing.py +++ b/tutorials/preprocessing/80_opm_processing.py @@ -57,7 +57,7 @@ data_ds, time_ds = raw[picks[::5], :stop] data_ds, time_ds = data_ds[:, ::step] * amp_scale, time_ds[::step] -fig, ax = plt.subplots(constrained_layout=True) +fig, ax = plt.subplots(layout="constrained") plot_kwargs = dict(lw=1, alpha=0.5) ax.plot(time_ds, data_ds.T - np.mean(data_ds, axis=1), **plot_kwargs) ax.grid(True) @@ -111,7 +111,7 @@ data_ds, _ = raw[picks[::5], :stop] data_ds = data_ds[:, ::step] * amp_scale -fig, ax = plt.subplots(constrained_layout=True) +fig, ax = plt.subplots(layout="constrained") ax.plot(time_ds, data_ds.T - np.mean(data_ds, axis=1), **plot_kwargs) ax.grid(True, ls=":") ax.set(title="After reference regression", **set_kwargs) @@ -139,7 +139,7 @@ data_ds, _ = raw[picks[::5], :stop] data_ds = data_ds[:, ::step] * amp_scale -fig, ax = plt.subplots(constrained_layout=True) +fig, ax = plt.subplots(layout="constrained") ax.plot(time_ds, data_ds.T - np.mean(data_ds, axis=1), **plot_kwargs) ax.grid(True, ls=":") ax.set(title="After HFC", **set_kwargs) @@ -168,7 +168,7 @@ shielding = 10 * np.log10(psd_pre[:] / psd_post_reg[:]) -fig, ax = plt.subplots(constrained_layout=True) +fig, ax = plt.subplots(layout="constrained") ax.plot(psd_post_reg.freqs, shielding.T, **plot_kwargs) ax.grid(True, ls=":") ax.set(xticks=psd_post_reg.freqs) @@ -182,7 +182,7 @@ shielding = 10 * np.log10(psd_pre[:] / psd_post_hfc[:]) -fig, ax = plt.subplots(constrained_layout=True) +fig, ax = plt.subplots(layout="constrained") ax.plot(psd_post_hfc.freqs, shielding.T, **plot_kwargs) ax.grid(True, ls=":") ax.set(xticks=psd_post_hfc.freqs) @@ -215,7 +215,7 @@ # plot data_ds, _ = raw[picks[::5], :stop] data_ds = data_ds[:, ::step] * amp_scale -fig, ax = plt.subplots(constrained_layout=True) +fig, ax = plt.subplots(layout="constrained") plot_kwargs = dict(lw=1, alpha=0.5) ax.plot(time_ds, data_ds.T - np.mean(data_ds, axis=1), **plot_kwargs) ax.grid(True) diff --git a/tutorials/raw/20_event_arrays.py b/tutorials/raw/20_event_arrays.py index 6fedcfe0ade..e13b1f361a7 100644 --- a/tutorials/raw/20_event_arrays.py +++ b/tutorials/raw/20_event_arrays.py @@ -158,7 +158,6 @@ fig = mne.viz.plot_events( events, sfreq=raw.info["sfreq"], first_samp=raw.first_samp, event_id=event_dict ) -fig.subplots_adjust(right=0.7) # make room for legend # %% # Plotting events and raw data together diff --git a/tutorials/simulation/80_dics.py b/tutorials/simulation/80_dics.py index b8efcad9319..951671df1e4 100644 --- a/tutorials/simulation/80_dics.py +++ b/tutorials/simulation/80_dics.py @@ -99,7 +99,7 @@ def coh_signal_gen(): signal1 = coh_signal_gen() signal2 = coh_signal_gen() -fig, axes = plt.subplots(2, 2, figsize=(8, 4)) +fig, axes = plt.subplots(2, 2, figsize=(8, 4), layout="constrained") # Plot the timeseries ax = axes[0][0] @@ -133,7 +133,6 @@ def coh_signal_gen(): ylabel="Coherence", title="Coherence between the timeseries", ) -fig.tight_layout() # %% # Now we put the signals at two locations on the cortex. We construct a diff --git a/tutorials/stats-sensor-space/10_background_stats.py b/tutorials/stats-sensor-space/10_background_stats.py index 066ab249121..412715b3042 100644 --- a/tutorials/stats-sensor-space/10_background_stats.py +++ b/tutorials/stats-sensor-space/10_background_stats.py @@ -76,7 +76,7 @@ # %% # The data averaged over all subjects looks like this: -fig, ax = plt.subplots() +fig, ax = plt.subplots(layout="constrained") ax.imshow(X.mean(0), cmap="inferno") ax.set(xticks=[], yticks=[], title="Data averaged over subjects") @@ -121,7 +121,7 @@ def plot_t_p(t, p, title, mcc, axes=None): if axes is None: - fig = plt.figure(figsize=(6, 3)) + fig = plt.figure(figsize=(6, 3), layout="constrained") axes = [fig.add_subplot(121, projection="3d"), fig.add_subplot(122)] show = True else: @@ -150,7 +150,7 @@ def plot_t_p(t, p, title, mcc, axes=None): xticks=[], yticks=[], zticks=[], xlim=[0, width - 1], ylim=[0, width - 1] ) axes[0].view_init(30, 15) - cbar = plt.colorbar( + cbar = axes[0].figure.colorbar( ax=axes[0], shrink=0.75, orientation="horizontal", @@ -172,7 +172,7 @@ def plot_t_p(t, p, title, mcc, axes=None): use_p, cmap="inferno", vmin=p_lims[0], vmax=p_lims[1], interpolation="nearest" ) axes[1].set(xticks=[], yticks=[]) - cbar = plt.colorbar( + cbar = axes[1].figure.colorbar( ax=axes[1], shrink=0.75, orientation="horizontal", @@ -188,8 +188,6 @@ def plot_t_p(t, p, title, mcc, axes=None): text = fig.suptitle(title) if mcc: text.set_weight("bold") - plt.subplots_adjust(0, 0.05, 1, 0.9, wspace=0, hspace=0) - mne.viz.utils.plt_show() plot_t_p(ts[-1], ps[-1], titles[-1], mccs[-1]) @@ -286,7 +284,7 @@ def plot_t_p(t, p, title, mcc, axes=None): N = np.arange(1, 80) alpha = 0.05 p_type_I = 1 - (1 - alpha) ** N -fig, ax = plt.subplots(figsize=(4, 3)) +fig, ax = plt.subplots(figsize=(4, 3), layout="constrained") ax.scatter(N, p_type_I, 3) ax.set( xlim=N[[0, -1]], @@ -295,7 +293,6 @@ def plot_t_p(t, p, title, mcc, axes=None): ylabel="Probability of at least\none type I error", ) ax.grid(True) -fig.tight_layout() fig.show() # %% @@ -612,7 +609,7 @@ def plot_t_p(t, p, title, mcc, axes=None): # and the bottom shows p-values for various statistical tests, with the ones # with proper control over FWER or FDR with bold titles. -fig = plt.figure(facecolor="w", figsize=(14, 3)) +fig = plt.figure(facecolor="w", figsize=(14, 3), layout="constrained") assert len(ts) == len(titles) == len(ps) for ii in range(len(ts)): ax = [ @@ -620,8 +617,6 @@ def plot_t_p(t, p, title, mcc, axes=None): fig.add_subplot(2, 10, 11 + ii), ] plot_t_p(ts[ii], ps[ii], titles[ii], mccs[ii], ax) -fig.tight_layout(pad=0, w_pad=0.05, h_pad=0.1) -plt.show() # %% # The first three columns show the parametric and non-parametric statistics diff --git a/tutorials/stats-sensor-space/40_cluster_1samp_time_freq.py b/tutorials/stats-sensor-space/40_cluster_1samp_time_freq.py index a43fdfd46aa..cf49f48ddf4 100644 --- a/tutorials/stats-sensor-space/40_cluster_1samp_time_freq.py +++ b/tutorials/stats-sensor-space/40_cluster_1samp_time_freq.py @@ -235,8 +235,7 @@ evoked_data = evoked.data times = 1e3 * evoked.times -plt.figure() -plt.subplots_adjust(0.12, 0.08, 0.96, 0.94, 0.2, 0.43) +fig, (ax, ax2) = plt.subplots(2, layout="constrained") T_obs_plot = np.nan * np.ones_like(T_obs) for c, p_val in zip(clusters, cluster_p_values): @@ -252,8 +251,7 @@ vmax = np.max(np.abs(T_obs)) vmin = -vmax -plt.subplot(2, 1, 1) -plt.imshow( +ax.imshow( T_obs[ch_idx], cmap=plt.cm.gray, extent=[times[0], times[-1], freqs[0], freqs[-1]], @@ -262,7 +260,7 @@ vmin=vmin, vmax=vmax, ) -plt.imshow( +ax.imshow( T_obs_plot[ch_idx], cmap=plt.cm.RdBu_r, extent=[times[0], times[-1], freqs[0], freqs[-1]], @@ -271,11 +269,8 @@ vmin=vmin, vmax=vmax, ) -plt.colorbar() -plt.xlabel("Time (ms)") -plt.ylabel("Frequency (Hz)") -plt.title(f"Induced power ({tfr_epochs.ch_names[ch_idx]})") +fig.colorbar(ax.images[0]) +ax.set(xlabel="Time (ms)", ylabel="Frequency (Hz)") +ax.set(title=f"Induced power ({tfr_epochs.ch_names[ch_idx]})") -ax2 = plt.subplot(2, 1, 2) evoked.plot(axes=[ax2], time_unit="s") -plt.show() diff --git a/tutorials/stats-sensor-space/50_cluster_between_time_freq.py b/tutorials/stats-sensor-space/50_cluster_between_time_freq.py index 6ef0eaf3de3..69bdbbc5d91 100644 --- a/tutorials/stats-sensor-space/50_cluster_between_time_freq.py +++ b/tutorials/stats-sensor-space/50_cluster_between_time_freq.py @@ -147,8 +147,7 @@ times = 1e3 * epochs_condition_1.times # change unit to ms -fig, (ax, ax2) = plt.subplots(2, 1, figsize=(6, 4)) -fig.subplots_adjust(0.12, 0.08, 0.96, 0.94, 0.2, 0.43) +fig, (ax, ax2) = plt.subplots(2, 1, figsize=(6, 4), layout="constrained") # Compute the difference in evoked to determine which was greater since # we used a 1-way ANOVA which tested for a difference in population means diff --git a/tutorials/stats-sensor-space/70_cluster_rmANOVA_time_freq.py b/tutorials/stats-sensor-space/70_cluster_rmANOVA_time_freq.py index 1dfcfc79f86..a57112bedc4 100644 --- a/tutorials/stats-sensor-space/70_cluster_rmANOVA_time_freq.py +++ b/tutorials/stats-sensor-space/70_cluster_rmANOVA_time_freq.py @@ -172,7 +172,7 @@ effect_labels = ["modality", "location", "modality by location"] -fig, axes = plt.subplots(3, 1, figsize=(6, 6)) +fig, axes = plt.subplots(3, 1, figsize=(6, 6), layout="constrained") # let's visualize our effects by computing f-images for effect, sig, effect_label, ax in zip(fvals, pvals, effect_labels, axes): @@ -198,8 +198,6 @@ ax.set_ylabel("Frequency (Hz)") ax.set_title(f'Time-locked response for "{effect_label}" ({ch_name})') -fig.tight_layout() - # %% # Account for multiple comparisons using FDR versus permutation clustering test # ----------------------------------------------------------------------------- @@ -250,7 +248,7 @@ def stat_fun(*args): F_obs_plot = F_obs.copy() F_obs_plot[~clusters[np.squeeze(good_clusters)]] = np.nan -fig, ax = plt.subplots(figsize=(6, 4)) +fig, ax = plt.subplots(figsize=(6, 4), layout="constrained") for f_image, cmap in zip([F_obs, F_obs_plot], ["gray", "autumn"]): c = ax.imshow( f_image, @@ -267,7 +265,6 @@ def stat_fun(*args): f'Time-locked response for "modality by location" ({ch_name})\n' "cluster-level corrected (p <= 0.05)" ) -fig.tight_layout() # %% # Now using FDR: @@ -276,7 +273,7 @@ def stat_fun(*args): F_obs_plot2 = F_obs.copy() F_obs_plot2[~mask.reshape(F_obs_plot.shape)] = np.nan -fig, ax = plt.subplots(figsize=(6, 4)) +fig, ax = plt.subplots(figsize=(6, 4), layout="constrained") for f_image, cmap in zip([F_obs, F_obs_plot2], ["gray", "autumn"]): c = ax.imshow( f_image, @@ -293,7 +290,6 @@ def stat_fun(*args): f'Time-locked response for "modality by location" ({ch_name})\n' "FDR corrected (p <= 0.05)" ) -fig.tight_layout() # %% # Both cluster-level and FDR correction help get rid of potential diff --git a/tutorials/stats-sensor-space/75_cluster_ftest_spatiotemporal.py b/tutorials/stats-sensor-space/75_cluster_ftest_spatiotemporal.py index db6505fbafe..7a3234c5346 100644 --- a/tutorials/stats-sensor-space/75_cluster_ftest_spatiotemporal.py +++ b/tutorials/stats-sensor-space/75_cluster_ftest_spatiotemporal.py @@ -199,7 +199,7 @@ mask[ch_inds, :] = True # initialize figure - fig, ax_topo = plt.subplots(1, 1, figsize=(10, 3)) + fig, ax_topo = plt.subplots(1, 1, figsize=(10, 3), layout="constrained") # plot average test statistic and mark significant sensors f_evoked = mne.EvokedArray(f_map[:, np.newaxis], epochs.info, tmin=0) @@ -251,10 +251,7 @@ (ymin, ymax), sig_times[0], sig_times[-1], color="orange", alpha=0.3 ) - # clean up viz - mne.viz.tight_layout(fig=fig) - fig.subplots_adjust(bottom=0.05) - plt.show() +plt.show() # %% # Permutation statistic for time-frequencies @@ -352,7 +349,7 @@ sig_times = epochs.times[time_inds] # initialize figure - fig, ax_topo = plt.subplots(1, 1, figsize=(10, 3)) + fig, ax_topo = plt.subplots(1, 1, figsize=(10, 3), layout="constrained") # create spatial mask mask = np.zeros((f_map.shape[0], 1), dtype=bool) @@ -414,9 +411,7 @@ ax_colorbar2.set_ylabel("F-stat") # clean up viz - mne.viz.tight_layout(fig=fig) - fig.subplots_adjust(bottom=0.05) - plt.show() +plt.show() # %% diff --git a/tutorials/time-freq/20_sensors_time_frequency.py b/tutorials/time-freq/20_sensors_time_frequency.py index 776a230ecad..07a31e99db5 100644 --- a/tutorials/time-freq/20_sensors_time_frequency.py +++ b/tutorials/time-freq/20_sensors_time_frequency.py @@ -209,7 +209,7 @@ power.plot_topo(baseline=(-0.5, 0), mode="logratio", title="Average power") power.plot([82], baseline=(-0.5, 0), mode="logratio", title=power.ch_names[82]) -fig, axes = plt.subplots(1, 2, figsize=(7, 4), constrained_layout=True) +fig, axes = plt.subplots(1, 2, figsize=(7, 4), layout="constrained") topomap_kw = dict( ch_type="grad", tmin=0.5, tmax=1.5, baseline=(-0.5, 0), mode="logratio", show=False )