D4414: cborutil: implement sans I/O decoder

indygreg (Gregory Szorc) phabricator at mercurial-scm.org
Wed Aug 29 03:29:57 UTC 2018


indygreg created this revision.
Herald added subscribers: mercurial-devel, mjpieters.
Herald added a reviewer: hg-reviewers.

REVISION SUMMARY
  The vendored CBOR package decodes by calling read(n) on an object.
  There are a number of disadvantages to this:
  
  - Uses blocking I/O. If sufficient data is not available, the decoder will hang until it is.
  - No support for partial reads. If read(n) returns less data than requested, the decoder raises an error.
  - Requires the use of a file-like object. If the original data is in, say, a buffer, we need to "cast" it to e.g. a BytesIO to appease the decoder.
  
  In addition, the vendored CBOR decoder doesn't provide the flexibility
  that we want. Specifically:
  
  - It buffers indefinite length bytestrings instead of streaming them.
  - It doesn't allow limiting the set of types that can be decoded. This property is useful when implementing a "hardened" decoder that is less susceptible to abusive input.
  - It doesn't provide sufficient "hook points" and introspection to institute checks around behavior. These are useful for implementing a "hardened" decoder.
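
  As a rough illustration of the "limit the set of types" idea, a
  decoder can inspect the major type (the top 3 bits of an item's
  initial byte) before doing any other work. This is only a sketch:
  checkmajortype, DecodeError and MAJOR_TYPE_NAMES are hypothetical
  names and are not part of this patch.

```python
# Hypothetical helper names; nothing here is taken from cborutil.
MAJOR_TYPE_NAMES = {
    0: 'uint', 1: 'negint', 2: 'bytestring', 3: 'string',
    4: 'array', 5: 'map', 6: 'semantic tag', 7: 'special',
}


class DecodeError(ValueError):
    """Raised when input uses a type the caller did not allow."""


def checkmajortype(initialbyte, allowed=frozenset([0, 1, 2, 4, 5, 6, 7])):
    """Return the major type of an initial byte, or raise if forbidden.

    The default allow list excludes major type 3 (text strings).
    """
    majortype = initialbyte >> 5
    if majortype not in allowed:
        raise DecodeError('%s major type not supported'
                          % MAJOR_TYPE_NAMES[majortype])
    return majortype
```

  A stricter caller could pass a smaller allowed set, e.g. only
  integers and bytestrings, to reject everything else up front.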
  
  This all adds up to a reasonable set of justifications for writing our
  own decoder.
  
  So, this commit implements our own CBOR decoder.
  
  At the heart of the decoder is a function that decodes a single "item"
  from a buffer. This item can be a complete simple value or a special
  value, such as "start of array." Using this function, we can build a
  decoder that effectively iterates over the stream of decoded items and
  builds up higher-level values, such as arrays, maps, sets, and indefinite
  length bytestrings. And we can do this without performing I/O in the
  decoder itself.
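
  To make the "decode a single item from a buffer" idea concrete, here
  is a minimal sketch that handles only unsigned integers and
  definite-length bytestrings. It is not the code in this patch:
  decode_uint and decode_item are hypothetical names, Python 3 is
  assumed, and the return shape (complete, value, readcount, special)
  merely mirrors what the tests below expect from decodeitem(), with a
  negative readcount meaning "this many more bytes are needed".

```python
import struct

SPECIAL_NONE = 0  # stand-in for cborutil.SPECIAL_NONE

# Additional-info values 24..27 select a 1, 2, 4 or 8 byte argument.
_ARGUMENT_FORMATS = {24: '>B', 25: '>H', 26: '>L', 27: '>Q'}


def decode_uint(buf, offset, subtype):
    """Decode the argument following an initial byte.

    Returns (complete, value, readcount). readcount excludes the
    initial byte and is negative if more input is needed.
    """
    if subtype < 24:
        return True, subtype, 0
    fmt = _ARGUMENT_FORMATS.get(subtype)
    if fmt is None:
        raise ValueError('unsupported additional info: %d' % subtype)
    size = struct.calcsize(fmt)
    available = len(buf) - offset
    if available < size:
        return False, None, available - size
    return True, struct.unpack_from(fmt, buf, offset)[0], size


def decode_item(buf, offset=0):
    """Decode one item from buf without performing any I/O."""
    if offset >= len(buf):
        raise ValueError('no data to decode')
    initial = buf[offset]  # an int on Python 3
    majortype, subtype = initial >> 5, initial & 0x1f
    offset += 1
    complete, value, readcount = decode_uint(buf, offset, subtype)
    if not complete:
        return False, None, readcount, SPECIAL_NONE
    if majortype == 0:
        return True, value, 1 + readcount, SPECIAL_NONE
    if majortype == 2:
        start = offset + readcount
        available = len(buf) - start
        if available < value:
            return False, None, available - value, SPECIAL_NONE
        return (True, bytes(buf[start:start + value]),
                1 + readcount + value, SPECIAL_NONE)
    raise ValueError('major type %d not handled in this sketch' % majortype)


partial = decode_item(b'\x46foobar'[0:3])  # (False, None, -4, 0)
full = decode_item(b'\x46foobar')          # (True, b'foobar', 7, 0)
```

  Because the function only looks at the bytes it is handed, the
  caller decides when and how to obtain more data.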
  
  The core of the sans I/O decoder will probably not be used directly.
  Instead, it is expected that we'll build utility functions for invoking
  the decoder given specific input types. This will allow extreme
  flexibility in how data is delivered to the decoder.
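
  One possible shape for such a utility, sketched under assumptions:
  a generator that accepts chunks of arbitrary size, buffers whatever
  the item decoder could not consume yet, and yields values as they
  complete. decode_from_chunks and decode_small_bytestring are
  hypothetical names (the latter is a toy item decoder that only
  understands bytestrings shorter than 24 bytes), and Python 3 is
  assumed.

```python
def decode_small_bytestring(buf):
    """Toy item decoder for definite-length bytestrings under 24 bytes.

    Returns (complete, value, readcount); readcount is negative when
    more input is needed.
    """
    if not buf:
        return False, None, -1
    initial = buf[0]
    if initial >> 5 != 2 or (initial & 0x1f) >= 24:
        raise ValueError('only short definite-length bytestrings handled')
    size = initial & 0x1f
    available = len(buf) - 1
    if available < size:
        return False, None, available - size
    return True, bytes(buf[1:1 + size]), 1 + size


def decode_from_chunks(chunks, decodeone=decode_small_bytestring):
    """Yield decoded values from an iterable of byte chunks."""
    pending = bytearray()
    for chunk in chunks:
        pending.extend(chunk)
        # Drain every complete item currently sitting in the buffer.
        while pending:
            complete, value, readcount = decodeone(pending)
            if not complete:
                break
            yield value
            del pending[:readcount]
    if pending:
        raise ValueError('input ended in the middle of an item')


values = list(decode_from_chunks([b'\x46fo', b'obar\x45fi', b'zbi']))
```

  The same driver would work for a socket, a file, or an in-memory
  buffer, since it never calls read() itself.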
  
  I'm pretty happy with the state of the decoder modulo the TODO items
  to track wanted features to help with a "hardened" decoder. The one
  thing I could be convinced to change is the handling of semantic tags.
  Since we only support a single semantic tag (sets), I thought it would
  be easier to handle them inline in decodeitem(). This is simpler for now.
  But if we add support for other semantic tags, it will likely be easier
  to move semantic tag handling outside of decodeitem(). However, properly
  supporting semantic tags opens up a whole can of worms, as many
  semantic tags imply new types. I'm optimistic we won't need these in
  Mercurial. But who knows.
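
  For context, the set tag seen in the tests is the byte sequence
  d9 01 02: major type 6 (semantic tag) with a 16-bit argument of 258,
  followed by an array holding the members. A small sketch of reading
  that header (read_semantic_tag is a hypothetical helper, Python 3 is
  assumed, and only 16-bit tag numbers are handled):

```python
import struct

SEMANTIC_TAG_FINITE_SET = 258


def read_semantic_tag(buf):
    """Return (tag, headerlength) for a buffer starting with a tag.

    Only the 2-byte-argument form (initial byte 0xd9) is handled.
    """
    initial = buf[0]
    if initial >> 5 != 6:
        raise ValueError('not a semantic tag')
    if initial & 0x1f != 25:
        raise ValueError('sketch only handles 16-bit tag numbers')
    (tag,) = struct.unpack_from('>H', buf, 1)
    return tag, 3


encoded = b'\xd9\x01\x02\x80'  # empty set: tag 258 + empty array
tag, headerlen = read_semantic_tag(encoded)
# The item after the tag must be an array (major type 4).
followedbyarray = encoded[headerlen] >> 5 == 4
```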
  
  I'm also pretty happy with the test coverage. Writing comprehensive
  tests for partial decoding did flush out a handful of bugs. One
  general improvement to testing would be fuzz testing for partial
  decoding. I may implement that later. I also anticipate that switching
  the wire protocol code to this new decoder will flush out any lingering
  bugs.

REPOSITORY
  rHG Mercurial

REVISION DETAIL
  https://phab.mercurial-scm.org/D4414

AFFECTED FILES
  mercurial/utils/cborutil.py
  tests/test-cbor.py

CHANGE DETAILS

diff --git a/tests/test-cbor.py b/tests/test-cbor.py
--- a/tests/test-cbor.py
+++ b/tests/test-cbor.py
@@ -10,10 +10,17 @@
     cborutil,
 )
 
+class TestCase(unittest.TestCase):
+    if not getattr(unittest.TestCase, 'assertRaisesRegex', False):
+        # Python 3.7 deprecates the regex*p* version, but 2.7 lacks
+        # the regex version.
+        assertRaisesRegex = (# camelcase-required
+            unittest.TestCase.assertRaisesRegexp)
+
 def loadit(it):
     return cbor.loads(b''.join(it))
 
-class BytestringTests(unittest.TestCase):
+class BytestringTests(TestCase):
     def testsimple(self):
         self.assertEqual(
             list(cborutil.streamencode(b'foobar')),
@@ -23,11 +30,20 @@
             loadit(cborutil.streamencode(b'foobar')),
             b'foobar')
 
+        self.assertEqual(cborutil.decodeall(b'\x46foobar'),
+                         [b'foobar'])
+
+        self.assertEqual(cborutil.decodeall(b'\x46foobar\x45fizbi'),
+                         [b'foobar', b'fizbi'])
+
     def testlong(self):
         source = b'x' * 1048576
 
         self.assertEqual(loadit(cborutil.streamencode(source)), source)
 
+        encoded = b''.join(cborutil.streamencode(source))
+        self.assertEqual(cborutil.decodeall(encoded), [source])
+
     def testfromiter(self):
         # This is the example from RFC 7049 Section 2.2.2.
         source = [b'\xaa\xbb\xcc\xdd', b'\xee\xff\x99']
@@ -47,6 +63,25 @@
             loadit(cborutil.streamencodebytestringfromiter(source)),
             b''.join(source))
 
+        self.assertEqual(cborutil.decodeall(b'\x5f\x44\xaa\xbb\xcc\xdd'
+                                            b'\x43\xee\xff\x99\xff'),
+                         [b'\xaa\xbb\xcc\xdd', b'\xee\xff\x99', b''])
+
+        for i, chunk in enumerate(
+            cborutil.decodeall(b'\x5f\x44\xaa\xbb\xcc\xdd'
+                               b'\x43\xee\xff\x99\xff')):
+            self.assertIsInstance(chunk, cborutil.bytestringchunk)
+
+            if i == 0:
+                self.assertTrue(chunk.isfirst)
+            else:
+                self.assertFalse(chunk.isfirst)
+
+            if i == 2:
+                self.assertTrue(chunk.islast)
+            else:
+                self.assertFalse(chunk.islast)
+
     def testfromiterlarge(self):
         source = [b'a' * 16, b'b' * 128, b'c' * 1024, b'd' * 1048576]
 
@@ -71,6 +106,18 @@
             source, chunksize=42))
         self.assertEqual(cbor.loads(dest), source)
 
+        self.assertEqual(b''.join(cborutil.decodeall(dest)), source)
+
+        for chunk in cborutil.decodeall(dest):
+            self.assertIsInstance(chunk, cborutil.bytestringchunk)
+            self.assertIn(len(chunk), (0, 8, 42))
+
+        encoded = b'\x5f\xff'
+        b = cborutil.decodeall(encoded)
+        self.assertEqual(b, [b''])
+        self.assertTrue(b[0].isfirst)
+        self.assertTrue(b[0].islast)
+
     def testreadtoiter(self):
         source = io.BytesIO(b'\x5f\x44\xaa\xbb\xcc\xdd\x43\xee\xff\x99\xff')
 
@@ -81,42 +128,405 @@
         with self.assertRaises(StopIteration):
             next(it)
 
-class IntTests(unittest.TestCase):
+    def testdecodevariouslengths(self):
+        for i in (0, 1, 22, 23, 24, 25, 254, 255, 256, 65534, 65535, 65536):
+            source = b'x' * i
+            encoded = b''.join(cborutil.streamencode(source))
+
+            if len(source) < 24:
+                hlen = 1
+            elif len(source) < 256:
+                hlen = 2
+            elif len(source) < 65536:
+                hlen = 3
+            elif len(source) < 1048576:
+                hlen = 5
+
+            self.assertEqual(cborutil.decodeitem(encoded),
+                             (True, source, hlen + len(source),
+                              cborutil.SPECIAL_NONE))
+
+    def testpartialdecode(self):
+        encoded = b''.join(cborutil.streamencode(b'foobar'))
+
+        self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                         (False, None, -6, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+                         (False, None, -5, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+                         (False, None, -4, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:4]),
+                         (False, None, -3, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:5]),
+                         (False, None, -2, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:6]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:7]),
+                         (True, b'foobar', 7, cborutil.SPECIAL_NONE))
+
+    def testpartialdecodevariouslengths(self):
+        lens = [
+            2,
+            3,
+            10,
+            23,
+            24,
+            25,
+            31,
+            100,
+            254,
+            255,
+            256,
+            257,
+            16384,
+            65534,
+            65535,
+            65536,
+            65537,
+            131071,
+            131072,
+            131073,
+            1048575,
+            1048576,
+            1048577,
+        ]
+
+        for size in lens:
+            if size < 24:
+                hlen = 1
+            elif size < 2**8:
+                hlen = 2
+            elif size < 2**16:
+                hlen = 3
+            elif size < 2**32:
+                hlen = 5
+            else:
+                assert False
+
+            source = b'x' * size
+            encoded = b''.join(cborutil.streamencode(source))
+
+            res = cborutil.decodeitem(encoded[0:1])
+
+            if hlen > 1:
+                self.assertEqual(res, (False, None, -(hlen - 1),
+                                       cborutil.SPECIAL_NONE))
+            else:
+                self.assertEqual(res, (False, None, -(size + hlen - 1),
+                                       cborutil.SPECIAL_NONE))
+
+            # Decoding partial header reports remaining header size.
+            for i in range(hlen - 1):
+                self.assertEqual(cborutil.decodeitem(encoded[0:i + 1]),
+                                 (False, None, -(hlen - i - 1),
+                                  cborutil.SPECIAL_NONE))
+
+            # Decoding complete header reports item size.
+            self.assertEqual(cborutil.decodeitem(encoded[0:hlen]),
+                             (False, None, -size, cborutil.SPECIAL_NONE))
+
+            # Decoding single byte after header reports item size - 1
+            self.assertEqual(cborutil.decodeitem(encoded[0:hlen + 1]),
+                             (False, None, -(size - 1), cborutil.SPECIAL_NONE))
+
+            # Decoding all but the last byte reports -1 needed.
+            self.assertEqual(cborutil.decodeitem(encoded[0:hlen + size - 1]),
+                             (False, None, -1, cborutil.SPECIAL_NONE))
+
+            # Decoding last byte retrieves value.
+            self.assertEqual(cborutil.decodeitem(encoded[0:hlen + size]),
+                             (True, source, hlen + size, cborutil.SPECIAL_NONE))
+
+    def testindefinitepartialdecode(self):
+        encoded = b''.join(cborutil.streamencodebytestringfromiter(
+            [b'foobar', b'biz']))
+
+        # First item should be begin of bytestring special.
+        self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                         (True, None, 1,
+                          cborutil.SPECIAL_START_INDEFINITE_BYTESTRING))
+
+        # Second item should be the first chunk. But only available when
+        # we give it 7 bytes (1 byte header + 6 byte chunk).
+        self.assertEqual(cborutil.decodeitem(encoded[1:2]),
+                         (False, None, -6, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[1:3]),
+                         (False, None, -5, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[1:4]),
+                         (False, None, -4, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[1:5]),
+                         (False, None, -3, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[1:6]),
+                         (False, None, -2, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[1:7]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+
+        self.assertEqual(cborutil.decodeitem(encoded[1:8]),
+                         (True, b'foobar', 7, cborutil.SPECIAL_NONE))
+
+        # Third item should be second chunk. But only available when
+        # we give it 4 bytes (1 byte header + 3 byte chunk).
+        self.assertEqual(cborutil.decodeitem(encoded[8:9]),
+                         (False, None, -3, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[8:10]),
+                         (False, None, -2, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[8:11]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+
+        self.assertEqual(cborutil.decodeitem(encoded[8:12]),
+                         (True, b'biz', 4, cborutil.SPECIAL_NONE))
+
+        # Fourth item should be end of indefinite stream marker.
+        self.assertEqual(cborutil.decodeitem(encoded[12:13]),
+                         (True, None, 1, cborutil.SPECIAL_INDEFINITE_BREAK))
+
+        # Now test the behavior when going through the decoder.
+
+        self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:1]),
+                         (False, 1, 0))
+        self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:2]),
+                         (False, 1, 6))
+        self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:3]),
+                         (False, 1, 5))
+        self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:4]),
+                         (False, 1, 4))
+        self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:5]),
+                         (False, 1, 3))
+        self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:6]),
+                         (False, 1, 2))
+        self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:7]),
+                         (False, 1, 1))
+        self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:8]),
+                         (True, 8, 0))
+
+        self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:9]),
+                         (True, 8, 3))
+        self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:10]),
+                         (True, 8, 2))
+        self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:11]),
+                         (True, 8, 1))
+        self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:12]),
+                         (True, 12, 0))
+
+        self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:13]),
+                         (True, 13, 0))
+
+        decoder = cborutil.sansiodecoder()
+        decoder.decode(encoded[0:8])
+        values = decoder.getavailable()
+        self.assertEqual(values, [b'foobar'])
+        self.assertTrue(values[0].isfirst)
+        self.assertFalse(values[0].islast)
+
+        self.assertEqual(decoder.decode(encoded[8:12]),
+                         (True, 4, 0))
+        values = decoder.getavailable()
+        self.assertEqual(values, [b'biz'])
+        self.assertFalse(values[0].isfirst)
+        self.assertFalse(values[0].islast)
+
+        self.assertEqual(decoder.decode(encoded[12:]),
+                         (True, 1, 0))
+        values = decoder.getavailable()
+        self.assertEqual(values, [b''])
+        self.assertFalse(values[0].isfirst)
+        self.assertTrue(values[0].islast)
+
+class StringTests(TestCase):
+    def testdecodeforbidden(self):
+        encoded = b'\x63foo'
+        with self.assertRaisesRegex(cborutil.CBORDecodeError,
+                                    'string major type not supported'):
+            cborutil.decodeall(encoded)
+
+class IntTests(TestCase):
     def testsmall(self):
         self.assertEqual(list(cborutil.streamencode(0)), [b'\x00'])
+        self.assertEqual(cborutil.decodeall(b'\x00'), [0])
+
         self.assertEqual(list(cborutil.streamencode(1)), [b'\x01'])
+        self.assertEqual(cborutil.decodeall(b'\x01'), [1])
+
         self.assertEqual(list(cborutil.streamencode(2)), [b'\x02'])
+        self.assertEqual(cborutil.decodeall(b'\x02'), [2])
+
         self.assertEqual(list(cborutil.streamencode(3)), [b'\x03'])
+        self.assertEqual(cborutil.decodeall(b'\x03'), [3])
+
         self.assertEqual(list(cborutil.streamencode(4)), [b'\x04'])
+        self.assertEqual(cborutil.decodeall(b'\x04'), [4])
+
+        # Multiple value decode works.
+        self.assertEqual(cborutil.decodeall(b'\x00\x01\x02\x03\x04'),
+                         [0, 1, 2, 3, 4])
 
     def testnegativesmall(self):
         self.assertEqual(list(cborutil.streamencode(-1)), [b'\x20'])
+        self.assertEqual(cborutil.decodeall(b'\x20'), [-1])
+
         self.assertEqual(list(cborutil.streamencode(-2)), [b'\x21'])
+        self.assertEqual(cborutil.decodeall(b'\x21'), [-2])
+
         self.assertEqual(list(cborutil.streamencode(-3)), [b'\x22'])
+        self.assertEqual(cborutil.decodeall(b'\x22'), [-3])
+
         self.assertEqual(list(cborutil.streamencode(-4)), [b'\x23'])
+        self.assertEqual(cborutil.decodeall(b'\x23'), [-4])
+
         self.assertEqual(list(cborutil.streamencode(-5)), [b'\x24'])
+        self.assertEqual(cborutil.decodeall(b'\x24'), [-5])
+
+        # Multiple value decode works.
+        self.assertEqual(cborutil.decodeall(b'\x20\x21\x22\x23\x24'),
+                         [-1, -2, -3, -4, -5])
 
     def testrange(self):
         for i in range(-70000, 70000, 10):
-            self.assertEqual(
-                b''.join(cborutil.streamencode(i)),
-                cbor.dumps(i))
+            encoded = b''.join(cborutil.streamencode(i))
+
+            self.assertEqual(encoded, cbor.dumps(i))
+            self.assertEqual(cborutil.decodeall(encoded), [i])
+
+    def testdecodepartialubyte(self):
+        encoded = b''.join(cborutil.streamencode(250))
+
+        self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+                         (True, 250, 2, cborutil.SPECIAL_NONE))
+
+    def testdecodepartialbyte(self):
+        encoded = b''.join(cborutil.streamencode(-42))
+        self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+                         (True, -42, 2, cborutil.SPECIAL_NONE))
+
+    def testdecodepartialushort(self):
+        encoded = b''.join(cborutil.streamencode(2**15))
+
+        self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                         (False, None, -2, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+                         (True, 2**15, 3, cborutil.SPECIAL_NONE))
+
+    def testdecodepartialshort(self):
+        encoded = b''.join(cborutil.streamencode(-1024))
+
+        self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                         (False, None, -2, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+                         (True, -1024, 3, cborutil.SPECIAL_NONE))
+
+    def testdecodepartialulong(self):
+        encoded = b''.join(cborutil.streamencode(2**28))
+
+        self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                         (False, None, -4, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+                         (False, None, -3, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+                         (False, None, -2, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:4]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:5]),
+                         (True, 2**28, 5, cborutil.SPECIAL_NONE))
+
+    def testdecodepartiallong(self):
+        encoded = b''.join(cborutil.streamencode(-1048580))
 
-class ArrayTests(unittest.TestCase):
+        self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                         (False, None, -4, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+                         (False, None, -3, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+                         (False, None, -2, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:4]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:5]),
+                         (True, -1048580, 5, cborutil.SPECIAL_NONE))
+
+    def testdecodepartialulonglong(self):
+        encoded = b''.join(cborutil.streamencode(2**32))
+
+        self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                         (False, None, -8, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+                         (False, None, -7, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+                         (False, None, -6, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:4]),
+                         (False, None, -5, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:5]),
+                         (False, None, -4, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:6]),
+                         (False, None, -3, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:7]),
+                         (False, None, -2, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:8]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:9]),
+                         (True, 2**32, 9, cborutil.SPECIAL_NONE))
+
+        with self.assertRaisesRegex(
+            cborutil.CBORDecodeError, 'input data not fully consumed'):
+            cborutil.decodeall(encoded[0:1])
+
+        with self.assertRaisesRegex(
+            cborutil.CBORDecodeError, 'input data not fully consumed'):
+            cborutil.decodeall(encoded[0:2])
+
+    def testdecodepartiallonglong(self):
+        encoded = b''.join(cborutil.streamencode(-7000000000))
+
+        self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                         (False, None, -8, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+                         (False, None, -7, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+                         (False, None, -6, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:4]),
+                         (False, None, -5, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:5]),
+                         (False, None, -4, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:6]),
+                         (False, None, -3, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:7]),
+                         (False, None, -2, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:8]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:9]),
+                         (True, -7000000000, 9, cborutil.SPECIAL_NONE))
+
+class ArrayTests(TestCase):
     def testempty(self):
         self.assertEqual(list(cborutil.streamencode([])), [b'\x80'])
         self.assertEqual(loadit(cborutil.streamencode([])), [])
 
+        self.assertEqual(cborutil.decodeall(b'\x80'), [[]])
+
     def testbasic(self):
         source = [b'foo', b'bar', 1, -10]
 
-        self.assertEqual(list(cborutil.streamencode(source)), [
-            b'\x84', b'\x43', b'foo', b'\x43', b'bar', b'\x01', b'\x29'])
+        chunks = [
+            b'\x84', b'\x43', b'foo', b'\x43', b'bar', b'\x01', b'\x29']
+
+        self.assertEqual(list(cborutil.streamencode(source)), chunks)
+
+        self.assertEqual(cborutil.decodeall(b''.join(chunks)), [source])
 
     def testemptyfromiter(self):
         self.assertEqual(b''.join(cborutil.streamencodearrayfromiter([])),
                          b'\x9f\xff')
 
+        with self.assertRaisesRegex(cborutil.CBORDecodeError,
+                                    'indefinite length uint not allowed'):
+            cborutil.decodeall(b'\x9f\xff')
+
     def testfromiter1(self):
         source = [b'foo']
 
@@ -129,57 +539,241 @@
         dest = b''.join(cborutil.streamencodearrayfromiter(source))
         self.assertEqual(cbor.loads(dest), source)
 
+        with self.assertRaisesRegex(cborutil.CBORDecodeError,
+                                    'indefinite length uint not allowed'):
+            cborutil.decodeall(dest)
+
     def testtuple(self):
         source = (b'foo', None, 42)
+        encoded = b''.join(cborutil.streamencode(source))
 
-        self.assertEqual(cbor.loads(b''.join(cborutil.streamencode(source))),
-                         list(source))
+        self.assertEqual(cbor.loads(encoded), list(source))
+
+        self.assertEqual(cborutil.decodeall(encoded), [list(source)])
+
+    def testpartialdecode(self):
+        source = list(range(4))
+        encoded = b''.join(cborutil.streamencode(source))
+        self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                         (True, 4, 1, cborutil.SPECIAL_START_ARRAY))
+        self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+                         (True, 4, 1, cborutil.SPECIAL_START_ARRAY))
+
+        source = list(range(23))
+        encoded = b''.join(cborutil.streamencode(source))
+        self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                         (True, 23, 1, cborutil.SPECIAL_START_ARRAY))
+        self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+                         (True, 23, 1, cborutil.SPECIAL_START_ARRAY))
+
+        source = list(range(24))
+        encoded = b''.join(cborutil.streamencode(source))
+        self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+                         (True, 24, 2, cborutil.SPECIAL_START_ARRAY))
+        self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+                         (True, 24, 2, cborutil.SPECIAL_START_ARRAY))
 
-class SetTests(unittest.TestCase):
+        source = list(range(256))
+        encoded = b''.join(cborutil.streamencode(source))
+        self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                         (False, None, -2, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+                         (True, 256, 3, cborutil.SPECIAL_START_ARRAY))
+        self.assertEqual(cborutil.decodeitem(encoded[0:4]),
+                         (True, 256, 3, cborutil.SPECIAL_START_ARRAY))
+
+    def testnested(self):
+        source = [[], [], [[], [], []]]
+        encoded = b''.join(cborutil.streamencode(source))
+        self.assertEqual(cborutil.decodeall(encoded), [source])
+
+        source = [True, None, [True, 0, 2], [None], [], [[[]], -87]]
+        encoded = b''.join(cborutil.streamencode(source))
+        self.assertEqual(cborutil.decodeall(encoded), [source])
+
+        # A set within an array.
+        source = [None, {b'foo', b'bar', None, False}, set()]
+        encoded = b''.join(cborutil.streamencode(source))
+        self.assertEqual(cborutil.decodeall(encoded), [source])
+
+        # A map within an array.
+        source = [None, {}, {b'foo': b'bar', True: False}, [{}]]
+        encoded = b''.join(cborutil.streamencode(source))
+        self.assertEqual(cborutil.decodeall(encoded), [source])
+
+    def testindefinitebytestringvalues(self):
+        # Single value array whose value is an empty indefinite bytestring.
+        encoded = b'\x81\x5f\x40\xff'
+
+        with self.assertRaisesRegex(cborutil.CBORDecodeError,
+                                    'indefinite length bytestrings not '
+                                    'allowed as array values'):
+            cborutil.decodeall(encoded)
+
+class SetTests(TestCase):
     def testempty(self):
         self.assertEqual(list(cborutil.streamencode(set())), [
             b'\xd9\x01\x02',
             b'\x80',
         ])
 
+        self.assertEqual(cborutil.decodeall(b'\xd9\x01\x02\x80'), [set()])
+
     def testset(self):
         source = {b'foo', None, 42}
+        encoded = b''.join(cborutil.streamencode(source))
 
-        self.assertEqual(cbor.loads(b''.join(cborutil.streamencode(source))),
-                         source)
+        self.assertEqual(cbor.loads(encoded), source)
+
+        self.assertEqual(cborutil.decodeall(encoded), [source])
+
+    def testinvalidtag(self):
+        # Must use array to encode sets.
+        encoded = b'\xd9\x01\x02\xa0'
+
+        with self.assertRaisesRegex(cborutil.CBORDecodeError,
+                                    'expected array after finite set '
+                                    'semantic tag'):
+            cborutil.decodeall(encoded)
+
+    def testpartialdecode(self):
+        # Semantic tag item will be 3 bytes. Set header will be variable
+        # depending on length.
+        encoded = b''.join(cborutil.streamencode({i for i in range(23)}))
+        self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                         (False, None, -2, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:4]),
+                         (True, 23, 4, cborutil.SPECIAL_START_SET))
+        self.assertEqual(cborutil.decodeitem(encoded[0:5]),
+                         (True, 23, 4, cborutil.SPECIAL_START_SET))
+
+        encoded = b''.join(cborutil.streamencode({i for i in range(24)}))
+        self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                         (False, None, -2, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:4]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:5]),
+                         (True, 24, 5, cborutil.SPECIAL_START_SET))
+        self.assertEqual(cborutil.decodeitem(encoded[0:6]),
+                         (True, 24, 5, cborutil.SPECIAL_START_SET))
 
-class BoolTests(unittest.TestCase):
+        encoded = b''.join(cborutil.streamencode({i for i in range(256)}))
+        self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                         (False, None, -2, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:4]),
+                         (False, None, -2, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:5]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:6]),
+                         (True, 256, 6, cborutil.SPECIAL_START_SET))
+
+    def testinvalidvalue(self):
+        encoded = b''.join([
+            b'\xd9\x01\x02', # semantic tag
+            b'\x81', # array of size 1
+            b'\x5f\x43foo\xff', # indefinite length bytestring "foo"
+        ])
+
+        with self.assertRaisesRegex(cborutil.CBORDecodeError,
+                                    'indefinite length bytestrings not '
+                                    'allowed as set values'):
+            cborutil.decodeall(encoded)
+
+        encoded = b''.join([
+            b'\xd9\x01\x02',
+            b'\x81',
+            b'\x80', # empty array
+        ])
+
+        with self.assertRaisesRegex(cborutil.CBORDecodeError,
+                                    'collections not allowed as set values'):
+            cborutil.decodeall(encoded)
+
+        encoded = b''.join([
+            b'\xd9\x01\x02',
+            b'\x81',
+            b'\xa0', # empty map
+        ])
+
+        with self.assertRaisesRegex(cborutil.CBORDecodeError,
+                                    'collections not allowed as set values'):
+            cborutil.decodeall(encoded)
+
+        encoded = b''.join([
+            b'\xd9\x01\x02',
+            b'\x81',
+            b'\xd9\x01\x02\x81\x01', # set with integer 1
+        ])
+
+        with self.assertRaisesRegex(cborutil.CBORDecodeError,
+                                    'collections not allowed as set values'):
+            cborutil.decodeall(encoded)
+
+class BoolTests(TestCase):
     def testbasic(self):
         self.assertEqual(list(cborutil.streamencode(True)),  [b'\xf5'])
         self.assertEqual(list(cborutil.streamencode(False)), [b'\xf4'])
 
         self.assertIs(loadit(cborutil.streamencode(True)), True)
         self.assertIs(loadit(cborutil.streamencode(False)), False)
 
-class NoneTests(unittest.TestCase):
+        self.assertEqual(cborutil.decodeall(b'\xf4'), [False])
+        self.assertEqual(cborutil.decodeall(b'\xf5'), [True])
+
+        self.assertEqual(cborutil.decodeall(b'\xf4\xf5\xf5\xf4'),
+                         [False, True, True, False])
+
+class NoneTests(TestCase):
     def testbasic(self):
         self.assertEqual(list(cborutil.streamencode(None)), [b'\xf6'])
 
         self.assertIs(loadit(cborutil.streamencode(None)), None)
 
-class MapTests(unittest.TestCase):
+        self.assertEqual(cborutil.decodeall(b'\xf6'), [None])
+        self.assertEqual(cborutil.decodeall(b'\xf6\xf6'), [None, None])
+
+class MapTests(TestCase):
     def testempty(self):
         self.assertEqual(list(cborutil.streamencode({})), [b'\xa0'])
         self.assertEqual(loadit(cborutil.streamencode({})), {})
 
+        self.assertEqual(cborutil.decodeall(b'\xa0'), [{}])
+
     def testemptyindefinite(self):
         self.assertEqual(list(cborutil.streamencodemapfromiter([])), [
             b'\xbf', b'\xff'])
 
         self.assertEqual(loadit(cborutil.streamencodemapfromiter([])), {})
 
+        with self.assertRaisesRegex(cborutil.CBORDecodeError,
+                                    'indefinite length uint not allowed'):
+            cborutil.decodeall(b'\xbf\xff')
+
     def testone(self):
         source = {b'foo': b'bar'}
         self.assertEqual(list(cborutil.streamencode(source)), [
             b'\xa1', b'\x43', b'foo', b'\x43', b'bar'])
 
         self.assertEqual(loadit(cborutil.streamencode(source)), source)
 
+        self.assertEqual(cborutil.decodeall(b'\xa1\x43foo\x43bar'), [source])
+
     def testmultiple(self):
         source = {
             b'foo': b'bar',
@@ -192,6 +786,9 @@
             loadit(cborutil.streamencodemapfromiter(source.items())),
             source)
 
+        encoded = b''.join(cborutil.streamencode(source))
+        self.assertEqual(cborutil.decodeall(encoded), [source])
+
     def testcomplex(self):
         source = {
             b'key': 1,
@@ -205,6 +802,170 @@
             loadit(cborutil.streamencodemapfromiter(source.items())),
             source)
 
+        encoded = b''.join(cborutil.streamencode(source))
+        self.assertEqual(cborutil.decodeall(encoded), [source])
+
+    def testnested(self):
+        source = {b'key1': None, b'key2': {b'sub1': b'sub2'}, b'sub2': {}}
+        encoded = b''.join(cborutil.streamencode(source))
+
+        self.assertEqual(cborutil.decodeall(encoded), [source])
+
+        source = {
+            b'key1': [],
+            b'key2': [None, False],
+            b'key3': {b'foo', b'bar'},
+            b'key4': {},
+        }
+        encoded = b''.join(cborutil.streamencode(source))
+        self.assertEqual(cborutil.decodeall(encoded), [source])
+
+    def testillegalkey(self):
+        encoded = b''.join([
+            # map header + len 1
+            b'\xa1',
+            # indefinite length bytestring "foo" in key position
+            b'\x5f\x43foo\xff'
+        ])
+
+        with self.assertRaisesRegex(cborutil.CBORDecodeError,
+                                    'indefinite length bytestrings not '
+                                    'allowed as map keys'):
+            cborutil.decodeall(encoded)
+
+        encoded = b''.join([
+            b'\xa1',
+            b'\x80', # empty array
+            b'\x43foo',
+        ])
+
+        with self.assertRaisesRegex(cborutil.CBORDecodeError,
+                                    'collections not supported as map keys'):
+            cborutil.decodeall(encoded)
+
+    def testillegalvalue(self):
+        encoded = b''.join([
+            b'\xa1', # map header + len 1
+            b'\x43foo', # key
+            b'\x5f\x43bar\xff', # indefinite length bytestring "bar" as value
+        ])
+
+        with self.assertRaisesRegex(cborutil.CBORDecodeError,
+                                    'indefinite length bytestrings not '
+                                    'allowed as map values'):
+            cborutil.decodeall(encoded)
+
+    def testpartialdecode(self):
+        source = {b'key1': b'value1'}
+        encoded = b''.join(cborutil.streamencode(source))
+
+        self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                         (True, 1, 1, cborutil.SPECIAL_START_MAP))
+        self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+                         (True, 1, 1, cborutil.SPECIAL_START_MAP))
+
+        source = {b'key%d' % i: None for i in range(23)}
+        encoded = b''.join(cborutil.streamencode(source))
+        self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                         (True, 23, 1, cborutil.SPECIAL_START_MAP))
+
+        source = {b'key%d' % i: None for i in range(24)}
+        encoded = b''.join(cborutil.streamencode(source))
+        self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+                         (True, 24, 2, cborutil.SPECIAL_START_MAP))
+        self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+                         (True, 24, 2, cborutil.SPECIAL_START_MAP))
+
+        source = {b'key%d' % i: None for i in range(256)}
+        encoded = b''.join(cborutil.streamencode(source))
+        self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                         (False, None, -2, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+                         (True, 256, 3, cborutil.SPECIAL_START_MAP))
+        self.assertEqual(cborutil.decodeitem(encoded[0:4]),
+                         (True, 256, 3, cborutil.SPECIAL_START_MAP))
+
+        source = {b'key%d' % i: None for i in range(65536)}
+        encoded = b''.join(cborutil.streamencode(source))
+        self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                         (False, None, -4, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+                         (False, None, -3, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+                         (False, None, -2, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:4]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:5]),
+                         (True, 65536, 5, cborutil.SPECIAL_START_MAP))
+        self.assertEqual(cborutil.decodeitem(encoded[0:6]),
+                         (True, 65536, 5, cborutil.SPECIAL_START_MAP))
+
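The lengths 23, 24, 256, and 65536 in the partial-decode tests above are the points where a CBOR item head grows to the next size. A minimal sketch of that arithmetic, assuming a hypothetical helper named headsize() that is not part of cborutil:

```python
def headsize(n):
    """Return the byte count of a CBOR item head carrying length/value n.

    Hypothetical helper for illustration only; it mirrors the CBOR rule
    that values below 24 are stored inline in the initial byte and larger
    values follow in 1, 2, 4, or 8 big-endian bytes.
    """
    if n < 0:
        raise ValueError('length must be non-negative')
    if n < 24:
        return 1          # value packed into the initial byte
    if n < 1 << 8:
        return 1 + 1      # subtype 24: one extra byte
    if n < 1 << 16:
        return 1 + 2      # subtype 25: two extra bytes
    if n < 1 << 32:
        return 1 + 4      # subtype 26: four extra bytes
    if n < 1 << 64:
        return 1 + 8      # subtype 27: eight extra bytes
    raise ValueError('value does not fit in 64 bits')
```

These sizes line up with the byte counts the map tests expect from decodeitem() once the head is complete: 1, 2, 3, and 5 bytes for 23, 24, 256, and 65536 entries respectively.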
+class SemanticTagTests(TestCase):
+    def testdecodeforbidden(self):
+        for i in range(500):
+            if i == cborutil.SEMANTIC_TAG_FINITE_SET:
+                continue
+
+            tag = cborutil.encodelength(cborutil.MAJOR_TYPE_SEMANTIC,
+                                        i)
+
+            encoded = tag + cborutil.encodelength(cborutil.MAJOR_TYPE_UINT, 42)
+
+            # Partial decode is incomplete.
+            if i < 24:
+                pass
+            elif i < 256:
+                self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                                 (False, None, -1, cborutil.SPECIAL_NONE))
+            elif i < 65536:
+                self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                                 (False, None, -2, cborutil.SPECIAL_NONE))
+                self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+                                 (False, None, -1, cborutil.SPECIAL_NONE))
+
+            with self.assertRaisesRegex(cborutil.CBORDecodeError,
+                                        r'semantic tag \d+ not allowed'):
+                cborutil.decodeitem(encoded)
+
+class SpecialTypesTests(TestCase):
+    def testforbiddentypes(self):
+        for i in range(256):
+            if i == cborutil.SUBTYPE_FALSE:
+                continue
+            elif i == cborutil.SUBTYPE_TRUE:
+                continue
+            elif i == cborutil.SUBTYPE_NULL:
+                continue
+
+            encoded = cborutil.encodelength(cborutil.MAJOR_TYPE_SPECIAL, i)
+
+            with self.assertRaisesRegex(cborutil.CBORDecodeError,
+                                        r'special type \d+ not allowed'):
+                cborutil.decodeitem(encoded)
+
+class SansIODecoderTests(TestCase):
+    def testemptyinput(self):
+        decoder = cborutil.sansiodecoder()
+        self.assertEqual(decoder.decode(b''), (False, 0, 0))
+
+class DecodeallTests(TestCase):
+    def testemptyinput(self):
+        self.assertEqual(cborutil.decodeall(b''), [])
+
+    def testpartialinput(self):
+        encoded = b''.join([
+            b'\x82', # array of 2 elements
+            b'\x01', # integer 1
+        ])
+
+        with self.assertRaisesRegex(cborutil.CBORDecodeError,
+                                    'input data not complete'):
+            cborutil.decodeall(encoded)
+
 if __name__ == '__main__':
     import silenttestrunner
     silenttestrunner.main(__name__)
diff --git a/mercurial/utils/cborutil.py b/mercurial/utils/cborutil.py
--- a/mercurial/utils/cborutil.py
+++ b/mercurial/utils/cborutil.py
@@ -8,6 +8,7 @@
 from __future__ import absolute_import
 
 import struct
+import sys
 
 from ..thirdparty.cbor.cbor2 import (
     decoder as decodermod,
@@ -35,11 +36,16 @@
 
 SUBTYPE_MASK = 0b00011111
 
+SUBTYPE_FALSE = 20
+SUBTYPE_TRUE = 21
+SUBTYPE_NULL = 22
 SUBTYPE_HALF_FLOAT = 25
 SUBTYPE_SINGLE_FLOAT = 26
 SUBTYPE_DOUBLE_FLOAT = 27
 SUBTYPE_INDEFINITE = 31
 
+SEMANTIC_TAG_FINITE_SET = 258
+
 # Indefinite types begin with their major type ORd with information value 31.
 BEGIN_INDEFINITE_BYTESTRING = struct.pack(
     r'>B', MAJOR_TYPE_BYTESTRING << 5 | SUBTYPE_INDEFINITE)
@@ -146,7 +152,7 @@
 def streamencodeset(s):
     # https://www.iana.org/assignments/cbor-tags/cbor-tags.xhtml defines
     # semantic tag 258 for finite sets.
-    yield encodelength(MAJOR_TYPE_SEMANTIC, 258)
+    yield encodelength(MAJOR_TYPE_SEMANTIC, SEMANTIC_TAG_FINITE_SET)
 
     for chunk in streamencodearray(sorted(s, key=_mixedtypesortkey)):
         yield chunk
@@ -260,3 +266,710 @@
                     len(chunk), length))
 
         yield chunk
+
+class CBORDecodeError(Exception):
+    """Represents an error decoding CBOR."""
+
+if sys.version_info.major >= 3:
+    def _elementtointeger(b, i):
+        return b[i]
+else:
+    def _elementtointeger(b, i):
+        return ord(b[i])
+
+STRUCT_BIG_UBYTE = struct.Struct(r'>B')
+STRUCT_BIG_USHORT = struct.Struct('>H')
+STRUCT_BIG_ULONG = struct.Struct('>L')
+STRUCT_BIG_ULONGLONG = struct.Struct('>Q')
+
+SPECIAL_NONE = 0
+SPECIAL_START_INDEFINITE_BYTESTRING = 1
+SPECIAL_START_ARRAY = 2
+SPECIAL_START_MAP = 3
+SPECIAL_START_SET = 4
+SPECIAL_INDEFINITE_BREAK = 5
+
+def decodeitem(b, offset=0):
+    """Decode a new CBOR value from a buffer at offset.
+
+    This function attempts to decode up to one complete CBOR value
+    from ``b`` starting at offset ``offset``.
+
+    The beginning of a collection (such as an array, map, set, or
+    indefinite length bytestring) counts as a single value. For these
+    special cases, a state flag will indicate that a special value was seen.
+
+    When called, the function either returns a decoded value or gives
+    a hint as to how many more bytes are needed to do so. By calling
+    the function repeatedly given a stream of bytes, the caller can
+    build up the original values.
+
+    Returns a tuple with the following elements:
+
+    * Bool indicating whether a complete value was decoded.
+    * The decoded value if the first element is True, otherwise None.
+    * Integer number of bytes. If positive, the number of bytes
+      read. If negative, the number of bytes we need to read to
+      decode this value or the next chunk in this value.
+    * One of the ``SPECIAL_*`` constants indicating special treatment
+      for this value. ``SPECIAL_NONE`` means this is a fully decoded
+      simple value (such as an integer or bool).
+    """
+
+    initial = _elementtointeger(b, offset)
+    offset += 1
+
+    majortype = initial >> 5
+    subtype = initial & SUBTYPE_MASK
+
+    if majortype == MAJOR_TYPE_UINT:
+        complete, value, readcount = decodeuint(subtype, b, offset)
+
+        if complete:
+            return True, value, readcount + 1, SPECIAL_NONE
+        else:
+            return False, None, readcount, SPECIAL_NONE
+
+    elif majortype == MAJOR_TYPE_NEGINT:
+        # Negative integers are encoded like UINT. For an encoded value n,
+        # the decoded value is -1 - n.
+        complete, value, readcount = decodeuint(subtype, b, offset)
+
+        if complete:
+            return True, -value - 1, readcount + 1, SPECIAL_NONE
+        else:
+            return False, None, readcount, SPECIAL_NONE
+
+    elif majortype == MAJOR_TYPE_BYTESTRING:
+        # Beginning of bytestrings are treated as uints in order to
+        # decode their length, which may be indefinite.
+        complete, size, readcount = decodeuint(subtype, b, offset,
+                                               allowindefinite=True)
+
+        # We don't know the size of the bytestring yet. It must be a definite
+        # length since the indefinite subtype would be encoded in the initial
+        # byte.
+        if not complete:
+            return False, None, readcount, SPECIAL_NONE
+
+        # We know the length of the bytestring.
+        if size is not None:
+            # And the data is available in the buffer.
+            if offset + readcount + size <= len(b):
+                value = b[offset + readcount:offset + readcount + size]
+                return True, value, readcount + size + 1, SPECIAL_NONE
+
+            # And we need more data in order to return the bytestring.
+            else:
+                wanted = len(b) - offset - readcount - size
+                return False, None, wanted, SPECIAL_NONE
+
+        # It is an indefinite length bytestring.
+        else:
+            return True, None, 1, SPECIAL_START_INDEFINITE_BYTESTRING
+
+    elif majortype == MAJOR_TYPE_STRING:
+        raise CBORDecodeError('string major type not supported')
+
+    elif majortype == MAJOR_TYPE_ARRAY:
+        # Beginning of arrays are treated as uints in order to decode their
+        # length. We don't allow indefinite length arrays.
+        complete, size, readcount = decodeuint(subtype, b, offset)
+
+        if complete:
+            return True, size, readcount + 1, SPECIAL_START_ARRAY
+        else:
+            return False, None, readcount, SPECIAL_NONE
+
+    elif majortype == MAJOR_TYPE_MAP:
+        # Beginning of maps are treated as uints in order to decode their
+        # number of elements. We don't allow indefinite length maps.
+        complete, size, readcount = decodeuint(subtype, b, offset)
+
+        if complete:
+            return True, size, readcount + 1, SPECIAL_START_MAP
+        else:
+            return False, None, readcount, SPECIAL_NONE
+
+    elif majortype == MAJOR_TYPE_SEMANTIC:
+        # Semantic tag value is read the same as a uint.
+        complete, tagvalue, readcount = decodeuint(subtype, b, offset)
+
+        if not complete:
+            return False, None, readcount, SPECIAL_NONE
+
+        # This behavior here is a little wonky. The main type being "decorated"
+        # by this semantic tag follows. A more robust parser would probably emit
+        # a special flag indicating this as a semantic tag and let the caller
+        # deal with the types that follow. But since we don't support many
+        # semantic tags, it is easier to deal with the special cases here and
+        # hide complexity from the caller. If we add support for more semantic
+        # tags, we should probably move semantic tag handling into the caller.
+        if tagvalue == SEMANTIC_TAG_FINITE_SET:
+            if offset + readcount >= len(b):
+                return False, None, -1, SPECIAL_NONE
+
+            complete, size, readcount2, special = decodeitem(b,
+                                                             offset + readcount)
+
+            if not complete:
+                return False, None, readcount2, SPECIAL_NONE
+
+            if special != SPECIAL_START_ARRAY:
+                raise CBORDecodeError('expected array after finite set '
+                                      'semantic tag')
+
+            return True, size, readcount + readcount2 + 1, SPECIAL_START_SET
+
+        else:
+            raise CBORDecodeError('semantic tag %d not allowed' % tagvalue)
+
+    elif majortype == MAJOR_TYPE_SPECIAL:
+        # Only specific values for the information field are allowed.
+        if subtype == SUBTYPE_FALSE:
+            return True, False, 1, SPECIAL_NONE
+        elif subtype == SUBTYPE_TRUE:
+            return True, True, 1, SPECIAL_NONE
+        elif subtype == SUBTYPE_NULL:
+            return True, None, 1, SPECIAL_NONE
+        elif subtype == SUBTYPE_INDEFINITE:
+            return True, None, 1, SPECIAL_INDEFINITE_BREAK
+        # Subtype 24 means the simple value is stored in the next byte. That
+        # form and all remaining subtypes are rejected.
+        else:
+            raise CBORDecodeError('special type %d not allowed' % subtype)
+    else:
+        assert False
+
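The decodeitem() docstring says a caller can rebuild values by calling the function repeatedly as bytes arrive. A minimal sketch of such a caller follows. It assumes a simplified stand-in, decode_uint_item(), that handles only major type 0 and returns a 3-tuple without the SPECIAL_* flag; both function names are hypothetical and not part of this patch.

```python
import struct

def decode_uint_item(buf, offset=0):
    """Hypothetical stand-in for decodeitem(); handles unsigned ints only.

    Returns (complete, value, count). count is the number of bytes read
    when positive, or minus the number of bytes still needed.
    """
    initial = buf[offset]
    if initial >> 5 != 0:
        raise ValueError('only major type 0 is handled in this sketch')
    subtype = initial & 0x1f
    if subtype < 24:
        return True, subtype, 1
    if subtype > 27:
        raise ValueError('unsupported subtype: %d' % subtype)
    size = 1 << (subtype - 24)           # 1, 2, 4, or 8 extra bytes
    available = len(buf) - offset - 1
    if available < size:
        return False, None, available - size
    fmt = {1: '>B', 2: '>H', 4: '>L', 8: '>Q'}[size]
    return True, struct.unpack_from(fmt, buf, offset + 1)[0], size + 1

def feed(chunks):
    """Decode uints from an iterable of byte chunks without doing I/O.

    Unconsumed bytes are carried over and re-parsed with the next chunk,
    which is the redundant but simple strategy the decoder docs describe.
    """
    pending = b''
    values = []
    for chunk in chunks:
        pending += chunk
        offset = 0
        while offset < len(pending):
            complete, value, count = decode_uint_item(pending, offset)
            if not complete:
                break                    # wait for more input
            values.append(value)
            offset += count
        pending = pending[offset:]
    return values, pending
```

For example, feed([b'\x01\x19', b'\x01', b'\x00\x18']) yields the values 1 and 256 and leaves the single byte b'\x18' pending, because its one-byte argument has not arrived yet.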
+def decodeuint(subtype, b, offset=0, allowindefinite=False):
+    """Decode an unsigned integer.
+
+    ``subtype`` is the lower 5 bits from the initial byte CBOR item
+    "header." ``b`` is a buffer containing bytes. ``offset`` points to
+    the index of the first byte after the byte that ``subtype`` was
+    derived from.
+
+    ``allowindefinite`` allows the special indefinite length value
+    indicator.
+
+    Returns a 3-tuple of (successful, value, count).
+
+    The first element is a bool indicating if decoding completed. The 2nd
+    is the decoded integer value or None if not fully decoded or the subtype
+    is 31 and ``allowindefinite`` is True. The 3rd value is the count of bytes.
+    If positive, it is the number of additional bytes decoded. If negative,
+    it is the number of additional bytes needed to decode this value.
+    """
+
+    # Small values are inline.
+    if subtype < 24:
+        return True, subtype, 0
+    # Indefinite length specifier.
+    elif subtype == 31:
+        if allowindefinite:
+            return True, None, 0
+        else:
+            raise CBORDecodeError('indefinite length uint not allowed here')
+    elif subtype >= 28:
+        raise CBORDecodeError('unsupported subtype on integer type: %d' %
+                              subtype)
+
+    if subtype == 24:
+        s = STRUCT_BIG_UBYTE
+    elif subtype == 25:
+        s = STRUCT_BIG_USHORT
+    elif subtype == 26:
+        s = STRUCT_BIG_ULONG
+    elif subtype == 27:
+        s = STRUCT_BIG_ULONGLONG
+    else:
+        raise CBORDecodeError('bounds condition checking violation')
+
+    if len(b) - offset >= s.size:
+        return True, s.unpack_from(b, offset)[0], s.size
+    else:
+        return False, None, len(b) - offset - s.size
+
+class bytestringchunk(bytes):
+    """Represents a chunk/segment in an indefinite length bytestring.
+
+    This behaves like a ``bytes`` but in addition has the ``isfirst``
+    and ``islast`` attributes indicating whether this chunk is the first
+    or last in an indefinite length bytestring.
+    """
+
+    def __new__(cls, v, first=False, last=False):
+        self = bytes.__new__(cls, v)
+        self.isfirst = first
+        self.islast = last
+
+        return self
+
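To illustrate how a consumer might use the ``isfirst`` and ``islast`` attributes, here is a minimal sketch. The class chunk and the function reassemble() are hypothetical stand-ins written for this example. The sketch assumes the final data chunk carries ``islast``, which this part of the patch does not show.

```python
class chunk(bytes):
    """Hypothetical stand-in mirroring the bytestringchunk pattern."""

    def __new__(cls, v, first=False, last=False):
        # bytes is immutable, so extra attributes are attached in __new__.
        self = bytes.__new__(cls, v)
        self.isfirst = first
        self.islast = last
        return self

def reassemble(chunks):
    """Join chunks into one bytes value once the last chunk is seen.

    Returns None if the stream ends before a chunk flagged as last.
    """
    parts = []
    for c in chunks:
        if c.isfirst:
            parts = []               # start of a new bytestring
        parts.append(c)
        if c.islast:
            return b''.join(parts)
    return None
```

A streaming consumer would more likely write each chunk out as it arrives rather than buffer it; the join here only keeps the example short.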
+class sansiodecoder(object):
+    """A CBOR decoder that doesn't perform its own I/O.
+
+    To use, construct an instance and feed it segments containing
+    CBOR-encoded bytes via ``decode()``. The return value from ``decode()``
+    indicates whether a fully-decoded value is available, how many bytes
+    were consumed, and offers a hint as to how many bytes should be fed
+    in next time to decode the next value.
+
+    The decoder assumes it will decode N discrete CBOR values, not just
+    a single value. i.e. if the bytestream contains uints packed one after
+    the other, the decoder will decode them all, rather than just the initial
+    one.
+
+    When ``decode()`` indicates a value is available, call ``getavailable()``
+    to return all fully decoded values.
+
+    ``decode()`` can partially decode input. It is up to the caller to keep
+    track of what data was consumed and to pass unconsumed data in on the
+    next invocation.
+
+    The decoder decodes atomically at the *item* level. See ``decodeitem()``.
+    If an *item* cannot be fully decoded, the decoder won't record it as
+    partially consumed. Instead, the caller will be instructed to pass in
+    the initial bytes of this item on the next invocation. This does result
+    in some redundant parsing. But the overhead should be minimal.
+
+    This decoder only supports a subset of CBOR as required by Mercurial.
+    It lacks support for:
+
+    * Indefinite length arrays
+    * Indefinite length maps
+    * Use of indefinite length bytestrings as keys or values within
+      arrays, maps, or sets.
+    * Nested arrays, maps, or sets within sets
+    * Any semantic tag that isn't a mathematical finite set
+    * Floating point numbers
+    * Undefined special value
+
+    CBOR types are decoded to Python types as follows:
+
+    uint -> int
+    negint -> int
+    bytestring -> bytes
+    map -> dict
+    array -> list
+    True -> bool
+    False -> bool
+    null -> None
+    indefinite length bytestring chunk -> [bytestringchunk]
+
+    The only non-obvious mapping here is an indefinite length bytestring
+    to the ``bytestringchunk`` type. This is to facilitate streaming
+    indefinite length bytestrings out of the decoder and to differentiate
+    a regular bytestring from an indefinite length bytestring.
+    """
+
+    _STATE_NONE = 0
+    _STATE_WANT_MAP_KEY = 1
+    _STATE_WANT_MAP_VALUE = 2
+    _STATE_WANT_ARRAY_VALUE = 3
+    _STATE_WANT_SET_VALUE = 4
+    _STATE_WANT_BYTESTRING_CHUNK_FIRST = 5
+    _STATE_WANT_BYTESTRING_CHUNK_SUBSEQUENT = 6
+
+    def __init__(self):
+        # TODO add support for limiting size of bytestrings
+        # TODO add support for limiting number of keys / values in collections
+        # TODO add support for limiting size of buffered partial values
+
+        self.decodedbytecount = 0
+
+        self._state = self._STATE_NONE
+
+        # Stack of active nested collections. Each entry is a dict describing
+        # the collection.
+        self._collectionstack = []
+
+        # Fully decoded key to use for the current map.
+        self._currentmapkey = None
+
+        # Fully decoded values available for retrieval.
+        self._decodedvalues = []
+
+    @property
+    def inprogress(self):
+        """Whether the decoder has partially decoded a value."""
+        return self._state != self._STATE_NONE
+
+    def decode(self, b, offset=0):
+        """Attempt to decode bytes from an input buffer.
+
+        ``b`` is a collection of bytes and ``offset`` is the byte
+        offset within that buffer from which to begin reading data.
+
+        ``b`` must support ``len()`` as well as indexing and slicing via
+        ``__getitem__``. Typically ``bytes`` instances are used.
+
+        Returns a tuple with the following fields:
+
+        * Bool indicating whether values are available for retrieval.
+        * Integer indicating the number of bytes that were fully consumed,
+          starting from ``offset``.
+        * Integer indicating the number of bytes that are desired for the
+          next call in order to decode an item.
+        """
+        if not b:
+            return bool(self._decodedvalues), 0, 0
+
+        initialoffset = offset
+
+        # We could easily split the body of this loop into a function. But
+        # Python performance is sensitive to function calls and collections
+        # are composed of many items. So leaving as a while loop could help
+        # with performance. One thing that may not help is the use of
+        # if..elif versus a lookup/dispatch table. There may be value
+        # in switching that.
+        while offset < len(b):
+            # Attempt to decode an item. This could be a whole value or a
+            # special value indicating an event, such as start or end of a
+            # collection or indefinite length type.
+            complete, value, readcount, special = decodeitem(b, offset)
+
+            if readcount > 0:
+                self.decodedbytecount += readcount
+
+            if not complete:
+                assert readcount < 0
+                return (
+                    bool(self._decodedvalues),
+                    offset - initialoffset,
+                    -readcount,
+                )
+
+            offset += readcount
+
+            # No nested state. We either have a full value or beginning of a
+            # complex value to deal with.
+            if self._state == self._STATE_NONE:
+                # A normal value.
+                if special == SPECIAL_NONE:
+                    self._decodedvalues.append(value)
+
+                elif special == SPECIAL_START_ARRAY:
+                    self._collectionstack.append({
+                        'remaining': value,
+                        'v': [],
+                    })
+                    self._state = self._STATE_WANT_ARRAY_VALUE
+
+                elif special == SPECIAL_START_MAP:
+                    self._collectionstack.append({
+                        'remaining': value,
+                        'v': {},
+                    })
+                    self._state = self._STATE_WANT_MAP_KEY
+
+                elif special == SPECIAL_START_SET:
+                    self._collectionstack.append({
+                        'remaining': value,
+                        'v': set(),
+                    })
+                    self._state = self._STATE_WANT_SET_VALUE
+
+                elif special == SPECIAL_START_INDEFINITE_BYTESTRING:
+                    self._state = self._STATE_WANT_BYTESTRING_CHUNK_FIRST
+
+                else:
+                    raise CBORDecodeError('unhandled special state: %d' %
+                                          special)
+
+            # This value becomes an element of the current array.
+            elif self._state == self._STATE_WANT_ARRAY_VALUE:
+                # Simple values get appended.
+                if special == SPECIAL_NONE:
+                    c = self._collectionstack[-1]
+                    c['v'].append(value)
+                    c['remaining'] -= 1
+
+                    # self._state doesn't need to change.
+
+                # An array nested within an array.
+                elif special == SPECIAL_START_ARRAY:
+                    lastc = self._collectionstack[-1]
+                    newvalue = []
+
+                    lastc['v'].append(newvalue)
+                    lastc['remaining'] -= 1
+
+                    self._collectionstack.append({
+                        'remaining': value,
+                        'v': newvalue,
+                    })
+
+                    # self._state doesn't need to change.
+
+                # A map nested within an array.
+                elif special == SPECIAL_START_MAP:
+                    lastc = self._collectionstack[-1]
+                    newvalue = {}
+
+                    lastc['v'].append(newvalue)
+                    lastc['remaining'] -= 1
+
+                    self._collectionstack.append({
+                        'remaining': value,
+                        'v': newvalue
+                    })
+
+                    self._state = self._STATE_WANT_MAP_KEY
+
+                elif special == SPECIAL_START_SET:
+                    lastc = self._collectionstack[-1]
+                    newvalue = set()
+
+                    lastc['v'].append(newvalue)
+                    lastc['remaining'] -= 1
+
+                    self._collectionstack.append({
+                        'remaining': value,
+                        'v': newvalue,
+                    })
+
+                    self._state = self._STATE_WANT_SET_VALUE
+
+                elif special == SPECIAL_START_INDEFINITE_BYTESTRING:
+                    raise CBORDecodeError('indefinite length bytestrings '
+                                          'not allowed as array values')
+
+                else:
+                    raise CBORDecodeError('unhandled special item when '
+                                          'expecting array value: %d' % special)
+
+            # This value becomes the key of the current map instance.
+            elif self._state == self._STATE_WANT_MAP_KEY:
+                if special == SPECIAL_NONE:
+                    self._currentmapkey = value
+                    self._state = self._STATE_WANT_MAP_VALUE
+
+                elif special == SPECIAL_START_INDEFINITE_BYTESTRING:
+                    raise CBORDecodeError('indefinite length bytestrings '
+                                          'not allowed as map keys')
+
+                elif special in (SPECIAL_START_ARRAY, SPECIAL_START_MAP,
+                                 SPECIAL_START_SET):
+                    raise CBORDecodeError('collections not supported as map '
+                                          'keys')
+
+                # We do not allow special values to be used as map keys.
+                else:
+                    raise CBORDecodeError('unhandled special item when '
+                                          'expecting map key: %d' % special)
+
+            # This value becomes the value of the current map key.
+            elif self._state == self._STATE_WANT_MAP_VALUE:
+                # Simple values simply get inserted into the map.
+                if special == SPECIAL_NONE:
+                    lastc = self._collectionstack[-1]
+                    lastc['v'][self._currentmapkey] = value
+                    lastc['remaining'] -= 1
+
+                    self._state = self._STATE_WANT_MAP_KEY
+
+                # A new array is used as the map value.
+                elif special == SPECIAL_START_ARRAY:
+                    lastc = self._collectionstack[-1]
+                    newvalue = []
+
+                    lastc['v'][self._currentmapkey] = newvalue
+                    lastc['remaining'] -= 1
+
+                    self._collectionstack.append({
+                        'remaining': value,
+                        'v': newvalue,
+                    })
+
+                    self._state = self._STATE_WANT_ARRAY_VALUE
+
+                # A new map is used as the map value.
+                elif special == SPECIAL_START_MAP:
+                    lastc = self._collectionstack[-1]
+                    newvalue = {}
+
+                    lastc['v'][self._currentmapkey] = newvalue
+                    lastc['remaining'] -= 1
+
+                    self._collectionstack.append({
+                        'remaining': value,
+                        'v': newvalue,
+                    })
+
+                    self._state = self._STATE_WANT_MAP_KEY
+
+                # A new set is used as the map value.
+                elif special == SPECIAL_START_SET:
+                    lastc = self._collectionstack[-1]
+                    newvalue = set()
+
+                    lastc['v'][self._currentmapkey] = newvalue
+                    lastc['remaining'] -= 1
+
+                    self._collectionstack.append({
+                        'remaining': value,
+                        'v': newvalue,
+                    })
+
+                    self._state = self._STATE_WANT_SET_VALUE
+
+                elif special == SPECIAL_START_INDEFINITE_BYTESTRING:
+                    raise CBORDecodeError('indefinite length bytestrings not '
+                                          'allowed as map values')
+
+                else:
+                    raise CBORDecodeError('unhandled special item when '
+                                          'expecting map value: %d' % special)
+
+                self._currentmapkey = None
+
+            # This value is added to the current set.
+            elif self._state == self._STATE_WANT_SET_VALUE:
+                if special == SPECIAL_NONE:
+                    lastc = self._collectionstack[-1]
+                    lastc['v'].add(value)
+                    lastc['remaining'] -= 1
+
+                elif special == SPECIAL_START_INDEFINITE_BYTESTRING:
+                    raise CBORDecodeError('indefinite length bytestrings not '
+                                          'allowed as set values')
+
+                elif special in (SPECIAL_START_ARRAY,
+                                 SPECIAL_START_MAP,
+                                 SPECIAL_START_SET):
+                    raise CBORDecodeError('collections not allowed as set '
+                                          'values')
+
+                # We don't allow non-trivial types to exist as set values.
+                else:
+                    raise CBORDecodeError('unhandled special item when '
+                                          'expecting set value: %d' % special)
+
+            # This value represents the first chunk in an indefinite length
+            # bytestring.
+            elif self._state == self._STATE_WANT_BYTESTRING_CHUNK_FIRST:
+                # We received a full chunk.
+                if special == SPECIAL_NONE:
+                    self._decodedvalues.append(bytestringchunk(value,
+                                                               first=True))
+
+                    self._state = self._STATE_WANT_BYTESTRING_CHUNK_SUBSEQUENT
+
+                # The end of stream marker. This means it is an empty
+                # indefinite length bytestring.
+                elif special == SPECIAL_INDEFINITE_BREAK:
+                    # We /could/ convert this to a b''. But we want to preserve
+                    # the nature of the underlying data so consumers expecting
+                    # an indefinite length bytestring get one.
+                    self._decodedvalues.append(bytestringchunk(b'',
+                                                               first=True,
+                                                               last=True))
+
+                    # Since indefinite length bytestrings can't be used in
+                    # collections, we must be at the root level.
+                    assert not self._collectionstack
+                    self._state = self._STATE_NONE
+
+                else:
+                    raise CBORDecodeError('unexpected special value when '
+                                          'expecting bytestring chunk: %d' %
+                                          special)
+
+            # This value represents the non-initial chunk in an indefinite
+            # length bytestring.
+            elif self._state == self._STATE_WANT_BYTESTRING_CHUNK_SUBSEQUENT:
+                # We received a full chunk.
+                if special == SPECIAL_NONE:
+                    self._decodedvalues.append(bytestringchunk(value))
+
+                # The end of stream marker.
+                elif special == SPECIAL_INDEFINITE_BREAK:
+                    self._decodedvalues.append(bytestringchunk(b'', last=True))
+
+                    # Since indefinite length bytestrings can't be used in
+                    # collections, we must be at the root level.
+                    assert not self._collectionstack
+                    self._state = self._STATE_NONE
+
+                else:
+                    raise CBORDecodeError('unexpected special value when '
+                                          'expecting bytestring chunk: %d' %
+                                          special)
+
+            else:
+                raise CBORDecodeError('unhandled decoder state: %d' %
+                                      self._state)
+
+            # We could have just added the final value in a collection. End
+            # all complete collections at the top of the stack.
+            while True:
+                # Bail if we're not waiting on a new collection item.
+                if self._state not in (self._STATE_WANT_ARRAY_VALUE,
+                                       self._STATE_WANT_MAP_KEY,
+                                       self._STATE_WANT_SET_VALUE):
+                    break
+
+                # Also bail if more items are expected for this collection.
+                lastc = self._collectionstack[-1]
+
+                if lastc['remaining']:
+                    break
+
+                # The collection at the top of the stack is complete.
+
+                # Discard it, as it isn't needed for future items.
+                self._collectionstack.pop()
+
+                # If this is a nested collection, we don't emit it, since it
+                # will be emitted by its parent collection. But we do need to
+                # update state to reflect what the new top-most collection
+                # on the stack is.
+                if self._collectionstack:
+                    self._state = {
+                        list: self._STATE_WANT_ARRAY_VALUE,
+                        dict: self._STATE_WANT_MAP_KEY,
+                        set: self._STATE_WANT_SET_VALUE,
+                    }[type(self._collectionstack[-1]['v'])]
+
+                # If this is the root collection, emit it.
+                else:
+                    self._decodedvalues.append(lastc['v'])
+                    self._state = self._STATE_NONE
+
+        return (
+            bool(self._decodedvalues),
+            offset - initialoffset,
+            0,
+        )
+
+    def getavailable(self):
+        """Returns a list of fully decoded values.
+
+        Once values are retrieved, they won't be available on the next call.
+        """
+
+        l = list(self._decodedvalues)
+        self._decodedvalues = []
+        return l
+
+def decodeall(b):
+    """Decode all CBOR items present in an iterable of bytes.
+
+    In addition to regular decode errors, raises CBORDecodeError if the
+    entirety of the passed buffer does not fully decode to complete CBOR
+    values. This includes failure to decode any value, incomplete collection
+    types, incomplete indefinite length items, and extra data at the end of
+    the buffer.
+    """
+    if not b:
+        return []
+
+    decoder = sansiodecoder()
+
+    havevalues, readcount, wantbytes = decoder.decode(b)
+
+    if readcount != len(b):
+        raise CBORDecodeError('input data not fully consumed')
+
+    if decoder.inprogress:
+        raise CBORDecodeError('input data not complete')
+
+    return decoder.getavailable()
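
The collection handling above keeps a stack of {'remaining', 'v'} entries and pops every completed collection after each item. The following is a minimal standalone sketch of that idea, not the code from this patch. The constant names, the build() helper and the (special, value) item stream are hypothetical stand-ins, and sets, indefinite length bytestrings and type restrictions are left out.

```python
# Simplified, hypothetical model of the collection stack used in the patch.
# Items are (special, value) pairs; for START_* items, value is the number
# of entries the collection will hold. These constants are illustrative
# stand-ins, not Mercurial's actual values.
NONE, START_ARRAY, START_MAP = 0, 1, 2

_UNSET = object()  # sentinel: no map key is currently pending


def build(items):
    """Assemble nested lists/dicts from a flat stream of decoded items."""
    stack = []       # entries look like {'remaining': int, 'v': list|dict}
    done = []        # completed root-level values
    mapkey = _UNSET  # plays the role of _currentmapkey in the patch

    for special, value in items:
        if special == NONE:
            new, frame = value, None
        elif special == START_ARRAY:
            new = []
            frame = {'remaining': value, 'v': new}
        elif special == START_MAP:
            new = {}
            frame = {'remaining': value, 'v': new}
        else:
            raise ValueError('unhandled special item: %r' % (special,))

        if not stack:
            # Root level: plain values are emitted immediately.
            if frame is None:
                done.append(new)
                continue
        else:
            top = stack[-1]
            if isinstance(top['v'], dict):
                if mapkey is _UNSET:
                    # Expecting a key; collections are rejected as keys.
                    if frame is not None:
                        raise ValueError('collections not supported as '
                                         'map keys')
                    mapkey = new
                    continue
                top['v'][mapkey] = new
                mapkey = _UNSET
            else:
                top['v'].append(new)
            # The parent counts the child as soon as it is attached.
            top['remaining'] -= 1

        if frame is not None:
            stack.append(frame)

        # Pop every collection that is now complete. Only a root
        # collection is emitted; nested ones already live in their parent.
        while stack and not stack[-1]['remaining']:
            finished = stack.pop()
            if not stack:
                done.append(finished['v'])

    if stack:
        raise ValueError('input data not complete')
    return done


# An array of 3 items whose second item is a 1-entry map holding an array.
stream = [
    (START_ARRAY, 3), (NONE, 1),
    (START_MAP, 1), (NONE, 'a'),
    (START_ARRAY, 2), (NONE, 2), (NONE, 3),
    (NONE, 4),
]
result = build(stream)
```

A single pending-key variable is enough in this sketch because collections are rejected as map keys, so the item after a key is always its value, and the key is consumed when that value is attached to the map.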



To: indygreg, #hg-reviewers
Cc: mjpieters, mercurial-devel

