nbody/data.py

148 lines
4.6 KiB
Python

from pathlib import Path
import matplotlib.animation as animation
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import PathCollection
from numpy.typing import NDArray
from physics import n_body_matrix
def parse_csv(file: Path) -> tuple[NDArray, NDArray, NDArray]:
    """
    Reads the starting conditions of a simulation from a CSV

    Each non-blank line describes one object, either as ``x,y,vx,vy,r`` (2D)
    or ``x,y,z,vx,vy,vz,r`` (3D). The field count of the first line decides
    the dimensionality, and every other line must have the same count.

    :param file: The CSV file
    :return: The position, velocity, and radius matrices, in 2 or 3 dimensions.
        Positions and velocities have shape (objects, dimensions); radii have
        shape (objects, 1).
    :raises RuntimeError: If the file contains no data lines, or the first
        line has neither 5 nor 7 fields
    :raises ValueError: If a line's field count differs from the first line's,
        or a field is not a valid number
    """
    # Get text from file, skipping blank lines; a stray empty line between
    # records would otherwise fail to parse as a number
    lines = [line for line in file.read_text().splitlines() if line.strip()]
    if not lines:
        raise RuntimeError("CSV file contains no objects")
    field_count = len(lines[0].split(","))
    if field_count not in (5, 7):
        raise RuntimeError("CSV format not recognized, can only show scenes in 2 or 3 dimensions")
    # Allocate matrices: 5 fields -> 2D (2 position, 2 velocity, 1 radius),
    # 7 fields -> 3D (3 position, 3 velocity, 1 radius)
    dimensions = (field_count - 1) // 2
    pos = np.zeros((len(lines), dimensions), dtype=np.float64)
    vel = np.zeros((len(lines), dimensions), dtype=np.float64)
    rad = np.zeros((len(lines), 1), dtype=np.float64)
    # Parse CSV lines
    for i, line in enumerate(lines):
        fields = line.split(",")
        if len(fields) != field_count:
            raise ValueError(
                f"CSV record {i + 1} has {len(fields)} fields, expected {field_count}"
            )
        values = [float(field) for field in fields]
        pos[i] = values[:dimensions]
        vel[i] = values[dimensions:2 * dimensions]
        rad[i] = values[-1]
    return pos, vel, rad
class Animator:
    """Runs the simulation and displays it in a plot"""

    def __init__(self, pos: NDArray, vel: NDArray, rad: NDArray, gravity: float) -> None:
        """
        Sets up the simulation using the given data and displays it

        This calls ``plt.show()`` at the end, so in matplotlib's default
        (non-interactive) mode the constructor does not return until the plot
        window is closed.

        :param pos: Start positions of the objects, shape (objects, 2 or 3)
        :param vel: Start velocities of the objects
        :param rad: Radii of the objects, shape (objects, 1)
        :param gravity: Strength of gravity in this simulation
        """
        # We update our arrays in-place, so we'll make our own copies to be safe
        self._pos = pos.copy()
        self._vel = vel.copy()
        self._rad = rad.copy()
        # Mass is the volume of a sphere of this radius
        self._mass = np.pi * 4 / 3 * self._rad ** 3
        self._gravity = gravity
        object_count, dimensions = self._pos.shape
        # Objects will be represented using a scatter plot, created in setup_plot
        self._scatter_plot: PathCollection | None = None
        # We'll give each object a random colour to better differentiate them
        self._object_colours: NDArray = cm.rainbow(np.random.random((object_count,)))
        # Create plot in an appropriate number of dimensions
        self._fig = plt.figure()
        if dimensions == 2:
            self._ax = self._fig.add_subplot()
        else:
            self._ax = self._fig.add_subplot(projection="3d")
        # Set up animation loop
        # This attribute never gets used again, but we'll keep a reference so that it doesn't get garbage collected
        self._animation = animation.FuncAnimation(
            self._fig,
            self.update,
            interval=1000 / (15 * 2 ** 4),
            init_func=self.setup_plot,
            blit=True,
            cache_frame_data=False,
        )
        # Display the animation
        plt.show()

    def setup_plot(self) -> tuple[PathCollection]:
        """
        This is a FuncAnimation initialization function in the form matplotlib expects

        FuncAnimation may call its init function more than once (for example
        when redrawing after the window is resized while blitting), so any
        scatter plot left over from a previous call is removed first instead
        of leaving a stale copy on the axes.

        :return: The single scatter plot we're using
        """
        _n, dimensions = self._pos.shape
        if self._scatter_plot is not None:
            self._scatter_plot.remove()
        # Marker area scales with radius, magnified to make the objects more visible
        sizes = self._rad[:, 0] * 10
        # Set up the scatter plot in 2 or 3 dimensions
        if dimensions == 2:
            self._scatter_plot = self._ax.scatter(
                self._pos[:, 0],
                self._pos[:, 1],
                c=self._object_colours,
                s=sizes,
            )
            # These values work nicely for a landscape window
            self._ax.axis([-950, 950, -500, 500])
        else:
            self._scatter_plot = self._ax.scatter(
                self._pos[:, 0],
                self._pos[:, 1],
                self._pos[:, 2],
                c=self._object_colours,
                s=sizes,
                depthshade=False,  # I find it confusing, YMMV
            )
            # These values work nicely for a square window.
            # Axes.axis() only accepts [xmin, xmax, ymin, ymax], so each of the
            # three 3D limits is set explicitly.
            self._ax.set_xlim(-500, 500)
            self._ax.set_ylim(-500, 500)
            self._ax.set_zlim(-500, 500)
        return self._scatter_plot,

    def update(self, *_args, **_kwargs) -> tuple[PathCollection]:
        """
        This is a FuncAnimation update function in the form matplotlib expects

        Runs ``n_body_matrix`` on the simulation state, then moves the scatter
        plot markers to the resulting positions.

        :arg _args: We don't need any of matplotlib's information for our implementation
        :arg _kwargs: Again, not necessary for us
        :return: The single scatter plot we're using
        """
        _n, dimensions = self._pos.shape
        # Update the state of our simulation; n_body_matrix's result is not used,
        # so it presumably modifies self._pos / self._vel in place
        n_body_matrix(self._pos, self._vel, self._mass, self._gravity)
        # Set the X and Y values of the objects
        self._scatter_plot.set_offsets(self._pos[:, :2])
        if dimensions == 3:
            # Update the Z value if in 3D
            self._scatter_plot.set_3d_properties(self._pos[:, 2], 'z')
        # Use radius to represent mass
        self._scatter_plot.set_sizes(self._rad[:, 0] * 10)
        # Redraw the plot
        # NOTE(review): a full canvas redraw on every frame largely defeats
        # blit=True; it may be what keeps the 3D view refreshing correctly, so
        # confirm before removing it.
        self._fig.canvas.draw()
        return self._scatter_plot,