from overrides import overrides

from keras.engine import InputSpec
from keras import backend as K
from keras.layers.recurrent import GRU, _time_distributed_dense


class AttentiveGru(GRU):
    """
    A GRU whose update gate is replaced by a pre-computed, per-timestep scalar attention weight.

    GRUs typically operate over sequences of words. This layer instead runs over a sequence of
    sentence (background knowledge) encodings. The motivation is that a plain attention-weighted
    average of those encodings loses ordering information over its inputs; this is important,
    for instance, in the bAbI tasks.

    See Dynamic Memory Networks for more information: https://arxiv.org/pdf/1603.01417v1.pdf.
    This class extends the Keras Gated Recurrent Unit. In place of the GRU update gate, it uses a
    pre-computed scalar attention weight, one per input (such as the output of a softmax over the
    input vectors). The update gate is normally a vector z; ``step`` notes where it would
    normally be computed.

    The implementation is subtle: it is only very slightly different from a standard GRU.
    The layer is fed a tensor of shape (batch, knowledge_length, 1 + encoding_dim). Column 0 of
    the last axis is the attention weight and the remaining columns are the encoding. Keras uses
    the real (1 + encoding_dim) shape to check the inputs passed to this layer, so we keep that
    shape in ``input_spec``. For the internal implementation, however, everywhere this input
    dimension is used we make it one smaller. Every input kernel is therefore built for
    encoding_dim input features, NOT 1 + encoding_dim. Each overridden method below corrects some
    use of this dimension.

    NOTE(review): this relies on methods and attributes of the parent ``GRU`` (``step``,
    ``preprocess_input``, ``get_constants``, ``implementation`` modes 0/1/2, ``kernel_z`` etc.)
    from the Keras 2.0-era recurrent API. Re-check it whenever Keras is upgraded.
    """

    def __init__(self, output_dim, input_length, **kwargs):
        """
        Parameters
        ----------
        output_dim : int
            Size of the hidden state. The encoding dimension of each input is assumed to equal
            this, so the layer is declared with ``output_dim + 1`` input features per timestep
            (the extra one is the attention weight).
        input_length : int
            Number of timesteps (the knowledge length).
        **kwargs
            Forwarded to ``GRU``. ``name`` is optional; if it is omitted, None is forwarded and
            the base Keras layer chooses the name.
        """
        # Previously a missing ``name`` raised KeyError; default to None instead.
        self.name = kwargs.pop('name', None)
        super(AttentiveGru, self).__init__(output_dim,
                                           input_length=input_length,
                                           # +1 for the attention weight carried in column 0.
                                           input_dim=output_dim + 1,
                                           name=self.name, **kwargs)

    @overrides
    def step(self, inputs, states):
        # pylint: disable=invalid-name
        """
        Run a single timestep of the attentive GRU.

        ``inputs`` is a timeslice of this layer's input, where the time axis is the
        knowledge_length. Its shape depends on the implementation:

        - implementations 1 and 2: (batch, 1 + encoding_dim);
        - implementation 0: (batch, 1 + 3 * units), because ``preprocess_input`` has already
          applied the input kernels.

        Column 0 is always the attention weight. We strip it off first. Then we do the equations
        for a normal GRU, except that we don't calculate the update gate z; we substitute the
        attention weight for it instead.

        ``states`` is ``[h_tm1, dp_mask, rec_dp_mask]``:

        - the previous hidden state;
        - the input dropout masks;
        - the recurrent dropout masks.

        Note that there is some redundancy here. For instance, in implementation 2 we do a larger
        matrix multiplication than required, because the z part of it is never used. For
        readability and similarity to the original GRU code in Keras, it has not been changed.
        The commented-out lines show where the standard GRU computes z. If you uncomment them,
        remove the differences in the input size and replace the attention with the z gate in
        the final update, you get a standard GRU back.
        """
        # Keep the attention as (batch, 1) so it broadcasts over the hidden units below.
        attention = K.expand_dims(inputs[:, 0], 1)
        inputs = inputs[:, 1:]
        h_tm1 = states[0]  # previous hidden state
        # Order matters. In the Keras 2 GRU, get_constants returns the input dropout masks
        # first and the recurrent dropout masks second.
        dp_mask = states[1]  # dropout masks applied to the inputs
        rec_dp_mask = states[2]  # dropout masks applied to the recurrent state

        if self.implementation == 2:
            matrix_x = K.dot(inputs * dp_mask[0], self.kernel)
            if self.use_bias:
                matrix_x = K.bias_add(matrix_x, self.bias)
            matrix_inner = K.dot(h_tm1 * rec_dp_mask[0], self.recurrent_kernel[:, :2 * self.units])

            # x_z = matrix_x[:, :self.units]
            # inner_z = matrix_inner[:, :self.units]
            x_r = matrix_x[:, self.units: 2 * self.units]
            inner_r = matrix_inner[:, self.units: 2 * self.units]

            # z = self.recurrent_activation(x_z + inner_z)
            r = self.recurrent_activation(x_r + inner_r)

            x_h = matrix_x[:, 2 * self.units:]
            inner_h = K.dot(r * h_tm1 * rec_dp_mask[0], self.recurrent_kernel[:, 2 * self.units:])
            hh = self.activation(x_h + inner_h)
        else:
            if self.implementation == 0:
                # preprocess_input already applied the input kernels (and input dropout),
                # laying the columns out as [x_z, x_r, x_h].
                # x_z = inputs[:, :self.units]
                x_r = inputs[:, self.units: 2 * self.units]
                x_h = inputs[:, 2 * self.units:]
            elif self.implementation == 1:
                # x_z = K.dot(inputs * dp_mask[0], self.kernel_z)
                x_r = K.dot(inputs * dp_mask[1], self.kernel_r)
                x_h = K.dot(inputs * dp_mask[2], self.kernel_h)
                if self.use_bias:
                    # x_z = K.bias_add(x_z, self.bias_z)
                    x_r = K.bias_add(x_r, self.bias_r)
                    x_h = K.bias_add(x_h, self.bias_h)
            else:
                raise ValueError('Unknown implementation')

            # z = self.recurrent_activation(x_z + K.dot(h_tm1 * rec_dp_mask[0], self.recurrent_kernel_z))
            r = self.recurrent_activation(x_r + K.dot(h_tm1 * rec_dp_mask[1], self.recurrent_kernel_r))

            hh = self.activation(x_h + K.dot(r * h_tm1 * rec_dp_mask[2], self.recurrent_kernel_h))

        # This is the KEY difference between a GRU and an AttentiveGRU. Instead of a learnt
        # update gate z, we use the scalar attention weight (batch, 1) for this particular
        # background knowledge vector. The attention plays the role of (1 - z): with
        # attention 1 the state is replaced by the candidate hh, and with attention 0 the
        # previous state is carried through unchanged.
        h = attention * hh + (1 - attention) * h_tm1
        return h, [h]

    @overrides
    def build(self, input_shape):
        """
        Build the GRU weights for the encoding only, then restore the real input spec.

        Keras uses ``build`` both to verify things and to create the weights. This is otherwise
        the Keras GRU ``build``, with two differences:

        - The parent builds its weights from ``input_shape[2] - 1`` input features rather than
          ``input_shape[2]``. Column 0 of every timestep is the attention weight, not part of the
          encoding.
        - Afterwards, ``input_spec`` is reset to the real (1 + encoding_dim) shape. Incoming
          tensors are then checked against what is actually passed in, and
          ``preprocess_input`` reads this full shape back.

        The parent still creates the update-gate (z) parts of the kernels. Depending on the
        implementation, their products may be computed, but they never affect the output of
        ``step``.
        """
        encoding_shape = list(input_shape)
        encoding_shape[2] -= 1  # strip the attention column
        super(AttentiveGru, self).build(tuple(encoding_shape))
        self.input_spec = [InputSpec(shape=input_shape)]

    @overrides
    def get_constants(self, inputs, training=None):
        """
        Build the dropout masks that ``step`` receives as ``states[1]`` and ``states[2]``.

        The parent sizes the input dropout masks from the last axis of ``inputs``, which here
        includes the attention column. The masks would therefore have 1 + encoding_dim columns,
        while ``step`` multiplies them into the encoding alone, which does not broadcast. We drop
        their first column so the widths line up. The remaining entries are still an independent
        dropout mask.

        NOTE(review): this mirrors the Keras 2.0 ``GRU.get_constants`` layout, which returns
        ``[input_masks, recurrent_masks]``. That method only builds tensor input masks when
        ``implementation != 0 and 0 < dropout < 1``; otherwise it uses scalar ones, which
        broadcast and must not be sliced. Re-check this if Keras is upgraded.
        """
        constants = super(AttentiveGru, self).get_constants(inputs, training=training)
        if self.implementation != 0 and 0 < self.dropout < 1:
            constants[0] = [mask[:, 1:] for mask in constants[0]]
        return constants

    @overrides
    def preprocess_input(self, inputs, training=None):
        """
        For implementation 0, apply the input kernels up front while keeping the attention column.

        In implementation 0, the Keras GRU does the weight-input multiplications (with input
        dropout) for every timestep before the recurrence. It does them as separate, smaller
        matrix multiplications and concatenates the results. Those kernels only cover the
        encoding, so we split off the attention first and add it back afterwards in column 0.
        ``step`` then reads column 0 as the attention and the rest as ``[x_z, x_r, x_h]``.
        Other implementations return ``inputs`` unchanged.
        """
        if self.implementation != 0:
            return inputs

        attention = inputs[:, :, 0]  # Shape: (samples, knowledge_length)
        encodings = inputs[:, :, 1:]  # Shape: (samples, knowledge_length, encoding_dim)

        # input_spec holds the full (1 + encoding_dim) shape; see build().
        input_shape = self.input_spec[0].shape
        input_dim = input_shape[2] - 1
        timesteps = input_shape[1]

        # Same z, r, h order as the Keras GRU, so step() can slice by self.units.
        x_z, x_r, x_h = [_time_distributed_dense(encodings, kernel, bias,
                                                 self.dropout, input_dim, self.units,
                                                 timesteps, training=training)
                         for kernel, bias in ((self.kernel_z, self.bias_z),
                                              (self.kernel_r, self.bias_r),
                                              (self.kernel_h, self.bias_h))]

        # Add the attention back in its original place (column 0).
        return K.concatenate([K.expand_dims(attention, 2), x_z, x_r, x_h], axis=2)
