Source code for morango.sync.syncsession

"""
The main module to be used for initiating the synchronization of data between morango instances.
"""
import json
import logging
import os
import socket
import uuid
from io import BytesIO
from urllib.parse import urljoin
from urllib.parse import urlparse

from django.utils import timezone
from requests.adapters import HTTPAdapter
from requests.exceptions import HTTPError
from requests.packages.urllib3.util.retry import Retry

from .session import SessionWrapper
from morango.api.serializers import CertificateSerializer
from morango.api.serializers import InstanceIDSerializer
from morango.constants import api_urls
from morango.constants import transfer_stages
from morango.constants import transfer_statuses
from morango.constants.capabilities import ALLOW_CERTIFICATE_PUSHING
from morango.constants.capabilities import GZIP_BUFFER_POST
from morango.errors import CertificateSignatureInvalid
from morango.errors import MorangoError
from morango.errors import MorangoResumeSyncError
from morango.errors import MorangoServerDoesNotAllowNewCertPush
from morango.models.certificates import Certificate
from morango.models.certificates import Key
from morango.models.core import InstanceIDModel
from morango.models.core import SyncSession
from morango.sync.context import CompositeSessionContext
from morango.sync.context import LocalSessionContext
from morango.sync.context import NetworkSessionContext
from morango.sync.controller import SessionController
from morango.sync.utils import SyncSignal
from morango.sync.utils import SyncSignalGroup
from morango.utils import CAPABILITIES
from morango.utils import pid_exists

if GZIP_BUFFER_POST in CAPABILITIES:
    from gzip import GzipFile


logger = logging.getLogger(__name__)


def _join_with_logical_operator(lst, operator):
    op = ") {operator} (".format(operator=operator)
    return "(({items}))".format(items=op.join(lst))


def _get_server_ip(hostname):
    try:
        return socket.gethostbyname(hostname)
    except:  # noqa: E722
        return ""


def _get_client_ip_for_server(server_host, server_port):
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        s.connect((server_host, server_port))
        IP = s.getsockname()[0]
    except:  # noqa: E722
        IP = "127.0.0.1"
    finally:
        s.close()
    return IP


# borrowed from https://github.com/django/django/blob/1.11.20/django/utils/text.py#L295
def compress_string(s, compresslevel=9):
    zbuf = BytesIO()
    with GzipFile(
        mode="wb", compresslevel=compresslevel, fileobj=zbuf, mtime=0
    ) as zfile:
        zfile.write(s)
    return zbuf.getvalue()
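

# Illustrative sketch, not part of morango: a round-trip showing that ``compress_string``
# takes bytes and returns a gzip payload that a standard gzip decompressor can reverse.
# This is the payload format ``_push_record_chunk`` sends when both peers support
# GZIP_BUFFER_POST. Assumes only the standard library; never called by the module itself.
def _example_compress_roundtrip():
    import gzip

    payload = json.dumps([{"id": "abc123"}]).encode("utf-8")
    gzipped = compress_string(payload, compresslevel=9)
    # decompressing the result yields the original JSON bytes
    assert gzip.decompress(gzipped) == payload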


class Connection(object):
    """
    Abstraction around a connection with a syncing peer (network or disk),
    supporting interactions with that peer. This may be used by a SyncClient,
    but also supports other operations (e.g. querying certificates) outside
    the context of syncing.

    This class should be subclassed for particular transport mechanisms,
    and the necessary methods overridden.
    """

    pass


class NetworkSyncConnection(Connection):
    __slots__ = (
        "base_url",
        "compresslevel",
        "session",
        "server_info",
        "capabilities",
        "chunk_size",
    )

    default_chunk_size = 500

    def __init__(
        self,
        base_url="",
        compresslevel=9,
        retries=7,
        backoff_factor=0.3,
        chunk_size=default_chunk_size,
    ):
        """
        The underlying network connection with a syncing peer. Any network requests
        (such as certificate querying or syncing related) will be done through this class.
        """
        if base_url == "":
            raise AssertionError("Network connection `base_url` cannot be empty")

        self.base_url = base_url
        self.compresslevel = compresslevel

        # set up requests session with retry logic
        self.session = SessionWrapper()
        # sleep for {backoff factor} * (2 ^ ({number of total retries} - 1)) between requests
        # with 7 retry attempts, sleep escalation becomes (0.6s, 1.2s, ..., 38.4s)
        retry = Retry(total=retries, backoff_factor=backoff_factor)
        adapter = HTTPAdapter(max_retries=retry)
        self.session.mount("http://", adapter)
        self.session.mount("https://", adapter)

        # get morango information about server
        self.server_info = self.session.get(
            urljoin(self.base_url, api_urls.INFO)
        ).json()
        self.capabilities = self.server_info.get("capabilities", [])
        self.chunk_size = chunk_size

    @property
    def bytes_sent(self):
        return self.session.bytes_sent

    @property
    def bytes_received(self):
        return self.session.bytes_received

    def urlresolve(self, endpoint, lookup=None):
        if lookup:
            lookup = lookup + "/"
        url = urljoin(urljoin(self.base_url, endpoint), lookup)
        return url

    def create_sync_session(self, client_cert, server_cert, chunk_size=None):
        """
        Starts a sync session by creating it on the server side and returning a client to use
        for initiating transfer operations

        :param client_cert: The local certificate to use, already registered with the server
        :type client_cert: Certificate
        :param server_cert: The server's certificate that relates to the same profile as local
        :type server_cert: Certificate
        :param chunk_size: An optional parameter specifying the size for each transferred chunk
        :type chunk_size: int
        :return: A SyncSessionClient instance
        :rtype: SyncSessionClient
        """
        if chunk_size is not None:
            self.chunk_size = chunk_size

        # if server cert does not exist locally, retrieve it from server
        if not Certificate.objects.filter(id=server_cert.id).exists():
            cert_chain_response = self._get_certificate_chain(
                params={"ancestors_of": server_cert.id}
            )

            # upon receiving cert chain from server, we attempt to save the chain into our records
            Certificate.save_certificate_chain(
                cert_chain_response.json(), expected_last_id=server_cert.id
            )

        # request the server for a one-time-use nonce
        nonce_resp = self._get_nonce()
        nonce = nonce_resp.json()["id"]

        # if no hostname then url is actually an ip
        url = urlparse(self.base_url)
        hostname = url.hostname or self.base_url
        port = url.port or (80 if url.scheme == "http" else 443)

        # prepare the data to send in the syncsession creation request
        data = {
            "id": uuid.uuid4().hex,
            "server_certificate_id": server_cert.id,
            "client_certificate_id": client_cert.id,
            "profile": client_cert.profile,
            "certificate_chain": json.dumps(
                CertificateSerializer(
                    client_cert.get_ancestors(include_self=True), many=True
                ).data
            ),
            "connection_path": self.base_url,
            "instance": json.dumps(
                InstanceIDSerializer(
                    InstanceIDModel.get_or_create_current_instance()[0]
                ).data
            ),
            "nonce": nonce,
            "client_ip": _get_client_ip_for_server(hostname, port),
            "server_ip": _get_server_ip(hostname),
        }

        # sign the nonce/ID combo to attach to the request
        message = "{nonce}:{id}".format(**data)
        data["signature"] = client_cert.sign(message)

        # Sync Session creation request
        session_resp = self._create_sync_session(data)

        # check that the nonce/id were properly signed by the server cert
        if not server_cert.verify(message, session_resp.json().get("signature")):
            raise CertificateSignatureInvalid()

        client_instance = InstanceIDModel.get_or_create_current_instance()[0]
        server_instance_json = session_resp.json().get("server_instance") or "{}"
        server_instance_id = None
        if server_instance_json:
            server_instance = json.loads(server_instance_json)
            server_instance_id = server_instance.get("id")

        # build the data to be used for creating our own syncsession
        data = {
            "id": data["id"],
            "start_timestamp": timezone.now(),
            "last_activity_timestamp": timezone.now(),
            "active": True,
            "is_server": False,
            "client_certificate": client_cert,
            "server_certificate": server_cert,
            "profile": client_cert.profile,
            "connection_kind": "network",
            "connection_path": self.base_url,
            "client_ip": data["client_ip"],
            "server_ip": data["server_ip"],
            "client_instance_id": client_instance.id,
            "client_instance_json": json.dumps(
                InstanceIDSerializer(client_instance).data
            ),
            "server_instance_id": server_instance_id,
            "server_instance_json": session_resp.json().get("server_instance") or "{}",
            "process_id": os.getpid(),
        }
        sync_session = SyncSession.objects.create(**data)

        return SyncSessionClient(self, sync_session)

    def resume_sync_session(
        self, sync_session_id, chunk_size=None, ignore_existing_process=False
    ):
        """
        Resumes an existing sync session given an ID

        :param sync_session_id: The UUID of the `SyncSession` to resume
        :param chunk_size: An optional parameter specifying the size for each transferred chunk
        :type chunk_size: int
        :param ignore_existing_process: An optional parameter specifying whether to ignore an
            existing active process ID
        :type ignore_existing_process: bool
        :return: A SyncSessionClient instance
        :rtype: SyncSessionClient
        """
        if chunk_size is not None:
            self.chunk_size = chunk_size

        try:
            sync_session = SyncSession.objects.get(pk=sync_session_id, active=True)
        except SyncSession.DoesNotExist:
            raise MorangoResumeSyncError(
                "Session for ID '{}' not found".format(sync_session_id)
            )

        # check that process of existing session isn't still running
        if (
            not ignore_existing_process
            and sync_session.process_id
            and sync_session.process_id != os.getpid()
            and pid_exists(sync_session.process_id)
        ):
            raise MorangoResumeSyncError(
                "Session process '{}' is still running".format(sync_session.process_id)
            )

        # In order to resume, we need sync sessions on both server and client, otherwise resuming
        # wouldn't have any benefit
        try:
            self._get_sync_session(sync_session)
        except HTTPError as e:
            raise MorangoResumeSyncError("Failure resuming sync session") from e

        # update process id
        sync_session.process_id = os.getpid()
        sync_session.save()

        return SyncSessionClient(self, sync_session)

    def close_sync_session(self, sync_session):
        # "delete" sync session on server side
        self._close_sync_session(sync_session)
        sync_session.active = False
        sync_session.save()

    def close(self):
        # close adapters on requests session object
        self.session.close()

    def get_remote_certificates(self, primary_partition, scope_def_id=None):
        remote_certs = []
        # request certs for this primary partition, for which the server also has a private key
        remote_certs_resp = self._get_certificate_chain(
            params={"primary_partition": primary_partition}
        )

        # inflate remote certs into a list of unsaved models
        for cert in remote_certs_resp.json():
            remote_certs.append(
                Certificate.deserialize(cert["serialized"], cert["signature"])
            )

        # filter certs by scope definition id, if provided
        if scope_def_id:
            remote_certs = [
                cert
                for cert in remote_certs
                if cert.scope_definition_id == scope_def_id
            ]

        return remote_certs

    def certificate_signing_request(
        self,
        parent_cert,
        scope_definition_id,
        scope_params,
        userargs=None,
        password=None,
    ):
        # if server cert does not exist locally, retrieve it from server
        if not Certificate.objects.filter(id=parent_cert.id).exists():
            cert_chain_response = self._get_certificate_chain(
                params={"ancestors_of": parent_cert.id}
            )

            # upon receiving cert chain from server, we attempt to save the chain into our records
            Certificate.save_certificate_chain(
                cert_chain_response.json(), expected_last_id=parent_cert.id
            )

        csr_key = Key()
        # build up data for csr
        data = {
            "parent": parent_cert.id,
            "profile": parent_cert.profile,
            "scope_definition": scope_definition_id,
            "scope_version": parent_cert.scope_version,
            "scope_params": json.dumps(scope_params),
            "public_key": csr_key.get_public_key_string(),
        }
        csr_resp = self._certificate_signing(data, userargs, password)
        csr_data = csr_resp.json()

        # verify cert returned from server, and proceed to save into our records
        csr_cert = Certificate.deserialize(
            csr_data["serialized"], csr_data["signature"]
        )
        csr_cert.private_key = csr_key
        csr_cert.check_certificate()
        csr_cert.save()
        return csr_cert

    def push_signed_client_certificate_chain(
        self, local_parent_cert, scope_definition_id, scope_params
    ):
        if ALLOW_CERTIFICATE_PUSHING not in self.capabilities:
            raise MorangoServerDoesNotAllowNewCertPush(
                "Server does not allow certificate pushing"
            )

        # grab shared public key of server
        publickey_response = self._get_public_key()

        # request the server for a one-time-use nonce
        nonce_response = self._get_nonce()

        # build up data for csr
        certificate = Certificate(
            parent_id=local_parent_cert.id,
            profile=local_parent_cert.profile,
            scope_definition_id=scope_definition_id,
            scope_version=local_parent_cert.scope_version,
            scope_params=json.dumps(scope_params),
            public_key=Key(
                public_key_string=publickey_response.json()[0]["public_key"]
            ),
            # for pushing signed certs, we use nonce as salt
            salt=nonce_response.json()["id"],
        )

        # add ID and signature to the certificate
        certificate.id = certificate.calculate_uuid()
        certificate.parent.sign_certificate(certificate)

        # serialize the chain for sending to server
        certificate_chain = list(
            local_parent_cert.get_ancestors(include_self=True)
        ) + [certificate]
        data = json.dumps(CertificateSerializer(certificate_chain, many=True).data)

        # client sends signed certificate chain to server
        self._push_certificate_chain(data)

        # if there are no errors, we can save the pushed certificate
        certificate.save()
        return certificate

    def _get_public_key(self):
        return self.session.get(self.urlresolve(api_urls.PUBLIC_KEY))

    def _get_nonce(self):
        return self.session.post(self.urlresolve(api_urls.NONCE))

    def _get_certificate_chain(self, params):
        return self.session.get(self.urlresolve(api_urls.CERTIFICATE), params=params)

    def _certificate_signing(self, data, userargs, password):
        # convert user arguments into query str for passing to auth layer
        if isinstance(userargs, dict):
            userargs = "&".join(
                ["{}={}".format(key, val) for (key, val) in userargs.items()]
            )
        return self.session.post(
            self.urlresolve(api_urls.CERTIFICATE), json=data, auth=(userargs, password)
        )

    def _push_certificate_chain(self, data):
        return self.session.post(
            self.urlresolve(api_urls.CERTIFICATE_CHAIN), json=data
        )

    def _create_sync_session(self, data):
        return self.session.post(self.urlresolve(api_urls.SYNCSESSION), json=data)

    def _get_sync_session(self, sync_session):
        return self.session.get(
            self.urlresolve(api_urls.SYNCSESSION, lookup=sync_session.id)
        )

    def _create_transfer_session(self, data):
        return self.session.post(self.urlresolve(api_urls.TRANSFERSESSION), json=data)

    def _get_transfer_session(self, transfer_session):
        return self.session.get(
            self.urlresolve(api_urls.TRANSFERSESSION, lookup=transfer_session.id)
        )

    def _update_transfer_session(self, data, transfer_session):
        return self.session.patch(
            self.urlresolve(api_urls.TRANSFERSESSION, lookup=transfer_session.id),
            json=data,
        )

    def _close_transfer_session(self, transfer_session):
        return self.session.delete(
            self.urlresolve(api_urls.TRANSFERSESSION, lookup=transfer_session.id)
        )

    def _close_sync_session(self, sync_session):
        return self.session.delete(
            self.urlresolve(api_urls.SYNCSESSION, lookup=sync_session.id)
        )

    def _push_record_chunk(self, data):
        # gzip the data if both client and server have gzipping capabilities
        if GZIP_BUFFER_POST in self.capabilities and GZIP_BUFFER_POST in CAPABILITIES:
            json_data = json.dumps([dict(el) for el in data])
            gzipped_data = compress_string(
                bytes(json_data.encode("utf-8")), compresslevel=self.compresslevel
            )
            return self.session.post(
                self.urlresolve(api_urls.BUFFER),
                data=gzipped_data,
                headers={"content-type": "application/gzip"},
            )
        else:
            return self.session.post(self.urlresolve(api_urls.BUFFER), json=data)

    def _pull_record_chunk(self, transfer_session):
        # pull records from server for given transfer session
        params = {
            "limit": self.chunk_size,
            "offset": transfer_session.records_transferred,
            "transfer_session_id": transfer_session.id,
        }
        return self.session.get(self.urlresolve(api_urls.BUFFER), params=params)
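

# Illustrative usage sketch, not part of morango: one way a client might open a connection,
# pick a server certificate, and start or resume a sync session. The server URL is a
# hypothetical placeholder; `client_cert` is assumed to be a Certificate already registered
# with the server. Never called by the module itself.
def _example_network_sync_setup(client_cert, server_partition, previous_session_id=None):
    # connect to a (hypothetical) morango server; server info and capabilities are fetched
    # when the connection is constructed
    conn = NetworkSyncConnection(base_url="https://server.example.com/")

    # pick a server certificate for the partition we want to sync against
    server_cert = conn.get_remote_certificates(server_partition)[0]

    if previous_session_id:
        try:
            # reuse an earlier session when both sides still have it and it is still active
            return conn, conn.resume_sync_session(previous_session_id)
        except MorangoResumeSyncError:
            pass

    # otherwise create a fresh sync session on both client and server
    return conn, conn.create_sync_session(client_cert, server_cert, chunk_size=500)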


class SyncClientSignals(SyncSignal):
    """
    Class for holding all signal types, attached to `SyncClient` as an attribute. All groups
    are sent the `TransferSession` object via the `transfer_session` keyword argument.
    """

    session = SyncSignalGroup(transfer_session=None)
    """Signal group firing for each push and pull `TransferSession`."""
    queuing = SyncSignalGroup(transfer_session=None)
    """Queuing signal group for locally or remotely queuing data before transfer."""
    transferring = SyncSignalGroup(transfer_session=None)
    """Transferring signal group for tracking progress of push/pull on `TransferSession`."""
    dequeuing = SyncSignalGroup(transfer_session=None)
    """Dequeuing signal group for locally or remotely dequeuing data after transfer."""
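

# Illustrative sketch, not part of morango: registering handlers on the signal groups above.
# This assumes SyncSignal/SyncSignalGroup expose a `connect` method for handlers and that the
# fired `transfer_session` carries `records_transferred`/`records_total`; treat both as
# assumptions to verify against morango.sync.utils and the TransferSession model.
def _example_connect_progress_handlers(signals):
    def log_progress(transfer_session=None):
        if transfer_session is not None:
            logger.info(
                "Transferred {} of {} records".format(
                    transfer_session.records_transferred,
                    transfer_session.records_total,
                )
            )

    # fired repeatedly while records are being pushed or pulled
    signals.transferring.in_progress.connect(log_progress)
    # fired once when the overall transfer session completes
    signals.session.completed.connect(
        lambda transfer_session=None: logger.info("Transfer session completed")
    )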


class SyncSessionClient(object):
    __slots__ = (
        "sync_connection",
        "sync_session",
        "signals",
        "controller",
    )

    def __init__(self, sync_connection, sync_session, controller=None):
        """
        :param sync_connection: NetworkSyncConnection
        :param sync_session: SyncSession
        :param controller: SessionController
        """
        self.sync_connection = sync_connection
        self.sync_session = sync_session
        self.signals = SyncClientSignals()
        self.controller = controller or SessionController.build()

    def get_pull_client(self):
        """
        Returns a ``PullClient`` bound to this connection and sync session
        """
        return PullClient(self.sync_connection, self.sync_session, self.controller)

    def get_push_client(self):
        """
        Returns a ``PushClient`` bound to this connection and sync session
        """
        return PushClient(self.sync_connection, self.sync_session, self.controller)

    def initiate_pull(self, sync_filter):
        """
        Deprecated - Please use ``get_pull_client`` and drive the returned client directly

        :param sync_filter: Filter
        """
        client = self.get_pull_client()
        client.signals = self.signals
        client.initialize(sync_filter)
        client.run()
        client.finalize()

    def initiate_push(self, sync_filter):
        """
        Deprecated - Please use ``get_push_client`` and drive the returned client directly

        :param sync_filter: Filter
        """
        client = self.get_push_client()
        client.signals = self.signals
        client.initialize(sync_filter)
        client.run()
        client.finalize()

    def close_sync_session(self):
        """
        Deprecated - Please use ``NetworkSyncConnection.close_sync_session`` and
        ``NetworkSyncConnection.close``
        """
        self.sync_connection.close_sync_session(self.sync_session)
        self.sync_connection.close()
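

# Illustrative sketch, not part of morango: the non-deprecated flow obtains a dedicated pull
# client from SyncSessionClient, drives it through initialize/run/finalize, and then closes
# the session through the connection. `sync_filter` is assumed to be a ``Filter`` instance
# for the partitions being synced. Never called by the module itself.
def _example_pull_then_close(conn, session_client, sync_filter):
    pull_client = session_client.get_pull_client()
    pull_client.initialize(sync_filter)  # creates the transfer session and queues data remotely
    pull_client.run()                    # transfers the queued records in chunks
    pull_client.finalize()               # dequeues/deserializes locally and cleans up

    # close the sync session on the server, then release the underlying requests session
    conn.close_sync_session(session_client.sync_session)
    conn.close()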


class TransferClient(object):
    """
    Base class for handling common operations for initiating syncing and other related operations.
    """

    __slots__ = (
        "sync_connection",
        "sync_session",
        "controller",
        "signals",
        "context",
    )

    def __init__(self, sync_connection, sync_session, controller):
        """
        :param sync_connection: NetworkSyncConnection
        :param sync_session: SyncSession
        :param controller: SessionController
        """
        self.sync_connection = sync_connection
        self.sync_session = sync_session
        self.controller = controller
        self.signals = SyncClientSignals()

        capabilities = sync_connection.server_info.get("capabilities", [])
        self.context = CompositeSessionContext(
            [LocalSessionContext(), NetworkSessionContext(sync_connection)],
            sync_session=sync_session,
            capabilities=capabilities,
        )
        self.controller.context = self.context

    @property
    def current_transfer_session(self):
        return self.context.transfer_session

    def proceed_to_and_wait_for(self, stage, error_msg=None, callback=None):
        """
        Proceeds to the given stage via the controller, and raises an exception if the
        controller reports an ERRORED result

        :param stage: The stage to proceed to
        :param error_msg: An error message str to use as the exception message if it errors
        :param callback: A callback to pass along to the controller
        """
        result = self.controller.proceed_to_and_wait_for(stage, callback=callback)
        if result == transfer_statuses.ERRORED:
            raise MorangoError(
                error_msg or "Stage `{}` failed".format(self.context.stage)
            ) from self.context.error

    def initialize(self, sync_filter):
        """
        :param sync_filter: Filter
        """
        # set filter on controller
        self.context.update(sync_filter=sync_filter)

        # initialize the transfer session
        self.proceed_to_and_wait_for(
            transfer_stages.INITIALIZING,
            error_msg="Failed to initialize transfer session",
        )
        self.signals.session.started.fire(
            transfer_session=self.current_transfer_session
        )

        # backwards compatibility for the queuing signal as it included both serialization
        # and queuing originally
        with self.signals.queuing.send(transfer_session=self.current_transfer_session):
            # proceeding to queuing on remote will trigger initialization and serialization as well
            self.proceed_to_and_wait_for(transfer_stages.QUEUING)

    def run(self):
        """
        Execute the transferring portion of the sync
        """
        with self.signals.transferring.send(
            transfer_session=self.current_transfer_session
        ) as status:
            self.proceed_to_and_wait_for(
                transfer_stages.TRANSFERRING, callback=status.in_progress.fire
            )

    def finalize(self):
        with self.signals.dequeuing.send(
            transfer_session=self.current_transfer_session
        ):
            self.proceed_to_and_wait_for(transfer_stages.DESERIALIZING)

        self.proceed_to_and_wait_for(transfer_stages.CLEANUP)
        self.signals.session.completed.fire(
            transfer_session=self.current_transfer_session
        )


class PushClient(TransferClient):
    """
    Sync client for pushing to a server
    """

    def __init__(self, *args, **kwargs):
        super(PushClient, self).__init__(*args, **kwargs)
        self.context.update(is_push=True)


class PullClient(TransferClient):
    """
    Sync client for pulling from a server
    """

    def __init__(self, *args, **kwargs):
        super(PullClient, self).__init__(*args, **kwargs)
        self.context.update(is_push=False)
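

# Illustrative sketch, not part of morango: a typical two-way sync pushes local changes and
# then pulls remote changes over the same sync session, sharing the SyncSessionClient's
# signal groups with each transfer client (mirroring the deprecated initiate_push/initiate_pull
# helpers). `sync_filter` is again assumed to be a ``Filter`` instance; never called here.
def _example_two_way_sync(session_client, sync_filter):
    for client in (session_client.get_push_client(), session_client.get_pull_client()):
        client.signals = session_client.signals  # share one set of signal groups
        client.initialize(sync_filter)
        client.run()
        client.finalize()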