#include "iris/arch/x86/amd64/hw/pit.h"

#include "iris/arch/x86/amd64/io_instructions.h"
#include "iris/core/global_state.h"
#include "iris/core/interrupt_disabler.h"
#include "iris/core/print.h"
#include "iris/hw/irq.h"
#include "iris/hw/irq_controller.h"
#include "iris/hw/timer.h"

namespace iris::x86::amd64 {
// Number of PIT counter ticks per second (Hz). Used below to report the timer resolution and to convert a
// requested duration into a counter reload value.
constexpr static auto pit_frequency = 1193182;

class PitTimer {
private:
    auto register_irq() -> Expected<void> {
        if (m_irq_id) {
            return {};
        }

        m_irq_id = TRY(register_external_irq_handler(IrqLine(0), [this](IrqContext& context) -> IrqStatus {
            send_eoi(*context.controller->lock(), IrqLine(0));
            m_callback(context);
            return IrqStatus::Handled;
        }));
        return {};
    }

    friend auto tag_invoke(di::Tag<timer_name>, PitTimer const&) -> di::StringView { return "PIT"_sv; }

    friend auto tag_invoke(di::Tag<timer_capabilities>, PitTimer const&) -> TimerCapabilities {
        return TimerCapabilities::SingleShot | TimerCapabilities::Periodic;
    }

    friend auto tag_invoke(di::Tag<timer_resolution>, PitTimer const&) -> TimerResolution {
        return TimerResolution(1_s) / pit_frequency;
    }

    friend auto tag_invoke(di::Tag<timer_set_single_shot>, PitTimer& self, TimerResolution duration,
                           di::Function<void(IrqContext&)> callback) -> Expected<void> {
        auto divisor = duration.count() * pit_frequency / TimerResolution(1_s).count();

        return with_interrupts_disabled([&] -> Expected<void> {
            TRY(self.register_irq());
            self.m_callback = di::move(callback);
            // Set PIT to mode 0: interrupt on terminal count.
            x86::amd64::io_out(0x43, 0b00110000_u8);
            x86::amd64::io_out(0x40, u8(divisor & 0xFF));
            x86::amd64::io_out(0x40, u8(divisor >> 8));
            return {};
        });
    }

    friend auto tag_invoke(di::Tag<timer_set_interval>, PitTimer& self, TimerResolution duration,
                           di::Function<void(IrqContext&)> callback) -> Expected<void> {
        auto divisor = duration.count() * pit_frequency / TimerResolution(1_s).count();

        return with_interrupts_disabled([&] -> Expected<void> {
            TRY(self.register_irq());
            self.m_callback = di::move(callback);

            // Set PIT to mode 3: square wave generator.
            x86::amd64::io_out(0x43, 0b00110110_u8);
            x86::amd64::io_out(0x40, u8(divisor & 0xFF));
            x86::amd64::io_out(0x40, u8(divisor >> 8));
            return {};
        });
    }

    di::Function<void(IrqContext&)> m_callback;
    di::Optional<usize> m_irq_id;
};

// Compile-time check that PitTimer implements TimerInterface.
static_assert(di::Impl<PitTimer, TimerInterface>);

void init_pit() {
    // FIXME: In the future, we should determine whether the HPET is available, and if so, disable the PIT.

    *global_state_in_boot().timers.emplace_back(*Timer::create(di::in_place_type<PitTimer>));
}
}
