#!/usr/bin/env python3
# SPDX-License-Identifier: WTFPL

import argparse
import os
import resource
import sys


# Maps a command-line friendly limit name to its resource.RLIMIT_* constant.
# The name is the attribute name with the "RLIMIT_" prefix dropped, lowercased
# and with underscores turned into dashes.
LIMITS = {
    attr[len("RLIMIT_"):].replace("_", "-").lower(): getattr(resource, attr)
    for attr in dir(resource)
    if attr.startswith("RLIMIT_")
}

# Snapshot of what getrlimit() reports for every entry of LIMITS, taken once
# when the module is loaded; index 0 is used as the soft value and index 1 as
# the hard value elsewhere in this file.
CURRENT = {
    label: resource.getrlimit(constant)
    for label, constant in LIMITS.items()
}


def limit_to_str(v):
    """Return a printable form of the limit value *v*.

    ``resource.RLIM_INFINITY`` is rendered as ``'inf'``; every other value is
    passed through ``str()`` so the result is always a string, as the
    function name promises.
    """
    if v == resource.RLIM_INFINITY:
        return 'inf'
    return str(v)


class LimitSetAction(argparse.Action):
    """argparse action that records one requested soft or hard limit.

    Parsed values are collected in ``namespace.limits``, a dict mapping the
    limit name to a two-item list indexed by ``softness``. A slot stays
    ``None`` until the matching option has been seen.

    Args:
        name: Key under which the value is stored in ``namespace.limits``.
        softness: Index into the two-item list (the caller in this file uses
            0 for soft and 1 for hard).

    Raises:
        ValueError: From ``__init__`` if ``nargs`` is supplied.
        argparse.ArgumentError: From ``__call__`` if the option value is
            neither an integer nor one of the "unlimited" spellings.
    """

    # Case-insensitive spellings that select resource.RLIM_INFINITY.
    _UNLIMITED = frozenset({"inf", "infinity", "unlimited"})

    def __init__(self, option_strings, dest, nargs=None, *, name, softness, **kwargs):
        if nargs is not None:
            raise ValueError("nargs not allowed")
        super().__init__(option_strings, dest, **kwargs)
        self.softness = softness
        self.name = name

    def __call__(self, parser, namespace, values, option_string=None):
        if isinstance(values, str) and values.strip().lower() in self._UNLIMITED:
            limit = resource.RLIM_INFINITY
        else:
            try:
                limit = int(values)
            except (TypeError, ValueError):
                # ArgumentError is what argparse turns into a usage message;
                # a plain ValueError raised here would escape as a traceback.
                raise argparse.ArgumentError(
                    self, f"invalid limit value: {values!r}"
                ) from None

        # The dict is created lazily, on the first limit option encountered.
        limits = getattr(namespace, "limits", None)
        if limits is None:
            limits = namespace.limits = {}

        limits.setdefault(self.name, [None, None])[self.softness] = limit


def main():
    """Parse limit options, apply them, then replace this process with COMMAND.

    One ``--soft-NAME`` and one ``--hard-NAME`` option is registered for every
    entry in ``LIMITS``. For each limit that was requested, values that were
    not given fall back to the ones recorded in ``CURRENT``. If the resulting
    hard limit is finite and the soft limit is infinite or not below it, the
    soft limit is set equal to the hard limit.

    Exits with status 1 if a limit cannot be applied, 127 if COMMAND cannot
    be found and 126 if it cannot be executed for another reason. On success
    this function does not return because the process image is replaced.
    """
    parser = argparse.ArgumentParser()
    for name in LIMITS:
        for n, softness in enumerate(("soft", "hard")):
            parser.add_argument(
                f"--{softness}-{name}",
                action=LimitSetAction, softness=n, name=name,
                help=f"Set {softness} limit of {name!r} (current: "
                + f"{limit_to_str(CURRENT[name][n])})",
            )

    parser.add_argument(
        "command", help="Command to run",
    )
    # NOTE(review): with nargs="*", arguments for COMMAND that start with "-"
    # presumably have to be preceded by "--" on the command line - confirm.
    parser.add_argument(
        "args", nargs="*", default=[],
        help="Args for COMMAND",
    )
    args = parser.parse_args()

    # LimitSetAction only creates `limits` once a limit option is seen, so
    # fall back to an empty mapping of the same type.
    limits = getattr(args, "limits", {})

    for name, wishes in limits.items():
        soft, hard = (
            wish if wish is not None else system
            for system, wish in zip(CURRENT[name], wishes)
        )
        # Keep the soft limit from ending up above a finite hard limit.
        if hard != resource.RLIM_INFINITY:
            if soft == resource.RLIM_INFINITY or soft >= hard:
                soft = hard

        print(
            f"setting {name} to soft={limit_to_str(soft)} "
            f"hard={limit_to_str(hard)}",
            file=sys.stderr,
        )
        try:
            resource.setrlimit(LIMITS[name], (soft, hard))
        except (ValueError, OSError) as exc:
            # Report a one-line error instead of a traceback.
            sys.exit(f"{parser.prog}: cannot set {name}: {exc}")

    try:
        os.execlp(args.command, args.command, *args.args)
    except OSError as exc:
        print(f"{parser.prog}: cannot run {args.command!r}: {exc}", file=sys.stderr)
        # Same exit statuses shells use: 127 = not found, 126 = not runnable.
        sys.exit(127 if isinstance(exc, FileNotFoundError) else 126)


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
