#!/usr/bin/env python3
# qunpak: extract Quake .pak files
# SPDX-License-Identifier: WTFPL
# date: 2014-06-18

import sys
import os
import struct
import errno
import fnmatch


class PakReader(object):
    """Read the table of contents of a Quake .pak archive.

    The layout parsed here is: a 4-byte ``PACK`` magic, a little-endian
    uint32 table-of-contents offset and a uint32 table-of-contents size,
    followed (at that offset) by 64-byte entries made of a 56-byte
    NUL-padded name, a uint32 data offset and a uint32 data size.

    All format violations are reported as ``AssertionError`` so that
    callers catching that exception keep working; the errors are raised
    explicitly rather than with ``assert`` so they survive ``python -O``.
    """

    MAGIC = b'PACK'
    ENTRY_SIZE = 64
    NAME_SIZE = ENTRY_SIZE - 8  # an entry ends with two uint32 fields

    def __init__(self, fd):
        """Wrap *fd*, a seekable binary file object positioned anywhere."""
        self.fd = fd

    def _read_bytes(self, size):
        """Read exactly *size* bytes or raise AssertionError on a short read."""
        data = self.fd.read(size)
        if len(data) != size:
            raise AssertionError('Unexpected end of file')
        return data

    def _read(self, fmt):
        """Read and unpack one ``struct`` record described by *fmt*."""
        return struct.unpack(fmt, self._read_bytes(struct.calcsize(fmt)))

    def _read_uint32(self):
        """Read one little-endian unsigned 32-bit integer."""
        return self._read('<L')[0]

    def _read_str(self, size):
        """Read a *size*-byte field and return the bytes before the first NUL."""
        return self._read_bytes(size).split(b'\0', 1)[0]

    def get_toc(self):
        """Iterate over the archive's table of contents.

        Yields a ``(filename, offset, size)`` tuple for every usable
        entry, and a plain error-message string for every entry that
        must be skipped (no name, or data extending past the end of the
        archive).

        Raises AssertionError (lazily, on first iteration) when the
        header is not recognized, the table size is not a multiple of
        the entry size, or the table does not fit inside the archive.
        """
        self.fd.seek(0, os.SEEK_END)
        archive_size = self.fd.tell()

        self.fd.seek(0, os.SEEK_SET)
        if self._read_str(4) != self.MAGIC:
            raise AssertionError('Header is not recognized')

        toc_offset = self._read_uint32()
        toc_size = self._read_uint32()
        if toc_size % self.ENTRY_SIZE != 0:
            raise AssertionError('The table of contents size is unexpected')
        if toc_offset + toc_size > archive_size:
            raise AssertionError('The table of contents lies outside the archive')

        self.fd.seek(toc_offset, os.SEEK_SET)
        for _ in range(toc_size // self.ENTRY_SIZE):
            # latin-1 maps every byte to a code point, so decoding cannot fail
            filename = self._read_str(self.NAME_SIZE).decode('latin-1')
            offset = self._read_uint32()
            size = self._read_uint32()

            if not filename:
                yield 'File at offset %d has no name, it will not be extracted' % offset
            elif offset + size > archive_size:
                yield 'The archive may be truncated, file %r will not be extracted' % filename
            else:
                yield (filename, offset, size)


class PakExtractor(object):
    """List or extract the entries of a Quake .pak archive.

    Typical use: construct with the archive path, optionally call
    ``set_output_dir`` / ``set_pattern``, call ``parse`` to load the
    table of contents, then ``list_all`` or ``extract_all``.

    Attributes:
        fd: the archive, opened for binary reading.
        toc: list of ``(filename, offset, size)`` tuples, or None until
            ``parse`` has been called.
        output_dir: directory under which extracted files are written.
        progress: when true, ``print_status`` draws a progress bar;
            otherwise it prints the file name followed by a dot.
        pattern: ``fnmatch`` pattern selecting which entries to handle.
    """

    CHUNK_SIZE = 4096
    PROGRESS_WIDTH = 20

    def __init__(self, filename):
        """Open the archive *filename*; raises IOError/OSError if it cannot be opened."""
        self.fd = open(filename, 'rb')
        self.toc = None
        self.output_dir = '.'
        self.progress = True
        self.pattern = '*'

    def close(self):
        """Close the underlying archive file."""
        self.fd.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def parse(self):
        """Load the table of contents into ``self.toc``.

        Unusable entries are reported through ``print_error`` and
        skipped. Exceptions raised by ``PakReader.get_toc`` propagate.
        """
        self.toc = []
        for info in PakReader(self.fd).get_toc():
            if isinstance(info, tuple):
                self.toc.append(info)
            else:
                # the reader yields plain strings for entries it rejects
                self.print_error(info)

    def _makedirs(self, path):
        """Create directory *path* (and parents) unless empty or already present."""
        if not path or os.path.isdir(path):
            return
        os.makedirs(path)

    def _clean_path(self, path):
        """Return *path* as a relative path that cannot leave the output directory.

        The path is normalized, any drive prefix is discarded, and empty
        and parent-directory components are dropped. Components are
        compared whole, so a name such as ``a../b`` is left intact.
        A path with nothing left yields ``'.'``.
        """
        normalized = os.path.normpath(path)
        # a drive prefix would make os.path.join() ignore the output directory
        normalized = os.path.splitdrive(normalized)[1]
        parts = [part for part in normalized.split(os.sep)
                 if part and part != os.pardir]
        if not parts:
            return os.curdir
        return os.path.join(*parts)

    def set_output_dir(self, output_dir):
        """Set the directory under which files are extracted."""
        self.output_dir = output_dir

    def set_pattern(self, pattern):
        """Set the ``fnmatch`` pattern used to select entries."""
        self.pattern = pattern

    def _output_path(self, path):
        """Map an archive path to its sanitized location under the output directory."""
        return os.path.join(self.output_dir, self._clean_path(path))

    def _matches(self, filename):
        """Tell whether the sanitized *filename* matches the current pattern."""
        return fnmatch.fnmatch(self._clean_path(filename), self.pattern)

    def extract_all(self):
        """Extract every matching entry; return True only if all succeeded."""
        success = True
        for (filename, offset, size) in self.toc:
            if self._matches(filename) and not self.extract(filename, offset, size):
                success = False
        return success

    def list_all(self, verbose=False):
        """Print the name and size of every matching entry to stdout.

        *verbose* is accepted but currently has no effect.
        """
        for (filename, offset, size) in self.toc:
            if self._matches(filename):
                self._write_ok(sys.stdout, '%s (%d bytes)\n' % (filename, size))

    def _write_ok(self, fd, blk):
        """Write *blk* to *fd*, silently ignoring a broken pipe."""
        try:
            fd.write(blk)
        except IOError as exc:
            if exc.errno == errno.EPIPE:
                return
            raise

    def print_error(self, error):
        """Write *error* and a newline to stderr."""
        self._write_ok(sys.stderr, error + '\n')

    def print_status(self, filename, done, total):
        """Report that *done* of *total* bytes of *filename* are extracted."""
        if self.progress:
            maxdots = self.PROGRESS_WIDTH
            # an empty file is complete from the start; also avoids dividing by zero
            dots = maxdots if total == 0 else int(done * maxdots / total)
            dotstring = '.' * dots + ' ' * (maxdots - dots)
            self._write_ok(sys.stdout, '\r%s (%d/%d) |%s|' % (filename, done, total, dotstring))
        else:
            if done == 0:
                self._write_ok(sys.stdout, filename)
            # separate test so that an empty file gets both its name and its dot
            if done == total:
                self._write_ok(sys.stdout, '.')

    def _unlink_temp(self, out_filename):
        """Remove a partially written output file, reporting any failure."""
        try:
            os.unlink(out_filename)
        except OSError as exc:
            self.print_error('Cannot remove %r: %s' % (out_filename, exc))

    @staticmethod
    def _reason(exc):
        """Return a short description of an I/O exception.

        Python 3 exceptions have no ``message`` attribute, so use the
        OS error text when present and the string form otherwise.
        """
        return getattr(exc, 'strerror', None) or str(exc)

    def extract(self, filename, offset, size):
        """Extract one entry to the output directory.

        Returns True on success. On failure an error is printed and
        False is returned. On KeyboardInterrupt the partial output file
        is removed and the interrupt is re-raised.
        """
        filename = filename.lstrip('/')
        out_filename = self._output_path(filename)

        # derive the directory from the final path so both always agree
        outdir = os.path.dirname(out_filename)
        try:
            self._makedirs(outdir)
        except OSError as exc:
            self.print_error('Cannot create %r: %s' % (outdir, exc))
            return False

        try:
            self._extract_into(filename, offset, size, out_filename)
        except KeyboardInterrupt:
            self._unlink_temp(out_filename)
            raise
        except (IOError, OSError) as exc:
            self.print_error('Cannot extract %r: %s' % (filename, exc))
            return False
        return True

    def _extract_into(self, filename, offset, size, out_filename):
        """Copy *size* bytes at *offset* of the archive into *out_filename*.

        Progress is reported through ``print_status``. Raises IOError
        when the output cannot be opened or written, when the archive
        cannot be read, or when it ends before *size* bytes were read.
        """
        try:
            out = open(out_filename, 'wb')
        except (IOError, OSError) as exc:
            raise IOError(exc.errno, 'Cannot write %r: %s' % (out_filename, self._reason(exc)))

        with out:
            try:
                self.fd.seek(offset, os.SEEK_SET)
            except (IOError, OSError) as exc:
                raise IOError(exc.errno, 'Cannot read %r: %s' % (filename, self._reason(exc)))

            done = 0
            self.print_status(filename, done, size)
            while done < size:
                readsize = min(self.CHUNK_SIZE, size - done)

                try:
                    buf = self.fd.read(readsize)
                except (IOError, OSError) as exc:
                    raise IOError(exc.errno, 'Cannot read %r: %s' % (filename, self._reason(exc)))

                if len(buf) != readsize:
                    raise IOError('Cannot read %r: end of archive was reached' % filename)

                try:
                    out.write(buf)
                except (IOError, OSError) as exc:
                    raise IOError(exc.errno, 'Cannot write %r: %s' % (out_filename, self._reason(exc)))

                # count the chunk first so the status reflects what is on disk
                done += readsize
                self.print_status(filename, done, size)
            self._write_ok(sys.stdout, '\n')


def main():
    """Command-line entry point: list or extract a .pak archive.

    Parses ``sys.argv``, then lists (``-l``) or extracts (``-x``, the
    default) the entries of the archive that match the optional
    pattern, writing extracted files under the ``-O`` directory.

    Returns the process exit status: 0 on success, 1 when the archive
    cannot be opened or parsed or when any entry failed to extract.
    """
    import argparse
    parser = argparse.ArgumentParser(description='Extract Quake .pak files')
    parser.add_argument('-l', dest='action', action='store_const', const='l', help='List archive only')
    parser.add_argument('-x', dest='action', action='store_const', const='x', default='x', help='Extract archive (default)')
    parser.add_argument('-O', dest='output_dir', metavar='DIR', default='.', help='Extract files into DIR')
    parser.add_argument('archive', metavar='FILE', help='Extract FILE')
    parser.add_argument('pattern', metavar='PATTERN', nargs='?', default='*', help='Handle only files matching PATTERN')
    args = parser.parse_args()

    # the constructor opens the archive; report a failure instead of a traceback
    try:
        extractor = PakExtractor(args.archive)
    except (IOError, OSError) as exc:
        sys.stderr.write('Cannot open %r: %s\n' % (args.archive, exc))
        return 1

    extractor.set_output_dir(args.output_dir)
    extractor.set_pattern(args.pattern)

    success = True
    try:
        try:
            extractor.parse()
        except (AssertionError, struct.error) as exc:
            # struct.error: the archive ended in the middle of a record
            extractor.print_error('File %r is corrupted or is not a PAK file: %s' % (args.archive, exc))
            return 1

        if args.action == 'l':
            extractor.list_all()
        else:
            success = extractor.extract_all()
    finally:
        extractor.fd.close()

    # in case of a broken pipe, python would flush at end and finally barf
    try:
        sys.stdout.close()
    except IOError:
        pass

    return 0 if success else 1

# Run the command-line interface when executed as a script; the value
# returned by main() becomes the process exit status.
if __name__ == '__main__':
    raise SystemExit(main())
