#include "CodeGenerator.h"
#include "CodeGenerator+CppHelpers.h"
#include "InterfaceDescription.h"
#include "version.h"

#include <algorithm>
#include <cctype>
#include <exception>
#include <filesystem>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <stdexcept>

/**
 * Returns the name of the client stub class for the given interface: the interface name with its
 * first character uppercased, followed by "Client".
 */
static inline std::string GetClassName(const std::shared_ptr<InterfaceDescription> &id) {
    auto temp = id->getName();
    if(!temp.empty()) {
        // std::toupper is only defined for values representable as unsigned char (or EOF), so a
        // plain (possibly signed) char such as a UTF-8 lead byte must be converted first
        temp[0] = static_cast<char>(std::toupper(static_cast<unsigned char>(temp[0])));
    }

    return temp + "Client";
}



/**
 * Generate the C++ client stub for the interface.
 *
 * This is a full implementation that can be used directly by callers. The method calls on the
 * object are automatically converted to RPC messages, sent via the RPC stream, and the replies
 * decoded.
 *
 * Two files are written to the output directory: `Client_<name>.hpp` (the class declaration) and
 * `Client_<name>.cpp` (its implementation).
 *
 * @throws std::runtime_error If either output file cannot be opened, or writing to it fails
 */
void CodeGenerator::generateClientStub() {
    auto fileNameH = this->outDir / ("Client_" + this->interface->getName() + ".hpp");
    auto fileNameCpp = this->outDir / ("Client_" + this->interface->getName() + ".cpp");
    std::cout << "    * Client stub: " << fileNameCpp.string() << ", " << fileNameH.string()
              << std::endl;

    // generate the header first
    std::string includeGuardName("RPC_CLIENT_GENERATED_");
    includeGuardName.append(std::to_string(this->interface->getIdentifier()));

    std::ofstream header(fileNameH.string(), std::ofstream::trunc);
    if(!header.is_open()) {
        throw std::runtime_error("failed to open client stub header '" + fileNameH.string() + "'");
    }
    this->clientWriteInfoBlock(header);
    header << R"(#ifndef )" << includeGuardName << R"(
#define )" << includeGuardName << R"(

#include <string>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <span>
#include <string_view>
#include <vector>
)" << std::endl;
    this->cppWriteIncludes(header);
    this->clientWriteHeader(header);
    header << "#endif // defined(" << includeGuardName << ")" << std::endl;
    // the trailing std::endl flushed the stream, so any write error is visible in its state now
    if(!header) {
        throw std::runtime_error("failed to write client stub header '" + fileNameH.string() + "'");
    }

    // and then the implementation
    std::ofstream implementation(fileNameCpp.string(), std::ofstream::trunc);
    if(!implementation.is_open()) {
        throw std::runtime_error("failed to open client stub implementation '" +
                fileNameCpp.string() + "'");
    }
    this->clientWriteInfoBlock(implementation);
    implementation << "#include \"" << fileNameH.filename().generic_string() << "\"" << std::endl
       << "#include \"" << this->serializationFile.filename().generic_string() << '"' << std::endl
       << R"(
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>

#include <rpc/rt/RpcIoStream.hpp>

using namespace rpc;

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunused-variable"
)";
    this->clientWriteImpl(implementation);
    implementation << "#pragma clang diagnostic pop" << std::endl;
    if(!implementation) {
        throw std::runtime_error("failed to write client stub implementation '" +
                fileNameCpp.string() + "'");
    }
}

/**
 * Emits the banner comment placed at the top of both generated client stub files.
 *
 * The banner names the generator version (first 8 characters of the version hash), the IDL
 * source file, the interface and the creation timestamp, followed by a short usage note.
 */
void CodeGenerator::clientWriteInfoBlock(std::ofstream &os) {
    const std::string shortHash = std::string(gVERSION_HASH).substr(0, 8);

    os << "/*\n"
       << " * This RPC client stub was autogenerated by idlc (version " << shortHash
       << "). DO NOT EDIT!" << std::endl;

    os << " * Generated from " << this->interface->getSourceFilename() << " for interface "
       << this->interface->getName() << " at " << this->creationTimestamp << std::endl;

    os << " *\n"
       << " * You may use these generated stubs directly as the RPC interface, or you can subclass it to\n"
       << " * override the behavior of the function calls, or to perform some preprocessing to the data as\n"
       << " * needed before sending it.\n"
       << " *\n"
       << " * See the full RPC documentation for more details.\n"
       << " */" << std::endl;
}



/**
 * Writes the header for the client stub.
 *
 * Most of this interface is the same boilerplate as the server code contains, with the change
 * that there's only one set of call methods available.
 *
 * The generated class owns a malloc'ed transmit buffer (`txBuf`) that its destructor frees, so its
 * copy constructor and copy assignment are emitted as deleted; an implicit copy would otherwise
 * free the same buffer twice.
 */
void CodeGenerator::clientWriteHeader(std::ofstream &os) {
    // set up the starting of the class
    const auto className = GetClassName(this->interface);

    os << "namespace rpc {" << std::endl << "namespace rt { class ClientRpcIoStream; }"
       << std::endl;
    os << "class " << className << " {";

    /*
     * Define the types, including the MessageHeader, which is the structure we expect to receive
     * from whatever input/output stream we're provided. Also define the return type structs for
     * any methods with multiple returns.
     */
    os << R"(
    struct MessageHeader {
        enum Flags: uint32_t {
            Request                     = (1 << 0),
            Response                    = (1 << 1),
        };

        uint64_t type;
        Flags flags;
        uint32_t tag;

        std::byte payload[];
    };
    static_assert(!(offsetof(MessageHeader, payload) % sizeof(uintptr_t)),
        "message header's payload is not word aligned");

    constexpr static const std::string_view kServiceName{")" << this->interface->getName() << R"("};

    protected:
        using IoStream = rt::ClientRpcIoStream;

    public:
)";

    for(const auto &m : this->interface->getMethods()) {
        if(!m.hasMultipleReturns()) continue;
        this->cppWriteReturnStruct(os, m);
    }


    // public methods; copying is disabled since each instance owns its transmit buffer
    os << R"(
    public:
        )" << className << R"((const std::shared_ptr<IoStream> &stream);
        virtual ~)" << className << R"(();

        // instances own their transmit buffer, so copying would free it twice
        )" << className << R"((const )" << className << R"( &) = delete;
        )" << className << R"( &operator=(const )" << className << R"( &) = delete;
)";

    // actual RPC methods
    os << R"(
)";
    for(const auto &m : this->interface->getMethods()) {
        os << "        virtual ";
        this->cppWriteMethodDef(os, m);
        os << ';' << std::endl;
    }

    // implementation details
    os << R"(
    // Helpers provided to subclasses for implementation of interface methods
    protected:
        constexpr inline auto &getIo() {
            return this->io;
        }

        /// Handles errors occurring during client operations; typically terminates the task
        virtual void _HandleError(const bool fatal, const std::string_view &what);

    // Implementation details; pretend this does not exist
    private:
        std::shared_ptr<IoStream> io;
        size_t txBufSize{0};
        void *txBuf{nullptr};

        uint32_t nextTag{0};

        void _ensureTxBuf(const size_t);
        uint32_t _sendRequest(const uint64_t type, const size_t payloadBytes);
)";

    // close the class and namespace
    os << "}; // class " << className << std::endl
       << "} // namespace rpc" << std::endl;
}




/**
 * Writes the implementation for the client stub.
 *
 * This emits the constructor/destructor, the internal transmit helpers (`_sendRequest`,
 * `_ensureTxBuf`), the default `_HandleError`, and then one marshalling method per interface
 * method.
 */
void CodeGenerator::clientWriteImpl(std::ofstream &os) {
    // define templated custom serialization methods if needed
    const auto className = GetClassName(this->interface);
    os << "using Client = " << className << ';' << std::endl << std::endl;

    this->cppWriteCustomTypeHelpers(os);

    // define the constructor and destructors, as well as internal helpers
    os << R"(/**
 * Creates a new client instance, with the given IO stream.
 */
Client::)" << className << R"((const std::shared_ptr<IoStream> &stream) : io(stream) {
}

/**
 * Shuts down the RPC client, releasing any allocated resources.
 */
Client::~)" << className << R"(() {
    free(this->txBuf);
}
)";

    /*
     * Internal helpers. The buffer is reset before reallocating so that a failed allocation (with
     * a _HandleError override that throws or returns) never leaves a dangling pointer behind for
     * the destructor to free again; and error strings are printed with an explicit length since a
     * std::string_view is not guaranteed to be NUL-terminated.
     */
    os << R"(
/// Sends the message that's been built up in the transmit message buffer.
uint32_t Client::_sendRequest(const uint64_t type, const size_t payloadBytes) {
    const size_t len = sizeof(MessageHeader) + payloadBytes;

    const auto tag = __atomic_add_fetch(&this->nextTag, 1, __ATOMIC_RELAXED);
    auto hdr = reinterpret_cast<MessageHeader *>(this->txBuf);
    memset(hdr, 0, sizeof(*hdr));
    hdr->type = type;
    hdr->flags = MessageHeader::Flags::Request;
    hdr->tag = tag;

    const std::span<std::byte> txBufSpan(reinterpret_cast<std::byte *>(this->txBuf), len);
    if(!this->io->sendRequest(txBufSpan)) {
        this->_HandleError(true, "Failed to send RPC request");
        return 0;
    }

    return tag;
}

// Allocates an aligned transmit buffer of the given size
void Client::_ensureTxBuf(const size_t payloadBytes) {
    const size_t len = sizeof(MessageHeader) + payloadBytes + 16;
    if(len > this->txBufSize) {
        // release the old buffer up front so no stale pointer/size survives a failed allocation
        free(this->txBuf);
        this->txBuf = nullptr;
        this->txBufSize = 0;

        int err = posix_memalign(&this->txBuf, 16, len);
        if(err) {
            // the pointer's value is not reliable after a failure; never keep it
            this->txBuf = nullptr;
            return this->_HandleError(true, "Failed to allocate RPC send buffer");
        }
        this->txBufSize = len;
    }
}

/**
 * Handles an error that occurred on the client connection. Implementations may override this
 * method if they want to use exceptions, for example.
 *
 * @param fatal If set, the error precludes further operation on this RPC connection
 * @param what Descriptive string for the error
 */
void Client::_HandleError(const bool fatal, const std::string_view &what) {
    // string_views need not be NUL-terminated, so print them with an explicit length
    fprintf(stderr, "[RPCC] %.*s: Encountered %s RPC error: %.*s\n",
        static_cast<int>(kServiceName.size()), kServiceName.data(),
        fatal ? "fatal" : "recoverable", static_cast<int>(what.size()), what.data());
    if(fatal) exit(-1);
}
)";

    // write out the implementations for each of the calls
    for(const auto &m : this->interface->getMethods()) {
        this->clientWriteMarshallMethod(os, m);
    }
}

/**
 * Writes the implementation of the method to marshall the specified method call.
 *
 * The generated method serializes its parameters into the request struct and sends it. For
 * synchronous methods it then blocks for the reply; see clientWriteMarshallMethodReply.
 */
void CodeGenerator::clientWriteMarshallMethod(std::ofstream &os, const Method &m) {
    // start method; the identifier is printed in hex, but the stream's formatting flags are
    // restored immediately so later integers (counts below, other methods, helper output) stay
    // decimal instead of silently switching to hex
    os << R"(/*
 * Autogenerated call method for ')" << m.getName() << R"(' (id $)";
    const std::ios_base::fmtflags savedFlags = os.flags();
    os << std::hex << m.getIdentifier();
    os.flags(savedFlags);
    os << R"()
 * Have )" << m.getParameters().size() << R"( parameter(s), )" << m.getReturns().size()
       << R"( return(s); method is )" << (m.isAsync() ? "async" : "sync") << R"(
 */
)";
    this->cppWriteMethodDef(os, m, "Client::", "Client::");
    os << " {";

    // synchronous calls need the request's tag to match against the reply
    if(!m.isAsync()) {
        os << std::endl << "    uint32_t sentTag;";
    }

    // build up the request
    os << R"(
    {
        internals::)" << SerGetMessageStructName(m, false) << R"( request;
)";

    for(const auto &a : m.getParameters()) {
        os << "        request." << a.getName() << " = " << a.getName() << ";" << std::endl;
    }

    // serialize and send
    os << R"(
        const auto numBytes = bytesFor(request);
        this->_ensureTxBuf(numBytes);

        auto packet = reinterpret_cast<MessageHeader *>(this->txBuf);
        std::span<std::byte> data(packet->payload, numBytes);
        serialize(data, request);
        )" << (m.isAsync() ? "" : "sentTag = ") << R"(this->_sendRequest(static_cast<uint64_t>()"
          << SerGetMessageIdEnumName(m) << R"(), numBytes);
    }
)";

    // read the response, if applicable
    if(!m.isAsync()) {
        this->clientWriteMarshallMethodReply(os, m);
    }

    // finish method
    os << '}' << std::endl;
}

/**
 * Writes marshalling code for receiving the reply of the invoked method.
 *
 * @note This will always result in a blocking call; there's no support in _not_ blocking on the
 * received message. Failing to receive a reply, or receiving one too short to contain a message
 * header, is reported as a fatal error: the header cannot be inspected safely, so execution must
 * not continue. A reply with a mismatched tag or type, or whose payload fails to decode, is
 * reported as a recoverable error.
 */
void CodeGenerator::clientWriteMarshallMethodReply(std::ofstream &os, const Method &m) {
    // first, receive a message and validate the received buffer; the first two checks are fatal
    // because the following line dereferences the header, which would otherwise read out of bounds
    // whenever the (default, non-exiting for recoverable errors) handler returns
    os << R"(    {
        std::span<std::byte> buf;
        if(!this->io->receiveReply(buf)) this->_HandleError(true, "Failed to receive RPC reply");
        if(buf.size() < sizeof(MessageHeader)) this->_HandleError(true, "Received message too small");
        const auto hdr = reinterpret_cast<const MessageHeader *>(buf.data());
        if(hdr->tag != sentTag) this->_HandleError(false, "Invalid tag in reply RPC packet");
        else if(hdr->type != static_cast<uint64_t>()" << SerGetMessageIdEnumName(m) << R"()) this->_HandleError(false, "Invalid type in reply RPC packet");
        const auto payload = buf.subspan(offsetof(MessageHeader, payload));
)";

    // then deserialize it
    os << R"(
        internals::)" << SerGetMessageStructName(m, true) << R"( reply;
        if(!deserialize(payload, reply)) this->_HandleError(false, "Failed to decode message");
)";

    // handle if there's reply arguments: a single return is passed through directly, multiple
    // returns are packed into the method's generated return struct
    const auto &returns = m.getReturns();
    if(returns.size() == 1) {
        const auto &a = returns[0];
        os << "        return reply." << a.getName() << ";";
    } else if(!returns.empty()) {
        os << "        " << m.getName() << "Return r;" << std::endl;

        for(const auto &a : returns) {
            os << "        r." << a.getName() << " =  reply." << a.getName() << ";\n";
        }

        os << "        return r;" << std::endl;
    }

    // end the method
    os << R"(
    }
)";
}
