// Copyright 2022 The Cockroach Authors.
//
// Use of this software is governed by the CockroachDB Software License
// included in the /LICENSE file.

package sqlproxyccl

import (
	"encoding/binary"
	"io"
	"net"
	"time"

	"github.com/cockroachdb/cockroach/pkg/util/randutil"
	"github.com/cockroachdb/cockroach/pkg/util/syncutil"
	"github.com/cockroachdb/cockroach/pkg/util/timeutil"
	"github.com/cockroachdb/errors"
	pgproto3 "github.com/jackc/pgproto3/v2"
)

// cancelInfo contains the information that sqlproxy needs in order to cancel
// a query using the pgwire cancellation protocol.
//
// proxyBackendKeyData and clientAddr are assigned in makeCancelInfo. The
// fields under mu are replaced by setNewBackend and read by
// sendCancelToBackend.
type cancelInfo struct {
	// proxyBackendKeyData is the cancel key generated by sqlproxy, which is
	// sent to the client. makeCancelInfo fills its ProcessID with the encoded
	// IPv4 address of the proxy and its SecretKey with a random value.
	proxyBackendKeyData *pgproto3.BackendKeyData
	// clientAddr is the address where proxyBackendKeyData is sent to.
	// sendCancelToBackend rejects requests whose source IP differs from
	// clientAddr.IP.
	clientAddr *net.TCPAddr

	// mu protects the fields of cancelInfo that can be modified as a result of
	// a session transfer.
	mu struct {
		syncutil.RWMutex
		// origBackendKeyData is the cancel key originally generated by a SQL node.
		origBackendKeyData *pgproto3.BackendKeyData
		// crdbAddr is the address of the SQL node that generated origBackendKeyData.
		crdbAddr net.Addr
	}
}

// makeCancelInfo builds a cancelInfo for a client at clientAddr that is
// connected to the proxy at localAddr. Only the proxy-side cancel key is
// populated here: its ProcessID is the encoded IPv4 address of localAddr and
// its SecretKey is a freshly generated random value. The backend address and
// backend cancel key are left unset; the caller of the connector is
// responsible for providing them.
//
// Both localAddr and clientAddr must be of type *net.TCPAddr; any other
// dynamic type makes the type assertions below panic.
func makeCancelInfo(localAddr, clientAddr net.Addr) *cancelInfo {
	secret := randutil.FastUint32()
	proxyIP := localAddr.(*net.TCPAddr).IP
	info := &cancelInfo{clientAddr: clientAddr.(*net.TCPAddr)}
	info.proxyBackendKeyData = &pgproto3.BackendKeyData{
		ProcessID: encodeIP(proxyIP),
		SecretKey: secret,
	}
	return info
}

// encodeIP packs an IPv4 address into a uint32 using big-endian byte order.
// Any address that has no 4-byte IPv4 form (for example a genuine IPv6
// address, or a nil net.IP) is encoded as 0.
func encodeIP(src net.IP) uint32 {
	if v4 := src.To4(); v4 != nil {
		return binary.BigEndian.Uint32(v4)
	}
	// To4 yields nil when src is not an IPv4 address.
	return 0
}

// decodeIP expands a uint32 into a 4-byte net.IP, writing the value in
// big-endian byte order. It reverses what encodeIP does for IPv4 addresses.
func decodeIP(src uint32) net.IP {
	var octets [4]byte
	binary.BigEndian.PutUint32(octets[:], src)
	return net.IP(octets[:])
}

// proxySecretID reports the SecretKey of the proxy-generated cancel key, i.e.
// the random value that was chosen when this cancelInfo was made.
func (c *cancelInfo) proxySecretID() uint32 {
	keyData := c.proxyBackendKeyData
	return keyData.SecretKey
}

// setNewBackend replaces the backend cancel key and the backend address in a
// single step. Both fields are written while holding the write lock, so a
// concurrent reader that takes the read lock never observes a key paired with
// the address of a different backend.
func (c *cancelInfo) setNewBackend(
	newBackendKeyData *pgproto3.BackendKeyData, newCrdbAddr *net.TCPAddr,
) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.mu.crdbAddr, c.mu.origBackendKeyData = newCrdbAddr, newBackendKeyData
}

// sendCancelToBackend sends a cancel request to the backend after checking that
// the given client IP is allowed to send this request.
//
// The backend address and backend cancel key are read together under the read
// lock, so they always belong to the same backend even if setNewBackend runs
// concurrently. The request is written to a fresh TCP connection; dialing is
// bounded by a 2s timeout, and so is the write/read exchange that follows.
//
// An error is returned if requestClientIP does not match the client address
// recorded in c, if no backend has been set yet, or if dialing, writing to,
// or reading from the backend fails. The backend closing the connection
// (io.EOF on read) is treated as success.
func (c *cancelInfo) sendCancelToBackend(requestClientIP net.IP) error {
	const timeout = 2 * time.Second
	if !c.clientAddr.IP.Equal(requestClientIP) {
		// If the IP associated with the cancelInfo does not match the IP from
		// which the request came, then ignore it.
		return errors.New("mismatched client IP for cancel request")
	}
	var crdbAddr net.Addr
	var origBackendKeyData *pgproto3.BackendKeyData
	func() {
		c.mu.RLock()
		defer c.mu.RUnlock()
		crdbAddr = c.mu.crdbAddr
		origBackendKeyData = c.mu.origBackendKeyData
	}()
	// makeCancelInfo leaves the backend fields unset, so a request that arrives
	// before setNewBackend has been called must fail cleanly rather than
	// dereference nil.
	if crdbAddr == nil || origBackendKeyData == nil {
		return errors.New("backend not yet set for cancel request")
	}
	cancelConn, err := net.DialTimeout("tcp", crdbAddr.String(), timeout)
	if err != nil {
		return err
	}
	defer cancelConn.Close()
	if err := cancelConn.SetDeadline(timeutil.Now().Add(timeout)); err != nil {
		return err
	}
	crdbRequest := &pgproto3.CancelRequest{
		ProcessID: origBackendKeyData.ProcessID,
		SecretKey: origBackendKeyData.SecretKey,
	}
	buf := crdbRequest.Encode(nil /* buf */)
	if _, err := cancelConn.Write(buf); err != nil {
		return err
	}
	// Wait for the backend to react. io.EOF means the backend closed the
	// connection and is not a failure; errors.Is also recognizes a wrapped EOF.
	if _, err := cancelConn.Read(buf); err != nil && !errors.Is(err, io.EOF) {
		return err
	}
	return nil
}

// cancelInfoMap contains all the cancelInfo objects that this proxy instance
// is aware of. It is safe for concurrent use, and is keyed by a secret that
// is shared between the proxy and the client.
type cancelInfoMap struct {
	// The embedded RWMutex guards m; the accessor methods of cancelInfoMap
	// acquire it before touching the map.
	syncutil.RWMutex
	// m maps a proxy secret ID to the cancelInfo registered under it.
	m map[uint32]*cancelInfo
}

// makeCancelInfoMap returns an empty cancelInfoMap whose underlying map is
// already allocated, so entries can be added right away.
func makeCancelInfoMap() *cancelInfoMap {
	infos := new(cancelInfoMap)
	infos.m = map[uint32]*cancelInfo{}
	return infos
}

// addCancelInfo stores info under proxySecretID while holding the write lock.
// An entry already stored under the same key is overwritten.
func (c *cancelInfoMap) addCancelInfo(proxySecretID uint32, info *cancelInfo) {
	c.Lock()
	defer c.Unlock()
	entries := c.m
	entries[proxySecretID] = info
}

// deleteCancelInfo removes the entry stored under proxySecretID while holding
// the write lock. It is a no-op if no such entry exists.
func (c *cancelInfoMap) deleteCancelInfo(proxySecretID uint32) {
	c.Lock()
	delete(c.m, proxySecretID)
	c.Unlock()
}

// getCancelInfo looks up the cancelInfo stored under proxySecretID while
// holding the read lock. The boolean result reports whether an entry was
// found.
func (c *cancelInfoMap) getCancelInfo(proxySecretID uint32) (*cancelInfo, bool) {
	c.RLock()
	info, found := c.m[proxySecretID]
	c.RUnlock()
	return info, found
}

// proxyCancelRequestLen is the size in bytes of an encoded proxyCancelRequest:
// three consecutive 4-byte fields (ProxyIP, SecretKey, ClientIP).
const proxyCancelRequestLen = 12

// proxyCancelRequest is a pgwire cancel request that is forwarded from
// one proxy to another.
//
// On the wire it occupies proxyCancelRequestLen bytes: ProxyIP, SecretKey and
// ClientIP, in that order, each as a big-endian uint32. Both IPs go through
// encodeIP, so an address without an IPv4 form is transmitted as 0.
type proxyCancelRequest struct {
	// ProxyIP is carried in bytes 0-3.
	// NOTE(review): presumably the IP of the proxy that issued the cancel key
	// (makeCancelInfo stores the proxy IP in the key's ProcessID) — confirm
	// against callers.
	ProxyIP net.IP
	// SecretKey is carried in bytes 4-7.
	// NOTE(review): presumably the proxy secret ID used as the cancelInfoMap
	// key — confirm against callers.
	SecretKey uint32
	// ClientIP is carried in bytes 8-11.
	// NOTE(review): presumably the IP of the client that sent the original
	// cancel request — confirm against callers.
	ClientIP net.IP
}

// Decode parses src into r. src must be exactly proxyCancelRequestLen bytes
// long and is read as three big-endian uint32 values: ProxyIP (bytes 0-3),
// SecretKey (bytes 4-7) and ClientIP (bytes 8-11). If src has any other
// length an error is returned and r is left unmodified.
func (r *proxyCancelRequest) Decode(src []byte) error {
	if len(src) != proxyCancelRequestLen {
		return errors.New("bad cancel request size")
	}
	be := binary.BigEndian
	proxyIP, secretKey, clientIP := be.Uint32(src), be.Uint32(src[4:]), be.Uint32(src[8:])
	r.ProxyIP = decodeIP(proxyIP)
	r.SecretKey = secretKey
	r.ClientIP = decodeIP(clientIP)
	return nil
}

// Encode serializes r into a newly allocated slice of proxyCancelRequestLen
// bytes. The fields are written as consecutive big-endian uint32 values in
// the order ProxyIP, SecretKey, ClientIP. Both IPs go through encodeIP, so an
// address without an IPv4 form is written as 0.
func (r *proxyCancelRequest) Encode() []byte {
	fields := [...]uint32{encodeIP(r.ProxyIP), r.SecretKey, encodeIP(r.ClientIP)}
	dst := make([]byte, proxyCancelRequestLen)
	for i, field := range fields {
		// Each field occupies its own 4-byte slot.
		binary.BigEndian.PutUint32(dst[4*i:], field)
	}
	return dst
}
