// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree-dialects/Dialect/Input/InputOps.h"

#include "iree-dialects/Dialect/Input/InputDialect.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TypeUtilities.h"

using namespace mlir;
using namespace mlir::iree_compiler::IREE::Input;

#include "iree-dialects/Dialect/Input/InputOpInterfaces.cpp.inc"

//===----------------------------------------------------------------------===//
// IREE::Input::TiedOpInterface
//===----------------------------------------------------------------------===//

namespace mlir::iree_compiler::IREE::Input::detail {

// Returns the operand index that result |resultIndex| of |op| is tied to, or
// std::nullopt when the result is not tied.
//
// The lookup reads the tied operand storage attribute of |op|. When |op|
// implements TiedOpInterface, |resultIndex| is first rebased onto the range
// reported by getTiedResultsIndexAndLength() and the stored value is offset by
// the start of the range reported by getTiedOperandsIndexAndLength(). For any
// other op, |resultIndex| and the stored value are used as-is.
//
// std::nullopt is returned when the attribute is missing or empty, when the
// result lies outside of the tied result range, when the attribute has no
// entry for the result, or when the entry is kUntiedIndex.
std::optional<unsigned> getTiedResultOperandIndex(Operation *op,
                                                  unsigned resultIndex) {
  auto storageAttr =
      op->getAttrOfType<ArrayAttr>(TiedOpInterface::getStorageAttrName());
  if (!storageAttr)
    return std::nullopt;
  auto valueAttrs = storageAttr.getValue();
  if (valueAttrs.empty())
    return std::nullopt;

  auto tiedOp = dyn_cast<TiedOpInterface>(op);
  if (tiedOp) {
    auto indexAndLength = tiedOp.getTiedResultsIndexAndLength();
    if (resultIndex < indexAndLength.first)
      return std::nullopt;
    resultIndex -= indexAndLength.first;
    if (resultIndex >= indexAndLength.second)
      return std::nullopt;
  }

  // The attribute may hold fewer entries than the op has results (malformed
  // IR, or an op that does not implement the interface); never index past the
  // end of the stored array.
  if (resultIndex >= valueAttrs.size())
    return std::nullopt;

  int64_t value = llvm::cast<IntegerAttr>(valueAttrs[resultIndex]).getInt();
  if (value == TiedOpInterface::kUntiedIndex)
    return std::nullopt;

  unsigned operandIndex = static_cast<unsigned>(value);
  if (tiedOp)
    operandIndex += tiedOp.getTiedOperandsIndexAndLength().first;
  return operandIndex;
}

// Returns the tied operand index of every result in the tied result range of
// |op|, or an empty vector when the storage attribute is missing or empty.
//
// The returned vector always has getTiedResultsIndexAndLength().second
// entries. Each entry is either kUntiedIndex or the stored index offset by the
// start of the range reported by getTiedOperandsIndexAndLength().
//
// |op| must implement TiedOpInterface whenever the storage attribute is
// present and non-empty (enforced by cast<>).
SmallVector<int64_t> getTiedResultOperandIndices(Operation *op) {
  SmallVector<int64_t> indices;
  auto storageAttr =
      op->getAttrOfType<ArrayAttr>(TiedOpInterface::getStorageAttrName());
  if (!storageAttr)
    return indices;
  auto valueAttrs = storageAttr.getValue();
  if (valueAttrs.empty())
    return indices;
  auto tiedOp = cast<TiedOpInterface>(op);
  auto resultRange = tiedOp.getTiedResultsIndexAndLength();
  unsigned tiedOperandsOffset = tiedOp.getTiedOperandsIndexAndLength().first;

  // Results without a stored entry are reported as untied rather than being
  // left zero-initialized, which would read as "tied to operand 0".
  const size_t resultCount = resultRange.second;
  indices.resize(resultCount, TiedOpInterface::kUntiedIndex);

  // Only visit entries that have a slot in |indices|; a storage attribute
  // longer than the result range must not write past the end of the vector.
  const size_t storedCount =
      valueAttrs.size() < resultCount ? valueAttrs.size() : resultCount;
  for (size_t i = 0; i < storedCount; ++i) {
    int64_t index = llvm::cast<IntegerAttr>(valueAttrs[i]).getInt();
    indices[i] = index != TiedOpInterface::kUntiedIndex
                     ? tiedOperandsOffset + index
                     : TiedOpInterface::kUntiedIndex;
  }
  return indices;
}

// Updates the tied operand storage attribute of |op| so that result
// |resultIndex| is tied to |operandIndex|, or untied when |operandIndex| is
// std::nullopt.
//
// |resultIndex| is an index into the op's full result list and is rebased
// onto the range reported by getTiedResultsIndexAndLength(). The whole
// attribute is rewritten; when none existed yet, every other result in the
// range is stored as kUntiedIndex.
//
// NOTE(review): |operandIndex| is stored exactly as given, without
// subtracting the tied operand range offset that is removed from the existing
// entries below - confirm that callers pass a range-relative index.
void setTiedResultOperandIndex(Operation *op, unsigned resultIndex,
                               std::optional<unsigned> operandIndex) {
  auto tiedOp = cast<TiedOpInterface>(op);
  auto resultRange = tiedOp.getTiedResultsIndexAndLength();
  resultIndex -= resultRange.first;

  auto indices = getTiedResultOperandIndices(op);
  if (indices.empty()) {
    indices.resize(resultRange.second, TiedOpInterface::kUntiedIndex);
  } else {
    // getTiedResultOperandIndices() returns indices into the full operand
    // list of the op, but the attribute stores indices relative to the range
    // returned by `getTiedOperandsIndexAndLength`.
    unsigned tiedOperandsOffset = tiedOp.getTiedOperandsIndexAndLength().first;
    for (auto &index : indices) {
      if (index != TiedOpInterface::kUntiedIndex)
        index -= tiedOperandsOffset;
    }
  }

  // Select the sentinel explicitly: std::optional<unsigned>::value_or returns
  // `unsigned`, so passing kUntiedIndex through it would convert the int64_t
  // sentinel to unsigned before it is stored.
  indices[resultIndex] = operandIndex.has_value()
                             ? static_cast<int64_t>(operandIndex.value())
                             : TiedOpInterface::kUntiedIndex;
  op->setAttr(TiedOpInterface::getStorageAttrName(),
              Builder(op).getIndexArrayAttr(indices));
}

// Returns true if any result of |op| is tied to the operand at
// |operandIndex|. Ops that do not implement TiedOpInterface never have tied
// operands.
bool isOperandTied(Operation *op, unsigned operandIndex) {
  if (auto tiedOp = dyn_cast<TiedOpInterface>(op)) {
    return llvm::any_of(tiedOp.getTiedResultOperandIndices(),
                        [operandIndex](int64_t tiedIndex) {
                          return tiedIndex == operandIndex;
                        });
  }
  return false;
}

// Collects every result of |op| that is tied to the operand at
// |operandIndex|. Returns an empty list when |op| does not implement
// TiedOpInterface or when no result is tied to that operand.
SmallVector<Value> getOperandTiedResults(Operation *op, unsigned operandIndex) {
  SmallVector<Value> tiedResults;
  auto tiedOp = dyn_cast<TiedOpInterface>(op);
  if (!tiedOp)
    return tiedResults;

  // Entry N of the tied indices describes result (firstTiedResult + N).
  const unsigned firstTiedResult = tiedOp.getTiedResultsIndexAndLength().first;
  auto tiedIndices = tiedOp.getTiedResultOperandIndices();
  for (auto entry : llvm::enumerate(tiedIndices)) {
    if (entry.value() != operandIndex)
      continue;
    tiedResults.push_back(op->getResult(firstTiedResult + entry.index()));
  }
  return tiedResults;
}

// Verifies that |tiedOp| reports either no tied operand indices at all or
// exactly one per result in its tied result range.
LogicalResult verifyTiedOp(TiedOpInterface tiedOp) {
  auto tiedIndices = tiedOp.getTiedResultOperandIndices();
  if (!tiedIndices.empty()) {
    auto tiedResults = tiedOp.getTiedResultsIndexAndLength();
    if (tiedIndices.size() != tiedResults.second)
      return tiedOp.emitError("op results/tied operand indices mismatch");
  }
  return success();
}

} // namespace mlir::iree_compiler::IREE::Input::detail

// Walks up the chain of tied results starting at |derivedValue| and returns
// the first value that is not itself a tied result: one that has no defining
// op, whose defining op does not implement TiedOpInterface, or for which the
// defining op reports no tied operand.
Value TiedOpInterface::findTiedBaseValue(Value derivedValue) {
  Value current = derivedValue;
  while (true) {
    auto producer = dyn_cast_or_null<TiedOpInterface>(current.getDefiningOp());
    if (!producer)
      return current;
    auto tiedOperand = producer.getTiedResultOperand(current);
    if (!tiedOperand)
      return current;
    current = tiedOperand;
  }
}

// Returns true if at least one use of |value| is an operand that its owning
// op (a TiedOpInterface implementation) reports as tied to a result.
bool TiedOpInterface::hasAnyTiedUses(Value value) {
  for (OpOperand &use : value.getUses()) {
    auto tiedUser = dyn_cast<TiedOpInterface>(use.getOwner());
    if (tiedUser && tiedUser.isOperandTied(use.getOperandNumber()))
      return true;
  }
  return false;
}

//===----------------------------------------------------------------------===//
// custom<SymbolVisibility>($sym_visibility)
//===----------------------------------------------------------------------===//
// some.op custom<SymbolVisibility>($sym_visibility) $sym_name
// ->
// some.op @foo
// some.op private @foo

// Parses an optional `public`/`private`/`nested` keyword into
// |symVisibilityAttr|. The attribute is left untouched when no keyword is
// present; parsing never fails.
static ParseResult parseSymbolVisibility(OpAsmParser &parser,
                                         StringAttr &symVisibilityAttr) {
  StringRef keyword;
  const bool sawVisibility = succeeded(
      parser.parseOptionalKeyword(&keyword, {"public", "private", "nested"}));
  if (sawVisibility)
    symVisibilityAttr = parser.getBuilder().getStringAttr(keyword);
  return success();
}

// Prints the symbol visibility keyword, falling back to `public` when no
// visibility attribute is set.
static void printSymbolVisibility(OpAsmPrinter &p, Operation *op,
                                  StringAttr symVisibilityAttr) {
  if (symVisibilityAttr) {
    p << symVisibilityAttr.getValue();
    return;
  }
  p << "public";
}

//===----------------------------------------------------------------------===//
// custom<TypeOrAttr>($type, $attr)
//===----------------------------------------------------------------------===//
// some.op custom<TypeOrAttr>($type, $attr)
// ->
// some.op : i32
// some.op = 42 : i32
// some.op : i32 = 42 : index

// Parses one of:
//   `=` attr              - the type is taken from the attribute
//   `:` type              - only a type, no attribute
//   `:` type `=` attr     - explicit type followed by an attribute
static ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
                                   TypedAttr &attr) {
  // Shared by both forms that carry an attribute.
  auto parseValueAttr = [&]() -> ParseResult {
    if (succeeded(parser.parseAttribute(attr)))
      return success();
    return parser.emitError(parser.getCurrentLocation())
           << "expected attribute";
  };

  // `= attr`: no explicit type, infer it from the attribute.
  if (succeeded(parser.parseOptionalEqual())) {
    if (failed(parseValueAttr()))
      return failure();
    typeAttr = TypeAttr::get(attr.getType());
    return success();
  }

  // `: type`, optionally followed by `= attr`.
  Type declaredType;
  if (failed(parser.parseColonType(declaredType)))
    return parser.emitError(parser.getCurrentLocation()) << "expected type";
  typeAttr = TypeAttr::get(declaredType);

  if (succeeded(parser.parseOptionalEqual()))
    return parseValueAttr();
  return success();
}

// Prints ` : type` unless the attribute is present and already carries that
// exact type, then ` = attr` when an attribute is present.
static void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
                            TypedAttr attr) {
  const bool typeImpliedByAttr = attr && attr.getType() == type.getValue();
  if (!typeImpliedByAttr) {
    p << " : ";
    p.printAttribute(type);
  }
  if (!attr)
    return;
  p << " = ";
  p.printAttribute(attr);
}

//===----------------------------------------------------------------------===//
// custom<ShapedTiedResult>
//===----------------------------------------------------------------------===//
// type{%dim0, %dim1}
// %arg0 as type{%dim0}

// Parses a single result of the form `[%operand as] type[{%dims...}]`.
//
// When a leading operand is present the result is recorded as tied to operand
// index 0; otherwise it is recorded as untied. Dynamic dimension operands are
// required, in braces, when the parsed type is a shaped type without a static
// shape.
static ParseResult parseShapedTiedResult(
    OpAsmParser &parser, Type &resultType,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &resultDims,
    ArrayAttr &tiedOperands) {
  OpAsmParser::UnresolvedOperand tiedOperand;
  auto maybeTied = parser.parseOptionalOperand(tiedOperand);
  const bool isTied = maybeTied.has_value() && succeeded(maybeTied.value());
  if (isTied && failed(parser.parseKeyword("as")))
    return failure();
  const int64_t tiedOperandIndex = isTied ? 0 : TiedOpInterface::kUntiedIndex;

  if (failed(parser.parseType(resultType)))
    return failure();

  auto shapedType = dyn_cast<ShapedType>(resultType);
  if (shapedType && !shapedType.hasStaticShape()) {
    // Parse into a scratch list so that the operand count check applies to
    // this result only.
    SmallVector<OpAsmParser::UnresolvedOperand> parsedDims;
    if (failed(parser.parseLBrace()) ||
        failed(parser.parseOperandList(parsedDims,
                                       shapedType.getNumDynamicDims(),
                                       OpAsmParser::Delimiter::None)) ||
        failed(parser.parseRBrace())) {
      return failure();
    }
    resultDims.append(parsedDims);
  }

  tiedOperands = parser.getBuilder().getIndexArrayAttr({tiedOperandIndex});
  return success();
}

// Overload for ops that do not store a tied operands attribute: parses the
// same syntax and drops the resulting attribute.
static ParseResult parseShapedTiedResult(
    OpAsmParser &parser, Type &resultType,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &resultDims) {
  ArrayAttr discardedTiedOperands;
  return parseShapedTiedResult(parser, resultType, resultDims,
                               discardedTiedOperands);
}

// Prints a single result as `[%operand as] type[{%dims...}]`.
//
// The `%operand as` prefix is printed when result 0 of |op| is tied to an
// operand. The dimension list is printed when |resultType| is a shaped type
// without a static shape; `{<<INVALID>>}` is printed in its place when no
// dimension values are available.
void printShapedTiedResult(OpAsmPrinter &p, TiedOpInterface op, Type resultType,
                           ValueRange resultDims) {
  if (auto tiedOperandIndex = op.getTiedResultOperandIndex(0)) {
    p.printOperand(op->getOperand(*tiedOperandIndex));
    p << " as ";
  }
  p.printType(resultType);

  auto shapedType = dyn_cast<ShapedType>(resultType);
  if (!shapedType || shapedType.hasStaticShape())
    return;
  if (resultDims.empty()) {
    p << "{<<INVALID>>}";
    return;
  }
  p << "{";
  llvm::interleaveComma(resultDims.take_front(shapedType.getNumDynamicDims()),
                        p, [&](Value dim) { p.printOperand(dim); });
  p << "}";
}

// Overload matching the parser signature that carries a tied operands
// attribute. The attribute is not consulted here; tied information is read
// from |op| by the overload this forwards to.
static void printShapedTiedResult(OpAsmPrinter &p, TiedOpInterface op,
                                  Type resultType, ValueRange resultDims,
                                  ArrayAttr tiedOperands) {
  (void)tiedOperands;
  printShapedTiedResult(p, op, resultType, resultDims);
}

//===----------------------------------------------------------------------===//
// custom<ShapedOperandList>($values, type($values), $value_dims)
//===----------------------------------------------------------------------===//
// %value : type{%dynamic_dims}, ...

// Parses a non-empty comma-separated list of `%value : type[{%dims...}]`.
//
// One entry is appended to |values| and |valueTypes| per list element. For
// each shaped type with dynamic dimensions, exactly that many dimension
// operands are required in braces and are appended to |valueDims| in list
// order. Types that are not shaped types carry no dimension list.
ParseResult parseShapedOperandList(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
    SmallVectorImpl<Type> &valueTypes,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &valueDims) {
  do {
    values.emplace_back();
    valueTypes.emplace_back();
    if (failed(parser.parseOperand(values.back())) ||
        failed(parser.parseColon()) ||
        failed(parser.parseType(valueTypes.back())))
      return failure();

    // The type comes straight from user input: do not assert that it is
    // shaped, just skip the dimension list like the other shaped-type parsers
    // in this file do.
    auto shapedType = dyn_cast<ShapedType>(valueTypes.back());
    if (!shapedType)
      continue;
    int64_t dynamicDimCount = shapedType.getNumDynamicDims();
    if (dynamicDimCount == 0)
      continue;

    // Parse into a scratch list: the required operand count passed to
    // parseOperandList describes this value only, while |valueDims| may
    // already hold the dimensions of earlier values in the list.
    SmallVector<OpAsmParser::UnresolvedOperand> dynamicDims;
    if (failed(parser.parseOperandList(dynamicDims, dynamicDimCount,
                                       AsmParser::Delimiter::Braces)))
      return failure();
    valueDims.append(dynamicDims);
  } while (succeeded(parser.parseOptionalComma()));
  return success();
}

// Prints a comma-separated list of `%value : type[{%dims...}]`, consuming
// dimension values from the front of |valueDims| for every type that has
// dynamic dimensions. Every entry of |valueTypes| must be a shaped type.
void printShapedOperandList(OpAsmPrinter &p, Operation *op, ValueRange values,
                            TypeRange valueTypes, ValueRange valueDims) {
  auto printOne = [&](auto entry) {
    auto [value, valueType] = entry;
    p << value << " : " << valueType;
    const int64_t dynamicDimCount =
        cast<ShapedType>(valueType).getNumDynamicDims();
    if (dynamicDimCount == 0)
      return;
    p << "{";
    llvm::interleaveComma(valueDims.take_front(dynamicDimCount), p);
    valueDims = valueDims.drop_front(dynamicDimCount);
    p << "}";
  };
  llvm::interleaveComma(llvm::zip_equal(values, valueTypes), p, printOne);
}

//===----------------------------------------------------------------------===//
// custom<ShapedFunctionType>
//===----------------------------------------------------------------------===//
// (type, type{%dim0, %dim1}, type) -> (type{%dim2}, %operand4)

// Parses a non-empty comma-separated list of `type[{%dims...}]`, appending
// each type to |types| and the dimension operands of every shaped type
// without a static shape to |dims|.
static ParseResult
parseShapedTypeList(OpAsmParser &parser, SmallVectorImpl<Type> &types,
                    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dims) {
  do {
    Type parsedType;
    if (failed(parser.parseType(parsedType)))
      return failure();

    auto shapedType = dyn_cast<ShapedType>(parsedType);
    if (shapedType && !shapedType.hasStaticShape()) {
      // Scratch list so that the operand count check applies to this type
      // only.
      SmallVector<OpAsmParser::UnresolvedOperand> parsedDims;
      if (failed(parser.parseLBrace()) ||
          failed(parser.parseOperandList(parsedDims,
                                         shapedType.getNumDynamicDims(),
                                         OpAsmParser::Delimiter::None)) ||
          failed(parser.parseRBrace())) {
        return failure();
      }
      dims.append(parsedDims);
    }

    types.push_back(parsedType);
  } while (succeeded(parser.parseOptionalComma()));
  return success();
}

// Returns the position of the first entry in |operands| that has the same
// SSA name and result number as |tiedResult|, or
// TiedOpInterface::kUntiedIndex when there is no such entry.
static int64_t
findTiedOperand(OpAsmParser::UnresolvedOperand tiedResult,
                ArrayRef<OpAsmParser::UnresolvedOperand> operands) {
  for (auto entry : llvm::enumerate(operands)) {
    const OpAsmParser::UnresolvedOperand &candidate = entry.value();
    if (candidate.name == tiedResult.name &&
        candidate.number == tiedResult.number)
      return static_cast<int64_t>(entry.index());
  }
  return TiedOpInterface::kUntiedIndex;
}

// Parses a non-empty comma-separated list of results, each one of:
//   type[{%dims...}]               - untied result
//   %operand[{%dims...}]           - tied result reusing the operand's type
//   %operand as type[{%dims...}]   - tied result with an explicit type
//
// Tied results must name one of |operands|; the matching position is recorded
// in |tiedOperands| (kUntiedIndex for untied results), one entry per result.
// Dimension operands are required in braces for every shaped result type
// without a static shape and are appended to |resultDims|.
static ParseResult parseShapedResultList(
    OpAsmParser &parser, ArrayRef<OpAsmParser::UnresolvedOperand> operands,
    TypeRange operandTypes,
    ArrayRef<OpAsmParser::UnresolvedOperand> operandDims,
    SmallVectorImpl<Type> &resultTypes,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &resultDims,
    ArrayAttr &tiedOperands) {
  SmallVector<int64_t> tiedOperandIndices;
  do {
    OpAsmParser::UnresolvedOperand tiedResult;
    auto res = parser.parseOptionalOperand(tiedResult);
    Type type;
    int64_t tiedOperandIndex = TiedOpInterface::kUntiedIndex;
    if (res.has_value() && succeeded(res.value())) {
      tiedOperandIndex = findTiedOperand(tiedResult, operands);
      if (tiedOperandIndex == TiedOpInterface::kUntiedIndex) {
        return parser.emitError(tiedResult.location,
                                "tied operand not found for result reference ")
               << tiedResult.name;
      }
      if (succeeded(parser.parseOptionalKeyword("as"))) {
        // Type _may_ differ from the operand.
        if (failed(parser.parseType(type)))
          return failure();
      } else {
        // Use the operand's type. The operand and operand type lists are
        // parsed independently, so the type list may be shorter than the
        // operand list in malformed input; diagnose that instead of indexing
        // out of range.
        if (static_cast<size_t>(tiedOperandIndex) >= operandTypes.size()) {
          return parser.emitError(tiedResult.location,
                                  "no operand type available for tied result "
                                  "reference ")
                 << tiedResult.name;
        }
        type = operandTypes[tiedOperandIndex];
      }
    } else if (failed(parser.parseType(type))) {
      return failure();
    }
    if (auto shapedType = dyn_cast<ShapedType>(type)) {
      if (!shapedType.hasStaticShape()) {
        // Scratch list so that the operand count check applies to this
        // result only.
        SmallVector<OpAsmParser::UnresolvedOperand> dynamicDims;
        if (failed(parser.parseLBrace()) ||
            failed(parser.parseOperandList(dynamicDims,
                                           shapedType.getNumDynamicDims(),
                                           OpAsmParser::Delimiter::None)) ||
            failed(parser.parseRBrace())) {
          return failure();
        }
        resultDims.append(dynamicDims);
      }
    }
    resultTypes.push_back(type);
    tiedOperandIndices.push_back(tiedOperandIndex);
  } while (succeeded(parser.parseOptionalComma()));
  if (!tiedOperandIndices.empty()) {
    tiedOperands = parser.getBuilder().getIndexArrayAttr(tiedOperandIndices);
  }
  return success();
}

// Prints a comma-separated result list in the syntax accepted by
// parseShapedResultList.
//
// A tied result is printed as its operand; the type follows (after ` as `)
// only when it differs from the operand's type. Shaped result types without a
// static shape are followed by their dimension values taken from the front of
// |resultDims|. If no dimension values remain, `{<<INVALID>>}` is printed and
// printing of the list stops.
static void printShapedResultList(OpAsmPrinter &p, TiedOpInterface tiedOp,
                                  ValueRange operands, TypeRange operandTypes,
                                  ValueRange operandDims, TypeRange resultTypes,
                                  ValueRange resultDims,
                                  ArrayAttr tiedOperands) {
  const unsigned resultCount = resultTypes.size();
  for (unsigned resultIndex = 0; resultIndex < resultCount; ++resultIndex) {
    if (resultIndex > 0)
      p << ", ";

    Type resultType = resultTypes[resultIndex];
    bool typeElided = false;
    if (auto tiedOperandIndex = tiedOp.getTiedResultOperandIndex(resultIndex)) {
      Value tiedOperand = tiedOp->getOperand(*tiedOperandIndex);
      p.printOperand(tiedOperand);
      // The type is only spelled out when it differs from the operand's.
      typeElided = tiedOperand.getType() == resultType;
      if (!typeElided)
        p << " as ";
    }
    if (!typeElided)
      p.printType(resultType);

    auto shapedType = dyn_cast<ShapedType>(resultType);
    if (!shapedType || shapedType.hasStaticShape())
      continue;
    if (resultDims.empty()) {
      p << "{<<INVALID>>}";
      return;
    }
    const int64_t dynamicDimCount = shapedType.getNumDynamicDims();
    p << "{";
    llvm::interleaveComma(resultDims.take_front(dynamicDimCount), p,
                          [&](Value dim) { p.printOperand(dim); });
    p << "}";
    resultDims = resultDims.drop_front(dynamicDimCount);
  }
}

// Parses `(operand-types) -> results` where the operand types form a possibly
// empty shaped type list and the results are one of:
//   `()`              - no results
//   `(` results `)`   - one or more results
//   result            - results with the parentheses omitted
static ParseResult parseShapedFunctionType(
    OpAsmParser &parser, ArrayRef<OpAsmParser::UnresolvedOperand> operands,
    SmallVectorImpl<Type> &operandTypes,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operandDims,
    SmallVectorImpl<Type> &resultTypes,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &resultDims,
    ArrayAttr &tiedOperands) {
  // Operand side: `(` [shaped-type-list] `)`.
  if (failed(parser.parseLParen()))
    return failure();
  const bool hasOperandTypes = failed(parser.parseOptionalRParen());
  if (hasOperandTypes &&
      (failed(parseShapedTypeList(parser, operandTypes, operandDims)) ||
       failed(parser.parseRParen()))) {
    return failure();
  }

  if (failed(parser.parseArrow()))
    return failure();

  // Result side.
  const bool parenthesized = succeeded(parser.parseOptionalLParen());
  if (parenthesized && succeeded(parser.parseOptionalRParen())) {
    // Empty list/no results `()`.
    return success();
  }
  if (failed(parseShapedResultList(parser, operands, operandTypes, operandDims,
                                   resultTypes, resultDims, tiedOperands))) {
    return failure();
  }
  if (parenthesized && failed(parser.parseRParen()))
    return failure();
  return success();
}

// Prints `(operand-types) -> results` in the syntax accepted by
// parseShapedFunctionType. The result list is wrapped in parentheses unless
// there is exactly one result.
static void printShapedFunctionType(OpAsmPrinter &p, TiedOpInterface tiedOp,
                                    ValueRange operands, TypeRange operandTypes,
                                    OperandRange operandDims,
                                    TypeRange resultTypes,
                                    OperandRange resultDims,
                                    ArrayAttr tiedOperands) {
  p << "(";
  bool isFirstOperandType = true;
  for (Type operandType : operandTypes) {
    if (!isFirstOperandType)
      p << ", ";
    isFirstOperandType = false;

    p.printType(operandType);
    auto shapedType = dyn_cast<ShapedType>(operandType);
    if (!shapedType || shapedType.hasStaticShape())
      continue;
    if (operandDims.empty()) {
      // Keep going: only this operand type is missing its dimensions.
      p << "{<<INVALID>>}";
      continue;
    }
    const int64_t dynamicDimCount = shapedType.getNumDynamicDims();
    p << "{";
    llvm::interleaveComma(operandDims.take_front(dynamicDimCount), p,
                          [&](Value dim) { p.printOperand(dim); });
    p << "}";
    operandDims = operandDims.drop_front(dynamicDimCount);
  }
  p << ") -> ";

  const bool wrapResults = resultTypes.size() != 1;
  if (wrapResults)
    p << "(";
  printShapedResultList(p, tiedOp, operands, operandTypes, operandDims,
                        resultTypes, resultDims, tiedOperands);
  if (wrapResults)
    p << ")";
}

//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//

// Populates |result| for a global named |name| of the given |type|.
// `is_mutable` is only attached for mutable globals and `initial_value` only
// when an initial value is supplied; |attrs| are appended last.
void GlobalOp::build(OpBuilder &builder, OperationState &result, StringRef name,
                     bool isMutable, Type type,
                     std::optional<TypedAttr> initialValue,
                     ArrayRef<NamedAttribute> attrs) {
  StringAttr symbolName = builder.getStringAttr(name);
  result.addAttribute(SymbolTable::getSymbolAttrName(), symbolName);

  if (isMutable)
    result.addAttribute("is_mutable", builder.getUnitAttr());
  if (initialValue)
    result.addAttribute("initial_value", *initialValue);

  result.addAttribute("type", TypeAttr::get(type));
  result.attributes.append(attrs.begin(), attrs.end());
}

// Convenience builder for a global without an initial value.
void GlobalOp::build(OpBuilder &builder, OperationState &result, StringRef name,
                     bool isMutable, Type type,
                     ArrayRef<NamedAttribute> attrs) {
  build(builder, result, name, isMutable, type, /*initialValue=*/std::nullopt,
        attrs);
}

// Returns true if a global of |globalType| may be accessed as |accessType|.
//
// Two shaped types are compatible when their shapes are compatible and their
// element types are equal (for example a tensor<?xf32> global accessed as
// tensor<4xf32>). Two non-shaped types are always treated as compatible. A
// shaped type is never compatible with a non-shaped one.
static bool isGlobalTypeCompatible(Type globalType, Type accessType) {
  auto globalShaped = dyn_cast<ShapedType>(globalType);
  auto accessShaped = dyn_cast<ShapedType>(accessType);

  if (globalShaped && accessShaped) {
    if (failed(mlir::verifyCompatibleShape(globalType, accessType)))
      return false;
    return globalShaped.getElementType() == accessShaped.getElementType();
  }

  // Permissive for everything else, as long as neither side is shaped.
  return !globalShaped && !accessShaped;
}

// Verifies that the referenced global exists and that its type is compatible
// with the type of the loaded result.
LogicalResult
GlobalLoadOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto globalOp =
      symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getGlobalAttr());
  if (!globalOp)
    return emitOpError() << "undefined global: " << getGlobal();

  auto loadType = getResult().getType();
  if (isGlobalTypeCompatible(globalOp.getType(), loadType))
    return success();

  return emitOpError() << "global type mismatch; global " << getGlobal()
                       << " is " << globalOp.getType() << " but load is "
                       << loadType;
}

// Verifies that the target type of the global pointer operand is compatible
// with the type of the loaded result.
LogicalResult GlobalLoadIndirectOp::verify() {
  auto pointerType = cast<PtrType>(getGlobal().getType());
  auto globalType = pointerType.getTargetType();
  auto loadType = getResult().getType();
  if (isGlobalTypeCompatible(globalType, loadType))
    return success();
  return emitOpError() << "global type mismatch; global pointer is "
                       << globalType << " but load is " << loadType;
}

// Verifies that the referenced global exists and that its type is compatible
// with the type of the stored value.
LogicalResult
GlobalStoreOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto globalOp =
      symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getGlobalAttr());
  if (!globalOp)
    return emitOpError() << "undefined global: " << getGlobal();

  auto storeType = getValue().getType();
  if (isGlobalTypeCompatible(globalOp.getType(), storeType))
    return success();

  return emitOpError() << "global type mismatch; global " << getGlobal()
                       << " is " << globalOp.getType() << " but store is "
                       << storeType;
}

// Verifies that the target type of the global pointer operand is compatible
// with the type of the stored value.
LogicalResult GlobalStoreIndirectOp::verify() {
  auto pointerType = cast<PtrType>(getGlobal().getType());
  auto globalType = pointerType.getTargetType();
  auto storeType = getValue().getType();
  if (isGlobalTypeCompatible(globalType, storeType))
    return success();
  return emitOpError() << "global type mismatch; global pointer is "
                       << globalType << " but store is " << storeType;
}

//===----------------------------------------------------------------------===//
// !iree_input.byte_buffer
//===----------------------------------------------------------------------===//

// Names the result after the op's optional name, falling back to `bytes_cst`
// when no name is set.
void ByteBufferConstantOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  auto explicitName = getName();
  setNameFn(getResult(), explicitName.value_or("bytes_cst"));
}

// Verifies the optional alignment attribute: a positive alignment must be a
// power of two. A missing attribute and non-positive values are accepted.
LogicalResult ByteBufferConstantOp::verify() {
  auto alignmentAttr = getAlignmentAttr();
  if (!alignmentAttr)
    return success();

  const int64_t alignment = alignmentAttr.getInt();
  const bool isAccepted = alignment <= 0 || llvm::isPowerOf2_64(alignment);
  if (!isAccepted)
    return emitOpError("invalid alignment; must be a power of two");
  return success();
}

//===----------------------------------------------------------------------===//
// iree_input.buffer.subspan
//===----------------------------------------------------------------------===//

// Suggests `buffer` as the printed SSA name of the result.
void BufferSubspanOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  const StringRef resultName = "buffer";
  setNameFn(getResult(), resultName);
}

//===----------------------------------------------------------------------===//
// iree_input.buffer_view.create
//===----------------------------------------------------------------------===//

// Convenience builder taking the element and encoding types as integers. Both
// are materialized as 32-bit integer constants (or folded equivalents) at
// |state.location| and forwarded to the Value-based builder.
void BufferViewCreateOp::build(OpBuilder &builder, OperationState &state,
                               Value sourceBuffer, Value sourceOffset,
                               Value sourceLength, int32_t elementType,
                               int32_t encodingType, ValueRange shape) {
  auto makeI32Constant = [&](int32_t value) -> Value {
    return builder.createOrFold<arith::ConstantIntOp>(state.location, value,
                                                      32);
  };
  Value elementTypeValue = makeI32Constant(elementType);
  Value encodingTypeValue = makeI32Constant(encodingType);
  build(builder, state, sourceBuffer, sourceOffset, sourceLength,
        elementTypeValue, encodingTypeValue, shape);
}

// Builds a buffer view from the source buffer range, the element and encoding
// type values and the |shape| values. The operands are added in exactly that
// order and the single result is a BufferViewType.
void BufferViewCreateOp::build(OpBuilder &builder, OperationState &state,
                               Value sourceBuffer, Value sourceOffset,
                               Value sourceLength, Value elementType,
                               Value encodingType, ValueRange shape) {
  SmallVector<Value> allOperands = {sourceBuffer, sourceOffset, sourceLength,
                                    elementType, encodingType};
  allOperands.append(shape.begin(), shape.end());
  state.addOperands(allOperands);
  state.addTypes({BufferViewType::get(builder.getContext())});
}

//===----------------------------------------------------------------------===//
// iree_input.tensor.update
//===----------------------------------------------------------------------===//

// Returns the base value the $target operand is derived from; |resultIndex|
// is not consulted.
Value TensorUpdateOp::getTiedResult(unsigned resultIndex) {
  Value target = getTarget();
  return TiedOpInterface::findTiedBaseValue(target);
}

// Always reports operand 0 ($target), whatever |resultIndex| is.
std::optional<unsigned>
TensorUpdateOp::getTiedResultOperandIndex(unsigned resultIndex) {
  constexpr unsigned kTargetOperandIndex = 0; // $target
  return kTargetOperandIndex;
}

// Returns a single entry naming operand 0 ($target).
SmallVector<int64_t> TensorUpdateOp::getTiedResultOperandIndices() {
  SmallVector<int64_t> tiedIndices;
  tiedIndices.push_back(0); // $target
  return tiedIndices;
}

//===----------------------------------------------------------------------===//
// iree_input.tensor.trace
//===----------------------------------------------------------------------===//

// Convenience builder that derives the dynamic dimension operands itself: a
// tensor.dim (or its folded value) is created at |state.location| for every
// dynamic dimension of every value, in value order and then dimension order.
// Every value must have a shaped type.
void TensorTraceOp::build(OpBuilder &builder, OperationState &state,
                          StringRef key, ValueRange values) {
  SmallVector<Value> dynamicDims;
  for (Value value : values) {
    auto shapedType = cast<ShapedType>(value.getType());
    for (unsigned dim = 0; dim < shapedType.getRank(); ++dim) {
      if (!shapedType.isDynamicDim(dim))
        continue;
      Value dimValue =
          builder.createOrFold<tensor::DimOp>(state.location, value, dim);
      dynamicDims.push_back(dimValue);
    }
  }
  build(builder, state, key, values, dynamicDims);
}

//===----------------------------------------------------------------------===//
// iree_input.dispatch
//===----------------------------------------------------------------------===//

// Builds a dispatch of |exportOp|.
//
// The entry point is a nested symbol reference made of the symbol name of the
// export's parent op followed by the export itself. Operands are added as four
// segments - workload, operands, operand dims, result dims - and the segment
// size attribute is written to match. Any tied operand or segment size
// attribute present in |attributes| is replaced by the values computed here.
void DispatchOp::build(OpBuilder &builder, OperationState &state,
                       ExecutableExportOp exportOp, ValueRange workload,
                       TypeRange resultTypes, ValueRange resultDims,
                       ValueRange operands, ValueRange operandDims,
                       ArrayAttr tiedOperands,
                       ArrayRef<NamedAttribute> attributes) {
  // Entry point: @<parent symbol>::@<export>.
  Operation *executableOp = exportOp->getParentOp();
  StringRef executableOpSymName =
      executableOp->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
          .getValue();
  state.addAttribute("entry_point",
                     SymbolRefAttr::get(builder.getContext(),
                                        executableOpSymName,
                                        {SymbolRefAttr::get(exportOp)}));

  // Operand segments, in the order recorded in the segment sizes below.
  state.addOperands(workload);
  state.addOperands(operands);
  state.addOperands(operandDims);
  state.addOperands(resultDims);
  state.addTypes(resultTypes);

  // Caller-provided attributes first, so that the computed ones take over.
  state.addAttributes(attributes);

  const StringRef tiedAttrName = TiedOpInterface::getStorageAttrName();
  state.attributes.erase(tiedAttrName);
  state.addAttribute(tiedAttrName, tiedOperands);

  const StringRef segmentSizesAttrName = getOperandSegmentSizeAttr();
  state.attributes.erase(segmentSizesAttrName);
  state.addAttribute(segmentSizesAttrName,
                     builder.getDenseI32ArrayAttr({
                         static_cast<int32_t>(workload.size()),
                         static_cast<int32_t>(operands.size()),
                         static_cast<int32_t>(operandDims.size()),
                         static_cast<int32_t>(resultDims.size()),
                     }));
}

// Verifies that the entry point symbol resolves to an executable export.
LogicalResult DispatchOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  Operation *op = getOperation();
  auto exportOp = symbolTable.lookupNearestSymbolFrom<ExecutableExportOp>(
      op, getEntryPoint());
  if (exportOp) {
    // TODO(ezhulenev): verify that the exported function has matching
    // operands.
    return success();
  }
  return op->emitOpError() << "undefined entry point: " << getEntryPoint();
}

// Tied operand indices are relative to ODS operand group 1 ($arguments).
std::pair<unsigned, unsigned> DispatchOp::getTiedOperandsIndexAndLength() {
  constexpr unsigned kArgumentsGroupIndex = 1; // $arguments
  return getODSOperandIndexAndLength(kArgumentsGroupIndex);
}

//===----------------------------------------------------------------------===//
// iree_input.optimization_barrier
//===----------------------------------------------------------------------===//

// Builds a barrier with one result per operand, each result having the type
// of the corresponding operand.
void OptimizationBarrierOp::build(OpBuilder &builder, OperationState &state,
                                  ValueRange operands,
                                  ArrayRef<NamedAttribute> attributes) {
  auto resultTypes = llvm::to_vector<2>(operands.getTypes());
  state.addOperands(operands);
  state.addTypes(resultTypes);
  state.addAttributes(attributes);
}

// Verifies that the op has as many results as operands and that every result
// has the same type as the operand at the same position. Verification fails
// at the first mismatch found.
LogicalResult OptimizationBarrierOp::verify() {
  Operation *op = getOperation();
  if (op->getNumOperands() != op->getNumResults()) {
    return op->emitOpError()
           << "must have same number of operands and results, but has "
           << op->getNumOperands() << " and " << op->getNumResults()
           << ", respectively";
  }

  for (int i = 0, e = op->getNumOperands(); i < e; ++i) {
    if (op->getOperand(i).getType() != op->getResult(i).getType()) {
      // The diagnostic must be returned: emitting it without returning would
      // report the error yet still let verification succeed.
      return op->emitOpError()
             << "must have same operand and result types, but they "
                "differ at index "
             << i;
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "iree-dialects/Dialect/Input/InputOps.cpp.inc"
