Initial commit: Email alerts application
This commit is contained in:
@@ -0,0 +1 @@
|
||||
b01999d409b29bd916e067bc963d5f2d9ee63cfc9ae0bccb769910131417bf93
|
||||
@@ -0,0 +1 @@
|
||||
0478ceb55d0ed30ef1a7da742cd003449bc69a07cf9fdb06789bd2b347cbfffe
|
||||
@@ -0,0 +1 @@
|
||||
9e5fe78ed0ebce5414d2b8e01868d90c1facc20b84d2d5ff6c23e86e44a155ae
|
||||
@@ -0,0 +1 @@
|
||||
"""WebSocket protocol versions 13 and 8."""
|
||||
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
@@ -0,0 +1,147 @@
|
||||
"""Helpers for WebSocket protocol versions 13 and 8."""
|
||||
|
||||
import functools
|
||||
import re
|
||||
from struct import Struct
|
||||
from typing import TYPE_CHECKING, Final, List, Optional, Pattern, Tuple
|
||||
|
||||
from ..helpers import NO_EXTENSIONS
|
||||
from .models import WSHandshakeError
|
||||
|
||||
# Pre-compiled struct codecs. The "!" prefix selects network (big-endian)
# byte order with standard sizes.
# Reads one unsigned 64-bit integer from a buffer at a caller-supplied offset.
UNPACK_LEN3 = Struct("!Q").unpack_from
|
||||
# Reads one unsigned 16-bit integer (a close code, going by the name).
UNPACK_CLOSE_CODE = Struct("!H").unpack
|
||||
# Packs two unsigned bytes.
PACK_LEN1 = Struct("!BB").pack
|
||||
# Packs two unsigned bytes followed by an unsigned 16-bit integer.
PACK_LEN2 = Struct("!BBH").pack
|
||||
# Packs two unsigned bytes followed by an unsigned 64-bit integer.
PACK_LEN3 = Struct("!BBQ").pack
|
||||
# Packs one unsigned 16-bit integer.
PACK_CLOSE_CODE = Struct("!H").pack
|
||||
# Packs one unsigned 32-bit integer.
PACK_RANDBITS = Struct("!L").pack
|
||||
# 16 KiB; NOTE(review): the consumer of this size is not visible in this file.
MSG_SIZE: Final[int] = 2**14
|
||||
# Length in bytes of a frame mask (the masking helpers assert len(mask) == 4).
MASK_LEN: Final[int] = 4
|
||||
|
||||
# The fixed GUID that RFC 6455 appends to Sec-WebSocket-Key during the handshake.
WS_KEY: Final[bytes] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
|
||||
|
||||
|
||||
# Used by _websocket_mask_python
|
||||
@functools.lru_cache
|
||||
def _xor_table() -> List[bytes]:
|
||||
return [bytes(a ^ b for a in range(256)) for b in range(256)]
|
||||
|
||||
|
||||
def _websocket_mask_python(mask: bytes, data: bytearray) -> None:
|
||||
"""Websocket masking function.
|
||||
|
||||
`mask` is a `bytes` object of length 4; `data` is a `bytearray`
|
||||
object of any length. The contents of `data` are masked with `mask`,
|
||||
as specified in section 5.3 of RFC 6455.
|
||||
|
||||
Note that this function mutates the `data` argument.
|
||||
|
||||
This pure-python implementation may be replaced by an optimized
|
||||
version when available.
|
||||
|
||||
"""
|
||||
assert isinstance(data, bytearray), data
|
||||
assert len(mask) == 4, mask
|
||||
|
||||
if data:
|
||||
_XOR_TABLE = _xor_table()
|
||||
a, b, c, d = (_XOR_TABLE[n] for n in mask)
|
||||
data[::4] = data[::4].translate(a)
|
||||
data[1::4] = data[1::4].translate(b)
|
||||
data[2::4] = data[2::4].translate(c)
|
||||
data[3::4] = data[3::4].translate(d)
|
||||
|
||||
|
||||
# Pick the masking implementation exposed as `websocket_mask`: the pure-Python
# one when type checking or when NO_EXTENSIONS is set, otherwise the compiled
# `.mask` extension, falling back to pure Python if that import fails.
if TYPE_CHECKING or NO_EXTENSIONS: # pragma: no cover
|
||||
websocket_mask = _websocket_mask_python
|
||||
else:
|
||||
try:
|
||||
from .mask import _websocket_mask_cython # type: ignore[import-not-found]
|
||||
|
||||
websocket_mask = _websocket_mask_cython
|
||||
# The compiled extension is optional; its absence is not an error.
except ImportError: # pragma: no cover
|
||||
websocket_mask = _websocket_mask_python
|
||||
|
||||
|
||||
# Matches the whole parameter list that follows `permessage-deflate`: zero or
# more `; <param>` items. Capture groups, as used by ws_ext_parse:
#   1 = server_no_context_takeover
#   2 = client_no_context_takeover
#   3 = server_max_window_bits[=N], 4 = N
#   5 = client_max_window_bits[=N], 6 = N
# Because the outer group repeats, each capture holds its last occurrence.
_WS_EXT_RE: Final[Pattern[str]] = re.compile(
|
||||
r"^(?:;\s*(?:"
|
||||
r"(server_no_context_takeover)|"
|
||||
r"(client_no_context_takeover)|"
|
||||
r"(server_max_window_bits(?:=(\d+))?)|"
|
||||
r"(client_max_window_bits(?:=(\d+))?)))*$"
|
||||
)
|
||||
|
||||
# Finds each `permessage-deflate` offer in a header value; group 1 is the
# optional parameter text up to the next comma (None when there is none).
_WS_EXT_RE_SPLIT: Final[Pattern[str]] = re.compile(r"permessage-deflate([^,]+)?")
|
||||
|
||||
|
||||
def ws_ext_parse(extstr: Optional[str], isserver: bool = False) -> Tuple[int, bool]:
    """Parse the ``permessage-deflate`` offers found in an extensions header.

    Returns ``(compress, notakeover)``. ``compress`` is the window size to
    use: 0 when no usable offer was found, 15 when an offer gives no explicit
    size, otherwise the advertised value. ``notakeover`` is True when the
    accepted offer carries ``server_no_context_takeover`` (server side) or
    ``client_no_context_takeover`` (client side).

    On the server side unusable offers are skipped. On the client side
    (``isserver`` false) they raise ``WSHandshakeError``, both for
    unrecognised parameters and for a window size outside 9..15.
    """
    if not extstr:
        return 0, False

    compress = 0
    notakeover = False
    for ext in _WS_EXT_RE_SPLIT.finditer(extstr):
        params = ext.group(1)
        if not params:
            # Bare `permessage-deflate` with no parameters: default window.
            compress = 15
            break

        parsed = _WS_EXT_RE.match(params)
        if not parsed:
            if isserver:
                # A server never fails the handshake over an offer it does
                # not understand; it simply tries the next one.
                continue
            raise WSHandshakeError("Extension for deflate not supported" + ext.group(1))

        # The server reads the server_* parameters (groups 4 and 1), the
        # client reads the client_* ones (groups 6 and 2); the other side's
        # parameters are ignored.
        if isserver:
            wbits, no_takeover = parsed.group(4), parsed.group(1)
        else:
            wbits, no_takeover = parsed.group(6), parsed.group(2)

        compress = int(wbits) if wbits else 15
        # zlib does not support a window of 8 bits, so 9..15 is the valid range.
        if compress > 15 or compress < 9:
            if not isserver:
                raise WSHandshakeError("Invalid window size")
            compress = 0
            continue

        if no_takeover:
            notakeover = True
        break

    return compress, notakeover
|
||||
|
||||
|
||||
def ws_ext_gen(
    compress: int = 15, isserver: bool = False, server_notakeover: bool = False
) -> str:
    """Build the ``permessage-deflate`` extension string for a handshake.

    ``compress`` is the window size (9..15). A client (``isserver`` false)
    additionally advertises ``client_max_window_bits``. A window smaller than
    15 adds ``server_max_window_bits=<compress>``, and ``server_notakeover``
    adds ``server_no_context_takeover``.

    Raises ``ValueError`` when ``compress`` is outside 9..15.
    """
    # zlib cannot work with wbits=8, hence the lower bound of 9.
    if compress > 15 or compress < 9:
        raise ValueError(
            "Compress wbits must between 9 and 15, zlib does not support wbits=8"
        )

    # (enabled, token) pairs in the order they appear in the header value.
    candidates = (
        (True, "permessage-deflate"),
        (not isserver, "client_max_window_bits"),
        (compress < 15, f"server_max_window_bits={compress!s}"),
        (server_notakeover, "server_no_context_takeover"),
    )
    return "; ".join(token for enabled, token in candidates if enabled)
|
||||
BIN
Binary file not shown.
@@ -0,0 +1,3 @@
|
||||
"""Cython declarations for websocket masking."""
|
||||
|
||||
# Declaration of the masking routine so other Cython modules can cimport it;
# it takes the 4-byte mask and the bytearray that is masked in place.
cpdef void _websocket_mask_cython(bytes mask, bytearray data)
|
||||
@@ -0,0 +1,48 @@
|
||||
from cpython cimport PyBytes_AsString
|
||||
|
||||
|
||||
# from cpython cimport PyByteArray_AsString  # Cython does not export this yet
|
||||
# Declared by hand from Python.h for that reason; `except NULL` makes a NULL
# return propagate as a Python exception.
cdef extern from "Python.h":
|
||||
char* PyByteArray_AsString(bytearray ba) except NULL
|
||||
|
||||
from libc.stdint cimport uint32_t, uint64_t, uintmax_t
|
||||
|
||||
|
||||
cpdef void _websocket_mask_cython(bytes mask, bytearray data):
    """XOR ``data`` in place with the repeating 4-byte ``mask``.

    Note that this function mutates its ``data`` argument.
    """
    cdef:
        Py_ssize_t remaining, idx
        # Unsigned types throughout: bit operations on signed integers are
        # implementation-specific.
        unsigned char * cursor
        const unsigned char * mask_bytes
        uint32_t mask32
        uint64_t mask64

    assert len(mask) == 4

    remaining = len(data)
    cursor = <unsigned char*>PyByteArray_AsString(data)
    mask_bytes = <const unsigned char*>PyBytes_AsString(mask)
    mask32 = (<uint32_t*>mask_bytes)[0]

    # TODO: align the data pointer to achieve even faster speeds.
    # Is that needed in Python? malloc() always aligns to sizeof(long) bytes.

    if sizeof(size_t) >= 8:
        # On 64-bit builds, duplicate the mask into both halves of a 64-bit
        # word and process eight bytes per iteration.
        mask64 = mask32
        mask64 = (mask64 << 32) | mask32

        while remaining >= 8:
            (<uint64_t*>cursor)[0] ^= mask64
            cursor += 8
            remaining -= 8

    # Then four bytes at a time.
    while remaining >= 4:
        (<uint32_t*>cursor)[0] ^= mask32
        cursor += 4
        remaining -= 4

    # Finally the 0..3 trailing bytes. The cursor has advanced by a multiple
    # of four, so mask byte ``idx`` still lines up with data byte ``idx``.
    for idx in range(remaining):
        cursor[idx] ^= mask_bytes[idx]
|
||||
@@ -0,0 +1,84 @@
|
||||
"""Models for WebSocket protocol versions 13 and 8."""
|
||||
|
||||
import json
|
||||
from enum import IntEnum
|
||||
from typing import Any, Callable, Final, NamedTuple, Optional, cast
|
||||
|
||||
# The four bytes 00 00 FF FF; the reader appends them to a compressed payload
# before handing it to the decompressor.
WS_DEFLATE_TRAILING: Final[bytes] = bytes([0x00, 0x00, 0xFF, 0xFF])
|
||||
|
||||
|
||||
class WSCloseCode(IntEnum):
    """Status codes that may accompany a WebSocket close frame.

    Being an ``IntEnum``, members compare equal to their integer value.
    See RFC 6455, section 7.4.1, for the meaning of each code.
    """

    # Normal shutdown and endpoint going away.
    OK = 1000
    GOING_AWAY = 1001

    # Errors caused by what the peer sent.
    PROTOCOL_ERROR = 1002
    UNSUPPORTED_DATA = 1003
    ABNORMAL_CLOSURE = 1006
    INVALID_TEXT = 1007
    POLICY_VIOLATION = 1008
    MESSAGE_TOO_BIG = 1009
    MANDATORY_EXTENSION = 1010

    # Conditions on the sending side.
    INTERNAL_ERROR = 1011
    SERVICE_RESTART = 1012
    TRY_AGAIN_LATER = 1013
    BAD_GATEWAY = 1014
|
||||
|
||||
|
||||
class WSMsgType(IntEnum):
    """Message types delivered to WebSocket consumers.

    The first six members carry the opcode values of the WebSocket
    specification. The remaining three are aiohttp-specific markers that sit
    above the opcode range. Lower-case aliases refer to the same members.
    """

    # Opcodes from the WebSocket specification.
    CONTINUATION = 0x0
    TEXT = 0x1
    BINARY = 0x2
    PING = 0x9
    PONG = 0xA
    CLOSE = 0x8

    # aiohttp-specific types.
    CLOSING = 0x100
    CLOSED = 0x101
    ERROR = 0x102

    # Lower-case aliases of the members above (same objects, not new members).
    text = TEXT
    binary = BINARY
    ping = PING
    pong = PONG
    close = CLOSE
    closing = CLOSING
    closed = CLOSED
    error = ERROR
|
||||
|
||||
|
||||
class WSMessage(NamedTuple):
    """One WebSocket message: its type, its payload and an optional extra string."""

    type: WSMsgType
    # To type correctly, this would need some kind of tagged union for each type.
    data: Any
    extra: Optional[str]

    def json(self, *, loads: Callable[[Any], Any] = json.loads) -> Any:
        """Decode ``data`` with ``loads`` (``json.loads`` by default).

        .. versionadded:: 0.22
        """
        payload = self.data
        return loads(payload)
|
||||
|
||||
|
||||
# Constructing the tuple directly to avoid the overhead of
|
||||
# the lambda and arg processing since NamedTuples are constructed
|
||||
# with a run time built lambda
|
||||
# https://github.com/python/cpython/blob/d83fcf8371f2f33c7797bc8f5423a8bca8c46e5c/Lib/collections/__init__.py#L441
|
||||
# Shared message with type CLOSED and no data or extra.
WS_CLOSED_MESSAGE = tuple.__new__(WSMessage, (WSMsgType.CLOSED, None, None))
|
||||
# Shared message with type CLOSING and no data or extra.
WS_CLOSING_MESSAGE = tuple.__new__(WSMessage, (WSMsgType.CLOSING, None, None))
|
||||
|
||||
|
||||
class WebSocketError(Exception):
    """Error raised by the WebSocket protocol parser.

    ``code`` holds the close code associated with the failure. ``args`` is
    ``(code, message)`` and ``str()`` of the exception is the message alone.
    """

    def __init__(self, code: int, message: str) -> None:
        super().__init__(code, message)
        self.code = code

    def __str__(self) -> str:
        message = self.args[1]
        return cast(str, message)
|
||||
|
||||
|
||||
class WSHandshakeError(Exception):
    """Error raised when the WebSocket handshake cannot be completed."""
|
||||
@@ -0,0 +1,31 @@
|
||||
"""Reader for WebSocket protocol versions 13 and 8."""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ..helpers import NO_EXTENSIONS
|
||||
|
||||
# Export WebSocketReader / WebSocketDataQueue from the pure-Python module when
# type checking or when NO_EXTENSIONS is set; otherwise prefer the compiled
# `reader_c` module and fall back to `reader_py` if it cannot be imported.
if TYPE_CHECKING or NO_EXTENSIONS: # pragma: no cover
|
||||
from .reader_py import (
|
||||
WebSocketDataQueue as WebSocketDataQueuePython,
|
||||
WebSocketReader as WebSocketReaderPython,
|
||||
)
|
||||
|
||||
WebSocketReader = WebSocketReaderPython
|
||||
WebSocketDataQueue = WebSocketDataQueuePython
|
||||
else:
|
||||
try:
|
||||
from .reader_c import ( # type: ignore[import-not-found]
|
||||
WebSocketDataQueue as WebSocketDataQueueCython,
|
||||
WebSocketReader as WebSocketReaderCython,
|
||||
)
|
||||
|
||||
WebSocketReader = WebSocketReaderCython
|
||||
WebSocketDataQueue = WebSocketDataQueueCython
|
||||
# The compiled module is optional; use the pure-Python classes instead.
except ImportError: # pragma: no cover
|
||||
from .reader_py import (
|
||||
WebSocketDataQueue as WebSocketDataQueuePython,
|
||||
WebSocketReader as WebSocketReaderPython,
|
||||
)
|
||||
|
||||
WebSocketReader = WebSocketReaderPython
|
||||
WebSocketDataQueue = WebSocketDataQueuePython
|
||||
BIN
Binary file not shown.
@@ -0,0 +1,110 @@
|
||||
import cython
|
||||
|
||||
from .mask cimport _websocket_mask_cython as websocket_mask
|
||||
|
||||
|
||||
# C-level types for the reader module's globals.
# NOTE(review): presumably this .pxd augments the pure-Python reader module
# when it is compiled with Cython - confirm against the build setup.
# Parser states.
cdef unsigned int READ_HEADER
|
||||
cdef unsigned int READ_PAYLOAD_LENGTH
|
||||
cdef unsigned int READ_PAYLOAD_MASK
|
||||
cdef unsigned int READ_PAYLOAD
|
||||
|
||||
# Opcodes; signed because OP_CODE_NOT_SET is -1 in the Python source.
cdef int OP_CODE_NOT_SET
|
||||
cdef int OP_CODE_CONTINUATION
|
||||
cdef int OP_CODE_TEXT
|
||||
cdef int OP_CODE_BINARY
|
||||
cdef int OP_CODE_CLOSE
|
||||
cdef int OP_CODE_PING
|
||||
cdef int OP_CODE_PONG
|
||||
|
||||
# Compression flags; signed because COMPRESSED_NOT_SET is -1.
cdef int COMPRESSED_NOT_SET
|
||||
cdef int COMPRESSED_FALSE
|
||||
cdef int COMPRESSED_TRUE
|
||||
|
||||
# Callables kept as generic Python objects.
cdef object UNPACK_LEN3
|
||||
cdef object UNPACK_CLOSE_CODE
|
||||
cdef object TUPLE_NEW
|
||||
|
||||
cdef object WSMsgType
|
||||
cdef object WSMessage
|
||||
|
||||
cdef object WS_MSG_TYPE_TEXT
|
||||
cdef object WS_MSG_TYPE_BINARY
|
||||
|
||||
cdef set ALLOWED_CLOSE_CODES
|
||||
# NOTE(review): no assignment to MESSAGE_TYPES_WITH_CONTENT is visible in the
# Python reader source shown alongside this file - confirm it is still needed.
cdef set MESSAGE_TYPES_WITH_CONTENT
|
||||
|
||||
cdef tuple EMPTY_FRAME
|
||||
cdef tuple EMPTY_FRAME_ERROR
|
||||
|
||||
# C-level layout of WebSocketDataQueue: typed attributes plus the methods
# that get C signatures.
cdef class WebSocketDataQueue:
|
||||
|
||||
cdef unsigned int _size
|
||||
# `public` keeps _protocol reachable from Python code.
cdef public object _protocol
|
||||
cdef unsigned int _limit
|
||||
cdef object _loop
|
||||
cdef bint _eof
|
||||
cdef object _waiter
|
||||
cdef object _exception
|
||||
# `public` keeps _buffer reachable from Python code.
cdef public object _buffer
|
||||
cdef object _get_buffer
|
||||
cdef object _put_buffer
|
||||
|
||||
cdef void _release_waiter(self)
|
||||
|
||||
cpdef void feed_data(self, object data, unsigned int size)
|
||||
|
||||
# The local named `size` inside _read_from_buffer is typed as unsigned int.
@cython.locals(size="unsigned int")
|
||||
cdef _read_from_buffer(self)
|
||||
|
||||
# C-level layout of WebSocketReader: typed attributes plus typed locals for
# the two hot parsing methods. The names in each @cython.locals(...) must
# match the local variable names used by the Python implementation.
cdef class WebSocketReader:

    cdef WebSocketDataQueue queue
    cdef unsigned int _max_msg_size

    cdef Exception _exc
    cdef bytearray _partial
    cdef unsigned int _state

    cdef int _opcode
    cdef bint _frame_fin
    cdef int _frame_opcode
    cdef list _payload_fragments
    cdef Py_ssize_t _frame_payload_len

    cdef bytes _tail
    cdef bint _has_mask
    cdef bytes _frame_mask
    cdef Py_ssize_t _payload_bytes_to_read
    cdef unsigned int _payload_len_flag
    cdef int _compressed
    cdef object _decompressobj
    cdef bint _compress

    cpdef tuple feed_data(self, object data)

    @cython.locals(
        is_continuation=bint,
        fin=bint,
        has_partial=bint,
        payload_merged=bytes,
    )
    cpdef void _handle_frame(self, bint fin, int opcode, object payload, int compressed) except *

    # `data_len` used to be listed twice here (a repeated keyword argument);
    # it is now declared once.
    @cython.locals(
        start_pos=Py_ssize_t,
        data_len=Py_ssize_t,
        length=Py_ssize_t,
        chunk_size=Py_ssize_t,
        chunk_len=Py_ssize_t,
        data_cstr="const unsigned char *",
        first_byte="unsigned char",
        second_byte="unsigned char",
        f_start_pos=Py_ssize_t,
        f_end_pos=Py_ssize_t,
        has_mask=bint,
        fin=bint,
        had_fragments=Py_ssize_t,
        payload_bytearray=bytearray,
    )
    cpdef void _feed_data(self, bytes data) except *
|
||||
@@ -0,0 +1,476 @@
|
||||
"""Reader for WebSocket protocol versions 13 and 8."""
|
||||
|
||||
import asyncio
|
||||
import builtins
|
||||
from collections import deque
|
||||
from typing import Deque, Final, Optional, Set, Tuple, Union
|
||||
|
||||
from ..base_protocol import BaseProtocol
|
||||
from ..compression_utils import ZLibDecompressor
|
||||
from ..helpers import _EXC_SENTINEL, set_exception
|
||||
from ..streams import EofStream
|
||||
from .helpers import UNPACK_CLOSE_CODE, UNPACK_LEN3, websocket_mask
|
||||
from .models import (
|
||||
WS_DEFLATE_TRAILING,
|
||||
WebSocketError,
|
||||
WSCloseCode,
|
||||
WSMessage,
|
||||
WSMsgType,
|
||||
)
|
||||
|
||||
# Integer values of every WSCloseCode member; used to validate received
# close codes below 3000.
ALLOWED_CLOSE_CODES: Final[Set[int]] = {int(i) for i in WSCloseCode}
|
||||
|
||||
# States for the reader, used to parse the WebSocket frame
|
||||
# integer values are used so they can be cythonized
|
||||
READ_HEADER = 1
|
||||
READ_PAYLOAD_LENGTH = 2
|
||||
READ_PAYLOAD_MASK = 3
|
||||
READ_PAYLOAD = 4
|
||||
|
||||
# Module-level aliases of the two data message types.
WS_MSG_TYPE_BINARY = WSMsgType.BINARY
|
||||
WS_MSG_TYPE_TEXT = WSMsgType.TEXT
|
||||
|
||||
# WSMsgType values unpacked so they can be cythonized to ints
|
||||
# Sentinel meaning "no opcode recorded"; it is not a WSMsgType value.
OP_CODE_NOT_SET = -1
|
||||
OP_CODE_CONTINUATION = WSMsgType.CONTINUATION.value
|
||||
OP_CODE_TEXT = WSMsgType.TEXT.value
|
||||
OP_CODE_BINARY = WSMsgType.BINARY.value
|
||||
OP_CODE_CLOSE = WSMsgType.CLOSE.value
|
||||
OP_CODE_PING = WSMsgType.PING.value
|
||||
OP_CODE_PONG = WSMsgType.PONG.value
|
||||
|
||||
# Results of WebSocketReader.feed_data: the first element is True when
# parsing failed, False otherwise; the second is always empty bytes.
EMPTY_FRAME_ERROR = (True, b"")
|
||||
EMPTY_FRAME = (False, b"")
|
||||
|
||||
# Tri-state flag recording whether the current message is compressed.
COMPRESSED_NOT_SET = -1
|
||||
COMPRESSED_FALSE = 0
|
||||
COMPRESSED_TRUE = 1
|
||||
|
||||
# Used to build WSMessage tuples without going through the NamedTuple constructor.
TUPLE_NEW = tuple.__new__
|
||||
|
||||
cython_int = int # Typed to int in Python, but Cython will use a signed int in the pxd
|
||||
|
||||
|
||||
class WebSocketDataQueue:
    """Queue of parsed WebSocket messages with read-side flow control.

    The parser pushes messages in with ``feed_data`` and a consumer awaits
    them with ``read``. When the queued size rises above the limit the
    protocol is asked to pause reading; it is resumed once the consumer has
    drained the queue back below the limit.
    """

    def __init__(
        self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop
    ) -> None:
        self._protocol = protocol
        self._loop = loop
        # The threshold actually applied is twice the requested limit.
        self._limit = limit * 2
        self._size = 0
        self._eof = False
        self._waiter: Optional[asyncio.Future[None]] = None
        self._exception: Union[BaseException, None] = None
        self._buffer: Deque[Tuple[WSMessage, int]] = deque()
        # Bound methods are cached to skip the attribute lookups per message.
        self._put_buffer = self._buffer.append
        self._get_buffer = self._buffer.popleft

    def is_eof(self) -> bool:
        """Tell whether end of stream (or an error) has been signalled."""
        return self._eof

    def exception(self) -> Optional[BaseException]:
        """Return the stored exception, or None when there is none."""
        return self._exception

    def set_exception(
        self,
        exc: BaseException,
        exc_cause: builtins.BaseException = _EXC_SENTINEL,
    ) -> None:
        """Mark the queue finished with ``exc`` and fail any pending reader."""
        self._eof = True
        self._exception = exc
        waiter = self._waiter
        if waiter is None:
            return
        self._waiter = None
        set_exception(waiter, exc, exc_cause)

    def _release_waiter(self) -> None:
        """Wake the pending reader, if there is one that is not done yet."""
        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.done():
                waiter.set_result(None)

    def feed_eof(self) -> None:
        """Signal end of stream and wake the pending reader."""
        self._eof = True
        self._release_waiter()
        self._exception = None  # Break cyclic references

    def feed_data(self, data: "WSMessage", size: "cython_int") -> None:
        """Queue ``data``, accounting ``size`` towards the flow-control limit."""
        self._size += size
        self._put_buffer((data, size))
        self._release_waiter()
        if self._size > self._limit:
            if not self._protocol._reading_paused:
                self._protocol.pause_reading()

    async def read(self) -> WSMessage:
        """Return the next message, waiting for one if the queue is empty.

        Raises the stored exception, or ``EofStream``, once the queue is
        empty and finished.
        """
        if self._buffer or self._eof:
            return self._read_from_buffer()

        assert not self._waiter
        self._waiter = self._loop.create_future()
        try:
            await self._waiter
        except (asyncio.CancelledError, asyncio.TimeoutError):
            # The read was abandoned; allow a later read to wait again.
            self._waiter = None
            raise
        return self._read_from_buffer()

    def _read_from_buffer(self) -> WSMessage:
        """Pop the oldest message, resuming the protocol when below the limit."""
        if not self._buffer:
            if self._exception is not None:
                raise self._exception
            raise EofStream

        data, size = self._get_buffer()
        self._size -= size
        if self._size < self._limit and self._protocol._reading_paused:
            self._protocol.resume_reading()
        return data
|
||||
|
||||
|
||||
# Incremental WebSocket frame parser. feed_data() accepts raw bytes,
# _feed_data() walks the READ_HEADER -> READ_PAYLOAD_LENGTH ->
# READ_PAYLOAD_MASK -> READ_PAYLOAD states across calls, and _handle_frame()
# turns each complete frame into a WSMessage pushed onto self.queue.
class WebSocketReader:
|
||||
def __init__(
|
||||
self, queue: WebSocketDataQueue, max_msg_size: int, compress: bool = True
|
||||
) -> None:
|
||||
self.queue = queue
|
||||
# A value of 0 disables every message-size check below.
self._max_msg_size = max_msg_size
|
||||
|
||||
# First exception raised while parsing; once set, feed_data() stops parsing.
self._exc: Optional[Exception] = None
|
||||
# Accumulated payload of the non-final data frames of the current message.
self._partial = bytearray()
|
||||
self._state = READ_HEADER
|
||||
|
||||
# Opcode of the fragmented message in progress (OP_CODE_NOT_SET when none).
self._opcode: int = OP_CODE_NOT_SET
|
||||
self._frame_fin = False
|
||||
self._frame_opcode: int = OP_CODE_NOT_SET
|
||||
# Pieces of the current frame's payload when it spans several feed calls.
self._payload_fragments: list[bytes] = []
|
||||
self._frame_payload_len = 0
|
||||
|
||||
# Unconsumed bytes carried over to the next _feed_data() call.
self._tail: bytes = b""
|
||||
self._has_mask = False
|
||||
self._frame_mask: Optional[bytes] = None
|
||||
self._payload_bytes_to_read = 0
|
||||
self._payload_len_flag = 0
|
||||
self._compressed: int = COMPRESSED_NOT_SET
|
||||
# Created lazily by _handle_frame on the first compressed message.
self._decompressobj: Optional[ZLibDecompressor] = None
|
||||
# When false, a frame with rsv1 set is rejected as a protocol error.
self._compress = compress
|
||||
|
||||
def feed_eof(self) -> None:
|
||||
self.queue.feed_eof()
|
||||
|
||||
# data can be bytearray on Windows because proactor event loop uses bytearray
|
||||
# and asyncio types this to Union[bytes, bytearray, memoryview] so we need
|
||||
# coerce data to bytes if it is not
|
||||
# Returns (True, data) without parsing if an earlier call already failed,
# EMPTY_FRAME_ERROR if this call fails (the exception is stored and passed to
# the queue), and EMPTY_FRAME on success.
def feed_data(
|
||||
self, data: Union[bytes, bytearray, memoryview]
|
||||
) -> Tuple[bool, bytes]:
|
||||
if type(data) is not bytes:
|
||||
data = bytes(data)
|
||||
|
||||
if self._exc is not None:
|
||||
return True, data
|
||||
|
||||
try:
|
||||
self._feed_data(data)
|
||||
except Exception as exc:
|
||||
self._exc = exc
|
||||
set_exception(self.queue, exc)
|
||||
return EMPTY_FRAME_ERROR
|
||||
|
||||
return EMPTY_FRAME
|
||||
|
||||
# Dispatch one complete frame. Data frames are reassembled, optionally
# decompressed and queued as TEXT or BINARY messages; CLOSE, PING and PONG are
# queued directly; any other opcode raises WebSocketError.
def _handle_frame(
|
||||
self,
|
||||
fin: bool,
|
||||
opcode: Union[int, cython_int], # Union intended: Cython pxd uses C int
|
||||
payload: Union[bytes, bytearray],
|
||||
compressed: Union[int, cython_int], # Union intended: Cython pxd uses C int
|
||||
) -> None:
|
||||
msg: WSMessage
|
||||
if opcode in {OP_CODE_TEXT, OP_CODE_BINARY, OP_CODE_CONTINUATION}:
|
||||
# load text/binary
|
||||
if not fin:
|
||||
# got partial frame payload
|
||||
# Remember the message opcode from its first (non-continuation) frame.
if opcode != OP_CODE_CONTINUATION:
|
||||
self._opcode = opcode
|
||||
self._partial += payload
|
||||
if self._max_msg_size and len(self._partial) >= self._max_msg_size:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.MESSAGE_TOO_BIG,
|
||||
f"Message size {len(self._partial)} "
|
||||
f"exceeds limit {self._max_msg_size}",
|
||||
)
|
||||
return
|
||||
|
||||
# Final frame of a message from here on.
has_partial = bool(self._partial)
|
||||
if opcode == OP_CODE_CONTINUATION:
|
||||
if self._opcode == OP_CODE_NOT_SET:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
"Continuation frame for non started message",
|
||||
)
|
||||
# Restore the opcode recorded from the first fragment.
opcode = self._opcode
|
||||
self._opcode = OP_CODE_NOT_SET
|
||||
# previous frame was non finished
|
||||
# we should get continuation opcode
|
||||
elif has_partial:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
"The opcode in non-fin frame is expected "
|
||||
f"to be zero, got {opcode!r}",
|
||||
)
|
||||
|
||||
assembled_payload: Union[bytes, bytearray]
|
||||
if has_partial:
|
||||
assembled_payload = self._partial + payload
|
||||
self._partial.clear()
|
||||
else:
|
||||
assembled_payload = payload
|
||||
|
||||
if self._max_msg_size and len(assembled_payload) >= self._max_msg_size:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.MESSAGE_TOO_BIG,
|
||||
f"Message size {len(assembled_payload)} "
|
||||
f"exceeds limit {self._max_msg_size}",
|
||||
)
|
||||
|
||||
# Decompression must be done after all packets have been
|
||||
# received.
|
||||
if compressed:
|
||||
if not self._decompressobj:
|
||||
self._decompressobj = ZLibDecompressor(suppress_deflate_header=True)
|
||||
# XXX: It's possible that the zlib backend (isal is known to
|
||||
# do this, maybe others too?) will return max_length bytes,
|
||||
# but internally buffer more data such that the payload is
|
||||
# >max_length, so we return one extra byte and if we're able
|
||||
# to do that, then the message is too big.
|
||||
payload_merged = self._decompressobj.decompress_sync(
|
||||
assembled_payload + WS_DEFLATE_TRAILING,
|
||||
(
|
||||
self._max_msg_size + 1
|
||||
if self._max_msg_size
|
||||
else self._max_msg_size
|
||||
),
|
||||
)
|
||||
if self._max_msg_size and len(payload_merged) > self._max_msg_size:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.MESSAGE_TOO_BIG,
|
||||
f"Decompressed message exceeds size limit {self._max_msg_size}",
|
||||
)
|
||||
elif type(assembled_payload) is bytes:
|
||||
payload_merged = assembled_payload
|
||||
else:
|
||||
payload_merged = bytes(assembled_payload)
|
||||
|
||||
if opcode == OP_CODE_TEXT:
|
||||
try:
|
||||
text = payload_merged.decode("utf-8")
|
||||
except UnicodeDecodeError as exc:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
|
||||
) from exc
|
||||
|
||||
# XXX: The Text and Binary messages here can be a performance
|
||||
# bottleneck, so we use tuple.__new__ to improve performance.
|
||||
# This is not type safe, but many tests should fail in
|
||||
# test_client_ws_functional.py if this is wrong.
|
||||
self.queue.feed_data(
|
||||
TUPLE_NEW(WSMessage, (WS_MSG_TYPE_TEXT, text, "")),
|
||||
len(payload_merged),
|
||||
)
|
||||
else:
|
||||
self.queue.feed_data(
|
||||
TUPLE_NEW(WSMessage, (WS_MSG_TYPE_BINARY, payload_merged, "")),
|
||||
len(payload_merged),
|
||||
)
|
||||
# Close frame: an optional 2-byte close code followed by a UTF-8 reason.
elif opcode == OP_CODE_CLOSE:
|
||||
if len(payload) >= 2:
|
||||
close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
|
||||
# Codes below 3000 must be one of the WSCloseCode values.
if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
f"Invalid close code: {close_code}",
|
||||
)
|
||||
try:
|
||||
close_message = payload[2:].decode("utf-8")
|
||||
except UnicodeDecodeError as exc:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
|
||||
) from exc
|
||||
msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, close_code, close_message))
|
||||
# A single payload byte cannot hold a close code.
elif payload:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
f"Invalid close frame: {fin} {opcode} {payload!r}",
|
||||
)
|
||||
else:
|
||||
# Empty close frame: reported with code 0 and an empty reason.
msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, 0, ""))
|
||||
|
||||
# Close messages do not count towards the queue size.
self.queue.feed_data(msg, 0)
|
||||
elif opcode == OP_CODE_PING:
|
||||
msg = TUPLE_NEW(WSMessage, (WSMsgType.PING, payload, ""))
|
||||
self.queue.feed_data(msg, len(payload))
|
||||
elif opcode == OP_CODE_PONG:
|
||||
msg = TUPLE_NEW(WSMessage, (WSMsgType.PONG, payload, ""))
|
||||
self.queue.feed_data(msg, len(payload))
|
||||
else:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}"
|
||||
)
|
||||
|
||||
# NOTE(review): despite the docstring below, this method returns None. It
# parses as many frames as `data` (plus any saved tail) contains, hands each
# complete frame to _handle_frame(), and saves unparsed bytes in self._tail.
def _feed_data(self, data: bytes) -> None:
|
||||
"""Return the next frame from the socket."""
|
||||
# Prepend bytes left over from the previous call.
if self._tail:
|
||||
data, self._tail = self._tail + data, b""
|
||||
|
||||
start_pos: int = 0
|
||||
data_len = len(data)
|
||||
data_cstr = data
|
||||
|
||||
while True:
|
||||
# read header
|
||||
if self._state == READ_HEADER:
|
||||
if data_len - start_pos < 2:
|
||||
break
|
||||
first_byte = data_cstr[start_pos]
|
||||
second_byte = data_cstr[start_pos + 1]
|
||||
start_pos += 2
|
||||
|
||||
fin = (first_byte >> 7) & 1
|
||||
rsv1 = (first_byte >> 6) & 1
|
||||
rsv2 = (first_byte >> 5) & 1
|
||||
rsv3 = (first_byte >> 4) & 1
|
||||
opcode = first_byte & 0xF
|
||||
|
||||
# frame-fin = %x0 ; more frames of this message follow
|
||||
# / %x1 ; final frame of this message
|
||||
# frame-rsv1 = %x0 ;
|
||||
# 1 bit, MUST be 0 unless negotiated otherwise
|
||||
# frame-rsv2 = %x0 ;
|
||||
# 1 bit, MUST be 0 unless negotiated otherwise
|
||||
# frame-rsv3 = %x0 ;
|
||||
# 1 bit, MUST be 0 unless negotiated otherwise
|
||||
#
|
||||
# Remove rsv1 from this test for deflate development
|
||||
if rsv2 or rsv3 or (rsv1 and not self._compress):
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
"Received frame with non-zero reserved bits",
|
||||
)
|
||||
|
||||
# Opcodes above 0x7 are control frames, which cannot be fragmented.
if opcode > 0x7 and fin == 0:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
"Received fragmented control frame",
|
||||
)
|
||||
|
||||
has_mask = (second_byte >> 7) & 1
|
||||
length = second_byte & 0x7F
|
||||
|
||||
# Control frames MUST have a payload
|
||||
# length of 125 bytes or less
|
||||
if opcode > 0x7 and length > 125:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
"Control frame payload cannot be larger than 125 bytes",
|
||||
)
|
||||
|
||||
# Set compress status if last package is FIN
|
||||
# OR set compress status if this is first fragment
|
||||
# Raise error if not first fragment with rsv1 = 0x1
|
||||
if self._frame_fin or self._compressed == COMPRESSED_NOT_SET:
|
||||
self._compressed = COMPRESSED_TRUE if rsv1 else COMPRESSED_FALSE
|
||||
elif rsv1:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
"Received frame with non-zero reserved bits",
|
||||
)
|
||||
|
||||
self._frame_fin = bool(fin)
|
||||
self._frame_opcode = opcode
|
||||
self._has_mask = bool(has_mask)
|
||||
self._payload_len_flag = length
|
||||
self._state = READ_PAYLOAD_LENGTH
|
||||
|
||||
# read payload length
|
||||
if self._state == READ_PAYLOAD_LENGTH:
|
||||
len_flag = self._payload_len_flag
|
||||
# 126 means the real length follows as a 16-bit big-endian integer.
if len_flag == 126:
|
||||
if data_len - start_pos < 2:
|
||||
break
|
||||
first_byte = data_cstr[start_pos]
|
||||
second_byte = data_cstr[start_pos + 1]
|
||||
start_pos += 2
|
||||
self._payload_bytes_to_read = first_byte << 8 | second_byte
|
||||
# 127 means the real length follows as a 64-bit big-endian integer.
elif len_flag > 126:
|
||||
if data_len - start_pos < 8:
|
||||
break
|
||||
self._payload_bytes_to_read = UNPACK_LEN3(data, start_pos)[0]
|
||||
start_pos += 8
|
||||
else:
|
||||
self._payload_bytes_to_read = len_flag
|
||||
|
||||
self._state = READ_PAYLOAD_MASK if self._has_mask else READ_PAYLOAD
|
||||
|
||||
# read payload mask
|
||||
if self._state == READ_PAYLOAD_MASK:
|
||||
if data_len - start_pos < 4:
|
||||
break
|
||||
self._frame_mask = data_cstr[start_pos : start_pos + 4]
|
||||
start_pos += 4
|
||||
self._state = READ_PAYLOAD
|
||||
|
||||
# read payload; only part of it may be available in this call
if self._state == READ_PAYLOAD:
|
||||
chunk_len = data_len - start_pos
|
||||
if self._payload_bytes_to_read >= chunk_len:
|
||||
f_end_pos = data_len
|
||||
self._payload_bytes_to_read -= chunk_len
|
||||
else:
|
||||
f_end_pos = start_pos + self._payload_bytes_to_read
|
||||
self._payload_bytes_to_read = 0
|
||||
|
||||
# Non-zero when earlier calls already buffered part of this frame's payload.
had_fragments = self._frame_payload_len
|
||||
self._frame_payload_len += f_end_pos - start_pos
|
||||
f_start_pos = start_pos
|
||||
start_pos = f_end_pos
|
||||
|
||||
if self._payload_bytes_to_read != 0:
|
||||
# If we don't have a complete frame, we need to save the
|
||||
# data for the next call to feed_data.
|
||||
self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
|
||||
break
|
||||
|
||||
payload: Union[bytes, bytearray]
|
||||
if had_fragments:
|
||||
# We have to join the payload fragments to get the payload
|
||||
self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
|
||||
if self._has_mask:
|
||||
assert self._frame_mask is not None
|
||||
payload_bytearray = bytearray(b"".join(self._payload_fragments))
|
||||
websocket_mask(self._frame_mask, payload_bytearray)
|
||||
payload = payload_bytearray
|
||||
else:
|
||||
payload = b"".join(self._payload_fragments)
|
||||
self._payload_fragments.clear()
|
||||
elif self._has_mask:
|
||||
assert self._frame_mask is not None
|
||||
payload_bytearray = data_cstr[f_start_pos:f_end_pos] # type: ignore[assignment]
|
||||
if type(payload_bytearray) is not bytearray: # pragma: no branch
|
||||
# Cython will do the conversion for us
|
||||
# but we need to do it for Python and we
|
||||
# will always get here in Python
|
||||
payload_bytearray = bytearray(payload_bytearray)
|
||||
websocket_mask(self._frame_mask, payload_bytearray)
|
||||
payload = payload_bytearray
|
||||
else:
|
||||
payload = data_cstr[f_start_pos:f_end_pos]
|
||||
|
||||
self._handle_frame(
|
||||
self._frame_fin, self._frame_opcode, payload, self._compressed
|
||||
)
|
||||
# Frame complete: reset per-frame state and look for the next header.
self._frame_payload_len = 0
|
||||
self._state = READ_HEADER
|
||||
|
||||
# XXX: Cython needs slices to be bounded, so we can't omit the slice end here.
|
||||
self._tail = data_cstr[start_pos:data_len] if start_pos < data_len else b""
|
||||
@@ -0,0 +1,476 @@
|
||||
"""Reader for WebSocket protocol versions 13 and 8."""
|
||||
|
||||
import asyncio
|
||||
import builtins
|
||||
from collections import deque
|
||||
from typing import Deque, Final, Optional, Set, Tuple, Union
|
||||
|
||||
from ..base_protocol import BaseProtocol
|
||||
from ..compression_utils import ZLibDecompressor
|
||||
from ..helpers import _EXC_SENTINEL, set_exception
|
||||
from ..streams import EofStream
|
||||
from .helpers import UNPACK_CLOSE_CODE, UNPACK_LEN3, websocket_mask
|
||||
from .models import (
|
||||
WS_DEFLATE_TRAILING,
|
||||
WebSocketError,
|
||||
WSCloseCode,
|
||||
WSMessage,
|
||||
WSMsgType,
|
||||
)
|
||||
|
||||
# Integer values of every WSCloseCode member.
ALLOWED_CLOSE_CODES: Final[Set[int]] = {int(i) for i in WSCloseCode}
|
||||
|
||||
# States for the reader, used to parse the WebSocket frame
|
||||
# integer values are used so they can be cythonized
|
||||
READ_HEADER = 1
|
||||
READ_PAYLOAD_LENGTH = 2
|
||||
READ_PAYLOAD_MASK = 3
|
||||
READ_PAYLOAD = 4
|
||||
|
||||
# Module-level aliases of the two data message types.
WS_MSG_TYPE_BINARY = WSMsgType.BINARY
|
||||
WS_MSG_TYPE_TEXT = WSMsgType.TEXT
|
||||
|
||||
# WSMsgType values unpacked so they can be cythonized to ints
|
||||
# Sentinel meaning "no opcode recorded"; it is not a WSMsgType value.
OP_CODE_NOT_SET = -1
|
||||
OP_CODE_CONTINUATION = WSMsgType.CONTINUATION.value
|
||||
OP_CODE_TEXT = WSMsgType.TEXT.value
|
||||
OP_CODE_BINARY = WSMsgType.BINARY.value
|
||||
OP_CODE_CLOSE = WSMsgType.CLOSE.value
|
||||
OP_CODE_PING = WSMsgType.PING.value
|
||||
OP_CODE_PONG = WSMsgType.PONG.value
|
||||
|
||||
# (flag, bytes) result tuples; the flag is True only in the error variant.
EMPTY_FRAME_ERROR = (True, b"")
|
||||
EMPTY_FRAME = (False, b"")
|
||||
|
||||
# Tri-state compression flag values.
COMPRESSED_NOT_SET = -1
|
||||
COMPRESSED_FALSE = 0
|
||||
COMPRESSED_TRUE = 1
|
||||
|
||||
# Shortcut to tuple.__new__.
TUPLE_NEW = tuple.__new__
|
||||
|
||||
cython_int = int # Typed to int in Python, but Cython will use a signed int in the pxd
|
||||
|
||||
|
||||
class WebSocketDataQueue:
    """WebSocketDataQueue resumes and pauses an underlying stream.

    It is a destination for WebSocket data.

    Messages are stored together with their size. When the total buffered
    size rises above ``2 * limit`` the protocol is asked to pause reading;
    when a read brings it back below that mark, reading is resumed.
    """

    def __init__(
        self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop
    ) -> None:
        """Create an empty queue bound to ``protocol`` and ``loop``."""
        self._size = 0
        self._protocol = protocol
        # Flow control threshold is twice the requested limit.
        self._limit = limit * 2
        self._loop = loop
        self._eof = False
        # Future awaited by a pending read(); None when nobody is waiting.
        self._waiter: Optional[asyncio.Future[None]] = None
        self._exception: Union[BaseException, None] = None
        # Each entry is (message, size) so the size can be subtracted on read.
        self._buffer: Deque[Tuple[WSMessage, int]] = deque()
        # Bound methods cached to avoid attribute lookups per message.
        self._get_buffer = self._buffer.popleft
        self._put_buffer = self._buffer.append

    def is_eof(self) -> bool:
        """Return True once EOF or an exception has been signalled."""
        return self._eof

    def exception(self) -> Optional[BaseException]:
        """Return the stored exception, or None if there is none."""
        return self._exception

    def set_exception(
        self,
        exc: BaseException,
        exc_cause: builtins.BaseException = _EXC_SENTINEL,
    ) -> None:
        """Mark the queue as finished with ``exc``.

        The exception is stored so later reads raise it once the buffer is
        drained. A pending waiter, if any, is detached and handed to the
        ``set_exception`` helper together with ``exc`` and ``exc_cause``.
        """
        self._eof = True
        self._exception = exc
        if (waiter := self._waiter) is not None:
            self._waiter = None
            set_exception(waiter, exc, exc_cause)

    def _release_waiter(self) -> None:
        """Wake the pending reader, if there is one that is not done yet."""
        if (waiter := self._waiter) is None:
            return
        self._waiter = None
        if not waiter.done():
            waiter.set_result(None)

    def feed_eof(self) -> None:
        """Signal end of stream and wake any pending reader."""
        self._eof = True
        self._release_waiter()
        self._exception = None  # Break cyclic references

    def feed_data(self, data: "WSMessage", size: "cython_int") -> None:
        """Enqueue ``data`` accounted as ``size`` bytes.

        Wakes a pending reader and pauses the protocol when the buffered
        size exceeds the limit and reading is not already paused.
        """
        self._size += size
        self._put_buffer((data, size))
        self._release_waiter()
        if self._size > self._limit and not self._protocol._reading_paused:
            self._protocol.pause_reading()

    async def read(self) -> WSMessage:
        """Return the next message, waiting for one if the queue is empty.

        Raises the stored exception, or ``EofStream``, when the queue is
        empty and finished.
        """
        if not self._buffer and not self._eof:
            assert not self._waiter
            self._waiter = self._loop.create_future()
            try:
                await self._waiter
            except (asyncio.CancelledError, asyncio.TimeoutError):
                # Drop the waiter so a later read() can install a new one.
                self._waiter = None
                raise
        return self._read_from_buffer()

    def _read_from_buffer(self) -> WSMessage:
        """Pop the oldest message or raise if nothing is buffered."""
        if self._buffer:
            data, size = self._get_buffer()
            self._size -= size
            if self._size < self._limit and self._protocol._reading_paused:
                self._protocol.resume_reading()
            return data
        if self._exception is not None:
            raise self._exception
        raise EofStream
|
||||
|
||||
|
||||
class WebSocketReader:
    """Incremental parser for incoming WebSocket frames.

    Raw bytes are passed to :meth:`feed_data`; complete frames are turned
    into ``WSMessage`` tuples and pushed to ``queue``. Parser state is kept
    on the instance so a frame may be split across any number of calls.
    """

    def __init__(
        self, queue: WebSocketDataQueue, max_msg_size: int, compress: bool = True
    ) -> None:
        """Initialize the reader.

        ``max_msg_size`` of 0 disables the message size checks. ``compress``
        controls whether frames with the RSV1 bit set are accepted.
        """
        self.queue = queue
        self._max_msg_size = max_msg_size

        # First exception raised while parsing; once set, parsing stops.
        self._exc: Optional[Exception] = None
        # Payload accumulated from non-final frames of a fragmented message.
        self._partial = bytearray()
        self._state = READ_HEADER

        # Opcode of the fragmented message in progress, if any.
        self._opcode: int = OP_CODE_NOT_SET
        self._frame_fin = False
        self._frame_opcode: int = OP_CODE_NOT_SET
        # Pieces of a single frame's payload that arrived in separate calls.
        self._payload_fragments: list[bytes] = []
        self._frame_payload_len = 0

        # Unparsed bytes left over from the previous _feed_data call.
        self._tail: bytes = b""
        self._has_mask = False
        self._frame_mask: Optional[bytes] = None
        self._payload_bytes_to_read = 0
        self._payload_len_flag = 0
        self._compressed: int = COMPRESSED_NOT_SET
        self._decompressobj: Optional[ZLibDecompressor] = None
        self._compress = compress

    def feed_eof(self) -> None:
        """Forward end-of-stream to the queue."""
        self.queue.feed_eof()

    # data can be bytearray on Windows because proactor event loop uses bytearray
    # and asyncio types this to Union[bytes, bytearray, memoryview] so we need
    # coerce data to bytes if it is not
    def feed_data(
        self, data: Union[bytes, bytearray, memoryview]
    ) -> Tuple[bool, bytes]:
        """Parse ``data`` and return ``(error, remaining)``.

        If an earlier call already failed, ``(True, data)`` is returned and
        nothing is parsed. If parsing raises, the exception is remembered,
        passed to the queue via the ``set_exception`` helper, and
        ``(True, b"")`` is returned. Otherwise ``(False, b"")`` is returned.
        """
        if type(data) is not bytes:
            data = bytes(data)

        if self._exc is not None:
            return True, data

        try:
            self._feed_data(data)
        except Exception as exc:
            self._exc = exc
            set_exception(self.queue, exc)
            return EMPTY_FRAME_ERROR

        return EMPTY_FRAME

    def _handle_frame(
        self,
        fin: bool,
        opcode: Union[int, cython_int],  # Union intended: Cython pxd uses C int
        payload: Union[bytes, bytearray],
        compressed: Union[int, cython_int],  # Union intended: Cython pxd uses C int
    ) -> None:
        """Turn one complete frame into a queued message.

        Non-final data frames are accumulated in ``self._partial`` and
        produce no message. Final data frames are reassembled, optionally
        decompressed, and queued as TEXT or BINARY. CLOSE, PING and PONG
        frames are queued directly.

        Raises ``WebSocketError`` on oversized messages, invalid
        continuation sequences, invalid UTF-8, invalid close frames and
        unknown opcodes.
        """
        msg: WSMessage
        if opcode in {OP_CODE_TEXT, OP_CODE_BINARY, OP_CODE_CONTINUATION}:
            # load text/binary
            if not fin:
                # got partial frame payload
                if opcode != OP_CODE_CONTINUATION:
                    self._opcode = opcode
                self._partial += payload
                if self._max_msg_size and len(self._partial) >= self._max_msg_size:
                    raise WebSocketError(
                        WSCloseCode.MESSAGE_TOO_BIG,
                        f"Message size {len(self._partial)} "
                        f"exceeds limit {self._max_msg_size}",
                    )
                return

            has_partial = bool(self._partial)
            if opcode == OP_CODE_CONTINUATION:
                if self._opcode == OP_CODE_NOT_SET:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Continuation frame for non started message",
                    )
                # The final continuation frame takes the opcode of the
                # frame that started the message.
                opcode = self._opcode
                self._opcode = OP_CODE_NOT_SET
            elif has_partial:
                # previous frame was non finished
                # we should get continuation opcode
                raise WebSocketError(
                    WSCloseCode.PROTOCOL_ERROR,
                    "The opcode in non-fin frame is expected "
                    f"to be zero, got {opcode!r}",
                )

            assembled_payload: Union[bytes, bytearray]
            if has_partial:
                assembled_payload = self._partial + payload
                self._partial.clear()
            else:
                assembled_payload = payload

            if self._max_msg_size and len(assembled_payload) >= self._max_msg_size:
                raise WebSocketError(
                    WSCloseCode.MESSAGE_TOO_BIG,
                    f"Message size {len(assembled_payload)} "
                    f"exceeds limit {self._max_msg_size}",
                )

            # Decompression must be done after all packets have been
            # received.
            if compressed:
                if not self._decompressobj:
                    self._decompressobj = ZLibDecompressor(suppress_deflate_header=True)
                # XXX: It's possible that the zlib backend (isal is known to
                # do this, maybe others too?) will return max_length bytes,
                # but internally buffer more data such that the payload is
                # >max_length, so we return one extra byte and if we're able
                # to do that, then the message is too big.
                payload_merged = self._decompressobj.decompress_sync(
                    assembled_payload + WS_DEFLATE_TRAILING,
                    (
                        self._max_msg_size + 1
                        if self._max_msg_size
                        else self._max_msg_size
                    ),
                )
                if self._max_msg_size and len(payload_merged) > self._max_msg_size:
                    raise WebSocketError(
                        WSCloseCode.MESSAGE_TOO_BIG,
                        f"Decompressed message exceeds size limit {self._max_msg_size}",
                    )
            elif type(assembled_payload) is bytes:
                payload_merged = assembled_payload
            else:
                payload_merged = bytes(assembled_payload)

            if opcode == OP_CODE_TEXT:
                try:
                    text = payload_merged.decode("utf-8")
                except UnicodeDecodeError as exc:
                    raise WebSocketError(
                        WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
                    ) from exc

                # XXX: The Text and Binary messages here can be a performance
                # bottleneck, so we use tuple.__new__ to improve performance.
                # This is not type safe, but many tests should fail in
                # test_client_ws_functional.py if this is wrong.
                self.queue.feed_data(
                    TUPLE_NEW(WSMessage, (WS_MSG_TYPE_TEXT, text, "")),
                    len(payload_merged),
                )
            else:
                self.queue.feed_data(
                    TUPLE_NEW(WSMessage, (WS_MSG_TYPE_BINARY, payload_merged, "")),
                    len(payload_merged),
                )
        elif opcode == OP_CODE_CLOSE:
            if len(payload) >= 2:
                close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
                if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        f"Invalid close code: {close_code}",
                    )
                try:
                    close_message = payload[2:].decode("utf-8")
                except UnicodeDecodeError as exc:
                    raise WebSocketError(
                        WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
                    ) from exc
                msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, close_code, close_message))
            elif payload:
                # A one-byte payload cannot hold a close code.
                raise WebSocketError(
                    WSCloseCode.PROTOCOL_ERROR,
                    f"Invalid close frame: {fin} {opcode} {payload!r}",
                )
            else:
                msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, 0, ""))

            self.queue.feed_data(msg, 0)
        elif opcode == OP_CODE_PING:
            msg = TUPLE_NEW(WSMessage, (WSMsgType.PING, payload, ""))
            self.queue.feed_data(msg, len(payload))
        elif opcode == OP_CODE_PONG:
            msg = TUPLE_NEW(WSMessage, (WSMsgType.PONG, payload, ""))
            self.queue.feed_data(msg, len(payload))
        else:
            raise WebSocketError(
                WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}"
            )

    def _feed_data(self, data: bytes) -> None:
        """Parse as many frames as ``data`` contains.

        Bytes left over from the previous call are prepended first. Each
        complete frame is passed to :meth:`_handle_frame`. Bytes that do not
        yet form a complete header, length or mask are kept in
        ``self._tail``; partial payload bytes are kept in
        ``self._payload_fragments``.
        """
        if self._tail:
            data, self._tail = self._tail + data, b""

        start_pos: int = 0
        data_len = len(data)
        data_cstr = data

        while True:
            # read header
            if self._state == READ_HEADER:
                if data_len - start_pos < 2:
                    break
                first_byte = data_cstr[start_pos]
                second_byte = data_cstr[start_pos + 1]
                start_pos += 2

                fin = (first_byte >> 7) & 1
                rsv1 = (first_byte >> 6) & 1
                rsv2 = (first_byte >> 5) & 1
                rsv3 = (first_byte >> 4) & 1
                opcode = first_byte & 0xF

                # frame-fin = %x0 ; more frames of this message follow
                #           / %x1 ; final frame of this message
                # frame-rsv1 = %x0 ;
                #    1 bit, MUST be 0 unless negotiated otherwise
                # frame-rsv2 = %x0 ;
                #    1 bit, MUST be 0 unless negotiated otherwise
                # frame-rsv3 = %x0 ;
                #    1 bit, MUST be 0 unless negotiated otherwise
                #
                # Remove rsv1 from this test for deflate development
                if rsv2 or rsv3 or (rsv1 and not self._compress):
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Received frame with non-zero reserved bits",
                    )

                if opcode > 0x7 and fin == 0:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Received fragmented control frame",
                    )

                has_mask = (second_byte >> 7) & 1
                length = second_byte & 0x7F

                # Control frames MUST have a payload
                # length of 125 bytes or less
                if opcode > 0x7 and length > 125:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Control frame payload cannot be larger than 125 bytes",
                    )

                # Set compress status if last package is FIN
                # OR set compress status if this is first fragment
                # Raise error if not first fragment with rsv1 = 0x1
                if self._frame_fin or self._compressed == COMPRESSED_NOT_SET:
                    self._compressed = COMPRESSED_TRUE if rsv1 else COMPRESSED_FALSE
                elif rsv1:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Received frame with non-zero reserved bits",
                    )

                self._frame_fin = bool(fin)
                self._frame_opcode = opcode
                self._has_mask = bool(has_mask)
                self._payload_len_flag = length
                self._state = READ_PAYLOAD_LENGTH

            # read payload length
            if self._state == READ_PAYLOAD_LENGTH:
                len_flag = self._payload_len_flag
                if len_flag == 126:
                    # 16-bit big-endian extended length follows.
                    if data_len - start_pos < 2:
                        break
                    first_byte = data_cstr[start_pos]
                    second_byte = data_cstr[start_pos + 1]
                    start_pos += 2
                    self._payload_bytes_to_read = first_byte << 8 | second_byte
                elif len_flag > 126:
                    # 64-bit extended length follows.
                    if data_len - start_pos < 8:
                        break
                    self._payload_bytes_to_read = UNPACK_LEN3(data, start_pos)[0]
                    start_pos += 8
                else:
                    self._payload_bytes_to_read = len_flag

                self._state = READ_PAYLOAD_MASK if self._has_mask else READ_PAYLOAD

            # read payload mask
            if self._state == READ_PAYLOAD_MASK:
                if data_len - start_pos < 4:
                    break
                self._frame_mask = data_cstr[start_pos : start_pos + 4]
                start_pos += 4
                self._state = READ_PAYLOAD

            if self._state == READ_PAYLOAD:
                chunk_len = data_len - start_pos
                if self._payload_bytes_to_read >= chunk_len:
                    f_end_pos = data_len
                    self._payload_bytes_to_read -= chunk_len
                else:
                    f_end_pos = start_pos + self._payload_bytes_to_read
                    self._payload_bytes_to_read = 0

                # Non-zero when earlier calls already stored part of this
                # frame's payload.
                had_fragments = self._frame_payload_len
                self._frame_payload_len += f_end_pos - start_pos
                f_start_pos = start_pos
                start_pos = f_end_pos

                if self._payload_bytes_to_read != 0:
                    # If we don't have a complete frame, we need to save the
                    # data for the next call to feed_data.
                    self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
                    break

                payload: Union[bytes, bytearray]
                if had_fragments:
                    # We have to join the payload fragments to get the payload
                    self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
                    if self._has_mask:
                        assert self._frame_mask is not None
                        payload_bytearray = bytearray(b"".join(self._payload_fragments))
                        websocket_mask(self._frame_mask, payload_bytearray)
                        payload = payload_bytearray
                    else:
                        payload = b"".join(self._payload_fragments)
                    self._payload_fragments.clear()
                elif self._has_mask:
                    assert self._frame_mask is not None
                    payload_bytearray = data_cstr[f_start_pos:f_end_pos]  # type: ignore[assignment]
                    if type(payload_bytearray) is not bytearray:  # pragma: no branch
                        # Cython will do the conversion for us
                        # but we need to do it for Python and we
                        # will always get here in Python
                        payload_bytearray = bytearray(payload_bytearray)
                    websocket_mask(self._frame_mask, payload_bytearray)
                    payload = payload_bytearray
                else:
                    payload = data_cstr[f_start_pos:f_end_pos]

                self._handle_frame(
                    self._frame_fin, self._frame_opcode, payload, self._compressed
                )
                self._frame_payload_len = 0
                self._state = READ_HEADER

        # XXX: Cython needs slices to be bounded, so we can't omit the slice end here.
        self._tail = data_cstr[start_pos:data_len] if start_pos < data_len else b""
|
||||
@@ -0,0 +1,178 @@
|
||||
"""WebSocket protocol versions 13 and 8."""
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
from functools import partial
|
||||
from typing import Any, Final, Optional, Union
|
||||
|
||||
from ..base_protocol import BaseProtocol
|
||||
from ..client_exceptions import ClientConnectionResetError
|
||||
from ..compression_utils import ZLibBackend, ZLibCompressor
|
||||
from .helpers import (
|
||||
MASK_LEN,
|
||||
MSG_SIZE,
|
||||
PACK_CLOSE_CODE,
|
||||
PACK_LEN1,
|
||||
PACK_LEN2,
|
||||
PACK_LEN3,
|
||||
PACK_RANDBITS,
|
||||
websocket_mask,
|
||||
)
|
||||
from .models import WS_DEFLATE_TRAILING, WSMsgType
|
||||
|
||||
# Default number of written bytes after which the writer checks whether it
# has to wait for the transport to drain.
DEFAULT_LIMIT: Final[int] = 2**16

# For websockets, keeping latency low is extremely important as implementations
# generally expect to be able to send and receive messages quickly. We use a
# larger chunk size than the default to reduce the number of executor calls
# since the executor is a significant source of latency and overhead when
# the chunks are small. A size of 5KiB was chosen because it is also the
# value python-zlib-ng uses as the threshold to release the GIL.
WEBSOCKET_MAX_SYNC_CHUNK_SIZE = 5 * 1024
|
||||
|
||||
|
||||
class WebSocketWriter:
    """WebSocket writer.

    The writer is responsible for sending messages to the client. It is
    created by the protocol when a connection is established. The writer
    should avoid implementing any application logic and should only be
    concerned with the low-level details of the WebSocket protocol.
    """

    def __init__(
        self,
        protocol: BaseProtocol,
        transport: asyncio.Transport,
        *,
        use_mask: bool = False,
        limit: int = DEFAULT_LIMIT,
        random: random.Random = random.Random(),
        compress: int = 0,
        notakeover: bool = False,
    ) -> None:
        """Initialize a WebSocket writer.

        ``use_mask`` enables payload masking. ``limit`` is the number of
        written bytes after which draining is considered. ``random`` is the
        source of mask bits. A non-zero ``compress`` compresses data frames,
        and its value is passed (negated) as ``wbits`` to the compressor.
        ``notakeover`` selects a full flush instead of a sync flush after
        each compressed message.
        """
        self.protocol = protocol
        self.transport = transport
        self.use_mask = use_mask
        self.get_random_bits = partial(random.getrandbits, 32)
        self.compress = compress
        self.notakeover = notakeover
        self._closing = False
        self._limit = limit
        # Bytes written since the last drain check.
        self._output_size = 0
        # Created on first use and reused for frames compressed with
        # self.compress.
        self._compressobj: Any = None  # actually compressobj

    async def send_frame(
        self, message: bytes, opcode: int, compress: Optional[int] = None
    ) -> None:
        """Send a frame over the websocket with message as its payload.

        ``compress`` overrides the writer-wide compression setting for this
        frame only. Raises ``ClientConnectionResetError`` when the writer is
        closing (and ``opcode`` shares no bits with ``WSMsgType.CLOSE``) or
        when the transport reports that it is closing.
        """
        if self._closing and not (opcode & WSMsgType.CLOSE):
            raise ClientConnectionResetError("Cannot write to closing transport")

        # RSV are the reserved bits in the frame header. They are used to
        # indicate that the frame is using an extension.
        # https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
        rsv = 0
        # Only compress larger packets (disabled)
        # Does small packet needs to be compressed?
        # if self.compress and opcode < 8 and len(message) > 124:
        if (compress or self.compress) and opcode < 8:
            # RSV1 (rsv = 0x40) is set for compressed frames
            # https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1
            rsv = 0x40

            if compress:
                # Do not set self._compress if compressing is for this frame
                compressobj = self._make_compress_obj(compress)
            else:  # self.compress
                if not self._compressobj:
                    self._compressobj = self._make_compress_obj(self.compress)
                compressobj = self._compressobj

            message = (
                await compressobj.compress(message)
                + compressobj.flush(
                    ZLibBackend.Z_FULL_FLUSH
                    if self.notakeover
                    else ZLibBackend.Z_SYNC_FLUSH
                )
            ).removesuffix(WS_DEFLATE_TRAILING)
            # It's critical that we do not return control to the event
            # loop until we have finished sending all the compressed
            # data. Otherwise we could end up mixing compressed frames
            # if there are multiple coroutines compressing data.

        msg_length = len(message)

        use_mask = self.use_mask
        mask_bit = 0x80 if use_mask else 0

        # Depending on the message length, the header is assembled differently.
        # The first byte is reserved for the opcode and the RSV bits.
        # 0x80 sets the top bit of the first byte on every frame.
        first_byte = 0x80 | rsv | opcode
        if msg_length < 126:
            header = PACK_LEN1(first_byte, msg_length | mask_bit)
            header_len = 2
        elif msg_length < 65536:
            header = PACK_LEN2(first_byte, 126 | mask_bit, msg_length)
            header_len = 4
        else:
            header = PACK_LEN3(first_byte, 127 | mask_bit, msg_length)
            header_len = 10

        if self.transport.is_closing():
            raise ClientConnectionResetError("Cannot write to closing transport")

        # https://datatracker.ietf.org/doc/html/rfc6455#section-5.3
        # If we are using a mask, we need to generate it randomly
        # and apply it to the message before sending it. A mask is
        # a 32-bit value that is applied to the message using a
        # bitwise XOR operation. It is used to prevent certain types
        # of attacks on the websocket protocol. The mask is only used
        # when aiohttp is acting as a client. Servers do not use a mask.
        if use_mask:
            mask = PACK_RANDBITS(self.get_random_bits())
            message = bytearray(message)
            websocket_mask(mask, message)
            self.transport.write(header + mask + message)
            self._output_size += MASK_LEN
        elif msg_length > MSG_SIZE:
            # Write header and payload separately rather than concatenating
            # them into one new bytes object.
            self.transport.write(header)
            self.transport.write(message)
        else:
            self.transport.write(header + message)

        self._output_size += header_len + msg_length

        # It is safe to return control to the event loop when using compression
        # after this point as we have already sent or buffered all the data.

        # Once we have written output_size up to the limit, we call the
        # drain helper which waits for the transport to be ready to accept
        # more data. This is a flow control mechanism to prevent the buffer
        # from growing too large. The drain helper will return right away
        # if the writer is not paused.
        if self._output_size > self._limit:
            self._output_size = 0
            if self.protocol._paused:
                await self.protocol._drain_helper()

    def _make_compress_obj(self, compress: int) -> ZLibCompressor:
        """Build a compressor using ``-compress`` as its ``wbits`` value."""
        return ZLibCompressor(
            level=ZLibBackend.Z_BEST_SPEED,
            wbits=-compress,
            max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
        )

    async def close(self, code: int = 1000, message: Union[bytes, str] = b"") -> None:
        """Close the websocket, sending the specified code and message.

        A ``str`` message is encoded as UTF-8. The writer is marked as
        closing even if sending the close frame raises.
        """
        if isinstance(message, str):
            message = message.encode("utf-8")
        try:
            await self.send_frame(
                PACK_CLOSE_CODE(code) + message, opcode=WSMsgType.CLOSE
            )
        finally:
            self._closing = True
|
||||
Reference in New Issue
Block a user