import numpy as np
from numpy import linalg as LA

from keras.applications.resnet50 import ResNet50 as KerasResNet50
from keras.preprocessing import image
from keras.applications.resnet50 import preprocess_input

class ResNet50:
    """Image feature extractor built on the Keras ResNet50 application.

    The network is created without its classification head
    (``include_top=False``) and with a global pooling layer, and
    ``extract_feat`` returns the L2-normalised output for one image file.
    """

    def __init__(self, input_shape=(224, 224, 3), weights='imagenet',
                 pooling='max'):
        """Build the underlying Keras model and run one warm-up prediction.

        Args:
            input_shape: ``(height, width, channels)`` passed to the Keras
                model; the first two entries are also used as the resize
                target when loading images.
            weights: Value forwarded as ``weights`` to the Keras ResNet50
                constructor.
            pooling: Value forwarded as ``pooling`` to the Keras ResNet50
                constructor.
        """
        self.input_shape = tuple(input_shape)
        self.weight = weights
        self.pooling = pooling
        self.model = KerasResNet50(
            weights=self.weight,
            input_shape=self.input_shape,
            pooling=self.pooling,
            include_top=False,
        )
        # Warm-up call on an all-zero batch of one image.
        # NOTE(review): presumably this forces Keras to build its predict
        # function up front rather than on the first real request - confirm.
        # The shape is derived from input_shape so the two cannot drift apart.
        self.model.predict(np.zeros((1,) + self.input_shape))

    @staticmethod
    def _l2_normalize(vec):
        """Return *vec* scaled to unit Euclidean length.

        A vector whose norm is exactly zero is returned unchanged, because
        dividing by the norm would otherwise fill the result with NaNs.
        """
        norm = LA.norm(vec)
        if norm == 0:
            return vec
        return vec / norm

    def extract_feat(self, img_path):
        """Compute the normalised feature vector of the image at *img_path*.

        The image is loaded and resized to the model's input height and
        width, converted to an array, given a leading batch axis, passed
        through ``preprocess_input`` and then through the model.

        Args:
            img_path: Path of the image file to load.

        Returns:
            The first (and only) row of the model output, L2-normalised.
            If that row is all zeros it is returned as-is.
        """
        height, width = self.input_shape[0], self.input_shape[1]
        img = image.load_img(img_path, target_size=(height, width))
        img = image.img_to_array(img)
        # The model expects a batch, so wrap the single image in one.
        img = np.expand_dims(img, axis=0)
        img = preprocess_input(img)
        feat = self.model.predict(img)
        return self._l2_normalize(feat[0])