
#include "NTStorPort.h"
#include "dev/pci/DKPCIBus.h"
#include "dev/pci_reg.h"
#include "kdk/dev.h"
#include "kdk/kmem.h"
#include "kdk/libkern.h"
#include "kdk/kern.h"
#include "kdk/object.h"
#include "kdk/queue.h"
#include "kdk/vm.h"
#include "ntcompat/ntcompat.h"
#include "ntcompat/storport.h"
#include "ntcompat/storportcompat.h"
#include "ntcompat/win_types.h"
#include "vm/vmp.h"


/*
 * Expands a struct pci_dev_info pointer into the (seg, bus, slot, fun)
 * leading arguments taken by the pci_read and pci_write accessors used below.
 */
#define INFO_ARGS(PINFO) (PINFO)->seg, (PINFO)->bus, (PINFO)->slot, (PINFO)->fun

/*
 * StorPort miniport drivers known to this port. probeWithPCIBus: walks this
 * list looking for a PCI ID match; entries are presumably added elsewhere.
 */
struct storport_driver_queue drivers = TAILQ_HEAD_INITIALIZER(drivers);

/*
 * Encodes a PCI location into the two PORT_CONFIGURATION_INFORMATION fields.
 *
 * SystemIoBusNumber holds the segment in bits 31..16, the bus in bits 15..8
 * and the slot in bits 7..0; SlotNumber holds the function in bits 31..24.
 * unpackPci() performs the inverse.
 */
static inline void
packPci(uint16_t seg, uint8_t bus, uint8_t slot, uint8_t fun,
    ULONG *SystemIoBusNumber, ULONG *SlotNumber)
{
	uint32_t busWord = 0;

	busWord |= (uint32_t)(seg & 0xFFFF) << 16;
	busWord |= (uint32_t)(bus & 0xFF) << 8;
	busWord |= (uint32_t)(slot & 0xFF);

	*SystemIoBusNumber = busWord;
	*SlotNumber = (uint32_t)(fun & 0xFF) << 24;
}

/*
 * Decodes a PCI location previously encoded by packPci().
 *
 * seg comes from bits 31..16, bus from bits 15..8 and slot from bits 7..0
 * of SystemIoBusNumber; fun comes from bits 31..24 of SlotNumber.
 */
void
unpackPci(uint32_t SystemIoBusNumber, uint32_t SlotNumber, uint16_t *seg,
    uint8_t *bus, uint8_t *slot, uint8_t *fun)
{
	uint32_t busWord = SystemIoBusNumber;

	*seg = (uint16_t)(busWord >> 16);
	*bus = (uint8_t)(busWord >> 8);
	*slot = (uint8_t)busWord;
	*fun = (uint8_t)(SlotNumber >> 24);
}

/*
 * Builds the ACCESS_RANGE array handed to a miniport's HwFindAdapter by
 * reading (and, for memory BARs, sizing) the six type-0 BARs of the device.
 *
 * Entry i describes BAR i:
 *  - I/O BARs: base only; RangeLength is left 0 (not probed).
 *  - 32-bit memory BARs: base and size.
 *  - 64-bit memory BARs: base and size in entry i; entry i + 1 (the BAR's
 *    upper dword, not a BAR in its own right) is left empty.
 * Unimplemented BARs (reading back all-zero size bits) get RangeLength 0.
 *
 * The caller is expected to have disabled I/O and memory decoding while the
 * BARs are written with all-ones for sizing (probeWithPCIBus: does so).
 *
 * Returns a kmem_alloc'd, zero-initialised array of 7 entries; only the first
 * 6 correspond to BARs (NOTE(review): the spare 7th entry's purpose is not
 * visible here). On architectures without PCI config access below, all
 * entries are left zeroed.
 */
static PACCESS_RANGE
getAccessRanges(struct pci_dev_info *info)
{
	PACCESS_RANGE accessranges = (PACCESS_RANGE)kmem_alloc(
	    sizeof(ACCESS_RANGE) * 7);

	/* Unused and skipped entries must read as empty, not heap garbage. */
	memset(accessranges, 0x0, sizeof(ACCESS_RANGE) * 7);

#if defined(__aarch64__) || defined(__amd64__) || defined(__riscv)
	for (size_t i = 0; i < 6; i++) {
		size_t off = kBaseAddress0 + sizeof(uint32_t) * i;
		uint64_t base;
		size_t len;
		uint32_t bar;

		bar = pci_readl(INFO_ARGS(info), off);

		if ((bar & 1) == 1) {
			/* I/O space BAR. */
			accessranges[i].RangeStart.QuadPart = bar & 0xFFFFFFFC;
			accessranges[i].RangeLength = 0;
			accessranges[i].RangeInMemory = false;
		} else if (((bar >> 1) & 3) == 0) {
			/* 32-bit memory BAR. */
			uint32_t size_mask;

			pci_writel(INFO_ARGS(info), off, 0xffffffff);
			size_mask = pci_readl(INFO_ARGS(info), off);
			pci_writel(INFO_ARGS(info), off, bar);

			size_mask &= 0xfffffff0;
			base = bar & 0xfffffff0;
			/*
			 * An unimplemented BAR reads back as zero; ctz of zero is
			 * undefined, so report an empty range instead.
			 */
			len = size_mask == 0 ?
			    0 :
			    (size_t)1 << __builtin_ctzl(size_mask);

			accessranges[i].RangeStart.QuadPart = base;
			accessranges[i].RangeLength = len;
			accessranges[i].RangeInMemory = true;
		} else {
			/* 64-bit memory BAR spanning this dword and the next. */
			uint64_t size_mask, bar_high, size_mask_high;

			kassert(((bar >> 1) & 3) == 2);

			bar_high = pci_readl(INFO_ARGS(info), off + 4);
			base = (bar & 0xfffffff0) | (bar_high << 32);

			pci_writel(INFO_ARGS(info), off, 0xffffffff);
			pci_writel(INFO_ARGS(info), off + 4, 0xffffffff);
			size_mask = pci_readl(INFO_ARGS(info), off);
			size_mask_high = pci_readl(INFO_ARGS(info), off + 4);
			pci_writel(INFO_ARGS(info), off, bar);
			pci_writel(INFO_ARGS(info), off + 4, bar_high);

			size_mask = (size_mask | size_mask_high << 32) &
			    0xfffffffffffffff0;
			len = size_mask == 0 ?
			    0 :
			    (size_t)1 << __builtin_ctzll(size_mask);

			accessranges[i].RangeStart.QuadPart = base;
			accessranges[i].RangeLength = len;
			accessranges[i].RangeInMemory = true;

			/*
			 * The next dword is this BAR's upper half, not a separate
			 * BAR: skip it so it is neither re-sized nor reported, and
			 * leave its (zeroed) entry empty.
			 */
			i++;
		}
	}
#endif

	return accessranges;
}

/*
 * Private methods of NTStorPort. Declared in a class extension (rather than
 * an unimplemented named category) so the @implementation below can call
 * them, e.g. -iterate from the initialiser, ahead of their definitions.
 */
@interface NTStorPort ()
- (bool)executeIOP:(iop_t *)iop;
- (void)iterate;
@end

@implementation NTStorPort

/*
 * DPC callback (installed as srb_deferred_completion_dpc.callback by the
 * initialiser). Drains the deferred-completion list that completeSrb: fills
 * when it is called above kIPLDPC, and completes each SRB in turn.
 *
 * The list lock is taken at kIPLHigh and released before each call to
 * completeSrb:, so completion never runs with the lock held. The DPC is
 * presumably run at kIPLDPC, letting completeSrb: take its direct path.
 */
void
srb_deferred_completion_callback(void *arg)
{
	NTStorPort *self = arg;

	while (true) {
		PSCSI_REQUEST_BLOCK srb;
		ipl_t ipl = ke_spinlock_acquire_at(
		    &self->srb_deferred_completion_lock, kIPLHigh);
		srb = self->srb_deferred_completion_queue;
		if (srb == NULL) {
			ke_spinlock_release(
			    &self->srb_deferred_completion_lock, ipl);
			break;
		}
		self->srb_deferred_completion_queue = srb->NextSrb;
		ke_spinlock_release(&self->srb_deferred_completion_lock, ipl);

		[self completeSrb:srb];
	}
}

/*
 * Returns true if the driver's PCI match table names the device. The table
 * is terminated by an entry whose vendor_id is 0.
 */
static bool
storport_driver_matches(struct storport_driver *driver,
    struct pci_dev_info *info)
{
	for (struct pci_match *match = driver->nt_driver_object->matches;
	     match->vendor_id != 0; match++) {
		if (info->vendorId == match->vendor_id &&
		    info->deviceId == match->device_id)
			return true;
	}

	return false;
}

/*
 * Class-level probe for a PCI device.
 *
 * Walks the registered miniport drivers for one whose match table names the
 * device, builds a PORT_CONFIGURATION_INFORMATION (access ranges from the
 * BARs, packed PCI location) and a zeroed device extension, and calls the
 * miniport's HwFindAdapter. On SP_RETURN_FOUND an NTStorPort instance is
 * created and YES is returned; on any other result the next matching driver
 * is tried. Returns NO if no driver claims the device.
 */
+ (BOOL)probeWithPCIBus:(DKPCIBus *)provider info:(struct pci_dev_info *)info
{
	struct storport_driver *driver = NULL;
	PACCESS_RANGE ranges = NULL;
	PPORT_CONFIGURATION_INFORMATION pcfg = NULL;

	TAILQ_FOREACH (driver, &drivers, queue_entry) {
		struct sp_dev_ext *deviceExtension;
		size_t deviceExtensionSize;
		PVOID HwDeviceExtension;
		BOOLEAN again;
		ULONG ret;

		if (!storport_driver_matches(driver, info))
			continue;

#if defined(__aarch64__) || defined(__amd64__) || defined(__riscv)
		/* Disable I/O and memory decoding while the BARs are sized. */
		pci_writew(INFO_ARGS(info), kCommand,
		    pci_readw(INFO_ARGS(info), kCommand) &
			~(0x1 | 0x2));
#endif

		/* pcfg and ranges are reused if an earlier driver declined. */
		if (pcfg == NULL)
			pcfg = kmem_alloc(
			    sizeof(PORT_CONFIGURATION_INFORMATION));
		memset(pcfg, 0x0, sizeof(PORT_CONFIGURATION_INFORMATION));

		if (ranges == NULL)
			ranges = getAccessRanges(info);

		pcfg->AccessRanges = (ACCESS_RANGE(*)[])ranges;
		pcfg->NumberOfAccessRanges = 6;
		pcfg->AdapterInterfaceType = kPCIBus;
		packPci(info->seg, info->bus, info->slot, info->fun,
		    &pcfg->SystemIoBusNumber, &pcfg->SlotNumber);

#if defined(__aarch64__) || defined(__amd64__) || defined(__riscv)
		/* Re-enable I/O and memory decoding. */
		pci_writew(INFO_ARGS(info), kCommand,
		    pci_readw(INFO_ARGS(info), kCommand) | (0x1 | 0x2));
#endif

		deviceExtensionSize = driver->hwinit.DeviceExtensionSize +
		    sizeof(struct sp_dev_ext);
		deviceExtension = kmem_alloc(deviceExtensionSize);
		/*
		 * Zero the whole extension: StorPort miniports expect a
		 * zero-filled HwDeviceExtension, and optional hooks such as
		 * passive_init are later tested against NULL.
		 */
		memset(deviceExtension, 0x0, deviceExtensionSize);
		deviceExtension->driver = driver;
		deviceExtension->portConfig = pcfg;
		ke_spinlock_init(&deviceExtension->intxLock);

		HwDeviceExtension = &deviceExtension->hw_dev_ext[0];

		DKDevLog(self,
		    "Matching driver %s found; calling HwFindAdapter\n",
		    driver->nt_driver_object->name);

		ret = driver->hwinit.HwFindAdapter(HwDeviceExtension, NULL,
		    NULL, NULL, pcfg, &again);
		if (ret != SP_RETURN_FOUND) {
			DKDevLog(self,
			    "HwFindAdapter failed (%lu); trying next driver\n",
			    (unsigned long)ret);
			kmem_free(deviceExtension, deviceExtensionSize);
			continue;
		}

		DKDevLog(self,
		    "HwFindAdapter succeeded; instantiating StorPort device\n");

		[[self alloc] initWithPCIBus:provider
					info:info
			      storportDriver:driver
			     deviceExtension:deviceExtension];

		return YES;
	}

	/*
	 * No driver claimed the device: release the scratch configuration.
	 * The size matches the 7-entry allocation made by getAccessRanges().
	 */
	if (ranges != NULL)
		kmem_free(ranges, sizeof(ACCESS_RANGE) * 7);
	if (pcfg != NULL)
		kmem_free(pcfg, sizeof(PORT_CONFIGURATION_INFORMATION));

	return NO;
}

/*
 * INTx handler registered by the initialiser; arg is the device's
 * struct sp_dev_ext. Forwards to the miniport's HwInterrupt and reports its
 * verdict, so that an interrupt the miniport does not recognise is not
 * claimed (presumably letting other handlers on a shared line see it).
 */
static bool
intx_handler(md_intr_frame_t *frame, void *arg)
{
	struct sp_dev_ext *devExt = arg;

	(void)frame;
	return devExt->driver->hwinit.HwInterrupt(devExt->hw_dev_ext) !=
	    FALSE;
}

/*
 * Initialiser called from probeWithPCIBus: once HwFindAdapter has claimed
 * the device.
 *
 * Records the PCI location and device extension, names the device
 * "<driver>-<counter>", sets up the deferred-completion DPC, runs the
 * miniport's HwInitialize, connects the INTx source to intx_handler, runs
 * the passive-initialisation routine if one was registered, probes the SCSI
 * bus and finally registers the device.
 */
- (instancetype)initWithPCIBus:(DKPCIBus *)provider
			  info:(struct pci_dev_info *)info
		storportDriver:(struct storport_driver *)driver
	       deviceExtension:(struct sp_dev_ext *)devExt
{
	BOOLEAN suc;
	int r;

	self = [super initWithProvider:provider];

	m_info = *info;
	m_deviceExtension = devExt;
	m_HwDeviceExtension = &devExt->hw_dev_ext[0];

	/*
	 * NOTE(review): driver->counter is not advanced here; unless it is
	 * bumped elsewhere, two adapters on one driver share a name.
	 */
	kmem_asprintf(obj_name_ptr(self), "%s-%lu",
	    driver->nt_driver_object->name, driver->counter);

	ke_spinlock_init(&srb_deferred_completion_lock);
	srb_deferred_completion_dpc.cpu = NULL;
	srb_deferred_completion_dpc.callback = srb_deferred_completion_callback;
	srb_deferred_completion_dpc.arg = self;
	srb_deferred_completion_queue = NULL;

	m_deviceExtension->device = self;

	/*
	 * NOTE(review): HwInitialize runs before the interrupt is connected
	 * below; Windows StorPort connects it first. Confirm miniports in use
	 * do not need interrupts during HwInitialize.
	 */
	DKDevLog(self, "Calling HwInitialize\n");
	suc = driver->hwinit.HwInitialize(m_HwDeviceExtension);

	/*
	 * Use our own copy of the PCI info, in case the interrupt controller
	 * keeps a reference to the source beyond the caller's struct.
	 */
	r = [[platformDevice platformInterruptController]
	    handleSource:&m_info.intx_source
	     withHandler:intx_handler
		argument:m_deviceExtension
	      atPriority:kIPLHigh
		   entry:&m_intxEntry];
	kassert(r == 0);

	if (devExt->passive_init != NULL) {
		DKDevLog(self, "Calling HwPassiveInitializeRoutine\n");
		devExt->passive_init(m_HwDeviceExtension);
	}

	DKDevLog(self, "Probing SCSI bus\n");
	[self iterate];

	if (suc)
		DKDevLog(self, "StorPort device successfully initialised\n");
	else
		DKDevLog(self, "HwInitialize reported failure\n");

	[self registerDevice];
	DKLogAttach(self);

	return self;
}

/* Allocates an (uninitialised) SCSI request block; free with kmem_free. */
static PSCSI_REQUEST_BLOCK
srb_alloc(void)
{
	return kmem_alloc(sizeof(SCSI_REQUEST_BLOCK));
}

/*
 * Resets srb to a fresh SRB_FUNCTION_EXECUTE_SCSI request addressed to
 * pathId/targetId/lun with the given SRB flags and data buffer. The CDB is
 * zeroed; the caller fills in the opcode and fields (cdbLength bytes).
 */
static void
srb_init_execute_scsi(PSCSI_REQUEST_BLOCK srb, int direction, uint8_t pathId,
    uint8_t targetId, uint8_t lun, size_t cdbLength, void *buffer,
    size_t bufferLength)
{
	memset(srb, 0x0, sizeof(SCSI_REQUEST_BLOCK));

	srb->PathId = pathId;
	srb->TargetId = targetId;
	srb->Lun = lun;
	srb->SrbFlags = direction;

	srb->Length = sizeof(SCSI_REQUEST_BLOCK);
	srb->Function = SRB_FUNCTION_EXECUTE_SCSI;
	srb->CdbLength = cdbLength;

	srb->DataBuffer = buffer;
	srb->DataTransferLength = bufferLength;
}

/*
 * Strips trailing spaces from a fixed-width, space-padded field of `length`
 * bytes by writing a NUL after the last non-space character.
 *
 * The NUL is only written inside the field: a field with no trailing
 * padding is left untouched, since writing str[length] would clobber the
 * first byte of whatever follows it. Callers must therefore print with a
 * precision (e.g. "%.8s").
 */
static void
trimstr(char *str, size_t length)
{
	size_t end = length;

	while (end > 0 && str[end - 1] == ' ')
		end--;

	if (end < length)
		str[end] = '\0';
}

/*
 * Big-endian accessors for byte-array fields in CDBs and SCSI data. They
 * work byte-wise, so they are safe on unaligned fields and host-endian
 * independent.
 */
static inline uint32_t
sp_load_be32(const uint8_t *p)
{
	return (uint32_t)p[0] << 24 | (uint32_t)p[1] << 16 |
	    (uint32_t)p[2] << 8 | (uint32_t)p[3];
}

static inline uint64_t
sp_load_be64(const uint8_t *p)
{
	return (uint64_t)sp_load_be32(p) << 32 | sp_load_be32(p + 4);
}

static inline void
sp_store_be32(uint8_t *p, uint32_t value)
{
	p[0] = (uint8_t)(value >> 24);
	p[1] = (uint8_t)(value >> 16);
	p[2] = (uint8_t)(value >> 8);
	p[3] = (uint8_t)value;
}

/*
 * Scans the SCSI bus. Sends REPORT LUNS to each probed target, then INQUIRY
 * to each reported LUN, and logs the vendor/product/revision strings.
 *
 * Only path 0 / target 0 are probed for now (the port configuration limits
 * are left commented out). All requests are issued synchronously through one
 * reused SRB and IOP.
 */
- (void)iterate
{
	const size_t bufferSize = 512;
	void *buffer = kmem_alloc(bufferSize);
	struct _LUN_LIST *list = buffer;
	PSCSI_REQUEST_BLOCK srb;
	iop_t *iop;
	iop_return_t ret;

	srb = srb_alloc();
	iop = iop_new(self);

	for (size_t path = 0;
	     path < 1 /* m_deviceExtension->portConfig->NumberOfBuses*/;
	     path++) {
		for (size_t target = 0; target < 1
		     /*m_deviceExtension->portConfig->MaximumNumberOfTargets*/;
		     target++) {
			struct _REPORT_LUNS *reportLunsCdb;
			size_t validLength, maxLuns;
			uint32_t lunCount;

			srb_init_execute_scsi(srb, SRB_FLAGS_DATA_IN, path,
			    target, 0, sizeof(struct _REPORT_LUNS), buffer,
			    bufferSize);
			reportLunsCdb = (void *)srb->Cdb;
			reportLunsCdb->OperationCode = SCSIOP_REPORT_LUNS;
			/* CDB fields are big-endian and may be unaligned. */
			sp_store_be32((uint8_t *)reportLunsCdb->AllocationLength,
			    (uint32_t)bufferSize);

			sp_store_be32((uint8_t *)list->LunListLength, 0);

			iop_init_scsi(iop, self, srb);
			ret = iop_send_sync(iop);
			kassert(ret == kIOPRetCompleted);

			/* DATA_OVERRUN with a usable header is a short transfer. */
			if (srb->SrbStatus != SRB_STATUS_SUCCESS &&
			    !(srb->SrbStatus == SRB_STATUS_DATA_OVERRUN &&
				srb->DataTransferLength >=
				    sizeof(struct _LUN_LIST))) {
				kprintf("Target %zu: SRB status is %x\n", target,
				    srb->SrbStatus);
				continue;
			}

			/*
			 * LunListLength describes the target's full list, which
			 * may be longer than what fit in our buffer (or, on a
			 * short transfer, than what arrived). Clamp the entry
			 * count so we never read past either limit.
			 */
			validLength = bufferSize;
			if (srb->SrbStatus == SRB_STATUS_DATA_OVERRUN &&
			    srb->DataTransferLength < validLength)
				validLength = srb->DataTransferLength;
			maxLuns = validLength >= sizeof(struct _LUN_LIST) ?
			    (validLength - sizeof(struct _LUN_LIST)) / 8 :
			    0;

			lunCount = sp_load_be32(
				       (const uint8_t *)list->LunListLength) /
			    8;
			if (lunCount > maxLuns)
				lunCount = (uint32_t)maxLuns;

			for (size_t lunIdx = 0; lunIdx < lunCount; lunIdx++) {
				uint64_t lun = sp_load_be64(
				    (const uint8_t *)list->Lun[lunIdx]);
				struct _CDB6INQUIRY *inquiryCdb;
				INQUIRYDATA *data = kmem_alloc(sizeof(*data));

				/*
				 * NOTE(review): srb->Lun is 8 bits wide, so only the
				 * low byte of this 8-byte LUN is used. That is right
				 * if the miniport encodes LUNs in the last byte; SAM
				 * single-level addressing puts the LUN in byte 1
				 * instead. Confirm against the miniports in use.
				 */
				srb_init_execute_scsi(srb, SRB_FLAGS_DATA_IN,
				    path, target, (uint8_t)lun,
				    sizeof(struct _CDB6INQUIRY), data,
				    sizeof(*data));

				inquiryCdb = (void *)srb->Cdb;
				inquiryCdb->OperationCode = SCSIOP_INQUIRY;
				/*
				 * The CDB's LUN bits are obsolete (and were being
				 * fed the list index, not the LUN); the LUN is
				 * addressed through srb->Lun only.
				 */
				inquiryCdb->LogicalUnitNumber = 0;
				inquiryCdb->Control = 0;
				inquiryCdb->AllocationLength = sizeof(*data);

				iop_init_scsi(iop, self, srb);
				ret = iop_send_sync(iop);
				kassert(ret == kIOPRetCompleted);

				if (srb->SrbStatus != SRB_STATUS_SUCCESS) {
					kprintf("Target %zu LUN %" PRIu64
						": INQUIRY SRB status is %x\n",
					    target, lun, srb->SrbStatus);
					kmem_free(data, sizeof(*data));
					continue;
				}

				trimstr(data->VendorId, sizeof(data->VendorId));
				trimstr(data->ProductId,
				    sizeof(data->ProductId));
				trimstr((char *)data->ProductRevisionLevel,
				    sizeof(data->ProductRevisionLevel));

				DKDevLog(self,
				    "<%.8s %.16s %.4s> at bus %zu target %zu lun %"PRIu64"\n",
				    data->VendorId, data->ProductId,
				    data->ProductRevisionLevel, path, target,
				    lun);

				kmem_free(data, sizeof(*data));
			}
		}
	}

	kmem_free(srb, sizeof(SCSI_REQUEST_BLOCK));
	kmem_free(buffer, bufferSize);
	/*
	 * NOTE(review): the IOP from iop_new() is not released; add the
	 * matching free once that API is confirmed.
	 */
}

/*
 * IOP entry point: hands the SCSI request to the miniport and reports it as
 * pending; completion arrives later through completeSrb:.
 */
- (iop_return_t)dispatchIOP:(iop_t *)iop
{
	[self executeIOP:iop];
	return kIOPRetPending;
}

/*
 * Starts the SRB carried by the IOP's current SCSI frame: allocates its
 * SrbExtension, records the originating IOP and passes it through the
 * miniport's HwBuildIo (if any) and HwStartIo.
 *
 * Per the StorPort contract, HwBuildIo returning FALSE means the miniport
 * already completed the request itself, so HwStartIo must not be called.
 * Always returns true.
 */
- (bool)executeIOP:(iop_t *)iop
{
	iop_frame_t *frame = iop_stack_current(iop);
	PSCSI_REQUEST_BLOCK srb = frame->scsi.srb;
	struct storport_driver *drv = m_deviceExtension->driver;
	BOOLEAN r;

	kassert(frame->function == kIOPTypeSCSI);

	srb->SrbStatus = SRB_STATUS_PENDING;
	srb->SrbExtension = kmem_alloc(drv->hwinit.SrbExtensionSize);
	srb->OriginalRequest = iop;

	/* HwBuildIo is optional for StorPort miniports. */
	if (drv->hwinit.HwBuildIo != NULL &&
	    !drv->hwinit.HwBuildIo(m_HwDeviceExtension, srb))
		return true;

	r = drv->hwinit.HwStartIo(m_HwDeviceExtension, srb);
	kassert(r == TRUE);
	return true;
}

/*
 * Completes an SRB. Above kIPLDPC the SRB is pushed onto the deferred list
 * and the completion DPC enqueued; otherwise its SrbExtension is freed and
 * the originating IOP continued with kIOPRetCompleted.
 *
 * The deferred list is LIFO, so deferred completions may run out of order.
 */
- (void)completeSrb:(PSCSI_REQUEST_BLOCK)Srb
{
	struct storport_driver *drv = m_deviceExtension->driver;
	ipl_t ipl = splget();
	if (ipl > kIPLDPC) {
		ke_spinlock_acquire_nospl(&srb_deferred_completion_lock);
		Srb->NextSrb = srb_deferred_completion_queue;
		srb_deferred_completion_queue = Srb;
		ke_spinlock_release_nospl(&srb_deferred_completion_lock);
		ke_dpc_enqueue(&srb_deferred_completion_dpc);
	} else {
		kmem_free(Srb->SrbExtension, drv->hwinit.SrbExtensionSize);
		iop_continue(Srb->OriginalRequest, kIOPRetCompleted);
	}
}

@end
