Changed code to support older Python versions
This commit is contained in:
parent
eb92d2d36f
commit
582458cdd0
5027 changed files with 794942 additions and 4 deletions
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
188
venv/lib/python3.11/site-packages/anyio/streams/buffered.py
Normal file
188
venv/lib/python3.11/site-packages/anyio/streams/buffered.py
Normal file
|
|
@ -0,0 +1,188 @@
|
|||
from __future__ import annotations
|
||||
|
||||
__all__ = (
|
||||
"BufferedByteReceiveStream",
|
||||
"BufferedByteStream",
|
||||
"BufferedConnectable",
|
||||
)
|
||||
|
||||
import sys
|
||||
from collections.abc import Callable, Iterable, Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, SupportsIndex
|
||||
|
||||
from .. import ClosedResourceError, DelimiterNotFound, EndOfStream, IncompleteRead
|
||||
from ..abc import (
|
||||
AnyByteReceiveStream,
|
||||
AnyByteStream,
|
||||
AnyByteStreamConnectable,
|
||||
ByteReceiveStream,
|
||||
ByteStream,
|
||||
ByteStreamConnectable,
|
||||
)
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class BufferedByteReceiveStream(ByteReceiveStream):
|
||||
"""
|
||||
Wraps any bytes-based receive stream and uses a buffer to provide sophisticated
|
||||
receiving capabilities in the form of a byte stream.
|
||||
"""
|
||||
|
||||
receive_stream: AnyByteReceiveStream
|
||||
_buffer: bytearray = field(init=False, default_factory=bytearray)
|
||||
_closed: bool = field(init=False, default=False)
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await self.receive_stream.aclose()
|
||||
self._closed = True
|
||||
|
||||
@property
|
||||
def buffer(self) -> bytes:
|
||||
"""The bytes currently in the buffer."""
|
||||
return bytes(self._buffer)
|
||||
|
||||
@property
|
||||
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
|
||||
return self.receive_stream.extra_attributes
|
||||
|
||||
def feed_data(self, data: Iterable[SupportsIndex], /) -> None:
|
||||
"""
|
||||
Append data directly into the buffer.
|
||||
|
||||
Any data in the buffer will be consumed by receive operations before receiving
|
||||
anything from the wrapped stream.
|
||||
|
||||
:param data: the data to append to the buffer (can be bytes or anything else
|
||||
that supports ``__index__()``)
|
||||
|
||||
"""
|
||||
self._buffer.extend(data)
|
||||
|
||||
async def receive(self, max_bytes: int = 65536) -> bytes:
|
||||
if self._closed:
|
||||
raise ClosedResourceError
|
||||
|
||||
if self._buffer:
|
||||
chunk = bytes(self._buffer[:max_bytes])
|
||||
del self._buffer[:max_bytes]
|
||||
return chunk
|
||||
elif isinstance(self.receive_stream, ByteReceiveStream):
|
||||
return await self.receive_stream.receive(max_bytes)
|
||||
else:
|
||||
# With a bytes-oriented object stream, we need to handle any surplus bytes
|
||||
# we get from the receive() call
|
||||
chunk = await self.receive_stream.receive()
|
||||
if len(chunk) > max_bytes:
|
||||
# Save the surplus bytes in the buffer
|
||||
self._buffer.extend(chunk[max_bytes:])
|
||||
return chunk[:max_bytes]
|
||||
else:
|
||||
return chunk
|
||||
|
||||
async def receive_exactly(self, nbytes: int) -> bytes:
|
||||
"""
|
||||
Read exactly the given amount of bytes from the stream.
|
||||
|
||||
:param nbytes: the number of bytes to read
|
||||
:return: the bytes read
|
||||
:raises ~anyio.IncompleteRead: if the stream was closed before the requested
|
||||
amount of bytes could be read from the stream
|
||||
|
||||
"""
|
||||
while True:
|
||||
remaining = nbytes - len(self._buffer)
|
||||
if remaining <= 0:
|
||||
retval = self._buffer[:nbytes]
|
||||
del self._buffer[:nbytes]
|
||||
return bytes(retval)
|
||||
|
||||
try:
|
||||
if isinstance(self.receive_stream, ByteReceiveStream):
|
||||
chunk = await self.receive_stream.receive(remaining)
|
||||
else:
|
||||
chunk = await self.receive_stream.receive()
|
||||
except EndOfStream as exc:
|
||||
raise IncompleteRead from exc
|
||||
|
||||
self._buffer.extend(chunk)
|
||||
|
||||
async def receive_until(self, delimiter: bytes, max_bytes: int) -> bytes:
|
||||
"""
|
||||
Read from the stream until the delimiter is found or max_bytes have been read.
|
||||
|
||||
:param delimiter: the marker to look for in the stream
|
||||
:param max_bytes: maximum number of bytes that will be read before raising
|
||||
:exc:`~anyio.DelimiterNotFound`
|
||||
:return: the bytes read (not including the delimiter)
|
||||
:raises ~anyio.IncompleteRead: if the stream was closed before the delimiter
|
||||
was found
|
||||
:raises ~anyio.DelimiterNotFound: if the delimiter is not found within the
|
||||
bytes read up to the maximum allowed
|
||||
|
||||
"""
|
||||
delimiter_size = len(delimiter)
|
||||
offset = 0
|
||||
while True:
|
||||
# Check if the delimiter can be found in the current buffer
|
||||
index = self._buffer.find(delimiter, offset)
|
||||
if index >= 0:
|
||||
found = self._buffer[:index]
|
||||
del self._buffer[: index + len(delimiter) :]
|
||||
return bytes(found)
|
||||
|
||||
# Check if the buffer is already at or over the limit
|
||||
if len(self._buffer) >= max_bytes:
|
||||
raise DelimiterNotFound(max_bytes)
|
||||
|
||||
# Read more data into the buffer from the socket
|
||||
try:
|
||||
data = await self.receive_stream.receive()
|
||||
except EndOfStream as exc:
|
||||
raise IncompleteRead from exc
|
||||
|
||||
# Move the offset forward and add the new data to the buffer
|
||||
offset = max(len(self._buffer) - delimiter_size + 1, 0)
|
||||
self._buffer.extend(data)
|
||||
|
||||
|
||||
class BufferedByteStream(BufferedByteReceiveStream, ByteStream):
|
||||
"""
|
||||
A full-duplex variant of :class:`BufferedByteReceiveStream`. All writes are passed
|
||||
through to the wrapped stream as-is.
|
||||
"""
|
||||
|
||||
def __init__(self, stream: AnyByteStream):
|
||||
"""
|
||||
:param stream: the stream to be wrapped
|
||||
|
||||
"""
|
||||
super().__init__(stream)
|
||||
self._stream = stream
|
||||
|
||||
@override
|
||||
async def send_eof(self) -> None:
|
||||
await self._stream.send_eof()
|
||||
|
||||
@override
|
||||
async def send(self, item: bytes) -> None:
|
||||
await self._stream.send(item)
|
||||
|
||||
|
||||
class BufferedConnectable(ByteStreamConnectable):
|
||||
def __init__(self, connectable: AnyByteStreamConnectable):
|
||||
"""
|
||||
:param connectable: the connectable to wrap
|
||||
|
||||
"""
|
||||
self.connectable = connectable
|
||||
|
||||
@override
|
||||
async def connect(self) -> BufferedByteStream:
|
||||
stream = await self.connectable.connect()
|
||||
return BufferedByteStream(stream)
|
||||
154
venv/lib/python3.11/site-packages/anyio/streams/file.py
Normal file
154
venv/lib/python3.11/site-packages/anyio/streams/file.py
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
from __future__ import annotations
|
||||
|
||||
__all__ = (
|
||||
"FileReadStream",
|
||||
"FileStreamAttribute",
|
||||
"FileWriteStream",
|
||||
)
|
||||
|
||||
from collections.abc import Callable, Mapping
|
||||
from io import SEEK_SET, UnsupportedOperation
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import Any, BinaryIO, cast
|
||||
|
||||
from .. import (
|
||||
BrokenResourceError,
|
||||
ClosedResourceError,
|
||||
EndOfStream,
|
||||
TypedAttributeSet,
|
||||
to_thread,
|
||||
typed_attribute,
|
||||
)
|
||||
from ..abc import ByteReceiveStream, ByteSendStream
|
||||
|
||||
|
||||
class FileStreamAttribute(TypedAttributeSet):
|
||||
#: the open file descriptor
|
||||
file: BinaryIO = typed_attribute()
|
||||
#: the path of the file on the file system, if available (file must be a real file)
|
||||
path: Path = typed_attribute()
|
||||
#: the file number, if available (file must be a real file or a TTY)
|
||||
fileno: int = typed_attribute()
|
||||
|
||||
|
||||
class _BaseFileStream:
|
||||
def __init__(self, file: BinaryIO):
|
||||
self._file = file
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await to_thread.run_sync(self._file.close)
|
||||
|
||||
@property
|
||||
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
|
||||
attributes: dict[Any, Callable[[], Any]] = {
|
||||
FileStreamAttribute.file: lambda: self._file,
|
||||
}
|
||||
|
||||
if hasattr(self._file, "name"):
|
||||
attributes[FileStreamAttribute.path] = lambda: Path(self._file.name)
|
||||
|
||||
try:
|
||||
self._file.fileno()
|
||||
except UnsupportedOperation:
|
||||
pass
|
||||
else:
|
||||
attributes[FileStreamAttribute.fileno] = lambda: self._file.fileno()
|
||||
|
||||
return attributes
|
||||
|
||||
|
||||
class FileReadStream(_BaseFileStream, ByteReceiveStream):
|
||||
"""
|
||||
A byte stream that reads from a file in the file system.
|
||||
|
||||
:param file: a file that has been opened for reading in binary mode
|
||||
|
||||
.. versionadded:: 3.0
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
async def from_path(cls, path: str | PathLike[str]) -> FileReadStream:
|
||||
"""
|
||||
Create a file read stream by opening the given file.
|
||||
|
||||
:param path: path of the file to read from
|
||||
|
||||
"""
|
||||
file = await to_thread.run_sync(Path(path).open, "rb")
|
||||
return cls(cast(BinaryIO, file))
|
||||
|
||||
async def receive(self, max_bytes: int = 65536) -> bytes:
|
||||
try:
|
||||
data = await to_thread.run_sync(self._file.read, max_bytes)
|
||||
except ValueError:
|
||||
raise ClosedResourceError from None
|
||||
except OSError as exc:
|
||||
raise BrokenResourceError from exc
|
||||
|
||||
if data:
|
||||
return data
|
||||
else:
|
||||
raise EndOfStream
|
||||
|
||||
async def seek(self, position: int, whence: int = SEEK_SET) -> int:
|
||||
"""
|
||||
Seek the file to the given position.
|
||||
|
||||
.. seealso:: :meth:`io.IOBase.seek`
|
||||
|
||||
.. note:: Not all file descriptors are seekable.
|
||||
|
||||
:param position: position to seek the file to
|
||||
:param whence: controls how ``position`` is interpreted
|
||||
:return: the new absolute position
|
||||
:raises OSError: if the file is not seekable
|
||||
|
||||
"""
|
||||
return await to_thread.run_sync(self._file.seek, position, whence)
|
||||
|
||||
async def tell(self) -> int:
|
||||
"""
|
||||
Return the current stream position.
|
||||
|
||||
.. note:: Not all file descriptors are seekable.
|
||||
|
||||
:return: the current absolute position
|
||||
:raises OSError: if the file is not seekable
|
||||
|
||||
"""
|
||||
return await to_thread.run_sync(self._file.tell)
|
||||
|
||||
|
||||
class FileWriteStream(_BaseFileStream, ByteSendStream):
|
||||
"""
|
||||
A byte stream that writes to a file in the file system.
|
||||
|
||||
:param file: a file that has been opened for writing in binary mode
|
||||
|
||||
.. versionadded:: 3.0
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
async def from_path(
|
||||
cls, path: str | PathLike[str], append: bool = False
|
||||
) -> FileWriteStream:
|
||||
"""
|
||||
Create a file write stream by opening the given file for writing.
|
||||
|
||||
:param path: path of the file to write to
|
||||
:param append: if ``True``, open the file for appending; if ``False``, any
|
||||
existing file at the given path will be truncated
|
||||
|
||||
"""
|
||||
mode = "ab" if append else "wb"
|
||||
file = await to_thread.run_sync(Path(path).open, mode)
|
||||
return cls(cast(BinaryIO, file))
|
||||
|
||||
async def send(self, item: bytes) -> None:
|
||||
try:
|
||||
await to_thread.run_sync(self._file.write, item)
|
||||
except ValueError:
|
||||
raise ClosedResourceError from None
|
||||
except OSError as exc:
|
||||
raise BrokenResourceError from exc
|
||||
325
venv/lib/python3.11/site-packages/anyio/streams/memory.py
Normal file
325
venv/lib/python3.11/site-packages/anyio/streams/memory.py
Normal file
|
|
@ -0,0 +1,325 @@
|
|||
from __future__ import annotations
|
||||
|
||||
__all__ = (
|
||||
"MemoryObjectReceiveStream",
|
||||
"MemoryObjectSendStream",
|
||||
"MemoryObjectStreamStatistics",
|
||||
)
|
||||
|
||||
import warnings
|
||||
from collections import OrderedDict, deque
|
||||
from dataclasses import dataclass, field
|
||||
from types import TracebackType
|
||||
from typing import Generic, NamedTuple, TypeVar
|
||||
|
||||
from .. import (
|
||||
BrokenResourceError,
|
||||
ClosedResourceError,
|
||||
EndOfStream,
|
||||
WouldBlock,
|
||||
)
|
||||
from .._core._testing import TaskInfo, get_current_task
|
||||
from ..abc import Event, ObjectReceiveStream, ObjectSendStream
|
||||
from ..lowlevel import checkpoint
|
||||
|
||||
T_Item = TypeVar("T_Item")
|
||||
T_co = TypeVar("T_co", covariant=True)
|
||||
T_contra = TypeVar("T_contra", contravariant=True)
|
||||
|
||||
|
||||
class MemoryObjectStreamStatistics(NamedTuple):
|
||||
current_buffer_used: int #: number of items stored in the buffer
|
||||
#: maximum number of items that can be stored on this stream (or :data:`math.inf`)
|
||||
max_buffer_size: float
|
||||
open_send_streams: int #: number of unclosed clones of the send stream
|
||||
open_receive_streams: int #: number of unclosed clones of the receive stream
|
||||
#: number of tasks blocked on :meth:`MemoryObjectSendStream.send`
|
||||
tasks_waiting_send: int
|
||||
#: number of tasks blocked on :meth:`MemoryObjectReceiveStream.receive`
|
||||
tasks_waiting_receive: int
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class _MemoryObjectItemReceiver(Generic[T_Item]):
|
||||
task_info: TaskInfo = field(init=False, default_factory=get_current_task)
|
||||
item: T_Item = field(init=False)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
# When item is not defined, we get following error with default __repr__:
|
||||
# AttributeError: 'MemoryObjectItemReceiver' object has no attribute 'item'
|
||||
item = getattr(self, "item", None)
|
||||
return f"{self.__class__.__name__}(task_info={self.task_info}, item={item!r})"
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class _MemoryObjectStreamState(Generic[T_Item]):
|
||||
max_buffer_size: float = field()
|
||||
buffer: deque[T_Item] = field(init=False, default_factory=deque)
|
||||
open_send_channels: int = field(init=False, default=0)
|
||||
open_receive_channels: int = field(init=False, default=0)
|
||||
waiting_receivers: OrderedDict[Event, _MemoryObjectItemReceiver[T_Item]] = field(
|
||||
init=False, default_factory=OrderedDict
|
||||
)
|
||||
waiting_senders: OrderedDict[Event, T_Item] = field(
|
||||
init=False, default_factory=OrderedDict
|
||||
)
|
||||
|
||||
def statistics(self) -> MemoryObjectStreamStatistics:
|
||||
return MemoryObjectStreamStatistics(
|
||||
len(self.buffer),
|
||||
self.max_buffer_size,
|
||||
self.open_send_channels,
|
||||
self.open_receive_channels,
|
||||
len(self.waiting_senders),
|
||||
len(self.waiting_receivers),
|
||||
)
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class MemoryObjectReceiveStream(Generic[T_co], ObjectReceiveStream[T_co]):
|
||||
_state: _MemoryObjectStreamState[T_co]
|
||||
_closed: bool = field(init=False, default=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self._state.open_receive_channels += 1
|
||||
|
||||
def receive_nowait(self) -> T_co:
|
||||
"""
|
||||
Receive the next item if it can be done without waiting.
|
||||
|
||||
:return: the received item
|
||||
:raises ~anyio.ClosedResourceError: if this send stream has been closed
|
||||
:raises ~anyio.EndOfStream: if the buffer is empty and this stream has been
|
||||
closed from the sending end
|
||||
:raises ~anyio.WouldBlock: if there are no items in the buffer and no tasks
|
||||
waiting to send
|
||||
|
||||
"""
|
||||
if self._closed:
|
||||
raise ClosedResourceError
|
||||
|
||||
if self._state.waiting_senders:
|
||||
# Get the item from the next sender
|
||||
send_event, item = self._state.waiting_senders.popitem(last=False)
|
||||
self._state.buffer.append(item)
|
||||
send_event.set()
|
||||
|
||||
if self._state.buffer:
|
||||
return self._state.buffer.popleft()
|
||||
elif not self._state.open_send_channels:
|
||||
raise EndOfStream
|
||||
|
||||
raise WouldBlock
|
||||
|
||||
async def receive(self) -> T_co:
|
||||
await checkpoint()
|
||||
try:
|
||||
return self.receive_nowait()
|
||||
except WouldBlock:
|
||||
# Add ourselves in the queue
|
||||
receive_event = Event()
|
||||
receiver = _MemoryObjectItemReceiver[T_co]()
|
||||
self._state.waiting_receivers[receive_event] = receiver
|
||||
|
||||
try:
|
||||
await receive_event.wait()
|
||||
finally:
|
||||
self._state.waiting_receivers.pop(receive_event, None)
|
||||
|
||||
try:
|
||||
return receiver.item
|
||||
except AttributeError:
|
||||
raise EndOfStream from None
|
||||
|
||||
def clone(self) -> MemoryObjectReceiveStream[T_co]:
|
||||
"""
|
||||
Create a clone of this receive stream.
|
||||
|
||||
Each clone can be closed separately. Only when all clones have been closed will
|
||||
the receiving end of the memory stream be considered closed by the sending ends.
|
||||
|
||||
:return: the cloned stream
|
||||
|
||||
"""
|
||||
if self._closed:
|
||||
raise ClosedResourceError
|
||||
|
||||
return MemoryObjectReceiveStream(_state=self._state)
|
||||
|
||||
def close(self) -> None:
|
||||
"""
|
||||
Close the stream.
|
||||
|
||||
This works the exact same way as :meth:`aclose`, but is provided as a special
|
||||
case for the benefit of synchronous callbacks.
|
||||
|
||||
"""
|
||||
if not self._closed:
|
||||
self._closed = True
|
||||
self._state.open_receive_channels -= 1
|
||||
if self._state.open_receive_channels == 0:
|
||||
send_events = list(self._state.waiting_senders.keys())
|
||||
for event in send_events:
|
||||
event.set()
|
||||
|
||||
async def aclose(self) -> None:
|
||||
self.close()
|
||||
|
||||
def statistics(self) -> MemoryObjectStreamStatistics:
|
||||
"""
|
||||
Return statistics about the current state of this stream.
|
||||
|
||||
.. versionadded:: 3.0
|
||||
"""
|
||||
return self._state.statistics()
|
||||
|
||||
def __enter__(self) -> MemoryObjectReceiveStream[T_co]:
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
self.close()
|
||||
|
||||
def __del__(self) -> None:
|
||||
if not self._closed:
|
||||
warnings.warn(
|
||||
f"Unclosed <{self.__class__.__name__} at {id(self):x}>",
|
||||
ResourceWarning,
|
||||
stacklevel=1,
|
||||
source=self,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class MemoryObjectSendStream(Generic[T_contra], ObjectSendStream[T_contra]):
|
||||
_state: _MemoryObjectStreamState[T_contra]
|
||||
_closed: bool = field(init=False, default=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self._state.open_send_channels += 1
|
||||
|
||||
def send_nowait(self, item: T_contra) -> None:
|
||||
"""
|
||||
Send an item immediately if it can be done without waiting.
|
||||
|
||||
:param item: the item to send
|
||||
:raises ~anyio.ClosedResourceError: if this send stream has been closed
|
||||
:raises ~anyio.BrokenResourceError: if the stream has been closed from the
|
||||
receiving end
|
||||
:raises ~anyio.WouldBlock: if the buffer is full and there are no tasks waiting
|
||||
to receive
|
||||
|
||||
"""
|
||||
if self._closed:
|
||||
raise ClosedResourceError
|
||||
if not self._state.open_receive_channels:
|
||||
raise BrokenResourceError
|
||||
|
||||
while self._state.waiting_receivers:
|
||||
receive_event, receiver = self._state.waiting_receivers.popitem(last=False)
|
||||
if not receiver.task_info.has_pending_cancellation():
|
||||
receiver.item = item
|
||||
receive_event.set()
|
||||
return
|
||||
|
||||
if len(self._state.buffer) < self._state.max_buffer_size:
|
||||
self._state.buffer.append(item)
|
||||
else:
|
||||
raise WouldBlock
|
||||
|
||||
async def send(self, item: T_contra) -> None:
|
||||
"""
|
||||
Send an item to the stream.
|
||||
|
||||
If the buffer is full, this method blocks until there is again room in the
|
||||
buffer or the item can be sent directly to a receiver.
|
||||
|
||||
:param item: the item to send
|
||||
:raises ~anyio.ClosedResourceError: if this send stream has been closed
|
||||
:raises ~anyio.BrokenResourceError: if the stream has been closed from the
|
||||
receiving end
|
||||
|
||||
"""
|
||||
await checkpoint()
|
||||
try:
|
||||
self.send_nowait(item)
|
||||
except WouldBlock:
|
||||
# Wait until there's someone on the receiving end
|
||||
send_event = Event()
|
||||
self._state.waiting_senders[send_event] = item
|
||||
try:
|
||||
await send_event.wait()
|
||||
except BaseException:
|
||||
self._state.waiting_senders.pop(send_event, None)
|
||||
raise
|
||||
|
||||
if send_event in self._state.waiting_senders:
|
||||
del self._state.waiting_senders[send_event]
|
||||
raise BrokenResourceError from None
|
||||
|
||||
def clone(self) -> MemoryObjectSendStream[T_contra]:
|
||||
"""
|
||||
Create a clone of this send stream.
|
||||
|
||||
Each clone can be closed separately. Only when all clones have been closed will
|
||||
the sending end of the memory stream be considered closed by the receiving ends.
|
||||
|
||||
:return: the cloned stream
|
||||
|
||||
"""
|
||||
if self._closed:
|
||||
raise ClosedResourceError
|
||||
|
||||
return MemoryObjectSendStream(_state=self._state)
|
||||
|
||||
def close(self) -> None:
|
||||
"""
|
||||
Close the stream.
|
||||
|
||||
This works the exact same way as :meth:`aclose`, but is provided as a special
|
||||
case for the benefit of synchronous callbacks.
|
||||
|
||||
"""
|
||||
if not self._closed:
|
||||
self._closed = True
|
||||
self._state.open_send_channels -= 1
|
||||
if self._state.open_send_channels == 0:
|
||||
receive_events = list(self._state.waiting_receivers.keys())
|
||||
self._state.waiting_receivers.clear()
|
||||
for event in receive_events:
|
||||
event.set()
|
||||
|
||||
async def aclose(self) -> None:
|
||||
self.close()
|
||||
|
||||
def statistics(self) -> MemoryObjectStreamStatistics:
|
||||
"""
|
||||
Return statistics about the current state of this stream.
|
||||
|
||||
.. versionadded:: 3.0
|
||||
"""
|
||||
return self._state.statistics()
|
||||
|
||||
def __enter__(self) -> MemoryObjectSendStream[T_contra]:
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
self.close()
|
||||
|
||||
def __del__(self) -> None:
|
||||
if not self._closed:
|
||||
warnings.warn(
|
||||
f"Unclosed <{self.__class__.__name__} at {id(self):x}>",
|
||||
ResourceWarning,
|
||||
stacklevel=1,
|
||||
source=self,
|
||||
)
|
||||
147
venv/lib/python3.11/site-packages/anyio/streams/stapled.py
Normal file
147
venv/lib/python3.11/site-packages/anyio/streams/stapled.py
Normal file
|
|
@ -0,0 +1,147 @@
|
|||
from __future__ import annotations
|
||||
|
||||
__all__ = (
|
||||
"MultiListener",
|
||||
"StapledByteStream",
|
||||
"StapledObjectStream",
|
||||
)
|
||||
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from ..abc import (
|
||||
ByteReceiveStream,
|
||||
ByteSendStream,
|
||||
ByteStream,
|
||||
Listener,
|
||||
ObjectReceiveStream,
|
||||
ObjectSendStream,
|
||||
ObjectStream,
|
||||
TaskGroup,
|
||||
)
|
||||
|
||||
T_Item = TypeVar("T_Item")
|
||||
T_Stream = TypeVar("T_Stream")
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class StapledByteStream(ByteStream):
|
||||
"""
|
||||
Combines two byte streams into a single, bidirectional byte stream.
|
||||
|
||||
Extra attributes will be provided from both streams, with the receive stream
|
||||
providing the values in case of a conflict.
|
||||
|
||||
:param ByteSendStream send_stream: the sending byte stream
|
||||
:param ByteReceiveStream receive_stream: the receiving byte stream
|
||||
"""
|
||||
|
||||
send_stream: ByteSendStream
|
||||
receive_stream: ByteReceiveStream
|
||||
|
||||
async def receive(self, max_bytes: int = 65536) -> bytes:
|
||||
return await self.receive_stream.receive(max_bytes)
|
||||
|
||||
async def send(self, item: bytes) -> None:
|
||||
await self.send_stream.send(item)
|
||||
|
||||
async def send_eof(self) -> None:
|
||||
await self.send_stream.aclose()
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await self.send_stream.aclose()
|
||||
await self.receive_stream.aclose()
|
||||
|
||||
@property
|
||||
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
|
||||
return {
|
||||
**self.send_stream.extra_attributes,
|
||||
**self.receive_stream.extra_attributes,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class StapledObjectStream(Generic[T_Item], ObjectStream[T_Item]):
|
||||
"""
|
||||
Combines two object streams into a single, bidirectional object stream.
|
||||
|
||||
Extra attributes will be provided from both streams, with the receive stream
|
||||
providing the values in case of a conflict.
|
||||
|
||||
:param ObjectSendStream send_stream: the sending object stream
|
||||
:param ObjectReceiveStream receive_stream: the receiving object stream
|
||||
"""
|
||||
|
||||
send_stream: ObjectSendStream[T_Item]
|
||||
receive_stream: ObjectReceiveStream[T_Item]
|
||||
|
||||
async def receive(self) -> T_Item:
|
||||
return await self.receive_stream.receive()
|
||||
|
||||
async def send(self, item: T_Item) -> None:
|
||||
await self.send_stream.send(item)
|
||||
|
||||
async def send_eof(self) -> None:
|
||||
await self.send_stream.aclose()
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await self.send_stream.aclose()
|
||||
await self.receive_stream.aclose()
|
||||
|
||||
@property
|
||||
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
|
||||
return {
|
||||
**self.send_stream.extra_attributes,
|
||||
**self.receive_stream.extra_attributes,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class MultiListener(Generic[T_Stream], Listener[T_Stream]):
|
||||
"""
|
||||
Combines multiple listeners into one, serving connections from all of them at once.
|
||||
|
||||
Any MultiListeners in the given collection of listeners will have their listeners
|
||||
moved into this one.
|
||||
|
||||
Extra attributes are provided from each listener, with each successive listener
|
||||
overriding any conflicting attributes from the previous one.
|
||||
|
||||
:param listeners: listeners to serve
|
||||
:type listeners: Sequence[Listener[T_Stream]]
|
||||
"""
|
||||
|
||||
listeners: Sequence[Listener[T_Stream]]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
listeners: list[Listener[T_Stream]] = []
|
||||
for listener in self.listeners:
|
||||
if isinstance(listener, MultiListener):
|
||||
listeners.extend(listener.listeners)
|
||||
del listener.listeners[:] # type: ignore[attr-defined]
|
||||
else:
|
||||
listeners.append(listener)
|
||||
|
||||
self.listeners = listeners
|
||||
|
||||
async def serve(
|
||||
self, handler: Callable[[T_Stream], Any], task_group: TaskGroup | None = None
|
||||
) -> None:
|
||||
from .. import create_task_group
|
||||
|
||||
async with create_task_group() as tg:
|
||||
for listener in self.listeners:
|
||||
tg.start_soon(listener.serve, handler, task_group)
|
||||
|
||||
async def aclose(self) -> None:
|
||||
for listener in self.listeners:
|
||||
await listener.aclose()
|
||||
|
||||
@property
|
||||
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
|
||||
attributes: dict = {}
|
||||
for listener in self.listeners:
|
||||
attributes.update(listener.extra_attributes)
|
||||
|
||||
return attributes
|
||||
176
venv/lib/python3.11/site-packages/anyio/streams/text.py
Normal file
176
venv/lib/python3.11/site-packages/anyio/streams/text.py
Normal file
|
|
@ -0,0 +1,176 @@
|
|||
from __future__ import annotations
|
||||
|
||||
__all__ = (
|
||||
"TextConnectable",
|
||||
"TextReceiveStream",
|
||||
"TextSendStream",
|
||||
"TextStream",
|
||||
)
|
||||
|
||||
import codecs
|
||||
import sys
|
||||
from collections.abc import Callable, Mapping
|
||||
from dataclasses import InitVar, dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..abc import (
|
||||
AnyByteReceiveStream,
|
||||
AnyByteSendStream,
|
||||
AnyByteStream,
|
||||
AnyByteStreamConnectable,
|
||||
ObjectReceiveStream,
|
||||
ObjectSendStream,
|
||||
ObjectStream,
|
||||
ObjectStreamConnectable,
|
||||
)
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class TextReceiveStream(ObjectReceiveStream[str]):
|
||||
"""
|
||||
Stream wrapper that decodes bytes to strings using the given encoding.
|
||||
|
||||
Decoding is done using :class:`~codecs.IncrementalDecoder` which returns any
|
||||
completely received unicode characters as soon as they come in.
|
||||
|
||||
:param transport_stream: any bytes-based receive stream
|
||||
:param encoding: character encoding to use for decoding bytes to strings (defaults
|
||||
to ``utf-8``)
|
||||
:param errors: handling scheme for decoding errors (defaults to ``strict``; see the
|
||||
`codecs module documentation`_ for a comprehensive list of options)
|
||||
|
||||
.. _codecs module documentation:
|
||||
https://docs.python.org/3/library/codecs.html#codec-objects
|
||||
"""
|
||||
|
||||
transport_stream: AnyByteReceiveStream
|
||||
encoding: InitVar[str] = "utf-8"
|
||||
errors: InitVar[str] = "strict"
|
||||
_decoder: codecs.IncrementalDecoder = field(init=False)
|
||||
|
||||
def __post_init__(self, encoding: str, errors: str) -> None:
|
||||
decoder_class = codecs.getincrementaldecoder(encoding)
|
||||
self._decoder = decoder_class(errors=errors)
|
||||
|
||||
async def receive(self) -> str:
|
||||
while True:
|
||||
chunk = await self.transport_stream.receive()
|
||||
decoded = self._decoder.decode(chunk)
|
||||
if decoded:
|
||||
return decoded
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await self.transport_stream.aclose()
|
||||
self._decoder.reset()
|
||||
|
||||
@property
|
||||
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
|
||||
return self.transport_stream.extra_attributes
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class TextSendStream(ObjectSendStream[str]):
|
||||
"""
|
||||
Sends strings to the wrapped stream as bytes using the given encoding.
|
||||
|
||||
:param AnyByteSendStream transport_stream: any bytes-based send stream
|
||||
:param str encoding: character encoding to use for encoding strings to bytes
|
||||
(defaults to ``utf-8``)
|
||||
:param str errors: handling scheme for encoding errors (defaults to ``strict``; see
|
||||
the `codecs module documentation`_ for a comprehensive list of options)
|
||||
|
||||
.. _codecs module documentation:
|
||||
https://docs.python.org/3/library/codecs.html#codec-objects
|
||||
"""
|
||||
|
||||
transport_stream: AnyByteSendStream
|
||||
encoding: InitVar[str] = "utf-8"
|
||||
errors: str = "strict"
|
||||
_encoder: Callable[..., tuple[bytes, int]] = field(init=False)
|
||||
|
||||
def __post_init__(self, encoding: str) -> None:
|
||||
self._encoder = codecs.getencoder(encoding)
|
||||
|
||||
async def send(self, item: str) -> None:
|
||||
encoded = self._encoder(item, self.errors)[0]
|
||||
await self.transport_stream.send(encoded)
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await self.transport_stream.aclose()
|
||||
|
||||
@property
|
||||
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
|
||||
return self.transport_stream.extra_attributes
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class TextStream(ObjectStream[str]):
|
||||
"""
|
||||
A bidirectional stream that decodes bytes to strings on receive and encodes strings
|
||||
to bytes on send.
|
||||
|
||||
Extra attributes will be provided from both streams, with the receive stream
|
||||
providing the values in case of a conflict.
|
||||
|
||||
:param AnyByteStream transport_stream: any bytes-based stream
|
||||
:param str encoding: character encoding to use for encoding/decoding strings to/from
|
||||
bytes (defaults to ``utf-8``)
|
||||
:param str errors: handling scheme for encoding errors (defaults to ``strict``; see
|
||||
the `codecs module documentation`_ for a comprehensive list of options)
|
||||
|
||||
.. _codecs module documentation:
|
||||
https://docs.python.org/3/library/codecs.html#codec-objects
|
||||
"""
|
||||
|
||||
transport_stream: AnyByteStream
|
||||
encoding: InitVar[str] = "utf-8"
|
||||
errors: InitVar[str] = "strict"
|
||||
_receive_stream: TextReceiveStream = field(init=False)
|
||||
_send_stream: TextSendStream = field(init=False)
|
||||
|
||||
def __post_init__(self, encoding: str, errors: str) -> None:
|
||||
self._receive_stream = TextReceiveStream(
|
||||
self.transport_stream, encoding=encoding, errors=errors
|
||||
)
|
||||
self._send_stream = TextSendStream(
|
||||
self.transport_stream, encoding=encoding, errors=errors
|
||||
)
|
||||
|
||||
async def receive(self) -> str:
|
||||
return await self._receive_stream.receive()
|
||||
|
||||
async def send(self, item: str) -> None:
|
||||
await self._send_stream.send(item)
|
||||
|
||||
async def send_eof(self) -> None:
|
||||
await self.transport_stream.send_eof()
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await self._send_stream.aclose()
|
||||
await self._receive_stream.aclose()
|
||||
|
||||
@property
|
||||
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
|
||||
return {
|
||||
**self._send_stream.extra_attributes,
|
||||
**self._receive_stream.extra_attributes,
|
||||
}
|
||||
|
||||
|
||||
class TextConnectable(ObjectStreamConnectable[str]):
|
||||
def __init__(self, connectable: AnyByteStreamConnectable):
|
||||
"""
|
||||
:param connectable: the bytestream endpoint to wrap
|
||||
|
||||
"""
|
||||
self.connectable = connectable
|
||||
|
||||
@override
|
||||
async def connect(self) -> TextStream:
|
||||
stream = await self.connectable.connect()
|
||||
return TextStream(stream)
|
||||
424
venv/lib/python3.11/site-packages/anyio/streams/tls.py
Normal file
424
venv/lib/python3.11/site-packages/anyio/streams/tls.py
Normal file
|
|
@ -0,0 +1,424 @@
|
|||
from __future__ import annotations
|
||||
|
||||
__all__ = (
|
||||
"TLSAttribute",
|
||||
"TLSConnectable",
|
||||
"TLSListener",
|
||||
"TLSStream",
|
||||
)
|
||||
|
||||
import logging
|
||||
import re
|
||||
import ssl
|
||||
import sys
|
||||
from collections.abc import Callable, Mapping
|
||||
from dataclasses import dataclass
|
||||
from functools import wraps
|
||||
from ssl import SSLContext
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from .. import (
|
||||
BrokenResourceError,
|
||||
EndOfStream,
|
||||
aclose_forcefully,
|
||||
get_cancelled_exc_class,
|
||||
to_thread,
|
||||
)
|
||||
from .._core._typedattr import TypedAttributeSet, typed_attribute
|
||||
from ..abc import (
|
||||
AnyByteStream,
|
||||
AnyByteStreamConnectable,
|
||||
ByteStream,
|
||||
ByteStreamConnectable,
|
||||
Listener,
|
||||
TaskGroup,
|
||||
)
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
from typing import TypeAlias
|
||||
else:
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import TypeVarTuple, Unpack
|
||||
else:
|
||||
from typing_extensions import TypeVarTuple, Unpack
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
T_Retval = TypeVar("T_Retval")
|
||||
PosArgsT = TypeVarTuple("PosArgsT")
|
||||
_PCTRTT: TypeAlias = tuple[tuple[str, str], ...]
|
||||
_PCTRTTT: TypeAlias = tuple[_PCTRTT, ...]
|
||||
|
||||
|
||||
class TLSAttribute(TypedAttributeSet):
|
||||
"""Contains Transport Layer Security related attributes."""
|
||||
|
||||
#: the selected ALPN protocol
|
||||
alpn_protocol: str | None = typed_attribute()
|
||||
#: the channel binding for type ``tls-unique``
|
||||
channel_binding_tls_unique: bytes = typed_attribute()
|
||||
#: the selected cipher
|
||||
cipher: tuple[str, str, int] = typed_attribute()
|
||||
#: the peer certificate in dictionary form (see :meth:`ssl.SSLSocket.getpeercert`
|
||||
# for more information)
|
||||
peer_certificate: None | (dict[str, str | _PCTRTTT | _PCTRTT]) = typed_attribute()
|
||||
#: the peer certificate in binary form
|
||||
peer_certificate_binary: bytes | None = typed_attribute()
|
||||
#: ``True`` if this is the server side of the connection
|
||||
server_side: bool = typed_attribute()
|
||||
#: ciphers shared by the client during the TLS handshake (``None`` if this is the
|
||||
#: client side)
|
||||
shared_ciphers: list[tuple[str, str, int]] | None = typed_attribute()
|
||||
#: the :class:`~ssl.SSLObject` used for encryption
|
||||
ssl_object: ssl.SSLObject = typed_attribute()
|
||||
#: ``True`` if this stream does (and expects) a closing TLS handshake when the
|
||||
#: stream is being closed
|
||||
standard_compatible: bool = typed_attribute()
|
||||
#: the TLS protocol version (e.g. ``TLSv1.2``)
|
||||
tls_version: str = typed_attribute()
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class TLSStream(ByteStream):
|
||||
"""
|
||||
A stream wrapper that encrypts all sent data and decrypts received data.
|
||||
|
||||
This class has no public initializer; use :meth:`wrap` instead.
|
||||
All extra attributes from :class:`~TLSAttribute` are supported.
|
||||
|
||||
:var AnyByteStream transport_stream: the wrapped stream
|
||||
|
||||
"""
|
||||
|
||||
transport_stream: AnyByteStream
|
||||
standard_compatible: bool
|
||||
_ssl_object: ssl.SSLObject
|
||||
_read_bio: ssl.MemoryBIO
|
||||
_write_bio: ssl.MemoryBIO
|
||||
|
||||
@classmethod
|
||||
async def wrap(
|
||||
cls,
|
||||
transport_stream: AnyByteStream,
|
||||
*,
|
||||
server_side: bool | None = None,
|
||||
hostname: str | None = None,
|
||||
ssl_context: ssl.SSLContext | None = None,
|
||||
standard_compatible: bool = True,
|
||||
) -> TLSStream:
|
||||
"""
|
||||
Wrap an existing stream with Transport Layer Security.
|
||||
|
||||
This performs a TLS handshake with the peer.
|
||||
|
||||
:param transport_stream: a bytes-transporting stream to wrap
|
||||
:param server_side: ``True`` if this is the server side of the connection,
|
||||
``False`` if this is the client side (if omitted, will be set to ``False``
|
||||
if ``hostname`` has been provided, ``False`` otherwise). Used only to create
|
||||
a default context when an explicit context has not been provided.
|
||||
:param hostname: host name of the peer (if host name checking is desired)
|
||||
:param ssl_context: the SSLContext object to use (if not provided, a secure
|
||||
default will be created)
|
||||
:param standard_compatible: if ``False``, skip the closing handshake when
|
||||
closing the connection, and don't raise an exception if the peer does the
|
||||
same
|
||||
:raises ~ssl.SSLError: if the TLS handshake fails
|
||||
|
||||
"""
|
||||
if server_side is None:
|
||||
server_side = not hostname
|
||||
|
||||
if not ssl_context:
|
||||
purpose = (
|
||||
ssl.Purpose.CLIENT_AUTH if server_side else ssl.Purpose.SERVER_AUTH
|
||||
)
|
||||
ssl_context = ssl.create_default_context(purpose)
|
||||
|
||||
# Re-enable detection of unexpected EOFs if it was disabled by Python
|
||||
if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"):
|
||||
ssl_context.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF
|
||||
|
||||
bio_in = ssl.MemoryBIO()
|
||||
bio_out = ssl.MemoryBIO()
|
||||
|
||||
# External SSLContext implementations may do blocking I/O in wrap_bio(),
|
||||
# but the standard library implementation won't
|
||||
if type(ssl_context) is ssl.SSLContext:
|
||||
ssl_object = ssl_context.wrap_bio(
|
||||
bio_in, bio_out, server_side=server_side, server_hostname=hostname
|
||||
)
|
||||
else:
|
||||
ssl_object = await to_thread.run_sync(
|
||||
ssl_context.wrap_bio,
|
||||
bio_in,
|
||||
bio_out,
|
||||
server_side,
|
||||
hostname,
|
||||
None,
|
||||
)
|
||||
|
||||
wrapper = cls(
|
||||
transport_stream=transport_stream,
|
||||
standard_compatible=standard_compatible,
|
||||
_ssl_object=ssl_object,
|
||||
_read_bio=bio_in,
|
||||
_write_bio=bio_out,
|
||||
)
|
||||
await wrapper._call_sslobject_method(ssl_object.do_handshake)
|
||||
return wrapper
|
||||
|
||||
async def _call_sslobject_method(
|
||||
self, func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT]
|
||||
) -> T_Retval:
|
||||
while True:
|
||||
try:
|
||||
result = func(*args)
|
||||
except ssl.SSLWantReadError:
|
||||
try:
|
||||
# Flush any pending writes first
|
||||
if self._write_bio.pending:
|
||||
await self.transport_stream.send(self._write_bio.read())
|
||||
|
||||
data = await self.transport_stream.receive()
|
||||
except EndOfStream:
|
||||
self._read_bio.write_eof()
|
||||
except OSError as exc:
|
||||
self._read_bio.write_eof()
|
||||
self._write_bio.write_eof()
|
||||
raise BrokenResourceError from exc
|
||||
else:
|
||||
self._read_bio.write(data)
|
||||
except ssl.SSLWantWriteError:
|
||||
await self.transport_stream.send(self._write_bio.read())
|
||||
except ssl.SSLSyscallError as exc:
|
||||
self._read_bio.write_eof()
|
||||
self._write_bio.write_eof()
|
||||
raise BrokenResourceError from exc
|
||||
except ssl.SSLError as exc:
|
||||
self._read_bio.write_eof()
|
||||
self._write_bio.write_eof()
|
||||
if isinstance(exc, ssl.SSLEOFError) or (
|
||||
exc.strerror and "UNEXPECTED_EOF_WHILE_READING" in exc.strerror
|
||||
):
|
||||
if self.standard_compatible:
|
||||
raise BrokenResourceError from exc
|
||||
else:
|
||||
raise EndOfStream from None
|
||||
|
||||
raise
|
||||
else:
|
||||
# Flush any pending writes first
|
||||
if self._write_bio.pending:
|
||||
await self.transport_stream.send(self._write_bio.read())
|
||||
|
||||
return result
|
||||
|
||||
async def unwrap(self) -> tuple[AnyByteStream, bytes]:
|
||||
"""
|
||||
Does the TLS closing handshake.
|
||||
|
||||
:return: a tuple of (wrapped byte stream, bytes left in the read buffer)
|
||||
|
||||
"""
|
||||
await self._call_sslobject_method(self._ssl_object.unwrap)
|
||||
self._read_bio.write_eof()
|
||||
self._write_bio.write_eof()
|
||||
return self.transport_stream, self._read_bio.read()
|
||||
|
||||
async def aclose(self) -> None:
|
||||
if self.standard_compatible:
|
||||
try:
|
||||
await self.unwrap()
|
||||
except BaseException:
|
||||
await aclose_forcefully(self.transport_stream)
|
||||
raise
|
||||
|
||||
await self.transport_stream.aclose()
|
||||
|
||||
async def receive(self, max_bytes: int = 65536) -> bytes:
|
||||
data = await self._call_sslobject_method(self._ssl_object.read, max_bytes)
|
||||
if not data:
|
||||
raise EndOfStream
|
||||
|
||||
return data
|
||||
|
||||
async def send(self, item: bytes) -> None:
|
||||
await self._call_sslobject_method(self._ssl_object.write, item)
|
||||
|
||||
async def send_eof(self) -> None:
|
||||
tls_version = self.extra(TLSAttribute.tls_version)
|
||||
match = re.match(r"TLSv(\d+)(?:\.(\d+))?", tls_version)
|
||||
if match:
|
||||
major, minor = int(match.group(1)), int(match.group(2) or 0)
|
||||
if (major, minor) < (1, 3):
|
||||
raise NotImplementedError(
|
||||
f"send_eof() requires at least TLSv1.3; current "
|
||||
f"session uses {tls_version}"
|
||||
)
|
||||
|
||||
raise NotImplementedError(
|
||||
"send_eof() has not yet been implemented for TLS streams"
|
||||
)
|
||||
|
||||
@property
|
||||
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
|
||||
return {
|
||||
**self.transport_stream.extra_attributes,
|
||||
TLSAttribute.alpn_protocol: self._ssl_object.selected_alpn_protocol,
|
||||
TLSAttribute.channel_binding_tls_unique: (
|
||||
self._ssl_object.get_channel_binding
|
||||
),
|
||||
TLSAttribute.cipher: self._ssl_object.cipher,
|
||||
TLSAttribute.peer_certificate: lambda: self._ssl_object.getpeercert(False),
|
||||
TLSAttribute.peer_certificate_binary: lambda: self._ssl_object.getpeercert(
|
||||
True
|
||||
),
|
||||
TLSAttribute.server_side: lambda: self._ssl_object.server_side,
|
||||
TLSAttribute.shared_ciphers: lambda: self._ssl_object.shared_ciphers()
|
||||
if self._ssl_object.server_side
|
||||
else None,
|
||||
TLSAttribute.standard_compatible: lambda: self.standard_compatible,
|
||||
TLSAttribute.ssl_object: lambda: self._ssl_object,
|
||||
TLSAttribute.tls_version: self._ssl_object.version,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class TLSListener(Listener[TLSStream]):
|
||||
"""
|
||||
A convenience listener that wraps another listener and auto-negotiates a TLS session
|
||||
on every accepted connection.
|
||||
|
||||
If the TLS handshake times out or raises an exception,
|
||||
:meth:`handle_handshake_error` is called to do whatever post-mortem processing is
|
||||
deemed necessary.
|
||||
|
||||
Supports only the :attr:`~TLSAttribute.standard_compatible` extra attribute.
|
||||
|
||||
:param Listener listener: the listener to wrap
|
||||
:param ssl_context: the SSL context object
|
||||
:param standard_compatible: a flag passed through to :meth:`TLSStream.wrap`
|
||||
:param handshake_timeout: time limit for the TLS handshake
|
||||
(passed to :func:`~anyio.fail_after`)
|
||||
"""
|
||||
|
||||
listener: Listener[Any]
|
||||
ssl_context: ssl.SSLContext
|
||||
standard_compatible: bool = True
|
||||
handshake_timeout: float = 30
|
||||
|
||||
@staticmethod
|
||||
async def handle_handshake_error(exc: BaseException, stream: AnyByteStream) -> None:
|
||||
"""
|
||||
Handle an exception raised during the TLS handshake.
|
||||
|
||||
This method does 3 things:
|
||||
|
||||
#. Forcefully closes the original stream
|
||||
#. Logs the exception (unless it was a cancellation exception) using the
|
||||
``anyio.streams.tls`` logger
|
||||
#. Reraises the exception if it was a base exception or a cancellation exception
|
||||
|
||||
:param exc: the exception
|
||||
:param stream: the original stream
|
||||
|
||||
"""
|
||||
await aclose_forcefully(stream)
|
||||
|
||||
# Log all except cancellation exceptions
|
||||
if not isinstance(exc, get_cancelled_exc_class()):
|
||||
# CPython (as of 3.11.5) returns incorrect `sys.exc_info()` here when using
|
||||
# any asyncio implementation, so we explicitly pass the exception to log
|
||||
# (https://github.com/python/cpython/issues/108668). Trio does not have this
|
||||
# issue because it works around the CPython bug.
|
||||
logging.getLogger(__name__).exception(
|
||||
"Error during TLS handshake", exc_info=exc
|
||||
)
|
||||
|
||||
# Only reraise base exceptions and cancellation exceptions
|
||||
if not isinstance(exc, Exception) or isinstance(exc, get_cancelled_exc_class()):
|
||||
raise
|
||||
|
||||
async def serve(
|
||||
self,
|
||||
handler: Callable[[TLSStream], Any],
|
||||
task_group: TaskGroup | None = None,
|
||||
) -> None:
|
||||
@wraps(handler)
|
||||
async def handler_wrapper(stream: AnyByteStream) -> None:
|
||||
from .. import fail_after
|
||||
|
||||
try:
|
||||
with fail_after(self.handshake_timeout):
|
||||
wrapped_stream = await TLSStream.wrap(
|
||||
stream,
|
||||
ssl_context=self.ssl_context,
|
||||
standard_compatible=self.standard_compatible,
|
||||
)
|
||||
except BaseException as exc:
|
||||
await self.handle_handshake_error(exc, stream)
|
||||
else:
|
||||
await handler(wrapped_stream)
|
||||
|
||||
await self.listener.serve(handler_wrapper, task_group)
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await self.listener.aclose()
|
||||
|
||||
@property
|
||||
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
|
||||
return {
|
||||
TLSAttribute.standard_compatible: lambda: self.standard_compatible,
|
||||
}
|
||||
|
||||
|
||||
class TLSConnectable(ByteStreamConnectable):
|
||||
"""
|
||||
Wraps another connectable and does TLS negotiation after a successful connection.
|
||||
|
||||
:param connectable: the connectable to wrap
|
||||
:param hostname: host name of the server (if host name checking is desired)
|
||||
:param ssl_context: the SSLContext object to use (if not provided, a secure default
|
||||
will be created)
|
||||
:param standard_compatible: if ``False``, skip the closing handshake when closing
|
||||
the connection, and don't raise an exception if the server does the same
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connectable: AnyByteStreamConnectable,
|
||||
*,
|
||||
hostname: str | None = None,
|
||||
ssl_context: ssl.SSLContext | None = None,
|
||||
standard_compatible: bool = True,
|
||||
) -> None:
|
||||
self.connectable = connectable
|
||||
self.ssl_context: SSLContext = ssl_context or ssl.create_default_context(
|
||||
ssl.Purpose.SERVER_AUTH
|
||||
)
|
||||
if not isinstance(self.ssl_context, ssl.SSLContext):
|
||||
raise TypeError(
|
||||
"ssl_context must be an instance of ssl.SSLContext, not "
|
||||
f"{type(self.ssl_context).__name__}"
|
||||
)
|
||||
self.hostname = hostname
|
||||
self.standard_compatible = standard_compatible
|
||||
|
||||
@override
|
||||
async def connect(self) -> TLSStream:
|
||||
stream = await self.connectable.connect()
|
||||
try:
|
||||
return await TLSStream.wrap(
|
||||
stream,
|
||||
hostname=self.hostname,
|
||||
ssl_context=self.ssl_context,
|
||||
standard_compatible=self.standard_compatible,
|
||||
)
|
||||
except BaseException:
|
||||
await aclose_forcefully(stream)
|
||||
raise
|
||||
Loading…
Add table
Add a link
Reference in a new issue