/* Copyright © 2017-2021 ABBYY Production LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

	http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
--------------------------------------------------------------------------------------------------------------*/

#include <common.h>
#pragma hdrstop

#include "PySplitLayer.h"

// Python-facing wrapper around a CSplitChannelsLayer that lives inside a network.
class CPySplitChannelsLayer : public CPyLayer {
public:
	explicit CPySplitChannelsLayer( CSplitChannelsLayer& splitLayer, CPyMathEngineOwner& owner ) :
		CPyLayer( splitLayer, owner )
	{
	}

	// Builds a neoml.Dnn.SplitChannels Python object that wraps this C++ wrapper.
	// The trailing 0 fills the Python constructor's second positional parameter
	// (presumably ignored when the first argument is already a wrapper — confirm in neoml.Dnn).
	py::object CreatePythonObject() const
	{
		py::object splitClass = py::module::import( "neoml.Dnn" ).attr( "SplitChannels" );
		return splitClass( py::cast( this ), 0 );
	}
};

// Python-facing wrapper around a CSplitDepthLayer that lives inside a network.
class CPySplitDepthLayer : public CPyLayer {
public:
	explicit CPySplitDepthLayer( CSplitDepthLayer& splitLayer, CPyMathEngineOwner& owner ) :
		CPyLayer( splitLayer, owner )
	{
	}

	// Builds a neoml.Dnn.SplitDepth Python object that wraps this C++ wrapper.
	// The trailing 0 fills the Python constructor's second positional parameter
	// (presumably ignored when the first argument is already a wrapper — confirm in neoml.Dnn).
	py::object CreatePythonObject() const
	{
		py::object splitClass = py::module::import( "neoml.Dnn" ).attr( "SplitDepth" );
		return splitClass( py::cast( this ), 0 );
	}
};

// Python-facing wrapper around a CSplitWidthLayer that lives inside a network.
class CPySplitWidthLayer : public CPyLayer {
public:
	explicit CPySplitWidthLayer( CSplitWidthLayer& splitLayer, CPyMathEngineOwner& owner ) :
		CPyLayer( splitLayer, owner )
	{
	}

	// Builds a neoml.Dnn.SplitWidth Python object that wraps this C++ wrapper.
	// The trailing 0 fills the Python constructor's second positional parameter
	// (presumably ignored when the first argument is already a wrapper — confirm in neoml.Dnn).
	py::object CreatePythonObject() const
	{
		py::object splitClass = py::module::import( "neoml.Dnn" ).attr( "SplitWidth" );
		return splitClass( py::cast( this ), 0 );
	}
};

// Python-facing wrapper around a CSplitHeightLayer that lives inside a network.
class CPySplitHeightLayer : public CPyLayer {
public:
	explicit CPySplitHeightLayer( CSplitHeightLayer& splitLayer, CPyMathEngineOwner& owner ) :
		CPyLayer( splitLayer, owner )
	{
	}

	// Builds a neoml.Dnn.SplitHeight Python object that wraps this C++ wrapper.
	// The trailing 0 fills the Python constructor's second positional parameter
	// (presumably ignored when the first argument is already a wrapper — confirm in neoml.Dnn).
	py::object CreatePythonObject() const
	{
		py::object splitClass = py::module::import( "neoml.Dnn" ).attr( "SplitHeight" );
		return splitClass( py::cast( this ), 0 );
	}
};

// Python-facing wrapper around a CSplitListSizeLayer that lives inside a network.
class CPySplitListSizeLayer : public CPyLayer {
public:
	explicit CPySplitListSizeLayer( CSplitListSizeLayer& splitLayer, CPyMathEngineOwner& owner ) :
		CPyLayer( splitLayer, owner )
	{
	}

	// Builds a neoml.Dnn.SplitListSize Python object that wraps this C++ wrapper.
	// The trailing 0 fills the Python constructor's second positional parameter
	// (presumably ignored when the first argument is already a wrapper — confirm in neoml.Dnn).
	py::object CreatePythonObject() const
	{
		py::object splitClass = py::module::import( "neoml.Dnn" ).attr( "SplitListSize" );
		return splitClass( py::cast( this ), 0 );
	}
};

// Python-facing wrapper around a CSplitBatchWidthLayer that lives inside a network.
class CPySplitBatchWidthLayer : public CPyLayer {
public:
	explicit CPySplitBatchWidthLayer( CSplitBatchWidthLayer& splitLayer, CPyMathEngineOwner& owner ) :
		CPyLayer( splitLayer, owner )
	{
	}

	// Builds a neoml.Dnn.SplitBatchWidth Python object that wraps this C++ wrapper.
	// The trailing 0 fills the Python constructor's second positional parameter
	// (presumably ignored when the first argument is already a wrapper — confirm in neoml.Dnn).
	py::object CreatePythonObject() const
	{
		py::object splitClass = py::module::import( "neoml.Dnn" ).attr( "SplitBatchWidth" );
		return splitClass( py::cast( this ), 0 );
	}
};

// Python-facing wrapper around a CSplitBatchLengthLayer that lives inside a network.
class CPySplitBatchLengthLayer : public CPyLayer {
public:
	explicit CPySplitBatchLengthLayer( CSplitBatchLengthLayer& splitLayer, CPyMathEngineOwner& owner ) :
		CPyLayer( splitLayer, owner )
	{
	}

	// Builds a neoml.Dnn.SplitBatchLength Python object that wraps this C++ wrapper.
	// The trailing 0 fills the Python constructor's second positional parameter
	// (presumably ignored when the first argument is already a wrapper — confirm in neoml.Dnn).
	py::object CreatePythonObject() const
	{
		py::object splitClass = py::module::import( "neoml.Dnn" ).attr( "SplitBatchLength" );
		return splitClass( py::cast( this ), 0 );
	}
};

void InitializeSplitLayer( py::module& m )
{
	py::class_<CPySplitChannelsLayer, CPyLayer>(m, "SplitChannels")
		.def( py::init([]( const CPyLayer& layer )
		{
			return new CPySplitChannelsLayer( *layer.Layer<CSplitChannelsLayer>(), layer.MathEngineOwner() );
		}))
		.def( py::init([]( const std::string& name, const CPyLayer& layer1, int outputNumber1, py::array sizes ) {
			py::gil_scoped_release release;
			CDnn& dnn = layer1.Dnn();
			IMathEngine& mathEngine = dnn.GetMathEngine();
			CPtr<CSplitChannelsLayer> split = new CSplitChannelsLayer( mathEngine );
			CArray<int> outputs;
			outputs.SetSize(static_cast<int>(sizes.size()));
			for( int i = 0; i < outputs.Size(); i++ ) {
				outputs[i] = reinterpret_cast<const int*>(sizes.data())[i];
			}
			split->SetOutputCounts(outputs);
			split->SetName( FindFreeLayerName( dnn, "SplitChannels", name ).c_str() );
			dnn.AddLayer( *split );
			split->Connect( 0, layer1.BaseLayer(), outputNumber1 );
			return new CPySplitChannelsLayer( *split, layer1.MathEngineOwner() );
		}) )
	;

	py::class_<CPySplitDepthLayer, CPyLayer>(m, "SplitDepth")
		.def( py::init([]( const CPyLayer& layer )
		{
			return new CPySplitDepthLayer( *layer.Layer<CSplitDepthLayer>(), layer.MathEngineOwner() );
		}))
		.def( py::init([]( const std::string& name, const CPyLayer& layer1, int outputNumber1, py::array sizes ) {
			py::gil_scoped_release release;
			CDnn& dnn = layer1.Dnn();
			IMathEngine& mathEngine = dnn.GetMathEngine();
			CPtr<CSplitDepthLayer> split = new CSplitDepthLayer( mathEngine );
			CArray<int> outputs;
			outputs.SetSize(static_cast<int>(sizes.size()));
			for( int i = 0; i < outputs.Size(); i++ ) {
				outputs[i] = reinterpret_cast<const int*>(sizes.data())[i];
			}
			split->SetOutputCounts(outputs);
			split->SetName( FindFreeLayerName( dnn, "SplitDepth", name ).c_str() );
			dnn.AddLayer( *split );
			split->Connect( 0, layer1.BaseLayer(), outputNumber1 );
			return new CPySplitDepthLayer( *split, layer1.MathEngineOwner() );
		}) )
	;

	py::class_<CPySplitWidthLayer, CPyLayer>(m, "SplitWidth")
		.def( py::init([]( const CPyLayer& layer )
		{
			return new CPySplitWidthLayer( *layer.Layer<CSplitWidthLayer>(), layer.MathEngineOwner() );
		}))
		.def( py::init([]( const std::string& name, const CPyLayer& layer1, int outputNumber1, py::array sizes ) {
			py::gil_scoped_release release;
			CDnn& dnn = layer1.Dnn();
			IMathEngine& mathEngine = dnn.GetMathEngine();
			CPtr<CSplitWidthLayer> split = new CSplitWidthLayer( mathEngine );
			CArray<int> outputs;
			outputs.SetSize(static_cast<int>(sizes.size()));
			for( int i = 0; i < outputs.Size(); i++ ) {
				outputs[i] = reinterpret_cast<const int*>(sizes.data())[i];
			}
			split->SetOutputCounts(outputs);
			split->SetName( FindFreeLayerName( dnn, "SplitWidth", name ).c_str() );
			dnn.AddLayer( *split );
			split->Connect( 0, layer1.BaseLayer(), outputNumber1 );
			return new CPySplitWidthLayer( *split, layer1.MathEngineOwner() );
		}) )
	;

	py::class_<CPySplitHeightLayer, CPyLayer>(m, "SplitHeight")
		.def( py::init([]( const CPyLayer& layer )
		{
			return new CPySplitHeightLayer( *layer.Layer<CSplitHeightLayer>(), layer.MathEngineOwner() );
		}))
		.def( py::init([]( const std::string& name, const CPyLayer& layer1, int outputNumber1, py::array sizes ) {
			py::gil_scoped_release release;
			CDnn& dnn = layer1.Dnn();
			IMathEngine& mathEngine = dnn.GetMathEngine();
			CPtr<CSplitHeightLayer> split = new CSplitHeightLayer( mathEngine );
			CArray<int> outputs;
			outputs.SetSize(static_cast<int>(sizes.size()));
			for( int i = 0; i < outputs.Size(); i++ ) {
				outputs[i] = reinterpret_cast<const int*>(sizes.data())[i];
			}
			split->SetOutputCounts(outputs);
			split->SetName( FindFreeLayerName( dnn, "SplitHeight", name ).c_str() );
			dnn.AddLayer( *split );
			split->Connect( 0, layer1.BaseLayer(), outputNumber1 );
			return new CPySplitHeightLayer( *split, layer1.MathEngineOwner() );
		}) )
	;

	py::class_<CPySplitListSizeLayer, CPyLayer>(m, "SplitListSize")
		.def( py::init([]( const CPyLayer& layer )
		{
			return new CPySplitListSizeLayer( *layer.Layer<CSplitListSizeLayer>(), layer.MathEngineOwner() );
		}))
		.def( py::init([]( const std::string& name, const CPyLayer& layer1, int outputNumber1, py::array sizes ) {
			py::gil_scoped_release release;
			CDnn& dnn = layer1.Dnn();
			IMathEngine& mathEngine = dnn.GetMathEngine();
			CPtr<CSplitListSizeLayer> split = new CSplitListSizeLayer( mathEngine );
			CArray<int> outputs;
			outputs.SetSize(static_cast<int>(sizes.size()));
			for( int i = 0; i < outputs.Size(); i++ ) {
				outputs[i] = reinterpret_cast<const int*>(sizes.data())[i];
			}
			split->SetOutputCounts(outputs);
			split->SetName( FindFreeLayerName( dnn, "SplitListSize", name ).c_str() );
			dnn.AddLayer( *split );
			split->Connect( 0, layer1.BaseLayer(), outputNumber1 );
			return new CPySplitListSizeLayer( *split, layer1.MathEngineOwner() );
		}) )
	;

	py::class_<CPySplitBatchWidthLayer, CPyLayer>(m, "SplitBatchWidth")
		.def( py::init([]( const CPyLayer& layer )
		{
			return new CPySplitBatchWidthLayer( *layer.Layer<CSplitBatchWidthLayer>(), layer.MathEngineOwner() );
		}))
		.def( py::init([]( const std::string& name, const CPyLayer& layer1, int outputNumber1, py::array sizes ) {
			py::gil_scoped_release release;
			CDnn& dnn = layer1.Dnn();
			IMathEngine& mathEngine = dnn.GetMathEngine();
			CPtr<CSplitBatchWidthLayer> split = new CSplitBatchWidthLayer( mathEngine );
			CArray<int> outputs;
			outputs.SetSize(static_cast<int>(sizes.size()));
			for( int i = 0; i < outputs.Size(); i++ ) {
				outputs[i] = reinterpret_cast<const int*>(sizes.data())[i];
			}
			split->SetOutputCounts(outputs);
			split->SetName( FindFreeLayerName( dnn, "SplitBatchWidth", name ).c_str() );
			dnn.AddLayer( *split );
			split->Connect( 0, layer1.BaseLayer(), outputNumber1 );
			return new CPySplitBatchWidthLayer( *split, layer1.MathEngineOwner() );
		}) )
	;

	py::class_<CPySplitBatchLengthLayer, CPyLayer>(m, "SplitBatchLength")
		.def( py::init([]( const CPyLayer& layer )
		{
			return new CPySplitBatchLengthLayer( *layer.Layer<CSplitBatchLengthLayer>(), layer.MathEngineOwner() );
		}))
		.def( py::init([]( const std::string& name, const CPyLayer& layer1, int outputNumber1, py::array sizes ) {
			py::gil_scoped_release release;
			CDnn& dnn = layer1.Dnn();
			IMathEngine& mathEngine = dnn.GetMathEngine();
			CPtr<CSplitBatchLengthLayer> split = new CSplitBatchLengthLayer( mathEngine );
			CArray<int> outputs;
			outputs.SetSize(static_cast<int>(sizes.size()));
			for( int i = 0; i < outputs.Size(); i++ ) {
				outputs[i] = reinterpret_cast<const int*>(sizes.data())[i];
			}
			split->SetOutputCounts(outputs);
			split->SetName( FindFreeLayerName( dnn, "SplitBatchLength", name ).c_str() );
			dnn.AddLayer( *split );
			split->Connect( 0, layer1.BaseLayer(), outputNumber1 );
			return new CPySplitBatchLengthLayer( *split, layer1.MathEngineOwner() );
		}) )
	;

}
