"""Module provides class BaseServer."""
from __future__ import annotations
# std imports
import zlib
import asyncio
import logging
import datetime
from typing import Any, Union, Optional
# local
from ._base import TelnetProtocolBase, _log_exception, _process_data_chunk
from ._types import ShellCallback
from .telopt import theNULL
from .accessories import TRACE, hexdump
from .stream_reader import TelnetReader, TelnetReaderUnicode
from .stream_writer import TelnetWriter, TelnetWriterUnicode
__all__ = ("BaseServer",)

#: Module-level logger used by :class:`BaseServer`.
logger = logging.getLogger("telnetlib3.server_base")
class BaseServer(TelnetProtocolBase, asyncio.streams.FlowControlMixin, asyncio.Protocol):
    """Base Telnet Server Protocol."""

    _advanced = False
    _closing = False
    _check_later: Optional[asyncio.Handle] = None
    _rx_bytes = 0
    _tx_bytes = 0
    _mccp3_decompressor: Optional[zlib._Decompress] = None
    #: Task running the :attr:`shell` coroutine, once started by :meth:`begin_shell`.
    #: A strong reference is kept because the event loop only holds weak ones.
    _shell_task: Optional[asyncio.Task[Any]] = None

    def __init__(
        self,
        shell: Optional[ShellCallback] = None,
        _waiter_connected: Optional[asyncio.Future[None]] = None,
        encoding: Union[str, bool] = "utf8",
        encoding_errors: str = "strict",
        force_binary: bool = False,
        never_send_ga: bool = False,
        line_mode: bool = False,
        connect_maxwait: float = 4.0,
        limit: Optional[int] = None,
        reader_factory: type = TelnetReader,
        reader_factory_encoding: type = TelnetReaderUnicode,
        writer_factory: type = TelnetWriter,
        writer_factory_encoding: type = TelnetWriterUnicode,
    ) -> None:
        """
        Class initializer.

        :param shell: Callable invoked as ``shell(reader, writer)`` after
            negotiation completes; a returned coroutine is scheduled as a task.
        :param _waiter_connected: Future resolved when negotiation completes;
            a new one is created when not given (used for testing).
        :param encoding: Default encoding. When falsy, ``reader_factory`` and
            ``writer_factory`` are used instead of the ``*_encoding`` variants.
        :param encoding_errors: Error handler passed to the encoding
            reader/writer factories.
        :param force_binary: Stored as :attr:`force_binary`.
        :param never_send_ga: Stored as :attr:`never_send_ga`.
        :param line_mode: Stored as :attr:`line_mode`.
        :param connect_maxwait: Maximum duration for :meth:`check_negotiation`.
        :param limit: When truthy, passed as ``limit`` to the reader factory.
        :param reader_factory: Reader class used when ``encoding`` is falsy.
        :param reader_factory_encoding: Reader class used when ``encoding`` is set.
        :param writer_factory: Writer class used when ``encoding`` is falsy.
        :param writer_factory_encoding: Writer class used when ``encoding`` is set.
        """
        super().__init__()
        self.default_encoding = encoding
        self._encoding_errors = encoding_errors
        self.force_binary = force_binary
        self.never_send_ga = never_send_ga
        self.line_mode = line_mode
        self._extra: dict[str, Any] = {}
        self._reader_factory = reader_factory
        self._reader_factory_encoding = reader_factory_encoding
        self._writer_factory = writer_factory
        self._writer_factory_encoding = writer_factory_encoding

        #: a future used for testing
        self._waiter_connected = _waiter_connected or asyncio.Future()
        # everything in this list is cancelled by connection_lost()
        self._tasks: list[Any] = [self._waiter_connected]
        self.shell = shell
        self.reader: Optional[Union[TelnetReader, TelnetReaderUnicode]] = None
        self.writer: Optional[Union[TelnetWriter, TelnetWriterUnicode]] = None
        #: maximum duration for :meth:`check_negotiation`.
        self.connect_maxwait = connect_maxwait
        self._limit = limit

    def timeout_connection(self) -> None:
        """Close the connection due to timeout."""
        # reader/writer only exist after connection_made()
        if self.reader is not None:
            self.reader.feed_eof()
        if self.writer is not None:
            self.writer.close()

    # Base protocol methods

    def eof_received(self) -> None:
        """
        Called when the other end calls write_eof() or equivalent.

        This callback may be exercised by the nc(1) client argument ``-z``.
        """
        logger.debug("EOF from client, closing.")
        self.connection_lost(None)

    def connection_lost(self, exc: Optional[Exception]) -> None:
        """
        Called when the connection is lost or closed.

        Only the first call has any effect. It signals EOF (or ``exc``) to the
        reader, cancels protocol tasks, detaches and closes the transport, and
        cancels ``_waiter_connected`` if it is still pending.

        :param exc: Exception instance, or ``None`` to indicate close by EOF.
        """
        if self._closing:
            return
        self._closing = True

        # inform yielding readers about closed connection
        if exc is None:
            logger.info("Connection closed for %s", self)
            if self.reader is not None:
                self.reader.feed_eof()
        else:
            logger.info("Connection lost for %s: %s", self, exc)
            if self.reader is not None:
                self.reader.set_exception(exc)

        # cancel protocol tasks, namely on-connect negotiations
        for task in self._tasks:
            try:
                task.cancel()
            except Exception:  # best-effort: never let cleanup raise
                pass

        # drop references to scheduled tasks/callbacks
        self._tasks.clear()
        try:
            self._waiter_connected.remove_done_callback(self.begin_shell)
        except Exception:  # best-effort: never let cleanup raise
            pass

        # close transport (may already be closed), cancel Future _waiter_connected.
        if self._transport is not None:
            # Detach protocol from transport to drop strong reference immediately.
            try:
                if hasattr(self._transport, "set_protocol"):
                    self._transport.set_protocol(asyncio.Protocol())
            except Exception:  # best-effort: never let cleanup raise
                pass
            self._transport.close()
        if not self._waiter_connected.done():
            self._waiter_connected.cancel()

        # break circular references for transport; keep reader/writer available
        # for inspection by tests after close.
        self._transport = None

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """
        Called when a connection is made.

        Sets attributes ``_transport``, ``_when_connected``, ``_last_received``,
        ``reader`` and ``writer``, then schedules :meth:`begin_negotiation`.

        Ensure ``super().connection_made(transport)`` is called when derived.
        """
        self._transport = transport
        self._when_connected = datetime.datetime.now()
        self._last_received = datetime.datetime.now()

        reader_factory = self._reader_factory
        writer_factory = self._writer_factory
        reader_kwds: dict[str, Any] = {}
        writer_kwds: dict[str, Any] = {}

        if self.default_encoding:
            reader_kwds["fn_encoding"] = self.encoding
            writer_kwds["fn_encoding"] = self.encoding
            reader_kwds["encoding_errors"] = self._encoding_errors
            writer_kwds["encoding_errors"] = self._encoding_errors
            reader_factory = self._reader_factory_encoding
            writer_factory = self._writer_factory_encoding

        if self._limit:
            reader_kwds["limit"] = self._limit

        self.reader = reader_factory(**reader_kwds)
        self.writer = writer_factory(
            transport=transport, protocol=self, reader=self.reader, server=True, **writer_kwds
        )

        logger.info("Connection from %s", self)
        self._log_tls_info(logger)

        self._waiter_connected.add_done_callback(self.begin_shell)
        asyncio.get_event_loop().call_soon(self.begin_negotiation)

    def begin_shell(self, future: asyncio.Future[None]) -> None:
        """
        Start the shell coroutine after negotiation completes.

        :param future: The ``_waiter_connected`` future; nothing is started
            when it was cancelled or finished with an exception.
        """
        # Don't start shell if the connection was cancelled or errored
        if future.cancelled() or future.exception() is not None:
            return
        if self.shell is not None:
            assert self.reader is not None and self.writer is not None
            coro = self.shell(self.reader, self.writer)
            if asyncio.iscoroutine(coro):
                loop = asyncio.get_event_loop()
                # keep a strong reference: the event loop only holds weak
                # references to tasks, which may otherwise vanish mid-execution.
                self._shell_task = loop.create_task(coro)

    def data_received(self, data: bytes) -> None:
        """
        Process bytes received by transport.

        Feeds raw bytes through the writer's IAC interpreter, forwarding in-band
        data to the reader. While client→server MCCP3 is active, bytes are
        decompressed first.
        """
        if logger.isEnabledFor(TRACE):
            logger.log(TRACE, "recv %d bytes\n%s", len(data), hexdump(data, prefix="<< "))
        self._last_received = datetime.datetime.now()
        # count bytes exactly once, as they arrived on the wire
        self._rx_bytes += len(data)
        self._feed_received(data)

    def _feed_received(self, data: bytes) -> None:
        """Decompress ``data`` when MCCP3 is active, then interpret it."""
        unused = b""
        decompressor = self._mccp3_decompressor
        # MCCP3: decompress client→server data when active
        if decompressor is not None:
            try:
                data = decompressor.decompress(data)
            except zlib.error:
                logger.warning("MCCP3 decompression error, disabling")
                self._mccp3_end()
                return
            if decompressor.eof:
                # The compressed stream ended inside this chunk: ``data`` holds
                # its final decompressed bytes, and anything past the end of
                # the stream is uncompressed and follows them in order.
                unused = decompressor.unused_data
                self._mccp3_end()
        self._interpret_received(data)
        if unused:
            self._feed_received(unused)

    def _interpret_received(self, data: bytes) -> None:
        """Run uncompressed ``data`` through the IAC interpreter and react to commands."""
        if self.writer.slc_simulated:
            slc_vals = {defn.val[0] for defn in self.writer.slctab.values() if defn.val != theNULL}
            slc_special: frozenset[int] | None = frozenset({255} | slc_vals)
        else:
            slc_special = None

        cmd_received = _process_data_chunk(
            data, self.writer, self.reader, slc_special, logger.warning
        )

        # Check if MCCP3 SB was just received (client→server compression start)
        if self.writer.mccp3_active and self._mccp3_decompressor is None:
            self._mccp3_start()

        if not self._waiter_connected.done() and cmd_received:
            self._check_negotiation_timer()

    # public properties

    @property
    def rx_bytes(self) -> int:
        """Total bytes received from client."""
        return self._rx_bytes

    @property
    def tx_bytes(self) -> int:
        """Total bytes sent to client."""
        return self._tx_bytes

    # public protocol methods

    def begin_negotiation(self) -> None:
        """
        Begin on-connect negotiation.

        A Telnet server is expected to demand preferred session options
        immediately after connection.  Deriving implementations should always
        call ``super().begin_negotiation()``.
        """
        self._check_later = asyncio.get_event_loop().call_soon(self._check_negotiation_timer)
        self._tasks.append(self._check_later)

    def begin_advanced_negotiation(self) -> None:
        """
        Begin advanced negotiation.

        Callback method further requests advanced telnet options.  Called
        once on receipt of any ``DO`` or ``WILL`` acknowledgments
        received, indicating that the remote end is capable of negotiating
        further.

        Only called if sub-classing :meth:`begin_negotiation` causes
        at least one negotiation option to be affirmatively acknowledged.
        """

    def encoding(self, outgoing: bool = False, incoming: bool = False) -> Union[str, bool]:
        """
        Encoding that should be used for the direction indicated.

        The base implementation **always** returns the encoding given to class
        initializer, or, when unset (falsy), ``US-ASCII``.
        """
        return self.default_encoding or "US-ASCII"

    def negotiation_should_advance(self) -> bool:
        """
        Whether advanced negotiation should commence.

        :returns: ``True`` if advanced negotiation should be permitted.

        The base implementation returns True if any negotiation options
        were affirmatively acknowledged by client, more than likely
        options requested in callback :meth:`begin_negotiation`.
        """
        # Generally, this separates a bare TCP connect() from a True
        # RFC-compliant telnet client with responding IAC interpreter.
        server_do = any(self.writer.remote_option.values())
        client_will = any(self.writer.local_option.values())
        return server_do or client_will

    def check_negotiation(self, final: bool = False) -> bool:
        """
        Callback, return whether negotiation is complete.

        :param final: Whether this is the final time this callback
            will be requested to answer regarding protocol negotiation.
        :returns: Whether negotiation is over (server end is satisfied).

        Called whenever received data contains telnet commands, and on a timer,
        until negotiation is complete or ``connect_maxwait`` has elapsed; the
        ``_waiter_connected`` future is then resolved with ``None``.

        Ensure ``super().check_negotiation()`` is called and conditionally
        combined when derived.
        """
        if not self._advanced and self.negotiation_should_advance():
            self._advanced = True
            logger.debug("begin advanced negotiation")
            asyncio.get_event_loop().call_soon(self.begin_advanced_negotiation)

        # negotiation is complete (returns True) when all negotiation options
        # that have been requested have been acknowledged.
        return not any(self.writer.pending_option.values())

    # private methods

    def _check_negotiation_timer(self) -> None:
        """Resolve ``_waiter_connected`` when negotiation is done or timed out, else re-arm."""
        if self._check_later is not None:
            self._check_later.cancel()
            if self._check_later in self._tasks:
                self._tasks.remove(self._check_later)

        # The waiter may already be resolved, or cancelled by connection_lost()
        # before a callback scheduled earlier ran; set_result() on a done
        # future would raise InvalidStateError.
        if self._waiter_connected.done():
            return

        later = self.connect_maxwait - self.duration
        final = bool(later < 0)

        if self.check_negotiation(final=final):
            logger.debug("negotiation complete after %1.2fs.", self.duration)
            self._waiter_connected.set_result(None)
        elif final:
            logger.debug("negotiation failed after %1.2fs.", self.duration)
            self._waiter_connected.set_result(None)
        else:
            # keep re-queuing until complete
            self._check_later = asyncio.get_event_loop().call_later(
                later, self._check_negotiation_timer
            )
            self._tasks.append(self._check_later)

    def _mccp3_start(self) -> None:
        """Start MCCP3 decompression of client→server data."""
        self._mccp3_decompressor = zlib.decompressobj()
        logger.debug("MCCP3 decompression started (client→server)")

    def _mccp3_end(self) -> None:
        """Stop MCCP3 decompression."""
        self._mccp3_decompressor = None
        self.writer.mccp3_active = False
        logger.debug("MCCP3 decompression ended (client→server)")

    _log_exception = staticmethod(_log_exception)