Make MyPy happy
All checks were successful
Lint / MyPy (push) Successful in 1m5s

You know, I don't think this makes it more readable.
The problem is really that MPL's typed interface isn't very good, especially when dealing with 3D graphs, so I've gotta do a bunch of workaround to get around it.
This commit is contained in:
Michael Bradley 2025-02-23 11:38:32 -05:00
parent 8dcb9b21c4
commit a72b709214
Signed by: MichaelBradley
SSH key fingerprint: SHA256:cj/YZ5VT+QOKncqSkx+ibKTIn0Obg7OIzwzl9BL8EO8

32
data.py
View file

@ -1,10 +1,13 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Callable, cast, Literal
import matplotlib.animation as animation
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import PathCollection
from matplotlib.colors import LinearSegmentedColormap
import numpy as np
from numpy.typing import NDArray
from physics import n_body_matrix
@ -45,6 +48,17 @@ def parse_csv(file: Path) -> tuple[NDArray, NDArray, NDArray]:
return pos, vel, rad
class Axis(ABC):
"""Improved types for ``matplotlib.axes.Axes``"""
@abstractmethod
def axis(self, _arg: tuple[float, float, float, float, float, float] | tuple[float, float, float, float], /) -> Any: ...
@abstractmethod
def scatter(self, x: NDArray, y: NDArray, z: NDArray | None = None, *, c: NDArray | None = None, s: NDArray | None = None, depthshade: bool = True) -> PathCollection: ...
type Set3DProperties = Callable[[float | list[float] | NDArray, Literal["x", "y", "z"]], None]
class Animator:
"""Runs the simulation and displays it in a plot"""
def __init__(self, pos: NDArray, vel: NDArray, rad: NDArray, gravity: float) -> None:
@ -69,7 +83,7 @@ class Animator:
# Objects will be represented using a scatter plot
self._scatter_plot: plt.PathCollection | None = None
# We'll give each objects a random colour to better differentiate them
self._object_colours: NDArray = cm.rainbow(
self._object_colours = cast(LinearSegmentedColormap, cast(Any, cm).rainbow)(
np.random.random(
(object_count,)
)
@ -78,9 +92,9 @@ class Animator:
# Create plot in an appropriate number of dimensions
self._fig = plt.figure()
if dimensions == 2:
self._ax = self._fig.add_subplot()
self._ax = cast(Axis, self._fig.add_subplot())
else:
self._ax = self._fig.add_subplot(projection="3d")
self._ax = cast(Axis, self._fig.add_subplot(projection="3d"))
# Set up animation loop
# This attribute never gets used again, but we'll keep a reference so that it doesn't get garbage collected
@ -111,7 +125,7 @@ class Animator:
s=self._rad * 10, # To make the objects more visible
)
# These values work nicely for a landscape window
self._ax.axis([-950, 950, -500, 500])
self._ax.axis((-950., 950., -500., 500.))
else:
self._scatter_plot = self._ax.scatter(
self._pos[:, 0],
@ -122,7 +136,7 @@ class Animator:
depthshade=False, # I find it confusing, YMMV
)
# These values work nicely for a square window
self._ax.axis([-500, 500, -500, 500, -500, 500])
self._ax.axis((-500., 500., -500., 500., -500., 500.))
return self._scatter_plot,
def update(self, *_args, **_kwargs) -> tuple[PathCollection]:
@ -132,6 +146,10 @@ class Animator:
"arg _kwargs: Again, not necessary for us
:return: The single scatter plot we're using
"""
# As long as this is called after setup_plot() we're good
# Hey, dealing with race conditions is my day job, I will 100% ignore them on my on time
assert self._scatter_plot
_n, dimensions = self._pos.shape
# Update the state of our simulation
n_body_matrix(self._pos, self._vel, self._mass, self._gravity)
@ -140,7 +158,7 @@ class Animator:
self._scatter_plot.set_offsets(self._pos[:, :2])
if dimensions == 3:
# Update the Z value if in 3D
self._scatter_plot.set_3d_properties(self._pos[:, 2], 'z')
cast(Set3DProperties, cast(Any, self._scatter_plot).set_3d_properties)(self._pos[:, 2], 'z')
# Use radius to represent mass
self._scatter_plot.set_sizes(self._rad[:, 0] * 10)
# Redraw the plot