﻿using System;
using System.Buffers.Binary;
using System.Net;
using System.Net.Sockets;
using System.Text;

using Renci.SshNet.Abstractions;
using Renci.SshNet.Common;

namespace Renci.SshNet.Connection
{
    /// <summary>
    /// Establishes a tunnel via a SOCKS4 proxy server.
    /// </summary>
    /// <remarks>
    /// https://www.openssh.com/txt/socks4.protocol.
    /// </remarks>
    /// <summary>
    /// Establishes a tunnel via a SOCKS4 proxy server.
    /// </summary>
    /// <remarks>
    /// https://www.openssh.com/txt/socks4.protocol.
    /// The destination host name is resolved locally and sent to the proxy as an IPv4 address,
    /// because a SOCKS4 request carries only a 4-byte destination IP.
    /// </remarks>
    internal sealed class Socks4Connector : ProxyConnector
    {
        /// <summary>
        /// Initializes a new instance of the <see cref="Socks4Connector"/> class.
        /// </summary>
        /// <param name="socketFactory">The socket factory passed to the base <see cref="ProxyConnector"/>.</param>
        public Socks4Connector(ISocketFactory socketFactory)
            : base(socketFactory)
        {
        }

        /// <summary>
        /// Establishes a connection to the server via a SOCKS4 proxy.
        /// </summary>
        /// <param name="connectionInfo">The connection information.</param>
        /// <param name="socket">The <see cref="Socket"/> to send the SOCKS4 request on and read the reply from.</param>
        /// <exception cref="ProxyException">
        /// The destination host has no IPv4 address, the proxy username contains a null character,
        /// the reply version byte is not zero, or the reply result code is not "request granted" (0x5A).
        /// </exception>
        protected override void HandleProxyConnect(IConnectionInfo connectionInfo, Socket socket)
        {
            var connectionRequest = CreateSocks4ConnectionRequest(connectionInfo.Host, (ushort)connectionInfo.Port, connectionInfo.ProxyUsername);
            SocketAbstraction.Send(socket, connectionRequest);

            // Reply version (VN): a SOCKS4 reply carries 0 here, not 4.
            if (SocketReadByte(socket, connectionInfo.Timeout) != 0x00)
            {
                throw new ProxyException("SOCKS4: Null is expected.");
            }

            // Reply result code (CD).
            var code = SocketReadByte(socket, connectionInfo.Timeout);

            switch (code)
            {
                case 0x5a: // request granted
                    break;
                case 0x5b:
                    throw new ProxyException("SOCKS4: Connection rejected.");
                case 0x5c:
                    throw new ProxyException("SOCKS4: Client is not running identd or not reachable from the server.");
                case 0x5d:
                    throw new ProxyException("SOCKS4: Client's identd could not confirm the user ID string in the request.");
                default:
                    throw new ProxyException("SOCKS4: Not valid response.");
            }

            // The rest of the reply is DSTPORT (2 bytes) + DSTIP (4 bytes). The values are ignored,
            // but the bytes are consumed so that they are not left in front of the tunnelled stream.
            var destBuffer = new byte[6];
            _ = SocketRead(socket, destBuffer, 0, destBuffer.Length, connectionInfo.Timeout);
        }

        /// <summary>
        /// Builds a SOCKS4 CONNECT request: VN | CD | DSTPORT | DSTIP | USERID | NULL.
        /// </summary>
        /// <param name="hostname">The destination host; it is resolved to an IPv4 address.</param>
        /// <param name="port">The destination port, written big-endian.</param>
        /// <param name="username">The proxy user ID; <see langword="null"/> sends an empty user ID.</param>
        /// <returns>The encoded request bytes.</returns>
        /// <exception cref="ProxyException">
        /// <paramref name="hostname"/> has no IPv4 address, or <paramref name="username"/> contains a null character.
        /// </exception>
        private static byte[] CreateSocks4ConnectionRequest(string hostname, ushort port, string username)
        {
            var addressBytes = GetSocks4DestinationAddress(hostname);
            var proxyUserBytes = GetProxyUserBytes(username);

            // version (1) + command (1) + port (2) + IPv4 address + user ID + null terminator (1)
            var connectionRequest = new byte[1 + 1 + 2 + addressBytes.Length + proxyUserBytes.Length + 1];

            var index = 0;

            // SOCKS version number (VN)
            connectionRequest[index++] = 0x04;

            // Command code (CD): establish a TCP/IP stream connection
            connectionRequest[index++] = 0x01;

            // Destination port (DSTPORT), network byte order
            BinaryPrimitives.WriteUInt16BigEndian(connectionRequest.AsSpan(index), port);
            index += 2;

            // Destination address (DSTIP)
            Buffer.BlockCopy(addressBytes, 0, connectionRequest, index, addressBytes.Length);
            index += addressBytes.Length;

            // User ID (USERID)
            Buffer.BlockCopy(proxyUserBytes, 0, connectionRequest, index, proxyUserBytes.Length);
            index += proxyUserBytes.Length;

            // Null terminator for USERID
            connectionRequest[index] = 0x00;

            return connectionRequest;
        }

        /// <summary>
        /// Resolves <paramref name="hostname"/> and returns the bytes of its first IPv4 address.
        /// </summary>
        /// <param name="hostname">The host name or address literal to resolve.</param>
        /// <returns>The 4 address bytes of the first <see cref="AddressFamily.InterNetwork"/> result.</returns>
        /// <exception cref="ProxyException">The resolution yields no IPv4 address.</exception>
        private static byte[] GetSocks4DestinationAddress(string hostname)
        {
            // NOTE(review): Dns.GetHostAddresses is synchronous and is not bounded by connectionInfo.Timeout.
            var addresses = Dns.GetHostAddresses(hostname);

            foreach (var address in addresses)
            {
                if (address.AddressFamily == AddressFamily.InterNetwork)
                {
                    return address.GetAddressBytes();
                }
            }

            throw new ProxyException($"SOCKS4 only supports IPv4. No such address found for '{hostname}'.");
        }

        /// <summary>
        /// Encodes the proxy user ID for the SOCKS4 USERID field.
        /// </summary>
        /// <param name="proxyUser">The proxy username, or <see langword="null"/> for none.</param>
        /// <returns>The ASCII bytes of <paramref name="proxyUser"/>, or an empty array when it is <see langword="null"/>.</returns>
        /// <exception cref="ProxyException"><paramref name="proxyUser"/> contains a null character.</exception>
        private static byte[] GetProxyUserBytes(string proxyUser)
        {
            if (proxyUser is null)
            {
                return Array.Empty<byte>();
            }

            // USERID is null-terminated on the wire. An embedded NUL would end the field early, and
            // the remaining bytes could be misinterpreted by the proxy, so reject it up front.
            if (proxyUser.IndexOf('\0') >= 0)
            {
                throw new ProxyException("SOCKS4: The proxy username must not contain a null character.");
            }

            // Encoding.ASCII substitutes '?' for characters outside the ASCII range.
            return Encoding.ASCII.GetBytes(proxyUser);
        }
    }
}
