/*
 * Copyright The OpenTelemetry Authors
 * SPDX-License-Identifier: Apache-2.0
 */

package io.opentelemetry.instrumentation.runtimemetrics.java8;

import static java.util.Objects.requireNonNull;

import io.opentelemetry.api.OpenTelemetry;
import io.opentelemetry.api.common.Attributes;
import io.opentelemetry.api.metrics.Meter;
import io.opentelemetry.api.metrics.ObservableLongMeasurement;
import io.opentelemetry.instrumentation.runtimemetrics.java8.internal.JmxRuntimeMetricsUtil;
import io.opentelemetry.semconv.JvmAttributes;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.management.ManagementFactory;
import java.lang.management.ThreadInfo;
import java.lang.management.ThreadMXBean;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import javax.annotation.Nullable;

/**
 * Registers measurements that generate metrics about JVM threads. The metrics generated by this
 * class follow <a
 * href="https://github.com/open-telemetry/semantic-conventions/blob/main/docs/runtime/jvm-metrics.md">the
 * stable JVM metrics semantic conventions</a>.
 *
 * <p>Example usage:
 *
 * <pre>{@code
 * Threads.registerObservers(GlobalOpenTelemetry.get());
 * }</pre>
 *
 * <p>Example metrics being exported:
 *
 * <pre>
 *   jvm.thread.count{jvm.thread.daemon=true,jvm.thread.state="waiting"} 1
 *   jvm.thread.count{jvm.thread.daemon=true,jvm.thread.state="runnable"} 2
 *   jvm.thread.count{jvm.thread.daemon=false,jvm.thread.state="waiting"} 2
 *   jvm.thread.count{jvm.thread.daemon=false,jvm.thread.state="runnable"} 3
 * </pre>
 */
public final class Threads {

  // Visible for testing
  static final Threads INSTANCE = new Threads();

  /** Attributes for the daemon-thread series recorded by {@link #java8Callback}. */
  private static final Attributes DAEMON_ATTRIBUTES =
      Attributes.of(JvmAttributes.JVM_THREAD_DAEMON, true);

  /** Attributes for the non-daemon-thread series recorded by {@link #java8Callback}. */
  private static final Attributes NON_DAEMON_ATTRIBUTES =
      Attributes.of(JvmAttributes.JVM_THREAD_DAEMON, false);

  /**
   * Registers observers for JVM thread metrics.
   *
   * <p>On Java 9+ thread information is read from the platform {@link ThreadMXBean}. On Java 8,
   * {@code ThreadInfo#isDaemon()} does not exist, so live {@link Thread} objects are enumerated
   * from the root thread group instead.
   *
   * @param openTelemetry the OpenTelemetry instance used to obtain the meter
   * @return the {@link AutoCloseable} handles of the instruments that were registered
   */
  public static List<AutoCloseable> registerObservers(OpenTelemetry openTelemetry) {
    return INSTANCE.registerObservers(openTelemetry, !isJava9OrNewer());
  }

  /**
   * Chooses the thread-information source.
   *
   * @param useThread {@code true} to enumerate {@link Thread} objects directly, {@code false} to
   *     use the platform {@link ThreadMXBean}
   */
  private List<AutoCloseable> registerObservers(OpenTelemetry openTelemetry, boolean useThread) {
    if (useThread) {
      return registerObservers(openTelemetry, Threads::getThreads);
    }
    return registerObservers(openTelemetry, ManagementFactory.getThreadMXBean());
  }

  /**
   * Registers the thread-count observer backed by the given {@link ThreadMXBean}.
   *
   * <p>On Java 9+ counts are broken down by daemon flag and thread state. Otherwise they are
   * broken down by daemon flag only.
   */
  // Visible for testing
  List<AutoCloseable> registerObservers(OpenTelemetry openTelemetry, ThreadMXBean threadBean) {
    return registerObservers(
        openTelemetry,
        isJava9OrNewer() ? Threads::java9AndNewerCallback : Threads::java8Callback,
        threadBean);
  }

  /**
   * Registers the thread-count observer backed by a supplier of live {@link Thread} objects.
   * Counts are broken down by daemon flag and thread state.
   */
  // Visible for testing
  List<AutoCloseable> registerObservers(
      OpenTelemetry openTelemetry, Supplier<Thread[]> threadSupplier) {
    return registerObservers(openTelemetry, Threads::java8ThreadCallback, threadSupplier);
  }

  /**
   * Builds the {@code jvm.thread.count} up-down counter.
   *
   * @param callbackProvider creates the measurement callback for {@code threadInfo}
   * @param threadInfo the source of thread information handed to {@code callbackProvider}
   * @return a list containing the single registered instrument
   */
  private static <T> List<AutoCloseable> registerObservers(
      OpenTelemetry openTelemetry,
      Function<T, Consumer<ObservableLongMeasurement>> callbackProvider,
      T threadInfo) {
    Meter meter = JmxRuntimeMetricsUtil.getMeter(openTelemetry);
    List<AutoCloseable> observables = new ArrayList<>();

    observables.add(
        meter
            .upDownCounterBuilder("jvm.thread.count")
            .setDescription("Number of executing platform threads.")
            .setUnit("{thread}")
            .buildWithCallback(callbackProvider.apply(threadInfo)));

    return observables;
  }

  /**
   * Handle to {@code ThreadInfo#isDaemon()}, which only exists on Java 9+. This is {@code null}
   * when the method cannot be found, and its presence doubles as the Java-version check.
   */
  @Nullable private static final MethodHandle THREAD_INFO_IS_DAEMON;

  static {
    MethodHandle isDaemon;
    try {
      isDaemon =
          MethodHandles.publicLookup()
              .findVirtual(ThreadInfo.class, "isDaemon", MethodType.methodType(boolean.class));
    } catch (NoSuchMethodException | IllegalAccessException e) {
      // Expected on Java 8: fall back to the Thread-enumeration or daemon-count-only paths.
      isDaemon = null;
    }
    THREAD_INFO_IS_DAEMON = isDaemon;
  }

  private static boolean isJava9OrNewer() {
    return THREAD_INFO_IS_DAEMON != null;
  }

  /**
   * Returns a callback that records daemon and non-daemon thread counts from the given bean.
   * Thread state is not reported on this path.
   */
  private static Consumer<ObservableLongMeasurement> java8Callback(ThreadMXBean threadBean) {
    return measurement -> {
      int daemonThreadCount = threadBean.getDaemonThreadCount();
      measurement.record(daemonThreadCount, DAEMON_ATTRIBUTES);
      // The two counts are read separately, so they are not a single atomic snapshot.
      measurement.record(threadBean.getThreadCount() - daemonThreadCount, NON_DAEMON_ATTRIBUTES);
    };
  }

  /**
   * Returns a callback that counts the threads produced by {@code supplier}, grouped by daemon
   * flag and thread state.
   */
  private static Consumer<ObservableLongMeasurement> java8ThreadCallback(
      Supplier<Thread[]> supplier) {
    return measurement -> {
      Map<Attributes, Long> counts = new HashMap<>();
      for (Thread thread : supplier.get()) {
        counts.merge(threadAttributes(thread), 1L, Long::sum);
      }
      counts.forEach((threadAttributes, count) -> measurement.record(count, threadAttributes));
    };
  }

  /**
   * Returns a snapshot of the live threads by enumerating the root thread group and its
   * subgroups. The returned array contains no {@code null} entries.
   */
  // Visible for testing
  static Thread[] getThreads() {
    ThreadGroup threadGroup = Thread.currentThread().getThreadGroup();
    while (threadGroup.getParent() != null) {
      threadGroup = threadGroup.getParent();
    }
    // activeCount() is only an estimate and new threads may start concurrently, so over-allocate.
    // If enumerate() fills the array completely, the result may have been silently truncated;
    // retry with a larger array until there is spare room.
    int capacity = threadGroup.activeCount() + 10;
    while (true) {
      Thread[] threads = new Thread[capacity];
      int resultSize = threadGroup.enumerate(threads);
      if (resultSize < threads.length) {
        Thread[] result = new Thread[resultSize];
        System.arraycopy(threads, 0, result, 0, resultSize);
        return result;
      }
      capacity *= 2;
    }
  }

  /**
   * Returns a callback that counts all live threads reported by {@code threadBean}, grouped by
   * daemon flag and thread state.
   */
  private static Consumer<ObservableLongMeasurement> java9AndNewerCallback(
      ThreadMXBean threadBean) {
    return measurement -> {
      Map<Attributes, Long> counts = new HashMap<>();
      long[] threadIds = threadBean.getAllThreadIds();
      for (ThreadInfo threadInfo : threadBean.getThreadInfo(threadIds)) {
        // getThreadInfo yields null for threads that terminated after getAllThreadIds().
        if (threadInfo == null) {
          continue;
        }
        counts.merge(threadAttributes(threadInfo), 1L, Long::sum);
      }
      counts.forEach((threadAttributes, count) -> measurement.record(count, threadAttributes));
    };
  }

  /**
   * Builds the daemon/state attributes for a {@link ThreadInfo}. This is only called on Java 9+,
   * where {@link #THREAD_INFO_IS_DAEMON} is non-null.
   *
   * @throws IllegalStateException if invoking {@code ThreadInfo#isDaemon()} fails
   */
  private static Attributes threadAttributes(ThreadInfo threadInfo) {
    boolean isDaemon;
    try {
      isDaemon = (boolean) requireNonNull(THREAD_INFO_IS_DAEMON).invoke(threadInfo);
    } catch (Throwable e) {
      throw new IllegalStateException("Unexpected error happened during ThreadInfo#isDaemon()", e);
    }
    String threadState = threadInfo.getThreadState().name().toLowerCase(Locale.ROOT);
    return Attributes.of(
        JvmAttributes.JVM_THREAD_DAEMON, isDaemon, JvmAttributes.JVM_THREAD_STATE, threadState);
  }

  /** Builds the daemon/state attributes for a live {@link Thread}. */
  private static Attributes threadAttributes(Thread thread) {
    boolean isDaemon = thread.isDaemon();
    String threadState = thread.getState().name().toLowerCase(Locale.ROOT);
    return Attributes.of(
        JvmAttributes.JVM_THREAD_DAEMON, isDaemon, JvmAttributes.JVM_THREAD_STATE, threadState);
  }

  private Threads() {}
}
