package utils

import (
	"crypto/aes"
	"encoding/binary"
	"io"
	"net"
)

// EncryptedConn is an encrypted connection between socccks-client and
// socccks-server. It embeds a net.Conn and defines its own Read and Write,
// which frame and encrypt/decrypt the traffic. All other net.Conn methods
// (Close, addresses, deadlines) are promoted unchanged from the embedded Conn.
type EncryptedConn struct {
	// Conn is the underlying transport that carries the framed ciphertext.
	net.Conn
	// Encryptor performs the CFB encryption (CFBEncrypter) and decryption
	// (CFBDecrypter) of each frame body.
	Encryptor *Encryptor
}

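// Encrypted frame wire format (sizes in bytes):
// header:
//   length         2              big-endian uint16, byte length of the body below
// body:
//   iv             aes.BlockSize  (16)
//   encryptedData  ...            ciphertext, same length as the plaintext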

// NewEncryptedConn wraps conn in an EncryptedConn whose Encryptor is built
// from password by NewEncryptor. Data written through the result is encrypted
// before it reaches conn, and data read from it is decrypted.
func NewEncryptedConn(conn net.Conn, password string) *EncryptedConn {
	enc := NewEncryptor(password)

	wrapped := new(EncryptedConn)
	wrapped.Conn = conn
	wrapped.Encryptor = enc
	return wrapped
}

// Write encrypts rawData and sends it on the underlying connection as one or
// more frames of the form [2-byte big-endian body length][body]. The body is
// the output of Encryptor.CFBEncrypter.
//
// rawData is split into chunks small enough that every frame fits both the
// pooled write buffer and the 16-bit length header. Without the split, an
// oversized rawData would overflow the buffer or wrap the uint16 length. An
// empty rawData still produces a single IV-only frame.
//
// On success nw == len(rawData). Callers such as io.Copy compare the count
// returned by Write with the number of bytes they passed in, so Write reports
// plaintext bytes consumed, not bytes put on the wire. On failure nw counts
// only the plaintext bytes whose frames were completely written before the
// error.
func (ec *EncryptedConn) Write(rawData []byte) (nw int, err error) {
	// Largest body length the 2-byte header can represent.
	const maxBody = 1<<16 - 1

	writeBuf := Pool33K.Get()
	defer Pool33K.Put(writeBuf)

	// The body occupies everything after the 2-byte header, capped by the
	// header's range. Each body is assumed to carry an aes.BlockSize IV ahead of
	// the ciphertext (the per-frame overhead the length arithmetic has always
	// subtracted) — NOTE(review): confirm against CFBEncrypter.
	bodyCap := len(writeBuf) - 2
	if bodyCap > maxBody {
		bodyCap = maxBody
	}
	chunkCap := bodyCap - aes.BlockSize
	if chunkCap <= 0 {
		// The pooled buffer cannot hold even a 1-byte payload plus IV.
		return 0, io.ErrShortBuffer
	}

	rest := rawData
	for {
		chunk := rest
		if len(chunk) > chunkCap {
			chunk = chunk[:chunkCap]
		}
		rest = rest[len(chunk):]

		bodyLen := ec.Encryptor.CFBEncrypter(chunk, writeBuf[2:])
		binary.BigEndian.PutUint16(writeBuf[:2], uint16(bodyLen))
		frameLen := 2 + bodyLen

		n, werr := ec.Conn.Write(writeBuf[:frameLen])
		if werr != nil {
			return nw, werr
		}
		if n != frameLen {
			return nw, io.ErrShortWrite
		}
		nw += len(chunk)

		if len(rest) == 0 {
			return nw, nil
		}
	}
}

// Read reads exactly one encrypted frame from the underlying connection. It
// decrypts the frame body into buf with Encryptor.CFBDecrypter and returns
// the number of plaintext bytes written to buf.
//
// The frame length comes from the peer, so Read validates it before use:
//   - A body shorter than the aes.BlockSize IV, or a connection that ends
//     mid-frame, yields io.ErrUnexpectedEOF.
//   - A body larger than the pooled read buffer yields io.ErrShortBuffer
//     rather than slicing past the buffer.
//   - A frame whose plaintext does not fit in buf yields io.ErrShortBuffer
//     rather than dropping data.
//
// A zero-length buf returns (0, nil) without consuming a frame, as the
// io.Reader contract allows.
func (ec *EncryptedConn) Read(buf []byte) (rn int, err error) {
	if len(buf) == 0 {
		return 0, nil
	}

	readBuffer := Pool33K.Get()
	defer Pool33K.Put(readBuffer)

	// io.ReadFull returns a bare io.EOF only when no header byte arrived, i.e.
	// the peer closed cleanly between frames. That is passed through as-is.
	if _, err = io.ReadFull(ec.Conn, readBuffer[:2]); err != nil {
		return 0, err
	}
	dataLen := int(binary.BigEndian.Uint16(readBuffer[:2]))

	// Plaintext length is the body minus its IV. This assumes CFBDecrypter
	// expects IV+ciphertext, mirroring the overhead Write accounts for —
	// NOTE(review): confirm against CFBDecrypter.
	switch {
	case dataLen < aes.BlockSize:
		return 0, io.ErrUnexpectedEOF
	case dataLen > len(readBuffer):
		return 0, io.ErrShortBuffer
	case dataLen-aes.BlockSize > len(buf):
		return 0, io.ErrShortBuffer
	}

	if _, err = io.ReadFull(ec.Conn, readBuffer[:dataLen]); err != nil {
		// A header was already received, so EOF here means a truncated frame,
		// not a clean end of stream.
		if err == io.EOF {
			err = io.ErrUnexpectedEOF
		}
		return 0, err
	}

	return ec.Encryptor.CFBDecrypter(readBuffer[:dataLen], buf), nil
}
