#!/usr/bin/env python
"""
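Plot per-body time series read from a whitespace-separated data file.

The file is given as the first command line argument, or read from stdin
when there is none. Its first line is skipped and the first 15 columns are
loaded. Column 0 is the time, and column 1 decides which body each row
belongs to. For every body, the position, velocity, rotation and angular
velocity columns are plotted against time, one subplot per group, and the
figure is saved as a PNG named after the input.

Background on splitting rows by the value of a column: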

- http://stackoverflow.com/questions/31863083/python-split-numpy-array-based-on-values-in-the-array
- http://stackoverflow.com/questions/33622888/how-to-plot-2-lines-based-on-the-value-not-column pandas example
"""
import imp
import os
import sys

from cycler import cycler
import matplotlib
import matplotlib.pyplot as plt
import numpy

# File extension of the saved figure.
out_ext = 'png'

# Index of the column plotted on the x axis of every subplot.
time_col = 0

# Column index groups; each group is drawn on its own subplot.
position_cols = list(range(2, 5))
velocity_cols = list(range(5, 8))
rotation_cols = list(range(8, 12))
angular_velocity_cols = list(range(12, 15))

# Number of subplots drawn per body: one per column group above.
nplots_per_body = 4

def plot_cols(cols, title, nplots, key, plot_index, data=None):
    """
    Plot several data columns against the time column on one subplot.

    The subplot is slot ``plot_index + 1`` of an ``nplots`` x 1 grid on the
    current pyplot figure and is titled ``'<title> <key>'``. Each column is
    drawn with 'x' markers and labelled with its column index.

    :param cols: column indices of the data table to plot, one series each
    :param title: first part of the subplot title
    :param nplots: number of rows of the subplot grid
    :param key: second part of the subplot title
    :param plot_index: 0-based slot of the subplot inside the grid
    :param data: 2D array with the rows to plot. When omitted, the
        module-level ``v`` is used, which is what existing callers rely on.
    """
    cols = list(cols)
    if not cols:
        # Nothing to draw: do not create an empty subplot.
        return
    if data is None:
        data = v
    # Create the subplot and its title once instead of once per column, so
    # the result does not depend on pyplot handing back the same Axes for
    # repeated identical subplot() calls.
    ax = plt.subplot(nplots, 1, plot_index + 1)
    ax.set_title('{} {}'.format(title, key))
    times = data[:, time_col]
    for c in cols:
        ax.plot(times, data[:, c], 'x', label=str(c))
    # The per-series labels are only visible through a legend. A fixed
    # location avoids the 'best' placement search over every data point.
    ax.legend(loc='upper right')

# Input selection: the first command line argument is the data file path,
# otherwise the table is read from stdin.
if len(sys.argv) > 1:
    data_path = sys.argv[1]
    data_name = os.path.splitext(os.path.basename(data_path))[0]
    # Pass the path itself so that numpy opens and closes the file; the
    # previous explicit open() was never closed.
    data_file = data_path
else:
    data_name = 'stdin'
    data_file = sys.stdin

a = numpy.loadtxt(
    data_file,
    # TODO: convert first two columns to int here.
    # dtype=([('a','i4')] * 2 + [('a','f4')] * 6),
    skiprows=1,
    usecols=range(15),
    # Keep the table 2D even when the input holds a single data row,
    # otherwise the column indexing below fails.
    ndmin=2
)
if a.size == 0:
    # Without rows there are no bodies, and a zero-height figure cannot be
    # saved.
    sys.exit('{}: no data rows to plot'.format(data_name))

matplotlib.rcParams['axes.prop_cycle'] = cycler('color', ['r', 'g', 'b', 'y'])
body_col = 1
# Sorted distinct values of the body column. Sorting makes the subplot order
# deterministic, which iterating a set did not guarantee.
keys = [int(k) for k in numpy.unique(a[:, body_col])]
nplots = nplots_per_body * len(keys)
# A bare figure: plt.subplots() would also add a full-figure Axes that
# ends up underneath the per-body subplots.
fig = plt.figure()
for body_index, key in enumerate(keys):
    # Rows of the current body; plot_cols() reads them through this global.
    v = a[a[:, body_col] == key, :]
    # Place subplots by the body's rank, not by its id, so that ids which
    # are not exactly 0..n-1 still land inside the nplots-row grid.
    first_plot = nplots_per_body * body_index
    plot_cols(position_cols, 'pos', nplots, key, first_plot)
    plot_cols(velocity_cols, 'vel', nplots, key, first_plot + 1)
    plot_cols(rotation_cols, 'rot', nplots, key, first_plot + 2)
    plot_cols(angular_velocity_cols, 'ang_vel', nplots, key, first_plot + 3)
# TODO custom per figure plots: load a plotter module from `name + '.py'`
# and call its `plot(plt, data_path)`. The original sketch used
# `imp.load_source(name, name + '.py')`; the `imp` module was removed in
# Python 3.12, so this should be written with `importlib` instead.
fig.set_size_inches(8, nplots * 3)
plt.savefig(data_name + '.' + out_ext, dpi=100)
