"""
An example that uses TensorRT's Python API to run YOLOv5 inference on a directory of images.
"""
import ctypes
import os
import random
import sys
import threading
import time

import cv2
import numpy as np
import pycuda.autoinit
import pycuda.driver as cuda
import tensorrt as trt
import torch
import torchvision

# Network input size in pixels; images are letterboxed to this size.
INPUT_W = 640
INPUT_H = 640
# Candidate boxes whose confidence is at or below this are dropped before NMS.
CONF_THRESH = 0.25
# IoU threshold passed to NMS.
IOU_THRESHOLD = 0.45

# A detection must score above this to be drawn on the output image.
PROB_THRESH = 0.65

# Class id -> label.  Ids 0-16 correspond to the codes A-G and N1-N10; every
# class maps to "normal" except F (5) and G (6), the two cancer classes.
id2label = dict.fromkeys(range(17), "normal")
id2label[5] = "early_esophageal_cancer"  # F
id2label[6] = "early_gastric_cancer"  # G

# Drawing helper: draw a bounding box with an optional label.
def plot_one_box(x, img, color=None, label=None, line_thickness=None):
    """
    description: Draw one bounding box on ``img`` in place, optionally with a
                 filled label tag above its top-left corner.  Adapted from the
                 YOLOv5 project.
    param:
        x:              box corners [x1, y1, x2, y2]; each value is cast to int
        img:            OpenCV image (numpy array), modified in place
        color:          color for the box and label background, e.g. (0, 255, 0);
                        a random color is used when it is falsy
        label:          optional text to draw; no tag is drawn when falsy
        line_thickness: int box thickness; derived from the image size when falsy
    return:
        no return
    """
    thickness = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1
    if not color:
        color = [random.randint(0, 255) for _ in range(3)]

    top_left = (int(x[0]), int(x[1]))
    bottom_right = (int(x[2]), int(x[3]))
    cv2.rectangle(img, top_left, bottom_right, color, thickness=thickness, lineType=cv2.LINE_AA)
    if not label:
        return

    font_scale = thickness / 3
    font_thickness = max(thickness - 1, 1)
    (text_w, text_h), _ = cv2.getTextSize(label, 0, fontScale=font_scale, thickness=font_thickness)
    # Filled background tag sitting just above the box's top-left corner.
    tag_corner = (top_left[0] + text_w, top_left[1] - text_h - 3)
    cv2.rectangle(img, top_left, tag_corner, color, -1, cv2.LINE_AA)
    cv2.putText(
        img,
        label,
        (top_left[0], top_left[1] - 2),
        0,
        font_scale,
        [225, 255, 255],
        thickness=font_thickness,
        lineType=cv2.LINE_AA,
    )


class YoLov5TRT(object):
    """
    description: A YOLOv5 class that wraps TensorRT ops, preprocess and postprocess ops.

    The CUDA context, stream, engine, execution context and the host/device
    buffers are created once in ``__init__`` and reused by every ``infer``
    call.  Call ``destroy`` when finished to pop the CUDA context.
    """

    def __init__(self, engine_file_path):
        """
        description: Create a CUDA context on device 0, deserialize a TensorRT
                     engine and allocate one page-locked host buffer and one
                     device buffer per engine binding.
        param:
            engine_file_path: str, path to a serialized TensorRT engine
        raise:
            RuntimeError: the engine could not be deserialized.
            Any other error (file I/O, CUDA allocation) is re-raised after the
            CUDA context created here has been popped again.
        """
        # make_context() creates a context on device 0 and makes it current.
        self.cfx = cuda.Device(0).make_context()
        try:
            stream = cuda.Stream()
            trt_logger = trt.Logger(trt.Logger.INFO)
            runtime = trt.Runtime(trt_logger)

            # Read the serialized engine.
            with open(engine_file_path, "rb") as f:
                engine = runtime.deserialize_cuda_engine(f.read())
            if engine is None:
                # deserialize_cuda_engine reports failure by returning None.
                raise RuntimeError(
                    "Failed to deserialize TensorRT engine: {}".format(engine_file_path)
                )
            context = engine.create_execution_context()

            host_inputs = []
            cuda_inputs = []
            host_outputs = []
            cuda_outputs = []
            bindings = []

            for binding in engine:
                size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
                dtype = trt.nptype(engine.get_binding_dtype(binding))
                # Allocate a page-locked host buffer and a device buffer of equal size.
                host_mem = cuda.pagelocked_empty(size, dtype)
                cuda_mem = cuda.mem_alloc(host_mem.nbytes)
                # Device pointers, in binding order, are what execute_async consumes.
                bindings.append(int(cuda_mem))
                if engine.binding_is_input(binding):
                    host_inputs.append(host_mem)
                    cuda_inputs.append(cuda_mem)
                else:
                    host_outputs.append(host_mem)
                    cuda_outputs.append(cuda_mem)
        except Exception:
            # Don't leave a half-initialized context on this thread's context stack.
            self.cfx.pop()
            raise

        # Keep the logger and runtime referenced for the engine's lifetime
        # (precaution: TensorRT objects should not outlive their creators).
        self.trt_logger = trt_logger
        self.runtime = runtime
        self.stream = stream
        self.context = context
        self.engine = engine
        self.host_inputs = host_inputs
        self.cuda_inputs = cuda_inputs
        self.host_outputs = host_outputs
        self.cuda_outputs = cuda_outputs
        self.bindings = bindings

    def infer(self, input_image_path):
        """
        description: Run detection on one image and write an annotated copy to
                     ``detect_res/<file name>``.  Only detections scoring above
                     PROB_THRESH with class id 5 or 6 are drawn.
        param:
            input_image_path: str, image path
        return:
            no return
        """
        # Preprocessing is CPU-only, so it runs before the CUDA context is pushed.
        input_image, image_raw, origin_h, origin_w = self.preprocess_image_0(
            input_image_path
        )

        stream = self.stream
        host_inputs = self.host_inputs
        host_outputs = self.host_outputs

        # Make our context current for this thread; always pop it again, even
        # if a copy or the inference itself raises.
        self.cfx.push()
        try:
            # Copy the input into the page-locked host buffer, then to the GPU.
            np.copyto(host_inputs[0], input_image.ravel())
            cuda.memcpy_htod_async(self.cuda_inputs[0], host_inputs[0], stream)
            start = time.time()
            self.context.execute_async(bindings=self.bindings, stream_handle=stream.handle)
            # Transfer predictions back from the GPU and wait for completion.
            cuda.memcpy_dtoh_async(host_outputs[0], self.cuda_outputs[0], stream)
            stream.synchronize()
        finally:
            self.cfx.pop()
        end = time.time()

        # Batch size is 1, so the first output buffer holds this image's result.
        output = host_outputs[0]
        print(output.shape)

        result_boxes, result_scores, result_classid = self.post_process(
            output, origin_h, origin_w
        )
        print("waste_time: {}".format(end - start))

        # Draw only confident detections of classes 5 and 6.
        for box, score, classid in zip(result_boxes, result_scores, result_classid):
            cls_id = int(classid)
            if score <= PROB_THRESH or cls_id not in (5, 6):
                continue
            plot_one_box(
                box,
                image_raw,
                label="{}:{:.2f}".format(id2label[cls_id], score),
            )

        os.makedirs("detect_res", exist_ok=True)
        save_name = os.path.join("detect_res", os.path.basename(input_image_path))
        # Save the annotated image.
        cv2.imwrite(save_name, image_raw)

    def destroy(self):
        """Pop the CUDA context created in ``__init__`` off the context stack."""
        self.cfx.pop()

    @staticmethod
    def _read_image(input_image_path):
        """Read an image with OpenCV, raising instead of returning None on failure."""
        image_raw = cv2.imread(input_image_path)
        if image_raw is None:
            raise OSError("Could not read image: {}".format(input_image_path))
        return image_raw

    @staticmethod
    def _letterbox(image):
        """
        description: Resize ``image`` to fit INPUT_W x INPUT_H, keeping its
                     aspect ratio, then pad the short side symmetrically with
                     (114, 114, 114).  When the padding is odd, the extra pixel
                     goes to the bottom/right side.
        """
        h, w = image.shape[:2]
        r_w = INPUT_W / w
        r_h = INPUT_H / h
        if r_h > r_w:
            # Width is the limiting side: fit it exactly and scale height by r_w.
            tw = INPUT_W
            th = int(r_w * h)
            tx1 = tx2 = 0
            ty1 = int((INPUT_H - th) / 2)
            ty2 = INPUT_H - th - ty1
        else:
            # Height is the limiting side: fit it exactly and scale width by r_h.
            tw = int(r_h * w)
            th = INPUT_H
            tx1 = int((INPUT_W - tw) / 2)
            tx2 = INPUT_W - tw - tx1
            ty1 = ty2 = 0
        image = cv2.resize(image, (tw, th), interpolation=cv2.INTER_LINEAR)
        # `value` must be passed by keyword: the 7th positional parameter of
        # cv2.copyMakeBorder is `dst`, so a positional color tuple would be
        # taken as the destination and the border would default to black.
        return cv2.copyMakeBorder(
            image, ty1, ty2, tx1, tx2, cv2.BORDER_CONSTANT, value=(114, 114, 114)
        )

    @staticmethod
    def _to_nchw(image):
        """Convert an HWC uint8 image to a contiguous float32 NCHW array in [0, 1]."""
        image = image.astype(np.float32)
        image /= 255.0
        image = np.transpose(image, [2, 0, 1])  # HWC -> CHW
        image = np.expand_dims(image, axis=0)  # CHW -> NCHW
        # Row-major ("C order") memory so ravel() yields the layout the engine reads.
        return np.ascontiguousarray(image)

    def preprocess_image(self, input_image_path):
        """
        description: Read an image from image path, letterbox it (aspect-ratio
                     resize plus (114, 114, 114) padding) to INPUT_W x INPUT_H,
                     convert it to RGB, normalize to [0, 1] and transform to
                     NCHW format.  The output always has the full network
                     input size, so it fits the engine buffer and matches the
                     offsets assumed by ``xywh2xyxy``.
        param:
            input_image_path: str, image path
        return:
            image:  the processed image
            image_raw: the original image
            h: original height
            w: original width
        raise:
            OSError: the image could not be read.
        """
        image_raw = self._read_image(input_image_path)
        h, w = image_raw.shape[:2]
        image = self._letterbox(image_raw)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return self._to_nchw(image), image_raw, h, w

    def preprocess_image_0(self, input_image_path):
        """
        description: Read an image from image path, convert it to RGB,
                     letterbox it (aspect-ratio resize plus (114, 114, 114)
                     padding) to INPUT_W x INPUT_H, normalize to [0, 1] and
                     transform to NCHW format.
        param:
            input_image_path: str, image path
        return:
            image:  the processed image
            image_raw: the original image (BGR, unmodified)
            h: original height
            w: original width
        raise:
            OSError: the image could not be read.
        """
        image_raw = self._read_image(input_image_path)
        h, w = image_raw.shape[:2]
        image = cv2.cvtColor(image_raw, cv2.COLOR_BGR2RGB)
        image = self._letterbox(image)
        return self._to_nchw(image), image_raw, h, w

    def xywh2xyxy(self, origin_h, origin_w, x):
        """
        description:    Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2], where
                        xy1 = top-left and xy2 = bottom-right.  The letterbox is undone:
                        the symmetric padding offset on the padded axis is subtracted and
                        the result is divided by the resize ratio, giving coordinates in
                        the original image.
        param:
            origin_h:   height of original image
            origin_w:   width of original image
            x:          A boxes tensor, each row is a box [center_x, center_y, w, h]
        return:
            y:          A boxes tensor, each row is a box [x1, y1, x2, y2]
        """
        y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x)
        r_w = INPUT_W / origin_w
        r_h = INPUT_H / origin_h
        if r_h > r_w:
            # Image was padded vertically.
            y[:, 0] = x[:, 0] - x[:, 2] / 2  # x1
            y[:, 2] = x[:, 0] + x[:, 2] / 2  # x2
            y[:, 1] = x[:, 1] - x[:, 3] / 2 - (INPUT_H - r_w * origin_h) / 2  # y1
            y[:, 3] = x[:, 1] + x[:, 3] / 2 - (INPUT_H - r_w * origin_h) / 2  # y2
            y /= r_w
        else:
            # Image was padded horizontally.
            y[:, 0] = x[:, 0] - x[:, 2] / 2 - (INPUT_W - r_h * origin_w) / 2
            y[:, 2] = x[:, 0] + x[:, 2] / 2 - (INPUT_W - r_h * origin_w) / 2
            y[:, 1] = x[:, 1] - x[:, 3] / 2
            y[:, 3] = x[:, 1] + x[:, 3] / 2
            y /= r_h

        return y

    def post_process(self, output, origin_h, origin_w):
        """
        description: Postprocess the prediction: confidence filtering, box
                     conversion to original-image coordinates, then NMS.
        param:
            output:     A flat array like [num_boxes, cx,cy,w,h,conf,cls_id, cx,cy,w,h,conf,cls_id, ...]
            origin_h:   height of original image
            origin_w:   width of original image
        return:
            result_boxes: final boxes, a CPU tensor, each row is a box [x1, y1, x2, y2]
            result_scores: final scores, a CPU tensor, each element is the score corresponding to a box
            result_classid: final class ids, a CPU tensor, each element is the class id corresponding to a box
        """
        # The first element is the number of detected boxes.
        num = int(output[0])
        # Reshape to a 2-D array of [cx, cy, w, h, conf, cls_id] rows.
        pred = np.reshape(output[1:], (-1, 6))[:num, :]
        # To a torch tensor on the GPU.
        pred = torch.Tensor(pred).cuda()
        boxes = pred[:, :4]
        scores = pred[:, 4]
        classid = pred[:, 5]
        # Keep boxes with score > CONF_THRESH.
        si = scores > CONF_THRESH
        boxes = boxes[si, :]
        scores = scores[si]
        classid = classid[si]
        # Transform bbox from [center_x, center_y, w, h] to [x1, y1, x2, y2].
        boxes = self.xywh2xyxy(origin_h, origin_w, boxes)
        # Class-agnostic NMS.
        indices = torchvision.ops.nms(boxes, scores, iou_threshold=IOU_THRESHOLD).cpu()
        result_boxes = boxes[indices, :].cpu()
        result_scores = scores[indices].cpu()
        result_classid = classid[indices].cpu()
        return result_boxes, result_scores, result_classid


class myThread(threading.Thread):
    """Thread that calls ``func(*args)`` once when started."""

    def __init__(self, func, args):
        super().__init__()
        # Callable and positional-argument sequence consumed by run().
        self.func = func
        self.args = args

    def run(self):
        """Invoke the stored callable with the stored positional arguments."""
        target, params = self.func, self.args
        target(*params)


if __name__ == "__main__":

    # Load custom TensorRT plugins before the engine is deserialized
    # (presumably the engine uses plugin layers from this library).
    PLUGIN_LIBRARY = "build/libmyplugins.so"
    ctypes.CDLL(PLUGIN_LIBRARY)

    engine_file_path = "build/yolov5x.engine"
    # Directory of input images (e.g. samples from
    # https://github.com/ultralytics/yolov5/tree/master/inference/images).
    image_dir = "test"

    # a YoLov5TRT instance
    yolov5_wrapper = YoLov5TRT(engine_file_path)
    try:
        # Sorted for a deterministic processing order; sub-directories are skipped.
        input_image_paths = [
            os.path.join(image_dir, name)
            for name in sorted(os.listdir(image_dir))
            if os.path.isfile(os.path.join(image_dir, name))
        ]

        for input_image_path in input_image_paths:
            # Run each inference in its own thread; joining immediately keeps
            # the images processed one at a time.
            thread1 = myThread(yolov5_wrapper.infer, [input_image_path])
            thread1.start()
            thread1.join()
    finally:
        # Release the CUDA context even if listing the images fails.
        yolov5_wrapper.destroy()
