You know, I don't think this makes it more readable. The problem is really that MPL's typed interface isn't very good, especially when dealing with 3D graphs, so I've gotta do a bunch of workaround to get around it.
This commit is contained in:
parent
8dcb9b21c4
commit
a72b709214
1 changed files with 25 additions and 7 deletions
32
data.py
32
data.py
|
|
@ -1,10 +1,13 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, cast, Literal
|
||||
|
||||
import matplotlib.animation as animation
|
||||
import matplotlib.cm as cm
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from matplotlib.collections import PathCollection
|
||||
from matplotlib.colors import LinearSegmentedColormap
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from physics import n_body_matrix
|
||||
|
|
@ -45,6 +48,17 @@ def parse_csv(file: Path) -> tuple[NDArray, NDArray, NDArray]:
|
|||
return pos, vel, rad
|
||||
|
||||
|
||||
class Axis(ABC):
|
||||
"""Improved types for ``matplotlib.axes.Axes``"""
|
||||
@abstractmethod
|
||||
def axis(self, _arg: tuple[float, float, float, float, float, float] | tuple[float, float, float, float], /) -> Any: ...
|
||||
@abstractmethod
|
||||
def scatter(self, x: NDArray, y: NDArray, z: NDArray | None = None, *, c: NDArray | None = None, s: NDArray | None = None, depthshade: bool = True) -> PathCollection: ...
|
||||
|
||||
|
||||
type Set3DProperties = Callable[[float | list[float] | NDArray, Literal["x", "y", "z"]], None]
|
||||
|
||||
|
||||
class Animator:
|
||||
"""Runs the simulation and displays it in a plot"""
|
||||
def __init__(self, pos: NDArray, vel: NDArray, rad: NDArray, gravity: float) -> None:
|
||||
|
|
@ -69,7 +83,7 @@ class Animator:
|
|||
# Objects will be represented using a scatter plot
|
||||
self._scatter_plot: plt.PathCollection | None = None
|
||||
# We'll give each objects a random colour to better differentiate them
|
||||
self._object_colours: NDArray = cm.rainbow(
|
||||
self._object_colours = cast(LinearSegmentedColormap, cast(Any, cm).rainbow)(
|
||||
np.random.random(
|
||||
(object_count,)
|
||||
)
|
||||
|
|
@ -78,9 +92,9 @@ class Animator:
|
|||
# Create plot in an appropriate number of dimensions
|
||||
self._fig = plt.figure()
|
||||
if dimensions == 2:
|
||||
self._ax = self._fig.add_subplot()
|
||||
self._ax = cast(Axis, self._fig.add_subplot())
|
||||
else:
|
||||
self._ax = self._fig.add_subplot(projection="3d")
|
||||
self._ax = cast(Axis, self._fig.add_subplot(projection="3d"))
|
||||
|
||||
# Set up animation loop
|
||||
# This attribute never gets used again, but we'll keep a reference so that it doesn't get garbage collected
|
||||
|
|
@ -111,7 +125,7 @@ class Animator:
|
|||
s=self._rad * 10, # To make the objects more visible
|
||||
)
|
||||
# These values work nicely for a landscape window
|
||||
self._ax.axis([-950, 950, -500, 500])
|
||||
self._ax.axis((-950., 950., -500., 500.))
|
||||
else:
|
||||
self._scatter_plot = self._ax.scatter(
|
||||
self._pos[:, 0],
|
||||
|
|
@ -122,7 +136,7 @@ class Animator:
|
|||
depthshade=False, # I find it confusing, YMMV
|
||||
)
|
||||
# These values work nicely for a square window
|
||||
self._ax.axis([-500, 500, -500, 500, -500, 500])
|
||||
self._ax.axis((-500., 500., -500., 500., -500., 500.))
|
||||
return self._scatter_plot,
|
||||
|
||||
def update(self, *_args, **_kwargs) -> tuple[PathCollection]:
|
||||
|
|
@ -132,6 +146,10 @@ class Animator:
|
|||
"arg _kwargs: Again, not necessary for us
|
||||
:return: The single scatter plot we're using
|
||||
"""
|
||||
# As long as this is called after setup_plot() we're good
|
||||
# Hey, dealing with race conditions is my day job, I will 100% ignore them on my on time
|
||||
assert self._scatter_plot
|
||||
|
||||
_n, dimensions = self._pos.shape
|
||||
# Update the state of our simulation
|
||||
n_body_matrix(self._pos, self._vel, self._mass, self._gravity)
|
||||
|
|
@ -140,7 +158,7 @@ class Animator:
|
|||
self._scatter_plot.set_offsets(self._pos[:, :2])
|
||||
if dimensions == 3:
|
||||
# Update the Z value if in 3D
|
||||
self._scatter_plot.set_3d_properties(self._pos[:, 2], 'z')
|
||||
cast(Set3DProperties, cast(Any, self._scatter_plot).set_3d_properties)(self._pos[:, 2], 'z')
|
||||
# Use radius to represent mass
|
||||
self._scatter_plot.set_sizes(self._rad[:, 0] * 10)
|
||||
# Redraw the plot
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue