import tarfile
from io import BytesIO
from pathlib import Path

import numpy as np
import pytest

from lhotse import CutSet
from lhotse.audio.backend import audio_backend, check_torchaudio_version_gt
from lhotse.lazy import LazyJsonlIterator
from lhotse.shar import AudioTarWriter, SharWriter, TarIterator, TarWriter
from lhotse.testing.dummies import DummyManifest, dummy_cut


def test_tar_writer(tmp_path: Path):
    """A sharded ``TarWriter`` puts a single written file into the first shard."""
    first_shard = tmp_path / "test.000000.tar"
    pattern = str(tmp_path / "test.%06d.tar")

    with TarWriter(pattern, shard_size=10) as writer:
        writer.write("test.txt", BytesIO(b"test"))

    # Only one shard is reported, and its name has the marker filled in.
    assert writer.output_paths == [str(first_shard)]

    # The archive holds exactly the bytes that were written.
    with tarfile.open(first_shard) as archive:
        payload = archive.extractfile(archive.getmember("test.txt"))
        assert payload.read() == b"test"


def test_tar_writer_not_sharded(tmp_path: Path, caplog):
    """With ``shard_size=None`` and no shard marker, the 'ignoring shard_size' message is not logged."""
    tar_path = tmp_path / "test.tar"
    ignored_shard_size_msg = (
        "Sharding is disabled because `pattern` doesn't contain a formatting marker (e.g., '%06d'), "
        "but shard_size is not None - ignoring shard_size."
    )

    with TarWriter(str(tar_path), shard_size=None) as writer:
        writer.write("test.txt", BytesIO(b"test"))

    # No message about an ignored shard_size, and a single unsharded output path.
    assert ignored_shard_size_msg not in caplog.text
    assert writer.output_paths == [str(tar_path)]

    with tarfile.open(tar_path) as archive:
        payload = archive.extractfile(archive.getmember("test.txt"))
        assert payload.read() == b"test"


def test_tar_writer_not_sharded_with_shard_size(tmp_path: Path, caplog):
    """A ``shard_size`` passed without a shard marker in the pattern is ignored and a message is logged."""
    tar_path = tmp_path / "test.tar"
    ignored_shard_size_msg = (
        "Sharding is disabled because `pattern` doesn't contain a formatting marker (e.g., '%06d'), "
        "but shard_size is not None - ignoring shard_size."
    )

    with TarWriter(str(tar_path), shard_size=10) as writer:
        writer.write("test.txt", BytesIO(b"test"))

    # The message must show up in the captured logs; the output stays a single archive.
    assert ignored_shard_size_msg in caplog.text
    assert writer.output_paths == [str(tar_path)]

    with tarfile.open(tar_path) as archive:
        payload = archive.extractfile(archive.getmember("test.txt"))
        assert payload.read() == b"test"


def test_tar_writer_pipe(tmp_path: Path):
    """``TarWriter`` accepts a ``pipe:`` pattern and reports the pipe command as its output path."""
    pipe_prefix = f"pipe:cat > {tmp_path}/test"

    with TarWriter(pipe_prefix + ".%06d.tar", shard_size=10) as writer:
        writer.write("test.txt", BytesIO(b"test"))

    # The reported path is the pipe command with the shard marker filled in.
    assert writer.output_paths == [pipe_prefix + ".000000.tar"]

    # The piped command must have produced a readable tar file on disk.
    with tarfile.open(tmp_path / "test.000000.tar") as archive:
        payload = archive.extractfile(archive.getmember("test.txt"))
        assert payload.read() == b"test"


@pytest.mark.parametrize(
    ["format", "backend"],
    [
        ("flac", "default"),
        ("flac", "LibsndfileBackend"),
        ("flac", "TorchaudioDefaultBackend"),
        pytest.param(
            "flac",
            "TorchaudioFFMPEGBackend",
            marks=pytest.mark.skipif(
                not check_torchaudio_version_gt("2.1.0"),
                reason="Older torchaudio versions don't support FFMPEG.",
            ),
        ),
        ("opus", "default"),
        ("opus", "LibsndfileBackend"),
        pytest.param(
            "opus",
            "TorchaudioDefaultBackend",
            marks=pytest.mark.skipif(
                not check_torchaudio_version_gt("2.1.0"),
                reason="Older torchaudio versions won't support writing OPUS.",
            ),
        ),
        pytest.param(
            "opus",
            "TorchaudioFFMPEGBackend",
            marks=pytest.mark.skipif(
                not check_torchaudio_version_gt("2.1.0"),
                reason="Older torchaudio versions won't support writing OPUS.",
            ),
        ),
    ],
)
def test_audio_tar_writer(tmp_path: Path, format: str, backend: str):
    """Round-trip a dummy recording through ``AudioTarWriter`` for each format/backend pair.

    The audio read back from the tar file is compared to the original with a
    loose RMSE bound.
    """
    from lhotse.testing.dummies import dummy_recording

    recording = dummy_recording(0, with_data=True)
    audio = recording.load_audio()
    tar_path = str(tmp_path / "test.tar")

    with audio_backend(backend):
        with AudioTarWriter(tar_path, shard_size=None, format=format) as writer:
            writer.write(
                key="my-recording",
                value=audio,
                sampling_rate=recording.sampling_rate,
                manifest=recording,
            )

        # Exactly one output file holding exactly one (manifest, inner path) item is expected.
        (path,) = writer.output_paths
        (item,) = list(TarIterator(path))
        deserialized_recording, _inner_path = item

        # Resample back to the original rate before comparing the samples.
        resampled = deserialized_recording.resample(recording.sampling_rate)
        deserialized_audio = resampled.load_audio()

    diff = audio - deserialized_audio
    rmse = np.sqrt(np.mean(diff**2))
    assert rmse < 0.5


@pytest.mark.parametrize(
    ["original_format", "rmse_threshold"],
    [("wav", 0.0), ("flac", 0.0), ("mp3", 0.003), ("opus", 0.3)],
)
def test_audio_tar_writer_original_format(
    tmp_path: Path, original_format: str, rmse_threshold: float
):
    """Check that AudioTarWriter with ``format="original"`` keeps the source format.

    The recording is written using the "default" audio backend, read back, and
    compared to the original audio with a per-format RMSE threshold.
    """
    from lhotse.testing.dummies import dummy_recording

    recording = dummy_recording(0, with_data=True, source_format=original_format)
    audio = recording.load_audio()

    # sanity check: the dummy recording really has the requested source format
    assert (
        recording.source_format == original_format
    ), f"Recording source format ({recording.source_format}) not matching the expected original format ({original_format})"

    tar_path = str(tmp_path / "test.tar")
    with audio_backend("default"):
        # "original" asks the writer to keep the format the audio was loaded from
        with AudioTarWriter(tar_path, shard_size=None, format="original") as writer:
            writer.write(
                key="my-recording",
                value=audio,
                sampling_rate=recording.sampling_rate,
                manifest=recording,
                original_format=recording.source_format,
            )

        # exactly one output file with exactly one item is expected
        (path,) = writer.output_paths
        (item,) = list(TarIterator(path))
        deserialized_recording, _inner_path = item

        # the format must survive the round trip
        assert (
            deserialized_recording.source_format == original_format
        ), f"Deserialized recording source format ({deserialized_recording.source_format}) not matching the expected original format ({original_format})"

        # resample back to the original rate and load the samples
        resampled = deserialized_recording.resample(recording.sampling_rate)
        deserialized_audio = resampled.load_audio()

    # compare the original and the restored audio
    diff = audio - deserialized_audio
    rmse = np.sqrt(np.mean(diff**2))
    assert (
        rmse <= rmse_threshold
    ), f"RMSE between original and deserialized audio is {rmse}, which is above the threshold of {rmse_threshold}"


def test_shar_writer(tmp_path: Path):
    """Write 20 dummy cuts with ``SharWriter`` (shard_size=10) and check the files and placeholders produced."""
    data_fields = {
        "recording": "wav",
        "features": "lilcom",
        "custom_embedding": "numpy",
        "custom_features": "lilcom",
        "custom_indexes": "numpy",
        "custom_recording": "wav",
    }
    shard_ids = ("000000", "000001")

    # Prepare data
    cuts = DummyManifest(CutSet, begin_id=0, end_id=20, with_data=True)

    # Actual test (the writer gets its own copy of the field mapping)
    with SharWriter(tmp_path, fields=dict(data_fields), shard_size=10) as writer:
        for c in cuts:
            writer.write(c)

    # Post-conditions

    # Expected file names per field: cuts go to JSONL shards, data fields to tar shards.
    expected_files = {"cuts": [f"cuts.{idx}.jsonl.gz" for idx in shard_ids]}
    expected_files.update(
        (field, [f"{field}.{idx}.tar" for idx in shard_ids]) for field in data_fields
    )

    assert writer.output_paths == {
        field: [str(tmp_path / fname) for fname in fnames]
        for field, fnames in expected_files.items()
    }

    # - two shards exist for the cuts and for every data field
    for fnames in expected_files.values():
        for fname in fnames:
            assert (tmp_path / fname).is_file()

    # - we didn't create a third shard
    assert not (tmp_path / "cuts.000002.jsonl.gz").exists()

    # - the cuts only carry "shar" placeholders, so opening the shard as a
    #   normal CutSet makes every data loader fail
    for cut in CutSet.from_file(tmp_path / "cuts.000000.jsonl.gz"):
        placeholder_checks = (
            (cut.recording.sources[0].type, "load_audio"),
            (cut.features.storage_type, "load_features"),
            (cut.custom_embedding.storage_type, "load_custom_embedding"),
            (cut.custom_features.array.storage_type, "load_custom_features"),
            (cut.custom_indexes.array.storage_type, "load_custom_indexes"),
            (cut.custom_recording.sources[0].type, "load_custom_recording"),
        )
        for storage_kind, loader_name in placeholder_checks:
            assert storage_kind == "shar"
            with pytest.raises(RuntimeError):
                getattr(cut, loader_name)()


def test_shar_writer_custom_nondata_attribute(tmp_path: Path):
    """A non-data custom attribute written as "jsonl" lands in its own JSONL shards.

    Each line of those shards must pair up with the cut at the same position in
    the matching cuts shard and hold only ``cut_id`` and ``custom_attribute``.
    """
    # Prepare data
    cuts = DummyManifest(CutSet, begin_id=0, end_id=20)

    # Prepare system under test
    writer = SharWriter(
        tmp_path,
        fields={"custom_attribute": "jsonl"},
        shard_size=10,
    )

    # Actual test
    with writer:
        for c in cuts:
            writer.write(c)

    # Post-conditions

    assert writer.output_paths == {
        "cuts": [
            str(tmp_path / "cuts.000000.jsonl.gz"),
            str(tmp_path / "cuts.000001.jsonl.gz"),
        ],
        "custom_attribute": [
            str(tmp_path / "custom_attribute.000000.jsonl.gz"),
            str(tmp_path / "custom_attribute.000001.jsonl.gz"),
        ],
    }

    # - we created 2 shards with cutsets and 2 shards with the custom attribute
    for fname in (
        "cuts.000000.jsonl.gz",
        "cuts.000001.jsonl.gz",
        "custom_attribute.000000.jsonl.gz",
        "custom_attribute.000001.jsonl.gz",
    ):
        assert (tmp_path / fname).is_file()

    # - we didn't create a third shard
    assert not (tmp_path / "custom_attribute.000002.jsonl.gz").exists()

    # The custom_attribute contains valid JSON lines
    for suffix in (
        ".000000.jsonl.gz",
        ".000001.jsonl.gz",
    ):
        shard_cuts = list(CutSet.from_file(tmp_path / f"cuts{suffix}"))
        shard_attrs = list(LazyJsonlIterator(tmp_path / f"custom_attribute{suffix}"))

        # zip() stops at the shorter input, so an empty or truncated shard would
        # otherwise pass unnoticed: 20 cuts in exactly two shards of at most 10
        # means 10 items per shard on both sides.
        assert len(shard_cuts) == 10
        assert len(shard_attrs) == len(shard_cuts)

        for cut, data in zip(shard_cuts, shard_attrs):
            assert isinstance(data, dict)
            assert "cut_id" in data
            assert data["cut_id"] == cut.id
            assert "custom_attribute" in data
            assert data["custom_attribute"] == "dummy-value"
            assert len(data) == 2  # nothing else


def test_shar_writer_custom_nondata_attribute_missing(tmp_path: Path):
    """A cut lacking the custom attribute still gets a JSONL line with ``cut_id`` but no ``custom_attribute``."""
    # Prepare data: drop the attribute from the cut at index 5 only
    cuts = DummyManifest(CutSet, begin_id=0, end_id=20)
    for idx, c in enumerate(cuts):
        if idx == 5:
            del c.custom["custom_attribute"]

    # Prepare system under test
    writer = SharWriter(
        tmp_path,
        fields={"custom_attribute": "jsonl"},
        shard_size=10,
    )

    # Actual test
    with writer:
        for c in cuts:
            writer.write(c)

    # Post-conditions

    assert writer.output_paths == {
        "cuts": [
            str(tmp_path / "cuts.000000.jsonl.gz"),
            str(tmp_path / "cuts.000001.jsonl.gz"),
        ],
        "custom_attribute": [
            str(tmp_path / "custom_attribute.000000.jsonl.gz"),
            str(tmp_path / "custom_attribute.000001.jsonl.gz"),
        ],
    }

    # - we created 2 shards with cutsets and 2 shards with the custom attribute
    for fname in (
        "cuts.000000.jsonl.gz",
        "cuts.000001.jsonl.gz",
        "custom_attribute.000000.jsonl.gz",
        "custom_attribute.000001.jsonl.gz",
    ):
        assert (tmp_path / fname).is_file()

    # - we didn't create a third shard
    assert not (tmp_path / "custom_attribute.000002.jsonl.gz").exists()

    # The custom_attribute contains valid JSON lines
    # (plain string literal: there is nothing to interpolate in this file name)
    attrs = list(LazyJsonlIterator(tmp_path / "custom_attribute.000000.jsonl.gz"))

    # One line per cut is expected, including the cut without the attribute;
    # checking the count first fails more clearly than an IndexError on attrs[5].
    assert len(attrs) == 10

    assert "cut_id" in attrs[5]
    assert "custom_attribute" not in attrs[5]
    for attr in attrs[:5] + attrs[6:]:
        assert "cut_id" in attr
        assert "custom_attribute" in attr


def test_cut_set_to_shar(tmp_path: Path):
    """``CutSet.to_shar`` with shard_size=10 writes 20 dummy cuts into two shards per field."""
    data_fields = {
        "recording": "wav",
        "features": "lilcom",
        "custom_embedding": "numpy",
        "custom_features": "lilcom",
        "custom_indexes": "numpy",
        "custom_recording": "wav",
    }
    shard_ids = ("000000", "000001")

    # Prepare data
    cuts = DummyManifest(CutSet, begin_id=0, end_id=20, with_data=True)

    # System under test (it gets its own copy of the field mapping)
    output_paths = cuts.to_shar(tmp_path, fields=dict(data_fields), shard_size=10)

    # Post-conditions

    # Expected file names per field: cuts go to JSONL shards, data fields to tar shards.
    expected_files = {"cuts": [f"cuts.{idx}.jsonl.gz" for idx in shard_ids]}
    expected_files.update(
        (field, [f"{field}.{idx}.tar" for idx in shard_ids]) for field in data_fields
    )

    assert output_paths == {
        field: [str(tmp_path / fname) for fname in fnames]
        for field, fnames in expected_files.items()
    }

    # - two shards exist for the cuts and for every data field
    for fnames in expected_files.values():
        for fname in fnames:
            assert (tmp_path / fname).is_file()

    # - we didn't create a third shard
    assert not (tmp_path / "cuts.000002.jsonl.gz").exists()


def test_cut_set_to_shar_not_include_cuts(tmp_path: Path):
    """With ``include_cuts=False`` only the requested data field is written, not the cuts."""
    shard_ids = ("000000", "000001")
    unrequested_fields = (
        "features",
        "custom_embedding",
        "custom_features",
        "custom_indexes",
        "custom_recording",
    )

    # Prepare data
    cuts = DummyManifest(CutSet, begin_id=0, end_id=20, with_data=True)

    # System under test
    output_paths = cuts.to_shar(
        tmp_path,
        fields={"recording": "wav"},
        shard_size=10,
        include_cuts=False,
    )

    # Post-conditions

    # - we created 2 shards only for recordings
    recording_shards = [f"recording.{idx}.tar" for idx in shard_ids]
    assert output_paths == {
        "recording": [str(tmp_path / fname) for fname in recording_shards],
    }
    for fname in recording_shards:
        assert (tmp_path / fname).is_file()

    # - we did not create the shards for any other field, including cuts
    absent = [f"cuts.{idx}.jsonl.gz" for idx in shard_ids]
    absent += [
        f"{field}.{idx}.tar" for field in unrequested_fields for idx in shard_ids
    ]
    for fname in absent:
        assert not (tmp_path / fname).is_file()

    # - we didn't create a third shard
    assert not (tmp_path / "recording.000002.tar").exists()


def test_cut_set_to_shar_recordings_with_transforms(tmp_path: Path):
    """A cut resampled to 8 kHz can be exported with ``to_shar`` and loaded back via ``from_shar``."""
    source = DummyManifest(CutSet, begin_id=0, end_id=1, with_data=True)
    cuts = source.resample(8000)

    # Strip everything except the recording from the only cut.
    # NOTE(review): assumes cuts[0] returns the stored cut rather than a copy - confirm.
    cuts[0].features = None
    cuts[0].custom = None

    output_paths = cuts.to_shar(tmp_path, fields={"recording": "wav"})
    restored = CutSet.from_shar(fields=output_paths)

    assert len(restored) == 1
    # The restored audio must have the shape expected after resampling.
    assert restored[0].load_audio().shape == (1, 8000)


def test_shar_writer_not_sharded(tmp_path: Path):
    """With ``shard_size=None`` the ``SharWriter`` writes one unnumbered file per field."""
    data_fields = {
        "recording": "wav",
        "features": "lilcom",
        "custom_embedding": "numpy",
        "custom_features": "lilcom",
        "custom_indexes": "numpy",
        "custom_recording": "wav",
    }

    # Prepare data
    cuts = DummyManifest(CutSet, begin_id=0, end_id=20, with_data=True)

    # Actual test (the writer gets its own copy of the field mapping)
    with SharWriter(tmp_path, fields=dict(data_fields), shard_size=None) as writer:
        for c in cuts:
            writer.write(c)

    # Post-conditions

    # Expected file name per field: a single JSONL file for cuts, a single tar per data field.
    expected_file = {"cuts": "cuts.jsonl.gz"}
    expected_file.update((field, f"{field}.tar") for field in data_fields)

    assert writer.output_paths == {
        field: [str(tmp_path / fname)] for field, fname in expected_file.items()
    }

    # - exactly one file exists for the cuts and for every data field
    for fname in expected_file.values():
        assert (tmp_path / fname).is_file()

    # - we didn't create a shard
    assert not (tmp_path / "cuts.000000.jsonl.gz").exists()

    # - the cuts only carry "shar" placeholders, so opening the file as a
    #   normal CutSet makes every data loader fail
    for cut in CutSet.from_file(tmp_path / "cuts.jsonl.gz"):
        placeholder_checks = (
            (cut.recording.sources[0].type, "load_audio"),
            (cut.features.storage_type, "load_features"),
            (cut.custom_embedding.storage_type, "load_custom_embedding"),
            (cut.custom_features.array.storage_type, "load_custom_features"),
            (cut.custom_indexes.array.storage_type, "load_custom_indexes"),
            (cut.custom_recording.sources[0].type, "load_custom_recording"),
        )
        for storage_kind, loader_name in placeholder_checks:
            assert storage_kind == "shar"
            with pytest.raises(RuntimeError):
                getattr(cut, loader_name)()


def test_shar_writer_pipe(tmp_path: Path):
    """``SharWriter`` given a ``pipe:`` output location still produces the expected shard files."""
    data_fields = {
        "recording": "wav",
        "features": "lilcom",
        "custom_embedding": "numpy",
        "custom_features": "lilcom",
        "custom_indexes": "numpy",
        "custom_recording": "wav",
    }
    shard_ids = ("000000", "000001")

    # Prepare data
    cuts = DummyManifest(CutSet, begin_id=0, end_id=20, with_data=True)

    # Actual test (the writer gets its own copy of the field mapping)
    with SharWriter(
        f"pipe:cat >{tmp_path}", fields=dict(data_fields), shard_size=10
    ) as writer:
        for c in cuts:
            writer.write(c)

    # Post-conditions

    # - two shards exist for the cuts and for every data field
    expected_files = [f"cuts.{idx}.jsonl.gz" for idx in shard_ids]
    expected_files += [
        f"{field}.{idx}.tar" for field in data_fields for idx in shard_ids
    ]
    for fname in expected_files:
        assert (tmp_path / fname).is_file()

    # - the cuts only carry "shar" placeholders, so opening the shard as a
    #   normal CutSet makes every data loader fail
    for cut in CutSet.from_file(tmp_path / "cuts.000000.jsonl.gz"):
        placeholder_checks = (
            (cut.recording.sources[0].type, "load_audio"),
            (cut.features.storage_type, "load_features"),
            (cut.custom_embedding.storage_type, "load_custom_embedding"),
            (cut.custom_features.array.storage_type, "load_custom_features"),
            (cut.custom_indexes.array.storage_type, "load_custom_indexes"),
            (cut.custom_recording.sources[0].type, "load_custom_recording"),
        )
        for storage_kind, loader_name in placeholder_checks:
            assert storage_kind == "shar"
            with pytest.raises(RuntimeError):
                getattr(cut, loader_name)()


def test_shar_writer_truncates_temporal_array_and_features(tmp_path: Path):
    """A truncated cut written to Shar loads the same truncated audio, features and indexes afterwards."""
    # Basic data and sanity check of shapes.
    cut = dummy_cut(0, with_data=True)
    for dropped in ("custom_embedding", "custom_features", "custom_recording"):
        cut = cut.drop_custom(dropped)
    ref_audio = cut.load_audio()
    ref_feats = cut.load_features()
    ref_indxs = cut.load_custom_indexes()
    assert ref_audio.shape == (1, 16000)
    assert ref_feats.shape == (100, 23)
    assert ref_indxs.shape == (100,)

    def check_truncated(candidate):
        # The loaded data must equal the reference data with the edges cut off:
        # 3200 samples / 20 frames removed on each side.
        audio = candidate.load_audio()
        feats = candidate.load_features()
        indxs = candidate.load_custom_indexes()
        assert audio.shape == (1, 9600)
        np.testing.assert_array_equal(audio, ref_audio[:, 3200:-3200])
        assert feats.shape == (60, 23)
        np.testing.assert_array_equal(feats, ref_feats[20:-20, :])
        assert indxs.shape == (60,)
        np.testing.assert_array_equal(indxs, ref_indxs[20:-20])

    # Truncated cut before writing to Shar.
    cut = cut.truncate(offset=0.2, duration=0.6)
    check_truncated(cut)

    # System under test.
    shar_fields = {"recording": "wav", "features": "numpy", "custom_indexes": "numpy"}
    with SharWriter(tmp_path, fields=shar_fields, shard_size=None) as writer:
        writer.write(cut)

    # Truncated cut restored from Shar.
    sharcuts = CutSet.from_shar(in_dir=writer.output_dir)
    check_truncated(sharcuts[0])
