#include <fcntl.h>
#include <limits.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/time.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/syscall.h>
#include <sys/uio.h>
#include <time.h>
#include <unistd.h>
#include "lua.h"
#include "lualib.h"
#include "lauxlib.h"
#include "lua-bundle.h"
#include "multiboot2.h"
#include "sqlite3.h"
#include "util.h"
#include "vbe.h"

// Start of the heap region. Only its address (&heap_start) is used: main()
// hands it to SQLite as the base of its heap. Presumably a linker-script
// symbol placed after the kernel image — TODO confirm.
extern const uintptr_t heap_start;
// End of the heap region. get_multiboot_info() sets it from the multiboot
// memory map (the available region starting at 0x100000). It stays 0 if that
// region is not reported.
static u8 *heap_end = 0;

// The one Lua state. The syscall handlers below call into it.
static lua_State *L = NULL;

// trap(): debugger breakpoint hook, called on unexpected conditions.
// `xchg bx, bx` is Bochs' "magic breakpoint" instruction. Elsewhere it is a
// plain register exchange, i.e. a no-op. The QEMU variant is an out-of-line
// (noinline) function, presumably so a debugger breakpoint can be set on
// the `trap` symbol — confirm.
#if 1
// QEMU
void __attribute__((noinline))
trap()
{
  //~ asm volatile("");
  //~ asm volatile("int 3");
  asm volatile("xchg bx, bx");
}
#else
// bochs
#define trap() asm volatile("xchg bx, bx")
#endif

// Lua allocator (lua_Alloc) backed by SQLite's allocator, so Lua and SQLite
// share the single heap configured in main().
//
// lua_Alloc contract:
// - nsize == 0 frees `ptr` and returns NULL.
// - Otherwise it returns a block of at least nsize bytes holding the old
//   contents.
// - On failure it returns NULL and leaves `ptr` untouched.
static void* l_alloc(void *ud, void *ptr, size_t osize, size_t nsize)
{
  (void)ud;
  (void)osize;
  if (nsize == 0)
  {
    sqlite3_free(ptr);
    return NULL;
  }
  // sqlite3_realloc() takes an int. A larger size would be truncated,
  // possibly to a negative value, which makes sqlite3_realloc() free `ptr`
  // and return NULL. Lua would then keep using freed memory. Report a plain
  // allocation failure instead.
  if (nsize > (size_t)INT_MAX)
  {
    return NULL;
  }
  return sqlite3_realloc(ptr, (int)nsize);
}

long handle_syscall(long n, long a1, long a2, long a3, long a4, long a5, long a6)
{
  switch (n)
  {
    case SYS_brk:
    {
      void *end = (void*)a1;
      // trap();
      return 0;
      break;
    }
    
    case SYS_open:
    {
      const char *pathname = (const char*)a1;
      int flags = a2;
      int mode = a3;
      
      lua_getglobal(L, "open");
      lua_pushstring(L, pathname);
      lua_pushnumber(L, flags);
      lua_pushnumber(L, mode);
      lua_call(L, 3, 1);
      return luaL_ref(L, LUA_REGISTRYINDEX);
      //int fd = luaL_ref(L, LUA_REGISTRYINDEX);
      //return lua_tonumber(L, 1);
      break;
    }
    
    case SYS_close:
    {
      return 0;
      break;
    }
    
    case SYS_lseek:
    {
      int fd = a1;
      off_t offset = a2;
      int whence = a3;
      
      lua_getglobal(L, "lseek");
      lua_rawgeti(L, LUA_REGISTRYINDEX, fd);
      lua_pushnumber(L, offset);
      lua_pushnumber(L, whence);
      lua_call(L, 3, 1);
      int retn = lua_tonumber(L, -1);
      lua_pop(L, 1);
      return retn;
      break;
    }
    
    case SYS_read:
    {
      int fd = a1;
      char *buf = (char*)a2;
      size_t count = a3;
      
      lua_getglobal(L, "read");
      if (fd != 0)
      {
        lua_rawgeti(L, LUA_REGISTRYINDEX, fd);
      }
      else
      {
        // stdin
        lua_pushnumber(L, fd);
      }
      lua_pushnumber(L, count);
      lua_call(L, 2, 1);
      size_t len;
      const char *resbuf = lua_tolstring(L, -1, &len);
      memcpy(buf, resbuf, len);
      lua_pop(L, 1);
      return len;
      //~ break;
    }
    
    case SYS_write:
    {
      int fd = a1;
      const char *buf = (const char*)a2;
      size_t count = a3;
      
      lua_getglobal(L, "write");
      if (fd != 1)
      {
        lua_rawgeti(L, LUA_REGISTRYINDEX, fd);
      }
      else
      {
        lua_pushnumber(L, fd);
      }
      lua_pushlstring(L, buf, count);
      lua_call(L, 2, 1);
      int len = lua_tonumber(L, -1);
      lua_pop(L, 1);
      return len;
    }
    
    case SYS_writev:
    {
      // p (char*)iov[0].iov_base
      int fd = a1;
      struct iovec *iov = (struct iovec*)a2;
      int count = (int)a3;
      ssize_t num_bytes_written = 0;
      for (int i = 0; i < count; ++i)
      {
        num_bytes_written += write(fd, iov[i].iov_base, iov[i].iov_len);
      }
      return num_bytes_written;
      break;
    }
    
    case SYS_ioctl:
    {
      return 0;
    }
    
    case SYS_getpid:
    {
      return 0;
    }
    
    case SYS_fsync:
    {
      return 0;
    }
    
    case SYS_chown:
    {
      return 0;
    }
    
    case SYS_clock_gettime:
    {
      clockid_t clk_id = (clockid_t)a1;
      struct timespec *tp = (struct timespec*)a2;
      *tp = (struct timespec){0};
      return 0;
    }
    
    case SYS_gettimeofday:
    {
      struct timeval *tv = (struct timeval*)a1;
      *tv = (struct timeval){0};
      return 0;
    }
    
    case SYS_fcntl:
    {
      return 0;
    }
    
    default:
    {
      trap();
      return -1;
    }
  }
}

// System-call entry point. Presumably invoked by the C library's syscall
// wrappers (musl's __syscall) in place of a real kernel transition —
// confirm. Every request is serviced by handle_syscall().
long __syscall(long n, long a1, long a2, long a3, long a4, long a5, long a6)
{
  // The handler's result must be returned: falling off the end of a
  // non-void function whose value is used is undefined behavior.
  return handle_syscall(n, a1, a2, a3, a4, a5, a6);
}

// Address of the multiboot2 boot information structure, read by
// get_multiboot_info(). Presumably stored by the boot assembly before
// main() runs — TODO confirm. The pointer itself is volatile.
extern u8 *volatile multiboot_boot_information;
// VBE mode information copied from the multiboot VBE tag.
static struct VBEModeInfoBlock modeinfo;

// Framebuffer address from the multiboot framebuffer tag. It stays NULL
// unless that tag was seen after a VBE mode with XResolution > 0.
static u8 *fbmem = NULL;
// Back buffer drawn into by putpixel/clear_screen and copied to fbmem by
// swap_buffers. main() allocates it as Lua userdata.
static u8 *display_buffer = NULL;
// Size of display_buffer in bytes (YResolution * BytesPerScanLine).
static u32 display_buffer_len = 0;

// Lua: putpixel(x, y, r, g, b)
// Writes one pixel (stored as blue, green, red bytes) into the back buffer.
// The pixel becomes visible after swap_buffers(). Coordinates outside the
// display, or pixel formats narrower than 3 bytes, are rejected via trap()
// instead of writing out of bounds.
static int putpixel(lua_State *l)
{
  const lua_Number xn = lua_tonumber(l, 1);
  const lua_Number yn = lua_tonumber(l, 2);
  const u32 r = lua_tonumber(l, 3);
  const u32 g = lua_tonumber(l, 4);
  const u32 b = lua_tonumber(l, 5);
  lua_pop(l, 5);

  // Validate the coordinates while they are still floating point. This
  // matters for three reasons:
  // - Converting a negative/NaN/huge value to u32 is undefined.
  // - Forming a pointer past the buffer is undefined too.
  // - Bounding x by XResolution stops writes from wrapping onto the next
  //   scanline.
  if (display_buffer == NULL ||
      !(xn >= 0 && xn < modeinfo.XResolution) ||
      !(yn >= 0 && yn < modeinfo.YResolution))
  {
    trap();
    return 0;
  }
  const u32 x = (u32)xn;
  const u32 y = (u32)yn;

  const u32 bytes_per_pixel = (modeinfo.BitsPerPixel / 8);
  // http://forum.osdev.org/viewtopic.php?p=77998&sid=d4699cf03655c572906144641a98e4aa#p77998
  const u32 offset = (y * modeinfo.BytesPerScanLine) + (x * bytes_per_pixel);
  if (bytes_per_pixel < 3 || offset + bytes_per_pixel > display_buffer_len)
  {
    trap();
    return 0;
  }

  u8 *ptr = &display_buffer[offset];
  ptr[0] = b;
  ptr[1] = g;
  ptr[2] = r;
  // Only 32 bpp pixels have a fourth byte. In 24 bpp modes that byte is the
  // next pixel's blue channel and must not be touched.
  if (bytes_per_pixel >= 4)
  {
    ptr[3] = 0;
  }
  return 0;
}

// Lua: clear_screen()
// Sets every byte of the back buffer (display_buffer) to zero. The visible
// framebuffer is not touched until swap_buffers() runs.
static int clear_screen(lua_State *l)
{
  (void)l;
  u8 *const dst = display_buffer;
  const size_t nbytes = display_buffer_len;
  memset(dst, 0, nbytes);
  return 0;
}

// Lua: swap_buffers()
// Copies the back buffer to the framebuffer, then clears the back buffer
// for the next frame. The copy is skipped when no framebuffer address was
// found. In that case fbmem is NULL, and copying would overwrite memory
// starting at address 0.
static int swap_buffers(lua_State *l)
{
  if (fbmem != NULL && display_buffer != NULL)
  {
    memcpy(fbmem, display_buffer, display_buffer_len);
  }
  clear_screen(l);
  return 0;
}

static void get_multiboot_info(void)
{
  if (multiboot_boot_information == NULL)
  {
    trap();
  }
  u32 total_size = *(u32*)multiboot_boot_information;
  struct multiboot_tag *tag = (struct multiboot_tag*)&multiboot_boot_information[8];
  while (tag->type != MULTIBOOT_TAG_TYPE_END)
  {
    switch (tag->type)
    {
      case MULTIBOOT_TAG_TYPE_VBE:
      {
        struct multiboot_tag_vbe *vbetag = (struct multiboot_tag_vbe*)tag;
        modeinfo = *(struct VBEModeInfoBlock*)&vbetag->vbe_mode_info;
        break;
      }
      
      case MULTIBOOT_TAG_TYPE_FRAMEBUFFER:
      {
        struct multiboot_tag_framebuffer *fb = (struct multiboot_tag_framebuffer*)tag;
        if (modeinfo.XResolution > 0)
        {
          fbmem = (u8*)fb->common.framebuffer_addr;
        }
        break;
      }
#if 0
      case MULTIBOOT_TAG_TYPE_BASIC_MEMINFO:
      {
        struct multiboot_tag_basic_meminfo *meminfo = (struct multiboot_tag_basic_meminfo*)tag;
        break;
      }
#endif

      case MULTIBOOT_TAG_TYPE_MMAP:
      {
        struct multiboot_tag_mmap *mmap = (struct multiboot_tag_mmap*)tag;
        for
        (
          struct multiboot_mmap_entry *entry = mmap->entries;
          (u8*)entry < (u8*)mmap + tag->size;
          entry = (struct multiboot_mmap_entry*)((u8*)entry + mmap->entry_size)
        )
        {
          if (entry->type == MULTIBOOT_MEMORY_AVAILABLE)
          {
            // available memory:
            // 0x0 to 0x9f000
            // 0x100000 to 0x3FFF0000
            // 0xdcaa00
            if (entry->addr == 0x100000)
            {
              heap_end = (u8*)entry->addr + entry->len;
              //~ trap();
            }
          }
        }
        break;
      }
    }

    // tags are padded to ensure 8 byte alignment
    if (tag->size % 8 == 0)
    {
      tag = (struct multiboot_tag*)((u8*)tag + tag->size);
    }
    else
    {
      tag = (struct multiboot_tag*)((u8*)tag + tag->size + (8 - (tag->size % 8)));
    }
  }
}

// Updated by handle_interrupt() and read by Lua-callable functions. These
// are volatile because an interrupt can change them at any instruction, so
// the compiler must not keep them cached in registers.

// Timer interrupts seen since boot.
static volatile u64 timer_ticks = 0;

// Keyboard scancodes received since the last get_keyboard_interrupt() call.
// Scancodes arriving while the queue is full are dropped.
static volatile u32 keyboard_scancode_queue[8] = {0};
static volatile u32 keyboard_scancode_queue_len = 0;

// Common interrupt dispatcher; `n` is the interrupt vector number.
//   13  general protection fault  -> trap()
//   14  page fault                -> trap()
//   32  timer (100 Hz)            -> count a tick
//   33  keyboard                  -> queue the scancode
//   44  mouse                     -> read and discard the data byte
// Any other vector hits trap() and then spins forever.
void handle_interrupt(u32 n)
{
  // 8259 PIC command ports, the end-of-interrupt command, and the PS/2
  // controller data port.
  enum
  {
    PIC1_COMMAND = 0x20,
    PIC2_COMMAND = 0xa0,
    PIC_EOI = 0x20,
    PS2_DATA = 0x60,
  };

  switch (n)
  {
    // general protection fault
    case 13:
    {
      trap();
      break;
    }

    // page fault
    case 14:
    {
      trap();
      break;
    }

    // timer
    case 32:
    {
      // 100 Hz
      timer_ticks += 1;
      outb(PIC1_COMMAND, PIC_EOI);
      break;
    }

    // keyboard
    case 33:
    {
      // Always read the port. The scancode is dropped if the queue is full.
      u32 scancode = inb(PS2_DATA);
      if (keyboard_scancode_queue_len < arraylen(keyboard_scancode_queue))
      {
        keyboard_scancode_queue[keyboard_scancode_queue_len] = scancode;
        keyboard_scancode_queue_len += 1;
      }
      outb(PIC1_COMMAND, PIC_EOI);
      break;
    }

    // mouse
    case 44:
    {
      // Mouse data is not used yet, but the byte is still read from the
      // port. This interrupt is acknowledged on both PICs.
      (void)inb(PS2_DATA);
      outb(PIC2_COMMAND, PIC_EOI);
      outb(PIC1_COMMAND, PIC_EOI);
      break;
    }

    default:
    {
      trap(); while (1);
      break;
    }
  }
}

// Lua: outb(port, value)
// Writes the byte `value` to the x86 I/O port `port`.
static int lua_outb(lua_State *l)
{
  const u8 byte = lua_tonumber(l, 2);
  const u32 port = lua_tonumber(l, 1);
  lua_pop(l, 2);
  outb(port, byte);
  return 0;
}

// Lua: inb(port) -> number
// Reads one byte from the x86 I/O port `port` and returns it to Lua.
static int lua_inb(lua_State *l)
{
  const u32 port = lua_tonumber(l, 1);
  lua_pop(l, 1);
  lua_pushnumber(l, (u8)inb(port));
  return 1;
}

// http://lua-users.org/lists/lua-l/2003-12/msg00301.html
// http://lua-users.org/lists/lua-l/2002-12/msg00171.html
// http://lua-users.org/lists/lua-l/2011-06/msg00426.html
// http://lua-users.org/lists/lua-l/2010-03/msg00679.html

// Count hook installed by lua_setmaskhook(). It yields the coroutine it is
// running in. Presumably a Lua-side scheduler resumes it later (preemptive
// multitasking) — confirm against the Lua code. A yielding hook must call
// lua_yield with zero results as its last action.
static void lua_hook(lua_State *co, lua_Debug *ar)
{
  (void)ar;
  lua_yield(co, 0);
}

// Lua: setmaskhook(co, count)
// Installs lua_hook on coroutine `co` as a count hook, so it fires every
// `count` VM instructions. Does nothing if the first argument is not a
// thread.
static int lua_setmaskhook(lua_State *l)
{
  lua_State *const co = lua_tothread(l, 1);
  const int instructions = lua_tointeger(l, 2);
  lua_pop(l, 2);
  if (co == NULL)
  {
    return 0;
  }
  lua_sethook(co, lua_hook, LUA_MASKCOUNT, instructions);
  return 0;
}

// Lua: get_timer_ticks() -> integer
// Returns the number of timer interrupts counted since boot.
static int lua_get_timer_ticks(lua_State *l)
{
  // timer_ticks is 64 bits wide and the timer interrupt increments it. On a
  // 32-bit build it is loaded as two halves, so an interrupt between them
  // could give a torn value. Re-read until two consecutive loads agree. The
  // empty asm with a "memory" clobber forces the second load to go to
  // memory.
  u64 ticks;
  u64 again;
  do
  {
    ticks = timer_ticks;
    asm volatile("" ::: "memory");
    again = timer_ticks;
  } while (ticks != again);
  lua_pushinteger(l, (lua_Integer)ticks);
  return 1;
}

// Lua: get_keyboard_interrupt() -> { scancode, ... }
// Returns (and empties) the scancodes queued by the keyboard interrupt
// since the previous call. The table may be empty.
static int lua_get_keyboard_interrupt(lua_State *l)
{
  enum { QUEUE_CAPACITY = sizeof keyboard_scancode_queue / sizeof keyboard_scancode_queue[0] };
  u32 scancodes[QUEUE_CAPACITY];
  u32 count;

  // Snapshot and empty the queue with interrupts disabled, so the keyboard
  // handler cannot change it mid-copy. The "memory" clobbers make cli/sti
  // compiler barriers, so the queue accesses cannot be moved outside this
  // window.
  asm volatile ("cli" ::: "memory");
  count = keyboard_scancode_queue_len;
  if (count > QUEUE_CAPACITY)
  {
    count = QUEUE_CAPACITY;
  }
  for (u32 i = 0; i < count; ++i)
  {
    scancodes[i] = keyboard_scancode_queue[i];
  }
  keyboard_scancode_queue_len = 0;
  asm volatile ("sti" ::: "memory");

  // Build the Lua table only after interrupts are re-enabled. A Lua
  // allocation error would otherwise unwind with interrupts still off.
  lua_createtable(l, (int)count, 0);
  for (u32 i = 0; i < count; ++i)
  {
    lua_pushinteger(l, scancodes[i]);
    lua_rawseti(l, -2, (int)(i + 1));
  }
  return 1;
}

// Lua: hlt()
// Executes the x86 `hlt` instruction: the CPU idles until the next
// interrupt.
static int lua_hlt(lua_State *l)
{
  (void)l;
  __asm__ __volatile__("hlt");
  return 0;
}

// Set by lua_loader() and main() when loading or running a Lua chunk fails,
// just before trap(). It is intended to hold the error message for
// inspection from a debugger. It points at a string owned by the Lua state.
const char *errstr = NULL;

// Lua: loader(modname) -> module value | nil
// Finds `modname` in the compiled-in lua_bundle, runs that chunk and
// returns its first result (nil if it returns nothing).
// - Returns nil if no bundled module matches.
// - Load and runtime errors are recorded in errstr, trigger trap(), and are
//   re-raised as Lua errors.
static int lua_loader(lua_State *l)
{
  size_t len;
  const char *modname = lua_tolstring(l, -1, &len);
  if (modname == NULL)
  {
    lua_pushnil(l);
    return 1;
  }

  struct module *mod = NULL;
  for (size_t i = 0; i < arraylen(lua_bundle); ++i)
  {
    // Prefix match: a bundle name matches if it starts with modname. If
    // several match, the last one wins. strncmp stops at the bundle name's
    // terminator instead of reading past a shorter name.
    if (strncmp(modname, lua_bundle[i].name, len) == 0)
    {
      mod = &lua_bundle[i];
    }
  }
  if (!mod)
  {
    lua_pushnil(l);
    return 1;
  }

  if (luaL_loadbuffer(l, mod->buf, mod->len, mod->name) != LUA_OK)
  {
    // The error message is on top of the stack; index 1 is modname.
    errstr = lua_tostring(l, -1);
    trap();
    return lua_error(l);
  }
  // Ask for exactly one result, so the value returned below is the chunk's
  // own result (nil if it returns nothing).
  if (lua_pcall(l, 0, 1, 0) != LUA_OK)
  {
    errstr = lua_tostring(l, -1);
    trap();
    return lua_error(l);
  }
  if (!lua_istable(l, -1))
  {
    puts("not a table");
  }
  return 1;
}

void main(void)
{
  get_multiboot_info();
  
  // Use SQLite3 as the only memory allocator because musl's malloc requires mmap.
  sqlite3_config(SQLITE_CONFIG_HEAP, &heap_start, heap_end - (u8*)&heap_start, 64);
  
  L = lua_newstate(l_alloc, NULL);
  if (!L)
  {
    puts("lua_newstate: error");
    return;
  }
  luaL_openlibs(L);
  
  display_buffer_len = (modeinfo.YResolution * modeinfo.BytesPerScanLine);
  display_buffer = lua_newuserdata(L, display_buffer_len);
  clear_screen(L);
  
  lua_pushnumber(L, modeinfo.XResolution);
  lua_setglobal(L, "DISPLAY_WIDTH");
  lua_pushnumber(L, modeinfo.YResolution);
  lua_setglobal(L, "DISPLAY_HEIGHT");  
  lua_register(L, "clear_screen", clear_screen);
  lua_register(L, "putpixel", putpixel);
  lua_register(L, "swap_buffers", swap_buffers);
  lua_register(L, "outb", lua_outb);
  lua_register(L, "inb", lua_inb);
  lua_register(L, "setmaskhook", lua_setmaskhook);
  lua_register(L, "loader", lua_loader);
  lua_register(L, "get_timer_ticks", lua_get_timer_ticks);
  lua_register(L, "get_keyboard_interrupt", lua_get_keyboard_interrupt);
  lua_register(L, "hlt", lua_hlt);
  
  int luaopen_lsqlite3(lua_State *L);
  luaL_requiref(L, "lsqlite3", luaopen_lsqlite3, 0);
  lua_pop(L, 1);
  
  if (luaL_loadbuffer(L, luakernel_lua, luakernel_lua_len, "luakernel") != LUA_OK)
  {
    //~ puts("luaL_loadstring: error");
    errstr = lua_tostring(L, 1);
    trap();
    return;
  }
  int err = lua_pcall(L, 0, LUA_MULTRET, 0);
  if (err != LUA_OK)
  {
    //~ puts("lua_pcall: error");
    trap();
    return;
  }
  trap();
}
