Compare commits

..

No commits in common. "6822df04327bd6aa40aaae45b8af83012580a391" and "9651d1470a0034305c8dcd29d22351aa137b59b9" have entirely different histories.

6 changed files with 192 additions and 289 deletions

View file

@ -1,13 +1,16 @@
# nbody # nbody
Threw this together to get more comfortable with Numpy. Threw this together in a day or two cause I thought it would be fun to mess around with.
Can simulate a few hundred bodies in 2 or 3 dimensions without much hassle. Can simulate a few hundred bodies in 2 or 3 dimensions without much hassle (hardware dependent of course).
Comments are non-existent, sorry about that.
To set up: To set up:
```shell ```shell
python -m venv venv python -m venv venv
source venv/bin/activate source venv/bin/activate
pip install -r requirements.txt pip -r requirements.txt
``` ```
To run: To run:
@ -16,7 +19,7 @@ To run:
./main.py -f 2d/simple.csv ./main.py -f 2d/simple.csv
# 3d simulation, increased gravity # 3d simulation, increased gravity
./main.py -f 3d/some.csv -g 30 ./main.py -f 3d/some.csv -d 3 -g 30
``` ```
To create a new start state: To create a new start state:

177
data.py
View file

@ -1,148 +1,93 @@
from pathlib import Path
import matplotlib.animation as animation import matplotlib.animation as animation
import matplotlib.cm as cm import matplotlib.cm as cm
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
from matplotlib.collections import PathCollection import physics
from numpy.typing import NDArray
from physics import n_body_matrix
def parse_csv(file: Path) -> tuple[NDArray, NDArray, NDArray]: def parse_csv(filename: str, dimensions=2):
""" if not (1 < dimensions < 4):
Reads the starting conditions of a simulation from a CSV raise ValueError(f"Can only show 2or 3 dimensional scenes, not {dimensions}")
:param file: The CSV file with open(filename, 'r') as file:
:return: The position, velocity, and radius matrices, in 2 or 3 dimensions lines = file.read().strip().splitlines()
""" pos = np.zeros((len(lines), dimensions))
# Get text from file vel = np.zeros((len(lines), dimensions))
lines = file.read_text().strip().splitlines() rad = np.zeros((len(lines), 1))
field_count = len(lines[0].split(",")) for i, values in enumerate(map(lambda l: map(float, l.split(',')), lines)):
if field_count not in (5, 7):
raise RuntimeError("CSV format not recognized, can only show scenes in 2 or 3 dimensions")
# Allocate matrices
dimensions = (field_count - 1) // 2
pos = np.zeros((len(lines), dimensions), dtype=np.float64)
vel = np.zeros((len(lines), dimensions), dtype=np.float64)
rad = np.zeros((len(lines), 1), dtype=np.float64)
# Parse CSV lines
for i, line in enumerate(lines):
values = tuple(float(field) for field in line.split(","))
if dimensions == 2: if dimensions == 2:
x, y, vx, vy, r = values [x, y, vx, vy, r] = values
pos[i] = (x, y) pos[i] = [x, y]
vel[i] = (vx, vy) vel[i] = [vx, vy]
rad[i] = r rad[i] = r
elif dimensions == 3: elif dimensions == 3:
x, y, z, vx, vy, vz, r = values [x, y, z, vx, vy, vz, r] = values
pos[i] = (x, y, z) pos[i] = [x, y, z]
vel[i] = (vx, vy, vz) vel[i] = [vx, vy, vz]
rad[i] = r rad[i] = r
return pos, vel, rad return pos, vel, rad
class Animator: class Animator:
"""Runs the simulation and displays it in a plot""" def __init__(self, pos: np.ndarray, vel: np.ndarray, rad: np.ndarray):
def __init__(self, pos: NDArray, vel: NDArray, rad: NDArray, gravity: float) -> None: self.pos = pos
""" self.vel = vel
Sets up the simulation using the given data self.rad = rad
:param pos: Start positions of the objects self.mass = np.pi * 4 / 3 * rad ** 3
:param vel: Start velocities of the objects
:param rad: Radii of the objects
:param gravity: Strength of gravity in this simulation
"""
# We update our arrays in-place, so we'll make out own copies to be safe
self._pos = pos.copy()
self._vel = vel.copy()
self._rad = rad.copy()
# Calculate volume of a sphere of this radius
self._mass = np.pi * 4 / 3 * rad ** 3
self._gravity = gravity n, d = self.pos.shape
object_count, dimensions = self._pos.shape self.scat: plt.PathCollection = None
self.colours = cm.rainbow(
# Objects will be represented using a scatter plot
self._scatter_plot: plt.PathCollection | None = None
# We'll give each objects a random colour to better differentiate them
self._object_colours: NDArray = cm.rainbow(
np.random.random( np.random.random(
(object_count,) (n,)
) )
) )
# Create plot in an appropriate number of dimensions self.fig = plt.figure()
self._fig = plt.figure() if d == 2:
if dimensions == 2: self.ax = self.fig.add_subplot()
self._ax = self._fig.add_subplot()
else: else:
self._ax = self._fig.add_subplot(projection="3d") self.ax = self.fig.add_subplot(projection="3d")
self.ani = animation.FuncAnimation(
# Set up animation loop self.fig,
# This attribute never gets used again, but we'll keep a reference so that it doesn't get garbage collected
self._animation = animation.FuncAnimation(
self._fig,
self.update, self.update,
interval=1000 / (15 * 2 ** 4), interval=1000 / (15 * 2 ** 4),
init_func=self.setup_plot, init_func=self.setup_plot,
blit=True, blit=True,
cache_frame_data=False, cache_frame_data=False
) )
# Display the animation def setup_plot(self):
plt.show() _n, d = self.pos.shape
if d == 2:
def setup_plot(self) -> tuple[PathCollection]: self.scat = self.ax.scatter(
""" self.pos[:, 0],
This is a FuncAnimation initialization function in the form matplotlib expects self.pos[:, 1],
:return: The single scatter plot we're using c=self.colours,
""" s=self.rad * 10
_n, dimensions = self._pos.shape
# Set up the scatter plot in 2 or 3 dimensions
if dimensions == 2:
self._scatter_plot = self._ax.scatter(
self._pos[:, 0],
self._pos[:, 1],
c=self._object_colours,
s=self._rad * 10, # To make the objects more visible
) )
# These values work nicely for a landscape window self.ax.axis([-950, 950, -500, 500])
self._ax.axis([-950, 950, -500, 500])
else: else:
self._scatter_plot = self._ax.scatter( self.scat = self.ax.scatter(
self._pos[:, 0], self.pos[:, 0],
self._pos[:, 1], self.pos[:, 1],
self._pos[:, 2], self.pos[:, 2],
c=self._object_colours, c=self.colours,
s=self._rad * 10, # To make the objects more visible s=self.rad * 10,
depthshade=False, # I find it confusing, YMMV depthshade=False
) )
# These values work nicely for a square window self.ax.axis([-500, 500, -500, 500, -500, 500])
self._ax.axis([-500, 500, -500, 500, -500, 500]) return self.scat,
return self._scatter_plot,
def update(self, *_args, **_kwargs) -> tuple[PathCollection]: def update(self, *_args, **_kwargs):
""" _n, d = self.pos.shape
This is a FuncAnimation update function in the form matplotlib expects physics.n_body_matrix(self.pos, self.vel, self.mass, constrain=2.)
:arg _args: We don't need any of matplotlib's information for our implementation self.scat.set_offsets(self.pos[:, :2])
"arg _kwargs: Again, not necessary for us if d == 3:
:return: The single scatter plot we're using self.scat.set_3d_properties(self.pos[:, 2], 'z')
""" self.scat.set_sizes(self.rad[:, 0] * 10)
_n, dimensions = self._pos.shape self.fig.canvas.draw()
# Update the state of our simulation return self.scat,
n_body_matrix(self._pos, self._vel, self._mass, self._gravity)
# Set the X and Y values of the objects def show(self):
self._scatter_plot.set_offsets(self._pos[:, :2]) plt.show()
if dimensions == 3:
# Update the Z value if in 3D
self._scatter_plot.set_3d_properties(self._pos[:, 2], 'z')
# Use radius to represent mass
self._scatter_plot.set_sizes(self._rad[:, 0] * 10)
# Redraw the plot
self._fig.canvas.draw()
return self._scatter_plot,

View file

@ -1,62 +1,38 @@
#!venv/bin/python #!./venv/bin/python
"""Generates random CSV data to be read by the simulator""" import argparse
from random import uniform, randint
from argparse import ArgumentParser
from random import uniform
from typing import Any, cast
class Args: class Args:
"""
The types of the arguments retrieved from the user
"""
width: int width: int
length: int height: int
depth: int depth: int
speed: float velocity: float
radius: float mass: float
count: int count: int
def print_part(data: Any) -> None:
"""
Prints a CSV field
:param data: The data to put in the field
"""
print(str(data), end=",")
def main(args: Args) -> None:
"""
Generates the starting data
:param args: Parameters for the random generation
"""
# Print <count> objects
for _ in range(args.count):
# Object location
print_part(uniform(-args.width / 2, args.width / 2))
print_part(uniform(-args.length / 2, args.length / 2))
if args.depth:
print_part(uniform(-args.depth / 2, args.depth / 2))
# Object velocity
print_part(uniform(-args.speed, args.speed))
print_part(uniform(-args.speed, args.speed))
if args.depth:
print_part(uniform(-args.speed, args.speed))
# Finish line with a positive radius
print(uniform(1e-2, args.radius))
if __name__ == "__main__": if __name__ == "__main__":
parser = ArgumentParser(description="Generates data for the n-body simulator", epilog="You should redirect the output to a file") parser = argparse.ArgumentParser(
prog="n-body data generator",
description="Generates data for the n-body simulator.",
add_help=False
)
parser.add_argument("-w", "--width", type=float, default=1900., help="The width of the spawning area") parser.add_argument("-w", "--width", type=int, default=1900)
parser.add_argument("-l", "--length", type=float, default=1000., help="The length of the spawning area") parser.add_argument("-h", "--height", type=int, default=1000)
parser.add_argument("-d", "--depth", type=float, default=0., help="The depth of the spawning area, where 0 implies only 2 dimensions") parser.add_argument("-d", "--depth", type=int, default=0)
parser.add_argument("-s", "--speed", type=float, default=1., help="The maximum initial starting speed of an object in any dimension") parser.add_argument("-v", "--velocity", type=float, default=1.)
parser.add_argument("-r", "--radius", type=float, default=1., help="The maximum radius of an object") parser.add_argument("-m", "--mass", type=float, default=1.)
parser.add_argument("-c", "--count", type=int, default=500, help="How many objects to create") parser.add_argument("-c", "--count", type=int, default=500)
main(cast(Args, parser.parse_args())) args: Args = parser.parse_args()
for _ in range(args.count):
print(f"{randint(-args.width // 2, args.width // 2)},"
f"{randint(-args.height // 2, args.height // 2)},"
f"{f'{randint(-args.depth // 2, args.depth // 2)},' if args.depth else ''}"
f"{uniform(-args.velocity, args.velocity)},"
f"{uniform(-args.velocity, args.velocity)},"
f"{f'{uniform(-args.velocity, args.velocity)},' if args.depth else ''}"
f"{uniform(1e-2, args.mass)}")

52
main.py
View file

@ -1,26 +1,46 @@
#!venv/bin/python #!./venv/bin/python
from argparse import ArgumentParser import argparse
from pathlib import Path import typing
from typing import cast
from data import parse_csv, Animator import data
import physics
class Args: class Args:
file: Path filename: str
gravity: float gravity: float
dimensions: typing.Literal[2, 3]
def main(file: Path, gravity: float) -> None:
pos, vel, rad = parse_csv(file)
Animator(pos, vel, rad, gravity)
if __name__ == "__main__": if __name__ == "__main__":
parser = ArgumentParser(description="Gravitation simulation") parser = argparse.ArgumentParser(
prog="n-body simulation",
description="Simulating gravitational effects"
)
parser.add_argument("-f", "--file", type=Path, default="data/2d/simple.csv", help="The starting state of the simulation") parser.add_argument(
parser.add_argument("-g", "--gravity", type=float, default=1., help="The strength of gravity") "-f",
"--filename",
default="data/2d/simple.csv"
)
parser.add_argument(
"-g",
"--gravity",
type=float,
default=1.
)
parser.add_argument(
"-d",
"--dimensions",
type=int,
choices=[2, 3],
default=2
)
args = cast(Args, parser.parse_args()) args: Args = parser.parse_args()
main(args.file, args.gravity)
physics.G = args.gravity
objects = data.parse_csv(args.filename, dimensions=args.dimensions)
a = data.Animator(*objects)
a.show()

View file

@ -1,82 +1,41 @@
"""
Simulation tick algorithms
Both algorithms cancel out some terms. As you know, the force of gravity is $\frac{G * m_1 * m_2}{r^2}$. However, this
force is applied in the direction of the vector between the two masses. Because this direction vector needs to be
normalized, we can combine the normalization with the above equation to get $\frac{G * m_1 * m_2 * (p_2 - p_1)}{r^3}$
and skip out on a costly square root to calculate $r$ again. Finally, because this force is applied to one of the
masses (say $m_1$), the actual change in velocity is the force divided by the mass. This means that we can just drop
the $m_1$ term from the equation, and we have our change in velocity.
"""
from typing import Any, Generator
import numpy as np import numpy as np
from numpy.typing import NDArray
def _rotations(arr: NDArray) -> Generator[NDArray, Any, None]: G = 6.674e-11
"""
"Rotates" through an array, returning the [i: n+i] range on the ith iteration
:param arr: Array to rotate through
"""
a2 = np.concatenate((arr, arr))
for i in range(1, len(arr)):
yield a2[i: i + len(arr)]
def n_body(pos: NDArray, vel: NDArray, mass: NDArray, gravity: float) -> None: def rotations(a: np.ndarray):
""" a2 = np.concatenate((a, a))
Easier-to-understand but slower update algorithm that simulates a tick. for i in range(1, len(a)):
Unused, just for demonstration purposes. yield a2[i: i + len(a)]
:param pos: Previous positions
:param vel: Previous velocities
:param mass: Object masses def n_body(pos: np.ndarray, vel: np.ndarray, mass: np.ndarray):
:param gravity: Simulation gravity for (o_pos, o_mass) in zip(rotations(pos), rotations(mass)):
:return: None - updated in-place dist = o_pos - pos
""" vel += G * dist * o_mass / (np.linalg.norm(dist, axis=1) ** 3)[:, np.newaxis]
for (other_pos, other_mass) in zip(_rotations(pos), _rotations(mass)):
# For each combination of 2 objects
dist = other_pos - pos
# Calculate the force of gravity from the first to the second, and use it to update the velocity
vel += gravity * dist * other_mass / (np.linalg.norm(dist, axis=1) ** 3)[:, np.newaxis]
# Update positions
pos += vel pos += vel
def n_body_matrix(pos: NDArray, vel: NDArray, mass: NDArray, gravity: float, constrain: float = 2.) -> None: def n_body_matrix(pos: np.ndarray, vel: np.ndarray, mass: np.ndarray, constrain=2.):
""" n, d = pos.shape
Harder-to-understand but faster update algorithm that simulates a tick. dist = np.zeros((n - 1, n, d))
Basically does the simpler algorithm all at once using numpy parallelism. rot_mass = np.zeros((n - 1, n, 1))
:param pos: Previous positions
:param vel: Previous velocities
:param mass: Object masses
:param gravity: Simulation gravity
:param constrain: Numerical stability term
:return: None - updated in-place
"""
num_masses, dimensions = pos.shape
dist = np.zeros((num_masses - 1, num_masses, dimensions), dtype=np.float64)
rot_mass = np.zeros((num_masses - 1, num_masses, 1), dtype=np.float64)
pos2 = np.concatenate((pos, pos)) pos2 = np.concatenate((pos, pos))
mass2 = np.concatenate((mass, mass)) mass2 = np.concatenate((mass, mass))
# Generates a matrix using the rotated arrays, like the for loop in the above algorithm in one go
for i in range(1, len(pos)):
# The distance between two objects
dist[i - 1] = pos2[i: i + num_masses] - pos
# The mass of the other object
rot_mass[i - 1] = mass2[i: i + num_masses]
# Normalize direction vectors, and ensure the distances aren't too close for numerical stability for i in range(1, len(pos)):
dist[i - 1] = pos2[i: i + n] - pos
rot_mass[i - 1] = mass2[i: i + n]
norms = np.linalg.norm(dist, axis=2) norms = np.linalg.norm(dist, axis=2)
if constrain: if constrain:
norms[norms < constrain] = constrain norms[norms < constrain] = constrain
# Calculate all changes in velocity at once, using the same method described and implemented above vel += G * np.sum(
vel += gravity * np.sum(
dist * rot_mass / (norms ** 3)[:, :, np.newaxis], dist * rot_mass / (norms ** 3)[:, :, np.newaxis],
axis=0 axis=0
) )
# Update positions
pos += vel pos += vel

View file

@ -1,11 +1,11 @@
contourpy==1.3.1 contourpy==1.1.1
cycler==0.12.1 cycler==0.12.1
fonttools==4.55.3 fonttools==4.43.1
kiwisolver==1.4.8 kiwisolver==1.4.5
matplotlib==3.10.0 matplotlib==3.8.0
numpy==2.2.1 numpy==1.26.0
packaging==24.2 packaging==23.2
pillow==11.0.0 Pillow==10.0.1
pyparsing==3.2.1 pyparsing==3.1.1
python-dateutil==2.9.0.post0 python-dateutil==2.8.2
six==1.17.0 six==1.16.0