//go:build integration
// +build integration

package internal

import (
	"context"
	"crypto/rand"
	"encoding/json"
	"fmt"
	"log"
	"strconv"
	"strings"
	"testing"
	"time"

	"github.com/google/uuid"
	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/require"
	"github.com/uptrace/bun"
	"golang.org/x/exp/maps"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/types/known/structpb"
	"google.golang.org/protobuf/types/known/timestamppb"

	apiPkg "github.com/determined-ai/determined/master/internal/api"
	authz2 "github.com/determined-ai/determined/master/internal/authz"
	"github.com/determined-ai/determined/master/internal/db"
	"github.com/determined-ai/determined/master/internal/trials"
	"github.com/determined-ai/determined/master/pkg/model"
	"github.com/determined-ai/determined/master/pkg/protoutils/protoconverter"
	"github.com/determined-ai/determined/master/pkg/ptrs"
	"github.com/determined-ai/determined/proto/pkg/apiv1"
	"github.com/determined-ai/determined/proto/pkg/checkpointv1"
	"github.com/determined-ai/determined/proto/pkg/commonv1"
	"github.com/determined-ai/determined/proto/pkg/trialv1"
)

// inferenceMetricGroup is the metric group name "inference".
// NOTE(review): not referenced in this part of the file; presumably used by tests further
// down — confirm before removing.
var inferenceMetricGroup = "inference"

// createTestTrial creates a test experiment in project 1, inserts a trial task for it and
// then a paused trial attached to that task. It returns the trial as re-read through
// db.TrialByID together with the inserted task. Any database error fails the test.
func createTestTrial(
	t *testing.T, api *apiServer, curUser model.User,
) (*model.Trial, *model.Task) {
	ctx := context.TODO()
	exp := createTestExpWithProjectID(t, api, curUser, 1)
	requestID := model.NewRequestID(rand.Reader)

	task := &model.Task{
		TaskID:     trialTaskID(exp.ID, requestID),
		TaskType:   model.TaskTypeTrial,
		LogVersion: model.TaskLogVersion1,
		StartTime:  time.Now(),
	}
	err := db.AddTask(ctx, task)
	require.NoError(t, err)

	trial := &model.Trial{
		ExperimentID: exp.ID,
		RequestID:    &requestID,
		State:        model.PausedState,
		StartTime:    time.Now(),
	}
	err = db.AddTrial(ctx, trial, task.TaskID)
	require.NoError(t, err)

	// Hand back the row as the database returns it, which is how the API usually sees it.
	stored, err := db.TrialByID(ctx, trial.ID)
	require.NoError(t, err)
	return stored, task
}

// createTestTrialWithMetrics creates a trial via createTestTrial and reports ten steps
// (0 through 9) of metrics for it: once under the generic group "mygroup", once as training
// metrics and once as validation metrics. Numeric metrics carry the step index as value.
//
// The returned map holds the reported payloads keyed by "mygroup", the training group and
// the validation group. The same *commonv1.Metrics object is reported for "mygroup" and for
// training, so when includeBatchMetrics is true the batch metrics attached after the
// "mygroup" report are also visible through the returned "mygroup" entries.
func createTestTrialWithMetrics(
	ctx context.Context, t *testing.T, api *apiServer, curUser model.User, includeBatchMetrics bool,
) (*model.Trial, map[model.MetricGroup][]*commonv1.Metrics) {
	trial, _ := createTestTrial(t, api, curUser)

	genericGroup := model.MetricGroup("mygroup")
	reported := make(map[model.MetricGroup][]*commonv1.Metrics)
	var trainReports, valReports []*commonv1.Metrics

	// wrap builds the per-trial envelope shared by all three report calls.
	wrap := func(step *int32, payload *commonv1.Metrics) *trialv1.TrialMetrics {
		return &trialv1.TrialMetrics{
			TrialId:        int32(trial.ID),
			TrialRunId:     0,
			StepsCompleted: step,
			Metrics:        payload,
		}
	}

	for i := 0; i < 10; i++ {
		value := float64(i)
		step := int32(i)

		trainPayload := &commonv1.Metrics{
			AvgMetrics: &structpb.Struct{
				Fields: map[string]*structpb.Value{
					"epoch":                   structpb.NewNumberValue(value),
					"loss":                    structpb.NewNumberValue(value),
					"zgroup_b/me.t r%i]\\c_1": structpb.NewNumberValue(value),
					"textMetric":              structpb.NewStringValue("random_text"),
				},
			},
		}

		// Generic group first, before any batch metrics are attached to the payload.
		_, err := api.ReportTrialMetrics(ctx, &apiv1.ReportTrialMetricsRequest{
			Metrics: wrap(&step, trainPayload),
			Group:   genericGroup.ToString(),
		})
		require.NoError(t, err)
		reported[genericGroup] = append(reported[genericGroup], trainPayload)

		if includeBatchMetrics {
			trainPayload.BatchMetrics = []*structpb.Struct{
				{
					Fields: map[string]*structpb.Value{
						"batch_loss": structpb.NewNumberValue(value),
					},
				},
			}
		}

		_, err = api.ReportTrialTrainingMetrics(ctx, &apiv1.ReportTrialTrainingMetricsRequest{
			TrainingMetrics: wrap(&step, trainPayload),
		})
		require.NoError(t, err)
		trainReports = append(trainReports, trainPayload)

		valPayload := &commonv1.Metrics{
			AvgMetrics: &structpb.Struct{
				Fields: map[string]*structpb.Value{
					"epoch":      structpb.NewNumberValue(value),
					"loss":       structpb.NewNumberValue(value),
					"val_loss2":  structpb.NewNumberValue(value),
					"textMetric": structpb.NewStringValue("random_text"),
				},
			},
		}
		_, err = api.ReportTrialValidationMetrics(ctx, &apiv1.ReportTrialValidationMetricsRequest{
			ValidationMetrics: wrap(&step, valPayload),
		})
		require.NoError(t, err)
		valReports = append(valReports, valPayload)
	}

	reported[model.TrainingMetricGroup] = trainReports
	reported[model.ValidationMetricGroup] = valReports

	return trial, reported
}

// compareMetrics checks that the metric reports in resp carry the metric values in expected,
// position by position.
//
// expected is treated as the concatenation of each trial's metrics in the order of trialIDs,
// with every trial contributing len(expected)/len(trialIDs) entries. When isValidation is
// true the expected payload is {"validation_metrics": avg}; otherwise it is
// {"avg_metrics": avg, "batch_metrics": batches or nil}.
//
// Size mismatches between resp, expected and trialIDs are reported as test failures rather
// than as index-out-of-range or divide-by-zero panics.
func compareMetrics(
	t *testing.T, trialIDs []int,
	resp []*trialv1.MetricsReport, expected []*commonv1.Metrics, isValidation bool,
) {
	require.NotNil(t, resp)

	// Reports expected per trial; loop-invariant, so computed once.
	perTrial := 0
	if len(trialIDs) > 0 {
		perTrial = len(expected) / len(trialIDs)
	}

	trialIndex := 0
	totalBatches := 0
	for i, actual := range resp {
		require.Less(t, i, len(expected), "more metric reports than expected metrics")
		if i != 0 {
			require.NotZero(t, perTrial, "fewer expected metrics than trial IDs")
			if i%perTrial == 0 {
				// Crossed into the next trial's slice of expected metrics.
				trialIndex++
				totalBatches = 0
			}
		}
		require.Less(t, trialIndex, len(trialIDs), "metric reports span more trials than given")

		metrics := map[string]any{
			"avg_metrics":   expected[i].AvgMetrics.AsMap(),
			"batch_metrics": nil,
		}
		if expected[i].BatchMetrics != nil {
			var batchMetrics []any
			for _, b := range expected[i].BatchMetrics {
				batchMetrics = append(batchMetrics, b.AsMap())
			}
			metrics["batch_metrics"] = batchMetrics
		}
		if isValidation {
			metrics = map[string]any{
				"validation_metrics": expected[i].AvgMetrics.AsMap(),
			}
		}
		protoStruct, err := structpb.NewStruct(metrics)
		require.NoError(t, err)

		// EndTime and Id are copied from the actual row, so they never cause a mismatch.
		expectedRow := &trialv1.MetricsReport{
			TrialId:      int32(trialIDs[trialIndex]),
			EndTime:      actual.EndTime,
			Metrics:      protoStruct,
			TotalBatches: int32(totalBatches),
			TrialRunId:   int32(0),
			Id:           actual.Id,
		}
		// NOTE(review): the result of proto.Equal is discarded, so only the Metrics payload
		// is actually asserted below. Left as-is because asserting full equality may change
		// test outcomes — confirm the intent before tightening.
		proto.Equal(actual, expectedRow)
		require.Equal(t, expectedRow.Metrics.AsMap(), actual.Metrics.AsMap())

		totalBatches++
	}
}

// isMultiTrialSampleCorrect reports whether every data point in actualMetrics matches the
// expected metrics for the same epoch.
//
// Each data point's Epoch is used as the index into expectedMetrics (downsampling returns
// a randomized subset, so position cannot be used). For every expected average metric the
// actual value must have the same type (float64 or string) and the same value; the "epoch"
// metric is compared against the data point's Epoch field instead of its Values map.
//
// A missing Epoch, an epoch outside expectedMetrics, or an actual value of the wrong type
// yields false. An expected metric that is neither float64 nor string panics, since that
// indicates a broken fixture.
func isMultiTrialSampleCorrect(expectedMetrics []*commonv1.Metrics,
	actualMetrics *apiv1.DownsampledMetrics,
) bool {
	for _, point := range actualMetrics.Data {
		if point.Epoch == nil {
			return false
		}
		epoch := int(*point.Epoch)
		if epoch < 0 || epoch >= len(expectedMetrics) {
			return false
		}

		expectedAvg := expectedMetrics[epoch].AvgMetrics.AsMap()
		// Built once per data point rather than once per metric name.
		actualAvg := point.Values.AsMap()

		for metricName, expectedRaw := range expectedAvg {
			switch expectedVal := expectedRaw.(type) {
			case float64:
				if metricName == "epoch" {
					if expectedVal != *point.Epoch {
						return false
					}
					continue
				}
				// A missing key or a non-numeric value both fail the assertion here.
				actualVal, ok := actualAvg[metricName].(float64)
				if !ok || actualVal != expectedVal {
					return false
				}
			case string:
				actualVal, ok := actualAvg[metricName].(string)
				if !ok || actualVal != expectedVal {
					return false
				}
			default:
				panic("unexpected metric type in multi trial sample")
			}
		}
	}
	return true
}

// TestMultiTrialSampleSpecialMetrics samples a single "mygroup" metric whose name contains
// special characters (slash, dot, space, percent, bracket, backslash) and checks that one
// group with maxDataPoints points, each carrying exactly one value, is returned.
func TestMultiTrialSampleSpecialMetrics(t *testing.T) {
	api, curUser, ctx := setupAPITest(t, nil)

	trial, _ := createTestTrialWithMetrics(
		ctx, t, api, curUser, false)

	maxDataPoints := 7

	actualMetrics, err := api.multiTrialSample(int32(trial.ID), []string{},
		"", maxDataPoints, 0, 10, nil, []string{
			"mygroup.zgroup_b/me.t r%i]\\c_1",
		})
	// Check the error first so a failing call reports its cause, not an unexpected length.
	require.NoError(t, err)
	require.Len(t, actualMetrics, 1)
	mygroup := actualMetrics[0]
	require.Len(t, mygroup.Data, maxDataPoints)
	require.Len(t, mygroup.Data[0].Values.AsMap(), 1)
}

// TestMultiTrialSampleMetrics samples training, validation and generic ("mygroup") metrics
// for one trial, first one group at a time through metric names and then all together
// through metric IDs of the form "<group>.<name>", and checks the sampled values against
// what was reported.
func TestMultiTrialSampleMetrics(t *testing.T) {
	api, curUser, ctx := setupAPITest(t, nil)

	trial, expectedMetrics := createTestTrialWithMetrics(
		ctx, t, api, curUser, false)

	expectedTrainMetrics := expectedMetrics[model.TrainingMetricGroup]
	expectedValMetrics := expectedMetrics[model.ValidationMetricGroup]
	maxDataPoints := 7

	var trainMetricNames []string
	var metricIds []string
	for metricName := range expectedTrainMetrics[0].AvgMetrics.AsMap() {
		trainMetricNames = append(trainMetricNames, metricName)
		metricIds = append(metricIds, "training."+metricName)
	}
	actualTrainingMetrics, err := api.multiTrialSample(int32(trial.ID), trainMetricNames,
		model.TrainingMetricGroup, maxDataPoints, 0, 10, nil, []string{})
	require.NoError(t, err)
	require.Len(t, actualTrainingMetrics, 1)

	var validationMetricNames []string
	for metricName := range expectedValMetrics[0].AvgMetrics.AsMap() {
		validationMetricNames = append(validationMetricNames, metricName)
		metricIds = append(metricIds, "validation."+metricName)
	}
	actualValidationTrainingMetrics, err := api.multiTrialSample(int32(trial.ID),
		validationMetricNames, model.ValidationMetricGroup, maxDataPoints,
		0, 10, nil, []string{})
	// Errors are checked before lengths so a failing call reports its cause.
	require.NoError(t, err)
	require.Len(t, actualValidationTrainingMetrics, 1)

	// NOTE(review): the generic names are taken from the validation metrics, not from the
	// metrics reported under "mygroup", and the generic result is never value-checked.
	// This looks like a copy-paste leftover; kept unchanged pending confirmation.
	var genericMetricNames []string
	for metricName := range expectedValMetrics[0].AvgMetrics.AsMap() {
		genericMetricNames = append(genericMetricNames, metricName)
		metricIds = append(metricIds, "mygroup."+metricName)
	}
	actualGenericTrainingMetrics, err := api.multiTrialSample(int32(trial.ID),
		genericMetricNames, model.MetricGroup("mygroup"), maxDataPoints,
		0, 10, nil, []string{})
	require.NoError(t, err)
	require.Len(t, actualGenericTrainingMetrics, 1)

	require.True(t, isMultiTrialSampleCorrect(expectedTrainMetrics, actualTrainingMetrics[0]))
	require.True(t, isMultiTrialSampleCorrect(expectedValMetrics, actualValidationTrainingMetrics[0]))

	actualAllMetrics, err := api.multiTrialSample(int32(trial.ID), []string{},
		"", maxDataPoints, 0, 10, nil, metricIds)
	require.NoError(t, err)
	require.Len(t, actualAllMetrics, 3)
	// Indices 1 and 2 are checked against the training and validation expectations.
	require.Len(t, actualAllMetrics[1].Data, maxDataPoints) // max datapoints check
	require.Len(t, actualAllMetrics[2].Data, maxDataPoints) // max datapoints check
	require.True(t, isMultiTrialSampleCorrect(expectedTrainMetrics, actualAllMetrics[1]))
	require.True(t, isMultiTrialSampleCorrect(expectedValMetrics, actualAllMetrics[2]))
}

// TestStreamTrainingMetrics exercises the GetTrainingMetrics and GetValidationMetrics
// streams against two trials (one without and one with batch metrics). For each stream it
// checks the empty-ID and unknown-ID errors, each trial on its own, and both trials in one
// request.
func TestStreamTrainingMetrics(t *testing.T) {
	api, curUser, ctx := setupAPITest(t, nil)

	// Named createdTrials so the imported `trials` package is not shadowed.
	var createdTrials []*model.Trial
	var trainingMetrics, validationMetrics [][]*commonv1.Metrics
	for _, haveBatchMetrics := range []bool{false, true} {
		trial, metrics := createTestTrialWithMetrics(
			ctx, t, api, curUser, haveBatchMetrics)
		createdTrials = append(createdTrials, trial)
		trainingMetrics = append(trainingMetrics, metrics[model.TrainingMetricGroup])
		validationMetrics = append(validationMetrics, metrics[model.ValidationMetricGroup])
	}

	cases := []struct {
		requestFunc  func(trialIDs []int32) ([]*trialv1.MetricsReport, error)
		metrics      [][]*commonv1.Metrics
		isValidation bool
	}{
		{
			func(trialIDs []int32) ([]*trialv1.MetricsReport, error) {
				res := &mockStream[*apiv1.GetTrainingMetricsResponse]{ctx: ctx}
				err := api.GetTrainingMetrics(&apiv1.GetTrainingMetricsRequest{
					TrialIds: trialIDs,
				}, res)
				if err != nil {
					return nil, err
				}
				var out []*trialv1.MetricsReport
				for _, d := range res.getData() {
					out = append(out, d.Metrics...)
				}
				return out, nil
			}, trainingMetrics, false,
		},
		{
			func(trialIDs []int32) ([]*trialv1.MetricsReport, error) {
				res := &mockStream[*apiv1.GetValidationMetricsResponse]{ctx: ctx}
				err := api.GetValidationMetrics(&apiv1.GetValidationMetricsRequest{
					TrialIds: trialIDs,
				}, res)
				if err != nil {
					return nil, err
				}
				var out []*trialv1.MetricsReport
				for _, d := range res.getData() {
					out = append(out, d.Metrics...)
				}
				return out, nil
			}, validationMetrics, true,
		},
	}
	for _, curCase := range cases {
		// No trial IDs.
		_, err := curCase.requestFunc([]int32{})
		require.Error(t, err)
		require.Equal(t, codes.InvalidArgument, status.Code(err))

		// Trial IDs not found.
		_, err = curCase.requestFunc([]int32{-1})
		require.Error(t, err)
		require.Equal(t, codes.NotFound, status.Code(err))

		// One trial.
		resp, err := curCase.requestFunc([]int32{int32(createdTrials[0].ID)})
		require.NoError(t, err)
		compareMetrics(t, []int{createdTrials[0].ID}, resp, curCase.metrics[0], curCase.isValidation)

		// Other trial.
		resp, err = curCase.requestFunc([]int32{int32(createdTrials[1].ID)})
		require.NoError(t, err)
		compareMetrics(t, []int{createdTrials[1].ID}, resp, curCase.metrics[1], curCase.isValidation)

		// Both trials, requested in reverse order; results are expected in trial order.
		resp, err = curCase.requestFunc(
			[]int32{int32(createdTrials[1].ID), int32(createdTrials[0].ID)})
		require.NoError(t, err)
		// Concatenate into a fresh slice so metrics[0]'s backing array is never written to.
		combined := make([]*commonv1.Metrics, 0,
			len(curCase.metrics[0])+len(curCase.metrics[1]))
		combined = append(combined, curCase.metrics[0]...)
		combined = append(combined, curCase.metrics[1]...)
		compareMetrics(t, []int{createdTrials[0].ID, createdTrials[1].ID}, resp,
			combined, curCase.isValidation)
	}
}

// TestNonNumericEpochMetric checks that reporting validation metrics whose "epoch" value is
// a string is rejected with the expected error message.
func TestNonNumericEpochMetric(t *testing.T) {
	api, curUser, ctx := setupAPITest(t, nil)
	expectedMetricsMap := map[string]any{
		"numeric_met": 1.5,
		"epoch":       "x",
	}
	expectedMetrics, err := structpb.NewStruct(expectedMetricsMap)
	require.NoError(t, err)

	trial, _ := createTestTrial(t, api, curUser)
	step := int32(1)
	_, err = api.ReportTrialValidationMetrics(ctx, &apiv1.ReportTrialValidationMetricsRequest{
		ValidationMetrics: &trialv1.TrialMetrics{
			TrialId:        int32(trial.ID),
			TrialRunId:     0,
			StepsCompleted: &step,
			Metrics: &commonv1.Metrics{
				AvgMetrics: expectedMetrics,
			},
		},
	})
	// EqualError fails cleanly on a nil error; calling err.Error() directly would panic.
	require.EqualError(t, err, "cannot add metric with non numeric 'epoch' value got x")
}

// TestTrialsNonNumericMetrics reports one validation step containing string, numeric,
// date-like, boolean and null metric values, then checks that CompareTrials and
// TrialsSample return those values unchanged.
func TestTrialsNonNumericMetrics(t *testing.T) {
	api, curUser, ctx := setupAPITest(t, nil)

	expectedMetricsMap := map[string]any{
		"string_met":  "abc",
		"numeric_met": 1.5,
		"date_met":    "2021-03-15T13:32:18.91626111111Z",
		"bool_met":    false,
		"null_met":    nil,
	}
	expectedMetrics, err := structpb.NewStruct(expectedMetricsMap)
	require.NoError(t, err)

	step := int32(1)

	trial, _ := createTestTrial(t, api, curUser)
	_, err = api.ReportTrialMetrics(ctx, &apiv1.ReportTrialMetricsRequest{
		Metrics: &trialv1.TrialMetrics{
			TrialId:        int32(trial.ID),
			TrialRunId:     0,
			StepsCompleted: &step,
			Metrics: &commonv1.Metrics{
				AvgMetrics: expectedMetrics,
			},
		},
		Group: model.ValidationMetricGroup.ToString(),
	})
	require.NoError(t, err)

	t.Run("CompareTrialsNonNumeric", func(t *testing.T) {
		resp, err := api.CompareTrials(ctx, &apiv1.CompareTrialsRequest{
			TrialIds:    []int32{int32(trial.ID)},
			MetricNames: maps.Keys(expectedMetricsMap),
		})
		require.NoError(t, err)

		require.Len(t, resp.Trials, 1)
		require.Len(t, resp.Trials[0].Metrics, 1)
		require.Len(t, resp.Trials[0].Metrics[0].Data, 1)
		require.Equal(t, expectedMetricsMap, resp.Trials[0].Metrics[0].Data[0].Values.AsMap())
	})

	t.Run("TrialsSample", func(t *testing.T) {
		// Rewrite the experiment's searcher name to "custom" before sampling.
		_, err := db.Bun().NewUpdate().Table("experiments").
			Set("config = jsonb_set(config, '{searcher,name}', ?, true)", `"custom"`).
			Where("id = ?", trial.ExperimentID).
			Exec(ctx)
		require.NoError(t, err)

		for metricName := range expectedMetricsMap {
			childCtx, cancel := context.WithCancel(ctx)
			resp := &mockStream[*apiv1.TrialsSampleResponse]{ctx: childCtx}

			// Watchdog: cancel the stream as soon as a response arrives, or after about
			// 5 seconds (100 polls of 50ms). It exits once it has cancelled, or once the
			// context is cancelled from elsewhere, so it does not outlive the iteration.
			go func() {
				defer cancel()
				for i := 0; i < 100; i++ {
					if len(resp.getData()) > 0 {
						return
					}
					select {
					case <-childCtx.Done():
						return
					case <-time.After(50 * time.Millisecond):
					}
				}
			}()

			err = api.TrialsSample(&apiv1.TrialsSampleRequest{
				ExperimentId:  int32(trial.ExperimentID),
				Group:         "validation",
				MetricName:    metricName,
				PeriodSeconds: 1,
			}, resp)
			// Release the watchdog even if the stream returned on its own.
			cancel()
			require.NoError(t, err)

			// Read through getData() only, rather than touching resp.data directly.
			data := resp.getData()
			require.NotEmpty(t, data)
			require.Len(t, data[0].Trials, 1)
			require.Len(t, data[0].Trials[0].Data, 1)
			require.Equal(t, map[string]any{
				metricName: expectedMetricsMap[metricName],
			}, data[0].Trials[0].Data[0].Values.AsMap())
		}
	})
}

// TestReportCheckpoint reports a completed checkpoint on a trial's task, reads it back with
// GetCheckpoint and compares the JSON of the returned checkpoint with the JSON of the
// reported one after the expected training metadata has been filled in.
func TestReportCheckpoint(t *testing.T) {
	api, curUser, ctx := setupAPITest(t, nil)
	trial, task := createTestTrial(t, api, curUser)

	meta, err := structpb.NewStruct(map[string]any{
		"steps_completed": 1,
	})
	require.NoError(t, err)

	ckptUUID := uuid.New().String()
	reported := &checkpointv1.Checkpoint{
		TaskId:       string(task.TaskID),
		AllocationId: nil,
		Uuid:         ckptUUID,
		ReportTime:   timestamppb.New(time.Now().Truncate(time.Millisecond)),
		Resources:    nil,
		Metadata:     meta,
		State:        checkpointv1.State_STATE_COMPLETED,
	}
	_, err = api.ReportCheckpoint(ctx, &apiv1.ReportCheckpointRequest{Checkpoint: reported})
	require.NoError(t, err)

	getResp, err := api.GetCheckpoint(ctx, &apiv1.GetCheckpointRequest{
		CheckpointUuid: ckptUUID,
	})
	require.NoError(t, err)

	gotJSON, err := json.MarshalIndent(getResp.Checkpoint, "", "\t")
	require.NoError(t, err)

	expResp, err := api.GetExperiment(ctx, &apiv1.GetExperimentRequest{
		ExperimentId: int32(trial.ExperimentID),
	})
	require.NoError(t, err)

	// The reported checkpoint is extended with the training metadata the test expects the
	// read-back checkpoint to carry, then serialized as the expectation.
	trialID := int32(trial.ID)
	experimentID := int32(trial.ExperimentID)
	reported.Training = &checkpointv1.CheckpointTrainingMetadata{
		TrialId:           &trialID,
		ExperimentId:      &experimentID,
		ExperimentConfig:  expResp.Config,
		Hparams:           nil,
		TrainingMetrics:   &commonv1.Metrics{},
		ValidationMetrics: &commonv1.Metrics{},
	}
	wantJSON, err := json.MarshalIndent(reported, "", "\t")
	require.NoError(t, err)

	require.Equal(t, string(wantJSON), string(gotJSON))
}

// TestReportCheckpointNonTrialErrors checks that reporting a checkpoint against a notebook
// task is rejected.
//
// Reporting on non-trial tasks may have worked at some point, but it does not work since
// trials became one-to-many with tasks and the foreign key reference was switched.
func TestReportCheckpointNonTrialErrors(t *testing.T) {
	api, _, ctx := setupAPITest(t, nil)
	notebookTaskID := mockNotebookWithWorkspaceID(ctx, t, 1)

	meta, err := structpb.NewStruct(map[string]any{
		"steps_completed": 1,
	})
	require.NoError(t, err)

	ckpt := &checkpointv1.Checkpoint{
		TaskId:       string(notebookTaskID),
		AllocationId: nil,
		Uuid:         uuid.New().String(),
		ReportTime:   timestamppb.New(time.Now().Truncate(time.Millisecond)),
		Resources:    nil,
		Metadata:     meta,
		State:        checkpointv1.State_STATE_COMPLETED,
	}
	_, err = api.ReportCheckpoint(ctx, &apiv1.ReportCheckpointRequest{Checkpoint: ckpt})
	require.ErrorContains(t, err, "can only report checkpoints on trial's tasks")
}

// TestUnusualMetricNames reports validation metrics whose names contain a dot, a slash,
// 100 repeated characters, and every code point from 1 through 255, then checks that
// CompareTrials returns all of them.
func TestUnusualMetricNames(t *testing.T) {
	api, curUser, ctx := setupAPITest(t, nil)
	longMetric := strings.Repeat("a", 100)
	expectedMetricsMap := map[string]any{
		"a.loss":   1.5,
		"b/loss":   2.5,
		longMetric: 4.5,
	}

	// One metric name made of code points 1..255, each written as a UTF-8 encoded rune.
	var sweep strings.Builder
	for i := 1; i <= 255; i++ {
		sweep.WriteRune(rune(i))
	}
	asciiSweep := sweep.String()
	expectedMetricsMap[asciiSweep] = 3
	expectedMetrics, err := structpb.NewStruct(expectedMetricsMap)
	require.NoError(t, err)

	step := int32(1)

	trial, _ := createTestTrial(t, api, curUser)
	_, err = api.ReportTrialValidationMetrics(ctx, &apiv1.ReportTrialValidationMetricsRequest{
		ValidationMetrics: &trialv1.TrialMetrics{
			TrialId:        int32(trial.ID),
			TrialRunId:     0,
			StepsCompleted: &step,
			Metrics: &commonv1.Metrics{
				AvgMetrics: expectedMetrics,
			},
		},
	})
	require.NoError(t, err)

	req := &apiv1.CompareTrialsRequest{
		TrialIds:      []int32{int32(trial.ID)},
		MaxDatapoints: 3,
		MetricNames:   []string{"a.loss", "b/loss", asciiSweep, longMetric},
		StartBatches:  0,
		EndBatches:    1000,
		MetricType:    apiv1.MetricType_METRIC_TYPE_VALIDATION,
	}
	resp, err := api.CompareTrials(ctx, req)
	require.NoError(t, err)

	// Fail with a message instead of panicking if the response is emptier than expected.
	require.NotEmpty(t, resp.Trials)
	require.NotEmpty(t, resp.Trials[0].Metrics)
	require.NotEmpty(t, resp.Trials[0].Metrics[0].Data)

	valuesMap := resp.Trials[0].Metrics[0].Data[0].Values.AsMap()
	require.Len(t, valuesMap, len(expectedMetricsMap))
	// Expected value first, actual second, so failure output is labelled correctly.
	require.Equal(t, expectedMetricsMap[longMetric], valuesMap[longMetric])
}

// TestTrialAuthZ runs a table of trial-scoped API calls through the mocked experiment
// authorization layer. For every call it checks that:
//   - an unknown trial ID yields the trial not-found error;
//   - a PermissionDeniedError from CanGetExperiment is reported as that same not-found error;
//   - any other CanGetExperiment error is returned unmodified;
//   - unless SkipActionFunc is set, an error from the case's DenyFuncName check is returned
//     as a gRPC PermissionDenied status.
func TestTrialAuthZ(t *testing.T) {
	api, authZExp, _, curUser, ctx := setupExpAuthTest(t, nil)
	authZNSC := setupNSCAuthZ()
	workspaceAuthZ := setupWorkspaceAuthZ()
	trial, _ := createTestTrial(t, api, curUser)

	mockUserArg := mock.MatchedBy(func(u model.User) bool {
		return u.ID == curUser.ID
	})

	cases := []struct {
		// DenyFuncName is the authz mock method made to fail in the action check.
		DenyFuncName string
		// IDToReqCall issues the API call under test for the given trial ID.
		IDToReqCall func(id int) error
		// SkipActionFunc disables the action check for this case.
		SkipActionFunc bool
	}{
		{"CanGetExperimentArtifacts", func(id int) error {
			return api.TrialLogs(&apiv1.TrialLogsRequest{
				TrialId: int32(id),
			}, &mockStream[*apiv1.TrialLogsResponse]{ctx: ctx})
		}, false},
		{"CanGetExperimentArtifacts", func(id int) error {
			return api.TrialLogsFields(&apiv1.TrialLogsFieldsRequest{
				TrialId: int32(id),
			}, &mockStream[*apiv1.TrialLogsFieldsResponse]{ctx: ctx})
		}, false},
		{"CanGetExperimentArtifacts", func(id int) error {
			_, err := api.GetTrialCheckpoints(ctx, &apiv1.GetTrialCheckpointsRequest{
				Id: int32(id),
			})
			return err
		}, false},
		{"CanEditExperiment", func(id int) error {
			_, err := api.KillTrial(ctx, &apiv1.KillTrialRequest{
				Id: int32(id),
			})
			return err
		}, false},
		{"CanGetExperimentArtifacts", func(id int) error {
			_, err := api.GetTrial(ctx, &apiv1.GetTrialRequest{
				TrialId: int32(id),
			})
			return err
		}, true},
		{"CanGetExperimentArtifacts", func(id int) error {
			_, err := api.CompareTrials(ctx, &apiv1.CompareTrialsRequest{
				TrialIds: []int32{int32(id)},
			})
			return err
		}, false},
		{"CanGetExperimentArtifacts", func(id int) error {
			_, err := api.GetTrialWorkloads(ctx, &apiv1.GetTrialWorkloadsRequest{
				TrialId: int32(id),
			})
			return err
		}, false},
		{"CanGetExperimentArtifacts", func(id int) error {
			return api.GetTrialProfilerMetrics(&apiv1.GetTrialProfilerMetricsRequest{
				Labels: &trialv1.TrialProfilerMetricLabels{TrialId: int32(id)},
			}, &mockStream[*apiv1.GetTrialProfilerMetricsResponse]{ctx: ctx})
		}, false},
		{"CanGetExperimentArtifacts", func(id int) error {
			return api.GetTrialProfilerAvailableSeries(
				&apiv1.GetTrialProfilerAvailableSeriesRequest{
					TrialId: int32(id),
				}, &mockStream[*apiv1.GetTrialProfilerAvailableSeriesResponse]{ctx: ctx})
		}, false},
		{"CanEditExperiment", func(id int) error {
			_, err := api.PostTrialProfilerMetricsBatch(ctx,
				&apiv1.PostTrialProfilerMetricsBatchRequest{
					Batches: []*trialv1.TrialProfilerMetricsBatch{
						{
							Labels: &trialv1.TrialProfilerMetricLabels{TrialId: int32(id)},
						},
					},
				})
			return err
		}, false},
		{"CanEditExperiment", func(id int) error {
			_, err := api.ReportTrialSearcherEarlyExit(ctx,
				&apiv1.ReportTrialSearcherEarlyExitRequest{
					TrialId: int32(id),
				})
			return err
		}, false},
		{"CanEditExperiment", func(id int) error {
			_, err := api.ReportTrialProgress(ctx,
				&apiv1.ReportTrialProgressRequest{
					TrialId: int32(id),
				})
			return err
		}, false},
		{"CanEditExperiment", func(id int) error {
			_, err := api.ReportTrialTrainingMetrics(ctx,
				&apiv1.ReportTrialTrainingMetricsRequest{
					TrainingMetrics: &trialv1.TrialMetrics{TrialId: int32(id)},
				})
			return err
		}, false},
		{"CanEditExperiment", func(id int) error {
			_, err := api.ReportTrialValidationMetrics(ctx,
				&apiv1.ReportTrialValidationMetricsRequest{
					ValidationMetrics: &trialv1.TrialMetrics{TrialId: int32(id)},
				})
			return err
		}, false},
		{"CanEditExperiment", func(id int) error {
			_, err := api.PostTrialRunnerMetadata(ctx, &apiv1.PostTrialRunnerMetadataRequest{
				TrialId: int32(id),
			})
			return err
		}, false},
		{"CanGetExperimentArtifacts", func(id int) error {
			// Launching a tensorboard also consults the NSC and workspace authz mocks and
			// the resource manager, so those are stubbed for this call.
			authZNSC.On("CanGetTensorboard", mock.Anything, mockUserArg, mock.Anything, mock.Anything,
				mock.Anything).Return(nil).Once()
			workspaceAuthZ.On("CanGetWorkspace", mock.Anything, mock.Anything, mock.Anything).
				Return(nil).Once()
			mockRM := MockRM()
			mockRM.On("SmallerValueIsHigherPriority", mock.Anything).Return(true, nil)
			api.m.rm = mockRM
			_, err := api.LaunchTensorboard(ctx, &apiv1.LaunchTensorboardRequest{
				TrialIds: []int32{int32(id)},
			})
			return err
		}, false},
		{"CanEditExperiment", func(id int) error {
			req := &apiv1.ReportTrialSourceInfoRequest{TrialSourceInfo: &trialv1.TrialSourceInfo{
				TrialId:             int32(id),
				CheckpointUuid:      uuid.NewString(),
				TrialSourceInfoType: trialv1.TrialSourceInfoType_TRIAL_SOURCE_INFO_TYPE_INFERENCE,
			}}
			_, err := api.ReportTrialSourceInfo(ctx, req)
			return err
		}, false},
	}

	for _, curCase := range cases {
		require.ErrorIs(t, curCase.IDToReqCall(-999), apiPkg.NotFoundErrs("trial", "-999", true))
		// Can't view trials experiment gives same error.
		authZExp.On("CanGetExperiment", mock.Anything, mockUserArg, mock.Anything).
			Return(authz2.PermissionDeniedError{}).Once()
		require.ErrorIs(t, curCase.IDToReqCall(trial.ID),
			apiPkg.NotFoundErrs("trial", strconv.Itoa(trial.ID), true))

		// Experiment view error returns error unmodified.
		expectedErr := fmt.Errorf("canGetTrialError")
		authZExp.On("CanGetExperiment", mock.Anything, mockUserArg, mock.Anything).
			Return(expectedErr).Once()
		require.ErrorIs(t, curCase.IDToReqCall(trial.ID), expectedErr)

		// Honor the per-case opt-out; the flag was previously declared but never read.
		if curCase.SkipActionFunc {
			continue
		}

		// Action func error returns error in forbidden.
		expectedErr = status.Error(codes.PermissionDenied, curCase.DenyFuncName+"Error")
		authZExp.On("CanGetExperiment", mock.Anything, mockUserArg, mock.Anything).
			Return(nil).Once()
		// Constant format string: the message itself is passed as an argument.
		authZExp.On(curCase.DenyFuncName, mock.Anything, mockUserArg, mock.Anything).
			Return(fmt.Errorf("%s", curCase.DenyFuncName+"Error")).Once()
		require.ErrorIs(t, curCase.IDToReqCall(trial.ID), expectedErr)
	}
}

// TestTrialProtoTaskIDs attaches two additional tasks to a trial and checks that GetTrial,
// GetExperimentTrials and SearchExperiments (via BestTrial) all return the first task in the
// deprecated TaskId field and all three task IDs, in creation order, in TaskIds.
func TestTrialProtoTaskIDs(t *testing.T) {
	api, curUser, ctx := setupAPITest(t, nil)
	trial, firstTask := createTestTrial(t, api, curUser)

	// Mark the trial as its experiment's best trial so SearchExperiments returns it.
	_, err := db.Bun().NewUpdate().Table("experiments").
		Set("best_trial_id = ?", trial.ID).
		Where("id = ?", trial.ExperimentID).Exec(ctx)
	require.NoError(t, err)

	// addTaskAfter inserts a trial task that starts one second after prev.
	addTaskAfter := func(prev *model.Task) *model.Task {
		next := &model.Task{
			TaskType:   model.TaskTypeTrial,
			LogVersion: model.TaskLogVersion1,
			StartTime:  prev.StartTime.Add(time.Second),
			TaskID:     trialTaskID(trial.ExperimentID, model.NewRequestID(rand.Reader)),
		}
		require.NoError(t, db.AddTask(ctx, next))
		return next
	}
	secondTask := addTaskAfter(firstTask)
	thirdTask := addTaskAfter(secondTask)

	extraLinks := []model.RunTaskID{
		{RunID: trial.ID, TaskID: secondTask.TaskID},
		{RunID: trial.ID, TaskID: thirdTask.TaskID},
	}
	_, err = db.Bun().NewInsert().Model(&extraLinks).Exec(ctx)
	require.NoError(t, err)

	wantTaskIDs := []string{
		string(firstTask.TaskID),
		string(secondTask.TaskID),
		string(thirdTask.TaskID),
	}

	// CompareTrials is deliberately not covered: it previously sent, and still sends,
	// TaskID="" and it will also send TaskIDs=[].
	cases := []struct {
		name string
		f    func(t *testing.T) *trialv1.Trial
	}{
		{
			name: "GetTrial",
			f: func(t *testing.T) *trialv1.Trial {
				resp, err := api.GetTrial(ctx, &apiv1.GetTrialRequest{
					TrialId: int32(trial.ID),
				})
				require.NoError(t, err)
				return resp.Trial
			},
		},
		{
			name: "GetExperimentTrials",
			f: func(t *testing.T) *trialv1.Trial {
				resp, err := api.GetExperimentTrials(ctx, &apiv1.GetExperimentTrialsRequest{
					ExperimentId: int32(trial.ExperimentID),
				})
				require.NoError(t, err)
				require.Len(t, resp.Trials, 1)
				return resp.Trials[0]
			},
		},
		{
			name: "SearchExperiments",
			f: func(t *testing.T) *trialv1.Trial {
				filter := fmt.Sprintf(
					`{"filterGroup":
	{"children":[{"columnName":"id","kind":"field","operator":"=","value":%d}],
		"conjunction":"and","kind":"group"},"showArchived":false}`, trial.ExperimentID)
				resp, err := api.SearchExperiments(ctx, &apiv1.SearchExperimentsRequest{
					Filter: ptrs.Ptr(filter),
				})
				require.NoError(t, err)
				require.Len(t, resp.Experiments, 1)
				return resp.Experiments[0].BestTrial
			},
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := tc.f(t)
			// Still test deprecated field for TaskID.
			require.Equal(t, string(firstTask.TaskID), got.TaskId) // nolint: staticcheck
			require.Equal(t, wantTaskIDs, got.TaskIds)
		})
	}
}

// TestExperimentIDFromTrialTaskID checks that experimentIDFromTrialTaskID returns the
// experiment ID for a task ID built by trialTaskID, and errIsNotTrialTaskID for task IDs
// that are plain UUIDs, whether or not a task with that ID has been stored.
func TestExperimentIDFromTrialTaskID(t *testing.T) {
	api, curUser, _ := setupAPITest(t, nil)

	trial, task := createTestTrial(t, api, curUser)
	actual, err := experimentIDFromTrialTaskID(task.TaskID)
	require.NoError(t, err)
	require.Equal(t, trial.ExperimentID, actual)

	// A stored task whose ID does not follow the trial task ID format.
	notTrialTask := &model.Task{
		TaskType:   model.TaskTypeTrial,
		LogVersion: model.TaskLogVersion1,
		StartTime:  time.Now(),
		TaskID:     model.TaskID(uuid.New().String()),
	}
	// Store notTrialTask itself; the original re-added the already stored trial task.
	require.NoError(t, db.AddTask(context.TODO(), notTrialTask))
	_, err = experimentIDFromTrialTaskID(notTrialTask.TaskID)
	require.ErrorIs(t, err, errIsNotTrialTaskID)

	// A task ID that was never stored.
	_, err = experimentIDFromTrialTaskID(model.TaskID(uuid.New().String()))
	require.ErrorIs(t, err, errIsNotTrialTaskID)
}

// TestTrialLogsBackported creates a trial whose task ID has the form
// "backported.<experiment id>", stores one log line for that task and checks that TrialLogs
// streams it back.
func TestTrialLogsBackported(t *testing.T) {
	api, curUser, ctx := setupAPITest(t, nil)
	exp := createTestExpWithProjectID(t, api, curUser, 1)

	task := &model.Task{
		TaskID:     model.TaskID(fmt.Sprintf("backported.%d", exp.ID)),
		TaskType:   model.TaskTypeTrial,
		LogVersion: model.TaskLogVersion1,
		StartTime:  time.Now(),
	}
	err := db.AddTask(ctx, task)
	require.NoError(t, err)

	trial := &model.Trial{
		ExperimentID: exp.ID,
		State:        model.PausedState,
		StartTime:    time.Now(),
	}
	err = db.AddTrial(context.TODO(), trial, task.TaskID)
	require.NoError(t, err)

	wantLogs := []*model.TaskLog{
		{TaskID: string(task.TaskID), Log: "test"},
	}
	require.NoError(t, api.m.db.AddTaskLogs(wantLogs))

	stream := &mockStream[*apiv1.TrialLogsResponse]{ctx: ctx}
	req := &apiv1.TrialLogsRequest{TrialId: int32(trial.ID)}
	require.NoError(t, api.TrialLogs(req, stream))

	gotLogs := stream.getData()
	require.Len(t, gotLogs, len(wantLogs))
	for idx := range wantLogs {
		require.Equal(t, wantLogs[idx].Log, *gotLogs[idx].Log)
	}
}

// TestTrialLogs verifies that TrialLogs streams the logs of every task
// attached to a trial, in task order, both for a one-shot request and for a
// follow request that keeps running until the trial's latest task has ended.
func TestTrialLogs(t *testing.T) {
	api, curUser, ctx := setupAPITest(t, nil)
	trial, task0 := createTestTrial(t, api, curUser)

	// Attach two more tasks to the trial, each starting one second after the
	// previous one.
	task1 := &model.Task{
		TaskType:   model.TaskTypeTrial,
		LogVersion: model.TaskLogVersion1,
		StartTime:  task0.StartTime.Add(time.Second),
		TaskID:     trialTaskID(trial.ExperimentID, model.NewRequestID(rand.Reader)),
	}
	require.NoError(t, db.AddTask(ctx, task1))

	task2 := &model.Task{
		TaskType:   model.TaskTypeTrial,
		LogVersion: model.TaskLogVersion1,
		StartTime:  task1.StartTime.Add(time.Second),
		TaskID:     trialTaskID(trial.ExperimentID, model.NewRequestID(rand.Reader)),
	}
	require.NoError(t, db.AddTask(ctx, task2))

	_, err := db.Bun().NewInsert().Model(&[]model.RunTaskID{
		{RunID: trial.ID, TaskID: task1.TaskID},
		{RunID: trial.ID, TaskID: task2.TaskID},
	}).Exec(ctx)
	require.NoError(t, err)

	// Write nine log lines per task; the expected stream is all of them in
	// task order.
	var expected []string
	tasks := []model.TaskID{task0.TaskID, task1.TaskID, task2.TaskID}
	var taskLogs []*model.TaskLog
	for i, taskID := range tasks {
		for j := 0; j < 9; j++ {
			// Named line rather than log so the imported log package is not shadowed.
			line := fmt.Sprintf("%d-%d\n", i, j)
			taskLogs = append(taskLogs, &model.TaskLog{TaskID: string(taskID), Log: line})
			expected = append(expected, line)
		}
	}
	require.NoError(t, api.m.db.AddTaskLogs(taskLogs))

	stream := &mockStream[*apiv1.TrialLogsResponse]{ctx: ctx}

	err = api.TrialLogs(&apiv1.TrialLogsRequest{
		TrialId: int32(trial.ID),
	}, stream)
	require.NoError(t, err)

	require.Len(t, stream.data, len(expected))
	for i, want := range expected {
		require.Equal(t, want, *stream.data[i].Log)
	}

	// Retry with follow. The call's result is delivered on done; the channel is
	// buffered so the goroutine can exit even if nobody receives.
	newStream := &mockStream[*apiv1.TrialLogsResponse]{ctx: ctx}
	done := make(chan error, 1)
	go func() {
		done <- api.TrialLogs(&apiv1.TrialLogsRequest{
			TrialId: int32(trial.ID),
			Follow:  true,
		}, newStream)
	}()

	// Send a new log.
	newLine := "new log\n"
	require.NoError(t, api.m.db.AddTaskLogs(
		[]*model.TaskLog{{TaskID: string(task2.TaskID), Log: newLine}}))
	expected = append(expected, newLine)

	// Ensure we are still following.
	select {
	case followErr := <-done:
		// Report the error the follow call actually returned, if any, before
		// failing for having stopped too early.
		require.NoError(t, followErr)
		t.Fatal("follow isn't following task logs")
	case <-time.After(10 * time.Second):
	}

	// Note we only update the latest task. We only care about the latest task in following.
	_, err = db.Bun().NewUpdate().Table("tasks").
		Set("end_time = ?", time.Now().Add(-time.Hour)). // An hour ago to avoid termination delay.
		Where("task_id = ?", task2.TaskID).
		Exec(ctx)
	require.NoError(t, err)

	select {
	case followErr := <-done:
		// The follow call must finish cleanly once the latest task has ended.
		require.NoError(t, followErr)
	case <-time.After(30 * time.Second):
		t.Fatal("follow is following too long task logs")
	}

	actual := newStream.getData()
	require.Len(t, actual, len(expected))
	for i, want := range expected {
		require.Equal(t, want, *actual[i].Log)
	}
}

// TestTrialLogFields verifies that TrialLogsFields reports the container IDs
// found in the logs of every task attached to a trial, both for a one-shot
// request and for a follow request that keeps running until the trial's latest
// task has ended.
func TestTrialLogFields(t *testing.T) {
	api, curUser, ctx := setupAPITest(t, nil)
	trial, task0 := createTestTrial(t, api, curUser)

	// Attach two more tasks to the trial, each starting one second after the
	// previous one.
	task1 := &model.Task{
		TaskType:   model.TaskTypeTrial,
		LogVersion: model.TaskLogVersion1,
		StartTime:  task0.StartTime.Add(time.Second),
		TaskID:     trialTaskID(trial.ExperimentID, model.NewRequestID(rand.Reader)),
	}
	require.NoError(t, db.AddTask(ctx, task1))

	task2 := &model.Task{
		TaskType:   model.TaskTypeTrial,
		LogVersion: model.TaskLogVersion1,
		StartTime:  task1.StartTime.Add(time.Second),
		TaskID:     trialTaskID(trial.ExperimentID, model.NewRequestID(rand.Reader)),
	}
	require.NoError(t, db.AddTask(ctx, task2))

	_, err := db.Bun().NewInsert().Model(&[]model.RunTaskID{
		{RunID: trial.ID, TaskID: task1.TaskID},
		{RunID: trial.ID, TaskID: task2.TaskID},
	}).Exec(ctx)
	require.NoError(t, err)

	// One log line per task, each tagged with a distinct container ID.
	expectedContainerIDs := make(map[string]bool)
	tasks := []model.TaskID{task0.TaskID, task1.TaskID, task2.TaskID}
	var taskLogs []*model.TaskLog
	for i, taskID := range tasks {
		containerID := fmt.Sprintf("id-%d", i)
		taskLogs = append(taskLogs, &model.TaskLog{
			TaskID:      string(taskID),
			Log:         "test log",
			ContainerID: ptrs.Ptr(containerID),
		})
		expectedContainerIDs[containerID] = true
	}
	require.NoError(t, api.m.db.AddTaskLogs(taskLogs))

	stream := &mockStream[*apiv1.TrialLogsFieldsResponse]{ctx: ctx}

	err = api.TrialLogsFields(&apiv1.TrialLogsFieldsRequest{
		TrialId: int32(trial.ID),
	}, stream)
	require.NoError(t, err)

	// Union the container IDs over every streamed response.
	actualContainerIDs := make(map[string]bool)
	for _, s := range stream.getData() {
		for _, containerID := range s.ContainerIds {
			actualContainerIDs[containerID] = true
		}
	}
	require.Equal(t, expectedContainerIDs, actualContainerIDs)

	// Retry with follow. The call's result is delivered on done; the channel is
	// buffered so the goroutine can exit even if nobody receives.
	newStream := &mockStream[*apiv1.TrialLogsFieldsResponse]{ctx: ctx}
	done := make(chan error, 1)
	go func() {
		done <- api.TrialLogsFields(&apiv1.TrialLogsFieldsRequest{
			TrialId: int32(trial.ID),
			Follow:  true,
		}, newStream)
	}()

	// Send a new log.
	containerID := "newContainerID"
	require.NoError(t, api.m.db.AddTaskLogs(
		[]*model.TaskLog{{
			TaskID:      string(task2.TaskID),
			Log:         "test log",
			ContainerID: ptrs.Ptr(containerID),
		}}))
	expectedContainerIDs[containerID] = true

	// Ensure we are still following.
	select {
	case followErr := <-done:
		// Report the error the follow call actually returned, if any, before
		// failing for having stopped too early.
		require.NoError(t, followErr)
		t.Fatal("follow isn't following task logs")
	case <-time.After(10 * time.Second):
	}

	// Note we only update the latest task. We only care about the latest task in following.
	_, err = db.Bun().NewUpdate().Table("tasks").
		Set("end_time = ?", time.Now().Add(-time.Hour)). // An hour ago to avoid termination delay.
		Where("task_id = ?", task2.TaskID).
		Exec(ctx)
	require.NoError(t, err)

	select {
	case followErr := <-done:
		// The follow call must finish cleanly once the latest task has ended.
		require.NoError(t, followErr)
	case <-time.After(30 * time.Second):
		t.Fatal("follow is following too long task logs")
	}

	actualContainerIDs = make(map[string]bool)
	for _, s := range newStream.getData() {
		for _, containerID := range s.ContainerIds {
			actualContainerIDs[containerID] = true
		}
	}
	require.Equal(t, expectedContainerIDs, actualContainerIDs)
}

// compareTrialsResponseToBatches collects the Batches value of every data
// point in the first metric series of the first trial in resp.
// It indexes Trials[0] and Metrics[0] directly, so resp must contain at least
// one trial with at least one metric series.
func compareTrialsResponseToBatches(resp *apiv1.CompareTrialsResponse) []int32 {
	points := resp.Trials[0].Metrics[0].Data

	batches := make([]int32, 0, len(points))
	for _, point := range points {
		batches = append(batches, point.Batches)
	}
	return batches
}

// TestCompareTrialsSampling checks that CompareTrials returns exactly
// MaxDatapoints samples and that repeating the same request yields the same
// sampled batches.
func TestCompareTrialsSampling(t *testing.T) {
	api, curUser, ctx := setupAPITest(t, nil)
	trial, _ := createTestTrialWithMetrics(ctx, t, api, curUser, false)

	const datapoints = 3
	req := &apiv1.CompareTrialsRequest{
		TrialIds:      []int32{int32(trial.ID)},
		MaxDatapoints: datapoints,
		MetricNames:   []string{"loss"},
		StartBatches:  0,
		EndBatches:    1000,
		MetricType:    apiv1.MetricType_METRIC_TYPE_TRAINING,
	}

	// sampledBatches issues the request once and returns the batches it sampled.
	sampledBatches := func() []int32 {
		resp, err := api.CompareTrials(ctx, req)
		require.NoError(t, err)
		return compareTrialsResponseToBatches(resp)
	}

	first := sampledBatches()
	require.Len(t, first, datapoints)

	second := sampledBatches()
	require.Equal(t, first, second)
}

// createTestTrialInferenceMetrics records two metric reports, {"a":1} and
// {"b":2}, under the "inference" metric group for the trial with the given id.
// Both reports use trial run ID 0 and zero completed steps.
func createTestTrialInferenceMetrics(ctx context.Context, t *testing.T, api *apiServer, id int32) {
	const rawMetrics = `{"inference": [{"a":1}, {"b":2}]}`

	var byGroup map[model.MetricGroup][]map[string]any
	require.NoError(t, json.Unmarshal([]byte(rawMetrics), &byGroup))

	stepsCompleted := int32(0)
	for group, reports := range byGroup {
		for _, report := range reports {
			avgMetrics, err := structpb.NewStruct(report)
			require.NoError(t, err)

			trialMetrics := &trialv1.TrialMetrics{
				TrialId:        id,
				TrialRunId:     int32(0),
				StepsCompleted: &stepsCompleted,
				Metrics: &commonv1.Metrics{
					AvgMetrics: avgMetrics,
				},
			}
			require.NoError(t, api.m.db.AddTrialMetrics(ctx, trialMetrics, group))
		}
	}
}

// TestTrialSourceInfoCheckpoint links two inference trials to one checkpoint
// and verifies that GetTrialMetricsByCheckpoint returns metrics for both when
// every permission check passes, and only for the trial whose experiment
// artifacts are viewable when one check is denied.
func TestTrialSourceInfoCheckpoint(t *testing.T) {
	api, authZExp, _, curUser, ctx := setupExpAuthTest(t, nil)
	infTrial, _ := createTestTrial(t, api, curUser)
	infTrial2, _ := createTestTrial(t, api, curUser)
	createTestTrialInferenceMetrics(ctx, t, api, int32(infTrial.ID))
	createTestTrialInferenceMetrics(ctx, t, api, int32(infTrial2.ID))

	mockUserArg := mock.MatchedBy(func(u model.User) bool {
		return u.ID == curUser.ID
	})

	// Create a checkpoint to index with
	checkpointUUID := createVersionTwoCheckpoint(ctx, t, api, curUser, map[string]int64{"a": 1})

	// Create a TrialSourceInfo associated with each of the two trials.
	resp, err := trials.CreateTrialSourceInfo(
		ctx, &trialv1.TrialSourceInfo{
			TrialId:             int32(infTrial.ID),
			CheckpointUuid:      checkpointUUID,
			TrialSourceInfoType: trialv1.TrialSourceInfoType_TRIAL_SOURCE_INFO_TYPE_INFERENCE,
		},
	)
	require.NoError(t, err)
	// require.Equal takes the expected value first, then the actual one.
	require.Equal(t, int32(infTrial.ID), resp.TrialId)
	require.Equal(t, checkpointUUID, resp.CheckpointUuid)

	resp, err = trials.CreateTrialSourceInfo(
		ctx, &trialv1.TrialSourceInfo{
			TrialId:             int32(infTrial2.ID),
			CheckpointUuid:      checkpointUUID,
			TrialSourceInfoType: trialv1.TrialSourceInfoType_TRIAL_SOURCE_INFO_TYPE_INFERENCE,
		},
	)
	require.NoError(t, err)
	require.Equal(t, int32(infTrial2.ID), resp.TrialId)
	require.Equal(t, checkpointUUID, resp.CheckpointUuid)

	authZExp.On("CanGetExperiment", mock.Anything, mockUserArg, mock.Anything).
		Return(nil).Times(3)
	authZExp.On("CanGetExperimentArtifacts", mock.Anything, mockUserArg, mock.Anything).
		Return(nil).Times(3)

	// If there are no restrictions, we should see all the trials
	getCkptResp, getErr := api.GetTrialMetricsByCheckpoint(
		ctx, &apiv1.GetTrialMetricsByCheckpointRequest{
			CheckpointUuid: checkpointUUID,
			MetricGroup:    &inferenceMetricGroup,
		},
	)
	require.NoError(t, getErr)
	require.Len(t, getCkptResp.Metrics, 2)

	infTrialExp, err := db.ExperimentByID(ctx, infTrial.ExperimentID)
	require.NoError(t, err)
	infTrial2Exp, err := db.ExperimentByID(ctx, infTrial2.ExperimentID)
	require.NoError(t, err)

	// All experiments can be seen
	authZExp.On("CanGetExperiment", mock.Anything, mockUserArg, mock.Anything).
		Return(nil).Times(3)
	// We can see the experiment that generated the checkpoint
	authZExp.On("CanGetExperimentArtifacts", mock.Anything, mockUserArg, mock.Anything).
		Return(nil).Once()
	// We can't see the experiment for infTrial
	authZExp.On("CanGetExperimentArtifacts", mock.Anything, mockUserArg, infTrialExp).
		Return(authz2.PermissionDeniedError{}).Once()
	// We can see the experiment for infTrial2
	authZExp.On("CanGetExperimentArtifacts", mock.Anything, mockUserArg, infTrial2Exp).
		Return(nil).Once()
	getCkptResp, getErr = api.GetTrialMetricsByCheckpoint(
		ctx, &apiv1.GetTrialMetricsByCheckpointRequest{
			CheckpointUuid: checkpointUUID,
			MetricGroup:    &inferenceMetricGroup,
		},
	)
	require.NoError(t, getErr)
	// Only infTrial2 should be visible, so only its metrics are returned.
	require.Len(t, getCkptResp.Metrics, 1)
	require.Equal(t, int32(infTrial2.ID), getCkptResp.Metrics[0].TrialId)
}

// TestTrialSourceInfoModelVersion links two inference trials to one
// checkpoint, only the first of which is also tied to a model version, and
// verifies that GetTrialMetricsByModelVersion returns metrics for that first
// trial alone.
func TestTrialSourceInfoModelVersion(t *testing.T) {
	api, curUser, ctx := setupAPITest(t, nil)
	infTrial, _ := createTestTrial(t, api, curUser)
	infTrial2, _ := createTestTrial(t, api, curUser)
	createTestTrialInferenceMetrics(ctx, t, api, int32(infTrial.ID))

	// Create a checkpoint to index with
	checkpointUUID := createVersionTwoCheckpoint(ctx, t, api, curUser, map[string]int64{"a": 1})

	// Create a model_version to index with
	conv := &protoconverter.ProtoConverter{}
	modelVersion := RegisterCheckpointAsModelVersion(ctx, t, api.m.db, conv.ToUUID(checkpointUUID))

	// Create a TrialSourceInfo for the first trial that references both the
	// checkpoint and the model version.
	resp, err := trials.CreateTrialSourceInfo(
		ctx, &trialv1.TrialSourceInfo{
			TrialId:             int32(infTrial.ID),
			CheckpointUuid:      checkpointUUID,
			TrialSourceInfoType: trialv1.TrialSourceInfoType_TRIAL_SOURCE_INFO_TYPE_INFERENCE,
			ModelId:             &modelVersion.Model.Id,
			ModelVersion:        &modelVersion.Version,
		},
	)
	require.NoError(t, err)
	// require.Equal takes the expected value first, then the actual one.
	require.Equal(t, int32(infTrial.ID), resp.TrialId)
	require.Equal(t, checkpointUUID, resp.CheckpointUuid)

	// The second trial references the checkpoint only, not the model version.
	resp, err = trials.CreateTrialSourceInfo(
		ctx, &trialv1.TrialSourceInfo{
			TrialId:             int32(infTrial2.ID),
			CheckpointUuid:      checkpointUUID,
			TrialSourceInfoType: trialv1.TrialSourceInfoType_TRIAL_SOURCE_INFO_TYPE_INFERENCE,
		},
	)
	require.NoError(t, err)
	require.Equal(t, int32(infTrial2.ID), resp.TrialId)
	require.Equal(t, checkpointUUID, resp.CheckpointUuid)

	getMVResp, getMVErr := api.GetTrialMetricsByModelVersion(
		ctx, &apiv1.GetTrialMetricsByModelVersionRequest{
			ModelName:       modelVersion.Model.Name,
			ModelVersionNum: modelVersion.Version,
			MetricGroup:     &inferenceMetricGroup,
		},
	)
	require.NoError(t, getMVErr)
	// One trial is valid and it has one aggregated MetricsReport
	require.Len(t, getMVResp.Metrics, 1)
	require.Equal(t, int32(infTrial.ID), getMVResp.Metrics[0].TrialId)
}

// TestGetTrialByExternalID assigns an external experiment ID and an external
// run ID to a freshly created trial and verifies that GetTrialByExternalID
// finds that trial by the pair.
func TestGetTrialByExternalID(t *testing.T) {
	api, curUser, ctx := setupAPITest(t, nil)
	trial, _ := createTestTrial(t, api, curUser)
	externalExpID := uuid.New().String()
	externalTrialID := "trial"

	// Tag the trial's experiment with the external experiment ID.
	_, err := db.Bun().NewUpdate().Model(&model.Experiment{}).
		Where("id = ?", trial.ExperimentID).
		Set("external_experiment_id = ?", externalExpID).
		Exec(ctx)
	require.NoError(t, err)

	// Tag the trial's run with the external trial ID.
	_, err = db.Bun().NewUpdate().Model(&model.Run{}).
		Where("id = ?", trial.ID).
		Set("external_run_id = ?", externalTrialID).
		Exec(ctx)
	require.NoError(t, err)

	resp, err := api.GetTrialByExternalID(ctx, &apiv1.GetTrialByExternalIDRequest{
		ExternalExperimentId: externalExpID,
		ExternalTrialId:      externalTrialID,
	})
	require.NoError(t, err)

	// require.Equal takes the expected value first, then the actual one.
	require.Equal(t, trial.ID, int(resp.Trial.Id))
}

// getLogRetentionDays reads the log_retention_days column of the runs whose
// IDs are listed in trialIDs. The query has no ORDER BY, so the values come
// back in whatever order the database yields them.
func getLogRetentionDays(ctx context.Context, trialIDs []int) ([]int32, error) {
	query := db.Bun().NewSelect().
		Table("runs").
		Column("log_retention_days").
		Where("id IN (?)", bun.In(trialIDs))

	var days []int32
	err := query.Scan(ctx, &days)
	return days, err
}

// TestPutTrialRetainLogs creates a completed experiment with five trials whose
// log retention is -1, sets each trial's retention to 10 days through
// PutTrialRetainLogs, and verifies the stored values changed accordingly.
func TestPutTrialRetainLogs(t *testing.T) {
	api, _, ctx := setupAPITest(t, nil)
	exp, trialIDs, _ := CreateTestRetentionExperiment(ctx, t, api, logRetentionConfigForever, 5)

	err := CompleteExpAndTrials(ctx, exp.Id, trialIDs)
	require.NoError(t, err)

	// Before the update every trial reports -1.
	orgLogRetentionDays, err := getLogRetentionDays(ctx, trialIDs)
	require.NoError(t, err)
	require.Equal(t, []int32{-1, -1, -1, -1, -1}, orgLogRetentionDays)

	newLogRetentionDays := []int32{10, 10, 10, 10, 10}
	for i, trialID := range trialIDs {
		res, err := api.PutTrialRetainLogs(ctx, &apiv1.PutTrialRetainLogsRequest{
			TrialId: int32(trialID), NumDays: newLogRetentionDays[i],
		})
		require.NoError(t, err)
		require.NotNil(t, res)
	}

	updatedLogRetentionDays, err := getLogRetentionDays(ctx, trialIDs)
	require.NoError(t, err)
	// require.Equal takes the expected value first, then the actual one.
	require.Equal(t, newLogRetentionDays, updatedLogRetentionDays)
}

// completeTrialsandTasks marks the run with ID trialID as completed and sets
// its end time, and the end time of every task linked to it through
// run_id_task_id, to endTimeDays days before now.
// It returns the first database error encountered, wrapped with context.
func completeTrialsandTasks(ctx context.Context, trialID int, endTimeDays int) error {
	// End every task attached to the run; run_id_task_id supplies the link.
	if _, err := db.Bun().NewUpdate().
		Table("tasks", "run_id_task_id").
		Set("end_time = (NOW() - make_interval(days => ?))", endTimeDays).
		Where("run_id_task_id.run_id = ? and tasks.task_id = run_id_task_id.task_id", trialID).
		Exec(ctx); err != nil {
		return fmt.Errorf("ending tasks of trial %d: %w", trialID, err)
	}

	// Complete the run itself. Only runs is listed here: the filter uses the
	// run's own id, so adding run_id_task_id and tasks would only introduce a
	// join with no condition tying those tables to the run.
	if _, err := db.Bun().NewUpdate().
		Table("runs").
		Set("state = ?", model.CompletedState).
		Set("end_time = (NOW() - make_interval(days => ?))", endTimeDays).
		Where("id = ?", trialID).
		Exec(ctx); err != nil {
		return fmt.Errorf("completing trial %d: %w", trialID, err)
	}
	return nil
}

// completeExp sets the state of the experiment with ID expID to
// model.CompletedState and returns any error from the update.
func completeExp(ctx context.Context, expID int32) error {
	_, err := db.Bun().NewUpdate().
		Table("experiments").
		Set("state = ?", model.CompletedState).
		Where("id = ?", expID).
		Exec(ctx)
	return err
}

// TestGetTrialRemainingLogRetentionDaysNullMasterConfig checks the remaining
// log retention days reported for trials that ended 15, 10, 5 and 0 days ago,
// for experiments with no retention setting, with log_retention_days -1, and
// with log_retention_days 10. This test does not set a retention value on the
// master config.
func TestGetTrialRemainingLogRetentionDaysNullMasterConfig(t *testing.T) {
	api, _, ctx := setupAPITest(t, nil)

	const foreverPolicy = `
retention_policy:
  log_retention_days: -1
`
	const tenDayPolicy = `
retention_policy:
  log_retention_days: 10
`

	cases := []struct {
		name             string
		logRetentionDays string
		expRemainingDays []int32
	}{
		{
			name:             "test-null-days",
			logRetentionDays: "",
			expRemainingDays: []int32{-1, -1, -1, -1},
		},
		{
			name:             "test-forever",
			logRetentionDays: foreverPolicy,
			expRemainingDays: []int32{-1, -1, -1, -1},
		},
		{
			name:             "test-10days",
			logRetentionDays: tenDayPolicy,
			expRemainingDays: []int32{0, 0, 4, 9},
		},
	}

	// How many days ago each trial of an experiment is made to end.
	endedDaysAgo := []int{15, 10, 5, 0}

	for _, tc := range cases {
		log.Printf("Starting %v", tc.name)
		exp, trialIDs, _ := CreateTestRetentionExperiment(
			ctx, t, api, tc.logRetentionDays, len(endedDaysAgo))

		for i, daysAgo := range endedDaysAgo {
			trialID := trialIDs[i]
			require.NoError(t, completeTrialsandTasks(ctx, trialID, daysAgo))

			req := &apiv1.GetTrialRemainingLogRetentionDaysRequest{Id: int32(trialID)}
			resp, err := api.GetTrialRemainingLogRetentionDays(ctx, req)
			require.NoError(t, err)
			require.Equal(t, tc.expRemainingDays[i], *resp.RemainingDays)
		}

		require.NoError(t, completeExp(ctx, exp.Id))
	}
}

// TestGetTrialRemainingLogRetentionDaysNonNullMasterConfig checks the
// remaining log retention days reported for trials that ended 15, 10, 5 and 0
// days ago when the master config sets a log retention of 100 days. An
// experiment with no retention setting is expected to report 84, 89, 94 and
// 99 days, while experiment-level settings of -1 and 10 give the same results
// as they do without a master setting.
func TestGetTrialRemainingLogRetentionDaysNonNullMasterConfig(t *testing.T) {
	api, _, ctx := setupAPITest(t, nil)

	// set Log retention days in master config
	masterRetentionDays := int16(100)
	api.m.config.RetentionPolicy.LogRetentionDays = &masterRetentionDays

	const foreverPolicy = `
retention_policy:
  log_retention_days: -1
`
	const tenDayPolicy = `
retention_policy:
  log_retention_days: 10
`

	cases := []struct {
		name             string
		logRetentionDays string
		expRemainingDays []int32
	}{
		{
			name:             "test-null-days",
			logRetentionDays: "",
			expRemainingDays: []int32{84, 89, 94, 99},
		},
		{
			name:             "test-forever",
			logRetentionDays: foreverPolicy,
			expRemainingDays: []int32{-1, -1, -1, -1},
		},
		{
			name:             "test-10days",
			logRetentionDays: tenDayPolicy,
			expRemainingDays: []int32{0, 0, 4, 9},
		},
	}

	// How many days ago each trial of an experiment is made to end.
	endedDaysAgo := []int{15, 10, 5, 0}

	for _, tc := range cases {
		log.Printf("Starting %v", tc.name)
		exp, trialIDs, _ := CreateTestRetentionExperiment(
			ctx, t, api, tc.logRetentionDays, len(endedDaysAgo))

		for i, daysAgo := range endedDaysAgo {
			trialID := trialIDs[i]
			require.NoError(t, completeTrialsandTasks(ctx, trialID, daysAgo))

			req := &apiv1.GetTrialRemainingLogRetentionDaysRequest{Id: int32(trialID)}
			resp, err := api.GetTrialRemainingLogRetentionDays(ctx, req)
			require.NoError(t, err)
			require.Equal(t, tc.expRemainingDays[i], *resp.RemainingDays)
		}

		require.NoError(t, completeExp(ctx, exp.Id))
	}
}

// TestRunLocalID adds five trials to an experiment in a freshly created
// project and verifies after each insert that the new run's local_id and the
// project's max_local_id both equal the number of trials added so far.
func TestRunLocalID(t *testing.T) {
	api, curUser, ctx := setupAPITest(t, nil)
	_, projectID := createProjectAndWorkspace(ctx, t, api)
	exp := createTestExpWithProjectID(t, api, curUser, projectID)

	// Every trial below is attached to this one task.
	sharedTask := &model.Task{
		TaskID:     model.TaskID(fmt.Sprintf("backported.%d", exp.ID)),
		TaskType:   model.TaskTypeTrial,
		LogVersion: model.TaskLogVersion1,
		StartTime:  time.Now(),
	}
	require.NoError(t, db.AddTask(ctx, sharedTask))

	const numTrials = 5
	for want := 1; want <= numTrials; want++ {
		trial := &model.Trial{
			ExperimentID: exp.ID,
			State:        model.PausedState,
			StartTime:    time.Now(),
		}
		require.NoError(t, db.AddTrial(ctx, trial, sharedTask.TaskID))

		// validate local_run_id and max id is correct
		var localID int
		require.NoError(t, db.Bun().NewSelect().
			Table("runs").
			Column("local_id").
			Where("id = ?", trial.ID).
			Scan(ctx, &localID))
		require.Equal(t, want, localID)

		var maxLocalID int
		require.NoError(t, db.Bun().NewSelect().
			Table("projects").
			Column("max_local_id").
			Where("id = ?", projectID).
			Scan(ctx, &maxLocalID))
		require.Equal(t, want, maxLocalID)
	}
}
