#include "stored_table_node.hpp"

#include <algorithm>
#include <cstddef>
#include <memory>
#include <ostream>
#include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include <boost/container_hash/hash.hpp>

#include "expression/expression_utils.hpp"
#include "expression/lqp_column_expression.hpp"
#include "hyrise.hpp"
#include "logical_query_plan/abstract_lqp_node.hpp"
#include "logical_query_plan/data_dependencies/order_dependency.hpp"
#include "logical_query_plan/data_dependencies/unique_column_combination.hpp"
#include "lqp_utils.hpp"
#include "storage/index/chunk_index_statistics.hpp"
#include "storage/index/table_index_statistics.hpp"
#include "storage/storage_manager.hpp"
#include "types.hpp"
#include "utils/assert.hpp"
#include "utils/pruning_utils.hpp"

namespace {

using namespace hyrise;  // NOLINT(build/namespaces)

template <typename ColumnIDs>
bool contains_any_column_id(const ColumnIDs& search_columns, const std::vector<ColumnID>& columns) {
  return std::any_of(columns.cbegin(), columns.cend(), [&](const auto& column_id) {
    return std::find(search_columns.cbegin(), search_columns.cend(), column_id) != search_columns.cend();
  });
}

}  // namespace

namespace hyrise {

// Creates an LQP node of type LQPNodeType::StoredTable that refers to a stored table by name. Only the name is kept
// here; the member functions below look the table up via the StorageManager whenever they need it.
StoredTableNode::StoredTableNode(const std::string& init_table_name)
    : AbstractLQPNode(LQPNodeType::StoredTable), table_name(init_table_name) {}

// Resolves the column `name` against the stored table and returns an LQPColumnExpression that refers to this node and
// the resolved ColumnID.
std::shared_ptr<LQPColumnExpression> StoredTableNode::get_column(const std::string& name) const {
  const auto& stored_table = Hyrise::get().storage_manager.get_table(table_name);
  // Resolve the ColumnID before building the expression so that a failing lookup surfaces first.
  const auto resolved_column_id = stored_table->column_id_by_name(name);
  return std::make_shared<LQPColumnExpression>(shared_from_this(), resolved_column_id);
}

void StoredTableNode::set_pruned_chunk_ids(const std::vector<ChunkID>& pruned_chunk_ids) {
  DebugAssert(std::is_sorted(pruned_chunk_ids.begin(), pruned_chunk_ids.end()), "Expected sorted vector of ChunkIDs");
  DebugAssert(std::adjacent_find(pruned_chunk_ids.begin(), pruned_chunk_ids.end()) == pruned_chunk_ids.end(),
              "Expected vector of unique ChunkIDs");

  _pruned_chunk_ids = pruned_chunk_ids;
}

// Returns the ChunkIDs last stored via set_pruned_chunk_ids(). The reference points to a member of this node.
const std::vector<ChunkID>& StoredTableNode::pruned_chunk_ids() const {
  return _pruned_chunk_ids;
}

void StoredTableNode::set_pruned_column_ids(const std::vector<ColumnID>& pruned_column_ids) {
  DebugAssert(std::is_sorted(pruned_column_ids.begin(), pruned_column_ids.end()),
              "Expected sorted vector of ColumnIDs");
  DebugAssert(std::adjacent_find(pruned_column_ids.begin(), pruned_column_ids.end()) == pruned_column_ids.end(),
              "Expected vector of unique ColumnIDs");

  // It is valid for an LQP to not use any of the table's columns (e.g., SELECT 5 FROM t). We still need to include at
  // least one column in the output of this node, which is used by Table::size() to determine the number of 5's.
  const auto stored_column_count = Hyrise::get().storage_manager.get_table(table_name)->column_count();
  Assert(pruned_column_ids.size() < static_cast<size_t>(stored_column_count), "Cannot exclude all columns from Table.");

  _pruned_column_ids = pruned_column_ids;

  _set_output_expressions();
}

// Returns the ColumnIDs last stored via set_pruned_column_ids(). The reference points to a member of this node.
const std::vector<ColumnID>& StoredTableNode::pruned_column_ids() const {
  return _pruned_column_ids;
}

// Stores weak references to the nodes whose subquery predicates can be used for pruning. Debug builds verify that
// every reference is still alive and points to a node of type LQPNodeType::Predicate.
void StoredTableNode::set_prunable_subquery_predicates(
    const std::vector<std::weak_ptr<AbstractLQPNode>>& predicate_nodes) {
  DebugAssert(std::all_of(predicate_nodes.cbegin(), predicate_nodes.cend(),
                          [](const auto& node) {
                            // Lock exactly once so that the null check and the type check operate on the same
                            // shared_ptr instead of two independently acquired ones.
                            const auto locked_node = node.lock();
                            return locked_node && locked_node->type == LQPNodeType::Predicate;
                          }),
              "No PredicateNode set as prunable predicate.");
  _prunable_subquery_predicates = predicate_nodes;
}

// Returns owning pointers to all nodes stored via set_prunable_subquery_predicates(), in the stored order. Fails if
// any of the weakly referenced nodes has expired in the meantime.
std::vector<std::shared_ptr<AbstractLQPNode>> StoredTableNode::prunable_subquery_predicates() const {
  auto predicate_nodes = std::vector<std::shared_ptr<AbstractLQPNode>>{};
  predicate_nodes.reserve(_prunable_subquery_predicates.size());

  for (const auto& weak_predicate_node : _prunable_subquery_predicates) {
    auto predicate_node = weak_predicate_node.lock();
    Assert(predicate_node, "Referenced PredicateNode expired. LQP is invalid.");
    predicate_nodes.push_back(std::move(predicate_node));
  }

  return predicate_nodes;
}

// Builds a single-line description consisting of the table name and the number of pruned chunks and columns, each
// relative to the stored table's total count. The description mode is ignored.
std::string StoredTableNode::description(const DescriptionMode /*mode*/) const {
  const auto& table = Hyrise::get().storage_manager.get_table(table_name);

  auto description_stream = std::ostringstream{};
  description_stream << "[StoredTable] Name: '" << table_name << "' pruned: " << _pruned_chunk_ids.size() << "/"
                     << table->chunk_count() << " chunk(s), " << _pruned_column_ids.size() << "/"
                     << table->column_count() << " column(s)";
  return description_stream.str();
}

// Returns a copy of the node's output expressions. They are computed on first use and then served from the cached
// `_output_expressions`.
std::vector<std::shared_ptr<AbstractExpression>> StoredTableNode::output_expressions() const {
  if (_output_expressions) {
    return *_output_expressions;
  }

  // Nothing cached yet: compute the expressions once.
  _set_output_expressions();
  return *_output_expressions;
}

// Forwards the nullability lookup for `column_id` to the stored table.
bool StoredTableNode::is_column_nullable(const ColumnID column_id) const {
  return Hyrise::get().storage_manager.get_table(table_name)->column_is_nullable(column_id);
}

// Derives unique column combinations (UCCs) from the stored table's soft key constraints. Key constraints that
// reference at least one pruned column do not yield a UCC.
UniqueColumnCombinations StoredTableNode::unique_column_combinations() const {
  const auto& stored_table = Hyrise::get().storage_manager.get_table(table_name);
  const auto& key_constraints = stored_table->soft_key_constraints();

  auto uccs = UniqueColumnCombinations{};
  for (const auto& key_constraint : key_constraints) {
    const auto& key_columns = key_constraint.columns();

    // Only key constraints whose columns are all still part of the output can be expressed as a UCC.
    if (!contains_any_column_id(key_columns, _pruned_column_ids)) {
      // Translate the constraint's ColumnIDs into this node's column expressions.
      auto key_expressions = get_expressions_for_column_ids(*this, key_columns);
      DebugAssert(key_expressions.size() == key_columns.size(), "Unexpected count of column expressions.");

      uccs.emplace(std::move(key_expressions));
    }
  }

  return uccs;
}

// Derives order dependencies (ODs) from the stored table's soft order constraints and extends them by their
// transitive closure. Order constraints that reference at least one pruned column do not yield an OD.
OrderDependencies StoredTableNode::order_dependencies() const {
  const auto& stored_table = Hyrise::get().storage_manager.get_table(table_name);
  const auto& order_constraints = stored_table->soft_order_constraints();

  auto ods = OrderDependencies{};
  for (const auto& order_constraint : order_constraints) {
    // If c is pruned, [a] |-> [b, c] could still be kept as [a] |-> [b]. For now, we do not do this and drop every
    // constraint that touches a pruned column on either side.
    const auto ordering_side_pruned = contains_any_column_id(order_constraint.ordering_columns(), _pruned_column_ids);
    if (ordering_side_pruned || contains_any_column_id(order_constraint.ordered_columns(), _pruned_column_ids)) {
      continue;
    }

    // Translate both sides of the constraint into this node's column expressions.
    auto ordering_expressions = get_expressions_for_column_ids(*this, order_constraint.ordering_columns());
    auto ordered_expressions = get_expressions_for_column_ids(*this, order_constraint.ordered_columns());

    ods.emplace(std::move(ordering_expressions), std::move(ordered_expressions));
  }

  // Add transitive ODs, e.g., [a] |-> [c] given [a] |-> [b] and [b] |-> [c].
  build_transitive_od_closure(ods);

  return ods;
}

std::vector<ChunkIndexStatistics> StoredTableNode::chunk_indexes_statistics() const {
  DebugAssert(!left_input() && !right_input(), "StoredTableNode must be a leaf");

  const auto& table = Hyrise::get().storage_manager.get_table(table_name);
  if (_pruned_column_ids.empty()) {
    return table->chunk_indexes_statistics();
  }

  auto pruned_indexes_statistics = table->chunk_indexes_statistics();
  const auto column_id_mapping = pruned_column_id_mapping(table->column_count(), _pruned_column_ids);

  // Update index statistics
  // Note: The lambda also modifies statistics.column_ids. This is done because a regular for loop runs into issues
  // when remove(iterator) invalidates the iterator.
  pruned_indexes_statistics.erase(std::remove_if(pruned_indexes_statistics.begin(), pruned_indexes_statistics.end(),
                                                 [&](auto& statistics) {
                                                   for (auto& original_column_id : statistics.column_ids) {
                                                     const auto updated_column_id =
                                                         column_id_mapping[original_column_id];
                                                     if (updated_column_id == INVALID_COLUMN_ID) {
                                                       // Indexed column was pruned - remove index from statistics
                                                       return true;
                                                     }

                                                     // Update column id
                                                     original_column_id = updated_column_id;
                                                   }
                                                   return false;
                                                 }),
                                  pruned_indexes_statistics.end());

  return pruned_indexes_statistics;
}

// Returns the stored table's table index statistics, adjusted to this node's pruned columns: statistics of indexes on
// a pruned column are dropped, and the ColumnID of each remaining statistic is remapped via
// pruned_column_id_mapping(). Without pruned columns, the table's statistics are returned unchanged.
std::vector<TableIndexStatistics> StoredTableNode::table_indexes_statistics() const {
  const auto& table = Hyrise::get().storage_manager.get_table(table_name);

  if (_pruned_column_ids.empty()) {
    return table->table_indexes_statistics();
  }

  const auto& index_statistics = table->table_indexes_statistics();
  const auto column_id_mapping = pruned_column_id_mapping(table->column_count(), _pruned_column_ids);

  // At most every stored statistic survives, so the number of statistics (rather than the number of remaining
  // columns) is the upper bound for the result size.
  auto pruned_index_statistics = std::vector<TableIndexStatistics>{};
  pruned_index_statistics.reserve(index_statistics.size());

  for (const auto& index_statistic : index_statistics) {
    // TODO(anyone): When chunk indexes are removed, TableIndexStatistics should no longer store a vector of ColumnIDs
    // as multi-column indexes are no longer supported.
    DebugAssert(index_statistic.column_ids.size() == 1, "Unexpected multi-column index");

    const auto& updated_column_id = column_id_mapping[index_statistic.column_ids[0]];
    if (updated_column_id == INVALID_COLUMN_ID) {
      // Indexed column was pruned.
      continue;
    }

    // Append statistic and update its column id.
    pruned_index_statistics.push_back(index_statistic);
    pruned_index_statistics.back().column_ids[0] = updated_column_id;
  }

  return pruned_index_statistics;
}

size_t StoredTableNode::_on_shallow_hash() const {
  auto hash = size_t{0};
  boost::hash_combine(hash, table_name);
  for (const auto& pruned_chunk_id : _pruned_chunk_ids) {
    boost::hash_combine(hash, pruned_chunk_id);
  }
  for (const auto& pruned_column_id : _pruned_column_ids) {
    boost::hash_combine(hash, pruned_column_id);
  }
  // We intentionally force a hash collision for StoredTableNodes with the same number of prunable subquery predicates
  // even though these predicates are different. Since we assume that (i) these predicates are not often set and (ii) we
  // hash LQPs often, this reduces the hash overhead, makes the code simpler, and triggers an in-depth equality check
  // for the rare cases with (the same number of) prunable subquery predicates.
  boost::hash_combine(hash, _prunable_subquery_predicates.size());
  return hash;
}

// Creates a new StoredTableNode for the same table with the same pruned chunks and columns. The node mapping is not
// used.
std::shared_ptr<AbstractLQPNode> StoredTableNode::_on_shallow_copy(LQPNodeMapping& /*node_mapping*/) const {
  // _prunable_subquery_predicates cannot be copied here: deep_copy() recurses into the input nodes, which makes the
  // StoredTableNodes the first nodes to be copied. AbstractLQPNode::deep_copy() sets the copied PredicateNodes once
  // the entire LQP has been copied.
  const auto copied_node = StoredTableNode::make(table_name);
  copied_node->set_pruned_chunk_ids(_pruned_chunk_ids);
  copied_node->set_pruned_column_ids(_pruned_column_ids);
  return copied_node;
}

bool StoredTableNode::_on_shallow_equals(const AbstractLQPNode& rhs, const LQPNodeMapping& node_mapping) const {
  const auto& stored_table_node = static_cast<const StoredTableNode&>(rhs);
  if (table_name != stored_table_node.table_name || _pruned_chunk_ids != stored_table_node._pruned_chunk_ids ||
      _pruned_column_ids != stored_table_node._pruned_column_ids) {
    return false;
  }

  // Check equality of prunable subquery predicates. For now, the order of the predicates matters. Though this is a
  // missed opportunity for LQP deduplication, we do not consider this a problem for now.
  const auto& prunable_subquery_predicates = this->prunable_subquery_predicates();
  const auto& rhs_prunable_subquery_predicates = stored_table_node.prunable_subquery_predicates();
  const auto subquery_predicate_count = prunable_subquery_predicates.size();

  if (subquery_predicate_count != rhs_prunable_subquery_predicates.size()) {
    return false;
  }

  for (auto predicate_idx = size_t{0}; predicate_idx < subquery_predicate_count; ++predicate_idx) {
    // We cannot check that the PredicateNodes are equal since this equality check recurses into the inputs und we do
    // not terminate. We have to compare the predicate expressions.
    if (!expressions_equal_to_expressions_in_different_lqp(
            prunable_subquery_predicates[predicate_idx]->node_expressions,
            rhs_prunable_subquery_predicates[predicate_idx]->node_expressions, node_mapping)) {
      return false;
    }
  }

  return true;
}

void StoredTableNode::_set_output_expressions() const {
  const auto& table = Hyrise::get().storage_manager.get_table(table_name);
  const auto stored_column_count = table->column_count();

  // Create `_output_expressions` sized with respect to the `_pruned_column_ids`
  const auto unpruned_column_count = stored_column_count - _pruned_column_ids.size();
  _output_expressions = std::vector<std::shared_ptr<AbstractExpression>>(unpruned_column_count);

  auto pruned_column_ids_iter = _pruned_column_ids.begin();
  auto output_column_id = ColumnID{0};
  for (auto stored_column_id = ColumnID{0}; stored_column_id < stored_column_count; ++stored_column_id) {
    // Skip `stored_column_id` if it is in the sorted vector `_pruned_column_ids`.
    if (pruned_column_ids_iter != _pruned_column_ids.end() && stored_column_id == *pruned_column_ids_iter) {
      ++pruned_column_ids_iter;
      continue;
    }

    (*_output_expressions)[output_column_id] =
        std::make_shared<LQPColumnExpression>(shared_from_this(), stored_column_id);
    ++output_column_id;
  }
}

}  // namespace hyrise
