Make MyPy happy
All checks were successful
Lint / MyPy (push) Successful in 1m5s

You know, I don't think this makes it more readable.
The problem is really that MPL's typed interface isn't very good, especially when dealing with 3D graphs, so I've gotta do a bunch of workaround to get around it.
This commit is contained in:
Michael Bradley 2025-02-23 11:38:32 -05:00
parent 8dcb9b21c4
commit a72b709214
Signed by: MichaelBradley
SSH key fingerprint: SHA256:cj/YZ5VT+QOKncqSkx+ibKTIn0Obg7OIzwzl9BL8EO8

32
data.py
View file

@ -1,10 +1,13 @@
from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import Any, Callable, cast, Literal
import matplotlib.animation as animation import matplotlib.animation as animation
import matplotlib.cm as cm import matplotlib.cm as cm
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import PathCollection from matplotlib.collections import PathCollection
from matplotlib.colors import LinearSegmentedColormap
import numpy as np
from numpy.typing import NDArray from numpy.typing import NDArray
from physics import n_body_matrix from physics import n_body_matrix
@ -45,6 +48,17 @@ def parse_csv(file: Path) -> tuple[NDArray, NDArray, NDArray]:
return pos, vel, rad return pos, vel, rad
class Axis(ABC):
"""Improved types for ``matplotlib.axes.Axes``"""
@abstractmethod
def axis(self, _arg: tuple[float, float, float, float, float, float] | tuple[float, float, float, float], /) -> Any: ...
@abstractmethod
def scatter(self, x: NDArray, y: NDArray, z: NDArray | None = None, *, c: NDArray | None = None, s: NDArray | None = None, depthshade: bool = True) -> PathCollection: ...
type Set3DProperties = Callable[[float | list[float] | NDArray, Literal["x", "y", "z"]], None]
class Animator: class Animator:
"""Runs the simulation and displays it in a plot""" """Runs the simulation and displays it in a plot"""
def __init__(self, pos: NDArray, vel: NDArray, rad: NDArray, gravity: float) -> None: def __init__(self, pos: NDArray, vel: NDArray, rad: NDArray, gravity: float) -> None:
@ -69,7 +83,7 @@ class Animator:
# Objects will be represented using a scatter plot # Objects will be represented using a scatter plot
self._scatter_plot: plt.PathCollection | None = None self._scatter_plot: plt.PathCollection | None = None
# We'll give each objects a random colour to better differentiate them # We'll give each objects a random colour to better differentiate them
self._object_colours: NDArray = cm.rainbow( self._object_colours = cast(LinearSegmentedColormap, cast(Any, cm).rainbow)(
np.random.random( np.random.random(
(object_count,) (object_count,)
) )
@ -78,9 +92,9 @@ class Animator:
# Create plot in an appropriate number of dimensions # Create plot in an appropriate number of dimensions
self._fig = plt.figure() self._fig = plt.figure()
if dimensions == 2: if dimensions == 2:
self._ax = self._fig.add_subplot() self._ax = cast(Axis, self._fig.add_subplot())
else: else:
self._ax = self._fig.add_subplot(projection="3d") self._ax = cast(Axis, self._fig.add_subplot(projection="3d"))
# Set up animation loop # Set up animation loop
# This attribute never gets used again, but we'll keep a reference so that it doesn't get garbage collected # This attribute never gets used again, but we'll keep a reference so that it doesn't get garbage collected
@ -111,7 +125,7 @@ class Animator:
s=self._rad * 10, # To make the objects more visible s=self._rad * 10, # To make the objects more visible
) )
# These values work nicely for a landscape window # These values work nicely for a landscape window
self._ax.axis([-950, 950, -500, 500]) self._ax.axis((-950., 950., -500., 500.))
else: else:
self._scatter_plot = self._ax.scatter( self._scatter_plot = self._ax.scatter(
self._pos[:, 0], self._pos[:, 0],
@ -122,7 +136,7 @@ class Animator:
depthshade=False, # I find it confusing, YMMV depthshade=False, # I find it confusing, YMMV
) )
# These values work nicely for a square window # These values work nicely for a square window
self._ax.axis([-500, 500, -500, 500, -500, 500]) self._ax.axis((-500., 500., -500., 500., -500., 500.))
return self._scatter_plot, return self._scatter_plot,
def update(self, *_args, **_kwargs) -> tuple[PathCollection]: def update(self, *_args, **_kwargs) -> tuple[PathCollection]:
@ -132,6 +146,10 @@ class Animator:
"arg _kwargs: Again, not necessary for us "arg _kwargs: Again, not necessary for us
:return: The single scatter plot we're using :return: The single scatter plot we're using
""" """
# As long as this is called after setup_plot() we're good
# Hey, dealing with race conditions is my day job, I will 100% ignore them on my on time
assert self._scatter_plot
_n, dimensions = self._pos.shape _n, dimensions = self._pos.shape
# Update the state of our simulation # Update the state of our simulation
n_body_matrix(self._pos, self._vel, self._mass, self._gravity) n_body_matrix(self._pos, self._vel, self._mass, self._gravity)
@ -140,7 +158,7 @@ class Animator:
self._scatter_plot.set_offsets(self._pos[:, :2]) self._scatter_plot.set_offsets(self._pos[:, :2])
if dimensions == 3: if dimensions == 3:
# Update the Z value if in 3D # Update the Z value if in 3D
self._scatter_plot.set_3d_properties(self._pos[:, 2], 'z') cast(Set3DProperties, cast(Any, self._scatter_plot).set_3d_properties)(self._pos[:, 2], 'z')
# Use radius to represent mass # Use radius to represent mass
self._scatter_plot.set_sizes(self._rad[:, 0] * 10) self._scatter_plot.set_sizes(self._rad[:, 0] * 10)
# Redraw the plot # Redraw the plot