#!/Library/AutoPkg/Python3/Python.framework/Versions/Current/bin/python3
#
# Copyright 2014-2015 Greg Neagle
# Based heavily on Per Olofsson's autopkgserver
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A server that installs Apple packages"""

import logging
import logging.handlers
import os
import plistlib
import socket
import socketserver
import stat
import struct
import sys
import time

from installer import Installer, InstallerError
from itemcopier import ItemCopier, ItemCopierError

# Name of this daemon; used for the logger name, the log file path under
# /private/var/log/, and in startup/error messages.
APPNAME = "autopkginstalld"
# Version string reported in the startup log line.
VERSION = "0.2"


class RunHandler(socketserver.StreamRequestHandler):
    """Handler for autopkginstalld run requests.

    Each connection carries one plist-encoded request. A request containing
    a ``package`` key is passed to ``Installer``; one containing a
    ``mount_point`` key is passed to ``ItemCopier``. Every reply written by
    this handler is a newline-terminated line starting with ``OK:`` or
    ``ERROR:``.
    """

    def verify_request_syntax(self, plist):
        """Verify the basic syntax of an installation request plist.

        Note that ``handle`` does not call this method; it only validates
        requests of the ``package`` kind.

        Args:
            plist: The parsed request object.

        Returns:
            A ``(syntax_ok, errors)`` tuple where ``errors`` is a list of
            human-readable messages (empty when ``syntax_ok`` is True).
        """
        # Root should be a dictionary; bail out early if it's not.
        if not isinstance(plist, dict):
            return False, ["Request root is not a dictionary"]

        required_keys = ("package",)
        errors = [
            f"Request does not contain {key}"
            for key in required_keys
            if key not in plist
        ]
        return not errors, errors

    def getpeerid(self):
        """Get peer credentials on a UNIX domain socket.

        Returns:
            A ``(uid, gids)`` tuple: the peer's effective user id and a
            tuple of its advisory group ids.

        Raises:
            OSError: If the kernel reports an unexpected xucred version
                (or if the underlying ``getsockopt`` call fails).
        """

        # /usr/include/sys/ucred.h
        #
        # struct xucred {
        #         u_int   cr_version;           /* structure layout version */
        #         uid_t   cr_uid;               /* effective user id */
        #         short   cr_ngroups;           /* number of advisory groups */
        #         gid_t   cr_groups[NGROUPS];   /* advisory group list */
        # };

        LOCAL_PEERCRED = 0x001
        XUCRED_VERSION = 0
        NGROUPS = 16

        # Native alignment (struct's default) matches the C layout above.
        xucred = struct.Struct(f"IIh{NGROUPS}I")
        fields = xucred.unpack(
            self.request.getsockopt(0, LOCAL_PEERCRED, xucred.size)
        )
        version, uid, ngroups = fields[:3]

        if version != XUCRED_VERSION:
            raise OSError("Incompatible struct xucred version")

        # Only the first cr_ngroups entries of the group array are valid.
        return uid, fields[3 : 3 + ngroups]

    def _reply(self, message):
        """Send one newline-terminated status line to the client.

        Uses ``sendall`` so the whole line is written even if the kernel
        accepts it in several pieces.
        """
        self.request.sendall(f"{message}\n".encode())

    def handle(self):
        """Handle an incoming run request.

        Reads a plist from the connection, dispatches it to the matching
        worker, and reports the outcome to the client. Ordinary exceptions
        are logged and reported as an ``ERROR:`` line rather than raised;
        interpreter-exit exceptions (``KeyboardInterrupt``, ``SystemExit``)
        are allowed to propagate.
        """

        try:
            # Log through server parent.
            self.log = self.server.log
            self.log.debug("Handling request")

            # Get uid and primary gid of connecting peer.
            uid, gids = self.getpeerid()
            gid = gids[0]
            self.log.debug("Got run request from uid %d gid %d", uid, gid)

            # Receive a plist.
            # NOTE(review): a single recv() caps requests at 8192 bytes and
            # may return less than the client sent; left as-is because the
            # client keeps the connection open while waiting for the reply.
            plist_string = self.request.recv(8192)

            # Try to parse it.
            try:
                plist = plistlib.loads(plist_string)
            except Exception:
                self.log.error("Malformed request")
                self._reply("ERROR:Malformed request")
                return
            self.log.debug("Parsed request plist")

            # A non-dict root (string, array, number, ...) can never be a
            # valid request; reject it before the key checks below, which
            # would otherwise do substring/element matching on it.
            if not isinstance(plist, dict):
                self.log.error("Unsupported request format")
                self._reply("ERROR:Unsupported request format")
                return

            if "package" in plist:
                self.log.info(
                    "Dispatching Installer worker to process request for "
                    "user %d",
                    uid,
                )
                try:
                    installer = Installer(self.log, self.request, plist)
                    installer.install()
                    self._reply("OK:DONE")
                except InstallerError as err:
                    self.log.error(f"Installing failed: {err}")
                    self._reply(f"ERROR:{err}")
            elif "mount_point" in plist:
                self.log.info(
                    "Dispatching ItemCopier worker to process request for "
                    "user %d",
                    uid,
                )
                try:
                    copier = ItemCopier(self.log, self.request, plist)
                    copier.copy()
                    self._reply("OK:DONE")
                except ItemCopierError as err:
                    self.log.error(f"Copying failed: {err}")
                    self._reply(f"ERROR:{err}")
            else:
                self.log.error("Unsupported request format")
                self._reply("ERROR:Unsupported request format")

        except Exception as err:
            self.log.error(f"Caught exception: {err!r}")
            try:
                self._reply(f"ERROR:Caught exception: {err!r}")
            except OSError as send_err:
                # The peer may already be gone (e.g. broken pipe); there is
                # nobody left to report to, so just record it.
                self.log.error(f"Could not report error to client: {send_err!r}")


class AutoPkgInstallDaemonError(Exception):
    """Raised when the AutoPkgInstallDaemon cannot be set up or run."""


class AutoPkgInstallDaemon(socketserver.UnixStreamServer):
    """Daemon that runs as root, receiving requests to install packages.

    The listening socket is adopted from an already-open file descriptor
    instead of being created and bound by this class. ``setup_logging``
    must be called before requests are served, because that is where the
    ``log`` attribute is created.
    """

    allow_reuse_address = True
    request_queue_size = 10
    # Seconds handle_request() waits for a connection before it calls
    # handle_timeout().
    timeout = 10

    def __init__(self, socket_fd, RequestHandlerClass):
        """Wrap an existing UNIX stream socket descriptor as the server socket.

        Args:
            socket_fd: File descriptor of the UNIX domain stream socket to
                serve on.
            RequestHandlerClass: Handler class instantiated per connection.
        """
        # Avoid initialization of UnixStreamServer as we need to open the
        # socket from a file descriptor instead of creating our own.
        self.socket = socket.fromfd(socket_fd, socket.AF_UNIX, socket.SOCK_STREAM)
        self.socket.listen(self.request_queue_size)
        socketserver.BaseServer.__init__(
            self, self.socket.getsockname(), RequestHandlerClass
        )
        # Set by handle_timeout() when a wait for a request expires.
        self.timed_out = False

    def setup_logging(self):
        """Set up logging to the console and to a file in /private/var/log/.

        Creates ``self.log`` with DEBUG-level console and rotating-file
        handlers.

        Raises:
            AutoPkgInstallDaemonError: If the log file cannot be opened. The
                original ``OSError`` is attached as the cause.
        """
        try:
            self.log = logging.getLogger(APPNAME)
            self.log.setLevel(logging.DEBUG)

            log_console = logging.StreamHandler()
            log_console.setLevel(logging.DEBUG)
            # Append mode, ~100 kB per file, keep 9 rotated backups.
            log_file = logging.handlers.RotatingFileHandler(
                f"/private/var/log/{APPNAME}", "a", 100000, 9, "utf-8"
            )
            log_file.setLevel(logging.DEBUG)

            console_formatter = logging.Formatter("%(message)s")
            file_formatter = logging.Formatter(
                "%(asctime)s %(module)s[%(process)d]: %(message)s   (%(funcName)s)"
            )

            log_console.setFormatter(console_formatter)
            log_file.setFormatter(file_formatter)

            self.log.addHandler(log_console)
            self.log.addHandler(log_file)
        except OSError as err:
            # strerror is None for OSErrors built without an errno; fall
            # back to the exception text so the message stays meaningful.
            raise AutoPkgInstallDaemonError(
                f"Can't open log: {err.strerror or err}"
            ) from err

    def handle_timeout(self):
        """Record that waiting for a request timed out."""
        self.timed_out = True


def main(argv):
    """Run the install daemon until it has been idle past its timeout.

    Performs startup safety checks (running as root; the executable and
    every containing directory below "/" owned by root with group wheel or
    admin and not world-writable), obtains the listening socket from
    launchd, then serves requests until a wait for a request times out and
    the process has been up for at least 10 seconds.

    Args:
        argv: Command-line arguments. Currently unused; the executable
            path is taken from ``sys.argv[0]``.

    Returns:
        0 on normal shutdown, 1 if any startup check or setup step fails.
        Failure paths sleep for 10 seconds first so launchd does not
        respawn the daemon in a tight loop.
    """
    # Make sure we're launched as root
    if os.geteuid() != 0:
        print(f"{APPNAME} must be run as root.", file=sys.stderr)
        # Sleep to avoid respawn.
        time.sleep(10)
        return 1

    # Make sure that the executable and all containing directories are owned
    # by root:wheel or root:admin, and not writeable by other users.
    root_uid = 0
    wheel_gid = 0
    admin_gid = 80

    exepath = os.path.realpath(os.path.abspath(sys.argv[0]))
    path_ok = True
    # Walk from the executable up through its parents; "/" itself is not
    # checked. All problems are reported before giving up.
    while True:
        info = os.stat(exepath)
        if info.st_uid != root_uid:
            print(f"{exepath} must be owned by root.", file=sys.stderr)
            path_ok = False
        if info.st_gid not in (wheel_gid, admin_gid):
            print(f"{exepath} must have group wheel or admin.", file=sys.stderr)
            path_ok = False
        if info.st_mode & stat.S_IWOTH:
            print(f"{exepath} mustn't be world writeable.", file=sys.stderr)
            path_ok = False
        exepath = os.path.dirname(exepath)
        if exepath == "/":
            break

    if not path_ok:
        # Sleep to avoid respawn.
        time.sleep(10)
        return 1

    # Keep track of time for launchd.
    start_time = time.time()

    # Get socket file descriptors from launchd.
    import launch2

    try:
        sockets = launch2.launch_activate_socket("autopkginstalld")
    except launch2.LaunchDError as err:
        print(f"launchd check-in failed: {err}", file=sys.stderr)
        time.sleep(10)
        return 1

    if not sockets:
        # Nothing to listen on; treat it like any other check-in failure
        # instead of crashing on sockets[0].
        print("launchd check-in failed: no sockets returned", file=sys.stderr)
        time.sleep(10)
        return 1

    sock_fd = sockets[0]

    # Create the daemon object.
    daemon = AutoPkgInstallDaemon(sock_fd, RunHandler)
    try:
        daemon.setup_logging()
    except AutoPkgInstallDaemonError as err:
        print(f"{APPNAME} could not set up logging: {err}", file=sys.stderr)
        # Sleep to avoid respawn, as for the other startup failures.
        time.sleep(10)
        return 1
    daemon.log.info(f"{APPNAME} v{VERSION} starting")

    # Serve all pending requests until we time out.
    while True:
        # handle_timeout() only ever sets the flag, so clear it before each
        # wait; otherwise one timeout would make every later, successfully
        # served request look like a timeout as well.
        daemon.timed_out = False
        daemon.handle_request()
        if not daemon.timed_out:
            continue

        # Keep running for at least 10 seconds to make launchd happy.
        run_time = time.time() - start_time
        daemon.log.info("run time: %fs", run_time)
        if run_time >= 10.0:
            break

        # Only sleep for a short while in case new requests pop up.
        sleep_time = min(1.0, 10.0 - run_time)
        daemon.log.debug(
            "sleeping for %f seconds to make launchd happy", sleep_time
        )
        time.sleep(sleep_time)

    return 0


# When executed as a script, run the daemon and use main()'s return value
# as the process exit status.
if __name__ == "__main__":
    sys.exit(main(sys.argv))
