import { FilterValue, SorterResult, TablePaginationConfig } from 'antd/es/table/interface';
import Button from 'hew/Button';
import Icon from 'hew/Icon';
import Select, { Option, SelectValue } from 'hew/Select';
import { Loadable, Loaded, NotLoaded } from 'hew/utils/loadable';
import _ from 'lodash';
import React, { useCallback, useEffect, useMemo, useState } from 'react';

import HumanReadableNumber from 'components/HumanReadableNumber';
import MetricBadgeTag from 'components/MetricBadgeTag';
import ResponsiveFilters from 'components/ResponsiveFilters';
import Section from 'components/Section';
import ResponsiveTable from 'components/Table/ResponsiveTable';
import { defaultRowClassName, getFullPaginationConfig } from 'components/Table/Table';
import { useCheckpointFlow } from 'hooks/useCheckpointFlow';
import useFeature from 'hooks/useFeature';
import { useFetchModels } from 'hooks/useFetchModels';
import usePolling from 'hooks/usePolling';
import { getTrialWorkloads } from 'services/api';
import {
  CheckpointWorkloadExtended,
  ExperimentBase,
  Metric,
  Step,
  TrialDetails,
  TrialWorkloadFilter,
  WorkloadGroup,
} from 'types';
import handleError, { ErrorType } from 'utils/error';
import {
  extractMetricSortValue,
  extractMetricValue,
  metricKeyToMetric,
  metricToKey,
} from 'utils/metric';
import { numericSorter } from 'utils/sort';
import { hasCheckpoint, hasCheckpointStep, workloadsToSteps } from 'utils/workload';

import { Settings } from './TrialDetailsOverview.settings';
import { columns as defaultColumns } from './TrialDetailsWorkloads.table';

/** Props for the TrialDetailsWorkloads table section. */
export interface Props {
  /** Baseline metric selection; when `metrics` differs from it, the filter indicator shows filters as applied. */
  defaultMetrics: Metric[];
  /** Experiment owning the trial; its config is passed to the checkpoint flow and its searcher metric sets a default sort order. */
  experiment: ExperimentBase;
  /** Not destructured by TrialDetailsWorkloads in this file. */
  metricNames: Metric[];
  /** Selected metrics; each one in the `training` or `validation` group is rendered as a table column. */
  metrics: Metric[];
  /** Table state read by the component: workload filter, sort key/direction, page limit and offset. */
  settings: Settings;
  /** Trial whose workloads are listed; while it (or its id) is missing, no workloads are fetched. */
  trial?: TrialDetails;
  /** Receives partial settings updates from the component (filter, sort and pagination changes). */
  updateSettings: (newSettings: Partial<Settings>) => void;
}

/**
 * "Workloads" section of the trial details page: a server-paginated table of workload
 * steps for one trial, with one column per selected training/validation metric, a
 * checkpoint/validation filter and a per-row "View Checkpoint" action.
 *
 * Table state (filter, sort key/direction, page limit/offset) is read from `settings`
 * and written back through `updateSettings`. Workloads are fetched through `usePolling`,
 * which reruns whenever the fetch callback changes (i.e. when the trial id or any of
 * those settings change).
 */
const TrialDetailsWorkloads: React.FC<Props> = ({
  defaultMetrics,
  experiment,
  metrics,
  settings,
  trial,
  updateSettings,
}: Props) => {
  const models = useFetchModels();
  const [checkpoint, setCheckpoint] = useState<CheckpointWorkloadExtended>();
  const { checkpointModalComponents, openCheckpoint } = useCheckpointFlow({
    checkpoint,
    config: experiment.config,
    models,
    title: `Checkpoint for Batch ${checkpoint?.totalBatches}`,
  });
  const f_flat_runs = useFeature().isOn('flat_runs');

  // Filters count as applied when the metric selection differs from the defaults or
  // the table is not showing every workload.
  const hasFiltersApplied = useMemo(() => {
    const metricsApplied = !_.isEqual(metrics, defaultMetrics);
    const checkpointValidationFilterApplied = settings.filter !== TrialWorkloadFilter.All;
    return metricsApplied || checkpointValidationFilterApplied;
  }, [defaultMetrics, metrics, settings.filter]);

  // Opens the checkpoint modal for a step, tagging the checkpoint with its trial and
  // experiment ids. Steps without a checkpoint are ignored.
  const handleOpenCheckpoint = useCallback(
    (step: Step) => {
      if (trial && step.checkpoint && hasCheckpointStep(step)) {
        setCheckpoint({
          ...step.checkpoint,
          experimentId: trial.experimentId,
          trialId: trial.id,
        });
        openCheckpoint();
      }
    },
    [openCheckpoint, trial],
  );

  const columns = useMemo(() => {
    // Unused leading renderer parameters are named `_value` (not `_`) so they do not
    // shadow the lodash `_` import used in this component.
    const checkpointRenderer = (_value: string, record: Step): React.ReactNode => {
      if (trial && record.checkpoint && hasCheckpointStep(record)) {
        return (
          <Button
            aria-label="View Checkpoint"
            icon={<Icon name="checkpoint" showTooltip title="View Checkpoint" />}
            onClick={() => handleOpenCheckpoint(record)}
          />
        );
      }
      return null;
    };

    const metricRenderer = (metric: Metric) => (_value: string, record: Step) => {
      const value = extractMetricValue(record, metric);
      return <HumanReadableNumber num={value} />;
    };

    const { metric: searcherMetric, smallerIsBetter } = experiment?.config?.searcher || {};

    // Copy every column object, not just the array: `defaultColumns` is module-level and
    // shared by every instance of this component, so assigning `render`/`sortOrder` on
    // its elements in place would leak state between renders and mounted tables.
    const newColumns: typeof defaultColumns = defaultColumns.map((column) =>
      column.key === 'checkpoint' ? { ...column, render: checkpointRenderer } : { ...column },
    );

    // One sortable column per selected metric; only training/validation metrics qualify.
    const metricColumns: typeof defaultColumns = [];
    metrics.forEach((metric) => {
      if (!['validation', 'training'].includes(metric.group)) return;
      const isSearcherMetric = !!searcherMetric && searcherMetric === metric.name;
      metricColumns.push({
        // The searcher's own metric defaults to its "better" direction.
        defaultSortOrder: isSearcherMetric ? (smallerIsBetter ? 'ascend' : 'descend') : undefined,
        key: metricToKey(metric),
        render: metricRenderer(metric),
        sorter: (a, b) => {
          const aVal = extractMetricSortValue(a, metric),
            bVal = extractMetricSortValue(b, metric);
          // A row missing the metric is ordered after a row that has it. The sign flips
          // with `settings.sortDesc`, presumably because the table inverts comparator
          // results for descending order, which keeps missing values last either way.
          if (aVal === undefined && bVal !== undefined) {
            return settings.sortDesc ? -1 : 1;
          } else if (aVal !== undefined && bVal === undefined) {
            return settings.sortDesc ? 1 : -1;
          }
          return numericSorter(aVal, bVal);
        },
        title: <MetricBadgeTag metric={metric} />,
      });
    });

    // Metric columns go immediately before the "state" column, in selection order. If
    // there is no "state" column, append them instead of letting splice(-1, ...) place
    // them before the last column.
    const stateIndex = newColumns.findIndex((column) => column.key === 'state');
    newColumns.splice(stateIndex === -1 ? newColumns.length : stateIndex, 0, ...metricColumns);

    // Sorting is controlled by `settings`: only the active sort column shows an order.
    // Every object in `newColumns` is a fresh copy, so assigning here is safe.
    newColumns.forEach((column) => {
      column.sortOrder =
        column.key === settings.sortKey ? (settings.sortDesc ? 'descend' : 'ascend') : null;
    });

    return newColumns;
  }, [
    experiment?.config?.searcher,
    metrics,
    trial,
    handleOpenCheckpoint,
    settings.sortDesc,
    settings.sortKey,
  ]);

  const [workloads, setWorkloads] = useState<Loadable<WorkloadGroup[]>>(NotLoaded);
  const [workloadCount, setWorkloadCount] = useState<number>(0);

  // Fetches one page of workloads using the current filter, sort and pagination
  // settings. Failures are reported through handleError and leave the last loaded
  // state in place.
  const fetchWorkloads = useCallback(async () => {
    try {
      if (trial?.id) {
        const wl = await getTrialWorkloads({
          filter: settings.filter,
          id: trial.id,
          limit: settings.tableLimit,
          offset: settings.tableOffset,
          orderBy: settings.sortDesc ? 'ORDER_BY_DESC' : 'ORDER_BY_ASC',
          // Only metric columns map to a server-side sort key; others send none.
          sortKey: metricKeyToMetric(settings.sortKey)?.name || undefined,
        });
        setWorkloads(Loaded(wl.workloads));
        setWorkloadCount(wl.pagination.total ?? 0);
      } else {
        setWorkloadCount(0);
      }
    } catch (e) {
      handleError(e, {
        publicMessage: `Failed to load recent ${f_flat_runs ? 'run' : 'trial'} workloads.`,
        publicSubject: `Unable to fetch ${f_flat_runs ? 'run' : 'trial'} Workloads.`,
        silent: false,
        type: ErrorType.Api,
      });
    }
  }, [
    f_flat_runs,
    trial?.id,
    settings.sortDesc,
    settings.sortKey,
    settings.tableLimit,
    settings.tableOffset,
    settings.filter,
  ]);

  const { stopPolling } = usePolling(fetchWorkloads, { rerunOnNewFn: true });

  // Converts the fetched workload groups into table rows. `settings.filter` is also
  // sent to the API in fetchWorkloads; the rows are re-filtered here on the client too.
  const workloadSteps = useMemo(() => {
    const data = Loadable.getOrElse([], workloads);
    const steps = workloadsToSteps(data);
    if (settings.filter === TrialWorkloadFilter.All) return steps;
    return steps.filter((wlStep) => {
      if (settings.filter === TrialWorkloadFilter.Checkpoint) {
        return hasCheckpoint(wlStep);
      } else if (settings.filter === TrialWorkloadFilter.Validation) {
        return !!wlStep.metrics.validation;
      } else if (settings.filter === TrialWorkloadFilter.CheckpointOrValidation) {
        return !!wlStep.checkpoint || !!wlStep.metrics.validation;
      }
      return false;
    });
  }, [settings.filter, workloads]);

  // Applies the selected workload filter and jumps back to the first page. A value that
  // is not a TrialWorkloadFilter member is stored as `undefined`.
  const handleHasCheckpointOrValidationSelect = useCallback(
    (value: SelectValue): void => {
      const newFilter = value as TrialWorkloadFilter;
      const isValidFilter = Object.values(TrialWorkloadFilter).includes(newFilter);
      const filter = isValidFilter ? newFilter : undefined;
      updateSettings({ filter, tableOffset: 0 });
    },
    [updateSettings],
  );

  // Persists sort and pagination changes from the table. Multi-column sorts and sorts
  // on keys that are not among the current columns are ignored.
  const handleTableChange = useCallback(
    (
      tablePagination: TablePaginationConfig,
      _tableFilters: Record<string, FilterValue | null>,
      tableSorter: SorterResult<Step> | SorterResult<Step>[],
    ) => {
      if (Array.isArray(tableSorter)) return;

      // `tableSorter` is narrowed to a single SorterResult by the guard above.
      const { columnKey, order } = tableSorter;
      if (!columnKey || !columns.find((column) => column.key === columnKey)) return;

      updateSettings({
        sortDesc: order === 'descend',
        sortKey: columnKey as string,
        tableLimit: tablePagination.pageSize,
        tableOffset: ((tablePagination.current ?? 1) - 1) * (tablePagination.pageSize ?? 0),
      });
    },
    [columns, updateSettings],
  );

  const options = (
    <ResponsiveFilters hasFiltersApplied={hasFiltersApplied}>
      <Select label="Show" value={settings.filter} onSelect={handleHasCheckpointOrValidationSelect}>
        {Object.values(TrialWorkloadFilter).map((key) => (
          <Option key={key} value={key}>
            {key}
          </Option>
        ))}
      </Select>
    </ResponsiveFilters>
  );

  // Stop polling when the component unmounts (or the stop function changes).
  useEffect(() => {
    return () => {
      stopPolling();
    };
  }, [stopPolling]);

  return (
    <Section options={options} title="Workloads">
      <ResponsiveTable<Step>
        columns={columns}
        dataSource={workloadSteps}
        loading={Loadable.isNotLoaded(workloads)}
        pagination={getFullPaginationConfig(
          {
            limit: settings.tableLimit,
            offset: settings.tableOffset,
          },
          workloadCount,
        )}
        rowClassName={defaultRowClassName({ clickable: false })}
        rowKey="batchNum"
        scroll={{ x: 1000 }}
        showSorterTooltip={false}
        size="small"
        onChange={handleTableChange}
      />
      {checkpointModalComponents}
    </Section>
  );
};

export default TrialDetailsWorkloads;
