from __future__ import annotations import asyncio import errno import socket import struct from abc import ABC, abstractmethod from collections.abc import Iterable from ctypes import c_int32 as signed_int32 from ctypes import c_int64 as signed_int64 from ctypes import c_uint32 as unsigned_int32 from ctypes import c_uint64 as unsigned_int64 from ipaddress import ip_address from typing import TYPE_CHECKING, cast import asyncio_dgram from mcstatus.address import Address if TYPE_CHECKING: from typing_extensions import Self, SupportsIndex, TypeAlias BytesConvertable: TypeAlias = "SupportsIndex | Iterable[SupportsIndex]" def ip_type(address: int | str) -> int | None: """Determinate what IP version is. :param address: A string or integer, the IP address. Either IPv4 or IPv6 addresses may be supplied. Integers less than 2**32 will be considered to be IPv4 by default. :return: ``4`` or ``6`` if the IP is IPv4 or IPv6, respectively. :obj:`None` if the IP is invalid. """ try: return ip_address(address).version except ValueError: return None class BaseWriteSync(ABC): """Base synchronous write class""" __slots__ = () @abstractmethod def write(self, data: Connection | str | bytearray | bytes) -> None: """Write data to ``self``.""" def __repr__(self) -> str: return f"<{self.__class__.__name__} Object>" @staticmethod def _pack(format_: str, data: int) -> bytes: """Pack data in with format in big-endian mode.""" return struct.pack(">" + format_, data) def write_varint(self, value: int) -> None: """Write varint with value ``value`` to ``self``. :param value: Maximum is ``2 ** 31 - 1``, minimum is ``-(2 ** 31)``. :raises ValueError: If value is out of range. """ remaining = unsigned_int32(value).value for _ in range(5): if not remaining & -0x80: # remaining & ~0x7F == 0: self.write(struct.pack("!B", remaining)) if value > 2**31 - 1 or value < -(2**31): break return self.write(struct.pack("!B", remaining & 0x7F | 0x80)) remaining >>= 7 raise ValueError(f'The value "{value}" is too big to send in a varint') def write_varlong(self, value: int) -> None: """Write varlong with value ``value`` to ``self``. :param value: Maximum is ``2 ** 63 - 1``, minimum is ``-(2 ** 63)``. :raises ValueError: If value is out of range. """ remaining = unsigned_int64(value).value for _ in range(10): if not remaining & -0x80: # remaining & ~0x7F == 0: self.write(struct.pack("!B", remaining)) if value > 2**63 - 1 or value < -(2**31): break return self.write(struct.pack("!B", remaining & 0x7F | 0x80)) remaining >>= 7 raise ValueError(f'The value "{value}" is too big to send in a varlong') def write_utf(self, value: str) -> None: """Write varint of length of ``value`` up to 32767 bytes, then write ``value`` encoded with ``UTF-8``.""" self.write_varint(len(value)) self.write(bytearray(value, "utf8")) def write_ascii(self, value: str) -> None: """Write value encoded with ``ISO-8859-1``, then write an additional ``0x00`` at the end.""" self.write(bytearray(value, "ISO-8859-1")) self.write(bytearray.fromhex("00")) def write_short(self, value: int) -> None: """Write 2 bytes for value ``-32768 - 32767``.""" self.write(self._pack("h", value)) def write_ushort(self, value: int) -> None: """Write 2 bytes for value ``0 - 65535 (2 ** 16 - 1)``.""" self.write(self._pack("H", value)) def write_int(self, value: int) -> None: """Write 4 bytes for value ``-2147483648 - 2147483647``.""" self.write(self._pack("i", value)) def write_uint(self, value: int) -> None: """Write 4 bytes for value ``0 - 4294967295 (2 ** 32 - 1)``.""" self.write(self._pack("I", value)) def write_long(self, value: int) -> None: """Write 8 bytes for value ``-9223372036854775808 - 9223372036854775807``.""" self.write(self._pack("q", value)) def write_ulong(self, value: int) -> None: """Write 8 bytes for value ``0 - 18446744073709551613 (2 ** 64 - 1)``.""" self.write(self._pack("Q", value)) def write_bool(self, value: bool) -> None: """Write 1 byte for boolean `True` or `False`""" self.write(self._pack("?", value)) def write_buffer(self, buffer: "Connection") -> None: """Flush buffer, then write a varint of the length of the buffer's data, then write buffer data.""" data = buffer.flush() self.write_varint(len(data)) self.write(data) class BaseWriteAsync(ABC): """Base synchronous write class""" __slots__ = () @abstractmethod async def write(self, data: Connection | str | bytearray | bytes) -> None: """Write data to ``self``.""" def __repr__(self) -> str: return f"<{self.__class__.__name__} Object>" @staticmethod def _pack(format_: str, data: int) -> bytes: """Pack data in with format in big-endian mode.""" return struct.pack(">" + format_, data) async def write_varint(self, value: int) -> None: """Write varint with value ``value`` to ``self``. :param value: Maximum is ``2 ** 31 - 1``, minimum is ``-(2 ** 31)``. :raises ValueError: If value is out of range. """ remaining = unsigned_int32(value).value for _ in range(5): if not remaining & -0x80: # remaining & ~0x7F == 0: await self.write(struct.pack("!B", remaining)) if value > 2**31 - 1 or value < -(2**31): break return await self.write(struct.pack("!B", remaining & 0x7F | 0x80)) remaining >>= 7 raise ValueError(f'The value "{value}" is too big to send in a varint') async def write_varlong(self, value: int) -> None: """Write varlong with value ``value`` to ``self``. :param value: Maximum is ``2 ** 63 - 1``, minimum is ``-(2 ** 63)``. :raises ValueError: If value is out of range. """ remaining = unsigned_int64(value).value for _ in range(10): if not remaining & -0x80: # remaining & ~0x7F == 0: await self.write(struct.pack("!B", remaining)) if value > 2**63 - 1 or value < -(2**31): break return await self.write(struct.pack("!B", remaining & 0x7F | 0x80)) remaining >>= 7 raise ValueError(f'The value "{value}" is too big to send in a varlong') async def write_utf(self, value: str) -> None: """Write varint of length of ``value`` up to 32767 bytes, then write ``value`` encoded with ``UTF-8``.""" await self.write_varint(len(value)) await self.write(bytearray(value, "utf8")) async def write_ascii(self, value: str) -> None: """Write value encoded with ``ISO-8859-1``, then write an additional ``0x00`` at the end.""" await self.write(bytearray(value, "ISO-8859-1")) await self.write(bytearray.fromhex("00")) async def write_short(self, value: int) -> None: """Write 2 bytes for value ``-32768 - 32767``.""" await self.write(self._pack("h", value)) async def write_ushort(self, value: int) -> None: """Write 2 bytes for value ``0 - 65535 (2 ** 16 - 1)``.""" await self.write(self._pack("H", value)) async def write_int(self, value: int) -> None: """Write 4 bytes for value ``-2147483648 - 2147483647``.""" await self.write(self._pack("i", value)) async def write_uint(self, value: int) -> None: """Write 4 bytes for value ``0 - 4294967295 (2 ** 32 - 1)``.""" await self.write(self._pack("I", value)) async def write_long(self, value: int) -> None: """Write 8 bytes for value ``-9223372036854775808 - 9223372036854775807``.""" await self.write(self._pack("q", value)) async def write_ulong(self, value: int) -> None: """Write 8 bytes for value ``0 - 18446744073709551613 (2 ** 64 - 1)``.""" await self.write(self._pack("Q", value)) async def write_bool(self, value: bool) -> None: """Write 1 byte for boolean `True` or `False`""" await self.write(self._pack("?", value)) async def write_buffer(self, buffer: "Connection") -> None: """Flush buffer, then write a varint of the length of the buffer's data, then write buffer data.""" data = buffer.flush() await self.write_varint(len(data)) await self.write(data) class BaseReadSync(ABC): """Base synchronous read class""" __slots__ = () @abstractmethod def read(self, length: int) -> bytearray: """Read length bytes from ``self``, and return a byte array.""" def __repr__(self) -> str: return f"<{self.__class__.__name__} Object>" @staticmethod def _unpack(format_: str, data: bytes) -> int: """Unpack data as bytes with format in big-endian.""" return struct.unpack(">" + format_, bytes(data))[0] def read_varint(self) -> int: """Read varint from ``self`` and return it. :param value: Maximum is ``2 ** 31 - 1``, minimum is ``-(2 ** 31)``. :raises IOError: If varint received is out of range. """ result = 0 for i in range(5): part = self.read(1)[0] result |= (part & 0x7F) << (7 * i) if not part & 0x80: return signed_int32(result).value raise IOError("Received varint is too big!") def read_varlong(self) -> int: """Read varlong from ``self`` and return it. :param value: Maximum is ``2 ** 63 - 1``, minimum is ``-(2 ** 63)``. :raises IOError: If varint received is out of range. """ result = 0 for i in range(10): part = self.read(1)[0] result |= (part & 0x7F) << (7 * i) if not part & 0x80: return signed_int64(result).value raise IOError("Received varlong is too big!") def read_utf(self) -> str: """Read up to 32767 bytes by reading a varint, then decode bytes as ``UTF-8``.""" length = self.read_varint() return self.read(length).decode("utf8") def read_ascii(self) -> str: """Read ``self`` until last value is not zero, then return that decoded with ``ISO-8859-1``""" result = bytearray() while len(result) == 0 or result[-1] != 0: result.extend(self.read(1)) return result[:-1].decode("ISO-8859-1") def read_short(self) -> int: """Return ``-32768 - 32767``. Read 2 bytes.""" return self._unpack("h", self.read(2)) def read_ushort(self) -> int: """Return ``0 - 65535 (2 ** 16 - 1)``. Read 2 bytes.""" return self._unpack("H", self.read(2)) def read_int(self) -> int: """Return ``-2147483648 - 2147483647``. Read 4 bytes.""" return self._unpack("i", self.read(4)) def read_uint(self) -> int: """Return ``0 - 4294967295 (2 ** 32 - 1)``. 4 bytes read.""" return self._unpack("I", self.read(4)) def read_long(self) -> int: """Return ``-9223372036854775808 - 9223372036854775807``. Read 8 bytes.""" return self._unpack("q", self.read(8)) def read_ulong(self) -> int: """Return ``0 - 18446744073709551613 (2 ** 64 - 1)``. Read 8 bytes.""" return self._unpack("Q", self.read(8)) def read_bool(self) -> bool: """Return `True` or `False`. Read 1 byte.""" return cast(bool, self._unpack("?", self.read(1))) def read_buffer(self) -> "Connection": """Read a varint for length, then return a new connection from length read bytes.""" length = self.read_varint() result = Connection() result.receive(self.read(length)) return result class BaseReadAsync(ABC): """Asynchronous Read connection base class.""" __slots__ = () @abstractmethod async def read(self, length: int) -> bytearray: """Read length bytes from ``self``, return a byte array.""" def __repr__(self) -> str: return f"<{self.__class__.__name__} Object>" @staticmethod def _unpack(format_: str, data: bytes) -> int: """Unpack data as bytes with format in big-endian.""" return struct.unpack(">" + format_, bytes(data))[0] async def read_varint(self) -> int: """Read varint from ``self`` and return it. :param value: Maximum is ``2 ** 31 - 1``, minimum is ``-(2 ** 31)``. :raises IOError: If varint received is out of range. """ result = 0 for i in range(5): part = (await self.read(1))[0] result |= (part & 0x7F) << 7 * i if not part & 0x80: return signed_int32(result).value raise IOError("Received a varint that was too big!") async def read_varlong(self) -> int: """Read varlong from ``self`` and return it. :param value: Maximum is ``2 ** 63 - 1``, minimum is ``-(2 ** 63)``. :raises IOError: If varint received is out of range. """ result = 0 for i in range(10): part = (await self.read(1))[0] result |= (part & 0x7F) << (7 * i) if not part & 0x80: return signed_int64(result).value raise IOError("Received varlong is too big!") async def read_utf(self) -> str: """Read up to 32767 bytes by reading a varint, then decode bytes as ``UTF-8``.""" length = await self.read_varint() return (await self.read(length)).decode("utf8") async def read_ascii(self) -> str: """Read ``self`` until last value is not zero, then return that decoded with ``ISO-8859-1``""" result = bytearray() while len(result) == 0 or result[-1] != 0: result.extend(await self.read(1)) return result[:-1].decode("ISO-8859-1") async def read_short(self) -> int: """Return ``-32768 - 32767``. Read 2 bytes.""" return self._unpack("h", await self.read(2)) async def read_ushort(self) -> int: """Return ``0 - 65535 (2 ** 16 - 1)``. Read 2 bytes.""" return self._unpack("H", await self.read(2)) async def read_int(self) -> int: """Return ``-2147483648 - 2147483647``. Read 4 bytes.""" return self._unpack("i", await self.read(4)) async def read_uint(self) -> int: """Return ``0 - 4294967295 (2 ** 32 - 1)``. 4 bytes read.""" return self._unpack("I", await self.read(4)) async def read_long(self) -> int: """Return ``-9223372036854775808 - 9223372036854775807``. Read 8 bytes.""" return self._unpack("q", await self.read(8)) async def read_ulong(self) -> int: """Return ``0 - 18446744073709551613 (2 ** 64 - 1)``. Read 8 bytes.""" return self._unpack("Q", await self.read(8)) async def read_bool(self) -> bool: """Return `True` or `False`. Read 1 byte.""" return cast(bool, self._unpack("?", await self.read(1))) async def read_buffer(self) -> Connection: """Read a varint for length, then return a new connection from length read bytes.""" length = await self.read_varint() result = Connection() result.receive(await self.read(length)) return result class BaseConnection: """Base Connection class. Implements flush, receive, and remaining.""" __slots__ = () def __repr__(self) -> str: return f"<{self.__class__.__name__} Object>" def flush(self) -> bytearray: """Raise :exc:`TypeError`, unsupported.""" raise TypeError(f"{self.__class__.__name__} does not support flush()") def receive(self, data: BytesConvertable | bytearray) -> None: """Raise :exc:`TypeError`, unsupported.""" raise TypeError(f"{self.__class__.__name__} does not support receive()") def remaining(self) -> int: """Raise :exc:`TypeError`, unsupported.""" raise TypeError(f"{self.__class__.__name__} does not support remaining()") class BaseSyncConnection(BaseConnection, BaseReadSync, BaseWriteSync): """Base synchronous read and write class""" __slots__ = () class BaseAsyncReadSyncWriteConnection(BaseConnection, BaseReadAsync, BaseWriteSync): """Base asynchronous read and synchronous write class""" __slots__ = () class BaseAsyncConnection(BaseConnection, BaseReadAsync, BaseWriteAsync): """Base asynchronous read and write class""" __slots__ = () class Connection(BaseSyncConnection): """Base connection class.""" __slots__ = ("received", "sent") def __init__(self) -> None: self.sent = bytearray() self.received = bytearray() def read(self, length: int) -> bytearray: """Return :attr:`.received` up to length bytes, then cut received up to that point.""" if len(self.received) < length: raise IOError(f"Not enough data to read! {len(self.received)} < {length}") result = self.received[:length] self.received = self.received[length:] return result def write(self, data: Connection | str | bytearray | bytes) -> None: """Extend :attr:`.sent` from ``data``.""" if isinstance(data, Connection): data = data.flush() if isinstance(data, str): data = bytearray(data, "utf-8") self.sent.extend(data) def receive(self, data: BytesConvertable | bytearray) -> None: """Extend :attr:`.received` with ``data``.""" if not isinstance(data, bytearray): data = bytearray(data) self.received.extend(data) def remaining(self) -> int: """Return length of :attr:`.received`.""" return len(self.received) def flush(self) -> bytearray: """Return :attr:`.sent`, also clears :attr:`.sent`.""" result, self.sent = self.sent, bytearray() return result def copy(self) -> "Connection": """Return a copy of ``self``""" new = self.__class__() new.receive(self.received) new.write(self.sent) return new class SocketConnection(BaseSyncConnection): """Socket connection.""" __slots__ = ("socket",) def __init__(self) -> None: # These will only be None until connect is called, ignore the None type assignment self.socket: socket.socket = None # type: ignore[assignment] def close(self) -> None: """Close :attr:`.socket`.""" if self.socket is not None: # If initialized try: self.socket.shutdown(socket.SHUT_RDWR) except OSError as exception: # Socket wasn't connected (nothing to shut down) if exception.errno != errno.ENOTCONN: raise self.socket.close() def __enter__(self) -> Self: return self def __exit__(self, *_) -> None: self.close() class TCPSocketConnection(SocketConnection): """TCP Connection to address. Timeout defaults to 3 seconds.""" __slots__ = () def __init__(self, addr: tuple[str | None, int], timeout: float = 3): super().__init__() self.socket = socket.create_connection(addr, timeout=timeout) self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) def read(self, length: int) -> bytearray: """Return length bytes read from :attr:`.socket`. Raises :exc:`IOError` when server doesn't respond.""" result = bytearray() while len(result) < length: new = self.socket.recv(length - len(result)) if len(new) == 0: raise IOError("Server did not respond with any information!") result.extend(new) return result def write(self, data: Connection | str | bytes | bytearray) -> None: """Send data on :attr:`.socket`.""" if isinstance(data, Connection): data = bytearray(data.flush()) elif isinstance(data, str): data = bytearray(data, "utf-8") self.socket.send(data) class UDPSocketConnection(SocketConnection): """UDP Connection class""" __slots__ = ("addr",) def __init__(self, addr: Address, timeout: float = 3): super().__init__() self.addr = addr self.socket = socket.socket( socket.AF_INET if ip_type(addr[0]) == 4 else socket.AF_INET6, socket.SOCK_DGRAM, ) self.socket.settimeout(timeout) def remaining(self) -> int: """Always return ``65535`` (``2 ** 16 - 1``).""" return 65535 def read(self, length: int) -> bytearray: """Return up to :meth:`.remaining` bytes. Length does nothing here.""" result = bytearray() while len(result) == 0: result.extend(self.socket.recvfrom(self.remaining())[0]) return result def write(self, data: Connection | str | bytes | bytearray) -> None: """Use :attr:`.socket` to send data to :attr:`.addr`.""" if isinstance(data, Connection): data = bytearray(data.flush()) elif isinstance(data, str): data = bytearray(data, "utf-8") self.socket.sendto(data, self.addr) class TCPAsyncSocketConnection(BaseAsyncReadSyncWriteConnection): """Asynchronous TCP Connection class""" __slots__ = ("_addr", "reader", "timeout", "writer") def __init__(self, addr: Address, timeout: float = 3) -> None: # These will only be None until connect is called, ignore the None type assignment self.reader: asyncio.StreamReader = None # type: ignore[assignment] self.writer: asyncio.StreamWriter = None # type: ignore[assignment] self.timeout: float = timeout self._addr = addr async def connect(self) -> None: """Use :mod:`asyncio` to open a connection to address. Timeout is in seconds.""" conn = asyncio.open_connection(*self._addr) self.reader, self.writer = await asyncio.wait_for(conn, timeout=self.timeout) if self.writer is not None: # it might be None in unittest sock: socket.socket = self.writer.transport.get_extra_info("socket") sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) async def read(self, length: int) -> bytearray: """Read up to ``length`` bytes from :attr:`.reader`.""" result = bytearray() while len(result) < length: new = await asyncio.wait_for(self.reader.read(length - len(result)), timeout=self.timeout) if len(new) == 0: raise IOError("Socket did not respond with any information!") result.extend(new) return result def write(self, data: Connection | str | bytes | bytearray) -> None: """Write data to :attr:`.writer`.""" if isinstance(data, Connection): data = bytearray(data.flush()) elif isinstance(data, str): data = bytearray(data, "utf-8") self.writer.write(data) def close(self) -> None: """Close :attr:`.writer`.""" if self.writer is not None: # If initialized self.writer.close() async def __aenter__(self) -> Self: await self.connect() return self async def __aexit__(self, *_) -> None: self.close() class UDPAsyncSocketConnection(BaseAsyncConnection): """Asynchronous UDP Connection class""" __slots__ = ("_addr", "stream", "timeout") def __init__(self, addr: Address, timeout: float = 3) -> None: # This will only be None until connect is called, ignore the None type assignment self.stream: asyncio_dgram.aio.DatagramClient = None # type: ignore[assignment] self.timeout: float = timeout self._addr = addr async def connect(self) -> None: """Connect to address. Timeout is in seconds.""" conn = asyncio_dgram.connect(self._addr) self.stream = await asyncio.wait_for(conn, timeout=self.timeout) def remaining(self) -> int: """Always return ``65535`` (``2 ** 16 - 1``).""" return 65535 async def read(self, length: int) -> bytearray: """Read from :attr:`.stream`. Length does nothing here.""" data, remote_addr = await asyncio.wait_for(self.stream.recv(), timeout=self.timeout) return bytearray(data) async def write(self, data: Connection | str | bytes | bytearray) -> None: """Send data with :attr:`.stream`.""" if isinstance(data, Connection): data = bytearray(data.flush()) elif isinstance(data, str): data = bytearray(data, "utf-8") await self.stream.send(data) def close(self) -> None: """Close :attr:`.stream`.""" if self.stream is not None: # If initialized self.stream.close() async def __aenter__(self) -> Self: await self.connect() return self async def __aexit__(self, *_) -> None: self.close()