import json

from listenbrainz_spark.stats.listener.entity import get_listener_stats
from listenbrainz_spark.stats.user.tests import StatsTestCase


class EntityListenerTestCase(StatsTestCase):
    """Tests for the message stream produced by ``get_listener_stats``."""

    def _assert_listener_stats(self, entity, stats_range, expected_file):
        """Run ``get_listener_stats`` and verify the first three messages.

        The first message must be a ``couchdb_data_start`` marker, the second
        must match the first entry of the expected JSON fixture, and the third
        must be a ``couchdb_data_end`` marker. All three must carry the
        database name ``<entity>_listeners_<stats_range>``.

        Args:
            entity: entity name passed to ``get_listener_stats``.
            stats_range: stats range passed to ``get_listener_stats``.
            expected_file: name of the JSON fixture, resolved through
                ``self.path_to_data_file``.
        """
        database = f"{entity}_listeners_{stats_range}"

        with open(self.path_to_data_file(expected_file)) as f:
            expected = json.load(f)[0]

        messages = list(get_listener_stats(entity, stats_range))
        start_message, data_message, end_message = messages[0], messages[1], messages[2]

        self.assertEqual(start_message["type"], "couchdb_data_start")
        self.assertEqual(start_message["database"], database)

        for field in ("type", "entity", "stats_range", "from_ts", "to_ts"):
            self.assertEqual(data_message[field], expected[field], field)
        # The order of the entries in "data" is not part of the contract, so
        # compare them as an unordered collection.
        self.assertCountEqual(data_message["data"], expected["data"])
        # The database name is a string: it must be compared for equality.
        # assertCountEqual would only compare the characters as a multiset and
        # would accept any anagram of the expected name.
        self.assertEqual(data_message["database"], database)

        self.assertEqual(end_message["type"], "couchdb_data_end")
        self.assertEqual(end_message["database"], database)

    def test_get_artists(self):
        """All-time artist listener stats match the stored fixture."""
        self._assert_listener_stats(
            "artists", "all_time", "user_top_artist_listeners_output.json"
        )

    def test_get_release_groups(self):
        """All-time release group listener stats match the stored fixture."""
        self._assert_listener_stats(
            "release_groups",
            "all_time",
            "user_top_release_group_listeners_output.json",
        )
