// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
//
// Copyright 2024 FeatureForm Inc.
//

package scheduling

import (
	"fmt"
	"sort"
	"time"

	mapset "github.com/deckarep/golang-set/v2"

	"github.com/featureform/helpers/notifications"
	"github.com/featureform/metadata/proto"

	"github.com/featureform/fferr"
	"github.com/featureform/ffsync"
	ptypes "github.com/featureform/provider/types"
	ss "github.com/featureform/storage"
)

// EmptyList is zero, the length of an empty list or string.
const EmptyList int = 0

type (
	// TaskMetadataList is an ordered collection of task metadata records.
	TaskMetadataList []TaskMetadata

	// TaskRunList is an ordered collection of task run metadata records.
	TaskRunList []TaskRunMetadata
)

// FilterByStatus returns a new, non-nil TaskRunList holding, in their original
// order, only the runs whose Status is one of statuses. The receiver is left
// untouched. Calling it with no statuses yields an empty list.
func (trl TaskRunList) FilterByStatus(statuses ...Status) TaskRunList {
	wanted := mapset.NewSet(statuses...)
	filtered := make(TaskRunList, 0)
	for idx := range trl {
		if !wanted.Contains(trl[idx].Status) {
			continue
		}
		filtered = append(filtered, trl[idx])
	}
	return filtered
}

type (
	// TaskManagerType is a string-valued type. It is not referenced elsewhere
	// in this file; presumably it names a task manager implementation.
	TaskManagerType string

	// TaskMetadataManager persists task and task-run metadata in Storage.
	// idGenerator hands out the IDs for new tasks and task runs, and notifier
	// (which may be nil) is used to announce run status changes.
	TaskMetadataManager struct {
		Storage     ss.MetadataStorage
		idGenerator ffsync.OrderedIdGenerator
		notifier    notifications.Notifier
	}
)

// SyncIncompleteRuns rebuilds the unfinished-runs index from the stored run
// metadata. Every run whose status is PENDING or RUNNING gets an entry keyed
// by UnfinishedTaskRunPath(run.ID) whose value is the owning task's ID.
//
// All entries are written with a single MultiCreate call. When there are no
// incomplete runs, no write is issued at all. Errors from listing the runs or
// from the write are returned unchanged.
func (m *TaskMetadataManager) SyncIncompleteRuns() error {
	runs, err := m.GetAllTaskRuns()
	if err != nil {
		return err
	}

	incomplete := runs.FilterByStatus(PENDING, RUNNING)
	if len(incomplete) == 0 {
		// Nothing to index: skip the round trip instead of sending an empty
		// batch write to the storage backend.
		return nil
	}

	createMap := make(map[string]string, len(incomplete))
	for _, run := range incomplete {
		createMap[UnfinishedTaskRunPath(run.ID)] = run.TaskId.String()
	}
	return m.Storage.MultiCreate(createMap)
}

// CreateTask allocates a new task ID and stores two records for it: the task's
// TaskMetadata (under its TaskMetadataKey) and an empty run history (TaskRuns,
// under its TaskRunKey). It returns the stored metadata.
//
// Both records are serialized before anything is written, so a marshalling
// failure leaves storage untouched. If writing the run history fails after the
// metadata was written, the metadata record is deleted again (best effort; a
// failed rollback is logged) so no task is left without a run history.
// CreateTaskRun and GetLatestRun both read that history key. The original
// write error is returned in either case.
func (m *TaskMetadataManager) CreateTask(name string, tType TaskType, target TaskTarget) (TaskMetadata, error) {
	id, err := m.idGenerator.NextId("task")
	if err != nil {
		return TaskMetadata{}, err
	}
	uintID, ok := id.(ffsync.Uint64OrderedId)
	if !ok {
		return TaskMetadata{}, fferr.NewInternalErrorf("cannot use type %T as task id", id)
	}

	metadata := TaskMetadata{
		ID:          TaskID(uintID),
		Name:        name,
		TaskType:    tType,
		Target:      target,
		TargetType:  target.Type(),
		DateCreated: time.Now().UTC(),
	}

	// I do this serialize and deserialize a lot in this file. Would be nice to have set and get helpers that deal with
	// all the converting instead
	serializedMetadata, err := metadata.Marshal()
	if err != nil {
		return TaskMetadata{}, err
	}

	runs := TaskRuns{
		TaskID: metadata.ID,
		Runs:   []TaskRunSimple{},
	}
	serializedRuns, err := runs.Marshal()
	if err != nil {
		return TaskMetadata{}, err
	}

	key := TaskMetadataKey{taskID: metadata.ID}
	if err := m.Storage.Create(key.String(), string(serializedMetadata)); err != nil {
		return TaskMetadata{}, err
	}

	taskRunKey := TaskRunKey{taskID: metadata.ID}
	if err := m.Storage.Create(taskRunKey.String(), string(serializedRuns)); err != nil {
		// Undo the metadata write so the task does not exist without a run history.
		if _, deleteErr := m.Storage.Delete(key.String()); deleteErr != nil {
			m.Storage.Logger.Errorf("could not roll back metadata for task %s: %s",
				metadata.ID.String(),
				deleteErr.Error())
		}
		return TaskMetadata{}, err
	}

	return metadata, nil
}

// GetTaskByID loads and decodes the TaskMetadata stored under id's task
// metadata key. An empty stored value is reported as an internal
// "task not found" error. Storage and decoding errors are returned unchanged.
func (m *TaskMetadataManager) GetTaskByID(id TaskID) (TaskMetadata, error) {
	raw, err := m.Storage.Get(TaskMetadataKey{taskID: id}.String())
	switch {
	case err != nil:
		return TaskMetadata{}, err
	case len(raw) == EmptyList:
		return TaskMetadata{}, fferr.NewInternalError(fmt.Errorf("task not found for id: %s", id.String()))
	}

	var task TaskMetadata
	if err := task.Unmarshal([]byte(raw)); err != nil {
		return TaskMetadata{}, err
	}
	return task, nil
}

// GetAllTasks lists every record under the task-metadata key prefix and
// decodes each one into a TaskMetadata. The order of the result is unspecified.
func (m *TaskMetadataManager) GetAllTasks() (TaskMetadataList, error) {
	prefix := TaskMetadataKey{}.String()
	records, err := m.Storage.List(prefix)
	if err != nil {
		return TaskMetadataList{}, err
	}
	return m.convertToTaskMetadataList(records)
}

// convertToTaskMetadataList decodes every value of metadata into a
// TaskMetadata. Keys are ignored. Because map iteration order is random, the
// order of the result is unspecified. The first decoding error aborts the
// conversion and is returned together with an empty list.
func (m *TaskMetadataManager) convertToTaskMetadataList(metadata map[string]string) (TaskMetadataList, error) {
	decoded := make(TaskMetadataList, 0, len(metadata))
	for _, serialized := range metadata {
		var task TaskMetadata
		if err := task.Unmarshal([]byte(serialized)); err != nil {
			return TaskMetadataList{}, err
		}
		decoded = append(decoded, task)
	}
	return decoded, nil
}

// sortByIdDesc reorders s in place so that the entry with the largest numeric
// RunID comes first. It panics if any RunID's Value is not a uint64.
func sortByIdDesc(s []TaskRunSimple) {
	numericID := func(idx int) uint64 {
		return s[idx].RunID.Value().(uint64)
	}
	sort.Slice(s, func(a, b int) bool {
		return numericID(a) > numericID(b)
	})
}

// LastSuccessfulRun returns the ID of the run with the highest run ID among
// runs.Runs whose stored status is READY. It returns nil if no such run exists.
// Each candidate's metadata is loaded with GetRunByID, and the first lookup
// error is returned.
//
// runs is not modified. The ordering happens on a private copy, because
// CreateTaskRun persists runs.Runs right after calling this, and sorting it in
// place would rewrite the stored history order.
func (m *TaskMetadataManager) LastSuccessfulRun(taskID TaskID, runs TaskRuns) (TaskRunID, error) {
	ordered := make([]TaskRunSimple, len(runs.Runs))
	copy(ordered, runs.Runs)
	sortByIdDesc(ordered)

	for _, run := range ordered {
		last, err := m.GetRunByID(taskID, run.RunID)
		if err != nil {
			return nil, err
		}
		if last.Status == READY {
			return last.ID, nil
		}
	}
	return nil, nil
}

// CreateTaskRun registers a new PENDING run of task taskID and returns its
// metadata. The run ID comes from idGenerator. Target and TargetType are
// copied from the parent task, and LastSuccessful is filled from
// LastSuccessfulRun.
//
// Three keys are written in one MultiCreate call:
//   - the task's run history, with the new run appended,
//   - the run's own metadata record,
//   - the run's entry in the unfinished-runs index.
//
// NOTE(review): the run history is read, modified and written back without
// any visible locking here. Concurrent calls for the same task may overwrite
// each other's appended run; confirm whether callers serialize this.
func (m *TaskMetadataManager) CreateTaskRun(name string, taskID TaskID, trigger Trigger) (TaskRunMetadata, error) {
	// ids will be generated by TM
	history, err := m.getTaskRunRecords(taskID)
	if err != nil {
		return TaskRunMetadata{}, err
	}

	lastSuccess, err := m.LastSuccessfulRun(taskID, history)
	if err != nil {
		return TaskRunMetadata{}, err
	}

	parentTask, err := m.GetTaskByID(taskID)
	if err != nil {
		return TaskRunMetadata{}, err
	}

	nextID, err := m.idGenerator.NextId("task_run")
	if err != nil {
		return TaskRunMetadata{}, err
	}
	runID, ok := nextID.(ffsync.Uint64OrderedId)
	if !ok {
		return TaskRunMetadata{}, fferr.NewInternalErrorf("cannot use type %T as task run id", nextID)
	}

	createdAt := time.Now().UTC()
	run := TaskRunMetadata{
		ID:             TaskRunID(runID),
		TaskId:         taskID,
		Name:           name,
		Trigger:        trigger,
		TriggerType:    trigger.Type(),
		Target:         parentTask.Target,
		TargetType:     parentTask.TargetType,
		Status:         PENDING,
		StartTime:      createdAt,
		LastSuccessful: lastSuccess,
	}

	history.Runs = append(history.Runs, TaskRunSimple{RunID: run.ID, DateCreated: createdAt})

	serializedHistory, err := history.Marshal()
	if err != nil {
		return TaskRunMetadata{}, err
	}
	serializedRun, err := run.Marshal()
	if err != nil {
		return TaskRunMetadata{}, err
	}

	historyKey := TaskRunKey{taskID: taskID}
	runKey := TaskRunMetadataKey{taskID: taskID, runID: run.ID, date: createdAt}
	writes := map[string]string{
		historyKey.String():           string(serializedHistory),
		runKey.String():               string(serializedRun),
		UnfinishedTaskRunPath(run.ID): run.TaskId.String(),
	}
	if err := m.Storage.MultiCreate(writes); err != nil {
		return TaskRunMetadata{}, err
	}
	return run, nil
}

// latestTaskRun returns the RunID of the entry in runs.Runs with the latest
// DateCreated. When several entries share that timestamp, the earliest one in
// the slice wins. An empty history yields a NoRunsForTask error for
// runs.TaskID instead of an index-out-of-range panic.
func (m *TaskMetadataManager) latestTaskRun(runs TaskRuns) (TaskRunID, error) {
	if len(runs.Runs) == 0 {
		return nil, fferr.NewNoRunsForTaskError(runs.TaskID.String())
	}

	latest := runs.Runs[0]
	for _, run := range runs.Runs[1:] {
		if run.DateCreated.After(latest.DateCreated) {
			latest = run
		}
	}
	return latest.RunID, nil
}

// GetLatestRun returns the metadata of taskID's run with the most recent
// creation date. It returns a NoRunsForTask error when the task has no runs.
//
// The result is not guaranteed to be completely accurate. Use it only for
// display purposes (Dashboard and CLI), never for internal business logic.
func (m *TaskMetadataManager) GetLatestRun(taskID TaskID) (TaskRunMetadata, error) {
	history, err := m.getTaskRunRecords(taskID)
	switch {
	case err != nil:
		return TaskRunMetadata{}, err
	case len(history.Runs) == 0:
		return TaskRunMetadata{}, fferr.NewNoRunsForTaskError(taskID.String())
	}

	latestID, err := m.latestTaskRun(history)
	if err != nil {
		return TaskRunMetadata{}, err
	}

	latestRun, err := m.GetRunByID(taskID, latestID)
	if err != nil {
		return TaskRunMetadata{}, err
	}
	return latestRun, nil
}

// GetTaskRunMetadata returns the full metadata of every run listed in taskID's
// run history, in the order the history stores them. The first lookup failure
// aborts the call and is returned together with an empty list.
func (m *TaskMetadataManager) GetTaskRunMetadata(taskID TaskID) (TaskRunList, error) {
	history, err := m.getTaskRunRecords(taskID)
	if err != nil {
		return TaskRunList{}, err
	}

	result := make(TaskRunList, 0, len(history.Runs))
	for i := range history.Runs {
		run, err := m.GetRunByID(taskID, history.Runs[i].RunID)
		if err != nil {
			return TaskRunList{}, err
		}
		result = append(result, run)
	}
	return result, nil
}

// getTaskRunRecords reads and decodes the run history (TaskRuns) stored under
// taskID's TaskRunKey. Storage and decoding errors are returned unchanged.
func (m *TaskMetadataManager) getTaskRunRecords(taskID TaskID) (TaskRuns, error) {
	historyKey := TaskRunKey{taskID: taskID}
	raw, err := m.Storage.Get(historyKey.String())
	if err != nil {
		return TaskRuns{}, err
	}

	var history TaskRuns
	if err := history.Unmarshal([]byte(raw)); err != nil {
		return TaskRuns{}, err
	}
	return history, nil
}

// GetRunByID returns the stored metadata of run runID of task taskID.
//
// The run must appear in the task's run history, because the creation date
// recorded there is part of the run's metadata key. A run missing from the
// history yields a KeyNotFound error for the history key. Storage and decoding
// errors are returned unchanged.
func (m *TaskMetadataManager) GetRunByID(taskID TaskID, runID TaskRunID) (TaskRunMetadata, error) {
	history, err := m.getTaskRunRecords(taskID)
	if err != nil {
		return TaskRunMetadata{}, err
	}

	// Want to move this logic out
	found, entry := history.ContainsRun(runID)
	if !found {
		historyKey := TaskRunKey{taskID: taskID}
		return TaskRunMetadata{}, fferr.NewKeyNotFoundError(historyKey.String(), fmt.Errorf("run not found"))
	}

	metadataKey := TaskRunMetadataKey{taskID: taskID, runID: entry.RunID, date: entry.DateCreated}
	raw, err := m.Storage.Get(metadataKey.String())
	if err != nil {
		return TaskRunMetadata{}, err
	}

	var run TaskRunMetadata
	if err := run.Unmarshal([]byte(raw)); err != nil {
		return TaskRunMetadata{}, err
	}
	return run, nil
}

// GetRunsByDate returns every run whose StartTime lies within [start, end].
// Both bounds are inclusive. If start is after end, the result is empty.
//
// Run metadata keys embed the run's start date; CreateTaskRun builds them from
// a UTC timestamp. getRunsForDay lists a single day's key prefix and filters it
// by [start, end]. This method therefore visits every UTC calendar day from
// start's day through end's day. Iteration is currently per day; supporting
// hour or minute granularity would require changing both this loop and
// getRunsForDay.
func (m *TaskMetadataManager) GetRunsByDate(start time.Time, end time.Time) (TaskRunList, error) {
	var runs []TaskRunMetadata

	// Walk whole days. Stepping from start's exact time-of-day would skip end's
	// day whenever end's time-of-day is earlier than start's: 23:00 on day 1 to
	// 01:00 on day 2 would only visit day 1. Day boundaries are computed in UTC
	// to match the UTC timestamps the keys are built from.
	startUTC := start.UTC()
	lastInstant := end.UTC()
	firstDay := time.Date(startUTC.Year(), startUTC.Month(), startUTC.Day(), 0, 0, 0, 0, time.UTC)

	for day := firstDay; !day.After(lastInstant); day = day.AddDate(0, 0, 1) {
		dayRuns, err := m.getRunsForDay(day, start, end)
		if err != nil {
			return []TaskRunMetadata{}, err
		}
		runs = append(runs, dayRuns...)
	}

	return runs, nil
}

// getRunsForDay lists the run metadata under the key prefix that
// TaskRunMetadataKey.TruncateToDay produces for date (presumably that calendar
// day's bucket; verify against TruncateToDay). It keeps only runs whose
// StartTime lies within [start, end], with both bounds inclusive. The order of
// the result is unspecified. The first decoding error aborts the call.
func (m *TaskMetadataManager) getRunsForDay(date time.Time, start time.Time, end time.Time) ([]TaskRunMetadata, error) {
	dayKey := TaskRunMetadataKey{date: date}
	records, err := m.Storage.List(dayKey.TruncateToDay())
	if err != nil {
		return []TaskRunMetadata{}, err
	}

	var inRange []TaskRunMetadata
	for _, serialized := range records {
		var run TaskRunMetadata
		if err := run.Unmarshal([]byte(serialized)); err != nil {
			return []TaskRunMetadata{}, err
		}

		withinWindow := !run.StartTime.Before(start) && !run.StartTime.After(end)
		if withinWindow {
			inRange = append(inRange, run)
		}
	}
	return inRange, nil
}

// GetAllTaskRuns lists and decodes every record under the task-run metadata
// key prefix. The order of the result is unspecified, and the result is nil
// when nothing is stored. The first decoding error aborts the call.
func (m *TaskMetadataManager) GetAllTaskRuns() (TaskRunList, error) {
	prefix := TaskRunMetadataKey{}.String()
	records, err := m.Storage.List(prefix)
	if err != nil {
		return []TaskRunMetadata{}, err
	}

	var all TaskRunList
	for _, serialized := range records {
		var run TaskRunMetadata
		if err := run.Unmarshal([]byte(serialized)); err != nil {
			return []TaskRunMetadata{}, err
		}
		all = append(all, run)
	}
	return all, nil
}

// GetUnfinishedTaskRuns returns the full metadata of every run listed in the
// unfinished-runs index. Each index entry maps a run path (parsed into a run
// ID) to the owning task's ID string, and the metadata is loaded with
// GetRunByID. The first parse or lookup error aborts the call and is returned
// with a nil list. The order of the result is unspecified.
func (m *TaskMetadataManager) GetUnfinishedTaskRuns() (TaskRunList, error) {
	index, err := m.Storage.List(unfinishedTaskRunPath.Prefix())
	if err != nil {
		return []TaskRunMetadata{}, err
	}

	load := func(runPath, ownerTaskID string) (TaskRunMetadata, error) {
		runID, err := unfinishedTaskRunPath.Parse(runPath)
		if err != nil {
			return TaskRunMetadata{}, err
		}
		taskID, err := ParseTaskID(ownerTaskID)
		if err != nil {
			return TaskRunMetadata{}, err
		}
		return m.GetRunByID(taskID, runID)
	}

	var unfinished []TaskRunMetadata
	for runPath, owner := range index {
		run, err := load(runPath, owner)
		if err != nil {
			return nil, err
		}
		unfinished = append(unfinished, run)
	}
	return unfinished, nil
}

// SetRunStatus moves run runID of task taskID to status.Status and records the
// accompanying error text on the run. It notifies m.notifier when the stored
// status actually changed.
//
// The change is applied through Storage.Update. Inside that callback:
//   - the requested transition is validated against the stored status,
//   - a FAILED status must carry an ErrorStatus or ErrorMessage,
//   - the run's unfinished-runs index entry is created for PENDING and RUNNING,
//     and deleted for READY, FAILED and CANCELLED.
//
// A nil status is rejected with an InvalidArgument error. Otherwise it returns
// the lookup error from GetRunByID, or the error from Storage.Update, which
// includes any error returned by the callback.
func (m *TaskMetadataManager) SetRunStatus(runID TaskRunID, taskID TaskID, status *proto.ResourceStatus) error {
	if status == nil {
		return fferr.NewInvalidArgumentError(fmt.Errorf("status cannot be nil"))
	}

	fetchedMetadata, getRunErr := m.GetRunByID(taskID, runID)
	if getRunErr != nil {
		return getRunErr
	}

	requested := Status(status.Status)

	var prevStatus, newStatus Status
	var updatedMetadata TaskRunMetadata
	// applied is set once the callback has built the updated metadata. If the
	// callback bails out earlier (invalid transition, missing error, index write
	// failure), updatedMetadata is still the zero value and there is no status
	// change to announce.
	applied := false

	updateStatus := func(runMetadata string) (string, error) {
		applied = false

		updatedRunMetadata := TaskRunMetadata{}
		if unmarshalErr := updatedRunMetadata.Unmarshal([]byte(runMetadata)); unmarshalErr != nil {
			return "", fferr.NewInternalError(unmarshalErr)
		}

		// set the old status first, we'll need it to notify with later
		prevStatus = updatedRunMetadata.Status

		if validateErr := updatedRunMetadata.Status.validateTransition(requested); validateErr != nil {
			return "", validateErr
		}

		if requested == FAILED && status.ErrorStatus == nil && status.ErrorMessage == "" {
			return "", fferr.NewInvalidArgumentError(fmt.Errorf("error is required for failed status"))
		}

		// Keep the unfinished-runs index in sync with the new status.
		switch requested {
		case PENDING, RUNNING:
			if createErr := m.Storage.Create(UnfinishedTaskRunPath(runID), fetchedMetadata.TaskId.String()); createErr != nil {
				return "", createErr
			}
		case READY, FAILED, CANCELLED:
			if _, deleteErr := m.Storage.Delete(UnfinishedTaskRunPath(runID)); deleteErr != nil {
				return "", deleteErr
			}
		}

		// Handles old and the new fields for storing the error message
		updatedRunMetadata.Status = requested
		if status.ErrorStatus == nil && status.ErrorMessage != "" {
			updatedRunMetadata.Error = status.ErrorMessage
		} else {
			// Set the error for use in the CLI and dashboard. Should move this logic out
			updatedRunMetadata.Error = fferr.ToDashboardError(status)
		}

		// grab the new metadata and its status
		updatedMetadata = updatedRunMetadata
		newStatus = updatedMetadata.Status
		applied = true

		serializedMetadata, marshalErr := updatedRunMetadata.Marshal()
		if marshalErr != nil {
			return "", marshalErr
		}
		return string(serializedMetadata), nil
	}

	taskRunMetadataKey := TaskRunMetadataKey{taskID: taskID, runID: fetchedMetadata.ID, date: fetchedMetadata.StartTime}
	updateErr := m.Storage.Update(taskRunMetadataKey.String(), updateStatus)

	// Notify only when the callback produced new metadata whose status differs
	// from the stored one. updateErr is forwarded so that a failed write is
	// reported alongside the notification.
	switch {
	case !applied:
		// Nothing was changed, so there is nothing to announce.
	case prevStatus != newStatus:
		m.notifyChange(updatedMetadata, updateErr)
	default:
		m.Storage.Logger.Debugf("status has not changed, do not notify status: %s", prevStatus)
	}

	return updateErr
}

// notifyChange asynchronously reports updatedMetadata's status through
// m.notifier. If no notifier is configured, or the run's target type is not
// NameVariantTarget, it only logs and returns. A non-nil updateErr is
// forwarded as the notification's error message.
//
// Sending happens in a goroutine that the caller neither tracks nor waits for.
// Send failures and target-assertion failures are logged there, not returned.
func (m *TaskMetadataManager) notifyChange(updatedMetadata TaskRunMetadata, updateErr error) {
	if m.notifier == nil {
		m.Storage.Logger.Warn("notifier is not set, skipping notification")
		return
	}

	if updatedMetadata.TargetType != NameVariantTarget {
		m.Storage.Logger.Debugf("target type is not NameVariant, do not notify targetType: %s", updatedMetadata.TargetType)
		return
	}

	errorMsg := ""
	if updateErr != nil {
		errorMsg = updateErr.Error()
	}

	go func() {
		nameVariant, ok := updatedMetadata.Target.(NameVariant)
		if !ok {
			m.Storage.Logger.Error("could not assert metadata target as NameVariant, cannot send slack notification")
			return
		}
		slackError := m.notifier.ChangeNotification(
			nameVariant.ResourceType,
			nameVariant.Name,
			nameVariant.Variant,
			updatedMetadata.Status.String(),
			errorMsg,
		)
		if slackError != nil {
			m.Storage.Logger.Errorf("could not notify slack for resource update taskId: %s, runId: %s, error: %s",
				updatedMetadata.TaskId.String(),
				updatedMetadata.ID.String(),
				slackError.Error())
		}
	}()
}

// SetResumeID stores id as the ResumeID of run runID of task taskID. The run is
// looked up first to locate its metadata key. The stored record is then
// rewritten through Storage.Update. Lookup, decoding and storage errors are
// returned unchanged.
func (m *TaskMetadataManager) SetResumeID(runID TaskRunID, taskID TaskID, id ptypes.ResumeID) error {
	current, err := m.GetRunByID(taskID, runID)
	if err != nil {
		return err
	}

	key := TaskRunMetadataKey{taskID: taskID, runID: current.ID, date: current.StartTime}
	return m.Storage.Update(key.String(), func(stored string) (string, error) {
		var run TaskRunMetadata
		if err := run.Unmarshal([]byte(stored)); err != nil {
			return "", err
		}
		run.ResumeID = id

		encoded, err := run.Marshal()
		if err != nil {
			return "", err
		}
		return string(encoded), nil
	})
}

// SetRunEndTime records endTime as the EndTime of run runID of task taskID.
//
// It returns an InvalidArgument error when endTime is the zero time (checked
// before any storage access), or when it is earlier than the run's stored
// StartTime (checked inside the update). Lookup, decoding and storage errors
// are returned unchanged.
//
// The parameter is named endTime rather than time so that it does not shadow
// the time package inside this function.
func (m *TaskMetadataManager) SetRunEndTime(runID TaskRunID, taskID TaskID, endTime time.Time) error {
	if endTime.IsZero() {
		return fferr.NewInvalidArgumentError(fmt.Errorf("end time cannot be zero"))
	}

	current, err := m.GetRunByID(taskID, runID)
	if err != nil {
		return err
	}

	updateEndTime := func(stored string) (string, error) {
		var run TaskRunMetadata
		if err := run.Unmarshal([]byte(stored)); err != nil {
			return "", err
		}

		if run.StartTime.After(endTime) {
			return "", fferr.NewInvalidArgumentError(fmt.Errorf("end time cannot be before start time"))
		}

		run.EndTime = endTime
		encoded, err := run.Marshal()
		if err != nil {
			return "", err
		}
		return string(encoded), nil
	}

	key := TaskRunMetadataKey{taskID: taskID, runID: current.ID, date: current.StartTime}
	return m.Storage.Update(key.String(), updateEndTime)
}

// AppendRunLog appends log to the Logs of run runID of task taskID. An empty
// log is rejected with an InvalidArgument error before any storage access.
// Lookup, decoding and storage errors are returned unchanged.
func (m *TaskMetadataManager) AppendRunLog(runID TaskRunID, taskID TaskID, log string) error {
	if log == "" {
		return fferr.NewInvalidArgumentError(fmt.Errorf("log cannot be empty"))
	}

	current, err := m.GetRunByID(taskID, runID)
	if err != nil {
		return err
	}

	key := TaskRunMetadataKey{taskID: taskID, runID: current.ID, date: current.StartTime}
	return m.Storage.Update(key.String(), func(stored string) (string, error) {
		var run TaskRunMetadata
		if err := run.Unmarshal([]byte(stored)); err != nil {
			return "", err
		}

		run.Logs = append(run.Logs, log)

		encoded, err := run.Marshal()
		if err != nil {
			return "", err
		}
		return string(encoded), nil
	})
}

// WatchForCancel is a placeholder until cancellation support is implemented.
// It blocks forever and never returns.
//
// The previous empty `for {}` busy-spun a CPU core for as long as a caller was
// parked here. Sleeping in a loop blocks just as permanently but yields the
// CPU, and a pending timer keeps the runtime's deadlock detector from
// aborting. Because an infinite `for` is a terminating statement, no
// trailing return is needed.
func (m *TaskMetadataManager) WatchForCancel(runID TaskRunID, taskID TaskID) error {
	for {
		time.Sleep(time.Hour)
	}
}
