Support 3d simulations

This commit is contained in:
Michael Bradley 2023-10-08 12:13:53 -04:00
parent 4a31fc1a7d
commit 424721ccfc
4 changed files with 70 additions and 46 deletions

69
data.py
View file

@ -2,20 +2,28 @@ import matplotlib.animation as animation
import matplotlib.cm as cm import matplotlib.cm as cm
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import physics import physics
def parse_csv(filename: str): def parse_csv(filename: str, dimensions=2):
if not (1 < dimensions < 4):
raise ValueError(f"Can only show 2or 3 dimensional scenes, not {dimensions}")
with open(filename, 'r') as file: with open(filename, 'r') as file:
lines = file.read().strip().splitlines() lines = file.read().strip().splitlines()
pos = np.zeros((len(lines), 2)) pos = np.zeros((len(lines), dimensions))
vel = np.zeros((len(lines), 2)) vel = np.zeros((len(lines), dimensions))
rad = np.zeros((len(lines), 1)) rad = np.zeros((len(lines), 1))
for i, [x, y, vx, vy, r] in enumerate(map(lambda l: map(float, l.split(',')), lines)): for i, values in enumerate(map(lambda l: map(float, l.split(',')), lines)):
pos[i] = [x, y] if dimensions == 2:
vel[i] = [vx, vy] [x, y, vx, vy, r] = values
rad[i] = r pos[i] = [x, y]
vel[i] = [vx, vy]
rad[i] = r
elif dimensions == 3:
[x, y, z, vx, vy, vz, r] = values
pos[i] = [x, y, z]
vel[i] = [vx, vy, vz]
rad[i] = r
return pos, vel, rad return pos, vel, rad
@ -26,14 +34,20 @@ class Animator:
self.rad = rad self.rad = rad
self.mass = np.pi * 4 / 3 * rad ** 3 self.mass = np.pi * 4 / 3 * rad ** 3
self.scat = None n, d = self.pos.shape
self.scat: plt.PathCollection = None
self.colours = cm.rainbow( self.colours = cm.rainbow(
np.random.random( np.random.random(
(len(self.rad),) (n,)
) )
) )
self.fig, self.ax = plt.subplots() self.fig = plt.figure()
if d == 2:
self.ax = self.fig.add_subplot()
else:
self.ax = self.fig.add_subplot(projection="3d")
self.ani = animation.FuncAnimation( self.ani = animation.FuncAnimation(
self.fig, self.fig,
self.update, self.update,
@ -44,18 +58,33 @@ class Animator:
) )
def setup_plot(self): def setup_plot(self):
self.scat = self.ax.scatter( _n, d = self.pos.shape
self.pos[:, 0], if d == 2:
self.pos[:, 1], self.scat = self.ax.scatter(
c=self.colours, self.pos[:, 0],
s=self.rad * 10 self.pos[:, 1],
) c=self.colours,
self.ax.axis([-950, 950, -500, 500]) s=self.rad * 10
)
self.ax.axis([-950, 950, -500, 500])
else:
self.scat = self.ax.scatter(
self.pos[:, 0],
self.pos[:, 1],
self.pos[:, 2],
c=self.colours,
s=self.rad * 10
)
self.ax.axis([-500, 500, -500, 500, -500, 500])
return self.scat, return self.scat,
def update(self, *_args, **_kwargs): def update(self, *_args, **_kwargs):
physics.n_body_matrix_constrained(self.pos, self.vel, self.mass) _n, d = self.pos.shape
self.scat.set_offsets(self.pos) physics.n_body_matrix(self.pos, self.vel, self.mass, constrain=2.)
self.scat.set_offsets(self.pos[:, :2])
if d == 3:
self.scat.set_3d_properties(self.pos[:, 2], 'z')
self.fig.canvas.draw()
return self.scat, return self.scat,
def show(self): def show(self):

2
data/3d.csv Normal file
View file

@ -0,0 +1,2 @@
0,50,10,5,0,-1,11
0,-50,-10,-5,0,1,11
1 0 50 10 5 0 -1 11
2 0 -50 -10 -5 0 1 11

11
main.py
View file

@ -1,5 +1,6 @@
#!./venv/bin/python #!./venv/bin/python
import argparse import argparse
import typing
import data import data
import physics import physics
@ -8,6 +9,7 @@ import physics
class Args: class Args:
filename: str filename: str
gravity: float gravity: float
dimensions: typing.Literal[2, 3]
if __name__ == "__main__": if __name__ == "__main__":
@ -27,11 +29,18 @@ if __name__ == "__main__":
type=float, type=float,
default=1. default=1.
) )
parser.add_argument(
"-d",
"--dimensions",
type=int,
choices=[2, 3],
default=2
)
args: Args = parser.parse_args() args: Args = parser.parse_args()
physics.G = args.gravity physics.G = args.gravity
objects = data.parse_csv(args.filename) objects = data.parse_csv(args.filename, dimensions=args.dimensions)
a = data.Animator(*objects) a = data.Animator(*objects)
a.show() a.show()

View file

@ -17,38 +17,22 @@ def n_body(pos: np.ndarray, vel: np.ndarray, mass: np.ndarray):
pos += vel pos += vel
def n_body_matrix(pos: np.ndarray, vel: np.ndarray, mass: np.ndarray): def n_body_matrix(pos: np.ndarray, vel: np.ndarray, mass: np.ndarray, constrain=2.):
dist = np.zeros((len(pos) - 1, len(pos), 2)) n, d = pos.shape
rot_mass = np.zeros((len(mass) - 1, len(mass), 1)) dist = np.zeros((n - 1, n, d))
rot_mass = np.zeros((n - 1, n, 1))
pos2 = np.concatenate((pos, pos)) pos2 = np.concatenate((pos, pos))
mass2 = np.concatenate((mass, mass)) mass2 = np.concatenate((mass, mass))
for i in range(1, len(pos)): for i in range(1, len(pos)):
dist[i - 1] = pos2[i: i + len(pos)] - pos dist[i - 1] = pos2[i: i + n] - pos
rot_mass[i - 1] = mass2[i: i + len(mass)] rot_mass[i - 1] = mass2[i: i + n]
vel += G * np.sum(
dist * rot_mass / (np.linalg.norm(dist, axis=2) ** 3)[:, :, np.newaxis],
axis=0
)
pos += vel
def n_body_matrix_constrained(pos: np.ndarray, vel: np.ndarray, mass: np.ndarray, close=2.):
dist = np.zeros((len(pos) - 1, len(pos), 2))
rot_mass = np.zeros((len(mass) - 1, len(mass), 1))
pos2 = np.concatenate((pos, pos))
mass2 = np.concatenate((mass, mass))
for i in range(1, len(pos)):
dist[i - 1] = pos2[i: i + len(pos)] - pos
rot_mass[i - 1] = mass2[i: i + len(mass)]
norms = np.linalg.norm(dist, axis=2) norms = np.linalg.norm(dist, axis=2)
norms[norms < close] = close if constrain:
norms[norms < constrain] = constrain
vel += G * np.sum( vel += G * np.sum(
dist * rot_mass / (norms ** 3)[:, :, np.newaxis], dist * rot_mass / (norms ** 3)[:, :, np.newaxis],
axis=0 axis=0