"""Transports for message bus connections."""
import array
import logging
import socket
import threading
import time
from contextlib import suppress
from .auth import authenticate
from .const import MAX_MESSAGE_LEN, MIN_HEADER_SIZE
from .errors import TransportError, TooLongError
from .message import get_sizes, get_unix_fds_cnt, Message
from .raw import RawData
# Names exported by ``from <module> import *``: the transport base class
# and its concrete implementations.
__all__ = [
    'Transport',
    'UnixTransport',
    'TcpTransport',
    'NonceTcpTransport',
]
# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)
def check_for_known_transport(addr):
    """Ensure that at least one address names a registered transport.

    Returns ``None`` as soon as the first known transport name is seen.

    :param ~dcar.address.Address addr: addresses
    :raises TransportError: if no transport was found
    """
    has_known_transport = any(name in _transports for name, _ in addr)
    if not has_known_transport:
        raise TransportError('no transport found')
def connect(addr, router):
    """Connect to message bus.

    Tries every address until connection is successful
    or there are no more addresses.  Addresses whose transport name is
    not registered in ``_transports`` are skipped.

    :param ~dcar.address.Address addr: addresses
    :param router: router object handed to the transport constructor
    :returns: tuple of the connected transport object and the address
        that was used, formatted as ``name:key=value,...``
    :raises TransportError: if connection failed
    """
    for name, params in addr:
        transport_class = _transports.get(name)
        if transport_class is None:
            continue
        try:
            # Constructing the transport is part of the attempt: the
            # constructors index ``params`` directly and raise (e.g.
            # KeyError) when a parameter is missing.  Such an address must
            # not prevent the remaining addresses from being tried.
            transport = transport_class(params, router)
            transport.connect()
        except Exception:
            _logger.debug('connect failed: %s, %s',
                          name, params, exc_info=True)
            continue
        return (transport,
                '%s:%s' % (name,
                           ','.join('%s=%s' % (k, v)
                                    for k, v in params.items())))
    raise TransportError('connection failed')
class Transport:
    """Base class.

    Subclasses have to provide ``_addr_family`` (socket address family)
    and ``_address`` (argument for ``socket.connect()``); both are read by
    :meth:`connect`.

    :param dict params: address parameters
    :param ~dcar.router.Router router: router object
    """

    def __init__(self, params, router):
        self.guid = params.get('guid')
        self.unix_fds_enabled = False
        self.connected = False
        self._router = router
        self._error = None
        # Guards the connected/disconnected state transitions.
        self._lock = threading.Lock()
        # Loop threads, created by start_loops().  They are kept under
        # their own names so that they do not shadow the _recv_loop() and
        # _send_loop() methods on the instance.
        self._recv_thread = None
        self._send_thread = None

    @property
    def error(self):
        """Return the error which caused disconnection or ``None``.

        The stored exception is wrapped in a new :class:`TransportError`:
        a stored ``TransportError`` passes on its traceback, any other
        exception becomes the ``__cause__`` of the returned error.
        """
        err = self._error
        if err is None:
            return None
        if isinstance(err, TransportError):
            ex = TransportError('connection lost')
            ex.__traceback__ = err.__traceback__
        else:
            ex = TransportError('connection lost: %s' % err)
            ex.__cause__ = err
        return ex

    def _set_error(self, exc):
        """Remember ``exc`` as disconnection cause.

        Only the first error is kept, and only while still connected.
        """
        if self._error is None and self.connected:
            self._error = exc

    def connect(self):
        """Connect to message bus.

        Does nothing if the transport is already connected.

        :raises OSError: if the socket could not be connected
        """
        with self._lock:
            if self.connected:
                return
            sock = socket.socket(self._addr_family)
            try:
                sock.connect(self._address)
            except BaseException:
                # Do not leak the descriptor of a socket that never
                # got connected.
                sock.close()
                raise
            self._sock = sock
            self.connected = True

    def disconnect(self):
        """Disconnect from message bus.

        Does nothing if the transport is not connected.  Otherwise the
        socket is shut down and closed, and the router is notified by
        passing ``None`` to its ``incoming()`` method.
        """
        with self._lock:
            if not self.connected:
                return
            self.connected = False
            # shutdown() fails if the peer is already gone; the socket
            # has to be closed nevertheless.
            with suppress(OSError):
                self._sock.shutdown(socket.SHUT_RDWR)
            self._sock.close()
        # The router is notified after the lock has been released, so that
        # a disconnect() triggered from within the callback cannot block
        # on the (non-reentrant) lock.
        self._router.incoming(None)

    def authenticate(self):
        """Authenticate to message bus.

        Stores the values returned by the authentication in ``guid`` and
        ``unix_fds_enabled``.
        """
        self.guid, self.unix_fds_enabled = authenticate(
            self._sock, self.unix_fds_enabled)

    def start_loops(self):
        """Start threads with ``recv-loop`` and ``send-loop``."""
        self._recv_thread = threading.Thread(target=self._recv_loop,
                                             name='recv-loop', daemon=True)
        self._recv_thread.start()
        self._send_thread = threading.Thread(target=self._send_loop,
                                             name='send-loop', daemon=True)
        self._send_thread.start()

    def block(self, timeout=None):
        """Block until loop threads are finished.

        Returns immediately if the loop threads were never started.

        :param float timeout: timeout value in seconds
                              (``None`` means no timeout)

        .. versionchanged:: 0.2.0 Add parameter ``timeout``
        """
        threads = [thread
                   for thread in (self._recv_thread, self._send_thread)
                   if thread is not None]
        if timeout is None:
            for thread in threads:
                thread.join()
            return
        # The timeout covers both joins together; a monotonic clock keeps
        # the remaining time correct if the system time is adjusted.
        deadline = time.monotonic() + timeout
        for thread in threads:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            thread.join(remaining)

    def _send_loop(self):
        """Send the items of the router's ``out_queue`` over the socket.

        Each item is a pair of message bytes and file descriptors.  The
        loop ends when the transport is disconnected, when an item with
        empty/``None`` bytes is fetched, or when sending fails; in the
        last case the error is stored and the transport disconnects.
        """
        while self.connected:
            b, fds = self._router.out_queue.get()
            if not b:
                break
            try:
                if self.unix_fds_enabled and fds:
                    ancillary = [(socket.SOL_SOCKET,
                                  socket.SCM_RIGHTS,
                                  array.array('i', fds))]
                    sent = self._sock.sendmsg([b], ancillary)
                    # sendmsg() may send only a part of the data; the
                    # descriptors went out with the first part, the rest
                    # is plain data.
                    if sent < len(b):
                        self._sock.sendall(b[sent:])
                else:
                    self._sock.sendall(b)
            except Exception as ex:
                _logger.debug('send loop', exc_info=True)
                self._set_error(ex)
                self.disconnect()
                break
        _logger.debug('EXIT send loop')

    def _recv_loop(self):
        """Receive messages from the socket and pass them to the router.

        For every message the header is peeked to learn the message size,
        then the whole message is read into a buffer of exactly that
        size, together with passed file descriptors if there are any.
        Any error (including a closed connection) is stored and
        disconnects the transport.
        """
        try:
            while self.connected:
                # NOTE(review): a peek may return fewer bytes than
                # requested on a stream socket -- confirm that get_sizes()
                # and get_unix_fds_cnt() cope with short input.
                b = self._sock.recv(MIN_HEADER_SIZE, socket.MSG_PEEK)
                if not b:
                    raise TransportError()
                total_size, fields_size = get_sizes(RawData(b))
                if total_size > MAX_MESSAGE_LEN:
                    raise TooLongError('message too long: %d bytes' %
                                       total_size)
                raw = RawData(bytearray(total_size))
                view = raw.getbuffer()
                if self.unix_fds_enabled:
                    b = self._sock.recv(MIN_HEADER_SIZE + fields_size,
                                        socket.MSG_PEEK)
                    unix_fds_cnt = get_unix_fds_cnt(RawData(b))
                else:
                    unix_fds_cnt = 0
                if unix_fds_cnt:
                    fds = array.array('i')
                    cnt, anc, _, _ = self._sock.recvmsg_into(
                        [view],
                        socket.CMSG_SPACE(unix_fds_cnt * fds.itemsize))
                    for cmsg_level, cmsg_type, cmsg_data in anc:
                        if (cmsg_level == socket.SOL_SOCKET and
                                cmsg_type == socket.SCM_RIGHTS):
                            # Ignore a truncated trailing descriptor.
                            usable = (len(cmsg_data) -
                                      len(cmsg_data) % fds.itemsize)
                            fds.frombytes(cmsg_data[:usable])
                    raw.unix_fds = fds.tolist()
                else:
                    cnt = self._sock.recv_into(view)
                # A single receive call may deliver only a part of the
                # message; keep reading until the buffer is filled or the
                # connection is closed (cnt == 0).
                received = cnt
                while cnt and received < total_size:
                    with view[received:] as rest:
                        cnt = self._sock.recv_into(rest)
                    received += cnt
                view.release()
                if not cnt:
                    raise TransportError()
                self._router.incoming(Message.from_bytes(raw))
        except Exception as ex:
            # After a regular disconnect() the failing socket call is
            # expected and not recorded as an error.
            if self.connected:
                _logger.debug('recv loop', exc_info=True)
                self._set_error(ex)
                self.disconnect()
        _logger.debug('EXIT recv loop')
class UnixTransport(Transport):
    """Transport that uses a unix domain socket.

    It supports the passing of file descriptors.

    The socket address is taken from the ``path`` address parameter if
    present; otherwise the ``abstract`` parameter is encoded and prefixed
    with a null byte.
    """

    def __init__(self, params, router):
        super().__init__(params, router)
        self._addr_family = socket.AF_UNIX
        self._address = (params['path'] if 'path' in params
                         else b'\0' + params['abstract'].encode())
        # Overrides the base class default (False).
        self.unix_fds_enabled = True
class TcpTransport(Transport):
    """Transport that uses a TCP socket.

    Address parameters used: ``host``, ``port`` and, optionally,
    ``family`` (``ipv4`` or ``ipv6``).  Without ``family``, or with an
    unrecognized value, IPv4 is used.
    """

    # Maps the values of the ``family`` address parameter to socket
    # address families.
    _families = {
        'ipv4': socket.AF_INET,
        'ipv6': socket.AF_INET6,
    }

    def __init__(self, params, router):
        super().__init__(params, router)
        self._addr_family = self._families.get(params.get('family'),
                                               socket.AF_INET)
        self._address = (params['host'], int(params['port']))
class NonceTcpTransport(TcpTransport):
    """Transport that uses a nonce-authenticated TCP socket.

    In addition to the parameters of the TCP transport, the address
    parameter ``noncefile`` is required.  The content of that file is
    sent to the peer right after the socket has been connected.
    """

    def __init__(self, params, router):
        super().__init__(params, router)
        self._noncefile = params['noncefile']

    def connect(self):
        """Connect to message bus.

        Does nothing if the transport is already connected.

        :raises OSError: if connecting, reading the nonce file or sending
            its content failed; the transport is left disconnected
        """
        if self.connected:
            # The nonce must be sent only once, as first data on a fresh
            # connection; sending it again would end up in the middle of
            # the established stream.
            return
        super().connect()
        try:
            with open(self._noncefile, 'rb') as fh:
                self._sock.sendall(fh.read())
        except BaseException:
            # Without the nonce the connection is of no use.  The socket
            # is closed directly instead of calling disconnect(), which
            # would also notify the router.
            with self._lock:
                self.connected = False
                self._sock.close()
            raise
# Registry of the available transports: maps the transport name of an
# address to the class implementing it.  Looked up by
# check_for_known_transport() and connect().
_transports = {
    'unix': UnixTransport,
    'tcp': TcpTransport,
    'nonce-tcp': NonceTcpTransport,
}