import numpy as np
from numpy import linalg as LA

from keras.applications.densenet import DenseNet169
from keras.preprocessing import image
from keras.applications.densenet import preprocess_input

class DenseNet:
    """Image feature extractor built on a headless Keras DenseNet169.

    Images are resized to ``input_shape``, passed through the convolutional
    base (``include_top=False``) with global ``pooling``, and each resulting
    feature vector is scaled to unit L2 norm.
    """

    def __init__(self, input_shape=(224, 224, 3), weight='imagenet', pooling='max'):
        """Build the backbone model.

        Args:
            input_shape: ``(height, width, channels)`` the model is built for;
                images are resized to ``(height, width)`` before prediction.
            weight: Value forwarded as ``weights=`` to ``DenseNet169``.
            pooling: Value forwarded as ``pooling=`` to ``DenseNet169``.
        """
        self.input_shape = tuple(input_shape)
        self.weight = weight
        self.pooling = pooling
        self.model = DenseNet169(weights=self.weight,
                                 input_shape=self.input_shape,
                                 pooling=self.pooling,
                                 include_top=False)
        # Dummy forward pass, presumably to warm the model up before the first
        # real call. Shaped from input_shape so it always matches the model.
        self.model.predict(np.zeros((1,) + self.input_shape))

    def _load_array(self, img_path):
        """Load *img_path*, resize it to the model's (height, width), return it as an array."""
        img = image.load_img(img_path, target_size=self.input_shape[:2])
        return image.img_to_array(img)

    @staticmethod
    def _l2_normalize(feats):
        """Scale every vector along the last axis of *feats* to unit L2 norm.

        All-zero vectors are returned as zeros instead of being divided by
        zero (which would yield NaNs).
        """
        norms = LA.norm(feats, axis=-1, keepdims=True)
        norms = np.where(norms == 0, 1, norms)
        return feats / norms

    def extract_feat(self, img_path):
        """Return the L2-normalized feature vector for a single image file."""
        batch = np.expand_dims(self._load_array(img_path), axis=0)
        feat = self.model.predict(preprocess_input(batch))
        return self._l2_normalize(feat[0])

    def extract_feats(self, img_paths):
        """Return a list of L2-normalized feature vectors, one per path, in order.

        All images are predicted as a single batch. An empty *img_paths*
        returns ``[]`` without calling the model.
        """
        # len() rather than truthiness so numpy arrays of paths also work.
        if len(img_paths) == 0:
            return []
        batch = np.stack([self._load_array(path) for path in img_paths])
        feats = self.model.predict(preprocess_input(batch))
        return list(self._l2_normalize(feats))