D3937: ssh: avoid reading beyond the end of stream when using compression

joerg.sonnenberger (Joerg Sonnenberger) phabricator at mercurial-scm.org
Thu Jul 12 14:10:02 EDT 2018


joerg.sonnenberger updated this revision to Diff 9576.

REPOSITORY
  rHG Mercurial

CHANGES SINCE LAST UPDATE
  https://phab.mercurial-scm.org/D3937?vs=9573&id=9576

REVISION DETAIL
  https://phab.mercurial-scm.org/D3937

AFFECTED FILES
  mercurial/sshpeer.py
  mercurial/util.py

CHANGE DETAILS

diff --git a/mercurial/util.py b/mercurial/util.py
--- a/mercurial/util.py
+++ b/mercurial/util.py
@@ -355,6 +355,11 @@
             self._fillbuffer()
         return self._frombuffer(size)
 
+    def unbufferedread(self, size):
+        if not self._eof and self._lenbuf == 0:
+            self._fillbuffer(max(size, _chunksize))
+        return self._frombuffer(min(self._lenbuf, size))
+
     def readline(self, *args, **kwargs):
         if 1 < len(self._buffer):
             # this should not happen because both read and readline end with a
@@ -396,9 +401,9 @@
             self._lenbuf = 0
         return data
 
-    def _fillbuffer(self):
+    def _fillbuffer(self, size=_chunksize):
         """read data to the buffer"""
-        data = os.read(self._input.fileno(), _chunksize)
+        data = os.read(self._input.fileno(), size)
         if not data:
             self._eof = True
         else:
@@ -3328,6 +3333,104 @@
         """
         raise NotImplementedError()
 
+class _CompressedStreamReader(object):
+    def __init__(self, fh):
+        if safehasattr(fh, 'unbufferedread'):
+            self._reader = fh.unbufferedread
+        else:
+            self._reader = fh.read
+        self._pending = []
+        self._pos = 0
+        self._eof = False
+
+    def _decompress(self, chunk):
+        raise NotImplementedError()
+
+    def read(self, l):
+        buf = []
+        while True:
+            while self._pending:
+                if len(self._pending[0]) > l + self._pos:
+                    newbuf = self._pending[0]
+                    buf.append(newbuf[self._pos:self._pos + l])
+                    self._pos += l
+                    return ''.join(buf)
+
+                newbuf = self._pending.pop(0)
+                if self._pos:
+                    buf.append(newbuf[self._pos:])
+                    l -= len(newbuf) - self._pos
+                else:
+                    buf.append(newbuf)
+                    l -= len(newbuf)
+                self._pos = 0
+
+            if self._eof:
+                return ''.join(buf)
+            chunk = self._reader(65536)
+            self._decompress(chunk)
+
+class _GzipCompressedStreamReader(_CompressedStreamReader):
+    def __init__(self, fh):
+        super(_GzipCompressedStreamReader, self).__init__(fh)
+        self._decompobj = zlib.decompressobj()
+    def _decompress(self, chunk):
+        newbuf = self._decompobj.decompress(chunk)
+        if newbuf:
+            self._pending.append(newbuf)
+        d = self._decompobj.copy()
+        try:
+            d.decompress('x')
+            d.flush()
+            if d.unused_data == 'x':
+                self._eof = True
+        except zlib.error:
+            pass
+
+class _BZ2CompressedStreamReader(_CompressedStreamReader):
+    def __init__(self, fh):
+        super(_BZ2CompressedStreamReader, self).__init__(fh)
+        self._decompobj = bz2.BZ2Decompressor()
+    def _decompress(self, chunk):
+        newbuf = self._decompobj.decompress(chunk)
+        if newbuf:
+            self._pending.append(newbuf)
+        try:
+            while True:
+                newbuf = self._decompobj.decompress('')
+                if newbuf:
+                    self._pending.append(newbuf)
+                else:
+                    break
+        except EOFError:
+            self._eof = True
+
+class _TruncatedBZ2CompressedStreamReader(_BZ2CompressedStreamReader):
+    def __init__(self, fh):
+        super(_TruncatedBZ2CompressedStreamReader, self).__init__(fh)
+        newbuf = self._decompobj.decompress('BZ')
+        if newbuf:
+            self._pending.append(newbuf)
+
+class _ZstdCompressedStreamReader(_CompressedStreamReader):
+    def __init__(self, fh, zstd):
+        super(_ZstdCompressedStreamReader, self).__init__(fh)
+        self._zstd = zstd
+        self._decompobj = zstd.ZstdDecompressor().decompressobj()
+    def _decompress(self, chunk):
+        newbuf = self._decompobj.decompress(chunk)
+        if newbuf:
+            self._pending.append(newbuf)
+        try:
+            while True:
+                newbuf = self._decompobj.decompress('')
+                if newbuf:
+                    self._pending.append(newbuf)
+                else:
+                    break
+        except self._zstd.ZstdError:
+            self._eof = True
+
 class _zlibengine(compressionengine):
     def name(self):
         return 'zlib'
@@ -3361,15 +3464,7 @@
         yield z.flush()
 
     def decompressorreader(self, fh):
-        def gen():
-            d = zlib.decompressobj()
-            for chunk in filechunkiter(fh):
-                while chunk:
-                    # Limit output size to limit memory.
-                    yield d.decompress(chunk, 2 ** 18)
-                    chunk = d.unconsumed_tail
-
-        return chunkbuffer(gen())
+        return _GzipCompressedStreamReader(fh)
 
     class zlibrevlogcompressor(object):
         def compress(self, data):
@@ -3449,12 +3544,7 @@
         yield z.flush()
 
     def decompressorreader(self, fh):
-        def gen():
-            d = bz2.BZ2Decompressor()
-            for chunk in filechunkiter(fh):
-                yield d.decompress(chunk)
-
-        return chunkbuffer(gen())
+        return _BZ2CompressedStreamReader(fh)
 
 compengines.register(_bz2engine())
 
@@ -3468,14 +3558,7 @@
     # We don't implement compressstream because it is hackily handled elsewhere.
 
     def decompressorreader(self, fh):
-        def gen():
-            # The input stream doesn't have the 'BZ' header. So add it back.
-            d = bz2.BZ2Decompressor()
-            d.decompress('BZ')
-            for chunk in filechunkiter(fh):
-                yield d.decompress(chunk)
-
-        return chunkbuffer(gen())
+        return _TruncatedBZ2CompressedStreamReader(fh)
 
 compengines.register(_truncatedbz2engine())
 
@@ -3570,9 +3653,7 @@
         yield z.flush()
 
     def decompressorreader(self, fh):
-        zstd = self._module
-        dctx = zstd.ZstdDecompressor()
-        return chunkbuffer(dctx.read_from(fh))
+        return _ZstdCompressedStreamReader(fh, self._module)
 
     class zstdrevlogcompressor(object):
         def __init__(self, zstd, level=3):
diff --git a/mercurial/sshpeer.py b/mercurial/sshpeer.py
--- a/mercurial/sshpeer.py
+++ b/mercurial/sshpeer.py
@@ -98,6 +98,17 @@
             _forwardoutput(self._ui, self._side)
         return r
 
+    def unbufferedread(self, size):
+        r = self._call('unbufferedread', size)
+        if size != 0 and not r:
+            # We've observed a condition that indicates the
+            # stdout closed unexpectedly. Check stderr one
+            # more time and snag anything that's there before
+            # letting anyone know the main part of the pipe
+            # closed prematurely.
+            _forwardoutput(self._ui, self._side)
+        return r
+
     def readline(self):
         return self._call('readline')
 



To: joerg.sonnenberger, #hg-reviewers
Cc: mercurial-devel


More information about the Mercurial-devel mailing list