import {
  CheckpointState,
  CheckpointStorageType,
  CheckpointWorkloadExtended,
  ExperimentBase,
  ExperimentSearcherName,
  FlatRun,
  HyperparameterType,
  Project,
  RunState,
  SummaryMetrics,
  TrialDetails,
  Workspace,
  WorkspaceState,
} from 'types';
import { generateExperiment } from 'utils/task';

/**
 * Builds a matched set of fixtures for experiment-detail tests.
 *
 * Starts from `generateExperiment()` and overrides most of its fields with a
 * fixed adaptive-ASHA `mnist_pytorch_adaptive_search` configuration, then pairs
 * it with a completed checkpoint and an active trial.
 *
 * NOTE(review): the checkpoint points at experimentId 2 / trialId 3, while the
 * experiment and trial both use id 1. Also, the raw config's hyperparameters
 * (dropout1, learning_rate, ...) do not mirror the parsed ones (categorical,
 * constant, ...). Confirm no test relies on these mismatches before aligning them.
 *
 * @returns the checkpoint, experiment and trial fixtures.
 */
export const generateTestExperimentData = (): {
  checkpoint: CheckpointWorkloadExtended;
  experiment: ExperimentBase;
  trial: TrialDetails;
} => {
  const baseExperiment = generateExperiment();

  // The same parsed hyperparameter definitions appear in the parsed config and
  // on the experiment itself. Build a fresh copy for each location so the two
  // never alias one another.
  const makeHyperparameters = (): ExperimentBase['hyperparameters'] => ({
    categorical: {
      maxval: 64,
      minval: 8,
      type: HyperparameterType.Categorical,
      vals: [8, 16, 32, 64],
    },
    constant: {
      type: HyperparameterType.Constant,
      val: 64,
    },
    double: {
      maxval: 0.8,
      minval: 0.2,
      type: HyperparameterType.Double,
    },
    log: {
      maxval: 1,
      minval: 0.0001,
      type: HyperparameterType.Log,
    },
  });

  // Parsed (camelCase) experiment configuration.
  const parsedConfig: ExperimentBase['config'] = {
    checkpointPolicy: 'best',
    checkpointStorage: {
      hostPath: '/tmp',
      saveExperimentBest: 0,
      saveTrialBest: 1,
      saveTrialLatest: 1,
      storagePath: 'determined-checkpoint',
      type: CheckpointStorageType.SharedFS,
    },
    hyperparameters: makeHyperparameters(),
    labels: [],
    maxRestarts: 5,
    name: 'mnist_pytorch_adaptive_search',
    profiling: { enabled: false },
    resources: {},
    searcher: {
      metric: 'validation_loss',
      name: ExperimentSearcherName.AdaptiveAsha,
      smallerIsBetter: true,
    },
  };

  // Raw (snake_case) configuration, as stored on the experiment's `configRaw`.
  const rawConfig: ExperimentBase['configRaw'] = {
    bind_mounts: [],
    checkpoint_policy: 'best',
    checkpoint_storage: {
      host_path: '/tmp',
      propagation: 'rprivate',
      save_experiment_best: 0,
      save_trial_best: 1,
      save_trial_latest: 1,
      storage_path: 'determined-checkpoint',
      type: 'shared_fs',
    },
    data: {
      url: 'https://s3-us-west-2.amazonaws.com/determined-ai-test-data/pytorch_mnist.tar.gz',
    },
    data_layer: {
      container_storage_path: null,
      host_storage_path: null,
      type: 'shared_fs',
    },
    debug: false,
    description: null,
    entrypoint: 'model_def:MNistTrial',
    environment: {
      add_capabilities: [],
      drop_capabilities: [],
      environment_variables: {
        cpu: [],
        gpu: [],
      },
      force_pull_image: false,
      image: {
        cpu: 'determinedai/environments:py-3.7-pytorch-1.7-tf-1.15-cpu-da845fc',
        gpu: 'determinedai/environments:cuda-10.2-pytorch-1.7-tf-1.15-gpu-da845fc',
      },
      pod_spec: null,
      ports: {},
      registry_auth: null,
    },
    hyperparameters: {
      dropout1: {
        maxval: 0.8,
        minval: 0.2,
        type: 'double',
      },
      dropout2: {
        maxval: 0.8,
        minval: 0.2,
        type: 'double',
      },
      global_batch_size: {
        type: 'const',
        val: 64,
      },
      learning_rate: {
        maxval: 1,
        minval: 0.0001,
        type: 'double',
      },
      n_filters1: {
        maxval: 64,
        minval: 8,
        type: 'int',
      },
      n_filters2: {
        maxval: 72,
        minval: 8,
        type: 'int',
      },
    },
    labels: [],
    max_restarts: 5,
    min_checkpoint_period: { batches: 0 },
    min_validation_period: { batches: 0 },
    name: 'mnist_pytorch_adaptive_search',
    optimizations: {
      aggregation_frequency: 1,
      auto_tune_tensor_fusion: false,
      average_aggregated_gradients: true,
      average_training_metrics: false,
      grad_updates_size_file: null,
      gradient_compression: false,
      mixed_precision: 'O0',
      tensor_fusion_cycle_time: 5,
      tensor_fusion_threshold: 64,
    },
    perform_initial_validation: false,
    profiling: {
      begin_on_batch: 0,
      enabled: false,
      end_after_batch: null,
    },
    records_per_epoch: 0,
    reproducibility: { experiment_seed: 1623252417 },
    resources: {
      agent_label: '',
      devices: [],
      max_slots: null,
      native_parallel: false,
      priority: null,
      resource_pool: 'default',
      shm_size: null,
      slots_per_trial: 1,
      weight: 1,
    },
    scheduling_unit: 100,
    searcher: {
      bracket_rungs: [],
      divisor: 4,
      max_concurrent_trials: 16,
      max_rungs: 5,
      max_trials: 16,
      metric: 'validation_loss',
      mode: 'standard',
      name: 'adaptive_asha',
      smaller_is_better: true,
      source_checkpoint_uuid: null,
      source_trial_id: null,
      stop_once: false,
    },
  };

  // Explicit fields come after the spread, so they replace any same-named
  // fields produced by generateExperiment().
  const experiment: ExperimentBase = {
    ...baseExperiment,
    archived: false,
    config: parsedConfig,
    configRaw: rawConfig,
    hyperparameters: makeHyperparameters(),
    id: 1,
    name: 'Sample Experiment',
    originalConfig: `
      entrypoint: model_def:MNistTrial
      hyperparameters:
        dropout1: {maxval: 0.8, minval: 0.2, type: double}
        dropout2: {maxval: 0.8, minval: 0.2, type: double}
        global_batch_size: 64
        learning_rate: {maxval: 1.0, minval: 0.0001, type: double}
        n_filters1: {maxval: 64, minval: 8, type: int}
        n_filters2: {maxval: 72, minval: 8, type: int}
      name: mnist_pytorch_adaptive_search
      records_per_epoch: 10
      searcher:
        max_trials: 16
        metric: validation_loss
        name: adaptive_asha
        smaller_is_better: true`,
    parentArchived: false,
    projectId: 1,
    projectName: 'project',
    projectOwnerId: 1,
    resourcePool: 'default',
    startTime: '2021-06-09T15:26:57.610700Z',
    state: RunState.Completed,
    userId: 2,
    workspaceId: 1,
    workspaceName: 'workspace',
  };

  const checkpoint: CheckpointWorkloadExtended = {
    experimentId: 2,
    resources: { foo: 12 },
    state: CheckpointState.Completed,
    totalBatches: 50,
    trialId: 3,
    uuid: 'b6aab473-a959-47fa-a962-ba791b0230fb',
  };

  const trial: TrialDetails = {
    autoRestarts: 1,
    experimentId: 1,
    hyperparameters: { 1: 1 },
    id: 1,
    runnerState: 'Active',
    startTime: '1',
    state: RunState.Active,
    totalBatchesProcessed: 1,
    totalCheckpointSize: 0,
  };

  return { checkpoint, experiment, trial };
};

/**
 * Builds a `FlatRun` fixture for run-list tests.
 *
 * @param id - id assigned to the run (defaults to 1).
 * @param summaryMetrics - when true (the default), attach the fixture from
 *   `generateSummaryMetricsData()`; otherwise `summaryMetrics` is `undefined`.
 * @returns an active, unarchived run. Its `learning_rate` hyperparameter comes
 *   from `Math.random()` and its `startTime` is the moment of the call, so those
 *   two fields differ between calls.
 */
export const generateTestRunData = (id = 1, summaryMetrics = true): FlatRun => {
  let metrics: SummaryMetrics | undefined;
  if (summaryMetrics) {
    metrics = generateSummaryMetricsData();
  }

  const run: FlatRun = {
    archived: false,
    checkpointCount: 0,
    checkpointSize: 0,
    hyperparameters: { learning_rate: Math.random() },
    id,
    parentArchived: false,
    projectId: 1,
    projectName: 'Uncategorized',
    startTime: new Date(),
    state: RunState.Active,
    summaryMetrics: metrics,
    workspaceId: 1,
    workspaceName: 'Uncategorized',
  };
  return run;
};

/**
 * Builds a fixed `SummaryMetrics` fixture with one training metric (`loss`) and
 * two validation metrics (`accuracy`, `validation_loss`).
 *
 * @returns a new object on every call, so callers may mutate it freely.
 */
export const generateSummaryMetricsData = (): SummaryMetrics => {
  // Every metric is a single-sample aggregate. `last` keeps the full-precision
  // reading, while max/min/sum share a lower-precision rendering of that value.
  const singleSample = (last: number, aggregate: number) => ({
    count: 1,
    last,
    max: aggregate,
    min: aggregate,
    sum: aggregate,
    type: 'number' as const,
  });

  const avgMetrics = {
    loss: singleSample(0.5823304653167725, 0.582330465316772),
  };
  const validationMetrics = {
    accuracy: singleSample(0.8522093949044586, 0.852209394904459),
    validation_loss: singleSample(0.49773818169050155, 0.497738181690502),
  };

  return { avgMetrics, validationMetrics };
};

/**
 * Builds a `Project` fixture, optionally customised.
 *
 * @param overrides - fields merged over the defaults. They are applied last, so
 *   any key present here wins, even one explicitly set to `undefined`.
 * @returns a new `Project` object.
 */
export const generateTestProjectData = (overrides: Partial<Project> = {}): Project => {
  const defaults: Project = {
    archived: false,
    id: 1,
    immutable: false,
    name: 'Project Name',
    notes: [],
    numActiveExperiments: 1,
    numExperiments: 1,
    state: WorkspaceState.Unspecified,
    userId: 1,
    workspaceId: 1,
    workspaceName: 'Workspace Name',
  };
  return { ...defaults, ...overrides };
};

/**
 * Builds a `Workspace` fixture, optionally customised.
 *
 * @param overrides - fields merged over the defaults. They are applied last, so
 *   any key present here wins, even one explicitly set to `undefined`.
 * @returns a new `Workspace` object.
 */
export const generateTestWorkspaceData = (overrides: Partial<Workspace> = {}): Workspace => {
  const defaults: Workspace = {
    archived: false,
    id: 1,
    immutable: false,
    name: 'Workspace Name',
    numExperiments: 1,
    numProjects: 1,
    pinned: false,
    state: WorkspaceState.Unspecified,
    userId: 1,
  };
  return { ...defaults, ...overrides };
};
