///
/// Copyright (C) 2016, Dependable Systems Laboratory, EPFL
/// Copyright (C) 2016, Cyberhaven
///
/// Permission is hereby granted, free of charge, to any person obtaining a copy
/// of this software and associated documentation files (the "Software"), to deal
/// in the Software without restriction, including without limitation the rights
/// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
/// copies of the Software, and to permit persons to whom the Software is
/// furnished to do so, subject to the following conditions:
///
/// The above copyright notice and this permission notice shall be included in all
/// copies or substantial portions of the Software.
///
/// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
/// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
/// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
/// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
/// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
/// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
/// SOFTWARE.
///

#include <s2e/S2E.h>
#include <s2e/Utils.h>

#include <memory>
#include <stdexcept>
#include <unordered_map>

#include <boost/algorithm/string/classification.hpp>
#include <boost/algorithm/string/split.hpp>
#include <boost/filesystem.hpp>
#include <boost/regex.hpp>

#include <s2e/Plugins/OSMonitors/Linux/DecreeMonitor.h>

#include "Recipe.h"
#include "RecipeDescriptor.h"

namespace s2e {
namespace plugins {
namespace recipe {

using namespace pov;

// Messages are routed through the Recipe plugin's own output streams.
// g_s2e is deliberately not used, because it does not obey log verbosity levels.
static Plugin *s_plugin;

// TODO: design a clean solution so that classes that are
// related to a plugin but not derived from it can use
// its output functions (note that it also ties these classes
// to that plugin).

/// \brief Return the Recipe plugin, looking it up on first use
///
/// The result is cached in s_plugin.
/// NOTE(review): a null result (Recipe plugin not loaded) is not checked.
///
static Plugin *getRecipePlugin() {
    if (s_plugin == nullptr) {
        s_plugin = g_s2e->getPlugin<Recipe>();
    }
    return s_plugin;
}

/// \brief Debug stream of the Recipe plugin, prefixed with "RecipeDescriptor: "
static llvm::raw_ostream &getDebugStream() {
    llvm::raw_ostream &out = getRecipePlugin()->getDebugStream();
    return out << "RecipeDescriptor: ";
}

/// \brief Warnings stream of the Recipe plugin, prefixed with "RecipeDescriptor: "
static llvm::raw_ostream &getWarningsStream() {
    llvm::raw_ostream &out = getRecipePlugin()->getWarningsStream();
    return out << "RecipeDescriptor: ";
}

/// \brief Load and parse a recipe file
///
/// How lines are handled:
///   - empty lines and lines starting with '#' are skipped;
///   - lines starting with ':' are parsed as settings;
///   - all other lines are parsed as preconditions.
///
/// The preconditions on the program counter are then used to classify the
/// recipe as having a concrete or a symbolic target EIP.
///
/// \param recipeFile path of the recipe file
/// \return a newly allocated descriptor owned by the caller, or nullptr if the
///         file cannot be read or describes an invalid recipe
///
RecipeDescriptor *RecipeDescriptor::fromFile(const std::string &recipeFile) {
    std::vector<std::string> lines;
    if (!ReadLines(recipeFile, lines, true)) {
        return nullptr;
    }

    // The unique_ptr owns the descriptor until it is fully validated,
    // so every early error return releases it.
    std::unique_ptr<RecipeDescriptor> ret(new RecipeDescriptor());

    for (const std::string &line : lines) {
        /* Skip empty lines and comments */
        if (line.empty() || line[0] == '#') {
            continue;
        }

        const bool parsed = (line[0] == ':') ? ret->parseSettingsLine(line) : ret->parsePreconditionLine(line);
        if (!parsed) {
            return nullptr;
        }
    }

    if (!ret->isValid()) {
        return nullptr;
    }

    /* Check category */
    /* First, extract preconditions on EIP */
    Preconditions eipPreconditions;
    for (const auto &precondition : ret->preconditions) {
        if (precondition.left->type() == Left::Type::REGBYTE && precondition.left->reg()->isPc()) {
            eipPreconditions.push_back(precondition);
        }
    }

    const unsigned ptrSize = ret->settings.arch == RECIPE_AMD64 ? 8 : 4;
    if (eipPreconditions.size() > ptrSize) {
        getWarningsStream() << "Invalid set of preconditions on EIP\n";
        return nullptr;
    }

    // Record the concrete value of each EIP byte. Stop at the first
    // precondition whose right-hand side is not concrete.
    std::unordered_map<uint8_t, uint8_t> byteValues;
    for (const auto &precondition : eipPreconditions) {
        if (precondition.right->type() != Right::Type::CONCRETE) {
            break;
        }

        const uint8_t byteIdx = static_cast<uint8_t>(precondition.left->reg()->idx());
        if (byteValues.find(byteIdx) != byteValues.end()) {
            getWarningsStream() << "Multiple preconditions for byte " << int(byteIdx) << " of EIP\n";
            return nullptr;
        }

        assert(precondition.right->valueWidth() == klee::Expr::Int8 && "Only 8bit values must be used in recipe");
        byteValues[byteIdx] = static_cast<uint8_t>(precondition.right->value());
    }

    // A concrete target is only known when the four low EIP bytes (indexes
    // 0-3) are all concrete. The precondition parser accepts byte indexes
    // 0-7, so check for exactly these indexes rather than assuming them.
    bool haveLowBytes = byteValues.size() == 4;
    for (uint8_t i = 0; haveLowBytes && i < 4; ++i) {
        haveLowBytes = byteValues.count(i) != 0;
    }

    if (haveLowBytes) {
        ret->eipType = EIPType::CONCRETE_EIP;
        // Widen every byte to uint32_t before shifting. Shifting the promoted
        // (signed) int by 24 yields a negative value for bytes >= 0x80, and
        // that value would be sign-extended if concreteTargetEIP is wider than
        // 32 bits (it is compared against a uint64_t EIP in mustTryRecipe).
        uint32_t targetEip = 0;
        for (uint8_t i = 0; i < 4; ++i) {
            targetEip |= static_cast<uint32_t>(byteValues.at(i)) << (8 * i);
        }
        ret->concreteTargetEIP = targetEip;
        return ret.release();
    }

    /* ATM, we do not have any other case */
    ret->eipType = EIPType::SYMBOLIC_EIP;
    return ret.release();
}

/// \brief Check that the parsed settings form a usable recipe
///
/// \return false (after printing a warning) when the recipe type is still
///         PovType::POV_GENERAL, i.e. invalid or never set; true otherwise
///
bool RecipeDescriptor::isValid() const {
    // No separate "Type 1 recipe without GP" check is done here: per the
    // earlier note in this function, registers are always initialized.
    const bool hasConcreteType = settings.type != PovType::POV_GENERAL;
    if (!hasConcreteType) {
        getWarningsStream() << "Recipe has invalid or unset type!\n";
    }
    return hasConcreteType;
}

static const boost::regex REGEX_SETTINGS(":(.+?)\\s*=\\s*(.+)");

/// \brief Parse one settings line of the form ":name = value"
///
/// Recognized names:
///   - type, skip, reg_mask, pc_mask: numeric values;
///   - arch: i386 or amd64;
///   - platform: generic or decree;
///   - exec_mem, gp: register names;
///   - module_name: free text.
///
/// Numeric values are parsed with std::stoull base 0, so decimal, 0x-hex and
/// 0-prefixed octal are all accepted.
///
/// \param line the settings line, including the leading ':'
/// \return true on success; false (after printing a warning) for an unknown
///         name, an unsupported value or a malformed number
///
bool RecipeDescriptor::parseSettingsLine(const std::string &line) {

    boost::smatch match;
    if (!boost::regex_match(line, match, REGEX_SETTINGS) || (match.size() != 3)) {
        getWarningsStream() << "Invalid settings format:" << line << "\n";
        return false;
    }

    const std::string name = match[1];
    const std::string value = match[2];

    // std::stoull throws std::invalid_argument / std::out_of_range (both
    // derived from std::logic_error) for malformed numbers. Report them as a
    // parse failure instead of letting the exception escape a bool-returning
    // parser.
    try {
        if (name == "type") {
            settings.type = PovType(std::stoull(value, nullptr, 0));
        } else if (name == "arch") {
            if (value == "i386") {
                settings.arch = RECIPE_I386;
            } else if (value == "amd64") {
                settings.arch = RECIPE_AMD64;
            } else {
                getWarningsStream() << "Invalid arch " << value << "\n";
                return false;
            }
        } else if (name == "platform") {
            if (value == "generic") {
                settings.platform = RECIPE_GENERIC;
            } else if (value == "decree") {
                settings.platform = RECIPE_DECREE;
            } else {
                getWarningsStream() << "Invalid platform " << value << "\n";
                return false;
            }
        } else if (name == "exec_mem") {
            auto reg = Register::fromName(value, 0);
            if (!reg) {
                getWarningsStream() << "Invalid register " << value << " in exec_mem setting\n";
                return false;
            }

            auto l = Left::createRegPtrExec(reg);

            preconditions.push_back(Precondition(l, Right::createInvalid()));
        } else if (name == "skip") {
            settings.skip = std::stoull(value, nullptr, 0);
        } else if (name == "module_name") {
            settings.moduleName = value;
        } else if (name == "gp") {
            // Validate like exec_mem, so that a misspelled register is
            // reported instead of silently leaving an invalid GP register.
            auto reg = Register::fromName(value, 0);
            if (!reg) {
                getWarningsStream() << "Invalid register " << value << " in gp setting\n";
                return false;
            }
            settings.gp = reg;
        } else if (name == "reg_mask") {
            settings.regMask = std::stoull(value, nullptr, 0);
        } else if (name == "pc_mask") {
            settings.ipMask = std::stoull(value, nullptr, 0);
        } else {
            getWarningsStream() << "Invalid settings name: " << name << "\n";
            return false;
        }
    } catch (const std::logic_error &) {
        getWarningsStream() << "Invalid value " << value << " for setting " << name << "\n";
        return false;
    }

    return true;
}

/// \brief Escape special characters for regex
///
/// Prefixes every character that is special in a Perl/ECMAScript-style
/// regular expression (as used by boost::regex) with a backslash:
/// \ ^ $ . | ? * + ( ) [ ] { }
/// The escaped string then matches \p s literally when embedded in a larger
/// pattern.
///
/// \param s input string
/// \return escaped string (identical to \p s if it has no special characters)
///
static std::string esc(const std::string &s) {
    static const std::string kSpecial = "\\^$.|?*+()[]{}";

    std::string ret;
    ret.reserve(s.size() * 2);
    for (const char c : s) {
        if (kSpecial.find(c) != std::string::npos) {
            ret += '\\';
        }
        ret += c;
    }
    return ret;
}

static const boost::regex REGEX_PRECONDITION("(.+?)\\s*==\\s*(.+)");

bool RecipeDescriptor::parsePreconditionLine(const std::string &line) {
    const std::string inputStr = trim(line);

    boost::smatch match;
    if (!boost::regex_match(line, match, REGEX_PRECONDITION) || (match.size() != 3)) {
        getWarningsStream() << "Invalid precondition line:" << line << "\n";
        return false;
    }

    const std::string &leftStr = match[1];
    const std::string &rightStr = match[2];

    const std::string varNameRegex =
        esc(VARNAME_PC) + "|" + esc(VARNAME_GP) + "|" + esc(VARNAME_ADDR) + "|" + esc(VARNAME_SIZE);
    const std::string regNameRegex = "EAX|EBX|ECX|EDX|EBP|ESP|ESI|EDI|EIP|RAX|RBX|RCX|RDX|RBP|RSP|RSI|RDI|RIP";

    const std::string numberRegex = "0x[[:xdigit:]]+|[[:digit:]]+";
    const std::string byteOffsetRegex = "[0-7]{1}";

    const boost::regex assignRegex("([^=]+)==([^=]+)");
    const boost::regex addrRegex("\\[(" + numberRegex + ")\\]");

    const boost::regex regOffsetRegex("(" + regNameRegex + ")\\[(" + byteOffsetRegex + ")\\]");
    const boost::regex regPtrRegex("\\[(" + regNameRegex + ")\\+(" + numberRegex + ")\\]");
    const boost::regex regPtrPtrOffsetRegex("\\[(" + regNameRegex + ")\\+(" + numberRegex + ")\\]" + "\\[(" +
                                            numberRegex + ")\\]");
    const boost::regex valRegex("(" + numberRegex + ")");
    const boost::regex charRegex("'([[:print:]]{1})'");
    const boost::regex tagRegex("((" + varNameRegex + ")\\[" + byteOffsetRegex + "\\])");

    // Parse left
    klee::ref<Left> left;
    if (boost::regex_match(leftStr, match, regOffsetRegex) && match.size() == 3) {
        uint8_t byteIdx = static_cast<uint8_t>(std::stoul(match[2], nullptr, 0));
        auto reg = Register::fromName(match[1], byteIdx);
        left = Left::createRegByte(reg);
    } else if (boost::regex_match(leftStr, match, addrRegex) && match.size() == 2) {
        left = Left::createAddr(std::stoull(match[1], nullptr, 0));
    } else if (boost::regex_match(leftStr, match, regPtrRegex) && match.size() == 3) {
        auto reg = Register::fromName(match[1], 0);
        left = Left::createRegPtr(reg, static_cast<off_t>(std::stoull(match[2], nullptr, 0)));
    } else if (boost::regex_match(leftStr, match, regPtrPtrOffsetRegex) && match.size() == 4) {
        auto reg = Register::fromName(match[1], 0);
        auto ptr1 = std::stoull(match[2], nullptr, 0);
        auto ptr2 = std::stoull(match[3], nullptr, 0);
        left = Left::createRegPtrPtr(reg, ptr1, ptr2);
    } else {
        getWarningsStream() << "Invalid left expression: '" << leftStr << "'\n";
        return false;
    }

    // Parse right
    klee::ref<Right> right;
    if (boost::regex_match(rightStr, match, valRegex) && match.size() == 2) {
        unsigned long long v = std::stoull(match[1], nullptr, 0);
        if (v != klee::bits64::truncateToNBits(v, klee::Expr::Int8)) {
            getWarningsStream() << "Value must fit 8 bits: " << hexval(v) << "\n";
            return false;
        }
        right = Right::createConcrete(v, klee::Expr::Int8);
    } else if (boost::regex_match(rightStr, match, charRegex) && match.size() == 2) {
        right = Right::createConcrete(static_cast<uint8_t>(match[1].str()[0]), klee::Expr::Int8);
    } else if (boost::regex_match(rightStr, match, tagRegex) && match.size() == 3) {
        right = Right::createNegotiable(match[1]);
    } else if (boost::regex_match(rightStr, match, regOffsetRegex) && match.size() == 3) {
        unsigned long long idx = std::stoull(match[2], nullptr, 0);
        auto reg = Register::fromName(match[1], static_cast<uint8_t>(idx));
        right = Right::createRegByte(reg);
    } else {
        getWarningsStream() << "Invalid right expression: '" << rightStr << "'\n";
        return false;
    }

    // Done
    preconditions.push_back(Precondition(left, right));
    return true;
}

/// \brief Decide whether a recipe should be tried in the current state
///
/// \param recipe     the candidate recipe
/// \param recipeName name of the recipe (not used by this check)
/// \param sc         conditions of the current state (module, EIP kind)
/// \param eip        current program counter value
/// \return true if the recipe's module filter matches, and either the current
///         EIP is symbolic, or both the current and the recipe EIP are
///         concrete and equal
///
bool RecipeDescriptor::mustTryRecipe(const RecipeDescriptor &recipe, const std::string &recipeName,
                                     const StateConditions &sc, uint64_t eip) {
    const std::string &recipeModule = recipe.settings.moduleName;

    // An empty module name in the recipe means "any module".
    if (!recipeModule.empty() && recipeModule != sc.module.Name) {
        getDebugStream() << "Recipe.settings.moduleName (" << recipe.settings.moduleName << ") != (" << sc.module.Name
                         << ") moduleName. Skipping.\n";
        return false;
    }

    // With a symbolic current EIP, recipes with either symbolic or concrete EIP are candidates.
    if (sc.eipType == EIPType::SYMBOLIC_EIP) {
        return true;
    }

    // Otherwise both EIPs must be concrete and identical.
    return sc.eipType == EIPType::CONCRETE_EIP && recipe.eipType == EIPType::CONCRETE_EIP &&
           recipe.concreteTargetEIP == eip;
}

/// \brief Check whether this recipe may be applied in the given state
///
/// \param state execution state whose pointer size is checked against the
///              recipe architecture
/// \return false if an amd64 recipe meets 4-byte pointers, if an i386 recipe
///         meets 8-byte pointers, or if a DECREE recipe is used while the
///         DecreeMonitor plugin is not loaded; true otherwise
///
bool RecipeDescriptor::isUsable(S2EExecutionState *state) const {
    const unsigned pointerBytes = state->getPointerSize();

    const bool amd64RecipeOn32BitPtrs = settings.arch == RECIPE_AMD64 && pointerBytes == sizeof(uint32_t);
    const bool i386RecipeOn64BitPtrs = settings.arch == RECIPE_I386 && pointerBytes == sizeof(uint64_t);
    if (amd64RecipeOn32BitPtrs || i386RecipeOn64BitPtrs) {
        return false;
    }

    // DECREE recipes need the DecreeMonitor plugin to be present.
    if (settings.platform == RECIPE_DECREE && !g_s2e->getPlugin("DecreeMonitor")) {
        return false;
    }

    return true;
}
} // namespace recipe
} // namespace plugins
} // namespace s2e
