#!/System/Library/Frameworks/Ruby.framework/Versions/Current/usr/bin/ruby -w
require 'pathname'
require 'erb'

user_code = ""
setup, tests, benchmarks = [ ], [ ], [ ]

# Wrap each file named in ARGV in a namespace called after its basename (extension dropped),
# hoisting #include/#import lines above the namespace and recording setup/test/benchmark functions.
ARGV.each do |file|
  path = Pathname.new(file)
  ns = path.basename.sub(/\.[^.]+$/, '')
  includes, body, source = [ ], [ ], path.realpath
  path.each_line.with_index(1) do |line, lineno|
    if line =~ /^\s*#(include|import)\b/
      includes << "#line #{lineno} \"#{source}\"\n"
      includes << line.sub(/"(.*?)"/) { '"' + File.expand_path($1, path.parent) + '"' } # quoted paths become absolute
    else
      body << "#line #{lineno} \"#{source}\"\n"
      body << line
      case line
        when /^\s*void (setup(?:_\w+)?) \(\)/ then setup      << "#{ns}::#$1"
        when /^\s*void (test_\w+) \(\)/       then tests      << "#{ns}::#$1"
        when /^\s*void (benchmark_\w+) \(\)/  then benchmarks << "#{ns}::#$1"
      end
    end
  end
  user_code << "#{includes.join}namespace #{ns} {\n#{body.join}} /* #{ns} */\n"
end

# ERB.new takes trim_mode as a keyword since Ruby 2.6; the positional form is deprecated and warns under -w.
takes_keywords = ERB.instance_method(:initialize).parameters.assoc(:key)
STDOUT << (takes_keywords ? ERB.new(DATA.read, trim_mode: '-<>') : ERB.new(DATA.read, 0, '-<>')).result(binding)

__END__
#line 37 "<%= $PROGRAM_NAME %>"
// printf-style formatting into a std::string; returns an empty string when vasprintf() reports failure.
static std::string oak_format (char const* format, ...) __attribute__ ((format (printf, 1, 2)));
static std::string oak_format (char const* format, ...)
{
  char* tmp = NULL;
  va_list ap;
  va_start(ap, format);
  int const len = vasprintf(&tmp, format, ap);
  va_end(ap);
  if(len < 0) return std::string(); // on failure the buffer pointer must be neither read nor freed
  std::string res(tmp, static_cast<size_t>(len));
  free(tmp);
  return res;
}

std::string to_s (bool value)               { if(value) return "YES"; return "NO"; } // to_s(): text form of a value, used by the OAK_ASSERT_* messages
std::string to_s (char value)               { return oak_format("%c", value); }
std::string to_s (size_t value)             { return oak_format("%zu", value); }
std::string to_s (ssize_t value)            { return oak_format("%zd", value); }
std::string to_s (int16_t value)            { return oak_format("%d", value); }   // 16-bit values are promoted to int by the variadic call
std::string to_s (uint16_t value)           { return oak_format("%u", value); }
std::string to_s (int32_t value)            { return oak_format("%d", value); }
std::string to_s (uint32_t value)           { return oak_format("%u", value); }
std::string to_s (int64_t value)            { return oak_format("%lld", value); }
std::string to_s (uint64_t value)           { return oak_format("%llu", value); }
std::string to_s (double value)             { return oak_format("%g", value); }
std::string to_s (char const* value)        { if(value == NULL) return "«NULL»"; return oak_format("\"%s\"", value); }   // null pointer gets a marker, anything else is double-quoted
std::string to_s (std::string const& value) { if(value == NULL_STR) return "«NULL_STR»"; return "\"" + value + "\""; } // NULL_STR is declared outside this template

// Renders a std::pair as "pair<first, second>", formatting both members with to_s().
template <typename First, typename Second> std::string to_s (std::pair<First, Second> const& pair)
{
  return std::string("pair<") + to_s(pair.first) + ", " + to_s(pair.second) + ">";
}

// Formats anything usable in a range-for as "( a, b, )" via to_s() on each element ("( )" when empty).
template <typename T>
std::string to_s (T const& container)
{
  std::string res = "( ";
  for(auto const& element : container) // by reference: elements are only read, so no per-element copy
    res += to_s(element) + ", ";
  return res + ")";
}

void oak_warning (std::string const& message, char const* file, int line)
{
  fprintf(stderr, "%s:%d: %s\n", file, line, message.c_str()); // non-fatal: writes "file:line: message" to stderr and returns
}

struct oak_exception : std::exception // exception type carrying a message; what() returns that stored text
{
  oak_exception (std::string const& message) : _message(message) { }
  char const* what () const noexcept override { return _message.c_str(); } // override: the compiler checks this matches std::exception::what()
private:
  std::string _message; // owned copy, so the pointer returned by what() stays valid for the object's lifetime
};

[[noreturn]] void oak_assertion_error (std::string const& message, char const* file, int line)
{
  throw oak_exception(oak_format("%s:%d: ", file, line) + message); // always throws; the text is "file:line: " followed by message
}

std::string oak_format_bad_relation (char const* lhs, char const* op, char const* rhs, std::string const& realLHS, char const* realOp, std::string const& realRHS)
{
  return oak_format("Expected (%s %s %s), found (%s %s %s)", lhs, op, rhs, realLHS.c_str(), realOp, realRHS.c_str()); // first triple: the relation as asserted; second triple: the real* arguments
}

// OAK_TYPE_OF(x): type of expression x with any reference removed. OAK_WARN prints to stderr and continues; OAK_FAIL throws.
#define OAK_TYPE_OF(x)          std::remove_reference<decltype(x)>::type
#define OAK_WARN(msg)           oak_warning(msg, __FILE__, __LINE__)
#define OAK_FAIL(msg)           oak_assertion_error(msg, __FILE__, __LINE__)
// Assertions with generated messages. Each expands to one do { } while(false) statement, so it is safe in an un-braced if/else; the relational forms evaluate each operand once.
#define OAK_ASSERT(expr)        do { if(!(expr)) oak_assertion_error(oak_format("Assertion failed: %s", #expr), __FILE__, __LINE__); } while(false)
#define OAK_ASSERT_LT(lhs, rhs) do { OAK_TYPE_OF(lhs) _lhs = (lhs); OAK_TYPE_OF(rhs) _rhs = (rhs); if(!(_lhs <  _rhs)) oak_assertion_error(oak_format_bad_relation(#lhs, "<",  #rhs, to_s(_lhs), ">=", to_s(_rhs)), __FILE__, __LINE__); } while(false)
#define OAK_ASSERT_LE(lhs, rhs) do { OAK_TYPE_OF(lhs) _lhs = (lhs); OAK_TYPE_OF(rhs) _rhs = (rhs); if(!(_lhs <= _rhs)) oak_assertion_error(oak_format_bad_relation(#lhs, "<=", #rhs, to_s(_lhs), ">",  to_s(_rhs)), __FILE__, __LINE__); } while(false)
#define OAK_ASSERT_GT(lhs, rhs) do { OAK_TYPE_OF(lhs) _lhs = (lhs); OAK_TYPE_OF(rhs) _rhs = (rhs); if(!(_lhs >  _rhs)) oak_assertion_error(oak_format_bad_relation(#lhs, ">",  #rhs, to_s(_lhs), "<=", to_s(_rhs)), __FILE__, __LINE__); } while(false)
#define OAK_ASSERT_GE(lhs, rhs) do { OAK_TYPE_OF(lhs) _lhs = (lhs); OAK_TYPE_OF(rhs) _rhs = (rhs); if(!(_lhs >= _rhs)) oak_assertion_error(oak_format_bad_relation(#lhs, ">=", #rhs, to_s(_lhs), "<",  to_s(_rhs)), __FILE__, __LINE__); } while(false)
#define OAK_ASSERT_EQ(lhs, rhs) do { OAK_TYPE_OF(lhs) _lhs = (lhs); OAK_TYPE_OF(rhs) _rhs = (rhs); if(!(_lhs == _rhs)) oak_assertion_error(oak_format_bad_relation(#lhs, "==", #rhs, to_s(_lhs), "!=", to_s(_rhs)), __FILE__, __LINE__); } while(false)
#define OAK_ASSERT_NE(lhs, rhs) do { OAK_TYPE_OF(lhs) _lhs = (lhs); OAK_TYPE_OF(rhs) _rhs = (rhs); if(!(_lhs != _rhs)) oak_assertion_error(oak_format_bad_relation(#lhs, "!=", #rhs, to_s(_lhs), "==", to_s(_rhs)), __FILE__, __LINE__); } while(false)
// Same checks, but failing with the caller-supplied msg instead of a generated description.
#define OAK_MASSERT(msg, expr)        do { if(!(expr))           oak_assertion_error(msg, __FILE__, __LINE__); } while(false)
#define OAK_MASSERT_LT(msg, lhs, rhs) do { if(!((lhs) <  (rhs))) oak_assertion_error(msg, __FILE__, __LINE__); } while(false)
#define OAK_MASSERT_LE(msg, lhs, rhs) do { if(!((lhs) <= (rhs))) oak_assertion_error(msg, __FILE__, __LINE__); } while(false)
#define OAK_MASSERT_GT(msg, lhs, rhs) do { if(!((lhs) >  (rhs))) oak_assertion_error(msg, __FILE__, __LINE__); } while(false)
#define OAK_MASSERT_GE(msg, lhs, rhs) do { if(!((lhs) >= (rhs))) oak_assertion_error(msg, __FILE__, __LINE__); } while(false)
#define OAK_MASSERT_EQ(msg, lhs, rhs) do { if(!((lhs) == (rhs))) oak_assertion_error(msg, __FILE__, __LINE__); } while(false)
#define OAK_MASSERT_NE(msg, lhs, rhs) do { if(!((lhs) != (rhs))) oak_assertion_error(msg, __FILE__, __LINE__); } while(false)

<%= user_code %>
#line 128 "<%= $PROGRAM_NAME %>"

// Version number reported by the generated runner.
static double const AppVersion = 1.0;

// Writes "program-name version (timestamp)" and a newline to io. The timestamp is
// substituted into the format string when this template is rendered.
static void write_version (FILE* io)
{
  char const* const program = getprogname();
  fprintf(io, "%s %.1f (<%= Time.now %>)\n", program, AppVersion);
}

// Writes the version banner, a one-line synopsis and the option summary to io.
// The synopsis lists every option described in the summary below it.
static void write_usage (FILE* io)
{
  write_version(io);
  fprintf(io, "Usage: %s [-b|--benchmark] [-m|--measure] [-p|--[no-]parallel] [-r|--repeat <n>] [-v|--verbose] [-h|--help] [-V|--version]\n", getprogname());
  fprintf(io, "Options:\n");
  fprintf(io, " -b/--benchmark      Run benchmarks instead of tests.\n");
  fprintf(io, " -m/--measure        Measure time of each test.\n");
  fprintf(io, " -p/--[no-]parallel  Run tests in parallel (default).\n");
  fprintf(io, " -r/--repeat <n>     Number of times to repeat each test/benchmark.\n");
  fprintf(io, " -v/--verbose        Be verbose.\n");
  fprintf(io, " -h/--help           Show this help.\n");
  fprintf(io, " -V/--version        Show version number.\n");
}

int main (int argc, char const* argv[])
{
  static struct option const options[] =
  {
    { "benchmark",   no_argument,       0, 'b' },
    { "measure",     no_argument,       0, 'm' },
    { "parallel",    no_argument,       0, 'p' },
    { "no-parallel", no_argument,       0, 'P' },
    { "repeat",      required_argument, 0, 'r' },
    { "verbose",     no_argument,       0, 'v' },
    { "help",        no_argument,       0, 'h' },
    { "version",     no_argument,       0, 'V' },
    { 0,             0,                 0,   0 }
  };

  bool measure        = false;
  bool parallel       = true;
  bool run_benchmarks = false;
  bool verbose        = false;
  size_t repeat       = 1;

  unsigned int ch;
  while((ch = getopt_long(argc, (char* const*)argv, "bmpr:vhV", options, NULL)) != -1)
  {
    extern char* optarg;
    switch(ch)
    {
      case 'b': run_benchmarks = true;         break;
      case 'm': measure        = true;         break;
      case 'p': parallel       = true;         break;
      case 'P': parallel       = false;        break;
      case 'r': repeat         = atoi(optarg); break;
      case 'v': verbose        = true;         break;
      case 'h': write_usage(stdout);           return 0;
      case 'V': write_version(stdout);         return 0;
      case '?': /* unknown option */           return 1;
      case ':': /* missing option */           return 1;
      default:  write_usage(stderr);           return 1;
    }
  }

  <%= setup.map { |fun| "#{fun}();" }.join("\n  ") %>

  struct function_t
  {
    function_t (void (*f)(), std::string const& n) : run(f), name(n) { }

    void execute (bool measure, size_t repeat)
    {
      try {
        auto start = std::chrono::high_resolution_clock::now();
        for(size_t i = 0; i < repeat; ++i)
          run();

        if(measure)
        {
          auto end = std::chrono::high_resolution_clock::now();
          double duration = std::chrono::duration_cast<std::chrono::duration<double, std::nano>>(end-start).count() / std::chrono::duration_cast<std::chrono::duration<double, std::nano>>(std::chrono::seconds(1)).count();
          fprintf(stdout, "%.3fs %s\n", duration, name.c_str());
        }
      }
      catch(std::exception const& e) {
        message = e.what();
      }
    }

    void (*run)();
    std::string name;
    std::string message;
  };

  __block std::vector<function_t> benchmarks, tests;
  <%= benchmarks.map { |fun| "benchmarks.emplace_back(&#{fun}, \"#{fun}\");" }.join("\n  ") %>
  <%= tests.map      { |fun| "tests.emplace_back(&#{fun}, \"#{fun}\");" }.join("\n  ") %>

  dispatch_group_t group = dispatch_group_create();

  if(!run_benchmarks)
  {
    dispatch_group_async(group, dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_DEFAULT, 0), ^{
      if(parallel)
      {
        dispatch_apply(tests.size(), DISPATCH_APPLY_AUTO, ^(size_t n){
          tests[n].execute(measure, repeat);
        });
      }
      else
      {
        for(auto& test : tests)
          test.execute(measure, repeat);
      }
    });
  }

  __block bool hasMoreTests = true;

  dispatch_group_notify(group, dispatch_get_main_queue(), ^{
    std::vector<function_t> failed;
    std::copy_if(tests.begin(), tests.end(), back_inserter(failed), [](function_t const& t){ return !t.message.empty(); });

    if(!failed.empty())
    {
      fprintf(stderr, "%s: %zu of %zu %s failed:\n", getprogname(), failed.size(), tests.size(), tests.size() == 1 ? "test" : "tests");
      for(auto test : failed)
        fprintf(stderr, "%s\n", test.message.c_str());
      exit(1);
    }

    if(run_benchmarks)
    {
      for(auto& benchmark : benchmarks)
      {
        benchmark.execute(true, repeat);
        if(!benchmark.message.empty())
        {
          fprintf(stderr, "%s\n", benchmark.message.c_str());
          exit(1);
        }
      }

      if(benchmarks.empty())
        fprintf(stderr, "%s: warning: no benchmarks found\n", getprogname());
    }
    else if(verbose)
    {
      fprintf(stdout, "%s: %zu %s passed\n", getprogname(), tests.size(), tests.size() == 1 ? "test" : "tests");
    }
    hasMoreTests = false;
    CFRunLoopStop(CFRunLoopGetMain());
  });

  while(hasMoreTests)
    CFRunLoopRun();

  return 0;
}
