From a72b7092146c322e364fff7f4591256984675445 Mon Sep 17 00:00:00 2001 From: Michael Bradley Date: Sun, 23 Feb 2025 11:38:32 -0500 Subject: [PATCH] Make MyPy happy You know, I don't think this makes it more readable. The problem is really that MPL's typed interface isn't very good, especially when dealing with 3D graphs, so I've gotta do a bunch of workaround to get around it. --- data.py | 32 +++++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/data.py b/data.py index a382516..87cdbe5 100644 --- a/data.py +++ b/data.py @@ -1,10 +1,13 @@ +from abc import ABC, abstractmethod from pathlib import Path +from typing import Any, Callable, cast, Literal import matplotlib.animation as animation import matplotlib.cm as cm import matplotlib.pyplot as plt -import numpy as np from matplotlib.collections import PathCollection +from matplotlib.colors import LinearSegmentedColormap +import numpy as np from numpy.typing import NDArray from physics import n_body_matrix @@ -45,6 +48,17 @@ def parse_csv(file: Path) -> tuple[NDArray, NDArray, NDArray]: return pos, vel, rad +class Axis(ABC): + """Improved types for ``matplotlib.axes.Axes``""" + @abstractmethod + def axis(self, _arg: tuple[float, float, float, float, float, float] | tuple[float, float, float, float], /) -> Any: ... + @abstractmethod + def scatter(self, x: NDArray, y: NDArray, z: NDArray | None = None, *, c: NDArray | None = None, s: NDArray | None = None, depthshade: bool = True) -> PathCollection: ... + + +type Set3DProperties = Callable[[float | list[float] | NDArray, Literal["x", "y", "z"]], None] + + class Animator: """Runs the simulation and displays it in a plot""" def __init__(self, pos: NDArray, vel: NDArray, rad: NDArray, gravity: float) -> None: @@ -69,7 +83,7 @@ class Animator: # Objects will be represented using a scatter plot self._scatter_plot: plt.PathCollection | None = None # We'll give each objects a random colour to better differentiate them - self._object_colours: NDArray = cm.rainbow( + self._object_colours = cast(LinearSegmentedColormap, cast(Any, cm).rainbow)( np.random.random( (object_count,) ) @@ -78,9 +92,9 @@ class Animator: # Create plot in an appropriate number of dimensions self._fig = plt.figure() if dimensions == 2: - self._ax = self._fig.add_subplot() + self._ax = cast(Axis, self._fig.add_subplot()) else: - self._ax = self._fig.add_subplot(projection="3d") + self._ax = cast(Axis, self._fig.add_subplot(projection="3d")) # Set up animation loop # This attribute never gets used again, but we'll keep a reference so that it doesn't get garbage collected @@ -111,7 +125,7 @@ class Animator: s=self._rad * 10, # To make the objects more visible ) # These values work nicely for a landscape window - self._ax.axis([-950, 950, -500, 500]) + self._ax.axis((-950., 950., -500., 500.)) else: self._scatter_plot = self._ax.scatter( self._pos[:, 0], @@ -122,7 +136,7 @@ class Animator: depthshade=False, # I find it confusing, YMMV ) # These values work nicely for a square window - self._ax.axis([-500, 500, -500, 500, -500, 500]) + self._ax.axis((-500., 500., -500., 500., -500., 500.)) return self._scatter_plot, def update(self, *_args, **_kwargs) -> tuple[PathCollection]: @@ -132,6 +146,10 @@ class Animator: "arg _kwargs: Again, not necessary for us :return: The single scatter plot we're using """ + # As long as this is called after setup_plot() we're good + # Hey, dealing with race conditions is my day job, I will 100% ignore them on my on time + assert self._scatter_plot + _n, dimensions = self._pos.shape # Update the state of our simulation n_body_matrix(self._pos, self._vel, self._mass, self._gravity) @@ -140,7 +158,7 @@ class Animator: self._scatter_plot.set_offsets(self._pos[:, :2]) if dimensions == 3: # Update the Z value if in 3D - self._scatter_plot.set_3d_properties(self._pos[:, 2], 'z') + cast(Set3DProperties, cast(Any, self._scatter_plot).set_3d_properties)(self._pos[:, 2], 'z') # Use radius to represent mass self._scatter_plot.set_sizes(self._rad[:, 0] * 10) # Redraw the plot