Commit ea1e8b38 authored by Alex Gaynor's avatar Alex Gaynor
Browse files

Ensured that the archive module consistantly explicitly closed all files.

parent ca6015ca
Loading
Loading
Loading
Loading
+17 −1
Original line number Diff line number Diff line
@@ -46,7 +46,8 @@ def extract(path, to_path=''):
    Unpack the tar or zip file at the specified path to the directory
    specified by to_path.
    """
    Archive(path).extract(to_path)
    with Archive(path) as archive:
        archive.extract(to_path)


class Archive(object):
@@ -77,12 +78,21 @@ class Archive(object):
                "Path not a recognized archive format: %s" % filename)
        return cls

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def extract(self, to_path=''):
        self._archive.extract(to_path)

    def list(self):
        self._archive.list()

    def close(self):
        self._archive.close()


class BaseArchive(object):
    """
@@ -161,6 +171,9 @@ class TarArchive(BaseArchive):
                    if extracted:
                        extracted.close()

    def close(self):
        self._archive.close()


class ZipArchive(BaseArchive):

@@ -189,6 +202,9 @@ class ZipArchive(BaseArchive):
                with open(filename, 'wb') as outfile:
                    outfile.write(data)

    def close(self):
        self._archive.close()

extension_map = {
    '.tar': TarArchive,
    '.tar.bz2': TarArchive,
+4 −2
Original line number Diff line number Diff line
@@ -27,12 +27,14 @@ class ArchiveTester(object):
        os.chdir(self.old_cwd)

    def test_extract_method(self):
        Archive(self.archive).extract(self.tmpdir)
        with Archive(self.archive) as archive:
            archive.extract(self.tmpdir)
        self.check_files(self.tmpdir)

    def test_extract_method_no_to_path(self):
        os.chdir(self.tmpdir)
        Archive(self.archive_path).extract()
        with Archive(self.archive_path) as archive:
            archive.extract()
        self.check_files(self.tmpdir)

    def test_extract_function(self):