From a10cec34cde0d12c6889f76fff0418698a415285 Mon Sep 17 00:00:00 2001 From: Nabil Freij Date: Wed, 7 Feb 2024 13:55:27 -0800 Subject: [PATCH 1/2] Added Gallery --- .circleci/config.yml | 3 +- .github/workflows/ci.yml | 20 +- .pre-commit-config.yaml | 109 +++------ .readthedocs.yml | 1 - CHANGES.md | 1 + docs/conf.py | 93 +++----- examples/README.txt | 6 + examples/arrayanimatorwcs.py | 87 +++++++ examples/lineanimator.py | 52 ++++ mpl_animators/__init__.py | 11 +- mpl_animators/base.py | 236 ++++++++++--------- mpl_animators/extern/modest_image.py | 186 ++++++++------- mpl_animators/image.py | 22 +- mpl_animators/line.py | 44 ++-- mpl_animators/tests/helpers.py | 17 +- mpl_animators/tests/test_basefuncanimator.py | 102 ++++---- mpl_animators/tests/test_wcs.py | 130 +++++----- mpl_animators/wcs.py | 104 ++++---- pyproject.toml | 77 +++++- pytest.ini | 27 +++ ruff.toml | 80 +++++++ setup.cfg | 72 ------ setup.py | 20 +- tox.ini | 5 +- 24 files changed, 866 insertions(+), 639 deletions(-) create mode 100644 examples/README.txt create mode 100644 examples/arrayanimatorwcs.py create mode 100644 examples/lineanimator.py create mode 100644 pytest.ini create mode 100644 ruff.toml delete mode 100644 setup.cfg diff --git a/.circleci/config.yml b/.circleci/config.yml index 93af1f0..aa07c87 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -37,8 +37,7 @@ jobs: type: string docker: - image: cimg/python:3.11.6 - environment: - TOXENV=<< parameters.jobname >> + environment: TOXENV=<< parameters.jobname >> steps: - run: *no-backports - checkout diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b8e8544..335556c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -3,14 +3,14 @@ name: CI on: push: branches: - - 'main' - - '*.*' - - '!*backport*' + - "main" + - "*.*" + - "!*backport*" tags: - - 'v*' - - '!*dev*' - - '!*pre*' - - '!*post*' + - "v*" + - "!*dev*" + - "!*pre*" + - "!*post*" pull_request: # Allow manual runs through the web UI workflow_dispatch: @@ -50,7 +50,7 @@ jobs: needs: [core] uses: OpenAstronomy/github-actions-workflows/.github/workflows/tox.yml@main with: - default_python: '3.9' + default_python: "3.9" submodules: false pytest: false toxdeps: tox-pypi-filter @@ -65,7 +65,7 @@ jobs: needs: [test] uses: OpenAstronomy/github-actions-workflows/.github/workflows/tox.yml@main with: - default_python: '3.9' + default_python: "3.9" submodules: false coverage: codecov toxdeps: tox-pypi-filter @@ -93,7 +93,7 @@ jobs: uses: OpenAstronomy/github-actions-workflows/.github/workflows/publish_pure_python.yml@main with: python-version: "3.10" - test_extras: 'all,tests' + test_extras: "all,tests" test_command: 'pytest -p no:warnings -m "not mpl_image_compare" --pyargs mpl_animators' submodules: false secrets: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7887a4f..685b486 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,93 +1,44 @@ repos: - # The warnings/errors we check for here are: - # E101 - mix of tabs and spaces - # E11 - Fix indentation. - # E111 - 4 spaces per indentation level - # E112 - 4 spaces per indentation level - # E113 - 4 spaces per indentation level - # E121 - Fix indentation to be a multiple of four. - # E122 - Add absent indentation for hanging indentation. - # E123 - Align closing bracket to match opening bracket. - # E124 - Align closing bracket to match visual indentation. - # E125 - Indent to distinguish line from next logical line. - # E126 - Fix over-indented hanging indentation. - # E127 - Fix visual indentation. - # E128 - Fix visual indentation. - # E129 - Fix visual indentation. - # E131 - Fix hanging indent for unaligned continuation line. - # E133 - Fix missing indentation for closing bracket. - # E20 - Remove extraneous whitespace. - # E211 - Remove extraneous whitespace. - # E231 - Add missing whitespace. - # E241 - Fix extraneous whitespace around keywords. - # E242 - Remove extraneous whitespace around operator. - # E251 - Remove whitespace around parameter '=' sign. - # E252 - Missing whitespace around parameter equals. - # E26 - Fix spacing after comment hash for inline comments. - # E265 - Fix spacing after comment hash for block comments. - # E266 - Fix too many leading '#' for block comments. - # E27 - Fix extraneous whitespace around keywords. - # E301 - Add missing blank line. - # E302 - Add missing 2 blank lines. - # E303 - Remove extra blank lines. - # E304 - Remove blank line following function decorator. - # E305 - expected 2 blank lines after class or function definition - # E305 - Expected 2 blank lines after end of function or class. - # E306 - expected 1 blank line before a nested definition - # E306 - Expected 1 blank line before a nested definition. - # E401 - Put imports on separate lines. - # E402 - Fix module level import not at top of file - # E502 - Remove extraneous escape of newline. - # E701 - Put colon-separated compound statement on separate lines. - # E711 - Fix comparison with None. - # E712 - Fix comparison with boolean. - # E713 - Use 'not in' for test for membership. - # E714 - Use 'is not' test for object identity. - # E722 - Fix bare except. - # E731 - Use a def when use do not assign a lambda expression. - # E901 - SyntaxError or IndentationError - # E902 - IOError - # F822 - undefined name in __all__ - # F823 - local variable name referenced before assignment - # W291 - Remove trailing whitespace. - # W292 - Add a single newline at the end of the file. - # W293 - Remove trailing whitespace on blank line. - # W391 - Remove trailing blank lines. - # W601 - Use "in" rather than "has_key()". - # W602 - Fix deprecated form of raising exception. - # W603 - Use "!=" instead of "<>" - # W604 - Use "repr()" instead of backticks. - # W605 - Fix invalid escape sequence 'x'. - # W690 - Fix various deprecated code (via lib2to3). - - repo: https://github.com/PyCQA/flake8 - rev: 7.0.0 + - repo: https://github.com/myint/docformatter + rev: v1.7.5 hooks: - - id: flake8 - args: ['--count', '--select', 'E101,E11,E111,E112,E113,E121,E122,E123,E124,E125,E126,E127,E128,E129,E131,E133,E20,E211,E231,E241,E242,E251,E252,E26,E265,E266,E27,E301,E302,E303,E304,E305,E306,E401,E402,E502,E701,E711,E712,E713,E714,E722,E731,E901,E902,F822,F823,W191,W291,W292,W293,W391,W601,W602,W603,W604,W605,W690'] - exclude: ".*(.fits|.fts|.fit|.txt|tca.*|extern.*|.rst|.md|mpl_animators/extern|docs/conf.py)$" - - repo: https://github.com/PyCQA/autoflake + - id: docformatter + args: ["--in-place", "--pre-summary-newline", "--make-summary-multi"] + - repo: https://github.com/myint/autoflake rev: v2.2.1 hooks: - id: autoflake - args: ['--in-place', '--remove-all-unused-imports', '--remove-unused-variable'] - exclude: ".*(.fits|.fts|.fit|.txt|tca.*|extern.*|.rst|.md|__init__.py|mpl_animators/extern|docs/conf.py)$" - - repo: https://github.com/PyCQA/isort - rev: 5.13.2 + args: + [ + "--in-place", + "--remove-all-unused-imports", + "--remove-unused-variable", + ] + exclude: ".*(.fits|.fts|.fit|.txt|tca.*|.*extern.*|.rst|.md)$" + - repo: https://github.com/charliermarsh/ruff-pre-commit + rev: "v0.2.1" hooks: - - id: isort - args: ['--sp','setup.cfg'] - exclude: ".*(.fits|.fts|.fit|.txt|tca.*|extern.*|.rst|.md|mpl_animators/extern|docs/conf.py)$" + - id: ruff + args: ["--fix", "--unsafe-fixes"] + exclude: ".*(.fits|.fts|.fit|.txt|tca.*|.*extern.*|.rst|.md)$" + - id: ruff-format + exclude: ".*(.fits|.fts|.fit|.txt|tca.*|.*extern.*|.rst|.md)$" - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.5.0 hooks: - id: check-ast - id: check-case-conflict - id: trailing-whitespace - exclude: ".*(.fits|.fts|.fit|.txt)$" + exclude: ".*(.fits|.fts|.fit|.txt|.csv)$" + - id: mixed-line-ending + exclude: ".*(.fits|.fts|.fit|.txt|.csv)$" + - id: end-of-file-fixer + exclude: ".*(.fits|.fts|.fit|.txt|.csv)$" - id: check-yaml - id: debug-statements - - id: check-added-large-files - - id: end-of-file-fixer - exclude: ".*(.fits|.fts|.fit|.txt|tca.*)$" - - id: mixed-line-ending - exclude: ".*(.fits|.fts|.fit|.txt|tca.*)$" + - repo: https://github.com/codespell-project/codespell + rev: v2.2.6 + hooks: + - id: codespell + additional_dependencies: + - tomli diff --git a/.readthedocs.yml b/.readthedocs.yml index 1f7666d..ea5ed1a 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -7,7 +7,6 @@ build: tools: python: "3.10" jobs: - post_checkout: - git fetch --unshallow || true pre_install: diff --git a/CHANGES.md b/CHANGES.md index 50bd8e4..9102150 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -3,6 +3,7 @@ ## v1.1.1 - 2023-11-17 + ### What's Changed #### Other Changes diff --git a/docs/conf.py b/docs/conf.py index 2d31b11..2da588c 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,63 +1,35 @@ -# -*- coding: utf-8 -*- -# # Configuration file for the Sphinx documentation builder. -# -# This file does only contain a selection of the most common options. For a -# full list see the documentation: -# http://www.sphinx-doc.org/en/master/config +import datetime +from pathlib import Path +from sunpy_sphinx_theme import PNG_ICON # -- Project information ----------------------------------------------------- - -project = 'mpl-animators' -copyright = '2021, The SunPy Developers' -author = 'The SunPy Developers' - -# The full version, including alpha/beta/rc tags -from mpl_animators import __version__ -release = __version__ +project = "mpl-animators" +author = "The SunPy Community" +copyright = f"{datetime.datetime.now(datetime.timezone.utc).year}, {author}" # NOQA: A001 +author = "The SunPy Developers" # -- General configuration --------------------------------------------------- - -# Add any Sphinx extension module names here, as strings. They can be -# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom -# ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.intersphinx', - 'sphinx.ext.todo', - 'sphinx.ext.coverage', - 'sphinx.ext.inheritance_diagram', - 'sphinx.ext.viewcode', - 'sphinx.ext.napoleon', - 'sphinx.ext.doctest', - 'sphinx.ext.mathjax', - 'sphinx_automodapi.automodapi', - 'sphinx_automodapi.smart_resolver', + "sphinx.ext.autodoc", + "sphinx.ext.intersphinx", + "sphinx.ext.todo", + "sphinx.ext.coverage", + "sphinx.ext.inheritance_diagram", + "sphinx.ext.viewcode", + "sphinx.ext.napoleon", + "sphinx.ext.doctest", + "sphinx.ext.mathjax", + "sphinx_automodapi.automodapi", + "sphinx_automodapi.smart_resolver", ] - -# Add any paths that contain templates here, relative to this directory. -# templates_path = ['_templates'] - -# List of patterns, relative to source directory, that match files and -# directories to ignore when looking for source files. -# This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] - -# The suffix(es) of source filenames. -# You can specify multiple suffix as a list of string: -source_suffix = '.rst' - -# The master toctree document. -master_doc = 'index' - -# The reST default role (used for this markup: `text`) to use for all -# documents. Set to the "smart" one. -default_role = 'obj' +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] +source_suffix = ".rst" +master_doc = "index" +default_role = "obj" # -- Options for intersphinx extension --------------------------------------- - -# Example configuration for intersphinx: refer to the Python standard library. intersphinx_mapping = { "python": ( "https://docs.python.org/3/", @@ -74,11 +46,16 @@ "astropy": ("https://docs.astropy.org/en/stable/", None), } -# -- Options for HTML output ------------------------------------------------- - -from sunpy_sphinx_theme.conf import * # NOQA - -# Add any paths that contain custom static files (such as style sheets) here, -# relative to this directory. They are copied after the builtin static files, -# so a file named "default.css" will overwrite the builtin "default.css". -# html_static_path = ['_static'] +# -- Sphinx Gallery ------------------------------------------------------------ +sphinx_gallery_conf = { + "backreferences_dir": Path("generated") / "modules", + "filename_pattern": "^((?!skip_).)*$", + "examples_dirs": Path("..") / "examples", + "gallery_dirs": Path("generated") / "gallery", + "matplotlib_animations": True, + "default_thumb_file": PNG_ICON, + "abort_on_example_error": False, + "plot_gallery": "True", + "remove_config_comments": True, + "only_warn_on_example_error": True, +} diff --git a/examples/README.txt b/examples/README.txt new file mode 100644 index 0000000..fdc4456 --- /dev/null +++ b/examples/README.txt @@ -0,0 +1,6 @@ +*************** +Example Gallery +*************** + +The gallery contains examples of how to use mpl-animators. +Each example is a short and self contained how-to guide for performing a specific task. diff --git a/examples/arrayanimatorwcs.py b/examples/arrayanimatorwcs.py new file mode 100644 index 0000000..7d685fe --- /dev/null +++ b/examples/arrayanimatorwcs.py @@ -0,0 +1,87 @@ +""" +============================================== +Creating a visualization with ArrayAnimatorWCS +============================================== + +This example shows how to create a simple visualization using +`~mpl_animators.ArrayAnimatorWCS`. +""" +import astropy.units as u +import astropy.wcs +import matplotlib.pyplot as plt +import sunpy.map +from astropy.visualization import AsinhStretch, ImageNormalize +from sunpy.data.sample import AIA_171_IMAGE, AIA_193_IMAGE +from sunpy.time import parse_time + +from mpl_animators import ArrayAnimatorWCS + +################################################################################ +# To showcase how to visualize a sequence of 2D images using +# `~mpl_animators.ArrayAnimatorWCS`, we will use images from +# our sample data. The problem with this is that they are not part of +# a continuous dataset. To overcome this we will do two things. +# Create a stacked array of the images and create a `~astropy.wcs.WCS` header. +# The easiest method for the array is to create a `~sunpy.map.MapSequence`. + +# Here we only use two files but you could pass in a larger selection of files. +map_sequence = sunpy.map.Map(AIA_171_IMAGE, AIA_193_IMAGE, sequence=True) + +# Now we can just cast the sequence away into a NumPy array. +sequence_array = map_sequence.as_array() + +# We'll also define a common normalization to use in the animations +norm = ImageNormalize(vmin=0, vmax=3e4, stretch=AsinhStretch(0.01)) + +############################################################################### +# Now we need to create the `~astropy.wcs.WCS` header that +# `~mpl_animators.ArrayAnimatorWCS` will need. +# To create the new header we can use the stored meta information from the +# ``map_sequence``. + +# Now we need to get the time difference between the two observations. +t0, t1 = map(parse_time, [k["date-obs"] for k in map_sequence.all_meta()]) +time_diff = (t1 - t0).to(u.s) + +m = map_sequence[0] + +wcs = astropy.wcs.WCS(naxis=3) +wcs.wcs.crpix = u.Quantity([0 * u.pix, *list(m.reference_pixel)]) +wcs.wcs.cdelt = [time_diff.value, *list(u.Quantity(m.scale).value)] +wcs.wcs.crval = [0, m._reference_longitude.value, m._reference_latitude.value] +wcs.wcs.ctype = ["TIME", *list(m.coordinate_system)] +wcs.wcs.cunit = ["s", *list(m.spatial_units)] +wcs.wcs.aux.rsun_ref = m.rsun_meters.to_value(u.m) + +# Now the resulting WCS object will look like: +print(wcs) + +############################################################################### +# Now we can create the animation. +# `~mpl_animators.ArrayAnimatorWCS` requires you to select which +# axes you want to plot on the image. All other axes should have a ``0`` and +# sliders will be created to control the value for this axis. + +wcs_anim = ArrayAnimatorWCS(sequence_array, wcs, [0, "x", "y"], norm=norm).get_animation() + +plt.show() + +############################################################################### +# You might notice that the animation could do with having the axes look +# neater. `~mpl_animators.ArrayAnimatorWCS` provides a way of setting +# some display properties of the `~astropy.visualization.wcsaxes.WCSAxes` +# object on every frame of the animation via use of the ``coord_params`` dict. +# They keys of the ``coord_params`` dict are either the first half of the +# ``CTYPE`` key, the whole ``CTYPE`` key or the entries in +# ``wcs.world_axis_physical_types`` here we use the short ctype identifiers for +# the latitude and longitude axes. + +coord_params = { + "hpln": {"axislabel": "Helioprojective Longitude", "ticks": {"spacing": 10 * u.arcmin, "color": "black"}}, + "hplt": {"axislabel": "Helioprojective Latitude", "ticks": {"spacing": 10 * u.arcmin, "color": "black"}}, +} + +# We have to recreate the visualization since we displayed it earlier. +wcs_anim = ArrayAnimatorWCS(sequence_array, wcs, [0, "x", "y"], norm=norm, coord_params=coord_params).get_animation() + +plt.show() diff --git a/examples/lineanimator.py b/examples/lineanimator.py new file mode 100644 index 0000000..b7607aa --- /dev/null +++ b/examples/lineanimator.py @@ -0,0 +1,52 @@ +""" +=========================== +How to use the LineAnimator +=========================== + +This example shows off some ways in which you can use the +LineAnimator object to animate line plots. +""" +import matplotlib.pyplot as plt +import numpy as np + +from mpl_animators import LineAnimator + +############################################################################### +# Animate a 2D cube of random data as a line plot along an +# axis where the x-axis drifts with time. + +# Define some random data +data_shape0 = (10, 20) +rng = np.random.default_rng() +data0 = rng.random(data_shape0) + +############################################################################### +# Define the axis that will make up the line plot. + +plot_axis0 = 1 +slider_axis0 = 0 + +############################################################################### +# Let's customize the values along the x-axis. To do this, we must define the +# edges of the pixels/bins being plotted along the x-axis. This requires us to +# supply an array, say xdata, of length equal to data.shape[plot_axis_index]+1. +# In this example, the data has a shape of (10, 20) and let's say we are +# iterating through the 0th axis and plotting the 1st axis, +# i.e. plot_axis_index=1. Therefore we need to define an xdata array of length +# 21. +# This will give the same customized x-axis values for each frame of the +# animation. However, what if we want the x-axis values to change as we +# animate through the other dimensions of the cube? To do this we supply a +# (10, 21) xdata where each row (i.e. xdata[i, :]) gives the pixel/bin edges +# along the x-axis for the of the i-th frame of the animation. Note that this +# API extends in the same way to higher dimension. In our 2D case here though, +# we can define our non-constant x-axis values like so: + +xdata = np.tile(np.linspace(0, 100, (data_shape0[plot_axis0] + 1)), (data_shape0[slider_axis0], 1)) + +############################################################################### +# Generate animation object with variable x-axis data. + +ani = LineAnimator(data0, plot_axis_index=plot_axis0, axis_ranges=[None, xdata]).get_animation() + +plt.show() diff --git a/mpl_animators/__init__.py b/mpl_animators/__init__.py index 27b74a9..4d5ebb9 100644 --- a/mpl_animators/__init__.py +++ b/mpl_animators/__init__.py @@ -1,7 +1,8 @@ -# Licensed under a 3-clause BSD style license - see LICENSE.rst -from mpl_animators.base import * -from mpl_animators.image import * -from mpl_animators.line import * -from mpl_animators.wcs import * +from mpl_animators.base import ArrayAnimator, BaseFuncAnimator +from mpl_animators.image import ImageAnimator +from mpl_animators.line import LineAnimator +from mpl_animators.wcs import ArrayAnimatorWCS from .version import __version__ + +__all__ = ["ArrayAnimator", "BaseFuncAnimator", "LineAnimator", "ArrayAnimatorWCS", "ImageAnimator", "__version__"] diff --git a/mpl_animators/base.py b/mpl_animators/base.py index 110e8c3..a903861 100644 --- a/mpl_animators/base.py +++ b/mpl_animators/base.py @@ -13,7 +13,7 @@ except ImportError: units = None -__all__ = ['BaseFuncAnimator', 'ArrayAnimator'] +__all__ = ["BaseFuncAnimator", "ArrayAnimator"] class BaseFuncAnimator(metaclass=abc.ABCMeta): @@ -79,10 +79,19 @@ class BaseFuncAnimator(metaclass=abc.ABCMeta): Extra keywords are passed to `matplotlib.pyplot.imshow`. """ - def __init__(self, data, slider_functions, slider_ranges, fig=None, - interval=200, colorbar=False, button_func=None, button_labels=None, - start_image_func=None, slider_labels=None, **kwargs): - + def __init__( + self, + data, + slider_functions, + slider_ranges, + fig=None, + interval=200, + colorbar=False, + button_func=None, + button_labels=None, + slider_labels=None, + **kwargs, + ): # Allow the user to specify the button func: self.button_func = button_func or [] if button_func and not button_labels: @@ -100,16 +109,17 @@ def __init__(self, data, slider_functions, slider_ranges, fig=None, self.imshow_kwargs = kwargs if len(slider_functions) != len(slider_ranges): - raise ValueError("slider_functions and slider_ranges must be the same length.") + msg = "slider_functions and slider_ranges must be the same length." + raise ValueError(msg) - if slider_labels is not None: - if len(slider_labels) != len(slider_functions): - raise ValueError("slider_functions and slider_labels must be the same length.") + if slider_labels is not None and len(slider_labels) != len(slider_functions): + msg = "slider_functions and slider_labels must be the same length." + raise ValueError(msg) self.num_sliders = len(slider_functions) self.slider_functions = slider_functions self.slider_ranges = slider_ranges - self.slider_labels = slider_labels or [''] * len(slider_functions) + self.slider_labels = slider_labels or [""] * len(slider_functions) # Set active slider self.active_slider = 0 @@ -149,8 +159,7 @@ def label_slider(self, i, label): """ self.sliders[i]._slider.label.set_text(label) - def get_animation(self, axes=None, slider=0, startframe=0, endframe=None, - stepframe=1, **kwargs): + def get_animation(self, axes=None, slider=0, startframe=0, endframe=None, stepframe=1, **kwargs): """ Return a `~matplotlib.animation.FuncAnimation` instance for the selected slider. @@ -186,14 +195,13 @@ def get_animation(self, axes=None, slider=0, startframe=0, endframe=None, im = self.plot_start_image(axes) - anim_kwargs = {'frames': list(range(startframe, endframe, stepframe)), - 'fargs': [im, self.sliders[slider]._slider]} + anim_kwargs = { + "frames": list(range(startframe, endframe, stepframe)), + "fargs": [im, self.sliders[slider]._slider], + } anim_kwargs.update(kwargs) - ani = mplanim.FuncAnimation(anim_fig, self.slider_functions[slider], - **anim_kwargs) - - return ani + return mplanim.FuncAnimation(anim_fig, self.slider_functions[slider], **anim_kwargs) @abc.abstractmethod def plot_start_image(self, ax): @@ -215,39 +223,41 @@ def plot_start_image(self, ax): `~matplotlib.image.AxesImage` object, or a `~matplotlib.lines.Line2D`. """ - raise NotImplementedError("Please define this function.") + msg = "Please define this function." + raise NotImplementedError(msg) def _connect_fig_events(self): - self.fig.canvas.mpl_connect('button_press_event', self._mouse_click) - self.fig.canvas.mpl_connect('key_press_event', self._key_press) + self.fig.canvas.mpl_connect("button_press_event", self._mouse_click) + self.fig.canvas.mpl_connect("key_press_event", self._key_press) def _add_colorbar(self, im): self.colorbar = self.fig.colorbar(im, self.cax) -# ============================================================================= -# Figure event callback functions -# ============================================================================= + # ============================================================================= + # Figure event callback functions + # ============================================================================= def _mouse_click(self, event): if event.inaxes in self.sliders: slider = self.sliders.index(event.inaxes) self._set_active_slider(slider) def _key_press(self, event): - if event.key == 'left': + if event.key == "left": self._previous(self.sliders[self.active_slider]._slider) - elif event.key == 'right': + elif event.key == "right": self._step(self.sliders[self.active_slider]._slider) - elif event.key == 'up': - self._set_active_slider((self.active_slider+1) % self.num_sliders) - elif event.key == 'down': - self._set_active_slider((self.active_slider-1) % self.num_sliders) - elif event.key == 'p': - self._click_slider_button(event, self.slider_buttons[self.active_slider]._button, - self.sliders[self.active_slider]._slider) - -# ============================================================================= -# Active Slider methods -# ============================================================================= + elif event.key == "up": + self._set_active_slider((self.active_slider + 1) % self.num_sliders) + elif event.key == "down": + self._set_active_slider((self.active_slider - 1) % self.num_sliders) + elif event.key == "p": + self._click_slider_button( + event, self.slider_buttons[self.active_slider]._button, self.sliders[self.active_slider]._slider + ) + + # ============================================================================= + # Active Slider methods + # ============================================================================= def _set_active_slider(self, ind): self._dehighlight_slider(self.active_slider) self._highlight_slider(ind) @@ -263,12 +273,13 @@ def _dehighlight_slider(self, ind): [a.set_linewidth(1.0) for n, a in ax.spines.items()] self.fig.canvas.draw() -# ============================================================================= -# Build the figure and place the widgets -# ============================================================================= + # ============================================================================= + # Build the figure and place the widgets + # ============================================================================= def _setup_main_axes(self): """ Allow replacement of main axes by subclassing. + This method must set the ``axes`` attribute. """ if self.axes is None: @@ -286,22 +297,22 @@ def _make_axes_grid(self): button_grid = max((7, self.num_buttons)) # Define size of useful axes cells, 50% each in x 20% for buttons in y. - ysize = Size.Fraction((1.-2.*pad)/15., Size.AxesY(self.axes)) - xsize = Size.Fraction((1.-2.*pad)/button_grid, Size.AxesX(self.axes)) + ysize = Size.Fraction((1.0 - 2.0 * pad) / 15.0, Size.AxesY(self.axes)) + xsize = Size.Fraction((1.0 - 2.0 * pad) / button_grid, Size.AxesX(self.axes)) # Set up grid, 3x3 with cells for padding. if self.num_buttons > 0: - horiz = [xsize] + [pad_size, xsize]*(button_grid-1) - vert = [ysize, pad_size] * self.num_sliders + \ - [large_pad_size, large_pad_size, Size.AxesY(self.axes)] + horiz = [xsize] + [pad_size, xsize] * (button_grid - 1) + vert = [ysize, pad_size] * self.num_sliders + [large_pad_size, large_pad_size, Size.AxesY(self.axes)] else: - vert = [ysize, large_pad_size] * self.num_sliders + \ - [large_pad_size, Size.AxesY(self.axes)] - horiz = [Size.Fraction(0.1, Size.AxesX(self.axes))] + \ - [Size.Fraction(0.05, Size.AxesX(self.axes))] + \ - [Size.Fraction(0.65, Size.AxesX(self.axes))] + \ - [Size.Fraction(0.1, Size.AxesX(self.axes))] + \ - [Size.Fraction(0.1, Size.AxesX(self.axes))] + vert = [ysize, large_pad_size] * self.num_sliders + [large_pad_size, Size.AxesY(self.axes)] + horiz = [ + Size.Fraction(0.1, Size.AxesX(self.axes)), + Size.Fraction(0.05, Size.AxesX(self.axes)), + Size.Fraction(0.65, Size.AxesX(self.axes)), + Size.Fraction(0.1, Size.AxesX(self.axes)), + Size.Fraction(0.1, Size.AxesX(self.axes)), + ] self.divider.set_horizontal(horiz) self.divider.set_vertical(vert) @@ -310,57 +321,58 @@ def _make_axes_grid(self): # If we are going to add a colorbar it'll need an axis next to the plot if self.if_colorbar: nx1 = -3 - self.cax = self.fig.add_axes((0., 0., 0.141, 1.)) - locator = self.divider.new_locator(nx=-2, ny=len(vert)-1, nx1=-1) + self.cax = self.fig.add_axes((0.0, 0.0, 0.141, 1.0)) + locator = self.divider.new_locator(nx=-2, ny=len(vert) - 1, nx1=-1) self.cax.set_axes_locator(locator) else: # Main figure spans all horiz and is in the top (2) in vert. nx1 = -1 - self.axes.set_axes_locator( - self.divider.new_locator(nx=0, ny=len(vert)-1, nx1=nx1)) + self.axes.set_axes_locator(self.divider.new_locator(nx=0, ny=len(vert) - 1, nx1=nx1)) def _add_widgets(self): self.buttons = [] - for i in range(0, self.num_buttons): + for i in range(self.num_buttons): x = i * 2 # The i+1/10. is a bug that if you make two axes directly on top of # one another then the divider doesn't work. - self.buttons.append(self.fig.add_axes((0., 0., 0.+i/10., 1.))) + self.buttons.append(self.fig.add_axes((0.0, 0.0, 0.0 + i / 10.0, 1.0))) locator = self.divider.new_locator(nx=x, ny=self.button_ny) self.buttons[-1].set_axes_locator(locator) - self.buttons[-1]._button = widgets.Button(self.buttons[-1], - self.button_labels[i]) + self.buttons[-1]._button = widgets.Button(self.buttons[-1], self.button_labels[i]) self.buttons[-1]._button.on_clicked(partial(self.button_func[i], self)) self.sliders = [] self.slider_buttons = [] for i in range(self.num_sliders): y = i * 2 - self.sliders.append(self.fig.add_axes((0., 0., 0.01+i/10., 1.))) - if self.num_buttons == 0: - nx1 = 3 - else: - nx1 = -2 + self.sliders.append(self.fig.add_axes((0.0, 0.0, 0.01 + i / 10.0, 1.0))) + nx1 = 3 if self.num_buttons == 0 else -2 locator = self.divider.new_locator(nx=2, ny=y, nx1=nx1) self.sliders[-1].set_axes_locator(locator) - self.sliders[-1].text(0.5, 0.5, self.slider_labels[i], - transform=self.sliders[-1].transAxes, - horizontalalignment="center", - verticalalignment="center") - - sframe = widgets.Slider(self.sliders[-1], "", - self.slider_ranges[i][0], - self.slider_ranges[i][-1]-1, - valinit=self.slider_ranges[i][0], - valfmt='%4.1f') + self.sliders[-1].text( + 0.5, + 0.5, + self.slider_labels[i], + transform=self.sliders[-1].transAxes, + horizontalalignment="center", + verticalalignment="center", + ) + + sframe = widgets.Slider( + self.sliders[-1], + "", + self.slider_ranges[i][0], + self.slider_ranges[i][-1] - 1, + valinit=self.slider_ranges[i][0], + valfmt="%4.1f", + ) sframe.on_changed(partial(self._slider_changed, slider=sframe)) sframe.slider_ind = i sframe.cval = sframe.val self.sliders[-1]._slider = sframe - self.slider_buttons.append( - self.fig.add_axes((0., 0., 0.05+y/10., 1.))) + self.slider_buttons.append(self.fig.add_axes((0.0, 0.0, 0.05 + y / 10.0, 1.0))) locator = self.divider.new_locator(nx=0, ny=y) self.slider_buttons[-1].set_axes_locator(locator) @@ -369,9 +381,9 @@ def _add_widgets(self): butt.clicked = False self.slider_buttons[-1]._button = butt -# ============================================================================= -# Widget callbacks -# ============================================================================= + # ============================================================================= + # Widget callbacks + # ============================================================================= def _slider_changed(self, val, slider): self.slider_functions[slider.slider_ind](val, self.im, slider) @@ -404,7 +416,7 @@ def _step(self, slider): if s.val >= s.valmax: s.set_val(s.valmin) else: - s.set_val(s.val+1) + s.set_val(s.val + 1) self.fig.canvas.draw() def _previous(self, slider): @@ -412,7 +424,7 @@ def _previous(self, slider): if s.val <= s.valmin: s.set_val(s.valmax) else: - s.set_val(s.val-1) + s.set_val(s.val - 1) self.fig.canvas.draw() @@ -450,8 +462,9 @@ class ArrayAnimator(BaseFuncAnimator, metaclass=abc.ABCMeta): Extra keywords are passed to `~sunpy.visualization.animator.BaseFuncAnimator`. """ - def __init__(self, data, image_axes=[-2, -1], axis_ranges=None, **kwargs): - + def __init__(self, data, image_axes=None, axis_ranges=None, **kwargs): + if image_axes is None: + image_axes = [-2, -1] all_axes = list(range(self.naxis)) # Handle negative indexes self.image_axes = [all_axes[i] for i in image_axes] @@ -461,14 +474,16 @@ def __init__(self, data, image_axes=[-2, -1], axis_ranges=None, **kwargs): slider_axes.remove(x) if len(slider_axes) != self.num_sliders: - raise ValueError("Number of sliders doesn't match the number of slider axes.") + msg = "Number of sliders doesn't match the number of slider axes." + raise ValueError(msg) self.slider_axes = slider_axes # Verify that combined slider_axes and image_axes make all axes ax = self.slider_axes + self.image_axes ax.sort() if ax != list(range(self.naxis)): - raise ValueError("Number of image and slider axes do not match total number of axes.") + msg = "Number of image and slider axes do not match total number of axes." + raise ValueError(msg) self.axis_ranges, self.extent = self._sanitize_axis_ranges(axis_ranges, data.shape) @@ -480,8 +495,8 @@ def __init__(self, data, image_axes=[-2, -1], axis_ranges=None, **kwargs): slider_functions = kwargs.pop("slider_functions", []) slider_ranges = kwargs.pop("slider_ranges", []) base_kwargs = { - 'slider_functions': ([self.update_plot] * self.num_sliders) + slider_functions, - 'slider_ranges': [[0, dim] for dim in np.array(data.shape)[self.slider_axes]] + slider_ranges + "slider_functions": ([self.update_plot] * self.num_sliders) + slider_functions, + "slider_ranges": [[0, dim] for dim in np.array(data.shape)[self.slider_axes]] + slider_ranges, } self.num_sliders = len(base_kwargs["slider_functions"]) base_kwargs.update(kwargs) @@ -537,19 +552,23 @@ def _sanitize_axis_ranges(self, axis_ranges, data_shape): # need the same number of axis ranges as axes if len(axis_ranges) != ndim: - raise ValueError("Length of axis_ranges must equal number of axes") + msg = "Length of axis_ranges must equal number of axes" + raise ValueError(msg) # Define error message for incompatible axis_range input. - def incompatible_axis_ranges_error_message(j): return \ - (f"Unrecognized format for {j}th entry in axis_ranges: {axis_ranges[j]}" - "axis_ranges must be None, a ``[min, max]`` pair, or " - "an array-like giving the edge values of each pixel, " - "i.e. length must be length of axis + 1.") + def incompatible_axis_ranges_error_message(j): + return ( + f"Unrecognized format for {j}th entry in axis_ranges: {axis_ranges[j]}" + "axis_ranges must be None, a ``[min, max]`` pair, or " + "an array-like giving the edge values of each pixel, " + "i.e. length must be length of axis + 1." + ) # If axis range not given, define a function such that the range goes # from -0.5 to number of pixels-0.5. Thus, the center of the pixels # along the axis will correspond to integer values. - def none_image_axis_range(j): return [-0.5, data_shape[j]-0.5] + def none_image_axis_range(j): + return [-0.5, data_shape[j] - 0.5] # For each axis validate and translate the axis_ranges. For image axes, # also determine the plot extent. To do this, iterate through image and slider @@ -567,16 +586,16 @@ def none_image_axis_range(j): return [-0.5, data_shape[j]-0.5] if len(axis_ranges[i]) == 2: # Set extent. extent += [axis_ranges[i][0], axis_ranges[i][-1]] - elif axis_ranges[i].ndim == 1 and len(axis_ranges[i]) == data_shape[i]+1: + elif axis_ranges[i].ndim == 1 and len(axis_ranges[i]) == data_shape[i] + 1: # If array of individual pixel edges supplied, first set extent # from first and last pixel edge, then convert axis_ranges to pixel centers. # The reason that pixel edges are required as input rather than centers # is so that the plot extent can be derived from axis_ranges (above) # and APIs using both [min, max] pair and manual definition of each pixel - # values can be unambiguously and simultanously supported. + # values can be unambiguously and simultaneously supported. extent += [axis_ranges[i][0], axis_ranges[i][-1]] axis_ranges[i] = edges_to_centers_nd(axis_ranges[i], 0) - elif axis_ranges[i].ndim == ndim and axis_ranges[i].shape[i] == data_shape[i]+1: + elif axis_ranges[i].ndim == ndim and axis_ranges[i].shape[i] == data_shape[i] + 1: extent += [axis_ranges[i].min(), axis_ranges[i].max()] axis_ranges[i] = edges_to_centers_nd(axis_ranges[i], i) else: @@ -586,6 +605,7 @@ def none_image_axis_range(j): return [-0.5, data_shape[j]-0.5] def get_pixel_to_world_callable(array): def pixel_to_world(pixel): return array[pixel] + return pixel_to_world for sidx in self.slider_axes: @@ -598,22 +618,22 @@ def pixel_to_world(pixel): # If axis range given as a min, max pair, derive the center of each pixel # assuming they are equally spaced. - axis_ranges[sidx] = np.linspace(axis_ranges[sidx][0], axis_ranges[sidx][-1], - data_shape[sidx]+1) - axis_ranges[sidx] = get_pixel_to_world_callable( - edges_to_centers_nd(axis_ranges[sidx], sidx)) - elif axis_ranges[sidx].ndim == 1 and len(axis_ranges[sidx]) == data_shape[sidx]+1: + axis_ranges[sidx] = np.linspace(axis_ranges[sidx][0], axis_ranges[sidx][-1], data_shape[sidx] + 1) + axis_ranges[sidx] = get_pixel_to_world_callable(edges_to_centers_nd(axis_ranges[sidx], sidx)) + elif axis_ranges[sidx].ndim == 1 and len(axis_ranges[sidx]) == data_shape[sidx] + 1: # If axis range given as 1D array of pixel edges (i.e. axis is independent), # derive pixel centers. axis_ranges[sidx] = get_pixel_to_world_callable( - edges_to_centers_nd(np.asarray(axis_ranges[sidx]), 0)) - elif axis_ranges[sidx].ndim == ndim and axis_ranges[sidx].shape[sidx] == data_shape[sidx]+1: + edges_to_centers_nd(np.asarray(axis_ranges[sidx]), 0) + ) + elif axis_ranges[sidx].ndim == ndim and axis_ranges[sidx].shape[sidx] == data_shape[sidx] + 1: # If axis range given as array of pixel edges the same shape as # the data array (i.e. axis is not independent), derive pixel centers. axis_ranges[sidx] = get_pixel_to_world_callable( - edges_to_centers_nd(np.asarray(axis_ranges[sidx]), i)) + edges_to_centers_nd(np.asarray(axis_ranges[sidx]), i) + ) else: raise ValueError(incompatible_axis_ranges_error_message(i)) @@ -639,9 +659,7 @@ def update_plot(self, val, artist, slider): # Update slider label to reflect real world values in axis_ranges. label = self.axis_ranges[ax_ind](ind) if units is not None and isinstance(label, units.Quantity): - slider.valtext.set_text(label.to_string(precision=5, - format='latex', - subfmt='inline')) + slider.valtext.set_text(label.to_string(precision=5, format="latex", subfmt="inline")) elif isinstance(label, str): slider.valtext.set_text(label) else: diff --git a/mpl_animators/extern/modest_image.py b/mpl_animators/extern/modest_image.py index 8e073ad..3ea444e 100644 --- a/mpl_animators/extern/modest_image.py +++ b/mpl_animators/extern/modest_image.py @@ -4,56 +4,53 @@ """ # This file is copied from glue under the terms of the 3 Clause BSD licence. See licenses/GLUE.rst -from __future__ import print_function, division -import matplotlib -rcParams = matplotlib.rcParams +import matplotlib as mpl -import matplotlib.image as mi -import matplotlib.colors as mcolors -import matplotlib.cbook as cbook -from matplotlib.transforms import IdentityTransform, Affine2D +rcParams = mpl.rcParams +import matplotlib.cbook as cbook +import matplotlib.colors as mcolors +import matplotlib.image as mi import numpy as np +from matplotlib.transforms import Affine2D, IdentityTransform IDENTITY_TRANSFORM = IdentityTransform() class ModestImage(mi.AxesImage): - """ Computationally modest image class. - ModestImage is an extension of the Matplotlib AxesImage class - better suited for the interactive display of larger images. Before - drawing, ModestImage resamples the data array based on the screen - resolution and view window. This has very little affect on the - appearance of the image, but can substantially cut down on - computation since calculations of unresolved or clipped pixels - are skipped. + ModestImage is an extension of the Matplotlib AxesImage class better + suited for the interactive display of larger images. Before drawing, + ModestImage resamples the data array based on the screen resolution + and view window. This has very little affect on the appearance of + the image, but can substantially cut down on computation since + calculations of unresolved or clipped pixels are skipped. The interface of ModestImage is the same as AxesImage. However, it - does not currently support setting the 'extent' property. There - may also be weird coordinate warping operations for images that - I'm not aware of. Don't expect those to work either. + does not currently support setting the 'extent' property. There may + also be weird coordinate warping operations for images that I'm not + aware of. Don't expect those to work either. """ def __init__(self, *args, **kwargs): self._pressed = False self._full_res = None - self._full_extent = kwargs.get('extent', None) - super(ModestImage, self).__init__(*args, **kwargs) + self._full_extent = kwargs.get("extent", None) + super().__init__(*args, **kwargs) self.invalidate_cache() - self.axes.figure.canvas.mpl_connect('button_press_event', self._press) - self.axes.figure.canvas.mpl_connect('button_release_event', self._release) - self.axes.figure.canvas.mpl_connect('resize_event', self._resize) + self.axes.figure.canvas.mpl_connect("button_press_event", self._press) + self.axes.figure.canvas.mpl_connect("button_release_event", self._release) + self.axes.figure.canvas.mpl_connect("resize_event", self._resize) self._timer = self.axes.figure.canvas.new_timer(interval=500) self._timer.single_shot = True self._timer.add_callback(self._resize_paused) def remove(self): - super(ModestImage, self).remove() + super().remove() self._timer.stop() self._timer = None @@ -79,20 +76,20 @@ def _release(self, *args): def set_data(self, A): """ - Set the image array + Set the image array. ACCEPTS: numpy/PIL Image A """ self._full_res = A self._A = A - if self._A.dtype != np.uint8 and not np.can_cast(self._A.dtype, - float): - raise TypeError("Image data can not convert to float") + if self._A.dtype != np.uint8 and not np.can_cast(self._A.dtype, float): + msg = "Image data can not convert to float" + raise TypeError(msg) - if (self._A.ndim not in (2, 3) or - (self._A.ndim == 3 and self._A.shape[-1] not in (3, 4))): - raise TypeError("Invalid dimensions for image data") + if self._A.ndim not in (2, 3) or (self._A.ndim == 3 and self._A.shape[-1] not in (3, 4)): + msg = "Invalid dimensions for image data" + raise TypeError(msg) self.invalidate_cache() @@ -113,7 +110,7 @@ def contains(self, mouseevent): if self._A is None or self._A.shape is None: return False else: - return super(ModestImage, self).contains(mouseevent) + return super().contains(mouseevent) def set_extent(self, extent): self._full_extent = extent @@ -121,14 +118,14 @@ def set_extent(self, extent): mi.AxesImage.set_extent(self, extent) def get_array(self): - """Override to return the full-resolution array""" + """ + Override to return the full-resolution array. + """ return self._full_res @property def _pixel2world(self): - if self._pixel2world_cache is None: - # Pre-compute affine transforms to convert between the 'world' # coordinates of the axes (what is shown by the axis labels) to # 'pixel' coordinates in the underlying array. @@ -136,17 +133,16 @@ def _pixel2world(self): extent = self._full_extent if extent is None: - self._pixel2world_cache = IDENTITY_TRANSFORM else: - self._pixel2world_cache = Affine2D() self._pixel2world.translate(+0.5, +0.5) - self._pixel2world.scale((extent[1] - extent[0]) / self._full_res.shape[1], - (extent[3] - extent[2]) / self._full_res.shape[0]) + self._pixel2world.scale( + (extent[1] - extent[0]) / self._full_res.shape[1], (extent[3] - extent[2]) / self._full_res.shape[0] + ) self._pixel2world.translate(extent[0], extent[2]) @@ -169,16 +165,21 @@ def _scale_to_res(self): # Find out how we need to slice the array to make sure we match the # resolution of the display. We pass self._world2pixel which matters # for cases where the extent has been set. - x0, x1, sx, y0, y1, sy = extract_matched_slices(axes=self.axes, - shape=self._full_res.shape, - transform=self._world2pixel) + x0, x1, sx, y0, y1, sy = extract_matched_slices( + axes=self.axes, shape=self._full_res.shape, transform=self._world2pixel + ) # Check whether we've already calculated what we need, and if so just # return without doing anything further. - if (self._bounds is not None and - sx >= self._sx and sy >= self._sy and - x0 >= self._bounds[0] and x1 <= self._bounds[1] and - y0 >= self._bounds[2] and y1 <= self._bounds[3]): + if ( + self._bounds is not None + and sx >= self._sx + and sy >= self._sy + and x0 >= self._bounds[0] + and x1 <= self._bounds[1] + and y0 >= self._bounds[2] + and y1 <= self._bounds[3] + ): return # Slice the array using the slices determined previously to optimally @@ -193,10 +194,10 @@ def _scale_to_res(self): # demonstration of why origin='upper' and extent=None needs to be # special-cased. - if self.origin == 'upper' and self._full_extent is None: - xmin, xmax, ymin, ymax = x0 - .5, x1 - .5, y1 - .5, y0 - .5 + if self.origin == "upper" and self._full_extent is None: + xmin, xmax, ymin, ymax = x0 - 0.5, x1 - 0.5, y1 - 0.5, y0 - 0.5 else: - xmin, xmax, ymin, ymax = x0 - .5, x1 - .5, y0 - .5, y1 - .5 + xmin, xmax, ymin, ymax = x0 - 0.5, x1 - 0.5, y0 - 0.5, y1 - 0.5 xmin, ymin, xmax, ymax = self._pixel2world.transform([(xmin, ymin), (xmax, ymax)]).ravel() @@ -218,16 +219,18 @@ def draw(self, renderer, *args, **kwargs): self._scale_to_res() # Due to a bug in Matplotlib, we need to return here if all values # in the array are masked. - if hasattr(self._A, 'mask') and np.all(self._A.mask): + if hasattr(self._A, "mask") and np.all(self._A.mask): return - super(ModestImage, self).draw(renderer, *args, **kwargs) + super().draw(renderer, *args, **kwargs) def main(): from time import time + import matplotlib.pyplot as plt + x, y = np.mgrid[0:2000, 0:2000] - data = np.sin(x / 10.) * np.cos(y / 30.) + data = np.sin(x / 10.0) * np.cos(y / 30.0) f = plt.figure() ax = f.add_subplot(111) @@ -235,38 +238,61 @@ def main(): # try switching between artist = ModestImage(ax, data=data) - ax.set_aspect('equal') + ax.set_aspect("equal") artist.norm.vmin = -1 artist.norm.vmax = 1 ax.add_artist(artist) - t0 = time() + time() plt.gcf().canvas.draw_idle() - t1 = time() - - print("Draw time for %s: %0.1f ms" % (artist.__class__.__name__, - (t1 - t0) * 1000)) + time() plt.show() -def imshow(axes, X, cmap=None, norm=None, aspect=None, - interpolation=None, alpha=None, vmin=None, vmax=None, - origin=None, extent=None, shape=None, filternorm=1, - filterrad=4.0, imlim=None, resample=None, url=None, **kwargs): - """Similar to matplotlib's imshow command, but produces a ModestImage +def imshow( + axes, + X, + cmap=None, + norm=None, + aspect=None, + interpolation=None, + alpha=None, + vmin=None, + vmax=None, + origin=None, + extent=None, + shape=None, + filternorm=1, + filterrad=4.0, + imlim=None, + resample=None, + url=None, + **kwargs, +): + """ + Similar to matplotlib's imshow command, but produces a ModestImage. Unlike matplotlib version, must explicitly specify axes """ if norm is not None: - assert(isinstance(norm, mcolors.Normalize)) + assert isinstance(norm, mcolors.Normalize) if aspect is None: - aspect = rcParams['image.aspect'] + aspect = rcParams["image.aspect"] axes.set_aspect(aspect) - im = ModestImage(axes, cmap=cmap, norm=norm, interpolation=interpolation, - origin=origin, extent=extent, filternorm=filternorm, - filterrad=filterrad, resample=resample, **kwargs) + im = ModestImage( + axes, + cmap=cmap, + norm=norm, + interpolation=interpolation, + origin=origin, + extent=extent, + filternorm=filternorm, + filterrad=filterrad, + resample=resample, + **kwargs, + ) im.set_data(X) im.set_alpha(alpha) @@ -300,21 +326,17 @@ def remove(h): return im -def extract_matched_slices(axes=None, shape=None, extent=None, - transform=IDENTITY_TRANSFORM): - """Determine the slice parameters to use, matched to the screen. +def extract_matched_slices(axes=None, shape=None, extent=None, transform=IDENTITY_TRANSFORM): + """ + Determine the slice parameters to use, matched to the screen. :param ax: Axes object to query. It's extent and pixel size - determine the slice parameters - + determine the slice parameters :param shape: Tuple of the full image shape to slice into. Upper - boundaries for slices will be cropped to fit within - this shape. - - :rtype: tulpe of x0, x1, sx, y0, y1, sy - - Indexing the full resolution array as array[y0:y1:sy, x0:x1:sx] returns - a view well-matched to the axes' resolution and extent + boundaries for slices will be cropped to fit within this shape. + :rtype: tulpe of x0, x1, sx, y0, y1, sy Indexing the full resolution + array as array[y0:y1:sy, x0:x1:sx] returns a view well-matched + to the axes' resolution and extent """ # Find extent in display pixels (this gives the resolution we need @@ -341,8 +363,8 @@ def _clip(val, lo, hi): x1 = _clip(ind1[0] + 5, 1, shape[1]) # Determine the strides that can be used when extracting the array - sy = int(max(1, min((y1 - y0) / 5., np.ceil(abs((ind1[1] - ind0[1]) / ext[1]))))) - sx = int(max(1, min((x1 - x0) / 5., np.ceil(abs((ind1[0] - ind0[0]) / ext[0]))))) + sy = int(max(1, min((y1 - y0) / 5.0, np.ceil(abs((ind1[1] - ind0[1]) / ext[1]))))) + sx = int(max(1, min((x1 - x0) / 5.0, np.ceil(abs((ind1[0] - ind0[0]) / ext[0]))))) return x0, x1, sx, y0, y1, sy diff --git a/mpl_animators/image.py b/mpl_animators/image.py index bc7146d..18896a3 100644 --- a/mpl_animators/image.py +++ b/mpl_animators/image.py @@ -2,7 +2,7 @@ from .base import ArrayAnimator -__all__ = ['ImageAnimator'] +__all__ = ["ImageAnimator"] class ImageAnimator(ArrayAnimator): @@ -39,13 +39,16 @@ class ImageAnimator(ArrayAnimator): Extra keywords are passed to `~sunpy.visualization.animator.ArrayAnimator`. """ - def __init__(self, data, image_axes=[-2, -1], axis_ranges=None, **kwargs): + def __init__(self, data, image_axes=None, axis_ranges=None, **kwargs): # Check that number of axes is 2. + if image_axes is None: + image_axes = [-2, -1] if len(image_axes) != 2: - raise ValueError("There can only be two spatial axes") + msg = "There can only be two spatial axes" + raise ValueError(msg) # Define number of slider axes. self.naxis = data.ndim - self.num_sliders = self.naxis-2 + self.num_sliders = self.naxis - 2 # Define marker to determine if plot axes values are supplied via array of # pixel values or min max pair. This will determine the type of image produced # and hence how to plot and update it. @@ -68,8 +71,7 @@ def plot_start_image(self, ax): extent.append(self.axis_ranges[i][0]) extent.append(self.axis_ranges[i][-1]) - imshow_args = {'interpolation': 'nearest', - 'origin': 'lower'} + imshow_args = {"interpolation": "nearest", "origin": "lower"} imshow_args.update(self.imshow_kwargs) # If value along an axis is set with an array, generate a NonUniformImage @@ -82,15 +84,14 @@ def plot_start_image(self, ax): # Initialize a NonUniformImage with the relevant data and axis values and # add the image to the axes. im = mpl.image.NonUniformImage(ax, **imshow_args) - im.set_data(self.axis_ranges[self.image_axes[0]], - self.axis_ranges[self.image_axes[1]], data) + im.set_data(self.axis_ranges[self.image_axes[0]], self.axis_ranges[self.image_axes[1]], data) ax.add_image(im) # Define the xlim and ylim from the pixel edges. ax.set_xlim(self.extent[0], self.extent[1]) ax.set_ylim(self.extent[2], self.extent[3]) else: # Else produce a more basic plot with regular axes. - imshow_args.update({'extent': extent}) + imshow_args.update({"extent": extent}) im = ax.imshow(self.data[self.frame_index], **imshow_args) if self.if_colorbar: self._add_colorbar(im) @@ -110,8 +111,7 @@ def update_plot(self, val, im, slider): data = self.data[self.frame_index].transpose() else: data = self.data[self.frame_index] - im.set_data(self.axis_ranges[self.image_axes[0]], - self.axis_ranges[self.image_axes[1]], data) + im.set_data(self.axis_ranges[self.image_axes[0]], self.axis_ranges[self.image_axes[1]], data) else: im.set_array(self.data[self.frame_index]) slider.cval = val diff --git a/mpl_animators/line.py b/mpl_animators/line.py index 81c72fa..9722f8c 100644 --- a/mpl_animators/line.py +++ b/mpl_animators/line.py @@ -2,7 +2,7 @@ from .base import ArrayAnimator, edges_to_centers_nd -__all__ = ['LineAnimator'] +__all__ = ["LineAnimator"] class LineAnimator(ArrayAnimator): @@ -77,16 +77,31 @@ class LineAnimator(ArrayAnimator): Extra keywords are passed to `~sunpy.visualization.animator.ArrayAnimator`. """ - def __init__(self, data, plot_axis_index=-1, axis_ranges=None, ylabel=None, xlabel=None, - xlim=None, ylim=None, aspect='auto', **kwargs): + def __init__( + self, + data, + plot_axis_index=-1, + axis_ranges=None, + ylabel=None, + xlabel=None, + xlim=None, + ylim=None, + aspect="auto", + **kwargs, + ): # Check inputs. self.plot_axis_index = int(plot_axis_index) if self.plot_axis_index not in range(-data.ndim, data.ndim): - raise ValueError("plot_axis_index must be within range of number of data dimensions" - " (or equivalent negative indices).") + msg = ( + "plot_axis_index must be within range of number of data dimensions" " (or equivalent negative indices)." + ) + raise ValueError(msg) if data.ndim < 2: - raise ValueError("data must have at least two dimensions. One for data " - "for each single plot and at least one for time/iteration.") + msg = ( + "data must have at least two dimensions. One for data " + "for each single plot and at least one for time/iteration." + ) + raise ValueError(msg) # Define number of slider axes. self.naxis = data.ndim self.num_sliders = self.naxis - 1 @@ -105,8 +120,7 @@ def __init__(self, data, plot_axis_index=-1, axis_ranges=None, ylabel=None, xlab # Else derive the xdata as pixel centers from the pixel edges supplied by # the user in axis_ranges[plot_axis_index] along axis=plot_axis_index else: - self.xdata = edges_to_centers_nd(np.asarray(axis_ranges[self.plot_axis_index]), - plot_axis_index) + self.xdata = edges_to_centers_nd(np.asarray(axis_ranges[self.plot_axis_index]), plot_axis_index) if ylim is None: ylim = (np.nanmin(data), np.nanmax(data)) if xlim is None: @@ -117,8 +131,7 @@ def __init__(self, data, plot_axis_index=-1, axis_ranges=None, ylabel=None, xlab self.ylabel = ylabel self.aspect = aspect # Run init for base class - super().__init__(data, image_axes=[self.plot_axis_index], axis_ranges=axis_ranges, - **kwargs) + super().__init__(data, image_axes=[self.plot_axis_index], axis_ranges=axis_ranges, **kwargs) def plot_start_image(self, ax): """ @@ -126,7 +139,7 @@ def plot_start_image(self, ax): """ ax.set_xlim(self.xlim) ax.set_ylim(self.ylim) - ax.set_aspect(self.aspect, adjustable='datalim') + ax.set_aspect(self.aspect, adjustable="datalim") if self.xlabel is not None: ax.set_xlabel(self.xlabel) if self.ylabel is not None: @@ -139,7 +152,7 @@ def plot_start_image(self, ax): xdata = np.squeeze(self.xdata[tuple(item)]) else: xdata = self.xdata - line, = ax.plot(xdata, self.data[self.frame_index], **plot_args) + (line,) = ax.plot(xdata, self.data[self.frame_index], **plot_args) return line def update_plot(self, val, line, slider): @@ -154,10 +167,7 @@ def update_plot(self, val, line, slider): if self.xdata.shape == self.data.shape: item = [int(slid._slider.val) for slid in self.sliders] item[ax_ind] = val - if self.plot_axis_index < 0: - i = self.data.ndim + self.plot_axis_index - else: - i = self.plot_axis_index + i = self.data.ndim + self.plot_axis_index if self.plot_axis_index < 0 else self.plot_axis_index item.insert(i, slice(None)) line.set_xdata(self.xdata[tuple(item)]) slider.cval = val diff --git a/mpl_animators/tests/helpers.py b/mpl_animators/tests/helpers.py index ea360fe..c7cd432 100644 --- a/mpl_animators/tests/helpers.py +++ b/mpl_animators/tests/helpers.py @@ -12,8 +12,14 @@ def get_hash_library_name(): Generate the hash library name for this env. """ ft2_version = f"{mpl.ft2font.__freetype_version__.replace('.', '')}" - mpl_version = "dev" if (("dev" in mpl.__version__) or ("rc" in mpl.__version__)) else mpl.__version__.replace('.', '') - astropy_version = "dev" if (("dev" in astropy.__version__) or ("rc" in astropy.__version__)) else astropy.__version__.replace('.', '') + mpl_version = ( + "dev" if (("dev" in mpl.__version__) or ("rc" in mpl.__version__)) else mpl.__version__.replace(".", "") + ) + astropy_version = ( + "dev" + if (("dev" in astropy.__version__) or ("rc" in astropy.__version__)) + else astropy.__version__.replace(".", "") + ) return f"figure_hashes_mpl_{mpl_version}_ft_{ft2_version}_astropy_{astropy_version}.json" @@ -35,13 +41,14 @@ def test_simple_plot(): hash_library_name = get_hash_library_name() hash_library_file = Path(__file__).parent / hash_library_name - @pytest.mark.mpl_image_compare(hash_library=hash_library_file, - savefig_kwargs={'metadata': {'Software': None}}, - style='default') + @pytest.mark.mpl_image_compare( + hash_library=hash_library_file, savefig_kwargs={"metadata": {"Software": None}}, style="default" + ) @wraps(test_function) def test_wrapper(*args, **kwargs): ret = test_function(*args, **kwargs) if ret is None: ret = plt.gcf() return ret + return test_wrapper diff --git a/mpl_animators/tests/test_basefuncanimator.py b/mpl_animators/tests/test_basefuncanimator.py index c062957..0432213 100644 --- a/mpl_animators/tests/test_basefuncanimator.py +++ b/mpl_animators/tests/test_basefuncanimator.py @@ -26,22 +26,23 @@ def update_plotval(val, im, slider, data): def button_func1(*args, **kwargs): - print(*args, **kwargs) + pass -@pytest.mark.parametrize('fig, colorbar, buttons', - ((None, False, [[], []]), - (mfigure.Figure(), True, [[button_func1], ["hi"]]))) +@pytest.mark.parametrize( + ("fig", "colorbar", "buttons"), [(None, False, [[], []]), (mfigure.Figure(), True, [[button_func1], ["hi"]])] +) def test_base_func_init(fig, colorbar, buttons): - data = np.random.random((3, 10, 10)) + rng = np.random.default_rng() + data = rng.random((3, 10, 10)) func0 = partial(update_plotval, data=data) - func1 = partial(update_plotval, data=data*10) + func1 = partial(update_plotval, data=data * 10) funcs = [func0, func1] ranges = [(0, 3), (0, 3)] - tfa = FuncAnimatorTest(data, funcs, ranges, fig=fig, colorbar=colorbar, - button_func=buttons[0], - button_labels=buttons[1]) + tfa = FuncAnimatorTest( + data, funcs, ranges, fig=fig, colorbar=colorbar, button_func=buttons[0], button_labels=buttons[1] + ) tfa.label_slider(0, "hello") assert tfa.sliders[0]._slider.label.get_text() == "hello" @@ -50,41 +51,41 @@ def test_base_func_init(fig, colorbar, buttons): assert tfa.active_slider == 1 fig = tfa.fig - event = mback.KeyEvent(name='key_press_event', canvas=fig.canvas, key='down') + event = mback.KeyEvent(name="key_press_event", canvas=fig.canvas, key="down") tfa._key_press(event) assert tfa.active_slider == 0 - event.key = 'up' + event.key = "up" tfa._key_press(event) assert tfa.active_slider == 1 tfa.slider_buttons[tfa.active_slider]._button.clicked = False - event.key = 'p' - tfa._click_slider_button(event=event, button=tfa.slider_buttons[tfa.active_slider]._button, - slider=tfa.sliders[tfa.active_slider]._slider) + event.key = "p" + tfa._click_slider_button( + event=event, button=tfa.slider_buttons[tfa.active_slider]._button, slider=tfa.sliders[tfa.active_slider]._slider + ) assert tfa.slider_buttons[tfa.active_slider]._button.label._text == "||" tfa._key_press(event) assert tfa.slider_buttons[tfa.active_slider]._button.label._text == ">" - event.key = 'left' + event.key = "left" tfa._key_press(event) assert tfa.sliders[tfa.active_slider]._slider.val == tfa.sliders[tfa.active_slider]._slider.valmax - event.key = 'right' + event.key = "right" tfa._key_press(event) assert tfa.sliders[tfa.active_slider]._slider.val == tfa.sliders[tfa.active_slider]._slider.valmin - event.key = 'right' + event.key = "right" tfa._key_press(event) assert tfa.sliders[tfa.active_slider]._slider.val == tfa.sliders[tfa.active_slider]._slider.valmin + 1 - event.key = 'left' + event.key = "left" tfa._key_press(event) assert tfa.sliders[tfa.active_slider]._slider.val == tfa.sliders[tfa.active_slider]._slider.valmin - tfa._start_play(event, tfa.slider_buttons[tfa.active_slider]._button, - tfa.sliders[tfa.active_slider]._slider) + tfa._start_play(event, tfa.slider_buttons[tfa.active_slider]._button, tfa.sliders[tfa.active_slider]._slider) assert tfa.timer tfa._stop_play(event) @@ -101,7 +102,8 @@ def test_base_func_init(fig, colorbar, buttons): # Make sure figures created directly and through pyplot work @pytest.fixture(params=[plt.figure, mfigure.Figure]) def funcanimator(request): - data = np.random.random((3, 10, 10)) + rng = np.random.default_rng() + data = rng.random((3, 10, 10)) func = partial(update_plotval, data=data) funcs = [func] ranges = [(0, 3)] @@ -120,7 +122,8 @@ def test_to_axes(funcanimator): def test_axes_set(): - data = np.random.random((3, 10, 10)) + rng = np.random.default_rng() + data = rng.random((3, 10, 10)) funcs = [partial(update_plotval, data=data)] ranges = [(0, 3)] @@ -162,24 +165,21 @@ def update_plot(self, val, artist, slider): axis_ranges1 = np.tile(np.linspace(0, 100, 21), (10, 1)) -@pytest.mark.parametrize('axis_ranges, exp_extent, exp_axis_ranges', - [([None, None], [-0.5, 19.5], - [np.arange(10), np.array([-0.5, 19.5])]), - - ([[0, 10], [0, 20]], [0, 20], - [np.arange(0.5, 10.5), np.asarray([0, 20])]), - - ([np.arange(0, 11), np.arange(0, 21)], [0, 20], - [np.arange(0.5, 10.5), np.arange(0.5, 20.5)]), - - ([None, axis_ranges1], [0.0, 100.0], - [np.arange(10), base.edges_to_centers_nd(axis_ranges1, 1)])]) +@pytest.mark.parametrize( + ("axis_ranges", "exp_extent", "exp_axis_ranges"), + [ + ([None, None], [-0.5, 19.5], [np.arange(10), np.array([-0.5, 19.5])]), + ([[0, 10], [0, 20]], [0, 20], [np.arange(0.5, 10.5), np.asarray([0, 20])]), + ([np.arange(0, 11), np.arange(0, 21)], [0, 20], [np.arange(0.5, 10.5), np.arange(0.5, 20.5)]), + ([None, axis_ranges1], [0.0, 100.0], [np.arange(10), base.edges_to_centers_nd(axis_ranges1, 1)]), + ], +) def test_sanitize_axis_ranges(axis_ranges, exp_extent, exp_axis_ranges): data_shape = (10, 20) - data = np.random.rand(*data_shape) + rng = np.random.default_rng() + data = rng.random(data_shape) aanim = ArrayAnimatorTest(data=data) - out_axis_ranges, out_extent = aanim._sanitize_axis_ranges(axis_ranges=axis_ranges, - data_shape=data_shape) + out_axis_ranges, out_extent = aanim._sanitize_axis_ranges(axis_ranges=axis_ranges, data_shape=data_shape) assert exp_extent == out_extent assert np.array_equal(exp_axis_ranges[1], out_axis_ranges[1]) assert callable(out_axis_ranges[0]) @@ -189,20 +189,23 @@ def test_sanitize_axis_ranges(axis_ranges, exp_extent, exp_axis_ranges): XDATA = np.tile(np.linspace(0, 100, 11), (5, 5, 1)) -@pytest.mark.parametrize('plot_axis_index, axis_ranges, xlabel, xlim', - [(-1, None, None, None), - (-1, [None, None, XDATA], 'x-axis', None)]) +@pytest.mark.parametrize( + ("plot_axis_index", "axis_ranges", "xlabel", "xlim"), + [(-1, None, None, None), (-1, [None, None, XDATA], "x-axis", None)], +) def test_lineanimator_init(plot_axis_index, axis_ranges, xlabel, xlim): - data = np.random.random((5, 5, 10)) - LineAnimator(data=data, plot_axis_index=plot_axis_index, axis_ranges=axis_ranges, - xlabel=xlabel, xlim=xlim) + rng = np.random.default_rng() + data = rng.random((5, 5, 10)) + LineAnimator(data=data, plot_axis_index=plot_axis_index, axis_ranges=axis_ranges, xlabel=xlabel, xlim=xlim) def test_lineanimator_init_nans(): - data = np.random.random((5, 5, 10)) + rng = np.random.default_rng() + data = rng.random((5, 5, 10)) data[0][0][:] = np.nan - line_anim = LineAnimator(data=data, plot_axis_index=-1, axis_ranges=[None, None, XDATA], - xlabel='x-axis', xlim=None, ylim=None) + line_anim = LineAnimator( + data=data, plot_axis_index=-1, axis_ranges=[None, None, XDATA], xlabel="x-axis", xlim=None, ylim=None + ) assert line_anim.ylim[0] is not None assert line_anim.ylim[1] is not None assert line_anim.xlim[0] is not None @@ -211,12 +214,11 @@ def test_lineanimator_init_nans(): @figure_test def test_lineanimator_figure(): - np.random.seed(1) + rng = np.random.default_rng(seed=1) data_shape0 = (10, 20) - data0 = np.random.rand(*data_shape0) + data0 = rng.random(data_shape0) plot_axis0 = 1 slider_axis0 = 0 - xdata = np.tile(np.linspace( - 0, 100, (data_shape0[plot_axis0] + 1)), (data_shape0[slider_axis0], 1)) + xdata = np.tile(np.linspace(0, 100, (data_shape0[plot_axis0] + 1)), (data_shape0[slider_axis0], 1)) ani = LineAnimator(data0, plot_axis_index=plot_axis0, axis_ranges=[None, xdata]) return ani.fig diff --git a/mpl_animators/tests/test_wcs.py b/mpl_animators/tests/test_wcs.py index 0fa1879..98ca4ec 100644 --- a/mpl_animators/tests/test_wcs.py +++ b/mpl_animators/tests/test_wcs.py @@ -11,13 +11,15 @@ from mpl_animators.wcs import ArrayAnimatorWCS # See https://github.com/astropy/astropy/pull/10400 -pytestmark = pytest.mark.filterwarnings('ignore:target cannot be converted to ICRS, so will not be ' - 'set on SpectralCoord') +pytestmark = pytest.mark.filterwarnings( + "ignore:target cannot be converted to ICRS, so will not be " "set on SpectralCoord" +) -@pytest.fixture +@pytest.fixture() def wcs_4d(): - header = dedent("""\ + header = dedent( + """\ WCSAXES = 4 / Number of coordinate axes CRPIX1 = 0.0 / Pixel coordinate of reference point CRPIX2 = 0.0 / Pixel coordinate of reference point @@ -41,13 +43,15 @@ def wcs_4d(): CRVAL4 = 0.0 / [deg] Coordinate value at reference point LONPOLE = 180.0 / [deg] Native longitude of celestial pole LATPOLE = 0.0 / [deg] Native latitude of celestial pole - """) - return WCS(header=fits.Header.fromstring(header, sep='\n')) + """ + ) + return WCS(header=fits.Header.fromstring(header, sep="\n")) -@pytest.fixture +@pytest.fixture() def wcs_3d(): - header = dedent("""\ + header = dedent( + """\ NAXIS = 3 / Number of data axes NAXIS1 = 205 / NAXIS2 = 77 / @@ -75,18 +79,22 @@ def wcs_3d(): PC3_2 = 0.000939457726278 / PC3_3 = 0.999988496304 / PC2_3 = -0.135178965950 / - """) - return WCS(header=fits.Header.fromstring(header, sep='\n')) - - -@pytest.mark.parametrize("data, slices, dim", ( - (np.arange(120).reshape((5, 4, 3, 2)), [0, 0, 'x', 'y'], 2), - (np.arange(120).reshape((5, 4, 3, 2)), [0, 'x', 0, 'y'], 2), - (np.arange(120).reshape((5, 4, 3, 2)), ['x', 0, 0, 'y'], 2), - (np.arange(120).reshape((5, 4, 3, 2)), ['y', 0, 'x', 0], 2), - (np.arange(120).reshape((5, 4, 3, 2)), ['x', 'y', 0, 0], 2), - (np.arange(120).reshape((5, 4, 3, 2)), [0, 0, 0, 'x'], 1), -)) + """ + ) + return WCS(header=fits.Header.fromstring(header, sep="\n")) + + +@pytest.mark.parametrize( + ("data", "slices", "dim"), + [ + (np.arange(120).reshape((5, 4, 3, 2)), [0, 0, "x", "y"], 2), + (np.arange(120).reshape((5, 4, 3, 2)), [0, "x", 0, "y"], 2), + (np.arange(120).reshape((5, 4, 3, 2)), ["x", 0, 0, "y"], 2), + (np.arange(120).reshape((5, 4, 3, 2)), ["y", 0, "x", 0], 2), + (np.arange(120).reshape((5, 4, 3, 2)), ["x", "y", 0, 0], 2), + (np.arange(120).reshape((5, 4, 3, 2)), [0, 0, 0, "x"], 1), + ], +) def test_construct_array_animator(wcs_4d, data, slices, dim): array_animator = ArrayAnimatorWCS(data, wcs_4d, slices) @@ -94,7 +102,7 @@ def test_construct_array_animator(wcs_4d, data, slices, dim): assert array_animator.plot_dimensionality == dim assert array_animator.num_sliders == data.ndim - dim for i, (wslice, arange) in enumerate(zip(slices, array_animator.axis_ranges[::-1])): - if wslice not in ['x', 'y']: + if wslice not in ["x", "y"]: assert callable(arange) a = arange(0) if "pos" in wcs_4d.world_axis_physical_types[i]: @@ -110,41 +118,40 @@ def test_construct_array_animator(wcs_4d, data, slices, dim): def test_constructor_errors(wcs_4d): # WCS is not BaseLowLevelWCS - with pytest.raises(ValueError, match="provided that implements the astropy WCS API."): - ArrayAnimatorWCS(np.arange(25).reshape((5, 5)), {}, ['x', 'y']) + with pytest.raises(TypeError, match="provided that implements the astropy WCS API."): + ArrayAnimatorWCS(np.arange(25).reshape((5, 5)), {}, ["x", "y"]) # Data has wrong number of dimensions with pytest.raises(ValueError, match="Dimensionality of the data and WCS object do not match."): - ArrayAnimatorWCS(np.arange(25).reshape((5, 5)), wcs_4d, ['x', 'y']) + ArrayAnimatorWCS(np.arange(25).reshape((5, 5)), wcs_4d, ["x", "y"]) # Slices is wrong length with pytest.raises(ValueError, match="slices should be the same length"): - ArrayAnimatorWCS(np.arange(16).reshape((2, 2, 2, 2)), wcs_4d, ['x', 'y']) + ArrayAnimatorWCS(np.arange(16).reshape((2, 2, 2, 2)), wcs_4d, ["x", "y"]) # x not in slices with pytest.raises(ValueError, match="slices should contain at least"): - ArrayAnimatorWCS(np.arange(16).reshape((2, 2, 2, 2)), wcs_4d, [0, 0, 0, 'y']) + ArrayAnimatorWCS(np.arange(16).reshape((2, 2, 2, 2)), wcs_4d, [0, 0, 0, "y"]) @figure_test def test_array_animator_wcs_2d_simple_plot(wcs_4d): data = np.arange(120).reshape((5, 4, 3, 2)) - a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, 'x', 'y']) + a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, "x", "y"]) return a.fig @figure_test def test_array_animator_wcs_2d_clip_interval(wcs_4d): data = np.arange(120).reshape((5, 4, 3, 2)) - a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, 'x', 'y'], clip_interval=(1, 99)*u.percent) + a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, "x", "y"], clip_interval=(1, 99) * u.percent) return a.fig def test_array_animator_wcs_2d_clip_interval_change(wcs_4d): data = np.arange(120).reshape((5, 4, 3, 2)) pclims = [5, 95] - a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, 'x', 'y'], - clip_interval=pclims * u.percent) + a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, "x", "y"], clip_interval=pclims * u.percent) lims0 = a._get_2d_plot_limits() a.update_plot(1, a.im, a.sliders[0]._slider) lims1 = a._get_2d_plot_limits() @@ -156,20 +163,20 @@ def test_array_animator_wcs_2d_clip_interval_change(wcs_4d): @figure_test def test_array_animator_wcs_2d_celestial_sliders(wcs_4d): data = np.arange(120).reshape((5, 4, 3, 2)) - a = ArrayAnimatorWCS(data, wcs_4d, ['x', 'y', 0, 0]) + a = ArrayAnimatorWCS(data, wcs_4d, ["x", "y", 0, 0]) return a.fig def test_to_axes(wcs_4d): data = np.arange(120).reshape((5, 4, 3, 2)) - a = ArrayAnimatorWCS(data, wcs_4d, ['x', 'y', 0, 0]) + a = ArrayAnimatorWCS(data, wcs_4d, ["x", "y", 0, 0]) assert isinstance(a.axes, WCSAxes) @figure_test def test_array_animator_wcs_2d_update_plot(wcs_4d): data = np.arange(120).reshape((5, 4, 3, 2)) - a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, 'x', 'y']) + a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, "x", "y"]) a.update_plot(1, a.im, a.sliders[0]._slider) return a.fig @@ -177,7 +184,7 @@ def test_array_animator_wcs_2d_update_plot(wcs_4d): @figure_test def test_array_animator_wcs_2d_transpose_update_plot(wcs_4d): data = np.arange(120).reshape((5, 4, 3, 2)) - a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, 'y', 'x'], colorbar=True) + a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, "y", "x"], colorbar=True) a.update_plot(1, a.im, a.sliders[0]._slider) return a.fig @@ -186,9 +193,8 @@ def test_array_animator_wcs_2d_transpose_update_plot(wcs_4d): def test_array_animator_wcs_2d_colorbar_buttons(wcs_4d): data = np.arange(120).reshape((5, 4, 3, 2)) bf = [lambda x: x] * 10 - bl = ['h'] * 10 - a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, 'y', 'x'], - colorbar=True, button_func=bf, button_labels=bl) + bl = ["h"] * 10 + a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, "y", "x"], colorbar=True, button_func=bf, button_labels=bl) a.update_plot(1, a.im, a.sliders[0]._slider) return a.fig @@ -197,7 +203,7 @@ def test_array_animator_wcs_2d_colorbar_buttons(wcs_4d): def test_array_animator_wcs_2d_colorbar_buttons_default_labels(wcs_4d): data = np.arange(120).reshape((5, 4, 3, 2)) bf = [lambda x: x] * 10 - a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, 'y', 'x'], colorbar=True, button_func=bf) + a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, "y", "x"], colorbar=True, button_func=bf) a.update_plot(1, a.im, a.sliders[0]._slider) return a.fig @@ -211,9 +217,14 @@ def vmax_slider(val, im, slider): im.set_clim(vmax=val) data = np.arange(120).reshape((5, 4, 3, 2)) - a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, 'y', 'x'], colorbar=True, - slider_functions=[vmin_slider, vmax_slider], - slider_ranges=[[0, 100], [0, 100]]) + a = ArrayAnimatorWCS( + data, + wcs_4d, + [0, 0, "y", "x"], + colorbar=True, + slider_functions=[vmin_slider, vmax_slider], + slider_ranges=[[0, 100], [0, 100]], + ) a.update_plot(1, a.im, a.sliders[0]._slider) return a.fig @@ -221,7 +232,7 @@ def vmax_slider(val, im, slider): @figure_test def test_array_animator_wcs_1d_update_plot(wcs_4d): data = np.arange(120).reshape((5, 4, 3, 2)) - a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, 'x', 0], ylabel="Y axis!") + a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, "x", 0], ylabel="Y axis!") a.sliders[0]._slider.set_val(1) return a.fig @@ -240,7 +251,7 @@ def test_array_animator_wcs_1d_update_plot_masked(wcs_3d): # Check that the generated data satisfies the test condition assert data.mask[0, 0].all() - a = ArrayAnimatorWCS(data, wcs_3d, ['x', 0, 0], ylabel="Y axis!") + a = ArrayAnimatorWCS(data, wcs_3d, ["x", 0, 0], ylabel="Y axis!") a.sliders[0]._slider.set_val(wcs_3d.array_shape[0] / 2) return a.fig @@ -248,50 +259,37 @@ def test_array_animator_wcs_1d_update_plot_masked(wcs_3d): @figure_test def test_array_animator_wcs_coord_params(wcs_4d): - coord_params = { - 'hpln': { - 'format_unit': u.deg, - 'major_formatter': 'hh:mm:ss', - 'axislabel': 'Longitude', - 'ticks': {'spacing': 10*u.arcsec} + "hpln": { + "format_unit": u.deg, + "major_formatter": "hh:mm:ss", + "axislabel": "Longitude", + "ticks": {"spacing": 10 * u.arcsec}, } } data = np.arange(120).reshape((5, 4, 3, 2)) - a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, 'x', 'y'], coord_params=coord_params) + a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, "x", "y"], coord_params=coord_params) return a.fig @figure_test def test_array_animator_wcs_coord_params_no_ticks(wcs_4d): - coord_params = { - 'hpln': { - 'format_unit': u.deg, - 'major_formatter': 'hh:mm:ss', - 'axislabel': 'Longitude', - 'ticks': False - } + "hpln": {"format_unit": u.deg, "major_formatter": "hh:mm:ss", "axislabel": "Longitude", "ticks": False} } data = np.arange(120).reshape((5, 4, 3, 2)) - a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, 'x', 'y'], coord_params=coord_params) + a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, "x", "y"], coord_params=coord_params) return a.fig @figure_test def test_array_animator_wcs_coord_params_grid(wcs_4d): - coord_params = { - 'hpln': { - 'format_unit': u.deg, - 'major_formatter': 'hh:mm:ss', - 'axislabel': 'Longitude', - 'grid': True - } + "hpln": {"format_unit": u.deg, "major_formatter": "hh:mm:ss", "axislabel": "Longitude", "grid": True} } data = np.arange(120).reshape((5, 4, 3, 2)) - a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, 'x', 'y'], coord_params=coord_params) + a = ArrayAnimatorWCS(data, wcs_4d, [0, 0, "x", "y"], coord_params=coord_params) return a.fig diff --git a/mpl_animators/wcs.py b/mpl_animators/wcs.py index 16cdfcf..bb05fb0 100644 --- a/mpl_animators/wcs.py +++ b/mpl_animators/wcs.py @@ -9,12 +9,13 @@ from .base import ArrayAnimator -__all__ = ['ArrayAnimatorWCS'] +__all__ = ["ArrayAnimatorWCS"] class ArrayAnimatorWCS(ArrayAnimator): """ - Animate an array with associated `~astropy.wcs.wcsapi.BaseLowLevelWCS` object. + Animate an array with associated `~astropy.wcs.wcsapi.BaseLowLevelWCS` + object. The following keyboard shortcuts are defined in the viewer: @@ -64,17 +65,29 @@ class ArrayAnimatorWCS(ArrayAnimator): If provided, the data for each step will be clipped to the percentile interval bounded by the two numbers. """ - def __init__(self, data, wcs, slices, coord_params=None, ylim='dynamic', ylabel=None, - clip_interval: u.percent = None, **kwargs): + def __init__( + self, + data, + wcs, + slices, + coord_params=None, + ylim="dynamic", + ylabel=None, + clip_interval: u.percent = None, + **kwargs, + ): if not isinstance(wcs, BaseLowLevelWCS): - raise ValueError("A WCS object should be provided that implements the astropy WCS API.") + msg = "A WCS object should be provided that implements the astropy WCS API." + raise TypeError(msg) if wcs.pixel_n_dim != data.ndim: - raise ValueError("Dimensionality of the data and WCS object do not match.") + msg = "Dimensionality of the data and WCS object do not match." + raise ValueError(msg) if len(slices) != wcs.pixel_n_dim: - raise ValueError("slices should be the same length as the number of pixel dimensions.") + msg = "slices should be the same length as the number of pixel dimensions." + raise ValueError(msg) if "x" not in slices: - raise ValueError( - "slices should contain at least 'x' to indicate the axis to plot on the x axis.") + msg = "slices should contain at least 'x' to indicate the axis to plot on the x axis." + raise ValueError(msg) self.plot_dimensionality = 1 @@ -92,23 +105,23 @@ def __init__(self, data, wcs, slices, coord_params=None, ylim='dynamic', ylabel= self.ylabel = ylabel if clip_interval is not None and len(clip_interval) != 2: - raise ValueError('A range of 2 values must be specified for clip_interval.') + msg = "A range of 2 values must be specified for clip_interval." + raise ValueError(msg) self.clip_interval = clip_interval extra_slider_labels = [] if "slider_functions" in kwargs and "slider_labels" not in kwargs: - extra_slider_labels = [a.__name__ for a in kwargs['slider_functions']] + extra_slider_labels = [a.__name__ for a in kwargs["slider_functions"]] slider_labels = self._compute_slider_labels_from_wcs(slices) + extra_slider_labels - super().__init__(data, image_axes=image_axes, axis_ranges=None, - slider_labels=slider_labels, - **kwargs) + super().__init__(data, image_axes=image_axes, axis_ranges=None, slider_labels=slider_labels, **kwargs) def _get_wcs_labels(self): """ - Read first the axes names property of the wcs and fall back to physical types. + Read first the axes names property of the wcs and fall back to physical + types. """ # Return the name if it is set, or the physical type if it is not. return [l or t for l, t in zip(self.wcs.world_axis_names, self.wcs.world_axis_physical_types)] @@ -116,22 +129,23 @@ def _get_wcs_labels(self): def _compute_slider_labels_from_wcs(self, slices): """ For each pixel dimension, not used in the plot, calculate the world - names which are correlated with that pixel dimension. This can return - more than one world name per pixel dimension (i.e. lat & lon) so join - them if there are. + names which are correlated with that pixel dimension. + + This can return more than one world name per pixel dimension + (i.e. lat & lon) so join them if there are. """ labels = [] wal = np.array(self._get_wcs_labels()) - pixel_indicies = np.array([a not in ['x', 'y'] for a in slices]) + pixel_indicies = np.array([a not in ["x", "y"] for a in slices]) for sliced_axis in self.wcs.axis_correlation_matrix[:, pixel_indicies].T: - labels.append(" / ".join(list(map(str, wal[sliced_axis])))) + labels.append(" / ".join([str(wal[axis]) for axis in sliced_axis])) return labels[::-1] def _partial_pixel_to_world(self, pixel_dimension, pixel_coord): """ - Return the world coordinate along one axis, if it is only - correlated to that axis. + Return the world coordinate along one axis, if it is only correlated to + that axis. """ wcs_dimension = self.wcs.pixel_n_dim - pixel_dimension - 1 corr = self.wcs.axis_correlation_matrix[:, wcs_dimension] @@ -195,13 +209,11 @@ def _apply_coord_params(self, axes): elif isinstance(ticks, dict): coord.set_ticks(**ticks) else: - raise TypeError( - "The 'ticks' value in the coord_params dictionary must be a dict or a boolean." - ) + msg = "The 'ticks' value in the coord_params dictionary must be a dict or a boolean." + raise TypeError(msg) def _setup_main_axes(self): - self.axes = self.fig.add_axes([0.1, 0.1, 0.8, 0.8], projection=self.wcs, - slices=self.slices_wcsaxes) + self.axes = self.fig.add_axes([0.1, 0.1, 0.8, 0.8], projection=self.wcs, slices=self.slices_wcsaxes) self._apply_coord_params(self.axes) def plot_start_image(self, ax): @@ -217,12 +229,12 @@ def update_plot(self, val, artist, slider): """ Update the plot when a slider changes. - This method both updates the state of the Animator and also re-draws - the matplotlib artist. + This method both updates the state of the Animator and also re- + draws the matplotlib artist. """ ind = int(val) if ind == int(slider.cval): - return + return None ax_ind = self.slider_axes[slider.slider_ind] self.frame_slice[ax_ind] = ind self.slices_wcsaxes[self.wcs.pixel_n_dim - ax_ind - 1] = ind @@ -239,11 +251,12 @@ def plot_start_image_1d(self, ax): """ Set up a line plot. - When plotting with WCSAxes, we always plot against pixel coordinate. + When plotting with WCSAxes, we always plot against pixel + coordinate. """ - if self.ylim != 'dynamic': + if self.ylim != "dynamic": ylim = self.ylim - if ylim == 'fixed': + if ylim == "fixed": ylim = (float(self.data.min()), float(self.data.max())) ax.set_ylim(ylim) @@ -251,7 +264,7 @@ def plot_start_image_1d(self, ax): ax.set_ylabel(self.ylabel) ydata = self.data[self.frame_index] - line, = ax.plot(ydata, **self.imshow_kwargs) + (line,) = ax.plot(ydata, **self.imshow_kwargs) if isinstance(self.data, np.ma.MaskedArray): ax.set_xlim((0, ydata.shape[0])) @@ -263,10 +276,9 @@ def data_transposed(self): """ Return data for 2D plotting, transposed if needed. """ - if self.slices_wcsaxes.index('y') < self.slices_wcsaxes.index("x"): + if self.slices_wcsaxes.index("y") < self.slices_wcsaxes.index("x"): return self.data[self.frame_index].transpose() - else: - return self.data[self.frame_index] + return self.data[self.frame_index] def update_plot_1d(self, val, line, slider): """ @@ -276,27 +288,25 @@ def update_plot_1d(self, val, line, slider): line.set_ydata(self.data[self.frame_index]) # If we are not setting ylim globally then we set it per frame. - if self.ylim == 'dynamic': - self.axes.set_ylim(float(self.data[self.frame_index].min()), - float(self.data[self.frame_index].max())) + if self.ylim == "dynamic": + self.axes.set_ylim(float(self.data[self.frame_index].min()), float(self.data[self.frame_index].max())) slider.cval = val def plot_start_image_2d(self, ax): """ Setup an image plot. """ - imshow_args = {'interpolation': 'nearest', - 'origin': 'lower'} + imshow_args = {"interpolation": "nearest", "origin": "lower"} imshow_args.update(self.imshow_kwargs) if self.clip_interval is not None: - imshow_args['vmin'], imshow_args['vmax'] = self._get_2d_plot_limits() + imshow_args["vmin"], imshow_args["vmax"] = self._get_2d_plot_limits() im = modest_image.imshow(ax, self.data_transposed, **imshow_args) - if 'extent' in imshow_args: - ax.set_xlim(imshow_args['extent'][:2]) - ax.set_ylim(imshow_args['extent'][2:]) + if "extent" in imshow_args: + ax.set_xlim(imshow_args["extent"][:2]) + ax.set_ylim(imshow_args["extent"][2:]) else: ny, nx = self.data_transposed.shape ax.set_xlim(-0.5, nx - 0.5) @@ -314,7 +324,7 @@ def _get_2d_plot_limits(self): """ Get vmin, vmax of a data slice when clip_interval is specified. """ - percent_limits = self.clip_interval.to('%').value + percent_limits = self.clip_interval.to("%").value vmin, vmax = AsymmetricPercentileInterval(*percent_limits).get_limits(self.data_transposed) return vmin, vmax diff --git a/pyproject.toml b/pyproject.toml index 953cec9..908b401 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,10 +1,79 @@ [build-system] +requires = [ + "setuptools", + "setuptools_scm[toml]", + "wheel", +] +build-backend = 'setuptools.build_meta' -requires = ["setuptools", - "setuptools_scm", - "wheel"] +[project] +name = "mpl_animators" +dynamic = ["version"] +description = "An interactive animation framework for matplotlib" +readme = "README.rst" +requires-python = ">=3.9" +license = {file = "LICENSE.txt"} +keywords = ["matplotlib", "animations", "mutli-dimensional", "interactive"] +authors = [ + {email = "sunpy@googlegroups.com"}, + {name = "The SunPy Developers"} +] +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: BSD License", + "Natural Language :: English", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering :: Physics", +] +dependencies = [ + 'matplotlib>=3.5.0', + 'numpy>=1.21.0', +] -build-backend = 'setuptools.build_meta' +[project.urls] +changelog = "https://github.com/sunpy/mpl-animators/releases" +documentation = "https://docs.sunpy.org/projects/mpl-animators" +issue_tracker = "https://github.com/sunpy/mpl-animators/issues" +repository = "https://github.com/sunpy/mpl-animators" + +[project.optional-dependencies] +all = ["mpl-animators"] +tests = [ + "mpl-animators[all]", + 'pytest-astropy', + "pytest-mpl", +] +docs = [ + "mpl-animators[all]", + 'sphinx', + 'sphinx-automodapi', + 'sphinx-gallery', + 'sunpy-sphinx-theme', + 'sunpy[all]', +] + +[tool.setuptools_scm] +write_to = "mpl_animators/version.py" + +[tool.setuptools] +include-package-data = true +platforms = ["any"] +provides = ["mpl_animators"] +license-files = ["LICENSE.rst"] + +[tool.setuptools.packages.find] +namespaces = false + +[tool.codespell] +skip = "*.asdf,*.fits,*.fts,*.header,*.json,*.xsh,*cache*,*egg*,*extern*,.git,.idea,.tox,_build,*truncated,*.svg,.asv_env,.history" +ignore-words-list = "sav,nd" [ tool.gilesbot ] [ tool.gilesbot.circleci_artifacts ] diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..52f2dc5 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,27 @@ +[pytest] +minversion = 7.0 +testpaths = + mpl_animators + docs +norecursedirs = + .tox + build + docs/_build + docs/generated + *.egg-info + examples + .history +doctest_plus = enabled +doctest_optionflags = NORMALIZE_WHITESPACE FLOAT_CMP ELLIPSIS +addopts = --arraydiff --doctest-rst --doctest-ignore-import-errors -p no:unraisableexception -p no:threadexception +remote_data_strict = true +junit_family = xunit1 +filterwarnings = + error + # Do not fail on pytest config issues (i.e. missing plugins) but do show them + always::pytest.PytestConfigWarning + # A list of warnings to ignore follows. If you add to this list, you MUST + # add a comment or ideally a link to an issue that explains why the warning + # is being ignored + ignore:.*utcfromtimestamp.*:DeprecationWarning + ignore:.*may indicate binary incompatibility.*:RuntimeWarning \ No newline at end of file diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 0000000..ece1ad4 --- /dev/null +++ b/ruff.toml @@ -0,0 +1,80 @@ +# Allow unused variables when underscore-prefixed. +lint.dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" +target-version = "py39" +line-length = 120 +exclude=[ + ".git,", + "__pycache__", + "build", + "tools/**", +] +lint.select = [ + "A", + "ASYNC", + "B", + "BLE", + "C4", + "COM", + "DTZ", + "E", + "EM", + "ERA", + "EXE", + "F", + "FBT", + "FLY", + "G", + "I", + "ICN", + "INP", + "INT", + "ISC", + "LOG", + "NPY", + "PGH", + "PIE", + "PLE", + "PT", + "PTH", + "PYI", + "Q", + "RET", + "RSE", + "RUF", + "SIM", + "SLOT", + "T10", + "T20", + "TCH", + "TID", + "TRIO", + "TRY", + "UP", + "W", + "YTT", +] +lint.extend-ignore = [ + # TODO: Fix in future + "E501", # Line too long + "E741", # Ambiguous variable name + "FBT002", # Bool arg + "COM812", # May cause conflicts when used with the formatter + "ISC001", # May cause conflicts when used with the formatter +] + +[lint.per-file-ignores] +"examples/*.py" = [ + "INP001", # examples is part of an implicit namespace package + "T201", # We need print in our examples +] +"docs/conf.py" = [ + "INP001", # conf.py is part of an implicit namespace package +] + +[lint.pydocstyle] +convention = "numpy" + +[format] +docstring-code-format = true +indent-style = "space" +quote-style = "double" diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 78caf6d..0000000 --- a/setup.cfg +++ /dev/null @@ -1,72 +0,0 @@ -[metadata] -name = mpl_animators -author = The SunPy Developers -author_email = sunpy@googlegroups.com -license = BSD 3-Clause -license_file = LICENSE.rst -url = https://sunpy.org -description = An interactive animation framework for matplotlib -long_description = file: README.rst -long_description_content_type = text/x-rst - -[options] -zip_safe = False -packages = find: -include_package_data = True -python_requires = >=3.9 -setup_requires = setuptools_scm -install_requires = - matplotlib>=3.5.0 - numpy>=1.21.0 - -[options.extras_require] -all = - astropy>=5.0.6,!=5.1.0 -wcs = - astropy>=5.0.6,!=5.1.0 -tests = - pytest - pytest-cov - pytest-mpl -docs = - sphinx - sphinx-automodapi - sunpy-sphinx-theme - -[tool:pytest] -testpaths = "mpl_animators" "docs" -mpl-results-path = figure_test_images -mpl-use-full-test-name = True - -[coverage:run] -omit = - mpl_animators/__init* - mpl_animators/conftest.py - mpl_animators/*setup_package* - mpl_animators/tests/* - mpl_animators/*/tests/* - mpl_animators/extern/* - mpl_animators/version* - */mpl_animators/__init* - */mpl_animators/conftest.py - */mpl_animators/*setup_package* - */mpl_animators/tests/* - */mpl_animators/*/tests/* - */mpl_animators/extern/* - */mpl_animators/version* - -[coverage:report] -exclude_lines = - # Have to re-enable the standard pragma - pragma: no cover - # Don't complain about packages we have installed - except ImportError - # Don't complain if tests don't hit assertions - raise AssertionError - raise NotImplementedError - # Don't complain about script hooks - def main\(.*\): - # Ignore branches that don't pertain to this version of Python - pragma: py{ignore_python_version} - # Don't complain about IPython completion helper - def _ipython_key_completions_ diff --git a/setup.py b/setup.py index e11957f..c823345 100755 --- a/setup.py +++ b/setup.py @@ -1,22 +1,4 @@ #!/usr/bin/env python -# Licensed under a 3-clause BSD style license - see LICENSE.rst - -import os - from setuptools import setup -VERSION_TEMPLATE = """ -# Note that we need to fall back to the hard-coded version if either -# setuptools_scm can't be imported or setuptools_scm can't determine the -# version, so we catch the generic 'Exception'. -try: - from setuptools_scm import get_version - __version__ = get_version(root='..', relative_to=__file__) -except Exception: - __version__ = '{version}' -""".lstrip() - -setup( - use_scm_version={'write_to': os.path.join('mpl_animators', 'version.py'), - 'write_to_template': VERSION_TEMPLATE}, -) +setup() diff --git a/tox.ini b/tox.ini index 3a7b253..207f1b5 100644 --- a/tox.ini +++ b/tox.ini @@ -1,7 +1,7 @@ [tox] minversion = 4.0 envlist = - py{39,310,311}{,-devdeps,-figure} + py{39,310,311,312}{,-devdeps,-figure} build_docs codestyle @@ -22,7 +22,8 @@ extras = all tests setenv = - PYTEST_COMMAND = pytest -vvv -s -raR --pyargs mpl_animators --cov-report=xml --cov=mpl_animators --cov-config={toxinidir}/setup.cfg {toxinidir}/docs + PYTEST_COMMAND = pytest -vvv -s -raR --pyargs mpl_animators --cov-report=xml --cov=mpl_animators {toxinidir}/docs + MPL_BACKEND = agg allowlist_externals = /bin/bash commands = pip freeze From 3b4cf6434a71840597406f00e8940b73a87abdbf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 7 Feb 2024 22:37:31 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .pre-commit-config.yaml | 3 +++ docs/api.rst | 1 + docs/conf.py | 23 +++++++++++++---------- docs/index.rst | 9 ++++----- pytest.ini | 3 +-- tox.ini | 3 +++ 6 files changed, 25 insertions(+), 17 deletions(-) create mode 100644 docs/api.rst diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 685b486..c2dc8da 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,3 +1,6 @@ +ci: + autofix_prs: false + autoupdate_schedule: "quarterly" repos: - repo: https://github.com/myint/docformatter rev: v1.7.5 diff --git a/docs/api.rst b/docs/api.rst new file mode 100644 index 0000000..1ce0bf3 --- /dev/null +++ b/docs/api.rst @@ -0,0 +1 @@ +.. automodapi:: mpl_animators diff --git a/docs/conf.py b/docs/conf.py index 2da588c..2426097 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -12,22 +12,24 @@ # -- General configuration --------------------------------------------------- extensions = [ + "sphinx_gallery.gen_gallery", + "sphinx_automodapi.automodapi", + "sphinx_automodapi.smart_resolver", "sphinx.ext.autodoc", - "sphinx.ext.intersphinx", - "sphinx.ext.todo", "sphinx.ext.coverage", - "sphinx.ext.inheritance_diagram", - "sphinx.ext.viewcode", - "sphinx.ext.napoleon", "sphinx.ext.doctest", + "sphinx.ext.inheritance_diagram", + "sphinx.ext.intersphinx", "sphinx.ext.mathjax", - "sphinx_automodapi.automodapi", - "sphinx_automodapi.smart_resolver", + "sphinx.ext.napoleon", + "sphinx.ext.todo", + "sphinx.ext.viewcode", ] exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] source_suffix = ".rst" master_doc = "index" default_role = "obj" +html_theme = "sunpy" # -- Options for intersphinx extension --------------------------------------- intersphinx_mapping = { @@ -48,14 +50,15 @@ # -- Sphinx Gallery ------------------------------------------------------------ sphinx_gallery_conf = { - "backreferences_dir": Path("generated") / "modules", + "backreferences_dir": (Path("generated") / "modules").absolute(), "filename_pattern": "^((?!skip_).)*$", - "examples_dirs": Path("..") / "examples", - "gallery_dirs": Path("generated") / "gallery", + "examples_dirs": (Path("..") / "examples").absolute(), + "gallery_dirs": (Path("generated") / "gallery").absolute(), "matplotlib_animations": True, "default_thumb_file": PNG_ICON, "abort_on_example_error": False, "plot_gallery": "True", "remove_config_comments": True, + "doc_module": ("mpl_animators"), "only_warn_on_example_error": True, } diff --git a/docs/index.rst b/docs/index.rst index ce825b0..5621d59 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -6,9 +6,8 @@ The ``mpl_animators`` package provides a set of classes which allow the easy con As well as this there is a specialised `.ArrayAnimatorWCS` class which can make line or image plots for a numpy array and associated World Coordinate System (WCS) object from `astropy`. Finally, there are two base classes: `.BaseFuncAnimator` which can be extended to generate an interactive visualization from any data structure and set of functions to update the plot, and `.ArrayAnimator` which can be extended to generate any visualisation based on the axes of a numpy array. - -.. automodapi:: mpl_animators - .. toctree:: - :maxdepth: 2 - :caption: Contents: + :maxdepth: 1 + + generated/gallery/index + api diff --git a/pytest.ini b/pytest.ini index 52f2dc5..3bcd18e 100644 --- a/pytest.ini +++ b/pytest.ini @@ -14,7 +14,6 @@ norecursedirs = doctest_plus = enabled doctest_optionflags = NORMALIZE_WHITESPACE FLOAT_CMP ELLIPSIS addopts = --arraydiff --doctest-rst --doctest-ignore-import-errors -p no:unraisableexception -p no:threadexception -remote_data_strict = true junit_family = xunit1 filterwarnings = error @@ -24,4 +23,4 @@ filterwarnings = # add a comment or ideally a link to an issue that explains why the warning # is being ignored ignore:.*utcfromtimestamp.*:DeprecationWarning - ignore:.*may indicate binary incompatibility.*:RuntimeWarning \ No newline at end of file + ignore:.*may indicate binary incompatibility.*:RuntimeWarning diff --git a/tox.ini b/tox.ini index 207f1b5..f3c441d 100644 --- a/tox.ini +++ b/tox.ini @@ -18,6 +18,9 @@ deps = oldestdeps: matplotlib==3.5.0 figure-!devdeps-!oldestdeps: astropy==5.3.4 figure-!devdeps-!oldestdeps: matplotlib==3.8.1 + # Due to https://github.com/matplotlib/pytest-mpl/issues/216 + # Needs a release of pytest-mpl with the fix + figure: pluggy<1.4 extras = all tests