D4414: cborutil: implement sans I/O decoder

indygreg (Gregory Szorc) phabricator at mercurial-scm.org
Mon Sep 3 13:38:18 UTC 2018


This revision was automatically updated to reflect the committed changes.
Closed by commit rHGaeb551a3bb8a: cborutil: implement sans I/O decoder (authored by indygreg; the committer name is empty in the original notification).

REPOSITORY
  rHG Mercurial

CHANGES SINCE LAST UPDATE
  https://phab.mercurial-scm.org/D4414?vs=10626&id=10728

REVISION DETAIL
  https://phab.mercurial-scm.org/D4414

AFFECTED FILES
  mercurial/utils/cborutil.py
  tests/test-cbor.py

CHANGE DETAILS

diff --git a/tests/test-cbor.py b/tests/test-cbor.py
--- a/tests/test-cbor.py
+++ b/tests/test-cbor.py
@@ -10,10 +10,17 @@
     cborutil,
 )
 
+class TestCase(unittest.TestCase):
+    if not getattr(unittest.TestCase, 'assertRaisesRegex', False):
+        # Python 3.7 deprecates the regex*p* version, but 2.7 lacks
+        # the regex version.
+        assertRaisesRegex = (# camelcase-required
+            unittest.TestCase.assertRaisesRegexp)
+
 def loadit(it):
     return cbor.loads(b''.join(it))
 
-class BytestringTests(unittest.TestCase):
+class BytestringTests(TestCase):
     def testsimple(self):
         self.assertEqual(
             list(cborutil.streamencode(b'foobar')),
@@ -23,11 +30,20 @@
             loadit(cborutil.streamencode(b'foobar')),
             b'foobar')
 
+        self.assertEqual(cborutil.decodeall(b'\x46foobar'),
+                         [b'foobar'])
+        # Back-to-back items decode as separate values.
+        self.assertEqual(cborutil.decodeall(b'\x46foobar\x45fizbi'),
+                         [b'foobar', b'fizbi'])
+
     def testlong(self):
         source = b'x' * 1048576
 
         self.assertEqual(loadit(cborutil.streamencode(source)), source)
 
+        # decodeall() returns the 1 MiB bytestring as a single value.
+        encoded = b''.join(cborutil.streamencode(source))
+        self.assertEqual(cborutil.decodeall(encoded), [source])
+
     def testfromiter(self):
         # This is the example from RFC 7049 Section 2.2.2.
         source = [b'\xaa\xbb\xcc\xdd', b'\xee\xff\x99']
@@ -47,6 +63,25 @@
             loadit(cborutil.streamencodebytestringfromiter(source)),
             b''.join(source))
 
+        self.assertEqual(cborutil.decodeall(b'\x5f\x44\xaa\xbb\xcc\xdd'
+                                            b'\x43\xee\xff\x99\xff'),
+                         [b'\xaa\xbb\xcc\xdd', b'\xee\xff\x99', b''])
+        # Chunk 0 is isfirst; the empty chunk from the break is islast.
+        for i, chunk in enumerate(
+            cborutil.decodeall(b'\x5f\x44\xaa\xbb\xcc\xdd'
+                               b'\x43\xee\xff\x99\xff')):
+            self.assertIsInstance(chunk, cborutil.bytestringchunk)
+
+            if i == 0:
+                self.assertTrue(chunk.isfirst)
+            else:
+                self.assertFalse(chunk.isfirst)
+
+            if i == 2:
+                self.assertTrue(chunk.islast)
+            else:
+                self.assertFalse(chunk.islast)
+
     def testfromiterlarge(self):
         source = [b'a' * 16, b'b' * 128, b'c' * 1024, b'd' * 1048576]
 
@@ -71,6 +106,18 @@
             source, chunksize=42))
         self.assertEqual(cbor.loads(dest), source)
 
+        self.assertEqual(b''.join(cborutil.decodeall(dest)), source)
+        # Every decoded piece is a bytestringchunk of an expected size.
+        for chunk in cborutil.decodeall(dest):
+            self.assertIsInstance(chunk, cborutil.bytestringchunk)
+            self.assertIn(len(chunk), (0, 8, 42))
+        # An empty indefinite bytestring is one chunk, first and last.
+        encoded = b'\x5f\xff'
+        b = cborutil.decodeall(encoded)
+        self.assertEqual(b, [b''])
+        self.assertTrue(b[0].isfirst)
+        self.assertTrue(b[0].islast)
+
     def testreadtoiter(self):
         source = io.BytesIO(b'\x5f\x44\xaa\xbb\xcc\xdd\x43\xee\xff\x99\xff')
 
@@ -81,42 +128,405 @@
         with self.assertRaises(StopIteration):
             next(it)
 
-class IntTests(unittest.TestCase):
+    def testdecodevariouslengths(self):
+        for i in (0, 1, 22, 23, 24, 25, 254, 255, 256, 65534, 65535, 65536):
+            source = b'x' * i
+            encoded = b''.join(cborutil.streamencode(source))
+            # Expected header size for a bytestring of this length.
+            if len(source) < 24:
+                hlen = 1
+            elif len(source) < 256:
+                hlen = 2
+            elif len(source) < 65536:
+                hlen = 3
+            elif len(source) < 1048576:
+                hlen = 5
+            # A complete item gives (True, value, bytes consumed, special).
+            self.assertEqual(cborutil.decodeitem(encoded),
+                             (True, source, hlen + len(source),
+                              cborutil.SPECIAL_NONE))
+
+    def testpartialdecode(self):
+        encoded = b''.join(cborutil.streamencode(b'foobar'))
+        # Truncated input gives (False, None, -bytes still needed).
+        self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                         (False, None, -6, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+                         (False, None, -5, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+                         (False, None, -4, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:4]),
+                         (False, None, -3, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:5]),
+                         (False, None, -2, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:6]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:7]),
+                         (True, b'foobar', 7, cborutil.SPECIAL_NONE))
+
+    def testpartialdecodevariouslengths(self):
+        lens = [
+            2,
+            3,
+            10,
+            23,
+            24,
+            25,
+            31,
+            100,
+            254,
+            255,
+            256,
+            257,
+            16384,
+            65534,
+            65535,
+            65536,
+            65537,
+            131071,
+            131072,
+            131073,
+            1048575,
+            1048576,
+            1048577,
+        ]
+
+        for size in lens:
+            if size < 24:
+                hlen = 1
+            elif size < 2**8:
+                hlen = 2
+            elif size < 2**16:
+                hlen = 3
+            elif size < 2**32:
+                hlen = 5
+            else:
+                assert False
+
+            source = b'x' * size
+            encoded = b''.join(cborutil.streamencode(source))
+
+            res = cborutil.decodeitem(encoded[0:1])
+            # One byte in: missing header bytes, or payload if hlen is 1.
+            if hlen > 1:
+                self.assertEqual(res, (False, None, -(hlen - 1),
+                                       cborutil.SPECIAL_NONE))
+            else:
+                self.assertEqual(res, (False, None, -(size + hlen - 1),
+                                       cborutil.SPECIAL_NONE))
+
+            # Decoding partial header reports remaining header size.
+            for i in range(hlen - 1):
+                self.assertEqual(cborutil.decodeitem(encoded[0:i + 1]),
+                                 (False, None, -(hlen - i - 1),
+                                  cborutil.SPECIAL_NONE))
+
+            # Decoding complete header reports item size.
+            self.assertEqual(cborutil.decodeitem(encoded[0:hlen]),
+                             (False, None, -size, cborutil.SPECIAL_NONE))
+
+            # Decoding single byte after header reports item size - 1
+            self.assertEqual(cborutil.decodeitem(encoded[0:hlen + 1]),
+                             (False, None, -(size - 1), cborutil.SPECIAL_NONE))
+
+            # Decoding all but the last byte reports -1 needed.
+            self.assertEqual(cborutil.decodeitem(encoded[0:hlen + size - 1]),
+                             (False, None, -1, cborutil.SPECIAL_NONE))
+
+            # Decoding last byte retrieves value.
+            self.assertEqual(cborutil.decodeitem(encoded[0:hlen + size]),
+                             (True, source, hlen + size, cborutil.SPECIAL_NONE))
+
+    def testindefinitepartialdecode(self):
+        encoded = b''.join(cborutil.streamencodebytestringfromiter(
+            [b'foobar', b'biz']))
+
+        # First item should be begin of bytestring special.
+        self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                         (True, None, 1,
+                          cborutil.SPECIAL_START_INDEFINITE_BYTESTRING))
+
+        # Second item should be the first chunk. But only available when
+        # we give it 7 bytes (1 byte header + 6 byte chunk).
+        self.assertEqual(cborutil.decodeitem(encoded[1:2]),
+                         (False, None, -6, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[1:3]),
+                         (False, None, -5, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[1:4]),
+                         (False, None, -4, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[1:5]),
+                         (False, None, -3, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[1:6]),
+                         (False, None, -2, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[1:7]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+
+        self.assertEqual(cborutil.decodeitem(encoded[1:8]),
+                         (True, b'foobar', 7, cborutil.SPECIAL_NONE))
+
+        # Third item should be second chunk. But only available when
+        # we give it 4 bytes (1 byte header + 3 byte chunk).
+        self.assertEqual(cborutil.decodeitem(encoded[8:9]),
+                         (False, None, -3, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[8:10]),
+                         (False, None, -2, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[8:11]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+
+        self.assertEqual(cborutil.decodeitem(encoded[8:12]),
+                         (True, b'biz', 4, cborutil.SPECIAL_NONE))
+
+        # Fourth item should be end of indefinite stream marker.
+        self.assertEqual(cborutil.decodeitem(encoded[12:13]),
+                         (True, None, 1, cborutil.SPECIAL_INDEFINITE_BREAK))
+
+        # Now test the behavior when going through the decoder.
+        # decode() results look like (values ready, bytes used, bytes needed).
+        self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:1]),
+                         (False, 1, 0))
+        self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:2]),
+                         (False, 1, 6))
+        self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:3]),
+                         (False, 1, 5))
+        self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:4]),
+                         (False, 1, 4))
+        self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:5]),
+                         (False, 1, 3))
+        self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:6]),
+                         (False, 1, 2))
+        self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:7]),
+                         (False, 1, 1))
+        self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:8]),
+                         (True, 8, 0))
+
+        self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:9]),
+                         (True, 8, 3))
+        self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:10]),
+                         (True, 8, 2))
+        self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:11]),
+                         (True, 8, 1))
+        self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:12]),
+                         (True, 12, 0))
+
+        self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:13]),
+                         (True, 13, 0))
+        # A reused decoder emits one chunk per complete input slice.
+        decoder = cborutil.sansiodecoder()
+        decoder.decode(encoded[0:8])
+        values = decoder.getavailable()
+        self.assertEqual(values, [b'foobar'])
+        self.assertTrue(values[0].isfirst)
+        self.assertFalse(values[0].islast)
+
+        self.assertEqual(decoder.decode(encoded[8:12]),
+                         (True, 4, 0))
+        values = decoder.getavailable()
+        self.assertEqual(values, [b'biz'])
+        self.assertFalse(values[0].isfirst)
+        self.assertFalse(values[0].islast)
+
+        self.assertEqual(decoder.decode(encoded[12:]),
+                         (True, 1, 0))
+        values = decoder.getavailable()
+        self.assertEqual(values, [b''])
+        self.assertFalse(values[0].isfirst)
+        self.assertTrue(values[0].islast)
+
+class StringTests(TestCase):
+    def testdecodeforbidden(self):
+        encoded = b'\x63foo'
+        with self.assertRaisesRegex(cborutil.CBORDecodeError,
+                                    'string major type not supported'):
+            cborutil.decodeall(encoded)
+
+class IntTests(TestCase):
     def testsmall(self):
         self.assertEqual(list(cborutil.streamencode(0)), [b'\x00'])
+        self.assertEqual(cborutil.decodeall(b'\x00'), [0])
+
         self.assertEqual(list(cborutil.streamencode(1)), [b'\x01'])
+        self.assertEqual(cborutil.decodeall(b'\x01'), [1])
+
         self.assertEqual(list(cborutil.streamencode(2)), [b'\x02'])
+        self.assertEqual(cborutil.decodeall(b'\x02'), [2])
+
         self.assertEqual(list(cborutil.streamencode(3)), [b'\x03'])
+        self.assertEqual(cborutil.decodeall(b'\x03'), [3])
+
         self.assertEqual(list(cborutil.streamencode(4)), [b'\x04'])
+        self.assertEqual(cborutil.decodeall(b'\x04'), [4])
+
+        # Multiple value decode works.
+        self.assertEqual(cborutil.decodeall(b'\x00\x01\x02\x03\x04'),
+                         [0, 1, 2, 3, 4])
 
     def testnegativesmall(self):
         self.assertEqual(list(cborutil.streamencode(-1)), [b'\x20'])
+        self.assertEqual(cborutil.decodeall(b'\x20'), [-1])
+
         self.assertEqual(list(cborutil.streamencode(-2)), [b'\x21'])
+        self.assertEqual(cborutil.decodeall(b'\x21'), [-2])
+
         self.assertEqual(list(cborutil.streamencode(-3)), [b'\x22'])
+        self.assertEqual(cborutil.decodeall(b'\x22'), [-3])
+
         self.assertEqual(list(cborutil.streamencode(-4)), [b'\x23'])
+        self.assertEqual(cborutil.decodeall(b'\x23'), [-4])
+
         self.assertEqual(list(cborutil.streamencode(-5)), [b'\x24'])
+        self.assertEqual(cborutil.decodeall(b'\x24'), [-5])
+
+        # Multiple value decode works.
+        self.assertEqual(cborutil.decodeall(b'\x20\x21\x22\x23\x24'),
+                         [-1, -2, -3, -4, -5])
 
     def testrange(self):
         for i in range(-70000, 70000, 10):
-            self.assertEqual(
-                b''.join(cborutil.streamencode(i)),
-                cbor.dumps(i))
+            encoded = b''.join(cborutil.streamencode(i))
+            # Match cbor.dumps() and round-trip through decodeall().
+            self.assertEqual(encoded, cbor.dumps(i))
+            self.assertEqual(cborutil.decodeall(encoded), [i])
+
+    def testdecodepartialubyte(self):
+        encoded = b''.join(cborutil.streamencode(250))
+        # 250 encodes in 2 bytes: initial byte plus a 1-byte argument.
+        self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+                         (True, 250, 2, cborutil.SPECIAL_NONE))
+
+    def testdecodepartialbyte(self):
+        encoded = b''.join(cborutil.streamencode(-42))
+        self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+                         (True, -42, 2, cborutil.SPECIAL_NONE))
+
+    def testdecodepartialushort(self):
+        encoded = b''.join(cborutil.streamencode(2**15))
+        # 2**15 encodes in 3 bytes: initial byte plus a 2-byte argument.
+        self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                         (False, None, -2, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+                         (True, 2**15, 3, cborutil.SPECIAL_NONE))
+
+    def testdecodepartialshort(self):
+        encoded = b''.join(cborutil.streamencode(-1024))
+
+        self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                         (False, None, -2, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+                         (True, -1024, 3, cborutil.SPECIAL_NONE))
+
+    def testdecodepartialulong(self):
+        encoded = b''.join(cborutil.streamencode(2**28))
+        # 2**28 encodes in 5 bytes: initial byte plus a 4-byte argument.
+        self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                         (False, None, -4, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+                         (False, None, -3, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+                         (False, None, -2, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:4]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:5]),
+                         (True, 2**28, 5, cborutil.SPECIAL_NONE))
+
+    def testdecodepartiallong(self):
+        encoded = b''.join(cborutil.streamencode(-1048580))
 
-class ArrayTests(unittest.TestCase):
+        self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                         (False, None, -4, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+                         (False, None, -3, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+                         (False, None, -2, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:4]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:5]),
+                         (True, -1048580, 5, cborutil.SPECIAL_NONE))
+
+    def testdecodepartialulonglong(self):
+        encoded = b''.join(cborutil.streamencode(2**32))
+        # 2**32 encodes in 9 bytes: initial byte plus an 8-byte argument.
+        self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                         (False, None, -8, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+                         (False, None, -7, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+                         (False, None, -6, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:4]),
+                         (False, None, -5, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:5]),
+                         (False, None, -4, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:6]),
+                         (False, None, -3, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:7]),
+                         (False, None, -2, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:8]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:9]),
+                         (True, 2**32, 9, cborutil.SPECIAL_NONE))
+        # decodeall() rejects input that ends in the middle of an item.
+        with self.assertRaisesRegex(
+            cborutil.CBORDecodeError, 'input data not fully consumed'):
+            cborutil.decodeall(encoded[0:1])
+
+        with self.assertRaisesRegex(
+            cborutil.CBORDecodeError, 'input data not fully consumed'):
+            cborutil.decodeall(encoded[0:2])
+
+    def testdecodepartiallonglong(self):
+        encoded = b''.join(cborutil.streamencode(-7000000000))
+
+        self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                         (False, None, -8, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+                         (False, None, -7, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+                         (False, None, -6, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:4]),
+                         (False, None, -5, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:5]),
+                         (False, None, -4, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:6]),
+                         (False, None, -3, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:7]),
+                         (False, None, -2, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:8]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:9]),
+                         (True, -7000000000, 9, cborutil.SPECIAL_NONE))
+
+class ArrayTests(TestCase):
     def testempty(self):
         self.assertEqual(list(cborutil.streamencode([])), [b'\x80'])
         self.assertEqual(loadit(cborutil.streamencode([])), [])
 
+        self.assertEqual(cborutil.decodeall(b'\x80'), [[]])
+
     def testbasic(self):
         source = [b'foo', b'bar', 1, -10]
 
-        self.assertEqual(list(cborutil.streamencode(source)), [
-            b'\x84', b'\x43', b'foo', b'\x43', b'bar', b'\x01', b'\x29'])
+        chunks = [
+            b'\x84', b'\x43', b'foo', b'\x43', b'bar', b'\x01', b'\x29']
+
+        self.assertEqual(list(cborutil.streamencode(source)), chunks)
+        # The concatenated chunks decode back to the source array.
+        self.assertEqual(cborutil.decodeall(b''.join(chunks)), [source])
 
     def testemptyfromiter(self):
         self.assertEqual(b''.join(cborutil.streamencodearrayfromiter([])),
                          b'\x9f\xff')
 
+        with self.assertRaisesRegex(cborutil.CBORDecodeError,
+                                    'indefinite length uint not allowed'):
+            cborutil.decodeall(b'\x9f\xff')
+
     def testfromiter1(self):
         source = [b'foo']
 
@@ -129,57 +539,241 @@
         dest = b''.join(cborutil.streamencodearrayfromiter(source))
         self.assertEqual(cbor.loads(dest), source)
 
+        with self.assertRaisesRegex(cborutil.CBORDecodeError,
+                                    'indefinite length uint not allowed'):
+            cborutil.decodeall(dest)
+
     def testtuple(self):
         source = (b'foo', None, 42)
+        encoded = b''.join(cborutil.streamencode(source))
 
-        self.assertEqual(cbor.loads(b''.join(cborutil.streamencode(source))),
-                         list(source))
+        self.assertEqual(cbor.loads(encoded), list(source))
+        # decodeall() produces a list even though a tuple was encoded.
+        self.assertEqual(cborutil.decodeall(encoded), [list(source)])
+
+    def testpartialdecode(self):
+        source = list(range(4))
+        encoded = b''.join(cborutil.streamencode(source))
+        self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                         (True, 4, 1, cborutil.SPECIAL_START_ARRAY))
+        self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+                         (True, 4, 1, cborutil.SPECIAL_START_ARRAY))
+
+        source = list(range(23))
+        encoded = b''.join(cborutil.streamencode(source))
+        self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                         (True, 23, 1, cborutil.SPECIAL_START_ARRAY))
+        self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+                         (True, 23, 1, cborutil.SPECIAL_START_ARRAY))
+        # From 24 elements on, the length needs extra header bytes.
+        source = list(range(24))
+        encoded = b''.join(cborutil.streamencode(source))
+        self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+                         (True, 24, 2, cborutil.SPECIAL_START_ARRAY))
+        self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+                         (True, 24, 2, cborutil.SPECIAL_START_ARRAY))
 
-class SetTests(unittest.TestCase):
+        source = list(range(256))
+        encoded = b''.join(cborutil.streamencode(source))
+        self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                         (False, None, -2, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+                         (True, 256, 3, cborutil.SPECIAL_START_ARRAY))
+        self.assertEqual(cborutil.decodeitem(encoded[0:4]),
+                         (True, 256, 3, cborutil.SPECIAL_START_ARRAY))
+
+    def testnested(self):
+        source = [[], [], [[], [], []]]
+        encoded = b''.join(cborutil.streamencode(source))
+        self.assertEqual(cborutil.decodeall(encoded), [source])
+        # Scalars mixed with nested arrays.
+        source = [True, None, [True, 0, 2], [None], [], [[[]], -87]]
+        encoded = b''.join(cborutil.streamencode(source))
+        self.assertEqual(cborutil.decodeall(encoded), [source])
+
+        # A set within an array.
+        source = [None, {b'foo', b'bar', None, False}, set()]
+        encoded = b''.join(cborutil.streamencode(source))
+        self.assertEqual(cborutil.decodeall(encoded), [source])
+
+        # A map within an array.
+        source = [None, {}, {b'foo': b'bar', True: False}, [{}]]
+        encoded = b''.join(cborutil.streamencode(source))
+        self.assertEqual(cborutil.decodeall(encoded), [source])
+
+    def testindefinitebytestringvalues(self):
+        # Single value array whose value is an empty indefinite bytestring.
+        encoded = b'\x81\x5f\x40\xff'
+
+        with self.assertRaisesRegex(cborutil.CBORDecodeError,
+                                    'indefinite length bytestrings not '
+                                    'allowed as array values'):
+            cborutil.decodeall(encoded)
+
+class SetTests(TestCase):
     def testempty(self):
         self.assertEqual(list(cborutil.streamencode(set())), [
             b'\xd9\x01\x02',
             b'\x80',
         ])
 
+        self.assertEqual(cborutil.decodeall(b'\xd9\x01\x02\x80'), [set()])
+
     def testset(self):
         source = {b'foo', None, 42}
+        encoded = b''.join(cborutil.streamencode(source))
 
-        self.assertEqual(cbor.loads(b''.join(cborutil.streamencode(source))),
-                         source)
+        self.assertEqual(cbor.loads(encoded), source)
+        # decodeall() must agree with the reference cbor decoder.
+        self.assertEqual(cborutil.decodeall(encoded), [source])
+
+    def testinvalidtag(self):
+        # Must use array to encode sets.
+        encoded = b'\xd9\x01\x02\xa0'
+        # 0xa0 (an empty map) follows the set tag instead of an array.
+        with self.assertRaisesRegex(cborutil.CBORDecodeError,
+                                    'expected array after finite set '
+                                    'semantic tag'):
+            cborutil.decodeall(encoded)
+
+    def testpartialdecode(self):
+        # Semantic tag item will be 3 bytes. Set header will be variable
+        # depending on length.
+        encoded = b''.join(cborutil.streamencode({i for i in range(23)}))
+        self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                         (False, None, -2, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:4]),
+                         (True, 23, 4, cborutil.SPECIAL_START_SET))
+        self.assertEqual(cborutil.decodeitem(encoded[0:5]),
+                         (True, 23, 4, cborutil.SPECIAL_START_SET))
+        # 24 elements: 3-byte tag plus a 2-byte array header.
+        encoded = b''.join(cborutil.streamencode({i for i in range(24)}))
+        self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                         (False, None, -2, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:4]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:5]),
+                         (True, 24, 5, cborutil.SPECIAL_START_SET))
+        self.assertEqual(cborutil.decodeitem(encoded[0:6]),
+                         (True, 24, 5, cborutil.SPECIAL_START_SET))
 
-class BoolTests(unittest.TestCase):
+        encoded = b''.join(cborutil.streamencode({i for i in range(256)}))
+        self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                         (False, None, -2, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:4]),
+                         (False, None, -2, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:5]),
+                         (False, None, -1, cborutil.SPECIAL_NONE))
+        self.assertEqual(cborutil.decodeitem(encoded[0:6]),
+                         (True, 256, 6, cborutil.SPECIAL_START_SET))
+
+    def testinvalidvalue(self):
+        encoded = b''.join([
+            b'\xd9\x01\x02', # semantic tag
+            b'\x81', # array of size 1
+            b'\x5f\x43foo\xff', # indefinite length bytestring "foo"
+        ])
+
+        with self.assertRaisesRegex(cborutil.CBORDecodeError,
+                                    'indefinite length bytestrings not '
+                                    'allowed as set values'):
+            cborutil.decodeall(encoded)
+
+        encoded = b''.join([
+            b'\xd9\x01\x02',
+            b'\x81',
+            b'\x80', # empty array
+        ])
+
+        with self.assertRaisesRegex(cborutil.CBORDecodeError,
+                                    'collections not allowed as set values'):
+            cborutil.decodeall(encoded)
+
+        encoded = b''.join([
+            b'\xd9\x01\x02',
+            b'\x81',
+            b'\xa0', # empty map
+        ])
+
+        with self.assertRaisesRegex(cborutil.CBORDecodeError,
+                                    'collections not allowed as set values'):
+            cborutil.decodeall(encoded)
+
+        encoded = b''.join([
+            b'\xd9\x01\x02',
+            b'\x81',
+            b'\xd9\x01\x02\x81\x01', # set with integer 1
+        ])
+
+        with self.assertRaisesRegex(cborutil.CBORDecodeError,
+                                    'collections not allowed as set values'):
+            cborutil.decodeall(encoded)
+
+class BoolTests(TestCase):
     def testbasic(self):
         self.assertEqual(list(cborutil.streamencode(True)),  [b'\xf5'])
         self.assertEqual(list(cborutil.streamencode(False)), [b'\xf4'])
 
         self.assertIs(loadit(cborutil.streamencode(True)), True)
         self.assertIs(loadit(cborutil.streamencode(False)), False)
 
-class NoneTests(unittest.TestCase):
+        self.assertEqual(cborutil.decodeall(b'\xf4'), [False])
+        self.assertEqual(cborutil.decodeall(b'\xf5'), [True])
+        # Several booleans back to back decode in order.
+        self.assertEqual(cborutil.decodeall(b'\xf4\xf5\xf5\xf4'),
+                         [False, True, True, False])
+
+class NoneTests(TestCase):
     def testbasic(self):
         self.assertEqual(list(cborutil.streamencode(None)), [b'\xf6'])
 
         self.assertIs(loadit(cborutil.streamencode(None)), None)
 
-class MapTests(unittest.TestCase):
+        self.assertEqual(cborutil.decodeall(b'\xf6'), [None])
+        self.assertEqual(cborutil.decodeall(b'\xf6\xf6'), [None, None])
+
+class MapTests(TestCase):
     def testempty(self):
         self.assertEqual(list(cborutil.streamencode({})), [b'\xa0'])
         self.assertEqual(loadit(cborutil.streamencode({})), {})
 
+        self.assertEqual(cborutil.decodeall(b'\xa0'), [{}])
+
     def testemptyindefinite(self):
         self.assertEqual(list(cborutil.streamencodemapfromiter([])), [
             b'\xbf', b'\xff'])
 
         self.assertEqual(loadit(cborutil.streamencodemapfromiter([])), {})
 
+        with self.assertRaisesRegex(cborutil.CBORDecodeError,
+                                    'indefinite length uint not allowed'):
+            cborutil.decodeall(b'\xbf\xff')
+
     def testone(self):
         source = {b'foo': b'bar'}
         self.assertEqual(list(cborutil.streamencode(source)), [
             b'\xa1', b'\x43', b'foo', b'\x43', b'bar'])
 
         self.assertEqual(loadit(cborutil.streamencode(source)), source)
 
+        self.assertEqual(cborutil.decodeall(b'\xa1\x43foo\x43bar'), [source])
+
     def testmultiple(self):
         source = {
             b'foo': b'bar',
@@ -192,6 +786,9 @@
             loadit(cborutil.streamencodemapfromiter(source.items())),
             source)
 
+        encoded = b''.join(cborutil.streamencode(source))
+        self.assertEqual(cborutil.decodeall(encoded), [source])
+
     def testcomplex(self):
         source = {
             b'key': 1,
@@ -205,6 +802,170 @@
             loadit(cborutil.streamencodemapfromiter(source.items())),
             source)
 
+        encoded = b''.join(cborutil.streamencode(source))
+        self.assertEqual(cborutil.decodeall(encoded), [source])
+
+    def testnested(self):
+        # Maps containing nested maps, arrays and sets must round-trip
+        # through streamencode() and decodeall().
+        sources = [
+            {b'key1': None, b'key2': {b'sub1': b'sub2'}, b'sub2': {}},
+            {
+                b'key1': [],
+                b'key2': [None, False],
+                b'key3': {b'foo', b'bar'},
+                b'key4': {},
+            },
+        ]
+
+        for source in sources:
+            encoded = b''.join(cborutil.streamencode(source))
+            self.assertEqual(cborutil.decodeall(encoded), [source])
+
+    def testillegalkey(self):
+        # Each case is (encoded input, expected error message regex).
+        cases = [
+            # Map of length 1 whose key is the indefinite length
+            # bytestring "foo".
+            (b'\xa1' + b'\x5f\x03foo\xff',
+             'indefinite length bytestrings not allowed as map keys'),
+            # Map of length 1 whose key is an empty array.
+            (b'\xa1' + b'\x80' + b'\x43foo',
+             'collections not supported as map keys'),
+        ]
+
+        for encoded, message in cases:
+            with self.assertRaisesRegex(cborutil.CBORDecodeError, message):
+                cborutil.decodeall(encoded)
+
+    def testillegalvalue(self):
+        # Map of length 1: bytestring key "foo" followed by the indefinite
+        # length bytestring "bar" as its value.
+        encoded = b'\xa1' b'\x43foo' b'\x5f\x03bar\xff'
+
+        with self.assertRaisesRegex(cborutil.CBORDecodeError,
+                                    'indefinite length bytestrings not '
+                                    'allowed as map values'):
+            cborutil.decodeall(encoded)
+
+    def testpartialdecode(self):
+        def startmap(count, headersize):
+            return (True, count, headersize, cborutil.SPECIAL_START_MAP)
+
+        def needmore(count):
+            return (False, None, -count, cborutil.SPECIAL_NONE)
+
+        def manykeys(count):
+            return {b'key%d' % i: None for i in range(count)}
+
+        # Each case pairs a map with the results decodeitem() must return
+        # when it only sees the first N bytes of that map's encoding. Maps
+        # with fewer than 24 entries have a 1 byte header; larger ones need
+        # 1, 2 or 4 extra bytes to hold the entry count.
+        cases = [
+            ({b'key1': b'value1'}, [(1, startmap(1, 1)),
+                                    (2, startmap(1, 1))]),
+            (manykeys(23), [(1, startmap(23, 1))]),
+            (manykeys(24), [(1, needmore(1)),
+                            (2, startmap(24, 2)),
+                            (3, startmap(24, 2))]),
+            (manykeys(256), [(1, needmore(2)),
+                             (2, needmore(1)),
+                             (3, startmap(256, 3)),
+                             (4, startmap(256, 3))]),
+            (manykeys(65536), [(1, needmore(4)),
+                               (2, needmore(3)),
+                               (3, needmore(2)),
+                               (4, needmore(1)),
+                               (5, startmap(65536, 5)),
+                               (6, startmap(65536, 5))]),
+        ]
+
+        for source, expectations in cases:
+            encoded = b''.join(cborutil.streamencode(source))
+            for length, expected in expectations:
+                self.assertEqual(cborutil.decodeitem(encoded[0:length]),
+                                 expected)
+
+class SemanticTagTests(TestCase):
+    def testdecodeforbidden(self):
+        # Every semantic tag except the finite set tag must be rejected.
+        for i in range(500):
+            if i == cborutil.SEMANTIC_TAG_FINITE_SET:
+                continue
+
+            tag = cborutil.encodelength(cborutil.MAJOR_TYPE_SEMANTIC, i)
+            encoded = tag + cborutil.encodelength(cborutil.MAJOR_TYPE_UINT, 42)
+
+            # Tags below 24 live in the initial byte and cannot be truncated.
+            # Larger tags store their value in 1 or 2 following bytes, so a
+            # truncated header must report how many bytes are still missing.
+            if 24 <= i < 256:
+                self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                                 (False, None, -1, cborutil.SPECIAL_NONE))
+            elif 256 <= i < 65536:
+                self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+                                 (False, None, -2, cborutil.SPECIAL_NONE))
+                self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+                                 (False, None, -1, cborutil.SPECIAL_NONE))
+
+            # Raw string: '\d' in a plain literal is an invalid escape
+            # sequence (DeprecationWarning, SyntaxWarning on Python 3.12+).
+            with self.assertRaisesRegex(cborutil.CBORDecodeError,
+                                        r'semantic tag \d+ not allowed'):
+                cborutil.decodeitem(encoded)
+
+class SpecialTypesTests(TestCase):
+    def testforbiddentypes(self):
+        # Only false, true and null are supported simple values; every other
+        # special type must be rejected by decodeitem().
+        allowed = (cborutil.SUBTYPE_FALSE, cborutil.SUBTYPE_TRUE,
+                   cborutil.SUBTYPE_NULL)
+
+        for i in range(256):
+            if i in allowed:
+                continue
+
+            encoded = cborutil.encodelength(cborutil.MAJOR_TYPE_SPECIAL, i)
+
+            # Raw string: '\d' in a plain literal is an invalid escape
+            # sequence (DeprecationWarning, SyntaxWarning on Python 3.12+).
+            with self.assertRaisesRegex(cborutil.CBORDecodeError,
+                                        r'special type \d+ not allowed'):
+                cborutil.decodeitem(encoded)
+
+class SansIODecoderTests(TestCase):
+    def testemptyinput(self):
+        # No input: no values available, nothing consumed, nothing wanted.
+        self.assertEqual(cborutil.sansiodecoder().decode(b''), (False, 0, 0))
+
+class DecodeallTests(TestCase):
+    def testemptyinput(self):
+        self.assertEqual(cborutil.decodeall(b''), [])
+
+    def testpartialinput(self):
+        # Array header announcing 2 elements (0x82) followed by only one of
+        # them, the integer 1 (0x01).
+        encoded = b'\x82\x01'
+
+        with self.assertRaisesRegex(cborutil.CBORDecodeError,
+                                    'input data not complete'):
+            cborutil.decodeall(encoded)
+
+# Allow running this file directly through Mercurial's silenttestrunner.
 if __name__ == '__main__':
     import silenttestrunner
     silenttestrunner.main(__name__)
diff --git a/mercurial/utils/cborutil.py b/mercurial/utils/cborutil.py
--- a/mercurial/utils/cborutil.py
+++ b/mercurial/utils/cborutil.py
@@ -8,6 +8,7 @@
 from __future__ import absolute_import
 
 import struct
+import sys
 
 from ..thirdparty.cbor.cbor2 import (
     decoder as decodermod,
@@ -35,11 +36,16 @@
 
 SUBTYPE_MASK = 0b00011111
 
+# Simple values carried in the low 5 bits of a MAJOR_TYPE_SPECIAL item; these
+# are the only simple values decodeitem() accepts besides the break marker.
+SUBTYPE_FALSE = 20
+SUBTYPE_TRUE = 21
+SUBTYPE_NULL = 22
 SUBTYPE_HALF_FLOAT = 25
 SUBTYPE_SINGLE_FLOAT = 26
 SUBTYPE_DOUBLE_FLOAT = 27
 SUBTYPE_INDEFINITE = 31
 
+# Semantic tag number streamencodeset() emits before a set's element array;
+# it is the only semantic tag decodeitem() accepts.
+SEMANTIC_TAG_FINITE_SET = 258
+
 # Indefinite types begin with their major type ORd with information value 31.
 BEGIN_INDEFINITE_BYTESTRING = struct.pack(
     r'>B', MAJOR_TYPE_BYTESTRING << 5 | SUBTYPE_INDEFINITE)
@@ -146,7 +152,7 @@
 def streamencodeset(s):
     # https://www.iana.org/assignments/cbor-tags/cbor-tags.xhtml defines
     # semantic tag 258 for finite sets.
-    yield encodelength(MAJOR_TYPE_SEMANTIC, 258)
+    yield encodelength(MAJOR_TYPE_SEMANTIC, SEMANTIC_TAG_FINITE_SET)
 
+    # The tag is followed by a regular array of the elements, ordered with
+    # _mixedtypesortkey rather than in raw set iteration order.
     for chunk in streamencodearray(sorted(s, key=_mixedtypesortkey)):
         yield chunk
@@ -260,3 +266,710 @@
                     len(chunk), length))
 
         yield chunk
+
+class CBORDecodeError(Exception):
+    """Represents an error decoding CBOR.
+
+    Raised by decodeitem(), decodeuint() and sansiodecoder when the input
+    is malformed or uses a CBOR feature this module does not support.
+    """
+
+if sys.version_info.major < 3:
+    # Python 2: indexing a byte string yields a length-1 str, so convert it.
+    def _elementtointeger(b, i):
+        return ord(b[i])
+else:
+    # Python 3: indexing bytes already yields an int.
+    def _elementtointeger(b, i):
+        return b[i]
+
+# Precompiled big-endian unsigned integer parsers used by decodeuint() for
+# values stored in the 1, 2, 4 or 8 bytes following an item's initial byte.
+# Formats use r'' like the module's other struct format strings.
+STRUCT_BIG_UBYTE = struct.Struct(r'>B')
+STRUCT_BIG_USHORT = struct.Struct(r'>H')
+STRUCT_BIG_ULONG = struct.Struct(r'>L')
+STRUCT_BIG_ULONGLONG = struct.Struct(r'>Q')
+
+# Values for the last element of the tuple returned by decodeitem(), telling
+# the caller how to treat the decoded item.
+SPECIAL_NONE = 0  # simple value, or more bytes are needed
+SPECIAL_START_INDEFINITE_BYTESTRING = 1  # chunks follow until a break
+SPECIAL_START_ARRAY = 2  # value is the number of array elements
+SPECIAL_START_MAP = 3  # value is the number of key/value pairs
+SPECIAL_START_SET = 4  # value is the number of set elements
+SPECIAL_INDEFINITE_BREAK = 5  # end of an indefinite length bytestring
+
+def decodeitem(b, offset=0):
+    """Decode a new CBOR value from a buffer at offset.
+
+    This function attempts to decode up to one complete CBOR value
+    from ``b`` starting at offset ``offset``.
+
+    The beginning of a collection (such as an array, map, set, or
+    indefinite length bytestring) counts as a single value. For these
+    special cases, a state flag will indicate that a special value was seen.
+
+    When called, the function either returns a decoded value or gives
+    a hint as to how many more bytes are needed to do so. By calling
+    the function repeatedly given a stream of bytes, the caller can
+    build up the original values.
+
+    Returns a tuple with the following elements:
+
+    * Bool indicating whether a complete value was decoded.
+    * The decoded value if the first element is True, otherwise None. For
+      the start of an array, map or set this is the number of elements in
+      the collection; for the start of an indefinite length bytestring it
+      is None.
+    * Integer number of bytes. If positive, the number of bytes read,
+      including the initial byte. If negative, the number of additional
+      bytes we need to read to decode this value or the next chunk in this
+      value.
+    * One of the ``SPECIAL_*`` constants indicating special treatment
+      for this value. ``SPECIAL_NONE`` means this is a fully decoded
+      simple value (such as an integer or bool), or that more bytes are
+      needed.
+
+    If ``offset`` is at or beyond the end of ``b``, one more byte is
+    requested instead of raising ``IndexError``.
+
+    Raises ``CBORDecodeError`` when the item is malformed or uses a CBOR
+    feature this decoder does not support.
+    """
+    # Nothing left to read: ask for the initial byte.
+    if offset >= len(b):
+        return False, None, -1, SPECIAL_NONE
+
+    initial = _elementtointeger(b, offset)
+    offset += 1
+
+    majortype = initial >> 5
+    subtype = initial & SUBTYPE_MASK
+
+    if majortype == MAJOR_TYPE_UINT:
+        complete, value, readcount = decodeuint(subtype, b, offset)
+
+        if complete:
+            return True, value, readcount + 1, SPECIAL_NONE
+        else:
+            return False, None, readcount, SPECIAL_NONE
+
+    elif majortype == MAJOR_TYPE_NEGINT:
+        # Negative integers are the same as UINT except inverted minus 1.
+        complete, value, readcount = decodeuint(subtype, b, offset)
+
+        if complete:
+            return True, -value - 1, readcount + 1, SPECIAL_NONE
+        else:
+            return False, None, readcount, SPECIAL_NONE
+
+    elif majortype == MAJOR_TYPE_BYTESTRING:
+        # Beginning of bytestrings are treated as uints in order to
+        # decode their length, which may be indefinite.
+        complete, size, readcount = decodeuint(subtype, b, offset,
+                                               allowindefinite=True)
+
+        # The length header is truncated. This can only be a definite length
+        # bytestring, since the indefinite marker is carried entirely in the
+        # initial byte.
+        if not complete:
+            return False, None, readcount, SPECIAL_NONE
+
+        # We know the length of the bytestring.
+        if size is not None:
+            # And the data is available in the buffer.
+            if offset + readcount + size <= len(b):
+                value = b[offset + readcount:offset + readcount + size]
+                return True, value, readcount + size + 1, SPECIAL_NONE
+
+            # And we need more data in order to return the bytestring.
+            # ``wanted`` is negative: minus the number of missing bytes.
+            else:
+                wanted = len(b) - offset - readcount - size
+                return False, None, wanted, SPECIAL_NONE
+
+        # It is an indefinite length bytestring.
+        else:
+            return True, None, 1, SPECIAL_START_INDEFINITE_BYTESTRING
+
+    elif majortype == MAJOR_TYPE_STRING:
+        raise CBORDecodeError('string major type not supported')
+
+    elif majortype == MAJOR_TYPE_ARRAY:
+        # Beginning of arrays are treated as uints in order to decode their
+        # length. We don't allow indefinite length arrays.
+        complete, size, readcount = decodeuint(subtype, b, offset)
+
+        if complete:
+            return True, size, readcount + 1, SPECIAL_START_ARRAY
+        else:
+            return False, None, readcount, SPECIAL_NONE
+
+    elif majortype == MAJOR_TYPE_MAP:
+        # Beginning of maps are treated as uints in order to decode their
+        # number of elements. We don't allow indefinite length maps.
+        complete, size, readcount = decodeuint(subtype, b, offset)
+
+        if complete:
+            return True, size, readcount + 1, SPECIAL_START_MAP
+        else:
+            return False, None, readcount, SPECIAL_NONE
+
+    elif majortype == MAJOR_TYPE_SEMANTIC:
+        # Semantic tag value is read the same as a uint.
+        complete, tagvalue, readcount = decodeuint(subtype, b, offset)
+
+        if not complete:
+            return False, None, readcount, SPECIAL_NONE
+
+        # This behavior here is a little wonky. The main type being "decorated"
+        # by this semantic tag follows. A more robust parser would probably emit
+        # a special flag indicating this as a semantic tag and let the caller
+        # deal with the types that follow. But since we don't support many
+        # semantic tags, it is easier to deal with the special cases here and
+        # hide complexity from the caller. If we add support for more semantic
+        # tags, we should probably move semantic tag handling into the caller.
+        if tagvalue == SEMANTIC_TAG_FINITE_SET:
+            if offset + readcount >= len(b):
+                return False, None, -1, SPECIAL_NONE
+
+            complete, size, readcount2, special = decodeitem(b,
+                                                             offset + readcount)
+
+            if not complete:
+                return False, None, readcount2, SPECIAL_NONE
+
+            if special != SPECIAL_START_ARRAY:
+                raise CBORDecodeError('expected array after finite set '
+                                      'semantic tag')
+
+            return True, size, readcount + readcount2 + 1, SPECIAL_START_SET
+
+        else:
+            raise CBORDecodeError('semantic tag %d not allowed' % tagvalue)
+
+    elif majortype == MAJOR_TYPE_SPECIAL:
+        # Only specific values for the information field are allowed.
+        if subtype == SUBTYPE_FALSE:
+            return True, False, 1, SPECIAL_NONE
+        elif subtype == SUBTYPE_TRUE:
+            return True, True, 1, SPECIAL_NONE
+        elif subtype == SUBTYPE_NULL:
+            return True, None, 1, SPECIAL_NONE
+        elif subtype == SUBTYPE_INDEFINITE:
+            return True, None, 1, SPECIAL_INDEFINITE_BREAK
+        # Everything else (floats, undefined, and other simple values,
+        # including those whose value follows in the next byte) is rejected.
+        else:
+            raise CBORDecodeError('special type %d not allowed' % subtype)
+    else:
+        # Should be unreachable: ``initial >> 5`` of a byte is 0-7, presumably
+        # all covered by the MAJOR_TYPE_* branches above. Raise rather than
+        # ``assert`` so the check still holds under ``python -O``.
+        raise CBORDecodeError('unhandled major type: %d' % majortype)
+
+def decodeuint(subtype, b, offset=0, allowindefinite=False):
+    """Decode an unsigned integer from a CBOR item header.
+
+    ``subtype`` holds the low 5 bits of the item's initial byte and
+    ``offset`` is the index in ``b`` of the first byte *after* that
+    initial byte.
+
+    When ``allowindefinite`` is true, subtype 31 (the indefinite length
+    marker) is accepted and produces a value of None; otherwise it raises
+    ``CBORDecodeError``.
+
+    Returns a 3-tuple ``(complete, value, count)``:
+
+    * ``complete`` is True once the integer (or indefinite marker) is known.
+    * ``value`` is the integer, or None when incomplete or indefinite.
+    * ``count`` is, when non-negative, the number of bytes consumed after
+      the initial byte; when negative, minus the number of additional bytes
+      required before the value can be decoded.
+    """
+    # Subtypes below 24 are the value itself; no extra bytes are read.
+    if subtype < 24:
+        return True, subtype, 0
+
+    # Subtype 31 flags an indefinite length.
+    if subtype == 31:
+        if not allowindefinite:
+            raise CBORDecodeError('indefinite length uint not allowed here')
+        return True, None, 0
+
+    # Subtypes 28-30 (and anything out of the 5-bit range) are rejected.
+    if subtype >= 28:
+        raise CBORDecodeError('unsupported subtype on integer type: %d' %
+                              subtype)
+
+    # Subtypes 24-27 carry the value in the next 1, 2, 4 or 8 bytes.
+    if subtype == 24:
+        parser = STRUCT_BIG_UBYTE
+    elif subtype == 25:
+        parser = STRUCT_BIG_USHORT
+    elif subtype == 26:
+        parser = STRUCT_BIG_ULONG
+    elif subtype == 27:
+        parser = STRUCT_BIG_ULONGLONG
+    else:
+        raise CBORDecodeError('bounds condition checking violation')
+
+    # Bytes left over after the value; negative means the buffer is short.
+    surplus = len(b) - offset - parser.size
+    if surplus < 0:
+        return False, None, surplus
+
+    return True, parser.unpack_from(b, offset)[0], parser.size
+
+class bytestringchunk(bytes):
+    """A ``bytes`` segment of an indefinite length bytestring.
+
+    Instances behave exactly like ``bytes``. The extra boolean attributes
+    ``isfirst`` and ``islast`` flag whether this segment opens or closes
+    its indefinite length bytestring.
+    """
+
+    def __new__(cls, v, first=False, last=False):
+        chunk = super(bytestringchunk, cls).__new__(cls, v)
+        chunk.isfirst, chunk.islast = first, last
+        return chunk
+
+class sansiodecoder(object):
+    """A CBOR decoder that doesn't perform its own I/O.
+
+    To use, construct an instance and feed it segments containing
+    CBOR-encoded bytes via ``decode()``. The return value from ``decode()``
+    indicates whether a fully-decoded value is available, how many bytes
+    were consumed, and offers a hint as to how many bytes should be fed
+    in next time to decode the next value.
+
+    The decoder assumes it will decode N discrete CBOR values, not just
+    a single value. i.e. if the bytestream contains uints packed one after
+    the other, the decoder will decode them all, rather than just the initial
+    one.
+
+    When ``decode()`` indicates a value is available, call ``getavailable()``
+    to return all fully decoded values.
+
+    ``decode()`` can partially decode input. It is up to the caller to keep
+    track of what data was consumed and to pass unconsumed data in on the
+    next invocation.
+
+    The decoder decodes atomically at the *item* level. See ``decodeitem()``.
+    If an *item* cannot be fully decoded, the decoder won't record it as
+    partially consumed. Instead, the caller will be instructed to pass in
+    the initial bytes of this item on the next invocation. This does result
+    in some redundant parsing. But the overhead should be minimal.
+
+    This decoder only supports a subset of CBOR as required by Mercurial.
+    It lacks support for:
+
+    * Indefinite length arrays
+    * Indefinite length maps
+    * Use of indefinite length bytestrings as keys or values within
+      arrays, maps, or sets.
+    * Nested arrays, maps, or sets within sets
+    * Any semantic tag that isn't a mathematical finite set
+    * Floating point numbers
+    * Undefined special value
+
+    CBOR types are decoded to Python types as follows:
+
+    uint -> int
+    negint -> int
+    bytestring -> bytes
+    map -> dict
+    array -> list
+    True -> bool
+    False -> bool
+    null -> None
+    indefinite length bytestring chunk -> [bytestringchunk]
+
+    The only non-obvious mapping here is an indefinite length bytestring
+    to the ``bytestringchunk`` type. This is to facilitate streaming
+    indefinite length bytestrings out of the decoder and to differentiate
+    a regular bytestring from an indefinite length bytestring.
+    """
+
+    _STATE_NONE = 0
+    _STATE_WANT_MAP_KEY = 1
+    _STATE_WANT_MAP_VALUE = 2
+    _STATE_WANT_ARRAY_VALUE = 3
+    _STATE_WANT_SET_VALUE = 4
+    _STATE_WANT_BYTESTRING_CHUNK_FIRST = 5
+    _STATE_WANT_BYTESTRING_CHUNK_SUBSEQUENT = 6
+
+    def __init__(self):
+        # TODO add support for limiting size of bytestrings
+        # TODO add support for limiting number of keys / values in collections
+        # TODO add support for limiting size of buffered partial values
+
+        self.decodedbytecount = 0
+
+        self._state = self._STATE_NONE
+
+        # Stack of active nested collections. Each entry is a dict describing
+        # the collection.
+        self._collectionstack = []
+
+        # Fully decoded key to use for the current map.
+        self._currentmapkey = None
+
+        # Fully decoded values available for retrieval.
+        self._decodedvalues = []
+
+    @property
+    def inprogress(self):
+        """Whether the decoder has partially decoded a value."""
+        return self._state != self._STATE_NONE
+
+    def decode(self, b, offset=0):
+        """Attempt to decode bytes from an input buffer.
+
+        ``b`` is a collection of bytes and ``offset`` is the byte
+        offset within that buffer from which to begin reading data.
+
+        ``b`` must support ``len()`` and accessing bytes slices via
+        ``__slice__``. Typically ``bytes`` instances are used.
+
+        Returns a tuple with the following fields:
+
+        * Bool indicating whether values are available for retrieval.
+        * Integer indicating the number of bytes that were fully consumed,
+          starting from ``offset``.
+        * Integer indicating the number of bytes that are desired for the
+          next call in order to decode an item.
+        """
+        if not b:
+            return bool(self._decodedvalues), 0, 0
+
+        initialoffset = offset
+
+        # We could easily split the body of this loop into a function. But
+        # Python performance is sensitive to function calls and collections
+        # are composed of many items. So leaving as a while loop could help
+        # with performance. One thing that may not help is the use of
+        # if..elif versus a lookup/dispatch table. There may be value
+        # in switching that.
+        while offset < len(b):
+            # Attempt to decode an item. This could be a whole value or a
+            # special value indicating an event, such as start or end of a
+            # collection or indefinite length type.
+            complete, value, readcount, special = decodeitem(b, offset)
+
+            if readcount > 0:
+                self.decodedbytecount += readcount
+
+            if not complete:
+                assert readcount < 0
+                return (
+                    bool(self._decodedvalues),
+                    offset - initialoffset,
+                    -readcount,
+                )
+
+            offset += readcount
+
+            # No nested state. We either have a full value or beginning of a
+            # complex value to deal with.
+            if self._state == self._STATE_NONE:
+                # A normal value.
+                if special == SPECIAL_NONE:
+                    self._decodedvalues.append(value)
+
+                elif special == SPECIAL_START_ARRAY:
+                    self._collectionstack.append({
+                        'remaining': value,
+                        'v': [],
+                    })
+                    self._state = self._STATE_WANT_ARRAY_VALUE
+
+                elif special == SPECIAL_START_MAP:
+                    self._collectionstack.append({
+                        'remaining': value,
+                        'v': {},
+                    })
+                    self._state = self._STATE_WANT_MAP_KEY
+
+                elif special == SPECIAL_START_SET:
+                    self._collectionstack.append({
+                        'remaining': value,
+                        'v': set(),
+                    })
+                    self._state = self._STATE_WANT_SET_VALUE
+
+                elif special == SPECIAL_START_INDEFINITE_BYTESTRING:
+                    self._state = self._STATE_WANT_BYTESTRING_CHUNK_FIRST
+
+                else:
+                    raise CBORDecodeError('unhandled special state: %d' %
+                                          special)
+
+            # This value becomes an element of the current array.
+            elif self._state == self._STATE_WANT_ARRAY_VALUE:
+                # Simple values get appended.
+                if special == SPECIAL_NONE:
+                    c = self._collectionstack[-1]
+                    c['v'].append(value)
+                    c['remaining'] -= 1
+
+                    # self._state doesn't need changed.
+
+                # An array nested within an array.
+                elif special == SPECIAL_START_ARRAY:
+                    lastc = self._collectionstack[-1]
+                    newvalue = []
+
+                    lastc['v'].append(newvalue)
+                    lastc['remaining'] -= 1
+
+                    self._collectionstack.append({
+                        'remaining': value,
+                        'v': newvalue,
+                    })
+
+                    # self._state doesn't need changed.
+
+                # A map nested within an array.
+                elif special == SPECIAL_START_MAP:
+                    lastc = self._collectionstack[-1]
+                    newvalue = {}
+
+                    lastc['v'].append(newvalue)
+                    lastc['remaining'] -= 1
+
+                    self._collectionstack.append({
+                        'remaining': value,
+                        'v': newvalue
+                    })
+
+                    self._state = self._STATE_WANT_MAP_KEY
+
+                elif special == SPECIAL_START_SET:
+                    lastc = self._collectionstack[-1]
+                    newvalue = set()
+
+                    lastc['v'].append(newvalue)
+                    lastc['remaining'] -= 1
+
+                    self._collectionstack.append({
+                        'remaining': value,
+                        'v': newvalue,
+                    })
+
+                    self._state = self._STATE_WANT_SET_VALUE
+
+                elif special == SPECIAL_START_INDEFINITE_BYTESTRING:
+                    raise CBORDecodeError('indefinite length bytestrings '
+                                          'not allowed as array values')
+
+                else:
+                    raise CBORDecodeError('unhandled special item when '
+                                          'expecting array value: %d' % special)
+
+            # This value becomes the key of the current map instance.
+            elif self._state == self._STATE_WANT_MAP_KEY:
+                if special == SPECIAL_NONE:
+                    self._currentmapkey = value
+                    self._state = self._STATE_WANT_MAP_VALUE
+
+                elif special == SPECIAL_START_INDEFINITE_BYTESTRING:
+                    raise CBORDecodeError('indefinite length bytestrings '
+                                          'not allowed as map keys')
+
+                elif special in (SPECIAL_START_ARRAY, SPECIAL_START_MAP,
+                                 SPECIAL_START_SET):
+                    raise CBORDecodeError('collections not supported as map '
+                                          'keys')
+
+                # We do not allow special values to be used as map keys.
+                else:
+                    raise CBORDecodeError('unhandled special item when '
+                                          'expecting map key: %d' % special)
+
+            # This value becomes the value of the current map key.
+            elif self._state == self._STATE_WANT_MAP_VALUE:
+                # Simple values simply get inserted into the map.
+                if special == SPECIAL_NONE:
+                    lastc = self._collectionstack[-1]
+                    lastc['v'][self._currentmapkey] = value
+                    lastc['remaining'] -= 1
+
+                    self._state = self._STATE_WANT_MAP_KEY
+
+                # A new array is used as the map value.
+                elif special == SPECIAL_START_ARRAY:
+                    lastc = self._collectionstack[-1]
+                    newvalue = []
+
+                    lastc['v'][self._currentmapkey] = newvalue
+                    lastc['remaining'] -= 1
+
+                    self._collectionstack.append({
+                        'remaining': value,
+                        'v': newvalue,
+                    })
+
+                    self._state = self._STATE_WANT_ARRAY_VALUE
+
+                # A new map is used as the map value.
+                elif special == SPECIAL_START_MAP:
+                    lastc = self._collectionstack[-1]
+                    newvalue = {}
+
+                    lastc['v'][self._currentmapkey] = newvalue
+                    lastc['remaining'] -= 1
+
+                    self._collectionstack.append({
+                        'remaining': value,
+                        'v': newvalue,
+                    })
+
+                    self._state = self._STATE_WANT_MAP_KEY
+
+                # A new set is used as the map value.
+                elif special == SPECIAL_START_SET:
+                    lastc = self._collectionstack[-1]
+                    newvalue = set()
+
+                    lastc['v'][self._currentmapkey] = newvalue
+                    lastc['remaining'] -= 1
+
+                    self._collectionstack.append({
+                        'remaining': value,
+                        'v': newvalue,
+                    })
+
+                    self._state = self._STATE_WANT_SET_VALUE
+
+                elif special == SPECIAL_START_INDEFINITE_BYTESTRING:
+                    raise CBORDecodeError('indefinite length bytestrings not '
+                                          'allowed as map values')
+
+                else:
+                    raise CBORDecodeError('unhandled special item when '
+                                          'expecting map value: %d' % special)
+
+                self._currentmapkey = None
+
+            # This value is added to the current set.
+            elif self._state == self._STATE_WANT_SET_VALUE:
+                if special == SPECIAL_NONE:
+                    lastc = self._collectionstack[-1]
+                    lastc['v'].add(value)
+                    lastc['remaining'] -= 1
+
+                elif special == SPECIAL_START_INDEFINITE_BYTESTRING:
+                    raise CBORDecodeError('indefinite length bytestrings not '
+                                          'allowed as set values')
+
+                elif special in (SPECIAL_START_ARRAY,
+                                 SPECIAL_START_MAP,
+                                 SPECIAL_START_SET):
+                    raise CBORDecodeError('collections not allowed as set '
+                                          'values')
+
+                # We don't allow non-trivial types to exist as set values.
+                else:
+                    raise CBORDecodeError('unhandled special item when '
+                                          'expecting set value: %d' % special)
+
+            # This value represents the first chunk in an indefinite length
+            # bytestring.
+            elif self._state == self._STATE_WANT_BYTESTRING_CHUNK_FIRST:
+                # We received a full chunk.
+                if special == SPECIAL_NONE:
+                    self._decodedvalues.append(bytestringchunk(value,
+                                                               first=True))
+
+                    self._state = self._STATE_WANT_BYTESTRING_CHUNK_SUBSEQUENT
+
+                # The end of stream marker. This means it is an empty
+                # indefinite length bytestring.
+                elif special == SPECIAL_INDEFINITE_BREAK:
+                    # We /could/ convert this to a b''. But we want to preserve
+                    # the nature of the underlying data so consumers expecting
+                    # an indefinite length bytestring get one.
+                    self._decodedvalues.append(bytestringchunk(b'',
+                                                               first=True,
+                                                               last=True))
+
+                    # Since indefinite length bytestrings can't be used in
+                    # collections, we must be at the root level.
+                    assert not self._collectionstack
+                    self._state = self._STATE_NONE
+
+                else:
+                    raise CBORDecodeError('unexpected special value when '
+                                          'expecting bytestring chunk: %d' %
+                                          special)
+
+            # This value represents the non-initial chunk in an indefinite
+            # length bytestring.
+            elif self._state == self._STATE_WANT_BYTESTRING_CHUNK_SUBSEQUENT:
+                # We received a full chunk.
+                if special == SPECIAL_NONE:
+                    self._decodedvalues.append(bytestringchunk(value))
+
+                # The end of stream marker.
+                elif special == SPECIAL_INDEFINITE_BREAK:
+                    self._decodedvalues.append(bytestringchunk(b'', last=True))
+
+                    # Since indefinite length bytestrings can't be used in
+                    # collections, we must be at the root level.
+                    assert not self._collectionstack
+                    self._state = self._STATE_NONE
+
+                else:
+                    raise CBORDecodeError('unexpected special value when '
+                                          'expecting bytestring chunk: %d' %
+                                          special)
+
+            else:
+                raise CBORDecodeError('unhandled decoder state: %d' %
+                                      self._state)
+
+            # We could have just added the final value in a collection. End
+            # all complete collections at the top of the stack.
+            while True:
+                # Bail if we're not waiting on a new collection item.
+                if self._state not in (self._STATE_WANT_ARRAY_VALUE,
+                                       self._STATE_WANT_MAP_KEY,
+                                       self._STATE_WANT_SET_VALUE):
+                    break
+
+                # Or we are expecting more items for this collection.
+                lastc = self._collectionstack[-1]
+
+                if lastc['remaining']:
+                    break
+
+                # The collection at the top of the stack is complete.
+
+                # Discard it, as it isn't needed for future items.
+                self._collectionstack.pop()
+
+                # If this is a nested collection, we don't emit it, since it
+                # will be emitted by its parent collection. But we do need to
+                # update state to reflect what the new top-most collection
+                # on the stack is.
+                if self._collectionstack:
+                    self._state = {
+                        list: self._STATE_WANT_ARRAY_VALUE,
+                        dict: self._STATE_WANT_MAP_KEY,
+                        set: self._STATE_WANT_SET_VALUE,
+                    }[type(self._collectionstack[-1]['v'])]
+
+                # If this is the root collection, emit it.
+                else:
+                    self._decodedvalues.append(lastc['v'])
+                    self._state = self._STATE_NONE
+
+        return (
+            bool(self._decodedvalues),
+            offset - initialoffset,
+            0,
+        )
+
+    def getavailable(self):
+        """Returns a list of fully decoded values, in decode order.
+
+        Once values are retrieved, they won't be available on the next call.
+        """
+        # Hand over the buffer as-is; a fresh list replaces it, so no copy.
+        values = self._decodedvalues
+        self._decodedvalues = []
+        return values
+
+def decodeall(b):
+    """Decode all CBOR items present in a bytes buffer.
+
+    ``b`` must be a complete buffer supporting ``len()`` (e.g. ``bytes``),
+    not an iterable of chunks. Returns a list of the decoded values.
+
+    In addition to regular decode errors, raises CBORDecodeError if the
+    entirety of the passed buffer does not fully decode to complete CBOR
+    values. This includes failure to decode any value, incomplete collection
+    types, incomplete indefinite length items, and extra data at the end of
+    the buffer.
+    """
+    if not b:
+        return []
+
+    decoder = sansiodecoder()
+    # decode() returns (havevalues, readcount, wantbytes); only readcount
+    # is needed here, since pending items are detected via ``inprogress``.
+    readcount = decoder.decode(b)[1]
+    if readcount != len(b):
+        raise CBORDecodeError('input data not fully consumed')
+    if decoder.inprogress:
+        raise CBORDecodeError('input data not complete')
+    return decoder.getavailable()



To: indygreg, #hg-reviewers
Cc: mjpieters, mercurial-devel


More information about the Mercurial-devel mailing list