extras-repoclosure/yum Errors.py, NONE, 1.1 __init__.py, NONE, 1.1 comps.py, NONE, 1.1 config.py, NONE, 1.1 constants.py, NONE, 1.1 depsolve.py, NONE, 1.1 failover.py, NONE, 1.1 logger.py, NONE, 1.1 mdcache.py, NONE, 1.1 mdparser.py, NONE, 1.1 misc.py, NONE, 1.1 packages.py, NONE, 1.1 parser.py, NONE, 1.1 pgpmsg.py, NONE, 1.1 plugins.py, NONE, 1.1 repos.py, NONE, 1.1 sqlitecache.py, NONE, 1.1 sqlitesack.py, NONE, 1.1 transactioninfo.py, NONE, 1.1 update_md.py, NONE, 1.1

Michael Schwendt (mschwendt) fedora-extras-commits at redhat.com
Tue Jul 4 18:15:30 UTC 2006


Author: mschwendt

Update of /cvs/fedora/extras-repoclosure/yum
In directory cvs-int.fedora.redhat.com:/tmp/cvs-serv6236/yum

Added Files:
	Errors.py __init__.py comps.py config.py constants.py 
	depsolve.py failover.py logger.py mdcache.py mdparser.py 
	misc.py packages.py parser.py pgpmsg.py plugins.py repos.py 
	sqlitecache.py sqlitesack.py transactioninfo.py update_md.py 
Log Message:
include a local yum for checkForObsolete support


--- NEW FILE Errors.py ---
#!/usr/bin/python -tt
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Library General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
# Copyright 2004 Duke University



import exceptions


class YumBaseError(exceptions.Exception):
    """Common base class of the yum exceptions defined in this module.

       The optional payload handed to the constructor is stored unchanged
       in the instance's args attribute (None when omitted)."""
    def __init__(self, args=None):
        exceptions.Exception.__init__(self)    
        # NOTE(review): args is assigned as-is, not wrapped in a tuple;
        # how str() renders it depends on the interpreter's Exception
        # implementation - confirm before relying on it.
        self.args = args

class LockError(YumBaseError):
    """Raised when the yum lock cannot be obtained.

       errno -- numeric code supplied by the raiser
       msg   -- human readable description of the problem"""
    def __init__(self, errno, msg):
        YumBaseError.__init__(self)
        self.errno = errno
        self.msg = msg

    def __str__(self):
        # the base initialiser is called without a payload, so without this
        # printing the exception would give an empty string
        return '%s' % (self.msg,)
        
class DepError(YumBaseError):
    """Dependency related failure; the payload is kept in args."""
    def __init__(self, args=None):
        # the base initialiser stores the payload in self.args
        YumBaseError.__init__(self, args)

class RepoError(YumBaseError):
    """Repository related failure; the payload is kept in args."""
    def __init__(self, args=None):
        # the base initialiser stores the payload in self.args
        YumBaseError.__init__(self, args)

class ConfigError(YumBaseError):
    """Configuration related failure; the payload is kept in args."""
    def __init__(self, args=None):
        # the base initialiser stores the payload in self.args
        YumBaseError.__init__(self, args)
    
class MiscError(YumBaseError):
    """Failure in a miscellaneous helper; the payload is kept in args."""
    def __init__(self, args=None):
        # the base initialiser stores the payload in self.args
        YumBaseError.__init__(self, args)

class GroupsError(YumBaseError):
    """Package group related failure; the payload is kept in args."""
    def __init__(self, args=None):
        # the base initialiser stores the payload in self.args
        YumBaseError.__init__(self, args)

class InstallError(YumBaseError):
    """Install related failure; the payload is kept in args."""
    def __init__(self, args=None):
        # the base initialiser stores the payload in self.args
        YumBaseError.__init__(self, args)

class UpdateError(YumBaseError):
    """Update related failure; the payload is kept in args."""
    def __init__(self, args=None):
        # the base initialiser stores the payload in self.args
        YumBaseError.__init__(self, args)

class RemoveError(YumBaseError):
    """Removal related failure; the payload is kept in args."""
    def __init__(self, args=None):
        # the base initialiser stores the payload in self.args
        YumBaseError.__init__(self, args)


--- NEW FILE __init__.py ---
#!/usr/bin/python -tt
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Library General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
# Copyright 2005 Duke University


import os
import sys
import os.path
import rpm
import re
import fnmatch
import types
import errno
import time
import sre_constants
import glob


import Errors
import rpmUtils
import rpmUtils.updates
import rpmUtils.arch
import comps
import config
import parser
import repos
import misc
import transactioninfo
from urlgrabber.grabber import URLGrabError
import depsolve
import plugins


from packages import parsePackages, YumAvailablePackage, YumLocalPackage, YumInstalledPackage
from repomd import mdErrors
from constants import *
from repomd.packageSack import ListPackageSack

__version__ = '2.6.1'

class YumBase(depsolve.Depsolve):
    """This is a primary structure and base class. It houses the objects and
       methods needed to perform most things in yum. It is almost an abstract
       class in that you will need to add your own class above it for most
       real use."""
    
    def __init__(self):
        """initialise the depsolver base, create the repository storage and
           start out with plugins disabled"""
        depsolve.Depsolve.__init__(self)
        self.localPackages = [] # for local package handling
        self.localdbimported = 0

        self.repos = repos.RepoStorage() # class of repositories
        if not self.repos.sqlite:
            self.log(1,"Warning, could not load sqlite, falling back to pickle")

        # plugins stay off until doPluginSetup() is called
        self.disablePlugins()

    def _transactionDataFactory(self):
        """Factory method: returns a ConditionalTransactionData when the
           enable_group_conditionals option is set, a plain TransactionData
           otherwise"""
        if not self.conf.enable_group_conditionals:
            return transactioninfo.TransactionData()
        return transactioninfo.ConditionalTransactionData()

    def log(self, value, msg):
        """placeholder logger: prints msg on stdout, value is ignored"""
        print >> sys.stdout, msg

    def errorlog(self, value, msg):
        """placeholder error logger: prints msg on stderr, value is ignored"""
        stream = sys.stderr
        print >> stream, msg

    def filelog(self, value, msg):
        """placeholder file logger: prints msg on stdout, value is ignored"""
        print >> sys.stdout, msg
   
    def doGenericSetup(self, cache=0):
        """do a default setup for all the normal/necessary yum components,
           really just a shorthand for testing"""
        
        self.doConfigSetup()
        self.conf.cache = cache
        self.doTsSetup()
        self.doRpmDBSetup()
        self.doRepoSetup()
        self.doSackSetup()
        
    def doConfigSetup(self, fn='/etc/yum.conf', root='/'):
        """read the main configuration via config.readMainConfig(fn, root),
           expose its yumvar mapping on self and register the repositories
           found in the configuration"""

        mainconf = config.readMainConfig(fn, root)
        self.conf = mainconf
        self.yumvar = mainconf.yumvar
        self.getReposFromConfig()

    def getReposFromConfig(self):
        """read in repositories from config main and .repo files"""

        reposlist = []

        # Check yum.conf for repositories
        for section in self.conf.cfg.sections():
            # All sections except [main] are repositories
            if section == 'main': 
                continue

            try:
                thisrepo = config.readRepoConfig(self.conf.cfg, section, self.conf)
            except (Errors.RepoError, Errors.ConfigError), e:
                self.errorlog(2, e)
            else:
                reposlist.append(thisrepo)

        # Read .repo files from directories specified by the reposdir option
        # (typically /etc/yum.repos.d and /etc/yum/repos.d)
        parser = config.IncludedDirConfigParser(vars=self.yumvar)
        for reposdir in self.conf.reposdir:
            if os.path.exists(self.conf.installroot+'/'+reposdir):
                reposdir = self.conf.installroot + '/' + reposdir

            if os.path.isdir(reposdir):
                #XXX: why can't we just pass the list of files?
                files = ' '.join(glob.glob('%s/*.repo' % reposdir))
                #XXX: error catching here
                parser.read(files)

        # Check sections in the .repo files that were just slurped up
        for section in parser.sections():
            try:
                thisrepo = config.readRepoConfig(parser, section, self.conf)
            except (Errors.RepoError, Errors.ConfigError), e:
                self.errorlog(2, e)
            else:
                reposlist.append(thisrepo)

        # Got our list of repo objects, add them to the repos collection
        for thisrepo in reposlist:
            try:
                self.repos.add(thisrepo)
            except Errors.RepoError, e: 
                self.errorlog(2, e)
                continue

    def disablePlugins(self):
        '''Disable yum plugins by installing a DummyYumPlugins object
        as self.plugins.
        '''
        dummy = plugins.DummyYumPlugins()
        self.plugins = dummy
    
    def doPluginSetup(self, optparser=None, types=None):
        '''Initialise and enable yum plugins.

        If plugins are going to be used, call this soon after
        doConfigSetup() has been called.

        @param optparser: The OptionParser instance for this run (optional)
        @param types: A sequence specifying the types of plugins to load.
            This should be a sequence containing one or more of the
            yum.plugins.TYPE_...  constants. If None (the default), all
            plugins will be loaded.
        '''
        # Plugins are loaded first as they may affect the available
        # config options
        loaded = plugins.YumPlugins(self, self.conf.pluginpath, optparser,
                                    types)
        self.plugins = loaded

        # Let the plugins process the options they registered
        all_repos = self.repos.findRepos('*')
        self.plugins.parseopts(self.conf, all_repos)

        # Run the plugins' init hook
        self.plugins.run('init')

    def doTsSetup(self):
        """create the read-only transaction set, the transaction info, the
           rpmdb holder and the action transaction set. This cannot be done
           in __init__ because the installroot is not known there. Does
           nothing when read_ts already exists; raises YumBaseError when
           the installroot is not configured yet"""

        if hasattr(self, 'read_ts'):
            return

        installroot = self.conf.installroot
        if not installroot:
            raise Errors.YumBaseError('Setting up TransactionSets before config class is up')

        self.read_ts = rpmUtils.transaction.initReadOnlyTransaction(root=installroot)
        self.tsInfo = self._transactionDataFactory()
        self.rpmdb = rpmUtils.RpmDBHolder()
        self.initActionTs()
        
    def doRpmDBSetup(self):
        """sets up a holder object for important information from the rpmdb"""
        
        if not self.localdbimported:
            self.log(3, 'Reading Local RPMDB')
            self.rpmdb.addDB(self.read_ts)
            self.localdbimported = 1

    def closeRpmDB(self):
        """closes down the instances of the rpmdb we have wangling around"""
        if hasattr(self, 'rpmdb'):
            del self.rpmdb
            self.localdbimported = 0
        if hasattr(self, 'ts'):
            del self.ts.ts
            del self.ts
        if hasattr(self, 'read_ts'):
            del self.read_ts.ts
            del self.read_ts
        if hasattr(self, 'up'):
            del self.up
        if hasattr(self, 'comps'):
            self.comps.compiled = False
            

    def doRepoSetup(self, thisrepo=None):
        """grabs the repomd.xml for each enabled repository and sets up 
           the basics of the repository"""

        self.plugins.run('prereposetup')
        
        if thisrepo is None:
            repos = self.repos.listEnabled()
        else:
            repos = self.repos.findRepos(thisrepo)

        if len(repos) < 1:
            self.errorlog(0, 'No Repositories Available to Set Up')

        num = 1
        for repo in repos:
            if repo.repoXML is not None and len(repo.urls) > 0:
                num += 1
                continue
            if self.repos.callback:
                self.repos.callback.log(2, '%-68s [%d/%d]' % 
                                        (repo.id, num, len(repos)))
                
            try:
                repo.cache = self.conf.cache
                repo.baseurlSetup()
                repo.dirSetup()
                self.log(3, 'Baseurl(s) for repo: %s' % repo.urls)
            except Errors.RepoError, e:
                self.errorlog(0, '%s' % e)
                raise
                
            try:
                repo.getRepoXML(text=repo)
            except Errors.RepoError, e:
                self.errorlog(0, 'Cannot open/read repomd.xml file for repository: %s' % repo)
                self.errorlog(0, str(e))
                raise
            num += 1
            
            
        if self.repos.callback and len(repos) > 0:
            self.repos.callback.progressbar(num, len(repos), repo.id)
            
        self.plugins.run('postreposetup')

    def doSackSetup(self, archlist=None, thisrepo=None):
        """fill the package sack from the selected repositories, apply the
           global and per-repository excludes/includes, run the 'exclude'
           plugin hook and build the sack indexes. archlist restricts the
           architectures kept (rpmUtils.arch.getArchList() when empty);
           without thisrepo nothing is done if a pkgSack already exists"""

        if hasattr(self, 'pkgSack') and thisrepo is None:
            self.log(7, 'skipping reposetup, pkgsack exists')
            return

        if thisrepo is not None:
            repolist = self.repos.findRepos(thisrepo)
        else:
            repolist = self.repos.listEnabled()

        self.log(3, 'Setting up Package Sacks')
        if not archlist:
            archlist = rpmUtils.arch.getArchList()

        # the sack wants the compatible arches as a dictionary
        self.repos.pkgSack.compatarchs = dict.fromkeys(archlist, 1)
        self.repos.populateSack(which=repolist)
        self.pkgSack = self.repos.pkgSack

        # global excludes first, then the arch filter
        self.excludePackages()
        self.pkgSack.excludeArchs(archlist)

        # per repository excludes and includes
        for repo in repolist:
            self.excludePackages(repo)
            self.includePackages(repo)

        self.plugins.run('exclude')
        self.pkgSack.buildIndexes()
        
    def doUpdateSetup(self):
        """build self.up, the rpmUtils.updates.Updates object, from the
           rpmdb and package sack lists and let it compute updates and,
           when the obsoletes option is set, obsoletes. Does nothing when
           self.up already exists; sets up repos and sack first when no
           pkgSack is present"""

        if hasattr(self, 'up'):
            return

        self.log(3, 'Building updates object')
        #FIXME - add checks for the other pkglists to see if we should
        # raise an error
        if not hasattr(self, 'pkgSack'):
            self.doRepoSetup()
            self.doSackSetup()

        installed = self.rpmdb.getPkgList()
        available = self.pkgSack.simplePkgList()
        self.up = rpmUtils.updates.Updates(installed, available)

        if self.conf.debuglevel >= 6:
            self.up.debug = 1

        if self.conf.obsoletes:
            self.up.rawobsoletes = self.pkgSack.returnObsoletes()

        self.up.exactarch = self.conf.exactarch
        self.up.exactarchlist = self.conf.exactarchlist
        self.up.doUpdates()

        if self.conf.obsoletes:
            self.up.doObsoletes()

        self.up.condenseUpdates()
        
    
    def doGroupSetup(self):
        """create the groups object that will store the comps metadata
           finds the repos with groups, gets their comps data and merge it
           into the group object"""
        
        self.log(3, 'Getting group metadata')
        reposWithGroups = []
        for repo in self.repos.listGroupsEnabled():
            if repo.groups_added: # already added the groups from this repo
                reposWithGroups.append(repo)
                continue
                
            if repo.repoXML is None:
                raise Errors.RepoError, "Repository '%s' not yet setup" % repo
            try:
                groupremote = repo.repoXML.groupLocation()
            except mdErrors.RepoMDError, e:
                pass
            else:
                reposWithGroups.append(repo)
                
        # now we know which repos actually have groups files.
        overwrite = self.conf.overwrite_groups
        if not hasattr(self, 'comps'):
            self.comps = comps.Comps(overwrite_groups = overwrite)

        for repo in reposWithGroups:
            if repo.groups_added: # already added the groups from this repo
                continue
                
            self.log(4, 'Adding group file from repository: %s' % repo)
            groupfile = repo.getGroups()
            try:
                self.comps.add(groupfile)
            except Errors.GroupsError, e:
                self.errorlog(0, 'Failed to add groups file for repository: %s' % repo)
            else:
                repo.groups_added = True

        if self.comps.compscount == 0:
            raise Errors.GroupsError, 'No Groups Available in any repository'
        
        self.doRpmDBSetup()
        pkglist = self.rpmdb.getPkgList()
        self.comps.compile(pkglist)
        

    def doSackFilelistPopulate(self):
        """convenience function to populate the repos with the filelist metadata
           it also is simply to only emit a log if anything actually gets populated"""
        
        necessary = False
        for repo in self.repos.listEnabled():
            if 'filelists' in self.pkgSack.added[repo.id]:
                continue
            else:
                necessary = True
        
        if necessary:
            msg = 'Importing additional filelist information'
            self.log(2, msg)
            self.repos.populateSack(with='filelists')
           
    def buildTransaction(self):
        """go through the packages in the transaction set, find them in the
           packageSack or rpmdb, and pack up the ts accordingly"""
        self.plugins.run('preresolve')
        (rescode, restring) = self.resolveDeps()
        self.plugins.run('postresolve', rescode=rescode, restring=restring)

        if self.tsInfo.changed:
            (rescode, restring) = self.resolveDeps()
            
        return rescode, restring

    def runTransaction(self, cb):
        """takes an rpm callback object, performs the transaction"""

        self.plugins.run('pretrans')

        errors = self.ts.run(cb.callback, '')
        if errors:
            raise Errors.YumBaseError, errors

        if not self.conf.keepcache:
            self.cleanHeaders()
            self.cleanPackages()

        self.plugins.run('posttrans')
        
    def excludePackages(self, repo=None):
        """delete the packages matching an exclude list from the package
           sack. Without repo the global exclude list is applied to the
           whole sack; with a repo object that repository's exclude list
           is applied to its own packages only"""

        if repo:
            # only this repository's packages and excludes
            excludelist, repoid = repo.exclude, repo.id
        else:
            # global excludes only
            excludelist, repoid = self.conf.exclude, None

        if len(excludelist) == 0:
            return

        if repo:
            self.log(2, 'Excluding Packages from %s' % repo.name)
        else:
            self.log(2, 'Excluding Packages in global exclude list')

        candidates = self.pkgSack.returnPackages(repoid)
        exact, globbed, unmatched = parsePackages(candidates, excludelist,
                                                  casematch=1)

        for po in exact + globbed:
            self.log(3, 'Excluding %s' % po)
            self.pkgSack.delPackage(po)

        self.log(2, 'Finished')

    def includePackages(self, repo):
        """reduce a repository's packages in the package sack to those
           matching its includepkgs list; nothing happens when that list
           is empty. Takes the repo object as a mandatory argument"""

        includelist = repo.includepkgs
        if len(includelist) == 0:
            return

        pkglist = self.pkgSack.returnPackages(repo.id)
        exact, globbed, unmatched = parsePackages(pkglist, includelist,
                                                  casematch=1)

        self.log(2, 'Reducing %s to included packages only' % repo.name)

        # first sort out what has to go ...
        wanted = exact + globbed
        rmlist = []
        for po in pkglist:
            if po not in wanted:
                rmlist.append(po)
            else:
                self.log(3, 'Keeping included package %s' % po)

        # ... then drop it from the sack
        for po in rmlist:
            self.log(3, 'Removing unmatched package %s' % po)
            self.pkgSack.delPackage(po)

        self.log(2, 'Finished')
        
    def doLock(self, lockfile):
        """perform the yum locking, raise yum-based exceptions, not OSErrors"""
        
        # if we're not root then we don't lock - just return nicely
        if self.conf.uid != 0:
            return
            
        root = self.conf.installroot
        lockfile = root + '/' + lockfile # lock in the chroot
        lockfile = os.path.normpath(lockfile) # get rid of silly preceding extra /
        
        mypid=str(os.getpid())    
        while not self._lock(lockfile, mypid, 0644):
            fd = open(lockfile, 'r')
            try: oldpid = int(fd.readline())
            except ValueError:
                # bogus data in the pid file. Throw away.
                self._unlock(lockfile)
            else:
                try: os.kill(oldpid, 0)
                except OSError, e:
                    if e[0] == errno.ESRCH:
                        # The pid doesn't exist
                        self._unlock(lockfile)
                    else:
                        # Whoa. What the heck happened?
                        msg = 'Unable to check if PID %s is active' % oldpid
                        raise Errors.LockError(1, msg)
                else:
                    # Another copy seems to be running.
                    msg = 'Existing lock %s: another copy is running. Aborting.' % lockfile
                    raise Errors.LockError(0, msg)
    
    def doUnlock(self, lockfile):
        """do the unlock for yum"""
        
        # if we're not root then we don't lock - just return nicely
        if self.conf.uid != 0:
            return
        
        root = self.conf.installroot
        lockfile = root + '/' + lockfile # lock in the chroot
        
        self._unlock(lockfile)
        
    def _lock(self, filename, contents='', mode=0777):
        lockdir = os.path.dirname(filename)
        try:
            if not os.path.exists(lockdir):
                os.makedirs(lockdir, mode=0755)
            fd = os.open(filename, os.O_EXCL|os.O_CREAT|os.O_WRONLY, mode)    
        except OSError, msg:
            if not msg.errno == errno.EEXIST: raise msg
            return 0
        else:
            os.write(fd, contents)
            os.close(fd)
            return 1
    
    def _unlock(self, filename):
        try:
            os.unlink(filename)
        except OSError, msg:
            pass


    def verifyPkg(self, fo, po, raiseError):
        """verifies the package is what we expect it to be
           raiseError  = defaults to 0 - if 1 then will raise
           a URLGrabError if the file does not check out.
           otherwise it returns false for a failure, true for success"""

        if type(fo) is types.InstanceType:
            fo = fo.filename

        for (csumtype, csum, csumid) in po.checksums:
            if csumid:
                checksum = csum
                checksumType = csumtype
                break
        try:
            self.verifyChecksum(fo, checksumType, checksum)
        except URLGrabError, e:
            if raiseError:
                raise
            else:
                return 0

        ylp = YumLocalPackage(self.read_ts, fo)
        if ylp.pkgtup != po.pkgtup:
            if raiseError:
                raise URLGrabError(-1, 'Package does not match intended download')
            else:
                return 0
        
        return 1
        
        
    def verifyChecksum(self, fo, checksumType, csum):
        """Verify the checksum of the file versus the 
           provided checksum"""

        try:
            filesum = misc.checksum(checksumType, fo)
        except Errors.MiscError, e:
            raise URLGrabError(-3, 'Could not perform checksum')
            
        if filesum != csum:
            raise URLGrabError(-1, 'Package does not match checksum')
        
        return 0
            
           
    def downloadPkgs(self, pkglist, callback=None):
        """download the package objects in pkglist that are not already
           present and verified in the local cache.

           Returns a dictionary mapping package objects to lists of error
           message strings; it is empty when nothing went wrong. The
           callback argument is not used by this method. Runs the
           'predownload' and 'postdownload' plugin hooks.

           NOTE(review): problems are collected in the returned dictionary;
           RepoError from the download is caught, not raised, here."""

        errors = {}
        self.plugins.run('predownload', pkglist=pkglist)
        repo_cached = False
        remote_pkgs = []
        # pass 1: work out which packages actually need to be fetched
        for po in pkglist:
            # local packages are never downloaded
            if hasattr(po, 'pkgtype') and po.pkgtype == 'local':
                continue
                    
            local = po.localPkg()
            if os.path.exists(local):
                cursize = os.stat(local)[6] # element 6 of a stat result is the size
                totsize = int(po.size())
                try:
                    result = self.verifyPkg(local, po, raiseError=1)
                except URLGrabError, e: # fails the check
                    
                    repo = self.repos.getRepo(po.repoid)
                    if repo.cache:
                        repo_cached = True
                        msg = 'package fails checksum but caching is enabled for %s' % repo.id
                        if not errors.has_key(po): errors[po] = []
                        errors[po].append(msg)
                        
                    # a full-sized file that fails the check is removed;
                    # a smaller one is kept around for regetting
                    if cursize >= totsize:
                        os.unlink(local)
                        
                else:
                    if result:
                        # already downloaded and verified
                        continue
                    else:
                        # same rule as above: only a full-sized file is removed
                        if cursize >= totsize:
                            os.unlink(local)
            remote_pkgs.append(po)
            
            # caching is enabled and the package 
            # just failed to check out there's no 
            # way to save this, report the error and return
            # NOTE(review): this check sits inside the loop, so it returns
            # at the first such package and skips the 'postdownload' hook
            if (self.conf.cache or repo_cached) and errors:
                return errors
                

        # pass 2: fetch what is left
        i = 0
        for po in remote_pkgs:
            i += 1
            repo = self.repos.getRepo(po.repoid)
            remote = po.returnSimple('relativepath')
            # handed to repo.get() so the download is checked with
            # verifyPkg(<file>, po, 1)
            checkfunc = (self.verifyPkg, (po, 1), {})


            # FIXME - add check here to make sure we have the disk space
            # available to download the package. If we don't then politely
            # bail out with an informative message.
            # os.statvfs(repo's local path)[4]*[2] >= po.size
            try:
                text = '(%s/%s): %s' % (i, len(remote_pkgs),
                                        os.path.basename(remote))
                local = po.localPkg()
                mylocal = repo.get(relative=remote,
                                   local=local,
                                   checkfunc=checkfunc,
                                   text=text,
                                   cache=repo.http_caching != 'none',
                                   )
            except Errors.RepoError, e:
                if not errors.has_key(po):
                    errors[po] = []
                errors[po].append(str(e))
            else:
                po.localpath = mylocal
                # the download succeeded, forget errors noted in pass 1
                if errors.has_key(po):
                    del errors[po]

        self.plugins.run('postdownload', pkglist=pkglist, errors=errors)

        return errors

    def verifyHeader(self, fo, po, raiseError):
        """check that the header file fo belongs to the package object po
           by comparing their pkgtup. fo is a file name or an instance
           carrying the name in its filename attribute. Returns 1 on
           success; on failure returns 0, or raises URLGrabError when
           raiseError is true"""
        if type(fo) is types.InstanceType:
            fo = fo.filename

        failure = None
        try:
            hdr = rpm.readHeaderListFromFile(fo)[0]
        except (rpm.error, IndexError):
            failure = 'Header is not complete.'
        else:
            # YumInstalledPackage is used because it takes headers <shrug>
            if YumInstalledPackage(hdr).pkgtup != po.pkgtup:
                failure = 'Header does not match intended download'

        if failure is None:
            return 1
        if raiseError:
            raise URLGrabError(-1, failure)
        return 0
        
    def downloadHeader(self, po):
        """make the header of the package object po available locally and
           record its path in po.hdrpath.

           A header already in the local cache is used when it verifies and
           removed when it does not. Nothing is done for local packages.
           Returns None. Raises Errors.RepoError when the header is not
           cached and caching-only mode is enabled, or when the download
           fails."""

        # local packages have no header to fetch
        if hasattr(po, 'pkgtype') and po.pkgtype == 'local':
            return
                
        errors = {} # not used by the code below
        local =  po.localHdr()
        start = po.returnSimple('hdrstart')
        end = po.returnSimple('hdrend')
        repo = self.repos.getRepo(po.repoid)
        remote = po.returnSimple('relativepath')
        if os.path.exists(local):
            try:
                result = self.verifyHeader(local, po, raiseError=1)
            except URLGrabError, e:
                # might add a check for length of file - if it is < 
                # required doing a reget
                # the cached copy is bad: remove it and download below
                try:
                    os.unlink(local)
                except OSError, e:
                    pass
            else:
                # the cached copy is fine
                po.hdrpath = local
                return
        else:
            if self.conf.cache:
                raise Errors.RepoError, \
                'Header not in local cache and caching-only mode enabled. Cannot download %s' % remote
        
        if self.dsCallback: self.dsCallback.downloadHeader(po.name)
        
        try:
            # repo.get() is asked for the range start..end of the package
            # file and to check the result with verifyHeader(<file>, po, 1)
            checkfunc = (self.verifyHeader, (po, 1), {})
            hdrpath = repo.get(relative=remote, local=local, start=start,
                    reget=None, end=end, checkfunc=checkfunc, copy_local=1,
                    cache=repo.http_caching != 'none',
                    )
        except Errors.RepoError, e:
            # remove whatever was left behind, then report the download
            # error whether or not the removal worked
            saved_repo_error = e
            try:
                os.unlink(local)
            except OSError, e:
                raise Errors.RepoError, saved_repo_error
            else:
                raise
        else:
            po.hdrpath = hdrpath
            return

    def sigCheckPkg(self, po):
        '''Take a package object and attempt to verify GPG signature if required

        Returns (result, error_string) where result is
            0 - GPG signature verifies ok or verification is not required
            1 - GPG verification failed but installation of the right GPG key might help
            2 - Fatal GPG verifcation error, give up
        '''
        if hasattr(po, 'pkgtype') and po.pkgtype == 'local':
            check = self.conf.gpgcheck
            hasgpgkey = 0
        else:
            repo = self.repos.getRepo(po.repoid)
            check = repo.gpgcheck
            hasgpgkey = not not repo.gpgkey 
        
        if check:
            sigresult = rpmUtils.miscutils.checkSig(self.read_ts, po.localPkg())
            localfn = os.path.basename(po.localPkg())
            
            if sigresult == 0:
                result = 0
                msg = ''

            elif sigresult == 1:
                if hasgpgkey:
                    result = 1
                else:
                    result = 2
                msg = 'Public key for %s is not installed' % localfn

            elif sigresult == 2:
                result = 2
                msg = 'Problem opening package %s' % localfn

            elif sigresult == 3:
                if hasgpgkey:
                    result = 1
                else:
                    result = 2
                result = 1
                msg = 'Public key for %s is not trusted' % localfn

            elif sigresult == 4:
                result = 2 
                msg = 'Package %s is not signed' % localfn
            
        else:
            result =0
            msg = ''

        return result, msg
        
    def cleanHeaders(self):
        """Remove the 'hdr' files collected from every enabled repo's hdrdir.

           Files that cannot be unlinked are reported through errorlog and
           skipped. Returns (0, [message]) where message states how many
           headers were removed."""
        hdr_files = []
        for repo in self.repos.listEnabled():
            repo.dirSetup()
            hdr_files = misc.getFileList(repo.hdrdir, 'hdr', hdr_files)

        count = 0
        for hdr_path in hdr_files:
            try:
                os.unlink(hdr_path)
            except OSError:
                # best effort: report the failure and keep going
                self.errorlog(0, 'Cannot remove header %s' % hdr_path)
            else:
                self.log(7, 'Header %s removed' % hdr_path)
                count += 1

        return 0, ['%d headers removed' % count]

            
    def cleanPackages(self):
        """Remove the 'rpm' files collected from every enabled repo's pkgdir.

           Files that cannot be unlinked are reported through errorlog and
           skipped. Returns (0, [message]) where message states how many
           packages were removed."""
        rpm_files = []
        for repo in self.repos.listEnabled():
            repo.dirSetup()
            rpm_files = misc.getFileList(repo.pkgdir, 'rpm', rpm_files)

        count = 0
        for rpm_path in rpm_files:
            try:
                os.unlink(rpm_path)
            except OSError:
                # best effort: report the failure and keep going
                self.errorlog(0, 'Cannot remove package %s' % rpm_path)
            else:
                self.log(7, 'Package %s removed' % rpm_path)
                count += 1

        return 0, ['%d packages removed' % count]

    def cleanPickles(self):
        """Remove the 'pickle' files collected from every enabled repo's
           cachedir.

           Files that cannot be unlinked are reported through errorlog and
           skipped. Returns (0, [message]) where message states how many
           cache files were removed."""
        pickle_files = []
        for repo in self.repos.listEnabled():
            repo.dirSetup()
            pickle_files = misc.getFileList(repo.cachedir, 'pickle', pickle_files)

        count = 0
        for pickle_path in pickle_files:
            try:
                os.unlink(pickle_path)
            except OSError:
                # best effort: report the failure and keep going
                self.errorlog(0, 'Cannot remove cache file %s' % pickle_path)
            else:
                self.log(7, 'Cache file %s removed' % pickle_path)
                count += 1

        return 0, ['%d cache files removed' % count]
    
    def cleanSqlite(self):
        """Remove the 'sqlite' files collected from every enabled repo's
           cachedir.

           Files that cannot be unlinked are reported through errorlog and
           skipped. Returns (0, [message]) where message states how many
           cache files were removed."""
        sqlite_files = []
        for repo in self.repos.listEnabled():
            repo.dirSetup()
            sqlite_files = misc.getFileList(repo.cachedir, 'sqlite', sqlite_files)

        count = 0
        for sqlite_path in sqlite_files:
            try:
                os.unlink(sqlite_path)
            except OSError:
                # best effort: report the failure and keep going
                self.errorlog(0, 'Cannot remove sqlite cache file %s' % sqlite_path)
            else:
                self.log(7, 'Cache file %s removed' % sqlite_path)
                count += 1

        return 0, ['%d cache files removed' % count]

    def cleanMetadata(self):
        """Remove the 'xml.gz', 'xml' and 'cachecookie' files collected from
           every enabled repo's cachedir.

           Files that cannot be unlinked are reported through errorlog and
           skipped. Returns (0, [message]) where message states how many
           metadata files were removed."""
        md_files = []
        # extension is the outer loop, so every repo is set up once per extension
        for suffix in ['xml.gz', 'xml', 'cachecookie']:
            for repo in self.repos.listEnabled():
                repo.dirSetup()
                md_files = misc.getFileList(repo.cachedir, suffix, md_files)

        count = 0
        for md_path in md_files:
            try:
                os.unlink(md_path)
            except OSError:
                # best effort: report the failure and keep going
                self.errorlog(0, 'Cannot remove metadata file %s' % md_path)
            else:
                self.log(7, 'metadata file %s removed' % md_path)
                count += 1

        return 0, ['%d metadata files removed' % count]

    def doPackageLists(self, pkgnarrow='all'):
        """Build the un-reduced package lists selected by pkgnarrow.

           pkgnarrow is one of 'all', 'updates', 'installed', 'available',
           'extras', 'obsoletes' or 'recent'; any other value leaves every
           list empty. Returns a misc.GenericHolder with the attributes
           installed, available, updates, obsoletes, obsoletesTuples,
           recent and extras; only those relevant to pkgnarrow are filled."""

        holder = misc.GenericHolder()

        installed = []
        available = []
        updates = []
        obsoletes = []
        obsoletesTuples = []
        recent = []
        extras = []

        if pkgnarrow in ('all', 'available'):
            # 'all' additionally reports what is installed; both modes list
            # the repo packages whose exact n/a/e/v/r is not in the rpmdb
            self.doRepoSetup()
            self.doRpmDBSetup()
            inst_tups = self.rpmdb.getPkgList()
            if pkgnarrow == 'all':
                for hdr in self.rpmdb.getHdrList():
                    installed.append(YumInstalledPackage(hdr))

            if self.conf.showdupesfromrepos:
                candidates = self.pkgSack.returnPackages()
            else:
                candidates = self.pkgSack.returnNewestByNameArch()

            for pkg in candidates:
                nevra = (pkg.name, pkg.arch, pkg.epoch, pkg.version, pkg.release)
                if nevra not in inst_tups:
                    available.append(pkg)

        elif pkgnarrow == 'updates':
            self.doRepoSetup()
            self.doRpmDBSetup()
            self.doUpdateSetup()
            for (n, a, e, v, r) in self.up.getUpdatesList():
                found = self.pkgSack.searchNevra(name=n, arch=a, epoch=e,
                                                 ver=v, rel=r)
                if len(found) == 0:
                    self.log(4, 'Nothing matches %s.%s %s:%s-%s from update' % (n,a,e,v,r))
                    continue

                # with several identical hits the first one wins
                updates.append(found[0])
                if len(found) > 1:
                    self.log(4, 'More than one identical match in sack for %s' % found[0])

        elif pkgnarrow == 'installed':
            self.doRpmDBSetup()
            installed = [YumInstalledPackage(hdr) for hdr in self.rpmdb.getHdrList()]

        elif pkgnarrow == 'extras':
            # installed but absent from every repo
            self.doRepoSetup()
            self.doRpmDBSetup()
            repo_tups = self.pkgSack.simplePkgList()
            for hdr in self.rpmdb.getHdrList():
                inst_po = YumInstalledPackage(hdr)
                if inst_po.pkgtup not in repo_tups:
                    extras.append(inst_po)

        elif pkgnarrow == 'obsoletes':
            self.doRepoSetup()
            self.doRpmDBSetup()
            # obsoletes processing is switched on in the config before the
            # update setup runs
            self.conf.obsoletes = 1
            self.doUpdateSetup()

            for (pkgtup, instTup) in self.up.getObsoletesTuples():
                (n, a, e, v, r) = pkgtup
                obsoleters = self.pkgSack.searchNevra(name=n, arch=a, ver=v, rel=r, epoch=e)
                # only the first installed header for the tuple is used
                inst_hdr = self.rpmdb.returnHeaderByTuple(instTup)[0]
                instpo = YumInstalledPackage(inst_hdr)
                for po in obsoleters:
                    obsoletes.append(po)
                    obsoletesTuples.append((po, instpo))

        elif pkgnarrow == 'recent':
            # conf.recent is a number of days; 86400 seconds per day
            cutoff = time.time() - (self.conf.recent*86400)
            by_ftime = {}
            self.doRepoSetup()
            if self.conf.showdupesfromrepos:
                candidates = self.pkgSack.returnPackages()
            else:
                candidates = self.pkgSack.returnNewestByNameArch()

            for po in candidates:
                ftime = int(po.returnSimple('filetime'))
                if ftime > cutoff:
                    by_ftime.setdefault(ftime, []).append(po)

            for same_time in by_ftime.values():
                recent.extend(same_time)

        holder.installed = installed
        holder.available = available
        holder.updates = updates
        holder.obsoletes = obsoletes
        holder.obsoletesTuples = obsoletesTuples
        holder.recent = recent
        holder.extras = extras

        return holder

    def _refineSearchPattern(self, arg):
        """Turn a search string from the cli (Search or Provides) into a
           regular expression source string.

           If arg contains any of * [ ] { } ? + it is treated as a shell
           glob and passed through fnmatch.translate(); otherwise it is
           taken literally and escaped with re.escape().

           Returns the regex source string (not a compiled pattern)."""

        # The old character class listed its members separated by commas,
        # which made ',' itself count as a glob character. The raw string
        # also avoids invalid-escape warnings on newer interpreters.
        if re.search(r'[*\[\]{}?+]', arg):
            return fnmatch.translate(arg)

        return re.escape(arg)
        
    def findDeps(self, pkgs):
        """Look up the requirements of the repo packages matching pkgs and
           the packages able to satisfy each requirement.

           pkgs is matched against the package sack with parsePackages();
           arguments without a match are reported through errorlog.

           Returns a dict of:
             packageobject -> {requirement tuple: [satisfying pkgs]}"""

        results = {}
        self.doRepoSetup()
        self.doRpmDBSetup()

        exact, globbed, missing = parsePackages(self.pkgSack.returnPackages(), pkgs)

        if len(missing) > 0:
            self.errorlog(0, 'No Match for arguments: %s' % missing)

        for pkg in misc.unique(exact + globbed):
            per_pkg = {}
            results[pkg] = per_pkg

            requires = pkg.returnPrco('requires')
            # sorts the list handed back by returnPrco in place
            requires.sort()

            for req in requires:
                (name, flags, version) = req
                # rpmlib() requirements are skipped
                if name.startswith('rpmlib('):
                    continue

                per_pkg[req] = [po for po in self.whatProvides(name, flags, version)]

        return results
    
    def searchGenerator(self, fields, criteria):
        """Generator method to lighten memory load for some searches.
           This is the preferred search function to use.

           fields   - names handed to po.returnSimple() on each package
           criteria - search strings; each is refined with
                      _refineSearchPattern() and compiled case-insensitively

           Yields (po, matched values list) tuples: first for packages in
           the package sack (one pass per criterion, so a package may be
           yielded more than once), then once per installed package that
           matched any criterion.

           Raises Errors.MiscError when a criterion does not compile."""
        self.doRepoSetup()

        # every criterion is compiled exactly once, here, and reused for the
        # installed-package pass instead of being recompiled per header
        compiled = []
        for string in criteria:
            restring = self._refineSearchPattern(string)
            try:
                crit_re = re.compile(restring, flags=re.I)
            except re.error:
                raise Errors.MiscError(
                    'Search Expression: %s is an invalid Regular Expression.\n' % string)
            compiled.append(crit_re)

            for po in self.pkgSack:
                tmpvalues = []
                for field in fields:
                    value = po.returnSimple(field)
                    if value and crit_re.search(value):
                        tmpvalues.append(value)

                if len(tmpvalues) > 0:
                    yield (po, tmpvalues)

        # do the same for installed pkgs

        self.doRpmDBSetup()
        # walking the headers is the expensive part, so it is the outer loop
        for hdr in self.rpmdb.getHdrList():
            po = YumInstalledPackage(hdr)
            tmpvalues = []
            for crit_re in compiled:
                for field in fields:
                    value = po.returnSimple(field)
                    # list values are searched through their string form
                    if isinstance(value, list):
                        value = str(value)
                    if value and crit_re.search(value):
                        tmpvalues.append(value)

            if len(tmpvalues) > 0:
                yield (po, tmpvalues)
        
    def searchPackages(self, fields, criteria, callback=None):
        """Collect everything searchGenerator(fields, criteria) yields.

           callback, when given, is called as callback(po, matched values
           list) for each yielded result as it arrives.

           Returns a dict of dict[po] = all matched values for that po."""

        found = {}
        for (po, hits) in self.searchGenerator(fields, criteria):
            if callback:
                callback(po, hits)
            # a po yielded several times accumulates all of its hits
            found.setdefault(po, []).extend(hits)

        return found
    
    def searchPackageProvides(self, args, callback=None):
        """Find packages whose file entries or provides match any of args.

           The search runs in several passes:
             1. each arg is looked up as a dependency via
                returnPackagesByDep()
             2. the sack file lists are populated unless the arg matches
                one of a few path patterns (bin/, /etc/, /usr/lib/sendmail)
             3. file entries and provides names of the packages in the sack
                are searched with a case-insensitive regex built from arg
             4. the filenames, dirnames and provides of installed packages
                are searched the same way

           callback, when given, is called as callback(po, matched values
           list) whenever a package is recorded.

           Returns a dict of dict[po] = matched values list.
           Raises Errors.MiscError when an arg yields an invalid regex."""
        
        self.doRepoSetup()
        matches = {}
        
        # search deps the simple way first
        for arg in args:
            self.log(4, 'searching the simple way')
            pkgs = self.returnPackagesByDep(arg)
            for po in pkgs:
                # if the match is already in the list, just skip it
                if matches.has_key(po):
                    continue
                
                matches[po] = [arg]
                if callback:
                    callback(po, [arg])
                

        # search pkgSack - fully populate the worthwhile metadata to search
        # if it even vaguely matches
        self.log(4, 'fully populating the necessary data')
        for arg in args:
            matched = 0
            # args matching one of these patterns do not trigger the
            # file list population below
            globs = ['.*bin\/.*', '.*\/etc\/.*', '^\/usr\/lib\/sendmail$']
            for glob in globs:
                globc = re.compile(glob)
                if globc.match(arg):
                    matched = 1
            if not matched:
                self.doSackFilelistPopulate()

        for arg in args:
            # assume we have to search every package, unless we can refine the search set
            where = self.pkgSack
            
            # this is annoying. If the user doesn't use any glob or regex-like
            # characters then we can use the where 'like' search in sqlite
            # if they do use globs or regexes then we can't b/c the string
            # will no longer have much meaning to use it for matches
            
            if hasattr(self.pkgSack, 'searchAll'):
                if not re.match('.*[\*,\[,\],\{,\},\?,\+,\%].*', arg):
                    self.log(4, 'Using the like search on %s' % arg)
                    where = self.pkgSack.searchAll(arg, query_type='like')
            
            self.log(4, 'Searching %d packages' % len(where))
            self.log(4, 'refining the search expression of %s' % arg) 
            restring = self._refineSearchPattern(arg)
            self.log(4, 'refined search: %s' % restring)
            try: 
                arg_re = re.compile(restring, flags=re.I)
            except sre_constants.error, e:
                raise Errors.MiscError, \
                  'Search Expression: %s is an invalid Regular Expression.\n' % arg

            for po in where:
                self.log(5, 'searching package %s' % po)
                tmpvalues = []
                
                self.log(5, 'searching in file entries')
                for filetype in po.returnFileTypes():
                    for fn in po.returnFileEntries(ftype=filetype):
                        if arg_re.search(fn):
                            tmpvalues.append(fn)
                
                self.log(5, 'searching in provides entries')
                for (p_name, p_flag, (p_e, p_v, p_r)) in po.returnPrco('provides'):
                    if arg_re.search(p_name):
                        prov = po.prcoPrintable((p_name, p_flag, (p_e, p_v, p_r)))
                        tmpvalues.append(prov)

                # a package already recorded (by the dep search or an
                # earlier arg) keeps its first result and is not reported again
                if len(tmpvalues) > 0:
                    if not matches.has_key(po):
                        matches[po] = tmpvalues
                        if callback:
                            callback(po, tmpvalues)
                
        
        self.doRpmDBSetup()
        # installed rpms, too
        taglist = ['filenames', 'dirnames', 'provides']
        # compile every arg once up front; arg_re is a list from here on
        arg_re = []
        for arg in args:
            restring = self._refineSearchPattern(arg)

            try: reg = re.compile(restring, flags=re.I)
            except sre_constants.error, e:
                raise Errors.MiscError, \
                 'Search Expression: %s is an invalid Regular Expression.\n' % arg
            
            arg_re.append(reg)

        for hdr in self.rpmdb.getHdrList():
            po = YumInstalledPackage(hdr)
            tmpvalues = []
            searchlist = []
            # flatten the tag values of this package into one list of items
            for tag in taglist:
                tagdata = po.returnSimple(tag)
                if tagdata is None:
                    continue
                if type(tagdata) is types.ListType:
                    searchlist.extend(tagdata)
                else:
                    searchlist.append(tagdata)
            
            for reg in arg_re:
                for item in searchlist:
                    if reg.search(item):
                        tmpvalues.append(item)

            del searchlist

            # unlike the sack pass above, this assigns without checking
            # whether po is already a key
            if len(tmpvalues) > 0:
                if callback:
                    callback(po, tmpvalues)
                matches[po] = tmpvalues
            
            
        return matches

    def doGroupLists(self, uservisible=0):
        """Split self.comps.groups into installed and available groups.

           When uservisible is true, groups whose user_visible flag is
           false are left out of both lists.

           Returns the tuple (installed groups, available groups)."""

        installed = []
        available = []

        for grp in self.comps.groups:
            if grp.installed:
                bucket = installed
            else:
                bucket = available

            if not uservisible or grp.user_visible:
                bucket.append(grp)

        return installed, available
    
    
    def groupRemove(self, grpid):
        """Mark all the packages in this group to be removed.

           The group is flagged with toremove = True, self.remove() is
           called for each package name in the group, and every resulting
           transaction member gets the group's id appended to its groups
           list.

           Returns the list of transaction members that remove() produced.
           Raises Errors.GroupsError when no group matches grpid."""

        txmbrs_used = []
        self.doGroupSetup()

        thisgroup = self.comps.return_group(grpid)
        if not thisgroup:
            raise Errors.GroupsError("No Group named %s exists" % grpid)

        thisgroup.toremove = True
        for pkg in thisgroup.packages:
            txmbrs = self.remove(name=pkg)
            txmbrs_used.extend(txmbrs)
            for txmbr in txmbrs:
                txmbr.groups.append(thisgroup.groupid)

        return txmbrs_used

    def groupUnremove(self, grpid):
        """Unmark any packages in the group from being removed.

           The group is flagged with toremove = False. For every
           transaction member whose package name belongs to the group and
           whose state passes the filter below, grpid is dropped from the
           member's groups list; a member left with no groups is removed
           from the transaction.

           Raises Errors.GroupsError when no group matches grpid."""

        self.doGroupSetup()

        thisgroup = self.comps.return_group(grpid)
        if not thisgroup:
            raise Errors.GroupsError("No Group named %s exists" % grpid)

        thisgroup.toremove = False
        for pkg in thisgroup.packages:
            for txmbr in self.tsInfo:
                # NOTE(review): this filters on TS_INSTALL_STATES although the
                # method undoes a removal, and reads the state from txmbr.po
                # -- confirm both are intended
                if txmbr.po.name == pkg and txmbr.po.state in TS_INSTALL_STATES:
                    try:
                        txmbr.groups.remove(grpid)
                    except ValueError:
                        self.log(4, "package %s was not marked in group %s" % (txmbr.po, grpid))
                        continue

                    # if there aren't any other groups mentioned then remove the pkg
                    if len(txmbr.groups) == 0:
                        self.tsInfo.remove(txmbr.po.pkgtup)
        
        
    def selectGroup(self, grpid):
        """Mark the mandatory and default packages of a group for install.

           Group data is set up first if self.comps is empty. A group that
           is already selected is left alone and an empty list is returned.
           When conf.enable_group_conditionals is set, each conditional
           package is installed right away if its condition package is
           installed, otherwise its best candidates are stored in
           tsInfo.conditionals under the condition.

           Returns the list of transaction members added to the
           transaction set; each is tagged with the group's id.
           Raises Errors.GroupsError when no group matches grpid."""

        txmbrs_used = []
        if not self.comps:
            self.doGroupSetup()

        if not self.comps.has_group(grpid):
            raise Errors.GroupsError("No Group named %s exists" % grpid)

        group = self.comps.return_group(grpid)
        if not group:
            raise Errors.GroupsError("No Group named %s exists" % grpid)

        if group.selected:
            return txmbrs_used

        group.selected = True

        def _claim(members):
            # remember the members and tag each one with this group
            txmbrs_used.extend(members)
            for member in members:
                member.groups.append(group.groupid)

        names = group.mandatory_packages.keys() + group.default_packages.keys()
        for name in names:
            self.log(5, 'Adding package %s from group %s' % (name, group.groupid))
            try:
                members = self.install(name = name)
            except Errors.InstallError:
                self.log(3, 'No package named %s available to be installed' % name)
            else:
                _claim(members)

        if not self.conf.enable_group_conditionals:
            return txmbrs_used

        for condreq, cond in group.conditional_packages.iteritems():
            if self._isPackageInstalled(cond):
                try:
                    members = self.install(name = condreq)
                except Errors.InstallError:
                    # we don't care if the package doesn't exist
                    continue
                _claim(members)
                continue

            # Otherwise we hook into tsInfo.add
            candidates = self.pkgSack.searchNevra(name=condreq)
            if candidates:
                candidates = self.bestPackagesFromList(candidates)
            if self.tsInfo.conditionals.has_key(cond):
                self.tsInfo.conditionals[cond].extend(candidates)
            else:
                self.tsInfo.conditionals[cond] = candidates

        return txmbrs_used

    def deselectGroup(self, grpid):
        """Undo the install marks of the packages in a group.

           Group data is set up first if self.comps is empty. The group is
           flagged with selected = False; grpid is dropped from the groups
           list of every transaction member whose package name is in the
           group and whose po.state is in TS_INSTALL_STATES. A member left
           with no groups is removed from the transaction.

           Raises Errors.GroupsError when no group matches grpid."""
        if not self.comps:
            self.doGroupSetup()

        if not self.comps.has_group(grpid):
            raise Errors.GroupsError("No Group named %s exists" % grpid)

        group = self.comps.return_group(grpid)
        if not group:
            raise Errors.GroupsError("No Group named %s exists" % grpid)

        group.selected = False

        for pkgname in group.packages:
            for txmbr in self.tsInfo:
                if txmbr.po.name != pkgname:
                    continue
                if txmbr.po.state not in TS_INSTALL_STATES:
                    continue

                try:
                    txmbr.groups.remove(grpid)
                except ValueError:
                    self.log(4, "package %s was not marked in group %s" % (txmbr.po, grpid))
                    continue

                # nothing else asked for this package: take it out again
                if len(txmbr.groups) == 0:
                    self.tsInfo.remove(txmbr.po.pkgtup)

                    
        
    def getPackageObject(self, pkgtup):
        """Retrieve a package object for a pkgtuple.

           self.localPackages is consulted first; otherwise the package
           sack is queried with packagesByTuple(). When the sack holds
           several packages for the tuple the first one is returned - no
           attempt is made yet to pick the best one.

           Raises Errors.DepError when the sack has no package for pkgtup."""

        # look it up in the self.localPackages first:
        for po in self.localPackages:
            if po.pkgtup == pkgtup:
                return po

        pkgs = self.pkgSack.packagesByTuple(pkgtup)

        if len(pkgs) == 0:
            raise Errors.DepError(
                'Package tuple %s could not be found in packagesack' % str(pkgtup))

        # FIXME - with more than one candidate this is where we could figure
        # out which repository is the best one to pull from
        return pkgs[0]

    def getInstalledPackageObject(self, pkgtup):
        """Return a YumInstalledPackage built from the first header the
           rpmdb returns for pkgtup."""

        first_hdr = self.rpmdb.returnHeaderByTuple(pkgtup)[0]
        return YumInstalledPackage(first_hdr)
        
    def gpgKeyCheck(self):
        """Check for the presence of gpg keys in the rpmdb.

           A marker file '.gpgkeyschecked.yum' in conf.cachedir
           short-circuits the check: if it exists 1 is returned at once.
           Otherwise the rpmdb is queried for 'gpg-pubkey' entries and the
           marker file is created when at least one is found.

           Returns 0 if no keys, 1 if keys."""

        marker = self.conf.cachedir + '/.gpgkeyschecked.yum'
        if os.path.exists(marker):
            return 1

        ts = rpmUtils.transaction.initReadOnlyTransaction(root=self.conf.installroot)
        ts.pushVSFlags(~(rpm._RPMVSF_NOSIGNATURES|rpm._RPMVSF_NODIGESTS))
        key_iter = ts.dbMatch('name', 'gpg-pubkey')
        keycount = key_iter.count()
        # drop the match iterator before the transaction it came from
        del key_iter
        del ts

        if keycount == 0:
            return 0

        marker_dir = os.path.dirname(marker)
        if not os.path.exists(marker_dir):
            os.makedirs(marker_dir)

        # an empty file is all that is needed
        marker_fo = open(marker, 'w')
        marker_fo.close()
        return 1

    def returnPackagesByDep(self, depstring):
        """Pass in a generic [build]require string and this function will
           pass back the packages it finds providing that dep.

           depstring is one of:
             'dep (some operator) e:v-r'  - versioned dependency
             '/file/dep'                  - file dependency
             'packagename'

           Returns whatever returnPackages() gives for the sack that
           self.whatProvides() hands back.
           Raises Errors.YumBaseError when a versioned string does not
           split into exactly three parts or uses an unknown operator."""

        self.doRepoSetup()
        depname = depstring
        depflags = None
        depver = None

        # startswith() instead of depstring[0]: an empty string is treated as
        # a plain (unversioned) name rather than blowing up with IndexError
        if not depstring.startswith('/'):
            # not a file dep - look at it for being versioned
            if re.search('[>=<]', depstring):  # versioned
                try:
                    depname, flagsymbol, depver = depstring.split()
                except ValueError:
                    raise Errors.YumBaseError(
                        'Invalid versioned dependency string, try quoting it.')
                if not SYMBOLFLAGS.has_key(flagsymbol):
                    raise Errors.YumBaseError('Invalid version flag')
                depflags = SYMBOLFLAGS[flagsymbol]

        sack = self.whatProvides(depname, depflags, depver)
        return sack.returnPackages()
        

    def returnPackageByDep(self, depstring):
        """Pass in a generic [build]require string and get back the single
           package that _bestPackageFromList() picks among its providers.

           Raises Errors.YumBaseError ('No Package found for ...') when the
           lookup itself fails with a YumBaseError or nothing provides the
           dependency."""

        try:
            candidates = self.returnPackagesByDep(depstring)
        except Errors.YumBaseError:
            # any lookup error is reported as a plain "not found"
            raise Errors.YumBaseError('No Package found for %s' % depstring)

        best = self._bestPackageFromList(candidates)
        if best is not None:
            return best

        raise Errors.YumBaseError('No Package found for %s' % depstring)

    def returnInstalledPackagesByDep(self, depstring):
        """Pass in a generic [build]require string and this function will
           pass back the installed packages it finds providing that dep.

           depstring is one of:
             'dep (some operator) e:v-r'  - versioned dependency
             '/file/dep'                  - file dependency
             'packagename'

           Returns a list with one getInstalledPackageObject() result per
           package tuple reported by self.rpmdb.whatProvides().
           Raises Errors.YumBaseError when a versioned string does not
           split into exactly three parts or uses an unknown operator."""

        self.doRpmDBSetup()
        depname = depstring
        depflags = None
        depver = None

        # startswith() instead of depstring[0]: an empty string is treated as
        # a plain (unversioned) name rather than blowing up with IndexError
        if not depstring.startswith('/'):
            # not a file dep - look at it for being versioned
            if re.search('[>=<]', depstring):  # versioned
                try:
                    depname, flagsymbol, depver = depstring.split()
                except ValueError:
                    raise Errors.YumBaseError(
                        'Invalid versioned dependency string, try quoting it.')
                if not SYMBOLFLAGS.has_key(flagsymbol):
                    raise Errors.YumBaseError('Invalid version flag')
                depflags = SYMBOLFLAGS[flagsymbol]

        results = []
        for pkgtup in self.rpmdb.whatProvides(depname, depflags, depver):
            results.append(self.getInstalledPackageObject(pkgtup))

        return results


    def _bestPackageFromList(self, pkglist):
        """Pick a single best package object out of pkglist.

           An empty list gives None and a one-element list gives that
           element. Otherwise the list is reduced with
           returnNewestByNameArch(); among the survivors the shortest name
           wins, and between equally long names the one whose arch
           rpmUtils.arch.getBestArchFromList() prefers.

           Note: this is not aware of multilib so make sure you're only
           passing it packages of a single arch group."""

        if len(pkglist) == 0:
            return None

        if len(pkglist) == 1:
            return pkglist[0]

        sack = ListPackageSack()
        sack.addList(pkglist)
        newest = sack.returnNewestByNameArch() # get rid of all lesser vers

        best = newest[0]
        for candidate in newest[1:]:
            name_diff = len(candidate.name) - len(best.name)
            if name_diff < 0: # shortest name silliness
                best = candidate
            elif name_diff == 0:
                # same name length: let the arch decide
                preferred = rpmUtils.arch.getBestArchFromList([candidate.arch, best.arch])
                if preferred == candidate.arch:
                    best = candidate

        return best

    def bestPackagesFromList(self, pkglist, arch=None):
        """Takes a list of packages, returns the best packages.

           Packages whose arch is not in rpmUtils.arch.getArchList(arch)
           are dropped. The rest is split into multilib and singlelib
           packages so the two kinds are never compared with each other,
           and _bestPackageFromList() picks one package from each kind.

           Returns a list of at most two packages, multilib pick first."""

        compat_arches = rpmUtils.arch.getArchList(arch)
        multilib_pkgs = []
        singlelib_pkgs = []
        for po in pkglist:
            if po.arch in compat_arches:
                if rpmUtils.arch.isMultiLibArch(arch=po.arch):
                    multilib_pkgs.append(po)
                else:
                    singlelib_pkgs.append(po)

        chosen = []
        for same_kind in (multilib_pkgs, singlelib_pkgs):
            best = self._bestPackageFromList(same_kind)
            if best is not None:
                chosen.append(best)

        return chosen


    def install(self, po=None, **kwargs):
        """Try to mark for install the item specified.

           po     - a YumAvailablePackage or YumLocalPackage; when given it
                    is used as is and kwargs are ignored
           kwargs - either pattern=<string>, matched against the package
                    sack with parsePackages(), or the keywords understood
                    by _nevra_kwarg_parse(), from which the best packages
                    are picked

           For each candidate package:
             - if it is already in the transaction set its existing
               members are returned
             - if it is a key of self.up.updating_dict it is passed to
               update() instead
             - if it is already installed it is skipped
             - if something obsoletes it, the obsoleting package is
               installed in its place

           Returns the list of txmbr of the items it installs, including
           those created for an obsoleting package.
           Raises Errors.InstallError when po is not a package object,
           nothing was specified, or no package is available."""

        self.doRepoSetup()
        self.doSackSetup()
        self.doRpmDBSetup()
        self.doUpdateSetup()

        pkgs = []
        if po:
            if isinstance(po, (YumAvailablePackage, YumLocalPackage)):
                pkgs.append(po)
            else:
                raise Errors.InstallError(
                    'Package Object was not a package object instance')

        else:
            if not kwargs:
                raise Errors.InstallError('Nothing specified to install')

            if 'pattern' in kwargs:
                exactmatch, matched, unmatched = \
                    parsePackages(self.pkgSack.returnPackages(),[kwargs['pattern']] , casematch=1)
                pkgs.extend(exactmatch)
                pkgs.extend(matched)

            else:
                nevra_dict = self._nevra_kwarg_parse(kwargs)

                pkgs = self.pkgSack.searchNevra(name=nevra_dict['name'],
                     epoch=nevra_dict['epoch'], arch=nevra_dict['arch'],
                     ver=nevra_dict['version'], rel=nevra_dict['release'])

                if pkgs:
                    pkgs = self.bestPackagesFromList(pkgs)

        if len(pkgs) == 0:
            #FIXME - this is where we could check to see if it already installed
            # for returning better errors
            raise Errors.InstallError('No package(s) available to install')

        # FIXME - lots more checking here
        #  - install instead of erase
        #  - better error handling/reporting

        tx_return = []
        for po in pkgs:
            if self.tsInfo.exists(pkgtup=po.pkgtup):
                self.log(4, 'Package: %s  - already in transaction set' % po)
                tx_return.extend(self.tsInfo.getMembers(pkgtup=po.pkgtup))
                continue

            # make sure this shouldn't be passed to update:
            if self.up.updating_dict.has_key(po.pkgtup):
                txmbrs = self.update(po=po)
                tx_return.extend(txmbrs)
                continue

            # make sure it's not already installed
            if self.rpmdb.installed(name=po.name, arch=po.arch, epoch=po.epoch,
                    rel=po.rel, ver=po.ver):
                self.errorlog(2, 'Package %s already installed and latest version' % po)
                continue

            # make sure we're not installing a package which is obsoleted by something
            # else in the repo
            thispkgobsdict = self.up.checkForObsolete([po.pkgtup])
            if thispkgobsdict.has_key(po.pkgtup):
                obsoleting = thispkgobsdict[po.pkgtup][0]
                obsoleting_pkg = self.getPackageObject(obsoleting)
                # hand the members of the replacement back to the caller;
                # they used to be dropped, so callers tagging the returned
                # members never saw the package that was really queued
                tx_return.extend(self.install(po=obsoleting_pkg))
                continue

            txmbr = self.tsInfo.addInstall(po)
            tx_return.append(txmbr)

        return tx_return

    
    def update(self, po=None, **kwargs):
        """Mark package(s) for update and return the resulting txmbr list.

        Three modes, chosen from the arguments:
          - neither po nor kwargs: every pair from self.up.getUpdatesTuples()
            is marked as an update and, when self.conf.obsoletes is set,
            every pair from self.up.getObsoletesTuples(newest=1) is marked
            as an obsoletion first.
          - po given: only that package object is considered and kwargs are
            ignored.  po.repoid == 'installed' makes it an installed
            package, anything else an available one.
          - kwargs only: the name/epoch/arch/ver(sion)/rel(ease) keywords
            select matching available packages from self.pkgSack and
            matching installed packages from self.rpmdb.

        In every mode a package that the transaction already records as
        obsoleted is logged and not updated.

        returns the list of txmbr of the items it marked for update
        """

        # do updates list
        # do obsoletes list

        # check for args - if no po nor kwargs, do them all
        # if po, do it, ignore all else
        # if no po do kwargs
        # uninstalled pkgs called for update get returned with errors in a list, maybe?

        self.doRepoSetup()
        self.doSackSetup()
        self.doRpmDBSetup()
        self.doUpdateSetup()
        # Both lists are computed up front, but only the "update everything"
        # branch below iterates them.
        updates = self.up.getUpdatesTuples()
        if self.conf.obsoletes:
            obsoletes = self.up.getObsoletesTuples(newest=1)
        else:
            obsoletes = []


        tx_return = []
        if not po and not kwargs.keys(): # update everything (the easy case)
            self.log(5, 'Updating Everything')
            # Obsoletions are recorded first so the update loop below can
            # skip anything that has just been obsoleted.
            for (obsoleting, installed) in obsoletes:
                obsoleting_pkg = self.getPackageObject(obsoleting)
                # [0]: take the first header returned for this tuple
                hdr = self.rpmdb.returnHeaderByTuple(installed)[0]
                installed_pkg =  YumInstalledPackage(hdr)
                txmbr = self.tsInfo.addObsoleting(obsoleting_pkg, installed_pkg)
                self.tsInfo.addObsoleted(installed_pkg, obsoleting_pkg)
                # only the obsoleting side's txmbr is reported to the caller
                tx_return.append(txmbr)

            for (new, old) in updates:
                if self.tsInfo.isObsoleted(pkgtup=old):
                    # 'old' is a 5-item package tuple and fills all five %s slots
                    self.log(5, 'Not Updating Package that is already obsoleted: %s.%s %s:%s-%s' % old)
                else:
                    updating_pkg = self.getPackageObject(new)
                    hdr = self.rpmdb.returnHeaderByTuple(old)[0]
                    updated_pkg = YumInstalledPackage(hdr)
                    txmbr = self.tsInfo.addUpdate(updating_pkg, updated_pkg)
                    tx_return.append(txmbr)

            return tx_return

        else:
            # Sort what was asked for into installed and available packages.
            instpkgs = []
            availpkgs = []
            if po: # just a po
                if po.repoid == 'installed':
                    instpkgs.append(po)
                else:
                    availpkgs.append(po)

            else: # we have kwargs, sort them out.
                nevra_dict = self._nevra_kwarg_parse(kwargs)

                availpkgs = self.pkgSack.searchNevra(name=nevra_dict['name'],
                          epoch=nevra_dict['epoch'], arch=nevra_dict['arch'],
                        ver=nevra_dict['version'], rel=nevra_dict['release'])

                installed_tuples = self.rpmdb.returnTupleByKeyword(
                                name=nevra_dict['name'], epoch=nevra_dict['epoch'],
                                arch=nevra_dict['arch'], ver=nevra_dict['version'],
                                rel=nevra_dict['release'])

                # wrap each matching installed tuple in a package object
                for tup in installed_tuples:
                    hdr = self.rpmdb.returnHeaderByTuple(tup)[0]
                    installed_pkg =  YumInstalledPackage(hdr)
                    instpkgs.append(installed_pkg)

            # for any thing specified
            # get the list of available pkgs matching it (or take the po)
            # get the list of installed pkgs matching it (or take the po)
            # go through each list and look for:
            #   - things obsoleting it if it is an installed pkg
            #   - things it updates if it is an available pkg
            #   - things updating it if it is an installed pkg
            # in that order
            # all along checking to make sure we:
            #   - don't update something that's already been obsoleted

            # TODO: we should search the updates and obsoletes list and
            # mark the package being updated or obsoleted away appropriately
            # and the package relationship in the tsInfo

            # Pass 1: installed packages that something obsoletes (only when
            # obsoletes processing is enabled); the first obsoleter is used.
            for installed_pkg in instpkgs:
                if self.up.obsoleted_dict.has_key(installed_pkg.pkgtup) and self.conf.obsoletes:
                    obsoleting = self.up.obsoleted_dict[installed_pkg.pkgtup][0]
                    obsoleting_pkg = self.getPackageObject(obsoleting)
                    # FIXME check for what might be in there here
                    txmbr = self.tsInfo.addObsoleting(obsoleting_pkg, installed_pkg)
                    self.tsInfo.addObsoleted(installed_pkg, obsoleting_pkg)
                    tx_return.append(txmbr)

            # Pass 2: available packages that update something installed;
            # the first entry of updating_dict for the package is used.
            for available_pkg in availpkgs:
                if self.up.updating_dict.has_key(available_pkg.pkgtup):
                    updated = self.up.updating_dict[available_pkg.pkgtup][0]
                    if self.tsInfo.isObsoleted(updated):
                        self.log(5, 'Not Updating Package that is already obsoleted: %s.%s %s:%s-%s' % updated)
                    else:
                        hdr = self.rpmdb.returnHeaderByTuple(updated)[0]
                        updated_pkg =  YumInstalledPackage(hdr)
                        txmbr = self.tsInfo.addUpdate(available_pkg, updated_pkg)
                        tx_return.append(txmbr)

            # Pass 3: installed packages for which an update exists; the
            # first entry of updatesdict for the package is used.
            for installed_pkg in instpkgs:
                if self.up.updatesdict.has_key(installed_pkg.pkgtup):
                    updating = self.up.updatesdict[installed_pkg.pkgtup][0]
                    updating_pkg = self.getPackageObject(updating)
                    if self.tsInfo.isObsoleted(installed_pkg.pkgtup):
                        self.log(5, 'Not Updating Package that is already obsoleted: %s.%s %s:%s-%s' % installed_pkg.pkgtup)
                    else:
                        txmbr = self.tsInfo.addUpdate(updating_pkg, installed_pkg)
                        tx_return.append(txmbr)

        return tx_return
        
        
    def remove(self, po=None, **kwargs):
        """Mark installed package(s) for removal and return their txmbr list.

        po, when given, is marked for erase directly.  Otherwise the
        name/epoch/arch/ver(sion)/rel(ease) keywords select the installed
        packages to erase.  Errors.RemoveError is raised when neither po
        nor any keyword is supplied.  When nothing matches, a message is
        written to the error log and an empty list is returned.
        """

        if not (po or kwargs.keys()):
            raise Errors.RemoveError('Nothing specified to remove')

        self.doRpmDBSetup()

        if po:
            targets = [po]
        else:
            wanted = self._nevra_kwarg_parse(kwargs)
            matches = self.rpmdb.returnTupleByKeyword(
                name=wanted['name'], epoch=wanted['epoch'],
                arch=wanted['arch'], ver=wanted['version'],
                rel=wanted['release'])
            # one installed-package object per matching tuple, built from
            # the first header the rpmdb returns for it
            targets = [YumInstalledPackage(self.rpmdb.returnHeaderByTuple(tup)[0])
                       for tup in matches]

        if not targets: # should this even be happening?
            self.errorlog(3, "No package matched to remove")

        return [self.tsInfo.addErase(pkg) for pkg in targets]

    def _nevra_kwarg_parse(self, kwargs):
            
        returndict = {}
        
        try: returndict['name'] = kwargs['name']
        except KeyError:  returndict['name'] = None

        try: returndict['epoch'] = kwargs['epoch']
        except KeyError: returndict['epoch'] = None

        try: returndict['arch'] = kwargs['arch']
        except KeyError: returndict['arch'] = None
        
        # get them as ver, version and rel, release - if someone
        # specifies one of each then that's kinda silly.
        try: returndict['version'] = kwargs['version']
        except KeyError: returndict['version'] = None
        if returndict['version'] is None:
            try: returndict['version'] = kwargs['ver']
            except KeyError: returndict['version'] = None

        try: returndict['release'] = kwargs['release']
        except KeyError: returndict['release'] = None
        if returndict['release'] is None:
            try: release = kwargs['rel']
            except KeyError: returndict['release'] = None
        
        return returndict

    def _isPackageInstalled(self, pkgname):
        """Tell whether pkgname will be installed once the transaction ran.

        Returns True when any transaction member for the name has an
        output state in TS_INSTALL_STATES.  Otherwise the rpmdb decides,
        except that an installed package with any other pending
        transaction member counts as not installed.
        """
        # FIXME: Taken from anaconda/pirut
        # clean up and make public
        in_rpmdb = bool(self.rpmdb.installed(name = pkgname))

        members = self.tsInfo.matchNaevr(name = pkgname)
        for member in members:
            if member.output_state in TS_INSTALL_STATES:
                return True

        # No member installs it.  If it is installed but still has members
        # in the tsInfo, those are an erase or obsoleted --> not going to be
        # installed at the end.
        return in_rpmdb and len(members) == 0


--- NEW FILE comps.py ---
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Library General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
# Copyright 2005 Duke University

import sys
from constants import *
from cElementTree import iterparse
import exceptions


# Attribute key, in "{namespace-uri}localname" form, under which the parsers
# below look up the language of a translated <name>/<description> element.
lang_attr = '{http://www.w3.org/XML/1998/namespace}lang'

def parse_boolean(strng):
    """Look strng up (lower-cased) in BOOLEAN_STATES; unknown text gives False."""
    return BOOLEAN_STATES.get(strng.lower(), False)

def parse_number(strng):
    """Return strng converted with int(); int() raises ValueError on bad text."""
    number = int(strng)
    return number

class CompsException(Exception):
    """Raised for malformed comps data or a missing comps source.

    Derives from the builtin Exception, which is the same class the
    'exceptions' module exposes on Python 2, so existing handlers keep
    working while the class no longer needs that module.
    """
        
class Group(object):
    """A package group read from a comps <group> element.

    The package lists are dicts keyed by package name: mandatory, default
    and optional packages map to 1, conditional packages map to the value
    of the 'requires' attribute of their <packagereq> element.
    translated_name / translated_description map a language code to the
    translated text.
    """

    def __init__(self, elem=None):
        """Create an empty group and fill it from elem when one is given."""
        self.user_visible = True
        self.default = False
        self.selected = False
        self.name = ""
        self.description = ""
        self.translated_name = {}
        self.translated_description = {}
        self.mandatory_packages = {}
        self.optional_packages = {}
        self.default_packages = {}
        self.conditional_packages = {}
        self.langonly = None ## FIXME - purpose of langonly is unclear
        self.groupid = None
        self.display_order = 1024
        self.installed = False
        self.toremove = False

        # Test against None instead of truth-testing the element: parsing
        # an element without children is a no-op either way.
        if elem is not None:
            self.parse(elem)

    def __str__(self):
        return self.name

    def _packageiter(self):
        """Return the names from all four package lists as one list."""
        return (list(self.mandatory_packages) +
                list(self.optional_packages) +
                list(self.default_packages) +
                list(self.conditional_packages))

    packages = property(_packageiter)

    def nameByLang(self, lang):
        """Return the name translated to lang, or the plain name if none."""
        # was has_key[lang]: subscripting the method raised TypeError
        return self.translated_name.get(lang, self.name)

    def descriptionByLang(self, lang):
        """Return the description translated to lang, or the plain one."""
        # was has_key[lang]: subscripting the method raised TypeError
        return self.translated_description.get(lang, self.description)

    def parse(self, elem):
        """Fill this group from the children of a <group> element.

        Raises CompsException when <id> or <langonly> occurs a second time.
        """
        for child in elem:
            tag = child.tag

            if tag == 'id':
                if self.groupid is not None:
                    raise CompsException
                self.groupid = child.text

            elif tag == 'name':
                text = child.text
                if text:
                    text = text.encode('utf8')

                lang = child.attrib.get(lang_attr)
                if lang:
                    self.translated_name[lang] = text
                else:
                    self.name = text

            elif tag == 'description':
                text = child.text
                if text:
                    text = text.encode('utf8')

                lang = child.attrib.get(lang_attr)
                if lang:
                    self.translated_description[lang] = text
                else:
                    self.description = text

            elif tag == 'uservisible':
                self.user_visible = parse_boolean(child.text)

            elif tag == 'display_order':
                self.display_order = parse_number(child.text)

            elif tag == 'default':
                self.default = parse_boolean(child.text)

            elif tag == 'langonly': ## FIXME - purpose of langonly is unclear
                if self.langonly is not None:
                    raise CompsException
                self.langonly = child.text

            elif tag == 'packagelist':
                self.parse_package_list(child)

    def parse_package_list(self, packagelist_elem):
        """Record every <packagereq> child in the package list for its type.

        A missing or empty 'type' attribute means 'mandatory'; any type
        other than mandatory/default/optional/conditional raises
        CompsException.
        """
        for child in packagelist_elem:
            if child.tag != 'packagereq':
                continue

            reqtype = child.attrib.get('type') or 'mandatory'
            if reqtype not in ('mandatory', 'default', 'optional', 'conditional'):
                raise CompsException

            package = child.text
            if reqtype == 'mandatory':
                self.mandatory_packages[package] = 1
            elif reqtype == 'default':
                self.default_packages[package] = 1
            elif reqtype == 'optional':
                self.optional_packages[package] = 1
            else:
                # conditional: remember which package triggers this one
                self.conditional_packages[package] = child.attrib.get('requires')

    def add(self, obj):
        """Add another group object to this object"""

        # we only need package lists and any translation that we don't already
        # have

        for pkg in obj.mandatory_packages:
            self.mandatory_packages[pkg] = 1
        for pkg in obj.default_packages:
            self.default_packages[pkg] = 1
        for pkg in obj.optional_packages:
            self.optional_packages[pkg] = 1
        for pkg in obj.conditional_packages:
            self.conditional_packages[pkg] = obj.conditional_packages[pkg]

        # name and description translations: existing entries win
        for lang in obj.translated_name:
            if lang not in self.translated_name:
                self.translated_name[lang] = obj.translated_name[lang]

        for lang in obj.translated_description:
            if lang not in self.translated_description:
                self.translated_description[lang] = obj.translated_description[lang]
        
        
        


class Category(object):
    """A comps <category>: an id, names/descriptions and a set of group ids.

    Group ids are kept as the keys of the _groups dict (values are 1) and
    exposed through the read-only 'groups' property.
    """

    def __init__(self, elem=None):
        """Create an empty category and fill it from elem when one is given."""
        self._groups = {}
        self.translated_description = {}
        self.translated_name = {}
        self.display_order = 1024
        self.description = ""
        self.categoryid = None
        self.name = ""

        if elem:
            self.parse(elem)

    def __str__(self):
        return self.name

    def _groupiter(self):
        """Return the ids of the groups belonging to this category."""
        return self._groups.keys()

    groups = property(_groupiter)

    def parse(self, elem):
        """Fill this category from the children of a <category> element.

        Raises CompsException when <id> occurs a second time.
        """
        for child in elem:
            tag = child.tag

            if tag == 'id':
                if self.categoryid is not None:
                    raise CompsException
                self.categoryid = child.text

            elif tag in ('name', 'description'):
                # both tags share the text/language handling and differ
                # only in where the result is stored
                text = child.text
                if text:
                    text = text.encode('utf8')
                lang = child.attrib.get(lang_attr)

                if tag == 'name':
                    if lang:
                        self.translated_name[lang] = text
                    else:
                        self.name = text
                else:
                    if lang:
                        self.translated_description[lang] = text
                    else:
                        self.description = text

            elif tag == 'grouplist':
                self.parse_group_list(child)

            elif tag == 'display_order':
                self.display_order = parse_number(child.text)

    def parse_group_list(self, grouplist_elem):
        """Record the text of every <groupid> child as a member group id."""
        for child in grouplist_elem:
            if child.tag == 'groupid':
                self._groups[child.text] = 1

    def add(self, obj):
        """Add another category object to this object"""

        for grp in obj.groups:
            self._groups[grp] = 1

        # name and description translations: keep what is already here,
        # take over only the languages that are still missing
        for lang, text in obj.translated_name.items():
            self.translated_name.setdefault(lang, text)

        for lang, text in obj.translated_description.items():
            self.translated_description.setdefault(lang, text)

        
class Comps:
    """Container for the groups and categories read from comps files.

    Groups and categories are stored by id; reading several files merges
    entries that share an id.  compile() then flags each group as
    installed or not, based on a list of installed package tuples.
    """

    def __init__(self, overwrite_groups=False):
        self._groups = {}
        self._categories = {}
        self.compscount = 0
        self.overwrite_groups = overwrite_groups
        self.compiled = False # have groups been compiled into avail/installed
                              # lists, yet.

    def _by_display_order(self, items):
        """Return items as a new list ordered by ascending display_order.

        The sort is stable, so items with equal display_order keep their
        relative order, exactly as with a comparison-function sort.
        """
        return sorted(items, key=lambda item: item.display_order)

    def get_groups(self):
        """Return all groups, ordered by display_order."""
        return self._by_display_order(self._groups.values())

    def get_categories(self):
        """Return all categories, ordered by display_order."""
        return self._by_display_order(self._categories.values())

    groups = property(get_groups)
    categories = property(get_categories)

    def has_group(self, grpid):
        """Return True when return_group(grpid) finds a group."""
        return self.return_group(grpid) is not None

    def return_group(self, grpid):
        """Return the group matching grpid, or None.

        grpid is first tried as a group id key; failing that it is compared
        against each group's name, id and translated names.
        """
        if grpid in self._groups:
            return self._groups[grpid]

        # do matches against group names and ids, too
        for group in self.groups:
            names = [group.name, group.groupid]
            names.extend(group.translated_name.values())
            if grpid in names:
                return group

        return None

    def add(self, srcfile = None):
        """Read a comps file and merge its groups and categories.

        srcfile is either a filename string or a file object.  A file
        opened here from a filename is closed again before returning; a
        file object supplied by the caller is left open.  Raises
        CompsException when srcfile is empty or missing.
        """
        if not srcfile:
            raise CompsException

        opened_here = isinstance(srcfile, str)
        if opened_here:
            # srcfile is a filename string
            infile = open(srcfile, 'rt')
        else:
            # srcfile is a file object
            infile = srcfile

        self.compscount += 1
        self.compiled = False

        try:
            for event, elem in iterparse(infile):
                if elem.tag == "group":
                    group = Group(elem)
                    if group.groupid in self._groups:
                        self._groups[group.groupid].add(group)
                    else:
                        self._groups[group.groupid] = group

                elif elem.tag == "category":
                    category = Category(elem)
                    if category.categoryid in self._categories:
                        self._categories[category.categoryid].add(category)
                    else:
                        self._categories[category.categoryid] = category
        finally:
            # never leak the handle we opened, even when parsing fails
            if opened_here:
                infile.close()

    def compile(self, pkgtuplist):
        """ compile the groups into installed/available groups

        pkgtuplist is a list of (name, arch, epoch, version, release)
        tuples of installed packages; only the names are used.
        """

        # convert the tuple list to a simple dict of pkgnames
        inst_pkg_names = {}
        for (n,a,e,v,r) in pkgtuplist:
            inst_pkg_names[n] = 1

        for group in self.groups:
            if group.mandatory_packages:
                # if there are mandatory packages in the group, then make
                # sure they're all installed.  if any are missing, then the
                # group isn't installed.
                group.installed = True
                for pkgname in group.mandatory_packages:
                    if pkgname not in inst_pkg_names:
                        group.installed = False
                        break
            else:
                # if it doesn't have any of those then see if it has ANY of
                # the optional/default/conditional packages installed.
                # If so - then the group is installed
                check_pkgs = (list(group.optional_packages) +
                              list(group.default_packages) +
                              list(group.conditional_packages))
                group.installed = False
                for pkgname in check_pkgs:
                    if pkgname in inst_pkg_names:
                        group.installed = True
                        break

        self.compiled = True

def main():
    """Command-line helper: load the comps files named in argv and dump them.

    Prints the first filename, then every group with its packages and
    every category with its group ids.  Exits with status 1 when no file
    is given or when a file cannot be read; the error names the file that
    actually failed.
    """
    srcfiles = sys.argv[1:]
    if not srcfiles:
        # indexing argv[1] here used to end in an IndexError traceback
        sys.stderr.write("usage: %s <comps-file> [<comps-file> ...]\n" % sys.argv[0])
        sys.exit(1)

    sys.stdout.write('%s\n' % srcfiles[0])
    p = Comps()
    for srcfile in srcfiles:
        try:
            p.add(srcfile)
        except IOError:
            sys.stderr.write("newcomps.py: No such file:\'%s\'\n" % srcfile)
            sys.exit(1)

    for group in p.groups:
        sys.stdout.write('%s\n' % group)
        for pkg in group.packages:
            sys.stdout.write('  ' + pkg + '\n')

    for category in p.categories:
        sys.stdout.write('%s\n' % category.name)
        for group in category.groups:
            sys.stdout.write('  ' + group + '\n')
        
# Run the command-line dump only when executed as a script, not on import.
if __name__ == '__main__':
    main()



--- NEW FILE config.py ---
#!/usr/bin/python -t

# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Library General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
# Copyright 2002 Duke University 

import os
import warnings
import rpm
import copy
import urlparse
import sys
from parser import IncludingConfigParser, IncludedDirConfigParser
from ConfigParser import NoSectionError, NoOptionError
import rpmUtils.transaction
import rpmUtils.arch
import Errors
from repos import Repository

class OptionData(object):
    '''
    Plain record holding the state of one option on one config instance:
    the option name, its current value, and the parser/section pair it can
    be written back to (both None until assigned).
    '''
    def __init__(self, name, initial_value):
        self.parser = self.section = None
        self.name, self.value = name, initial_value

class Option(object):
    '''
    This class handles a single Yum configuration file option. Create
    subclasses for each type of supported configuration option.
    
    Python descriptor foo (__get__ and __set__) is used to make option
    definition easy and consise.
    '''

    def __init__(self, default=None):
        self._setattrname()
        self.inherit = False
        self.default = default

    def _setattrname(self):
        '''Calculate the internal attribute name used to store option state in
        configuration instances.
        '''
        self._attrname = '__opt%d' % id(self)

    def __get__(self, obj, objtype):
        '''Called when the option is read (via the descriptor protocol). 

        @param obj: The configuration instance to modify.
        @param objtype: The type of the config instance (not used).
        @return: The parsed option value or the default value if the value
            wasn't set in the configuration file.
        '''
        if obj is None:
            return self
        optdata = getattr(obj, self._attrname, None)
        if optdata == None:
            return None
        else:
            return optdata.value

    def __set__(self, obj, value):
        '''Called when the option is set (via the descriptor protocol). 

        @param obj: The configuration instance to modify.
        @param value: The value to set the option to.
        @return: Nothing.
        '''
        optdata = getattr(obj, self._attrname)
       
        # Only try to parse if its a string
        if isinstance(value, basestring):
            try:
                value = self.parse(value)
            except ValueError, e:
                # Add the field name onto the error
                raise ValueError('Error parsing %r: %s' % (optdata.name,
                    str(e)))

        optdata.value = value

        # Write string value back to parser instance if possible
        if optdata.parser != None:
            strvalue = self.tostring(value)
            optdata.parser.set(optdata.section, optdata.name, strvalue)

    def setup(self, obj, name):
        '''Initialise the option for a config instance. 
        This must be called before the option can be set or retrieved. 

        @param obj: BaseConfig (or subclass) instance.
        @param name: Name of the option.
        '''
        setattr(obj, self._attrname, OptionData(name, self.default))

    def setparser(self, obj, parser, section):
        '''Set the configuration parser for this option. This is required so
        that options can be written back to a configuration file.

        @param obj: BaseConfig (or subclass) instance.
        @param parser: ConfigParser (or subclass) where the option is read from.
        @param section: config file section where the option is from.
        '''
        optdata = getattr(obj, self._attrname)
        optdata.parser = parser
        optdata.section = section

    def clone(self):
        '''Return a safe copy of this Option instance
        '''
        new = copy.copy(self)
        new._setattrname()
        return new

    def parse(self, s):
        '''Parse the string value to the Option's native value.

        @param s: Raw string value to parse.
        @return: Validated native value.
    
        Will raise ValueError if there was a problem parsing the string.
        Subclasses should override this.
        '''
        return s

    def tostring(self, value):
        '''Convert the Option's native value to a string value.

        @param value: Native option value.
        @return: String representation of input.

        This does the opposite of the parse() method above.
        Subclasses should override this.
        '''
        return str(value)

def Inherit(option_obj):
    '''Return a clone of option_obj that is flagged as inherited.

    The clone comes from option_obj.clone() and has its 'inherit'
    attribute set to True; option_obj itself is not modified.  Use this
    to reuse an existing option definition without redefining it.

    @param option_obj: Option instance to inherit.
    @return: New Option instance inherited from the input.
    '''
    inherited = option_obj.clone()
    inherited.inherit = True
    return inherited

class ListOption(Option):
    """Option whose value is a list of strings.

    A default of None is replaced by a new empty list.
    """

    def __init__(self, default=None):
        initial = default
        if initial is None:
            initial = []
        super(ListOption, self).__init__(initial)

    def parse(self, s):
        """Split a config file string into a list of items.

        Commas and any whitespace separate the items; treating newlines
        as spaces also covers the '\\n[whitespace]' continuation form.
        """
        return s.replace('\n', ' ').replace(',', ' ').split()

    def tostring(self, value):
        """Join the items with a newline followed by one space."""
        separator = '\n '
        return separator.join(value)

class UrlOption(Option):
    '''
    This option handles a single URL and validates its scheme against the
    allowed schemes.  The special value "_none_" (any case) stands for
    "no URL" and is only accepted when allow_none is set.
    '''

    def __init__(self, default=None, schemes=('http', 'ftp', 'file', 'https'),
            allow_none=False):
        super(UrlOption, self).__init__(default)
        self.allow_none = allow_none
        self.schemes = schemes

    def parse(self, url):
        '''Return the stripped URL, or None for an accepted "_none_".

        Raises ValueError for a disallowed "_none_" or a scheme that is
        not in self.schemes.
        '''
        url = url.strip()

        # Handle the "_none_" special case
        if url.lower() == '_none_':
            if not self.allow_none:
                raise ValueError('"_none_" is not a valid value')
            return None

        # Check that scheme is valid (first field of the urlparse result)
        scheme = urlparse.urlparse(url)[0]
        if scheme not in self.schemes:
            raise ValueError('URL must be %s not "%s"' % (self._schemelist(), scheme))

        return url

    def _schemelist(self):
        '''Return a user friendly list of the allowed schemes
        '''
        count = len(self.schemes)
        if count < 1:
            return 'empty'
        if count == 1:
            return self.schemes[0]
        leading = ', '.join(self.schemes[:-1])
        return '%s or %s' % (leading, self.schemes[-1])

class UrlListOption(ListOption):
    '''
    Option for handling lists of URLs with validation of the URL scheme.
    The string is split as in ListOption, then every item is validated by
    an internal UrlOption built with the same schemes.
    '''

    def __init__(self, default=None, schemes=('http', 'ftp', 'file', 'https')):
        super(UrlListOption, self).__init__(default)

        # Hold a UrlOption instance to assist with parsing
        self._urloption = UrlOption(schemes=schemes)

    def parse(self, s):
        '''Return the list of validated URLs found in s.

        A ValueError from validating any single URL propagates.
        '''
        urls = super(UrlListOption, self).parse(s)
        return [self._urloption.parse(url) for url in urls]


class IntOption(Option):
    """Option holding an integer value."""

    def parse(self, s):
        """Return int(s); raise ValueError('invalid integer value') otherwise."""
        try:
            return int(s)
        except (ValueError, TypeError):
            # same handler form as FloatOption; the exception object was
            # bound but never used
            raise ValueError('invalid integer value')

class BoolOption(Option):
    '''
    Option holding a boolean, written as 0/1, yes/no or true/false.
    '''

    def parse(self, s):
        '''Map a case-insensitive truth word to True or False.

        Raises ValueError for anything outside the recognised words.
        '''
        word = s.lower()
        if word in ('1', 'yes', 'true'):
            return True
        if word in ('0', 'no', 'false'):
            return False
        raise ValueError('invalid boolean value')

    def tostring(self, value):
        '''Return "yes" for a true value and "no" otherwise.
        '''
        if not value:
            return "no"
        return "yes"

class FloatOption(Option):
    '''
    Option holding a floating point value.
    '''

    def parse(self, s):
        '''Return the float spelled by s, ignoring surrounding whitespace.

        Raises ValueError if the text cannot be converted.
        '''
        try:
            text = s.strip()
            number = float(text)
        except (ValueError, TypeError):
            raise ValueError('invalid float value')
        return number

class SelectionOption(Option):
    '''Option restricted to a fixed collection of permitted string values
    '''
    def __init__(self, default=None, allowed=()):
        super(SelectionOption, self).__init__(default)
        self._allowed = allowed
        
    def parse(self, s):
        '''Return s unchanged if it is a permitted value, else raise
        ValueError
        '''
        if s in self._allowed:
            return s
        raise ValueError('"%s" is not an allowed value' % s)

class BytesOption(Option):
    '''
    Option holding a byte count, optionally written with a unit suffix.
    '''

    # Multipliers for unit symbols (keys are lower case)
    MULTS = {
        'k': 1024,
        'm': 1024 ** 2,
        'g': 1024 ** 3,
    }

    def parse(self, s):
        """Turn a friendly size string into a number of bytes

        The text is a number (a fractional part is allowed) optionally
        followed by exactly one unit letter: 'k', 'M' or 'G', in either case.
        1k = 1024 bytes.

        Valid inputs: 100, 123M, 45.6k, 12.4G, 100K, 786.3, 0
        Invalid inputs: -10, -0.1, 45.6L, 123Mb

        The result is always an integer.

        ValueError is raised for an empty string, an unknown unit, text that
        is not a number, or a negative number.
        """
        if len(s) < 1:
            raise ValueError("no value specified")

        suffix = s[-1]
        if not suffix.isalpha():
            # Plain number, counted in bytes
            digits = s
            factor = 1
        else:
            digits = s[:-1]
            unit = suffix.lower()
            factor = self.MULTS.get(unit, None)
            if not factor:
                raise ValueError("unknown unit '%s'" % unit)

        try:
            amount = float(digits)
        except ValueError:
            raise ValueError("couldn't convert '%s' to number" % digits)

        if amount < 0:
            raise ValueError("bytes value may not be negative")

        return int(amount * factor)

class ThrottleOption(BytesOption):
    '''
    Option holding either a percentage or a byte count.
    '''

    def parse(self, s):
        """Parse a throttle setting.

        Text ending in '%' is read as a percentage between 0 and 100 and
        returned as a float fraction (50% gives 0.5). Anything else is handed
        to BytesOption.parse and comes back as an int.

        Valid inputs: 100, 50%, 80.5%, 123M, 45.6k, 12.4G, 100K, 786.0, 0
        Invalid inputs: 100.1%, -4%, -500

        ValueError is raised if the input cannot be parsed.
        """
        if len(s) < 1:
            raise ValueError("no value specified")

        if s[-1] != '%':
            return BytesOption.parse(self, s)

        text = s[:-1]
        try:
            percent = float(text)
        except ValueError:
            raise ValueError("couldn't convert '%s' to number" % text)
        if percent > 100 or percent < 0:
            raise ValueError("percentage is out of range")
        return percent / 100.0


class BaseConfig(object):
    '''
    Base class for storing configuration definitions. Subclass when creating
    your own definitions.

    Options are declared as class attributes holding Option instances; the
    methods below find them by name and fill them in from an INI parser.
    '''

    def __init__(self):
        '''Create an unpopulated config and call setup() on every option.
        '''
        self._section = None

        for name in self.iterkeys():
            option = self.optionobj(name)
            option.setup(self, name)

    def __str__(self):
        '''Return "[section]" followed by one "name: value" line per option,
        with the value shown using repr().
        '''
        out = ['[%s]' % self._section]
        for name, value in self.iteritems():
            out.append('%s: %r' % (name, value))
        return '\n'.join(out)

    def populate(self, parser, section, parent=None):
        '''Set option values from a INI file section.

        @param parser: ConfParser instance (or subclass)
        @param section: INI file section to read.
        @param parent: Optional parent BaseConfig (or subclass) instance to use
            when doing option value inheritance.
        '''
        self._section = section
        self.cfg = parser           # Keep a reference to the parser

        for name in self.iterkeys():
            option = self.optionobj(name)
            value = None
            try:
                value = parser.get(section, name)
            except (NoSectionError, NoOptionError):
                # No matching option in this section, try inheriting
                if parent and option.inherit:
                    value = getattr(parent, name)

            # The option is attached to the parser whether or not a value
            # was found for it
            option.setparser(self, parser, section)
            if value is not None:
                setattr(self, name, value)

    def optionobj(cls, name):
        '''Return the Option instance for the given name

        Raises KeyError if name is not defined or is not an Option.
        '''
        # Walk the whole method resolution order so that options declared
        # any number of levels up the class hierarchy are found, not only
        # those on the class itself and its direct bases.
        for klass in cls.__mro__:
            obj = klass.__dict__.get(name, None)
            if obj is not None:
                # The nearest definition wins, even when it is not an Option
                if isinstance(obj, Option):
                    return obj
                break
        raise KeyError(name)
    optionobj = classmethod(optionobj)

    def isoption(cls, name):
        '''Return True if the given name refers to a defined option 
        '''
        try:
            cls.optionobj(name)
        except KeyError:
            return False
        return True
    isoption = classmethod(isoption)

    def iterkeys(self):
        '''Yield the names of all defined options in the instance.
        '''
        for name, _value in self.iteritems():
            yield name

    def iteritems(self):
        '''Yield (name, value) pairs for every option in the instance.

        The value returned is the parsed, validated option value.
        '''
        # Use dir() so that we see inherited options too
        for name in dir(self):
            if self.isoption(name):
                yield (name, getattr(self, name))

    def getConfigOption(self, option, default=None):
        '''Deprecated: return the attribute named option, or default if the
        instance has no such attribute. Emits a DeprecationWarning.
        '''
        warnings.warn('getConfigOption() will go away in a future version of Yum.\n'
                'Please access option values as attributes or using getattr().',
                DeprecationWarning)
        if hasattr(self, option):
            return getattr(self, option)
        return default

    def setConfigOption(self, option, value):
        '''Deprecated: set the attribute named option to value. Emits a
        DeprecationWarning.

        Raises Errors.ConfigError if the instance has no such attribute.
        '''
        warnings.warn('setConfigOption() will go away in a future version of Yum.\n'
                'Please set option values as attributes or using setattr().',
                DeprecationWarning)
        if hasattr(self, option):
            setattr(self, option, value)
        else:
            raise Errors.ConfigError('No such option %s' % option)

class EarlyConf(BaseConfig):
    '''
    Configuration option definitions for yum.conf's [main] section that are
    required before the other [main] options can be parsed (mainly due to
    variable substitution).
    '''
    # Name handed to the rpmdb 'provides' lookup in _getsysver(); the version
    # of the first matching package becomes the 'releasever' variable.
    distroverpkg = Option('fedora-release')
    # Installation root. readMainConfig() replaces this default with its
    # `root` argument and later prefixes the value to cachedir and logfile.
    installroot = Option('/')

class YumConf(EarlyConf):
    '''
    Configuration option definitions for yum.conf\'s [main] section.

    Note: inherits options from EarlyConf too.

    Each attribute is an Option (or subclass) instance constructed with the
    value used when yum.conf does not set that option.
    '''

    # Integer options
    debuglevel = IntOption(2)
    errorlevel = IntOption(2)
    retries = IntOption(10)
    recent = IntOption(7)

    # Paths and logging. readMainConfig() prefixes installroot to cachedir
    # and logfile after parsing.
    cachedir = Option('/var/cache/yum')
    keepcache = BoolOption(True)
    logfile = Option('/var/log/yum.log')
    reposdir = ListOption(['/etc/yum/repos.d', '/etc/yum.repos.d'])
    syslog_ident = Option()
    syslog_facility = Option('LOG_DAEMON')

    # Plain string and list options
    commands = ListOption()
    exclude = ListOption()
    failovermethod = Option('roundrobin')
    yumversion = Option('unversioned')
    # Only http, ftp and https proxies are accepted; "_none_" parses to None
    proxy = UrlOption(schemes=('http', 'ftp', 'https'), allow_none=True)
    proxy_username = Option()
    proxy_password = Option()
    # readMainConfig() rejects any entry that does not start with '/'
    pluginpath = ListOption(['/usr/lib/yum-plugins'])
    # Read as conf.installonlypkgs by Depsolve.allowedMultipleInstalls()
    installonlypkgs = ListOption(['kernel', 'kernel-bigmem',
            'kernel-enterprise','kernel-smp', 'kernel-modules', 'kernel-debug',
            'kernel-unsupported', 'kernel-source', 'kernel-devel'])
    kernelpkgnames = ListOption(['kernel','kernel-smp', 'kernel-enterprise',
            'kernel-bigmem', 'kernel-BOOT'])
    exactarchlist = ListOption(['kernel', 'kernel-smp', 'glibc',
            'kernel-hugemem', 'kernel-enterprise', 'kernel-bigmem',
            'kernel-devel'])
    # Flag names translated to rpm transaction flags in
    # Depsolve.initActionTs(); unknown names are reported there.
    tsflags = ListOption()

    # Boolean switches (values accepted: 0/1, yes/no, true/false)
    assumeyes = BoolOption(False)
    alwaysprompt = BoolOption(True)
    exactarch = BoolOption(True)
    tolerant = BoolOption(True)
    diskspacecheck = BoolOption(True)
    overwrite_groups = BoolOption(False)
    keepalive = BoolOption(True)
    gpgcheck = BoolOption(False)
    obsoletes = BoolOption(False)
    showdupesfromrepos = BoolOption(False)
    enabled = BoolOption(True)
    plugins = BoolOption(False)
    enablegroups = BoolOption(True)
    enable_group_conditionals = BoolOption(True)

    timeout = FloatOption(30.0)

    # bandwidth parses to an integer byte count; throttle parses to either
    # an integer byte count or a float fraction when given as a percentage.
    bandwidth = BytesOption(0)
    throttle = ThrottleOption(0)

    # Must be one of 'none', 'packages' or 'all'
    http_caching = SelectionOption('all', ('none', 'packages', 'all'))
    metadata_expire = IntOption(1800)   # time in seconds

class RepoConf(BaseConfig):
    '''
    Option definitions for repository INI file sections.
    '''
    # NOTE(review): Inherit() is defined outside this view. Presumably it
    # derives a repository-level option from the given [main] option so that
    # BaseConfig.populate() can fall back to the parent's value when the
    # repository section does not set it -- confirm against its definition.

    # Options specific to a repository section
    name = Option()
    enabled = Inherit(YumConf.enabled)
    baseurl = UrlListOption()       # list of URLs
    mirrorlist = UrlOption()        # a single URL
    gpgkey = UrlListOption()        # list of URLs
    exclude = ListOption() 
    includepkgs = ListOption() 

    # Network access settings taken from the matching [main] options
    proxy = Inherit(YumConf.proxy)
    proxy_username = Inherit(YumConf.proxy_username)
    proxy_password = Inherit(YumConf.proxy_password)
    retries = Inherit(YumConf.retries)
    failovermethod = Inherit(YumConf.failovermethod)

    # Boolean switches taken from the matching [main] options
    gpgcheck = Inherit(YumConf.gpgcheck)
    keepalive = Inherit(YumConf.keepalive)
    enablegroups = Inherit(YumConf.enablegroups)

    # Download tuning taken from the matching [main] options
    bandwidth = Inherit(YumConf.bandwidth)
    throttle = Inherit(YumConf.throttle)
    timeout = Inherit(YumConf.timeout)
    http_caching = Inherit(YumConf.http_caching)
    metadata_expire = Inherit(YumConf.metadata_expire)

def readMainConfig(configfile, root):
    '''Parse Yum's main configuration file

    The file is read twice: once without variable substitution to obtain the
    EarlyConf options needed to compute the substitution variables, and then
    again with those variables to fill in the full YumConf.

    @param configfile: Path to the configuration file to parse (typically
        '/etc/yum.conf').
    @param root: The base path to use for installation (typically '/')
    @return: Populated YumConf instance.
    @raise Errors.ConfigError: if configfile does not exist or a pluginpath
        entry is not an absolute path.
    '''

    # Read up config variables that are needed early to calculate substitution
    # variables.
    # NOTE: this assigns to the Option object stored on the EarlyConf class,
    # so the new default is not limited to the instance created below.
    EarlyConf.installroot.default = root
    earlyconf = EarlyConf()
    confparser = IncludingConfigParser()
    if not os.path.exists(configfile):
        raise Errors.ConfigError('No such config file %s' % configfile)

    confparser.read(configfile)
    earlyconf.populate(confparser, 'main')

    # Set up substitution vars
    yumvars = _getEnvVar()
    yumvars['basearch'] = rpmUtils.arch.getBaseArch()          # FIXME make this configurable??
    yumvars['arch'] = rpmUtils.arch.getCanonArch()             # FIXME make this configurable??
    yumvars['releasever'] = _getsysver(earlyconf.installroot, earlyconf.distroverpkg)

    # Read [main] section
    yumconf = YumConf()
    confparser = IncludingConfigParser(vars=yumvars)
    confparser.read(configfile)
    yumconf.populate(confparser, 'main')

    # Apply the installroot to directory options (plain string concatenation,
    # so installroot '/' yields paths beginning with '//')
    for option in ('cachedir', 'logfile'):
        path = getattr(yumconf, option)
        setattr(yumconf, option, yumconf.installroot + path)
    
    # Check that plugin paths are all absolute
    for path in yumconf.pluginpath:
        if not path.startswith('/'):
            raise Errors.ConfigError("All plugin search paths must be absolute")

    # Add in some extra attributes which aren't actually configuration values 
    yumconf.yumvar = yumvars
    yumconf.uid = 0
    yumconf.cache = 0
    # NOTE(review): attribute is spelled 'progess_obj' (sic); kept as-is
    # because readers of this attribute are not visible here.
    yumconf.progess_obj = None

    return yumconf

def readRepoConfig(parser, section, mainconf):
    '''Build a Repository from one section of a repository INI file.

    @param parser: ConfParser or similar to read INI file values from.
    @param section: INI file section to read; also used as the repository id.
    @param mainconf: Main configuration object. It is used as the parent for
        inherited option values and supplies cachedir and yumvar.
    @return: Repository instance.
    '''

    conf = RepoConf()
    conf.populate(parser, section, mainconf)

    # Fall back to the section id when no name was configured
    if not conf.name:
        conf.name = section
        warning = 'Repository %r is missing name in configuration, using id' % section
        sys.stderr.write(warning + '\n')

    thisrepo = Repository(section)

    # Copy the option values over to the repository object
    #TODO: merge RepoConf and Repository 
    for key, value in conf.iteritems():
        # A false value never replaces an attribute the repository already has
        if not value and hasattr(thisrepo, key):
            continue
        thisrepo.set(key, value)

    # Attributes that do not come from the repository section itself
    thisrepo.basecachedir = mainconf.cachedir
    thisrepo.yumvar.update(mainconf.yumvar)
    thisrepo.cfg = parser

    return thisrepo

def getOption(conf, section, name, default, option):
    '''Fetch one value from a ConfigParser and convert it with an Option.

    @param conf: ConfigParser instance or similar
    @param section: Section name
    @param name: Option name
    @param default: Value returned as-is (not parsed) if the section or the
        option is missing
    @param option: Option instance whose parse() does the conversion.
    @return: The parsed value or default if value was not present.

    Will raise ValueError if the option could not be parsed.
    '''
    try: 
        raw = conf.get(section, name)
    except (NoSectionError, NoOptionError):
        return default
    else:
        return option.parse(raw)

def _getEnvVar():
    '''Collect variable replacements from the environment.

    The environment variables YUM0 to YUM9 are examined; each one that is
    set to a non-empty value appears in the result under its lower-cased
    name ('yum0' .. 'yum9').

    The result is intended to be used with parser.varReplace()
    '''
    found = {}
    for envname in ['YUM%d' % digit for digit in range(10)]:
        setting = os.environ.get(envname, '')
        if not setting:
            continue
        found[envname.lower()] = setting
    return found

def _getsysver(installroot, distroverpkg):
    '''Work out the release version of the system.

    The rpm database under installroot is searched for packages providing
    distroverpkg; the 'version' field of the first match is the result.

    @param installroot: The value of the installroot option.
    @param distroverpkg: The value of the distroverpkg option.
    @return: The release version as a string (eg. '4' for FC4), or the
        string 'Null' when nothing provides distroverpkg.
    '''
    ts = rpmUtils.transaction.initReadOnlyTransaction(root=installroot)
    ts.pushVSFlags(~(rpm._RPMVSF_NOSIGNATURES | rpm._RPMVSF_NODIGESTS))
    matches = ts.dbMatch('provides', distroverpkg)
    if matches.count() != 0:
        # Only the first match is used; several providers of the release
        # package are not expected.
        header = matches.next()
        releasever = header['version']
        del header
    else:
        releasever = 'Null'
    # Drop the rpm objects explicitly, in the same order as they were
    # obtained in reverse
    del matches
    del ts
    return releasever


#def main():
#    mainconf = readMainConfig('/etc/yum.conf', '/') 
#    repoconf = readRepoConfig(mainconf.cfg, 'core', mainconf)
#
#    print `repoconf.name`
#
#if __name__ == '__main__':
#    main()


--- NEW FILE constants.py ---
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Library General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.


# Constants shared across the yum modules

# Transaction set states
# (for example, Depsolve.populateTs() assigns TS_INSTALL to a transaction
# member's output_state)
TS_UPDATE = 10
TS_INSTALL = 20
TS_TRUEINSTALL = 30
TS_ERASE = 40
TS_OBSOLETED = 50
TS_OBSOLETING = 60
TS_AVAILABLE = 70


# Groupings of the states above. TS_AVAILABLE belongs to neither list.
TS_INSTALL_STATES = [TS_INSTALL, TS_TRUEINSTALL, TS_UPDATE, TS_OBSOLETING]
TS_REMOVE_STATES = [TS_ERASE, TS_OBSOLETED]

# Transaction Relationships
# NOTE(review): the names suggest three pairs of opposite directions
# (updates/updated-by, obsoletes/obsoleted-by, depends/depends-on) --
# confirm against the code that uses them.
TR_UPDATES = 1
TR_UPDATEDBY = 2
TR_OBSOLETES = 3
TR_OBSOLETEDBY = 4
TR_DEPENDS = 5
TR_DEPENDSON = 6

# Transaction Member Sort Colors
# Each node in a topological sort is colored
# White nodes are unseen, black nodes are seen
# grey nodes are in progress
TX_WHITE = 0
TX_GREY = 1
TX_BLACK = 2

# package object file types
PO_FILE = 1
PO_DIR = 2
PO_GHOST = 3
PO_CONFIG = 4
PO_DOC = 5

# package object package types
PO_REMOTEPKG = 1
PO_LOCALPKG = 2
PO_INSTALLEDPKG = 3

# FLAGS
# SYMBOLFLAGS maps a comparison operator symbol to its two-letter code;
# both '=' and '==' map to 'EQ'.
# LETTERFLAGS maps the two-letter code back to a symbol; 'EQ' gives '='.
SYMBOLFLAGS = {'>':'GT', '<':'LT', '=': 'EQ', '==': 'EQ', '>=':'GE', '<=':'LE'}
LETTERFLAGS = {'GT':'>', 'LT':'<', 'EQ':'=', 'GE': '>=', 'LE': '<='}

# Constants for plugin config option registration
# Value type of the option
PLUG_OPT_STRING = 0
PLUG_OPT_INT = 1
PLUG_OPT_FLOAT = 2
PLUG_OPT_BOOL = 3

# Where the option may appear
PLUG_OPT_WHERE_MAIN = 0
PLUG_OPT_WHERE_REPO = 1
PLUG_OPT_WHERE_ALL = 2

# boolean dict:
# lower-case spellings mapped to the boolean they stand for; lookups with
# other capitalisation need the key lower-cased first
BOOLEAN_STATES = {'1': True, 'yes': True, 'true': True, 'on': True,
                  '0': False, 'no': False, 'false': False, 'off': False}


--- NEW FILE depsolve.py ---
#!/usr/bin/python -t
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Library General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
# Copyright 2005 Duke University 

import os
import os.path
import re
import types

import rpmUtils.transaction
import rpmUtils.miscutils
import rpmUtils.arch
from misc import unique
import rpm

from repomd.packageSack import ListPackageSack
from repomd.mdErrors import PackageSackError
from Errors import DepError, RepoError
from constants import *
import packages

class Depsolve(object):
    def __init__(self):
        """start without a depsolve callback and publish this instance as
           packages.base"""
        self.dsCallback = None
        packages.base = self
    
    def initActionTs(self):
        """creates self.ts for the configured installroot and applies the
           transaction flags named in conf.tsflags"""

        self.ts = rpmUtils.transaction.TransactionWrapper(self.conf.installroot)

        # config file spelling -> rpm transaction flag
        known_flags = {}
        known_flags['noscripts'] = rpm.RPMTRANS_FLAG_NOSCRIPTS
        known_flags['notriggers'] = rpm.RPMTRANS_FLAG_NOTRIGGERS
        known_flags['nodocs'] = rpm.RPMTRANS_FLAG_NODOCS
        known_flags['test'] = rpm.RPMTRANS_FLAG_TEST
        known_flags['repackage'] = rpm.RPMTRANS_FLAG_REPACKAGE

        self.ts.setFlags(0) # reset everything.

        for flagname in self.conf.tsflags:
            if flagname not in known_flags:
                # unknown names are reported and otherwise ignored
                self.errorlog(0, 'Invalid tsflag in config file: %s' % flagname)
                continue
            self.ts.addTsFlag(known_flags[flagname])

    def whatProvides(self, name, flags, version):
        """searches the packageSacks for what provides the arguments
           returns a ListPackageSack of providing packages, possibly empty

           name is the provide (or file path) being looked for, flags the
           comparison flag (0 is treated the same as None) and version
           either a version string, None, or an (epoch, version, release)
           tuple or list.

           raises TypeError if version is of any other type"""

        self.log(4, 'Searching pkgSack for dep: %s' % name)
        # we need to check the name - if it doesn't match:
        # /etc/* bin/* or /usr/lib/sendmail then we should fetch the 
        # filelists.xml for all repos to make the searchProvides more complete.
        if name[0] == '/':
            matched = 0
            globs = [r'.*bin\/.*', r'^\/etc\/.*', r'^\/usr\/lib\/sendmail$']
            for glob in globs:
                if re.match(glob, name):
                    matched = 1
                    break
            if not matched:
                self.doSackFilelistPopulate()
            
        pkgs = self.pkgSack.searchProvides(name)
        if flags == 0:
            flags = None

        # isinstance() rather than an exact type comparison, so that unicode
        # strings and subclasses are accepted instead of leaving r_e, r_v
        # and r_r unassigned.
        if version is None or isinstance(version, types.StringTypes):
            (r_e, r_v, r_r) = rpmUtils.miscutils.stringToVersion(version)
        elif isinstance(version, (types.TupleType, types.ListType)):
            (r_e, r_v, r_r) = version
        else:
            raise TypeError('version must be a string, None or an '
                            '(epoch, version, release) sequence, not %r'
                            % (version,))
        
        defSack = ListPackageSack() # holder for items definitely providing this dep
        
        for po in pkgs:
            self.log(5, 'Potential match for %s from %s' % (name, po))
            if name[0] == '/' and r_v is None:
                # file dep add all matches to the defSack
                defSack.addPackage(po)
                continue

            if po.checkPrco('provides', (name, flags, (r_e, r_v, r_r))):
                defSack.addPackage(po)
                self.log(3, 'Matched %s to require for %s' % (po, name))
        
        return defSack
        
    def allowedMultipleInstalls(self, po):
        """takes a packageObject, returns 1 or 0 depending on if the package 
           should/can be installed multiple times with different vers
           like kernels and kernel modules, for example"""
           
        if po.name in self.conf.installonlypkgs:
            return 1
        
        provides = po.getProvidesNames()
        if filter (lambda prov: prov in self.conf.installonlypkgs, provides):
            return 1
        
        return 0

    def populateTs(self, test=0, keepold=1):
        """fills self.ts from the members held in self.tsInfo

           when keepold is true the elements already in self.ts are noted
           first so that they are not added a second time. The test
           argument is accepted but not used here."""

        if self.dsCallback:
            self.dsCallback.transactionPopulation()

        # (pkgtup, mode) keys for everything already in self.ts
        present = {}
        if keepold:
            for te in self.ts:
                epoch = te.E()
                if epoch is None:
                    epoch = '0'
                pkginfo = (te.N(), te.A(), epoch, te.V(), te.R())
                # element type 1 is recorded as mode 'i', type 2 as mode 'e'
                if te.Type() == 1:
                    mode = 'i'
                elif te.Type() == 2:
                    mode = 'e'
                
                present[(pkginfo, mode)] = 1
                
        for member in self.tsInfo.getMembers():
            self.log(6, 'Member: %s' % member)
            if member.ts_state in ('u', 'i'):
                if (member.pkgtup, 'i') in present:
                    continue
                self.downloadHeader(member.po)
                header = member.po.returnLocalHeader()
                pkgfile = member.po.localPkg()
                
                # an update of a multiple-install package becomes an install
                if member.ts_state == 'u' and self.allowedMultipleInstalls(member.po):
                    self.log(5, '%s converted to install' % (member.po))
                    member.ts_state = 'i'
                    member.output_state = TS_INSTALL

                self.ts.addInstall(header, (header, pkgfile), member.ts_state)
                self.log(4, 'Adding Package %s in mode %s' % (member.po, member.ts_state))
                if self.dsCallback:
                    self.dsCallback.pkgAdded(member.pkgtup, member.ts_state)
            
            elif member.ts_state in ('e',):
                if (member.pkgtup, member.ts_state) in present:
                    continue
                for idx in self.rpmdb.returnIndexByTuple(member.pkgtup):
                    self.ts.addErase(idx)
                    if self.dsCallback:
                        self.dsCallback.pkgAdded(member.pkgtup, 'e')
                    self.log(4, 'Removing Package %s' % member.po)
        

    def resolveDeps(self):
        """repeatedly populates and checks the transaction set, handing each
           reported problem to _processReq or _processConflict, until a pass
           asks for no further checking

           returns a (code, messages) tuple:
             (2, ['Success - deps resolved']) when ts.check() reports nothing
             (1, errors) when error messages were collected
             (2, ['Run Callback']) when there are no errors and tsInfo is
               not empty
           NOTE(review): when there are no errors and tsInfo is empty the
           function falls off the end and returns None - confirm that the
           callers cope with that."""

        CheckDeps = 1
        conflicts = 0
        missingdep = 0
        depscopy = []
        unresolveableloop = 0

        errors = []
        if self.dsCallback: self.dsCallback.start()

        while CheckDeps > 0:
            self.cheaterlookup = {} # short cache for some information we'd resolve
                                    # (needname, needversion) = pkgtup
            self.populateTs(test=1)
            if self.dsCallback: self.dsCallback.tscheck()
            deps = self.ts.check()

            
            if not deps:
                self.tsInfo.changed = False
                return (2, ['Success - deps resolved'])

            deps = unique(deps) # get rid of duplicate deps            
            if deps == depscopy:
                # same problems as on the previous pass: no progress was made
                unresolveableloop += 1
                self.log(5, 'Identical Loop count = %d' % unresolveableloop)
                if unresolveableloop >= 2:
                    errors.append('Unable to satisfy dependencies')
                    for deptuple in deps:
                        ((name, version, release), (needname, needversion), flags, 
                          suggest, sense) = deptuple
                        if sense == rpm.RPMDEP_SENSE_REQUIRES:
                            msg = 'Package %s needs %s, this is not available.' % \
                                  (name, rpmUtils.miscutils.formatRequire(needname, 
                                                                          needversion, flags))
                        elif sense == rpm.RPMDEP_SENSE_CONFLICTS:
                            msg = 'Package %s conflicts with %s.' % \
                                  (name, rpmUtils.miscutils.formatRequire(needname, 
                                                                          needversion, flags))
                        else:
                            # no message exists for an unknown sense; skip it
                            # rather than reuse a stale or unassigned msg
                            continue

                        errors.append(msg)
                    CheckDeps = 0
                    break
            else:
                unresolveableloop = 0

            depscopy = deps
            CheckDeps = 0


            # things to resolve
            self.log (3, '# of Deps = %d' % len(deps))
            depcount = 0
            for dep in deps:
                ((name, version, release), (needname, needversion), flags, suggest, sense) = dep
                depcount += 1
                self.log(5, '\nDep Number: %d/%d\n' % (depcount, len(deps)))
                if sense == rpm.RPMDEP_SENSE_REQUIRES: # requires
                    # if our packageSacks aren't here, then set them up
                    if not hasattr(self, 'pkgSack'):
                        self.doRepoSetup()
                        self.doSackSetup()
                    (checkdep, missing, conflict, errormsgs) = self._processReq(dep)
                    
                elif sense == rpm.RPMDEP_SENSE_CONFLICTS: # conflicts - this is gonna be short :)
                    (checkdep, missing, conflict, errormsgs) = self._processConflict(dep)
                    
                else: # wtf?
                    self.errorlog(0, 'Unknown Sense: %d' % (sense))
                    continue

                missingdep += missing
                conflicts += conflict
                CheckDeps += checkdep
                for error in errormsgs:
                    if error not in errors:
                        errors.append(error)

            self.log(4, 'miss = %d' % missingdep)
            self.log(4, 'conf = %d' % conflicts)
            self.log(4, 'CheckDeps = %d' % CheckDeps)

            if CheckDeps > 0:
                if self.dsCallback: self.dsCallback.restartLoop()
                self.log(4, 'Restarting Loop')
            else:
                if self.dsCallback: self.dsCallback.end()
                self.log(4, 'Dependency Process ending')

            del deps

        self.tsInfo.changed = False
        if len(errors) > 0:
            return (1, errors)
        if len(self.tsInfo) > 0:
            return (2, ['Run Callback'])

    def _processReq(self, dep):
        """processes a Requires dep from the resolveDeps functions, returns a tuple
           of (CheckDeps, missingdep, conflicts, errors) the last item is an array
           of error messages

           dep is one entry as unpacked in resolveDeps:
           ((name, version, release), (needname, needversion), flags,
            suggest, sense)"""
        
        CheckDeps = 0
        missingdep = 0
        conflicts = 0
        errormsgs = []
        
        ((name, version, release), (needname, needversion), flags, suggest, sense) = dep
        
        niceformatneed = rpmUtils.miscutils.formatRequire(needname, needversion, flags)
        self.log(4, '%s requires: %s' % (name, niceformatneed))
        
        if self.dsCallback: self.dsCallback.procReq(name, niceformatneed)
        
        # is the requiring tuple (name, version, release) from an installed package?
        pkgs = []
        dumbmatchpkgs = self.rpmdb.returnTupleByKeyword(name=name, ver=version, rel=release)
        for pkgtuple in dumbmatchpkgs:
            self.log(6, 'Calling rpmdb.returnHeaderByTuple on %s.%s %s:%s-%s' % pkgtuple)
            hdrs = self.rpmdb.returnHeaderByTuple(pkgtuple)
            for hdr in hdrs:
                po = packages.YumInstalledPackage(hdr)
                if self.tsInfo.exists(po.pkgtup):
                    self.log(7, 'Skipping package already in Transaction Set: %s' % po)
                    continue
                if niceformatneed in po.requiresList():
                    pkgs.append(po)

        if len(pkgs) < 1: # requiring tuple is not in the rpmdb
            txmbrs = self.tsInfo.matchNaevr(name=name, ver=version, rel=release)
            if len(txmbrs) < 1:
                # adjacent literals instead of a backslash continuation, which
                # put the source indentation into the middle of the message
                msg = ('Requiring package %s-%s-%s not in transaction set '
                       'nor in rpmdb' % (name, version, release))
                self.log(4, msg)
                errormsgs.append(msg)
                missingdep = 1
                CheckDeps = 0

            else:
                txmbr = txmbrs[0]
                self.log(4, 'Requiring package is from transaction set')
                if txmbr.ts_state == 'e':
                    msg = 'Requiring package %s is set to be erased, ' % txmbr.name +\
                           'probably processing an old dep, restarting loop early.'
                    self.log(5, msg)
                    CheckDeps=1
                    missingdep=0
                    return (CheckDeps, missingdep, conflicts, errormsgs)
                    
                else:
                    self.log(4, 'Resolving for requiring package: %s-%s-%s in state %s' %
                                (name, version, release, txmbr.ts_state))
                    self.log(4, 'Resolving for requirement: %s' % 
                        rpmUtils.miscutils.formatRequire(needname, needversion, flags))
                    requirementTuple = (needname, flags, needversion)
                    # should we figure out which is pkg it is from the tsInfo?
                    requiringPkg = (name, version, release, txmbr.ts_state) 
                    CheckDeps, missingdep = self._requiringFromTransaction(requiringPkg, requirementTuple, errormsgs)
            
        if len(pkgs) > 0:  # requiring tuple is in the rpmdb
            if len(pkgs) > 1:
                self.log(5, 'Multiple Packages match. %s-%s-%s' % (name, version, release))
                # walk a copy: entries are removed from pkgs inside the loop,
                # and removing from the list being iterated skips elements
                for po in pkgs[:]:
                    # if one of them is (name, arch) already in the tsInfo somewhere, 
                    # pop it out of the list
                    (n,a,e,v,r) = po.pkgtup
                    thismode = self.tsInfo.getMode(name=n, arch=a)
                    if thismode is not None:
                        self.log(5, '   %s already in ts %s, skipping' % (po, thismode))
                        pkgs.remove(po)
                        continue
                    else:
                        self.log(5, '   %s' % po)
                    
            if len(pkgs) == 1:
                po = pkgs[0]
                self.log(5, 'Requiring package is installed: %s' % po)
            
            if len(pkgs) > 0:
                requiringPkg = pkgs[0] # take the first one, deal with the others (if there is one)
                                   # on another dep.
            else:
                self.errorlog(1, 'All pkgs in depset are also in tsInfo, this is wrong and bad')
                CheckDeps = 1
                return (CheckDeps, missingdep, conflicts, errormsgs)
            
            self.log(4, 'Resolving for installed requiring package: %s' % requiringPkg)
            self.log(4, 'Resolving for requirement: %s' % 
                rpmUtils.miscutils.formatRequire(needname, needversion, flags))
            
            requirementTuple = (needname, flags, needversion)
            
            CheckDeps, missingdep = self._requiringFromInstalled(requiringPkg.pkgtup,
                                                    requirementTuple, errormsgs)


        return (CheckDeps, missingdep, conflicts, errormsgs)


    def _requiringFromInstalled(self, requiringPkg, requirement, errorlist):
        """Resolve a dependency whose requiring package is already installed.

        requiringPkg -- (name, arch, epoch, ver, rel) tuple of the installed
                        package that carries the requirement
        requirement  -- (needname, needflags, needversion) tuple
        errorlist    -- list for error messages; this method only hands it
                        on to _requiringFromTransaction()

        Returns a (checkdeps, missingdep) pair of 0/1 flags.  checkdeps is
        set to 1 after an erase, obsolete or update has been scheduled in
        self.tsInfo; missingdep is set to 1 when the requirement cannot be
        resolved.  Several paths return the result of
        _requiringFromTransaction() directly.
        """

        # FIXME - should we think about dealing exclusively in package objects?

        (name, arch, epoch, ver, rel) = requiringPkg
        requiringPo = self.getInstalledPackageObject(requiringPkg)

        (needname, needflags, needversion) = requirement
        niceformatneed = rpmUtils.miscutils.formatRequire(needname, needversion, needflags)
        checkdeps = 0
        missingdep = 0

        # we must first find out why the requirement is no longer there
        # we must find out what provides/provided it from the rpmdb (if anything)
        # then check to see if that thing is being acted upon by the transaction set
        # if it is then we need to find out what is being done to it and act accordingly
        rpmdbNames = self.rpmdb.getNamePkgList()
        needmode = None # mode in the transaction of the needed pkg (if any)
        needpo = None
        providers = []

        if self.cheaterlookup.has_key((needname, needflags, needversion)):
            # a previous call already found the provider for this requirement
            self.log(5, 'Needed Require has already been looked up, cheating')
            cheater_tup = self.cheaterlookup[(needname, needflags, needversion)]
            providers = [cheater_tup]

        elif needname in rpmdbNames:
            # the requirement is an installed package name: the candidate
            # providers are the transaction members carrying that name
            txmbrs = self.tsInfo.matchNaevr(name=needname)
            for txmbr in txmbrs:
                providers.append(txmbr.pkgtup)

        else:
            self.log(5, 'Needed Require is not a package name. Looking up: %s' % niceformatneed)
            providers = self.rpmdb.whatProvides(needname, needflags, needversion)

        # the first provider that has a mode in the transaction decides needmode
        for insttuple in providers:
            inst_str = '%s.%s %s:%s-%s' % insttuple
            (i_n, i_a, i_e, i_v, i_r) = insttuple
            self.log(5, 'Potential Provider: %s' % inst_str)
            thismode = self.tsInfo.getMode(name=i_n, arch=i_a,
                            epoch=i_e, ver=i_v, rel=i_r)

            if thismode is None and i_n in self.conf.exactarchlist:
                # check for mode by the same name+arch
                thismode = self.tsInfo.getMode(name=i_n, arch=i_a)

            if thismode is None and i_n not in self.conf.exactarchlist:
                # check for mode by just the name
                thismode = self.tsInfo.getMode(name=i_n)

            if thismode is not None:
                needmode = thismode
                try:
                    needpo = self.getInstalledPackageObject(insttuple)
                except KeyError:
                    # not an installed package; fall back to getPackageObject()
                    needpo = self.getPackageObject(insttuple)
                # remember the provider so the next lookup can take the shortcut
                self.cheaterlookup[(needname, needflags, needversion)] = insttuple
                self.log(5, 'Mode is %s for provider of %s: %s' %
                            (needmode, niceformatneed, inst_str))
                break

        self.log(5, 'Mode for pkg providing %s: %s' % (niceformatneed, needmode))

        if needmode in ['e']:
            # the provider is being erased, so erase the requiring package too
            self.log(5, 'TSINFO: %s package requiring %s marked as erase' %
                            (requiringPo, needname))
            txmbr = self.tsInfo.addErase(requiringPo)
            txmbr.setAsDep(po=needpo)
            checkdeps = 1

        if needmode in ['i', 'u']:
            self.doUpdateSetup()
            obslist = []
            # check obsoletes first
            if self.conf.obsoletes:
                if self.up.obsoleted_dict.has_key(requiringPo.pkgtup):
                    obslist = self.up.obsoleted_dict[requiringPo.pkgtup]
                    self.log(4, 'Looking for Obsoletes for %s' % requiringPo)

            if len(obslist) > 0:
                po = None
                # every entry is looked up, but only the last one is kept
                for pkgtup in obslist:
                    po = self.getPackageObject(pkgtup)
                if po:
                    for (new, old) in self.up.getObsoletesTuples(): # FIXME query the obsoleting_list now?
                        if po.pkgtup == new:
                            txmbr = self.tsInfo.addObsoleting(po, requiringPo)
                            self.tsInfo.addObsoleted(requiringPo, po)
                            txmbr.setAsDep(po=needpo)
                            self.log(5, 'TSINFO: Obsoleting %s with %s to resolve dep.' % (requiringPo, po))
                            checkdeps = 1
                            return checkdeps, missingdep

            # check updates second
            uplist = self.up.getUpdatesList(name=name)
            # if there's an update for the reqpkg, then update it

            po = None
            if len(uplist) > 0:
                if name not in self.conf.exactarchlist:
                    # pick the newest package of the best available arch
                    pkgs = self.pkgSack.returnNewestByName(name)
                    archs = {}
                    for pkg in pkgs:
                        (n,a,e,v,r) = pkg.pkgtup
                        archs[a] = pkg
                    a = rpmUtils.arch.getBestArchFromList(archs.keys())
                    po = archs[a]
                else:
                    po = self.pkgSack.returnNewestByNameArch((name,arch))[0]
                if po.pkgtup not in uplist:
                    po = None

            if po:
                for (new, old) in self.up.getUpdatesTuples():
                    if po.pkgtup == new:
                        txmbr = self.tsInfo.addUpdate(po, requiringPo)
                        txmbr.setAsDep(po=needpo)
                        self.log(5, 'TSINFO: Updating %s to resolve dep.' % po)
                # set even when no tuple matched above
                checkdeps = 1

            else: # if there's no update then pass this over to requringFromTransaction()
                self.log(5, 'Cannot find an update path for dep for: %s' % niceformatneed)

                reqpkg = (name, ver, rel, None)
                return self._requiringFromTransaction(reqpkg, requirement, errorlist)


        if needmode is None:
            reqpkg = (name, ver, rel, None)
            if hasattr(self, 'pkgSack'):
                return self._requiringFromTransaction(reqpkg, requirement, errorlist)
            else:
                # without a pkgSack nothing can be pulled in to satisfy the need
                self.log(5, 'Unresolveable requirement %s for %s' % (niceformatneed, requiringPo))
                checkdeps = 0
                missingdep = 1


        return checkdeps, missingdep
        

    def _requiringFromTransaction(self, requiringPkg, requirement, errorlist):
        """processes the dependency resolution for a dep where requiring 
           package is in the transaction set

           requiringPkg -- (name, version, release, tsState) tuple; only
                           name is used below, in log and error messages
           requirement  -- (needname, needflags, needversion) tuple
           errorlist    -- list; a 'Missing Dependency' message is appended
                           when no usable provider is found

           Returns a (checkdeps, missingdep) pair of 0/1 flags.  checkdeps
           is 1 when a provider was added to self.tsInfo or was already
           there as install/update; missingdep is 1 when nothing usable
           provides the requirement."""
        
        (name, version, release, tsState) = requiringPkg
        (needname, needflags, needversion) = requirement
        checkdeps = 0
        missingdep = 0
        
        #~ - if it's not available from some repository:
        #~     - mark as unresolveable.
        #
        #~ - if it's available from some repo:
        #~    - if there is an another version of the package currently installed then
        #        - if the other version is marked in the transaction set
        #           - if it's marked as erase
        #              - mark the dep as unresolveable
         
        #           - if it's marked as update or install
        #              - check if the version for this requirement:
        #                  - if it is higher 
        #                       - mark this version to be updated/installed
        #                       - remove the other version from the transaction set
        #                       - tell the transaction set to be rebuilt
        #                  - if it is lower
        #                       - mark the dep as unresolveable
        #                   - if they are the same
        #                       - be confused but continue

        provSack = self.whatProvides(needname, needflags, needversion)

        # get rid of things that are already in the rpmdb - b/c it's pointless to use them here

        # NOTE(review): packages are deleted from provSack while looping over
        # provSack.returnPackages(); presumably that returns a separate list
        # - confirm.
        for pkg in provSack.returnPackages():
            if pkg.pkgtup in self.rpmdb.getPkgList(): # is it already installed?
                self.log(5, '%s is in providing packages but it is already installed, removing.' % pkg)
                provSack.delPackage(pkg)
                continue

            # we need to check to see, if we have anything similar to it (name-wise)
            # installed or in the ts, and this isn't a package that allows multiple installs
            # then if it's newer, fine - continue on, if not, then we're unresolveable
            # cite it and exit
        
            tspkgs = []
            if not self.allowedMultipleInstalls(pkg):
                (n, a, e, v, r) = pkg.pkgtup
                
                # from ts
                tspkgs = self.tsInfo.matchNaevr(name=pkg.name, arch=pkg.arch)
                for tspkg in tspkgs:
                    (tn, ta, te, tv, tr) = tspkg.pkgtup
                    rc = rpmUtils.miscutils.compareEVR((e, v, r), (te, tv, tr))
                    if rc < 0:
                        msg = 'Potential resolving package %s has newer instance in ts.' % pkg
                        self.log(5, msg)
                        provSack.delPackage(pkg)
                        # NOTE(review): this only advances the inner loop, so
                        # the rpmdb check below still runs for a deleted pkg
                        continue
                
                # from rpmdb
                dbpkgs = self.rpmdb.returnTupleByKeyword(name=pkg.name, arch=pkg.arch)
                for dbpkgtup in dbpkgs:
                    (dn, da, de, dv, dr) = dbpkgtup
                    rc = rpmUtils.miscutils.compareEVR((e, v, r), (de, dv, dr))
                    if rc < 0:
                        msg = 'Potential resolving package %s has newer instance installed.' % pkg
                        self.log(5, msg)
                        provSack.delPackage(pkg)
                        # NOTE(review): inner-loop continue; delPackage() may
                        # be called again for the same pkg
                        continue

        if len(provSack) == 0: # unresolveable
            missingdep = 1
            msg = 'Missing Dependency: %s is needed by package %s' % \
            (rpmUtils.miscutils.formatRequire(needname, needversion, needflags),
                                                                   name)
            errorlist.append(msg)
            return checkdeps, missingdep
        
        # iterate the provSack briefly, if we find the package is already in the 
        # tsInfo then just skip this run
        for pkg in provSack.returnPackages():
            (n,a,e,v,r) = pkg.pkgtup
            pkgmode = self.tsInfo.getMode(name=n, arch=a, epoch=e, ver=v, rel=r)
            if pkgmode in ['i', 'u']:
                self.doUpdateSetup()
                self.log(5, '%s already in ts, skipping this one' % (n))
                checkdeps = 1
                return checkdeps, missingdep
        

        # find the best one 
        # preference: shortest package name first, then the arch picked by
        # getBestArchFromList() when the name lengths tie
        newest = provSack.returnNewestByNameArch()
        if len(newest) > 1: # there's no way this can be zero
            best = newest[0]
            for po in newest[1:]:
                if len(po.name) < len(best.name):
                    best = po
                elif len(po.name) == len(best.name):
                    # compare arch
                    arch = rpmUtils.arch.getBestArchFromList([po.arch, best.arch])
                    if arch == po.arch:
                        best = po
        elif len(newest) == 1:
            best = newest[0]
        
        # NOTE(review): 'best' is unbound here if newest is empty; the comment
        # above says that cannot happen.
        if best.pkgtup in self.rpmdb.getPkgList(): # is it already installed?
            missingdep = 1
            checkdeps = 0
            msg = 'Missing Dependency: %s is needed by package %s' % (needname, name)
            errorlist.append(msg)
            return checkdeps, missingdep
        
                
            
        # FIXME - why can't we look up in the transaction set for the requiringPkg
        # and know what needs it that way and provide a more sensible dep structure in the txmbr
        # an installed package with the same (name, arch) makes this an update,
        # otherwise it is a fresh install
        if (best.name, best.arch) in self.rpmdb.getNameArchPkgList():
            self.log(3, 'TSINFO: Marking %s as update for %s' % (best, name))
            txmbr = self.tsInfo.addUpdate(best)
            txmbr.setAsDep()
        else:
            self.log(3, 'TSINFO: Marking %s as install for %s' % (best, name))
            txmbr = self.tsInfo.addInstall(best)
            txmbr.setAsDep()

        checkdeps = 1
        
        return checkdeps, missingdep


    def _processConflict(self, dep):
        """Handle one Conflict dep handed over by resolveDeps().

        dep is the 5-tuple
        ((name, version, release), (needname, needversion), flags,
        suggest, sense).  The method looks for an update of one side of the
        conflict; if one is found it is added to self.tsInfo as a
        dependency update, otherwise the conflict is recorded as
        unresolveable.

        Returns (CheckDeps, missingdep, conflicts, errormsgs); missingdep
        is always 0 here.
        """
        errormsgs = []
        missingdep = 0
        conflicts = 0

        (requirer, needed, flags, suggest, sense) = dep
        (name, version, release) = requirer
        (needname, needversion) = needed

        niceformatneed = rpmUtils.miscutils.formatRequire(needname, needversion, flags)
        if self.dsCallback:
            self.dsCallback.procConflict(name, niceformatneed)

        # we should try to update out of the dep, if possible
        # see which side of the conflict is installed and which is in the transaction Set
        needmode = self.tsInfo.getMode(name=needname)
        confmode = self.tsInfo.getMode(name=name, ver=version, rel=release)
        if confmode is not None and needmode is None:
            confname = needname
        else:
            confname = name

        self.doUpdateSetup()
        uplist = self.up.getUpdatesList(name=confname)
        installed = self.rpmdb.returnTupleByKeyword(name=confname)

        candidate = None
        if installed:
            # take the first one, probably the only one
            (confname, confarch, confepoch, confver, confrel) = installed[0]

            # if there's an update for the conflicting pkg, then update it
            if len(uplist) > 0:
                if confname in self.conf.exactarchlist:
                    candidate = self.pkgSack.returnNewestByNameArch((confname, confarch))[0]
                else:
                    by_arch = {}
                    for pkg in self.pkgSack.returnNewestByName(confname):
                        (n, a, e, v, r) = pkg.pkgtup
                        by_arch[a] = pkg
                    bestarch = rpmUtils.arch.getBestArchFromList(by_arch.keys())
                    candidate = by_arch[bestarch]
                if candidate.pkgtup not in uplist:
                    candidate = None

        if not candidate:
            # nothing to update to: report the conflict as it stands
            conf = rpmUtils.miscutils.formatRequire(needname, needversion, flags)
            CheckDeps, conflicts = self._unresolveableConflict(conf, name, errormsgs)
            self.log(4, '%s conflicts: %s' % (name, conf))
            return (CheckDeps, missingdep, conflicts, errormsgs)

        self.log(5, 'TSINFO: Updating %s to resolve conflict.' % candidate)
        txmbr = self.tsInfo.addUpdate(candidate)
        txmbr.setAsDep()
        CheckDeps = 1
        return (CheckDeps, missingdep, conflicts, errormsgs)

    def _unresolveableReq(self, req, name, namestate, errors):
        """Record a requirement that cannot be satisfied.

        Appends a 'Missing Dependency' message to errors, passes the same
        message to self.dsCallback.unresolved() when a callback is set, and
        returns (CheckDeps, missingdep) as (0, 1).
        """
        msg = 'Missing Dependency: %s needed for package %s (%s)' % (req, name, namestate)
        errors.append(msg)
        callback = self.dsCallback
        if callback:
            callback.unresolved(msg)
        return 0, 1

    def _unresolveableConflict(self, conf, name, errors):
        """Record a conflict that cannot be resolved.

        Appends '<name> conflicts with <conf>' to errors and returns
        (CheckDeps, conflicts) as (0, 1).
        """
        errors.append('%s conflicts with %s' % (name, conf))
        return 0, 1


--- NEW FILE failover.py ---
#!/usr/bin/python
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Library General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
# Copyright 2003 Jack Neely, NC State University

# Here we define a base class for failover methods.  The idea here is that each
# failover method uses a class derived from the base class so yum only has to
# worry about calling get_serverurl() and server_failed() and these classes will 
# figure out which URL to cough up based on the failover method.

import random

class baseFailOverMethod:
    """Base class for failover methods.

    An instance wraps a repo object (anything with a 'urls' list) and
    keeps a failure counter.  Subclasses override get_serverurl() to
    decide how the counter maps onto repo.urls.
    """

    def __init__(self, repo):
        self.repo = repo
        self.failures = 0

    def get_serverurl(self, i=None):
        """Returns a serverurl based on this failover method or None
           if complete failure.  If i is given it is a direct index
           to pull a server URL from instead of using the failures
           counter.  The base class knows no method and always
           returns None."""
        return None

    def server_failed(self):
        "Tells the failover method that the current server is failed."
        self.failures = self.failures + 1

    def reset(self, i=0):
        "Reset the failures counter to a given index."
        self.failures = i

    def get_index(self):
        """Returns the current number of failures which is also the
           index into the list this object represents.  get_serverurl()
           should always be used to translate an index into a URL
           as this object may change how indexes map.  (See roundRobin)"""
        return self.failures

    def len(self):
        """Returns how many URLs we've got to cycle through."""
        return len(self.repo.urls)


class priority(baseFailOverMethod):

    """Chooses server based on the first success in the list."""

    def get_serverurl(self, i=None):
        """Returns repo.urls[index], or None once index runs past the list.

        index is i when given, otherwise the current failure count."""
        if i is None:
            index = self.failures
        else:
            index = i

        if index >= len(self.repo.urls):
            return None

        # the attribute is 'repo' (set in __init__); 'repos' does not exist
        return self.repo.urls[index]


class roundRobin(baseFailOverMethod):

    """Chooses server based on a round robin."""

    def __init__(self, repo):
        baseFailOverMethod.__init__(self, repo)
        random.seed()
        # random starting point so different instances start on different URLs
        self.offset = random.randint(0, 37)

    def get_serverurl(self, i=None):
        """Returns the URL at (index + offset) modulo the number of URLs,
        or None once index runs past the list.

        index is i when given, otherwise the current failure count."""
        if i is None:
            index = self.failures
        else:
            index = i

        if index >= len(self.repo.urls):
            return None

        rr = (index + self.offset) % len(self.repo.urls)
        return self.repo.urls[rr]

# SDG


--- NEW FILE logger.py ---
"""
A module for convenient yet powerful file, console, and syslog logging

BASIC USAGE

  from logger import Logger

  log = Logger(threshold=0)    # create the log object and give it
                               # a threshold of 0
  log.log(2, 'all done')       # send a log of priority 2 (not printed)
  log(0, 'error: bandits!')    # send a log of priority 0 (printed)
  log.write(0, stringvar)      # do a raw write on the file object

DESCRIPTION

  Each logging object is given a threshold.  Any messages that are
  then sent to that object are logged only if their priority meets or
  exceeds the threshold.  Lower numerical priority means that a
  message is more important.  For example: if a log object has
  threshold 2, then all messages of priority 2, 1, 0, -1, etc will be
  logged, while those of priority 3, 4, etc. will not.  I suggest the
  following scale:
  
     LOG PRIORITY    MEANING
              -1     failure - cannot be ignored
               0     important message - printed in default mode
               1     informational message - printed with -v
               2     debugging information

        THRESHOLD    MEANING
              -1     quiet mode (-q) only failures are printed
               0     normal operation
               1     verbose mode (-v)
               2     debug mode (-vv or -d)

  It can be extended farther in both directions, but that is rarely
  useful.  It can also be shifted in either direction.  This might be
  useful if you want to supply the threshold directly on the command
  line but have trouble passing in negative numbers.  In that case,
  add 1 to all thresholds and priorities listed above.

CLASSES

  There are three primary classes defined in this module:

    Logger        Class for basic file and console logging
    SysLogger     Class for syslog logging
    LogContainer  Class for wrapping multiple other loggers together
                  for convenient use

  Instances of all of these support the same basic methods:

    obj.log(priority, message)   # log a message with smart formatting
    obj.write(priority, message) # log a string in a very raw way
    obj(priority, message)       # shorthand for obj.log(...)

  Different classes support other methods as well, but this is what you
  will mostly use.


ADVANCED

  There are a number of options available for these classes.  These are
  documented below in the respective classes and methods.  Here is a
  list of some of the things you can do:

    * make a prefix contain a string which gets repeated for more
      important logs.  (prefix)
    * directly test if a log object WOULD log, so you can do
      complicated stuff, like efficient for loops. (test)
    * make the priority, threshold arbitrary objects, with a
      home-rolled test to see if it should log. (test)
    * give log containers a "master threshold" and define arbitrary
      behavior based on it.  Examples include:
      - only pass on messages of sufficient priority (regardless of
        the thresholds of the log objects).
      - only pass on messages to objects whose thresholds are
        (numerically) lower than the master threshold.

SEE ALSO

  Take a look at the examples at the end of this file in the test &
  demo section.

COMMENTS

  I welcome comments, questions, bug reports and requests... I'm very
  lonely. :)
"""

# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Library General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.

# Copyright 2001-2003 Michael D. Stenner

import sys
import string
# syslog is imported from within SysLogger.__init__

AUTHOR  = "Michael D. Stenner <mstenner at phy.duke.edu>"
VERSION = "0.7"
DATE    = "2003/09/20"
URL     = "http://linux.duke.edu/projects/mini/logger/"

class Logger:
    """A class for file-object logging
    
    USAGE:
      from logger import Logger

      log_obj = Logger(THRESHOLD)     # create the instance

      log_obj.log(3, 'message')       # log a message with priority 3
      log_obj(3, 'message')           # same thing
      log_obj(3, ['message'])         # same thing

      log_obj.test(3)                 # boolean - would a message of
                                      # this priority be printed?

      # a raw write call after the priority test, for writing
      # arbitrary text -- (this will not be followed by \\n)
      log_obj.write(3, 'thing\\nto\\nwrite')  

      # generate the prefix used for priority 3
      pr = log_obj.gen_prefix(3)

      # see the examples in the test section for more

    BASIC OPTIONS:

      There are a couple of basic options that are commonly needed.
      These are attributes of instances of class Logger.

        preprefix

          Text that will be printed at the start of each line of
          output (for log()ged, not write()en messages).  This might
          be your program's name, for example.

            log.preprefix = 'myprog'

          If preprefix is callable, then it will be called for each
          log and the returned value will be used.  This is useful for
          printing the current time.

            import time
            def printtime():
                return time.strftime('%m/%d/%y %H:%M:%S ',
                         time.localtime(time.time()))
            log.preprefix = printtime
      
        file_object

          This is the file object to which output is directed.  If it
          is None, then the logs are quietly dropped.

      There are other options described in the next section, but these
      are the most commonly used.

    ATTRIBUTES:
      (all of these are settable as keyword args to __init__)

      ATTRIBUTES   DEFAULT      DESCRIPTION
      ----------------------------------------------------------
      threshold    = 0          how verbose the program should be
      file_object  = sys.stderr file object to which output goes
      prefix       = ''         prefix string - repeated for more
                                important logs
      prefix_depth = 5          times prefix is repeated for logs
                                of priority 0.  Basically, set this
                                one larger than your highest log
                                priority.
      preprefix    = ''         string printed before the prefix
                                if callable, returned string will
                                be used (useful for printing time)
      postprefix   = ''         string printed after the prefix
      default      = 1          default priority to log at

    """

    def __init__(self,
                 threshold    = 0,
                 file_object  = sys.stderr,
                 prefix       = '',
                 prefix_depth = 5,
                 preprefix    = '',
                 postprefix   = '',
                 default      = 1):
        self.threshold    = threshold
        self.file_object  = file_object
        self.prefix       = prefix
        self.prefix_depth = prefix_depth
        self.preprefix    = preprefix
        self.postprefix   = postprefix
        self.default      = default

    def test(self, priority):

        """
        Return true if a log of the given priority would be printed.
        
        This can be overridden to do any test you like.  Specifically,
        priority and threshold need not be integers.  They can be
        arbitrary objects.  You need only override this method, and
        possibly gen_prefix.
        """

        return int(self.threshold) >= int(priority)

    def gen_prefix(self, priority):

        """
        Return the full prefix (including pre and post) for the
        given priority.

        If you use prefix and use a more complicated priority and
        verbosity (non-numerical), then you should either give the
        chosen object a __int__ method, or override this function.
        """

        if callable(self.preprefix):
            prefix = self.preprefix()
        else:
            prefix = self.preprefix

        if self.prefix:
            # repeat the prefix more often for more important (lower)
            # priorities, but always at least once
            depth = self.prefix_depth - int(priority)
            if depth < 1:
                depth = 1
            prefix = prefix + self.prefix * depth

        return prefix + self.postprefix

    def log(self, priority, message=None):
        """
        Print a log message.  This prepends the prefix to each line
        and does some basic formatting.

        A string message is split on newlines (a trailing newline does
        not produce an empty line), a list message is written one item
        per line with trailing whitespace stripped, and anything else
        is converted with str().  When only one argument is given it is
        taken as the message and logged at the default priority.
        """
        p, m = self._use_default(priority, message)
        if not self.test(p) or self.file_object is None:
            return

        # str methods instead of the string-module functions, which no
        # longer exist on Python 3
        if isinstance(m, str): # message is a string
            mlist = m.split('\n')
            if mlist[-1] == '': del mlist[-1] # string ends in \n
        elif isinstance(m, list): # message is a list
            mlist = [item.rstrip() for item in m]
        else: mlist = [str(m)] # message is other type

        prefix = self.gen_prefix(p)
        for line in mlist:
            self.file_object.write(prefix + line + '\n')

    # make the objects callable
    __call__ = log

    def write(self, priority, message=None):
        """
        Print a log message.  In this case, 'message' must be a string
        as it will be passed directly to the file object's write method.
        No prefix or newline is added.
        """
        p, m = self._use_default(priority, message)
        if self.test(p):
            if self.file_object is None: return
            self.file_object.write(m)

    def _use_default(self, priority, message):
        """Substitute default priority if none was provided

        Returns a (priority, message) pair.  With a single argument the
        value passed as 'priority' is really the message.
        """
        if message is None: return self.default, priority
        else: return priority, message

class SysLogger:
    """A class for logging to syslog

    USAGE:
      For the most part, SysLogger instances are used just like Logger
      instances.  Notable exceptions:

        * prefixes aren't used (at least not as they are for Logger)
        * map_priority is pretty important because it controls
          conversion between Logger priorities and syslog priorities
        * although priority/threshold/test works the same, there is
          also maskpri and your syslog config which will limit what
          gets logged.  Keep this in mind if you see strange behavior.

      The most sensible use of this class is together with a
      LogContainer.  You can create one Logger instance for writing to
      (say) stderr, a second for writing to a verbose log file, and a
      SysLogger instance for writing important things to syslog so your
      automated log-readers can see them.  Then you put them all in a
      log container for convenient access.  That would go something
      like this:

        from logger import Logger, SysLogger, LogContainer
        
        fo = open(file_log, 'w')
        file_logger    = Logger(4, file_object=fo) # print 4 and lower
        console_logger = Logger(1)                 # print 1 and lower
        syslog_logger  = SysLogger(0)              # print 0 and lower
        log = LogContainer([file_logger, console_logger, syslog_logger])

        log(3,  'some debugging message') # printed to file only
        log(1,  'some warning message')   # printed to file and console
        log(0,  'some error message')     # printed to all (ERR level)
        log(-1, 'major problem')          # printed to all (CRIT level)
      
    ATTRIBUTES:
      (all of these are settable as keyword args to __init__)

      ARGUMENT     DEFAULT      DESCRIPTION
      ----------------------------------------------------------
      threshold    = 0          identical to Logger threshold
      ident        = None       string prepended to each log, if None,
                                it will be taken from the program name
                                as it appears in sys.argv[0]
      logopt       = 0          syslog log options
      facility     = 'LOG_USER' syslog facility (it can be a string)
      maskpri      = 0          syslog priority bitmask
      default      = 1          default priority to log at

    """

    def __init__(self,
                 threshold    = 0,
                 ident        = None,
                 logopt       = 0,
                 facility     = 'LOG_USER',
                 maskpri      = 0,
                 default      = 1):
        # Importing here is a little ugly, but it keeps syslog from being
        # imported at all unless a SysLogger is really created.  The name
        # is declared global because the other methods use it as well.
        global syslog
        import syslog

        self.default      = default
        self.threshold    = threshold

        if ident is None:
            # the program name is whatever follows the last '/' of argv[0]
            ident = sys.argv[0].split('/')[-1]

        # a facility given by name is looked up in the syslog module
        if type(facility) == type(''):
            facility = getattr(syslog, facility)

        syslog.openlog(ident, logopt, facility)
        if maskpri: syslog.setlogmask(maskpri)
        
    def setlogmask(self, maskpri):
        """a (very) thin wrapper over the syslog setlogmask function"""
        return syslog.setlogmask(maskpri)

    def map_priority(self, priority):
        """Convert a logger priority into a syslog priority

        The syslog priorities are (from syslog.h):
          LOG_EMERG       0       /* system is unusable */
          LOG_ALERT       1       /* action must be taken immediately */
          LOG_CRIT        2       /* critical conditions */
          LOG_ERR         3       /* error conditions */
          LOG_WARNING     4       /* warning conditions */
          LOG_NOTICE      5       /* normal but significant condition */
          LOG_INFO        6       /* informational */
          LOG_DEBUG       7       /* debug-level messages */

        The syslog priority is simply the logger priority plus 3.

           0 ->  0 + 3 =  3
          -5 -> -5 + 3 = -2  (which will be treated as 0)

        Overriding this is very simple.  Just do:

        def log_everything_emerg(priority): return 0
        log_obj.map_priority = log_everything_emerg

        The value returned here does not have to be an integer, nor lie
        within the allowed syslog range (0 to 7): log() and write()
        convert it to an int and force it into that range.
        """
        return priority + 3

    def test(self, priority):
        """
        Return true if a log of the given priority would be printed.
        
        This can be overridden to do any test you like.  Specifically,
        priority and threshold need not be integers.  They can be
        arbitrary objects.  If you override this and use non-integer
        priorities, you will also need to override map_priority.
        """
        limit = int(self.threshold)
        return limit >= int(priority)

    def log(self, priority, message=None):
        """
        Send a log message to syslog with some simple formatting.

        A string is split on newlines (one trailing newline is ignored),
        a list has each item right-stripped, and any other object is
        logged as str(message); every resulting line becomes a separate
        syslog entry.
        """
        p, m = self._use_default(priority, message)
        if not self.test(p): return

        if type(m) == type(''):
            lines = m.split('\n')
            if lines[-1] == '': lines.pop()   # string ends in \n
        elif type(m) == type([]):
            lines = [item.rstrip() for item in m]
        else:
            lines = [str(m)]

        # force the mapped priority into the valid syslog range 0..7
        level = min(7, max(0, int(self.map_priority(p))))
        for line in lines: syslog.syslog(level, line)

    # make the objects callable
    __call__ = log

    def write(self, priority, message=None):
        """
        Send a string message to syslog, one entry per line.
        """
        p, m = self._use_default(priority, message)
        if not self.test(p): return

        # force the mapped priority into the valid syslog range 0..7
        level = min(7, max(0, int(self.map_priority(p))))

        # syslog does not deal well with embedded newlines, so every
        # line of the message is sent separately
        lines = m.split('\n')
        if lines[-1] == '': lines.pop()   # string ends in \n
        for line in lines: syslog.syslog(level, line)

    def _use_default(self, priority, message):
        """Substitute default priority if none was provided"""
        single_argument = (message == None)
        if single_argument:
            return self.default, priority
        return priority, message

class LogContainer:
    """A class for consolidating calls to multiple sub-objects

    SUMMARY:
      If you want a program to log to multiple destinations, it might
      be convenient to use log containers.  A log container is an
      object which can hold several sub-log-objects (including other
      log containers).  When you log to a log container it passes the
      message on (with optional tests) to each of the log objects it
      contains.

    USAGE:

      The basic usage is very simple.  LogContainer's simply pass on
      logs to the contained Logger, SysLogger, or LogContainer
      objects.

        from logger import Logger, LogContainer

        log_1 = Logger(threshold=1, file_object=sys.stdout)
        log_2 = Logger(threshold=2, preprefix='LOG2')

        log = LogContainer([log_1, log_2])

        log(1, 'message')               # printed by log_1 and log_2
        log(2, 'message')               # only printed by log_2

      A more common example would be something like this:
      
        from logger import Logger, LogContainer

        system = Logger(threshold=1, file_object=logfile)
        debug  = Logger(threshold=5, file_object=sys.stdout)
        log = LogContainer([system, debug])

        log(3, 'sent to system and debug, but only debug will print it')
        log(0, 'very important, both will print it')

      In this mode, log containers are just shorthand for calling all
      contained objects with the same priority and message.

      When a log object is held in a container, it can still be used
      directly.  For example, you can still do

        debug(3, ['this will not be sent to the system log, even if its',
                  ' threshold is set very high'])

      (Yes, you can send lists of strings and they will be formatted
      on different lines.  The log methods are pretty smart.)

      There are more examples in the SysLogger docs.

    ATTRIBUTES:
      (all of these are settable as keyword args to __init__)

      ATTRIBUTES   DEFAULT      DESCRIPTION
      ----------------------------------------------------------
      list         = None       list of contained objects; when it is
                                omitted (or None) the container gets a
                                new empty list of its own.  A list that
                                is passed in is used as-is, not copied.
      threshold    = None       meaning depends on test - by default
                                threshold has no effect
      default      = 1          default priority to log at

    """


    def __init__(self, list=None, threshold=None, default=1):
        # A literal [] default would be created only once and then be
        # shared by every container built without an explicit list, so
        # add() on one container would show up in all the others.
        if list is None: list = []
        self.list = list
        self.threshold = threshold
        self.default = default

    def add(self, log_obj):
        """Add a log object to the container."""
        self.list.append(log_obj)

    def log(self, priority, message=None):
        """Log a message to all contained log objects, depending on
        the results of test()
        """
        p, m = self._use_default(priority, message)
        for log_obj in self.list:
            if self.test(p, m, self.threshold, log_obj):
                log_obj.log(p, m)

    # make the objects callable
    __call__ = log
    
    def write(self, priority, message=None):
        """Pass a raw message to the write method of all contained log
        objects, depending on the results of test()
        """
        p, m = self._use_default(priority, message)
        for log_obj in self.list:
            if self.test(p, m, self.threshold, log_obj):
                log_obj.write(p, m)

    def _use_default(self, priority, message):
        """Substitute default priority if none was provided"""
        # with a single argument, that argument is really the message
        if message is None: return self.default, priority
        else: return priority, message

    def test(self, priority, message, threshold, log_obj):
        """Test which log objects should be passed a given message.

        This method is used to determine if a given message (and
        priority) should be passed on to a given log_obj.  The
        container's threshold is also provided.

        This method always returns 1, and is the default, meaning that
        all messages will get passed to all objects.  It is intended
        to be overridden if you want more complex behavior.  To
        override with your own function, just do something like:

          def hell_no(p, m, t, object): return 0
          container.test = hell_no
        """
        return 1

    def test_limit_priority(self, priority, message, threshold, log_obj):
        """Only pass on messages with sufficient priority compared to
        the master threshold.

          container = LogContainer([system, debug], threshold = 2)
          container.test = container.test_limit_priority
        """
        return priority <= threshold
    
    def test_limit_threshold(self, priority, message, threshold, log_obj):
        """Only pass on messages to log objects whose threshold is
        (numerically) lower than or equal to the master threshold.

          container = LogContainer([system, debug], threshold = 2)
          container.test = container.test_limit_threshold
        """
        return log_obj.threshold <= threshold
    
if __name__ == '__main__':
    ###### TESTING AND DEMONSTRATION
    # Running this module as a script exercises Logger and LogContainer
    # on the console; the SysLogger part only runs when 'syslog' is
    # given as the first command line argument.

    threshold = 3
    print 'THRESHOLD = %s' % (threshold)
    log   = Logger(threshold,   preprefix = 'TEST  ')

    print " Lets log a few things!"
    for i in range(-2, 10): log(i, 'log priority %s' % (i))

    print "\n Now make it print the time for each log..."
    import time
    def printtime():
        return time.strftime('%m/%d/%y %H:%M:%S ',time.localtime(time.time()))
    # preprefix is set to a function here instead of a string
    log.preprefix = printtime

    print " and log a few more things"
    for i in range(-2, 10): log(i, 'log priority %s' % (i))
 
    print "\n now create another with a different prefix and priority..."
    print " and put them in a container..."
    log2 = Logger(threshold-2, preprefix = 'LOG2 ')
    cont = LogContainer([log, log2], threshold=0)
    # the container only passes on messages with priority <= its threshold
    cont.test = cont.test_limit_priority
    
    print " and log a bit more"
    for i in range(-2, 10): cont(i, 'log priority %s' % (i))

    print "\n OK, enough of the container... lets play with formatting"

    stuff = 'abcd\nefgh\nijkl'

    print "\n no trailing newline"
    log(stuff)

    print "\n with trailing newline"
    log(stuff + '\n') # should be the same because the log method
                      # takes care of the newline for you

    print "\n two trailing newlines"
    log(stuff + '\n\n') # should create a "blank" line.  If you use two
                        # newlines, it knows you really wanted one :)
    
    print "\n log JUST a newline"
    log('\n') # should create only a single "blank" line

    print "\n use the write method, with a trailing newline"
    log.write(stuff + '\n') # should just write with no prefix crap
                            # it will _NOT_ quietly tack on a newline

    print "\n print some complex object"
    log(1, {'key': 'value'}) # non-strings should be no trouble

    print "\n now set the file object to None (nothing should be printed)"
    log.file_object = None
    log("THIS SHOULD NOT BE PRINTED")
    log.write("THIS SHOULD NOT BE PRINTED")

    # stop here unless the syslog tests were explicitly requested
    if not (sys.argv[1:] and sys.argv[1] == 'syslog'):
        print '\n skipping syslog test (because they would annoy your admin)'
        print ' add "syslog" to the command line to enable syslog tests'
        sys.exit(0)

    print '\n performing syslog tests'
    print ' creating logger with threshold: 3'
    slog = SysLogger(3, 'logger-test')

    print ' logging at (logger) priorities from -2 to 9 (but only'
    print ' priorities <= 3 should show up'
    for i in range(-2, 10): slog(i, 'log priority %s' % (i))

    print '\n now test the write() method'
    slog.write(0, 'first line\nsecond line\nthird line')


--- NEW FILE mdcache.py ---
#!/usr/bin/python -tt
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Library General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
# Copyright 2004 Duke University 

import os
import sys
import libxml2
import cPickle

import Errors


class RepodataParser:
    """Parse repository metadata XML and cache the parsed result as pickles.

    The parsed data is kept in self.repodata under the keys 'metadata',
    'filelists' and 'otherdata'; each of them is a dict mapping a
    package id to an entry object.  Pickle files are stored in
    'storedir' and are named after the XML file and its checksum.

    'callback', when given, must have a progressbar(count, total, text)
    method; it is called for every package read from XML.

    NOTE: unpickling executes whatever the pickle file says, so
    'storedir' must not be writable by untrusted users.
    """

    def __init__(self, storedir, callback=None):
        self.storedir = storedir
        self.callback = callback
        self.repodata = {
            'metadata': {}, 
            'filelists': {},
            'otherdata': {}
        }
        self.debug = 0
    
        
    def debugprint(self, msg):
        """Print msg, but only when self.debug is set."""
        if self.debug:
            print(msg)
        
    def _piklFileName(self, location, checksum):
        """Return the pickle file name used for an XML file and checksum."""
        filename = os.path.basename(location)
        piklfile = os.path.join(self.storedir, filename)
        piklfile = '%s.%s.pickle' % (piklfile, checksum)
        self.debugprint('piklfile=%s' % piklfile)
        return piklfile

    def _pickle(self, outfile, obj):
        """Pickle obj into outfile.

        Raises cPickle.PicklingError when the file cannot be opened or
        written, so that callers only have one exception to deal with.
        """
        self.debugprint('Trying to pickle into %s' % outfile)
        # very old cPickle versions have no HIGHEST_PROTOCOL
        protocol = getattr(cPickle, 'HIGHEST_PROTOCOL', 1)
        try:
            # pickle protocols >= 1 are binary: open in binary mode
            outfh = open(outfile, 'wb')
            try:
                cPickle.dump(obj, outfh, protocol)
            finally:
                # close the file even when dumping fails
                outfh.close()
        except (IOError, OSError):
            raise cPickle.PicklingError(sys.exc_info()[1])
        self.debugprint('Pickle successful!')

    def _unpickle(self, infile):
        """Load and return the object pickled in infile.

        Raises cPickle.UnpicklingError when the file cannot be opened or
        does not contain a complete, loadable pickle.
        """
        self.debugprint('Trying to unpickle from %s' % infile)
        try: infh = open(infile, 'rb')
        except IOError:
            raise cPickle.UnpicklingError(sys.exc_info()[1])
        try:
            try:
                obj = cPickle.load(infh)
            except cPickle.UnpicklingError:
                raise
            except (EOFError, ValueError, TypeError, KeyError,
                    AttributeError, ImportError, IndexError):
                # a truncated or corrupt cache file raises one of these
                # instead of UnpicklingError; report them the same way
                # so that the caller falls back to parsing the XML
                raise cPickle.UnpicklingError(sys.exc_info()[1])
        finally:
            infh.close()
        self.debugprint('Unpickle successful!')
        return obj

    def _killold(self, location):
        """Remove all pickle files in storedir that belong to location.

        This is best effort: files (or a directory) that cannot be
        accessed are silently left alone.
        """
        filename = os.path.basename(location)
        try: dirfiles = os.listdir(self.storedir)
        except OSError:
            # no readable store directory means nothing to clean up
            return
        for dirfile in dirfiles:
            if dirfile.endswith('.pickle') and dirfile.startswith(filename):
                oldpickle = os.path.join(self.storedir, dirfile)
                self.debugprint('removing old pickle file %s' % oldpickle)
                try: os.unlink(oldpickle)
                except OSError:
                    ## Give an error or something
                    pass
        
    def _getGeneric(self, ident, location, checksum):
        """Return the data bank 'ident', loading it if it is still empty.

        The data comes from the pickle for (location, checksum) when one
        can be loaded, otherwise from the XML file at location, which is
        then pickled for the next time.
        """
        databank = self.repodata[ident]
        if databank: return databank
        if checksum is None:
            ##
            # Pass checksum as None to ignore pickling. This will
            # Go straight to xml files.
            return self.parseDataFromXml(location)
        piklfile = self._piklFileName(location, checksum)
        try: 
            databank = self._unpickle(piklfile)
        except cPickle.UnpicklingError:
            self.debugprint('Could not unpickle: %s!' % (sys.exc_info()[1],))
            databank = self.parseDataFromXml(location)
            self._killold(location)
            try: self._pickle(piklfile, databank)
            except cPickle.PicklingError:
                self.debugprint('Could not pickle %s data in %s' % (ident, piklfile))
            return databank
        self.repodata[ident] = databank
        return databank
        
    def getPrimary(self, location, checksum):
        """Return the 'metadata' bank for the primary XML file."""
        return self._getGeneric('metadata', location, checksum)

    def getFilelists(self, location, checksum):
        """Return the 'filelists' bank for the filelists XML file."""
        return self._getGeneric('filelists', location, checksum)

    def getOtherdata(self, location, checksum):
        """Return the 'otherdata' bank for the other XML file."""
        return self._getGeneric('otherdata', location, checksum)

    def parseDataFromXml(self, fileloc):
        """Parse the XML file fileloc into the matching data bank.

        The bank is chosen by the root element found in the file and is
        filled in place; it is also returned (None when no known root
        element was seen).
        """
        ## TODO: Fail sanely.
        self.debugprint('Parsing data from %s' % fileloc)
        reader = libxml2.newTextReaderFilename(fileloc)
        count = 0
        total = 9999
        mode = None
        databank = None
        while reader.Read():
            # only element start nodes are of interest here
            if reader.NodeType() != 1: continue
            name = reader.LocalName()
            if name in ('metadata', 'filelists', 'otherdata'):
                mode = name
                databank = self.repodata[mode]
                # 'total' only feeds the progress callback.  A missing
                # attribute presumably comes back as None, and int(None)
                # raises TypeError, not ValueError.
                try: total = int(reader.GetAttribute('packages'))
                except (TypeError, ValueError): pass
            elif name == 'package':
                count += 1
                if mode == 'metadata':
                    obj = PrimaryEntry(reader)
                    pkgid = obj.checksum['value']
                    if pkgid: databank[pkgid] = obj
                elif mode == 'filelists':
                    pkgid = reader.GetAttribute('pkgid')
                    if pkgid:
                        obj = FilelistsEntry(reader)
                        databank[pkgid] = obj
                elif mode == 'otherdata':
                    pkgid = reader.GetAttribute('pkgid')
                    if pkgid:
                        obj = OtherEntry(reader)
                        databank[pkgid] = obj
                if self.callback: 
                    self.callback.progressbar(count, total, 'MD Read')
        self.debugprint('Parsed %s packages' % count)
        reader.Close()
        del reader
        return databank

class BaseEntry:
    """Helpers shared by the entry classes for reading from an XML text reader."""

    def _props(self, reader):
        """Return the attributes of the current element as a name -> value dict."""
        attrs = {}
        if reader.HasAttributes():
            reader.MoveToFirstAttribute()
            attrs[reader.LocalName()] = reader.Value()
            while reader.MoveToNextAttribute():
                attrs[reader.LocalName()] = reader.Value()
            # step back from the last attribute to the element itself
            reader.MoveToElement()
        return attrs
        
    def _value(self, reader):
        """Return the text that directly follows the current element start.

        All consecutive text nodes (node type 3) are concatenated; the
        first node of another type ends the value and has then already
        been read from the reader.
        """
        if reader.IsEmptyElement(): return ''
        pieces = []
        while reader.Read() and reader.NodeType() == 3:
            pieces.append(reader.Value())
        return ''.join(pieces)

    def _getFileEntry(self, reader):
        """Return (type, path) for the current file element.

        The type is taken from the 'type' attribute and defaults to 'file'.
        """
        props = self._props(reader)
        kind = props.get('type', 'file')
        value = self._value(reader)
        return (kind, value)

class PrimaryEntry(BaseEntry):
    """One 'package' element of the primary metadata, read from a text reader."""

    def __init__(self, reader):
        """Read nodes from reader up to the closing 'package' tag."""
        self.nevra = (None, None, None, None, None)
        self.checksum = {'type': None, 'pkgid': None, 'value': None}
        self.info = {
            'summary': None,
            'description': None,
            'packager': None,
            'url': None,
            'license': None,
            'vendor': None,
            'group': None,
            'buildhost': None,
            'sourcerpm': None
        }
        self.time = {'file': None, 'build': None}
        self.size = {'package': None, 'installed': None, 'archive': None}
        self.location = {'href': None, 'value': None, 'base': None}
        self.hdrange = {'start': None, 'end': None}
        self.prco = {}
        self.files = {}

        pkgname = epoch = ver = rel = arch = None
        while reader.Read():
            kind = reader.NodeType()
            if kind == 15:
                # an end tag: only the one of the package itself matters
                if reader.LocalName() == 'package': break
                continue
            if kind != 1: continue
            tag = reader.LocalName()
            if tag == 'name':
                pkgname = self._value(reader)
            elif tag == 'arch':
                arch = self._value(reader)
            elif tag == 'version':
                evr = self._props(reader)
                epoch, ver, rel = evr['epoch'], evr['ver'], evr['rel']
            elif tag in ('summary', 'description', 'packager', 'url'):
                self.info[tag] = self._value(reader)
            elif tag in ('checksum', 'location'):
                # attributes plus the element text under the key 'value'
                props = self._props(reader)
                props['value'] = self._value(reader)
                setattr(self, tag, props)
            elif tag in ('time', 'size'):
                setattr(self, tag, self._props(reader))
            elif tag == 'format':
                self.setFormat(reader)
        self.nevra = (pkgname, epoch, ver, rel, arch)

    def dump(self):
        """Print every attribute of the entry, one per line."""
        print('nevra=%s,%s,%s,%s,%s' % self.nevra)
        for attr in ('checksum', 'info', 'time', 'size', 'location',
                     'hdrange', 'prco', 'files'):
            print('%s=%s' % (attr, getattr(self, attr)))

    def setFormat(self, reader):
        """Read the children of the 'format' element up to its closing tag."""
        while reader.Read():
            kind = reader.NodeType()
            if kind == 15:
                if reader.LocalName() == 'format': break
                continue
            if kind != 1: continue
            tag = reader.LocalName()
            if tag in ('license', 'vendor', 'group', 'buildhost',
                       'sourcerpm'):
                self.info[tag] = self._value(reader)
            elif tag in ('provides', 'requires', 'conflicts',
                         'obsoletes'):
                self.setPrco(reader)
            elif tag == 'header-range':
                self.hdrange = self._props(reader)
            elif tag == 'file':
                (ftype, path) = self._getFileEntry(reader)
                self.files[path] = ftype

    def setPrco(self, reader):
        """Collect the attributes of every child of the current element.

        The list of attribute dicts is stored in self.prco under the
        name of the element ('provides', 'requires', ...).
        """
        myname = reader.LocalName()
        members = []
        while reader.Read():
            kind = reader.NodeType()
            if kind == 15:
                if reader.LocalName() == myname: break
                continue
            if kind != 1: continue
            members.append(self._props(reader))
        self.prco[myname] = members

class FilelistsEntry(BaseEntry):
    """The file list of one 'package' element, read from a text reader."""

    def __init__(self, reader):
        """Read 'file' elements from reader up to the closing 'package' tag."""
        self.files = {}
        while reader.Read():
            kind = reader.NodeType()
            if kind == 15:
                if reader.LocalName() == 'package': break
                continue
            if kind != 1: continue
            if reader.LocalName() == 'file':
                (ftype, path) = self._getFileEntry(reader)
                # files maps each path to its type
                self.files[path] = ftype
                
    def dump(self):
        """Print the file dict."""
        print('files=%s' % self.files)

class OtherEntry(BaseEntry):
    """The changelog of one 'package' element, read from a text reader."""

    def __init__(self, reader):
        """Read 'changelog' elements from reader up to the closing 'package' tag."""
        self.changelog = []
        while reader.Read():
            kind = reader.NodeType()
            if kind == 15:
                if reader.LocalName() == 'package': break
                continue
            if kind != 1: continue
            if reader.LocalName() == 'changelog':
                # the element attributes plus its text under 'value'
                entry = self._props(reader)
                entry['value'] = self._value(reader)
                self.changelog.append(entry)

    def dump(self):
        """Print the list of changelog entries."""
        print('changelog=%s' % self.changelog)

def test(level, repodir, storedir, checksum):
    """Load repository metadata and print how long each step took.

    'level' selects how much is loaded: the primary data is always
    read, 'filelists' adds the file lists and 'other' adds both the
    file lists and the other data.  The XML files are looked up in
    'repodir'; 'storedir' and 'checksum' are handed to RepodataParser.
    """
    import time
    primary = os.path.join(repodir, 'primary.xml')
    filelists = os.path.join(repodir, 'filelists.xml')
    otherdata = os.path.join(repodir, 'other.xml')
    started = time.time()
    lap = started
    rp = RepodataParser(storedir)
    rp.getPrimary(primary, checksum)
    print('operation took: %d seconds' % (time.time() - lap))
    print('primary has %s entries' % len(rp.repodata['metadata']))
    lap = time.time()
    if level in ('filelists', 'other'):
        rp.getFilelists(filelists, checksum)
        print('operation took: %d seconds' % (time.time() - lap))
        print('filelists has %s entries' % len(rp.repodata['filelists']))
        lap = time.time()
    if level == 'other':
        rp.getOtherdata(otherdata, checksum)
        print('operation took: %d seconds' % (time.time() - lap))
        print('otherdata has %s entries' % len(rp.repodata['otherdata']))
    print('')
    print('total operation time: %d seconds' % (time.time() - started))

def testusage():
    """Print the command line usage of the test script and exit with status 1."""
    usage = (
        'Usage: %s level repodir storedir checksum' % sys.argv[0],
        'level can be primary, filelists, other',
        'repodir is the location of .xml files',
        'storedir is where pickles will be saved',
        'checksum can be anything you want it to be',
    )
    for line in usage:
        print(line)
    sys.exit(1)
    
if __name__ == '__main__':
    # exactly four arguments are required; testusage() does not return
    args = sys.argv[1:]
    if len(args) != 4: testusage()
    level, repodir, storedir, checksum = args
    if level not in ('primary', 'filelists', 'other'): testusage()
    # the literal word 'None' on the command line turns pickling off
    if checksum == 'None': checksum = None
    test(level, repodir, storedir, checksum)


--- NEW FILE mdparser.py ---

# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Library General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
# Copyright 2005 Duke University 

import gzip
from cElementTree import iterparse

try:
    from cStringIO import StringIO
except ImportError:
    from StringIO import StringIO

#TODO: document everything here

class MDParser:
    """Iterator over the packages of one repodata XML file.

    The file may be gzip compressed (name ending in '.gz').  The root
    element decides which entry class is created for every package:
    PrimaryEntry, FilelistsEntry or OtherEntry.

    Attributes:
      total   number of packages announced by the 'packages' attribute
              of the root element, or None when it has no such attribute
      count   number of packages returned so far

    Creating a parser raises ValueError when the root element is not
    one of the known repodata types.
    """

    def __init__(self, filename):

        # Set up mapping of meta types to handler classes 
        handlers = {
            '{http://linux.duke.edu/metadata/common}metadata': PrimaryEntry,
            '{http://linux.duke.edu/metadata/filelists}filelists': FilelistsEntry,
            '{http://linux.duke.edu/metadata/other}otherdata': OtherEntry,
        }
            
        self.total = None
        self.count = 0
        self._handlercls = None

        # Read in type, set package node handler and get total number of
        # packages
        if filename[-3:] == '.gz': fh = gzip.open(filename, 'r')
        else: fh = open(filename, 'r')
        self._fh = fh
        try:
            parser = iterparse(fh, events=('start', 'end'))
            self.reader = parser.__iter__()
            # the first event is the start of the root element
            event, elem = self.reader.next()
            self._handlercls = handlers.get(elem.tag, None)
            if not self._handlercls:
                raise ValueError('Unknown repodata type "%s" in %s' % (
                    elem.tag, filename))
            # Get the total number of packages.  An "x and None or y"
            # expression cannot be used here: it would evaluate int(None)
            # when the attribute is missing.
            total = elem.get('packages', None)
            if total is not None:
                self.total = int(total)
        except:
            # don't leak the file handle when the header can't be read
            fh.close()
            raise

    def __iter__(self):
        return self

    def next(self):
        """Return the entry object for the next package in the file."""
        for event, elem in self.reader:
            if event == 'end' and elem.tag[-7:] == 'package':
                self.count += 1
                return self._handlercls(elem)
        # the whole file has been read, the handle is no longer needed
        self._fh.close()
        raise StopIteration


class BaseEntry:
    """Base class of the package entries; wraps the property dict self._p."""

    def __init__(self, elem):
        self._p = {} 

    def __getitem__(self, k):
        return self._p[k]

    def keys(self):
        return self._p.keys()

    def values(self):
        return self._p.values()

    def has_key(self, k):
        return k in self._p

    def __str__(self):
        """Return all properties as utf8 encoded 'key=value' lines, sorted by key."""
        keys = self.keys()
        keys.sort()
        lines = [(u'%s=%s\n' % (k, self[k])).encode('utf8') for k in keys]
        return ''.join(lines)

    def _bn(self, qn):
        """Return the name qn without its leading '{namespace}' part, if any."""
        parts = qn.split('}')
        if len(parts) == 1: return qn
        return parts[1]
        
    def _prefixprops(self, elem, prefix):
        """Return the attributes of elem with their names changed to prefix_name."""
        return dict([(prefix + '_' + self._bn(key), value)
                     for key, value in elem.attrib.items()])

class PrimaryEntry(BaseEntry):
    """One 'package' element of the primary metadata.

    Simple values end up in the property dict; self.prco holds the
    provides/requires/conflicts/obsoletes lists and self.files maps
    file paths to their type.
    """

    def __init__(self, elem):
        BaseEntry.__init__(self, elem)
        props = self._p

        self.prco = {}
        self.files = {}

        for child in elem:
            tag = self._bn(child.tag)

            if tag == 'format':
                self.setFormat(child)

            elif tag == 'version':
                props.update(child.attrib)

            elif tag in ('name', 'arch', 'summary', 'description', 'url',
                         'packager'):
                props[tag] = child.text

            elif tag in ('time', 'size'):
                props.update(self._prefixprops(child, tag))

            elif tag in ('checksum', 'location'):
                props.update(self._prefixprops(child, tag))
                props[tag + '_value'] = child.text
                if tag == 'location':
                    # make sure the key exists even without a base attribute
                    props.setdefault('location_base', None)

        props['pkgId'] = props['checksum_value']
        # the element is not needed any longer, free its contents
        elem.clear()

    def setFormat(self, elem):
        """Record the children of the 'format' element."""
        props = self._p

        for child in elem:
            tag = self._bn(child.tag)

            if tag == 'file':
                self.files[child.text] = child.get('type', 'file')

            elif tag == 'header-range':
                props.update(self._prefixprops(child, 'rpm_header'))

            elif tag in ('provides', 'requires', 'conflicts',
                         'obsoletes'):
                self.prco[tag] = self.getPrco(child)

            elif tag in ('license', 'vendor', 'group', 'buildhost',
                         'sourcerpm'):
                props[tag] = child.text

    def getPrco(self, elem):
        """Return the attribute dicts of all children of elem as a list."""
        return [child.attrib for child in elem]
        
        
class FilelistsEntry(BaseEntry):
    """One package element of the filelists metadata.

    self.files maps every listed path to its type; entries that carry no
    type attribute are recorded as 'file'.
    """

    def __init__(self, elem):
        """Populate the entry from elem, then clear elem."""
        BaseEntry.__init__(self, elem)
        self._p['pkgId'] = elem.attrib['pkgid']
        self.files = {}
        for node in elem:
            if self._bn(node.tag) != 'file':
                continue
            ftype = node.get('type', 'file')
            self.files[node.text] = ftype
        elem.clear()
                
class OtherEntry(BaseEntry):
    """One package element of the "other" metadata (changelogs).

    self._p['changelog'] is a list with one dictionary per changelog
    child: the child's attributes plus its text under the 'value' key.
    """

    def __init__(self, elem):
        """Populate the entry from elem, then clear elem."""
        BaseEntry.__init__(self, elem)
        self._p['pkgId'] = elem.attrib['pkgid']
        self._p['changelog'] = []
        for node in elem:
            if self._bn(node.tag) != 'changelog':
                continue
            # The attribute dictionary itself is reused as the record.
            record = node.attrib
            record['value'] = node.text
            self._p['changelog'].append(record)
        elem.clear()



def test():
    """Parse the metadata file named by sys.argv[1] and print each package.

    A summary line with the parser's count and total follows the
    package dump.
    """
    import sys

    parser = MDParser(sys.argv[1])

    separator = '-' * 40
    for pkg in parser:
        print(separator)
        print(pkg)

    print('read: %s packages (%s suggested)' % (parser.count, parser.total))

if __name__ == '__main__':
    test()


--- NEW FILE misc.py ---
import types
import string
import os
import os.path
from cStringIO import StringIO
import base64
import struct
import re
import pgpmsg
import tempfile
import glob
import rpm
import pwd
from stat import *

import rpmUtils
from Errors import MiscError

###########
# Title: Remove duplicates from a sequence
# Submitter: Tim Peters 
# From: http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/52560                      
    
def unique(s):
    """Return the distinct elements of the sequence s, in no promised order.

    Three strategies are tried, fastest first:

    1. use the elements as dictionary keys (needs hashable elements,
       roughly linear time);
    2. sort a copy and drop neighbouring duplicates (needs elements that
       can be ordered, roughly O(N*log2(N)));
    3. compare every element against the ones already kept (needs only
       equality testing, quadratic time).

    So unique([1,2,3,1,2,3]) is some permutation of [1,2,3] and
    unique(([1, 2], [2, 3], [1, 2])) some permutation of
    [[2, 3], [1, 2]].
    """

    total = len(s)
    if total == 0:
        return []

    # Strategy 1: hashing.  When an element is unhashable this usually
    # fails on the first few elements, so the attempt is cheap.
    seen = {}
    try:
        for item in s:
            seen[item] = 1
    except TypeError:
        pass  # fall through to sorting
    else:
        return seen.keys()

    # Strategy 2: sorting brings equal elements together, after which a
    # single pass is enough to weed out the duplicates.
    try:
        ordered = list(s)
        ordered.sort()
    except TypeError:
        pass  # fall through to brute force
    else:
        kept = [ordered[0]]
        for item in ordered[1:]:
            if item != kept[-1]:
                kept.append(item)
        return kept

    # Strategy 3: brute force.
    distinct = []
    for item in s:
        if item not in distinct:
            distinct.append(item)
    return distinct

def checksum(sumtype, file, CHUNK=2**16):
    """Return the hexadecimal checksum of a file's contents.

       sumtype = 'md5' or 'sha' (SHA-1)
       file = /path/to/file, or an open file-like object (anything with
              a read() method); such an object is read from its current
              position and is left open
       CHUNK = number of bytes per read, 65536 by default

       Raises MiscError for an unknown sumtype, or when the file cannot
       be opened or read."""

    # Pick the digest before touching the file so that a bad sumtype
    # cannot leave a file open.  hashlib replaced the md5 and sha modules
    # in Python 2.5; fall back to them on older interpreters.
    if sumtype == 'md5':
        try:
            from hashlib import md5 as newsum
        except ImportError:
            from md5 import new as newsum
    elif sumtype == 'sha':
        try:
            from hashlib import sha1 as newsum
        except ImportError:
            from sha import new as newsum
    else:
        raise MiscError('Error Checksumming file, bad checksum type %s' % sumtype)

    # Anything that can read() is used as it is, everything else is taken
    # to be a path.  (An exact str type test would mistake unicode paths
    # for file objects.)
    opened = not hasattr(file, 'read')

    # chunking brazenly lifted from Ryan Tomayko
    try:
        if opened:
            # Binary mode: the digest is defined over the raw bytes.
            fo = open(file, 'rb', CHUNK)
        else:
            fo = file
        try:
            digest = newsum()
            while True:
                chunk = fo.read(CHUNK)
                if not chunk:
                    break
                digest.update(chunk)
        finally:
            # Only close what this function opened.
            if opened:
                fo.close()
    except (IOError, OSError):
        raise MiscError('Error opening file for checksum: %s' % file)

    return digest.hexdigest()

def getFileList(path, ext, filelist):
    """Return all files in path matching ext, store them in filelist, 
       recurse dirs return list object"""
    
    extlen = len(ext)
    try:
        dir_list = os.listdir(path)
    except OSError, e:
        raise MiscError, ('Error accessing directory %s, %s') % (path, e)
        
    for d in dir_list:
        if os.path.isdir(path + '/' + d):
            filelist = getFileList(path + '/' + d, ext, filelist)
        else:
            if string.lower(d[-extlen:]) == '%s' % (ext):
               newpath = os.path.normpath(path + '/' + d)
               filelist.append(newpath)
                    
    return filelist

class GenericHolder:
    """Generic Holder class used to hold other objects of known types
       It exists purely to be able to do object.somestuff, object.someotherstuff
       or object[key] and pass object to another function that will
       understand it"""

    def __getitem__(self, item):
        """Return the attribute named item; raise KeyError if there is none.

           Keys that cannot name an attribute at all (for example
           integers) are reported as KeyError as well, like any other
           missing key."""
        try:
            return getattr(self, item)
        except (AttributeError, TypeError, UnicodeError):
            raise KeyError(item)

def procgpgkey(rawkey):
    '''Convert ASCII armoured GPG key to binary
    '''
    # TODO: CRC checking? (will RPM do this anyway?)
    
    # Normalise newlines
    rawkey = re.compile('(\n|\r\n|\r)').sub('\n', rawkey)

    # Extract block
    block = StringIO()
    inblock = 0
    pastheaders = 0
    for line in rawkey.split('\n'):
        if line.startswith('-----BEGIN PGP PUBLIC KEY BLOCK-----'):
            inblock = 1
        elif inblock and line.strip() == '':
            pastheaders = 1
        elif inblock and line.startswith('-----END PGP PUBLIC KEY BLOCK-----'):
            # Hit the end of the block, get out
            break
        elif pastheaders and line.startswith('='):
            # Hit the CRC line, don't include this and stop
            break
        elif pastheaders:
            block.write(line+'\n')
  
    # Decode and return
    return base64.decodestring(block.getvalue())

def getgpgkeyinfo(rawkey):
    '''Return a dict of info for the given ASCII armoured key text

    Returned dict will have the following keys: 'userid', 'keyid', 'timestamp'

    Will raise ValueError if there was a problem decoding the key.
    '''
    # Catch all exceptions as there can be quite a variety raised by this call
    try:
        key = pgpmsg.decode_msg(rawkey)
    except Exception, e:
        raise ValueError(str(e))
    if key is None:
        raise ValueError('No key found in given key data')

    keyid_blob = key.public_key.key_id()

    info = {
        'userid': key.user_id,
        'keyid': struct.unpack('>Q', keyid_blob)[0],
        'timestamp': key.public_key.timestamp,
    }

    # Retrieve the timestamp from the matching signature packet 
    # (this is what RPM appears to do) 
    for userid in key.user_ids[0]:
        if not isinstance(userid, pgpmsg.signature):
            continue

        if userid.key_id() == keyid_blob:
            # Get the creation time sub-packet if available
            if hasattr(userid, 'hashed_subpaks'):
                tspkt = \
                    userid.get_hashed_subpak(pgpmsg.SIG_SUB_TYPE_CREATE_TIME)
                if tspkt != None:
                    info['timestamp'] = int(tspkt[1])
                    break
        
    return info

def keyIdToRPMVer(keyid):
    '''Convert an integer representing a GPG key ID to the hex version string
    used by RPM

    Only the low 32 bits of keyid are used; the result is always eight
    lowercase hexadecimal digits.
    '''
    # 2 ** 32 - 1 == 0xffffffff; spelling the mask this way needs no long
    # literal suffix and promotes to a long by itself where necessary.
    return "%08x" % (keyid & (2 ** 32 - 1))


def keyInstalled(ts, keyid, timestamp):
    '''Report whether the GPG key with the given keyid and timestamp is
    installed in the rpmdb.

    The keyid and timestamp should both be passed as integers.
    The ts is an rpm transaction set object

    Return values:
        -1      key is not installed
        0       key with matching ID and timestamp is installed
        1       key with matching ID is installed but has a older timestamp
        2       key with matching ID is installed but has a newer timestamp

    Duplicates are not handled: the first gpg-pubkey header whose version
    matches the key id decides the result.
    '''
    # gpg-pubkey headers keep the key id in 'version' (RPM hex form) and
    # the timestamp, in hex, in 'release'
    wanted = keyIdToRPMVer(keyid)

    for hdr in ts.dbMatch('name', 'gpg-pubkey'):
        if hdr['version'] != wanted:
            continue
        installedts = int(hdr['release'], 16)
        if installedts == timestamp:
            return 0
        if installedts < timestamp:
            return 1
        return 2

    return -1

def getCacheDir(tmpdir='/var/tmp'):
    """Return the path of a private cache directory below tmpdir.

       An existing yum-<username>-* entry is reused when it is a real
       directory (not a symlink), has mode 0700 and belongs to the
       effective user; otherwise a new one is made with
       tempfile.mkdtemp().  Returns None when the effective uid has no
       passwd entry.  Only used when not running as root or when
       --tempcache is set."""

    uid = os.geteuid()
    try:
        username = pwd.getpwuid(uid)[0]
    except KeyError:
        return None # if it returns None then, well, it's bollocksed

    # look for /var/tmp/yum-username-*
    prefix = 'yum-%s-' % username
    for candidate in glob.glob('%s/%s*' % (tmpdir, prefix)):
        info = os.lstat(candidate)
        # S_IRWXU is 0700 (448): accessible by the owner and nobody else
        if (S_ISDIR(info.st_mode) and S_IMODE(info.st_mode) == S_IRWXU
                and info.st_uid == uid):
            return candidate

    # nothing suitable yet, so make the dir
    return tempfile.mkdtemp(prefix=prefix, dir=tmpdir)

def sortPkgObj(pkg1 ,pkg2):
    """comparison function ordering yum package objects by name:
       returns 1, 0 or -1 when pkg1.name is greater than, equal to or
       less than pkg2.name"""
    first = pkg1.name
    second = pkg2.name
    if first > second:
        return 1
    if first == second:
        return 0
    return -1
        


--- NEW FILE packages.py ---
#!/usr/bin/python -tt
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Library General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
# Copyright 2004 Duke University 
# Written by Seth Vidal <skvidal at phy.duke.edu>

import rpm
import os
import os.path
import misc
import re
import types
import fnmatch
import rpmUtils
import rpmUtils.arch
import rpmUtils.miscutils
import Errors

import repomd.packageObject

# Module-level handle that localPkg() and localHdr() use to look up
# repositories (base.repos.getRepo(repoid)).  It is None here, so it has
# to be assigned from outside this module before those methods are called
# (presumably by whatever sets up the main yum object - TODO confirm).
base=None

def buildPkgRefDict(pkgs):
    """take a list of pkg objects and return a dict that maps every
       supported way of naming a package to the list of pkg objects
       that go by that name, eg: for a pkg with name=name, arch=i386,
       epoch=0, version=1, release=1 the keys are
       name
       name.i386
       name-1-1.i386
       name-1
       name-1-1
       0:name-1-1.i386
       name-0:1-1.i386
       Several packages may share a key (for example the same name in
       two arches); they are listed in the order they appear in pkgs.
       """
    pkgdict = {}
    for pkg in pkgs:
        n, a = pkg.name, pkg.arch
        e, v, r = pkg.epoch, pkg.version, pkg.release
        aliases = (
            n,                                      # name
            '%s.%s' % (n, a),                       # name.arch
            '%s-%s-%s.%s' % (n, v, r, a),           # name-ver-rel.arch
            '%s-%s' % (n, v),                       # name-ver
            '%s-%s-%s' % (n, v, r),                 # name-ver-rel
            '%s:%s-%s-%s.%s' % (e, n, v, r, a),     # epoch:name-ver-rel.arch
            '%s-%s:%s-%s.%s' % (n, e, v, r, a),     # name-epoch:ver-rel.arch
        )
        for alias in aliases:
            pkgdict.setdefault(alias, []).append(pkg)

    return pkgdict
       
def parsePackages(pkgs, usercommands, casematch=0):
    """match the user's requests against a list of pkg objects:
       for installs/updates pkgs should be the 'others list' of
       available pkgs, for removes it should be the installed list.
       Wildcard requests are matched case-insensitively unless casematch
       is true.

       Returns (exactmatch, matched, unmatched): the pkgs named exactly,
       the pkgs found through a wildcard, and the requests that matched
       nothing.  A name is used up by the first request that matches
       it."""

    pkgdict = buildPkgRefDict(pkgs)
    exactmatch = []
    matched = []
    unmatched = []

    for command in usercommands:
        if command in pkgdict:
            exactmatch.extend(pkgdict[command])
            del pkgdict[command]
            continue

        # no exact match: either it's not there or it's a wildcard
        if not re.match('.*[\*,\[,\],\{,\},\?].*', command):
            # we got nada
            unmatched.append(command)
            continue

        restring = fnmatch.translate(command)
        if casematch:
            regex = re.compile(restring) # case sensitive
        else:
            regex = re.compile(restring, flags=re.I) # case insensitive

        # take a copy of the names first: hits are removed from the dict
        hits = [item for item in list(pkgdict) if regex.match(item)]
        for item in hits:
            matched.extend(pkgdict[item])
            del pkgdict[item]

        if not hits:
            unmatched.append(command)

    matched = misc.unique(matched)
    unmatched = misc.unique(unmatched)
    exactmatch = misc.unique(exactmatch)
    return exactmatch, matched, unmatched


# goal for the below is to have a packageobject that can be used by generic
# functions independent of the type of package - ie: installed or available
class YumAvailablePackage(repomd.packageObject.PackageObject, repomd.packageObject.RpmBase):
    """package living in a repository: combines the repomd PackageObject
       and RpmBase classes and fills itself from a parsed metadata
       record"""

    def __init__(self, pkgdict, repoid):
        """build the package from the metadata record pkgdict, which
           belongs to the repository repoid"""
        repomd.packageObject.PackageObject.__init__(self)
        repomd.packageObject.RpmBase.__init__(self)

        self.importFromDict(pkgdict, repoid)

        # quick, common definitions: attribute name -> simple key
        shortcuts = (('name', 'name'), ('epoch', 'epoch'),
                     ('version', 'version'), ('release', 'release'),
                     ('ver', 'version'), ('rel', 'release'),
                     ('arch', 'arch'), ('repoid', 'repoid'))
        for attr, key in shortcuts:
            setattr(self, attr, self.returnSimple(key))
        self.pkgtup = self._pkgtup()
        self.state = None

    def size(self):
        """returns the 'packagesize' value of the package"""
        return self.returnSimple('packagesize')

    def _pkgtup(self):
        """returns the package tuple supplied by returnPackageTuple()"""
        return self.returnPackageTuple()

    def printVer(self):
        """returns a printable version string - including epoch, if it's set"""
        if self.epoch == '0':
            return '%s-%s' % (self.version, self.release)
        return '%s:%s-%s' % (self.epoch, self.version, self.release)

    def compactPrint(self):
        """returns 'name.arch version' as a single string"""
        ver = self.printVer()
        return "%s.%s %s" % (self.name, self.arch, ver)

    def returnLocalHeader(self):
        """returns an rpm header object from the package object's local
           header cache; raises Errors.RepoError when the header file is
           missing or unreadable"""

        if not os.path.exists(self.localHdr()):
            raise Errors.RepoError('Package Header Not Available')

        try:
            hlist = rpm.readHeaderListFromFile(self.localHdr())
            hdr = hlist[0]
        except (rpm.error, IndexError):
            raise Errors.RepoError('Cannot open package header')

        return hdr

    def getProvidesNames(self):
        """returns a list of providesNames"""
        return [name for (name, flag, vertup) in self.returnPrco('provides')]

    def localPkg(self):
        """return path to local package (whether it is present there, or not)"""
        if not hasattr(self, 'localpath'):
            repo = base.repos.getRepo(self.repoid)
            rpmfn = os.path.basename(self.returnSimple('relativepath'))
            self.localpath = repo.pkgdir + '/' + rpmfn
        return self.localpath

    def localHdr(self):
        """return path to local cached Header file downloaded from package
           byte ranges"""

        if not hasattr(self, 'hdrpath'):
            repo = base.repos.getRepo(self.repoid)
            pkgname = os.path.basename(self.returnSimple('relativepath'))
            # swap the last four characters (the '.rpm') for '.hdr'
            self.hdrpath = repo.hdrdir + '/' + pkgname[:-4] + '.hdr'

        return self.hdrpath

    def prcoPrintable(self, prcoTuple):
        """convert the prco tuples into a nicer human string"""
        (name, flag, (e, v, r)) = prcoTuple
        flags = {'GT':'>', 'GE':'>=', 'EQ':'=', 'LT':'<', 'LE':'<='}
        if flag is None:
            return name

        pieces = ['%s %s ' % (name, flags[flag])]
        if e not in [0, '0', None]:
            pieces.append('%s:' % e)
        if v is not None:
            pieces.append('%s' % v)
        if r is not None:
            pieces.append('-%s' % r)

        return ''.join(pieces)

    def requiresList(self):
        """return a list of requires in normal rpm format"""
        return [self.prcoPrintable(prcoTuple)
                for prcoTuple in self.returnPrco('requires')]

    def importFromDict(self, pkgdict, repoid):
        """populate the package from pkgdict, an mdCache package record
           whose data is read from its attributes (nevra, time, size,
           location, hdrange, info, files, prco, changelog, checksum);
           attributes that are absent are skipped"""

        simple = self.simple
        simple['repoid'] = repoid

        if hasattr(pkgdict, 'nevra'):
            (n, e, v, r, a) = pkgdict.nevra
            simple['name'] = n
            simple['epoch'] = e
            simple['version'] = v
            simple['release'] = r
            simple['arch'] = a

        if hasattr(pkgdict, 'time'):
            simple['buildtime'] = pkgdict.time['build']
            simple['filetime'] = pkgdict.time['file']

        if hasattr(pkgdict, 'size'):
            simple['packagesize'] = pkgdict.size['package']
            simple['archivesize'] = pkgdict.size['archive']
            simple['installedsize'] = pkgdict.size['installed']

        if hasattr(pkgdict, 'location'):
            # a missing or empty base both mean "no base url"
            url = None
            if 'base' in pkgdict.location and pkgdict.location['base'] != '':
                url = pkgdict.location['base']

            simple['basepath'] = url
            simple['relativepath'] = pkgdict.location['href']

        if hasattr(pkgdict, 'hdrange'):
            simple['hdrstart'] = pkgdict.hdrange['start']
            simple['hdrend'] = pkgdict.hdrange['end']

        if hasattr(pkgdict, 'info'):
            infodict = pkgdict.info
            for item in ['summary', 'description', 'packager', 'group',
                         'buildhost', 'sourcerpm', 'url', 'vendor']:
                simple[item] = infodict[item]

            self.licenses.append(infodict['license'])

        if hasattr(pkgdict, 'files'):
            # regroup path -> type into type -> [paths]
            for fname in pkgdict.files.keys():
                ftype = pkgdict.files[fname]
                if ftype not in self.files:
                    self.files[ftype] = []
                self.files[ftype].append(fname)

        if hasattr(pkgdict, 'prco'):
            for rtype in pkgdict.prco.keys():
                for rdict in pkgdict.prco[rtype]:
                    name = rdict['name']
                    # everything but the name is optional
                    f = e = v = r = None
                    if 'flags' in rdict:
                        f = rdict['flags']
                    if 'epoch' in rdict:
                        e = rdict['epoch']
                    if 'ver' in rdict:
                        v = rdict['ver']
                    if 'rel' in rdict:
                        r = rdict['rel']
                    self.prco[rtype].append((name, f, (e, v, r)))

        if hasattr(pkgdict, 'changelog'):
            for cdict in pkgdict.changelog:
                date = text = author = None
                if 'date' in cdict:
                    date = cdict['date']
                if 'value' in cdict:
                    text = cdict['value']
                if 'author' in cdict:
                    author = cdict['author']
                self.changelog.append((date, author, text))

        if hasattr(pkgdict, 'checksum'):
            ctype = pkgdict.checksum['type']
            csum = pkgdict.checksum['value']
            csumid = pkgdict.checksum['pkgid']
            # only an explicit YES (in any case) marks the package id
            if csumid is not None and csumid.upper() == 'YES':
                csumid = 1
            else:
                csumid = 0
            self.checksums.append((ctype, csum, csumid))
            


class YumInstalledPackage(YumAvailablePackage):
    """sub class for dealing with packages in the rpmdb: everything is
       answered from the rpm header handed to the constructor"""

    def __init__(self, hdr):
        """hand in an rpm header, we'll assume it's installed and query from there"""
        # The parent constructor is not called on purpose: it expects a
        # metadata record, while this class reads everything from hdr.
        self.hdr = hdr
        self.name = self.tagByName('name')
        self.arch = self.tagByName('arch')
        self.epoch = self.doepoch()
        self.version = self.tagByName('version')
        self.release = self.tagByName('release')
        self.ver = self.tagByName('version')
        self.rel = self.tagByName('release')
        self.pkgtup = self._pkgtup()
        self.repoid = 'installed'
        self.summary = self.tagByName('summary')
        self.description = self.tagByName('description')
        self.pkgid = self.tagByName(rpm.RPMTAG_SHA1HEADER)
        self.state = None

    def __str__(self):
        """returns 'name - [epoch:]version-release.arch'; an epoch of
           '0' is left out"""
        if self.epoch == '0':
            val = '%s - %s-%s.%s' % (self.name, self.version, self.release,
                                        self.arch)
        else:
            val = '%s - %s:%s-%s.%s' % (self.name, self.epoch, self.version,
                                           self.release, self.arch)
        return val

    def tagByName(self, tag):
        """returns the value stored in the header under tag"""
        data = self.hdr[tag]
        return data

    def doepoch(self):
        """returns the epoch as a string; a missing epoch counts as '0'"""
        tmpepoch = self.hdr['epoch']
        if tmpepoch is None:
            epoch = '0'
        else:
            epoch = str(tmpepoch)

        return epoch

    def returnSimple(self, thing):
        """returns the attribute called thing if this object has one,
           the header tag of that name otherwise"""
        if hasattr(self, thing):
            return getattr(self, thing)
        else:
            return self.tagByName(thing)

    def returnLocalHeader(self):
        """returns the rpm header this package was built from"""
        return self.hdr


    def getProvidesNames(self):
        """returns a list of providesNames"""

        provnames = self.tagByName('providename')
        # Normalise to a list: a single name arrives as a bare string
        # and anything else (such as None) means there are no provides.
        if isinstance(provnames, list):
            return provnames
        if isinstance(provnames, str):
            return [provnames]
        return []

    def requiresList(self):
        """return a list of all of the strings of the package requirements"""
        names = self.hdr[rpm.RPMTAG_REQUIRENAME]
        flags = self.hdr[rpm.RPMTAG_REQUIREFLAGS]
        ver = self.hdr[rpm.RPMTAG_REQUIREVERSION]
        if names is None:
            # a header without requires has nothing to format
            return []

        reqlist = []
        for (n, f, v) in zip(names, flags, ver):
            req = rpmUtils.miscutils.formatRequire(n, v, f)
            reqlist.append(req)

        return reqlist

    def _pkgtup(self):
        """returns the (name, arch, epoch, version, release) tuple"""
        return (self.name, self.arch, self.epoch, self.version, self.release)

    def size(self):
        """returns the value of the header's 'size' tag"""
        return self.tagByName('size')

    def printVer(self):
        """returns a printable version string - including epoch, if it's set"""
        if self.epoch != '0':
            ver = '%s:%s-%s' % (self.epoch, self.version, self.release)
        else:
            ver = '%s-%s' % (self.version, self.release)

        return ver

    def compactPrint(self):
        """returns 'name.arch version' as a single string"""
        ver = self.printVer()
        return "%s.%s %s" % (self.name, self.arch, ver)
    

class YumLocalPackage(YumInstalledPackage):
    """Package read from an rpm file at an arbitrary path.

       Most behaviour comes from YumInstalledPackage, because a package
       file on disk and an installed package are both described by an
       rpm header.  The constructor needs a ts instance and the
       filename/path of the package."""

    def __init__(self, ts=None, filename=None):
        """read the header of filename through ts; raises
           Errors.MiscError when an argument is missing or the file
           cannot be opened"""
        if ts is None:
            raise Errors.MiscError(
                 'No Transaction Set Instance for YumLocalPackage instance creation')
        if filename is None:
            raise Errors.MiscError(
                 'No Filename specified for YumLocalPackage instance creation')

        self.pkgtype = 'local'
        self.localpath = filename
        # a local file has no repository; its path doubles as the repo id
        self.repoid = filename
        try:
            self.hdr = rpmUtils.miscutils.hdrFromPackage(ts, self.localpath)
        except rpmUtils.RpmUtilsError:
            raise Errors.MiscError(
                'Could not open local rpm file: %s' % self.localpath)

        for attr, tag in (('name', 'name'), ('arch', 'arch')):
            setattr(self, attr, self.tagByName(tag))
        self.epoch = self.doepoch()
        for attr, tag in (('version', 'version'), ('release', 'release'),
                          ('ver', 'version'), ('rel', 'release'),
                          ('summary', 'summary'),
                          ('description', 'description')):
            setattr(self, attr, self.tagByName(tag))
        self.pkgtup = self._pkgtup()
        self.state = None

    def _pkgtup(self):
        """returns the (name, arch, epoch, version, release) tuple"""
        return (self.name, self.arch, self.epoch, self.version, self.release)

    def localPkg(self):
        """returns the path the package was loaded from"""
        return self.localpath
    
        


            
        
    
    


--- NEW FILE parser.py ---
import re
import glob
import shlex
import string
import urlparse
import urlgrabber
import os.path
from ConfigParser import ConfigParser, NoSectionError, NoOptionError

#TODO: better handling of recursion
#TODO: ability to handle bare includes (ie. before first [section])
#TODO: avoid include line reordering on write

# The above 3 items are probably handled best by more separation between
# include functionality and ConfigParser. Delegate instead of subclass. See how
# this was done in the previous implementation.

#TODO: problem: interpolation tokens are lost when config files are rewritten
#       - workaround is to not set vars, not sure if this is ok
#       - maybe we should do interpolation at the Option level after all?
#       - preserve original uninterpolated value?
#TODO: separate $var interpolation into YumParser?

class IncludingConfigParser(ConfigParser):
    """ConfigParser that also reads the files named by an include option.

    After a file has been read, every section is checked for the include
    option (named "include" by default).  Each file listed there is
    parsed by a parser of its own; the sections it defines are then
    served by this parser as well, and write() writes them back to the
    file they came from.  Values are interpolated with $var substitution
    instead of the standard %(var)s form.
    """

    def __init__(self, vars=None, include="include"):
        """
        @param vars: A dictionary of substitution variables.
        @param include: Name of option that lists files for inclusion
        """
        self.include = include

        # Dictionary of filenames -> included configparser objects
        self._fns = {}

        # Dictionary of sections -> filenames
        self._included = {}

        # Directory that relative include names are resolved against
        self.cwd = None
        ConfigParser.__init__(self, vars)

    def defaults(self):
        """Return a dictionary containing the instance-wide defaults."""
        return self._defaults

    def sections(self):
        """Return a list of the sections available in file and includes."""
        s = self.__sections()
        s.extend(self._included.keys())
        return s

    def has_section(self, section):
        """Indicates whether the named section is present in
        the configuration and includes."""
        return section in self.__sections() or section in self._included

    def _owner(self, section):
        """Return the included parser that holds section, or None when the
        section is not an included one."""
        if section in self._included:
            return self._fns[self._included[section]]
        return None

    def has_option(self, section, option):
        """Indicate whether section (own or included) has the option.

        Raises NoSectionError for an unknown section."""
        if not self.has_section(section):
            raise NoSectionError(section)
        owner = self._owner(section)
        if owner is not None:
            return owner.has_option(section, option)
        return ConfigParser.has_option(self, section, option)

    def options(self, section):
        """Return a list of option names for the given section name"""
        if not self.has_section(section):
            raise NoSectionError(section)
        owner = self._owner(section)
        if owner is not None:
            return owner.options(section)
        return ConfigParser.options(self, section)

    def items(self, section):
        """Return the (name, value) pairs of the given section.

        Raises NoSectionError for an unknown section."""
        if not self.has_section(section):
            raise NoSectionError(section)
        owner = self._owner(section)
        if owner is not None:
            return owner.items(section)
        return ConfigParser.items(self, section)

    def remove_section(self, section):
        """Remove a section from this file or from the include holding it.

        Raises NoSectionError for an unknown section."""
        if not self.has_section(section):
            raise NoSectionError(section)
        owner = self._owner(section)
        if owner is not None:
            # Forget the mapping too, otherwise the section would still be
            # reported although its parser no longer has it.
            del self._included[section]
            return owner.remove_section(section)
        return ConfigParser.remove_section(self, section)

    def add_include(self, section, fn):
        """Add a included file to config section

        Nothing happens when fn is already listed in the section's
        include option.  Raises NoSectionError or NoOptionError when the
        section or its include option is missing."""
        if not self.has_section(section):
            raise NoSectionError(section)
        if not self.has_option(section, self.include):
            raise NoOptionError(self.include, section)
        inc = self.get(section, self.include)
        if fn in shlex.split(inc):
            return
        self._add_include(section, fn)

    def remove_include(self, section, fn):
        """Remove an included config parser

        fn is dropped from the section's include option and the sections
        that were read from it are forgotten.  Nothing happens when fn is
        neither listed nor loaded.  Raises NoSectionError or
        NoOptionError when the section or its include option is
        missing."""
        if not self.has_section(section):
            raise NoSectionError(section)
        if not self.has_option(section, self.include):
            raise NoOptionError(self.include, section)
        #XXX: raise NoIncludeError???
        listed = fn in shlex.split(self.get(section, self.include))
        loaded = self._resolve_include(fn)[1] in self._fns
        if not (listed or loaded):
            return
        self._remove_include(section, fn)

    def __sections(self):
        """Return the sections of this file only, without the includes."""
        return ConfigParser.sections(self)

    def read(self, filenames):
        """Read the files named in the string filenames (split like a
        shell command line) and whatever they include."""
        for filename in shlex.split(filenames):
            self.cwd = os.path.dirname(os.path.realpath(filename))
            ConfigParser.read(self, filename)
            self._readincludes()

    def readfp(self, fp, filename=None):
        """Read the configuration from the file object fp, then its
        includes."""
        ConfigParser.readfp(self, fp, filename)
        self._readincludes()

    def _resolve_include(self, filename):
        """Return (scheme, name) for an include as it is written in a file.

        file: URLs are reduced to their path and relative paths are made
        absolute against self.cwd, or against the process working
        directory when no file has been read by name yet.  Other URLs are
        returned unchanged."""
        parts = urlparse.urlparse(filename, 'file')
        scheme = parts[0]
        if scheme == 'file':
            filename = parts[2]
            if not filename.startswith(os.path.sep):
                filename = os.path.join(self.cwd or os.getcwd(), filename)
        return scheme, filename

    def _add_include(self, section, filename):
        """Parse the include filename and take over its sections.

        The section argument is accepted for the callers' sake but does
        not influence what is loaded."""
        # NOTE(review): the child always looks for an option named
        # "include", even when this parser was given another name.
        c = IncludingConfigParser(self._defaults)

        # Be aware of URL style includes
        scheme, filename = self._resolve_include(filename)

        if scheme == 'file':
            # Includes nested in that file are relative to its directory.
            c.cwd = os.path.dirname(filename)

        fo = urlgrabber.urlopen(filename)
        try:
            c.readfp(fo, filename)
        finally:
            fo.close()

        # Keep track of included sections
        for includesection in c.sections():
            self._included[includesection] = filename
        self._fns[filename] = c

    def _remove_include(self, section, filename):
        """Drop filename from the include option of section and forget
        the parser and the sections that were loaded from it."""
        inc = self.get(section, self.include)
        filenames = shlex.split(inc)
        if filename in filenames:
            filenames.remove(filename)
        self.set(section, self.include, ' '.join(filenames))

        # _fns and _included are keyed by the resolved name, not by the
        # name as written in the option.
        resolved = self._resolve_include(filename)[1]
        for name, owner in list(self._included.items()):
            if owner == resolved:
                del self._included[name]
        if resolved in self._fns:
            del self._fns[resolved]

    def _readincludes(self):
        """Load every file listed in an include option of this file."""
        for section in ConfigParser.sections(self):
            if self.has_option(section, self.include):
                for filename in shlex.split(self.get(section, self.include)):
                    self._add_include(section, filename)

    def get(self, section, option, raw=False, vars=None):
        """Return section from file or included files"""
        owner = self._owner(section)
        if owner is not None:
            return owner.get(section, option, raw, vars)
        return ConfigParser.get(self, section, option, raw, vars)

    def set(self, section, option, value):
        """Set an option in this file or in the include holding section."""
        owner = self._owner(section)
        if owner is not None:
            return owner.set(section, option, value)
        return ConfigParser.set(self, section, option, value)

    def write(self, fp):
        """Take a file object and write it

        The sections of this file go to fp; every included parser is
        written back to the file it was read from."""

        # Don't call the parent write() method because it dumps out
        # self._defaults as its own section which isn't desirable here.

        # Write out the items for this file
        for section in self._sections:
            fp.write("[%s]\n" % section)
            for (key, value) in self._sections[section].items():
                if key == '__name__':
                    continue
                fp.write("%s = %s\n" % (key, str(value).replace('\n', '\n\t')))
            fp.write("\n")

        # Write out any included files
        for fn in self._fns.keys():
            # Only bother for files since we can't easily write back to much else.
            scheme = urlparse.urlparse(fn, 'file')[0]
            if scheme == 'file':
                inc = open(fn, 'w')
                try:
                    self._fns[fn].write(inc)
                finally:
                    # Close explicitly so the data is flushed to disk.
                    inc.close()

    def _interpolate(self, section, option, rawval, vars):
        '''Perform $var substitution (this overrides the default %(..)s substitution)

        Only the rawval and vars arguments are used. The rest are present for
        compatibility with the parent class.
        '''
        return varReplace(rawval, vars)


class IncludedDirConfigParser(IncludingConfigParser):
    """A conf.d recursive parser - supporting one level of included dirs"""

    def __init__(self, vars=None, includedir=None, includeglob="*.conf", include="include"):
        """Remember the include directory and glob, then set up the base parser.

        @param vars: Passed through to IncludingConfigParser.
        @param includedir: Directory whose matching files get included
            (None or empty disables directory includes).
        @param includeglob: Glob pattern applied inside includedir.
        @param include: Passed through to IncludingConfigParser.
        """
        self.includeglob = includeglob
        self.includedir = includedir
        IncludingConfigParser.__init__(self, vars=vars, include=include)

    def read(self, filenames):
        """Read each file named in filenames, then apply the include directory.

        filenames is a single string split with shlex.split(), so it may
        name several files separated by whitespace.
        """
        for filename in shlex.split(filenames):
            IncludingConfigParser.read(self,filename)
            self._includedir()

    def _includedir(self):
        """Add every existing file matching the include glob to every section."""
        if not self.includedir:
            return
        sections = ConfigParser.sections(self)
        if not sections:
            return
        # The glob and the existence checks do not depend on the section,
        # so do them once instead of once per section.
        pattern = "%s/%s" % (self.includedir, self.includeglob)
        matches = [match for match in glob.glob(pattern) if os.path.exists(match)]
        for section in sections:
            for match in matches:
                self._add_include(section, match)

    def add_include(self, section, filename):
        """Add a included file to config section

        Raises NoSectionError when section does not exist.
        """
        if not self.has_section(section):
            raise NoSectionError(section)
        self._add_include(section, filename)


# Matches a $name token; group 1 is the variable name without the '$'.
_KEYCRE = re.compile(r"\$(\w+)")

def varReplace(raw, vars):
    '''Perform variable replacement

    @param raw: String to perform substitution on.
    @param vars: Dictionary of variables to replace. Key is variable name
        (without $ prefix). Value is replacement string. Variable names are
        lower-cased before the lookup. None or an empty dictionary means
        nothing is replaced.
    @return: Input raw string with substituted values. Unknown variables
        are left in place, and replacement values are not re-scanned.
        An empty or None raw yields ''.
    '''
    if not raw:
        return ''
    if not vars:
        # Nothing to substitute: every token would be preserved anyway.
        return raw

    def _lookup(match):
        # Unknown variable: keep the original "$name" text.
        return vars.get(match.group(1).lower(), match.group())

    # A single pass over the string; the callback result is inserted literally.
    return _KEYCRE.sub(_lookup, raw)


def _test():
    """Manual smoke test: load sys.argv[1], modify it, dump it, write it back.

    Sets two options, prints every section with its items, then rewrites
    the file named by sys.argv[1] in place.
    """
    import sys

    p = IncludingConfigParser()
    p.read(sys.argv[1])

    p.set('one', 'a', '111')
    p.set('three', 'foo', 'bar')

    for section in p.sections():
        print('*** %s' % section)
        for k, v in p.items(section):
            print('%s = %r' % (k, v))

    # Close the output file explicitly so the rewritten config is flushed.
    out = open(sys.argv[1], 'wt')
    try:
        p.write(out)
    finally:
        out.close()

if __name__ == '__main__':
    _test()



--- NEW FILE pgpmsg.py ---
##Copyright (C) 2003,2005  Jens B. Jorgensen <jbj1 at ultraemail.net>
##
##This program is free software; you can redistribute it and/or
##modify it under the terms of the GNU General Public License
##as published by the Free Software Foundation; either version 2
##of the License, or (at your option) any later version.
##
##This program is distributed in the hope that it will be useful,
##but WITHOUT ANY WARRANTY; without even the implied warranty of
##MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
##GNU General Public License for more details.
##
##You should have received a copy of the GNU General Public License
##along with this program; if not, write to the Free Software
##Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
import string, struct, time, cStringIO, base64, types, md5, sha

# Module-level debug setting; None by default. Nothing in the part of the
# module shown here reads it.
debug = None

# Cipher Type Byte
# bits 7,6 of the CTB say what kind it is
# we only have reserved defined
CTB_76_NORMAL = 0x80
CTB_76_NEW = 0xc0
CTB_76_MASK = 0xc0

# CTB packet type, bits 5,4,3,2
CTB_PKTV2_MASK = 0x3c    # 1111 - mask for this field
CTB_PKT_MASK = 0x3f      # 111111 - all the lower bits

CTB_PKT_PK_ENC = 1       # 0001 - public-key encrypted session packet
CTB_PKT_SIG = 2          # 0010 - signature packet
CTB_PKT_SK_ENC = 3       # 0011 - symmetric-key encrypted session packet
CTB_PKT_OP_SIG = 4       # 0100 - one-pass signature packet
CTB_PKT_SK_CERT = 5      # 0101 - secret-key certificate packet
CTB_PKT_PK_CERT = 6      # 0110 - public-key certificate packet
CTB_PKT_SK_SUB = 7       # 0111 - secret-key subkey packet
CTB_PKT_COMPRESSED = 8   # 1000 - compressed data packet
CTB_PKT_ENC = 9          # 1001 - symmetric-key encrypted data packet
CTB_PKT_MARKER = 10      # 1010 - marker packet
CTB_PKT_LIT = 11         # 1011 - literal data packet
CTB_PKT_TRUST = 12       # 1100 - trust packet
CTB_PKT_USER_ID = 13     # 1101 - user id packet
CTB_PKT_PK_SUB = 14      # 1110 - public subkey packet

# Display names for the CTB_PKT_* packet type codes above
ctb_pkt_to_str = {
    CTB_PKT_PK_ENC : 'public-key encrypted session packet',
    CTB_PKT_SIG : 'signature packet',
    CTB_PKT_SK_ENC : 'symmetric-key encrypted session packet',
    CTB_PKT_OP_SIG : 'one-pass signature packet',
    CTB_PKT_SK_CERT : 'secret-key certificate packet',
    CTB_PKT_PK_CERT : 'public-key certificate packet',
    CTB_PKT_SK_SUB : 'secret-key subkey packet',
    CTB_PKT_COMPRESSED : 'compressed data packet',
    CTB_PKT_ENC : 'symmetric-key encrypted data packet',
    CTB_PKT_MARKER : 'marker packet',
    CTB_PKT_LIT : 'literal data packet',
    CTB_PKT_TRUST : 'trust packet',
    CTB_PKT_USER_ID : 'user id packet',
    CTB_PKT_PK_SUB : 'public subkey packet'
}


# CTB packet-length
CTB_PKT_LEN_MASK = 0x3   # 11 - mask

CTB_PKT_LEN_1 = 0        # 00 - 1 byte
CTB_PKT_LEN_2 = 1        # 01 - 2 bytes
CTB_PKT_LEN_4 = 2        # 10 - 4 bytes
CTB_PKT_LEN_UNDEF = 3    # 11 - no packet length supplied

# Algorithms
# Each group below defines numeric algorithm codes followed by a dictionary
# mapping those codes to display names.

# Public Key Algorithms
ALGO_PK_RSA_ENC_OR_SIGN = 1        # RSA (Encrypt or Sign)
ALGO_PK_RSA_ENC_ONLY = 2           # RSA Encrypt-Only
ALGO_PK_RSA_SIGN_ONLY = 3          # RSA Sign-Only
ALGO_PK_ELGAMAL_ENC_ONLY = 16      # Elgamal (Encrypt-Only)
ALGO_PK_DSA = 17                   # DSA (Digital Signature Standard)
ALGO_PK_ELLIPTIC_CURVE = 18        # Elliptic Curve
ALGO_PK_ECDSA = 19                 # ECDSA
ALGO_PK_ELGAMAL_ENC_OR_SIGN = 20   # Elgamal (Encrypt or Sign)
ALGO_PK_DH = 21                    # Diffie-Hellman

# Display names for the ALGO_PK_* codes
algo_pk_to_str = {
    ALGO_PK_RSA_ENC_OR_SIGN : 'RSA (Encrypt or Sign)',
    ALGO_PK_RSA_ENC_ONLY : 'RSA Encrypt-Only',
    ALGO_PK_RSA_SIGN_ONLY : 'RSA Sign-Only',
    ALGO_PK_ELGAMAL_ENC_ONLY : 'Elgamal Encrypt-Only',
    ALGO_PK_DSA : 'DSA (Digital Signature Standard)',
    ALGO_PK_ELLIPTIC_CURVE : 'Elliptic Curve',
    ALGO_PK_ECDSA : 'ECDSA',
    ALGO_PK_ELGAMAL_ENC_OR_SIGN : 'Elgamal (Encrypt or Sign)',
    ALGO_PK_DH : 'Diffie-Hellman'
}    

# Symmetric Key Algorithms
ALGO_SK_PLAIN = 0 # Plaintext or unencrypted data
ALGO_SK_IDEA = 1 # IDEA
ALGO_SK_3DES = 2 # Triple-DES
ALGO_SK_CAST5 = 3 # CAST5
ALGO_SK_BLOWFISH = 4 # Blowfish
ALGO_SK_SAFER_SK128 = 5 # SAFER-SK128
ALGO_SK_DES_SK = 6 # DES/SK
ALGO_SK_AES_128 = 7 # AES 128-bit
ALGO_SK_AES_192 = 8 # AES 192-bit
ALGO_SK_AES_256 = 9 # AES 256-bit

# Display names for the ALGO_SK_* codes
algo_sk_to_str = {
    ALGO_SK_PLAIN : 'Plaintext or unencrypted data',
    ALGO_SK_IDEA : 'IDEA',
    ALGO_SK_3DES : 'Triple-DES',
    ALGO_SK_CAST5 : 'CAST5',
    ALGO_SK_BLOWFISH : 'Blowfish',
    ALGO_SK_SAFER_SK128 : 'SAFER-SK128',
    ALGO_SK_DES_SK : 'DES/SK',
    ALGO_SK_AES_128 : 'AES 128-bit',
    ALGO_SK_AES_192 : 'AES 192-bit',
    ALGO_SK_AES_256 : 'AES 256-bit'
}

# Compression Algorithms
ALGO_COMP_UNCOMP = 0 # Uncompressed
ALGO_COMP_ZIP = 1 # ZIP
ALGO_COMP_ZLIB = 2 # ZLIB

# Display names for the ALGO_COMP_* codes
algo_comp_to_str = {
    ALGO_COMP_UNCOMP : 'Uncompressed',
    ALGO_COMP_ZIP : 'ZIP',
    ALGO_COMP_ZLIB : 'ZLIB'
}

# Hash Algorithms
ALGO_HASH_MD5 = 1                  # MD5
ALGO_HASH_SHA1 = 2                 # SHA1
ALGO_HASH_RIPEMD160 = 3            # RIPEMD160
ALGO_HASH_SHA_DBL = 4              # double-width SHA
ALGO_HASH_MD2 = 5                  # MD2
ALGO_HASH_TIGER192 = 6             # TIGER192
ALGO_HASH_HAVAL_5_160 = 7          # HAVAL-5-160

# Display names for the ALGO_HASH_* codes
algo_hash_to_str = {
    ALGO_HASH_MD5 : 'MD5',
    ALGO_HASH_SHA1 : 'SHA1',
    ALGO_HASH_RIPEMD160 : 'RIPEMD160',
    ALGO_HASH_SHA_DBL : 'double-width SHA',
    ALGO_HASH_MD2 : 'MD2',
    ALGO_HASH_TIGER192 : 'TIGER192',
    ALGO_HASH_HAVAL_5_160 : 'HAVAL-5-160'
}

# Signature types
# Numeric signature type codes, followed by a dictionary of display names.
SIG_TYPE_DOCUMENT = 0x00           # document signature, binary image
SIG_TYPE_DOCUMENT_CANON = 0x01     # document signature, canonical text
SIG_TYPE_STANDALONE = 0x02         # signature over just subpackets
SIG_TYPE_PK_USER_GEN = 0x10        # public key packet and user ID packet, generic certification
SIG_TYPE_PK_USER_PER = 0x11        # public key packet and user ID packet, persona
SIG_TYPE_PK_USER_CAS = 0x12        # public key packet and user ID packet, casual certification
SIG_TYPE_PK_USER_POS = 0x13        # public key packet and user ID packet, positive certification
SIG_TYPE_SUBKEY_BIND = 0x18        # subkey binding
SIG_TYPE_KEY = 0x1F                # key signature
SIG_TYPE_KEY_REVOKE = 0x20      # key revocation
SIG_TYPE_SUBKEY_REVOKE = 0x28   # subkey revocation
SIG_TYPE_CERT_REVOKE = 0x30     # certificate revocation
SIG_TYPE_TIMESTAMP = 0x40       # timestamp

# Display names for the SIG_TYPE_* codes
sig_type_to_str = {
    SIG_TYPE_DOCUMENT : 'document signature, binary image',
    SIG_TYPE_DOCUMENT_CANON : 'document signature, canonical text',
    SIG_TYPE_STANDALONE : 'signature over just subpackets',
    SIG_TYPE_PK_USER_GEN : 'public key packet and user ID packet, generic certification',
    SIG_TYPE_PK_USER_PER : 'public key packet and user ID packet, persona',
    SIG_TYPE_PK_USER_CAS : 'public key packet and user ID packet, casual certification',
    SIG_TYPE_PK_USER_POS : 'public key packet and user ID packet, positive certification',
    SIG_TYPE_SUBKEY_BIND : 'subkey binding',
    SIG_TYPE_KEY : 'key signature',
    SIG_TYPE_KEY_REVOKE : 'key revocation',
    SIG_TYPE_SUBKEY_REVOKE : 'subkey revocation',
    SIG_TYPE_CERT_REVOKE : 'certificate revocation',
    SIG_TYPE_TIMESTAMP : 'timestamp'
}

# Signature sub-packet types
# These are the type codes compared against in signature.deserialize_subpacket
# and signature.subpacket_to_str.
SIG_SUB_TYPE_CREATE_TIME = 2        # signature creation time
SIG_SUB_TYPE_EXPIRE_TIME = 3        # signature expiration time
SIG_SUB_TYPE_EXPORT_CERT = 4        # exportable certification
SIG_SUB_TYPE_TRUST_SIG = 5          # trust signature
SIG_SUB_TYPE_REGEXP = 6             # regular expression
SIG_SUB_TYPE_REVOCABLE = 7          # revocable
SIG_SUB_TYPE_KEY_EXPIRE = 9         # key expiration time
SIG_SUB_TYPE_PLACEHOLDER = 10       # placeholder for backward compatibility
SIG_SUB_TYPE_PREF_SYMM_ALGO = 11    # preferred symmetric algorithms
SIG_SUB_TYPE_REVOKE_KEY = 12        # revocation key
SIG_SUB_TYPE_ISSUER_KEY_ID = 16     # issuer key ID
SIG_SUB_TYPE_NOTATION = 20          # notation data
SIG_SUB_TYPE_PREF_HASH_ALGO = 21    # preferred hash algorithms
SIG_SUB_TYPE_PREF_COMP_ALGO = 22    # preferred compression algorithms
SIG_SUB_TYPE_KEY_SRV_PREF = 23      # key server preferences
SIG_SUB_TYPE_PREF_KEY_SRVR = 24     # preferred key server
SIG_SUB_TYPE_PRIM_USER_ID = 25      # primary user id
SIG_SUB_TYPE_POLICY_URL = 26        # policy URL
SIG_SUB_TYPE_KEY_FLAGS = 27         # key flags
SIG_SUB_TYPE_SGNR_USER_ID = 28      # signer's user id
SIG_SUB_TYPE_REVOKE_REASON = 29     # reason for revocation

# Display names for the SIG_SUB_TYPE_* codes
sig_sub_type_to_str = {
    SIG_SUB_TYPE_CREATE_TIME : 'signature creation time',
    SIG_SUB_TYPE_EXPIRE_TIME : 'signature expiration time',
    SIG_SUB_TYPE_EXPORT_CERT : 'exportable certification',
    SIG_SUB_TYPE_TRUST_SIG : 'trust signature',
    SIG_SUB_TYPE_REGEXP : 'regular expression',
    SIG_SUB_TYPE_REVOCABLE : 'revocable',
    SIG_SUB_TYPE_KEY_EXPIRE : 'key expiration time',
    SIG_SUB_TYPE_PLACEHOLDER : 'placeholder for backward compatibility',
    SIG_SUB_TYPE_PREF_SYMM_ALGO : 'preferred symmetric algorithms',
    SIG_SUB_TYPE_REVOKE_KEY : 'revocation key',
    SIG_SUB_TYPE_ISSUER_KEY_ID : 'issuer key ID',
    SIG_SUB_TYPE_NOTATION : 'notation data',
    SIG_SUB_TYPE_PREF_HASH_ALGO : 'preferred hash algorithms',
    SIG_SUB_TYPE_PREF_COMP_ALGO : 'preferred compression algorithms',
    SIG_SUB_TYPE_KEY_SRV_PREF : 'key server preferences',
    SIG_SUB_TYPE_PREF_KEY_SRVR : 'preferred key server',
    SIG_SUB_TYPE_PRIM_USER_ID : 'primary user id',
    SIG_SUB_TYPE_POLICY_URL : 'policy URL',
    SIG_SUB_TYPE_KEY_FLAGS : 'key flags',
    SIG_SUB_TYPE_SGNR_USER_ID : "signer's user id",
    SIG_SUB_TYPE_REVOKE_REASON : 'reason for revocation'
}

# In a signature subpacket there may be a revocation reason; these codes
# indicate the reason.
REVOKE_REASON_NONE = 0              # No reason specified
REVOKE_REASON_SUPER = 0x01          # Key is superceded
REVOKE_REASON_COMPR = 0x02          # Key has been compromised
REVOKE_REASON_NOT_USED = 0x03       # Key is no longer used
REVOKE_REASON_ID_INVALID = 0x20     # user id information is no longer valid

# Display names for the REVOKE_REASON_* codes
revoke_reason_to_str = {
    REVOKE_REASON_NONE : 'No reason specified',
    REVOKE_REASON_SUPER : 'Key is superceded',
    REVOKE_REASON_COMPR : 'Key has been compromised',
    REVOKE_REASON_NOT_USED : 'Key is no longer used',
    REVOKE_REASON_ID_INVALID : 'user id information is no longer valid'
}

# These flags are used in a 'key flags' signature subpacket
# (bit masks tested against the first flags octet)
KEY_FLAGS1_MAY_CERTIFY = 0x01 # This key may be used to certify other keys
KEY_FLAGS1_MAY_SIGN = 0x02 # This key may be used to sign data
KEY_FLAGS1_MAY_ENC_COMM = 0x04 # This key may be used to encrypt communications
KEY_FLAGS1_MAY_ENC_STRG = 0x08 # This key may be used to encrypt storage
KEY_FLAGS1_PRIV_MAYBE_SPLIT = 0x10 # Private component may have been split through a secret-sharing mechanism
KEY_FLAGS1_GROUP = 0x80 # Private component may be among group

# A revocation key subpacket has these class values
REVOKE_KEY_CLASS_MAND = 0x80 # this bit must always be set
REVOKE_KEY_CLASS_SENS = 0x40 # sensitive

def get_whole_number(msg, idx, numlen) :
    """get_whole_number(msg, idx, numlen)
extracts a "whole number" field of length numlen from msg at index idx
returns (<whole number>, new_idx) where the whole number is a long integer
and new_idx is the index of the next element in the message"""
    n = 0L 
    while numlen > 0 :
        b = (struct.unpack("B", msg[idx:idx+1]))[0]
        n = n * 256L + long(b)
        idx = idx + 1
        numlen = numlen - 1
    return (n, idx)

def get_whole_int(msg, idx, numlen):
    """get_whole_int(msg, idx, numlen)
wrapper around get_whole_number that converts the number with int()
returns (<int>, new_idx)"""
    value, next_idx = get_whole_number(msg, idx, numlen)
    return int(value), next_idx

def pack_long(l):
    """pack_long(l)
    returns the big-endian byte string for unsigned long integer l;
    zero (or any value that is not positive) gives an empty string"""
    octets = []
    while l > 0:
        # collected least-significant byte first, flipped afterwards
        octets.append(struct.pack("B", l & 0xff))
        l >>= 8
    octets.reverse()
    return ''.join(octets)
    
def pack_mpi(l):
    """pack_mpi(l)
    returns unsigned long integer l as a PGP Multi-Precision Integer: a
    two-byte big-endian bit count followed by the pack_long() bytes"""
    body = pack_long(l)
    if len(body) > 0:
        # the bit count runs from the most significant set bit, so count
        # whole bytes after the first, then the used bits of the first byte
        bit_count = (len(body) - 1) * 8
        top = ord(body[0])
        while top != 0:
            bit_count += 1
            top >>= 1
    else:
        bit_count = 0
    return struct.pack(">H", bit_count) + body

def get_sig_subpak_len(msg, idx):
    """get_sig_subpak_len(msg, idx)
reads the variable-width length field of a signature subpacket
returns (subpak_len, new_idx)"""
    first, idx = get_whole_int(msg, idx, 1)
    if first >= 255:
        # marker octet: the length is in the next four bytes
        return get_whole_int(msg, idx, 4)
    if first >= 192:
        # two-octet form
        second, idx = get_whole_int(msg, idx, 1)
        return ((first - 192) << 8) + second + 192, idx
    # one-octet form
    return first, idx

def get_n_mpi(msg, idx):
    """get_n_mpi(msg, idx)
    reads a multi-precision integer field from msg at index idx
    returns (n, <mpi>, new_idx) where n is the bit count stored in the
    field's two-byte header, <mpi> is the value as a long integer and
    new_idx is the index of the next element in the message"""
    bit_len, idx = get_whole_int(msg, idx, 2)
    # round the bit count up to whole bytes
    byte_len = (bit_len + 7) // 8
    value, idx = get_whole_number(msg, idx, byte_len)
    return (bit_len, value, idx)

def get_mpi(msg, idx):
    """get_mpi(msg, idx)
same as get_n_mpi but without the bit count
returns (<mpi>, new_idx) where the mpi is a long integer and new_idx is
the index of the next element in the message"""
    _bit_len, value, next_idx = get_n_mpi(msg, idx)
    return (value, next_idx)

def str_to_hex(s):
    """str_to_hex(s)
    returns the lower-case hexadecimal rendering of string s, two digits
    per character (an empty string gives an empty string)"""
    # '%02x' yields the same zero-padded lower-case digits the old
    # hex()/zfill() pipeline produced, without the deprecated string-module
    # functions.
    return ''.join(['%02x' % ord(ch) for ch in s])

def duration_to_str(s):
    """duration_to_str(s)
    formats a duration of s seconds as '<days> days HH:MM:SS'
    a duration of 0 is reported as 'never'"""
    if s == 0:
        return 'never'
    mins, secs = divmod(s, 60)
    hrs, mins = divmod(mins, 60)
    # a day has 24 hours: the hour field must wrap at 24, not 60
    days, hrs = divmod(hrs, 24)
    return '%d days %02d:%02d:%02d' % (days, hrs, mins, secs)

def map_to_str(m, vals):
    """map_to_str(m, vals)
    looks up each value of vals in dictionary m and returns the results
    joined with ', '
    vals may be a single value or a list/tuple of values; a value missing
    from m is rendered as 'unknown(<value>)'"""
    # change to a list if it's a single value
    if not isinstance(vals, (list, tuple)):
        vals = [vals]
    names = []
    for val in vals:
        if val in m:
            names.append(m[val])
        else:
            names.append('unknown(' + str(val) + ')')
    return ', '.join(names)

class pgp_packet:
    """Common base of the packet classes; holds the packet type code."""

    def __init__(self):
        # Packet type code; starts out as None.
        self.pkt_typ = None

    def __str__(self):
        """Return the display name of pkt_typ as given by ctb_pkt_to_str."""
        description = map_to_str(ctb_pkt_to_str, self.pkt_typ)
        return description

class public_key(pgp_packet) :
    def __init__(self) :
        # just initialize the fields
        self.version = None
        self.pk_algo = None
        self.key_size = 0
        self.fingerprint_ = None # we cache this upon calculation

    def fingerprint(self) :
        # return cached value if we have it
        if self.fingerprint_ :
            return self.fingerprint_
        
        # otherwise calculate it now and cache it
        # v3 and v4 are calculated differently
        if self.version == 3 :
            h = md5.new()
            h.update(pack_long(self.pk_rsa_mod))
            h.update(pack_long(self.pk_rsa_exp))
            self.fingerprint_ = h.digest()
        elif self.version == 4 :
            # we hash what would be the whole PGP message containing
            # the pgp certificate
            h = sha.new()
            h.update('\x99')
            # we need to has the length of the packet as well
            buf = self.serialize()
            h.update(struct.pack(">H", len(buf)))
            h.update(buf)
            self.fingerprint_ = h.digest()
        else :
            raise RuntimeError("unknown public key version %d" % self.version)
        return self.fingerprint_

    def key_id(self) :
        if self.version == 3 :
            return pack_long(self.pk_rsa_mod & 0xffffffffffffffffL)
        elif self.version == 4 :
            return self.fingerprint()[-8:]

    def serialize(self) :
        chunks = []
        if self.version == 3 :
            chunks.append(struct.pack('>BIHB', self.version, int(self.timestamp), self.validity, self.pk_algo))
            chunks.append(pack_mpi(self.pk_rsa_mod))
            chunks.append(pack_mpi(self.pk_rsa_exp))
        elif self.version == 4 :
            chunks.append(struct.pack('>BIB', self.version, int(self.timestamp), self.pk_algo))
            if self.pk_algo == ALGO_PK_RSA_ENC_OR_SIGN or self.pk_algo == ALGO_PK_RSA_SIGN_ONLY :
                chunks.append(pack_mpi(self.pk_rsa_mod))
                chunks.append(pack_mpi(self.pk_rsa_exp))
            elif self.pk_algo == ALGO_PK_DSA :
                chunks.append(pack_mpi(self.pk_dsa_prime_p))
                chunks.append(pack_mpi(self.pk_dsa_grp_ord_q))
                chunks.append(pack_mpi(self.pk_dsa_grp_gen_g))
                chunks.append(pack_mpi(self.pk_dsa_pub_key))
            elif self.pk_algo == ALGO_PK_ELGAMAL_ENC_OR_SIGN or self.pk_algo == ALGO_PK_ELGAMAL_ENC_ONLY :
                chunks.append(pack_mpi(self.pk_elgamal_prime_p))
                chunks.append(pack_mpi(self.pk_elgamal_grp_gen_g))
                chunks.append(pack_mpi(self.pk_elgamal_pub_key))
            else :
                raise RuntimeError("unknown public key algorithm %d" % (self.pk_algo))
        return ''.join(chunks)
    
    def deserialize(self, msg, idx, pkt_len) :
        idx_save = idx
        self.version, idx = get_whole_int(msg, idx, 1)
        if self.version != 2 and self.version != 3 and self.version != 4 :
            raise 'unknown public key packet version %d at %d' % (self.version, idx_save)
        if self.version == 2 : # map v2 into v3 for coding simplicity since they're structurally the same
            self.version = 3
        self.timestamp, idx = get_whole_number(msg, idx, 4)
        self.timestamp = float(self.timestamp)
        if self.version == 3 :
            self.validity, idx = get_whole_number(msg, idx, 2)
        self.pk_algo, idx = get_whole_int(msg, idx, 1)
        if self.pk_algo == ALGO_PK_RSA_ENC_OR_SIGN or self.pk_algo == ALGO_PK_RSA_SIGN_ONLY :
            self.key_size, self.pk_rsa_mod, idx = get_n_mpi(msg, idx)
            self.pk_rsa_exp, idx = get_mpi(msg, idx)
        elif self.pk_algo == ALGO_PK_DSA :
            l1, self.pk_dsa_prime_p, idx = get_n_mpi(msg, idx)
            self.pk_dsa_grp_ord_q, idx = get_mpi(msg, idx)
            self.pk_dsa_grp_gen_g, idx = get_mpi(msg, idx)
            l2, self.pk_dsa_pub_key, idx = get_n_mpi(msg, idx)
            self.key_size = l1 + l2
        elif self.pk_algo == ALGO_PK_ELGAMAL_ENC_OR_SIGN or self.pk_algo == ALGO_PK_ELGAMAL_ENC_ONLY :
            self.key_size, self.pk_elgamal_prime_p, idx = get_n_mpi(msg, idx)
            self.pk_elgamal_grp_gen_g, idx = get_mpi(msg, idx)
            self.pk_elgamal_pub_key, idx = get_mpi(msg, idx)
        else :
            raise "unknown public key algorithm %d at %d" % (self.pk_algo, idx_save)

    def __str__(self) :
        sio = cStringIO.StringIO()
        sio.write(pgp_packet.__str__(self) + "\n")
        sio.write("version: " + str(self.version) + "\n")
        sio.write("timestamp: " + time.ctime(self.timestamp) + "\n")
        if self.version == 3 :
            sio.write("validity: " + time.ctime(self.timestamp + self.validity * 24 * 60 * 60) + "\n")
        sio.write("pubkey algo: " + algo_pk_to_str[self.pk_algo] + "\n")
        if self.pk_algo == ALGO_PK_RSA_ENC_OR_SIGN or self.pk_algo == ALGO_PK_RSA_SIGN_ONLY :
            sio.write("pk_rsa_mod: " + hex(self.pk_rsa_mod) + "\n")
            sio.write("pk_rsa_exp: " + hex(self.pk_rsa_exp) + "\n")
        elif self.pk_algo == ALGO_PK_DSA :
            sio.write("pk_dsa_prime_p: " + hex(self.pk_dsa_prime_p) + "\n")
            sio.write("pk_dsa_grp_ord_q: " + hex(self.pk_dsa_grp_ord_q) + "\n")
            sio.write("pk_dsa_grp_gen_g: " + hex(self.pk_dsa_grp_gen_g) + "\n")
            sio.write("pk_dsa_pub_key: " + hex(self.pk_dsa_pub_key) + "\n")
        elif self.pk_algo == ALGO_PK_ELGAMAL_ENC_OR_SIGN or self.pk_algo == ALGO_PK_ELGAMAL_ENC_ONLY :
            sio.write("pk_elgamal_prime_p: " + hex(self.pk_elgamal_prime_p) + "\n")
            sio.write("pk_elgamal_grp_gen_g: " + hex(self.pk_elgamal_grp_gen_g) + "\n")
            sio.write("pk_elgamal_pub_key: " + hex(self.pk_elgamal_pub_key) + "\n")
        return sio.getvalue()

class user_id(pgp_packet) :
    """User id packet: the raw user id string taken from the packet body."""

    def __init__(self) :
        # Run the base initialiser so pkt_typ exists for __str__.
        pgp_packet.__init__(self)
        # The id must be stored on the instance; the previous code only
        # bound a local variable, leaving the attribute undefined.
        self.id = None

    def deserialize(self, msg, idx, pkt_len) :
        """Store the pkt_len bytes of msg starting at idx as the user id."""
        self.id = msg[idx:idx + pkt_len]

    def __str__(self) :
        """Return the packet type line followed by an 'id: ...' line."""
        # An id that has not been deserialized yet prints as empty.
        return pgp_packet.__str__(self) + "\n" + "id: " + (self.id or '') + "\n"

class signature(pgp_packet) :
    def __init__(self) :
        # just initialize the fields
        self.version = None
        self.sig_type = None
        self.pk_algo = None
        self.hash_algo = None
        self.hash_frag = None

    def key_id(self) :
        if self.version == 3 :
            return self.key_id_
        else :
            i = self.get_hashed_subpak(SIG_SUB_TYPE_ISSUER_KEY_ID)
            if i :
                return i[1]
            i = self.get_unhashed_subpak(SIG_SUB_TYPE_ISSUER_KEY_ID)
            if i :
                return i[1]
            return None

    def expiration(self) :
        if self.version != 4 :
            raise ValueError('v3 signatures don\'t have expirations')
        i = self.get_hashed_subpak(SIG_SUB_TYPE_KEY_EXPIRE)
        if i :
            return i[1]
        return 0 # if not present then it never expires

    def get_hashed_subpak(self, typ) :
        for i in self.hashed_subpaks :
            if i[0] == typ :
                return i
        return None
    
    def get_unhashed_subpak(self, typ) :
        for i in self.unhashed_subpaks :
            if i[0] == typ :
                return i
        return None
    
    def deserialize_subpacket(self, msg, idx) :
        """Parse one signature subpacket from msg starting at index idx.

        Returns (subpacket, new_idx). subpacket is a tuple whose first
        element is the subpacket type code; the remaining elements depend
        on the type. Types without dedicated handling are returned as
        (type, raw_data).
        """
        sublen, idx = get_sig_subpak_len(msg, idx)
        subtype, idx = get_whole_int(msg, idx, 1)
        # sublen includes the type byte just read, so variable-length
        # payloads occupy the sublen - 1 bytes from here
        payload_end = idx + sublen - 1

        if subtype == SIG_SUB_TYPE_CREATE_TIME :
            created, idx = get_whole_number(msg, idx, 4)
            return (subtype, float(created)), idx

        if subtype in (SIG_SUB_TYPE_EXPIRE_TIME, SIG_SUB_TYPE_KEY_EXPIRE) :
            seconds, idx = get_whole_int(msg, idx, 4)
            return (subtype, seconds), idx

        # single-byte boolean payloads
        if subtype in (SIG_SUB_TYPE_EXPORT_CERT, SIG_SUB_TYPE_REVOCABLE,
                       SIG_SUB_TYPE_PRIM_USER_ID) :
            flag, idx = get_whole_int(msg, idx, 1)
            return (subtype, flag), idx

        if subtype == SIG_SUB_TYPE_TRUST_SIG :
            trust_lvl, idx = get_whole_int(msg, idx, 1)
            trust_amt, idx = get_whole_int(msg, idx, 1)
            return (subtype, trust_lvl, trust_amt), idx

        # payloads returned as a list of byte values
        if subtype in (SIG_SUB_TYPE_PREF_SYMM_ALGO, SIG_SUB_TYPE_PREF_HASH_ALGO,
                       SIG_SUB_TYPE_PREF_COMP_ALGO, SIG_SUB_TYPE_KEY_FLAGS,
                       SIG_SUB_TYPE_KEY_SRV_PREF) :
            octets = [ ord(ch) for ch in msg[idx:payload_end] ]
            return (subtype, octets), payload_end

        if subtype == SIG_SUB_TYPE_REVOKE_KEY :
            cls, idx = get_whole_int(msg, idx, 1)
            algo, idx = get_whole_int(msg, idx, 1)
            fprint_end = idx + 20
            return (subtype, cls, algo, msg[idx:fprint_end]), fprint_end

        if subtype == SIG_SUB_TYPE_ISSUER_KEY_ID :
            k_id_end = idx + 8
            return (subtype, msg[idx:k_id_end]), k_id_end

        if subtype == SIG_SUB_TYPE_NOTATION :
            flags = []
            for _ in range(4) :
                flg, idx = get_whole_int(msg, idx, 1)
                flags.append(flg)
            name_len, idx = get_whole_int(msg, idx, 2)
            val_len, idx = get_whole_int(msg, idx, 2)
            name_end = idx + name_len
            val_end = name_end + val_len
            nam = msg[idx:name_end]
            val = msg[name_end:val_end]
            return (subtype, flags[0], flags[1], flags[2], flags[3], nam, val), val_end

        if subtype == SIG_SUB_TYPE_REVOKE_REASON :
            rev_code, idx = get_whole_int(msg, idx, 1)
            # the type byte and the code byte are both counted in sublen
            reas_end = idx + (sublen - 2)
            return (subtype, rev_code, msg[idx:reas_end]), reas_end

        # Regular expression, preferred key server, policy URL, signer's
        # user id and every unknown type: the payload is kept as raw data.
        return (subtype, msg[idx:payload_end]), payload_end

    def is_primary_user_id(self) :
        """is_primary_user_id()
        returns true if this signature contains a primary user id subpacket with value true"""
        for i in self.hashed_subpaks :
            if i[0] == SIG_SUB_TYPE_PRIM_USER_ID :
                return i[1]
        return 0
    
    def subpacket_to_str(self, sp) :
        """Return a one-line human-readable description of subpacket sp.

        sp is a tuple as produced by deserialize_subpacket: the subpacket
        type code followed by type-specific fields. Types without
        dedicated formatting are shown as 'unknown(<type>): <hex data>'.
        """
        if sp[0] == SIG_SUB_TYPE_CREATE_TIME : # signature creation time
            return 'creation time: ' + time.ctime(sp[1])
        if sp[0] == SIG_SUB_TYPE_EXPIRE_TIME : # signature expiration time
            return 'signature expires: ' + duration_to_str(sp[1])
        if sp[0] == SIG_SUB_TYPE_EXPORT_CERT : # exportable certification
            if sp[1] :
                return 'signature exportable: TRUE'
            else :
                return 'signature exportable: FALSE'
        if sp[0] == SIG_SUB_TYPE_TRUST_SIG : # trust signature
            if sp[1] == 0 :
                return 'trust: ordinary'
            if sp[1] == 1 :
                return 'trust: introducer (%d)' % sp[2]
            if sp[1] == 2 :
                return 'trust: meta-introducer (%d)' % sp[2]
            return 'trust: %d %d' % (sp[1], sp[2])
        if sp[0] == SIG_SUB_TYPE_REGEXP : # regular expression
            return 'regexp: ' + sp[1]
        if sp[0] == SIG_SUB_TYPE_REVOCABLE : # revocable
            if sp[1] :
                return 'signature revocable: TRUE'
            else :
                return 'signature revocable: FALSE'
        if sp[0] == SIG_SUB_TYPE_KEY_EXPIRE : # key expiration time
            return 'key expires: ' + duration_to_str(sp[1])
        if sp[0] == SIG_SUB_TYPE_PREF_SYMM_ALGO : # preferred symmetric algorithms
            return 'preferred symmetric algorithms: ' + map_to_str(algo_sk_to_str, sp[1])
        if sp[0] == SIG_SUB_TYPE_REVOKE_KEY : # revocation key
            s = 'revocation key: '
            # The module constant is REVOKE_KEY_CLASS_SENS; the name used
            # here before did not exist and raised NameError.
            if sp[1] & REVOKE_KEY_CLASS_SENS :
                s = s + '(sensitive) '
            return s + map_to_str(algo_pk_to_str, sp[2]) + ' ' + str_to_hex(sp[3])
        if sp[0] == SIG_SUB_TYPE_ISSUER_KEY_ID : # issuer key ID
            return 'issuer key id: ' + str_to_hex(sp[1])
        if sp[0] == SIG_SUB_TYPE_NOTATION : # notation data
            return 'notation: flags(%d, %d, %d, %d) name(%s) value(%s)' % sp[1:]
        if sp[0] == SIG_SUB_TYPE_PREF_HASH_ALGO : # preferred hash algorithms
            return 'preferred hash algorithms: ' + map_to_str(algo_hash_to_str, sp[1])
        if sp[0] == SIG_SUB_TYPE_PREF_COMP_ALGO : # preferred compression algorithms
            return 'preferred compression algorithms: ' + map_to_str(algo_comp_to_str, sp[1])
        if sp[0] == SIG_SUB_TYPE_KEY_SRV_PREF : # key server preferences
            s = 'key server preferences: '
            prefs = []
            # An empty preference list has no first byte to test (the key
            # flags branch below guards the same way).
            if sp[1] and sp[1][0] & 0x80 :
                prefs.append('No-modify')
            return s + ', '.join(prefs)
        if sp[0] == SIG_SUB_TYPE_PREF_KEY_SRVR : # preferred key server
            return 'preferred key server: %s' % sp[1]
        if sp[0] == SIG_SUB_TYPE_PRIM_USER_ID : # primary user id
            if sp[1] :
                return 'is primary user id'
            else :
                return 'is not primary user id'
        if sp[0] == SIG_SUB_TYPE_POLICY_URL : # policy URL
            return 'policy url: %s' % sp[1]
        if sp[0] == SIG_SUB_TYPE_KEY_FLAGS : # key flags
            flags = []
            flgs1 = 0
            if len(sp[1]) >= 1 :
                flgs1 = sp[1][0]
            if flgs1 & KEY_FLAGS1_MAY_CERTIFY :
                flags.append('may certify other keys')
            if flgs1 & KEY_FLAGS1_MAY_SIGN :
                flags.append('may sign data')
            if flgs1 & KEY_FLAGS1_MAY_ENC_COMM :
                flags.append('may encrypt communications')
            if flgs1 & KEY_FLAGS1_MAY_ENC_STRG :
                flags.append('may encrypt storage')
            if flgs1 & KEY_FLAGS1_PRIV_MAYBE_SPLIT :
                flags.append('private component may have been secret-sharing split')
            if flgs1 & KEY_FLAGS1_GROUP :
                flags.append('group key')
            return 'key flags: ' + ', '.join(flags)
        if sp[0] == SIG_SUB_TYPE_SGNR_USER_ID : # signer's user id
            return 'signer id: ' + sp[1]
        if sp[0] == SIG_SUB_TYPE_REVOKE_REASON : # reason for revocation
            # unknown reason codes are shown with an empty description
            reas = revoke_reason_to_str.get(sp[1], '')
            return 'reason for revocation: %s, %s' % (reas, sp[2])
        # this means we don't know what the thing is so we just have raw data
        return 'unknown(%d): %s' % (sp[0], str_to_hex(sp[1]))

    def deserialize(self, msg, idx, pkt_len) :
        self.version, idx = get_whole_int(msg, idx, 1)
        if self.version == 2 :
            self.version = 3
        if self.version == 3 :
            hash_len, idx = get_whole_number(msg, idx, 1)
            self.sig_type, idx = get_whole_int(msg, idx, 1)
            self.timestamp, idx = get_whole_number(msg, idx, 4)
            self.timestamp = float(self.timestamp)
            self.key_id_ = msg[idx:idx+8]
            idx = idx + 8
            self.pk_algo, idx = get_whole_int(msg, idx, 1)
            self.hash_algo, idx = get_whole_int(msg, idx, 1)
        elif self.version == 4:
            self.sig_type, idx = get_whole_int(msg, idx, 1)
            self.pk_algo, idx = get_whole_int(msg, idx, 1)
            self.hash_algo, idx = get_whole_int(msg, idx, 1)
            sub_paks_len, idx = get_whole_int(msg, idx, 2)
            sub_paks_end = idx + sub_paks_len
            self.hashed_subpaks = []
            while idx < sub_paks_end :
                sp, idx = self.deserialize_subpacket(msg, idx)
                self.hashed_subpaks.append(sp)
            sub_paks_len, idx = get_whole_int(msg, idx, 2)
            sub_paks_end = idx + sub_paks_len
            self.unhashed_subpaks = []
            while idx < sub_paks_end :
                sp, idx = self.deserialize_subpacket(msg, idx)
                self.unhashed_subpaks.append(sp)
        else :
            raise 'unknown signature packet version %d at %d' % (self.version, idx)
        self.hash_frag, idx = get_whole_number(msg, idx, 2)
        if self.pk_algo == ALGO_PK_RSA_ENC_OR_SIGN or self.pk_algo == ALGO_PK_RSA_SIGN_ONLY :
            self.rsa_sig, idx = get_mpi(msg, idx)
        elif self.pk_algo == ALGO_PK_DSA :
            self.dsa_sig_r, idx = get_mpi(msg, idx)
            self.dsa_sig_s, idx = get_mpi(msg, idx)
        else :
            raise 'unknown public-key algorithm (%d) in signature at %d' % (self.pk_algo, idx)
        return idx

    def __str__(self) :
        sio = cStringIO.StringIO()
        sio.write(pgp_packet.__str__(self) + "\n")
        sio.write("version: " + str(self.version) + "\n")
        sio.write("type: " + sig_type_to_str[self.sig_type] + "\n")
        if self.version == 3 :
            sio.write("timestamp: " + time.ctime(self.timestamp) + "\n")
            sio.write("key_id: " + str_to_hex(self.key_id_) + "\n")
        elif self.version == 4 :
            sio.write("hashed subpackets:\n")
            for i in self.hashed_subpaks :
                sio.write("    " + self.subpacket_to_str(i) + "\n")
            sio.write("unhashed subpackets:\n")
            for i in self.unhashed_subpaks :
                sio.write("    " + self.subpacket_to_str(i) + "\n")
        sio.write("hash_algo: " + algo_hash_to_str[self.hash_algo] + "\n")
        sio.write("hash_frag: " + hex(self.hash_frag) + "\n")
        if self.pk_algo == ALGO_PK_RSA_ENC_OR_SIGN or self.pk_algo == ALGO_PK_RSA_SIGN_ONLY :
            sio.write("pk_algo: RSA\n")
            sio.write("rsa_sig: " + hex(self.rsa_sig) + "\n")
        elif self.pk_algo == ALGO_PK_DSA :
            sio.write("pk_algo: DSA\n")
            sio.write("dsa_sig_r: " + hex(self.dsa_sig_r) + "\n")
            sio.write("dsa_sig_s: " + hex(self.dsa_sig_s) + "\n")
        return sio.getvalue()

#
# This class encapsulates an openpgp public "certificate", which is formed in a message as
# a series of PGP packets of certain types in certain orders
#

class pgp_certificate :
    """An OpenPGP public key certificate built from a list of decoded packets.

After load() the instance carries:
  version        - version of the leading public key packet
  public_key     - the leading public key packet
  revocation     - key revocation signature packet, or None
  user_ids       - list of [user_id_packet, signature, ...] lists
  rvkd_user_ids  - same layout, for user ids that carry a revocation signature
  user_id        - the primary user id chosen by load(), falling back to the
                   id of the last entry in user_ids
Version 4 certificates additionally get direct_key_sigs, subkeys and
rvkd_subkeys."""

    def __init__(self) :
        self.version = None
        self.public_key = None
        self.revocation = None
        # explicit primary user id; None means "use the last entry in user_ids"
        self._primary_user_id = None
        self.user_ids = []
        self.rvkd_user_ids = []

    def __str__(self) :
        """__str__()
Return a printable dump of the certificate: a header line, the primary user
id, the public key and every non-revoked user id with its signatures."""
        out = ["PGP Public Key Certificate v%d\n" % self.version,
               "Primary ID: %s\n" % self.user_id,
               str(self.public_key)]
        for uid in self.user_ids :
            out.append(str(uid[0]))
            for sig in uid[1:] :
                out.append("   " + str(sig))
        return "".join(out)
    
    def get_user_id(self):
        """get_user_id()
Return the id of the last user id packet in user_ids."""
        # take the LAST one in the list, not first
        # they appear to be ordered FIFO from the key and that means if you
        # added a key later then it won't show the one you expect
        return self.user_ids[-1][0].id

    def _get_primary_user_id(self) :
        """Getter for user_id: the explicitly chosen id if there is one,
otherwise whatever get_user_id() returns."""
        if self._primary_user_id is not None :
            return self._primary_user_id
        return self.get_user_id()

    def _set_primary_user_id(self, value) :
        """Setter for user_id: remember value as the primary user id."""
        self._primary_user_id = value

    # The property has a setter so that assigning user_id works the same way
    # whether or not the interpreter honours properties on this class.
    user_id = property(_get_primary_user_id, _set_primary_user_id)
    
    def expiration(self) :
        """expiration()
Return the expiration time of the certificate, or 0 if it does not expire.
For version 3 this is the key timestamp plus its validity period in days; for
version 4 it is the key timestamp plus the expiration reported by the first
generic user id certification on the first user id."""
        if self.version == 3 :
            if self.public_key.validity == 0 :
                return 0
            return self.public_key.timestamp + self.public_key.validity * 24 * 60 * 60
        else : # self.version == 4
            # this is a bit more complex, we need to find the signature on the
            # key and get its expiration
            u_id = self.user_ids[0]
            for i in u_id[1:] :
                if i.sig_type == SIG_TYPE_PK_USER_GEN :
                    exp = i.expiration()
                    if exp == 0 :
                        return 0
                    return self.public_key.timestamp + exp
            return 0

    def key_size(self) :
        """key_size()
Placeholder; always returns 0."""
        return 0
    
    def load(self, pkts) :
        """load(pkts)
Initialize the pgp_certificate with a list of OpenPGP packets. The list of packets will
be scanned to make sure they are valid for a pgp certificate.

Raises ValueError if the list is empty, does not start with a public key
packet, contains a packet or signature type that is not allowed at its
position, or yields no usable (non-revoked) user id."""

        if not pkts :
            raise ValueError('no PGP packets were supplied for the cert')

        # each certificate should begin with a public key packet
        if pkts[0].pkt_typ != CTB_PKT_PK_CERT :
            raise ValueError('first PGP packet should be a public-key packet, not %s' % map_to_str(ctb_pkt_to_str, pkts[0].pkt_typ))

        # all versions have a public key although in a v4 cert the main key is only
        # used for signing, never encryption
        self.public_key = pkts[0]

        # ok, then what's the version
        self.version = self.public_key.version

        # signature types that may follow a user id packet
        user_sig_types = (SIG_TYPE_PK_USER_GEN, SIG_TYPE_PK_USER_PER, SIG_TYPE_PK_USER_CAS, SIG_TYPE_PK_USER_POS, SIG_TYPE_CERT_REVOKE)

        # now the behavior splits a little depending on the version
        if self.version == 3 :
            pkt_idx = 1

            # second packet could be a revocation (guard against a cert that
            # consists of the key packet only)
            if pkt_idx < len(pkts) and pkts[pkt_idx].pkt_typ == CTB_PKT_SIG :
                if pkts[pkt_idx].version != 3 :
                    raise ValueError('version 3 cert has version %d signature' % pkts[pkt_idx].version)
                if pkts[pkt_idx].sig_type != SIG_TYPE_KEY_REVOKE :
                    raise ValueError('v3 cert revocation sig has type %s' % map_to_str(sig_type_to_str, pkts[pkt_idx].sig_type))

                # ok, well at least the type is good, we'll assume the cert is
                # revoked
                self.revocation = pkts[pkt_idx]

                # increment the pkt_idx to go to the next one
                pkt_idx = pkt_idx + 1

            # the following packets are User ID, Signature pairs
            while pkt_idx < len(pkts) :
                # this packet is supposed to be a user id
                if pkts[pkt_idx].pkt_typ != CTB_PKT_USER_ID :
                    raise ValueError('pgp packet %d is not user id, is %s' % (pkt_idx, map_to_str(ctb_pkt_to_str, pkts[pkt_idx].pkt_typ)))

                user_id = [pkts[pkt_idx]]
                pkt_idx = pkt_idx + 1
                is_revoked = 0
                is_primary_user_id = 0

                # there may be a sequence of signatures following the user id which
                # bind it to the key
                while pkt_idx < len(pkts) and pkts[pkt_idx].pkt_typ == CTB_PKT_SIG :
                    sig = pkts[pkt_idx]
                    if sig.sig_type not in user_sig_types :
                        raise ValueError('signature %d doesn\'t bind user_id to key, is %s' % (pkt_idx, map_to_str(sig_type_to_str, sig.sig_type)))

                    if sig.version != 3 :
                        raise ValueError('version 3 cert has version %d signature' % sig.version)

                    user_id.append(sig)

                    # was this a revocation?
                    if sig.sig_type == SIG_TYPE_CERT_REVOKE :
                        is_revoked = 1

                    # is this a primary user id? only while none has been chosen
                    if self._primary_user_id is None and sig.sig_type == SIG_TYPE_PK_USER_GEN :
                        is_primary_user_id = pkt_idx
                    pkt_idx = pkt_idx + 1

                # we get the primary user id from the first user id that has a
                # candidate signature and is not revoked
                if self._primary_user_id is None and is_primary_user_id and not is_revoked :
                    self._set_primary_user_id(user_id[0].id)

                # append the user ID and signature(s) onto a list
                if is_revoked :
                    self.rvkd_user_ids.append(user_id)
                else :
                    self.user_ids.append(user_id)

        else : # self.version == 4
            pkt_idx = 1
            self.direct_key_sigs = []
            self.subkeys = []
            self.rvkd_subkeys = []

            # second packet could be a revocation (or a direct key self signature)
            if pkt_idx < len(pkts) and pkts[pkt_idx].pkt_typ == CTB_PKT_SIG :
                if pkts[pkt_idx].version != 4 :
                    raise ValueError('version 4 cert has version %d signature' % pkts[pkt_idx].version)
                if pkts[pkt_idx].sig_type == SIG_TYPE_KEY_REVOKE :
                    # ok, well at least the type is good, we'll assume the cert is
                    # revoked
                    self.revocation = pkts[pkt_idx]

                    # increment the pkt_idx to go to the next one
                    pkt_idx = pkt_idx + 1

            # there can then be a sequence of direct key signatures
            while pkt_idx < len(pkts) and pkts[pkt_idx].pkt_typ == CTB_PKT_SIG :
                if pkts[pkt_idx].sig_type != SIG_TYPE_KEY :
                    raise ValueError('v4 cert signature has type %s, supposed to be direct key' % map_to_str(sig_type_to_str, pkts[pkt_idx].sig_type))
                self.direct_key_sigs.append(pkts[pkt_idx])
                pkt_idx = pkt_idx + 1
                
            # the following packets are User ID, Signature... sets or subkey, signature... sets
            while pkt_idx < len(pkts) :
                
                # this packet is supposed to be a user id
                if pkts[pkt_idx].pkt_typ == CTB_PKT_USER_ID :
                    user_id = [pkts[pkt_idx]]
                    is_revoked = 0
                    is_primary_user_id = 0

                    pkt_idx = pkt_idx + 1

                    # there may be a sequence of signatures following the user id which
                    # bind it to the key
                    while pkt_idx < len(pkts) and pkts[pkt_idx].pkt_typ == CTB_PKT_SIG :
                        sig = pkts[pkt_idx]
                        if sig.sig_type not in user_sig_types :
                            raise ValueError('signature %d doesn\'t bind user_id to key, is %s' % (pkt_idx, map_to_str(sig_type_to_str, sig.sig_type)))
                        user_id.append(sig)

                        # was this a revocation?
                        if sig.sig_type == SIG_TYPE_CERT_REVOKE :
                            is_revoked = 1

                        # is this the primary user id?
                        if sig.version == 4 and sig.is_primary_user_id() :
                            is_primary_user_id = 1
                        pkt_idx = pkt_idx + 1

                    # append the user ID and signature(s) onto the list
                    if is_revoked :
                        self.rvkd_user_ids.append(user_id)
                    else :
                        if is_primary_user_id :
                            self._set_primary_user_id(user_id[0].id)
                        self.user_ids.append(user_id)

                elif pkts[pkt_idx].pkt_typ == CTB_PKT_PK_SUB :
                    # collect this list of subkey + signatures
                    subkey = [pkts[pkt_idx]]
                    pkt_idx = pkt_idx + 1
                    is_revoked = 0

                    # there may be a sequence of signatures following the subkey which
                    # bind it to the key
                    while pkt_idx < len(pkts) and pkts[pkt_idx].pkt_typ == CTB_PKT_SIG :
                        sig = pkts[pkt_idx]
                        if sig.sig_type not in (SIG_TYPE_SUBKEY_BIND, SIG_TYPE_SUBKEY_REVOKE) :
                            raise ValueError('signature %d doesn\'t bind subkey to key, is %s' % (pkt_idx, map_to_str(sig_type_to_str, sig.sig_type)))
                        subkey.append(sig)

                        # was this a revocation?
                        if sig.sig_type == SIG_TYPE_SUBKEY_REVOKE :
                            is_revoked = 1

                        pkt_idx = pkt_idx + 1

                    # append the subkey and signature(s) onto the list
                    if is_revoked :
                        self.rvkd_subkeys.append(subkey)
                    else :
                        self.subkeys.append(subkey)
                else :
                    raise ValueError('pgp packet %d is not user id or subkey, is %s' % (pkt_idx, map_to_str(ctb_pkt_to_str, pkts[pkt_idx].pkt_typ)))

        # did we get all the things we needed? at least one non-revoked user
        # id is required, since user_id falls back to the last entry
        if not self.user_ids :
            raise ValueError('no user id packet was present in the cert')



def get_ctb(msg, idx) :
    """get_ctb(msg, idx)
extracts the "cypher type bit" information from message msg at index idx
returns (type, len, new_idx) where type is the enumerated type of the packet,
len is the length of the packet, and new_idx is the index of the next element
in the message

Both header layouts are handled: the "normal" layout, where the low bits of
the first byte select a 1, 2 or 4 byte length field (or an undefined length,
reported as 0), and the "new" layout, where the first length byte selects a
one, two or five byte length encoding.

Raises ValueError for a partial body length in the new layout and for a first
byte that matches neither layout."""
    b, idx = get_whole_int(msg, idx, 1)
    if (b & CTB_76_MASK) == CTB_76_NORMAL :
        # size of the length field that follows; 0 means undefined length
        len_type = b & CTB_PKT_LEN_MASK
        if len_type == CTB_PKT_LEN_1 :
            n_len = 1
        elif len_type == CTB_PKT_LEN_2 :
            n_len = 2
        elif len_type == CTB_PKT_LEN_4 :
            n_len = 4
        else : # CTB_PKT_LEN_UNDEF
            n_len = 0
        pkt_len = 0
        if n_len > 0 :
            pkt_len, idx = get_whole_int(msg, idx, n_len)
        return (b & CTB_PKTV2_MASK) >> 2, pkt_len, idx
    elif (b & CTB_76_MASK) == CTB_76_NEW :
        plen, idx = get_whole_int(msg, idx, 1)
        if plen < 192 :
            # the byte is the length itself
            return b & CTB_PKT_MASK, plen, idx
        if plen < 224 :
            # two byte encoding
            plen2, idx = get_whole_int(msg, idx, 1)
            return b & CTB_PKT_MASK, ((plen - 192) << 8) + plen2 + 192, idx
        if plen == 255 :
            # marker byte followed by a 4 byte length
            plen, idx = get_whole_int(msg, idx, 4)
            return b & CTB_PKT_MASK, plen, idx
        # 224..254 announce a partial body length
        raise ValueError('partial message bodies are not supported by this version (%d)' % b)
    else :
        raise ValueError("unknown (not \"normal\") cypher type bit %d at byte %d" % (b, idx))

def crc24(msg) :
    """crc24(msg)
computes a 24 bit CRC over the characters of msg, starting from 0xb704ce and
using the polynomial 0x1864cfb, and returns it as an integer in the range
0..0xffffff"""
    poly = 0x1864cfb
    top_bit = 0x1000000

    crc = 0xb704ce
    for ch in msg :
        # fold the next character into the top byte of the register
        crc ^= ord(ch) << 16
        shifts = 8
        while shifts :
            crc <<= 1
            if crc & top_bit :
                crc ^= poly
            shifts -= 1
    return crc & 0xffffff

def decode(msg) :
    """decode(msg)
splits the binary message msg into its sequence of packets and returns them as
a list of packet objects. Public key and subkey packets, user id packets and
signature packets are supported; each packet object gets its pkt_typ set and
is filled in through its deserialize() method. If the module level debug
object is set, every decoded packet is also written to it.

Raises ValueError when a packet of any other type is found."""
    # each message is a sequence of packets so we go through the message
    # and generate a list of packets and return that
    pkt_list = []
    idx = 0
    msg_len = len(msg)
    while idx < msg_len :
        pkt_typ, pkt_len, idx = get_ctb(msg, idx)
        if pkt_typ == CTB_PKT_PK_CERT or pkt_typ == CTB_PKT_PK_SUB :
            pkt = public_key()
        elif pkt_typ == CTB_PKT_USER_ID :
            pkt = user_id()
        elif pkt_typ == CTB_PKT_SIG :
            pkt = signature()
        else :
            pkt = None

        # test identity rather than truthiness so the outcome does not depend
        # on how the packet classes define their truth value
        if pkt is None :
            raise ValueError('unknown pgp packet type %d at %d' % (pkt_typ, idx))

        pkt.pkt_typ = pkt_typ
        pkt.deserialize(msg, idx, pkt_len)
        if debug :
            debug.write(pkt.__str__() + "\n")

        pkt_list.append(pkt)

        # the packet header gives the body length; skip to the next packet
        idx = idx + pkt_len
    return pkt_list

def decode_msg(msg) :
    """decode_msg(msg)
looks for an ASCII armored public key block in the text msg and returns the
pgp_certificate built from it. Everything before the
'-----BEGIN PGP PUBLIC KEY BLOCK-----' line is ignored, as are the armor
header lines up to the first empty line. The following lines are collected as
base64 data until a line starting with '=' is found; that line carries the
checksum the decoded data is verified against.

Returns None when no complete block (begin line, data, checksum line) is
found. Raises ValueError when the checksum does not match."""
    in_block = 0
    in_data = 0

    data_lines = []
    for raw_line in msg.split('\n') :
        # trim trailing whitespace, which also removes any carriage return
        line = raw_line.rstrip()

        if not in_block :
            if line == '-----BEGIN PGP PUBLIC KEY BLOCK-----' :
                in_block = 1
            continue

        # are we at the actual data yet?
        if not in_data :
            if len(line) == 0 :
                in_data = 1
            continue

        # an empty line carries no data; skip it instead of indexing into it
        if not line :
            continue

        # are we at the checksum line?
        if line[0] == '=' :
            # get the checksum number
            csum = base64.decodestring(line[1:5])
            csum, _ = get_whole_number(csum, 0, 3)

            # convert the base64 cert data to binary data
            cert_msg = base64.decodestring(''.join(data_lines))

            # check the checksum
            if csum != crc24(cert_msg) :
                raise ValueError('bad checksum on pgp message')

            # ok, the sum looks ok so we'll actually decode the thing
            pkt_list = decode(cert_msg)

            # turn it into a real cert
            cert = pgp_certificate()
            cert.load(pkt_list)
            return cert

        # add the data to our buffer then
        data_lines.append(line)

    return None


--- NEW FILE plugins.py ---
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Library General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
# Copyright 2005 Duke University

import os
import glob
import imp
import atexit
from constants import *
import ConfigParser
import config 
import Errors

# XXX: break the API for how plugins define config file options
#   - they should just be able to manipulate the YumConf and RepoConf classes
#     directly, adding Option instances as required
#   - cleaner and more flexible
#   - update PLUGINS document
#   - will take care of the following existing TODOs:
# TODO: detect conflicts between builtin yum options and registered plugin
# options (will require refactoring of config.py)
# TODO: plugins should be able to specify convertor functions for config vars

# TODO: should plugin searchpath be affected by installroot option?

# TODO: cleaner method to query installed packages rather than exposing RpmDB
# (Panu?)

# TODO: consistent case of YumPlugins methods

# TODO: expose progress bar interface

# TODO: allow plugins to define new repository types

# TODO: check for *_hook methods that aren't supported

# TODO "log" slot? To allow plugins to do customised logging/history (say to a
# SQL db)

# TODO: multiversion plugin support

# TODO: config vars marked as PLUG_OPT_WHERE_ALL should inherit defaults from
#   the [main] setting if the user doesn't specify them

# TODO: allow plugins to extend shell commands

# TODO: allow plugins to extend commands (on the command line)

# TODO: More developer docs:  use epydoc as API begins to stabilise

# TODO: test the API by implementing some crack from bugzilla
#         - http://devel.linux.duke.edu/bugzilla/show_bug.cgi?id=181 
#         - http://devel.linux.duke.edu/bugzilla/show_bug.cgi?id=270
#         - http://devel.linux.duke.edu/bugzilla/show_bug.cgi?id=310
#         - http://devel.linux.duke.edu/bugzilla/show_bug.cgi?id=431
#         - http://devel.linux.duke.edu/bugzilla/show_bug.cgi?id=88 (?)
#         - http://devel.linux.duke.edu/bugzilla/show_bug.cgi?id=396 (DONE)


# The API_VERSION constant defines the current plugin API version. It is used
# to decide whether or not plugins can be loaded. It is compared against the
# 'requires_api_version' attribute of each plugin. The version number has the
# format: "major_version.minor_version".
# 
# For a plugin to be loaded the major version required by the plugin must match
# the major version in API_VERSION. Additionally, the minor version in
# API_VERSION must be greater than or equal to the minor version required by
# the plugin.
# 
# If a change to yum is made that breaks backwards compatibility wrt the plugin
# API, the major version number must be incremented and the minor version number
# reset to 0. If a change is made that doesn't break backwards compatibility,
# then the minor number must be incremented.
API_VERSION = '2.2'

# Plugin types: a plugin declares one or more of these and YumPlugins only
# loads plugins whose types were all requested
TYPE_CORE, TYPE_INTERFACE = 0, 1
ALL_TYPES = (TYPE_CORE, TYPE_INTERFACE)

# For every slot, the name of the conduit class whose instance is handed to
# the plugin's "<slot>_hook" function
SLOT_TO_CONDUIT = {
    'config':        'ConfigPluginConduit',
    'init':          'InitPluginConduit',
    'predownload':   'DownloadPluginConduit',
    'postdownload':  'DownloadPluginConduit',
    'prereposetup':  'PreRepoSetupPluginConduit',
    'postreposetup': 'PostRepoSetupPluginConduit',
    'close':         'PluginConduit',
    'pretrans':      'MainPluginConduit',
    'posttrans':     'MainPluginConduit',
    'exclude':       'MainPluginConduit',
    'preresolve':    'DepsolvePluginConduit',
    'postresolve':   'DepsolvePluginConduit',
}

# All known slot names
SLOTS = SLOT_TO_CONDUIT.keys()

class PluginYumExit(Exception):
    '''Exception a plugin raises to signal that yum should stop.'''
    pass

class YumPlugins:
    '''
    Manager class for Yum plugins.

    On construction it imports every enabled plugin module found in the search
    path, records the "<slot>_hook" functions each module defines and runs the
    "config" slot. The "close" slot is registered to run at interpreter exit.
    '''

    def __init__(self, base, searchpath, optparser=None, types=None):
        '''Initialise the instance.

        @param base: The object this plugin manager works for. Its log()
            method is used for messages and it is handed to every conduit.
        @param searchpath: A list of paths to look for plugin modules.
        @param optparser: The OptionParser instance for this run (optional).
            Use to allow plugins to extend command line options.
        @param types: A sequence specifying the types of plugins to load.
            This should be a sequence containing one or more of the TYPE_...
            constants. If None (the default), all plugins will be loaded.
        '''

        self.searchpath = searchpath
        self.base = base
        self.optparser = optparser
        self.cmdline = (None, None)
        if not types:
            types = ALL_TYPES

        self._importplugins(types)

        # Options registered by plugins: name -> (valuetype, where, default)
        self.opts = {}
        self.cmdlines = {}

        # Call close handlers when yum exits
        atexit.register(self.run, 'close')

        # Let plugins register custom config file options
        self.run('config')

    def run(self, slotname, **kwargs):
        '''Run all plugin functions for the given slot.

        @param slotname: Name of the slot; must be a key of SLOT_TO_CONDUIT.
        @param kwargs: Extra keyword arguments passed on to the constructor of
            the slot's conduit class.
        @raise ValueError: if slotname is not a known slot.
        '''
        # Determine handler class to use
        conduitname = SLOT_TO_CONDUIT.get(slotname, None)
        if conduitname is None:
            raise ValueError('unknown slot name "%s"' % slotname)
        # Convert name to class object with a plain module namespace lookup;
        # no need to evaluate the name as an expression
        conduitcls = globals()[conduitname]

        for modname, func in self._pluginfuncs[slotname]:
            self.base.log(3, 'Running "%s" handler for "%s" plugin' % (
                slotname, modname))
    
            _, conf = self._plugins[modname]
            func(conduitcls(self, self.base, conf, **kwargs))

    def _importplugins(self, types):
        '''Load plugins matching the given types.

        Every "*.py" file in each existing directory of the search path is
        passed to _loadplugin().
        '''

        # Initialise plugin dict
        self._plugins = {}
        self._pluginfuncs = {}
        for slot in SLOTS:
            self._pluginfuncs[slot] = []

        # Import plugins 
        for plugindir in self.searchpath:
            if not os.path.isdir(plugindir):
                continue
            for modulefile in glob.glob('%s/*.py' % plugindir):
                self._loadplugin(modulefile, types)

    def _loadplugin(self, modulefile, types):
        '''Attempt to import a plugin module and register the hook methods it
        uses.

        The plugin is skipped when it has no configuration, is not enabled in
        the [main] section of its configuration, declares no plugin type or
        declares a type that is not in types.

        @raise Errors.ConfigError: if the module doesn't declare
            requires_api_version, requires an unsupported API version, or a
            plugin with the same name was already loaded.
        '''
        plugindir, filename = os.path.split(modulefile)
        # Strip only the final extension; splitting on '.py' would cut the
        # name short at the first '.py' found anywhere in it
        modname = os.path.splitext(filename)[0]

        conf = self._getpluginconf(modname)
        if not conf or not config.getOption(conf, 'main', 'enabled', False,
                config.BoolOption()):
            self.base.log(3, '"%s" plugin is disabled' % modname)
            return

        fp, pathname, description = imp.find_module(modname, [plugindir])
        try:
            module = imp.load_module(modname, fp, pathname, description)
        finally:
            # find_module() hands back an open file which the caller must close
            if fp:
                fp.close()

        # Check API version required by the plugin
        if not hasattr(module, 'requires_api_version'):
             raise Errors.ConfigError(
                'Plugin "%s" doesn\'t specify required API version' % modname
                )
        if not apiverok(API_VERSION, module.requires_api_version):
            raise Errors.ConfigError(
                'Plugin "%s" requires API %s. Supported API is %s.' % (
                    modname,
                    module.requires_api_version,
                    API_VERSION,
                    ))

        # Check plugin type against filter
        plugintypes = getattr(module, 'plugin_type', ALL_TYPES)
        if not isinstance(plugintypes, (list, tuple)):
            plugintypes = (plugintypes,)

        if len(plugintypes) < 1:
            return
        # Every type the plugin declares must have been requested
        for plugintype in plugintypes:
            if plugintype not in types:
                return

        self.base.log(2, 'Loading "%s" plugin' % modname)

        # Store the plugin module and its configuration file
        if modname in self._plugins:
            raise Errors.ConfigError('Two or more plugins with the name "%s" ' \
                    'exist in the plugin search path' % modname)
        self._plugins[modname] = (module, conf)
        
        for slot in SLOTS:
            funcname = slot+'_hook'
            if hasattr(module, funcname):
                self._pluginfuncs[slot].append(
                        (modname, getattr(module, funcname))
                        )

    def _getpluginconf(self, modname):
        '''Parse the plugin specific configuration file and return a
        IncludingConfigParser instance representing it.

        An IOError while reading is logged and None is returned. A
        ConfigParser error is turned into an Errors.ConfigError.
        '''
        #XXX: should this use installroot?
        conffilename = os.path.join('/etc/yum/pluginconf.d', modname+'.conf')

        try:
            parser = config.IncludingConfigParser()
            parser.read(conffilename)
        except ConfigParser.Error, e:
            raise Errors.ConfigError("Couldn't parse %s: %s" % (conffilename,
                str(e)))
        except IOError, e:
            self.base.log(2, str(e))
            return None

        return parser

    def registeropt(self, name, valuetype, where, default):
        '''Called from plugins to register a new config file option.

        @param name: Name of the new option.
        @param valuetype: Option type (PLUG_OPT_BOOL, PLUG_OPT_STRING ...)
        @param where: Where the option should be available in the config file.
            (PLUG_OPT_WHERE_MAIN, PLUG_OPT_WHERE_REPO, ...)
        @param default: Default value for the option if not set by the user.
        @raise Errors.ConfigError: if an option of that name was already
            registered.
        '''
        if name in self.opts:
            raise Errors.ConfigError('Plugin option conflict: ' \
                    'an option named "%s" has already been registered' % name
                    )
        self.opts[name] = (valuetype, where, default)

    def parseopts(self, conf, repos):
        '''Parse any configuration options registered by plugins

        Options registered for [main] are set as attributes on conf; options
        registered for repositories are stored on each repo with repo.set().

        @param conf: the yumconf instance holding Yum's global options
        @param repos: a list of all repository objects
        '''
        #XXX: with the new config stuff this is an ugly hack!
        # See first TODO at top of this file

        type2opt =  {
            PLUG_OPT_STRING: config.Option(),
            PLUG_OPT_INT: config.IntOption(),
            PLUG_OPT_BOOL: config.BoolOption(),
            PLUG_OPT_FLOAT: config.FloatOption(),
            }

        # Process [main] options first
        for name, (vtype, where, default) in self.opts.iteritems(): 
            if where in (PLUG_OPT_WHERE_MAIN, PLUG_OPT_WHERE_ALL):
                val = config.getOption(conf.cfg, 'main', name, default,
                        type2opt[vtype])
                setattr(conf, name, val)

        # Process repository level options
        for repo in repos:
            for name, (vtype, where, default) in self.opts.iteritems(): 
                if where in (PLUG_OPT_WHERE_REPO, PLUG_OPT_WHERE_ALL):
                    val = config.getOption(conf.cfg, repo.id, name, default,
                            type2opt[vtype])
                    repo.set(name, val)

    def setCmdLine(self, opts, commands):
        '''Set the parsed command line options so that plugins can access them
        '''
        self.cmdline = (opts, commands)


class DummyYumPlugins:
    '''
    Stand-in for YumPlugins when plugins are not in use: run() and
    setCmdLine() accept any arguments and do nothing, so callers don't need
    to check whether plugins are active.
    '''
    def run(self, *args, **kwargs):
        '''Accept and ignore a request to run a slot.'''
        return None

    def setCmdLine(self, *args, **kwargs):
        '''Accept and ignore the parsed command line.'''
        return None

class PluginConduit:
    '''Object handed to plugin hook functions. It wraps the plugin manager
    (parent), the base object and the plugin's own configuration, and offers
    logging, prompting and configuration reading helpers.
    '''

    def __init__(self, parent, base, conf):
        '''Store the plugin manager, the base object and the plugin config.'''
        self._parent, self._base, self._conf = parent, base, conf

    def info(self, level, msg):
        '''Send msg to the base object's log() at the given level.'''
        self._base.log(level, msg)

    def error(self, level, msg):
        '''Send msg to the base object's errorlog() at the given level.'''
        self._base.errorlog(level, msg)

    def promptYN(self, msg):
        '''Log msg at level 2 and ask for confirmation.

        @return: 1 without asking if the assumeyes setting is on, otherwise
            the result of the base object's userconfirm().
        '''
        self.info(2, msg)
        if not self._base.conf.assumeyes:
            return self._base.userconfirm()
        return 1

    def getYumVersion(self):
        '''Return the __version__ attribute of the yum package.'''
        import yum
        version = yum.__version__
        return version

    def getOptParser(self):
        '''Return the optparse.OptionParser instance for this execution of Yum

        Plugins may add their own options to the instance while running in
        the "config" and "init" slots, which extends the command line options
        that Yum exposes.

        In every other slot the instance should be treated as read only;
        modifying it then has no effect.
        
        Use getCmdLine() to get at the parsed values of command line options.

        @return: the global optparse.OptionParser instance used by Yum. May be
            None if an OptionParser isn't in use.
        '''
        return self._parent.optparser

    def confString(self, section, opt, default=None):
        '''Read a string value from the plugin's own configuration file

        @param section: Configuration file section to read.
        @param opt: Option name to read.
        @param default: Default handed to config.getOption().
        @return: The result of config.getOption() using a config.Option.
        '''
        kind = config.Option()
        return config.getOption(self._conf, section, opt, default, kind)

    def confInt(self, section, opt, default=None):
        '''Read an integer value from the plugin's own configuration file

        @param section: Configuration file section to read.
        @param opt: Option name to read.
        @param default: Default handed to config.getOption().
        @return: The result of config.getOption() using a config.IntOption.
        '''
        kind = config.IntOption()
        return config.getOption(self._conf, section, opt, default, kind)

    def confFloat(self, section, opt, default=None):
        '''Read a float value from the plugin's own configuration file

        @param section: Configuration file section to read.
        @param opt: Option name to read.
        @param default: Default handed to config.getOption().
        @return: The result of config.getOption() using a config.FloatOption.
        '''
        kind = config.FloatOption()
        return config.getOption(self._conf, section, opt, default, kind)

    def confBool(self, section, opt, default=None):
        '''Read a boolean value from the plugin's own configuration file

        @param section: Configuration file section to read.
        @param opt: Option name to read.
        @param default: Default handed to config.getOption().
        @return: The result of config.getOption() using a config.BoolOption.
        '''
        kind = config.BoolOption()
        return config.getOption(self._conf, section, opt, default, kind)

class ConfigPluginConduit(PluginConduit):
    '''Conduit that lets a plugin register yum configuration file options.'''

    def registerOpt(self, *args, **kwargs):
        '''Register a yum configuration file option.

        Every positional and keyword argument is forwarded unchanged to the
        parent's registeropt(); see YumPlugins.registeropt().
        '''
        register = self._parent.registeropt
        register(*args, **kwargs)


class InitPluginConduit(PluginConduit):
    '''Conduit exposing the base object's configuration and repositories.'''

    def getConf(self):
        '''Return the conf attribute of the Yum base object.'''
        base = self._base
        return base.conf

    def getRepos(self):
        '''Return Yum's container object for all configured repositories.

        @return: Yum's RepoStorage instance
        '''
        base = self._base
        return base.repos

class PreRepoSetupPluginConduit(InitPluginConduit):
    '''Conduit adding access to the command line and the local RPM database.'''

    def getCmdLine(self):
        '''Return parsed command line options.

        @return: (options, commands) as returned by OptionParser.parse_args()
        '''
        parent = self._parent
        return parent.cmdline

    def getRpmDB(self):
        '''Return a representation of the local RPM database, which allows
        installed packages to be queried.

        The transaction set and the RPM database are set up on the base
        object before its rpmdb attribute is handed back.

        @return: rpmUtils.RpmDBHolder instance
        '''
        base = self._base
        base.doTsSetup()
        base.doRpmDBSetup()
        return base.rpmdb

class PostRepoSetupPluginConduit(PreRepoSetupPluginConduit):
    '''Conduit adding access to group information.'''

    def getGroups(self):
        '''Return group information.

        Group setup is run on the base object before its comps attribute
        is handed back.

        @return: yum.comps.Comps instance
        '''
        base = self._base
        base.doGroupSetup()
        return base.comps

class DownloadPluginConduit(PostRepoSetupPluginConduit):
    '''Conduit exposing the packages to download and any download errors.'''

    def __init__(self, parent, base, conf, pkglist, errors=None):
        PostRepoSetupPluginConduit.__init__(self, parent, base, conf)
        self._errors = errors
        self._pkglist = pkglist

    def getDownloadPackages(self):
        '''Return a list of package objects representing packages to be
        downloaded.
        '''
        packages = self._pkglist
        return packages

    def getErrors(self):
        '''Return a dictionary of download errors.

        The returned dictionary is indexed by package object. Each element is a
        list of strings describing the error. An empty dictionary is returned
        when no errors were recorded.
        '''
        return self._errors or {}

class MainPluginConduit(PostRepoSetupPluginConduit):
    '''Conduit exposing the package sack and the transaction info.'''

    def getPackages(self, repo=None):
        '''Return packages from the package sack.

        @param repo: Optional repository object; when given, its id is
            passed to the sack's returnPackages(), otherwise None is passed.
        @return: Whatever the sack's returnPackages() hands back.
        '''
        repoid = None
        if repo:
            repoid = repo.id
        sack = self._base.pkgSack
        return sack.returnPackages(repoid)

    def getPackageByNevra(self, nevra):
        '''Retrieve a package object from the packages loaded by Yum using
        nevra information

        @param nevra: A tuple holding (name, epoch, version, release, arch)
            for a package
        @return: A PackageObject instance (or subclass)
        '''
        base = self._base
        return base.getPackageObject(nevra)

    def delPackage(self, po):
        '''Delete the package object po from the package sack.'''
        sack = self._base.pkgSack
        sack.delPackage(po)

    def getTsInfo(self):
        '''Return the tsInfo attribute of the Yum base object.'''
        base = self._base
        return base.tsInfo

class DepsolvePluginConduit(MainPluginConduit):
    '''Conduit carrying a result code and result strings.

    The values are stored on the public attributes resultcode and
    resultstring.
    '''
    def __init__(self, parent, base, conf, rescode=None, restring=None):
        '''
        @param parent: Passed through to MainPluginConduit.
        @param base: Passed through to MainPluginConduit.
        @param conf: Passed through to MainPluginConduit.
        @param rescode: Stored as self.resultcode.
        @param restring: Stored as self.resultstring; a fresh empty list is
            used when omitted.
        '''
        MainPluginConduit.__init__(self, parent, base, conf)
        self.resultcode = rescode
        # A literal [] default would be one list object shared by every
        # conduit created without restring, so build a new one per instance.
        if restring is None:
            restring = []
        self.resultstring = restring

def parsever(apiver):
    '''Split a "major.minor" API version string into a pair of integers.
    '''
    major, minor = apiver.split('.')
    return int(major), int(minor)

def apiverok(a, b):
    '''Return true if API version "a" supports API version "b"

    The major numbers must be equal and the minor number of "a" must be at
    least that of "b". The result is the integer 1 or 0.
    '''
    have_major, have_minor = parsever(a)
    want_major, want_minor = parsever(b)

    if have_major == want_major and have_minor >= want_minor:
        return 1
    return 0


--- NEW FILE repos.py ---
#!/usr/bin/python -tt
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Library General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
# Copyright 2004 Duke University 

import os
import os.path
import re
import fnmatch
import urlparse
import types
import time

import Errors
from urlgrabber.grabber import URLGrabber
import urlgrabber.mirror
from urlgrabber.grabber import URLGrabError
from repomd import repoMDObject
from repomd import mdErrors
from repomd import packageSack
from packages import YumAvailablePackage
import mdcache
import parser

# Bound search() of a regex that matches any of the characters '*', '?' or
# '['; called with a string it returns a match object when the string
# contains one of them and None otherwise.
_is_fnmatch_pattern = re.compile(r"[*?[]").search

class YumPackageSack(packageSack.PackageSack):
    """imports/handles package objects from an mdcache dict object"""
    def __init__(self, packageClass):
        packageSack.PackageSack.__init__(self)
        self.pc = packageClass
        # repoid -> list of datatypes already imported for that repo
        self.added = {}

    def addDict(self, repoid, datatype, dataobj, callback=None):
        """import one kind of metadata for one repository into the sack

           repoid   -- id of the repository the data belongs to
           datatype -- 'metadata', 'filelists' or 'otherdata'; any other
                       value is silently ignored
           dataobj  -- dict-like object keyed by package id
           callback -- optional object whose progressbar(current, total,
                       repoid) is called once per package

           A (repoid, datatype) pair that was already imported is skipped.
           Errors.RepoError is raised when 'filelists' or 'otherdata' is
           imported before the 'metadata' of the same repository."""
        imported = self.added.get(repoid, [])
        if datatype in imported:
            return

        total = len(dataobj.keys())
        if datatype == 'metadata':
            current = 0
            for pkgid in dataobj.keys():
                current += 1
                if callback: callback.progressbar(current, total, repoid)
                pkgdict = dataobj[pkgid]
                po = self.pc(pkgdict, repoid)
                po.simple['id'] = pkgid
                self._addToDictAsList(self.pkgsByID, pkgid, po)
                self.addPackage(po)

            self.added.setdefault(repoid, []).append('metadata')
            # indexes will need to be rebuilt
            self.indexesBuilt = 0

        elif datatype in ['filelists', 'otherdata']:
            # This also covers a repo that has imported nothing at all yet;
            # previously that case fell through and ended in a KeyError on
            # self.added[repoid] below.
            if 'metadata' not in imported:
                raise Errors.RepoError('%s md for %s imported before primary'
                                       % (datatype, repoid))
            current = 0
            for pkgid in dataobj.keys():
                current += 1
                if callback: callback.progressbar(current, total, repoid)
                pkgdict = dataobj[pkgid]
                if self.pkgsByID.has_key(pkgid):
                    for po in self.pkgsByID[pkgid]:
                        po.importFromDict(pkgdict, repoid)

            self.added[repoid].append(datatype)
            # indexes will need to be rebuilt
            self.indexesBuilt = 0
        else:
            # unknown datatype: nothing to import
            pass
            
class RepoStorage:
    """This class contains multiple repositories and core configuration data
       about them."""
       
    def __init__(self):
        """start with no repositories and pick the package sack type"""
        # repo objects (options/misc data) keyed by repoid
        self.repos = {}
        # progress callback used by populateSack() while importing the
        # xml files
        self.callback = None
        self.cache = 0

        # Assume the sqlite modules are missing until both of them import
        self.sqlite = False
        try:
            import sqlitecache
            import sqlitesack
        except ImportError:
            pass
        else:
            self.sqlitecache = sqlitecache
            self.sqlite = True

        self._selectSackType()
    
    def _selectSackType(self):
        """create self.pkgSack: the sqlite sack when self.sqlite is true,
           otherwise a YumPackageSack of YumAvailablePackage objects"""
        if not self.sqlite:
            self.pkgSack = YumPackageSack(YumAvailablePackage)
            return

        import sqlitecache
        import sqlitesack
        package_class = sqlitesack.YumAvailablePackageSqlite
        self.pkgSack = sqlitesack.YumSqlitePackageSack(package_class)
        
    def __str__(self):
        return str(self.repos.keys())
    

    def add(self, repoobj):
        if self.repos.has_key(repoobj.id):
            raise Errors.RepoError, 'Repository %s is listed more than once in the configuration' % (repoobj.id)
        self.repos[repoobj.id] = repoobj
        

    def delete(self, repoid):
        if self.repos.has_key(repoid):
            del self.repos[repoid]
            
    def sort(self):
        repolist = self.repos.values()
        repolist.sort()
        return repolist
        
    def getRepo(self, repoid):
        try:
            return self.repos[repoid]
        except KeyError, e:
            raise Errors.RepoError, \
                'Error getting repository data for %s, repository not found' % (repoid)

    def findRepos(self,pattern):
        """find all repositories matching fnmatch `pattern`"""

        result = []
        
        for item in pattern.split(','):
            item = item.strip()
            match = re.compile(fnmatch.translate(item)).match
            for name,repo in self.repos.items():
                if match(name):
                    result.append(repo)
        return result
        
    def disableRepo(self, repoid):
        """disable a repository from use

        fnmatch wildcards (or a comma separated list) may be used to disable
        a group of repositories.
        returns the repoids of the disabled repos as a list
        """
        if _is_fnmatch_pattern(repoid) or ',' in repoid:
            targets = self.findRepos(repoid)
        else:
            targets = [self.getRepo(repoid)]

        disabled = []
        for repo in targets:
            disabled.append(repo.id)
            repo.disable()
        return disabled
        
    def enableRepo(self, repoid):
        """enable a repository for use

        fnmatch wildcards (or a comma separated list) may be used to enable
        a group of repositories.
        returns the repoids of the enabled repos as a list
        """
        if _is_fnmatch_pattern(repoid) or ',' in repoid:
            targets = self.findRepos(repoid)
        else:
            targets = [self.getRepo(repoid)]

        enabled = []
        for repo in targets:
            enabled.append(repo.id)
            repo.enable()
        return enabled
        
    def listEnabled(self):
        """return list of enabled repo objects"""
        returnlist = []
        for repo in self.repos.values():
            if repo.enabled:
                returnlist.append(repo)

        return returnlist

    def listGroupsEnabled(self):
        """return a list of repo objects that have groups enabled"""
        returnlist = []
        for repo in self.listEnabled():
            if repo.enablegroups:
                returnlist.append(repo)

        return returnlist

    def setCache(self, cacheval):
        """sets cache value in all repos"""
        self.cache = cacheval
        for repo in self.repos.values():
            repo.cache = cacheval

    def setCacheDir(self, cachedir):
        """sets the cachedir value in all repos"""

        for repo in self.repos.values():
            repo.basecachedir = cachedir

    def setProgressBar(self, obj):
        """set's the progress bar for downloading files from repos"""
        
        for repo in self.repos.values():
            repo.callback = obj
            repo.setupGrab()

    def setFailureCallback(self, obj):
        """set's the failure callback for all repos"""
        
        for repo in self.repos.values():
            repo.failure_obj = obj
            repo.setupGrab()
            
                
    def populateSack(self, which='enabled', with='metadata', callback=None, pickleonly=0):
        """This populates the package sack from the repositories.

           which      -- 'enabled' (default) for all enabled repos, 'all' for
                         every known repo, a single repoid string, or a list
                         mixing repoids and Repository objects
           with       -- 'metadata', 'filelists', 'otherdata' or 'all'
           callback   -- progress callback; self.callback is used when this
                         is not given
           pickleonly -- when true the data is still fetched and run through
                         the cache handler, but not added to the sack"""

        if not callback:
            callback = self.callback
        # resolve `which` into a list of repo objects
        myrepos = []
        if which == 'enabled':
            myrepos = self.listEnabled()
        elif which == 'all':
            myrepos = self.repos.values()
        else:
            if type(which) == types.ListType:
                for repo in which:
                    if isinstance(repo, Repository):
                        myrepos.append(repo)
                    else:
                        repobj = self.getRepo(repo)
                        myrepos.append(repobj)
            elif type(which) == types.StringType:
                repobj = self.getRepo(which)
                myrepos.append(repobj)
            # any other type of `which` leaves myrepos empty

        if with == 'all':
            data = ['metadata', 'filelists', 'otherdata']
        else:
            data = [ with ]
         
        for repo in myrepos:
            # give the repo a parser/cache handler for the duration of the
            # import; it is removed again at the bottom of this loop
            if not hasattr(repo, 'cacheHandler'):
                if (self.sqlite):
                    repo.cacheHandler = self.sqlitecache.RepodataParserSqlite(
                            storedir=repo.cachedir, 
                            repoid=repo.id,
                            callback=callback,
                            )
                else:
                    repo.cacheHandler = mdcache.RepodataParser(
                            storedir=repo.cachedir, 
                            callback=callback
                            )
            for item in data:
                # skip data types the sack already holds for this repo
                if self.pkgSack.added.has_key(repo.id):
                    if item in self.pkgSack.added[repo.id]:
                        continue
                        
                # each branch: get the xml file path, look up its checksum
                # in repomd, hand both to the cache handler, then import
                # the result into the sack
                if item == 'metadata':
                    xml = repo.getPrimaryXML()
                    (ctype, csum) = repo.repoXML.primaryChecksum()
                    dobj = repo.cacheHandler.getPrimary(xml, csum)
                    if not pickleonly:
                        self.pkgSack.addDict(repo.id, item, dobj, callback) 
                    del dobj
                        
                elif item == 'filelists':
                    xml = repo.getFileListsXML()
                    (ctype, csum) = repo.repoXML.filelistsChecksum()
                    dobj = repo.cacheHandler.getFilelists(xml, csum)
                    if not pickleonly:
                        self.pkgSack.addDict(repo.id, item, dobj, callback) 
                    del dobj
                        
                        
                elif item == 'otherdata':
                    xml = repo.getOtherXML()
                    (ctype, csum) = repo.repoXML.otherChecksum()
                    dobj = repo.cacheHandler.getOtherdata(xml, csum)
                    if not pickleonly:
                        self.pkgSack.addDict(repo.id, item, dobj, callback)
                    del dobj
                    
                        
                else:
                    # how odd, just move along
                    continue
            # get rid of all this stuff we don't need now
            del repo.cacheHandler
        
        
class Repository:
    """this is an actual repository object"""       

    def __init__(self, repoid):
        self.id = repoid
        self.name = repoid # name is repoid until someone sets it to a real name
        # some default (ish) things
        self.urls = []
        self.gpgcheck = 0
        self.enabled = 0
        self.enablegroups = 0
        self.groupsfilename = 'yumgroups.xml' # something some freaks might 
                                              # eventually want
        self.setkeys = []
        self.repoMDFile = 'repodata/repomd.xml'
        self.repoXML = None
        self.cache = 0
        self.callback = None # callback for the grabber
        self.failure_obj = None
        self.mirrorlist = None # filename/url of mirrorlist file
        self.mirrorlistparsed = 0
        self.baseurl = [] # baseurls from the config file
        self.yumvar = {} # empty dict of yumvariables for $string replacement
        self.proxy_password = None
        self.proxy_username = None
        self.proxy = None
        self._proxy_dict = {}
        self.metadata_cookie_fn = 'cachecookie'
        self.groups_added = False
        
        # throw in some stubs for things that will be set by the config class
        self.basecachedir = ""
        self.cachedir = ""
        self.pkgdir = ""
        self.hdrdir = ""

        # holder for stuff we've grabbed
        self.retrieved = { 'primary':0, 'filelists':0, 'other':0, 'groups':0 }

        
    def __cmp__(self, other):
        if self.id > other.id:
            return 1
        elif self.id < other.id:
            return -1
        else:
            return 0

    def __str__(self):
        return self.id

    def _checksum(self, sumtype, file, CHUNK=2**16):
        """takes filename, hand back Checksum of it
           sumtype = md5 or sha
           filename = /path/to/file
           CHUNK=65536 by default"""
           
        # chunking brazenly lifted from Ryan Tomayko
        try:
            if type(file) is not types.StringType:
                fo = file # assume it's a file-like-object
            else:           
                fo = open(file, 'r', CHUNK)
                
            if sumtype == 'md5':
                import md5
                sum = md5.new()
            elif sumtype == 'sha':
                import sha
                sum = sha.new()
            else:
                raise Errors.RepoError, 'Error Checksumming file, wrong \
                                         checksum type %s' % sumtype
            chunk = fo.read
            while chunk: 
                chunk = fo.read(CHUNK)
                sum.update(chunk)
    
            if type(file) is types.StringType:
                fo.close()
                del fo
                
            return sum.hexdigest()
        except (EnvironmentError, IOError, OSError):
            raise Errors.RepoError, 'Error opening file for checksum'
        
    def dump(self):
        output = '[%s]\n' % self.id
        vars = ['name', 'bandwidth', 'enabled', 'enablegroups', 
                 'gpgcheck', 'includepkgs', 'keepalive', 'proxy',
                 'proxy_password', 'proxy_username', 'exclude', 
                 'retries', 'throttle', 'timeout', 'mirrorlist', 
                 'cachedir', 'gpgkey', 'pkgdir', 'hdrdir']
        vars.sort()
        for attr in vars:
            output = output + '%s = %s\n' % (attr, getattr(self, attr))
        output = output + 'baseurl ='
        for url in self.urls:
            output = output + ' %s\n' % url
        
        return output
    
    def enable(self):
        """set up the base urls, then mark the repository enabled"""
        self.baseurlSetup()
        self.set('enabled', 1)
    
    def disable(self):
        self.set('enabled', 0)
    
    def check(self):
        """self-check the repo information  - if we don't have enough to move
           on then raise a repo error"""
        if len(self.urls) < 1:
            raise Errors.RepoError, \
             'Cannot find a valid baseurl for repo: %s' % self.id
           
                
    def set(self, key, value):
        """sets a generic attribute of this repository"""
        self.setkeys.append(key)
        setattr(self, key, value)
        
    def unset(self, key):
        """delete an attribute of this repository"""
        self.setkeys.remove(key)
        delattr(self, key)
   
    def listSetKeys(self):
        return self.setkeys
    
    def doProxyDict(self):
        if self._proxy_dict:
            return
        
        self._proxy_dict = {} # zap it
        proxy_string = None
        if self.proxy not in [None, '_none_']:
            proxy_string = '%s' % self.proxy
            if self.proxy_username is not None:
                proxy_parsed = urlparse.urlsplit(self.proxy, allow_fragments=0)
                proxy_proto = proxy_parsed[0]
                proxy_host = proxy_parsed[1]
                proxy_rest = proxy_parsed[2] + '?' + proxy_parsed[3]
                proxy_string = '%s://%s@%s%s' % (proxy_proto,
                        self.proxy_username, proxy_host, proxy_rest)
                        
                if self.proxy_password is not None:
                    proxy_string = '%s://%s:%s@%s%s' % (proxy_proto,
                              self.proxy_username, self.proxy_password,
                              proxy_host, proxy_rest)
                                                 
        if proxy_string is not None:
            self._proxy_dict['http'] = proxy_string
            self._proxy_dict['https'] = proxy_string
            self._proxy_dict['ftp'] = proxy_string

    def setupGrab(self):
        """(re)build the grabber objects from the current settings:
           self.grabfunc is a URLGrabber and self.grab a mirror group over
           self.urls that uses it.  The mirror group class is MGRandomOrder
           for the 'roundrobin' failover method and MirrorGroup otherwise."""
        mgclass = urlgrabber.mirror.MirrorGroup
        if self.failovermethod == 'roundrobin':
            mgclass = urlgrabber.mirror.MGRandomOrder

        grabber = URLGrabber(keepalive=self.keepalive,
                             bandwidth=self.bandwidth,
                             retry=self.retries,
                             throttle=self.throttle,
                             progress_obj=self.callback,
                             proxies=self.proxy_dict,
                             failure_callback=self.failure_obj,
                             timeout=self.timeout,
                             reget='simple')
        self.grabfunc = grabber
        self.grab = mgclass(grabber, self.urls)

    def dirSetup(self):
        """make the necessary dirs, if possible, raise on failure"""

        cachedir = os.path.join(self.basecachedir, self.id)
        pkgdir = os.path.join(cachedir, 'packages')
        hdrdir = os.path.join(cachedir, 'headers')
        self.set('cachedir', cachedir)
        self.set('pkgdir', pkgdir)
        self.set('hdrdir', hdrdir)
        cookie = self.cachedir + '/' + self.metadata_cookie_fn
        self.set('metadata_cookie', cookie)

        for dir in [self.cachedir, self.hdrdir, self.pkgdir]:
            if self.cache == 0:
                if os.path.exists(dir) and os.path.isdir(dir):
                    continue
                else:
                    try:
                        os.makedirs(dir, mode=0755)
                    except OSError, e:
                        raise Errors.RepoError, \
                            "Error making cache directory: %s error was: %s" % (dir, e)
            else:
                if not os.path.exists(dir):
                    raise Errors.RepoError, \
                        "Cannot access repository dir %s" % dir
 
    def baseurlSetup(self):
        """populate self.urls from the baseurls and, the first time round,
           from the mirrorlist; urls whose scheme is not http, https, ftp or
           file are reported and skipped.  self.check() is run at the end to
           make sure it worked, then the grabber is rebuilt."""
        if self.mirrorlist and not self.mirrorlistparsed:
            mirrorurls = getMirrorList(self.mirrorlist, self.proxy_dict)
            self.mirrorlistparsed = 1
            for mirror in mirrorurls:
                self.baseurl.append(parser.varReplace(mirror, self.yumvar))

        usable = []
        for candidate in self.baseurl:
            candidate = parser.varReplace(candidate, self.yumvar)
            scheme = urlparse.urlparse(candidate)[0]
            if scheme in ['http', 'ftp', 'file', 'https']:
                usable.append(candidate)
                continue
            print('not using ftp, http[s], or file for repos, skipping - %s' % (candidate))

        self.set('urls', usable)
        self.check()
        self.setupGrab() # update the grabber for the urls

    def get(self, url=None, relative=None, local=None, start=None, end=None,
            copy_local=0, checkfunc=None, text=None, reget='simple', cache=True):
        """retrieve file from the mirrorgroup for the repo
           relative to local, optionally get range from
           start to end, also optionally retrieve from a specific baseurl"""
           
        # if local or relative is None: raise an exception b/c that shouldn't happen
        # if url is not None - then do a grab from the complete url - not through
        # the mirror, raise errors as need be
        # if url is None do a grab via the mirror group/grab for the repo
        # return the path to the local file

        if cache:
            headers = None
        else:
            headers = (('Pragma', 'no-cache'),)

        if local is None or relative is None:
            raise Errors.RepoError, \
                  "get request for Repo %s, gave no source or dest" % self.id
                  
        if self.failure_obj:
            (f_func, f_args, f_kwargs) = self.failure_obj
            self.failure_obj = (f_func, f_args, f_kwargs)
        
        if self.cache == 1:
            if os.path.exists(local): # FIXME - we should figure out a way
                return local          # to run the checkfunc from here
                
            else: # ain't there - raise
                raise Errors.RepoError, \
                    "Caching enabled but no local cache of %s from %s" % (local,
                           self)

        if url is not None:
            ug = URLGrabber(keepalive = self.keepalive, 
                            bandwidth = self.bandwidth,
                            retry = self.retries,
                            throttle = self.throttle,
                            progres_obj = self.callback,
                            copy_local = copy_local,
                            reget = reget,
                            proxies = self.proxy_dict,
                            failure_callback = self.failure_obj,
                            timeout=self.timeout,
                            checkfunc=checkfunc,
                            http_headers=headers,
                            )
            
            remote = url + '/' + relative

            try:
                result = ug.urlgrab(remote, local,
                                    text=text,
                                    range=(start, end), 
                                    )
            except URLGrabError, e:
                raise Errors.RepoError, \
                    "failed to retrieve %s from %s\nerror was %s" % (relative, self.id, e)
              
        else:
            try:
                result = self.grab.urlgrab(relative, local,
                                           text = text,
                                           range = (start, end),
                                           copy_local=copy_local,
                                           reget = reget,
                                           checkfunc=checkfunc,
                                           http_headers=headers,
                                           )
            except URLGrabError, e:
                raise Errors.RepoError, "failure: %s from %s: %s" % (relative, self.id, e)
                
        return result
           
        
    def metadataCurrent(self):
        """Check if there is a metadata_cookie and check its age. If the 
        age of the cookie is less than metadata_expire time then return true
        else return False"""
        
        val = False
        if os.path.exists(self.metadata_cookie):
            cookie_info = os.stat(self.metadata_cookie)
            if cookie_info[8] + self.metadata_expire > time.time():
                val = True
        
        return val

    
    def setMetadataCookie(self):
        """if possible, set touch the metadata_cookie file"""
        
        check = self.metadata_cookie
        if not os.path.exists(self.metadata_cookie):
            check = self.cachedir
        
        if os.access(check, os.W_OK):
            fo = open(self.metadata_cookie, 'w+')
            fo.close()
            del fo
            
    def getRepoXML(self, text=None):
        """retrieve/check/read in repomd.xml from the repository"""

        remote = self.repoMDFile
        local = self.cachedir + '/repomd.xml'
        if self.repoXML is not None:
            return

        if self.cache or self.metadataCurrent():
            if not os.path.exists(local):
                raise Errors.RepoError, 'Cannot find repomd.xml file for %s' % (self)
            else:
                result = local
        else:
            checkfunc = (self._checkRepoXML, (), {})
            try:
                result = self.get(relative=remote,
                                  local=local,
                                  copy_local=1,
                                  text=text,
                                  reget=None,
                                  checkfunc=checkfunc,
                                  cache=self.http_caching == 'all')

            except URLGrabError, e:
                raise Errors.RepoError, 'Error downloading file %s: %s' % (local, e)
            
            # if we have a 'fresh' repomd.xml then update the cookie
            self.setMetadataCookie()

        try:
            self.repoXML = repoMDObject.RepoMD(self.id, result)
        except mdErrors.RepoMDError, e:
            raise Errors.RepoError, 'Error importing repomd.xml from %s: %s' % (self, e)
    
    def _checkRepoXML(self, fo):
        if type(fo) is types.InstanceType:
            filepath = fo.filename
        else:
            filepath = fo
        
        try:
            foo = repoMDObject.RepoMD(self.id, filepath)
        except mdErrors.RepoMDError, e:
            raise URLGrabError(-1, 'Error importing repomd.xml for %s: %s' % (self, e))


    def __getProxyDict(self):
        self.doProxyDict()
        if self._proxy_dict:
            return self._proxy_dict
        return None

    # consistent access to how proxy information should look (and ensuring
    # that it's actually determined for the repo)
    proxy_dict = property(__getProxyDict)

    def _checkMD(self, fn, mdtype):
        """check a metadata file against the checksum listed for its type

           fn is a path or, for an urlgrabber check, an instance object
           carrying the path in its filename attribute; mdtype is
           'primary', 'filelists', 'other' or 'group'.  1 is returned on a
           match.  URLGrabError is raised with code -1 on a mismatch and
           with code -3 when the local checksum cannot be computed."""
        checksum_for = { 'primary' : self.repoXML.primaryChecksum,
                         'filelists' : self.repoXML.filelistsChecksum,
                         'other' : self.repoXML.otherChecksum,
                         'group' : self.repoXML.groupChecksum }

        # the checksum the repository metadata says the file should have
        (r_ctype, r_csum) = checksum_for[mdtype]()

        path = fn
        if type(fn) == types.InstanceType: # this is an urlgrabber check
            path = fn.filename

        try:
            l_csum = self._checksum(r_ctype, path) # what is on disk
        except Errors.RepoError:
            raise URLGrabError(-3, 'Error performing checksum')

        if l_csum != r_csum:
            raise URLGrabError(-1, 'Metadata file does not match checksum')
        return 1
        

        
    def _retrieveMD(self, mdtype):
        """base function to retrieve metadata files from the remote url
           returns the path to the local metadata file of a 'mdtype'
           mdtype can be 'primary', 'filelists', 'other' or 'group'."""
        locDict = { 'primary' : self.repoXML.primaryLocation,
                    'filelists' : self.repoXML.filelistsLocation,
                    'other' : self.repoXML.otherLocation,
                    'group' : self.repoXML.groupLocation }
        
        locMethod = locDict[mdtype]
        
        (r_base, remote) = locMethod()
        fname = os.path.basename(remote)
        local = self.cachedir + '/' + fname

        if self.retrieved.has_key(mdtype):
            if self.retrieved[mdtype]: # got it, move along
                return local

        if self.cache == 1:
            if os.path.exists(local):
                try:
                    self._checkMD(local, mdtype)
                except URLGrabError, e:
                    raise Errors.RepoError, \
                        "Caching enabled and local cache: %s does not match checksum" % local
                else:
                    return local
                    
            else: # ain't there - raise
                raise Errors.RepoError, \
                    "Caching enabled but no local cache of %s from %s" % (local,
                           self)
                           
        if os.path.exists(local):
            try:
                self._checkMD(local, mdtype)
            except URLGrabError, e:
                pass
            else:
                self.retrieved[mdtype] = 1
                return local # it's the same return the local one

        try:
            checkfunc = (self._checkMD, (mdtype,), {})
            local = self.get(relative=remote, local=local, copy_local=1,
                             checkfunc=checkfunc, reget=None, 
                             cache=self.http_caching == 'all')
        except URLGrabError, e:
            raise Errors.RepoError, \
                "Could not retrieve %s matching remote checksum from %s" % (local, self)
        else:
            self.retrieved[mdtype] = 1
            return local

    
    def getPrimaryXML(self):
        """this gets you the path to the primary.xml file, retrieving it if we 
           need a new one"""

        return self._retrieveMD('primary')
        
    
    def getFileListsXML(self):
        """this gets you the path to the filelists.xml file, retrieving it if we 
           need a new one"""

        return self._retrieveMD('filelists')

    def getOtherXML(self):
        return self._retrieveMD('other')

    def getGroups(self):
        """gets groups and returns group file path for the repository, if there 
           is none it returns None"""
        try:
            file = self._retrieveMD('group')
        except URLGrabError:
            file = None
        
        return file
        
        

def getMirrorList(mirrorlist, pdict = None):
    """Retrieve an up2date-style mirrorlist file and return its urls.

    mirrorlist -- url of the mirrorlist; a value without a url scheme is
                  treated as a local path and prefixed with 'file://'.
    pdict      -- proxy dictionary handed to urlopen() as 'proxies'.

    Blank lines and lines starting with '#' are skipped, the trailing
    newline is removed and every '$ARCH' is rewritten to '$BASEARCH'.
    Returns a list of url strings; the list is empty when opening the
    mirrorlist raises URLGrabError."""

    returnlist = []
    # Prefer urlgrabber's own urlopen when this urlgrabber provides one
    if hasattr(urlgrabber.grabber, 'urlopen'):
        urlresolver = urlgrabber.grabber
    else:
        urlresolver = urllib

    scheme = urlparse.urlparse(mirrorlist)[0]
    if scheme == '':
        url = 'file://' + mirrorlist
    else:
        url = mirrorlist

    try:
        fo = urlresolver.urlopen(url, proxies=pdict)
    except urlgrabber.grabber.URLGrabError:
        fo = None

    if fo is None:
        return returnlist

    # Read everything up front and always release the handle, even when
    # the read itself fails (the handle used to be left open).
    try:
        content = fo.readlines()
    finally:
        fo.close()

    for line in content:
        if re.match(r'^\s*\#.*', line) or re.match(r'^\s*$', line):
            continue
        mirror = re.sub(r'\n$', '', line) # no more trailing \n's
        mirror = mirror.replace('$ARCH', '$BASEARCH')
        returnlist.append(mirror)

    return returnlist



--- NEW FILE sqlitecache.py ---
#!/usr/bin/python -tt

# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Library General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
# Copyright 2005 Duke University 

# TODO
# - Add support for multiple checksums per rpm (not required)

import os
import sqlite
import time
import mdparser
from sqlitesack import encodefiletypelist,encodefilenamelist

# Version of the internal structure of the sqlite cache files.
# loadCache() rejects every cache file whose stored dbversion differs from
# this value, and getDatabase() then throws that file away and regenerates
# it, so change this number whenever the table layout changes.
dbversion = '7'

class RepodataParserSqlite:
    """Create and refresh the sqlite caches of a repository's metadata.

    One cache file (metadata path + '.sqlite') is kept per metadata type
    ('primary', 'filelists', 'other').  Each cache records the dbversion
    and the metadata checksum it was built from in its db_info table.

    callback, when given, must provide log(level, msg) and
    progressbar(current, total, name)."""

    def __init__(self, storedir, repoid, callback=None):
        self.storedir = storedir
        self.callback = callback
        self.repodata = {
            'metadata': {},
            'filelists': {},
            'otherdata': {}
        }
        self.repoid = repoid
        self.debug = 0

    def loadCache(self,filename):
        """Load cache from filename, check if it is valid and that dbversion 
        matches the required dbversion

        Returns (checksum, db).  Raises sqlite.DatabaseError when db_info
        is empty or holds another dbversion; the connection is closed
        again before any error leaves this method."""
        db = sqlite.connect(filename)
        try:
            cur = db.cursor()
            cur.execute("select * from db_info")
            info = cur.fetchone()
            # If info is not in there this is an incomplete cache file
            # (this could happen when the user hits ctrl-c or kills yum
            # when the cache is being generated or updated)
            if not info:
                raise sqlite.DatabaseError("Incomplete database cache file")

            # Now check the database version
            if info['dbversion'] != dbversion:
                self.log(2, "Warning: cache file is version %s, we need %s, will regenerate" % (
                    info['dbversion'], dbversion))
                raise sqlite.DatabaseError("Older version of yum sqlite")

            checksum = info['checksum']
        except:
            # The caller is going to replace this file, do not keep a
            # connection to the rejected cache open behind its back
            db.close()
            raise

        # This appears to be a valid database, return checksum value and 
        # database object
        return (checksum,db)
        
    def getFilename(self,location):
        """Return the name of the sqlite cache file that belongs to the
           metadata file at location"""
        return location + '.sqlite'
            
    def getDatabase(self, location, cachetype):
        """Open the cache for location, creating a fresh one when needed.

        Returns (db, dbchecksum); dbchecksum is '' when a new cache had to
        be created because the old one was missing or unusable."""
        filename = self.getFilename(location)
        dbchecksum = ''
        # First try to open an existing database
        try:
            # open() is only a probe that raises IOError for a missing or
            # unreadable file, so give the handle back right away
            probe = open(filename)
            probe.close()
            (dbchecksum,db) = self.loadCache(filename)
        except (IOError,sqlite.DatabaseError,KeyError):
            # If it doesn't exist, create it
            db = self.makeSqliteCacheFile(filename,cachetype)
        return (db,dbchecksum)

    def _getbase(self, location, checksum, metadatatype):
        """Return an up to date cache database for one metadata file.

        The cache is refilled from the metadata at location when the
        checksum stored in it differs from checksum."""
        (db, dbchecksum) = self.getDatabase(location, metadatatype)
        # db should now contain a valid database object, check if it is
        # up to date
        if (checksum != dbchecksum):
            self.log(3, "%s sqlite cache needs updating, reading in metadata" % (metadatatype))
            parser = mdparser.MDParser(location)
            self.updateSqliteCache(db, parser, checksum, metadatatype)
        db.commit()
        return db

    def getPrimary(self, location, checksum):
        """Load primary.xml.gz from an sqlite cache and update it 
           if required"""
        return self._getbase(location, checksum, 'primary')

    def getFilelists(self, location, checksum):
        """Load filelist.xml.gz from an sqlite cache and update it if 
           required"""
        return self._getbase(location, checksum, 'filelists')

    def getOtherdata(self, location, checksum):
        """Load other.xml.gz from an sqlite cache and update it if required"""
        return self._getbase(location, checksum, 'other')
         
    def createTablesFilelists(self,db):
        """Create the required tables for filelists metadata in the sqlite 
           database"""
        cur = db.cursor()
        self.createDbInfo(cur)
        # This table is needed to match pkgKeys to pkgIds
        cur.execute("""CREATE TABLE packages(
            pkgKey INTEGER PRIMARY KEY,
            pkgId TEXT)
        """)
        cur.execute("""CREATE TABLE filelist(
            pkgKey INTEGER,
            dirname TEXT,
            filenames TEXT,
            filetypes TEXT)
        """)
        cur.execute("CREATE INDEX keyfile ON filelist (pkgKey)")
        cur.execute("CREATE INDEX pkgId ON packages (pkgId)")
    
    def createTablesOther(self,db):
        """Create the required tables for other.xml.gz metadata in the sqlite 
           database"""
        cur = db.cursor()
        self.createDbInfo(cur)
        # This table is needed to match pkgKeys to pkgIds
        cur.execute("""CREATE TABLE packages(
            pkgKey INTEGER PRIMARY KEY,
            pkgId TEXT)
        """)
        cur.execute("""CREATE TABLE changelog(
            pkgKey INTEGER,
            author TEXT,
            date TEXT,
            changelog TEXT)
        """)
        cur.execute("CREATE INDEX keychange ON changelog (pkgKey)")
        cur.execute("CREATE INDEX pkgId ON packages (pkgId)")
        
    def createTablesPrimary(self,db):
        """Create the required tables for primary metadata in the sqlite 
           database"""

        cur = db.cursor()
        self.createDbInfo(cur)
        # The packages table contains most of the information in primary.xml.gz

        q = 'CREATE TABLE packages(\n' \
            'pkgKey INTEGER PRIMARY KEY,\n'

        # One TEXT column per attribute exposed by PackageToDBAdapter
        cols = ['%s TEXT' % col for col in PackageToDBAdapter.COLUMNS]
        q += ',\n'.join(cols) + ')'

        cur.execute(q)

        # Create requires, provides, conflicts and obsoletes tables
        # to store prco data
        for t in ('requires','provides','conflicts','obsoletes'):
            cur.execute("""CREATE TABLE %s (
              name TEXT,
              flags TEXT,
              epoch TEXT,
              version TEXT,
              release TEXT,
              pkgKey TEXT)
            """ % (t))
        # Create the files table to hold all the file information
        cur.execute("""CREATE TABLE files (
            name TEXT,
            type TEXT,
            pkgKey TEXT)
        """)
        # Create indexes for faster searching
        cur.execute("CREATE INDEX packagename ON packages (name)")
        cur.execute("CREATE INDEX providesname ON provides (name)")
        cur.execute("CREATE INDEX packageId ON packages (pkgId)")
        db.commit()
    
    def createDbInfo(self,cur):
        """Create the db_info table, this contains sqlite cache metadata"""
        cur.execute("""CREATE TABLE db_info (
            dbversion TEXT,
            checksum TEXT)
        """)

    def insertHash(self,table,hash,cursor):
        """Insert the key value pairs in hash into a database table

        hash only has to provide keys() and values() in matching order
        (a dict or a PackageToDBAdapter).  Returns cursor.lastrowid."""

        keys = hash.keys()
        values = hash.values()
        # Quote all values by replacing None with NULL and ' by ''
        quoted = []
        for x in values:
            if x is None:
                quoted.append("NULL")
            else:
                try:
                    quoted.append("'%s'" % (x.replace("'","''")))
                except AttributeError:
                    # Not a string (e.g. an integer key), nothing to escape
                    quoted.append("'%s'" % x)
        query = "INSERT INTO %s (%s) VALUES (%s)" % (
            table, ",".join(keys), ",".join(quoted))
        cursor.execute(query.encode('utf8'))
        return cursor.lastrowid
             
    def makeSqliteCacheFile(self, filename, cachetype):
        """Create an initial database in the given filename

        Falls back to an in-memory database when the file can not be
        created.  Raises sqlite.DatabaseError for an unknown cachetype."""

        # If it exists, remove it as we were asked to create a new one
        if (os.path.exists(filename)):
            self.log(3, "Warning: cache already exists, removing old version")
            try:
                os.unlink(filename)
            except OSError:
                pass

        # Try to create the database in filename, or use in memory when
        # this fails
        try:
            # Creating the file ourselves turns "can not write here" into
            # an IOError; the handle itself is not needed afterwards
            probe = open(filename,'w')
            probe.close()
            db = sqlite.connect(filename) 
        except IOError:
            self.log(1, "Warning could not create sqlite cache file, using in memory cache instead")
            db = sqlite.connect(":memory:")

        # The file has been created, now create the tables and indexes
        if (cachetype == 'primary'):
            self.createTablesPrimary(db)
        elif (cachetype == 'filelists'):
            self.createTablesFilelists(db)
        elif (cachetype == 'other'):
            self.createTablesOther(db)
        else:
            raise sqlite.DatabaseError("Sorry don't know how to store %s" % (cachetype))
        return db

    def addPrimary(self, pkgId, package, cur):
        """Add a package to the primary cache"""
        # Store the package info into the packages table
        pkgKey = self.insertHash('packages', PackageToDBAdapter(package), cur)

        # Now store all prco data
        for ptype in package.prco:
            for entry in package.prco[ptype]:
                data = {
                    'pkgKey': pkgKey,
                    'name': entry.get('name'),
                    'flags': entry.get('flags'),
                    'epoch': entry.get('epoch'),
                    'version': entry.get('ver'),
                    'release': entry.get('rel'),
                }
                self.insertHash(ptype,data,cur)
        
        # Now store all file information
        for f in package.files:
            data = {
                'name': f,
                'type': package.files[f],
                'pkgKey': pkgKey,
            }
            self.insertHash('files',data,cur)

    def addFilelists(self, pkgId, package,cur):
        """Add a package to the filelists cache

        The files are grouped per directory; every directory becomes one
        row holding the encoded lists of its file names and file types."""
        pkginfo = {'pkgId': pkgId}
        pkgKey = self.insertHash('packages',pkginfo, cur)
        dirs = {}
        for (path,filetype) in package.files.iteritems():
            (dirname,basename) = os.path.split(path)
            entry = dirs.setdefault(dirname, {'files': [], 'types': []})
            entry['files'].append(basename)
            entry['types'].append(filetype)

        for (dirname,entry) in dirs.items():
            data = {
                'pkgKey': pkgKey,
                'dirname': dirname,
                'filenames': encodefilenamelist(entry['files']),
                'filetypes': encodefiletypelist(entry['types'])
            }
            self.insertHash('filelist',data,cur)

    def addOther(self, pkgId, package,cur):
        """Add a package and its changelog entries to the other cache"""
        pkginfo = {'pkgId': pkgId}
        pkgKey = self.insertHash('packages', pkginfo, cur)
        for entry in package['changelog']:
            data = {
                'pkgKey': pkgKey,
                'author': entry.get('author'),
                'date': entry.get('date'),
                'changelog': entry.get('value'),
            }
            self.insertHash('changelog', data, cur)

    def updateSqliteCache(self, db, parser, checksum, cachetype):
        """Update the sqlite cache by making it fit the packages yielded
        by parser (the metadata that is being read), afterwards update
        the checksum of the database to checksum

        Packages already in the cache are left alone, new ones are added
        and the ones that are gone from the metadata are deleted.  Raises
        sqlite.DatabaseError for an unknown cachetype, before the database
        is touched.  Returns db."""
       
        # Work out what we are dealing with before changing anything
        if (cachetype == 'primary'):
            deltables = ("packages","files","provides","requires", 
                        "conflicts","obsoletes")
        elif (cachetype == 'filelists'):
            deltables = ("packages","filelist")
        elif (cachetype == 'other'):
            deltables = ("packages","changelog")
        else:
            raise sqlite.DatabaseError("Unknown type %s" % (cachetype))

        start = time.time()
        delcount = 0
        newcount = 0

        # We start by removing the old db_info, as it is no longer valid
        cur = db.cursor()
        cur.execute("DELETE FROM db_info") 

        # First create a list of all pkgIds that are in the database now
        cur.execute("SELECT pkgId, pkgKey from packages")
        currentpkgs = {}
        for pkg in cur.fetchall():
            currentpkgs[pkg['pkgId']] = pkg['pkgKey']

        # Add packages that are not in the database yet and build up a list of
        # all pkgids in the current metadata

        all_pkgIds = {}
        for package in parser:

            if self.callback is not None:
                self.callback.progressbar(parser.count, parser.total, self.repoid)

            pkgId = package['pkgId']
            all_pkgIds[pkgId] = 1

            # This package is already in the database, skip it now
            if pkgId in currentpkgs:
                continue

            # This is a new package, lets insert it
            newcount += 1
            if cachetype == 'primary':
                self.addPrimary(pkgId, package, cur)
            elif cachetype == 'filelists':
                self.addFilelists(pkgId, package, cur)
            elif cachetype == 'other':
                self.addOther(pkgId, package, cur)

        # Remove those which are no longer in the metadata
        delpkgs = []
        for (pkgId, pkgKey) in currentpkgs.items():
            if pkgId not in all_pkgIds:
                delcount += 1
                delpkgs.append(str(pkgKey))
        # Nothing went away: do not bother the database with "IN ()"
        if delpkgs:
            delpkgs = "("+",".join(delpkgs)+")"
            for table in deltables:
                cur.execute("DELETE FROM %s where pkgKey in %s" % (table, delpkgs))

        cur.execute("INSERT into db_info (dbversion,checksum) VALUES (%s,%s)",
                (dbversion,checksum))
        db.commit()
        self.log(2, "Added %s new packages, deleted %s old in %.2f seconds" % (
                newcount, delcount, time.time()-start))
        return db

    def log(self, level, msg):
        '''Log to callback (if set)
        '''
        if self.callback:
            self.callback.log(level, msg)

class PackageToDBAdapter:

    '''
    Present a parsed primary entry the way the sqlite "packages" table
    wants to see it.

    Only the attributes named in COLUMNS are exposed, and a column name
    listed in NAME_MAPS is looked up in the wrapped entry under its mapped
    key, so the database schema stays independent of the parser's naming.
    '''

    # database column name -> key used by the wrapped entry
    NAME_MAPS = {
        'rpm_package': 'package',
        'version': 'ver',
        'release': 'rel',
        'rpm_license': 'license',
        'rpm_vendor': 'vendor',
        'rpm_group': 'group',
        'rpm_buildhost': 'buildhost',
        'rpm_sourcerpm': 'sourcerpm',
        'rpm_packager': 'packager',
        }
    
    # columns of the packages table, in insertion order
    COLUMNS = (
            'pkgId',
            'name',
            'arch',
            'version',
            'epoch',
            'release',
            'summary',
            'description',
            'url',
            'time_file',
            'time_build',
            'rpm_license',
            'rpm_vendor',
            'rpm_group',
            'rpm_buildhost',
            'rpm_sourcerpm',
            'rpm_header_start',
            'rpm_header_end',
            'rpm_packager',
            'size_package',
            'size_installed',
            'size_archive',
            'location_href',
            'location_base',
            'checksum_type',
            'checksum_value',
            )

    def __init__(self, package):
        self._pkg = package

    def __getitem__(self, k):
        # Translate the column name first, untranslated names pass through
        source_key = self.NAME_MAPS.get(k, k)
        return self._pkg[source_key]

    def keys(self):
        return self.COLUMNS

    def values(self):
        # Same order as keys()
        return [self[column] for column in self.keys()]



--- NEW FILE sqlitesack.py ---
#!/usr/bin/python -tt

# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Library General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
# Copyright 2005 Duke University 

#
# Implementation of the YumPackageSack class that uses an sqlite backend
#

import os
import os.path
import types
import repos
from packages import YumAvailablePackage
from repomd import mdUtils, mdErrors

# Simple subclass of YumAvailablePackage that can load 'simple headers' from
# the database when they are requested
class YumAvailablePackageSqlite(YumAvailablePackage):
    """YumAvailablePackage that fetches its details from the sqlite sack
    the first time they are asked for.

    pkgdict must carry the attributes 'sack' and 'pkgId'."""

    def __init__(self, pkgdict, repoid):
        YumAvailablePackage.__init__(self,pkgdict,repoid)
        self.sack = pkgdict.sack
        self.pkgId = pkgdict.pkgId
        self.simple['id'] = self.pkgId
        self.changelog = None
    
    def loadChangelog(self):
        """Fetch the changelog from the sack, at most once per package"""
        if not hasattr(self, 'dbusedother'):
            self.dbusedother = 1
            self.changelog = self.sack.getChangelog(self.pkgId)

    def returnSimple(self, varname):
        """Return a simple attribute, importing the package details from
           the sack when varname is not known yet"""
        missing = not self.simple.has_key(varname)
        if missing and not hasattr(self,'dbusedsimple'):
            # Make sure we only try once to get the stuff from the database
            self.dbusedsimple = 1
            details = self.sack.getPackageDetails(self.pkgId)
            self.importFromDict(details,self.simple['repoid'])

        return YumAvailablePackage.returnSimple(self,varname)

    def loadFiles(self):
        """Fetch the file list from the sack, at most once per package"""
        if not hasattr(self,'dbusedfiles'):
            self.dbusedfiles = 1
            self.files = self.sack.getFiles(self.pkgId)

    def returnChangelog(self):
        self.loadChangelog()
        return YumAvailablePackage.returnChangelog(self)
            
    def returnFileEntries(self, ftype='file'):
        self.loadFiles()
        return YumAvailablePackage.returnFileEntries(self,ftype)
    
    def returnFileTypes(self):
        self.loadFiles()
        return YumAvailablePackage.returnFileTypes(self)

    def returnPrco(self, prcotype):
        """Return the prco entries of prcotype; while that list is empty
           the whole prco dict is replaced by what the sack returns"""
        if self.prco[prcotype]:
            return self.prco[prcotype]
        self.prco = self.sack.getPrco(self.pkgId, prcotype)
        return self.prco[prcotype]

class YumSqlitePackageSack(repos.YumPackageSack):
    """ Implementation of a PackageSack that uses sqlite cache instead of fully
    expanded metadata objects to provide information """

    def __init__(self, packageClass):
        """Initialise the base sack and the per-repository lookup tables.

        primarydb, filelistsdb and otherdb map a repoid to the database
        object handed to addDict(); excludes maps a repoid to a dict of
        excluded pkgIds."""
        repos.YumPackageSack.__init__(self,packageClass)
        self.excludes = {}
        self.otherdb = {}
        self.filelistsdb = {}
        self.primarydb = {}
        
    def buildIndexes(self):
        """Do nothing: this sack does not need these indexes."""
        return None

    def _checkIndexes(self, failure='error'):
        return

    # Remove a package
    # Because we don't want to remove a package from the database we just
    # add it to the exclude list
    def delPackage(self, obj):
        """Mark obj as excluded by recording its pkgId in the exclude
           dict of its repository (obj.repoid)."""
        self.excludes[obj.repoid][obj.pkgId] = 1

    def addDict(self, repoid, datatype, dataobj, callback=None):
        """Register the database object of one metadata type for a repo.

        repoid   -- id of the repository the data belongs to
        datatype -- 'metadata', 'filelists' or 'otherdata'
        dataobj  -- database object to use for that repository and type
        callback -- accepted for interface compatibility, not used here

        A second registration of the same datatype for a repoid is
        ignored.  Raises ValueError for any other datatype (this used to
        be a string exception, which Python 2.6 and later reject with a
        TypeError)."""
        if repoid not in self.excludes:
            self.excludes[repoid] = {}
        # NOTE(review): every branch replaces self.added[repoid] with a
        # one element list instead of appending to it -- confirm that the
        # code reading self.added expects that.
        if datatype == 'metadata':
            if repoid in self.primarydb:
                return
            self.added[repoid] = ['primary']
            self.primarydb[repoid] = dataobj
        elif datatype == 'filelists':
            if repoid in self.filelistsdb:
                return
            self.added[repoid] = ['filelists']
            self.filelistsdb[repoid] = dataobj
        elif datatype == 'otherdata':
            if repoid in self.otherdb:
                return
            self.added[repoid] = ['otherdata']
            self.otherdb[repoid] = dataobj
        else:
            # We can not handle this yet...
            raise ValueError("Sorry sqlite does not support %s" % (datatype))
    
    def getChangelog(self,pkgId):
        """Return the changelog of pkgId as a list of
           (date, author, changelog) tuples, collected from every
           registered 'other' database."""
        entries = []
        for cache in self.otherdb.values():
            cur = cache.cursor()
            cur.execute("select changelog.date as date,\
                changelog.author as author,\
                changelog.changelog as changelog from packages,changelog where packages.pkgId = %s and packages.pkgKey = changelog.pkgKey",pkgId)
            for row in cur.fetchall():
                entry = (row['date'], row['author'], row['changelog'])
                entries.append(entry)
        return entries

    def getPrco(self,pkgId, prcotype=None):
        """Return the requires, provides, obsoletes and conflicts of pkgId.

        The result is a dict with the keys 'requires', 'provides',
        'obsoletes' and 'conflicts'; every value is a list of
        (name, flags, (epoch, version, release)) tuples.

        prcotype is accepted for interface compatibility only: all four
        lists are always loaded.  That is what the old code did for every
        prcotype that was not None, and YumAvailablePackageSqlite.returnPrco
        depends on it because it replaces its whole prco dict with this
        result.  With the default prcotype=None the old code built
        { None: [] } and queried a table called None."""
        result = {'requires': [], 'provides': [], 'obsoletes': [], 'conflicts': []}
        tables = ('requires', 'provides', 'obsoletes', 'conflicts')
        for (rep, cache) in self.primarydb.items():
            cur = cache.cursor()
            for prco in tables:
                cur.execute("select %s.name as name, %s.version as version,\
                    %s.release as release, %s.epoch as epoch, %s.flags as flags\
                    from packages,%s\
                    where packages.pkgId = %s and packages.pkgKey = %s.pkgKey", prco, prco, prco, prco, prco, prco, pkgId, prco)
                for ob in cur.fetchall():
                    name = ob['name']
                    version = ob['version']
                    release = ob['release']
                    epoch = ob['epoch']
                    flags = ob['flags']
                    result[prco].append((name, flags, (epoch, version, release)))
        return result

    # Get all files for a certain pkgId from the filelists.xml metadata
    def getFiles(self,pkgId):
        """Return the files of pkgId as a dict mapping a file type to the
           list of paths of that type.

        The first filelists database that has rows for pkgId supplies the
        answer; an empty dict is returned when none has any."""
        for cache in self.filelistsdb.values():
            cur = cache.cursor()
            cur.execute("select filelist.dirname as dirname,\
                filelist.filetypes as filetypes,\
                filelist.filenames as filenames from packages,filelist\
                where packages.pkgId = %s and packages.pkgKey = filelist.pkgKey", pkgId)
            rows = cur.fetchall()
            if not rows:
                # Not in this repository, try the next one
                continue
            result = {}
            for ob in rows:
                dirname = ob['dirname']
                filetypes = decodefiletypelist(ob['filetypes'])
                filenames = decodefilenamelist(ob['filenames'])
                # Both lists are consumed from the end, one pair at a time
                while filenames:
                    basename = filenames.pop()
                    if dirname:
                        filename = dirname+'/'+basename
                    else:
                        filename = basename
                    filetype = filetypes.pop()
                    result.setdefault(filetype,[]).append(filename)
            return result
        return {}
            
    # Search packages that either provide something containing name
    # or provide a file containing name 
    def searchAll(self, name, query_type='like'):
        """Return the packages that provide something containing name or
           own a file whose path contains name.

        name       -- substring to look for, must not contain '%'
        query_type -- accepted for interface compatibility, not used here

        Excluded packages are skipped and every pkgId is reported once.
        Returns the values of a dict of package objects built with
        self.pc()."""
        
        # This should never be called with a name containing a %
        assert(name.find('%') == -1)
        result = {}
        (dirname,filename) = os.path.split(name)
        # Everything that ends up inside an SQL string literal gets its
        # quotes doubled (dirname and filename used to go in unescaped)
        quotename = name.replace("'","''")
        quotedir = dirname.replace("'","''")
        quotefile = filename.replace("'","''")
        
        # check provides
        for (rep,cache) in self.primarydb.items():
            cur = cache.cursor()
            cur.execute("select DISTINCT packages.pkgId as pkgId from provides,packages where provides.name LIKE '%%%s%%' AND provides.pkgKey = packages.pkgKey" % quotename)
            for ob in cur.fetchall():
                pkgid = ob['pkgId']
                if pkgid in self.excludes[rep]:
                    continue

                if pkgid not in result:
                    pkg = self.getPackageDetails(pkgid)
                    result[pkgid] = (self.pc(pkg,rep))
        
        # The filelists queries are the same for every repository, so build
        # them once.  (They used to be rebuilt per repository from a
        # 'filename' that the culling loop below had overwritten.)
        querystrings = []
        # dirnames
        # just the dirname
        if dirname != '':
            querystrings.append("select packages.pkgId as pkgId,\
                filelist.dirname as dirname,\
                filelist.filetypes as filetypes,\
                filelist.filenames as filenames \
                from packages,filelist where \
                filelist.dirname LIKE '%%%s%%' \
                AND (filelist.pkgKey = packages.pkgKey)" % (quotedir))

        # look at full quotename
        querystrings.append("select packages.pkgId as pkgId,\
                filelist.dirname as dirname,\
                filelist.filetypes as filetypes,\
                filelist.filenames as filenames \
                from packages,filelist where \
                filelist.dirname LIKE '%%%s%%' \
                AND (filelist.pkgKey = packages.pkgKey)" % (quotename))

        # filenames
        querystrings.append("select packages.pkgId as pkgId,\
                filelist.dirname as dirname,\
                filelist.filetypes as filetypes,\
                filelist.filenames as filenames \
                from packages,filelist where \
                filelist.filenames LIKE '%%%s%%'\
                AND (filelist.pkgKey = packages.pkgKey)" % (quotefile))

        # check filelists/dirlists
        for (rep,cache) in self.filelistsdb.items():
            for querystring in querystrings:
                cur = cache.cursor()
                cur.execute(querystring)

                # cull the results for false positives
                for ob in cur.fetchall():
                    pkgid = ob['pkgId']
                    if pkgid in self.excludes[rep]:
                        continue

                    # Check if it is an actual match
                    # The query above can give false positives, when
                    # a package provides /foo/aaabar it will also match /foo/bar
                    real = False
                    for candidate in decodefilenamelist(ob['filenames']):
                        if (ob['dirname'] + '/' + candidate).find(name) != -1:
                            real = True
                            break
                    if (not real):
                        continue

                    if pkgid not in result:
                        pkg = self.getPackageDetails(pkgid)
                        result[pkgid] = (self.pc(pkg,rep))

        return result.values()
    
    def returnObsoletes(self):
        """Return a dict of all obsoletes declared by non-excluded packages.

        A key is the (name, arch, epoch, version, release) of the
        obsoleting package, the value is the list of
        (name, flags, (epoch, version, release)) it obsoletes."""
        obsoletes = {}
        for (rep,cache) in self.primarydb.items():
            cur = cache.cursor()
            cur.execute("select packages.name as name,\
                packages.pkgId as pkgId,\
                packages.arch as arch, packages.epoch as epoch,\
                packages.release as release, packages.version as version,\
                obsoletes.name as oname, obsoletes.epoch as oepoch,\
                obsoletes.release as orelease, obsoletes.version as oversion,\
                obsoletes.flags as oflags\
                from obsoletes,packages where obsoletes.pkgKey = packages.pkgKey")
            for ob in cur.fetchall():
                # If the package that is causing the obsoletes is excluded
                # continue without processing the obsoletes
                if ob['pkgId'] in self.excludes[rep]:
                    continue
                key = (ob['name'], ob['arch'], ob['epoch'],
                       ob['version'], ob['release'])
                obsoleted = (ob['oname'], ob['oflags'],
                             (ob['oepoch'], ob['oversion'], ob['orelease']))
                if key not in obsoletes:
                    obsoletes[key] = []
                obsoletes[key].append(obsoleted)

        return obsoletes

    def getPackageDetails(self,pkgId):
        """Return the db2class() object for the first packages row that
           matches pkgId in any primary database, or None when there is
           no such row."""
        for cache in self.primarydb.values():
            cur = cache.cursor()
            cur.execute("select * from packages where pkgId = %s",pkgId)
            rows = cur.fetchall()
            if rows:
                return self.db2class(rows[0])
        return None

    def searchPrco(self, name, prcotype):
        """Return the non-excluded packages that have a prcotype entry
           called name (any evr and flag).

        For prcotype "provides" and a name that starts with '/', packages
        owning that file according to the primary or the filelists data
        are added as well.  The result is a list of self.pc() objects."""
        matches = []
        for (rep,cache) in self.primarydb.items():
            cur = cache.cursor()
            cur.execute("select * from %s where name = %s" , (prcotype, name))
            for res in cur.fetchall():
                cur.execute("select * from packages where pkgKey = %s" , (res['pkgKey']))
                for row in cur.fetchall():
                    pkg = self.db2class(row)
                    if pkg.pkgId in self.excludes[rep]:
                        continue

                    # Add this provides to prco otherwise yum doesn't understand
                    # that it matches
                    matched = {
                        'name': res.name,
                        'flags': res.flags,
                        'rel': res.release,
                        'ver': res.version,
                        'epoch': res.epoch,
                    }
                    pkg.prco = {prcotype: [matched]}
                    matches.append(self.pc(pkg,rep))

        # Only a provides that looks like an absolute path can also be
        # satisfied by a file, everything else is done here
        is_file_provide = (prcotype == "provides" and name.find('/') == 0)
        if not is_file_provide:
            return matches

        # If it is a filename, search the primary.xml file info
        for (rep,cache) in self.primarydb.items():
            cur = cache.cursor()
            cur.execute("select * from files where name = %s" , (name))
            for res in cur.fetchall():
                cur.execute("select * from packages where pkgKey = %s" , (res['pkgKey']))
                for row in cur.fetchall():
                    pkg = self.db2class(row)
                    if pkg.pkgId in self.excludes[rep]:
                        continue

                    pkg.files = {name: res['type']}
                    matches.append(self.pc(pkg,rep))

        # If it is a filename, search the files.xml file info
        for (rep,cache) in self.filelistsdb.items():
            cur = cache.cursor()
            (dirname,filename) = os.path.split(name)
            cur.execute("select packages.pkgId as pkgId,\
                filelist.dirname as dirname,\
                filelist.filetypes as filetypes,\
                filelist.filenames as filenames \
                from filelist,packages where dirname = %s AND filelist.pkgKey = packages.pkgKey" , (dirname))
            for res in cur.fetchall():
                if res['pkgId'] in self.excludes[rep]:
                    continue

                # If it matches the dirname, that doesnt mean it matches
                # the filename, check if it does
                if filename and filename not in res['filenames'].split('/'):
                    continue
                # If it matches we only know the packageId
                pkg = self.getPackageDetails(res['pkgId'])
                matches.append(self.pc(pkg,rep))
        return matches

    def searchProvides(self, name):
        """Return the packages providing name, whatever the evr and flag.

        Delegates to searchPrco() asking for the "provides" relation.
        """
        providers = self.searchPrco(name, "provides")
        return providers
                
    def searchRequires(self, name):
        """Return the packages requiring name, whatever the evr and flag.

        Delegates to searchPrco() asking for the "requires" relation.
        """
        requirers = self.searchPrco(name, "requires")
        return requirers

    def searchObsoletes(self, name):
        """Return the packages obsoleting name, whatever the evr and flag.

        Delegates to searchPrco() asking for the "obsoletes" relation.
        """
        obsoleters = self.searchPrco(name, "obsoletes")
        return obsoleters

    def searchConflicts(self, name):
        """Return the packages conflicting with name, whatever the evr and flag.

        Delegates to searchPrco() asking for the "conflicts" relation.
        """
        conflicters = self.searchPrco(name, "conflicts")
        return conflicters

    # TODO this seems a bit ugly and hackish
    def db2class(self, db, nevra_only=False):
        """Turn a row of the packages table into a bare package-like object.

        db is read by attribute (db.name, db.pkgId, ...) except for the
        description column, which is read by key (db['description']).

        The returned object always carries nevra, sack (this object) and
        pkgId.  Unless nevra_only is true it additionally gets the hdrange,
        location, checksum, time, size and info dictionaries, filled from
        the remaining columns of the row.
        """
        class tmpObject:
            pass

        pkg = tmpObject()
        pkg.nevra = (db.name, db.epoch, db.version, db.release, db.arch)
        pkg.sack = self
        pkg.pkgId = db.pkgId
        if nevra_only:
            # Caller only wants the identity of the package.
            return pkg

        pkg.hdrange = dict(start=db.rpm_header_start, end=db.rpm_header_end)
        pkg.location = dict(href=db.location_href, value='',
                            base=db.location_base)
        pkg.checksum = dict(pkgid='YES', type=db.checksum_type,
                            value=db.checksum_value)
        pkg.time = dict(build=db.time_build, file=db.time_file)
        pkg.size = dict(package=db.size_package, archive=db.size_archive,
                        installed=db.size_installed)
        pkg.info = dict(summary=db.summary, description=db['description'],
                        packager=db.rpm_packager, group=db.rpm_group,
                        buildhost=db.rpm_buildhost,
                        sourcerpm=db.rpm_sourcerpm,
                        url=db.url, vendor=db.rpm_vendor,
                        license=db.rpm_license)
        return pkg

    def simplePkgList(self, repoid=None):
        """Return a list of (n, a, e, v, r) tuples for the known packages.

        repoid -- when given, only the repository with that id is listed;
                  when None every repository in self.primarydb is listed.

        Packages whose pkgId is present in self.excludes[<repo>] are left
        out.
        """
        simplelist = []
        for (rep, cache) in self.primarydb.items():
            if repoid is not None and repoid != rep:
                continue
            cur = cache.cursor()
            cur.execute("select pkgId,name,epoch,version,release,arch from packages")
            for pkg in cur.fetchall():
                # 'in' replaces the deprecated dict.has_key()
                if pkg.pkgId in self.excludes[rep]:
                    continue
                simplelist.append((pkg.name, pkg.arch, pkg.epoch,
                                   pkg.version, pkg.release))

        return simplelist

    def returnNewestByNameArch(self, naTup=None):
        """Return the newest package(s) matching a (name, arch) tuple.

        With a naTup the candidates are read straight from the packages
        table of every repository (skipping excluded pkgIds) and handed to
        mdUtils.newestInList().  Without one the request is passed on to
        repos.YumPackageSack.returnNewestByNameArch().

        Raises mdErrors.PackageSackError when naTup is given and nothing
        matches it.
        """
        if not naTup:
            # TODO process excludes here
            return repos.YumPackageSack.returnNewestByNameArch(self, naTup)

        # First find all packages that fulfill naTup
        allpkg = []
        for (rep, cache) in self.primarydb.items():
            cur = cache.cursor()
            cur.execute("select pkgId,name,epoch,version,release,arch from packages where name=%s and arch=%s", naTup)
            for x in cur.fetchall():
                if x.pkgId in self.excludes[rep]:
                    continue
                allpkg.append(self.pc(self.db2class(x, True), rep))

        # if we've got zilch then raise
        if not allpkg:
            raise mdErrors.PackageSackError('No Package Matching %s.%s' % naTup)
        return mdUtils.newestInList(allpkg)

    def returnNewestByName(self, name=None):
        """Return the newest package(s) called name.

        With a name the candidates are read straight from the packages
        table of every repository (skipping excluded pkgIds) and handed to
        mdUtils.newestInList().  Without one the request is passed on to
        repos.YumPackageSack.returnNewestByName().

        Raises mdErrors.PackageSackError when name is given and nothing
        matches it.
        """
        if not name:
            return repos.YumPackageSack.returnNewestByName(self, name)

        # First find all packages that fulfill name
        allpkg = []
        for (rep, cache) in self.primarydb.items():
            cur = cache.cursor()
            cur.execute("select pkgId,name,epoch,version,release,arch from packages where name=%s", name)
            for x in cur.fetchall():
                if x.pkgId in self.excludes[rep]:
                    continue
                allpkg.append(self.pc(self.db2class(x, True), rep))

        # if we've got zilch then raise
        if not allpkg:
            raise mdErrors.PackageSackError('No Package Matching %s' % name)
        return mdUtils.newestInList(allpkg)

    def returnPackages(self, repoid=None):
        """Return a list of packages that only carry nevra information.

        repoid -- when given, only the repository with that id is listed;
                  when None every repository in self.primarydb is listed.

        Each row is converted with db2class(row, True) and wrapped with
        self.pc(); rows whose pkgId is in self.excludes[<repo>] are skipped.
        """
        returnList = []
        for (rep, cache) in self.primarydb.items():
            if repoid is not None and repoid != rep:
                continue
            cur = cache.cursor()
            cur.execute("select pkgId,name,epoch,version,release,arch from packages")
            for x in cur.fetchall():
                # 'in' replaces the deprecated dict.has_key()
                if x.pkgId in self.excludes[rep]:
                    continue
                returnList.append(self.pc(self.db2class(x, True), rep))
        return returnList

    def searchNevra(self, name=None, epoch=None, ver=None, rel=None, arch=None):
        """Return the list of package objects matching the nevra requested.

        Only the arguments that are set (truthy) take part in the match;
        when none is set an empty list is returned without touching the
        databases.  Packages whose pkgId is in self.excludes[<repo>] are
        skipped.

        The values are handed to the cursor as bound parameters (the same
        '%s' placeholder style the other queries of this class use) instead
        of being pasted into the SQL text, so a value containing a quote, or
        one that happens to spell a column name, can no longer break or
        change the query.
        """
        returnList = []

        # Collect one "column = %s" condition per argument actually given.
        conditions = []
        values = []
        for (col, var) in [('name', name), ('epoch', epoch), ('version', ver),
                           ('arch', arch), ('release', rel)]:
            if var:
                conditions.append('%s = %%s' % col)
                values.append(var)

        # Nothing to search on: do not run an unbounded query.
        if not conditions:
            return returnList

        q = "select * from packages WHERE " + ' AND '.join(conditions)
        params = tuple(values)

        # Search all repositories
        for (rep, cache) in self.primarydb.items():
            cur = cache.cursor()
            cur.execute(q, params)
            for x in cur.fetchall():
                if x.pkgId in self.excludes[rep]:
                    continue
                returnList.append(self.pc(self.db2class(x), rep))
        return returnList
    
    def excludeArchs(self, archlist):
        """Exclude incompatible arches - archlist is a list of compat arches.

        Every package whose arch is not listed in archlist is converted with
        db2class(), wrapped with self.pc() and handed to self.delPackage().

        The arch names are passed to the cursor as bound '%s' parameters
        rather than pasted into the SQL text.  An empty archlist used to
        produce a query ending in a dangling WHERE (an SQL error); it now
        selects every package, since no arch is compatible in that case.
        """
        querystring = "select * from packages"
        params = tuple(archlist)
        if params:
            conditions = ['arch != %s'] * len(params)
            querystring = querystring + ' WHERE ' + ' AND '.join(conditions)

        for (rep, cache) in self.primarydb.items():
            cur = cache.cursor()
            if params:
                cur.execute(querystring, params)
            else:
                cur.execute(querystring)
            for x in cur.fetchall():
                obj = self.pc(self.db2class(x), rep)
                self.delPackage(obj)

# Simple helper functions

# Return a string representing filenamelist (filenames can not contain /)
def encodefilenamelist(filenamelist):
    """Pack a list of file names into one '/'-separated string."""
    separator = '/'
    encoded = separator.join(filenamelist)
    return encoded

# Return a list representing filestring (filenames can not contain /)
def decodefilenamelist(filenamestring):
    """Unpack a '/'-separated string into the list of file names."""
    separator = '/'
    names = filenamestring.split(separator)
    return names

# Return a string representing filetypeslist
# filetypes should be file, dir or ghost
def encodefiletypelist(filetypelist):
    """Pack a list of file types into a string of one-letter codes.

    'file' becomes 'f', 'dir' becomes 'd' and 'ghost' becomes 'g'; any other
    entry raises KeyError.
    """
    codes = {'file': 'f', 'dir': 'd', 'ghost': 'g'}
    return ''.join([codes[filetype] for filetype in filetypelist])

# Return a list representing filetypestring
# filetypes should be file, dir or ghost
def decodefiletypelist(filetypestring):
    """Unpack a string of one-letter codes into the list of file types.

    'f' becomes 'file', 'd' becomes 'dir' and 'g' becomes 'ghost'; any other
    character raises KeyError.
    """
    names = {'f': 'file', 'd': 'dir', 'g': 'ghost'}
    decoded = []
    for code in filetypestring:
        decoded.append(names[code])
    return decoded



--- NEW FILE transactioninfo.py ---
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Library General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
# Copyright 2005 Duke University
# Written by Seth Vidal

# TODOS: make all the package relationships deal with package objects
# search by package object for TransactionData, etc.
# provide a real TransactionData.remove(txmbr) method. It should
# remove the given txmbr and iterate to remove all those in dependent
# relationships with the given txmbr.

from constants import *

class TransactionData:
    """Data Structure designed to hold information on a yum Transaction Set"""
    def __init__(self):
        self.flags = []
        self.vsflags = []
        self.probFilterFlags = []
        self.root = '/'
        self.pkgdict = {} # key = pkgtup, val = list of TransactionMember obj
        self.debug = 0
        self.changed = False

    def __len__(self):
        """Number of distinct package tuples in the transaction."""
        return len(self.pkgdict)

    def __iter__(self):
        """Iterate over every transaction member."""
        return iter(self.getMembers())

    def debugprint(self, msg):
        """Print msg when self.debug is set."""
        if self.debug:
            # Parenthesised so the line is valid as a statement and as a call.
            print(msg)

    def getMembers(self, pkgtup=None):
        """takes an optional package tuple and returns all transaction members
           matching, no pkgtup means it returns all transaction members"""

        if pkgtup is None:
            returnlist = []
            for members in self.pkgdict.values():
                returnlist.extend(members)
            return returnlist

        if pkgtup in self.pkgdict:
            return self.pkgdict[pkgtup]
        return []

    def getMode(self, name=None, arch=None, epoch=None, ver=None, rel=None):
        """returns the mode of the first match from the transaction set,
           otherwise, returns None"""

        txmbrs = self.matchNaevr(name=name, arch=arch, epoch=epoch, ver=ver, rel=rel)
        if txmbrs:
            return txmbrs[0].ts_state
        return None

    def matchNaevr(self, name=None, arch=None, epoch=None, ver=None, rel=None):
        """returns the list of transaction members matching the args above

           An argument left at None matches anything."""
        wanted = (name, arch, epoch, ver, rel)
        returnmembers = []

        for pkgtup in self.pkgdict.keys():
            (n, a, e, v, r) = pkgtup
            for (want, have) in zip(wanted, (n, a, e, v, r)):
                if want is not None and want != have:
                    break
            else:
                # no field disagreed with the request
                returnmembers.extend(self.pkgdict[pkgtup])

        return returnmembers

    def add(self, txmember):
        """add a package to the transaction

           A member whose pkgtup is already present with the same ts_state
           is silently skipped."""
        pkgtup = txmember.pkgtup
        if pkgtup in self.pkgdict:
            self.debugprint("Package: %s.%s - %s:%s-%s already in ts" % txmember.pkgtup)
            for member in self.pkgdict[pkgtup]:
                if member.ts_state == txmember.ts_state:
                    self.debugprint("Package in same mode, skipping.")
                    return
        else:
            self.pkgdict[pkgtup] = []
        self.pkgdict[pkgtup].append(txmember)
        self.changed = True

    def remove(self, pkgtup):
        """remove a package from the transaction

           Resets po.state of every removed member to None."""
        if pkgtup not in self.pkgdict:
            self.debugprint("Package: %s not in ts" %(pkgtup,))
            return
        for txmbr in self.pkgdict[pkgtup]:
            txmbr.po.state = None

        del self.pkgdict[pkgtup]
        self.changed = True

    def exists(self, pkgtup):
        """tells if the pkg is in the class (returns 1 or 0)"""
        if pkgtup in self.pkgdict:
            if len(self.pkgdict[pkgtup]) != 0:
                return 1

        return 0

    def isObsoleted(self, pkgtup):
        """true if the pkgtup is marked to be obsoleted"""
        if self.exists(pkgtup):
            for txmbr in self.getMembers(pkgtup=pkgtup):
                if txmbr.output_state == TS_OBSOLETED:
                    return True

        return False

    def makelists(self):
        """fills in lists of transaction Member objects based on mode:
           updated, installed, removed, obsoleted, depupdated, depinstalled,
           depremoved, plus the instgroups and removedgroups group lists.

           The lists are stored as attributes of self; nothing is returned."""

        self.instgroups = []
        self.removedgroups = []
        self.removed = []
        self.installed = []
        self.updated = []
        self.obsoleted = []
        self.depremoved = []
        self.depinstalled = []
        self.depupdated = []

        for txmbr in self.getMembers():
            if txmbr.output_state == TS_UPDATE:
                if txmbr.isDep:
                    self.depupdated.append(txmbr)
                else:
                    self.updated.append(txmbr)

            elif txmbr.output_state == TS_INSTALL or txmbr.output_state == TS_TRUEINSTALL:
                for g in txmbr.groups:
                    if g not in self.instgroups:
                        self.instgroups.append(g)
                if txmbr.isDep:
                    self.depinstalled.append(txmbr)
                else:
                    self.installed.append(txmbr)

            elif txmbr.output_state == TS_ERASE:
                for g in txmbr.groups:
                    # de-duplicate against the list being filled (this used
                    # to test instgroups, which let duplicates through)
                    if g not in self.removedgroups:
                        self.removedgroups.append(g)
                if txmbr.isDep:
                    self.depremoved.append(txmbr)
                else:
                    self.removed.append(txmbr)

            elif txmbr.output_state == TS_OBSOLETED:
                self.obsoleted.append(txmbr)

            elif txmbr.output_state == TS_OBSOLETING:
                self.installed.append(txmbr)

        # Sort once, after every member has been filed (this used to run
        # on each iteration of the loop above).
        self.updated.sort()
        self.installed.sort()
        self.removed.sort()
        self.obsoleted.sort()
        self.depupdated.sort()
        self.depinstalled.sort()
        self.depremoved.sort()
        self.instgroups.sort()
        self.removedgroups.sort()

    def addInstall(self, po):
        """adds a package as an install but in mode 'u' to the ts
           takes a packages object and returns a TransactionMember Object"""

        txmbr = TransactionMember(po)
        txmbr.current_state = TS_AVAILABLE
        txmbr.output_state = TS_INSTALL
        txmbr.po.state = TS_INSTALL
        txmbr.ts_state = 'u'
        txmbr.reason = 'user'
        self.add(txmbr)
        return txmbr

    def addTrueInstall(self, po):
        """adds a package as an install
           takes a packages object and returns a TransactionMember Object"""

        txmbr = TransactionMember(po)
        txmbr.current_state = TS_AVAILABLE
        txmbr.output_state = TS_TRUEINSTALL
        txmbr.po.state = TS_INSTALL
        txmbr.ts_state = 'i'
        txmbr.reason = 'user'
        self.add(txmbr)
        return txmbr

    def addErase(self, po):
        """adds a package as an erasure
           takes a packages object and returns a TransactionMember Object"""

        txmbr = TransactionMember(po)
        txmbr.current_state = TS_INSTALL
        txmbr.output_state = TS_ERASE
        # NOTE(review): po.state is set to TS_INSTALL, not TS_ERASE - kept
        # as found; confirm this is intended.
        txmbr.po.state = TS_INSTALL
        txmbr.ts_state = 'e'
        self.add(txmbr)
        return txmbr

    def addUpdate(self, po, oldpo=None):
        """adds a package as an update
           takes a packages object and returns a TransactionMember Object"""

        txmbr = TransactionMember(po)
        txmbr.current_state = TS_AVAILABLE
        txmbr.output_state = TS_UPDATE
        txmbr.po.state = TS_UPDATE
        txmbr.ts_state = 'u'
        if oldpo:
            txmbr.relatedto.append((oldpo.pkgtup, 'updates'))
            txmbr.updates.append(oldpo)
        self.add(txmbr)
        return txmbr

    def addObsoleting(self, po, oldpo):
        """adds a package as an obsolete over another pkg
           takes a packages object and returns a TransactionMember Object"""

        txmbr = TransactionMember(po)
        txmbr.current_state = TS_AVAILABLE
        txmbr.output_state = TS_OBSOLETING
        txmbr.po.state = TS_OBSOLETING
        txmbr.ts_state = 'u'
        # NOTE(review): relatedto receives oldpo itself here, while
        # addUpdate stores oldpo.pkgtup - kept as found.
        txmbr.relatedto.append((oldpo, 'obsoletes'))
        txmbr.obsoletes.append(oldpo)
        self.add(txmbr)
        return txmbr

    def addObsoleted(self, po, obsoleting_po):
        """adds a package as being obsoleted by another pkg
           takes a packages object and returns a TransactionMember Object"""

        txmbr = TransactionMember(po)
        txmbr.current_state = TS_INSTALL
        txmbr.output_state =  TS_OBSOLETED
        txmbr.po.state = TS_OBSOLETED
        txmbr.ts_state = None
        txmbr.relatedto.append((obsoleting_po, 'obsoletedby'))
        txmbr.obsoleted_by.append(obsoleting_po)
        self.add(txmbr)
        return txmbr

class ConditionalTransactionData(TransactionData):
    """A transaction data implementing conditional package addition"""
    def __init__(self):
        # Key: package name to trigger condition
        # Value: list of package objects to add
        self.conditionals = {}
        TransactionData.__init__(self)

    def add(self, txmember):
        """Add txmember, then pull in the packages conditional on its name.

        Each package object registered in self.conditionals under
        txmember.name is added as an install and marked as a dependency
        of txmember's package.
        """
        TransactionData.add(self, txmember)
        # 'in' replaces the deprecated dict.has_key()
        if txmember.name in self.conditionals:
            for po in self.conditionals[txmember.name]:
                condtxmbr = self.addInstall(po)
                condtxmbr.setAsDep(po=txmember.po)

class SortableTransactionData(ConditionalTransactionData):
    """A transaction data implementing topological sort on its members"""
    def __init__(self):
        # Cache of sort
        self._sorted = []
        # Current dependency path
        self.path = []
        # List of loops
        self.loops = []
        ConditionalTransactionData.__init__(self)

    def _visit(self, txmbr):
        """Depth-first visit of txmbr and of the members it depends on."""
        self.path.append(txmbr.name)
        txmbr.sortColour = TX_GREY
        for po in txmbr.depends_on:
            vertex = self.getMembers(pkgtup=po.pkgtup)[0]
            if vertex.sortColour == TX_GREY:
                # reached a member that is still being visited: a loop
                self._doLoop(vertex.name)
            if vertex.sortColour == TX_WHITE:
                self._visit(vertex)
        txmbr.sortColour = TX_BLACK
        self._sorted.insert(0, txmbr.pkgtup)

    def _doLoop(self, name):
        """Record the dependency loop closed by name (if longer than 2)."""
        self.path.append(name)
        loop = self.path[self.path.index(self.path[-1]):]
        if len(loop) > 2:
            self.loops.append(loop)

    def add(self, txmember):
        """Add txmember as an unvisited node and drop the cached order."""
        txmember.sortColour = TX_WHITE
        ConditionalTransactionData.add(self, txmember)
        self._sorted = []

    def remove(self, pkgtup):
        """Remove pkgtup and drop the cached order."""
        ConditionalTransactionData.remove(self, pkgtup)
        self._sorted = []

    def sort(self):
        """Return the package tuples in sorted order, computing it if needed.

        The result is cached in self._sorted until add() or remove()
        clears it.
        """
        if self._sorted:
            return self._sorted
        self._sorted = []
        # Start every run from a clean slate: members left black by an
        # earlier sort() would otherwise be skipped after add()/remove()
        # cleared the cache, and the same loops would be recorded again.
        self.loops = []
        for txmbr in self.getMembers():
            txmbr.sortColour = TX_WHITE
        # loop over all members
        for txmbr in self.getMembers():
            if txmbr.sortColour == TX_WHITE:
                self.path = [ ]
                self._visit(txmbr)
        self._sorted.reverse()
        return self._sorted

class TransactionMember:
    """One package taking part in a transaction (to be installed, updated
       or erased), together with its states and its relationships."""

    def __init__(self, po):
        # the package object this member wraps
        self.po = po

        # states
        self.current_state = None # where the package currently is (repo, installed)
        self.ts_state = None      # what state to put it into in the transaction set
        self.output_state = None  # what state to list if printing it
        self.process = None

        # why it is in the transaction set
        self.reason = 'user'
        self.isDep = 0

        # relationships to other packages
        self.relatedto = []       # [(relatedpkgtup, relationship)]
        self.depends_on = []
        self.obsoletes = []
        self.obsoleted_by = []
        self.updates = []
        self.updated_by = []
        self.groups = []          # groups it's in

        # attributes mirrored from the package object
        self._poattr = ['pkgtup', 'repoid', 'name', 'arch', 'epoch', 'version',
                        'release']
        for attrname in self._poattr:
            setattr(self, attrname, getattr(self.po, attrname))

    def setAsDep(self, po=None):
        """Flag this member as a dependency; when po is given also record
           that this member depends on it."""
        self.isDep = 1
        if not po:
            return
        self.relatedto.append((po.pkgtup, 'dependson'))
        self.depends_on.append(po)

    def __cmp__(self, other):
        # members are ordered by package name only
        mine = self.name
        theirs = other.name
        if mine > theirs:
            return 1
        if mine < theirs:
            return -1
        if mine == theirs:
            return 0

    def __hash__(self):
        return hash(self.po.pkgtup)

    def __str__(self):
        fields = (self.name, self.arch, self.epoch,
                  self.version, self.release, self.ts_state)
        return "%s.%s %s-%s-%s - %s" % fields
        
    # This is the tricky part - how do we nicely setup all this data w/o going insane
    # we could make the txmember object be created from a YumPackage base object
    # we still may need to pass in 'groups', 'ts_state', 'output_state', 'reason', 'current_state'
    # and any related packages. A world of fun that will be, you betcha
    
    
    # definitions
    # current and output states are defined in constants
    # relationships are defined in constants
    # ts states are: u, i, e
    


--- NEW FILE update_md.py ---
#!/usr/bin/python -t

# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Library General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
# Copyright 2005 Duke University 


import sys
from cElementTree import iterparse
import exceptions



class UpdateNoticeException(Exception):
    """Raised when update notice metadata is missing or unusable."""
    # Based on the builtin Exception, which is the very class the
    # Python 2 only 'exceptions' module exposes as exceptions.Exception.
    pass
    
    
class UpdateNotice(object):
    """One update notice, optionally filled in from an <update> element.

    Attributes: update_id, release_date, status, classification,
    distribution, title, description, plus the cves, urls and packages
    lists.  Each entry of packages is a dict with the keys pkgid, name,
    arch, ver, rel and epoch.
    """
    def __init__(self, elem=None):
        self.cves = []
        self.urls = []
        self.packages = []
        self.description = ''
        self.update_id = None
        self.distribution = None
        self.release_date = None
        self.status = None
        self.classification = None
        self.title = ''

        # Compare against None: an element without children is false in a
        # boolean test, which used to skip parsing such notices entirely.
        if elem is not None:
            self.parse(elem)

    def __str__(self):
        """Return the multi-line, human readable form of the notice."""
        head = """
Class: %s
Status: %s
Distribution: %s
ID: %s
Release date: %s
Description: 
%s
        """ % (self.classification, self.status, self.distribution, 
               self.update_id, self.release_date, self.description)

        parts = [head]

        if self.urls:
            parts.append('\nRelated URLS:\n')
            parts.extend(['  %s\n' % url for url in self.urls])

        if self.cves:
            parts.append('\nResolves CVES:\n')
            parts.extend(['  %s\n' % cve for cve in self.cves])

        if self.packages:
            parts.append('\nPackages: \n')
            for pkg in self.packages:
                parts.append('%s-%s-%s.%s\t\t%s\n' % (pkg['name'], pkg['ver'],
                                                      pkg['rel'], pkg['arch'],
                                                      pkg['pkgid']))

        return ''.join(parts)

    def parse(self, elem):
        """Fill in the notice from elem and its direct children.

        The id, release_date, status and class attributes are only read
        when elem is an <update> element; a missing class defaults to
        'update'.

        Raises UpdateNoticeException when an <update> element has no id.
        """
        if elem.tag == 'update':
            update_id = elem.attrib.get('id')
            if not update_id:
                raise UpdateNoticeException("update element has no id attribute")
            self.update_id = update_id

            self.release_date = elem.attrib.get('release_date')
            self.status = elem.attrib.get('status')
            self.classification = elem.attrib.get('class') or 'update'

        for child in elem:

            if child.tag == 'cve':
                self.cves.append(child.text)

            elif child.tag == 'url':
                self.urls.append(child.text)

            elif child.tag == 'description':
                self.description = child.text

            elif child.tag == 'distribution':
                self.distribution = child.text

            elif child.tag == 'title':
                self.title = child.text

            elif child.tag == 'package':
                self.parse_package(child)

    def parse_package(self, elem):
        """Append the package described by a <package> element.

        ver, rel and epoch come from a <version> child and stay None when
        there is none, so that __str__ can always format the entry.
        """
        pkg = {}
        pkg['pkgid'] = elem.attrib.get('pkgid')
        pkg['name']  = elem.attrib.get('name')
        pkg['arch'] = elem.attrib.get('arch')
        pkg['ver'] = None
        pkg['rel'] = None
        pkg['epoch'] = None
        for child in elem:
            if child.tag == 'version':
                pkg['ver'] = child.attrib.get('ver')
                pkg['rel'] = child.attrib.get('rel')
                pkg['epoch'] = child.attrib.get('epoch')

        self.packages.append(pkg)




class UpdateMetadata(object):
    """Collection of UpdateNotice objects, keyed by their update id."""
    def __init__(self):
        self._notices = {}

    def get_notices(self):
        """Return the stored notices as a list."""
        return list(self._notices.values())

    notices = property(get_notices)

    def add(self, srcfile):
        """Read the <update> elements of srcfile into the collection.

        srcfile is either a file name or an already opened file object.
        The first notice seen for a given update id wins; later ones with
        the same id are ignored.

        Raises UpdateNoticeException when srcfile is empty or None.
        """
        if not srcfile:
            raise UpdateNoticeException("no update metadata source given")

        opened_here = isinstance(srcfile, str)
        if opened_here:
            infile = open(srcfile, 'rt')
        else:   # srcfile is a file object
            infile = srcfile

        try:
            for event, elem in iterparse(infile):
                if elem.tag == 'update':
                    un = UpdateNotice(elem)
                    if un.update_id not in self._notices:
                        self._notices[un.update_id] = un
        finally:
            # Only close what this method opened; a file object handed in
            # by the caller stays the caller's to close.
            if opened_here:
                infile.close()

    def dump(self):
        """Print every stored notice."""
        for notice in self.notices:
            print(notice)


def main():
    """Load every update metadata file named on the command line and
    print the notices found.

    Exits with status 1 when no file is given or when one cannot be read.
    """
    if len(sys.argv) < 2:
        # sys.argv[1] below used to raise an uncaught IndexError here
        sys.stderr.write("usage: update_md.py FILE [FILE ...]\n")
        sys.exit(1)

    print(sys.argv[1])
    um = UpdateMetadata()
    for srcfile in sys.argv[1:]:
        try:
            um.add(srcfile)
        except IOError:
            # name the file that failed rather than the whole argument list
            sys.stderr.write("update_md.py: No such file:'%s'\n" % srcfile)
            sys.exit(1)

    um.dump()

if __name__ == '__main__':
    main()




More information about the scm-commits mailing list