package internal

import (
	"context"
	"fmt"
	"strings"
	"time"

	"github.com/google/uuid"
	"github.com/pkg/errors"
	"github.com/sirupsen/logrus"
	"golang.org/x/sync/errgroup"

	"github.com/determined-ai/determined/master/internal/config"
	"github.com/determined-ai/determined/master/internal/db"
	"github.com/determined-ai/determined/master/internal/rm"
	"github.com/determined-ai/determined/master/internal/rm/tasklist"
	"github.com/determined-ai/determined/master/internal/sproto"
	"github.com/determined-ai/determined/master/internal/storage"
	"github.com/determined-ai/determined/master/internal/task"
	"github.com/determined-ai/determined/master/internal/user"
	"github.com/determined-ai/determined/master/pkg/logger"
	"github.com/determined-ai/determined/master/pkg/model"
	"github.com/determined-ai/determined/master/pkg/protoutils/protoconverter"
	"github.com/determined-ai/determined/master/pkg/schemas/expconf"
	"github.com/determined-ai/determined/master/pkg/tasks"
)

// fullDeleteGlob is the glob pattern "**/*".
// NOTE(review): the name suggests it is the checkpoint glob used when a
// checkpoint's files should be removed in their entirety; it is not
// referenced in this part of the file, so confirm against its callers.
const fullDeleteGlob = "**/*"

// runCheckpointGCForCheckpoints splits toDeleteCheckpoints into groups via
// storage.GroupCheckpoints and runs one checkpoint GC task per group
// concurrently, blocking until every task has returned.
//
// Each task receives a fresh task ID of the form "<expID>.<uuid>", its own
// by-value copy of *taskSpec, and the group's storage ID and checkpoint list.
// The remaining arguments (jobID, jobSubmissionTime, legacyConfig,
// checkpointGlobs, deleteTensorboards, agentUserGroup, owner, logCtx) are
// forwarded unchanged to every task.
//
// The grouping error is returned as-is. If any task fails, the error reported
// by errgroup.Group.Wait is returned wrapped with
// "one or more checkpoint GC jobs failed".
func runCheckpointGCForCheckpoints(
	rm rm.ResourceManager,
	db *db.PgDB,
	jobID model.JobID,
	jobSubmissionTime time.Time,
	taskSpec *tasks.TaskSpec,
	expID int,
	legacyConfig expconf.LegacyConfig,
	toDeleteCheckpoints []uuid.UUID,
	checkpointGlobs []string,
	deleteTensorboards bool,
	agentUserGroup *model.AgentUserGroup,
	owner *model.User,
	logCtx logger.Context,
) error {
	groups, err := storage.GroupCheckpoints(context.TODO(), toDeleteCheckpoints)
	if err != nil {
		return err
	}

	var wg errgroup.Group
	for _, g := range groups {
		// Copy what the goroutine needs into per-iteration locals. The
		// closure must not capture the range variable itself: with the
		// loop-variable semantics used before Go 1.22 every goroutine could
		// otherwise observe the last group instead of its own.
		storageID, checkpoints := g.StorageID, g.Checkpoints

		wg.Go(func() error {
			taskID := model.TaskID(fmt.Sprintf("%d.%s", expID, uuid.New()))
			return runCheckpointGCTask(
				rm, db, taskID, jobID, jobSubmissionTime, *taskSpec,
				expID, legacyConfig, storageID, checkpoints,
				checkpointGlobs, deleteTensorboards,
				agentUserGroup, owner, logCtx,
			)
		})
	}

	if err := wg.Wait(); err != nil {
		return fmt.Errorf("one or more checkpoint GC jobs failed: %w", err)
	}

	return nil
}

// runCheckpointGCTask starts one checkpoint GC allocation and blocks until
// that allocation exits, returning the allocation's exit error.
//
// toDeleteCheckpoints is rendered as a comma-separated list of IDs. When that
// list is empty and deleteTensorboards is false there is nothing to do and the
// function returns nil without touching the resource manager or the database.
//
// Otherwise it:
//   - resolves a resource pool and that pool's task container defaults,
//   - starts a user session for owner and stores the token, agentUserGroup and
//     owner on its copy of taskSpec,
//   - builds a tasks.GCCkptSpec, replacing LegacyConfig.CheckpointStorage with
//     the backend looked up for storageID when storageID is non-nil,
//   - persists a model.Task of type TaskTypeCheckpointGC under taskID/jobID,
//   - registers a dedicated GC job ID in tasklist.GroupPriorityChangeRegistry,
//   - starts the allocation through task.DefaultService.
//
// When the allocation exits, the task is marked complete and the registry
// entry is removed; failures of those two steps are logged, not returned. The
// same cleanup is performed if registration or allocation start-up fails, so
// that no open task row or registry entry is left behind.
func runCheckpointGCTask(
	rm rm.ResourceManager,
	pgDB *db.PgDB,
	taskID model.TaskID,
	jobID model.JobID,
	jobSubmissionTime time.Time,
	taskSpec tasks.TaskSpec,
	expID int,
	legacyConfig expconf.LegacyConfig,
	storageID *model.StorageBackendID,
	toDeleteCheckpoints []uuid.UUID,
	checkpointGlobs []string,
	deleteTensorboards bool,
	agentUserGroup *model.AgentUserGroup,
	owner *model.User,
	logCtx logger.Context,
) error {
	conv := &protoconverter.ProtoConverter{}
	checkpointStrIDs := conv.ToStringList(toDeleteCheckpoints)
	deleteCheckpointsStr := strings.Join(checkpointStrIDs, ",")

	if len(deleteCheckpointsStr) == 0 && !deleteTensorboards {
		// Early return as nothing to do
		return nil
	}

	// NOTE(review): the empty pool name presumably selects a default
	// resource pool; confirm against ResolveResourcePool's contract.
	rp, err := rm.ResolveResourcePool("", -1, 0)
	if err != nil {
		return fmt.Errorf("resolving resource pool: %w", err)
	}

	// taskSpec was passed by value, so the field assignments below only
	// affect this function's copy. It is still a shallow copy: anything
	// reachable through its pointer fields may be shared with the caller's
	// spec, so use caution when mutating it.
	tcd, err := rm.TaskContainerDefaults(
		rp,
		config.GetMasterConfig().TaskContainerDefaults)
	if err != nil {
		return fmt.Errorf("creating task container defaults: %w", err)
	}
	taskSpec.TaskContainerDefaults = tcd

	userSessionToken, err := user.StartSession(context.TODO(), owner)
	if err != nil {
		return errors.Wrapf(err, "unable to create user session for checkpoint gc")
	}
	taskSpec.UserSessionToken = userSessionToken
	taskSpec.AgentUserGroup = agentUserGroup
	taskSpec.Owner = owner

	gcSpec := tasks.GCCkptSpec{
		Base:               taskSpec,
		ExperimentID:       expID,
		LegacyConfig:       legacyConfig,
		ToDelete:           deleteCheckpointsStr,
		CheckpointGlobs:    checkpointGlobs,
		DeleteTensorboards: deleteTensorboards,
	}

	// Update checkpoint storage with storageID.
	if storageID != nil {
		checkpointStorage, err := storage.Backend(context.TODO(), *storageID)
		if err != nil {
			return fmt.Errorf("getting storage id %d in create gc task: %w", *storageID, err)
		}
		gcSpec.LegacyConfig.CheckpointStorage = checkpointStorage
	}

	logCtx = logger.MergeContexts(logCtx, logger.Context{
		"task-id":   taskID,
		"task-type": model.TaskTypeCheckpointGC,
	})
	syslog := logrus.WithField("component", "checkpointgc").WithFields(logCtx.Fields())

	if err := db.AddTask(context.TODO(), &model.Task{
		TaskID:     taskID,
		TaskType:   model.TaskTypeCheckpointGC,
		StartTime:  time.Now().UTC(),
		JobID:      &jobID,
		LogVersion: model.CurrentTaskLogVersion,
	}); err != nil {
		return errors.Wrapf(err, "persisting GC task %s", taskID)
	}

	allocationID := model.AllocationID(fmt.Sprintf("%s.%d", taskID, 1))
	gcJobID := model.JobID(fmt.Sprintf("checkpoint_gc-%s", allocationID))

	// completeTask closes the task row persisted above. It is best-effort:
	// a failure is logged and does not change the function's result.
	completeTask := func() {
		if err := db.CompleteTask(context.TODO(), taskID, time.Now().UTC()); err != nil {
			syslog.WithError(err).Error("marking GC task complete")
		}
	}
	// unregisterJob removes the registry entry added below; also best-effort.
	unregisterJob := func() {
		if err := tasklist.GroupPriorityChangeRegistry.Delete(gcJobID); err != nil {
			syslog.WithError(err).Error("deleting group priority change registry")
		}
	}

	// Buffered so that onExit can deliver the result without blocking.
	resultChan := make(chan error, 1)
	onExit := func(ae *task.AllocationExited) {
		completeTask()
		unregisterJob()
		resultChan <- ae.Err
	}

	if err := tasklist.GroupPriorityChangeRegistry.Add(gcJobID, nil); err != nil {
		// The task row already exists; close it so it is not left open.
		completeTask()
		return err
	}
	err = task.DefaultService.StartAllocation(logCtx, sproto.AllocateRequest{
		TaskID:            taskID,
		JobID:             gcJobID,
		JobSubmissionTime: jobSubmissionTime,
		AllocationID:      allocationID,
		Name:              fmt.Sprintf("Checkpoint GC (Experiment %d)", expID),
		FittingRequirements: sproto.FittingRequirements{
			SingleAgent: true,
		},
		ResourcePool: rp.String(),
	}, pgDB, rm, gcSpec, onExit)
	if err != nil {
		// NOTE(review): this assumes onExit is not invoked when
		// StartAllocation itself returns an error, so the cleanup it would
		// have done is performed here instead; confirm against the service.
		completeTask()
		unregisterJob()
		return err
	}
	return <-resultChan
}
