from pathlib import Path

import matplotlib.animation as animation
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import PathCollection
from numpy.typing import NDArray

from physics import n_body_matrix


def parse_csv(file: Path) -> tuple[NDArray, NDArray, NDArray]:
    """
    Reads the starting conditions of a simulation from a CSV

    Each non-blank line describes one object, either as ``x,y,vx,vy,r`` (2D) or
    ``x,y,z,vx,vy,vz,r`` (3D). The first data line decides the layout and every
    other line must use the same number of fields.

    :param file: The CSV file
    :return: The position, velocity, and radius matrices, in 2 or 3 dimensions
    :raises RuntimeError: If the file has no data lines, the layout is not recognized,
        or a line has a different number of fields than the first one
    :raises ValueError: If a field cannot be parsed as a number
    """
    # Get text from file, keeping the original (1-based) line numbers for error messages
    # and skipping blank lines so spacer/trailing lines don't break parsing
    rows = [
        (line_number, line)
        for line_number, line in enumerate(file.read_text().splitlines(), start=1)
        if line.strip()
    ]
    if not rows:
        raise RuntimeError("CSV file contains no objects")

    field_count = len(rows[0][1].split(","))
    if field_count not in (5, 7):
        raise RuntimeError("CSV format not recognized, can only show scenes in 2 or 3 dimensions")

    # Allocate matrices
    dimensions = (field_count - 1) // 2
    pos = np.zeros((len(rows), dimensions), dtype=np.float64)
    vel = np.zeros((len(rows), dimensions), dtype=np.float64)
    rad = np.zeros((len(rows), 1), dtype=np.float64)

    # Parse CSV lines: positions first, then velocities, then the radius as the last field
    for i, (line_number, line) in enumerate(rows):
        fields = line.split(",")
        if len(fields) != field_count:
            raise RuntimeError(
                f"CSV line {line_number} has {len(fields)} fields, expected {field_count}"
            )
        try:
            values = [float(field) for field in fields]
        except ValueError as err:
            raise ValueError(f"CSV line {line_number} contains a non-numeric field: {line!r}") from err
        pos[i] = values[:dimensions]
        vel[i] = values[dimensions:2 * dimensions]
        rad[i] = values[-1]

    return pos, vel, rad


class Animator:
    """Runs the simulation and displays it in a plot"""

    def __init__(self, pos: NDArray, vel: NDArray, rad: NDArray, gravity: float) -> None:
        """
        Sets up the simulation using the given data and shows the animated plot

        This calls ``plt.show()``, which normally blocks until the plot window is closed.

        :param pos: Start positions of the objects, one row per object (2 or 3 columns)
        :param vel: Start velocities of the objects
        :param rad: Radii of the objects
        :param gravity: Strength of gravity in this simulation
        :raises ValueError: If the positions are not 2- or 3-dimensional
        """
        # We update our arrays in-place, so we'll make our own copies to be safe
        self._pos = pos.copy()
        self._vel = vel.copy()
        self._rad = rad.copy()
        # Mass is the volume of a sphere of this radius (i.e. every object has unit density)
        self._mass = np.pi * 4 / 3 * self._rad ** 3
        self._gravity = gravity

        object_count, dimensions = self._pos.shape
        # Only 2D and 3D scenes can be plotted; fail before creating a figure
        if dimensions not in (2, 3):
            raise ValueError(f"Can only show scenes in 2 or 3 dimensions, got {dimensions}")

        # Objects will be represented using a scatter plot, which is created in setup_plot
        self._scatter_plot: PathCollection | None = None
        # We'll give each object a random colour to better differentiate them
        self._object_colours: NDArray = cm.rainbow(np.random.random((object_count,)))

        # Create plot in an appropriate number of dimensions
        self._fig = plt.figure()
        if dimensions == 2:
            self._ax = self._fig.add_subplot()
        else:
            self._ax = self._fig.add_subplot(projection="3d")

        # Set up animation loop
        # This attribute never gets used again, but we'll keep a reference so that it doesn't get garbage collected
        self._animation = animation.FuncAnimation(
            self._fig,
            self.update,
            interval=1000 / (15 * 2 ** 4),  # milliseconds between frames (~4.2 ms)
            init_func=self.setup_plot,
            blit=True,
            cache_frame_data=False,
        )

        # Display the animation
        plt.show()

    def setup_plot(self) -> tuple[PathCollection]:
        """
        This is a FuncAnimation initialization function in the form matplotlib expects

        Creates the scatter plot from the current positions and fixes the axis limits.

        :return: The single scatter plot we're using
        """
        _n, dimensions = self._pos.shape

        # Set up the scatter plot in 2 or 3 dimensions
        if dimensions == 2:
            self._scatter_plot = self._ax.scatter(
                self._pos[:, 0],
                self._pos[:, 1],
                c=self._object_colours,
                s=self._rad * 10,  # To make the objects more visible
            )
            # These values work nicely for a landscape window
            self._ax.axis([-950, 950, -500, 500])
        else:
            self._scatter_plot = self._ax.scatter(
                self._pos[:, 0],
                self._pos[:, 1],
                self._pos[:, 2],
                c=self._object_colours,
                s=self._rad * 10,  # To make the objects more visible
                depthshade=False,  # I find it confusing, YMMV
            )
            # These values work nicely for a square window
            self._ax.axis([-500, 500, -500, 500, -500, 500])

        return (self._scatter_plot,)

    def update(self, *_args, **_kwargs) -> tuple[PathCollection]:
        """
        This is a FuncAnimation update function in the form matplotlib expects

        Advances the simulation by one ``n_body_matrix`` step and moves the plotted points to match.

        :arg _args: We don't need any of matplotlib's information for our implementation
        :arg _kwargs: Again, not necessary for us
        :return: The single scatter plot we're using
        """
        _n, dimensions = self._pos.shape

        # Update the state of our simulation; the result is not used, the arrays are updated in place
        n_body_matrix(self._pos, self._vel, self._mass, self._gravity)

        # Set the X and Y values of the objects
        self._scatter_plot.set_offsets(self._pos[:, :2])
        if dimensions == 3:
            # Update the Z value if in 3D; this reads back the X/Y offsets, so it must follow set_offsets
            self._scatter_plot.set_3d_properties(self._pos[:, 2], 'z')

        # Use radius to represent mass
        self._scatter_plot.set_sizes(self._rad[:, 0] * 10)

        # Redraw the plot
        # NOTE(review): a full canvas redraw on every frame largely defeats blit=True; presumably it is
        # kept so the 3D scatter gets re-projected each frame — confirm before removing
        self._fig.canvas.draw()

        return (self._scatter_plot,)