#include "saber/funcs/permute_power.h"

#include "saber/core/context.h"
#include "test/saber/test_saber_func.h"
#include "test/saber/test_saber_base.h"
#include "saber/core/tensor_op.h"
#include "saber/saber_types.h"
#include "saber/core/data_traits.h"

#include <vector>

using namespace anakin::saber;
/**
 * \brief Host-side reference for the fused permute + power op.
 *
 * Step 1 copies input[0] into output[0] with its axes reordered according to
 * param.permute_param.order, honouring the strides reported by both tensors.
 * Step 2 rewrites the first valid_size() elements of the output buffer in
 * place as pow(x * scale + shift, power).
 *
 * \param input   source tensors; only input[0] is read.
 * \param output  destination tensors; only output[0] is written.
 * \param param   carries permute_param.order and power_param.{power, scale, shift}.
 *
 * NOTE(review): step 2 indexes the output linearly (0 .. valid_size()-1), which
 * matches step 1 only when the output is densely packed - confirm for callers
 * that use padded tensors.
 */
template <typename dtype, typename TargetType_D, typename TargetType_H>
void permute_power_cpu_func(const std::vector<Tensor<TargetType_H>*>& input, std::vector<Tensor<TargetType_H>*>& output, PermutePowerParam<TargetType_D>& param) {
    const dtype* in_data = static_cast<const dtype*>(input[0]->data());
    dtype* out_data = static_cast<dtype*>(output[0]->mutable_data());

    // ---- permute ----
    std::vector<int> axis_order = param.permute_param.order;
    int count = output[0]->valid_size();
    int rank = input[0]->valid_shape().size();
    std::vector<int> out_strides = output[0]->get_stride();
    std::vector<int> in_strides = input[0]->get_stride();
    std::vector<int> out_dims = output[0]->valid_shape();

    for (int flat = 0; flat < count; ++flat) {
        int src_offset = 0;
        int dst_offset = 0;
        // Peel coordinates off the flat index, innermost axis first.
        int remaining = flat;
        for (int axis = rank; axis-- > 0;) {
            const int coord = remaining % out_dims[axis];
            remaining /= out_dims[axis];
            // Output axis `axis` is fed from input axis axis_order[axis].
            src_offset += coord * in_strides[axis_order[axis]];
            dst_offset += coord * out_strides[axis];
        }
        out_data[dst_offset] = in_data[src_offset];
    }

    // ---- power ----
    float exponent = param.power_param.power;
    float scale = param.power_param.scale;
    float shift = param.power_param.shift;

    int k = 0;
    while (k < count) {
        out_data[k] = pow(out_data[k] * scale + shift, exponent);
        ++k;
    }
}

template <typename TargetType_D, typename TargetType_H, DataType OpDtype>
void test_permute_power(){
    TestSaberBase<TargetType_D, TargetType_H, OpDtype, PermutePower, PermutePowerParam> testbase;
    typedef typename DataTrait<TargetType_H, OpDtype> :: Dtype dtype;
    
    for (int s0 : {0, 1, 2, 3}){
    for (int s1 : {0, 1, 2, 3}){
    for (int s2 : {0, 1, 2, 3}){
    for (int s3 : {0, 1, 2, 3}){
        if (s0 != s1 && s0 != s2 && s0 != s3 && s1 != s2 && s1 != s3 && s2 != s3){
            PermuteParam<TargetType_D> permute_param({s0, s1,s2, s3});

            std::vector<int> v_p = {0, 1, 2};
            std::vector<int> v_scale = {0.5, 1.0, 2.0};
            std::vector<int> v_shift = {0, 1, 2};
            // mlu test is too slow for now
            if (std::is_same<TargetType_D, MLU>::value) {
                v_p = {2};
                v_scale = {0.5};
                v_shift = {2};
            }

            for (float p : v_p) {
            for (float scale : v_scale) {
            for (float shift : v_shift) {
                PowerParam<TargetType_D> power_param(p, scale, shift);
                LOG(INFO)<<"permute:("<<s0<<","<<s1<<","<<s2<<","<<s3<<")";
                LOG(INFO)<<"power_param: p:"<<p<<" scale: "<<scale<<" shift:"<<shift;
                PermutePowerParam<TargetType_D> param(permute_param, power_param);

                std::vector<int> v_n = {1, 2};    std::vector<int> v_c = {1, 3};
                std::vector<int> v_h = {32, 64};  std::vector<int> v_w = {32, 64};
                // mlu test is too slow for now
                if (std::is_same<TargetType_D, MLU>::value) {
                    v_n = {2};   v_c = {3};
                    v_h = {32};  v_w = {64};
                }
                for (int n : v_n){
                for (int c : v_c){
                for (int h : v_h){
                for (int w : v_w){
                    testbase.set_param(param);
                    testbase.set_input_shape(Shape({n, c, h, w}));
                    if (std::is_same<TargetType_D, MLU>::value) {
                        testbase.set_rand_limit(1.0, 2.0);
                        testbase.run_test(permute_power_cpu_func<dtype, TargetType_D, TargetType_H>,
                                          0.02, true);
                    } else {
                        testbase.run_test(permute_power_cpu_func<dtype, TargetType_D, TargetType_H>);
                    }
                }
                }
                }
                }//for after permute_power_param
            }
            }
            } //for after power_param
        }
    }
    }
    }
    }
}

// Runs the permute_power sweep once for each backend enabled at build time.
// NOTE(review): the case is registered as "test_func_normalize" although it
// exercises permute_power - looks like a copy-paste leftover; the name is kept
// so existing test selection keeps working.
TEST(TestSaberFunc, test_func_normalize) {
#if defined(USE_CUDA)
    // NVIDIA device, x86 host side.
    test_permute_power<NV, NVHX86, AK_FLOAT>();
#endif
#if defined(USE_X86_PLACE)
    // Pure x86: device and host are the same target.
    test_permute_power<X86, X86, AK_FLOAT>();
#endif
#if defined(USE_MLU)
    // MLU device, x86 host side.
    test_permute_power<MLU, MLUHX86, AK_FLOAT>();
#endif
}

// Test binary entry point: initialises the test framework, then runs every
// registered test case. Always returns 0.
int main(int argc, const char** argv) {
    // Path the binary was invoked with; handed to the test runner below.
    const char* program_path = argv[0];

    // Logger initialisation is intentionally disabled:
    // logger::init(program_path);

    InitTest();
    RUN_ALL_TESTS(program_path);

    return 0;
}
