[PATCH 1/4] Rewrite the sources module

Mathieu Bridon bochecha at fedoraproject.org
Tue Feb 17 03:33:43 UTC 2015


From: Mathieu Bridon <bochecha at daitauha.fr>

The API was awkward to use and to test, and hard to extend to support the
new file format we are targeting.

The new code is fully tested, and offers a simpler object API which
should be easier to work with.
---
 src/pyrpkg/__init__.py |  45 ++++-----
 src/pyrpkg/sources.py  |  92 +++++++++++-------
 test/test_sources.py   | 255 +++++++++++++++++++++++++++++++------------------
 3 files changed, 234 insertions(+), 158 deletions(-)

diff --git a/src/pyrpkg/__init__.py b/src/pyrpkg/__init__.py
index 5c4dbc5..7ecc8b4 100644
--- a/src/pyrpkg/__init__.py
+++ b/src/pyrpkg/__init__.py
@@ -38,7 +38,7 @@ try:
 except ImportError:
     pass
 
-import pyrpkg.sources
+from pyrpkg.sources import SourcesFile
 
 
 # Define our own error class
@@ -1568,16 +1568,20 @@ class Commands(object):
         # Default to putting the files where the module is
         if not outdir:
             outdir = self.path
-        for (csum, file) in self._read_sources():
+
+        sourcesf = SourcesFile(self.sources_filename)
+
+        for entry in sourcesf.entries:
             # See if we already have a valid copy downloaded
-            outfile = os.path.join(outdir, file)
+            outfile = os.path.join(outdir, entry.file)
             if os.path.exists(outfile):
-                if self._verify_file(outfile, csum, self.lookasidehash):
+                if self._verify_file(outfile, entry.hash, entry.hashtype):
                     continue
-            self.log.info("Downloading %s" % (file))
+            self.log.info("Downloading %s" % (entry.file))
             url = '%s/%s/%s/%s/%s' % (self.lookaside, self.module_name,
-                                      file.replace(' ', '%20'),
-                                      csum, file.replace(' ', '%20'))
+                                      entry.file.replace(' ', '%20'),
+                                      entry.hash,
+                                      entry.file.replace(' ', '%20'))
             # These options came from Makefile.common.
             # Probably need to support wget as well
             command = ['curl', '-H', 'Pragma:', '-o', outfile, '-R', '-S',
@@ -1586,8 +1590,8 @@ class Commands(object):
                 command.append('-s')
             command.append(url)
             self._run_command(command)
-            if not self._verify_file(outfile, csum, self.lookasidehash):
-                raise rpkgError('%s failed checksum' % file)
+            if not self._verify_file(outfile, entry.hash, entry.hashtype):
+                raise rpkgError('%s failed checksum' % entry.file)
         return
 
     def switch_branch(self, branch, fetch=True):
@@ -2226,18 +2230,6 @@ class Commands(object):
                            config_dir)
             self._cleanup_tmp_dir(config_dir)
 
-    def _read_sources(self):
-        """Parses 'sources' file"""
-        with open(self.sources_filename, 'rb') as sources_fp:
-            reader = pyrpkg.sources.reader(sources_fp)
-            return [a for a in reader]
-
-    def _write_sources(self, rows):
-        """Writes 'sources' file"""
-        with open(self.sources_filename, 'wb') as sources_fp:
-            writer = pyrpkg.sources.writer(sources_fp)
-            writer.writerows(rows)
-
     def upload(self, files, replace=False):
         """Upload source file(s) in the lookaside cache
 
@@ -2247,11 +2239,7 @@ class Commands(object):
         oldpath = os.getcwd()
         os.chdir(self.path)
 
-        # Decide to overwrite or append to sources:
-        if replace or not os.path.exists(self.sources_filename):
-            sources = []
-        else:
-            sources = self._read_sources()
+        sourcesf = SourcesFile(self.sources_filename, replace=replace)
 
         # Will add new sources to .gitignore if they are not already there.
         gitignore = GitIgnore(os.path.join(self.path, '.gitignore'))
@@ -2262,8 +2250,7 @@ class Commands(object):
             file_hash = self._hash_file(f, self.lookasidehash)
             self.log.info("Uploading: %s  %s" % (file_hash, f))
             file_basename = os.path.basename(f)
-            if (file_hash, file_basename) not in sources:
-                sources.append((file_hash, file_basename))
+            sourcesf.add_entry(self.lookasidehash, file_basename, file_hash)
 
             # Add this file to .gitignore if it's not already there:
             if not gitignore.match(file_basename):
@@ -2278,7 +2265,7 @@ class Commands(object):
                 self._do_curl(file_hash, f)
                 uploaded.append(file_basename)
 
-        self._write_sources(sources)
+        sourcesf.write()
 
         # Write .gitignore with the new sources if anything changed:
         gitignore.write()
diff --git a/src/pyrpkg/sources.py b/src/pyrpkg/sources.py
index c4263ad..e987fa6 100644
--- a/src/pyrpkg/sources.py
+++ b/src/pyrpkg/sources.py
@@ -1,55 +1,75 @@
 """
-Our so-called sources file is simple text-based line-oriented file format. Each
-line represents one file and has two fields: file hash and base name of the
-file. Field separator is two spaces and Unix end-of-lines.
+Our so-called sources file is a simple text-based line-oriented file format.
 
-This sources module implements API similar to csv module from standard library
-to read and write data in sources file format.
+Each line represents one source file and is in the same format as the output
+of commands like `md5sum filename`:
+
+    hash  filename
+
+This module implements a simple API to read these files, parse lines into
+entries, and write these entries to the file in the proper format.
 """
 
 
-class Reader(object):
-    def __init__(self, sourcesfile):
-        self.sourcesfile = sourcesfile
-        self._sourcesiter = None
+import os
+import re
 
-    def __iter__(self):
-        for entry in self.sourcesfile:
-            yield _parse_line(entry)
 
+class MalformedLineError(Exception):
+    pass
 
-class Writer(object):
-    def __init__(self, sourcesfile):
+
+class SourcesFile(object):
+    def __init__(self, sourcesfile, replace=False):
         self.sourcesfile = sourcesfile
+        self.entries = []
+
+        if not replace:
+            if not os.path.exists(sourcesfile):
+                return
+
+            with open(sourcesfile) as f:
+                for line in f:
+                    entry = self.parse_line(line)
+
+                    if entry and entry not in self.entries:
+                        self.entries.append(entry)
+
+    def parse_line(self, line):
+        stripped = line.strip()
+
+        if not stripped:
+            return
 
-    def writerow(self, row):
-        self.sourcesfile.write("%s\n" % _format_line(row))
+        try:
+            hash, file = stripped.split('  ', 1)
 
-    def writerows(self, rows):
-        for row in rows:
-            self.writerow(row)
+        except ValueError:
+            raise MalformedLineError(line)
 
+        return SourceFileEntry('md5', file, hash)
 
-def reader(sourcesfile):
-    return Reader(sourcesfile)
+    def add_entry(self, hashtype, file, hash):
+        entry = SourceFileEntry(hashtype, file, hash)
 
+        if entry not in self.entries:
+            self.entries.append(entry)
 
-def writer(sourcesfile):
-    return Writer(sourcesfile)
+    def write(self):
+        with open(self.sourcesfile, 'w') as f:
+            for entry in self.entries:
+                f.write(str(entry))
 
 
-def _parse_line(line):
-    stripped_line = line.strip()
-    if not stripped_line:
-        return []
-    entries = stripped_line.split('  ', 1)
-    if len(entries) != 2:
-        raise ValueError("Malformed line: %r." % line)
-    return entries
+class SourceFileEntry(object):
+    def __init__(self, hashtype, file, hash):
+            self.hashtype = hashtype.lower()
+            self.hash = hash
+            self.file = file
 
+    def __str__(self):
+        return '%s  %s\n' % (self.hash, self.file)
 
-def _format_line(entry):
-    if len(entry) != 0 and len(entry) != 2:
-        raise ValueError("Incorrect number of fields for entry: %r."
-                         % (entry,))
-    return "  ".join(entry)
+    def __eq__(self, other):
+        return ((self.hashtype, self.hash, self.file) ==
+                (other.hashtype, other.hash, other.file))
diff --git a/test/test_sources.py b/test/test_sources.py
index 997ab83..2b281e7 100644
--- a/test/test_sources.py
+++ b/test/test_sources.py
@@ -1,6 +1,9 @@
 import os
+import random
+import shutil
+import string
 import sys
-import StringIO
+import tempfile
 import unittest
 
 old_path = list(sys.path)
@@ -10,98 +13,164 @@ from pyrpkg import sources
 sys.path = old_path
 
 
-class formatLineTestCase(unittest.TestCase):
-    def test_wrong_number_of_fields(self):
-        WRONG_ENTRIES = [
-            ('foo'),
-            ('foo', 'bar', 'foo'),
-        ]
-        for entry in WRONG_ENTRIES:
-            self.assertRaises(ValueError, sources._format_line, entry)
-
-    def test_empty_entry(self):
-        self.assertEqual('', sources._format_line(()))
-
-    def test_correct_entry(self):
-        CORRECT_ENTRIES = [
-            (['foo', 'bar'], ('foo  bar')),
-        ]
-        for entry, line in CORRECT_ENTRIES:
-            self.assertEqual(line,
-                             sources._format_line(entry))
-
-
-class parseLineTestCase(unittest.TestCase):
-    def test_wrong_number_of_parts(self):
-        WRONG_LINES = [
-            'foo\n',
-            'foo  \n',
-            'foo bar\n',
-        ]
-        for line in WRONG_LINES:
-            self.assertRaises(ValueError, sources._parse_line, line)
-
-    def test_empty_line(self):
-        EMPTY_LINES = [
-            '',
-            '\n',
-            '  \n',
-        ]
-        for line in EMPTY_LINES:
-            self.assertEqual([], sources._parse_line(line))
-
-    def test_correct_line(self):
-        CORRECT_LINES = [
-            ('foo  bar\n', ['foo', 'bar']),
-            ('foo   bar\n', ['foo', ' bar'])
-        ]
-        for line, entry in CORRECT_LINES:
-            self.assertEqual(entry, sources._parse_line(line))
-
-
-class ReaderTestCase(unittest.TestCase):
-    def test_empty_sources(self):
-        EMPTY_SOURCES = [
-            ('', []),
-            ('\n', [[]]),
-            (' \n', [[]]),
-            ('\n\n', [[], []]),
-            (' \n ', [[], []]),
-        ]
-        for buffer, entries in EMPTY_SOURCES:
-            fp = StringIO.StringIO(buffer)
-            reader = sources.Reader(fp)
-            self.assertEqual(entries, [a for a in reader])
-            fp.close()
-
-    def test_correct_sources(self):
-        CORRECT_SOURCES = [
-            ('foo  bar\n', [['foo', 'bar']]),
-            ('foo  bar\nfooo  baaar\n', [['foo', 'bar'],
-                                         ['fooo', 'baaar'],
-                                         ]),
-        ]
-        for buffer, entries in CORRECT_SOURCES:
-            fp = StringIO.StringIO(buffer)
-            reader = sources.Reader(fp)
-            self.assertEqual(entries, [a for a in reader])
-            fp.close()
-
-
-class WriterTestCase(unittest.TestCase):
-    def test_writerows(self):
-        CORRECT_SOURCES = [
-            ([['foo', 'bar']], 'foo  bar\n'),
-            ([['foo', 'bar'],
-              ['fooo', 'baaar'],
-              ], 'foo  bar\nfooo  baaar\n'),
-        ]
-        for entries, buffer in CORRECT_SOURCES:
-            fp = StringIO.StringIO()
-            writer = sources.Writer(fp)
-            writer.writerows(entries)
-            self.assertEqual(fp.getvalue(), buffer)
-            fp.close()
+class SourceFileEntryTestCase(unittest.TestCase):
+    def test_entry(self):
+        e = sources.SourceFileEntry('md5', 'afile', 'ahash')
+        expected = 'ahash  afile\n'
+        self.assertEqual(str(e), expected)
+
+
+class SourcesFileTestCase(unittest.TestCase):
+    def setUp(self):
+        self.workdir = tempfile.mkdtemp(prefix='rpkg-tests.')
+        self.sourcesfile = os.path.join(self.workdir, self._testMethodName)
+
+    def tearDown(self):
+        shutil.rmtree(self.workdir)
+
+    def test_parse_empty_line(self):
+        s = sources.SourcesFile(self.sourcesfile)
+        entry = s.parse_line('')
+        self.assertIsNone(entry)
+
+    def test_parse_eol_line(self):
+        s = sources.SourcesFile(self.sourcesfile)
+        entry = s.parse_line('\n')
+        self.assertIsNone(entry)
+
+    def test_parse_whitespace_line(self):
+        s = sources.SourcesFile(self.sourcesfile)
+        entry = s.parse_line('    \n')
+        self.assertIsNone(entry)
+
+    def test_parse_entry_line(self):
+        s = sources.SourcesFile(self.sourcesfile)
+
+        line = 'ahash  afile\n'
+        entry = s.parse_line(line)
+
+        self.assertTrue(isinstance(entry, sources.SourceFileEntry))
+        self.assertEqual(entry.hashtype, 'md5')
+        self.assertEqual(entry.hash, 'ahash')
+        self.assertEqual(entry.file, 'afile')
+        self.assertEqual(str(entry), line)
+
+    def test_parse_wrong_lines(self):
+        s = sources.SourcesFile(self.sourcesfile)
+
+        lines = ['ahash',
+                 'ahash  ',
+                 'ahash afile',
+                 ]
+
+        for line in lines:
+            with self.assertRaises(sources.MalformedLineError):
+                s.parse_line(line)
+
+    def test_open_new_file(self):
+        s = sources.SourcesFile(self.sourcesfile)
+        self.assertEqual(len(s.entries), 0)
+
+    def test_open_empty_file(self):
+        open(self.sourcesfile, 'w').write('')
+        s = sources.SourcesFile(self.sourcesfile)
+        self.assertEqual(len(s.entries), 0)
+
+    def test_open_existing_file(self):
+        lines = ['ahash  afile\n', 'anotherhash  anotherfile\n']
+
+        with open(self.sourcesfile, 'w') as f:
+            for line in lines:
+                f.write(line)
+
+        s = sources.SourcesFile(self.sourcesfile)
+
+        for i, entry in enumerate(s.entries):
+            self.assertTrue(isinstance(entry, sources.SourceFileEntry))
+            self.assertEqual(str(entry), lines[i])
+
+    def test_open_existing_file_with_wrong_line(self):
+        line = 'some garbage here\n'
+
+        with open(self.sourcesfile, 'w') as f:
+            f.write(line)
+
+        with self.assertRaises(sources.MalformedLineError):
+            return sources.SourcesFile(self.sourcesfile)
+
+    def test_add_entry(self):
+        s = sources.SourcesFile(self.sourcesfile)
+        self.assertEqual(len(s.entries), 0)
+
+        s.add_entry('md5', 'afile', 'ahash')
+        self.assertEqual(len(s.entries), 1)
+        self.assertEqual(str(s.entries[-1]), 'ahash  afile\n')
+
+        s.add_entry('md5', 'anotherfile', 'anotherhash')
+        self.assertEqual(len(s.entries), 2)
+        self.assertEqual(str(s.entries[-1]), 'anotherhash  anotherfile\n')
+
+    def test_add_entry_twice(self):
+        s = sources.SourcesFile(self.sourcesfile)
+        self.assertEqual(len(s.entries), 0)
+
+        s.add_entry('md5', 'afile', 'ahash')
+        self.assertEqual(len(s.entries), 1)
+        self.assertEqual(str(s.entries[-1]), 'ahash  afile\n')
+
+        s.add_entry('md5', 'afile', 'ahash')
+        self.assertEqual(len(s.entries), 1)
+
+    def test_write_new_file(self):
+        s = sources.SourcesFile(self.sourcesfile)
+        self.assertEqual(len(s.entries), 0)
+
+        s.add_entry('md5', 'afile', 'ahash')
+        s.add_entry('md5', 'anotherfile', 'anotherhash')
+        s.write()
+
+        with open(self.sourcesfile) as f:
+             lines = f.readlines()
+
+        self.assertEqual(len(lines), 2)
+        self.assertEqual(lines[0], 'ahash  afile\n')
+        self.assertEqual(lines[1], 'anotherhash  anotherfile\n')
+
+    def test_write_adding_a_line(self):
+        lines = ['ahash  afile\n', 'anotherhash  anotherfile\n']
+
+        with open(self.sourcesfile, 'w') as f:
+            for line in lines:
+                f.write(line)
+
+        s = sources.SourcesFile(self.sourcesfile)
+        s.add_entry('md5', 'thirdfile', 'thirdhash')
+        s.write()
+
+        with open(self.sourcesfile) as f:
+             lines = f.readlines()
+
+        self.assertEqual(len(lines), 3)
+        self.assertEqual(lines[0], 'ahash  afile\n')
+        self.assertEqual(lines[1], 'anotherhash  anotherfile\n')
+        self.assertEqual(lines[2], 'thirdhash  thirdfile\n')
+
+    def test_write_over(self):
+        lines = ['ahash  afile\n', 'anotherhash  anotherfile\n']
+
+        with open(self.sourcesfile, 'w') as f:
+            for line in lines:
+                f.write(line)
+
+        s = sources.SourcesFile(self.sourcesfile, replace=True)
+        s.add_entry('md5', 'thirdfile', 'thirdhash')
+        s.write()
+
+        with open(self.sourcesfile) as f:
+             lines = f.readlines()
+
+        self.assertEqual(len(lines), 1)
+        self.assertEqual(lines[0], 'thirdhash  thirdfile\n')
 
 
 if __name__ == '__main__':
-- 
2.1.0



More information about the buildsys mailing list