Compare commits

...

3 commits

Author SHA1 Message Date
6822df0432
Clean up and comment 2025-01-02 01:42:26 +13:00
8bb6ba19d5
Switch to tabs 2025-01-02 00:05:16 +13:00
054b54c291
Update packages 2025-01-02 00:05:03 +13:00
6 changed files with 289 additions and 192 deletions

View file

@ -1,16 +1,13 @@
# nbody # nbody
Threw this together in a day or two cause I thought it would be fun to mess around with. Threw this together to get more comfortable with Numpy.
Can simulate a few hundred bodies in 2 or 3 dimensions without much hassle (hardware dependent of course). Can simulate a few hundred bodies in 2 or 3 dimensions without much hassle.
Comments are non-existent, sorry about that.
To set up: To set up:
```shell ```shell
python -m venv venv python -m venv venv
source venv/bin/activate source venv/bin/activate
pip -r requirements.txt pip install -r requirements.txt
``` ```
To run: To run:
@ -19,7 +16,7 @@ To run:
./main.py -f 2d/simple.csv ./main.py -f 2d/simple.csv
# 3d simulation, increased gravity # 3d simulation, increased gravity
./main.py -f 3d/some.csv -d 3 -g 30 ./main.py -f 3d/some.csv -g 30
``` ```
To create a new start state: To create a new start state:

185
data.py
View file

@ -1,93 +1,148 @@
from pathlib import Path
import matplotlib.animation as animation import matplotlib.animation as animation
import matplotlib.cm as cm import matplotlib.cm as cm
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import physics from matplotlib.collections import PathCollection
from numpy.typing import NDArray
from physics import n_body_matrix
def parse_csv(filename: str, dimensions=2): def parse_csv(file: Path) -> tuple[NDArray, NDArray, NDArray]:
if not (1 < dimensions < 4): """
raise ValueError(f"Can only show 2or 3 dimensional scenes, not {dimensions}") Reads the starting conditions of a simulation from a CSV
with open(filename, 'r') as file: :param file: The CSV file
lines = file.read().strip().splitlines() :return: The position, velocity, and radius matrices, in 2 or 3 dimensions
pos = np.zeros((len(lines), dimensions)) """
vel = np.zeros((len(lines), dimensions)) # Get text from file
rad = np.zeros((len(lines), 1)) lines = file.read_text().strip().splitlines()
for i, values in enumerate(map(lambda l: map(float, l.split(',')), lines)): field_count = len(lines[0].split(","))
if field_count not in (5, 7):
raise RuntimeError("CSV format not recognized, can only show scenes in 2 or 3 dimensions")
# Allocate matrices
dimensions = (field_count - 1) // 2
pos = np.zeros((len(lines), dimensions), dtype=np.float64)
vel = np.zeros((len(lines), dimensions), dtype=np.float64)
rad = np.zeros((len(lines), 1), dtype=np.float64)
# Parse CSV lines
for i, line in enumerate(lines):
values = tuple(float(field) for field in line.split(","))
if dimensions == 2: if dimensions == 2:
[x, y, vx, vy, r] = values x, y, vx, vy, r = values
pos[i] = [x, y] pos[i] = (x, y)
vel[i] = [vx, vy] vel[i] = (vx, vy)
rad[i] = r rad[i] = r
elif dimensions == 3: elif dimensions == 3:
[x, y, z, vx, vy, vz, r] = values x, y, z, vx, vy, vz, r = values
pos[i] = [x, y, z] pos[i] = (x, y, z)
vel[i] = [vx, vy, vz] vel[i] = (vx, vy, vz)
rad[i] = r rad[i] = r
return pos, vel, rad return pos, vel, rad
class Animator: class Animator:
def __init__(self, pos: np.ndarray, vel: np.ndarray, rad: np.ndarray): """Runs the simulation and displays it in a plot"""
self.pos = pos def __init__(self, pos: NDArray, vel: NDArray, rad: NDArray, gravity: float) -> None:
self.vel = vel """
self.rad = rad Sets up the simulation using the given data
self.mass = np.pi * 4 / 3 * rad ** 3 :param pos: Start positions of the objects
:param vel: Start velocities of the objects
:param rad: Radii of the objects
:param gravity: Strength of gravity in this simulation
"""
# We update our arrays in-place, so we'll make out own copies to be safe
self._pos = pos.copy()
self._vel = vel.copy()
self._rad = rad.copy()
# Calculate volume of a sphere of this radius
self._mass = np.pi * 4 / 3 * rad ** 3
n, d = self.pos.shape self._gravity = gravity
self.scat: plt.PathCollection = None object_count, dimensions = self._pos.shape
self.colours = cm.rainbow(
# Objects will be represented using a scatter plot
self._scatter_plot: plt.PathCollection | None = None
# We'll give each objects a random colour to better differentiate them
self._object_colours: NDArray = cm.rainbow(
np.random.random( np.random.random(
(n,) (object_count,)
) )
) )
self.fig = plt.figure() # Create plot in an appropriate number of dimensions
if d == 2: self._fig = plt.figure()
self.ax = self.fig.add_subplot() if dimensions == 2:
self._ax = self._fig.add_subplot()
else: else:
self.ax = self.fig.add_subplot(projection="3d") self._ax = self._fig.add_subplot(projection="3d")
self.ani = animation.FuncAnimation(
self.fig, # Set up animation loop
# This attribute never gets used again, but we'll keep a reference so that it doesn't get garbage collected
self._animation = animation.FuncAnimation(
self._fig,
self.update, self.update,
interval=1000 / (15 * 2 ** 4), interval=1000 / (15 * 2 ** 4),
init_func=self.setup_plot, init_func=self.setup_plot,
blit=True, blit=True,
cache_frame_data=False cache_frame_data=False,
) )
def setup_plot(self): # Display the animation
_n, d = self.pos.shape
if d == 2:
self.scat = self.ax.scatter(
self.pos[:, 0],
self.pos[:, 1],
c=self.colours,
s=self.rad * 10
)
self.ax.axis([-950, 950, -500, 500])
else:
self.scat = self.ax.scatter(
self.pos[:, 0],
self.pos[:, 1],
self.pos[:, 2],
c=self.colours,
s=self.rad * 10,
depthshade=False
)
self.ax.axis([-500, 500, -500, 500, -500, 500])
return self.scat,
def update(self, *_args, **_kwargs):
_n, d = self.pos.shape
physics.n_body_matrix(self.pos, self.vel, self.mass, constrain=2.)
self.scat.set_offsets(self.pos[:, :2])
if d == 3:
self.scat.set_3d_properties(self.pos[:, 2], 'z')
self.scat.set_sizes(self.rad[:, 0] * 10)
self.fig.canvas.draw()
return self.scat,
def show(self):
plt.show() plt.show()
def setup_plot(self) -> tuple[PathCollection]:
"""
This is a FuncAnimation initialization function in the form matplotlib expects
:return: The single scatter plot we're using
"""
_n, dimensions = self._pos.shape
# Set up the scatter plot in 2 or 3 dimensions
if dimensions == 2:
self._scatter_plot = self._ax.scatter(
self._pos[:, 0],
self._pos[:, 1],
c=self._object_colours,
s=self._rad * 10, # To make the objects more visible
)
# These values work nicely for a landscape window
self._ax.axis([-950, 950, -500, 500])
else:
self._scatter_plot = self._ax.scatter(
self._pos[:, 0],
self._pos[:, 1],
self._pos[:, 2],
c=self._object_colours,
s=self._rad * 10, # To make the objects more visible
depthshade=False, # I find it confusing, YMMV
)
# These values work nicely for a square window
self._ax.axis([-500, 500, -500, 500, -500, 500])
return self._scatter_plot,
def update(self, *_args, **_kwargs) -> tuple[PathCollection]:
"""
This is a FuncAnimation update function in the form matplotlib expects
:arg _args: We don't need any of matplotlib's information for our implementation
"arg _kwargs: Again, not necessary for us
:return: The single scatter plot we're using
"""
_n, dimensions = self._pos.shape
# Update the state of our simulation
n_body_matrix(self._pos, self._vel, self._mass, self._gravity)
# Set the X and Y values of the objects
self._scatter_plot.set_offsets(self._pos[:, :2])
if dimensions == 3:
# Update the Z value if in 3D
self._scatter_plot.set_3d_properties(self._pos[:, 2], 'z')
# Use radius to represent mass
self._scatter_plot.set_sizes(self._rad[:, 0] * 10)
# Redraw the plot
self._fig.canvas.draw()
return self._scatter_plot,

View file

@ -1,38 +1,62 @@
#!./venv/bin/python #!venv/bin/python
import argparse """Generates random CSV data to be read by the simulator"""
from random import uniform, randint
from argparse import ArgumentParser
from random import uniform
from typing import Any, cast
class Args: class Args:
"""
The types of the arguments retrieved from the user
"""
width: int width: int
height: int length: int
depth: int depth: int
velocity: float speed: float
mass: float radius: float
count: int count: int
if __name__ == "__main__": def print_part(data: Any) -> None:
parser = argparse.ArgumentParser( """
prog="n-body data generator", Prints a CSV field
description="Generates data for the n-body simulator.", :param data: The data to put in the field
add_help=False """
) print(str(data), end=",")
parser.add_argument("-w", "--width", type=int, default=1900)
parser.add_argument("-h", "--height", type=int, default=1000)
parser.add_argument("-d", "--depth", type=int, default=0)
parser.add_argument("-v", "--velocity", type=float, default=1.)
parser.add_argument("-m", "--mass", type=float, default=1.)
parser.add_argument("-c", "--count", type=int, default=500)
args: Args = parser.parse_args()
def main(args: Args) -> None:
"""
Generates the starting data
:param args: Parameters for the random generation
"""
# Print <count> objects
for _ in range(args.count): for _ in range(args.count):
print(f"{randint(-args.width // 2, args.width // 2)}," # Object location
f"{randint(-args.height // 2, args.height // 2)}," print_part(uniform(-args.width / 2, args.width / 2))
f"{f'{randint(-args.depth // 2, args.depth // 2)},' if args.depth else ''}" print_part(uniform(-args.length / 2, args.length / 2))
f"{uniform(-args.velocity, args.velocity)}," if args.depth:
f"{uniform(-args.velocity, args.velocity)}," print_part(uniform(-args.depth / 2, args.depth / 2))
f"{f'{uniform(-args.velocity, args.velocity)},' if args.depth else ''}"
f"{uniform(1e-2, args.mass)}") # Object velocity
print_part(uniform(-args.speed, args.speed))
print_part(uniform(-args.speed, args.speed))
if args.depth:
print_part(uniform(-args.speed, args.speed))
# Finish line with a positive radius
print(uniform(1e-2, args.radius))
if __name__ == "__main__":
parser = ArgumentParser(description="Generates data for the n-body simulator", epilog="You should redirect the output to a file")
parser.add_argument("-w", "--width", type=float, default=1900., help="The width of the spawning area")
parser.add_argument("-l", "--length", type=float, default=1000., help="The length of the spawning area")
parser.add_argument("-d", "--depth", type=float, default=0., help="The depth of the spawning area, where 0 implies only 2 dimensions")
parser.add_argument("-s", "--speed", type=float, default=1., help="The maximum initial starting speed of an object in any dimension")
parser.add_argument("-r", "--radius", type=float, default=1., help="The maximum radius of an object")
parser.add_argument("-c", "--count", type=int, default=500, help="How many objects to create")
main(cast(Args, parser.parse_args()))

52
main.py
View file

@ -1,46 +1,26 @@
#!./venv/bin/python #!venv/bin/python
import argparse from argparse import ArgumentParser
import typing from pathlib import Path
from typing import cast
import data from data import parse_csv, Animator
import physics
class Args: class Args:
filename: str file: Path
gravity: float gravity: float
dimensions: typing.Literal[2, 3]
def main(file: Path, gravity: float) -> None:
pos, vel, rad = parse_csv(file)
Animator(pos, vel, rad, gravity)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = ArgumentParser(description="Gravitation simulation")
prog="n-body simulation",
description="Simulating gravitational effects"
)
parser.add_argument( parser.add_argument("-f", "--file", type=Path, default="data/2d/simple.csv", help="The starting state of the simulation")
"-f", parser.add_argument("-g", "--gravity", type=float, default=1., help="The strength of gravity")
"--filename",
default="data/2d/simple.csv"
)
parser.add_argument(
"-g",
"--gravity",
type=float,
default=1.
)
parser.add_argument(
"-d",
"--dimensions",
type=int,
choices=[2, 3],
default=2
)
args: Args = parser.parse_args() args = cast(Args, parser.parse_args())
main(args.file, args.gravity)
physics.G = args.gravity
objects = data.parse_csv(args.filename, dimensions=args.dimensions)
a = data.Animator(*objects)
a.show()

View file

@ -1,41 +1,82 @@
"""
Simulation tick algorithms
Both algorithms cancel out some terms. As you know, the force of gravity is $\frac{G * m_1 * m_2}{r^2}$. However, this
force is applied in the direction of the vector between the two masses. Because this direction vector needs to be
normalized, we can combine the normalization with the above equation to get $\frac{G * m_1 * m_2 * (p_2 - p_1)}{r^3}$
and skip out on a costly square root to calculate $r$ again. Finally, because this force is applied to one of the
masses (say $m_1$), the actual change in velocity is the force divided by the mass. This means that we can just drop
the $m_1$ term from the equation, and we have our change in velocity.
"""
from typing import Any, Generator
import numpy as np import numpy as np
from numpy.typing import NDArray
G = 6.674e-11 def _rotations(arr: NDArray) -> Generator[NDArray, Any, None]:
"""
"Rotates" through an array, returning the [i: n+i] range on the ith iteration
:param arr: Array to rotate through
"""
a2 = np.concatenate((arr, arr))
for i in range(1, len(arr)):
yield a2[i: i + len(arr)]
def rotations(a: np.ndarray): def n_body(pos: NDArray, vel: NDArray, mass: NDArray, gravity: float) -> None:
a2 = np.concatenate((a, a)) """
for i in range(1, len(a)): Easier-to-understand but slower update algorithm that simulates a tick.
yield a2[i: i + len(a)] Unused, just for demonstration purposes.
:param pos: Previous positions
:param vel: Previous velocities
def n_body(pos: np.ndarray, vel: np.ndarray, mass: np.ndarray): :param mass: Object masses
for (o_pos, o_mass) in zip(rotations(pos), rotations(mass)): :param gravity: Simulation gravity
dist = o_pos - pos :return: None - updated in-place
vel += G * dist * o_mass / (np.linalg.norm(dist, axis=1) ** 3)[:, np.newaxis] """
for (other_pos, other_mass) in zip(_rotations(pos), _rotations(mass)):
# For each combination of 2 objects
dist = other_pos - pos
# Calculate the force of gravity from the first to the second, and use it to update the velocity
vel += gravity * dist * other_mass / (np.linalg.norm(dist, axis=1) ** 3)[:, np.newaxis]
# Update positions
pos += vel pos += vel
def n_body_matrix(pos: np.ndarray, vel: np.ndarray, mass: np.ndarray, constrain=2.): def n_body_matrix(pos: NDArray, vel: NDArray, mass: NDArray, gravity: float, constrain: float = 2.) -> None:
n, d = pos.shape """
dist = np.zeros((n - 1, n, d)) Harder-to-understand but faster update algorithm that simulates a tick.
rot_mass = np.zeros((n - 1, n, 1)) Basically does the simpler algorithm all at once using numpy parallelism.
:param pos: Previous positions
:param vel: Previous velocities
:param mass: Object masses
:param gravity: Simulation gravity
:param constrain: Numerical stability term
:return: None - updated in-place
"""
num_masses, dimensions = pos.shape
dist = np.zeros((num_masses - 1, num_masses, dimensions), dtype=np.float64)
rot_mass = np.zeros((num_masses - 1, num_masses, 1), dtype=np.float64)
pos2 = np.concatenate((pos, pos)) pos2 = np.concatenate((pos, pos))
mass2 = np.concatenate((mass, mass)) mass2 = np.concatenate((mass, mass))
# Generates a matrix using the rotated arrays, like the for loop in the above algorithm in one go
for i in range(1, len(pos)): for i in range(1, len(pos)):
dist[i - 1] = pos2[i: i + n] - pos # The distance between two objects
rot_mass[i - 1] = mass2[i: i + n] dist[i - 1] = pos2[i: i + num_masses] - pos
# The mass of the other object
rot_mass[i - 1] = mass2[i: i + num_masses]
# Normalize direction vectors, and ensure the distances aren't too close for numerical stability
norms = np.linalg.norm(dist, axis=2) norms = np.linalg.norm(dist, axis=2)
if constrain: if constrain:
norms[norms < constrain] = constrain norms[norms < constrain] = constrain
vel += G * np.sum( # Calculate all changes in velocity at once, using the same method described and implemented above
vel += gravity * np.sum(
dist * rot_mass / (norms ** 3)[:, :, np.newaxis], dist * rot_mass / (norms ** 3)[:, :, np.newaxis],
axis=0 axis=0
) )
# Update positions
pos += vel pos += vel

View file

@ -1,11 +1,11 @@
contourpy==1.1.1 contourpy==1.3.1
cycler==0.12.1 cycler==0.12.1
fonttools==4.43.1 fonttools==4.55.3
kiwisolver==1.4.5 kiwisolver==1.4.8
matplotlib==3.8.0 matplotlib==3.10.0
numpy==1.26.0 numpy==2.2.1
packaging==23.2 packaging==24.2
Pillow==10.0.1 pillow==11.0.0
pyparsing==3.1.1 pyparsing==3.2.1
python-dateutil==2.8.2 python-dateutil==2.9.0.post0
six==1.16.0 six==1.17.0