package database

import (
	"errors"
	"fmt"
	"reflect"
	"regexp"
	"slices"

	"github.com/google/uuid"
	"github.com/rotisserie/eris"
	log "github.com/sirupsen/logrus"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"

	"github.com/G-Research/fasttrackml/pkg/api/mlflow/common"
)

// experimentInfo pairs the ID an experiment had in the source database with the
// ID it received (or was matched to) in the destination database.
type experimentInfo struct {
	destID, sourceID int32
}

// uuidRegexp recognises the canonical lower-case textual UUID form
// (8-4-4-4-12 hexadecimal digits, anchored at both ends).
var (
	uuidRegexp = regexp.MustCompile(`^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`)
)

// Importer transports data from a source database into a destination database,
// optionally limited to a single source namespace and/or re-homed into a single
// destination namespace.
type Importer struct {
	// sourceDB is read from; destinationDB is written to.
	sourceDB, destinationDB *gorm.DB
	// experimentInfos records the source -> destination experiment ID mapping
	// built while importing experiments.
	experimentInfos []experimentInfo
	// sourceNamespace and destinationNamespace are resolved by Import from the
	// corresponding *NamespaceName codes; nil means "no restriction".
	sourceNamespace, destinationNamespace *Namespace
	// sourceNamespaceName and destinationNamespaceName hold the namespace codes
	// set through WithSourceNamespace / WithDestinationNamespace.
	sourceNamespaceName, destinationNamespaceName *string
}

// NewImporter builds an Importer that reads from input and writes to output.
// Each option (e.g. WithSourceNamespace, WithDestinationNamespace) is applied to
// the new Importer in the order it was passed.
func NewImporter(input, output *gorm.DB, options ...func(importer *Importer)) *Importer {
	importer := &Importer{
		sourceDB:        input,
		destinationDB:   output,
		experimentInfos: []experimentInfo{},
	}
	for i := range options {
		options[i](importer)
	}
	return importer
}

// Import copies the contents of the source database into the destination database.
//
// When a source namespace code was configured, it is resolved in sourceDB and used
// to restrict which rows are read. When a destination namespace code was
// configured, it is resolved in destinationDB and imported rows are assigned to it.
// Both lookups fail with an error if the namespace does not exist.
//
// Tables are then copied one at a time in a fixed order that places referenced
// tables (apps, experiments, runs, contexts, shared_tags) before the tables that
// point at them. Finally, namespace default experiments are re-pointed at the
// experiment IDs drawn during the import.
func (s *Importer) Import() error {
	if code := s.sourceNamespaceName; code != nil {
		namespace := Namespace{}
		err := s.sourceDB.Where("code = ?", *code).First(&namespace).Error
		switch {
		case errors.Is(err, gorm.ErrRecordNotFound):
			return eris.Wrapf(err, "source namespace %s not found", *code)
		case err != nil:
			return eris.Wrapf(err, "error getting namespace %s", *code)
		}
		s.sourceNamespace = &namespace
	}

	if code := s.destinationNamespaceName; code != nil {
		namespace := Namespace{}
		err := s.destinationDB.Where("code = ?", *code).First(&namespace).Error
		switch {
		case errors.Is(err, gorm.ErrRecordNotFound):
			return eris.Wrapf(err, "destination namespace %s not found", *code)
		case err != nil:
			return eris.Wrapf(err, "error getting namespace %s", *code)
		}
		s.destinationNamespace = &namespace
	}

	for _, table := range []string{
		"namespaces",
		"apps",
		"dashboards",
		"experiments",
		"experiment_tags",
		"runs",
		"tags",
		"params",
		"contexts",
		"metrics",
		"latest_metrics",
		"shared_tags",
		"run_shared_tags",
	} {
		if err := s.importTable(table); err != nil {
			return eris.Wrapf(err, "error importing table %s", table)
		}
	}

	if err := s.updateNamespaceDefaultExperiment(); err != nil {
		return eris.Wrap(err, "error updating namespace default experiment")
	}
	return nil
}

// importExperiments copies the experiments table from sourceDB to destinationDB in a
// single destination transaction. Through saveExperimentInfo it records which
// destination ID each source experiment ended up with, so that later tables can
// remap their experiment_id.
//
// Each experiment is looked up in the destination by name within its target
// namespace. The target is the destination namespace when one was configured,
// otherwise the source row's namespace. A match is reused; if none is found, the
// experiment is created. The experiment with ID 0 keeps its ID; every other
// experiment is created without an explicit ID so that a new one is drawn.
func (s *Importer) importExperiments() error {
	// Start transaction in the destDB
	err := s.destinationDB.Transaction(func(destTX *gorm.DB) error {
		// Query data from the source database
		rows, err := EntityLimitedByNamespace(
			"experiments",
			s.sourceDB.Model(Experiment{}),
			s.sourceNamespace,
		).Rows()
		if err != nil {
			return eris.Wrap(err, "error creating Rows instance from source")
		}
		// Deferred immediately so the cursor is released on every return path.
		//nolint:errcheck
		defer rows.Close()

		count := 0
		for rows.Next() {
			var scannedItem Experiment
			if err := s.sourceDB.ScanRows(rows, &scannedItem); err != nil {
				return eris.Wrap(err, "error scanning source row")
			}
			newItem := Experiment{
				Name:             scannedItem.Name,
				NamespaceID:      scannedItem.NamespaceID,
				ArtifactLocation: scannedItem.ArtifactLocation,
				LifecycleStage:   scannedItem.LifecycleStage,
				CreationTime:     scannedItem.CreationTime,
				LastUpdateTime:   scannedItem.LastUpdateTime,
			}
			// override Namespace if it was provided during the import.
			if s.destinationNamespace != nil {
				newItem.NamespaceID = s.destinationNamespace.ID
			}
			// Keep the default experiment ID (0), but otherwise draw a new one. The value
			// is copied so that newItem never shares scannedItem's pointer. FirstOrCreate
			// may write through that pointer when it loads an existing row into newItem.
			if *scannedItem.ID == int32(0) {
				newItem.ID = common.GetPointer[int32](*scannedItem.ID)
			}
			// Scope the lookup to the target namespace. Matching on name alone could
			// attach this experiment, and every run remapped to it, to a same-named
			// experiment that lives in another namespace.
			if err := destTX.Where(
				Experiment{Name: scannedItem.Name, NamespaceID: newItem.NamespaceID},
			).FirstOrCreate(
				&newItem,
			).Error; err != nil {
				return eris.Wrap(err, "error creating destination row")
			}
			s.saveExperimentInfo(scannedItem, newItem)
			count++
		}
		// rows.Err reports errors hit while iterating, so it is only meaningful
		// after rows.Next has returned false.
		if err := rows.Err(); err != nil {
			return eris.Wrap(err, "error getting query result")
		}
		log.Infof("Importing experiments - found %d records", count)
		return nil
	})
	if err != nil {
		return eris.Wrap(err, "error copying experiments table")
	}
	return nil
}

// importTable copies every row of table from sourceDB to destinationDB inside a
// single destination transaction.
//
// Each row is processed in four steps:
//   - it is read, limited to the source namespace (EntityLimitedByNamespace);
//   - it is rewritten by translateFields (boolean normalisation, experiment_id
//     remapping, UUID parsing);
//   - it is re-homed to the destination namespace (ApplyNamespaceRestriction);
//   - it is inserted with ON CONFLICT DO NOTHING, so rows that conflict with
//     existing destination rows are skipped.
//
// The experiments table is delegated to importExperiments, which also records the
// ID mapping the other tables rely on.
func (s *Importer) importTable(table string) error {
	// handle a special case for experiments.
	if table == "experiments" {
		if err := s.importExperiments(); err != nil {
			return eris.Wrap(err, "error importing table experiments")
		}
		return nil
	}

	// Start transaction in the destinationDB
	err := s.destinationDB.Transaction(func(destTX *gorm.DB) error {
		// Query data from the source database
		rows, err := EntityLimitedByNamespace(
			table,
			s.sourceDB.Table(table).Select(
				fmt.Sprintf("%s.*", table),
			),
			s.sourceNamespace,
		).Rows()
		if err != nil {
			return eris.Wrap(err, "error creating rows instance from source")
		}
		// Deferred immediately so the cursor is released on every return path.
		//nolint:errcheck
		defer rows.Close()

		count := 0
		for rows.Next() {
			var item map[string]any
			if err := s.sourceDB.ScanRows(rows, &item); err != nil {
				return eris.Wrap(err, "error scanning source row")
			}
			translated, err := s.translateFields(item)
			if err != nil {
				return eris.Wrap(err, "error translating fields")
			}
			translated = ApplyNamespaceRestriction(table, translated, s.destinationNamespace)
			if err := destTX.Table(table).Clauses(
				clause.OnConflict{DoNothing: true}, clause.Returning{},
			).Create(&translated).Error; err != nil {
				return eris.Wrap(err, "error creating destination row")
			}
			count++
		}
		// rows.Err reports errors hit while iterating, so it is only meaningful
		// after rows.Next has returned false.
		if err := rows.Err(); err != nil {
			return eris.Wrap(err, "error getting query result")
		}
		log.Infof("Importing %s - found %d records", table, count)
		return nil
	})
	if err != nil {
		return err
	}
	return nil
}

// saveExperimentInfo remembers that the source experiment was stored under dest's ID
// in the destination, so later rows referencing the source ID can be remapped.
// Both experiments are expected to carry a non-nil ID.
func (s *Importer) saveExperimentInfo(source, dest Experiment) {
	mapping := experimentInfo{
		sourceID: *source.ID,
		destID:   *dest.ID,
	}
	s.experimentInfos = append(s.experimentInfos, mapping)
}

// translateFields prepares a row read from the source database for insertion into
// the destination database. It modifies item in place and also returns it.
//
// Three translations are applied:
//   - is_nan / is_archived values that arrive as numbers (as they do from sqlite)
//     become bool. Values that are already boolean, or NULL, are left untouched.
//   - experiment_id is replaced with the ID the experiment received in the
//     destination, when a mapping was recorded for it.
//   - id / app_id values whose text form is a lower-case UUID are parsed into
//     uuid.UUID.
//
// An error is returned when a boolean or experiment_id value has an unsupported
// type, or when a UUID-shaped string cannot be parsed.
func (s *Importer) translateFields(item map[string]any) (map[string]any, error) {
	// boolean fields are numeric when coming from sqlite
	booleanFields := []string{"is_nan", "is_archived"}
	for _, field := range booleanFields {
		fieldVal, ok := item[field]
		if !ok {
			continue
		}
		// Inspect the value through reflect so that every integer, unsigned and float
		// width is handled. Comparing an `any` with the constant 0.0 only matches a
		// float64 zero, so an int64(0) would otherwise have turned into true.
		value := reflect.ValueOf(fieldVal)
		switch {
		case !value.IsValid(), value.Kind() == reflect.Bool:
			// NULL or already boolean: nothing to translate.
		case value.CanInt():
			item[field] = value.Int() != 0
		case value.CanUint():
			item[field] = value.Uint() != 0
		case value.CanFloat():
			item[field] = value.Float() != 0
		default:
			return nil, eris.Errorf("unable to convert %s to bool: %v (%T)", field, fieldVal, fieldVal)
		}
	}
	// items with experiment_id need to reference the new ID
	if expID, ok := item["experiment_id"]; ok {
		var id int32
		switch v := expID.(type) {
		case int32:
			id = v
		case int64:
			id = int32(v)
		case int:
			id = int32(v)
		default:
			return nil, eris.Errorf("unable to assert %s as int32: %v", "experiment_id", expID)
		}
		for _, expInfo := range s.experimentInfos {
			if expInfo.sourceID == id {
				item["experiment_id"] = expInfo.destID
				break
			}
		}
	}
	// items with string uuid need to translate to UUID native type
	uuidFields := []string{"id", "app_id"}
	for _, field := range uuidFields {
		if srcUUID, ok := item[field]; ok {
			// when uuid, this field will be pointer to interface{} and requires some reflection
			stringUUID := fmt.Sprintf("%v", reflect.Indirect(reflect.ValueOf(srcUUID)))
			if uuidRegexp.MatchString(stringUUID) {
				binID, err := uuid.Parse(stringUUID)
				if err != nil {
					return nil, eris.Errorf("unable to create binary UUID field from string: %s", stringUUID)
				}
				item[field] = binID
			}
		}
	}
	return item, nil
}

// updateNamespaceDefaultExperiment re-points default_experiment_id of destination
// namespaces at the experiment IDs drawn during this import.
//
// A namespace whose default_experiment_id equals a recorded source experiment ID is
// updated to the corresponding destination ID. Namespaces without a default
// experiment, without a mapping, or whose value the mapping leaves unchanged are
// not written. All work runs in a single destination transaction.
//
// NOTE(review): every namespace in the destination is inspected, including ones
// that existed before the import. If such a namespace's default experiment ID
// happens to equal a source experiment ID, it is remapped as well. Confirm this
// cannot occur, or restrict the update to imported namespaces.
func (s *Importer) updateNamespaceDefaultExperiment() error {
	// Build the lookup once instead of scanning the mapping per namespace;
	// the first recorded mapping for a source ID wins.
	destIDBySourceID := make(map[int32]int32, len(s.experimentInfos))
	for _, info := range s.experimentInfos {
		if _, exists := destIDBySourceID[info.sourceID]; !exists {
			destIDBySourceID[info.sourceID] = info.destID
		}
	}

	// Start transaction in the destinationDB
	err := s.destinationDB.Transaction(func(destTX *gorm.DB) error {
		// Get namespaces
		var namespaces []Namespace
		if err := destTX.Model(Namespace{}).Find(&namespaces).Error; err != nil {
			return eris.Wrap(err, "error reading namespaces in destination")
		}
		for _, ns := range namespaces {
			if ns.DefaultExperimentID == nil {
				continue
			}
			destID, mapped := destIDBySourceID[*ns.DefaultExperimentID]
			if !mapped || destID == *ns.DefaultExperimentID {
				continue
			}
			// Use an explicit condition: GORM drops zero-valued fields from struct
			// conditions, so Namespace{ID: 0} would turn this into an update of
			// every namespace.
			if err := destTX.Model(
				Namespace{},
			).Where(
				"id = ?", ns.ID,
			).Update(
				"default_experiment_id", common.GetPointer[int32](destID),
			).Error; err != nil {
				return eris.Wrap(err, "error updating destination namespace row")
			}
		}
		log.Infof("Updating namespaces - processed %d records", len(namespaces))
		return nil
	})
	return err
}

// ApplyNamespaceRestriction re-homes a row into namespace before it is written to
// the destination. For tables that carry a namespace_id column (apps, experiments
// and shared_tags), the value is overwritten with namespace.ID.
//
// Rows of other tables, and any row when namespace is nil, are returned unchanged.
// item is modified in place and also returned.
func ApplyNamespaceRestriction(table string, item map[string]any, namespace *Namespace) map[string]any {
	if namespace == nil {
		return item
	}
	// shared_tags are namespace-scoped in the same way as apps (they are filtered by
	// shared_tags.namespace_id when reading). Without this override they would keep
	// the source namespace_id, which may not exist, or may mean a different
	// namespace, in the destination.
	if slices.Contains([]string{"apps", "experiments", "shared_tags"}, table) {
		item["namespace_id"] = namespace.ID
	}
	return item
}

// EntityLimitedByNamespace scopes a source query for table so that it only returns
// rows that belong to namespace. When namespace is nil, or the table is not listed
// below, db is returned unchanged.
//
// Tables with their own namespace_id column are filtered directly. Other tables
// reach the namespace by joining through the app, experiment, run or shared tag
// they reference.
func EntityLimitedByNamespace(table string, db *gorm.DB, namespace *Namespace) *gorm.DB {
	if namespace == nil {
		return db
	}
	switch table {
	case "namespaces":
		// if source namespace has been provided, we don't need to import any namespace,
		// just other related data. Presumably no namespace has id -1, so nothing matches.
		return db.Where("id = ?", -1)
	case "apps", "experiments", "shared_tags":
		return db.Where(fmt.Sprintf("%s.namespace_id = ?", table), namespace.ID)
	case "dashboards":
		return db.Joins(
			"LEFT JOIN apps ON apps.id = dashboards.app_id",
		).Where(
			"apps.namespace_id = ?", namespace.ID,
		)
	case "experiment_tags":
		// Experiment tags belong to the namespace of their experiment. Without this
		// case, tags of experiments in every other source namespace were imported too.
		return db.Joins(
			"LEFT JOIN experiments ON experiments.experiment_id = experiment_tags.experiment_id",
		).Where(
			"experiments.namespace_id = ?", namespace.ID,
		)
	case "runs":
		return db.Joins(
			"LEFT JOIN experiments ON experiments.experiment_id = runs.experiment_id",
		).Where(
			"experiments.namespace_id = ?", namespace.ID,
		)
	case "tags", "params", "metrics", "latest_metrics":
		return db.Joins(
			fmt.Sprintf("LEFT JOIN runs ON runs.run_uuid = %s.run_uuid", table),
		).Joins(
			"LEFT JOIN experiments ON experiments.experiment_id = runs.experiment_id",
		).Where(
			"experiments.namespace_id = ?", namespace.ID,
		)
	case "contexts":
		// A context can be referenced by many metric rows. Joining metrics directly
		// returned each context once per metric, so filter with a subquery instead:
		// every matching context is then read at most once.
		return db.Where(
			"contexts.id IN ("+
				"SELECT metrics.context_id FROM metrics "+
				"JOIN runs ON runs.run_uuid = metrics.run_uuid "+
				"JOIN experiments ON experiments.experiment_id = runs.experiment_id "+
				"WHERE experiments.namespace_id = ?)",
			namespace.ID,
		)
	case "run_shared_tags":
		// run_shared_tags belong to the namespace of the shared tag they reference.
		return db.Joins(
			"LEFT JOIN shared_tags ON shared_tags.id = run_shared_tags.shared_tag_id",
		).Where(
			"shared_tags.namespace_id = ?", namespace.ID,
		)
	}
	return db
}

// WithSourceNamespace returns an Importer option that records the code of the
// source namespace. Import resolves it in the source database and only reads data
// belonging to that namespace.
func WithSourceNamespace(namespace string) func(importer *Importer) {
	code := namespace
	return func(importer *Importer) {
		importer.sourceNamespaceName = &code
	}
}

// WithDestinationNamespace returns an Importer option that records the code of the
// destination namespace. Import resolves it in the destination database and
// assigns imported data to it.
func WithDestinationNamespace(namespace string) func(importer *Importer) {
	code := namespace
	return func(importer *Importer) {
		importer.destinationNamespaceName = &code
	}
}
