diff --git a/data.py b/data.py index 40b7b8e..d9a2417 100644 --- a/data.py +++ b/data.py @@ -11,11 +11,11 @@ def parse_csv(filename: str): lines = file.read().strip().splitlines() pos = np.zeros((len(lines), 2)) vel = np.zeros((len(lines), 2)) - rad = np.zeros((len(lines),)) + rad = np.zeros((len(lines), 1)) for i, [x, y, vx, vy, r] in enumerate(map(lambda l: map(float, l.split(',')), lines)): pos[i] = [x, y] vel[i] = [vx, vy] - rad[i] = r + rad[i] = [r] return pos, vel, rad diff --git a/physics.py b/physics.py index 43060be..0d73b9e 100644 --- a/physics.py +++ b/physics.py @@ -8,10 +8,12 @@ def rotations(a: np.ndarray): a2 = np.concatenate((a, a)) for i in range(1, len(a)): yield np.split(a2, [i, i + len(a)])[1] + # TODO: Compare performance + # yield np.roll(a, i) def n_body(pos: np.ndarray, vel: np.ndarray, mass: np.ndarray): - for (o_pos, o_mass) in zip(rotations(pos), rotations(mass[:, np.newaxis])): + for (o_pos, o_mass) in zip(rotations(pos), rotations(mass)): dist = o_pos - pos vel += G * (dist / np.linalg.norm(dist, axis=1)[:, np.newaxis]) * o_mass / np.sum(dist ** 2, axis=1)[:, np.newaxis] pos += vel