//go:build integration
// +build integration

package internal

import (
	"context"
	"fmt"
	"sync"
	"testing"
	"time"

	"github.com/go-co-op/gocron/v2"
	"github.com/jonboulle/clockwork"
	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/require"
	"github.com/uptrace/bun"

	"github.com/determined-ai/determined/master/internal/db"
	"github.com/determined-ai/determined/master/internal/logretention"
	"github.com/determined-ai/determined/master/pkg/model"
	"github.com/determined-ai/determined/master/pkg/ptrs"
	"github.com/determined-ai/determined/proto/pkg/apiv1"
	"github.com/determined-ai/determined/proto/pkg/experimentv1"
	"github.com/determined-ai/determined/proto/pkg/utilv1"
)

const (
	// pgTimeFormat is the time.Format layout used to render a Go time as a
	// timestamp literal that is embedded in SQL by quoteSetRetentionTime.
	//
	// The fractional part must be written as ".000000": Go only treats runs of
	// '0' or '9' after the decimal point as fractional seconds. The previous
	// layout ended in ".888738", where the '8's and '7' were emitted literally
	// and the '3' was parsed as the 12-hour clock hour, so the rendered fraction
	// was unrelated to the actual sub-second value and varied with the hour.
	pgTimeFormat = "2006-01-02T15:04:05.000000 -07:00:00"

	// logRetentionConfig100days is an experiment config fragment that sets
	// log_retention_days to 100.
	logRetentionConfig100days = `
retention_policy:
  log_retention_days: 100
`

	// logRetentionConfig1000days is an experiment config fragment that sets
	// log_retention_days to 1000.
	logRetentionConfig1000days = `
retention_policy:
  log_retention_days: 1000
`

	// logRetentionConfigForever is an experiment config fragment that sets
	// log_retention_days to -1; the tests in this file expect logs of such
	// experiments to never be deleted.
	logRetentionConfigForever = `
retention_policy:
  log_retention_days: -1
`
)

// setRetentionTime (re)defines the retention_timestamp() database function so
// that it returns the given SQL expression.
//
// timestamp is spliced into the function body verbatim, so it must already be
// valid SQL (a quoted literal or an expression such as a function call). It is
// only ever fed test-controlled values.
func setRetentionTime(timestamp string) error {
	stmt := fmt.Sprintf(`
	CREATE or REPLACE FUNCTION retention_timestamp() RETURNS TIMESTAMPTZ AS $$
    BEGIN
        RETURN %s;
    END
    $$ LANGUAGE PLPGSQL;
	`, timestamp)

	if _, err := db.Bun().NewRaw(stmt).Exec(context.Background()); err != nil {
		return err
	}
	return nil
}

// CompleteExpAndTrials sets the state of experiment expID, and of every row in
// the runs table whose id is in trialIDs, to model.CompletedState.
//
// The experiment is updated first; if that update fails its error is returned
// and the runs are left untouched.
func CompleteExpAndTrials(ctx context.Context, expID int32, trialIDs []int) error {
	expUpdate := db.Bun().NewUpdate().
		Table("experiments").
		Set("state = ?", model.CompletedState).
		Where("id = ?", expID)
	if _, err := expUpdate.Exec(ctx); err != nil {
		return err
	}

	runsUpdate := db.Bun().NewUpdate().
		Table("runs").
		Set("state = ?", model.CompletedState).
		Where("id IN (?)", bun.In(trialIDs))
	_, err := runsUpdate.Exec(ctx)
	return err
}

func quoteSetRetentionTime(timestamp time.Time) error {
	return setRetentionTime(fmt.Sprintf("'%s'", timestamp.Format(pgTimeFormat)))
}

// resetRetentionTime makes retention_timestamp() return
// transaction_timestamp() again, undoing any fixed time installed by a test.
func resetRetentionTime() error {
	const liveClockExpr = "transaction_timestamp()"
	return setRetentionTime(liveClockExpr)
}

// CreateTestRetentionExperiment creates and activates a grid-search experiment
// in project 1 through api.CreateExperiment and returns the created
// experiment together with the IDs of its trials and tasks.
//
// config is appended verbatim to the generated experiment YAML (callers pass a
// retention_policy fragment or an empty string). numTrials is written into the
// searcher's max_concurrent_trials; callers in this file rely on that yielding
// numTrials trials. Any error, or any warning in the create response, fails t
// immediately.
//
// nolint: exhaustruct
func CreateTestRetentionExperiment(
	ctx context.Context, t *testing.T, api *apiServer, config string, numTrials int,
) (*experimentv1.Experiment, []int, []model.TaskID) {
	// Replace the API server's resource manager with a mock that answers the
	// SmallerValueIsHigherPriority call.
	rm := MockRM()
	api.m.rm = rm
	rm.On("SmallerValueIsHigherPriority", mock.Anything).Return(true, nil)

	expConfig := fmt.Sprintf(`
entrypoint: test
checkpoint_storage:
  type: shared_fs
  host_path: /tmp
hyperparameters:
  n_filters1:
    count: 100
    maxval: 100
    minval: 1
    type: int
searcher:
  name: grid
  metric: none
  max_concurrent_trials: %d
%s
`, numTrials, config)

	resp, err := api.CreateExperiment(ctx, &apiv1.CreateExperimentRequest{
		ModelDefinition: []*utilv1.File{{Content: []byte{1}}},
		Config:          expConfig,
		ParentId:        0,
		Activate:        true,
		ProjectId:       1,
	})
	require.NoError(t, err)
	require.Empty(t, resp.Warnings)

	expIDs := []int{int(resp.Experiment.Id)}
	trialIDs, taskIDs, err := db.ExperimentsTrialAndTaskIDs(ctx, db.Bun(), expIDs)
	require.NoError(t, err)

	return resp.Experiment, trialIDs, taskIDs
}

// TestDeleteExpiredTaskLogs exercises logretention.DeleteExpiredTaskLogs
// against three experiments: one without a retention config, one with
// log_retention_days set to 1000 and one with log_retention_days set to -1.
// It fakes the database's notion of "now" by redefining retention_timestamp()
// and asserts, step by step, how many log rows each deletion call removes.
func TestDeleteExpiredTaskLogs(t *testing.T) {
	// Reset retention time to transaction time on exit.
	defer func() {
		require.NoError(t, resetRetentionTime())
	}()

	api, _, ctx := setupAPITest(t, nil)

	// Clear all logs.
	_, err := db.Bun().NewDelete().Model(&model.TaskLog{}).Where("TRUE").Exec(context.Background())
	require.NoError(t, err)

	// Create experiment1 with 5 trials and no special config.
	experiment1, trialIDs1, taskIDs1 := CreateTestRetentionExperiment(ctx, t, api, "", 5)
	require.Nil(t, experiment1.EndTime)
	require.Len(t, trialIDs1, 5)
	require.Len(t, taskIDs1, 5)

	// Create experiment2 with 5 trials and a config to expire in 1000 days.
	experiment2, trialIDs2, taskIDs2 := CreateTestRetentionExperiment(ctx, t, api, logRetentionConfig1000days, 5)
	require.Nil(t, experiment2.EndTime)
	require.Len(t, trialIDs2, 5)
	require.Len(t, taskIDs2, 5)

	// Create experiment3 with 5 trials and a config to never expire.
	experiment3, trialIDs3, taskIDs3 := CreateTestRetentionExperiment(ctx, t, api, logRetentionConfigForever, 5)
	require.Nil(t, experiment3.EndTime)
	require.Len(t, trialIDs3, 5)
	require.Len(t, taskIDs3, 5)

	// Collect the task IDs of all three experiments (15 in total).
	taskIDs := []model.TaskID{}
	taskIDs = append(taskIDs, taskIDs1...)
	taskIDs = append(taskIDs, taskIDs2...)
	taskIDs = append(taskIDs, taskIDs3...)

	// Add two log lines for each task; no task has an end time yet.
	for _, taskID := range taskIDs {
		task, err := db.TaskByID(ctx, taskID)
		require.NoError(t, err)
		require.Nil(t, task.EndTime)
		require.NoError(t, api.m.db.AddTaskLogs(
			[]*model.TaskLog{{TaskID: string(taskID), Log: "log1\n"}}))
		require.NoError(t, api.m.db.AddTaskLogs(
			[]*model.TaskLog{{TaskID: string(taskID), Log: "log2\n"}}))
	}

	// Check that the logs are there.
	for _, taskID := range taskIDs {
		logCount, err := api.m.db.TaskLogsCount(taskID, nil)
		require.NoError(t, err)
		require.Equal(t, 2, logCount)
	}

	// Move the database time 30 days into the future.
	require.NoError(t, quoteSetRetentionTime(time.Now().AddDate(0, 0, 30)))

	// Verify that the logs are still there if we delete with a 0 day expiration
	// (the tasks still have no end time at this point).
	count, err := logretention.DeleteExpiredTaskLogs(ctx, ptrs.Ptr(int16(0)))
	require.NoError(t, err)
	require.Equal(t, int64(0), count)

	// Give every task an end time of "now" (real wall-clock time).
	for _, taskID := range taskIDs {
		logCount, err := api.m.db.TaskLogsCount(taskID, nil)
		require.NoError(t, err)
		require.Equal(t, 2, logCount)
		task, err := db.TaskByID(context.Background(), taskID)
		require.NoError(t, err)
		task.EndTime = ptrs.Ptr(time.Now())
		res, err := db.Bun().NewUpdate().Model(task).Where("task_id = ?", taskID).Exec(context.Background())
		require.NoError(t, err)
		rows, err := res.RowsAffected()
		require.NoError(t, err)
		require.Equal(t, int64(1), rows)
	}
	// Verify that the logs are still there if we delete without an expiry.
	count, err = logretention.DeleteExpiredTaskLogs(ctx, nil)
	require.NoError(t, err)
	require.Equal(t, int64(0), count)

	// Move the database time 100 days (plus one second) into the future.
	require.NoError(t, quoteSetRetentionTime(time.Now().AddDate(0, 0, 100).Add(time.Second)))
	// Verify that the logs are deleted with a 100 day expiration: 5 tasks of
	// experiment1 with 2 log lines each.
	count, err = logretention.DeleteExpiredTaskLogs(ctx, ptrs.Ptr(int16(100)))
	require.NoError(t, err)
	require.Equal(t, int64(10), count)

	// Ensure that experiment1 logs are deleted.
	for _, taskID := range taskIDs1 {
		logCount, err := api.m.db.TaskLogsCount(taskID, nil)
		require.NoError(t, err)
		require.Zero(t, logCount)
	}
	// Ensure that experiment2 logs are not deleted.
	for _, taskID := range taskIDs2 {
		logCount, err := api.m.db.TaskLogsCount(taskID, nil)
		require.NoError(t, err)
		require.Equal(t, 2, logCount)
	}
	// Ensure that experiment3 logs are not deleted.
	for _, taskID := range taskIDs3 {
		logCount, err := api.m.db.TaskLogsCount(taskID, nil)
		require.NoError(t, err)
		require.Equal(t, 2, logCount)
	}

	// Move the database time 999 days (plus one second) into the future.
	require.NoError(t, quoteSetRetentionTime(time.Now().AddDate(0, 0, 999).Add(time.Second)))
	// Verify that the logs are not deleted with a 999 day expiration.
	count, err = logretention.DeleteExpiredTaskLogs(ctx, ptrs.Ptr(int16(999)))
	require.NoError(t, err)
	require.Equal(t, int64(0), count)

	// Move the database time 1000 days (plus one second) into the future.
	require.NoError(t, quoteSetRetentionTime(time.Now().AddDate(0, 0, 1000).Add(time.Second)))
	// Verify that the logs are deleted with a 1000 day expiration: 5 tasks of
	// experiment2 with 2 log lines each.
	count, err = logretention.DeleteExpiredTaskLogs(ctx, ptrs.Ptr(int16(1000)))
	require.NoError(t, err)
	require.Equal(t, int64(10), count)

	// Ensure that experiment2 logs are deleted.
	for _, taskID := range taskIDs2 {
		logCount, err := api.m.db.TaskLogsCount(taskID, nil)
		require.NoError(t, err)
		require.Zero(t, logCount)
	}
	// Ensure that experiment3 logs are not deleted.
	for _, taskID := range taskIDs3 {
		logCount, err := api.m.db.TaskLogsCount(taskID, nil)
		require.NoError(t, err)
		require.Equal(t, 2, logCount)
	}

	// Move the database time 100 years (plus one second) into the future.
	require.NoError(t, quoteSetRetentionTime(time.Now().AddDate(100, 0, 0).Add(time.Second)))
	// Verify that the remaining (never-expire) logs are not deleted even with a
	// 0 day expiration.
	count, err = logretention.DeleteExpiredTaskLogs(ctx, ptrs.Ptr(int16(0)))
	require.NoError(t, err)
	require.Equal(t, int64(0), count)

	// Ensure that experiment3 logs are not deleted.
	for _, taskID := range taskIDs3 {
		logCount, err := api.m.db.TaskLogsCount(taskID, nil)
		require.NoError(t, err)
		require.Equal(t, 2, logCount)
	}
}

// countTaskLogs returns the sum of db.TaskLogsCount over all taskIDs.
// On the first lookup error it stops and returns 0 together with that error.
func countTaskLogs(db *db.PgDB, taskIDs []model.TaskID) (int, error) {
	var total int
	for i := range taskIDs {
		n, err := db.TaskLogsCount(taskIDs[i], nil)
		if err != nil {
			return 0, err
		}
		total += n
	}
	return total, nil
}

// incrementScheduler steps the fake clock and the database retention time
// forward one day at a time, days times, starting from timestamp.
//
// For every step it waits for one waiter on fakeClock, arms the scheduler's
// TestingOnlySynchronizationHelper, moves both the database retention time and
// the fake clock to the new day, and then blocks until the helper is released
// (presumably by the scheduled retention job finishing; confirm in the
// logretention package).
//
// It returns the timestamp reached after the last step and the same fakeClock
// that was passed in.
func incrementScheduler(
	t *testing.T,
	lrs *logretention.Scheduler,
	timestamp time.Time,
	fakeClock clockwork.FakeClock,
	days int,
) (time.Time, clockwork.FakeClock) {
	current := timestamp
	for remaining := days; remaining > 0; remaining-- {
		// Do not advance until something is actually waiting on the clock.
		fakeClock.BlockUntil(1)
		lrs.TestingOnlySynchronizationHelper.Add(1)

		current = current.AddDate(0, 0, 1)
		require.NoError(t, quoteSetRetentionTime(current))

		// Jump the fake clock to exactly the new timestamp.
		delta := current.Sub(fakeClock.Now())
		fakeClock.Advance(delta)

		lrs.TestingOnlySynchronizationHelper.Wait()
	}
	return current, fakeClock
}

// TestScheduleRetentionNoConfig schedules a daily ("0 0 * * *") log retention
// job with a 10 day retention period on a scheduler driven by a fake clock,
// and checks the logs of an experiment that has no retention config of its
// own: they must survive while the tasks have no end time and for the first
// day after completion, and must be gone after nine more days.
func TestScheduleRetentionNoConfig(t *testing.T) {
	// Reset retention time to transaction time on exit.
	defer func() {
		require.NoError(t, resetRetentionTime())
	}()

	fakeClock := clockwork.NewFakeClock()
	lrs, err := logretention.NewScheduler(gocron.WithClock(fakeClock))
	require.NoError(t, err)
	defer func() {
		if err := lrs.Shutdown(); err != nil {
			t.Logf("failed to shutdown gocron.Scheduler: %v", err)
		}
	}()
	// incrementScheduler uses this WaitGroup to wait on the scheduler.
	lrs.TestingOnlySynchronizationHelper = &sync.WaitGroup{}

	api, _, ctx := setupAPITest(t, nil)

	// Run retention every day at midnight with a 10 day retention period.
	err = lrs.Schedule(model.LogRetentionPolicy{
		LogRetentionDays: ptrs.Ptr(int16(10)),
		Schedule:         ptrs.Ptr("0 0 * * *"),
	})
	require.NoError(t, err)

	// Clear all logs.
	_, err = db.Bun().NewDelete().Model(&model.TaskLog{}).Where("TRUE").Exec(context.Background())
	require.NoError(t, err)

	// Create an experiment with 5 trials and no special config.
	experiment, trialIDs, taskIDs := CreateTestRetentionExperiment(ctx, t, api, "", 5)
	require.Nil(t, experiment.EndTime)
	require.Len(t, trialIDs, 5)
	require.Len(t, taskIDs, 5)

	// Add two log lines for each task; no task has an end time yet.
	for _, taskID := range taskIDs {
		task, err := db.TaskByID(ctx, taskID)
		require.NoError(t, err)
		require.Nil(t, task.EndTime)
		require.NoError(t, api.m.db.AddTaskLogs(
			[]*model.TaskLog{{TaskID: string(taskID), Log: "log1\n"}}))
		require.NoError(t, api.m.db.AddTaskLogs(
			[]*model.TaskLog{{TaskID: string(taskID), Log: "log2\n"}}))
	}

	// Check that the logs are there.
	for _, taskID := range taskIDs {
		logCount, err := api.m.db.TaskLogsCount(taskID, nil)
		require.NoError(t, err)
		require.Equal(t, 2, logCount)
	}

	// Start from 00:01 tomorrow (local time) and step the scheduler forward by
	// one day from there.
	now := time.Now()
	midnight := time.Date(now.Year(), now.Month(), now.Day()+1, 0, 1, 0, 0, now.Location())
	midnight, fakeClock = incrementScheduler(t, lrs, midnight, fakeClock, 1)

	// Verify that the logs are still there.
	count, err := countTaskLogs(api.m.db, taskIDs)
	require.NoError(t, err)
	require.Equal(t, 10, count)

	// Give every task an end time of "now" (real wall-clock time).
	for _, taskID := range taskIDs {
		logCount, err := api.m.db.TaskLogsCount(taskID, nil)
		require.NoError(t, err)
		require.Equal(t, 2, logCount)
		task, err := db.TaskByID(context.Background(), taskID)
		require.NoError(t, err)
		task.EndTime = ptrs.Ptr(time.Now())
		res, err := db.Bun().NewUpdate().Model(task).Where("task_id = ?", taskID).Exec(context.Background())
		require.NoError(t, err)
		rows, err := res.RowsAffected()
		require.NoError(t, err)
		require.Equal(t, int64(1), rows)
	}

	// Mark the experiment and its trials as completed.
	err = CompleteExpAndTrials(ctx, experiment.Id, trialIDs)
	require.NoError(t, err)

	// Advance time by 1 day.
	midnight, fakeClock = incrementScheduler(t, lrs, midnight, fakeClock, 1)
	// Verify that the logs are still there.
	count, err = countTaskLogs(api.m.db, taskIDs)
	require.NoError(t, err)
	require.Equal(t, 10, count)

	// Advance time by 9 days.
	_, _ = incrementScheduler(t, lrs, midnight, fakeClock, 9)
	// Verify that logs are deleted.
	count, err = countTaskLogs(api.m.db, taskIDs)
	require.NoError(t, err)
	require.Zero(t, count)

	// Ensure that every task of the experiment has no logs left.
	for _, taskID := range taskIDs {
		logCount, err := api.m.db.TaskLogsCount(taskID, nil)
		require.NoError(t, err)
		require.Zero(t, logCount)
	}
}

// TestScheduleRetention100days schedules a daily ("0 0 * * *") log retention
// job with a 10 day retention period on a scheduler driven by a fake clock,
// and checks the logs of an experiment whose own config sets
// log_retention_days to 100: they must still be present 98 daily steps after
// the tasks are ended and must be gone after one more step.
func TestScheduleRetention100days(t *testing.T) {
	// Reset retention time to transaction time on exit.
	defer func() {
		require.NoError(t, resetRetentionTime())
	}()

	fakeClock := clockwork.NewFakeClock()
	lrs, err := logretention.NewScheduler(gocron.WithClock(fakeClock))
	require.NoError(t, err)
	defer func() {
		if err := lrs.Shutdown(); err != nil {
			t.Logf("failed to shutdown gocron.Scheduler: %v", err)
		}
	}()
	// incrementScheduler uses this WaitGroup to wait on the scheduler.
	lrs.TestingOnlySynchronizationHelper = &sync.WaitGroup{}

	api, _, ctx := setupAPITest(t, nil)

	// Run retention every day at midnight with a 10 day retention period.
	err = lrs.Schedule(model.LogRetentionPolicy{
		LogRetentionDays: ptrs.Ptr(int16(10)),
		Schedule:         ptrs.Ptr("0 0 * * *"),
	})
	require.NoError(t, err)

	// Clear all logs.
	_, err = db.Bun().NewDelete().Model(&model.TaskLog{}).Where("TRUE").Exec(context.Background())
	require.NoError(t, err)

	// Create an experiment with 5 trials and a config to expire in 100 days.
	experiment, trialIDs, taskIDs := CreateTestRetentionExperiment(ctx, t, api, logRetentionConfig100days, 5)
	require.Nil(t, experiment.EndTime)
	require.Len(t, trialIDs, 5)
	require.Len(t, taskIDs, 5)

	// Add two log lines for each task; no task has an end time yet.
	for _, taskID := range taskIDs {
		task, err := db.TaskByID(ctx, taskID)
		require.NoError(t, err)
		require.Nil(t, task.EndTime)
		require.NoError(t, api.m.db.AddTaskLogs(
			[]*model.TaskLog{{TaskID: string(taskID), Log: "log1\n"}}))
		require.NoError(t, api.m.db.AddTaskLogs(
			[]*model.TaskLog{{TaskID: string(taskID), Log: "log2\n"}}))
	}

	// Check that the logs are there.
	for _, taskID := range taskIDs {
		logCount, err := api.m.db.TaskLogsCount(taskID, nil)
		require.NoError(t, err)
		require.Equal(t, 2, logCount)
	}

	// Start from 00:01 tomorrow (local time) and step the scheduler forward by
	// one day from there.
	now := time.Now()
	midnight := time.Date(now.Year(), now.Month(), now.Day()+1, 0, 1, 0, 0, now.Location())
	midnight, fakeClock = incrementScheduler(t, lrs, midnight, fakeClock, 1)

	// Verify that the logs are still there.
	count, err := countTaskLogs(api.m.db, taskIDs)
	require.NoError(t, err)
	require.Equal(t, 10, count)

	// Give every task an end time of "now" (real wall-clock time).
	for _, taskID := range taskIDs {
		logCount, err := api.m.db.TaskLogsCount(taskID, nil)
		require.NoError(t, err)
		require.Equal(t, 2, logCount)
		task, err := db.TaskByID(context.Background(), taskID)
		require.NoError(t, err)
		task.EndTime = ptrs.Ptr(time.Now())
		res, err := db.Bun().NewUpdate().Model(task).Where("task_id = ?", taskID).Exec(context.Background())
		require.NoError(t, err)
		rows, err := res.RowsAffected()
		require.NoError(t, err)
		require.Equal(t, int64(1), rows)
	}

	// Mark the experiment and its trials as completed.
	err = CompleteExpAndTrials(ctx, experiment.Id, trialIDs)
	require.NoError(t, err)

	// Advance time by 98 days.
	midnight, fakeClock = incrementScheduler(t, lrs, midnight, fakeClock, 98)
	// Verify that no logs are deleted.
	count, err = countTaskLogs(api.m.db, taskIDs)
	require.NoError(t, err)
	require.Equal(t, 10, count)

	// Ensure that experiment logs are not deleted.
	for _, taskID := range taskIDs {
		logCount, err := api.m.db.TaskLogsCount(taskID, nil)
		require.NoError(t, err)
		require.Equal(t, 2, logCount)
	}

	// Move time 1 day in the future.
	_, _ = incrementScheduler(t, lrs, midnight, fakeClock, 1)
	// Verify that logs are deleted.
	count, err = countTaskLogs(api.m.db, taskIDs)
	require.NoError(t, err)
	require.Zero(t, count)

	// Ensure that experiment logs are deleted.
	for _, taskID := range taskIDs {
		logCount, err := api.m.db.TaskLogsCount(taskID, nil)
		require.NoError(t, err)
		require.Zero(t, logCount)
	}
}

// TestScheduleRetentionNeverExpire schedules a daily ("0 0 * * *") log
// retention job with a 10 day retention period on a scheduler driven by a fake
// clock, and checks that the logs of an experiment whose own config sets
// log_retention_days to -1 are still present after 365 daily steps.
func TestScheduleRetentionNeverExpire(t *testing.T) {
	// Reset retention time to transaction time on exit.
	defer func() {
		require.NoError(t, resetRetentionTime())
	}()

	fakeClock := clockwork.NewFakeClock()
	lrs, err := logretention.NewScheduler(gocron.WithClock(fakeClock))
	require.NoError(t, err)
	defer func() {
		if err := lrs.Shutdown(); err != nil {
			t.Logf("failed to shutdown gocron.Scheduler: %v", err)
		}
	}()
	// incrementScheduler uses this WaitGroup to wait on the scheduler.
	lrs.TestingOnlySynchronizationHelper = &sync.WaitGroup{}

	api, _, ctx := setupAPITest(t, nil)

	// Run retention every day at midnight with a 10 day retention period.
	err = lrs.Schedule(model.LogRetentionPolicy{
		LogRetentionDays: ptrs.Ptr(int16(10)),
		Schedule:         ptrs.Ptr("0 0 * * *"),
	})
	require.NoError(t, err)

	// Clear all logs.
	_, err = db.Bun().NewDelete().Model(&model.TaskLog{}).Where("TRUE").Exec(context.Background())
	require.NoError(t, err)

	// Create an experiment with 5 trials and config to never expire.
	experiment, trialIDs, taskIDs := CreateTestRetentionExperiment(ctx, t, api, logRetentionConfigForever, 5)
	require.Nil(t, experiment.EndTime)
	require.Len(t, trialIDs, 5)
	require.Len(t, taskIDs, 5)

	// Add two log lines for each task; no task has an end time yet.
	for _, taskID := range taskIDs {
		task, err := db.TaskByID(ctx, taskID)
		require.NoError(t, err)
		require.Nil(t, task.EndTime)
		require.NoError(t, api.m.db.AddTaskLogs(
			[]*model.TaskLog{{TaskID: string(taskID), Log: "log1\n"}}))
		require.NoError(t, api.m.db.AddTaskLogs(
			[]*model.TaskLog{{TaskID: string(taskID), Log: "log2\n"}}))
	}

	// Check that the logs are there.
	for _, taskID := range taskIDs {
		logCount, err := api.m.db.TaskLogsCount(taskID, nil)
		require.NoError(t, err)
		require.Equal(t, 2, logCount)
	}

	// Start from 00:01 tomorrow (local time) and step the scheduler forward by
	// one day from there.
	now := time.Now()
	midnight := time.Date(now.Year(), now.Month(), now.Day()+1, 0, 1, 0, 0, now.Location())
	midnight, fakeClock = incrementScheduler(t, lrs, midnight, fakeClock, 1)

	// Verify that the logs are still there.
	count, err := countTaskLogs(api.m.db, taskIDs)
	require.NoError(t, err)
	require.Equal(t, 10, count)

	// Give every task an end time of "now" (real wall-clock time).
	for _, taskID := range taskIDs {
		logCount, err := api.m.db.TaskLogsCount(taskID, nil)
		require.NoError(t, err)
		require.Equal(t, 2, logCount)
		task, err := db.TaskByID(context.Background(), taskID)
		require.NoError(t, err)
		task.EndTime = ptrs.Ptr(time.Now())
		res, err := db.Bun().NewUpdate().Model(task).Where("task_id = ?", taskID).Exec(context.Background())
		require.NoError(t, err)
		rows, err := res.RowsAffected()
		require.NoError(t, err)
		require.Equal(t, int64(1), rows)
	}

	// Mark the experiment and its trials as completed.
	err = CompleteExpAndTrials(ctx, experiment.Id, trialIDs)
	require.NoError(t, err)

	// Move time 1 year (365 daily steps) into the future.
	_, _ = incrementScheduler(t, lrs, midnight, fakeClock, 365)
	// Verify that no logs are deleted.
	count, err = countTaskLogs(api.m.db, taskIDs)
	require.NoError(t, err)
	require.Equal(t, 10, count)

	// Ensure that experiment logs are not deleted.
	for _, taskID := range taskIDs {
		logCount, err := api.m.db.TaskLogsCount(taskID, nil)
		require.NoError(t, err)
		require.Equal(t, 2, logCount)
	}
}
