// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

// Utils for lowering of the StableHLO dialect to the Linalg dialect.

#ifndef IREE_COMPILER_PLUGINS_INPUT_STABLEHLO_CONVERSION_LEGALIZE_TO_LINALG_UTILS_H_
#define IREE_COMPILER_PLUGINS_INPUT_STABLEHLO_CONVERSION_LEGALIZE_TO_LINALG_UTILS_H_

#include <algorithm>
#include <numeric>
#include <optional>
#include <string>
#include <utility>

#include "compiler/plugins/input/StableHLO/Conversion/MapStableHLOToScalarOp.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "stablehlo/dialect/StablehloOps.h"

namespace mlir::iree_compiler::stablehlo {

/// Returns an ArrayAttr that contains `nLoops` attributes. All the attributes
/// are "parallel" except the last `nReduction` elements, where are "reduction"
/// attributes.
SmallVector<utils::IteratorType, 3>
getParallelAndReductionIterators(unsigned nLoops, unsigned nReduction);

/// Returns an ArrayAttr that contains `nParallelLoops` "parallel" attributes.
SmallVector<utils::IteratorType, 3>
getNParallelLoopsAttrs(unsigned nParallelLoops);

/// Generates an init sparse tensor.
Value getEmptySparseTensor(OpBuilder &b, Location loc, ShapedType type,
                           ArrayRef<Value> dynSizes);

/// Generates a tensor.empty op.
Value getEmptyTensor(OpBuilder &b, Location loc, ShapedType type,
                     ArrayRef<Value> dynSizes);

/// Generates an empty tensor for the result of the operation, which could be a
/// dense tensor or a sparse tensor.
Value getEmptyTensorFor(OpBuilder &b, Location loc, ShapedType resultType,
                        Operation *op, ValueRange operands);

/// Ensures a tensor has the same shape (not including the element type) as
/// another.
Value coerceTensorShape(OpBuilder &builder, Location loc,
                        TypedValue<ShapedType> value, ShapedType targetType);

/// Verifies |op|'s semantics by checking if all operands and results have
/// ranged tensor types.
LogicalResult verifyHloOpBufferOrTensorSemantics(Operation *op);

/// Fills |tensor| with a zero constant of the matching type. Returns the new
/// value.
Value fillTensorWithZeros(OpBuilder &builder, Location loc, Value tensor);

/// Sparsifies a (block of) operation(s) that cannot be handled directly
/// by the sparse compiler but has well-known semi-ring semantics.
///
/// This yields something of the following form:
///
///   %result = sparse_tensor.unary %values[0]
///     present={
///       ^bb1(%val):
///         ... codegen proceeds here using %val ....
///         sparse_tensor.yield
///     }
///     absent={}
///   linalg.yield %result
Value preSparsify(Operation *op, llvm::SmallVector<Value, 2> &values, Type rtp,
                  OpBuilder *b);

/// Finalizes sparse semi-ring construction.
Value postSparsify(Operation *op, Value semiring, Value result, OpBuilder *b);

/// Returns true if all operands are tensors with rank 0.
bool allOperandsAreScalarTensors(Operation *op);

/// Returns true if parent op is linalg.
bool isInBodyOfLinalgOps(Operation *op);

/// Extracts integer values from the attribute |elements|.
SmallVector<int64_t> extract1DVector(DenseIntElementsAttr elements);

/// Returns true if the given |values| is a splat of the given |queryValue|.
inline bool isSplatValue(const ArrayRef<int64_t> &values, int64_t queryValue) {
  for (auto value : values) {
    if (value != queryValue) {
      return false;
    }
  }
  return true;
}

/// Returns true if the given |attr| is a splat of the given |value|.
inline bool isSplatValue(DenseIntElementsAttr attr, uint64_t value) {
  return attr.isSplat() && attr.getSplatValue<uint64_t>() == value;
}

} // namespace mlir::iree_compiler::stablehlo

#endif // IREE_COMPILER_PLUGINS_INPUT_STABLEHLO_CONVERSION_LEGALIZE_TO_LINALG_UTILS_H_
