--Binding to the DynASM encoding engine.
--Written by Cosmin Apreutesei. Public Domain.
local ffi = require("ffi")
local bit = require("bit")
local arch = ffi.arch

--x64 uses the same encoder library as x86.
if arch == "x64" then
	arch = "x86"
end
--load the dasm_<arch> encoder library; M.C exposes it to users of this module.
local C = ffi.load("dasm_" .. arch)
--NOTE(review): presumably mirrors DynASM's DASM_VERSION (10400 = 1.4.0) -- confirm.
local M = {C = C, _VERSION = 10400}
--C declarations for the DynASM encoder API (presumably mirroring dasm_proto.h),
--plus the DASM_EXTERN_FUNC hook variable that the dasm_<arch> library is
--expected to export (it is assigned by M.setupextern()).
--the DASM_S_* values are the status codes checked after dasm_link(),
--dasm_encode() and dasm_checkstep(); status_map maps them to messages.
ffi.cdef[[
enum {
	DASM_S_OK         = 0x00000000,
	DASM_S_NOMEM      = 0x01000000,
	DASM_S_PHASE      = 0x02000000,
	DASM_S_MATCH_SEC  = 0x03000000,
	DASM_S_RANGE_I    = 0x11000000,
	DASM_S_RANGE_SEC  = 0x12000000,
	DASM_S_RANGE_LG   = 0x13000000,
	DASM_S_RANGE_PC   = 0x14000000,
	DASM_S_RANGE_VREG	= 0x15000000,
	DASM_S_UNDEF_L    = 0x21000000,
	DASM_S_UNDEF_PC   = 0x22000000,
};

/* Internal DynASM encoder state. */
typedef struct dasm_State_Ref { struct dasm_State *p; } dasm_State_Ref;
typedef dasm_State_Ref *Dst_DECL;

/* Initialize and free DynASM state. */
void dasm_init(Dst_DECL, int maxsection);
void dasm_free(Dst_DECL);

/* Setup global array. Must be called before dasm_setup(). */
void dasm_setupglobal(Dst_DECL, void **gl, unsigned int maxgl);

/* Grow PC label array. Can be called after dasm_setup(), too. */
void dasm_growpc(Dst_DECL, unsigned int maxpc);

/* Setup encoder. */
void dasm_setup(Dst_DECL, const void *actionlist);

/* Feed encoder with actions. Calls are generated by pre-processor. */
void dasm_put(Dst_DECL, int start, ...);

/* Link sections and return the resulting size. */
int dasm_link(Dst_DECL, size_t *szp);

/* Encode sections into buffer. */
int dasm_encode(Dst_DECL, void *buffer);

/* Get PC label offset. */
int dasm_getpclabel(Dst_DECL, unsigned int pc);

/* Optional sanity checker to call between isolated encoding steps. */
int dasm_checkstep(Dst_DECL, int secmatch);

typedef int (*DASM_EXTERN_TYPE) (void *ctx, unsigned char *addr, int idx, int type);
DASM_EXTERN_TYPE DASM_EXTERN_FUNC;
]]

--fatal error reporter: writes "dasm error: " plus all given message parts to
--an unbuffered stderr, ends the line and exits the process with status 1.
--never returns to the caller.
local function err(...)
	local stderr = io.stderr
	stderr:setvbuf("no")
	stderr:write("dasm error: ", ...)
	stderr:write("\n")
	os.exit(1)
end

--status check helper: human-readable messages keyed by the DASM_S_* error
--class (the high byte of a status code).
local status_map = {
	[C.DASM_S_NOMEM] = "out of memory",
	[C.DASM_S_PHASE] = "phase error",
	[C.DASM_S_MATCH_SEC] = "section not found",
	[C.DASM_S_RANGE_I] = "immediate value out of range",
	[C.DASM_S_RANGE_SEC] = "too many sections",
	[C.DASM_S_RANGE_LG] = "too many global labels",
	[C.DASM_S_RANGE_PC] = "too many pclabels",
	[C.DASM_S_RANGE_VREG] = "variable register out of range",
	[C.DASM_S_UNDEF_L] = "undefined global label",
	[C.DASM_S_UNDEF_PC] = "undefined pclabel",
}

--abort via err() unless `st` is DASM_S_OK.
--the high byte selects the message from status_map; the low 24 bits are
--reported as the status argument. unknown codes are printed in hex.
local function checkst(st)
	if st == C.DASM_S_OK then return end

	local status = status_map[bit.band(st, 0xff000000)]
	local arg = bit.band(st, 0x00ffffff)

	if status then
		err(status, ": ", arg) --was ". :", a malformed separator
	else
		err(string.format("0x%08X", st))
	end
end

--low level API: M.<name> is a direct alias of the C function dasm_<name>.
for _, name in ipairs{"init", "free", "setupglobal", "growpc", "setup"} do
	M[name] = C["dasm_" .. name]
end
local int_ct = ffi.typeof("int")

--dasm_put() only accepts int32 varargs. Lua numbers are first normalized with
--bit.tobit(), so uint32 values such as 0xffffffff map to their int32 bit
--pattern rather than relying on an out-of-range double->int conversion.
--non-number values are converted to int by the regular ffi rules.
local function convert_arg(arg)
	local value = arg
	if type(value) == "number" then
		value = bit.tobit(value)
	end
	return ffi.cast(int_ct, value)
end

--return every vararg passed through convert_arg(), as multiple values.
--recursion avoids packing the arguments into a temporary Lua table; it is
--not a tail call, so the stack depth grows with the number of arguments.
local function convert_args(...)
	local count = select("#", ...)
	if count == 0 then return end
	local first = ...
	return convert_arg(first), convert_args(select(2, ...))
end

--feed the encoder the actions starting at offset `start` of the action list.
--the extra arguments are passed to dasm_put() as C ints (see convert_args()).
function M.put(state, start, ...)
	local dasm_put = C.dasm_put
	dasm_put(state, start, convert_args(...))
end

--link all sections and return the resulting code size as a Lua number.
--`sz` is an optional caller-supplied size_t[1] out-buffer; one is allocated
--when it is omitted. aborts via checkst() on a non-OK status.
function M.link(state, sz)
	local szp = sz
	if not szp then
		szp = ffi.new("size_t[1]")
	end
	local st = C.dasm_link(state, szp)
	checkst(st)
	return tonumber(szp[0])
end

--encode all sections into `buf`. buf is presumably at least the size returned
--by M.link() (M.build() allocates it that way). aborts via checkst() on error.
function M.encode(state, buf)
	local st = C.dasm_encode(state, buf)
	checkst(st)
end

--keep M.encode interpreted: dasm_encode() may call back into Lua through
--the DASM_EXTERN_FUNC callback.
jit.off(M.encode)
local voidptr_ct = ffi.typeof("void*")
local byteptr_ct = ffi.typeof("int8_t*")

--get the offset of pclabel `pc` in the encoded code. If `buf` (the encoded
--code buffer) is given, return the label's address inside buf instead.
--DynASM's dasm_getpclabel() reports an undefined (-1) or unused/out-of-range
--(-2) label with a negative value. In that case no address is computed:
--nil, offset is returned rather than a bogus pointer before the buffer.
function M.getpclabel(state, pc, buf)
	local offset = C.dasm_getpclabel(state, pc)

	if not buf then
		return offset
	end

	if offset < 0 then
		return nil, offset
	end

	return ffi.cast(voidptr_ct, ffi.cast(byteptr_ct, buf) + offset)
end

--run DynASM's optional sanity check and abort via checkst() on error.
--`section` is passed as dasm_checkstep()'s secmatch argument (default -1).
function M.checkstep(state, section)
	local secmatch = section or -1
	local st = C.dasm_checkstep(state, secmatch)
	checkst(st)
end

--get the address of a standard symbol.
--TODO: ask Mike to expose clib_getsym() in ffi so we can get the address
--of symbols without having to declare them first.
local function clib_lookup(name)
	return ffi.C[name]
end

--resolve `name` in the default C namespace (ffi.C). If the lookup raises,
--e.g. because the symbol was never declared to the ffi, declare it as
--`void name();` and look it up once more; a second failure raises normally.
local function getsym(name)
	local ok, sym = pcall(clib_lookup, name)
	if ok then
		return sym
	end
	ffi.cdef(string.format("void %s();", name))
	return clib_lookup(name)
end
--DASM_EXTERN callback plumbing
local extern_names --t[idx] -> name
local extern_get --f(name) -> ptr
local byteptr_ct = ffi.typeof("uint8_t*")

--callback invoked by dasm_encode() to resolve extern number `idx`.
--  ctx:     context pointer from the C side (unused here)
--  addr:    address being patched in the output buffer
--  idx:     index into extern_names
--  reltype: 0 -> return the absolute address,
--           non-zero -> return the target relative to the end of the 4-byte field
--aborts via err() if the plumbing is not set up or the extern can't be resolved.
local function DASM_EXTERN_FUNC(ctx, addr, idx, reltype)
	if not extern_names or not extern_get then
		err("extern callback not initialized.")
	end

	local name = extern_names[idx]
	if name == nil then err("extern name not found for index: ", idx, ".") end

	local ptr = extern_get(name)
	if ptr == nil then err("extern not found: ", name, ".") end

	local target = ffi.cast(byteptr_ct, ptr)

	if reltype ~= 0 then
		return target - addr - 4
	end

	--the callback returns a C int, and LuaJIT does not implicitly convert a
	--pointer to an integer, so cast explicitly. This keeps the low 32 bits and
	--is only meaningful when the address fits in 32 bits.
	return ffi.cast("int32_t", ffi.cast("intptr_t", target))
end

--set the extern name table and resolver used by DASM_EXTERN_FUNC.
--`getter` defaults to getsym. This setting is module-wide: the most recent
--call wins for every state. The C hook is assigned only while it is still
--NULL, so the Lua callback is installed once.
function M.setupextern(_, names, getter)
	extern_names, extern_get = names, getter or getsym

	if C.DASM_EXTERN_FUNC ~= nil then
		return
	end
	C.DASM_EXTERN_FUNC = DASM_EXTERN_FUNC
end

--hi-level API

--create and fully set up a DynASM state.
--  actionlist:   action list (must stay alive; anchored by the finalizer)
--  externnames:  t[idx] -> name table for externs (may be nil)
--  sectioncount: number of sections (default 1)
--  globalcount:  size of the globals array (default 256)
--  externget:    f(name) -> ptr extern resolver (default getsym)
--  globals:      optional caller-supplied void*[globalcount] array
--returns state, globals.
function M.new(actionlist, externnames, sectioncount, globalcount, externget, globals)
	local state = ffi.new("dasm_State_Ref")
	M.init(state, sectioncount or 1)

	local nglobals = globalcount or 256
	local globals_arr = globals or ffi.new("void*[?]", nglobals)

	--free the native state on collection. The closure also references the
	--inputs, so they stay reachable for the state's lifetime.
	ffi.gc(state, function(s)
		local _ = actionlist, externnames, globals_arr, externget
		M.free(s)
	end)

	M.setupglobal(state, globals_arr, nglobals)
	M.setupextern(state, externnames, externget)
	M.setup(state, actionlist)

	return state, globals_arr
end

--check, link and encode `state` into a buffer from dasm_mm, then call
--mm.protect() on it (presumably making it executable). Returns buf, size.
--aborts via err() if linking produced no code.
function M.build(state)
	state:checkstep(-1)

	local size = state:link()
	if size == 0 then
		err("no code?")
	end

	local mm = require("dasm_mm") --runtime dependency
	local code = mm.new(size)
	state:encode(code)
	mm.protect(code, size)

	return code, size
end

--disassemble `size` bytes at `addr` with LuaJIT's jit.dis_<arch> module.
--Listed addresses are the real ones; `out` is forwarded to disass().
function M.dump(addr, size, out)
	local disass = require("jit.dis_" .. jit.arch).disass
	local code = ffi.string(addr, size)
	local base = tonumber(ffi.cast("uintptr_t", addr))
	disass(code, base, out)
end

--given the globals array from dasm.new() and the globalnames list
--from the `.globalnames` directive, return a map {global_name -> global_addr}.
--indices 0..#globalnames are scanned (the names list is 0-based). Entries
--with no name or a NULL address are skipped, so a nil name can no longer
--raise "table index is nil".
function M.globals(globals, globalnames)
	local t = {}

	for i = 0, #globalnames do
		local name = globalnames[i]
		local addr = globals[i]
		if name ~= nil and addr ~= nil then
			t[name] = addr
		end
	end

	return t
end

--object interface: dasm_State_Ref cdata get the module's low-level API plus
--build() as methods, so state:put(...) is the same as M.put(state, ...).
do
	local methods = {}
	for _, name in ipairs{
		--low-level API
		"init", "free", "setupglobal", "setupextern", "growpc", "setup",
		"put", "link", "encode", "getpclabel", "checkstep",
		--hi-level API
		"build",
	} do
		methods[name] = M[name]
	end
	ffi.metatype("dasm_State_Ref", {__index = methods})
end

--demo: runs only when this file is executed directly. When loaded via
--require(), `...` holds the module name and this branch is skipped.
if not ... then --demo
	local dasm = M
	--hand-written action list, presumably what the DynASM preprocessor would emit
	--for the `--|` lines below. 102,184,5,0 = 66 B8 05 00 (mov ax, 5);
	--102,187,3,0 = 66 BB 03 00 (mov bx, 3). The 254 n / 255 bytes are
	--presumably section-switch / end actions. The put() offsets must match them.
	local actions = ffi.new(
		"const uint8_t[19]",
		{254, 0, 102, 184, 5, 0, 254, 1, 102, 187, 3, 0, 254, 2, 102, 187, 3, 0, 255}
	)
	--3 sections (.code, .sub1, .sub2), no externs.
	local Dst, globals = dasm.new(actions, nil, 3)
	--each dasm.put(Dst, n) feeds the encoder the actions starting at offset n.
	--|.code
	dasm.put(Dst, 0)
	--| mov ax, 5
	--|.sub1
	dasm.put(Dst, 2)
	--| mov bx, 3
	--|.sub2
	dasm.put(Dst, 8)
	--| mov bx, 3
	dasm.put(Dst, 14)
	--link + encode into a fresh buffer, then print its disassembly.
	local addr, size = Dst:build()
	dasm.dump(addr, size)
end

return M