All checks were successful
Lint / MyPy (push) Successful in 1m5s
You know, I don't think this makes it more readable. The problem is really that MPL's typed interface isn't very good, especially when dealing with 3D graphs, so I've gotta do a bunch of workarounds to get around it.
166 lines
5.4 KiB
Python
166 lines
5.4 KiB
Python
from abc import ABC, abstractmethod
|
|
from pathlib import Path
|
|
from typing import Any, Callable, cast, Literal
|
|
|
|
import matplotlib.animation as animation
|
|
import matplotlib.cm as cm
|
|
import matplotlib.pyplot as plt
|
|
from matplotlib.collections import PathCollection
|
|
from matplotlib.colors import LinearSegmentedColormap
|
|
import numpy as np
|
|
from numpy.typing import NDArray
|
|
|
|
from physics import n_body_matrix
|
|
|
|
|
|
def parse_csv(file: Path) -> tuple[NDArray, NDArray, NDArray]:
    """
    Reads the starting conditions of a simulation from a CSV

    Each non-blank line describes one object, laid out as either ``x,y,vx,vy,r`` (2D)
    or ``x,y,z,vx,vy,vz,r`` (3D). The first line decides the layout and every other
    line must use the same one.

    :param file: The CSV file
    :return: The position, velocity, and radius matrices, in 2 or 3 dimensions.
        Positions and velocities have shape ``(objects, dimensions)``; radii have shape ``(objects, 1)``.
    :raises RuntimeError: If the file contains no objects, or its first line has neither 5 nor 7 fields
    :raises ValueError: If a later line has a different number of fields, or a field is not a number
    """
    # Get text from file, keeping the real line numbers so errors can point at the offending line.
    # Blank lines (including a stray one between records) are ignored.
    numbered_lines = [
        (number, line)
        for number, line in enumerate(file.read_text().splitlines(), start=1)
        if line.strip()
    ]
    if not numbered_lines:
        raise RuntimeError(f"CSV file {file} contains no objects")

    field_count = len(numbered_lines[0][1].split(","))
    if field_count not in (5, 7):
        raise RuntimeError("CSV format not recognized, can only show scenes in 2 or 3 dimensions")
    dimensions = (field_count - 1) // 2

    # Parse every line into one row of an (objects, fields) matrix
    rows: list[list[float]] = []
    for number, line in numbered_lines:
        fields = line.split(",")
        if len(fields) != field_count:
            raise ValueError(
                f"Line {number} of {file} has {len(fields)} fields, expected {field_count}"
            )
        rows.append([float(field) for field in fields])
    data = np.array(rows, dtype=np.float64)

    # Split the columns into position, velocity and radius. Copy each slice so every matrix
    # is its own contiguous array rather than a strided view into ``data``.
    pos = data[:, :dimensions].copy()
    vel = data[:, dimensions:2 * dimensions].copy()
    rad = data[:, 2 * dimensions:].copy()
    return pos, vel, rad
|
|
|
|
|
|
class Axis(ABC):
    """
    Improved types for ``matplotlib.axes.Axes``

    Plot axes are ``cast`` to this interface so the type checker sees narrower,
    more useful signatures for the handful of ``Axes`` methods this module calls.
    """

    @abstractmethod
    def axis(
        self,
        _arg: tuple[float, float, float, float, float, float] | tuple[float, float, float, float],
        /,
    ) -> Any:
        """Set the axis limits: four bounds for a 2D plot, six for a 3D plot"""

    @abstractmethod
    def scatter(
        self,
        x: NDArray,
        y: NDArray,
        z: NDArray | None = None,
        *,
        c: NDArray | None = None,
        s: NDArray | None = None,
        depthshade: bool = True,
    ) -> PathCollection:
        """Draw a scatter plot; ``z`` and ``depthshade`` are only meaningful on 3D axes"""
|
|
|
|
|
|
type Set3DProperties = Callable[[float | list[float] | NDArray, Literal["x", "y", "z"]], None]
|
|
|
|
|
|
class Animator:
    """Runs the simulation and displays it in a plot"""

    # Delay between animation frames in milliseconds (a target of 15 * 2**4 = 240 frames per second)
    _FRAME_INTERVAL_MS = 1000 / (15 * 2 ** 4)
    # Scatter marker size per unit of radius, to make the objects more visible
    _SIZE_SCALE = 10

    def __init__(self, pos: NDArray, vel: NDArray, rad: NDArray, gravity: float) -> None:
        """
        Sets up the simulation using the given data and displays it

        This ends by calling ``plt.show()``, which (in matplotlib's default non-interactive
        mode) blocks until the plot window is closed.

        :param pos: Start positions of the objects, one row per object with 2 or 3 columns
        :param vel: Start velocities of the objects, expected to have the same shape as ``pos``
        :param rad: Radii of the objects, as an ``(objects, 1)`` column like ``parse_csv`` returns
        :param gravity: Strength of gravity in this simulation
        :raises ValueError: If the positions are not in 2 or 3 dimensions
        """
        # We update our arrays in-place, so we'll make our own copies to be safe
        self._pos = pos.copy()
        self._vel = vel.copy()
        self._rad = rad.copy()
        # Calculate volume of a sphere of this radius, used as the object's mass
        self._mass = np.pi * 4 / 3 * rad ** 3

        self._gravity = gravity

        object_count, dimensions = self._pos.shape
        if dimensions not in (2, 3):
            raise ValueError(f"Can only show scenes in 2 or 3 dimensions, got {dimensions}")

        # Objects will be represented using a scatter plot, created in setup_plot()
        self._scatter_plot: PathCollection | None = None
        # We'll give each object a random colour to better differentiate them
        self._object_colours = cast(LinearSegmentedColormap, cast(Any, cm).rainbow)(
            np.random.random((object_count,))
        )

        # Create plot in an appropriate number of dimensions
        self._fig = plt.figure()
        if dimensions == 2:
            self._ax = cast(Axis, self._fig.add_subplot())
        else:
            self._ax = cast(Axis, self._fig.add_subplot(projection="3d"))

        # Set up animation loop
        # This attribute never gets used again, but we'll keep a reference so that it doesn't get garbage collected
        self._animation = animation.FuncAnimation(
            self._fig,
            self.update,
            interval=self._FRAME_INTERVAL_MS,
            init_func=self.setup_plot,
            blit=True,
            cache_frame_data=False,
        )

        # Display the animation
        plt.show()

    def setup_plot(self) -> tuple[PathCollection]:
        """
        This is a FuncAnimation initialization function in the form matplotlib expects

        Matplotlib can call this more than once (FuncAnimation re-initializes, for example,
        when the window is resized while blitting), so a scatter plot left over from a
        previous call is removed before the new one is created.

        :return: The single scatter plot we're using
        """
        if self._scatter_plot is not None:
            # Don't let scatter collections pile up on the axes across re-initializations
            self._scatter_plot.remove()

        _n, dimensions = self._pos.shape
        # Set up the scatter plot in 2 or 3 dimensions
        if dimensions == 2:
            self._scatter_plot = self._ax.scatter(
                self._pos[:, 0],
                self._pos[:, 1],
                c=self._object_colours,
                s=self._rad * self._SIZE_SCALE,  # To make the objects more visible
            )
            # These values work nicely for a landscape window
            self._ax.axis((-950., 950., -500., 500.))
        else:
            self._scatter_plot = self._ax.scatter(
                self._pos[:, 0],
                self._pos[:, 1],
                self._pos[:, 2],
                c=self._object_colours,
                s=self._rad * self._SIZE_SCALE,  # To make the objects more visible
                depthshade=False,  # I find it confusing, YMMV
            )
            # These values work nicely for a square window
            self._ax.axis((-500., 500., -500., 500., -500., 500.))
        return self._scatter_plot,

    def update(self, *_args, **_kwargs) -> tuple[PathCollection]:
        """
        This is a FuncAnimation update function in the form matplotlib expects

        Steps the simulation with ``n_body_matrix`` and moves the scatter plot to the new positions.

        :arg _args: We don't need any of matplotlib's information for our implementation
        :arg _kwargs: Again, not necessary for us
        :return: The single scatter plot we're using
        :raises RuntimeError: If called before setup_plot() has created the scatter plot
        """
        # As long as this is called after setup_plot() we're good
        # Hey, dealing with race conditions is my day job, I will 100% ignore them on my own time
        if self._scatter_plot is None:
            raise RuntimeError("update() called before setup_plot()")

        _n, dimensions = self._pos.shape
        # Update the state of our simulation (self._pos and self._vel are updated in-place)
        n_body_matrix(self._pos, self._vel, self._mass, self._gravity)

        # Set the X and Y values of the objects
        self._scatter_plot.set_offsets(self._pos[:, :2])
        if dimensions == 3:
            # Update the Z value if in 3D; set_3d_properties isn't in matplotlib's stubs, hence the casts
            cast(Set3DProperties, cast(Any, self._scatter_plot).set_3d_properties)(self._pos[:, 2], 'z')
        # Use radius to represent mass
        self._scatter_plot.set_sizes(self._rad[:, 0] * self._SIZE_SCALE)
        # Redraw the plot. With blitting this full redraw would normally be unnecessary;
        # NOTE(review): presumably kept so 3D axes re-project the moved points — confirm before removing
        self._fig.canvas.draw()
        return self._scatter_plot,
|