/*
 * Copyright The OpenZipkin Authors
 * SPDX-License-Identifier: Apache-2.0
 */
package brave.jms;

import brave.Tags;
import brave.handler.MutableSpan;
import brave.messaging.MessagingRuleSampler;
import brave.messaging.MessagingTracing;
import brave.propagation.SamplingFlags;
import brave.propagation.TraceContext;
import brave.sampler.Sampler;
import javax.jms.JMSConsumer;
import javax.jms.JMSContext;
import javax.jms.JMSProducer;
import javax.jms.Message;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import static brave.Span.Kind.CONSUMER;
import static brave.jms.MessageProperties.getPropertyIfString;
import static brave.messaging.MessagingRequestMatchers.operationEquals;
import static javax.jms.JMSContext.AUTO_ACKNOWLEDGE;
import static org.assertj.core.api.Assertions.assertThat;

/**
 * Integration tests for tracing of {@link JMSConsumer} instances created from a {@link JMSContext}
 * obtained via {@link JmsTracing#connectionFactory}.
 *
 * <p>When adding tests here, add tests not already in {@link ITJms_1_1_TracingMessageConsumer} to
 * {@link brave.jms.ITJms_2_0_TracingMessageConsumer}
 */
class ITTracingJMSConsumer extends ITJms {
  @RegisterExtension ArtemisJmsExtension jms = new ArtemisJmsExtension();

  /** Traced context that owns {@link #consumer}; replaced by {@link #setupTracedConsumer}. */
  JMSContext tracedContext;
  /** Untraced producer used to send test messages. */
  JMSProducer producer;
  /** Traced consumer under test, reading from {@code jms.queue}. */
  JMSConsumer consumer;
  /** Untraced context that owns {@link #producer}. */
  JMSContext context;

  @BeforeEach void setup() {
    context = jms.newContext();
    producer = context.createProducer();

    setupTracedConsumer(jmsTracing);
  }

  /**
   * Closes any previously created traced consumer and context, then creates new ones from the
   * connection factory wrapped by the given {@code jmsTracing}.
   *
   * @param jmsTracing the tracing instance used to wrap {@code jms.factory}
   */
  void setupTracedConsumer(JmsTracing jmsTracing) {
    if (consumer != null) consumer.close();
    if (tracedContext != null) tracedContext.close();
    tracedContext = jmsTracing.connectionFactory(jms.factory).createContext(AUTO_ACKNOWLEDGE);
    consumer = tracedContext.createConsumer(jms.queue);
  }

  /**
   * Closes both the traced and the untraced contexts opened by {@link #setup()}.
   *
   * <p>Null checks guard against {@link #setup()} failing part way, so the original failure is
   * not masked by a {@link NullPointerException} here.
   */
  @AfterEach void tearDownTraced() {
    if (tracedContext != null) tracedContext.close();
    // The untraced producer context is created per test in setup(), so it must be closed here too.
    if (context != null) context.close();
  }

  @Test void messageListener_runsAfterConsumer() {
    consumer.setMessageListener(m -> {
    });
    producer.send(jms.queue, "foo");

    MutableSpan consumerSpan = testSpanHandler.takeRemoteSpan(CONSUMER);
    MutableSpan listenerSpan = testSpanHandler.takeLocalSpan();

    assertChildOf(listenerSpan, consumerSpan);
    assertSequential(consumerSpan, listenerSpan);
  }

  @Test void messageListener_startsNewTrace() {
    messageListener_startsNewTrace(() -> producer.send(jms.queue, "foo"));
  }

  @Test void messageListener_startsNewTrace_bytes() {
    messageListener_startsNewTrace(() -> producer.send(jms.queue, new byte[] {1, 2, 3, 4}));
  }

  /**
   * Sends a message without trace headers and verifies that the listener path starts a new trace.
   * It checks for a root "receive" consumer span, followed by a child listener span that sees no
   * {@code b3} header.
   *
   * @param send sends a single message to {@code jms.queue}
   */
  void messageListener_startsNewTrace(Runnable send) {
    consumer.setMessageListener(m -> {
      tracing.tracer().currentSpanCustomizer().name("message-listener");

      // clearing headers ensures later work doesn't try to use the old parent
      String b3 = getPropertyIfString(m, "b3");
      tracing.tracer().currentSpanCustomizer().tag("b3", String.valueOf(b3 != null));
    });

    send.run();

    MutableSpan consumerSpan = testSpanHandler.takeRemoteSpan(CONSUMER);
    MutableSpan listenerSpan = testSpanHandler.takeLocalSpan();

    assertThat(consumerSpan.parentId()).isNull(); // no incoming headers: new trace
    assertThat(consumerSpan.name()).isEqualTo("receive");
    assertThat(consumerSpan.tags())
      .hasSize(1)
      .containsEntry("jms.queue", jms.queueName);

    assertChildOf(listenerSpan, consumerSpan);
    assertThat(listenerSpan.name()).isEqualTo("message-listener"); // overridden name
    assertThat(listenerSpan.tags())
      .hasSize(1) // no redundant copy of consumer tags
      .containsEntry("b3", "false"); // b3 header not leaked to listener
  }

  @Test void messageListener_resumesTrace() {
    messageListener_resumesTrace(() -> producer.send(jms.queue, "foo"));
  }

  @Test void messageListener_resumesTrace_bytes() {
    messageListener_resumesTrace(() -> producer.send(jms.queue, new byte[] {1, 2, 3, 4}));
  }

  /**
   * Sends a message carrying a sampled {@code b3} header and verifies that the listener path
   * continues that trace. The consumer span must be a child of the incoming context, the listener
   * span a child of the consumer span, and the header must be hidden from the listener.
   *
   * @param send sends a single message to {@code jms.queue}
   */
  void messageListener_resumesTrace(Runnable send) {
    consumer.setMessageListener(m -> {
      // clearing headers ensures later work doesn't try to use the old parent
      String b3 = getPropertyIfString(m, "b3");
      tracing.tracer().currentSpanCustomizer().tag("b3", String.valueOf(b3 != null));
    });

    TraceContext parent = newTraceContext(SamplingFlags.SAMPLED);
    producer.setProperty("b3", parent.traceIdString() + "-" + parent.spanIdString() + "-1");
    send.run();

    MutableSpan consumerSpan = testSpanHandler.takeRemoteSpan(CONSUMER);
    MutableSpan listenerSpan = testSpanHandler.takeLocalSpan();

    assertChildOf(consumerSpan, parent);
    assertChildOf(listenerSpan, consumerSpan);

    assertThat(listenerSpan.tags())
      .hasSize(1) // no redundant copy of consumer tags
      .containsEntry("b3", "false"); // b3 header not leaked to listener
  }

  @Test void messageListener_readsBaggage() {
    messageListener_readsBaggage(() -> producer.send(jms.queue, "foo"));
  }

  @Test void messageListener_readsBaggage_bytes() {
    messageListener_readsBaggage(() -> producer.send(jms.queue, new byte[] {1, 2, 3, 4}));
  }

  /**
   * Sends a message carrying only a baggage property and verifies that the listener's current
   * span can read that baggage value.
   *
   * @param send sends a single message to {@code jms.queue}
   */
  void messageListener_readsBaggage(Runnable send) {
    consumer.setMessageListener(m ->
      Tags.BAGGAGE_FIELD.tag(BAGGAGE_FIELD, tracing.tracer().currentSpan())
    );

    String baggage = "joey";
    producer.setProperty(BAGGAGE_FIELD_KEY, baggage);
    send.run();

    MutableSpan consumerSpan = testSpanHandler.takeRemoteSpan(CONSUMER);
    MutableSpan listenerSpan = testSpanHandler.takeLocalSpan();

    assertThat(consumerSpan.parentId()).isNull();
    assertChildOf(listenerSpan, consumerSpan);
    assertThat(listenerSpan.tags())
      .containsEntry(BAGGAGE_FIELD.name(), baggage);
  }

  @Test void receive_startsNewTrace() {
    receive_startsNewTrace(() -> producer.send(jms.queue, "foo"));
  }

  @Test void receive_startsNewTrace_bytes() {
    receive_startsNewTrace(() -> producer.send(jms.queue, new byte[] {1, 2, 3, 4}));
  }

  /**
   * Sends a message without trace headers, synchronously receives it, and verifies that the
   * consumer span is a new root "receive" span tagged with the queue name.
   *
   * @param send sends a single message to {@code jms.queue}
   */
  void receive_startsNewTrace(Runnable send) {
    send.run();
    consumer.receive();
    MutableSpan consumerSpan = testSpanHandler.takeRemoteSpan(CONSUMER);
    assertThat(consumerSpan.parentId()).isNull(); // no incoming headers: new trace
    assertThat(consumerSpan.name()).isEqualTo("receive");
    assertThat(consumerSpan.tags()).containsEntry("jms.queue", jms.queueName);
  }

  @Test void receive_resumesTrace() {
    receiveResumesTrace(() -> producer.send(jms.queue, "foo"));
  }

  @Test void receive_resumesTrace_bytes() {
    receiveResumesTrace(() -> producer.send(jms.queue, new byte[] {1, 2, 3, 4}));
  }

  /**
   * Sends a message carrying a sampled {@code b3} header and synchronously receives it. Verifies
   * that the consumer span is a child of the incoming context, and that the received message's
   * {@code b3} header now points at the consumer span.
   *
   * @param send sends a single message to {@code jms.queue}
   */
  void receiveResumesTrace(Runnable send) {
    TraceContext parent = newTraceContext(SamplingFlags.SAMPLED);
    producer.setProperty("b3", parent.traceIdString() + "-" + parent.spanIdString() + "-1");
    send.run();

    Message received = consumer.receive();

    MutableSpan consumerSpan = testSpanHandler.takeRemoteSpan(CONSUMER);
    assertChildOf(consumerSpan, parent);

    assertThat(getPropertyIfString(received, "b3"))
      .isEqualTo(parent.traceIdString() + "-" + consumerSpan.id() + "-1");
  }

  @Test void receive_customSampler() {
    setupTracedConsumer(JmsTracing.create(MessagingTracing.newBuilder(tracing)
      .consumerSampler(MessagingRuleSampler.newBuilder()
        .putRule(operationEquals("receive"), Sampler.NEVER_SAMPLE)
        .build()).build()));

    producer.send(jms.queue, "foo");

    // Check that the message headers are not sampled
    assertThat(getPropertyIfString(consumer.receive(), "b3"))
      .endsWith("-0");

    // @After will also check that the consumer was not sampled
  }
}
