/*=========================================================================
 *
 *  Copyright NumFOCUS
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *         https://www.apache.org/licenses/LICENSE-2.0.txt
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 *
 *=========================================================================*/


#include "itkFastMarchingExtensionImageFilterBase.h"
#include "itkFastMarchingThresholdStoppingCriterion.h"
#include "itkCommand.h"
#include "itkMath.h"


namespace
{
// Observer that prints the progress of a process object. It is attached to the
// filter below through an itk::SimpleMemberCommand listening for ProgressEvent.
class ShowProgressObject
{
public:
  // Stores a reference-counted handle to the process object being observed.
  // Explicit: constructing an observer from a filter pointer is never meant to
  // happen implicitly.
  explicit ShowProgressObject(itk::ProcessObject * o)
    : m_Process(o)
  {}

  // Callback run on each ProgressEvent; prints the observed object's progress.
  // Kept non-const because SimpleMemberCommand takes a pointer to a non-const
  // `void()` member function.
  void
  ShowProgress()
  {
    std::cout << "Progress " << m_Process->GetProgress() << std::endl;
  }

  itk::ProcessObject::Pointer m_Process;
};
} // namespace

int
itkFastMarchingExtensionImageFilterTest(int, char *[])
{
  // create a fastmarching object
  constexpr unsigned int Dimension = 2;
  using PixelType = float;

  using FloatImageType = itk::Image<PixelType, Dimension>;

  using CriterionType = itk::FastMarchingThresholdStoppingCriterion<FloatImageType, FloatImageType>;
  auto criterion = CriterionType::New();
  criterion->SetThreshold(100.);

  using MarcherType = itk::FastMarchingExtensionImageFilterBase<FloatImageType, FloatImageType, unsigned char, 1>;

  auto marcher = MarcherType::New();
  marcher->SetStoppingCriterion(criterion);

  ShowProgressObject                                    progressWatch(marcher);
  itk::SimpleMemberCommand<ShowProgressObject>::Pointer command;
  command = itk::SimpleMemberCommand<ShowProgressObject>::New();
  command->SetCallbackFunction(&progressWatch, &ShowProgressObject::ShowProgress);
  marcher->AddObserver(itk::ProgressEvent(), command);


  bool passed;

  // setup trial points
  using NodePairType = MarcherType::NodePairType;
  using NodePairContainerType = MarcherType::NodePairContainerType;

  // setup alive points
  auto AlivePoints = NodePairContainerType::New();

  const FloatImageType::OffsetType offset0 = { { 28, 35 } };

  itk::Index<2> index{};

  AlivePoints->push_back(NodePairType(index + offset0, 0.));

  index.Fill(200);
  AlivePoints->push_back(NodePairType(index, 42.));

  marcher->SetAlivePoints(AlivePoints);


  // setup trial points
  auto TrialPoints = NodePairContainerType::New();

  index.Fill(0);
  index += offset0;

  index[0] += 1;
  TrialPoints->push_back(NodePairType(index, 1.));

  index[0] -= 1;
  index[1] += 1;
  TrialPoints->push_back(NodePairType(index, 1.));

  index[0] -= 1;
  index[1] -= 1;
  TrialPoints->push_back(NodePairType(index, 1.));

  index[0] += 1;
  index[1] -= 1;
  TrialPoints->push_back(NodePairType(index, 1.));

  index.Fill(300); // this node is out of range
  TrialPoints->push_back(NodePairType(index, 42.));

  marcher->SetTrialPoints(TrialPoints);

  // specify the size of the output image
  const FloatImageType::SizeType size = { { 64, 64 } };
  marcher->SetOutputSize(size);

  // setup a speed image of ones
  auto                       speedImage = FloatImageType::New();
  FloatImageType::RegionType region;
  region.SetSize(size);
  speedImage->SetLargestPossibleRegion(region);
  speedImage->SetBufferedRegion(region);
  speedImage->Allocate();

  itk::ImageRegionIterator<FloatImageType> speedIter(speedImage, speedImage->GetBufferedRegion());
  while (!speedIter.IsAtEnd())
  {
    speedIter.Set(1.0);
    ++speedIter;
  }

  marcher->SetInput(speedImage);

  // deliberately cause an exception by not setting AuxAliveValues
  passed = false;
  try
  {
    marcher->Update();
  }
  catch (const itk::ExceptionObject & err)
  {
    passed = true;
    marcher->ResetPipeline();
    std::cout << err << std::endl;
  }
  if (!passed)
  {
    return EXIT_FAILURE;
  }

  using VectorType = MarcherType::AuxValueVectorType;
  using AuxValueContainerType = MarcherType::AuxValueContainerType;

  auto auxAliveValues = AuxValueContainerType::New();

  // deliberately cause an exception setting AuxAliveValues of the wrong size
  marcher->SetAuxiliaryAliveValues(auxAliveValues);

  passed = false;
  try
  {
    marcher->Update();
  }
  catch (const itk::ExceptionObject & err)
  {
    passed = true;
    marcher->ResetPipeline();
    std::cout << err << std::endl;
  }
  if (!passed)
  {
    return EXIT_FAILURE;
  }


  VectorType vector;
  vector[0] = 48;

  auxAliveValues->push_back(vector);
  auxAliveValues->push_back(vector);

  marcher->SetAuxiliaryAliveValues(auxAliveValues);

  // deliberately cause an exception by not setting AuxTrialValues
  passed = false;
  try
  {
    marcher->Update();
  }
  catch (const itk::ExceptionObject & err)
  {
    passed = true;
    marcher->ResetPipeline();
    std::cout << err << std::endl;
  }
  if (!passed)
  {
    return EXIT_FAILURE;
  }

  auto auxTrialValues = AuxValueContainerType::New();

  // deliberately cause an exception setting AuxTrialValues of the wrong size
  marcher->SetAuxiliaryTrialValues(auxTrialValues);

  passed = false;
  try
  {
    marcher->Update();
  }
  catch (const itk::ExceptionObject & err)
  {
    passed = true;
    marcher->ResetPipeline();
    std::cout << err << std::endl;
  }
  if (!passed)
  {
    return EXIT_FAILURE;
  }


  auxTrialValues->push_back(vector);
  auxTrialValues->push_back(vector);
  auxTrialValues->push_back(vector);
  auxTrialValues->push_back(vector);
  auxTrialValues->push_back(vector);

  marcher->SetAuxiliaryTrialValues(auxTrialValues);

  // run the algorithm
  passed = true;
  try
  {
    marcher->Update();
  }
  catch (const itk::ExceptionObject & err)
  {
    passed = false;
    marcher->ResetPipeline();
    std::cout << err << std::endl;
  }
  if (!passed)
  {
    return EXIT_FAILURE;
  }


  // check the results
  passed = true;
  const FloatImageType::Pointer            output = marcher->GetOutput();
  itk::ImageRegionIterator<FloatImageType> iterator(output, output->GetBufferedRegion());

  using AuxImageType = MarcherType::AuxImageType;
  const AuxImageType::Pointer            auxImage = marcher->GetAuxiliaryImage(0);
  itk::ImageRegionIterator<AuxImageType> auxIterator(auxImage, auxImage->GetBufferedRegion());

  while (!iterator.IsAtEnd())
  {
    FloatImageType::IndexType tempIndex;
    double                    distance;
    float                     outputValue;

    tempIndex = iterator.GetIndex();
    tempIndex -= offset0;
    distance = 0.0;
    for (int j = 0; j < 2; ++j)
    {
      distance += tempIndex[j] * tempIndex[j];
    }
    distance = std::sqrt(distance);

    outputValue = static_cast<float>(iterator.Get());

    if (itk::Math::NotAlmostEquals(distance, 0.0))
    {
      if (itk::Math::abs(outputValue) / distance > 1.42)
      {
        std::cout << iterator.GetIndex() << ' ';
        std::cout << itk::Math::abs(outputValue) / distance << ' ';
        std::cout << itk::Math::abs(outputValue) << ' ' << distance << std::endl;
        passed = false;
        break;
      }

      if (auxIterator.Get() != vector[0])
      {
        std::cout << auxIterator.GetIndex() << " got aux value of " << static_cast<double>(auxIterator.Get())
                  << " but it should be  " << static_cast<double>(vector[0]) << std::endl;
        passed = false;
        break;
      }
    }
    ++iterator;
    ++auxIterator;
  }

  // Exercise other member functions
  // std::cout << "Auxiliary alive values: " << marcher->GetAuxiliaryAliveValues();
  std::cout << std::endl;

  // std::cout << "Auxiliary trial values: " << marcher->GetAuxiliaryTrialValues();
  std::cout << std::endl;

  marcher->Print(std::cout);

  if (marcher->GetAuxiliaryImage(2))
  {
    std::cout << "GetAuxiliaryImage(2) should have returned nullptr";
    std::cout << std::endl;
    passed = false;
  }

  if (passed)
  {
    std::cout << "Fast Marching test passed" << std::endl;
    return EXIT_SUCCESS;
  }
  std::cout << "Fast Marching test failed" << std::endl;
  return EXIT_FAILURE;
}
