/*
 * SonarQube
 * Copyright (C) 2009-2024 SonarSource SA
 * mailto:info AT sonarsource DOT com
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 3 of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this program; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
 */
package org.sonar.server.platform.web;

import java.io.IOException;
import java.util.Optional;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletResponse;
import org.junit.Before;
import org.junit.Test;
import org.slf4j.MDC;
import org.sonar.core.platform.ExtensionContainer;
import org.sonar.server.platform.Platform;
import org.sonar.server.platform.web.requestid.RequestIdGenerator;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.fail;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

/**
 * Unit tests for {@link RequestIdFilter}.
 *
 * <p>The filter is exercised against a mocked {@link Platform} whose container either exposes a mocked
 * {@link RequestIdGenerator} (default fixture, see {@link #setUp()}) or exposes none at all
 * (see {@link #newFilterWithoutRequestIdGenerator()}).
 */
public class RequestIdFilterTest {
  /** MDC key under which the tests expect the request id while the chain executes. */
  private static final String MDC_REQUEST_ID_KEY = "HTTP_REQUEST_ID";
  /** Value returned by the mocked generator in the tests which need a request id. */
  private static final String REQUEST_ID = "request id";

  private final Platform platform = mock(Platform.class);
  private final RequestIdGenerator requestIdGenerator = mock(RequestIdGenerator.class);
  private final ServletRequest servletRequest = mock(ServletRequest.class);
  private final ServletResponse servletResponse = mock(ServletResponse.class);
  private final FilterChain filterChain = mock(FilterChain.class);
  private final RequestIdFilter underTest = new RequestIdFilter(platform);

  /**
   * Default fixture: the platform's container provides the mocked {@link RequestIdGenerator}.
   */
  @Before
  public void setUp() {
    ExtensionContainer container = mock(ExtensionContainer.class);
    when(container.getOptionalComponentByType(RequestIdGenerator.class)).thenReturn(Optional.of(requestIdGenerator));
    when(platform.getContainer()).thenReturn(container);
  }

  /**
   * The request id must be readable from the MDC while the chain runs and must be gone once the filter returns.
   */
  @Test
  public void filter_put_id_in_MDC_and_remove_it_after_chain_has_executed() throws IOException, ServletException {
    when(requestIdGenerator.generate()).thenReturn(REQUEST_ID);
    doAnswer(invocation -> {
      assertThat(MDC.get(MDC_REQUEST_ID_KEY)).isEqualTo(REQUEST_ID);
      // FilterChain#doFilter is void: there is nothing meaningful to return from the answer
      return null;
    })
      .when(filterChain)
      .doFilter(servletRequest, servletResponse);

    underTest.doFilter(servletRequest, servletResponse, filterChain);

    // without this verification the test would pass even if the chain (and the assertion it hosts) never ran
    verify(filterChain).doFilter(servletRequest, servletResponse);
    assertThat(MDC.get(MDC_REQUEST_ID_KEY)).isNull();
  }

  /**
   * Same as above, but the chain fails: the exception must reach the caller untouched and the MDC must still be
   * cleaned up.
   */
  @Test
  public void filter_put_id_in_MDC_and_remove_it_after_chain_throws_exception() throws IOException, ServletException {
    RuntimeException exception = new RuntimeException("Simulating chain failing");
    when(requestIdGenerator.generate()).thenReturn(REQUEST_ID);
    doAnswer(invocation -> {
      assertThat(MDC.get(MDC_REQUEST_ID_KEY)).isEqualTo(REQUEST_ID);
      throw exception;
    })
      .when(filterChain)
      .doFilter(servletRequest, servletResponse);

    try {
      underTest.doFilter(servletRequest, servletResponse, filterChain);
      // fail() raises an AssertionError, which is not caught by the catch clause below
      fail("A runtime exception should have been raised");
    } catch (RuntimeException e) {
      assertThat(e).isSameAs(exception);
    } finally {
      assertThat(MDC.get(MDC_REQUEST_ID_KEY)).isNull();
    }
  }

  /**
   * The generated id must be set as the {@code "ID"} attribute of the request.
   */
  @Test
  public void filter_adds_requestId_to_request_passed_on_to_chain() throws IOException, ServletException {
    when(requestIdGenerator.generate()).thenReturn(REQUEST_ID);

    underTest.doFilter(servletRequest, servletResponse, filterChain);

    verify(servletRequest).setAttribute("ID", REQUEST_ID);
  }

  /**
   * When the container has no {@link RequestIdGenerator}, filtering must complete without error and the request is
   * expected to be handed over to the chain.
   */
  @Test
  public void filter_does_not_fail_when_there_is_no_RequestIdGenerator_in_container() throws IOException, ServletException {
    RequestIdFilter filterWithoutGenerator = newFilterWithoutRequestIdGenerator();

    filterWithoutGenerator.doFilter(servletRequest, servletResponse, filterChain);

    verify(filterChain).doFilter(servletRequest, servletResponse);
  }

  /**
   * When the container has no {@link RequestIdGenerator}, no String attribute may be set on the request.
   */
  @Test
  public void filter_does_not_add_requestId_to_request_passed_on_to_chain_when_there_is_no_RequestIdGenerator_in_container() throws IOException, ServletException {
    RequestIdFilter filterWithoutGenerator = newFilterWithoutRequestIdGenerator();

    filterWithoutGenerator.doFilter(servletRequest, servletResponse, filterChain);

    verify(servletRequest, times(0)).setAttribute(anyString(), anyString());
  }

  /**
   * Re-stubs the platform so that its container provides no {@link RequestIdGenerator} and returns a new filter
   * built on that platform.
   */
  private RequestIdFilter newFilterWithoutRequestIdGenerator() {
    ExtensionContainer container = mock(ExtensionContainer.class);
    when(container.getOptionalComponentByType(RequestIdGenerator.class)).thenReturn(Optional.empty());
    when(platform.getContainer()).thenReturn(container);
    return new RequestIdFilter(platform);
  }
}
