// Copyright 2017 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package chunk

import (
	"bytes"
	"cmp"
	"sort"

	"github.com/pingcap/tidb/pkg/parser/mysql"
	"github.com/pingcap/tidb/pkg/types"
)

// CompareFunc is a function to compare the two values in Row, the two columns must have the same type.
type CompareFunc = func(l Row, lCol int, r Row, rCol int) int

// GetCompareFunc gets a compare function for the field type.
func GetCompareFunc(tp *types.FieldType) CompareFunc {
	switch tp.GetType() {
	case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeYear:
		if mysql.HasUnsignedFlag(tp.GetFlag()) {
			return cmpUint64
		}
		return cmpInt64
	case mysql.TypeFloat:
		return cmpFloat32
	case mysql.TypeDouble:
		return cmpFloat64
	case mysql.TypeString, mysql.TypeVarString, mysql.TypeVarchar,
		mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob:
		return genCmpStringFunc(tp.GetCollate())
	case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp:
		return cmpTime
	case mysql.TypeDuration:
		return cmpDuration
	case mysql.TypeNewDecimal:
		return cmpMyDecimal
	case mysql.TypeSet, mysql.TypeEnum:
		return cmpNameValue
	case mysql.TypeBit:
		return cmpBit
	case mysql.TypeJSON:
		return cmpJSON
	case mysql.TypeTiDBVectorFloat32:
		return cmpVectorFloat32
	case mysql.TypeNull:
		return cmpNullConst
	}
	return nil
}

func cmpNull(lNull, rNull bool) int {
	if lNull && rNull {
		return 0
	}
	if lNull {
		return -1
	}
	return 1
}

func cmpInt64(l Row, lCol int, r Row, rCol int) int {
	lNull, rNull := l.IsNull(lCol), r.IsNull(rCol)
	if lNull || rNull {
		return cmpNull(lNull, rNull)
	}
	return cmp.Compare(l.GetInt64(lCol), r.GetInt64(rCol))
}

func cmpUint64(l Row, lCol int, r Row, rCol int) int {
	lNull, rNull := l.IsNull(lCol), r.IsNull(rCol)
	if lNull || rNull {
		return cmpNull(lNull, rNull)
	}
	return cmp.Compare(l.GetUint64(lCol), r.GetUint64(rCol))
}

func genCmpStringFunc(collation string) func(l Row, lCol int, r Row, rCol int) int {
	return func(l Row, lCol int, r Row, rCol int) int {
		return cmpStringWithCollationInfo(l, lCol, r, rCol, collation)
	}
}

func cmpStringWithCollationInfo(l Row, lCol int, r Row, rCol int, collation string) int {
	lNull, rNull := l.IsNull(lCol), r.IsNull(rCol)
	if lNull || rNull {
		return cmpNull(lNull, rNull)
	}
	return types.CompareString(l.GetString(lCol), r.GetString(rCol), collation)
}

func cmpFloat32(l Row, lCol int, r Row, rCol int) int {
	lNull, rNull := l.IsNull(lCol), r.IsNull(rCol)
	if lNull || rNull {
		return cmpNull(lNull, rNull)
	}
	return cmp.Compare(float64(l.GetFloat32(lCol)), float64(r.GetFloat32(rCol)))
}

func cmpFloat64(l Row, lCol int, r Row, rCol int) int {
	lNull, rNull := l.IsNull(lCol), r.IsNull(rCol)
	if lNull || rNull {
		return cmpNull(lNull, rNull)
	}
	return cmp.Compare(l.GetFloat64(lCol), r.GetFloat64(rCol))
}

func cmpMyDecimal(l Row, lCol int, r Row, rCol int) int {
	lNull, rNull := l.IsNull(lCol), r.IsNull(rCol)
	if lNull || rNull {
		return cmpNull(lNull, rNull)
	}
	lDec, rDec := l.GetMyDecimal(lCol), r.GetMyDecimal(rCol)
	return lDec.Compare(rDec)
}

func cmpTime(l Row, lCol int, r Row, rCol int) int {
	lNull, rNull := l.IsNull(lCol), r.IsNull(rCol)
	if lNull || rNull {
		return cmpNull(lNull, rNull)
	}
	lTime, rTime := l.GetTime(lCol), r.GetTime(rCol)
	return lTime.Compare(rTime)
}

func cmpDuration(l Row, lCol int, r Row, rCol int) int {
	lNull, rNull := l.IsNull(lCol), r.IsNull(rCol)
	if lNull || rNull {
		return cmpNull(lNull, rNull)
	}
	lDur, rDur := l.GetDuration(lCol, 0).Duration, r.GetDuration(rCol, 0).Duration
	return cmp.Compare(int64(lDur), int64(rDur))
}

func cmpNameValue(l Row, lCol int, r Row, rCol int) int {
	lNull, rNull := l.IsNull(lCol), r.IsNull(rCol)
	if lNull || rNull {
		return cmpNull(lNull, rNull)
	}
	_, lVal := l.getNameValue(lCol)
	_, rVal := r.getNameValue(rCol)
	return cmp.Compare(lVal, rVal)
}

func cmpBit(l Row, lCol int, r Row, rCol int) int {
	lNull, rNull := l.IsNull(lCol), r.IsNull(rCol)
	if lNull || rNull {
		return cmpNull(lNull, rNull)
	}
	lBit := types.BinaryLiteral(l.GetBytes(lCol))
	rBit := types.BinaryLiteral(r.GetBytes(rCol))
	return lBit.Compare(rBit)
}

func cmpJSON(l Row, lCol int, r Row, rCol int) int {
	lNull, rNull := l.IsNull(lCol), r.IsNull(rCol)
	if lNull || rNull {
		return cmpNull(lNull, rNull)
	}
	lJ, rJ := l.GetJSON(lCol), r.GetJSON(rCol)
	return types.CompareBinaryJSON(lJ, rJ)
}

func cmpNullConst(_ Row, _ int, _ Row, _ int) int {
	return 0
}

func cmpVectorFloat32(l Row, lCol int, r Row, rCol int) int {
	lNull, rNull := l.IsNull(lCol), r.IsNull(rCol)
	if lNull || rNull {
		return cmpNull(lNull, rNull)
	}
	lv, rv := l.GetVectorFloat32(lCol), r.GetVectorFloat32(rCol)
	return lv.Compare(rv)
}

// Compare compares the value with ad.
// We assume that the collation information of the column is the same with the datum.
func Compare(row Row, colIdx int, ad *types.Datum) int {
	switch ad.Kind() {
	case types.KindNull:
		if row.IsNull(colIdx) {
			return 0
		}
		return 1
	case types.KindMinNotNull:
		if row.IsNull(colIdx) {
			return -1
		}
		return 1
	case types.KindMaxValue:
		return -1
	case types.KindInt64:
		return cmp.Compare(row.GetInt64(colIdx), ad.GetInt64())
	case types.KindUint64:
		return cmp.Compare(row.GetUint64(colIdx), ad.GetUint64())
	case types.KindFloat32:
		return cmp.Compare(float64(row.GetFloat32(colIdx)), float64(ad.GetFloat32()))
	case types.KindFloat64:
		return cmp.Compare(row.GetFloat64(colIdx), ad.GetFloat64())
	case types.KindString:
		return types.CompareString(row.GetString(colIdx), ad.GetString(), ad.Collation())
	case types.KindBytes, types.KindBinaryLiteral, types.KindMysqlBit:
		return bytes.Compare(row.GetBytes(colIdx), ad.GetBytes())
	case types.KindMysqlDecimal:
		l, r := row.GetMyDecimal(colIdx), ad.GetMysqlDecimal()
		return l.Compare(r)
	case types.KindMysqlDuration:
		l, r := row.GetDuration(colIdx, 0).Duration, ad.GetMysqlDuration().Duration
		return cmp.Compare(int64(l), int64(r))
	case types.KindMysqlEnum:
		l, r := row.GetEnum(colIdx).Value, ad.GetMysqlEnum().Value
		return cmp.Compare(l, r)
	case types.KindMysqlSet:
		l, r := row.GetSet(colIdx).Value, ad.GetMysqlSet().Value
		return cmp.Compare(l, r)
	case types.KindMysqlJSON:
		l, r := row.GetJSON(colIdx), ad.GetMysqlJSON()
		return types.CompareBinaryJSON(l, r)
	case types.KindVectorFloat32:
		l, r := row.GetVectorFloat32(colIdx), ad.GetVectorFloat32()
		return l.Compare(r)
	case types.KindMysqlTime:
		l, r := row.GetTime(colIdx), ad.GetMysqlTime()
		return l.Compare(r)
	default:
		return 0
	}
}

// LowerBound searches on the non-decreasing Column colIdx,
// returns the smallest index i such that the value at row i is not less than `d`.
func (c *Chunk) LowerBound(colIdx int, d *types.Datum) (index int, match bool) {
	if Compare(c.GetRow(c.NumRows()-1), colIdx, d) < 0 {
		return c.NumRows(), false
	}
	index = sort.Search(c.NumRows(), func(i int) bool {
		cmp := Compare(c.GetRow(i), colIdx, d)
		if cmp == 0 {
			match = true
		}
		return cmp >= 0
	})
	return
}

// UpperBound searches on the non-decreasing Column colIdx,
// returns the smallest index i such that the value at row i is larger than `d`.
func (c *Chunk) UpperBound(colIdx int, d *types.Datum) int {
	return sort.Search(c.NumRows(), func(i int) bool {
		return Compare(c.GetRow(i), colIdx, d) > 0
	})
}
