#include "CodeGenerator.h"
#include "CodeGenerator+CppHelpers.h"
#include "InterfaceDescription.h"
#include "version.h"

#include <algorithm>
#include <cctype>
#include <exception>
#include <filesystem>
#include <fstream>
#include <iomanip>
#include <iostream>

/**
 * Returns the name of the stub class for the server, given the interface name.
 *
 * The first character of the interface name is upper-cased and "Server" is appended, so an
 * interface named "foo" yields "FooServer".
 *
 * @param id Interface description whose name forms the class name prefix
 * @return Name of the generated server stub class
 */
static inline std::string GetClassName(const std::shared_ptr<InterfaceDescription> &id) {
    auto temp = id->getName();

    // std::toupper() has undefined behavior for negative values other than EOF, so the (possibly
    // signed) char has to go through unsigned char first. An empty name is left untouched.
    if(!temp.empty()) {
        temp[0] = static_cast<char>(std::toupper(static_cast<unsigned char>(temp[0])));
    }

    return temp + "Server";
}



/**
 * Generate the C++ server stub for the interface.
 *
 * This implements the concrete wire message (de)serialization for incoming requests, as well as
 * the replies thereto. The implementer then simply subclasses this base stub, and implements the
 * relevant abstract methods that actually implement the behavior of the interface.
 *
 * Two files are written to the output directory, truncating any existing contents:
 * `Server_<name>.hpp` with the class declaration, and `Server_<name>.cpp` with its
 * implementation.
 */
void CodeGenerator::generateServerStub() {
    auto fileNameH = this->outDir / ("Server_" + this->interface->getName() + ".hpp");
    auto fileNameCpp = this->outDir / ("Server_" + this->interface->getName() + ".cpp");
    std::cout << "    * Server stub: " << fileNameCpp.string() << ", " << fileNameH.string()
              << std::endl;

    // generate the header first
    std::string includeGuardName("RPC_SERVER_GENERATED_");
    includeGuardName.append(std::to_string(this->interface->getIdentifier()));

    std::ofstream header(fileNameH.string(), std::ofstream::trunc);
    this->serverWriteInfoBlock(header);
    header << R"(#ifndef )" << includeGuardName << R"(
#define )" << includeGuardName << R"(

#include <string>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <span>
#include <string_view>
#include <vector>
)" << std::endl;
    this->cppWriteIncludes(header);
    this->serverWriteHeader(header);
    header << "#endif // defined(" << includeGuardName << ")" << std::endl;

    // and then the implementation. The generated code uses fprintf()/stderr (in _HandleError()),
    // free()/posix_memalign()/exit() and memset(), so <cstdio>, <cstdlib> and <cstring> are all
    // included explicitly instead of relying on transitive includes.
    std::ofstream implementation(fileNameCpp.string(), std::ofstream::trunc);
    this->serverWriteInfoBlock(implementation);
    implementation << "#include \"" << fileNameH.filename().generic_string() << "\"" << std::endl
       << "#include \"" << this->serializationFile.filename().generic_string() << '"' << std::endl
       << R"(
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <stdexcept>

#include <rpc/rt/RpcIoStream.hpp>

using namespace rpc;

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunused-variable"
)";
    this->serverWriteImpl(implementation);
    implementation << "#pragma clang diagnostic pop" << std::endl;
}

/**
 * Emits the banner comment that opens each generated server stub file.
 *
 * The banner records the idlc version (the first eight characters of the version hash), the IDL
 * source file, the interface name and the generation timestamp. It ends with a short note on how
 * the stub is meant to be used.
 *
 * @param os Output stream the banner is written to
 */
void CodeGenerator::serverWriteInfoBlock(std::ofstream &os) {
    const std::string shortHash = std::string(gVERSION_HASH).substr(0, 8);

    os << "/*\n";
    os << " * This RPC server stub was autogenerated by idlc (version " << shortHash
       << "). DO NOT EDIT!" << std::endl;
    os << " * Generated from " << this->interface->getSourceFilename() << " for interface "
       << this->interface->getName() << " at " << this->creationTimestamp << std::endl;

    // fixed usage note, one comment line per literal
    os << " *\n"
          " * You should subclass this implementation and define the required abstract methods to complete\n"
          " * implementing the interface. Note that there are several helper methods available to simplify\n"
          " * this task, or to retrieve more information about the caller.\n"
          " *\n"
          " * See the full RPC documentation for more details.\n";

    os << " */" << std::endl;
}



/**
 * Writes the header for the server stub.
 *
 * This consists of a large fixed portion to define types and internal functions and helpers, as
 * well as two variable portions: one, declaring the abstract methods that the implementation shall
 * provide, and two, declaring the per-method marshalling helpers that deserialize method
 * parameters and serialize the response.
 *
 * @param os Stream receiving the class declaration (include guard is written by the caller)
 */
void CodeGenerator::serverWriteHeader(std::ofstream &os) {
    // set up the starting of the class
    const auto className = GetClassName(this->interface);

    os << "namespace rpc {" << std::endl << "namespace rt { class ServerRpcIoStream; }"
       << std::endl;
    os << "class " << className << " {";

    /*
     * Define the types, including the MessageHeader, which is the structure we expect to receive
     * from whatever input/output stream we're provided. Also define the return type structs for
     * any methods with multiple returns.
     */
    os << R"(
    struct MessageHeader {
        enum Flags: uint32_t {
            Request                     = (1 << 0),
            Response                    = (1 << 1),
        };

        uint64_t type;
        Flags flags;
        uint32_t tag;

        std::byte payload[];
    };
    static_assert(!(offsetof(MessageHeader, payload) % sizeof(uintptr_t)),
        "message header's payload is not word aligned");

    constexpr static const std::string_view kServiceName{")" << this->interface->getName() << R"("};

    protected:
        using IoStream = rt::ServerRpcIoStream;
)";

    for(const auto &m : this->interface->getMethods()) {
        if(!m.hasMultipleReturns()) continue;
        this->cppWriteReturnStruct(os, m);
    }

    // public methods
    //
    // The generated class owns a raw transmit buffer (`txBuf`) that its destructor frees. The
    // implicitly generated copy operations would make two instances share it and free it twice,
    // so they are deleted; this also suppresses the implicit move operations. The constructor is
    // explicit so that an IO stream pointer never converts implicitly into a server.
    os << R"(
    public:
        explicit )" << className << R"((const std::shared_ptr<IoStream> &stream);
        virtual ~)" << className << R"(();

        // Instances own a raw transmit buffer, so they can be neither copied nor moved.
        )" << className << R"((const )" << className << R"( &) = delete;
        )" << className << R"( &operator=(const )" << className << R"( &) = delete;

        // Server's main loop; continuously read and handle messages.
        bool run(const bool block = true);
        // Process a single message.
        bool runOne(const bool block);
)";

    // abstract methods to implement
    os << R"(
    // These are methods the implementation provides to complete implementation of the interface
    protected:
)";
    for(const auto &m : this->interface->getMethods()) {
        os << "        virtual ";
        this->cppWriteMethodDef(os, m, "impl");
        os << " = 0;" << std::endl;
    }

    // implementation details
    os << R"(
    // Helpers provided to subclasses for implementation of interface methods
    protected:
        constexpr inline auto &getIo() {
            return this->io;
        }

        /// Handles errors occurring during server operations
        virtual void _HandleError(const bool fatal, const std::string_view &what);

    // Implementation details; pretend this does not exist
    private:
        std::shared_ptr<IoStream> io;
        size_t txBufSize{0};
        void *txBuf{nullptr};

        void _ensureTxBuf(const size_t);
        void _sendReply(const MessageHeader &, const size_t);

)";

    // autogenerated marshalling methods
    for(const auto &m : this->interface->getMethods()) {
        os << "        void _marshall" << GetMethodName(m)
           << "(const MessageHeader &, const std::span<std::byte> &payload)"
           << ";" << std::endl;
    }

    // close the class and namespace
    os << "}; // class " << className << std::endl
       << "} // namespace rpc" << std::endl;
}



/**
 * Writes the server stub implementation.
 *
 * Emitted in order: a `Server` alias for the stub class, the custom type helpers, the
 * constructor and destructor, the `run()`/`runOne()` message loop (which dispatches on the
 * message type to the per-method marshalling functions), the built-in transmit and error
 * helpers, and finally one marshalling function per interface method.
 *
 * @param os Stream receiving the implementation file contents
 */
void CodeGenerator::serverWriteImpl(std::ofstream &os) {
    const auto className = GetClassName(this->interface);
    os << "using Server = " << className << ';' << std::endl << std::endl;

    this->cppWriteCustomTypeHelpers(os);

    // constructors and run method
    os << R"(/**
 * Creates a new server instance, with the given IO stream.
 */
Server::)" << className << R"((const std::shared_ptr<IoStream> &stream) : io(stream) {
}

/**
 * Releases any allocated resources.
 */
Server::~)" << className << R"(() {
    free(this->txBuf);
}

/**
 * Continuously processes messages until processing fails to receive another message.
 */
bool Server::run(const bool block) {
    bool cont;
    do {
        cont = this->runOne(block);
    } while(cont);
    return cont;
}

/**
 * Reads a single message from the RPC connection and attempts to process it.
 *
 * @return Whether a message was able to be received and processed.
 */
bool Server::runOne(const bool block) {
    // try to receive message
    std::span<std::byte> buf;
    if(!this->io->receive(buf, block)) return false;

    // get the message header and its payload
    if(buf.size() < sizeof(MessageHeader)) {
        this->_HandleError(false, "Received message too small");
        return false;
    }
    const auto hdr = reinterpret_cast<const MessageHeader *>(buf.data());

    const auto payload = buf.subspan(offsetof(MessageHeader, payload));

    // then invoke the appropriate marshalling function
    switch(hdr->type) {
)";

    for(const auto &m : this->interface->getMethods()) {
        os << "        case static_cast<uint64_t>(" << SerGetMessageIdEnumName(m) << "):" << std::endl
           << "            this->_marshall" << GetMethodName(m) << "(*hdr, payload);"
           << std::endl
           << "            break;" << std::endl;
    }

    // Messages with a type no method handles used to fall through the switch without any trace;
    // report them as a recoverable error so the problem is visible.
    os << R"(        default:
            this->_HandleError(false, "Received message with unknown type");
            break;
    }
    return true;
}

)";

    // built in helpers
    //
    // _ensureTxBuf(): the old buffer is freed first, so txBuf/txBufSize must be reset before
    // reallocating. Otherwise a failed posix_memalign() leaves a dangling pointer behind. That
    // would be freed again by the destructor, or reused by a later call, whenever _HandleError()
    // is overridden to not terminate.
    //
    // _HandleError(): std::string_view is not guaranteed to be NUL terminated, so it is printed
    // with an explicit length ("%.*s") instead of plain "%s".
    os << R"(
// Helper method to build and send a reply message
void Server::_sendReply(const MessageHeader &inHdr, const size_t payloadBytes) {
    const size_t len = sizeof(MessageHeader) + payloadBytes;

    auto hdr = reinterpret_cast<MessageHeader *>(this->txBuf);
    memset(hdr, 0, sizeof(*hdr));
    hdr->type = inHdr.type;
    hdr->flags = MessageHeader::Flags::Response;
    hdr->tag = inHdr.tag;

    const std::span<std::byte> txBufSpan(reinterpret_cast<std::byte *>(this->txBuf), len);
    if(!this->io->reply(txBufSpan)) {
        this->_HandleError(false, "Failed to send RPC reply");
    }
}

// Allocates an aligned transmit buffer of the given size
void Server::_ensureTxBuf(const size_t payloadBytes) {
    const size_t len = sizeof(MessageHeader) + payloadBytes + 16;
    if(len > this->txBufSize) {
        // forget the old buffer first, so a failed allocation never leaves a dangling pointer
        free(this->txBuf);
        this->txBuf = nullptr;
        this->txBufSize = 0;

        int err = posix_memalign(&this->txBuf, 16, len);
        if(err) {
            this->txBuf = nullptr;
            return this->_HandleError(true, "Failed to allocate RPC send buffer");
        }
        this->txBufSize = len;
    }
}

/**
 * Handles an error that occurred on the server connection. Implementations may override this
 * method if they want to use exceptions, for example.
 *
 * @param fatal If set, the error precludes further operation on this RPC connection
 * @param what Descriptive string for the error
 */
void Server::_HandleError(const bool fatal, const std::string_view &what) {
    fprintf(stderr, "[RPCS] %.*s: Encountered %s RPC error: %.*s\n",
        static_cast<int>(kServiceName.size()), kServiceName.data(),
        fatal ? "fatal" : "recoverable", static_cast<int>(what.size()), what.data());
    if(fatal) exit(-1);
}
)";

    // implementations of marshalling methods
    for(const auto &m : this->interface->getMethods()) {
        this->serverWriteMarshallMethod(os, m);
    }
}

/**
 * Writes the implementation of the method to marshall the specified method call.
 *
 * The generated method deserializes the request payload into the method's request struct. It then
 * invokes the corresponding abstract `impl` method, passing the request fields as arguments in
 * order. For synchronous methods it also writes the reply marshalling code.
 *
 * @param os Stream receiving the generated method
 * @param m Method to generate the marshalling code for
 */
void CodeGenerator::serverWriteMarshallMethod(std::ofstream &os, const Method &m) {
    // start method. The identifier is printed in hex. Stream format flags are sticky, so the
    // stream is switched back to decimal right after it. Otherwise the parameter/return counts
    // below, and every integer written to this file afterwards, would also come out in hex.
    os << R"(/*
 * Autogenerated marshalling method for ')" << m.getName() << R"(' (id $)" << std::hex
       << m.getIdentifier() << std::dec << R"()
 * Have )" << m.getParameters().size() << R"( parameter(s), )" << m.getReturns().size()
       << R"( return(s); method is )" << (m.isAsync() ? "async" : "sync") << R"(
 */
)"
       << "void Server::_marshall" << GetMethodName(m)
       << "(const MessageHeader &hdr, const std::span<std::byte> &payload) {";

    // deserialize the request
    os << R"(
    internals::)" << SerGetMessageStructName(m, false) << R"( request;
    if(!deserialize(payload, request)) return this->_HandleError(false, "Failed to deserialize request");
)";

    // invoke implementation method; the result is only captured for sync methods with returns
    os << std::endl;
    if(!m.isAsync() && !m.getReturns().empty()) {
        os << "    auto retVal = ";
    } else {
        os << "    ";
    }
    os << "this->impl" << GetMethodName(m) << '(';

    // pass each deserialized request field as an argument, comma separated
    const auto &params = m.getParameters();
    for(size_t i = 0; i < params.size(); i++) {
        if(i != 0) {
            os << ", ";
        }
        os << "request." << params[i].getName();
    }

    os << ");" << std::endl;

    // handle the reply
    if(!m.isAsync()) {
        this->serverWriteMarshallMethodReply(os, m);
    }

    // finish method; the blank line separates it from whatever is generated next
    os << '}' << std::endl << std::endl;
}
/**
 * Writes marshalling code for sending the reply of the invoked method.
 *
 * The generated code fills the method's reply struct from `retVal`:
 * - with a single return value, `retVal` itself is assigned;
 * - with multiple return values, each reply field is taken from the field of the same name in
 *   `retVal`. This assumes the struct emitted by cppWriteReturnStruct() names its fields after
 *   the returns.
 *
 * The struct is then serialized into the transmit buffer and sent as the response to `hdr`.
 *
 * @param os Stream receiving the generated code (emitted inside the marshalling method body)
 * @param m Method whose reply is marshalled; only invoked for synchronous methods
 */
void CodeGenerator::serverWriteMarshallMethodReply(std::ofstream &os, const Method &m) {
    // set up builder
    os << R"(
    internals::)" << SerGetMessageStructName(m, true) << R"( reply;
)";

    // set the values
    for(const auto &a : m.getReturns()) {
        // get the name of the variable to read the value out of (for multiple return values)
        std::string varName{"retVal"};
        if(m.hasMultipleReturns()) {
            varName.append(".");
            varName.append(a.getName());
        }

        os << "    reply." << a.getName() << " = " << varName << ";\n";
    }

    // serialize and send the message. _ensureTxBuf() reports allocation failures through
    // _HandleError(), which subclasses may override to not terminate. In that case there may be
    // no buffer to serialize into, so bail out instead of writing through a null pointer.
    os << R"(
    const auto numBytes = bytesFor(reply);
    this->_ensureTxBuf(numBytes);
    if(!this->txBuf) return;

    auto packet = reinterpret_cast<MessageHeader *>(this->txBuf);
    std::span<std::byte> data(packet->payload, numBytes);
    if(!serialize(data, reply)) return this->_HandleError(false, "Failed to serialize reply");

    this->_sendReply(hdr, numBytes);
)";
}



