#!/usr/bin/env python3

# Copyright 2021-2022 Yury Gribov
# 
# Use of this source code is governed by MIT license that can be
# found in the LICENSE.txt file.

"""
A wrapper for linker.
"""

import json
import os
import os.path
import re
import subprocess
import sys

# Base name of this script; used as prefix of diagnostic messages.
me = os.path.split(__file__)[1]

def warn(msg):
  """
  Write a warning to stderr, prefixed with the script name.
  """
  text = f'{me}: warning: {msg}\n'
  sys.stderr.write(text)

def error(msg):
  """
  Write an error to stderr, prefixed with the script name,
  then terminate the process with exit status 1.
  """
  text = f'{me}: error: {msg}\n'
  sys.stderr.write(text)
  raise SystemExit(1)

def warn_if(cond, msg):
  """
  Issue warning `msg` only when `cond` is truthy.
  """
  if not cond:
    return
  warn(msg)

def error_if(cond, msg):
  """
  Report fatal error `msg` (and exit) only when `cond` is truthy.
  """
  if not cond:
    return
  error(msg)

def run(cmd, **kwargs):
  """
  Run a command to completion and capture its output.

  `cmd` is either an argument list or a string (which is split on
  single spaces).  Two keywords are consumed here, the rest are
  forwarded to subprocess.Popen:
    fatal - report an error and exit if the command returns non-zero;
    tee   - echo captured stdout/stderr to our own stdout/stderr.

  Returns a tuple (return code, stdout text, stderr text).
  """
  fatal = kwargs.pop('fatal', False)
  tee = kwargs.pop('tee', False)

  argv = cmd.split(' ') if isinstance(cmd, str) else cmd

  with subprocess.Popen(argv, stdin=None, stdout=subprocess.PIPE,
                        stderr=subprocess.PIPE, **kwargs) as proc:
    raw_out, raw_err = proc.communicate()
  status = proc.returncode

  out = raw_out.decode()
  err = raw_err.decode()

  if fatal and status != 0:
    cmdline = ' '.join(argv)
    error(f"'{cmdline}' failed:\n{out}{err}")

  if tee:
    sys.stdout.write(out)
    sys.stderr.write(err)

  return status, out, err

def readelf(filename, extra_args=None):
  """
  Return symbols of an ELF file as printed by `readelf -W`.

  Args:
    filename: path of the file to inspect.
    extra_args: additional readelf option(s) selecting what to print
      (callers in this file pass '-s' or '--dyn-syms'); either a single
      string or a list/tuple of strings.

  Returns:
    A list of symbol dicts.  If the output contains per-member
    'File: archive(member)' sections (static library), a dict mapping
    member name to such a list is returned instead.  Each symbol dict
    is keyed by the column titles of readelf's table header (colons
    stripped, e.g. 'Bind', 'Ndx', 'Name') plus:
      'Version' - version suffix of the name (text after '@'/'@@') or None;
      'Default' - False only for names versioned with a single '@'.
    'Name' itself is stored without the version suffix.

  Exits via error() if the file is missing, readelf can not be run or
  fails, or its output has an unexpected layout.
  """

  if not os.path.exists(filename):
    error(f"unable to run readelf: file does not exist: {filename}")

  # Normalize extra_args to a fresh list
  if extra_args is None:
    extra_args = []
  elif isinstance(extra_args, (list, tuple)):
    extra_args = list(extra_args)
  else:
    extra_args = [extra_args]

  try:
    with subprocess.Popen(["readelf", "-W", filename] + extra_args,
                          stdout=subprocess.PIPE, stderr=subprocess.PIPE) as p:
      out, err = p.communicate()
  except OSError as e:
    # E.g. readelf is not installed
    error(f"unable to run readelf: {e}")
  out = out.decode()
  err = err.decode()
  if p.returncode != 0:
    error(f"readelf failed with retcode {p.returncode}: {filename}: {err}")
  warn_if(err, f"readelf: {filename}: {err}")

  file_map = {}

  toc = None  # Column index -> column title, None until header is seen
  syms = []
  for line in out.splitlines():
    line = line.strip()
    if not line:
      continue
    words = re.split(r' +', line)
    if line.startswith('File: '):  # New file in static library?
      # File: ./lib/libdiffutils.a(vasnprintf.o)
      m = re.search(r'\((.*)\)', line)
      if m is None:
        error(f"unexpected line in output of readelf for {filename}: {line}")
      # Symbols that follow belong to this member
      syms = []
      file_map[m.group(1)] = syms
      toc = None
    elif line.startswith('Num'):  # Header?
      if toc is not None:
        error(f"multiple headers in output of readelf for {filename}")
      # Colons are different across readelf versions so get rid of them.
      toc = {i: n.replace(':', '') for i, n in enumerate(words)}
    elif toc is not None:
      # Trailing columns may be missing (e.g. symbol without a name)
      sym = {k: (words[i] if i < len(words) else '') for i, k in toc.items()}
      name = sym['Name']
      if '@' in name:
        sym['Default'] = '@@' in name
        # Split at the first run of '@' only, so that unusual names
        # with several '@' do not break unpacking
        name, ver = re.split(r'@+', name, maxsplit=1)
        sym['Name'] = name
        sym['Version'] = ver
      else:
        sym['Default'] = True
        sym['Version'] = None
      syms.append(sym)

  # crtn.o on Ubuntu does not have a symtab
  if os.path.basename(filename) not in {'crtn.o'}:
    warn_if(toc is None, f"no symbol table in {filename}")

  return file_map if file_map else syms

class LinkerInvocation:
  """
  Command line of a single linker invocation.

  `args` is the complete argv: args[0] is the program to run,
  the remaining items are linker options and input files.
  """

  # Options which make symbols of the output visible to other modules.
  # GNU ld accepts multi-letter options with one or two dashes.
  _EXPORT_FLAGS = frozenset([
    '-shared', '--shared', '-Bshareable',
    '-E', '-export-dynamic', '--export-dynamic',
  ])

  def __init__(self, args):
    # The list is stored by reference: replace_exe() modifies
    # the caller's list as well.
    self.args = args

  def replace_exe(self, new_exe):
    """ Substitute the program to be run (args[0]). """
    self.args[0] = new_exe

  def eval(self):
    """
    Run the command line, echoing output of the child process to our
    stdout/stderr.  Returns (return code, stdout text, stderr text).
    """
    rc, out, err = run(self.args, tee=True)
    return rc, out, err

  def get_output_file(self):
    """
    Return name of the output file or 'a.out' if none is given.

    Understands '-o FILE', '-oFILE', '--output FILE' and '--output=FILE';
    the first occurrence wins.
    """
    opts = self.args[1:]  # Skip the executable
    for i, a in enumerate(opts):
      if a in ('-o', '--output'):
        # Name is in the next argument (if there is one)
        if i + 1 < len(opts):
          return opts[i + 1]
      elif a.startswith('--output='):
        return a[len('--output='):]
      elif a.startswith('-o'):
        # Any single-dash option starting with 'o' is '-o' with
        # joined file name
        return a[2:]
    return 'a.out'

  def get_object_files(self):
    """ Return arguments naming object files (*.o); options are skipped. """
    return [a for a in self.args
            if a.endswith('.o') and not a.startswith('-')]

  def get_static_library_files(self):
    """ Return arguments naming static libraries (*.a); options are skipped. """
    return [a for a in self.args
            if a.endswith('.a') and not a.startswith('-')]

  def has_global_exports(self):
    """
    Return True if the command line asks for a shared library
    or for exporting symbols of an executable dynamically.
    """
    # TODO: handle --dynamic-listXXX and --version-script
    return any(a in self._EXPORT_FLAGS for a in self.args)

def main():
  """
  Run the real linker with our command line and write a JSON report
  about symbols used and defined by the linked files.

  Environment variables:
    LOCALIZER_DIR     - directory where the report is stored (required);
    LOCALIZER_VERBOSE - integer, non-zero enables tracing to stderr.

  The report is written to <LOCALIZER_DIR>/<pid>.json and holds keys
  'cmdline', 'imports', 'exports' and 'global_exports'.

  Returns the exit code of the linker.
  """
  tmp_dir = os.environ.get('LOCALIZER_DIR')
  error_if(tmp_dir is None, "LOCALIZER_DIR variable not set")

  verbose_str = os.environ.get('LOCALIZER_VERBOSE', '0')
  try:
    v = int(verbose_str)
  except ValueError:
    error(f"invalid value of LOCALIZER_VERBOSE variable: {verbose_str}")

  if v:
    argvs = ' '.join(sys.argv)
    sys.stderr.write(f"ld-wrapper: called with: {argvs}\n")

  # Note that inv shares the list with sys.argv
  inv = LinkerInvocation(sys.argv)

  # Remove ourselves from PATH so that the real linker is found.
  # Make the path absolute first: dirname of a bare file name is ''
  # which samefile() can not handle.
  mydir = os.path.dirname(os.path.abspath(__file__))
  # If PATH is unset fall back to the default search path
  # instead of failing with KeyError.
  old_paths = os.environ.get('PATH', os.defpath).split(os.pathsep)
  new_paths = [d for d in old_paths
               if os.path.exists(d) and not os.path.samefile(d, mydir)]
  os.environ['PATH'] = os.pathsep.join(new_paths)
  if v:
    sys.stderr.write(f"ld-wrapper: updated PATH: {os.environ['PATH']}\n")

  inv.replace_exe(os.path.basename(sys.argv[0]))

  # eval() already echoes linker's stdout/stderr so do not print
  # them here again (they would show up twice).
  rc, _, _ = inv.eval()

  report_data = {'cmdline' : sys.argv, 'imports' : [], 'exports' : [], 'global_exports' : []}

  # Collect symtabs of .o files
  file_map = {}
  for o in inv.get_object_files():
    file_map[o] = readelf(o, '-s')

  # Collect symtabs of .o files in static libs
  # TODO: for global exports we need to know what files were linked
  # (via 'ld -verbose')
  for lib in inv.get_static_library_files():
    lib_map = readelf(lib, '-s')
    if not isinstance(lib_map, dict):
      # readelf() returns a plain list when it saw no per-member
      # sections (e.g. archive without members); treat it as
      # symbols of the library file itself.
      lib_map = {lib: lib_map}
    for o, syms in sorted(lib_map.items()):
      # TODO: check dups?
      file_map[o] = syms

  # Collect static defs/uses
  for o, syms in sorted(file_map.items()):
    for s in syms:
      if s['Bind'] == 'GLOBAL':
        lst = report_data['imports' if s['Ndx'] == 'UND' else 'exports']
        lst.append({'name' : s['Name'], 'file' : o})

  # Collect dynamic defs (uses are potentially unbounded)
  out_name = inv.get_output_file()
  if inv.has_global_exports():
    for s in readelf(out_name, '--dyn-syms'):
      if s['Bind'] == 'GLOBAL':
        report_data['global_exports'].append({'name' : s['Name'], 'file' : out_name})

  # Use pid in file name to keep reports of different links apart
  report_name = os.path.join(tmp_dir, str(os.getpid()) + '.json')
  with open(report_name, 'w') as f:
    json.dump(report_data, f, indent=2, sort_keys=True)

  return rc

if __name__ == '__main__':
  # Use the value returned by main() as exit status of the process.
  raise SystemExit(main())
