# -*- coding: utf-8 -*-
# @Time    : 2022/11/17 13:34

from typing import Tuple
import torch
from torch import nn, Tensor

# Names exported by ``from <this module> import *``.
__all__ = ["CircleLoss"]




class CircleLoss(nn.Module):
    """Circle loss computed over every distinct sample pair in a batch.

    The forward pass evaluates::

        softplus(logsumexp(logit_n) + logsumexp(logit_p))
          = log(1 + sum_j exp(logit_n_j) * sum_i exp(logit_p_i))

    with::

        logit_p = -gamma * a_p * (s_p - delta_p),   a_p = max(0, 1 + m - s_p)
        logit_n =  gamma * a_n * (s_n - delta_n),   a_n = max(0, s_n + m)
        delta_p = 1 - m,   delta_n = m

    where ``s_p`` / ``s_n`` are the similarities of same-label / different-label
    pairs. This is the pair-similarity form of Circle Loss (Sun et al., 2020).

    Args:
        m: Margin. Sets ``delta_p = 1 - m`` and ``delta_n = m`` and offsets the
            weighting terms ``a_p`` and ``a_n``.
        gamma: Scale factor multiplied into every logit.
    """

    def __init__(self, m: float, gamma: float) -> None:
        super().__init__()
        self.m = m
        self.gamma = gamma
        self.soft_plus = nn.Softplus()

    def forward(self, normed_feature: Tensor, label: Tensor) -> Tensor:
        """Compute the circle loss for one batch.

        Args:
            normed_feature: 2-D embedding matrix, presumably (N, D). The rows
                are expected to be L2-normalized already (this method does not
                normalize), so that the pairwise dot products are cosine
                similarities.
            label: Per-sample labels, expected 1-D with length N. Samples ``i``
                and ``j`` form a positive pair when ``label[i] == label[j]``.

        Returns:
            A 0-dim tensor holding the loss.
        """
        # Pairwise dot products (N, N).
        similarity_matrix = normed_feature @ normed_feature.transpose(1, 0)
        label_matrix = label.unsqueeze(1) == label.unsqueeze(0)

        # Strict upper triangle only: each unordered pair (i, j) with i < j is
        # counted once, and the self-pairs on the diagonal are excluded.
        positive_mask = label_matrix.triu(diagonal=1)
        negative_mask = label_matrix.logical_not().triu(diagonal=1)

        # 2-D boolean-mask indexing returns a 1-D tensor in row-major order.
        sp = similarity_matrix[positive_mask]
        sn = similarity_matrix[negative_mask]

        # Per-pair weights, detached so they act as constants and receive no
        # gradient.
        ap = torch.clamp_min(-sp.detach() + 1 + self.m, min=0.)
        an = torch.clamp_min(sn.detach() + self.m, min=0.)

        delta_p = 1 - self.m
        delta_n = self.m

        logit_p = -ap * (sp - delta_p) * self.gamma
        logit_n = an * (sn - delta_n) * self.gamma

        # logsumexp keeps the large gamma-scaled logits numerically stable.
        # NOTE(review): with no positive (or no negative) pairs one logsumexp
        # runs over an empty tensor; PyTorch returns -inf there, so the loss
        # evaluates to 0 -- confirm that is the intended behaviour.
        loss = self.soft_plus(
            torch.logsumexp(logit_n, dim=0) + torch.logsumexp(logit_p, dim=0)
        )
        return loss


# if __name__ == "__main__":
#     feat = nn.functional.normalize(torch.rand(256, 64, requires_grad=True))
#     lbl = torch.randint(high=10, size=(256,))
#
#     criterion = CircleLoss(m=0.25, gamma=256)
#     # forward() takes the normalized features and labels directly and builds
#     # the positive/negative pair similarities itself.
#     circle_loss = criterion(feat, lbl)
#
#     print(circle_loss)