#!/usr/local/munki/munki-python
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from decimal import Decimal
import hashlib
import json
import os
import plistlib
import ssl
import subprocess
import sys
import time
import urllib.request
import warnings
import zlib

# Munki data directory and the files this script reads from it
MANAGED_INSTALLS_DIR = "/Library/Managed Installs"
# Directory listed by iter_manage_install_report_paths() for the older run reports
ARCHIVES_DIR = os.path.join(MANAGED_INSTALLS_DIR, "Archives")
# Plist read by get_osx_app_instances() to build the application list
APPLICATION_INVENTORY = os.path.join(MANAGED_INSTALLS_DIR, "ApplicationInventory.plist")

# Sent as the User-Agent header with every API request
USER_AGENT = "Zentral/munkipostflight 0.17"
ZENTRAL_API_ENDPOINT = "https://%TLS_HOSTNAME%/public/munki/"  # set during the package build
# CA file for the TLS context; when empty, make_api_request() uses /private/etc/ssl/cert.pem
ZENTRAL_API_SERVER_CERTIFICATE = "%TLS_SERVER_CERTS%"  # set during the package build
# Sent in the Authorization header ("MunkiEnrolledMachine <token>")
ZENTRAL_API_AUTH_TOKEN = "%TOKEN%"  # set during the enrollment in the postinstall script of the enrollment package

SYSTEM_PROFILER = "/usr/sbin/system_profiler"

SCRIPT_CHECK_TIMEOUT = 3  # timeout in seconds for each script check
# Status codes returned by run_script_check()
SCRIPT_CHECK_STATUS_OK = 0  # result equals the expected result
SCRIPT_CHECK_STATUS_UNKNOWN = 200  # result could not be converted to the expected type
SCRIPT_CHECK_STATUS_FAILED = 300  # result differs from the expected result


# OSX apps


def get_bundle_info(path):
    """Return the parsed Info.plist of the bundle at `path`, or None.

    The plist is looked up in the "Contents" subdirectory first, then in
    "Resources". The first file that can be opened and parsed wins; any
    error on a candidate simply moves on to the next one.
    """
    candidates = (os.path.join(path, subdir, "Info.plist")
                  for subdir in ("Contents", "Resources"))
    for candidate in candidates:
        try:
            with open(candidate, "rb") as plist_file:
                bundle_info = plistlib.load(plist_file)
        except Exception:
            continue
        return bundle_info
    return None


def get_bundle_info_version(bundle_info):
    """Extract a normalized version from a bundle's CFBundleVersion.

    Only the first whitespace-separated token is kept, and only if it starts
    with an ASCII digit; commas are replaced with dots.

    Returns the normalized version string, or None when the value is missing,
    not a string, blank, or does not start with a digit.
    """
    version = bundle_info.get("CFBundleVersion")
    if not isinstance(version, str):
        return None
    tokens = version.split()
    if not tokens:
        # empty or whitespace-only value: nothing to index into
        return None
    version = tokens[0]
    if version[0] in '0123456789':
        return version.replace(",", ".")
    return None


def get_osx_app_instances(full_info=False):
    """Build the list of application instances from the application inventory plist.

    Each item is a dict with an 'app' dict (bundle id, name, version string)
    and the 'bundle_path'. When `full_info` is true and the inventory entry
    has a path, the bundle's Info.plist is also read to add 'bundle_version'
    and 'bundle_display_name' when available.

    Returns an empty list if the inventory cannot be read.
    """
    try:
        with open(APPLICATION_INVENTORY, "rb") as inventory_file:
            inventory = plistlib.load(inventory_file)
    except IOError:
        print("Could not read application inventory plist")
        return []

    instances = []
    for entry in inventory:
        app = {'bundle_id': entry['bundleid'],
               'bundle_name': entry['CFBundleName'],
               'bundle_version_str': entry['version']}
        bundle_path = entry['path']
        bundle_info = get_bundle_info(bundle_path) if full_info and bundle_path else None
        if bundle_info:
            bundle_version = get_bundle_info_version(bundle_info)
            if bundle_version:
                app["bundle_version"] = bundle_version
            display_name = bundle_info.get("CFBundleDisplayName")
            if display_name:
                app["bundle_display_name"] = display_name
        instances.append({'app': app, 'bundle_path': bundle_path})
    return instances


# Profiles


def iter_profiles():
    """Yield one dict per computer-level profile reported by /usr/bin/profiles.

    Each dict carries the profile uuid, the removal_disallowed and verified
    flags, the optional descriptive attributes present in the output, a naive
    install_date when it can be parsed, and the list of payloads.

    Yields nothing (after printing a message) if the command fails.
    """
    profile_optional_attrs = (("identifier", "ProfileIdentifier"),
                              ("display_name", "ProfileDisplayName"),
                              ("description", "ProfileDescription"),
                              ("organization", "ProfileOrganization"))
    payload_optional_attrs = (("identifier", "PayloadIdentifier"),
                              ("display_name", "PayloadDisplayName"),
                              ("description", "PayloadDescription"),
                              ("type", "PayloadType"))

    completed = subprocess.run(["/usr/bin/profiles", "-C", "-o", "stdout-xml"], capture_output=True)
    if completed.returncode != 0:
        print("Could not get the profiles")
        return

    for item in plistlib.loads(completed.stdout).get("_computerlevel", []):
        profile = {
            "uuid": item["ProfileUUID"],
            "removal_disallowed": item.get("ProfileRemovalDisallowed") == "true",
            "verified": item.get("ProfileVerificationState") == "verified",
            "payloads": []
        }
        try:
            # drop the trailing timezone token: unaware datetime for the inventory
            date_without_tz = item["ProfileInstallDate"].rsplit(" ", 1)[0]
            profile["install_date"] = datetime.strptime(date_without_tz, "%Y-%m-%d %H:%M:%S")
        except (KeyError, ValueError):
            print("Could not parse profile install date")
        profile.update((attr, item[key])
                       for attr, key in profile_optional_attrs
                       if key in item)
        for payload_item in item.get("ProfileItems", []):
            payload = {"uuid": payload_item["PayloadUUID"]}
            payload.update((attr, payload_item[key])
                           for attr, key in payload_optional_attrs
                           if key in payload_item)
            profile["payloads"].append(payload)
        yield profile


# Munki run reports


class ManagedInstallReport(object):
    """A Munki run report loaded from a plist file.

    Attributes set by the constructor:
      basename      -- file name of the report
      sha1sum       -- hex SHA-1 digest of the file content
      data          -- the parsed plist
      start_time    -- 'StartTime' (strings are parsed as "%Y-%m-%d %H:%M:%S %z")
      end_time      -- 'EndTime', same parsing, or None when absent
      munki_version -- MachineInfo/munki_version, or None when absent
    """

    def __init__(self, filename):
        self.basename = os.path.basename(filename)
        self.sha1sum = self._get_sha1_sum(filename)
        with open(filename, "rb") as f:
            self.data = plistlib.load(f)
        self.start_time = self.data['StartTime']
        if isinstance(self.start_time, str):
            self.start_time = datetime.strptime(self.start_time, "%Y-%m-%d %H:%M:%S %z")
        self.end_time = self.data.get('EndTime')
        if isinstance(self.end_time, str):
            self.end_time = datetime.strptime(self.end_time, "%Y-%m-%d %H:%M:%S %z")
        try:
            self.munki_version = self.data['MachineInfo']['munki_version']
        except KeyError:
            self.munki_version = None

    @staticmethod
    def _get_sha1_sum(filename):
        """Return the hex SHA-1 digest of the file, read in fixed-size chunks."""
        sha1 = hashlib.sha1()
        with open(filename, 'rb') as f:
            for chunk in iter(lambda: f.read(65536), b""):
                sha1.update(chunk)
        return sha1.hexdigest()

    @staticmethod
    def _ensure_aware(dt):
        """Return `dt` unchanged if it has a UTC offset, else tagged as UTC."""
        if dt.utcoffset() is None:
            dt = dt.replace(tzinfo=timezone.utc)
        return dt

    def _events(self):
        """Return the (time, event dict) tuples of the run, sorted by time.

        Install and removal results become events carrying every key of the
        result except 'time'. Each distinct error and warning becomes an event
        dated at the end of the run (or at the start if there is no end time).
        The report data is left untouched, so this method can be called
        repeatedly and alongside managed_installs().
        """
        events = [(self.start_time, {'type': 'start'})]
        for event_type, results_key in (('install', 'InstallResults'),
                                        ('removal', 'RemovalResults')):
            for result in self.data.get(results_key, []):
                event_time = self._ensure_aware(result['time'])
                # copy without 'time' instead of popping it from self.data
                details = {k: v for k, v in result.items() if k != 'time'}
                details['type'] = event_type
                events.append((event_time, details))
        # dict.fromkeys de-duplicates while keeping the order of first occurrence
        for err in dict.fromkeys(self.data.get('Errors', [])):
            events.append((self.end_time or self.start_time, {'type': 'error', 'message': err}))
        for warn in dict.fromkeys(self.data.get('Warnings', [])):
            events.append((self.end_time or self.start_time, {'type': 'warning', 'message': warn}))
        events.sort(key=lambda t: t[0])
        return events

    def managed_installs(self):
        """Return {name: (version, display_name, install_time)} after this run.

        install_time is None unless the item was successfully installed
        during this run.
        """
        managed_installs = {}
        # get the installed items
        for managed_install in self.data.get('ManagedInstalls', []):
            if managed_install.get("installed"):
                managed_installs[managed_install["name"]] = (
                    managed_install["installed_version"],
                    managed_install["display_name"],
                    None
                )
        # add the items to remove (because they are not removed yet)
        for item_to_remove in self.data.get('ItemsToRemove', []):
            if item_to_remove.get("installed"):
                managed_installs[item_to_remove["name"]] = (
                    item_to_remove["installed_version"],
                    item_to_remove["display_name"],
                    None
                )
        # update with the successful removals
        for removal_result in self.data.get('RemovalResults', []):
            if removal_result["status"] == 0:
                managed_installs.pop(removal_result["name"], None)
        # update with the successful installs
        for install_result in self.data.get('InstallResults', []):
            if install_result["status"] == 0:
                managed_installs[install_result["name"]] = (
                    install_result["version"],
                    install_result["display_name"],
                    install_result["time"]
                )
        return managed_installs

    def updated_managed_installs(self, from_managed_installs):
        """Return the managed installs of this run, merged with a previous state.

        `from_managed_installs` has the same shape as the managed_installs()
        result. It is returned as is when this report has errors and no
        managed installs.
        """
        managed_installs = self.managed_installs()
        # skip error reports
        if not managed_installs and self.data.get('Errors'):
            return from_managed_installs
        # carry over the install time if possible
        for n, (v, dn, t) in managed_installs.items():
            if t is None:
                try:
                    from_v, _, from_t = from_managed_installs[n]
                except KeyError:
                    pass
                else:
                    if from_v == v and from_t is not None:
                        managed_installs[n] = (v, dn, from_t)
        # carry over the already present pkg info for failed installs
        for install_result in self.data.get('InstallResults', []):
            if install_result["status"] != 0:
                n = install_result["name"]
                try:
                    managed_installs[n] = from_managed_installs[n]
                except KeyError:
                    pass
        return managed_installs

    def get_conditions(self):
        """Return the 'Conditions' dict of the report, or {} if it is not a dict."""
        conditions = self.data.get('Conditions')
        if not isinstance(conditions, dict):
            conditions = {}
        return conditions

    def get_extra_facts(self, keys):
        """Return the conditions whose key is in `keys` and whose value is not None."""
        conditions = self.get_conditions()
        extra_facts = {}
        for key in keys:
            val = conditions.get(key)
            if val is not None:
                extra_facts[key] = val
        return extra_facts

    def serialize(self):
        """Return the dict representation of the report sent to the server."""
        d = {'basename': self.basename,
             'conditions': self.get_conditions(),
             'sha1sum': self.sha1sum,
             'run_type': self.data['RunType'],
             'start_time': self.start_time,
             'end_time': self.end_time,
             'events': self._events()}
        if self.munki_version:
            d['munki_version'] = self.munki_version
        return d


def iter_manage_install_report_paths(reverse=True):
    """Yield the paths of the current run report and of the archived ones.

    With `reverse` true, the current report comes first, followed by the
    archives in descending file name order. With `reverse` false, the archives
    come first in ascending order and the current report last.
    """
    current_report = os.path.join(MANAGED_INSTALLS_DIR, 'ManagedInstallReport.plist')

    def current_report_paths():
        # existence is checked only when the report's turn comes
        if os.path.exists(current_report):
            yield current_report

    if reverse:
        yield from current_report_paths()
    if os.path.isdir(ARCHIVES_DIR):
        archive_names = os.listdir(ARCHIVES_DIR)
        archive_names.sort(reverse=reverse)
        for archive_name in archive_names:
            yield os.path.join(ARCHIVES_DIR, archive_name)
    if not reverse:
        yield from current_report_paths()


def iter_manage_install_reports(reverse=True):
    """Yield a ManagedInstallReport for each report path, skipping those without an end time."""
    reports = (ManagedInstallReport(report_path)
               for report_path in iter_manage_install_report_paths(reverse))
    yield from (report for report in reports if report.end_time)


def build_reports_payload(last_seen=None):
    """Serialize the finished run reports that have not been sent yet.

    Reports are walked from the most recent one. Without a `last_seen` sha1
    sum, only the first report is serialized. Otherwise, reports are collected
    until the one whose sha1 sum equals `last_seen` is reached.

    Returns (list of serialized reports, whether the last seen report was found).
    """
    first_time = last_seen is None
    serialized_reports = []
    for report in iter_manage_install_reports():
        if not first_time and report.sha1sum == last_seen:
            return serialized_reports, True
        serialized_reports.append(report.serialize())
        if first_time:
            # stop after one report the first time
            break
    return serialized_reports, False


def build_managed_installs_payload():
    """Replay all finished reports, oldest first, and return the resulting
    managed installs as a sorted list of (name, version, display name, install time)."""
    state = {}
    for report in iter_manage_install_reports(reverse=False):
        state = report.updated_managed_installs(state)
    rows = [(name, version, display_name, install_time)
            for name, (version, display_name, install_time) in state.items()]
    rows.sort()
    return rows


# Script checks


def get_current_user_info():
    """Return (user name, user id) of the console user, as reported by scutil.

    The "show State:/Users/ConsoleUser" command is fed to scutil on its
    standard input and the "Name : " / "UID : " lines of the output are
    parsed.

    Returns (None, None) if scutil cannot be run or if the console user is
    the login window. Either element is None when the matching line is
    missing (or, for the id, not an integer).
    """
    try:
        # no shell needed: scutil reads its commands from stdin
        info = subprocess.check_output(
            ["/usr/sbin/scutil"],
            input="show State:/Users/ConsoleUser\n",
            encoding="utf-8",
        )
    except Exception:
        return None, None
    current_user = curr_user_id = None
    for line in info.splitlines():
        line = line.strip()
        if line.startswith("Name : "):
            found_user = line[7:]
            if "loginwindow" in found_user.lower():
                # nobody is logged in: discard anything parsed so far
                return None, None
            current_user = found_user
        elif line.startswith("UID : "):
            try:
                curr_user_id = int(line[6:])
            except ValueError:
                # unexpected value: leave the user id unknown
                curr_user_id = None
    return current_user, curr_user_id


def get_script_check_env():
    """Return a copy of the process environment for the script checks,
    with CURRENT_USER and CURR_USER_ID added when the console user is known."""
    env = os.environ.copy()
    user_name, user_id = get_current_user_info()
    additions = {
        "CURRENT_USER": user_name,
        "CURR_USER_ID": None if user_id is None else str(user_id),
    }
    env.update({name: value for name, value in additions.items() if value is not None})
    return env


def convert_script_check_int_result(result):
    """Convert the output of an integer script check with int(); conversion errors propagate."""
    converted = int(result)
    return converted


def convert_script_check_bool_result(result):
    """Convert the output of a boolean script check to a bool.

    Accepted values, case-insensitively: "f", "false", "t", "true", or an
    integer literal equal to 0 or 1.

    Raises ValueError for anything else. An explicit exception is used
    rather than an assert, so that the validation is kept when Python runs
    with -O.
    """
    aliases = {"f": "0", "false": "0", "t": "1", "true": "1"}
    normalized = result.lower()
    value = int(aliases.get(normalized, normalized))
    if value not in (0, 1):
        raise ValueError("Unexpected boolean script check result: {!r}".format(result))
    return bool(value)


def run_script_check(script_check, env):
    """Run one script check with /bin/zsh and compare its output to the expected result.

    `script_check` is a dict with the "source" fed to zsh on stdin, the
    "type" of the result ("ZSH_INT" and "ZSH_BOOL" outputs are converted,
    other types are compared as stripped text) and the "expected_result".
    `env` is the environment of the zsh process.

    Returns (status, elapsed seconds). The status is SCRIPT_CHECK_STATUS_OK
    when the result equals the expected one, SCRIPT_CHECK_STATUS_FAILED when
    it differs, and SCRIPT_CHECK_STATUS_UNKNOWN when the output cannot be
    converted or the script exceeds SCRIPT_CHECK_TIMEOUT.
    """
    start_time = time.monotonic()
    try:
        cp = subprocess.run(
            ["/bin/zsh"],
            input=script_check["source"],
            capture_output=True,
            encoding="utf-8",
            timeout=SCRIPT_CHECK_TIMEOUT,
            env=env,
        )
    except subprocess.TimeoutExpired:
        # one slow script must not abort the other checks and the whole job
        return SCRIPT_CHECK_STATUS_UNKNOWN, time.monotonic() - start_time
    time_elapsed = time.monotonic() - start_time
    expected_result = script_check["expected_result"]
    result = cp.stdout.strip()
    if script_check["type"] == "ZSH_INT":
        try:
            result = convert_script_check_int_result(result)
        except Exception:
            return SCRIPT_CHECK_STATUS_UNKNOWN, time_elapsed
    elif script_check["type"] == "ZSH_BOOL":
        try:
            result = convert_script_check_bool_result(result)
        except Exception:
            return SCRIPT_CHECK_STATUS_UNKNOWN, time_elapsed
    if result == expected_result:
        return SCRIPT_CHECK_STATUS_OK, time_elapsed
    else:
        return SCRIPT_CHECK_STATUS_FAILED, time_elapsed


def run_script_checks(script_checks):
    """Run every script check with a shared environment and return the list
    of result dicts (pk, version, status, time)."""
    env = get_script_check_env()
    results = []
    for check in script_checks:
        status, elapsed = run_script_check(check, env)
        result = {"pk": check["pk"], "version": check["version"]}
        result["status"] = status
        result["time"] = elapsed
        results.append(result)
    return results


# Machine information


def build_os_version_dict(os_version_str):
    """Map the dot-separated components of a version string to the 'major',
    'minor' and 'patch' keys, as integers. Components after the third are
    ignored; missing ones are left out."""
    # map() is lazy: components beyond the third are never converted
    numbers = map(int, os_version_str.split('.'))
    return {label: number for label, number in zip(('major', 'minor', 'patch'), numbers)}


def get_sw_ver_os_version():
    """Build the OS version dict from the output of /usr/bin/sw_vers.

    ProductName is stored as 'name', ProductVersion is split into 'major',
    'minor' and 'patch', BuildVersion is stored as 'build' and
    ProductVersionExtra as 'version'. Lines without a colon are ignored, and
    only the first colon of a line separates the key from the value.

    Returns an empty dict if sw_vers cannot be run.
    """
    os_version = {}
    try:
        result = subprocess.check_output(["/usr/bin/sw_vers"], encoding="utf-8")
    except Exception:
        # should never happen
        print("Could not run sw_vers")
        return os_version
    for line in result.splitlines():
        key, separator, val = line.partition(":")
        if not separator:
            # not a "key: value" line
            continue
        key = key.strip()
        val = val.strip()
        if key == "ProductName":
            os_version["name"] = val
        elif key == "ProductVersion":
            os_version.update(build_os_version_dict(val))
        elif key == "BuildVersion":
            os_version["build"] = val
        elif key == "ProductVersionExtra":
            os_version["version"] = val
    return os_version


class SystemProfilerReport(object):
    """Hardware, software and storage data from system_profiler, parsed from its XML output."""

    def __init__(self):
        command = [SYSTEM_PROFILER, '-xml',
                   'SPHardwareDataType',
                   'SPSoftwareDataType',
                   'SPStorageDataType']
        process = subprocess.Popen(command, stdout=subprocess.PIPE)
        output = process.communicate()[0]
        self.data = plistlib.loads(output)

    def _get_data_type(self, data_type):
        """Return the first section whose '_dataType' is `data_type`, or None."""
        return next((section for section in self.data
                     if section['_dataType'] == data_type), None)

    def get_machine_snapshot(self):
        """Return a dict with the serial number, the system info (model, CPU,
        RAM, computer name), the OS version and the system uptime in seconds."""
        # Hardware
        hardware = self._get_data_type('SPHardwareDataType')
        if len(hardware['_items']) != 1:
            raise ValueError('0 or more than one item in a SPHardwareDataType output!')
        hardware_item = hardware['_items'][0]

        serial_number = hardware_item['serial_number']
        system_info = {'hardware_model': hardware_item['machine_model'],
                       'cpu_type': hardware_item.get('cpu_type', None),
                       'cpu_brand': hardware_item.get('chip_type', None)}

        # RAM: "<amount> <unit>", converted to bytes when the unit is known
        ram_amount, ram_amount_unit = hardware_item['physical_memory'].split()
        ram_multiplier = {'GB': 2**30, 'MB': 2**20}.get(ram_amount_unit)
        if ram_multiplier is None:
            warnings.warn('Unknown ram amount unit {}'.format(ram_amount_unit))
        else:
            normalized_amount = Decimal(ram_amount.replace(",", "."))
            system_info['physical_memory'] = int(normalized_amount * ram_multiplier)

        # Software
        software = self._get_data_type('SPSoftwareDataType')
        if len(software['_items']) != 1:
            raise ValueError('0 or more than one item in a SPSoftwareDataType output!')
        software_item = software['_items'][0]

        # preferable name to track since end users screw with other labels
        # oddly, falls back to hostname's bonjour-alike output if scutil would return Not Set
        system_info['computer_name'] = os.uname().nodename

        # uptime, for example "up 7:21:19:44": the last token is read from
        # right to left as seconds, minutes, hours, days
        uptime_fields = software_item['uptime'].rsplit(" ", 1)[-1].split(":")
        uptime_units = ("seconds", "minutes", "hours", "days")
        td_kwargs = dict(zip(uptime_units, map(int, reversed(uptime_fields))))
        uptime = int(timedelta(**td_kwargs).total_seconds())

        # OS version: sw_vers first, system_profiler string as fallback
        os_version = get_sw_ver_os_version()
        if not os_version:
            os_name, os_version_str, os_build = software_item['os_version'].rsplit(' ', 2)
            os_version = {'name': os_name,
                          'build': os_build.strip('()')}
            os_version.update(build_os_version_dict(os_version_str))

        return {'serial_number': serial_number,
                'system_info': system_info,
                'os_version': os_version,
                'system_uptime': uptime}


# Microsoft certificates and UUIDs


def parse_dn(dn):
    """Parse a slash-separated distinguished name into {attribute: [values]}.

    "/" ends the current attribute/value pair and "=" switches from the
    attribute to the value. A backslash escapes the next character, which is
    then taken literally. Pairs with an empty attribute name are dropped.

    Returns a defaultdict(list).
    """
    # TODO: poor man's DN parser
    parsed = defaultdict(list)
    attr_chars = []
    val_chars = []
    in_value = False
    escaped = False
    for char in dn:
        if not escaped:
            if char == "\\":
                escaped = True
                continue
            if char == "=":
                in_value = True
                continue
            if char == "/":
                in_value = False
                if attr_chars:
                    parsed["".join(attr_chars)].append("".join(val_chars))
                attr_chars = []
                val_chars = []
                continue
        # regular or escaped character
        (val_chars if in_value else attr_chars).append(char)
        escaped = False

    if attr_chars:
        parsed["".join(attr_chars)].append("".join(val_chars))
    return parsed


def read_cert_info(cert):
    """Return the issuer and subject of a PEM certificate, as printed by openssl.

    Each output line is split on the first "= ": the left part is the key of
    the returned dict and the right part is parsed with parse_dn().
    """
    openssl = subprocess.Popen(["/usr/bin/openssl", "x509", "-noout", "-issuer", "-subject"],
                               stdin=subprocess.PIPE, stdout=subprocess.PIPE)
    output = openssl.communicate(cert.encode("utf-8"))[0].decode("utf-8")
    info = {}
    for raw_line in output.splitlines():
        attr, dn = raw_line.strip().split("= ", 1)
        info[attr] = parse_dn(dn.strip())
    return info


def iter_certs():
    """Yield (keychain, PEM certificate, SHA-1 hash) for the certificates
    listed by `security find-certificate`.

    The command is run twice: once to map each SHA-1 hash to its keychain,
    and once with -p to read the PEM blocks. The keychain is None for a hash
    that was not seen during the first run.
    """
    # first run, for SHA1 and Keychain
    found_certs = {}
    p = subprocess.Popen(["/usr/bin/security", "find-certificate", "-a", "-Z"], stdout=subprocess.PIPE)
    stdout, _ = p.communicate()
    current_sha1 = None
    for line in stdout.decode("utf-8").splitlines():
        line = line.strip()
        if not line:
            continue
        if line.startswith("SHA-1 hash:"):
            current_sha1 = line.replace("SHA-1 hash:", "").strip()
        elif line.startswith("keychain:"):
            # the keychain line is attached to the most recent hash line
            found_certs[current_sha1] = line.replace("keychain:", "").strip('" ')
    # second run, for the PEM values
    p = subprocess.Popen(["/usr/bin/security", "find-certificate", "-a", "-Z", "-p"], stdout=subprocess.PIPE)
    stdout, _ = p.communicate()
    current_cert = current_keychain = current_sha1 = None
    for line in stdout.decode("utf-8").splitlines():
        line = line.strip()
        if not line:
            continue
        if line.startswith("SHA-1 hash:"):
            current_sha1 = line.replace("SHA-1 hash:", "").strip()
            try:
                current_keychain = found_certs[current_sha1]
            except KeyError:
                # TODO: probably a new certificate between the 2 runs...
                current_keychain = None
            continue
        elif "--BEGIN CERTIFICATE--" in line:
            # start accumulating a new PEM block
            current_cert = ""
        if current_cert is not None:
            # inside a PEM block: keep the line, including the BEGIN/END markers
            current_cert += "{}\n".format(line)
        if "--END CERTIFICATE--" in line:
            yield current_keychain, current_cert.strip(), current_sha1
            # lines are ignored until the next BEGIN marker
            current_cert = None


def iter_filtered_certs():
    """Yield the PEM certificates of the system keychain whose issuer is
    MS-Organization-Access (DC net/windows), the Microsoft Intune MDM Device CA,
    or has "JSS" in a common name. Each SHA-1 hash is yielded at most once."""
    yielded_sha1s = set()
    for keychain, cert, sha1 in iter_certs():
        if sha1 in yielded_sha1s:
            continue
        # only system keychain certificates
        # TODO: verify
        if keychain != "/Library/Keychains/System.keychain":
            continue
        issuer = read_cert_info(cert).get("issuer", {})
        issuer_dc = issuer.get("DC")
        issuer_cn = issuer.get("CN", [])
        wanted = (
            (issuer_dc == ["net", "windows"] and issuer_cn == ["MS-Organization-Access"])
            or issuer_cn == ["Microsoft Intune MDM Device CA"]
            or any("JSS" in cn for cn in issuer_cn)
        )
        if wanted:
            yield cert
            yielded_sha1s.add(sha1)


# Company portal user info


def iter_users():
    """Yield a dict (directory, username, description, uid) for each dscl
    user record whose home directory starts with "/Users".

    Attributes missing from a record are left out of the dict, as is a uid
    that cannot be converted to an int.
    """
    attribute_map = (("NFSHomeDirectory", "directory"),
                     ("RecordName", "username"),
                     ("RealName", "description"),
                     ("UniqueID", "uid"))
    dscl = subprocess.Popen(["/usr/bin/dscl", "-plist", ".", "-readall", "/Users",
                             "NFSHomeDirectory", "RealName", "UniqueID"],
                            stdout=subprocess.PIPE)
    output = dscl.communicate()[0]
    for record in plistlib.loads(output):
        user = {}
        for dscl_attr, user_attr in attribute_map:
            values = record.get("dsAttrTypeStandard:" + dscl_attr)
            if not values:
                continue
            # only the first value of the attribute is kept
            value = values[0]
            if user_attr == "uid":
                try:
                    value = int(value)
                except (TypeError, ValueError):
                    continue
            user[user_attr] = value
        home = user.get("directory")
        if isinstance(home, str) and home.startswith("/Users"):
            yield user


def get_company_portal_principal_user(domains):
    """Return the principal user found in the Company Portal user context files.

    The file is looked up in the home directory of every user yielded by
    iter_users(). When several users have one, the file with the most recent
    ctime that can be converted and holds the expected keys wins.

    `domains` is not used by this source; it is accepted so that all the
    principal user sources share the same signature.

    Returns the principal user dict, or None if none was found.
    """
    selected_plist_ctime = principal_user = None
    for user_d in iter_users():
        plist_path = os.path.join(
            user_d["directory"],
            "Library/Application Support/com.microsoft.CompanyPortal.usercontext.info"
        )
        try:
            plist_ctime = os.stat(plist_path).st_ctime
        except OSError:
            # plist doesn't exist
            continue
        # None cannot be compared with a number: test for it explicitly
        if selected_plist_ctime is not None and plist_ctime < selected_plist_ctime:
            # we have already found a more recent plist
            # TODO: better way to select the principal user?
            # TODO: do it on the client or on the server?
            continue
        try:
            p = subprocess.Popen(["/usr/bin/plutil", "-convert", "json", "-o", "-", plist_path],
                                 stdout=subprocess.PIPE)
            stdout, _ = p.communicate()
            company_portal_info = json.loads(stdout)
            candidate = {
                "source": {
                    "type": "COMPANY_PORTAL",
                    "properties": {
                        "azure_ad_authority_url": company_portal_info["aadAuthorityUrl"],
                        "version": company_portal_info["version"],
                    },
                },
                "unique_id": company_portal_info["aadUniqueId"],
                "principal_name": company_portal_info["aadUserId"],
            }
        except Exception:
            # unreadable or incomplete file: keep the current selection
            continue
        selected_plist_ctime = plist_ctime
        principal_user = candidate
    return principal_user


def get_google_chrome_principal_user(domains):
    """Return the principal user found in the Google Chrome "Local State" files.

    For every user yielded by iter_users(), the profiles of the info cache
    are examined. A profile is eligible when its hosted domain is in
    `domains` and is contained in its user name. Among the eligible profiles,
    the one with the highest active time wins.

    Returns the principal user dict, or None if none was found.
    """
    principal_user = None
    principal_user_active_time = None
    for user_d in iter_users():
        local_state_path = os.path.join(
            user_d["directory"],
            "Library/Application Support/Google/Chrome/Local State"
        )
        try:
            with open(local_state_path, 'r') as local_state_file:
                local_state = json.load(local_state_file)
        except Exception:
            continue
        try:
            for profile in local_state['profile']['info_cache'].values():
                active_time = profile["active_time"]
                if principal_user_active_time and principal_user_active_time > active_time:
                    # we have already found a more recent profile
                    continue
                hosted_domain = profile["hosted_domain"]
                if hosted_domain not in domains:
                    continue
                principal_name = profile["user_name"]
                if hosted_domain not in principal_name:
                    continue
                source = {
                    "type": "GOOGLE_CHROME",
                    "properties": {
                        "hosted_domain": hosted_domain,
                    },
                }
                principal_user = {
                    "source": source,
                    "unique_id": profile["gaia_id"],
                    "principal_name": principal_name,
                    "display_name": profile["gaia_name"],
                }
                principal_user_active_time = active_time
        except Exception:
            # unexpected structure: stop reading this user's profiles
            pass
    return principal_user


def get_logged_in_user_principal_user(domains):
    """Return the principal user built from the first console session found
    in the System Configuration dynamic store, or None on any error
    (including a missing SystemConfiguration module)."""
    try:
        from SystemConfiguration import SCDynamicStoreCreate, SCDynamicStoreCopyValue
        store = SCDynamicStoreCreate(None, "net", None, None)
        console_user = SCDynamicStoreCopyValue(store, "State:/Users/ConsoleUser")
        session = console_user["SessionInfo"][0]
        principal_user = {
            "source": {
                "type": "LOGGED_IN_USER",
                "properties": {
                    "method": "System Configuration",
                },
            },
            "unique_id": session["kCGSSessionUserIDKey"],
            "principal_name": session["kCGSSessionUserNameKey"],
            "display_name": session["kCGSessionLongUserNameKey"],
        }
    except Exception:
        return None
    return principal_user


# munkilib version


def get_munkilib_version_info():
    """Return the munkilib build, version and revision read from its
    version.plist, as munkilib_* keys. Only string values are kept.

    Returns an empty dict if the file cannot be read or is not a dict.
    """
    try:
        with open("/usr/local/munki/munkilib/version.plist", "rb") as version_file:
            raw_info = plistlib.load(version_file)
    except Exception:
        return {}
    if not isinstance(raw_info, dict):
        return {}
    wanted = (("BuildNumber", "build"),
              ("CFBundleShortVersionString", "version"),
              ("GitRevision", "revision"))
    candidates = ((attr, raw_info.get(plist_attr)) for plist_attr, attr in wanted)
    return {f"munkilib_{attr}": val for attr, val in candidates if isinstance(val, str)}


# Zentral Munki API calls


class DateTimeEncoder(json.JSONEncoder):
    """JSON encoder that writes datetime objects as ISO 8601 strings."""

    def default(self, obj):
        if not isinstance(obj, datetime):
            # let the base class raise its TypeError
            return super().default(obj)
        return obj.isoformat()


def make_api_request(url, data=None):
    """Send an authenticated request to the Zentral API and return the decoded JSON response.

    When `data` is truthy, it is JSON encoded (datetimes as ISO 8601
    strings), zlib compressed and sent as the request body with the
    'deflate' content encoding. Otherwise no body is sent.

    The server certificate is verified against ZENTRAL_API_SERVER_CERTIFICATE,
    or against /private/etc/ssl/cert.pem when that setting is empty.
    """
    req = urllib.request.Request(url)
    req.add_header('User-Agent', USER_AGENT)
    req.add_header('Authorization', 'MunkiEnrolledMachine {}'.format(ZENTRAL_API_AUTH_TOKEN))
    if data:
        data = DateTimeEncoder().encode(data)
        req.add_header('Content-Type', 'application/json')
        # the JSON encoder escapes non-ASCII characters by default
        data = zlib.compress(data.encode("ascii"), 9)
        req.add_header('Content-Encoding', 'deflate')
    ctx = ssl.create_default_context(cafile=ZENTRAL_API_SERVER_CERTIFICATE or "/private/etc/ssl/cert.pem")
    # the context manager closes the connection once the body is decoded
    with urllib.request.urlopen(req, data=data, context=ctx) as response:
        return json.load(response)


def fetch_job_details(machine_serial_number):
    """Request the job details for the machine serial number from the Zentral API."""
    base_url = ZENTRAL_API_ENDPOINT.strip('/')
    payload = {'machine_serial_number': machine_serial_number}
    return make_api_request(base_url + "/job_details/", payload)


def post_job(data):
    """Post the job payload to the Zentral API and return the decoded response."""
    base_url = ZENTRAL_API_ENDPOINT.strip('/')
    return make_api_request(base_url + "/post_job/", data)


def get_machine_snapshot():
    """Return the system profiler machine snapshot, completed with the
    filtered PEM certificates and the profiles."""
    snapshot = SystemProfilerReport().get_machine_snapshot()
    pem_certificates = list(iter_filtered_certs())
    profiles = list(iter_profiles())
    snapshot.update(pem_certificates=pem_certificates, profiles=profiles)
    return snapshot


def add_machine_snapshot_apps(machine_snapshot, msn, job_details):
    """Add the application instances to the machine snapshot.

    The "apps_full_info_shard" job detail (clamped to 0..100, 0 when missing
    or invalid) is the percentage of machines collecting the full bundle
    info. Between the two bounds, the decision is derived from the MD5 digest
    of the serial number `msn`, so it is stable for a given machine.
    """
    try:
        shard = int(job_details["apps_full_info_shard"])
    except (KeyError, ValueError, TypeError):
        shard = 0
    shard = max(0, min(100, shard))
    if shard in (0, 100):
        full_info = shard == 100
    else:
        digest = hashlib.md5(msn.encode("utf-8")).hexdigest()
        full_info = int(digest, 16) % 100 < shard
    machine_snapshot['osx_app_instances'] = get_osx_app_instances(full_info)


def get_principal_user(principal_user_detection):
    """Return the principal user from the first configured source that finds one.

    `principal_user_detection` is a dict with the optional "sources" list
    ("company_portal", "google_chrome" or "logged_in_user"; unknown sources
    are skipped) tried in order, and the optional "domains" list passed to
    each source.

    Returns None if the detection is not configured or if no source finds a user.
    """
    if not principal_user_detection:
        return None
    domains = principal_user_detection.get("domains", [])
    for source in principal_user_detection.get("sources", []):
        if source == "company_portal":
            f = get_company_portal_principal_user
        elif source == "google_chrome":
            f = get_google_chrome_principal_user
        elif source == "logged_in_user":
            f = get_logged_in_user_principal_user
        else:
            continue
        principal_user = f(domains)
        if principal_user:
            return principal_user
    return None


def get_job_details(msn):
    """Return the job details for this run.

    They are read from the job details plist, which is then deleted. If the
    file cannot be read, they are fetched from the Zentral API using the
    machine serial number `msn`.
    """
    job_details_path = "/usr/local/zentral/munki/job_details.plist"
    try:
        with open(job_details_path, "rb") as job_details_file:
            job_details = plistlib.load(job_details_file)
    except FileNotFoundError:
        print("Could not find", job_details_path)
        job_details = None
    except Exception as e:
        print("Error reading", job_details_path, e)
        job_details = None

    if job_details is None:
        # Fallback. Useful during the migration
        return fetch_job_details(msn)

    # cleanup
    try:
        os.unlink(job_details_path)
    except Exception as e:
        print("Could not delete", job_details_path, e)
    return job_details


# Main


# Entry point: collect the machine snapshot, the new run reports, the managed
# installs and the script check results, then post everything to Zentral in
# one request. The first command line argument, if any, is the run type and is
# only used in the final log line.
if __name__ == '__main__':
    # basic machine snapshot
    machine_snapshot = get_machine_snapshot()

    # get job details
    msn = machine_snapshot["serial_number"]
    job_details = get_job_details(msn)

    # enrich machine snapshot
    add_machine_snapshot_apps(machine_snapshot, msn, job_details)
    principal_user_detection = job_details.get('principal_user_detection')
    principal_user = get_principal_user(principal_user_detection)
    if principal_user:
        machine_snapshot['principal_user'] = principal_user
    # extra facts: requested conditions of the most recent finished report
    collected_condition_keys = job_details.get('collected_condition_keys')
    if collected_condition_keys and isinstance(collected_condition_keys, list):
        try:
            latest_mir = next(iter_manage_install_reports())
        except StopIteration:
            # no finished report available
            pass
        else:
            machine_snapshot['extra_facts'] = latest_mir.get_extra_facts(collected_condition_keys)
    # the munkilib version info is merged into the extra facts
    munkilib_version_info = get_munkilib_version_info()
    if munkilib_version_info:
        machine_snapshot.setdefault('extra_facts', {}).update(munkilib_version_info)

    data = {'machine_snapshot': machine_snapshot}

    # add the new reports
    last_seen_sha1sum = job_details.get('last_seen_sha1sum')
    data['reports'], data['last_seen_report_found'] = build_reports_payload(last_seen_sha1sum)
    # add the managed installs if requested or last seen report not found
    if job_details.get('managed_installs', True) or not data['last_seen_report_found']:
        data['managed_installs'] = build_managed_installs_payload()

    # run the script checks
    script_checks = job_details.get("script_checks")
    if isinstance(script_checks, list):
        data["script_check_results"] = run_script_checks(script_checks)

    # post the payload to zentral
    post_job(data)

    # run type
    run_type = None
    try:
        run_type = sys.argv[1]
    except IndexError:
        # no run type argument given
        pass

    # log line
    print('Zentral postflight job OK - '
          'run type %s, last sha1sum %s' % (run_type or "-",
                                            (last_seen_sha1sum or "-")[:7]))
