#include "PSXADPCMDecoder.h"
#include "logger.hpp"
#include <algorithm>
#include "Types.hpp"

// Prediction filter coefficients, indexed by the filter number taken from a
// block's parameter byte. DecodeBlock multiplies the previous sample by the
// "pos" entry and the sample before that by the "neg" entry, then divides the
// sum by 64.
const s16 pos_adpcm_table[5] = {
    0, 60, 115, 98, 122
};
const s16 neg_adpcm_table[5] = {
    0, 0, -52, -55, -60
};

// Sign-extends a 4-bit two's-complement value held in the low nibble of
// `number`: 0..7 map to themselves and 8..15 map to -8..-1.
// Only bit 3 is inspected; when it is clear the input is returned as-is, so
// callers are expected to mask the argument down to four bits first.
static s8 Signed4bit(u8 number)
{
    const bool isNegative = (number & 0x8) != 0;
    return isNegative ? ((number & 0x7) - 8) : number;
}

// Limits `number` to the inclusive range [min, max].
// The lower bound is checked first, so `min` wins if the bounds are ever
// passed in the wrong order and `number` is below `min`.
static s32 MinMax(s32 number, s32 min, s32 max)
{
    return (number < min) ? min
         : (number > max) ? max
                          : number;
}

template <class T>
static void DecodeBlock(
    T& out,
    const u8 (&samples)[112],
    const u8 (&parameters)[16],
    u8 block,
    u8 nibble,
    s16& dst,
    f64& old,
    f64& older)
{
    // 4 blocks for each nibble, so 8 blocks total
    const s32 shift = static_cast<u8>(12 - (parameters[4 + block * 2 + nibble] & 0xF));
    const s32 filter = static_cast<s8>((parameters[4 + block * 2 + nibble] & 0x30) >> 4);

    const s32 f0 = pos_adpcm_table[filter];
    const s32 f1 = neg_adpcm_table[filter];

    for (s32 d = 0; d < 28; d++)
    {
        const u8 value = samples[block + d * 4];
        const s32 t = Signed4bit(static_cast<u8>((value >> (nibble * 4)) & 0xF));
        s32 s = static_cast<s16>((t << shift) + ((old * f0 + older * f1 + 32) / 64));
        s = static_cast<u16>(MinMax(s, -32768, 32767));


        out[dst] = static_cast<s16>(s);
        dst += 2;

        older = old;
        old = static_cast<s16>(s);
    }
}

// Decodes every sound group of one frame into interleaved 16-bit PCM.
//
// For each of the 18 sound groups and each of its 4 byte lanes, the high
// nibbles (nibble 1) are decoded to the even indices of `out` (starting at 0)
// and the low nibbles (nibble 0) to the odd indices (starting at 1). One call
// therefore writes indices 0..4031; `out` must already be at least that large
// because it is written through operator[].
//
// NOTE: the predictor history is kept in function-local statics. It carries
// over from one call to the next, and because this is a template each
// instantiation (one per `T`) has its own independent copy.
template <class T>
static void Decode(const PSXADPCMDecoder::SoundFrame& sf, T& out)
{
    static f64 prevLeft = 0;
    static f64 prevPrevLeft = 0;
    static f64 prevRight = 0;
    static f64 prevPrevRight = 0;

    // Write cursors: left channel on even slots, right channel on odd slots.
    s16 leftCursor = 0;
    s16 rightCursor = 1;

    for (s32 groupIdx = 0; groupIdx < 18; ++groupIdx)
    {
        const PSXADPCMDecoder::SoundFrame::SoundGroup& group = sf.sound_groups[groupIdx];

        for (u8 lane = 0; lane < 4; ++lane)
        {
            DecodeBlock(out, group.audio_sample_bytes, group.sound_parameters,
                        lane, 1, leftCursor, prevLeft, prevPrevLeft);
            DecodeBlock(out, group.audio_sample_bytes, group.sound_parameters,
                        lane, 0, rightCursor, prevRight, prevPrevRight);
        }
    }
}

void PSXADPCMDecoder::DecodeFrameToPCM(std::vector<s16>& out, uint8_t* arg_adpcm_frame)
{
    const PSXADPCMDecoder::SoundFrame* sf = reinterpret_cast<const PSXADPCMDecoder::SoundFrame*>(arg_adpcm_frame);
    Decode(*sf, out);
}

void PSXADPCMDecoder::DecodeFrameToPCM(std::array<s16, 4032>& out, uint8_t* arg_adpcm_frame)
{
    const PSXADPCMDecoder::SoundFrame* sf = reinterpret_cast<const PSXADPCMDecoder::SoundFrame*>(arg_adpcm_frame);
    Decode(*sf, out);
}
