#!/usr/bin/env python2

from __future__ import with_statement
import os
import os.path
from glob import glob
from tfplib.external import elf
from snakeoil import formatters
from snakeoil.osutils import join, listdir_files
from optparse import OptionParser
import re
import struct
from string import Template
import copy
import sys
from contextlib import contextmanager, nested
from functools import partial
import fnmatch
from tfplib import fileutils

# The non-greedy group captures everything between the quotes; the caller
# splits that value on whitespace and globs each piece.
SEARCH_DIR_PATTERN = re.compile('''SEARCH_DIRS="(.*?)"''') # Match SEARCH_DIRS="" in a revdep-rebuild file

# cheap optimization, a cache of real paths
# (maps a filename to the os.path.realpath() result, filled by realpath())
REALPCACHE = {}
# Event handler instances; main() registers an OutputHandler here and the
# Checker/Resolver classes notify every handler in the set.
HANDLERS = set()
# Options always placed on the emerge command line before the user-supplied ones.
DEFAULT_EMERGE_OPTS = ["-1"]

class Searcher(fileutils.Searcher):
	"""fileutils.Searcher whose "libdirs" and "bindirs" sections are extended
	with the revdep-rebuild search directories when it is constructed."""
	pjoin = staticmethod(join)

	def __init__(self, extra=None):
		"""Initialise the base searcher (unicode=False) and then add the
		revdep-rebuild directories via parseRevdepRebuild()."""
		fileutils.Searcher.__init__(self, extra, unicode=False)
		self.parseRevdepRebuild()

	def parseRevdepRebuild(self, dirname="/etc/revdep-rebuild"):
		"""add directories configured in files in /etc/revdep-rebuild to the paths to search for bins and libs

		Every SEARCH_DIRS="..." value found in the files of dirname is split
		on whitespace, glob-expanded and appended to both the "libdirs" and
		"bindirs" sections. When dirname is not a directory a built-in list
		(/bin, /sbin, /usr/bin, /usr/sbin, /lib*, /usr/lib*) is appended instead.
		"""
		if os.path.isdir(dirname):
			for filename in listdir_files(dirname):
				# the with block closes every config file as soon as it has been
				# read, instead of leaving the handles to the garbage collector
				with open(join(dirname, filename), "r") as f:
					data = f.read()
				for match in SEARCH_DIR_PATTERN.findall(data):
					tmp = globs(match.split())
					self.sections["libdirs"] += tmp
					self.sections["bindirs"] += tmp
		else:
			tmp = ["/bin", "/sbin", "/usr/bin", "/usr/sbin"] + glob("/lib*") + glob("/usr/lib*")
			self.sections["libdirs"] += tmp
			self.sections["bindirs"] += tmp


class FileInfo(object):
	"""Per-file scratch record handed to the checkers while one file is examined."""
	def __init__(self, f, filename=None):
		"""f is the open file object; filename falls back to f.name when it is not given (or empty)."""
		if not filename:
			filename = f.name
		# open file object the checkers read from
		self.f = f
		# name under which the file is reported
		self.filename = filename
		# set to True by a checker that finds a problem
		self.broken = False
		# names of the dependencies that could not be satisfied
		self.needs = set()
		# paths of further files that should be checked because this file uses them
		self.check = []

class BaseChecker(object):
	"""Interface shared by all checkers (the ELF and .la checkers are
	examples). Subclasses override handles(), check() and cleanup(); the
	versions here only raise NotImplementedError."""

	def __init__(self, searcher):
		# searcher used by subclasses to locate dependencies
		self.searcher = searcher

	def handles(self, finfo):
		"""Return True when this checker is able to examine finfo"""
		raise NotImplementedError()

	def check(self, finfo):
		"""Examine finfo and record on it whether the file is broken"""
		raise NotImplementedError()

	def cleanup(self, finfo):
		"""Remove whatever this checker stored on finfo"""
		raise NotImplementedError()

class ELFChecker(BaseChecker):
	"""Checker for ELF files: a file is broken when one of its DT_NEEDED
	entries cannot be resolved to a library with the same ELF class."""
	FriendlyName = "ELF Checker"

	def _setup_elf(self, finfo):
		"""Parse finfo.f once and cache the result as finfo.elf (None when
		parsing fails with elf.Invalid or struct.error)."""
		if not hasattr(finfo, "elf"):
			try:
				finfo.elf = elf.ELF(finfo.f)
			except (elf.Invalid, struct.error):
				finfo.elf = None

	def _get_extraSearch(self, finfo, dynes):
		"""Return the extra library directories named by the DT_RPATH and
		DT_RUNPATH entries in dynes, with $ORIGIN replaced by the directory
		of finfo.filename."""
		extraSearch = []
		# check for DT_RPATH and DT_RUNPATH
		for entry in dynes:
			if entry.d_tag == entry.DT_RPATH or entry.d_tag == entry.DT_RUNPATH:
				tmpl = Template(entry.data)
				replacements = {"ORIGIN": os.path.dirname(finfo.filename)}
				# safe_substitute leaves unknown $placeholders untouched
				data = tmpl.safe_substitute(**replacements)
				extraSearch += data.split(":")
		return extraSearch

	def _same_class(self, path, klass):
		"""Return True if the file at path parses as ELF and its ELF class
		equals klass. The file is closed again before returning, and a file
		that is not valid ELF counts as a non-match instead of aborting the
		whole scan."""
		try:
			with open(path, "rb") as f:
				return elf.ELF(f).e_ident.klass == klass
		except (elf.Invalid, struct.error):
			return False

	def handles(self, finfo):
		"""Returns True if the file is handled by this checker"""
		# .la files are left to the libtool checker
		if finfo.filename.endswith(".la"):
			return False
		self._setup_elf(finfo)
		return finfo.elf is not None

	def check(self, finfo):
		"""Checks if this file is broken

		For every DT_NEEDED entry the candidates returned by
		searcher.AbsAll() are tried in order; the first one with a matching
		ELF class is queued in finfo.check. When none matches, the entry is
		added to finfo.needs and finfo.broken is set.
		"""
		self._setup_elf(finfo)
		e = finfo.elf
		if e is None:
			return
		dynes = e.getDynamicEntries()
		extraSearch = self._get_extraSearch(finfo, dynes)
		for entry in dynes:
			if entry.d_tag != entry.DT_NEEDED:
				continue
			# XXX: checking the ELF class here might be wrong, check what makes libraries compatible for dynamic linking
			paths = self.searcher.AbsAll(entry.data, extraSearch=extraSearch, cache=True, sections=["libdirs"])
			for path in paths:
				if self._same_class(path, e.e_ident.klass):
					finfo.check.append(path)
					break
			else:
				# no candidate was usable
				finfo.needs.add(entry.data)
				finfo.broken = True

	def cleanup(self, finfo):
		"""Drop the cached finfo.elf attribute, if any"""
		if hasattr(finfo, "elf"):
			del finfo.elf

class ScriptChecker(BaseChecker):
	"""Checker for scripts: a file starting with "#!" is broken when the
	interpreter it names (or the program run through env) does not exist."""
	FriendlyName = "Script Checker"

	def _setup_firstline(self, finfo):
		"""Read the first line of finfo.f once and cache it, stripped of
		trailing whitespace, as finfo.firstline."""
		if not hasattr(finfo, "firstline"):
			finfo.firstline = finfo.f.readline().rstrip()

	def handles(self, finfo):
		"""Returns True if the file is handled by this checker"""
		self._setup_firstline(finfo)
		return finfo.firstline[:2] == "#!"

	def check(self, finfo):
		"""Checks if this file is broken

		The interpreter path must exist. When the interpreter is env, the
		program it is asked to run is additionally looked up in the
		"bindirs" section of the searcher. Existing files are queued in
		finfo.check, missing ones are added to finfo.needs.
		"""
		self._setup_firstline(finfo)
		if finfo.firstline[:2] != "#!":
			return
		command = finfo.firstline[2:].split()
		if not command:
			# a bare "#!" names no interpreter, so there is nothing to verify
			return
		interpreter = command[0]
		if os.path.exists(interpreter):
			finfo.check.append(interpreter)
		else:
			finfo.broken = True
			finfo.needs.add(interpreter)
		if interpreter[-4:] == "/env" or interpreter == "env":
			# Special Case, env
			if len(command) < 2:
				# env without a program to run, nothing more to look up
				return
			program = command[1]
			path = self.searcher.Abs(program, cache=True, sections=["bindirs"])
			if path:
				finfo.check.append(path)
			else:
				finfo.broken = True
				# report the program name, not the empty search result
				finfo.needs.add(program)

	def cleanup(self, finfo):
		"""Drop the cached finfo.firstline attribute, if any"""
		if hasattr(finfo, "firstline"):
			del finfo.firstline

class ElfMatchChecker(ELFChecker):
	"""ELF checker that flags files by library name: a file is broken when
	one of its DT_NEEDED entries is accepted by the match function."""
	def __init__(self, mfunc, searcher):
		"""mfunc is called with each DT_NEEDED name; a true result marks the file broken."""
		ELFChecker.__init__(self, searcher)
		self.match = mfunc

	def check(self, finfo):
		"""Checks if this file is broken

		Matching DT_NEEDED names are added to finfo.needs and set
		finfo.broken. Independently of the match, every needed library the
		searcher can locate is queued in finfo.check.
		"""
		self._setup_elf(finfo)
		e = finfo.elf
		if e is None:
			return
		dynes = e.getDynamicEntries()
		extraSearch = self._get_extraSearch(finfo, dynes)
		for entry in dynes:
			if entry.d_tag != entry.DT_NEEDED:
				continue
			if self.match(entry.data):
				finfo.needs.add(entry.data)
				finfo.broken = True
			path = self.searcher.Abs(entry.data, extraSearch=extraSearch, cache=True, sections=["libdirs"])
			if path:
				finfo.check.append(path)

class LaChecker(BaseChecker):
	"""Checker for libtool .la files: verifies the dlname library and the
	entries of dependency_libs."""
	dlname_re = re.compile("dlname='(.*?)'")
	deplibs_re = re.compile("dependency_libs='(.*?)'")
	libdir_re = re.compile("libdir='(.*?)'")

	def _setup_matches(self, finfo):
		"""Read finfo.f once and cache the dlname / deplibs / libdir regex
		matches (match object or None) in the finfo.lamatches dict."""
		if hasattr(finfo, "lamatches"):
			return
		found = finfo.lamatches = {}
		text = finfo.f.read()
		found["dlname"] = self.dlname_re.search(text)
		found["deplibs"] = self.deplibs_re.search(text)
		found["libdir"] = self.libdir_re.search(text)

	def handles(self, finfo):
		"""Return True for a file named *.la that contains both a dlname and
		a dependency_libs assignment"""
		# More Sophisticated Checking?
		if not finfo.filename.endswith(".la"):
			return False
		self._setup_matches(finfo)
		found = finfo.lamatches
		if found["dlname"] and found["deplibs"]:
			return True
		return False

	def _check_dlname(self, finfo, dlname, extraSearch):
		"""Look dlname up in the library directories; queue it when found, mark finfo broken otherwise"""
		located = self.searcher.Abs(dlname, extraSearch=extraSearch, cache=True, sections=["libdirs"])
		if located:
			finfo.check.append(located)
			return
		finfo.needs.add(dlname)
		finfo.broken = True

	def _check_norm_deplib(self, finfo, lib):
		"""Check a dependency given as a plain path: queue it when it exists, mark finfo broken otherwise"""
		if not os.path.exists(lib):
			finfo.needs.add(lib)
			finfo.broken = True
			return
		finfo.check.append(lib)

	def _check_dashl_deplib(self, finfo, dynname, staticname, extraSearch):
		"""Check a -l dependency: finfo is broken when neither the shared nor the static library is found"""
		shared = self.searcher.Abs(dynname, extraSearch=extraSearch, cache=False, sections=["libdirs"])
		static = self.searcher.Abs(staticname, extraSearch=extraSearch, cache=False, sections=["libdirs"])
		# TODO: do checking of the dynamic libs?
		if not (shared or static):
			finfo.needs.add(dynname)
			finfo.broken = True

	def check(self, finfo):
		"""Check the dlname library first, then the plain-path dependencies,
		then the -l dependencies (searched with the -L directories as extra
		search paths). Does nothing unless both dlname and dependency_libs
		were found in the file."""
		self._setup_matches(finfo)
		found = finfo.lamatches
		m_dlname, m_deplibs, m_libdir = found["dlname"], found["deplibs"], found["libdir"]
		if not (m_dlname and m_deplibs):
			return
		dlname = m_dlname.group(1)
		tokens = m_deplibs.group(1).split()
		if m_libdir:
			libdir = m_libdir.group(1)
		else:
			libdir = None
		if libdir:
			search = [libdir]
		else:
			search = None
		# wtf, you can have args in deplibs :(, ignore those for now (TODO: handle at least -l properly)
		plain_paths = []
		dash_l = []
		dash_L = []
		for token in tokens:
			if not token.startswith("-"):
				plain_paths.append(token)
			elif token.startswith("-l"):
				dash_l.append(token[2:])
			elif token.startswith("-L"):
				dash_L.append(token[2:])
		if dlname: # apparently you can have .la's with no dlname
			self._check_dlname(finfo, dlname, search)
		# check the dependency libs
		for lib in plain_paths:
			self._check_norm_deplib(finfo, lib)
		for libname in dash_l:
			self._check_dashl_deplib(finfo, "lib" + libname + ".so", "lib" + libname + ".a", dash_L)

	def cleanup(self, finfo):
		"""Drop the cached finfo.lamatches attribute, if any"""
		if hasattr(finfo, "lamatches"):
			del finfo.lamatches

class LaMatchChecker(LaChecker):
	"""libtool .la checker that flags files by name: a file is broken when
	its dlname or one of its dependencies is accepted by the match function."""
	def __init__(self, mfunc, searcher):
		"""mfunc is called with a library name; a true result marks the file broken."""
		LaChecker.__init__(self, searcher)
		self.match = mfunc

	def _flag(self, finfo, name):
		"""Record name as a needed library and mark finfo broken"""
		finfo.needs.add(name)
		finfo.broken = True

	def _check_dlname(self, finfo, dlname, extraSearch):
		"""Mark finfo broken when dlname matches; extraSearch is unused here"""
		if self.match(dlname):
			self._flag(finfo, dlname)

	def _check_norm_deplib(self, finfo, lib):
		"""Mark finfo broken when the dependency path matches"""
		if self.match(lib):
			self._flag(finfo, lib)

	def _check_dashl_deplib(self, finfo, dynname, staticname, extraSearch):
		"""Mark finfo broken (recording dynname) when the shared or the static name matches"""
		if self.match(dynname) or self.match(staticname):
			self._flag(finfo, dynname)

# Checker classes selectable by name with the -a/--add option.
CHECKERS = {
	"ELF": ELFChecker,
	"Script": ScriptChecker,
	"la": LaChecker,
}

# Name-matching checker classes selectable with -L/--lib-checker; main()
# uses them when -l/--library is given.
LIB_CHECKERS = {
	"ELF": ElfMatchChecker,
	"la": LaMatchChecker,
}

# XXX: portage specific
def get_package_files(package):
	"""Return the paths recorded as "obj" entries in the contents of the
	first installed package that matches the given atom."""
	import portage
	vartree = portage.db["/"]["vartree"]
	# only the first match is used
	matched = vartree.dep_match(package)[0]
	link = portage.dblink(myroot="/", settings=vartree.settings, *portage.catsplit(matched))
	contents = link.getcontents()
	objects = []
	for path in contents:
		if contents[path][0] == "obj":
			objects.append(path)
	return objects

class Checker(object):
	"""Runs the configured checkers over files and collects the broken ones"""
	def __init__(self, searcher, checkers=None):
		"""searcher locates files and dependencies; checkers is a list of
		callables that are each called with searcher to build a checker."""
		self.searcher = searcher
		# broken files, in the order they were found
		self.broken = []
		# every file that has been examined already
		self.checked = set()
		# broken filename -> set of missing dependencies
		self.needs = {}
		if checkers is None:
			checkers = []
		self.checkers = [checker(searcher) for checker in checkers]

	def check(self, filename):
		"""check if the file passed (and all libs depended on by that file) is broken.
		broken being something it links to doesn't exist in the configured search paths
		in the end the following members of this object can be used to see a few things:
			.broken: list of the broken files
			.checked: set of files which have been checked
			.needs: a dictionary of filename to the set of things it needs which don't exist

		Files that cannot be located or read are skipped silently. Only the
		first checker whose handles() accepts the file is run on it.
		"""
		if not os.path.exists(filename):
			filename = self.searcher.Abs(filename)
		if not filename:
			# the searcher could not locate the file (for example a dangling
			# symlink); skip it the same way unreadable files are skipped
			return
		if os.path.islink(filename):
			filename = realpath(filename)
		if filename in self.checked:
			return
		if not os.access(filename, os.R_OK):
			return
		# the with block closes the file even when a checker raises
		with open(filename, "r") as f:
			finfo = FileInfo(f, filename)
			for checker in self.checkers:
				if checker.handles(finfo):
					checker.check(finfo)
					break
				# rewind so the next checker sees the file from the start
				f.seek(0)
			for checker in self.checkers:
				checker.cleanup(finfo)
		# mark before recursing so dependency cycles terminate
		self.checked.add(filename)
		for fname in finfo.check:
			self.check(fname)
		if finfo.broken:
			self.broken.append(filename)
			self.needs[filename] = finfo.needs
			for handler in HANDLERS:
				handler.onBrokenFile(filename, finfo.needs)

	def checks(self, files):
		"""Checks a list of files"""
		for filename in files:
			self.check(filename)

	def checkEverything(self, ffunc=None):
		"""checks everything in the paths in the searcher

		Walks every "libdirs" and "bindirs" path and checks each file for
		which ffunc returns true (every file when ffunc is None). A directory
		reachable under several names is only walked once, and a directory
		that does not exist is skipped.
		"""
		import errno
		manager = nested(*[handler.onCheckEverything() for handler in HANDLERS])
		checked_dirs = set()
		with manager:
			for top in self.searcher.getPaths(["libdirs", "bindirs"]):
				# XXX: not sure if stat on directories is supposed to work everywhere, comment out this optimization if it doesn't
				try:
					st = os.stat(top)
				except OSError:
					# sys.exc_info() instead of binding the exception in the
					# except clause keeps this valid on every Python version
					if sys.exc_info()[1].errno != errno.ENOENT:
						raise
					# a missing directory has nothing to check
					continue
				key = (st.st_ino, st.st_dev)
				if key in checked_dirs:
					# already walked under a different name
					continue
				checked_dirs.add(key)
				for dirpath, dirs, files in fileutils.walk(top):
					# XXX: the above optimization is also possible here, I just have to rework walk to take an argument back if the directories should be iterated over,
					# or I could make it so you can delete from dirs/clear it possibly
					# (the former would be os.walk-incompatible, the latter would compatible)
					#
					# dirs[:] = [] when the directory has already been checked would probably work alright, just need to fix walk
					paths = [join(dirpath, x) for x in files]
					if ffunc:
						paths = [x for x in paths if ffunc(x)]
					for path in paths:
						self.check(path)

	def checkPackages(self, packages, ffunc=None, pfunc=get_package_files):
		"""Check the files of the given packages: pfunc maps a package to
		its files, ffunc (when given) filters them. Handlers are notified
		through onCheckEverything(), as in checkEverything()."""
		manager = nested(*[handler.onCheckEverything() for handler in HANDLERS])
		with manager:
			for package in packages:
				files = pfunc(package)
				if ffunc:
					files = [x for x in files if ffunc(x)]
				for fname in files:
					self.check(fname)

class Resolver(object):
	"""Class to resolve files to a package in portage"""
	def __init__(self, checker, exec_=True):
		"""checker supplies the broken files; exec_ decides whether
		emergeBroken() really executes emerge."""
		self.checker = checker
		self.exec_ = exec_
		# filename -> [package name, SLOT]
		self.pkgmappings = {}

	def resolve(self, files):
		"""Tries to Resolve all the files in files and adds a file: [packagename, SLOT] mapping
		to pkgmappings, this checks the CONTENTS of every package installed on the system,
		so it will take a while. I recommend using resolve once on as many binaries as possible"""
		import portage
		# files still waiting for an owner; shrinks as owners are found so
		# the difference is not recomputed for every package
		remaining = set(files)
		vtree = portage.db["/"]["vartree"]
		for package in vtree.getallcpv():
			if not remaining:
				# everything is resolved, no need to read more CONTENTS
				return
			link = portage.dblink(myroot="/", settings=vtree.settings, *portage.catsplit(package))
			contents = link.getcontents()
			owned = [fname for fname in remaining if fname in contents]
			if not owned:
				continue
			name = link.mysplit[0]
			slot = link.getstring("SLOT").strip()
			for fname in owned:
				self.pkgmappings[fname] = [name, slot]
			remaining.difference_update(owned)

	def resolveBroken(self):
		"""wrapper around resolve which passes in all the broken files in self.checker.broken"""
		manager = nested(*[handler.onResolveBroken() for handler in HANDLERS])
		with manager:
			self.resolve(self.checker.broken)

	def getPackages(self, files=None):
		"""Returns a list of package atoms from the package mappings

		files defaults to self.checker.broken; files without a mapping are
		ignored and every package is listed only once.
		"""
		if not files:
			files = self.checker.broken
		packages = []
		for filename in files:
			mapping = self.pkgmappings.get(filename)
			if mapping and mapping not in packages:
				packages.append(mapping)
		return [getPackageAtom(x) for x in packages]

	def emergeBroken(self, options=None):
		"""emerge broken packages, based on package mappings (run resolveBroken first)

		options are extra emerge arguments placed after DEFAULT_EMERGE_OPTS.
		Handlers are told the command line; when self.exec_ is true the
		process is then replaced by emerge.
		"""
		options = [] if options is None else list(options)
		packages = self.getPackages()
		command = "emerge %s %s %s" % (" ".join(DEFAULT_EMERGE_OPTS), " ".join(options), " ".join(packages))
		for handler in HANDLERS:
			handler.onEmergeBroken(command)
		if self.exec_:
			os.execvp("emerge", ["emerge"] + DEFAULT_EMERGE_OPTS + options + packages)

class BaseHandler(object):
	"""Event handler whose hooks all do nothing; subclasses override the
	hooks they are interested in."""

	def onStart(self):
		"""Hook run when the application starts"""
		return None

	@contextmanager
	def onCheckEverything(self):
		"""Context manager wrapped around Checker.checkEverything"""
		yield

	def onBrokenFile(self, filename, needs):
		"""Hook run for every broken file; needs holds what the file is missing"""
		return None

	@contextmanager
	def onResolveBroken(self):
		"""Context manager wrapped around Resolver.resolveBroken"""
		yield

	def onEmergeBroken(self, emerge_command):
		"""Hook run from Resolver.emergeBroken with the emerge command line"""
		return None

	def onNoProbs(self):
		"""Hook run when nothing is broken"""
		return None

class OutputHandler(BaseHandler):
	"""Handler that reports progress on sys.stdout through a formatter"""
	def __init__(self, formatter=formatters.get_formatter):
		"""formatter is called with the output stream and returns the formatter to write with"""
		BaseHandler.__init__(self)
		self.stream = sys.stdout
		self.f = formatter(self.stream)

	def onStart(self):
		"""Print the start-up banner"""
		banner = (
			"Checking reverse dependencies...",
			"",
			"Packages containing binaries and libraries broken by a package update will be emerged.",
			"",
		)
		for line in banner:
			self.f.write(line)
		self.stream.flush()

	@contextmanager
	def onCheckEverything(self):
		"""Print a heading before the scan and a blank line after it; the
		blank line is omitted when the scan raises"""
		out = self.f
		out.write(out.bold, out.fg("green"), "Checking for broken files...")
		self.stream.flush()
		yield
		out.write("")
		self.stream.flush()

	def onBrokenFile(self, filename, needs):
		"""Print one line naming the broken file and what it needs"""
		out = self.f
		missing = ", ".join(needs)
		out.write("  broken ", out.bold, filename, out.reset, " (needs %s)" % missing)
		self.stream.flush()

	@contextmanager
	def onResolveBroken(self):
		"""Print the heading with autoline switched off for that write, then
		" done" once resolving ends, even when it ends with an exception"""
		out = self.f
		try:
			out.autoline = False
			out.write(out.bold, out.fg("green"), "Resolving files to ebuilds...")
			self.stream.flush()
			out.autoline = True
			yield
		finally:
			for text in (" done", ""):
				out.write(text)
			self.stream.flush()

	def onEmergeBroken(self, emerge_command):
		"""Print the heading followed by the emerge command line"""
		out = self.f
		out.write(out.bold, out.fg("green"), "Emerging broken packages...")
		out.write("  %s" % emerge_command)
		self.stream.flush()

	def onNoProbs(self):
		"""Print the message shown when nothing is broken"""
		out = self.f
		out.write(out.bold, out.fg("green"), "No broken packages found, done")
		out.write("")
		self.stream.flush()

def realpath(filename):
	"""os.path.realpath with the results remembered in REALPCACHE"""
	resolved = REALPCACHE.get(filename)
	if not resolved:
		# not cached yet: resolve and remember the answer
		resolved = os.path.realpath(filename)
		REALPCACHE[filename] = resolved
	return resolved

def getPackageAtom(package):
	"""Format the first two items of package as "first:second" """
	name, slot = package[0], package[1]
	return "%s:%s" % (name, slot)

def default_filter(filename):
	"""Returns True if the file should be checked: its name contains ".so",
	it ends with ".la", or it is executable. Names ending in ".debug" are
	always rejected."""
	# filter out .debug files (as far as I can see .debug files don't have any dynamic sections, correct me if I'm wrong)
	if filename.endswith(".debug"):
		return False
	if ".so" in filename:
		return True
	if filename.endswith(".la"):
		return True
	if os.access(filename, os.X_OK):
		return True
	return False

def globs(l):
	"""Glob every pattern in l and return all results in one flat list, in pattern order"""
	return [hit for pattern in l for hit in glob(pattern)]

def buildOptParser():
	"""Build and return the OptionParser for the command line"""
	parser = OptionParser(usage="usage: %prog [options] [-- emerge options]")
	# TODO: I should probably combine -L and -a, maybe change to -C (--checkers, a comma-separated? list of checker names) instead of doing the default=only ELF magic
	# one (flags, add_option keyword arguments) pair per option, in help order
	option_table = [
		(("-N", "--no-colour"), dict(
			help="don't colourize output",
			action="store_false", dest="colour", default=True)),
		(("-c", "--colour"), dict(
			help="force colourization",
			action="store_true", dest="force_colour", default=False)),
		(("-f", "--full"), dict(
			help="check EVERY file, by default we only check executable or .so files (this will be SLOW)",
			action="store_const", dest="filter", const=None, default=default_filter)),
		(("-a", "--add"), dict(
			help="add a checker to run",
			action="append", dest="checkers", choices=list(CHECKERS))),
		(("-l", "--library"), dict(
			help="Emerge existing packages that use the library with MATCHNAME, MATCHNAME can be a standard filename pattern, or if -r is used a full regular expression matching the filename only",
			dest="matchname")),
		(("-L", "--lib-checker"), dict(
			help="specify which handler to use for the -l option",
			dest="lib_checker", choices=list(LIB_CHECKERS), default="ELF")),
		(("-r", "--use-re"), dict(
			help="use regular expressions to match filenames when used with --library",
			action="store_true", dest="use_re")),
		(("--no-rice",), dict(
			help="disable using psyco if available, psyco can't be used when profiling in some cases",
			action="store_false", dest="psyco", default=True)),
		(("-p", "--package"), dict(
			help="only check the given package(s) and their binary deps for breakage",
			action="append", dest="packages")),
		(("-D", "--dont-exec"), dict(
			help="don't exec emerge, just print what would be executed",
			action="store_false", dest="exec_", default=True)),
	]
	for flags, settings in option_table:
		parser.add_option(*flags, **settings)
	return parser

def main():
	"""Command line entry point: parse the options, scan for broken files,
	resolve them to packages and hand those to emerge. Positional arguments
	are passed on to emerge."""
	options, args = buildOptParser().parse_args()
	if options.psyco:
		# psyco is optional, carry on without it when it is not installed
		try:
			import psyco
			psyco.full()
		except ImportError:
			pass

	# pick the output formatter
	if not options.colour:
		formatter = formatters.PlainTextFormatter
	elif options.force_colour:
		formatter = partial(formatters.TerminfoFormatter, forcetty=True)
	else:
		formatter = formatters.get_formatter
	HANDLERS.add(OutputHandler(formatter))
	for handler in HANDLERS:
		handler.onStart()

	def nothing_broken():
		for handler in HANDLERS:
			handler.onNoProbs()

	s = Searcher()
	if options.matchname:
		# -l replaces the checker list with a single name-matching checker
		if options.use_re:
			mfunc = re.compile(options.matchname).match
		else:
			pattern = options.matchname
			mfunc = lambda name: fnmatch.fnmatch(name, pattern)
		options.checkers = [partial(LIB_CHECKERS[options.lib_checker], mfunc)]
	if not options.checkers:
		options.checkers = ["ELF"]
	# turn checker names into classes, keep callables that build a BaseChecker
	selected = []
	for checker in options.checkers:
		if checker in CHECKERS:
			selected.append(CHECKERS[checker])
		elif callable(checker) and isinstance(checker(None), BaseChecker):
			selected.append(checker)
	options.checkers = selected

	c = Checker(s, options.checkers)
	r = Resolver(c, exec_=options.exec_)
	if options.packages:
		c.checkPackages(options.packages, ffunc=options.filter)
	else:
		c.checkEverything(ffunc=options.filter)
	if not c.broken:
		nothing_broken()
		return
	r.resolveBroken()
	if not r.pkgmappings:
		nothing_broken()
		return
	r.emergeBroken(args)

if __name__ == "__main__":
	try:
		main()
	except KeyboardInterrupt:
		# on Ctrl-C, suppress the traceback and exit with status 1
		sys.exit(1)

#parseldconfigp()
#s = Searcher()
#c = Checker(s)
#r = Resolver(c)
#c.checkEverything(ffunc=default_filter)
#c.check("/opt/vmware/workstation/lib/bin/vmplayer")
#r.resolveBroken()
#print r.pkgmappings
#for lib in c.broken:
#	print lib