#!/Library/AutoPkg/Python3/Python.framework/Versions/Current/bin/python3
#
# Copyright 2010 Per Olofsson
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""LaunchDaemon that creates installer packages.

A packaging request is made by creating a property list file and sending it to
the daemon's listening socket.

Protocol:

    When a client connects to the server it should send an XML property list
    with a packaging request. The server will process the request, and if
    there are any problems it'll reply with one or more lines of:

        ERROR:<error message 1>
        ERROR:<error message 2>
        ERROR:<error message 3...>

    Otherwise it'll reply with OK and the path to the package:

        OK:/Users/example/Downloads/example.pkg

Request format:

    Requests should be an XML property list with a dictionary as the root
    object and the following keys:

    pkgroot         The virtual root of the files to be packaged, must be a
                    directory owned by the calling user.

    pkgdir          The directory where the pkg will be created, must be a
                    directory owned by the calling user.

    pkgname         The file name of the pkg, without pkg extension.

    pkgtype         The package type. "flat" is the only valid value.

    id              The package ID.

    version         The package version.

    infofile        A path to an Info.plist file for the package.

    resources       A path to a directory to be included as Resources in the
                    package.

    chown           An array of dictionaries with paths, relative to the
                    pkgroot, and the desired owner. Keys:

                    path    Path relative to pkgroot. Symlinks and .. are not
                            allowed in the path.
                    user    A string with a user name, or an int with a uid.
                    group   A string with a group name, or an int with a gid.
                    mode    A string with the mode in octal notation, e.g. '0755'.

    scripts         A string with the path to a scripts directory. This is
                    supported only for flat package types, and passed directly
                    to the '--scripts' option of pkgbuild.

"""

import logging
import logging.handlers
import os
import plistlib
import re
import socket
import socketserver
import stat
import struct
import sys
import time

from packager import Packager, PackagerError

###############
## Constants ##
###############


APPNAME = "autopkgserver"
VERSION = "0.2"

SOCKET = f"/var/run/{APPNAME}"

request_structure = {
    "pkgroot": str,
    "pkgdir": str,
    "pkgname": str,
    "pkgtype": str,
    "id": str,
    "version": str,
    "infofile": str,
    "resources": str,
    "chown": list,
    "scripts": str,
}
chown_structure = {"path": str, "user": (str, int), "group": (str, int), "mode": str}


#################
## Global Init ##
#################


class AutoPkgServerError(Exception):
    """Error raised by the autopkgserver daemon itself."""


class PkgHandler(socketserver.StreamRequestHandler):
    """Handler for packaging requests.

    For each connection the handler looks up the peer's credentials on the
    socket, reads one property-list request, validates its structure and hands
    it to ``Packager``.  On success the reply is ``OK:<pkgpath>``; otherwise one
    or more error lines are sent back.
    """

    re_uid_gid = re.compile(r"^AUTH:AUTHENTICATE (?P<uid>\d{1,10}):(?P<gid>\d{1,10})$")

    def verify_request_syntax(self, plist):
        """Verify the basic syntax of request plist.

        Args:
            plist: The parsed request; expected to be a dictionary.

        Returns:
            A ``(syntax_ok, errors)`` tuple: ``syntax_ok`` is a bool and
            ``errors`` is a list of human-readable messages.
        """

        # Keep a list of error messages.
        errors = []

        # Root should be a dictionary; nothing else can be checked if it isn't.
        if not isinstance(plist, dict):
            errors.append("Request root is not a dictionary")
            return (False, errors)

        syntax_ok = True

        # Verify existence and type of keys in the root.
        for key, keytype in request_structure.items():
            if key not in plist:
                errors.append(f"Request is missing key '{key}'")
                syntax_ok = False
            elif not isinstance(plist[key], keytype):
                errors.append(f"Request key {key} is not of type {str(keytype)}")
                syntax_ok = False

        if syntax_ok:
            # Check package type.  This must be an equality test: a substring
            # test against "flat" would also accept "", "f", "la", ...
            if plist["pkgtype"] != "flat":
                errors.append("pkgtype must be flat")
                syntax_ok = False

            # Make sure all chown entries are dictionaries, and children are
            # the correct type.
            for chown_entry in plist["chown"]:
                if not isinstance(chown_entry, dict):
                    errors.append("chown entry isn't dictionary")
                    syntax_ok = False
                    # The key checks below only make sense for a dictionary.
                    continue
                for key, keytype in chown_structure.items():
                    if key not in chown_entry:
                        # NOTE(review): a missing key is recorded but does not
                        # fail validation (unchanged behavior); confirm whether
                        # some chown keys are meant to be optional.
                        errors.append(f"chown entry is missing {key}")
                    elif not isinstance(chown_entry[key], keytype):
                        errors.append(
                            f"Request key chown.{key} is not of type {str(keytype)}"
                        )
                        syntax_ok = False

        return (syntax_ok, errors)

    def getpeerid(self):
        """
        Get peer credentials on a UNIX domain socket.

        Returns:
            A ``(uid, gids)`` tuple where ``gids`` is a tuple holding the first
            ``cr_ngroups`` entries of the peer's group list.

        Raises:
            OSError: If the returned structure has an unexpected version.
        """

        # /usr/include/sys/ucred.h
        #
        # struct xucred {
        #         u_int   cr_version;            /* structure layout version */
        #         uid_t   cr_uid;                /* effective user id */
        #         short   cr_ngroups;            /* number of advisory groups */
        #         gid_t   cr_groups[NGROUPS];    /* advisory group list */
        # };

        LOCAL_PEERCRED = 0x001
        XUCRED_VERSION = 0
        NGROUPS = 16
        # Indexes into the unpacked tuple.
        cr_version = 0
        cr_uid = 1
        cr_ngroups = 2
        cr_groups = 3

        xucred_fmt = f"IIh{NGROUPS}I"
        res = struct.unpack(
            xucred_fmt,
            self.request.getsockopt(0, LOCAL_PEERCRED, struct.calcsize(xucred_fmt)),
        )

        if res[cr_version] != XUCRED_VERSION:
            raise OSError("Incompatible struct xucred version")

        return res[cr_uid], res[cr_groups : cr_groups + res[cr_ngroups]]

    def _reply(self, message):
        """Encode ``message`` and send all of it to the client."""
        # sendall() keeps writing until the whole reply is out, whereas send()
        # may transmit only part of the buffer.
        self.request.sendall(message.encode())

    def handle(self):
        """Handle an incoming packaging request."""

        # Log through server parent.
        self.log = self.server.log

        try:
            self.log.debug("Handling request")

            # Get uid and primary gid of connecting peer.
            uid, gids = self.getpeerid()
            gid = gids[0]
            self.log.debug("Got packaging request from uid %d gid %d", uid, gid)

            # Receive a plist.  A single read of at most 8192 bytes is treated
            # as the complete request.
            plist_string = self.request.recv(8192)

            # Try to parse it.
            try:
                plist = plistlib.loads(plist_string)
            except Exception:
                self.log.error("Malformed request")
                self._reply("ERROR:Malformed request\n")
                return
            self.log.debug("Parsed request plist")

            # Verify the plist syntax.
            syntax_ok, errors = self.verify_request_syntax(plist)
            if not syntax_ok:
                self.log.error("Plist syntax error")
                self._reply("".join(f"ERROR:{e}\n" for e in errors))
                return

            if not os.path.exists(plist["pkgroot"]):
                self.log.error("Can't find pkgroot")
                self._reply("ERROR:Can't find pkgroot\n")
                return
            name = os.path.basename(plist["pkgroot"])

            self.log.info(f"Dispatching worker to process request for user {uid}")
            try:
                pkgr = Packager(self.log, plist, name, uid, gid)
                pkgpath = pkgr.package()
                self.log.info(f"Package built at {pkgpath}")
                self.log.debug("Sending request")
                self._reply(f"OK:{pkgpath}\n")
            except PackagerError as err:
                self.log.error(f"Packaging failed: {err}")
                # The PackagerError text is forwarded to the client as-is.
                self._reply(f"{err}\n")

        except Exception as err:
            self.log.error(f"Caught exception: {err}")
            try:
                self._reply(f"ERROR:Caught exception: {err}\n")
            except OSError as send_err:
                # The connection itself may be what failed; don't raise again.
                self.log.error(f"Couldn't send error reply: {send_err}")
            return


class AutoPkgServer(socketserver.UnixStreamServer):
    """UNIX stream server for the root daemon that builds installer packages
    on request."""

    allow_reuse_address = True
    request_queue_size = 10
    timeout = 10

    def __init__(self, socket_fd, RequestHandlerClass):
        """Build the server around an already-open socket file descriptor.

        Args:
            socket_fd: File descriptor of the UNIX stream socket to serve on.
            RequestHandlerClass: Handler class instantiated per request.
        """
        # UnixStreamServer.__init__ is bypassed on purpose: the socket comes
        # from a file descriptor rather than being created here, so only the
        # BaseServer part of the initialization is run.
        listener = socket.fromfd(socket_fd, socket.AF_UNIX, socket.SOCK_STREAM)
        self.socket = listener
        listener.listen(self.request_queue_size)
        socketserver.BaseServer.__init__(
            self, listener.getsockname(), RequestHandlerClass
        )
        # Flipped to True by handle_timeout().
        self.timed_out = False

    def setup_logging(self):
        """Create ``self.log`` with a console handler and a rotating log file.

        Raises:
            AutoPkgServerError: If an OSError occurs while setting up logging.
        """
        try:
            logger = logging.getLogger(APPNAME)
            self.log = logger
            logger.setLevel(logging.DEBUG)

            console_handler = logging.StreamHandler()
            console_handler.setLevel(logging.DEBUG)
            file_handler = logging.handlers.RotatingFileHandler(
                filename=f"/private/var/log/{APPNAME}",
                mode="a",
                maxBytes=100000,
                backupCount=9,
                encoding="utf-8",
            )
            file_handler.setLevel(logging.DEBUG)

            # Console output is the bare message; the file adds a timestamp,
            # module, pid and the calling function's name.
            console_handler.setFormatter(logging.Formatter("%(message)s"))
            file_handler.setFormatter(
                logging.Formatter(
                    "%(asctime)s %(module)s[%(process)d]: %(message)s   (%(funcName)s)"
                )
            )

            for handler in (console_handler, file_handler):
                logger.addHandler(handler)
        except OSError as err:
            raise AutoPkgServerError(f"Can't open log: {err.strerror}")

    def handle_timeout(self):
        """Record that a timeout occurred by setting ``timed_out``."""
        self.timed_out = True


def _path_chain_is_secure(exepath):
    """Check ownership and permissions of ``exepath`` and its parent directories.

    Every path from ``exepath`` up to, but not including, "/" must be owned by
    root, have group wheel or admin, and not be world writeable.  A message is
    printed to stderr for each violation.

    Returns:
        True if no violation was found, False otherwise.
    """
    root_uid = 0
    wheel_gid = 0
    admin_gid = 80

    path_ok = True
    while True:
        info = os.stat(exepath)
        if info.st_uid != root_uid:
            print(f"{exepath} must be owned by root.", file=sys.stderr)
            path_ok = False
        if info.st_gid not in (wheel_gid, admin_gid):
            print(f"{exepath} must have group wheel or admin.", file=sys.stderr)
            path_ok = False
        if info.st_mode & stat.S_IWOTH:
            print(f"{exepath} mustn't be world writeable.", file=sys.stderr)
            path_ok = False
        exepath = os.path.dirname(exepath)
        if exepath == "/":
            break
    return path_ok


def main(argv):
    """Run the packaging daemon.

    Args:
        argv: Command-line arguments; ``argv[0]`` is taken as the path of the
            executable whose ownership is verified.  If ``argv`` is empty,
            ``sys.argv[0]`` is used instead.

    Returns:
        Process exit status: 0 after a normal run, 1 if a start-up check fails.
    """
    # Delay before returning on a start-up failure, to avoid a quick respawn.
    respawn_delay = 10
    # Minimum total run time, in seconds, before the daemon exits.
    min_run_time = 10.0

    # Make sure we're launched as root
    if os.geteuid() != 0:
        print(f"{APPNAME} must be run as root.", file=sys.stderr)
        time.sleep(respawn_delay)
        return 1

    # Make sure that the executable and all containing directories are owned
    # by root:wheel or root:admin, and not world writeable.
    exe_arg = argv[0] if argv else sys.argv[0]
    exepath = os.path.realpath(os.path.abspath(exe_arg))
    if not _path_chain_is_secure(exepath):
        time.sleep(respawn_delay)
        return 1

    # Keep track of time for launchd.
    start_time = time.time()

    # Get socket file descriptors from launchd.
    import launch2

    try:
        sockets = launch2.launch_activate_socket("AutoPkgServer")
    except launch2.LaunchDError as err:
        print(f"launchd check-in failed: {err}", file=sys.stderr)
        time.sleep(respawn_delay)
        return 1

    if not sockets:
        print("launchd check-in failed: no sockets returned", file=sys.stderr)
        time.sleep(respawn_delay)
        return 1
    sock_fd = sockets[0]

    # Create the server object.
    server = AutoPkgServer(sock_fd, PkgHandler)

    # Close the server's socket on every exit path from here on.
    try:
        try:
            server.setup_logging()
        except AutoPkgServerError as err:
            # Report the failure like the other start-up checks instead of
            # dying with a traceback and no respawn delay.
            print(err, file=sys.stderr)
            time.sleep(respawn_delay)
            return 1

        server.log.info(f"{APPNAME} v{VERSION} starting")

        # Serve all pending requests until we time out.
        while True:
            server.handle_request()
            if not server.timed_out:
                continue

            # Keep running for at least min_run_time seconds to make launchd
            # happy.
            run_time = time.time() - start_time
            server.log.info("run time: %fs", run_time)
            if run_time >= min_run_time:
                break

            # Only sleep for a short while in case new requests pop up.
            sleep_time = min(1.0, min_run_time - run_time)
            server.log.debug(
                "sleeping for %f seconds to make launchd happy", sleep_time
            )
            time.sleep(sleep_time)
    finally:
        server.server_close()

    return 0


# When executed as a script, run the daemon and use main()'s return value as
# the process exit status.
if __name__ == "__main__":
    sys.exit(main(sys.argv))
