/* Copyright © 2017-2023 ABBYY

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

	http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
--------------------------------------------------------------------------------------------------------------*/

#pragma once

#include <NeoML/NeoMLDefs.h>
#include <NeoML/TraditionalML/FloatVector.h>
#include <NeoML/TraditionalML/ClassificationResult.h>
#include <NeoML/TraditionalML/TrainingModel.h>
#include <NeoML/TraditionalML/SvmKernel.h>

namespace NeoML {

// Declares the model name SvmBinaryModelName = "FmlSvmBinaryModel".
// NOTE(review): presumably the identifier under which ISvmBinaryModel implementations are
// registered/serialized — confirm against the DECLARE_NEOML_MODEL_NAME definition.
DECLARE_NEOML_MODEL_NAME( SvmBinaryModelName, "FmlSvmBinaryModel" )

// Support-vector machine binary classifier.
// This is the model type CSvm::Train returns for two classes when a kernel other than KT_Linear is used.
// NOTE: exported interface (NEOML_API) — the order of the virtual methods below defines its
// binary layout, so do not reorder them.
class NEOML_API ISvmBinaryModel : public IModel {
public:
	~ISvmBinaryModel() override;

	// Gets the kernel type the model was trained with
	virtual CSvmKernel::TKernelType GetKernelType() const = 0;

	// Gets the support vectors (returned by value).
	// Presumably one row per support vector, matching the entries of GetAlphas() — confirm in the implementation.
	virtual CSparseFloatMatrix GetVectors() const = 0;

	// Gets the support vector coefficients (one per support vector, presumably in GetVectors() row order)
	virtual const CArray<double>& GetAlphas() const = 0;

	// Gets the free term of the decision function (presumably the bias added to the kernel sum)
	virtual double GetFreeTerm() const = 0;
};

// Forward declaration: CSvm only stores a pointer to the thread pool,
// so the full definition is not required in this header
class IThreadPool;

// Binary SVM training algorithm.
// For more than two classes the problem is solved with the strategy selected by CParams::MulticlassMode.
class NEOML_API CSvm : public ITrainingModel {
public:
	// Classification parameters
	struct CParams final {
		CSvmKernel::TKernelType KernelType; // the kernel type
		double ErrorWeight; // the weight of the error relative to the regularization function
		int MaxIterations; // the maximum number of algorithm iterations
		int Degree; // the polynomial kernel degree (used for KT_Poly)
		double Gamma; // the coefficient before the kernel (used for KT_Poly, KT_RBF, KT_Sigmoid).
		double Coeff0; // the free term in the kernel (used for KT_Poly, KT_Sigmoid).
		double Tolerance; // the solution precision and the stop criterion
		bool DoShrinking; // do shrinking or not
		int ThreadCount; // the number of processing threads used
		TMulticlassMode MulticlassMode; // algorithm used for multiclass classification

		// Only the kernel type is mandatory; every other parameter has a default value.
		// Not marked explicit: a bare TKernelType converts implicitly to CParams,
		// and existing callers may rely on that, so keep it non-explicit.
		CParams( CSvmKernel::TKernelType kerneltype, double errorWeight = 1., int maxIterations = 10000,
				int degree = 1, double gamma = 1., double coeff0 = 1., double tolerance = 0.1,
				bool doShrinking = true, int threadCount = 1, TMulticlassMode multiclassMode = MM_OneVsAll ) :
			KernelType( kerneltype ),
			ErrorWeight( errorWeight ),
			MaxIterations( maxIterations ),
			Degree( degree ),
			Gamma( gamma ),
			Coeff0( coeff0 ),
			Tolerance( tolerance ),
			DoShrinking( doShrinking ),
			ThreadCount( threadCount ),
			MulticlassMode( multiclassMode )
		{}
		CParams( const CParams& params ) = default;
		// Declared explicitly: with a user-declared copy constructor the implicitly generated
		// copy assignment is deprecated (and triggers -Wdeprecated-copy).
		CParams& operator=( const CParams& params ) = default;
		// Copies params but replaces ThreadCount with realThreadCount
		// (presumably the size of the thread pool actually created).
		CParams( const CParams& params, int realThreadCount ) : CParams( params ) { ThreadCount = realThreadCount; }
	};

	explicit CSvm( const CParams& params );
	~CSvm() override;

	// threadPool is a raw pointer that ~CSvm is expected to release; a member-wise copy would
	// share it and release it twice. Copy assignment is already impossible because of the const
	// members; the copy constructor is deleted here as well.
	CSvm( const CSvm& ) = delete;
	CSvm& operator=( const CSvm& ) = delete;

	// Sets the text stream for logging processing (nullptr, the initial value, means no log).
	// Only the pointer is stored, so the stream must outlive its use by this object.
	void SetLog( CTextStream* newLog ) { log = newLog; }

	// ITrainingModel interface methods:
	// The resulting IModel is either a ILinearBinaryModel (if the KT_Linear kernel was used)
	// or a ISvmBinaryModel (if some other kernel was used) or params.MulticlassMode model (if number of classes > 2)
	CPtr<IModel> Train( const IProblem& trainingClassificationData ) override;

private:
	// NOTE: member order is kept as-is; params is declared after threadPool and may be initialized from it.
	IThreadPool* const threadPool; // Parallel executors
	const CParams params; // Classification parameters
	CTextStream* log = nullptr; // Logging stream (not owned)
};

// Legacy name of CSvm, kept only so that code written against the old name still compiles.
// DEPRECATED: new code should use CSvm directly.
using CSvmBinaryClassifierBuilder = CSvm;

} // namespace NeoML
