From 524dd4afa8dc2a0d17ab08f20f592af03c165db1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 26 Sep 2024 23:11:28 +0200 Subject: [PATCH] Avoid making a copy of large frames. This isn't very significant compared to the cost of compression. It can make a real difference for decompression. --- docs/project/changelog.rst | 14 +++++++++ .../extensions/permessage_deflate.py | 30 ++++++++++++------- src/websockets/frames.py | 8 ++--- tests/extensions/test_permessage_deflate.py | 11 +++++++ 4 files changed, 49 insertions(+), 14 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index c8d854ba..f5b4812b 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -61,6 +61,20 @@ Backwards-incompatible changes Aliases for deprecated API were removed from ``__all__``. As a consequence, they cannot be imported e.g. with ``from websockets import *`` anymore. +.. admonition:: :attr:`Frame.data ` is now a bytes-like object. + :class: note + + In addition to :class:`bytes`, it may be a :class:`bytearray` or a + :class:`memoryview`. + + If you wrote an :class:`extension ` that relies on + methods not provided by these new types, you may need to update your code. + +Improvements +............ + +* Sending or receiving large compressed frames is now faster. + .. _13.1: 13.1 diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index 21df804f..ed16937d 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -129,16 +129,22 @@ def decode( # Uncompress data. Protect against zip bombs by preventing zlib from # decompressing more than max_length bytes (except when the limit is # disabled with max_size = None). 
- data = frame.data - if frame.fin: - data += _EMPTY_UNCOMPRESSED_BLOCK + if frame.fin and len(frame.data) < 2044: + # Profiling shows that appending four bytes, which makes a copy, is + # faster than calling decompress() again when data is less than 2kB. + data = bytes(frame.data) + _EMPTY_UNCOMPRESSED_BLOCK + else: + data = frame.data max_length = 0 if max_size is None else max_size try: data = self.decoder.decompress(data, max_length) + if self.decoder.unconsumed_tail: + raise PayloadTooBig(f"over size limit (? > {max_size} bytes)") + if frame.fin and len(frame.data) >= 2044: + # This cannot generate additional data. + self.decoder.decompress(_EMPTY_UNCOMPRESSED_BLOCK) except zlib.error as exc: raise ProtocolError("decompression failed") from exc - if self.decoder.unconsumed_tail: - raise PayloadTooBig(f"over size limit (? > {max_size} bytes)") # Allow garbage collection of the decoder if it won't be reused. if frame.fin and self.remote_no_context_takeover: @@ -176,11 +182,15 @@ def encode(self, frame: frames.Frame) -> frames.Frame: # Compress data. data = self.encoder.compress(frame.data) + self.encoder.flush(zlib.Z_SYNC_FLUSH) - if frame.fin and data[-4:] == _EMPTY_UNCOMPRESSED_BLOCK: - # Making a copy is faster than memoryview(a)[:-4] until about 2kB. - # On larger messages, it's slower but profiling shows that it's - # marginal compared to compress() and flush(). Keep it simple. - data = data[:-4] + if frame.fin: + # Sync flush generates 5 or 6 bytes, ending with the bytes + # 0x00 0x00 0xff 0xff, which must be removed. + assert data[-4:] == _EMPTY_UNCOMPRESSED_BLOCK + # Making a copy is faster than memoryview(a)[:-4] until 2kB. + if len(data) < 2048: + data = data[:-4] + else: + data = memoryview(data)[:-4] # Allow garbage collection of the encoder if it won't be reused. 
if frame.fin and self.local_no_context_takeover: diff --git a/src/websockets/frames.py b/src/websockets/frames.py index dace2c90..5fadf3c2 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -7,7 +7,7 @@ import secrets import struct from collections.abc import Generator, Sequence -from typing import Callable +from typing import Callable, Union from .exceptions import PayloadTooBig, ProtocolError @@ -139,7 +139,7 @@ class Frame: """ opcode: Opcode - data: bytes + data: Union[bytes, bytearray, memoryview] fin: bool = True rsv1: bool = False rsv2: bool = False @@ -160,7 +160,7 @@ def __str__(self) -> str: if self.opcode is OP_TEXT: # Decoding only the beginning and the end is needlessly hard. # Decode the entire payload then elide later if necessary. - data = repr(self.data.decode()) + data = repr(bytes(self.data).decode()) elif self.opcode is OP_BINARY: # We'll show at most the first 16 bytes and the last 8 bytes. # Encode just what we need, plus two dummy bytes to elide later. @@ -178,7 +178,7 @@ def __str__(self) -> str: # binary. If self.data is a memoryview, it has no decode() method, # which raises AttributeError. try: - data = repr(self.data.decode()) + data = repr(bytes(self.data).decode()) coding = "text" except (UnicodeDecodeError, AttributeError): binary = self.data diff --git a/tests/extensions/test_permessage_deflate.py b/tests/extensions/test_permessage_deflate.py index ee09813c..76cd4862 100644 --- a/tests/extensions/test_permessage_deflate.py +++ b/tests/extensions/test_permessage_deflate.py @@ -1,4 +1,5 @@ import dataclasses +import os import unittest from websockets.exceptions import ( @@ -167,6 +168,16 @@ def test_encode_decode_fragmented_binary_frame(self): self.assertEqual(dec_frame1, frame1) self.assertEqual(dec_frame2, frame2) + def test_encode_decode_large_frame(self): + # There is a separate code path that avoids copying data + # when frames are larger than 2kB. Test it for coverage. 
+ frame = Frame(OP_BINARY, os.urandom(4096)) + + enc_frame = self.extension.encode(frame) + dec_frame = self.extension.decode(enc_frame) + + self.assertEqual(dec_frame, frame) + def test_no_decode_text_frame(self): frame = Frame(OP_TEXT, "café".encode())