#pragma once

#include <functional>
#include <optional>

#include "ctranslate2/decoding_utils.h"
#include "ctranslate2/devices.h"
#include "ctranslate2/layers/decoder.h"
#include "ctranslate2/sampling.h"
#include "ctranslate2/storage_view.h"

namespace ctranslate2 {

  // Result of decoding, as returned by SearchStrategy::search and decode().
  // NOTE(review): presumably one DecodingResult per batch example, with one
  // entry per hypothesis in each field; optional fields are presumably left
  // empty when not requested — confirm against the implementation.
  struct DecodingResult {
    // Generated token ids, one vector per hypothesis.
    std::vector<std::vector<size_t>> hypotheses{};
    // Score of each hypothesis.
    std::vector<float> scores{};
    // Attention weights per hypothesis (inner layout presumably step x source
    // position — TODO confirm).
    std::vector<std::vector<std::vector<float>>> attention{};
    // Vocabulary logits per hypothesis (presumably one StorageView per step —
    // TODO confirm).
    std::vector<std::vector<StorageView>> logits_vocab{};
  };

  // Information about one decoding step, handed to the optional step callback
  // (std::function<bool(DecodingStepResult)>) used by GreedySearch and
  // DecodingOptions::callback.
  //
  // All numeric fields are zero-initialized so that a default-constructed
  // instance never carries indeterminate values; aggregate initialization
  // with explicit values is unaffected.
  struct DecodingStepResult {
    // Index of the decoding step.
    size_t step = 0;
    // Index of the batch example this step belongs to.
    size_t batch_id = 0;
    // Token id produced at this step.
    size_t token_id = 0;
    // Index of the hypothesis that produced the token.
    size_t hypothesis_id = 0;
    // Token score; presumably only set when scores are requested — confirm.
    std::optional<float> score;
    // Step logits; presumably only set when vocabulary logits are requested —
    // confirm.
    std::optional<StorageView> logits;
    // Presumably true for the final step of this hypothesis — confirm.
    bool is_last = false;
  };


  // Abstract interface for a decoding algorithm (implemented below by
  // BeamSearch and GreedySearch).
  class SearchStrategy {
  public:
    virtual ~SearchStrategy() = default;

    // Runs the search over the decoder and returns the decoding results
    // (presumably one per entry of start_ids — confirm).
    //
    // NOTE(review): default arguments of a virtual function are chosen from the
    // static type at the call site, so every override must repeat exactly these
    // defaults to avoid surprising behavior.
    virtual std::vector<DecodingResult>
    search(layers::Decoder& decoder,
           layers::DecoderState& state,
           const Sampler& sampler,
           const std::vector<size_t>& start_ids,
           const std::vector<size_t>& end_ids,
           dim_t start_step,
           dim_t max_length,
           dim_t min_length,
           bool return_scores = false,
           bool return_attention = false,
           bool return_logits_vocab = true,
           bool return_prefix = true,
           size_t num_hypotheses = 1,
           bool include_eos_in_hypotheses = true,
           const std::vector<std::shared_ptr<LogitsProcessor>>& logits_processors = {},
           const std::vector<std::vector<size_t>>* prefix_ids = nullptr) const = 0;
  };

  class BeamSearch : public SearchStrategy {
  public:
    BeamSearch(const dim_t beam_size,
               const float length_penalty = 0,
               const float coverage_penalty = 0,
               const float prefix_bias_beta = 0,
               const float patience = 1);

    std::vector<DecodingResult>
    search(layers::Decoder& decoder,
           layers::DecoderState& state,
           const Sampler& sampler,
           const std::vector<size_t>& start_ids,
           const std::vector<size_t>& end_ids,
           const dim_t start_step,
           const dim_t max_length,
           const dim_t min_length,
           const bool return_scores = false,
           const bool return_attention = false,
           const bool return_logits_vocab = true,
           const bool return_prefix = true,
           const size_t num_hypotheses = 1,
           const bool include_eos_in_hypotheses = true,
           const std::vector<std::shared_ptr<LogitsProcessor>>& logits_processors = {},
           const std::vector<std::vector<size_t>>* prefix_ids = nullptr) const override;

  private:
    const dim_t _beam_size;
    const float _length_penalty;
    const float _coverage_penalty;
    const float _prefix_bias_beta;
    const size_t _max_candidates;
  };

  class BiasedDecoder {
  public:
    BiasedDecoder(const float prefix_bias_beta,
                  const std::vector<std::vector<size_t>>& prefix_ids);

    void
    decode(const dim_t cur_batch_size,
           const size_t step,
           const std::vector<dim_t>& batch_offset,
           const std::vector<std::vector<bool>>& beams_diverged_from_prefix,
           const StorageView& logits,
           StorageView& log_probs);
  private:
    StorageView _spare_beam;
    const float _prefix_bias_beta;
    std::vector<std::vector<size_t>> _prefix_ids;
  };


  // Greedy decoding strategy.
  class GreedySearch : public SearchStrategy {
  public:
    // Penalties are applied only to the returned scores, so that those scores
    // are consistent with the ones produced by beam search.
    // callback: optional function invoked with a DecodingStepResult; the meaning
    //   of its bool return value is defined by the implementation (presumably
    //   "stop decoding" — confirm).
    GreedySearch(const float length_penalty = 0,
                 const float coverage_penalty = 0,
                 std::function<bool(DecodingStepResult)> callback = nullptr);

    // Same contract and defaults as SearchStrategy::search. The parameter is
    // named `end_ids` to match the base class: it is a list of end token ids,
    // not a single id.
    std::vector<DecodingResult>
    search(layers::Decoder& decoder,
           layers::DecoderState& state,
           const Sampler& sampler,
           const std::vector<size_t>& start_ids,
           const std::vector<size_t>& end_ids,
           const dim_t start_step,
           const dim_t max_length,
           const dim_t min_length,
           const bool return_scores = false,
           const bool return_attention = false,
           const bool return_logits_vocab = true,
           const bool return_prefix = true,
           const size_t num_hypotheses = 1,
           const bool include_eos_in_hypotheses = true,
           const std::vector<std::shared_ptr<LogitsProcessor>>& logits_processors = {},
           const std::vector<std::vector<size_t>>* prefix_ids = nullptr) const override;

  private:
    const float _length_penalty;
    const float _coverage_penalty;
    const std::function<bool(DecodingStepResult)> _callback;
  };


  struct DecodingOptions {
    size_t beam_size = 1;
    float patience = 1;
    float length_penalty = 0;
    float coverage_penalty = 0;
    float repetition_penalty = 1;
    size_t no_repeat_ngram_size = 0;
    float prefix_bias_beta = 0;
    dim_t start_step = 0;
    size_t max_length = 256;
    size_t min_length = 0;
    size_t sampling_topk = 1;
    float sampling_topp = 1;
    float sampling_temperature = 1;
    size_t num_hypotheses = 1;
    bool include_eos_in_hypotheses = true;
    bool return_scores = false;
    bool return_attention = false;
    bool return_logits_vocab = false;
    bool return_alternatives = false;
    bool return_prefix = true;
    float min_alternative_expansion_prob = 0;
    std::vector<size_t> disable_ids;
    std::vector<size_t> disable_ids_begin;
    std::vector<std::vector<size_t>> disable_sequences;
    std::vector<std::shared_ptr<LogitsProcessor>> logits_processors;
    std::function<bool(DecodingStepResult)> callback = nullptr;
  };

  std::vector<DecodingResult>
  decode(layers::Decoder& decoder,
         layers::DecoderState& state,
         std::vector<std::vector<size_t>> start_tokens,
         std::vector<size_t> end_ids,
         DecodingOptions options = DecodingOptions());

}
