# Source code for morango.models.certificates

"""
``Certificate`` objects are the core of the authentication system which allows the synchronization of data with varying permissions.
Each certificate has a ``private_key`` used for signing (child) certificates (thus giving certain permissions)
and a ``public_key`` used for verifying that certificates were properly signed.
"""

import json
import logging
import string
from contextlib import contextmanager

import mptt.models
from django.core.management import call_command
from django.db import connection, models, transaction
from django.db.utils import OperationalError
from django.utils import timezone

from morango.errors import (
    CertificateIDInvalid,
    CertificateProfileInvalid,
    CertificateRootScopeInvalid,
    CertificateScopeNotSubset,
    CertificateSignatureInvalid,
    NonceDoesNotExist,
    NonceExpired,
)
from morango.sync.backends.utils import load_backend
from morango.utils import _assert

from .fields.crypto import Key, PrivateKeyField, PublicKeyField
from .fields.uuids import UUIDModelMixin


class Certificate(mptt.models.MPTTModel, UUIDModelMixin):
    """
    A node in a tree of signed certificates.

    A certificate with no ``parent`` is a self-signed root; every other
    certificate is signed by its parent's private key. Each certificate
    carries a scope (definition + params), and ``check_certificate`` enforces
    that a child's scope is a subset of its parent's.
    """

    # fields that feed into the ID computed by ``calculate_uuid()``
    uuid_input_fields = ("public_key", "profile", "salt")

    parent = models.ForeignKey("Certificate", blank=True, null=True, on_delete=models.CASCADE)

    # the Morango profile with which this certificate is associated
    profile = models.CharField(max_length=20)

    # scope of this certificate, and version of the scope, along with associated params
    scope_definition = models.ForeignKey("ScopeDefinition", on_delete=models.CASCADE)
    scope_version = models.IntegerField()
    scope_params = models.TextField()  # JSON dict of values to insert into scope definitions

    # track the certificate's public key so we can verify any certificates it signs
    public_key = PublicKeyField()

    # a salt value to include in the UUID calculation, to prevent CSR requests from forcing ID collisions
    salt = models.CharField(max_length=32, blank=True)

    # the JSON-serialized copy of all the fields above
    serialized = models.TextField()

    # signature from the private key of the parent certificate, of the "serialized" field text
    signature = models.TextField()

    # when we own a certificate, we'll have the private key for it (otherwise not)
    _private_key = PrivateKeyField(blank=True, null=True, db_column="private_key")

    @property
    def private_key(self):
        """The private key for this certificate, or ``None`` if we don't own it."""
        return self._private_key

    @private_key.setter
    def private_key(self, value):
        """Store the private key and, if no public key is set yet, derive it from the private key."""
        self._private_key = value
        if value and not self.public_key:
            self.public_key = Key(public_key_string=self._private_key.get_public_key_string())

    @classmethod
    def generate_root_certificate(cls, scope_def_id, **extra_scope_params):
        """
        Create, self-sign, save and return a new root certificate.

        :param scope_def_id: ID of the ``ScopeDefinition`` to base the certificate on;
            it must define ``primary_scope_param_key``
        :param extra_scope_params: additional scope params merged on top of the
            primary partition param
        :return: the saved root certificate
        :rtype: Certificate
        """
        # attempt to retrieve the requested scope definition object
        scope_def = ScopeDefinition.retrieve_by_id(scope_def_id)

        # create a certificate model instance
        cert = cls()

        # set the scope definition foreign key, and read some values off of the scope definition model
        cert.scope_definition = scope_def
        cert.scope_version = scope_def.version
        cert.profile = scope_def.profile
        primary_scope_param_key = scope_def.primary_scope_param_key
        _assert(
            primary_scope_param_key,
            "Root cert can only be created for ScopeDefinition that has primary_scope_param_key defined",
        )

        # generate a key and extract the public key component
        cert.private_key = Key()
        cert.public_key = Key(public_key_string=cert.private_key.get_public_key_string())

        # calculate the certificate's ID on the basis of the profile and public key
        cert.id = cert.calculate_uuid()

        # set the scope params to include the primary partition value and any additional params
        scope_params = {primary_scope_param_key: cert.id}
        scope_params.update(extra_scope_params)
        cert.scope_params = json.dumps(scope_params)

        # self-sign the certificate
        cert.sign_certificate(cert)

        # save and return the certificate
        cert.save()
        return cert

    def has_private_key(self):
        """Return ``True`` if a private key is stored for this certificate."""
        return self._private_key is not None

    def serialize(self):
        """
        Return the JSON string that gets signed for this certificate.

        Assigns ``self.id`` from ``calculate_uuid()`` first if it is not set yet.

        :rtype: str
        """
        if not self.id:
            self.id = self.calculate_uuid()

        data = {
            "id": self.id,
            "parent_id": self.parent_id,
            "profile": self.profile,
            "salt": self.salt,
            "scope_definition_id": self.scope_definition_id,
            "scope_version": self.scope_version,
            "scope_params": self.scope_params,
            "public_key_string": self.public_key.get_public_key_string(),
        }
        return json.dumps(data)

    @classmethod
    def deserialize(cls, serialized, signature):
        """
        Build an unsaved, unverified certificate from its serialized JSON and signature.

        :param serialized: JSON string as produced by ``serialize()``
        :type serialized: str
        :param signature: the signature of ``serialized``
        :return: an in-memory certificate instance (not saved, not checked)
        :rtype: Certificate
        """
        data = json.loads(serialized)
        model = cls(
            id=data["id"],
            parent_id=data["parent_id"],
            profile=data["profile"],
            # a missing or null salt is normalized to the empty string
            salt=data.get("salt") or "",
            scope_definition_id=data["scope_definition_id"],
            scope_version=data["scope_version"],
            scope_params=data["scope_params"],
            public_key=Key(public_key_string=data["public_key_string"]),
            serialized=serialized,
            signature=signature,
        )
        return model

    def _serialize_if_needed(self):
        """Populate ``self.serialized`` if it is still empty."""
        if not self.serialized:
            self.serialized = self.serialize()

    def sign_certificate(self, cert_to_sign):
        """
        Sign ``cert_to_sign`` with this certificate's private key.

        Serializes ``cert_to_sign`` first if needed, then stores the resulting
        signature on ``cert_to_sign.signature``.
        """
        cert_to_sign._serialize_if_needed()
        cert_to_sign.signature = self.sign(cert_to_sign.serialized)

    def check_certificate(self):
        """
        Validate this certificate's ID, signature, scope and profile.

        :raises CertificateIDInvalid: the ID does not match ``calculate_uuid()``
        :raises CertificateSignatureInvalid: the signature does not verify
        :raises CertificateRootScopeInvalid: a root cert's scope entry does not start with its ID
        :raises CertificateScopeNotSubset: a child's scope is not a subset of its parent's
        :raises CertificateProfileInvalid: a child's profile differs from its parent's
        """
        # check that the certificate's ID is properly calculated
        expected_id = self.calculate_uuid()
        if self.id != expected_id:
            raise CertificateIDInvalid(
                "Certificate ID is {} but should be {}".format(self.id, expected_id)
            )

        if not self.parent:  # self-signed root certificate
            # check that the certificate is properly self-signed
            if not self.verify(self.serialized, self.signature):
                raise CertificateSignatureInvalid()
            # check that the certificate scopes all start with the primary partition value
            scope = self.get_scope()
            for item in scope.read_filter + scope.write_filter:
                if not item.startswith(self.id):
                    raise CertificateRootScopeInvalid(
                        "Scope entry {} does not start with primary partition {}".format(
                            item, self.id
                        )
                    )
        else:  # non-root child certificate
            # check that the certificate is properly signed by its parent
            if not self.parent.verify(self.serialized, self.signature):
                raise CertificateSignatureInvalid()
            # check that certificate's scope is a subset of parent's scope
            if not self.get_scope().is_subset_of(self.parent.get_scope()):
                raise CertificateScopeNotSubset()
            # check that certificate is for same profile as parent
            if self.profile != self.parent.profile:
                raise CertificateProfileInvalid(
                    "Certificate profile is {} but parent's is {}".format(
                        self.profile, self.parent.profile
                    )
                )

    @classmethod
    def save_certificate_chain(cls, cert_chain, expected_last_id=None):
        """
        Verify and save a chain of certificates, ordered root first.

        :param cert_chain: list of dicts with ``id``, ``serialized`` and ``signature``
            keys, or a JSON string encoding such a list
        :type cert_chain: list|str
        :param expected_last_id: if given, the ID the last certificate in the chain must have
        :return: the certificate at the end of the chain (existing or newly saved)
        :rtype: Certificate
        """
        # parse the chain from json if needed
        if isinstance(cert_chain, str):
            cert_chain = json.loads(cert_chain)

        # start from the bottom of the chain
        cert_data = cert_chain[-1]

        # create an in-memory instance of the cert from the serialized data and signature
        cert = cls.deserialize(cert_data["serialized"], cert_data["signature"])

        # verify the id of the cert matches the id of the outer serialized data
        _assert(cert_data["id"] == cert.id, "Serialized ID does not match")

        # check that the expected ID matches, if specified
        if expected_last_id:
            _assert(cert.id == expected_last_id, "ID does not match expected value")

        # if cert already exists locally, it's already been verified, so no need to continue
        # (this also means we have the full cert chain for it, given the `parent` relations)
        try:
            return cls.objects.get(id=cert.id)
        except cls.DoesNotExist:
            pass

        # recurse up the certificate chain, until we hit a cert that exists or is the root
        if len(cert_chain) > 1:
            cls.save_certificate_chain(cert_chain[:-1], expected_last_id=cert.parent_id)
        else:
            _assert(
                not cert.parent_id,
                "First cert in chain must be a root cert (no parent)",
            )

        # ensure the certificate checks out (now that we know its parent, if any, is saved)
        cert.check_certificate()

        # save the certificate, as it's now fully verified
        cert.save()

        return cert

    def sign(self, value):
        """Sign ``value`` with this certificate's private key (which must be present)."""
        _assert(self.private_key, "Can only sign using certificates that have private keys")
        return self.private_key.sign(value)

    def verify(self, value, signature):
        """Return the result of verifying ``signature`` over ``value`` with this certificate's public key."""
        return self.public_key.verify(value, signature)

    def get_scope(self):
        """Return the ``Scope`` built from this certificate's scope definition and params."""
        return self.scope_definition.get_scope(self.scope_params)

    @contextmanager
    def _attempt_lock_mptt(self):
        """
        Open a transaction and lock the partitions of this certificate's tree root.

        The managed body runs inside the transaction, after the lock is taken.
        """
        # imported locally rather than at module level -- presumably to avoid a
        # circular import; confirm before moving it
        from morango.sync.utils import lock_partitions

        DBBackend = load_backend(connection)

        with transaction.atomic():
            # Call get_root on the parent as it is already saved in the DB
            root_id = self.parent.get_root().id if self.parent else self.id
            # lock the partitions in our scope to prevent MPTT tree corruption during concurrent certificate creation
            lock_partitions(DBBackend, sync_filter=Filter(root_id) if root_id else None)
            yield

    @contextmanager
    def _lock_mptt(self):
        """
        Like ``_attempt_lock_mptt``, but retry once if acquiring the lock deadlocks.

        The retry only applies while the lock is being acquired. Once the managed
        body has started, a ``@contextmanager`` generator must not yield again, so
        any error raised from the body is propagated unchanged.
        """
        body_started = False
        try:
            with self._attempt_lock_mptt():
                body_started = True
                yield
        except OperationalError as e:
            # str(e) rather than e.args[0]: args may be empty or start with a non-string
            if body_started or "deadlock detected" not in str(e):
                raise
            logging.error(
                "Deadlock detected when attempting to lock MPTT partitions, retrying once more"
            )
            with self._attempt_lock_mptt():
                yield

    def save(self, *args, **kwargs):
        """Save the certificate while holding the MPTT partition lock."""
        with self._lock_mptt():
            super().save(*args, **kwargs)

    def __str__(self):
        """
        Return the scope definition's description rendered with this certificate's params.

        Falls back to the inherited string form when no scope definition is set,
        so that ``str()`` always receives a string.
        """
        if self.scope_definition_id is not None:
            return self.scope_definition.get_description(self.scope_params)
        return super().__str__()
class Nonce(UUIDModelMixin):
    """
    Stores temporary nonce values used for cryptographic handshakes during syncing.
    These nonces are requested by the client, and then generated and stored by the server.
    When the client then goes to initiate a sync session, it signs the nonce value using
    the private key from the certificate it is using for the session, to prove to the
    server that it owns the certificate. The server checks that the nonce exists and hasn't
    expired, and then deletes it.
    """

    uuid_input_fields = "RANDOM"

    timestamp = models.DateTimeField(default=timezone.now)
    ip = models.CharField(max_length=100, blank=True)

    @classmethod
    def use_nonce(cls, nonce_value):
        """
        Consume the nonce with the given ID, deleting it whether or not it is still valid.

        :param nonce_value: the ID of the nonce to consume
        :raises NonceDoesNotExist: no nonce with that ID exists
        :raises NonceExpired: the nonce's age is not strictly between 0 and 60 seconds
        """
        with transaction.atomic():
            # try fetching the nonce
            try:
                nonce = cls.objects.get(id=nonce_value)
            except cls.DoesNotExist:
                raise NonceDoesNotExist()
            # check that the nonce hasn't expired
            age_seconds = (timezone.now() - nonce.timestamp).total_seconds()
            expired = not (0 < age_seconds < 60)
            # a nonce is single-use: delete it whether it was valid or expired
            nonce.delete()

        # raise only after the atomic block has exited; raising inside it would
        # roll back the delete above and leave the expired nonce in the database
        if expired:
            raise NonceExpired()
class ScopeDefinition(models.Model):
    """
    A named, versioned template describing which partitions a certificate may read and write.

    The filter templates and description may contain ``${param}`` references
    that are filled in from a certificate's scope params.
    """

    # identifier by which certificates refer to this scope
    id = models.CharField(primary_key=True, max_length=20)

    # Morango profile this scope belongs to
    profile = models.CharField(max_length=20)

    # bumped every time the scope definition changes
    version = models.IntegerField()

    # name of the scope param that receives the primary partition value when a root
    # cert is generated (left blank, this definition cannot produce root certs)
    primary_scope_param_key = models.CharField(max_length=20, blank=True)

    # human-readable text; may reference scope params,
    # e.g. "Allows syncing data for user ${username}"
    description = models.TextField()

    # newline-delimited lists of colon-delimited partition strings; may reference
    # scope params, e.g. "122211:singleuser:${user_id}"
    read_filter_template = models.TextField()
    write_filter_template = models.TextField()
    read_write_filter_template = models.TextField()

    @classmethod
    def retrieve_by_id(cls, scope_def_id):
        """
        Fetch the scope definition with the given ID.

        If it is not found, the ``scopedefinitions`` fixture is loaded and the
        lookup is attempted once more.
        """
        try:
            scope_def = cls.objects.get(id=scope_def_id)
        except ScopeDefinition.DoesNotExist:
            call_command("loaddata", "scopedefinitions")
            scope_def = cls.objects.get(id=scope_def_id)
        return scope_def

    def get_scope(self, params):
        """Build the ``Scope`` for this definition using the given params."""
        return Scope(definition=self, params=params)

    def get_description(self, params):
        """
        Render the description template with the given params.

        :param params: a dict of substitutions, or a JSON string encoding one
        :return: the description with known ``${param}`` references replaced
        :rtype: str
        """
        substitutions = json.loads(params) if isinstance(params, str) else params
        template = string.Template(self.description)
        return template.safe_substitute(substitutions)
class Filter(object):
    """
    An ordered collection of partition prefix strings.

    A partition "matches" the filter when it starts with any of the filter's
    entries. A filter built from an empty string holds the single entry ``""``.
    """

    def __init__(self, filter_str, params=None):
        """
        :param filter_str: The partition filter string; multiple partitions are
            separated by whitespace (typically newlines)
        :type filter_str: str
        :param params: DEPRECATED: USE Filter.from_template() INSTEAD
        :type params: dict|str
        """
        if params is not None:
            logging.warning(
                "DEPRECATED: Constructing a filter with a template and params is deprecated. Use Filter.from_template() instead"
            )
            filter_str = str(Filter.from_template(filter_str, params=params))
        # an empty filter string yields the single empty-string partition
        self._filter_tuple = tuple(filter_str.split()) or ("",)

    def is_subset_of(self, other):
        """
        :param other: The other Filter
        :type other: Filter
        :return: A boolean on whether this Filter is captured within the other Filter
        :rtype: bool
        """
        return all(other.contains_partition(partition) for partition in self)

    def contains_partition(self, partition):
        """Returns True if the partition starts with at least one of the partitions in this Filter"""
        return partition.startswith(self._filter_tuple)

    def contains_exact_partition(self, partition):
        """Returns True if the partition exactly matches one of the partitions in this Filter"""
        return partition in self._filter_tuple

    def copy(self):
        """Returns a new Filter with the same partitions as this one"""
        return Filter(str(self))

    def __le__(self, other):
        """Returns True if this Filter is a subset of the other"""
        return self.is_subset_of(other)

    def __eq__(self, other):
        """
        Returns True if this Filter has exactly the same partitions as the other,
        ignoring order and duplicates. Comparison with anything that is not a
        Filter (including None) is delegated to Python's default handling, which
        makes ``==`` evaluate to False instead of raising.
        """
        if not isinstance(other, Filter):
            return NotImplemented
        return all(other.contains_exact_partition(p) for p in self) and all(
            self.contains_exact_partition(p) for p in other
        )

    def __contains__(self, partition):
        """
        Performs a 'startswith' comparison on the partition, determining whether it matches or
        is a subset of any partition in this Filter

        :param partition: str
        :return: A boolean
        :rtype: bool
        """
        return self.contains_partition(partition)

    def __add__(self, other):
        """
        The Filter's addition operator overload

        :param other: Filter or None
        :type other: Filter|None
        :return: The combined Filter
        :rtype: Filter
        """
        if other is None:
            return self
        # keep this filter's non-empty partitions, then append those from `other`
        # that have not been seen yet
        partitions = [p for p in self if p]
        seen = set(partitions)
        for p in other:
            if p and p not in seen:
                seen.add(p)
                partitions.append(p)
        return Filter("\n".join(partitions))

    def __iter__(self):
        """
        :return: An iterator over the partition strings
        """
        return iter(self._filter_tuple)

    def __str__(self):
        return "\n".join(self._filter_tuple)

    def __len__(self):
        return len(self._filter_tuple)

    @classmethod
    def add(cls, filter_a, filter_b):
        """
        The Filter's addition operator overload is already defensive against None being the
        right-hand operand, but this method is defensive against None being the left-hand operand

        :param filter_a: A Filter or None
        :type filter_a: Filter|None
        :param filter_b: A Filter or None
        :type filter_b: Filter|None
        :return: The combined Filter or None
        :rtype: Filter|None
        """
        if filter_a is None:
            return filter_b
        return filter_a + filter_b

    @classmethod
    def from_template(cls, template, params=None):
        """
        Create a filter from a string template, which may have params that will be replaced
        with values passed to `params`

        :param template: The partition filter template
        :type template: str
        :param params: The param dictionary or JSON object string
        :type params: dict|str
        :return: The filter with params replaced
        :rtype: Filter
        """
        if isinstance(params, str):
            params = json.loads(params)
        params = params or {}
        return Filter(string.Template(template).safe_substitute(params))
class Scope(object):
    """
    The read and write filters granted by a scope definition for a set of params.

    ``read_filter`` and ``write_filter`` each combine the definition's
    read-write template with its read-only or write-only template.
    """

    def __init__(self, definition, params):
        """
        :param definition: object providing ``read_write_filter_template``,
            ``read_filter_template`` and ``write_filter_template``
        :param params: substitutions for the templates, as accepted by
            ``Filter.from_template``
        """
        # turn the scope definition filter templates into Filter objects
        rw_filter = Filter.from_template(definition.read_write_filter_template, params)
        self.read_filter = rw_filter + Filter.from_template(definition.read_filter_template, params)
        self.write_filter = rw_filter + Filter.from_template(
            definition.write_filter_template, params
        )

    def is_subset_of(self, other):
        """Returns True if both the read and write filters are subsets of the other scope's"""
        return self.read_filter.is_subset_of(
            other.read_filter
        ) and self.write_filter.is_subset_of(other.write_filter)

    def __le__(self, other):
        """Returns True if this Scope is a subset of the other"""
        return self.is_subset_of(other)

    def __eq__(self, other):
        """
        Returns True if both scopes have equal read filters and equal write filters.
        Comparison with anything that is not a Scope (including None) is delegated
        to Python's default handling, so ``==`` evaluates to False instead of raising.
        """
        if not isinstance(other, Scope):
            return NotImplemented
        return self.read_filter == other.read_filter and self.write_filter == other.write_filter