From 424721ccfceeec69647e1e6f65b0288e789595a7 Mon Sep 17 00:00:00 2001 From: Michael Bradley Date: Sun, 8 Oct 2023 12:13:53 -0400 Subject: [PATCH] Support 3d simulations --- data.py | 69 +++++++++++++++++++++++++++++++++++++---------------- data/3d.csv | 2 ++ main.py | 11 ++++++++- physics.py | 34 +++++++------------------- 4 files changed, 70 insertions(+), 46 deletions(-) create mode 100644 data/3d.csv diff --git a/data.py b/data.py index 5098d0f..1222c20 100644 --- a/data.py +++ b/data.py @@ -2,20 +2,28 @@ import matplotlib.animation as animation import matplotlib.cm as cm import matplotlib.pyplot as plt import numpy as np - import physics -def parse_csv(filename: str): +def parse_csv(filename: str, dimensions=2): + if not (1 < dimensions < 4): + raise ValueError(f"Can only show 2or 3 dimensional scenes, not {dimensions}") with open(filename, 'r') as file: lines = file.read().strip().splitlines() - pos = np.zeros((len(lines), 2)) - vel = np.zeros((len(lines), 2)) + pos = np.zeros((len(lines), dimensions)) + vel = np.zeros((len(lines), dimensions)) rad = np.zeros((len(lines), 1)) - for i, [x, y, vx, vy, r] in enumerate(map(lambda l: map(float, l.split(',')), lines)): - pos[i] = [x, y] - vel[i] = [vx, vy] - rad[i] = r + for i, values in enumerate(map(lambda l: map(float, l.split(',')), lines)): + if dimensions == 2: + [x, y, vx, vy, r] = values + pos[i] = [x, y] + vel[i] = [vx, vy] + rad[i] = r + elif dimensions == 3: + [x, y, z, vx, vy, vz, r] = values + pos[i] = [x, y, z] + vel[i] = [vx, vy, vz] + rad[i] = r return pos, vel, rad @@ -26,14 +34,20 @@ class Animator: self.rad = rad self.mass = np.pi * 4 / 3 * rad ** 3 - self.scat = None + n, d = self.pos.shape + + self.scat: plt.PathCollection = None self.colours = cm.rainbow( np.random.random( - (len(self.rad),) + (n,) ) ) - self.fig, self.ax = plt.subplots() + self.fig = plt.figure() + if d == 2: + self.ax = self.fig.add_subplot() + else: + self.ax = self.fig.add_subplot(projection="3d") self.ani = animation.FuncAnimation( self.fig, self.update, @@ -44,18 +58,33 @@ class Animator: ) def setup_plot(self): - self.scat = self.ax.scatter( - self.pos[:, 0], - self.pos[:, 1], - c=self.colours, - s=self.rad * 10 - ) - self.ax.axis([-950, 950, -500, 500]) + _n, d = self.pos.shape + if d == 2: + self.scat = self.ax.scatter( + self.pos[:, 0], + self.pos[:, 1], + c=self.colours, + s=self.rad * 10 + ) + self.ax.axis([-950, 950, -500, 500]) + else: + self.scat = self.ax.scatter( + self.pos[:, 0], + self.pos[:, 1], + self.pos[:, 2], + c=self.colours, + s=self.rad * 10 + ) + self.ax.axis([-500, 500, -500, 500, -500, 500]) return self.scat, def update(self, *_args, **_kwargs): - physics.n_body_matrix_constrained(self.pos, self.vel, self.mass) - self.scat.set_offsets(self.pos) + _n, d = self.pos.shape + physics.n_body_matrix(self.pos, self.vel, self.mass, constrain=2.) + self.scat.set_offsets(self.pos[:, :2]) + if d == 3: + self.scat.set_3d_properties(self.pos[:, 2], 'z') + self.fig.canvas.draw() return self.scat, def show(self): diff --git a/data/3d.csv b/data/3d.csv new file mode 100644 index 0000000..7241158 --- /dev/null +++ b/data/3d.csv @@ -0,0 +1,2 @@ +0,50,10,5,0,-1,11 +0,-50,-10,-5,0,1,11 diff --git a/main.py b/main.py index c80d4cf..0d62070 100755 --- a/main.py +++ b/main.py @@ -1,5 +1,6 @@ #!./venv/bin/python import argparse +import typing import data import physics @@ -8,6 +9,7 @@ import physics class Args: filename: str gravity: float + dimensions: typing.Literal[2, 3] if __name__ == "__main__": @@ -27,11 +29,18 @@ if __name__ == "__main__": type=float, default=1. ) + parser.add_argument( + "-d", + "--dimensions", + type=int, + choices=[2, 3], + default=2 + ) args: Args = parser.parse_args() physics.G = args.gravity - objects = data.parse_csv(args.filename) + objects = data.parse_csv(args.filename, dimensions=args.dimensions) a = data.Animator(*objects) a.show() diff --git a/physics.py b/physics.py index a47e5cd..4965775 100644 --- a/physics.py +++ b/physics.py @@ -17,38 +17,22 @@ def n_body(pos: np.ndarray, vel: np.ndarray, mass: np.ndarray): pos += vel -def n_body_matrix(pos: np.ndarray, vel: np.ndarray, mass: np.ndarray): - dist = np.zeros((len(pos) - 1, len(pos), 2)) - rot_mass = np.zeros((len(mass) - 1, len(mass), 1)) +def n_body_matrix(pos: np.ndarray, vel: np.ndarray, mass: np.ndarray, constrain=2.): + n, d = pos.shape + dist = np.zeros((n - 1, n, d)) + rot_mass = np.zeros((n - 1, n, 1)) pos2 = np.concatenate((pos, pos)) mass2 = np.concatenate((mass, mass)) for i in range(1, len(pos)): - dist[i - 1] = pos2[i: i + len(pos)] - pos - rot_mass[i - 1] = mass2[i: i + len(mass)] - - vel += G * np.sum( - dist * rot_mass / (np.linalg.norm(dist, axis=2) ** 3)[:, :, np.newaxis], - axis=0 - ) - - pos += vel - - -def n_body_matrix_constrained(pos: np.ndarray, vel: np.ndarray, mass: np.ndarray, close=2.): - dist = np.zeros((len(pos) - 1, len(pos), 2)) - rot_mass = np.zeros((len(mass) - 1, len(mass), 1)) - - pos2 = np.concatenate((pos, pos)) - mass2 = np.concatenate((mass, mass)) - - for i in range(1, len(pos)): - dist[i - 1] = pos2[i: i + len(pos)] - pos - rot_mass[i - 1] = mass2[i: i + len(mass)] + dist[i - 1] = pos2[i: i + n] - pos + rot_mass[i - 1] = mass2[i: i + n] norms = np.linalg.norm(dist, axis=2) - norms[norms < close] = close + if constrain: + norms[norms < constrain] = constrain + vel += G * np.sum( dist * rot_mass / (norms ** 3)[:, :, np.newaxis], axis=0