nbody/data.py
Michael Bradley a72b709214
All checks were successful
Lint / MyPy (push) Successful in 1m5s
Make MyPy happy
You know, I don't think this makes it more readable.
The problem is really that MPL's typed interface isn't very good, especially when dealing with 3D graphs, so I've gotta do a bunch of workarounds to get around it.
2025-02-23 11:38:32 -05:00

166 lines
5.4 KiB
Python

from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Callable, Literal, Protocol, cast

import matplotlib.animation as animation
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import PathCollection
from matplotlib.colors import LinearSegmentedColormap
from numpy.typing import NDArray

from physics import n_body_matrix
def parse_csv(file: Path) -> tuple[NDArray, NDArray, NDArray]:
    """
    Reads the starting conditions of a simulation from a CSV

    Every non-blank row describes one object, laid out either as ``x,y,vx,vy,r`` (2D)
    or ``x,y,z,vx,vy,vz,r`` (3D). All rows must use the same layout.

    :param file: The CSV file
    :return: The position (n x d), velocity (n x d), and radius (n x 1) matrices, where d is 2 or 3
    :raises RuntimeError: If the file has no rows, uses an unsupported number of fields,
        or mixes rows with different numbers of fields
    :raises ValueError: If a field cannot be parsed as a float
    """
    # Keep each row's 1-based line number so errors can point at it;
    # blank lines (e.g. trailing newlines) are skipped rather than parsed as empty floats
    rows = [
        (line_number, line.split(","))
        for line_number, line in enumerate(file.read_text().splitlines(), start=1)
        if line.strip()
    ]
    if not rows:
        raise RuntimeError(f"CSV file {file} contains no objects")
    field_count = len(rows[0][1])
    if field_count not in (5, 7):
        raise RuntimeError("CSV format not recognized, can only show scenes in 2 or 3 dimensions")
    # Catch ragged rows up front instead of failing mid-parse with an opaque unpacking error
    for line_number, fields in rows:
        if len(fields) != field_count:
            raise RuntimeError(
                f"CSV line {line_number} has {len(fields)} fields, expected {field_count}"
            )
    # One row per object: positions first, then velocities, then the radius
    data = np.array([[float(field) for field in fields] for _, fields in rows], dtype=np.float64)
    dimensions = (field_count - 1) // 2
    # Copy the slices so each returned matrix is an independent, contiguous array
    pos = data[:, :dimensions].copy()
    vel = data[:, dimensions:2 * dimensions].copy()
    rad = data[:, 2 * dimensions:].copy()
    return pos, vel, rad
class Axis(ABC):
"""Improved types for ``matplotlib.axes.Axes``"""
@abstractmethod
def axis(self, _arg: tuple[float, float, float, float, float, float] | tuple[float, float, float, float], /) -> Any: ...
@abstractmethod
def scatter(self, x: NDArray, y: NDArray, z: NDArray | None = None, *, c: NDArray | None = None, s: NDArray | None = None, depthshade: bool = True) -> PathCollection: ...
type Set3DProperties = Callable[[float | list[float] | NDArray, Literal["x", "y", "z"]], None]
class Animator:
    """Runs the simulation and displays it in a plot"""

    def __init__(self, pos: NDArray, vel: NDArray, rad: NDArray, gravity: float) -> None:
        """
        Sets up the simulation using the given data, then displays it

        The animation is shown with ``plt.show()`` at the end of construction, which usually
        blocks until the window is closed (this depends on the matplotlib backend).

        :param pos: Start positions of the objects, one row per object, in 2 or 3 dimensions
        :param vel: Start velocities of the objects
        :param rad: Radii of the objects
        :param gravity: Strength of gravity in this simulation
        :raises ValueError: If ``pos`` does not have 2 or 3 columns
        """
        # We update our arrays in-place, so we'll make our own copies to be safe
        self._pos = pos.copy()
        self._vel = vel.copy()
        self._rad = rad.copy()
        object_count, dimensions = self._pos.shape
        # Fail early instead of building a plot that setup_plot() can't draw
        if dimensions not in (2, 3):
            raise ValueError(f"Can only show scenes in 2 or 3 dimensions, got {dimensions}")
        # Mass is the volume of a sphere of this radius (i.e. every object has the same density)
        self._mass = np.pi * 4 / 3 * self._rad ** 3
        self._gravity = gravity
        # Objects will be represented using a scatter plot, which setup_plot() creates
        self._scatter_plot: PathCollection | None = None
        # We'll give each object a random colour to better differentiate them
        # (the casts work around matplotlib's incomplete typing of colormaps)
        self._object_colours = cast(LinearSegmentedColormap, cast(Any, cm).rainbow)(
            np.random.random((object_count,))
        )
        # Create plot in an appropriate number of dimensions
        self._fig = plt.figure()
        if dimensions == 2:
            self._ax = cast(Axis, self._fig.add_subplot())
        else:
            self._ax = cast(Axis, self._fig.add_subplot(projection="3d"))
        # Set up animation loop
        # This attribute never gets used again, but we'll keep a reference so that it doesn't get garbage collected
        self._animation = animation.FuncAnimation(
            self._fig,
            self.update,
            # Milliseconds between frames, i.e. request 15 * 16 = 240 frames per second
            interval=1000 / (15 * 2 ** 4),
            init_func=self.setup_plot,
            blit=True,
            # No frame count is given, so frames run indefinitely; don't keep a cache of them
            cache_frame_data=False,
        )
        # Display the animation
        plt.show()

    def setup_plot(self) -> tuple[PathCollection]:
        """
        This is a FuncAnimation initialization function in the form matplotlib expects

        :return: The single scatter plot we're using
        """
        _n, dimensions = self._pos.shape
        # Scale the radii up to make the objects more visible (same sizing as update())
        sizes = self._rad[:, 0] * 10
        # Set up the scatter plot in 2 or 3 dimensions
        if dimensions == 2:
            self._scatter_plot = self._ax.scatter(
                self._pos[:, 0],
                self._pos[:, 1],
                c=self._object_colours,
                s=sizes,
            )
            # These values work nicely for a landscape window
            self._ax.axis((-950., 950., -500., 500.))
        else:
            self._scatter_plot = self._ax.scatter(
                self._pos[:, 0],
                self._pos[:, 1],
                self._pos[:, 2],
                c=self._object_colours,
                s=sizes,
                depthshade=False,  # I find it confusing, YMMV
            )
            # These values work nicely for a square window
            self._ax.axis((-500., 500., -500., 500., -500., 500.))
        return self._scatter_plot,

    def update(self, *_args: Any, **_kwargs: Any) -> tuple[PathCollection]:
        """
        This is a FuncAnimation update function in the form matplotlib expects

        Advances the simulation with ``n_body_matrix`` and moves the scatter plot's points to match.

        :arg _args: We don't need any of matplotlib's information for our implementation
        :arg _kwargs: Again, not necessary for us
        :return: The single scatter plot we're using
        :raises RuntimeError: If called before setup_plot() has created the scatter plot
        """
        # An explicit check rather than an assert, so it still runs under ``python -O``
        if self._scatter_plot is None:
            raise RuntimeError("update() was called before setup_plot()")
        _n, dimensions = self._pos.shape
        # Update the state of our simulation (positions and velocities are updated in-place)
        n_body_matrix(self._pos, self._vel, self._mass, self._gravity)
        # Set the X and Y values of the objects
        self._scatter_plot.set_offsets(self._pos[:, :2])
        if dimensions == 3:
            # Update the Z value if in 3D; set_3d_properties isn't part of PathCollection's types
            set_3d_properties = cast(Set3DProperties, cast(Any, self._scatter_plot).set_3d_properties)
            set_3d_properties(self._pos[:, 2], 'z')
        # Use radius to represent mass
        self._scatter_plot.set_sizes(self._rad[:, 0] * 10)
        # Redraw the plot
        # NOTE(review): with blit=True FuncAnimation normally redraws the returned artists itself, so this
        # full-figure draw may be redundant in 2D; it may be what keeps the 3D projection current. Confirm before removing.
        self._fig.canvas.draw()
        return self._scatter_plot,