/* Copyright © 2017-2024 ABBYY

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

	http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
--------------------------------------------------------------------------------------------------------------*/

#pragma once

#include <CudaMathEngineDnnConvs.h>
#include <Kernels/CudaGrid.h>

namespace NeoML {

// Time convolution can be computed as a matrix multiplication once inputData is reordered into a temporary matrix
// whose i-th row contains the input values covered by the filter when computing the i-th row of the output.

// This kernel builds only a PART of that matrix: matrixPartHeight rows starting at row firstLineIndex.
// This is done because the full temporary matrix may require a lot of memory.

// Fills rows [firstLineIndex, firstLineIndex + matrixPartHeight) of the time-convolution temporary matrix
// into `matrix` (row-major, matrixWidth floats per row; row 0 of `matrix` is full-matrix row firstLineIndex).
// Launched over a 2D task grid as consumed by GetCudaTaskIndex2D( matrixPartHeight, matrixWidth, ... );
// every thread writes exactly one element of `matrix`.
// Full-matrix row r corresponds to output time step r / batchSize and batch position r % batchSize,
// where batchSize = Source.BatchWidth() * Source.ListSize().
// Column c corresponds to filter row c / Source.ObjectSize() and element c % Source.ObjectSize().
// Input time steps that fall into the front/back padding produce zeros.
// NOTE(review): matrixWidth is presumably Filter.Height() * Source.ObjectSize() — confirm at the call site.
__global__ void BuildTempMatrixKernel( const CCudaTimeConvolutionDescInternal desc,
	const float* __restrict__ input, int matrixPartHeight, int matrixWidth, float* __restrict__ matrix,
	int firstLineIndex )
{
	const int objectSize = desc.Source.ObjectSize();
	const int batchSize = desc.Source.BatchWidth() * desc.Source.ListSize();
	const int inBatchLen = desc.Source.BatchLength();
	const int stride = desc.Stride;
	const int padFront = desc.PaddingFront;
	const int dilation = desc.Dilation;

	int matrixRow = 0;
	int matrixCol = 0;
	if( !GetCudaTaskIndex2D( matrixPartHeight, matrixWidth, matrixRow, matrixCol ) ) {
		return;
	}

	// Row index in the full temporary matrix, split into (output time step, batch position)
	const int fullMatrixRowIndex = matrixRow + firstLineIndex;
	const int batch = fullMatrixRowIndex % batchSize;
	const int seqPos = fullMatrixRowIndex / batchSize;

	// Column index split into (filter row, element within the source object)
	const int elemIndex = matrixCol % objectSize;
	const int filterSeq = matrixCol / objectSize;

	// Input time step covered by this filter row; it may lie in the padding
	const int inputSeq = seqPos * stride - padFront + filterSeq * dilation;

	float value = 0.f;
	if( inputSeq >= 0 && inputSeq < inBatchLen ) {
		value = input[( inputSeq * batchSize + batch ) * objectSize + elemIndex];
	}
	matrix[matrixRow * matrixWidth + matrixCol] = value;
}

// Size of the per-thread accumulator array in BlobTimeConvolutionBackwardUnpackKernel,
// i.e. the largest number of elements a single thread of that kernel can accumulate.
const int BlobTimeConvolutionBackwardUnpackCombine = 64;

// Adds one part of the backward temporary matrix into inputDiffData (the existing values are accumulated into).
// `data` holds currPartHeight rows of the temporary matrix, starting at full-matrix row firstRowIndex.
// Each row consists of Filter.Height() slices of Source.ObjectSize() floats (slice filterRow starts at
// filterRow * objectSize). Full-matrix row r corresponds to output time step r / batchSize and batch
// position r % batchSize, where batchSize = Source.BatchWidth() * Source.ListSize().
// Launch layout: the grid's y dimension enumerates input objects (time step * batchSize + batch position);
// the object's elements are split among threads via GetCudaTaskCountAndIndex( objectSize, combineCount, ... ).
// NOTE(review): the element count per thread must not exceed BlobTimeConvolutionBackwardUnpackCombine or the
// local accumulator overflows — presumably guaranteed by combineCount at the call site; confirm.
// filterData and xSizeNorm are not used by this kernel.
__global__ void BlobTimeConvolutionBackwardUnpackKernel( const CCudaTimeConvolutionDescInternal desc, const float* filterData,
	float* inputDiffData, int xSizeNorm, int combineCount, const float* data, int firstRowIndex, int currPartHeight )
{
	const CCudaBlobDesc& source = desc.Source;

	const int objectIndex = blockIdx.y * blockDim.y + threadIdx.y;
	if( objectIndex >= source.ObjectCount() ) {
		return;
	}

	const int objectSize = source.ObjectSize();
	const int batchSize = source.BatchWidth() * source.ListSize();
	const int filterHeight = desc.Filter.Height();
	const int outputLength = desc.Result.BatchLength();
	const int seqPos = objectIndex / batchSize;
	const int batchPos = objectIndex % batchSize;

	int firstElem = 0;
	int elemStep = 0;
	const int elemCount = GetCudaTaskCountAndIndex( objectSize, combineCount, firstElem, elemStep );

	float acc[BlobTimeConvolutionBackwardUnpackCombine];
	for( int k = 0; k < elemCount; ++k ) {
		acc[k] = 0;
	}

	const int partEnd = firstRowIndex + currPartHeight;
	for( int filterRow = 0; filterRow < filterHeight; ++filterRow ) {
		// Input position measured from the start of the front padding
		const int paddedPos = seqPos - filterRow * desc.Dilation + desc.PaddingFront;
		if( paddedPos < 0 ) {
			break; // later filter rows only give smaller positions (assuming Dilation > 0)
		}
		if( paddedPos % desc.Stride != 0 ) {
			continue; // no output step places this filter row on seqPos
		}
		const int outPos = paddedPos / desc.Stride;
		if( outPos >= outputLength ) {
			continue;
		}
		const int tempRow = outPos * batchSize + batchPos;
		if( tempRow < firstRowIndex || tempRow >= partEnd ) {
			continue; // the row lives in another part of the temporary matrix
		}
		const float* const slice = data + ( ( tempRow - firstRowIndex ) * filterHeight + filterRow ) * objectSize;
		for( int k = 0, pos = firstElem; k < elemCount; ++k, pos += elemStep ) {
			acc[k] += __ldg( slice + pos );
		}
	}

	// objectIndex == seqPos * batchSize + batchPos, which is this object's row in inputDiffData
	float* const dst = inputDiffData + objectIndex * objectSize;
	for( int k = 0, pos = firstElem; k < elemCount; ++k, pos += elemStep ) {
		dst[pos] += acc[k];
	}
}

// Accumulates the time-convolution filter gradient into filterDiff (filterDiff[i] += gradient).
// One thread per filter element: a 1D task grid over desc.Filter.BlobSize(), as consumed by GetCudaTaskIndex.
// Filter element index = ( filterNum * filterHeight + filterRow ) * objectSize + filterChannel.
// For every output time step outL whose filterRow tap lands on an in-range input step inL, adds
// outputDiff[outL][b][filterNum] * input[inL][b][filterChannel] over all batch positions b.
// The batch dimension is Source.BatchWidth() * Source.ListSize(), the same convention the other
// time-convolution kernels in this file use for the input and the temporary-matrix rows.
// NOTE(review): assumes Filter.Channels() == Source.ObjectSize() and Filter.ObjectCount() equals the
// output object size — confirm against the descriptor setup.
__global__ void BlobTimeConvolutionLearnFilterKernel( CCudaTimeConvolutionDescInternal desc,
	const float* __restrict__ input, const float* __restrict__ outputDiff, float* filterDiff )
{
	const int objectSize = desc.Filter.Channels();
	const int filterHeight = desc.Filter.Height();
	const int filterCount = desc.Filter.ObjectCount();

	const int inputLength = desc.Source.BatchLength();
	const int outputLength = desc.Result.BatchLength();

	// Batch positions per time step; ListSize belongs to the batch dimension (previously only
	// BatchWidth was used, which skipped positions and mis-strided rows when ListSize > 1)
	const int batchSize = desc.Source.BatchWidth() * desc.Source.ListSize();

	int index = 0;
	if( !GetCudaTaskIndex( desc.Filter.BlobSize(), index ) ) {
		return;
	}
	filterDiff += index;

	const int filterChannel = index % objectSize;
	index /= objectSize;
	const int filterRow = index % filterHeight;
	const int filterNum = index / filterHeight;

	// Offset of this filter row's input step relative to outL * Stride
	const int rowOffset = filterRow * desc.Dilation - desc.PaddingFront;

	float res = 0;
	for( int outL = 0; outL < outputLength; ++outL ) {
		const int inL = outL * desc.Stride + rowOffset;
		if( inL < 0 || inL >= inputLength ) {
			continue; // this tap falls into the padding for this output step
		}

		const float* currOutputDiff = outputDiff + outL * batchSize * filterCount + filterNum;
		const float* currInput = input + inL * batchSize * objectSize + filterChannel;
		for( int b = 0; b < batchSize; ++b ) {
			res += *currOutputDiff * *currInput;
			currInput += objectSize;
			currOutputDiff += filterCount;
		}
	}

	*filterDiff += res;
}

} // namespace NeoML
