/*
 * Licensed to Elasticsearch under one or more contributor
 * license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright
 * ownership. Elasticsearch licenses this file to you under
 * the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.elasticsearch.indices;

import static java.util.Collections.emptyMap;
import static java.util.Collections.emptySet;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.elasticsearch.common.util.concurrent.EsExecutors.PROCESSORS_SETTING;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

import org.apache.lucene.search.ReferenceManager;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.Version;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.codec.CodecService;
import org.elasticsearch.index.engine.EngineConfig;
import org.elasticsearch.index.engine.InternalEngine;
import org.elasticsearch.index.shard.IndexShard;
import org.elasticsearch.index.shard.IndexShardTestCase;
import org.elasticsearch.indices.recovery.RecoveryState;
import org.elasticsearch.threadpool.Scheduler.Cancellable;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.threadpool.ThreadPoolStats;

/**
 * Tests for {@link IndexingMemoryController}. They cover:
 * <ul>
 *   <li>validation of the {@code indices.memory.*} settings;</li>
 *   <li>which shard gets its indexing buffer written out once the shared buffer limit is exceeded;</li>
 *   <li>throttling of shards while too many bytes are being written;</li>
 *   <li>interaction with translog recovery and refreshes on real shards.</li>
 * </ul>
 */
public class IndexingMemoryControllerTests extends IndexShardTestCase {

    @Override
    public Settings threadPoolSettings() {
        // This test was failing on GitHub Actions because the processor count was so small that the
        // refresh thread pool ended up with only 1 thread. That was too small for this test to run correctly.
        // Therefore set the number of processors to 4 to make sure the refresh thread pool has at least 2 threads.
        return Settings.builder().put(PROCESSORS_SETTING.getKey(), 4).build();
    }

    /**
     * An {@link IndexingMemoryController} whose per-shard memory accounting lives in plain maps instead
     * of real shards.
     *
     * <p>The periodic check is effectively disabled: the interval is 200h and {@link #scheduleTask}
     * returns {@code null}. Tests drive the controller explicitly through {@link #simulateIndexing}
     * and {@code forceCheck()}.
     */
    static class MockController extends IndexingMemoryController {

        private static final long BYTES_PER_MB = 1024L * 1024L;

        // Size of each shard's indexing buffer (bytes not yet handed off for writing)
        final Map<IndexShard, Long> indexBufferRAMBytesUsed = new HashMap<>();

        // How many bytes each shard is currently moving to disk
        final Map<IndexShard, Long> writingBytes = new HashMap<>();

        // Shards that are currently throttled
        final Set<IndexShard> throttled = new HashSet<>();

        MockController(Settings settings) {
            super(Settings.builder()
                      .put("indices.memory.interval", "200h") // disable it
                      .put(settings)
                      .build(), null, null);
        }

        /**
         * Stops tracking the given shard, so it no longer appears in {@link #availableShards()}.
         * Its throttled state, if any, is left as is.
         */
        public void deleteShard(IndexShard shard) {
            indexBufferRAMBytesUsed.remove(shard);
            writingBytes.remove(shard);
        }

        /** Every shard that has been indexed into and not deleted. */
        @Override
        public List<IndexShard> availableShards() {
            return new ArrayList<>(indexBufferRAMBytesUsed.keySet());
        }

        /** Total heap attributed to the shard: buffered bytes plus bytes still being written. */
        @Override
        protected long getIndexBufferRAMBytesUsed(IndexShard shard) {
            return indexBufferRAMBytesUsed.getOrDefault(shard, 0L) + getShardWritingBytes(shard);
        }

        /** Bytes the shard is currently writing, or 0 for an unknown shard. */
        @Override
        protected long getShardWritingBytes(IndexShard shard) {
            return writingBytes.getOrDefault(shard, 0L);
        }

        /** No-op: shard idleness is not modelled by this mock. */
        @Override
        protected void checkIdle(IndexShard shard, long inactiveTimeNS) {
        }

        /**
         * Moves the shard's whole buffer into its "writing" bucket. The writing bucket stays
         * non-zero until {@link #doneWriting} is called.
         */
        @Override
        public void writeIndexingBufferAsync(IndexShard shard) {
            long bytes = indexBufferRAMBytesUsed.getOrDefault(shard, 0L);
            indexBufferRAMBytesUsed.put(shard, 0L);
            writingBytes.merge(shard, bytes, Long::sum);
        }

        /** Records throttling; fails if the shard was already throttled. */
        @Override
        public void activateThrottling(IndexShard shard) {
            assertThat(throttled.add(shard)).isTrue();
        }

        /** Records un-throttling; fails if the shard was not throttled. */
        @Override
        public void deactivateThrottling(IndexShard shard) {
            assertThat(throttled.remove(shard)).isTrue();
        }

        /** Simulates the shard finishing its pending write: its writing bytes drop to 0. */
        public void doneWriting(IndexShard shard) {
            writingBytes.put(shard, 0L);
        }

        /** Asserts the shard's (not yet written) buffer holds exactly {@code expectedMB} megabytes. */
        public void assertBuffer(IndexShard shard, int expectedMB) {
            long actual = indexBufferRAMBytesUsed.getOrDefault(shard, 0L);
            // long arithmetic: an int product would overflow for expectedMB > 2047
            assertThat(actual).isEqualTo(expectedMB * BYTES_PER_MB);
        }

        public void assertThrottled(IndexShard shard) {
            assertThat(throttled.contains(shard)).isTrue();
        }

        public void assertNotThrottled(IndexShard shard) {
            assertThat(throttled.contains(shard)).isFalse();
        }

        /** Asserts the shard is currently writing exactly {@code expectedMB} megabytes. */
        public void assertWriting(IndexShard shard, int expectedMB) {
            long actual = writingBytes.getOrDefault(shard, 0L);
            // long arithmetic: an int product would overflow for expectedMB > 2047
            assertThat(actual).isEqualTo(expectedMB * BYTES_PER_MB);
        }

        /**
         * Indexes one simulated document into the shard, which adds exactly 1 MB to its buffer.
         * It then runs a controller check.
         */
        public void simulateIndexing(IndexShard shard) {
            Long bytes = indexBufferRAMBytesUsed.get(shard);
            if (bytes == null) {
                // First time we are seeing this shard: start tracking its writing bytes as well.
                bytes = 0L;
                writingBytes.put(shard, 0L);
            }
            // Each doc we index takes up a megabyte!
            indexBufferRAMBytesUsed.put(shard, bytes + BYTES_PER_MB);
            forceCheck();
        }

        /** Never schedules the periodic check; tests call {@code forceCheck()} themselves. */
        @Override
        protected Cancellable scheduleTask(ThreadPool threadPool) {
            return null;
        }
    }

    public void testShardAdditionAndRemoval() throws IOException {

        MockController controller = new MockController(Settings.builder()
                                                           .put("indices.memory.index_buffer_size", "4mb").build());
        IndexShard shard0 = newStartedShard();
        controller.simulateIndexing(shard0);
        controller.assertBuffer(shard0, 1);

        // add another shard
        IndexShard shard1 = newStartedShard();
        controller.simulateIndexing(shard1);
        controller.assertBuffer(shard0, 1);
        controller.assertBuffer(shard1, 1);

        // remove first shard
        controller.deleteShard(shard0);
        controller.forceCheck();
        controller.assertBuffer(shard1, 1);

        // remove second shard
        controller.deleteShard(shard1);
        controller.forceCheck();

        // add a new one
        IndexShard shard2 = newStartedShard();
        controller.simulateIndexing(shard2);
        controller.assertBuffer(shard2, 1);
        closeShards(shard0, shard1, shard2);
    }

    public void testActiveInactive() throws IOException {

        MockController controller = new MockController(Settings.builder()
                                                           .put("indices.memory.index_buffer_size", "5mb")
                                                           .build());

        IndexShard shard0 = newStartedShard();
        controller.simulateIndexing(shard0);
        IndexShard shard1 = newStartedShard();
        controller.simulateIndexing(shard1);

        controller.assertBuffer(shard0, 1);
        controller.assertBuffer(shard1, 1);

        controller.simulateIndexing(shard0);
        controller.simulateIndexing(shard1);

        controller.assertBuffer(shard0, 2);
        controller.assertBuffer(shard1, 2);

        // Index into shard0 only. Total usage crosses the 5mb limit, so shard0 (the shard using the
        // most heap) has its buffer written out, while shard1 keeps its buffer.
        controller.simulateIndexing(shard0);
        controller.simulateIndexing(shard0);
        controller.assertBuffer(shard0, 0);
        controller.assertBuffer(shard1, 2);

        controller.simulateIndexing(shard1);
        controller.simulateIndexing(shard1);
        controller.assertBuffer(shard1, 4);
        controller.simulateIndexing(shard1);
        controller.simulateIndexing(shard1);
        // shard1 crossed 5 mb and is now cleared:
        controller.assertBuffer(shard1, 0);
        closeShards(shard0, shard1);
    }

    public void testMinBufferSizes() {
        MockController controller = new MockController(Settings.builder()
                                                           .put("indices.memory.index_buffer_size", "0.001%")
                                                           .put("indices.memory.min_index_buffer_size", "6mb").build());

        assertThat(controller.indexingBufferSize()).isEqualTo(new ByteSizeValue(6, ByteSizeUnit.MB));
    }

    public void testNegativeMinIndexBufferSize() {
        assertThatThrownBy(() -> new MockController(Settings.builder()
            .put("indices.memory.min_index_buffer_size", "-6mb").build())
        ).isExactlyInstanceOf(IllegalArgumentException.class)
            .hasMessage("failed to parse setting [indices.memory.min_index_buffer_size] with value [-6mb] as a size in bytes");
    }

    public void testNegativeInterval() {
        assertThatThrownBy(() -> new MockController(Settings.builder()
                .put("indices.memory.interval", "-42s").build())
        ).isExactlyInstanceOf(IllegalArgumentException.class)
            .hasMessage(
                "failed to parse setting [indices.memory.interval] with value " +
                "[-42s] as a time value: negative durations are not supported");
    }

    public void testNegativeShardInactiveTime() {
        assertThatThrownBy(() -> new MockController(Settings.builder()
            .put("indices.memory.shard_inactive_time", "-42s").build())
        ).isExactlyInstanceOf(IllegalArgumentException.class)
            .hasMessage(
                "failed to parse setting [indices.memory.shard_inactive_time] with value " +
                "[-42s] as a time value: negative durations are not supported");
    }

    public void testNegativeMaxIndexBufferSize() {
        assertThatThrownBy(() -> new MockController(Settings.builder()
                .put("indices.memory.max_index_buffer_size", "-6mb").build()))
            .isExactlyInstanceOf(IllegalArgumentException.class)
            .hasMessage("failed to parse setting [indices.memory.max_index_buffer_size] with value [-6mb] as a size in bytes");
    }

    public void testMaxBufferSizes() {
        MockController controller = new MockController(Settings.builder()
                                                           .put("indices.memory.index_buffer_size", "90%")
                                                           .put("indices.memory.max_index_buffer_size", "6mb").build());

        assertThat(controller.indexingBufferSize()).isEqualTo(new ByteSizeValue(6, ByteSizeUnit.MB));
    }

    public void testThrottling() throws Exception {

        MockController controller = new MockController(Settings.builder()
                                                           .put("indices.memory.index_buffer_size", "4mb").build());
        IndexShard shard0 = newStartedShard();
        IndexShard shard1 = newStartedShard();
        controller.simulateIndexing(shard0);
        controller.simulateIndexing(shard0);
        controller.simulateIndexing(shard0);
        controller.assertBuffer(shard0, 3);
        controller.simulateIndexing(shard1);
        controller.simulateIndexing(shard1);

        // We are now using 5 MB, so we should be writing shard0 since it's using the most heap:
        controller.assertWriting(shard0, 3);
        controller.assertWriting(shard1, 0);
        controller.assertBuffer(shard0, 0);
        controller.assertBuffer(shard1, 2);

        controller.simulateIndexing(shard0);
        controller.simulateIndexing(shard1);
        controller.simulateIndexing(shard1);

        // Now we are still writing 3 MB (shard0), and using 5 MB index buffers, so we should now 1) be writing shard1,
        // and 2) be throttling shard1:
        controller.assertWriting(shard0, 3);
        controller.assertWriting(shard1, 4);
        controller.assertBuffer(shard0, 1);
        controller.assertBuffer(shard1, 0);

        controller.assertNotThrottled(shard0);
        controller.assertThrottled(shard1);

        logger.info("--> Indexing more data");

        // More indexing to shard0
        controller.simulateIndexing(shard0);
        controller.simulateIndexing(shard0);
        controller.simulateIndexing(shard0);
        controller.simulateIndexing(shard0);

        // Now we are using 5 MB again, so shard0 should also be writing and now also be throttled:
        controller.assertWriting(shard0, 8);
        controller.assertWriting(shard1, 4);
        controller.assertBuffer(shard0, 0);
        controller.assertBuffer(shard1, 0);

        controller.assertThrottled(shard0);
        controller.assertThrottled(shard1);

        // Both shards finally finish writing, and throttling should stop:
        controller.doneWriting(shard0);
        controller.doneWriting(shard1);
        controller.forceCheck();
        controller.assertNotThrottled(shard0);
        controller.assertNotThrottled(shard1);
        closeShards(shard0, shard1);
    }

    /**
     * Indexes 100 docs, closes the shard, and re-initialises it behind a controller with a tiny
     * (50kb) buffer. It then checks that recovering from the store makes the controller write the
     * indexing buffer at least once.
     */
    public void testTranslogRecoveryWorksWithIMC() throws IOException {
        IndexShard shard = newStartedShard(true);
        for (int i = 0; i < 100; i++) {
            indexDoc(shard, Integer.toString(i), "{}", XContentType.JSON);
        }
        shard.close("simon says", false);
        AtomicReference<IndexShard> shardRef = new AtomicReference<>();
        Settings settings = Settings.builder().put("indices.memory.index_buffer_size", "50kb").build();
        // The controller only sees the shard once it has been re-initialised and published via shardRef.
        Iterable<IndexShard> iterable = () -> (shardRef.get() == null) ? Collections.emptyIterator()
            : Collections.singleton(shardRef.get()).iterator();
        AtomicInteger flushes = new AtomicInteger();
        IndexingMemoryController imc = new IndexingMemoryController(settings, threadPool, iterable) {
            @Override
            protected void writeIndexingBufferAsync(IndexShard shard) {
                assertThat(shardRef.get()).isEqualTo(shard);
                flushes.incrementAndGet();
                shard.writeIndexingBuffer();
            }
        };
        shard = reinitShard(shard, imc);
        shardRef.set(shard);
        assertThat(imc.availableShards()).isEmpty();
        DiscoveryNode localNode = new DiscoveryNode("foo", buildNewFakeTransportAddress(), emptyMap(), emptySet(), Version.CURRENT);
        shard.markAsRecovering("store", new RecoveryState(shard.routingEntry(), localNode, null));

        assertThat(imc.availableShards()).hasSize(1);
        assertThat(recoverFromStore(shard)).isTrue();
        assertThat(flushes.get())
            .as("we should have flushed in IMC at least once")
            .isGreaterThanOrEqualTo(1);
        closeShards(shard);
    }

    /**
     * Returns a copy of {@code config} with {@code listener} appended to its internal refresh
     * listeners. The copy uses a fresh {@link CodecService}; every other setting is taken from
     * {@code config}.
     */
    EngineConfig configWithRefreshListener(EngineConfig config, ReferenceManager.RefreshListener listener) {
        final List<ReferenceManager.RefreshListener> internalRefreshListeners = new ArrayList<>(config.getInternalRefreshListeners());
        internalRefreshListeners.add(listener);
        return new EngineConfig(config.getShardId(), config.getThreadPool(),
                                config.getIndexSettings(), config.getStore(), config.getMergePolicy(), config.getAnalyzer(),
                                new CodecService(), config.getEventListener(), config.getQueryCache(),
                                config.getQueryCachingPolicy(), config.getTranslogConfig(), config.getFlushMergesAfter(),
                                config.getExternalRefreshListeners(), internalRefreshListeners,
                                config.getCircuitBreakerService(), config.getGlobalCheckpointSupplier(), config.retentionLeasesSupplier(),
                                config.getPrimaryTermSupplier(), config.getTombstoneDocSupplier());
    }

    /**
     * Returns the current stats of the {@code refresh} thread pool.
     *
     * @throws AssertionError if the thread pool reports no refresh pool
     */
    ThreadPoolStats.Stats getRefreshThreadPoolStats() {
        final ThreadPoolStats stats = threadPool.stats();
        for (ThreadPoolStats.Stats s : stats) {
            if (s.getName().equals(ThreadPool.Names.REFRESH)) {
                return s;
            }
        }
        throw new AssertionError("refresh thread pool stats not found [" + stats + "]");
    }

    /**
     * While a refresh of the shard is blocked, repeated forced checks should not pile up further
     * refreshes of that shard.
     *
     * <p>The test expects all but one refresh task to complete while the latch is held. The remaining
     * task completes only after the latch is released.
     */
    public void testSkipRefreshIfShardIsRefreshingAlready() throws Exception {
        SetOnce<CountDownLatch> refreshLatch = new SetOnce<>();
        ReferenceManager.RefreshListener refreshListener = new ReferenceManager.RefreshListener() {
            @Override
            public void beforeRefresh() {
                if (refreshLatch.get() != null) {
                    try {
                        refreshLatch.get().await();
                    } catch (InterruptedException e) {
                        // restore the interrupt flag so the refresh thread can still observe it
                        Thread.currentThread().interrupt();
                        throw new AssertionError(e);
                    }
                }
            }

            @Override
            public void afterRefresh(boolean didRefresh) {

            }
        };
        IndexShard shard = newStartedShard(
            randomBoolean(),
            Settings.EMPTY,
            List.of(idxSettings -> Optional.of(config -> new InternalEngine(configWithRefreshListener(config, refreshListener))))
        );
        refreshLatch.set(new CountDownLatch(1)); // block refresh
        final IndexingMemoryController controller = new IndexingMemoryController(
            Settings.builder().put("indices.memory.interval", "200h") // disable it
                .put("indices.memory.index_buffer_size", "1024b")
                .build(),
            threadPool,
            Collections.singleton(shard)) {
            @Override
            protected long getIndexBufferRAMBytesUsed(IndexShard shard) {
                // always above the 1024b limit so that every check wants to write this shard
                return randomLongBetween(1025, 10 * 1024 * 1024);
            }

            @Override
            protected long getShardWritingBytes(IndexShard shard) {
                return 0L;
            }
        };
        int iterations = randomIntBetween(10, 100);
        ThreadPoolStats.Stats beforeStats = getRefreshThreadPoolStats();
        for (int i = 0; i < iterations; i++) {
            controller.forceCheck();
        }

        assertBusy(() -> {
            ThreadPoolStats.Stats stats = getRefreshThreadPoolStats();
            assertThat(stats.getCompleted()).isEqualTo(beforeStats.getCompleted() + iterations - 1);
        });

        refreshLatch.get().countDown(); // allow refresh
        assertBusy(() -> {
            ThreadPoolStats.Stats stats = getRefreshThreadPoolStats();
            assertThat(stats.getCompleted()).isEqualTo(beforeStats.getCompleted() + iterations);
        });
        closeShards(shard);
    }
}
