Source code for morango.models.certificates

"""
``Certificate`` objects are the core of the authentication system which allows the synchronization of data with varying permissions.
Each certificate has a ``private_key`` used for signing (child) certificates (thus giving certain permissions)
and a ``public_key`` used for verifying that a certificate (or certificates) was properly signed.
"""
import json
import string

import mptt.models
from django.core.management import call_command
from django.db import models
from django.db import transaction
from django.utils import timezone

from .fields.crypto import Key
from .fields.crypto import PrivateKeyField
from .fields.crypto import PublicKeyField
from .fields.uuids import UUIDModelMixin
from morango.errors import CertificateIDInvalid
from morango.errors import CertificateProfileInvalid
from morango.errors import CertificateRootScopeInvalid
from morango.errors import CertificateScopeNotSubset
from morango.errors import CertificateSignatureInvalid
from morango.errors import NonceDoesNotExist
from morango.errors import NonceExpired
from morango.utils import _assert


class Certificate(mptt.models.MPTTModel, UUIDModelMixin):
    """
    A signed certificate granting the permissions described by a scope definition.

    Certificates form a tree via ``parent``: a certificate without a parent is a
    self-signed root, and every other certificate is signed with the private key
    of its parent. The public key is always stored; the private key is only
    present for certificates we own.
    """

    uuid_input_fields = ("public_key", "profile", "salt")

    parent = models.ForeignKey(
        "Certificate", blank=True, null=True, on_delete=models.CASCADE
    )

    # the Morango profile with which this certificate is associated
    profile = models.CharField(max_length=20)

    # scope of this certificate, and version of the scope, along with associated params
    scope_definition = models.ForeignKey("ScopeDefinition", on_delete=models.CASCADE)
    scope_version = models.IntegerField()
    scope_params = (
        models.TextField()
    )  # JSON dict of values to insert into scope definitions

    # track the certificate's public key so we can verify any certificates it signs
    public_key = PublicKeyField()

    # a salt value to include in the UUID calculation, to prevent CSR requests from forcing ID collisions
    salt = models.CharField(max_length=32, blank=True)

    # the JSON-serialized copy of all the fields above
    serialized = models.TextField()

    # signature from the private key of the parent certificate, of the "serialized" field text
    signature = models.TextField()

    # when we own a certificate, we'll have the private key for it (otherwise not)
    _private_key = PrivateKeyField(blank=True, null=True, db_column="private_key")

    @property
    def private_key(self):
        """Return the private key for this certificate, or ``None`` if we don't hold it."""
        return self._private_key

    @private_key.setter
    def private_key(self, value):
        """Store the private key and, if no public key is set yet, derive it from the private key."""
        self._private_key = value
        if value and not self.public_key:
            self.public_key = Key(
                public_key_string=self._private_key.get_public_key_string()
            )

    @classmethod
    def generate_root_certificate(cls, scope_def_id, **extra_scope_params):
        """
        Create, self-sign, and save a new root certificate.

        :param scope_def_id: ID of the ``ScopeDefinition`` the certificate is issued for
        :param extra_scope_params: additional scope params, merged on top of the
            primary partition param (so they may override it)
        :return: the saved ``Certificate`` instance
        :raises: whatever ``_assert`` raises when the scope definition has no
            ``primary_scope_param_key``
        """
        # attempt to retrieve the requested scope definition object
        scope_def = ScopeDefinition.retrieve_by_id(scope_def_id)

        # create a certificate model instance
        cert = cls()

        # set the scope definition foreign key, and read some values off of the scope definition model
        cert.scope_definition = scope_def
        cert.scope_version = scope_def.version
        cert.profile = scope_def.profile
        primary_scope_param_key = scope_def.primary_scope_param_key
        _assert(
            primary_scope_param_key,
            "Root cert can only be created for ScopeDefinition that has primary_scope_param_key defined",
        )

        # generate a key and extract the public key component
        cert.private_key = Key()
        cert.public_key = Key(
            public_key_string=cert.private_key.get_public_key_string()
        )

        # calculate the certificate's ID on the basis of the profile and public key
        cert.id = cert.calculate_uuid()

        # set the scope params to include the primary partition value and any additional params
        scope_params = {primary_scope_param_key: cert.id}
        scope_params.update(extra_scope_params)
        cert.scope_params = json.dumps(scope_params)

        # self-sign the certificate
        cert.sign_certificate(cert)

        # save and return the certificate
        cert.save()
        return cert

    def has_private_key(self):
        """Return ``True`` if a private key is stored for this certificate."""
        return self._private_key is not None

    def serialize(self):
        """
        Return a JSON string of the certificate's identifying fields.

        Calculates and assigns ``self.id`` first if it is not yet set.
        """
        if not self.id:
            self.id = self.calculate_uuid()

        data = {
            "id": self.id,
            "parent_id": self.parent_id,
            "profile": self.profile,
            "salt": self.salt,
            "scope_definition_id": self.scope_definition_id,
            "scope_version": self.scope_version,
            "scope_params": self.scope_params,
            "public_key_string": self.public_key.get_public_key_string(),
        }
        return json.dumps(data)

    @classmethod
    def deserialize(cls, serialized, signature):
        """
        Build an unsaved certificate instance from its serialized JSON and signature.

        :param serialized: JSON string as produced by ``serialize``
        :param signature: signature to store alongside the serialized text
        :return: an in-memory (unsaved, unverified) ``Certificate``
        """
        data = json.loads(serialized)
        model = cls(
            id=data["id"],
            parent_id=data["parent_id"],
            profile=data["profile"],
            # "salt" is optional in the serialized data; fall back to an empty string
            salt=data.get("salt") or "",
            scope_definition_id=data["scope_definition_id"],
            scope_version=data["scope_version"],
            scope_params=data["scope_params"],
            public_key=Key(public_key_string=data["public_key_string"]),
            serialized=serialized,
            signature=signature,
        )
        return model

    def _serialize_if_needed(self):
        """Populate ``self.serialized`` if it is empty; an existing value is never regenerated."""
        if not self.serialized:
            self.serialized = self.serialize()

    def sign_certificate(self, cert_to_sign):
        """Sign ``cert_to_sign`` with this certificate's private key, storing the signature on it."""
        cert_to_sign._serialize_if_needed()
        cert_to_sign.signature = self.sign(cert_to_sign.serialized)

    def check_certificate(self):
        """
        Validate this certificate's ID, signature, scope, and (for children) profile.

        :raises CertificateIDInvalid: the stored ID differs from the calculated one
        :raises CertificateSignatureInvalid: the signature does not verify against
            the signing key (own key for roots, parent's key otherwise)
        :raises CertificateRootScopeInvalid: a root certificate's scope entry does
            not start with the certificate's own ID
        :raises CertificateScopeNotSubset: a child's scope is not a subset of its parent's
        :raises CertificateProfileInvalid: a child's profile differs from its parent's
        """
        # check that the certificate's ID is properly calculated
        # (computed once and reused for both the comparison and the error message)
        expected_id = self.calculate_uuid()
        if self.id != expected_id:
            raise CertificateIDInvalid(
                "Certificate ID is {} but should be {}".format(self.id, expected_id)
            )

        if not self.parent:  # self-signed root certificate
            # check that the certificate is properly self-signed
            if not self.verify(self.serialized, self.signature):
                raise CertificateSignatureInvalid()
            # check that the certificate scopes all start with the primary partition value
            scope = self.get_scope()
            for item in scope.read_filter + scope.write_filter:
                if not item.startswith(self.id):
                    raise CertificateRootScopeInvalid(
                        "Scope entry {} does not start with primary partition {}".format(
                            item, self.id
                        )
                    )
        else:  # non-root child certificate
            # check that the certificate is properly signed by its parent
            if not self.parent.verify(self.serialized, self.signature):
                raise CertificateSignatureInvalid()
            # check that certificate's scope is a subset of parent's scope
            if not self.get_scope().is_subset_of(self.parent.get_scope()):
                raise CertificateScopeNotSubset()
            # check that certificate is for same profile as parent
            if self.profile != self.parent.profile:
                raise CertificateProfileInvalid(
                    "Certificate profile is {} but parent's is {}".format(
                        self.profile, self.parent.profile
                    )
                )

    @classmethod
    def save_certificate_chain(cls, cert_chain, expected_last_id=None):
        """
        Verify and save a chain of certificates, ordered from root to leaf.

        :param cert_chain: a list of dicts with ``"id"``, ``"serialized"`` and
            ``"signature"`` keys, or a JSON string encoding such a list
        :param expected_last_id: if given, the ID the last certificate in the
            chain must have
        :return: the ``Certificate`` for the last entry in the chain (either the
            already-stored one or the newly saved one)
        """
        # parse the chain from json if needed
        if isinstance(cert_chain, str):
            cert_chain = json.loads(cert_chain)

        # start from the bottom of the chain
        cert_data = cert_chain[-1]

        # create an in-memory instance of the cert from the serialized data and signature
        cert = cls.deserialize(cert_data["serialized"], cert_data["signature"])

        # verify the id of the cert matches the id of the outer serialized data
        _assert(cert_data["id"] == cert.id, "Serialized ID does not match")

        # check that the expected ID matches, if specified
        if expected_last_id:
            _assert(cert.id == expected_last_id, "ID does not match expected value")

        # if cert already exists locally, it's already been verified, so no need to continue
        # (this also means we have the full cert chain for it, given the `parent` relations)
        try:
            return cls.objects.get(id=cert.id)
        except cls.DoesNotExist:
            pass

        # recurse up the certificate chain, until we hit a cert that exists or is the root
        if len(cert_chain) > 1:
            cls.save_certificate_chain(cert_chain[:-1], expected_last_id=cert.parent_id)
        else:
            _assert(
                not cert.parent_id,
                "First cert in chain must be a root cert (no parent)",
            )

        # ensure the certificate checks out (now that we know its parent, if any, is saved)
        cert.check_certificate()

        # save the certificate, as it's now fully verified
        cert.save()

        return cert

    def sign(self, value):
        """Return the signature of ``value`` made with this certificate's private key."""
        _assert(
            self.private_key, "Can only sign using certificates that have private keys"
        )
        return self.private_key.sign(value)

    def verify(self, value, signature):
        """Return the result of verifying ``signature`` for ``value`` against the public key."""
        return self.public_key.verify(value, signature)

    def get_scope(self):
        """Return the scope obtained by applying ``scope_params`` to the scope definition."""
        return self.scope_definition.get_scope(self.scope_params)

    def __str__(self):
        """Return the scope description, falling back to the default model representation."""
        if self.scope_definition:
            return self.scope_definition.get_description(self.scope_params)
        # Falling off the end would return None, which makes str() raise TypeError;
        # defer to the inherited representation instead.
        return super(Certificate, self).__str__()
class Nonce(UUIDModelMixin):
    """
    Stores temporary nonce values used for cryptographic handshakes during syncing.
    These nonces are requested by the client, and then generated and stored by the server.
    When the client then goes to initiate a sync session, it signs the nonce value using
    the private key from the certificate it is using for the session, to prove to the
    server that it owns the certificate. The server checks that the nonce exists and hasn't
    expired, and then deletes it.
    """

    uuid_input_fields = "RANDOM"

    timestamp = models.DateTimeField(default=timezone.now)
    ip = models.CharField(max_length=100, blank=True)

    @classmethod
    def use_nonce(cls, nonce_value):
        """
        Consume the nonce with the given ID, deleting it whether or not it is still valid.

        :param nonce_value: ID of the nonce to consume
        :raises NonceDoesNotExist: no nonce with that ID is stored
        :raises NonceExpired: the nonce's age is not strictly between 0 and 60 seconds
        """
        with transaction.atomic():
            # try fetching the nonce
            try:
                nonce = cls.objects.get(id=nonce_value)
            except cls.DoesNotExist:
                raise NonceDoesNotExist()

            # check that the nonce hasn't expired
            age_in_seconds = (timezone.now() - nonce.timestamp).total_seconds()
            expired = not (0 < age_in_seconds < 60)

            # a nonce is single-use: delete it whether it was valid or expired
            nonce.delete()

        # Raise only after the atomic block has exited cleanly: an exception
        # propagating out of the block would roll back the delete above and
        # leave the expired nonce in the database.
        if expired:
            raise NonceExpired()
class ScopeDefinition(models.Model):
    """
    A named, versioned set of filter templates that certificates refer to.

    The templates may contain ``${name}`` placeholders that are filled in from
    a certificate's scope params.
    """

    # the identifier used to specify this scope within a certificate
    id = models.CharField(primary_key=True, max_length=20)

    # the Morango profile with which this scope is associated
    profile = models.CharField(max_length=20)

    # version number is incremented whenever scope definition is updated
    version = models.IntegerField()

    # the scope_param key that the primary partition value will be inserted into when generating a root cert
    # (if this is not set, then this scope definition cannot be used to generate a root cert)
    primary_scope_param_key = models.CharField(max_length=20, blank=True)

    # human-readable description
    # (can include string template refs to scope params e.g. "Allows syncing data for user ${username}")
    description = models.TextField()

    # filter templates, in the form of a newline-delimited list of colon-delimited partition strings
    # (can include string template refs to scope params e.g. "122211:singleuser:${user_id}")
    read_filter_template = models.TextField()
    write_filter_template = models.TextField()
    read_write_filter_template = models.TextField()

    @classmethod
    def retrieve_by_id(cls, scope_def_id):
        """
        Fetch the scope definition with the given ID.

        If it is not found, the ``scopedefinitions`` fixture is loaded and the
        lookup is retried once; a second miss propagates ``DoesNotExist``.
        """
        try:
            scope_def = cls.objects.get(id=scope_def_id)
        except ScopeDefinition.DoesNotExist:
            call_command("loaddata", "scopedefinitions")
            scope_def = cls.objects.get(id=scope_def_id)
        return scope_def

    def get_scope(self, params):
        """Build the ``Scope`` for this definition using the given scope params."""
        return Scope(params=params, definition=self)

    def get_description(self, params):
        """
        Render the human-readable description with the given scope params.

        ``params`` may be a mapping or a JSON string encoding one; placeholders
        without a matching param are left in place.
        """
        substitutions = json.loads(params) if isinstance(params, str) else params
        description_template = string.Template(self.description)
        return description_template.safe_substitute(substitutions)
class Filter(object):
    """
    A set of partition prefixes rendered from a whitespace-delimited template.

    The template is rendered with ``string.Template.safe_substitute`` (so unknown
    placeholders are left untouched) and then split on whitespace. A filter with
    no entries is represented by the single empty prefix ``("",)``.
    """

    def __init__(self, template, params=None):
        """
        :param template: whitespace-delimited partition prefixes, optionally
            containing ``${name}`` placeholders
        :param params: mapping of placeholder values, or a JSON string encoding
            one; ``None`` (the default) means no substitutions
        """
        if params is None:
            # avoid sharing one mutable default dict between instances
            params = {}
        elif isinstance(params, str):
            # ensure params have been deserialized
            params = json.loads(params)
        self._template = template
        self._params = params
        self._filter_string = string.Template(template).safe_substitute(params)
        self._filter_tuple = tuple(self._filter_string.split()) or ("",)

    def is_subset_of(self, other):
        """Return ``True`` if every entry of this filter starts with some entry of ``other``."""
        return all(
            partition.startswith(other._filter_tuple)
            for partition in self._filter_tuple
        )

    def contains_partition(self, partition):
        """Return ``True`` if ``partition`` starts with any entry of this filter."""
        return partition.startswith(self._filter_tuple)

    def __le__(self, other):
        return self.is_subset_of(other)

    def __eq__(self, other):
        """Two filters are equal when they hold the same set of entries, in any order."""
        if not isinstance(other, Filter):
            # let Python fall back to its default comparison (False for `==`)
            # instead of raising AttributeError on a missing `_filter_tuple`
            return NotImplemented
        return set(self._filter_tuple) == set(other._filter_tuple)

    def __hash__(self):
        # defining __eq__ alone would make instances unhashable; hash on the same
        # order-insensitive set of entries that __eq__ compares
        return hash(frozenset(self._filter_tuple))

    def __contains__(self, partition):
        return self.contains_partition(partition)

    def __add__(self, other):
        """Return a new filter holding the entries of both operands."""
        return Filter(self._filter_string + "\n" + other._filter_string)

    def __iter__(self):
        return iter(self._filter_tuple)

    def __str__(self):
        return "\n".join(self._filter_tuple)

    def __len__(self):
        return len(self._filter_tuple)
class Scope(object):
    """
    The read and write filters obtained by rendering a scope definition with params.

    Entries from the definition's read-write template are included in both the
    read filter and the write filter.
    """

    def __init__(self, definition, params):
        """
        :param definition: object providing ``read_filter_template``,
            ``write_filter_template`` and ``read_write_filter_template``
        :param params: placeholder values passed through to each ``Filter``
        """
        # turn the scope definition filter templates into Filter objects
        rw_filter = Filter(definition.read_write_filter_template, params)
        self.read_filter = rw_filter + Filter(definition.read_filter_template, params)
        self.write_filter = rw_filter + Filter(definition.write_filter_template, params)

    def is_subset_of(self, other):
        """Return ``True`` if both the read and write filters are subsets of ``other``'s."""
        return self.read_filter.is_subset_of(
            other.read_filter
        ) and self.write_filter.is_subset_of(other.write_filter)

    def __le__(self, other):
        return self.is_subset_of(other)

    def __eq__(self, other):
        """Two scopes are equal when their read filters and write filters are equal."""
        if not isinstance(other, Scope):
            # let Python fall back to its default comparison (False for `==`)
            # instead of raising AttributeError on a missing `read_filter`
            return NotImplemented
        return (
            self.read_filter == other.read_filter
            and self.write_filter == other.write_filter
        )

    # equality is value-based and the filters are plain mutable attributes, so
    # instances stay unhashable (this is also what defining __eq__ alone implies)
    __hash__ = None