import json
import logging
import platform
import uuid
from django.core.exceptions import ValidationError
from django.utils import timezone
from ipware import get_client_ip
from rest_framework import mixins
from rest_framework import pagination
from rest_framework import response
from rest_framework import status
from rest_framework import viewsets
from rest_framework.parsers import JSONParser
import morango
from morango import errors
from morango.api import permissions
from morango.api import serializers
from morango.constants import transfer_stages
from morango.constants import transfer_statuses
from morango.constants.capabilities import ASYNC_OPERATIONS
from morango.constants.capabilities import GZIP_BUFFER_POST
from morango.models import certificates
from morango.models.core import Buffer
from morango.models.core import Certificate
from morango.models.core import InstanceIDModel
from morango.models.core import SyncSession
from morango.models.core import TransferSession
from morango.models.fields.crypto import SharedKey
from morango.sync.context import LocalSessionContext
from morango.sync.controller import SessionController
from morango.utils import _assert
from morango.utils import CAPABILITIES
from morango.utils import parse_capabilities_from_server_request
# JSON bodies are always accepted. When this instance lists GZIP_BUFFER_POST
# among its capabilities, the gzip parser is placed ahead of the JSON parser.
parsers = (JSONParser,)
if GZIP_BUFFER_POST in CAPABILITIES:
    from .parsers import GzipParser

    parsers = (GzipParser, JSONParser)
[docs]
def controller_signal_logger(context=None):
    """
    Log an info-level message describing the status of the context's stage.

    Nothing is logged when the stage status is none of PENDING, STARTED,
    COMPLETED or ERRORED.

    :param context: object exposing ``stage`` and ``stage_status``; must not be None
    """
    _assert(context is not None, "Missing context")

    # (status, message template) pairs; only the first matching status is logged
    status_templates = (
        (transfer_statuses.PENDING, "Starting stage '{}'"),
        (transfer_statuses.STARTED, "Stage '{}' is in progress"),
        (transfer_statuses.COMPLETED, "Completed stage '{}'"),
        (transfer_statuses.ERRORED, "Encountered error during stage '{}'"),
    )
    for known_status, template in status_templates:
        if context.stage_status == known_status:
            logging.info(template.format(context.stage))
            break
# Module-level controller used by the transfer session and buffer viewsets in
# this file; `controller_signal_logger` is registered on its signals so stage
# transitions get logged.
session_controller = SessionController.build()
session_controller.signals.connect(controller_signal_logger)
[docs]
def get_ip(request):
    """
    Return the client IP that ``ipware`` extracts from the request.

    :param request: the incoming request
    :return: the first element of the tuple returned by ``get_client_ip``
    """
    return get_client_ip(request)[0]
[docs]
class CertificateChainViewSet(viewsets.ViewSet):
    """
    Accepts a pushed certificate chain whose last certificate was generated
    against one of this server's shared public keys.
    """

    # DRF only reads `permission_classes`; the tuple used to be assigned to
    # `permissions`, which the framework ignores, so the push permission was
    # never actually enforced on this endpoint.
    permission_classes = (permissions.CertificatePushPermissions,)
    # kept so code that read the old attribute name keeps working
    permissions = permission_classes

    def create(self, request):
        """
        Verify and store a pushed certificate chain.

        The request body is expected to be a JSON-encoded string holding a list
        of certificates. The last entry is the client's certificate (with
        ``serialized`` and ``signature`` keys); the preceding entries are saved
        as its chain.

        :param request: the incoming request
        :return: 201 once the certificate is saved; 400 for a malformed payload,
            an unknown shared public key or a certificate that fails its checks;
            403 when the chain cannot be saved or the nonce is invalid
        """
        # reject undecodable payloads up front instead of failing with a 500
        try:
            cert_chain = json.loads(request.data)
        except (TypeError, ValueError):
            return response.Response(
                "Certificate chain must be a JSON-encoded list",
                status=status.HTTP_400_BAD_REQUEST,
            )
        if not isinstance(cert_chain, list) or not cert_chain:
            return response.Response(
                "Certificate chain must be a non-empty list",
                status=status.HTTP_400_BAD_REQUEST,
            )

        # pop last certificate in chain
        client_cert = cert_chain.pop()

        # verify the rest of the cert chain
        try:
            Certificate.save_certificate_chain(cert_chain)
        except (AssertionError, errors.MorangoCertificateError) as e:
            return response.Response(
                "Saving certificate chain has failed: {}".format(str(e)),
                status=status.HTTP_403_FORBIDDEN,
            )

        try:
            serialized = client_cert["serialized"]
            signature = client_cert["signature"]
        except (KeyError, TypeError):
            return response.Response(
                "Client certificate must include 'serialized' and 'signature'",
                status=status.HTTP_400_BAD_REQUEST,
            )

        # create an in-memory instance of the cert from the serialized data and signature
        certificate = Certificate.deserialize(serialized, signature)

        # check if certificate's public key is in our list of shared keys
        try:
            sharedkey = SharedKey.objects.get(public_key=certificate.public_key)
        except SharedKey.DoesNotExist:
            return response.Response(
                "Shared public key was not used", status=status.HTTP_400_BAD_REQUEST
            )

        # set private key
        certificate.private_key = sharedkey.private_key

        # check that the nonce is valid, and consume it so it can't be used again
        try:
            certificates.Nonce.use_nonce(certificate.salt)
        except errors.MorangoNonceError:
            return response.Response(
                "Nonce (certificate's salt) is not valid",
                status=status.HTTP_403_FORBIDDEN,
            )

        # verify the certificate (scope is a subset, profiles match, etc)
        try:
            certificate.check_certificate()
        except errors.MorangoCertificateError as e:
            return response.Response(
                {
                    "error_class": e.__class__.__name__,
                    # prefer an explicit `message`, else the first exception arg
                    "error_message": getattr(
                        e, "message", (getattr(e, "args") or ("",))[0]
                    ),
                },
                status=status.HTTP_400_BAD_REQUEST,
            )

        # we got this far, and everything looks good, so we can save the certificate
        certificate.save()

        return response.Response(
            "Certificate chain has been saved", status=status.HTTP_201_CREATED
        )
[docs]
class CertificateViewSet(
    mixins.CreateModelMixin,
    mixins.RetrieveModelMixin,
    mixins.ListModelMixin,
    viewsets.GenericViewSet,
):
    """Create, retrieve and list certificates."""

    permission_classes = (permissions.CertificatePermissions,)
    serializer_class = serializers.CertificateSerializer
    authentication_classes = (permissions.BasicMultiArgumentAuthentication,)

    def create(self, request):
        """
        Build, sign, validate and save a certificate from the request data.

        :param request: the incoming request carrying the certificate fields
        :return: 201 with the serialized signed certificate, or 400 when the
            serializer, model validation or certificate checks fail
        """
        cert_serializer = serializers.CertificateSerializer(data=request.data)
        if not cert_serializer.is_valid():
            return response.Response(
                cert_serializer.errors, status=status.HTTP_400_BAD_REQUEST
            )

        # turn the validated data into an in-memory certificate, then give it
        # a salt and an ID and have its parent sign it
        new_cert = Certificate(**cert_serializer.validated_data)
        new_cert.salt = uuid.uuid4().hex
        new_cert.id = new_cert.calculate_uuid()
        new_cert.parent.sign_certificate(new_cert)

        # model field validation
        try:
            new_cert.full_clean()
        except ValidationError as validation_error:
            return response.Response(
                validation_error, status=status.HTTP_400_BAD_REQUEST
            )

        # certificate-level checks (scope is a subset, profiles match, etc)
        try:
            new_cert.check_certificate()
        except errors.MorangoCertificateError as cert_error:
            first_arg = (getattr(cert_error, "args") or ("",))[0]
            error_payload = {
                "error_class": cert_error.__class__.__name__,
                "error_message": getattr(cert_error, "message", first_arg),
            }
            return response.Response(
                error_payload, status=status.HTTP_400_BAD_REQUEST
            )

        # every check passed: persist, then hand the signed copy back to the client
        new_cert.save()
        signed_copy = serializers.CertificateSerializer(new_cert).data
        return response.Response(signed_copy, status=status.HTTP_201_CREATED)

    def get_queryset(self):
        """
        Select certificates according to the request's query parameters.

        ``profile`` narrows the candidates; ``primary_partition`` and
        ``ancestors_of`` select a subtree or an ancestor chain respectively.

        :return: a queryset of certificates, empty when the targeted
            certificate cannot be found
        """
        query = self.request.query_params
        candidates = Certificate.objects

        # optional profile filter
        if "profile" in query:
            candidates = candidates.filter(profile=query["profile"])

        try:
            # descendants of the given certificate, restricted to server-owned certs
            if "primary_partition" in query:
                root_cert = candidates.get(id=query["primary_partition"])
                subtree = root_cert.get_descendants(include_self=True)
                return subtree.exclude(_private_key=None)

            # certificate chain of a certificate owned by the server
            if "ancestors_of" in query:
                owned_certs = candidates.exclude(_private_key=None)
                leaf_cert = owned_certs.get(id=query["ancestors_of"])
                return leaf_cert.get_ancestors(include_self=True)
        except Certificate.DoesNotExist:
            # unknown target certificate: nothing to return
            return candidates.none()

        # no targeting filter given: every certificate owned by the server
        return candidates.exclude(_private_key=None)
[docs]
class NonceViewSet(mixins.CreateModelMixin, viewsets.GenericViewSet):
    """Issues nonces."""

    serializer_class = serializers.NonceSerializer

    def create(self, request):
        """
        Create a nonce recorded against the requester's IP.

        :param request: the incoming request
        :return: 201 with the serialized nonce
        """
        requester_ip = get_ip(request)
        new_nonce = certificates.Nonce.objects.create(ip=requester_ip)
        payload = serializers.NonceSerializer(new_nonce).data
        return response.Response(payload, status=status.HTTP_201_CREATED)
[docs]
class SyncSessionViewSet(
    mixins.DestroyModelMixin,
    mixins.RetrieveModelMixin,
    viewsets.GenericViewSet,
):
    """Create, retrieve and destroy active sync sessions."""

    serializer_class = serializers.SyncSessionSerializer

    def create(self, request):
        """
        Open a sync session after checking the client's certificate, signature
        and nonce.

        Malformed requests (missing ``nonce``/``signature``, undecodable
        ``instance`` JSON, or data that fails model validation) are answered
        with 400 rather than surfacing as server errors.

        :param request: the incoming request
        :return: 201 with the server's signature and instance info; 400 for
            unknown certificates, mismatched profiles or malformed data; 403
            when the chain, signature or nonce is rejected
        """
        server_instance, _ = InstanceIDModel.get_or_create_current_instance()

        # verify and save the certificate chain to our cert store
        try:
            Certificate.save_certificate_chain(
                request.data.get("certificate_chain"),
                expected_last_id=request.data.get("client_certificate_id"),
            )
        except (AssertionError, errors.MorangoCertificateError):
            return response.Response(
                "Saving certificate chain has failed", status=status.HTTP_403_FORBIDDEN
            )

        # attempt to load the requested certificates
        try:
            server_cert = Certificate.objects.get(
                id=request.data.get("server_certificate_id")
            )
            client_cert = Certificate.objects.get(
                id=request.data.get("client_certificate_id")
            )
        except Certificate.DoesNotExist:
            return response.Response(
                "Requested certificate does not exist!",
                status=status.HTTP_400_BAD_REQUEST,
            )

        if server_cert.profile != client_cert.profile:
            return response.Response(
                "Certificates must both be associated with the same profile",
                status=status.HTTP_400_BAD_REQUEST,
            )

        # both values are required below; a missing one is a client error
        nonce = request.data.get("nonce")
        signature = request.data.get("signature")
        if nonce is None or signature is None:
            return response.Response(
                "Both nonce and signature are required",
                status=status.HTTP_400_BAD_REQUEST,
            )

        # check that the nonce/id were properly signed
        message = "{nonce}:{id}".format(nonce=nonce, id=request.data.get("id"))
        if not client_cert.verify(message, signature):
            return response.Response(
                "Client certificate failed to verify signature",
                status=status.HTTP_403_FORBIDDEN,
            )

        # decode the optional client instance info before the nonce is consumed,
        # so that a malformed request does not use up the nonce
        client_instance_json = request.data.get("instance")
        client_instance_id = None
        if client_instance_json:
            try:
                client_instance = json.loads(client_instance_json)
            except (TypeError, ValueError):
                client_instance = None
            if not isinstance(client_instance, dict):
                return response.Response(
                    "Client instance info must be a JSON-encoded object",
                    status=status.HTTP_400_BAD_REQUEST,
                )
            client_instance_id = client_instance.get("id")

        # check that the nonce is valid, and consume it so it can't be used again
        try:
            certificates.Nonce.use_nonce(nonce)
        except errors.MorangoNonceError:
            return response.Response(
                "Nonce is not valid", status=status.HTTP_403_FORBIDDEN
            )

        server_instance_json = json.dumps(
            serializers.InstanceIDSerializer(server_instance).data
        )
        # one timestamp so the start and last-activity fields agree exactly
        now = timezone.now()

        # build the data to be used for creating the syncsession
        data = {
            "id": request.data.get("id"),
            "start_timestamp": now,
            "last_activity_timestamp": now,
            "active": True,
            "is_server": True,
            "client_certificate": client_cert,
            "server_certificate": server_cert,
            "profile": server_cert.profile,
            "connection_kind": "network",
            "connection_path": request.data.get("connection_path"),
            "client_ip": get_ip(request) or "",
            "server_ip": request.data.get("server_ip") or "",
            "client_instance_id": client_instance_id,
            "client_instance_json": client_instance_json,
            "server_instance_id": server_instance.id,
            "server_instance_json": server_instance_json,
        }

        syncsession = SyncSession(**data)
        # model validation failures stem from client-supplied fields, so they
        # are reported as 400 like the certificate endpoint does
        try:
            syncsession.full_clean()
        except ValidationError as e:
            return response.Response(
                getattr(e, "message_dict", None) or e.messages,
                status=status.HTTP_400_BAD_REQUEST,
            )
        syncsession.save()

        resp_data = {
            "signature": server_cert.sign(message),
            "server_instance": server_instance_json,
        }

        return response.Response(resp_data, status=status.HTTP_201_CREATED)

    def get_queryset(self):
        """
        :return: the sync sessions that are still active
        """
        return SyncSession.objects.filter(active=True)
[docs]
class TransferSessionViewSet(
    mixins.RetrieveModelMixin,
    mixins.UpdateModelMixin,
    mixins.DestroyModelMixin,
    viewsets.GenericViewSet,
):
    """Create, retrieve, patch and destroy active transfer sessions."""

    serializer_class = serializers.TransferSessionSerializer

    def create(self, request):  # noqa: C901
        """
        Start a transfer session within an active sync session.

        :param request: the incoming request
        :return: 201 when the target stage completed, 202 when it is still in
            progress, 400 for an unknown or inactive sync session, 403 when the
            filter exceeds the certificate scopes, 500 on an error with no
            recorded exception
        :raises: the context's recorded error when the stage errored
        """
        # look the sync session up among the active ones only
        active_sessions = SyncSession.objects.filter(active=True)
        try:
            sync_session = active_sessions.get(id=request.data.get("sync_session_id"))
        except SyncSession.DoesNotExist:
            return response.Response(
                "Requested syncsession does not exist or is no longer active!",
                status=status.HTTP_400_BAD_REQUEST,
            )

        # a push moves data from client to server; a pull is the inverse
        pushing = request.data.get("push")

        sync_filter = certificates.Filter(request.data.get("filter"))
        server_scope = sync_session.server_certificate.get_scope()
        client_scope = sync_session.client_certificate.get_scope()

        # each entry pairs the filter the request must fit within with the
        # message reported when it does not
        if pushing:
            scope_checks = (
                (
                    client_scope.write_filter,
                    "Client certificate scope does not permit pushing for the requested filter.",
                ),
                (
                    server_scope.read_filter,
                    "Server certificate scope does not permit receiving pushes for the requested filter.",
                ),
            )
        else:
            scope_checks = (
                (
                    client_scope.read_filter,
                    "Client certificate scope does not permit pulling for the requested filter.",
                ),
                (
                    server_scope.write_filter,
                    "Server certificate scope does not permit responding to pulls for the requested filter.",
                ),
            )

        # every check runs; when several fail, the last failure's message is reported
        denial = None
        for permitted_filter, denial_message in scope_checks:
            if not sync_filter.is_subset_of(permitted_filter):
                denial = denial_message
        if denial:
            return response.Response(denial, status=status.HTTP_403_FORBIDDEN)

        context = LocalSessionContext.from_request(
            request,
            sync_session=sync_session,
            sync_filter=sync_filter,
            is_push=pushing,
        )

        # With async on both sides we only initialize here and the client PATCHes
        # the transfer session onward; otherwise we go on through queuing now.
        if self.async_allowed():
            target_stage = transfer_stages.INITIALIZING
        else:
            target_stage = transfer_stages.QUEUING

        outcome = session_controller.proceed_to_and_wait_for(
            target_stage, context=context, max_interval=2
        )

        if outcome == transfer_statuses.ERRORED:
            if context.error:
                raise context.error
            return response.Response(
                "Failed to initialize session",
                status=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )

        response_status = (
            status.HTTP_201_CREATED
            if outcome == transfer_statuses.COMPLETED
            else status.HTTP_202_ACCEPTED
        )
        serialized_session = self.get_serializer(context.transfer_session).data
        return response.Response(serialized_session, status=response_status)

    def update(self, request, *args, **kwargs):
        """
        Apply a PATCH to a transfer session, advancing its stage when the
        request carries ``transfer_stage``.

        :param request: the incoming request
        :return: 405 for non-partial updates, otherwise the parent's update response
        """
        if not kwargs.get("partial", False):
            return response.Response(
                "Only PATCH updates allowed", status=status.HTTP_405_METHOD_NOT_ALLOWED
            )

        # the stage is taken out of the data and driven through the controller
        requested_stage = request.data.pop("transfer_stage", None)
        if requested_stage is not None:
            context = LocalSessionContext.from_request(
                request,
                transfer_session=self.get_object(),
            )
            # transferring is chunked, so it is never waited on; other stages
            # are waited on unless both sides support async operations
            no_wait = (
                self.async_allowed()
                or requested_stage == transfer_stages.TRANSFERRING
            )
            if no_wait:
                session_controller.proceed_to(requested_stage, context=context)
            else:
                session_controller.proceed_to_and_wait_for(
                    requested_stage, context=context, max_interval=2
                )

        return super(TransferSessionViewSet, self).update(request, *args, **kwargs)

    def get_queryset(self):
        """
        :return: the transfer sessions that are still active
        """
        return TransferSession.objects.filter(active=True)

    def async_allowed(self):
        """
        :return: A boolean if async ops are allowed by client and self
        """
        client_capabilities = parse_capabilities_from_server_request(self.request)
        client_supports = ASYNC_OPERATIONS in client_capabilities
        server_supports = ASYNC_OPERATIONS in CAPABILITIES
        return client_supports and server_supports
[docs]
class BufferViewSet(mixins.ListModelMixin, viewsets.GenericViewSet):
    """Receives pushed buffer records and lists the buffers of a transfer session."""

    permission_classes = (permissions.BufferPermissions,)
    serializer_class = serializers.BufferSerializer
    pagination_class = pagination.LimitOffsetPagination
    parser_classes = parsers

    def create(self, request):
        """
        Accept one record or a list of records pushed for a transfer session.

        :param request: the incoming request
        :return: 201 on success; 400 when a record lacks ``transfer_session``,
            the payload is empty or the transfer session is unknown; 403 when
            the session does not allow pushes or the records span several
            sessions; 500 on an error with no recorded exception
        :raises: the context's recorded error when the transfer stage errored
        """
        data = request.data if isinstance(request.data, list) else [request.data]

        # an empty payload or a record without a usable transfer session id is
        # a client error, not a server failure
        try:
            session_ids = {rec["transfer_session"] for rec in data}
            first_session_id = data[0]["transfer_session"]
        except (IndexError, KeyError, TypeError):
            return response.Response(
                "Each pushed record must specify a TransferSession.",
                status=status.HTTP_400_BAD_REQUEST,
            )

        # ensure the transfer session allows pushes, and is same across records
        try:
            transfer_session = TransferSession.objects.get(id=first_session_id)
        except TransferSession.DoesNotExist:
            return response.Response(
                "Specified TransferSession does not exist.",
                status=status.HTTP_400_BAD_REQUEST,
            )
        if not transfer_session.push:
            return response.Response(
                "Specified TransferSession does not allow pushes.",
                status=status.HTTP_403_FORBIDDEN,
            )
        if len(session_ids) > 1:
            return response.Response(
                "All pushed records must be associated with the same TransferSession.",
                status=status.HTTP_403_FORBIDDEN,
            )

        context = LocalSessionContext.from_request(
            request, transfer_session=transfer_session
        )
        result = session_controller.proceed_to(
            transfer_stages.TRANSFERRING, context=context
        )

        if result == transfer_statuses.ERRORED:
            # surface the recorded exception when there is one
            if context.error:
                raise context.error
            return response.Response(status=status.HTTP_500_INTERNAL_SERVER_ERROR)

        return response.Response(status=status.HTTP_201_CREATED)

    def get_queryset(self):
        """
        :return: the buffers of the transfer session named by the
            ``transfer_session_id`` query parameter, ordered by primary key
        """
        session_id = self.request.query_params["transfer_session_id"]
        return Buffer.objects.filter(transfer_session_id=session_id).order_by("pk")
[docs]
class MorangoInfoViewSet(viewsets.ViewSet):
    """Reports information about this morango instance."""

    def retrieve(self, request, pk=None):
        """
        Return the instance's custom info together with its identifiers, OS,
        morango version and capabilities.

        :param request: the incoming request
        :param pk: unused
        :return: a response holding the info dictionary
        """
        instance = InstanceIDModel.get_or_create_current_instance()[0]

        # start from a copy of the custom instance info, then layer the
        # standard fields on top
        info = instance.instance_info.copy()
        info["instance_hash"] = instance.get_proquint()
        info["instance_id"] = instance.id
        info["system_os"] = platform.system()
        info["version"] = morango.__version__
        info["capabilities"] = CAPABILITIES

        return response.Response(info)
[docs]
class PublicKeyViewSet(viewsets.ReadOnlyModelViewSet):
    """Read-only access to the shared keys flagged as current."""

    serializer_class = serializers.SharedKeySerializer
    permission_classes = (permissions.CertificatePushPermissions,)

    def get_queryset(self):
        """
        :return: the shared keys whose ``current`` flag is set
        """
        current_keys = SharedKey.objects.filter(current=True)
        return current_keys