D3303: cborutil: implement support for streaming encoding, bytestring decoding

indygreg (Gregory Szorc) phabricator at mercurial-scm.org
Sat Apr 14 20:22:46 EDT 2018

indygreg updated this revision to Diff 8301.

  rHG Mercurial





diff --git a/tests/test-cbor.py b/tests/test-cbor.py
new file mode 100644
--- /dev/null
+++ b/tests/test-cbor.py
@@ -0,0 +1,211 @@
+from __future__ import absolute_import
+import io
+import unittest
+from mercurial.thirdparty import (
+    cbor,
+from mercurial.utils import (
+    cborutil,
+def loadit(it):
+    return cbor.loads(b''.join(it))
+class BytestringTests(unittest.TestCase):
+    def testsimple(self):
+        self.assertEqual(
+            list(cborutil.streamencode(b'foobar')),
+            [b'\x46', b'foobar'])
+        self.assertEqual(
+            loadit(cborutil.streamencode(b'foobar')),
+            b'foobar')
+    def testlong(self):
+        source = b'x' * 1048576
+        self.assertEqual(loadit(cborutil.streamencode(source)), source)
+    def testfromiter(self):
+        # This is the example from RFC 7049 Section 2.2.2.
+        source = [b'\xaa\xbb\xcc\xdd', b'\xee\xff\x99']
+        self.assertEqual(
+            list(cborutil.streamencodebytestringfromiter(source)),
+            [
+                b'\x5f',
+                b'\x44',
+                b'\xaa\xbb\xcc\xdd',
+                b'\x43',
+                b'\xee\xff\x99',
+                b'\xff',
+            ])
+        self.assertEqual(
+            loadit(cborutil.streamencodebytestringfromiter(source)),
+            b''.join(source))
+    def testfromiterlarge(self):
+        source = [b'a' * 16, b'b' * 128, b'c' * 1024, b'd' * 1048576]
+        self.assertEqual(
+            loadit(cborutil.streamencodebytestringfromiter(source)),
+            b''.join(source))
+    def testindefinite(self):
+        source = b'\x00\x01\x02\x03' + b'\xff' * 16384
+        it = cborutil.streamencodeindefinitebytestring(source, chunksize=2)
+        self.assertEqual(next(it), b'\x5f')
+        self.assertEqual(next(it), b'\x42')
+        self.assertEqual(next(it), b'\x00\x01')
+        self.assertEqual(next(it), b'\x42')
+        self.assertEqual(next(it), b'\x02\x03')
+        self.assertEqual(next(it), b'\x42')
+        self.assertEqual(next(it), b'\xff\xff')
+        dest = b''.join(cborutil.streamencodeindefinitebytestring(
+            source, chunksize=42))
+        self.assertEqual(cbor.loads(dest), b''.join(source))
+    def testreadtoiter(self):
+        source = io.BytesIO(b'\x5f\x44\xaa\xbb\xcc\xdd\x43\xee\xff\x99\xff')
+        it = cborutil.readindefinitebytestringtoiter(source)
+        self.assertEqual(next(it), b'\xaa\xbb\xcc\xdd')
+        self.assertEqual(next(it), b'\xee\xff\x99')
+        with self.assertRaises(StopIteration):
+            next(it)
+class IntTests(unittest.TestCase):
+    def testsmall(self):
+        self.assertEqual(list(cborutil.streamencode(0)), [b'\x00'])
+        self.assertEqual(list(cborutil.streamencode(1)), [b'\x01'])
+        self.assertEqual(list(cborutil.streamencode(2)), [b'\x02'])
+        self.assertEqual(list(cborutil.streamencode(3)), [b'\x03'])
+        self.assertEqual(list(cborutil.streamencode(4)), [b'\x04'])
+    def testnegativesmall(self):
+        self.assertEqual(list(cborutil.streamencode(-1)), [b'\x20'])
+        self.assertEqual(list(cborutil.streamencode(-2)), [b'\x21'])
+        self.assertEqual(list(cborutil.streamencode(-3)), [b'\x22'])
+        self.assertEqual(list(cborutil.streamencode(-4)), [b'\x23'])
+        self.assertEqual(list(cborutil.streamencode(-5)), [b'\x24'])
+    def testrange(self):
+        for i in range(-70000, 70000, 10):
+            self.assertEqual(
+                b''.join(cborutil.streamencode(i)),
+                cbor.dumps(i))
+class ArrayTests(unittest.TestCase):
+    def testempty(self):
+        self.assertEqual(list(cborutil.streamencode([])), [b'\x80'])
+        self.assertEqual(loadit(cborutil.streamencode([])), [])
+    def testbasic(self):
+        source = [b'foo', b'bar', 1, -10]
+        self.assertEqual(list(cborutil.streamencode(source)), [
+            b'\x84', b'\x43', b'foo', b'\x43', b'bar', b'\x01', b'\x29'])
+    def testemptyfromiter(self):
+        self.assertEqual(b''.join(cborutil.streamencodearrayfromiter([])),
+                         b'\x9f\xff')
+    def testfromiter1(self):
+        source = [b'foo']
+        self.assertEqual(list(cborutil.streamencodearrayfromiter(source)), [
+            b'\x9f',
+            b'\x43', b'foo',
+            b'\xff',
+        ])
+        dest = b''.join(cborutil.streamencodearrayfromiter(source))
+        self.assertEqual(cbor.loads(dest), source)
+    def testtuple(self):
+        source = (b'foo', None, 42)
+        self.assertEqual(cbor.loads(b''.join(cborutil.streamencode(source))),
+                         list(source))
+class SetTests(unittest.TestCase):
+    def testempty(self):
+        self.assertEqual(list(cborutil.streamencode(set())), [
+            b'\xd9\x01\x02',
+            b'\x80',
+        ])
+    def testset(self):
+        source = {b'foo', None, 42}
+        self.assertEqual(cbor.loads(b''.join(cborutil.streamencode(source))),
+                         source)
+class BoolTests(unittest.TestCase):
+    def testbasic(self):
+        self.assertEqual(list(cborutil.streamencode(True)),  [b'\xf5'])
+        self.assertEqual(list(cborutil.streamencode(False)), [b'\xf4'])
+        self.assertIs(loadit(cborutil.streamencode(True)), True)
+        self.assertIs(loadit(cborutil.streamencode(False)), False)
+class NoneTests(unittest.TestCase):
+    def testbasic(self):
+        self.assertEqual(list(cborutil.streamencode(None)), [b'\xf6'])
+        self.assertIs(loadit(cborutil.streamencode(None)), None)
+class MapTests(unittest.TestCase):
+    def testempty(self):
+        self.assertEqual(list(cborutil.streamencode({})), [b'\xa0'])
+        self.assertEqual(loadit(cborutil.streamencode({})), {})
+    def testemptyindefinite(self):
+        self.assertEqual(list(cborutil.streamencodemapfromiter([])), [
+            b'\xbf', b'\xff'])
+        self.assertEqual(loadit(cborutil.streamencodemapfromiter([])), {})
+    def testone(self):
+        source = {b'foo': b'bar'}
+        self.assertEqual(list(cborutil.streamencode(source)), [
+            b'\xa1', b'\x43', b'foo', b'\x43', b'bar'])
+        self.assertEqual(loadit(cborutil.streamencode(source)), source)
+    def testmultiple(self):
+        source = {
+            b'foo': b'bar',
+            b'baz': b'value1',
+        }
+        self.assertEqual(loadit(cborutil.streamencode(source)), source)
+        self.assertEqual(
+            loadit(cborutil.streamencodemapfromiter(source.items())),
+            source)
+    def testcomplex(self):
+        source = {
+            b'key': 1,
+            2: -10,
+        }
+        self.assertEqual(loadit(cborutil.streamencode(source)),
+                         source)
+        self.assertEqual(
+            loadit(cborutil.streamencodemapfromiter(source.items())),
+            source)
+if __name__ == '__main__':
+    import silenttestrunner
+    silenttestrunner.main(__name__)
diff --git a/mercurial/utils/cborutil.py b/mercurial/utils/cborutil.py
new file mode 100644
--- /dev/null
+++ b/mercurial/utils/cborutil.py
@@ -0,0 +1,259 @@
+# cborutil.py - CBOR extensions
+# Copyright 2018 Gregory Szorc <gregory.szorc at gmail.com>
+# This software may be used and distributed according to the terms of the
+# GNU General Public License version 2 or any later version.
+from __future__ import absolute_import
+import math
+import struct
+from ..thirdparty.cbor.cbor2 import (
+    decoder as decodermod,
+# Very short version of RFC 7049...
+# Each item begins with a byte. The 3 high bits of that byte denote the
+# "major type." The lower 5 bits denote the "subtype." Each major type
+# has its own encoding mechanism.
+# Most types have lengths. However, bytestring, string, array, and map
+# can be indefinite length. These are denoted by a subtype with value 31.
+# Sub-components of those types then come afterwards and are terminated
+# by a "break" byte.
+SUBTYPE_MASK = 0b00011111
+# Indefinite types begin with their major type ORd with information value 31.
+BEGIN_INDEFINITE_MAP = struct.pack(
+ENCODED_LENGTH_1 = struct.Struct(r'>B')
+ENCODED_LENGTH_2 = struct.Struct(r'>BB')
+ENCODED_LENGTH_3 = struct.Struct(r'>BH')
+ENCODED_LENGTH_4 = struct.Struct(r'>BL')
+ENCODED_LENGTH_5 = struct.Struct(r'>BQ')
+# The break ends an indefinite length item.
+BREAK = b'\xff'
+BREAK_INT = 255
+def encodelength(majortype, length):
+    """Obtain a value encoding the major type and its length."""
+    if length < 24:
+        return ENCODED_LENGTH_1.pack(majortype << 5 | length)
+    elif length < 256:
+        return ENCODED_LENGTH_2.pack(majortype << 5 | 24, length)
+    elif length < 65536:
+        return ENCODED_LENGTH_3.pack(majortype << 5 | 25, length)
+    elif length < 4294967296:
+        return ENCODED_LENGTH_4.pack(majortype << 5 | 26, length)
+    else:
+        return ENCODED_LENGTH_5.pack(majortype << 5 | 27, length)
+def streamencodebytestring(v):
+    yield encodelength(MAJOR_TYPE_BYTESTRING, len(v))
+    yield v
+def streamencodebytestringfromiter(it):
+    """Convert an iterator of chunks to an indefinite bytestring.
+    Given an input that is iterable and each element in the iterator is
+    representable as bytes, emit an indefinite length bytestring.
+    """
+    for chunk in it:
+        yield encodelength(MAJOR_TYPE_BYTESTRING, len(chunk))
+        yield chunk
+    yield BREAK
+def streamencodeindefinitebytestring(source, chunksize=65536):
+    """Given a large source buffer, emit as an indefinite length bytestring.
+    This is a generator of chunks constituting the encoded CBOR data.
+    """
+    i = 0
+    l = len(source)
+    while True:
+        chunk = source[i:i + chunksize]
+        i += len(chunk)
+        yield encodelength(MAJOR_TYPE_BYTESTRING, len(chunk))
+        yield chunk
+        if i >= l:
+            break
+    yield BREAK
+def streamencodeint(v):
+    if v >= 18446744073709551616 or v < -18446744073709551616:
+        raise ValueError('big integers not supported')
+    if v >= 0:
+        yield encodelength(MAJOR_TYPE_UINT, v)
+    else:
+        yield encodelength(MAJOR_TYPE_NEGINT, abs(v) - 1)
+def streamencodearray(l):
+    """Encode a known size iterable to an array."""
+    yield encodelength(MAJOR_TYPE_ARRAY, len(l))
+    for i in l:
+        for chunk in streamencode(i):
+            yield chunk
+def streamencodearrayfromiter(it):
+    """Encode an iterator of items to an indefinite length array."""
+    for i in it:
+        for chunk in streamencode(i):
+            yield chunk
+    yield BREAK
+def streamencodeset(s):
+    # https://www.iana.org/assignments/cbor-tags/cbor-tags.xhtml defines
+    # semantic tag 258 for finite sets.
+    yield encodelength(MAJOR_TYPE_SEMANTIC, 258)
+    for chunk in streamencodearray(sorted(s)):
+        yield chunk
+def streamencodemap(d):
+    """Encode dictionary to a generator.
+    Does not support indefinite length dictionaries.
+    """
+    yield encodelength(MAJOR_TYPE_MAP, len(d))
+    for key, value in sorted(d.iteritems()):
+        for chunk in streamencode(key):
+            yield chunk
+        for chunk in streamencode(value):
+            yield chunk
+def streamencodemapfromiter(it):
+    """Given an iterable of (key, value), encode to an indefinite length map."""
+    for key, value in it:
+        for chunk in streamencode(key):
+            yield chunk
+        for chunk in streamencode(value):
+            yield chunk
+    yield BREAK
+def streamencodebool(b):
+    # major type 7, simple value 20 and 21.
+    yield b'\xf5' if b else b'\xf4'
+def streamencodenone(v):
+    # major type 7, simple value 22.
+    yield b'\xf6'
+    bytes: streamencodebytestring,
+    int: streamencodeint,
+    list: streamencodearray,
+    tuple: streamencodearray,
+    dict: streamencodemap,
+    set: streamencodeset,
+    bool: streamencodebool,
+    type(None): streamencodenone,
+def streamencode(v):
+    """Encode a value in a streaming manner.
+    Given an input object, encode it to CBOR recursively.
+    Returns a generator of CBOR encoded bytes. There is no guarantee
+    that each emitted chunk fully decodes to a value or sub-value.
+    Encoding is deterministic - unordered collections are sorted.
+    """
+    fn = STREAM_ENCODERS.get(v.__class__)
+    if not fn:
+        raise ValueError('do not know how to encode %s' % type(v))
+    return fn(v)
+def readindefinitebytestringtoiter(fh, expectheader=True):
+    """Read an indefinite bytestring to a generator.
+    Receives an object with a ``read(N)`` method to read N bytes.
+    If ``expectheader`` is True, it is expected that the first byte read
+    will represent an indefinite length bytestring. Otherwise, we
+    expect the first byte to be part of the first bytestring chunk.
+    """
+    read = fh.read
+    decodeuint = decodermod.decode_uint
+    byteasinteger = decodermod.byte_as_integer
+    if expectheader:
+        initial = decodermod.byte_as_integer(read(1))
+        majortype = initial >> 5
+        subtype = initial & SUBTYPE_MASK
+        if majortype != MAJOR_TYPE_BYTESTRING:
+            raise decodermod.CBORDecodeError(
+                'expected major type %d; got %d' % (MAJOR_TYPE_BYTESTRING,
+                                                    majortype))
+        if subtype != SUBTYPE_INDEFINITE:
+            raise decodermod.CBORDecodeError(
+                'expected indefinite subtype; got %d' % subtype)
+    # The indefinite bytestring is composed of chunks of normal bytestrings.
+    # Read chunks until we hit a BREAK byte.
+    while True:
+        # We need to sniff for the BREAK byte.
+        initial = byteasinteger(read(1))
+        if initial == BREAK_INT:
+            break
+        length = decodeuint(fh, initial & SUBTYPE_MASK)
+        chunk = read(length)
+        if len(chunk) != length:
+            raise decodermod.CBORDecodeError(
+                'failed to read bytestring chunk: got %d bytes; expected %d' % (
+                    len(chunk), length))
+        yield chunk
diff --git a/contrib/import-checker.py b/contrib/import-checker.py
--- a/contrib/import-checker.py
+++ b/contrib/import-checker.py
@@ -36,6 +36,8 @@
     # third-party imports should be directly imported
+    'mercurial.thirdparty.cbor',
+    'mercurial.thirdparty.cbor.cbor2',

To: indygreg, #hg-reviewers
Cc: yuja, mercurial-devel

More information about the Mercurial-devel mailing list