Add slightly better n_body_matrix

This commit is contained in:
Michael Bradley 2023-10-07 21:10:37 -04:00
parent aca44fc3b0
commit e1b55dbad2
2 changed files with 22 additions and 3 deletions

View file

@ -3,7 +3,7 @@ import matplotlib.cm as cm
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
from physics import n_body from physics import n_body_matrix
def parse_csv(filename: str): def parse_csv(filename: str):
@ -54,7 +54,7 @@ class Animator:
return self.scat, return self.scat,
def update(self, *_args, **_kwargs): def update(self, *_args, **_kwargs):
n_body(self.pos, self.vel, self.mass) n_body_matrix(self.pos, self.vel, self.mass)
self.scat.set_offsets(self.pos) self.scat.set_offsets(self.pos)
return self.scat, return self.scat,

View file

@ -13,5 +13,24 @@ def rotations(a: np.ndarray):
def n_body(pos: np.ndarray, vel: np.ndarray, mass: np.ndarray): def n_body(pos: np.ndarray, vel: np.ndarray, mass: np.ndarray):
for (o_pos, o_mass) in zip(rotations(pos), rotations(mass)): for (o_pos, o_mass) in zip(rotations(pos), rotations(mass)):
dist = o_pos - pos dist = o_pos - pos
vel += G * (dist / np.linalg.norm(dist, axis=1)[:, np.newaxis]) * o_mass / np.sum(dist ** 2, axis=1)[:, np.newaxis] vel += G * dist * o_mass / (np.linalg.norm(dist, axis=1) ** 3)[:, np.newaxis]
pos += vel
def n_body_matrix(pos: np.ndarray, vel: np.ndarray, mass: np.ndarray):
dist = np.zeros((len(pos) - 1, len(pos), 2))
rot_mass = np.zeros((len(mass) - 1, len(mass), 1))
pos2 = np.concatenate((pos, pos))
mass2 = np.concatenate((mass, mass))
for i in range(1, len(pos)):
dist[i - 1] = pos2[i: i + len(pos)] - pos
rot_mass[i - 1] = mass2[i: i + len(mass)]
vel += G * np.sum(
dist * rot_mass / (np.linalg.norm(dist, axis=2) ** 3)[:, :, np.newaxis],
axis=0
)
pos += vel pos += vel