"""
``Certificate`` objects are the core of the authentication system which allows the synchronization of data with varying permissions.
Each certificate has a ``private_key`` used for signing (child) certificates (thus giving certain permissions)
and a ``public_key`` used for verifying that certificates were properly signed.
"""
import json
import logging
import string
from contextlib import contextmanager
import mptt.models
from django.core.management import call_command
from django.db import connection, models, transaction
from django.db.utils import OperationalError
from django.utils import timezone
from morango.errors import (
CertificateIDInvalid,
CertificateProfileInvalid,
CertificateRootScopeInvalid,
CertificateScopeNotSubset,
CertificateSignatureInvalid,
NonceDoesNotExist,
NonceExpired,
)
from morango.sync.backends.utils import load_backend
from morango.utils import _assert
from .fields.crypto import Key, PrivateKeyField, PublicKeyField
from .fields.uuids import UUIDModelMixin
[docs]
class Certificate(mptt.models.MPTTModel, UUIDModelMixin):
    """
    A certificate in the MPTT tree of certificates.

    A certificate carries a public key, optionally the matching private key (when it is owned
    locally), a scope (definition + version + params), and a signature over its ``serialized``
    form. Root certificates are self-signed; child certificates are verified against their
    parent's public key.
    """

    uuid_input_fields = ("public_key", "profile", "salt")

    parent = models.ForeignKey("Certificate", blank=True, null=True, on_delete=models.CASCADE)

    # the Morango profile with which this certificate is associated
    profile = models.CharField(max_length=20)

    # scope of this certificate, and version of the scope, along with associated params
    scope_definition = models.ForeignKey("ScopeDefinition", on_delete=models.CASCADE)
    scope_version = models.IntegerField()
    scope_params = models.TextField()  # JSON dict of values to insert into scope definitions

    # track the certificate's public key so we can verify any certificates it signs
    public_key = PublicKeyField()

    # a salt value to include in the UUID calculation, to prevent CSR requests from forcing ID collisions
    salt = models.CharField(max_length=32, blank=True)

    # the JSON-serialized copy of all the fields above
    serialized = models.TextField()

    # signature from the private key of the parent certificate, of the "serialized" field text
    signature = models.TextField()

    # when we own a certificate, we'll have the private key for it (otherwise not)
    _private_key = PrivateKeyField(blank=True, null=True, db_column="private_key")

    @property
    def private_key(self):
        """The private key of this certificate, or ``None`` when it is not held locally."""
        return self._private_key

    @private_key.setter
    def private_key(self, value):
        """Store the private key and, if no public key is set yet, derive it from the private key."""
        self._private_key = value
        if value and not self.public_key:
            self.public_key = Key(public_key_string=self._private_key.get_public_key_string())

    @classmethod
    def generate_root_certificate(cls, scope_def_id, **extra_scope_params):
        """
        Create, self-sign and save a new root certificate for the given scope definition.

        :param scope_def_id: ID of the ``ScopeDefinition`` to base the certificate on; it must
            define a ``primary_scope_param_key``
        :param extra_scope_params: Additional scope params merged over the primary partition param
        :return: The saved root certificate
        :rtype: Certificate
        """
        # attempt to retrieve the requested scope definition object
        scope_def = ScopeDefinition.retrieve_by_id(scope_def_id)

        # create a certificate model instance
        cert = cls()

        # set the scope definition foreign key, and read some values off of the scope definition model
        cert.scope_definition = scope_def
        cert.scope_version = scope_def.version
        cert.profile = scope_def.profile
        primary_scope_param_key = scope_def.primary_scope_param_key
        _assert(
            primary_scope_param_key,
            "Root cert can only be created for ScopeDefinition that has primary_scope_param_key defined",
        )

        # generate a key and extract the public key component
        cert.private_key = Key()
        cert.public_key = Key(public_key_string=cert.private_key.get_public_key_string())

        # calculate the certificate's ID on the basis of the profile and public key
        cert.id = cert.calculate_uuid()

        # set the scope params to include the primary partition value and any additional params
        scope_params = {primary_scope_param_key: cert.id}
        scope_params.update(extra_scope_params)
        cert.scope_params = json.dumps(scope_params)

        # self-sign the certificate
        cert.sign_certificate(cert)

        # save and return the certificate
        cert.save()

        return cert

    def has_private_key(self):
        """Return True when a private key is stored for this certificate."""
        return self._private_key is not None

    def serialize(self):
        """
        Return the JSON string of this certificate's identifying fields, calculating the ID first
        if it has not been set yet.

        The key order is significant to callers because the resulting text is what gets signed.
        """
        if not self.id:
            self.id = self.calculate_uuid()

        data = {
            "id": self.id,
            "parent_id": self.parent_id,
            "profile": self.profile,
            "salt": self.salt,
            "scope_definition_id": self.scope_definition_id,
            "scope_version": self.scope_version,
            "scope_params": self.scope_params,
            "public_key_string": self.public_key.get_public_key_string(),
        }
        return json.dumps(data)

    @classmethod
    def deserialize(cls, serialized, signature):
        """
        Build an unsaved, in-memory certificate from its serialized JSON text and signature.

        :param serialized: JSON text as produced by ``serialize()``
        :param signature: Signature of ``serialized``
        :return: The unsaved certificate; no verification is performed here
        :rtype: Certificate
        """
        data = json.loads(serialized)
        model = cls(
            id=data["id"],
            parent_id=data["parent_id"],
            profile=data["profile"],
            # "salt" may be absent or null in the serialized data; normalize to an empty string
            salt=data.get("salt") or "",
            scope_definition_id=data["scope_definition_id"],
            scope_version=data["scope_version"],
            scope_params=data["scope_params"],
            public_key=Key(public_key_string=data["public_key_string"]),
            serialized=serialized,
            signature=signature,
        )
        return model

    def _serialize_if_needed(self):
        """Populate ``serialized`` from ``serialize()`` unless it already has a value."""
        if not self.serialized:
            self.serialized = self.serialize()

    def sign_certificate(self, cert_to_sign):
        """
        Sign another certificate (or this one) with this certificate's private key.

        :param cert_to_sign: The certificate whose ``serialized`` text is signed; its
            ``signature`` attribute is set in place
        """
        cert_to_sign._serialize_if_needed()
        cert_to_sign.signature = self.sign(cert_to_sign.serialized)

    def check_certificate(self):
        """
        Validate this certificate's ID, signature, scope and profile.

        :raises CertificateIDInvalid: the ID does not match the calculated UUID
        :raises CertificateSignatureInvalid: the signature does not verify against this
            certificate (root) or its parent (child)
        :raises CertificateRootScopeInvalid: a root scope entry does not start with the cert ID
        :raises CertificateScopeNotSubset: a child's scope is not a subset of its parent's
        :raises CertificateProfileInvalid: a child's profile differs from its parent's
        """
        # check that the certificate's ID is properly calculated
        expected_id = self.calculate_uuid()
        if self.id != expected_id:
            raise CertificateIDInvalid(
                "Certificate ID is {} but should be {}".format(self.id, expected_id)
            )

        if not self.parent:  # self-signed root certificate
            # check that the certificate is properly self-signed
            if not self.verify(self.serialized, self.signature):
                raise CertificateSignatureInvalid()
            # check that the certificate scopes all start with the primary partition value
            scope = self.get_scope()
            for item in scope.read_filter + scope.write_filter:
                if not item.startswith(self.id):
                    raise CertificateRootScopeInvalid(
                        "Scope entry {} does not start with primary partition {}".format(
                            item, self.id
                        )
                    )
        else:  # non-root child certificate
            # check that the certificate is properly signed by its parent
            if not self.parent.verify(self.serialized, self.signature):
                raise CertificateSignatureInvalid()
            # check that certificate's scope is a subset of parent's scope
            if not self.get_scope().is_subset_of(self.parent.get_scope()):
                raise CertificateScopeNotSubset()
            # check that certificate is for same profile as parent
            if self.profile != self.parent.profile:
                raise CertificateProfileInvalid(
                    "Certificate profile is {} but parent's is {}".format(
                        self.profile, self.parent.profile
                    )
                )

    @classmethod
    def save_certificate_chain(cls, cert_chain, expected_last_id=None):
        """
        Verify and save a chain of certificates, ordered from the root to the leaf.

        :param cert_chain: A list of dicts with "id", "serialized" and "signature" keys, or the
            JSON text of such a list
        :param expected_last_id: When given, the ID the last certificate in the chain must have
        :return: The last certificate of the chain (the existing one if it was already stored)
        :rtype: Certificate
        """
        # parse the chain from json if needed
        if isinstance(cert_chain, str):
            cert_chain = json.loads(cert_chain)

        # start from the bottom of the chain
        cert_data = cert_chain[-1]

        # create an in-memory instance of the cert from the serialized data and signature
        cert = cls.deserialize(cert_data["serialized"], cert_data["signature"])

        # verify the id of the cert matches the id of the outer serialized data
        _assert(cert_data["id"] == cert.id, "Serialized ID does not match")

        # check that the expected ID matches, if specified
        if expected_last_id:
            _assert(cert.id == expected_last_id, "ID does not match expected value")

        # if cert already exists locally, it's already been verified, so no need to continue
        # (this also means we have the full cert chain for it, given the `parent` relations)
        try:
            return cls.objects.get(id=cert.id)
        except cls.DoesNotExist:
            pass

        # recurse up the certificate chain, until we hit a cert that exists or is the root
        if len(cert_chain) > 1:
            cls.save_certificate_chain(cert_chain[:-1], expected_last_id=cert.parent_id)
        else:
            _assert(
                not cert.parent_id,
                "First cert in chain must be a root cert (no parent)",
            )

        # ensure the certificate checks out (now that we know its parent, if any, is saved)
        cert.check_certificate()

        # save the certificate, as it's now fully verified
        cert.save()

        return cert

    def sign(self, value):
        """Sign ``value`` with this certificate's private key, which must be present."""
        _assert(self.private_key, "Can only sign using certificates that have private keys")
        return self.private_key.sign(value)

    def verify(self, value, signature):
        """Return the result of verifying ``signature`` over ``value`` with the public key."""
        return self.public_key.verify(value, signature)

    def get_scope(self):
        """Return the ``Scope`` built from the scope definition and this cert's scope params."""
        return self.scope_definition.get_scope(self.scope_params)

    @contextmanager
    def _attempt_lock_mptt(self):
        """Open a transaction and lock the partitions of this certificate's tree root."""
        # imported here rather than at module level (presumably to avoid a circular import)
        from morango.sync.utils import lock_partitions

        DBBackend = load_backend(connection)

        with transaction.atomic():
            # Call get_root on the parent as it is already saved in the DB
            root_id = self.parent.get_root().id if self.parent else self.id
            # lock the partitions in our scope to prevent MPTT tree corruption during concurrent certificate creation
            lock_partitions(DBBackend, sync_filter=Filter(root_id) if root_id else None)
            yield

    @contextmanager
    def _lock_mptt(self):
        """
        Acquire the MPTT lock, retrying the acquisition once if the database reports a deadlock.

        Errors raised by the managed block itself (or while committing it) are never retried and
        propagate unchanged.
        """
        body_started = False
        try:
            with self._attempt_lock_mptt():
                body_started = True
                yield
        except OperationalError as e:
            # A generator-based context manager may only yield once: if the managed block has
            # already run, yielding again would make contextlib raise a RuntimeError that hides
            # the real database error, so only failures to acquire the lock are retried.
            # `str(e)` is used because `e.args` may be empty or start with a non-string value.
            if body_started or "deadlock detected" not in str(e):
                raise
            logging.error(
                "Deadlock detected when attempting to lock MPTT partitions, retrying once more"
            )
            with self._attempt_lock_mptt():
                yield

    def save(self, *args, **kwargs):
        """Save the certificate while holding the MPTT partition lock."""
        with self._lock_mptt():
            super().save(*args, **kwargs)

    def __str__(self):
        """Return the scope description, or the default model string without a scope definition."""
        # check the raw FK value so that an unset relation is not dereferenced
        if self.scope_definition_id:
            return self.scope_definition.get_description(self.scope_params)
        # `__str__` must always return a string; falling through would return None
        return super().__str__()
[docs]
class Nonce(UUIDModelMixin):
    """
    Stores temporary nonce values used for cryptographic handshakes during syncing.
    These nonces are requested by the client, and then generated and stored by the server.
    When the client then goes to initiate a sync session, it signs the nonce value using
    the private key from the certificate it is using for the session, to prove to the
    server that it owns the certificate. The server checks that the nonce exists and hasn't
    expired, and then deletes it.
    """

    uuid_input_fields = "RANDOM"

    timestamp = models.DateTimeField(default=timezone.now)
    ip = models.CharField(max_length=100, blank=True)

    @classmethod
    def use_nonce(cls, nonce_value):
        """
        Consume a nonce: look it up, delete it, and fail if it was missing or expired.

        A nonce is valid only while its age is strictly between 0 and 60 seconds.

        :param nonce_value: The ID of the nonce to consume
        :raises NonceDoesNotExist: no nonce with that ID is stored
        :raises NonceExpired: the nonce exists but is outside its validity window
        """
        with transaction.atomic():
            # try fetching the nonce
            try:
                nonce = cls.objects.get(id=nonce_value)
            except cls.DoesNotExist:
                raise NonceDoesNotExist()
            # check that the nonce hasn't expired
            age_in_seconds = (timezone.now() - nonce.timestamp).total_seconds()
            is_expired = not (0 < age_in_seconds < 60)
            # the nonce is single-use, so it is deleted whether or not it is still valid
            nonce.delete()
        # Raise only after the atomic block has exited: an exception leaving the block would
        # roll the deletion back and leave the expired nonce in the database.
        # NOTE(review): an enclosing transaction opened by the caller can still roll this back.
        if is_expired:
            raise NonceExpired()
[docs]
class ScopeDefinition(models.Model):
    """
    A named, versioned template from which certificate scopes are built by substituting
    scope params into the filter templates.
    """

    # the identifier used to specify this scope within a certificate
    id = models.CharField(primary_key=True, max_length=20)

    # the Morango profile with which this scope is associated
    profile = models.CharField(max_length=20)

    # version number is incremented whenever scope definition is updated
    version = models.IntegerField()

    # the scope_param key that the primary partition value will be inserted into when generating a root cert
    # (if this is not set, then this scope definition cannot be used to generate a root cert)
    primary_scope_param_key = models.CharField(max_length=20, blank=True)

    # human-readable description
    # (can include string template refs to scope params e.g. "Allows syncing data for user ${username}")
    description = models.TextField()

    # filter templates, in the form of a newline-delimited list of colon-delimited partition strings
    # (can include string template refs to scope params e.g. "122211:singleuser:${user_id}")
    read_filter_template = models.TextField()
    write_filter_template = models.TextField()
    read_write_filter_template = models.TextField()

    @classmethod
    def retrieve_by_id(cls, scope_def_id):
        """
        Fetch the scope definition with the given ID, loading the "scopedefinitions" fixture
        and trying once more if it is not found on the first attempt.
        """
        manager = cls.objects
        try:
            definition = manager.get(id=scope_def_id)
        except ScopeDefinition.DoesNotExist:
            # the definitions may not have been loaded yet; load the fixture and look again
            call_command("loaddata", "scopedefinitions")
            definition = manager.get(id=scope_def_id)
        return definition

    def get_scope(self, params):
        """Build the ``Scope`` for this definition using the given scope params."""
        return Scope(definition=self, params=params)

    def get_description(self, params):
        """
        Render the human-readable description with the scope params substituted in.

        :param params: A dict of scope params, or the JSON text of one
        :return: The description; template refs without a matching param are left as-is
        :rtype: str
        """
        values = json.loads(params) if isinstance(params, str) else params
        description_template = string.Template(self.description)
        return description_template.safe_substitute(values)
[docs]
class Filter(object):
    """
    A set of partition prefixes, parsed from a whitespace-delimited string.

    An empty filter string produces a Filter holding the single partition ``""``; since every
    string starts with ``""``, such a Filter contains every partition.
    """

    # Defining `__eq__` already makes instances unhashable; this states it explicitly
    __hash__ = None

    def __init__(self, filter_str, params=None):
        """
        :param filter_str: The partition filter string, which may have multiple separated by newlines
        :type filter_str: str
        :param params: DEPRECATED: USE Filter.from_template() INSTEAD
        :type params: dict|str
        """
        if params is not None:
            logging.warning(
                "DEPRECATED: Constructing a filter with a template and params is deprecated. Use Filter.from_template() instead"
            )
            filter_str = str(Filter.from_template(filter_str, params=params))
        # fall back to a single empty partition so the tuple is never empty
        self._filter_tuple = tuple(filter_str.split()) or ("",)

    def is_subset_of(self, other):
        """
        :param other: The other Filter
        :type other: Filter
        :return: A boolean on whether this Filter is captured within the other Filter
        :rtype: bool
        """
        return all(other.contains_partition(partition) for partition in self)

    def contains_partition(self, partition):
        """Returns True if the partition starts with at least one of the partitions in this Filter"""
        return partition.startswith(self._filter_tuple)

    def contains_exact_partition(self, partition):
        """Returns True if the partition exactly matches one of the partitions in this Filter"""
        return partition in self._filter_tuple

    def copy(self):
        """Returns a new Filter with the same partitions as this one"""
        return Filter(str(self))

    def __le__(self, other):
        """Returns True if this Filter is a subset of the other"""
        return self.is_subset_of(other)

    def __eq__(self, other):
        """Returns True if this Filter has exactly the same partitions as the other"""
        # Let Python handle comparisons with non-Filter operands (including None), which
        # resolves `==` to False instead of failing on a missing method
        if not isinstance(other, Filter):
            return NotImplemented
        # order and repetition of partitions are irrelevant to equality
        return set(self._filter_tuple) == set(other._filter_tuple)

    def __contains__(self, partition):
        """
        Performs a 'startswith' comparison on the partition, determining whether it matches or
        is a subset of any partition in this Filter
        :param partition: str
        :return: A boolean
        :rtype: bool
        """
        return self.contains_partition(partition)

    def __add__(self, other):
        """
        The Filter's addition operator overload
        :param other: Filter or None
        :type other: Filter|None
        :return: The combined Filter
        :rtype: Filter
        """
        if other is None:
            return self
        # keep this Filter's non-empty partitions, then append those of the other Filter that
        # are not present yet, preserving order
        partitions = [p for p in self if p]
        for p in other:
            if p and p not in partitions:
                partitions.append(p)
        return Filter("\n".join(partitions))

    def __iter__(self):
        """
        :rtype: tuple[str]
        """
        return iter(self._filter_tuple)

    def __str__(self):
        return "\n".join(self._filter_tuple)

    def __repr__(self):
        return "{}({!r})".format(type(self).__name__, str(self))

    def __len__(self):
        return len(self._filter_tuple)

    @classmethod
    def add(cls, filter_a, filter_b):
        """
        The Filter's addition operator overload is already defensive against None being the
        right-hand operand, but this method is defensive against None being the left-hand operand
        :param filter_a: A Filter or None
        :type filter_a: Filter|None
        :param filter_b: A Filter or None
        :type filter_b: Filter|None
        :return: The combined Filter or None
        :rtype: Filter|None
        """
        if filter_a is None:
            return filter_b
        return filter_a + filter_b

    @classmethod
    def from_template(cls, template, params=None):
        """
        Create a filter from a string template, which may have params that will be replaced with
        values passed to `params`
        :param template: The partition filter template
        :type template: str
        :param params: The param dictionary or JSON object string
        :type params: dict|str
        :return: The filter with params replaced
        :rtype: Filter
        """
        if isinstance(params, str):
            params = json.loads(params)
        params = params or {}
        # safe_substitute leaves template refs without a matching param untouched
        return Filter(string.Template(template).safe_substitute(params))
[docs]
class Scope(object):
    """
    The read and write filters obtained by filling a scope definition's filter templates
    with a set of scope params.
    """

    def __init__(self, definition, params):
        """
        :param definition: Object providing the ``read_write_filter_template``,
            ``read_filter_template`` and ``write_filter_template`` strings
        :param params: The param dictionary or JSON object string substituted into the templates
        """
        # partitions from the read/write template are part of both resulting filters
        shared = Filter.from_template(definition.read_write_filter_template, params)
        read_only = Filter.from_template(definition.read_filter_template, params)
        self.read_filter = shared + read_only
        write_only = Filter.from_template(definition.write_filter_template, params)
        self.write_filter = shared + write_only

    def is_subset_of(self, other):
        """Return True when both the read and the write filter are subsets of the other's."""
        filter_pairs = (
            (self.read_filter, other.read_filter),
            (self.write_filter, other.write_filter),
        )
        # evaluated lazily, so the write filters are only compared if the read filters pass
        return all(mine.is_subset_of(theirs) for mine, theirs in filter_pairs)

    def __le__(self, other):
        """Return True when this Scope is a subset of the other."""
        return self.is_subset_of(other)

    def __eq__(self, other):
        """Return whether the read filters and the write filters of both scopes are equal."""
        reads_match = self.read_filter == other.read_filter
        return reads_match and self.write_filter == other.write_filter