diff --git a/.gitignore b/.gitignore index d6be2b0a..9de12852 100644 --- a/.gitignore +++ b/.gitignore @@ -33,3 +33,5 @@ coverage.xml .eggs/ .python-version venv +test_id_rsa +test_id_rsa.pub diff --git a/AUTHORS b/AUTHORS index ef975283..a4b7ebbc 100644 --- a/AUTHORS +++ b/AUTHORS @@ -92,3 +92,4 @@ Wes Winham Williams Mendez WoLpH dongweiming +SunnyCapt diff --git a/django_celery_beat/admin.py b/django_celery_beat/admin.py index 3d6d89ec..d1d2243d 100644 --- a/django_celery_beat/admin.py +++ b/django_celery_beat/admin.py @@ -81,7 +81,7 @@ def clean(self): regtask = data.get('regtask') if regtask: data['task'] = regtask - if not data['task']: + if not data['task'] and not (self.instance and self.instance.task_signature): exc = forms.ValidationError(_('Need name of task')) self._errors['task'] = self.error_class(exc.messages) raise exc @@ -198,11 +198,16 @@ def toggle_tasks(self, request, queryset): def run_tasks(self, request, queryset): self.celery_app.loader.import_default_modules() - tasks = [(self.celery_app.tasks.get(task.task), - loads(task.args), - loads(task.kwargs), - task.queue) - for task in queryset] + tasks = [ + ( + task.get_verified_task_signature(raise_exceptions=False) + if task.task_signature is not None + else self.celery_app.tasks.get(task.task), + loads(task.args), + loads(task.kwargs), + task.queue + ) for task in queryset + ] if any(t[0] is None for t in tasks): for i, t in enumerate(tasks): @@ -210,7 +215,9 @@ def run_tasks(self, request, queryset): break # variable "i" will be set because list "tasks" is not empty - not_found_task_name = queryset[i].task + not_found_task_name = queryset[i].get_verified_task_signature(raise_exceptions=False).name \ + if queryset[i].task_signature is not None and queryset[i].get_verified_task_signature( + raise_exceptions=False) is not None else queryset[i].task self.message_user( request, @@ -222,7 +229,7 @@ def run_tasks(self, request, queryset): task_ids = [task.apply_async(args=args, kwargs=kwargs, queue=queue) if queue and len(queue) else task.apply_async(args=args, kwargs=kwargs) - for task, args, kwargs, queue in tasks] + for task, args, kwargs, queue in tasks if task is not None] tasks_run = len(task_ids) self.message_user( request, diff --git a/django_celery_beat/migrations/0015_periodictask_task_signature.py b/django_celery_beat/migrations/0015_periodictask_task_signature.py new file mode 100644 index 00000000..301d24c0 --- /dev/null +++ b/django_celery_beat/migrations/0015_periodictask_task_signature.py @@ -0,0 +1,23 @@ +# Generated by Django 2.2.16 on 2020-09-01 10:17 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('django_celery_beat', '0014_remove_clockedschedule_enabled'), + ] + + operations = [ + migrations.AddField( + model_name='periodictask', + name='task_signature', + field=models.BinaryField(help_text="Serialized `celery.canvas.Signature` type's object of task (or chain, group, etc.) got by https://pypi.org/project/dill/", null=True), + ), + migrations.AddField( + model_name='periodictask', + name='task_signature_sign', + field=models.CharField(help_text="Signature (in hex) of serialized `celery.canvas.Signature` type's object (see task_signature field)", max_length=1028, null=True), + ), + ] diff --git a/django_celery_beat/migrations/0016_auto_20200903_1356.py b/django_celery_beat/migrations/0016_auto_20200903_1356.py new file mode 100644 index 00000000..82d49a17 --- /dev/null +++ b/django_celery_beat/migrations/0016_auto_20200903_1356.py @@ -0,0 +1,23 @@ +# Generated by Django 2.2.16 on 2020-09-03 13:56 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('django_celery_beat', '0015_periodictask_task_signature'), + ] + + operations = [ + migrations.AddField( + model_name='periodictask', + name='callback_signature', + field=models.BinaryField(help_text="Serialized `celery.canvas.Signature` type's callback task got by https://pypi.org/project/dill/ (use as link arg in `.apply_async` method)", null=True), + ), + migrations.AddField( + model_name='periodictask', + name='callback_signature_sign', + field=models.CharField(help_text="Signature (in hex) of serialized `celery.canvas.Signature` type's callback task (see callback_signature field)", max_length=1028, null=True), + ), + ] diff --git a/django_celery_beat/migrations/0017_merge_20210421_1344.py b/django_celery_beat/migrations/0017_merge_20210421_1344.py new file mode 100644 index 00000000..54b58d2a --- /dev/null +++ b/django_celery_beat/migrations/0017_merge_20210421_1344.py @@ -0,0 +1,14 @@ +# Generated by Django 3.2 on 2021-04-21 13:44 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('django_celery_beat', '0015_edit_solarschedule_events_choices'), + ('django_celery_beat', '0016_auto_20200903_1356'), + ] + + operations = [ + ] diff --git a/django_celery_beat/models.py b/django_celery_beat/models.py index 583d8b63..08e81623 100644 --- a/django_celery_beat/models.py +++ b/django_celery_beat/models.py @@ -1,8 +1,10 @@ """Database models.""" from datetime import timedelta +import dill import timezone_field from celery import schedules, current_app +from celery.utils.log import get_logger from django.conf import settings from django.core.exceptions import MultipleObjectsReturned, ValidationError from django.core.validators import MaxValueValidator, MinValueValidator @@ -11,10 +13,11 @@ from django.utils.translation import gettext_lazy as _ from . import managers, validators -from .tzcrontab import TzAwareCrontab -from .utils import make_aware, now from .clockedschedule import clocked +from .tzcrontab import TzAwareCrontab +from .utils import make_aware, now, verify_task_signature +logger = get_logger(__name__) DAYS = 'days' HOURS = 'hours' @@ -396,6 +399,27 @@ class PeriodicTask(models.Model): help_text=_('The Name of the Celery Task that Should be Run. ' '(Example: "proj.tasks.import_contacts")'), ) + task_signature = models.BinaryField( + null=True, + help_text='Serialized `celery.canvas.Signature` type\'s object of task (or chain, group, ' + 'etc.) got by https://pypi.org/project/dill/' + ) + callback_signature = models.BinaryField( + null=True, + help_text='Serialized `celery.canvas.Signature` type\'s callback task got ' + 'by https://pypi.org/project/dill/ (use as link arg in `.apply_async` method)' + ) # todo: add support for error_callback (link_error option) + task_signature_sign = models.CharField( + null=True, + max_length=1028, + help_text='Signature (in hex) of serialized `celery.canvas.Signature` type\'s object (see task_signature field)' + ) + callback_signature_sign = models.CharField( + null=True, + max_length=1028, + help_text='Signature (in hex) of serialized `celery.canvas.Signature` type\'s callback ' + 'task (see callback_signature field)' + ) # You can only set ONE of the following schedule FK's # TODO: Redo this as a GenericForeignKey @@ -556,8 +580,8 @@ def validate_unique(self, *args, **kwargs): 'must be set.' ) - err_msg = 'Only one of clocked, interval, crontab, '\ - 'or solar must be set' + err_msg = 'Only one of clocked, interval, crontab, ' \ + 'or solar must be set' if len(selected_schedule_types) > 1: error_info = {} for selected_schedule_type in selected_schedule_types: @@ -578,6 +602,17 @@ def save(self, *args, **kwargs): self.last_run_at = None self._clean_expires() self.validate_unique() + + if self.task_signature: + task = self.get_verified_task_signature().__repr__() + pattern = '' + max_length = PeriodicTask.task.field.max_length - len(pattern) + 2 - 3 + + if len(task) > max_length: + task = pattern.format(task[:max_length] + '...') + + self.task = task + super().save(*args, **kwargs) def _clean_expires(self): @@ -586,6 +621,53 @@ def _clean_expires(self): _('Only one can be set, in expires and expire_seconds') ) + def get_verified_task_signature(self, raise_exceptions=True): + try: + self.get_verified_callback_signature() + except ValueError as e: + err = 'Wrong callback: {} [{}]'.format(e, self) + logger.error(err) + if raise_exceptions: + raise ValueError(err) + return None + + return self._get_verified_obj_signature('task', raise_exceptions) + + def get_verified_callback_signature(self, raise_exceptions=True): + return self._get_verified_obj_signature('callback', raise_exceptions) + + def _get_verified_obj_signature(self, object_name, raise_exceptions): + assert object_name in ('task', 'callback'), ValueError('Unknown object_name') + + obj_signarute = getattr(self, '{}_signature'.format(object_name), None) + obj_signarute_sign = getattr(self, '{}_signature_sign'.format(object_name), None) + + if obj_signarute is None: + return None + + if obj_signarute_sign is None: + err = 'Not found `{}_signature_sign` for `{}` (use django_celery_be' \ + 'at.utils.sign to sign). Task disabled.'.format(object_name, self) + self.enabled = False + self.save(update_fields=['enabled']) + logger.error(err) + if raise_exceptions: + raise ValueError(err) + return None + + obj_signarute = bytes(obj_signarute) + + if not verify_task_signature(obj_signarute, obj_signarute_sign): + err = 'Wrong sign for `{}`. Task disabled.'.format(self) + self.enabled = False + self.save(update_fields=['enabled']) + logger.error(err) + if raise_exceptions: + raise ValueError(err) + return None + + return dill.loads(obj_signarute) + @property def expires_(self): return self.expires or self.expire_seconds diff --git a/django_celery_beat/schedulers.py b/django_celery_beat/schedulers.py index e2606866..5fee37fc 100644 --- a/django_celery_beat/schedulers.py +++ b/django_celery_beat/schedulers.py @@ -1,30 +1,34 @@ """Beat Scheduler Implementation.""" +from __future__ import absolute_import, unicode_literals + import datetime +import importlib import logging import math - -from multiprocessing.util import Finalize +import sys from celery import current_app from celery import schedules -from celery.beat import Scheduler, ScheduleEntry - +# noinspection PyProtectedMember +from celery.beat import Scheduler, ScheduleEntry, SchedulingError from celery.utils.log import get_logger from celery.utils.time import maybe_make_aware -from kombu.utils.encoding import safe_str, safe_repr -from kombu.utils.json import dumps, loads - from django.conf import settings +from django.core.exceptions import ObjectDoesNotExist +# noinspection PyProtectedMember from django.db import transaction, close_old_connections from django.db.utils import DatabaseError, InterfaceError -from django.core.exceptions import ObjectDoesNotExist +from kombu.utils.encoding import safe_str, safe_repr +from kombu.utils.json import dumps, loads +# noinspection PyUnresolvedReferences +from multiprocessing.util import Finalize +from .clockedschedule import clocked from .models import ( PeriodicTask, PeriodicTasks, CrontabSchedule, IntervalSchedule, SolarSchedule, ClockedSchedule ) -from .clockedschedule import clocked from .utils import NEVER_CHECK_TIMEOUT # This scheduler must wake up more frequently than the @@ -56,6 +60,8 @@ def __init__(self, model, app=None): self.app = app or current_app._get_current_object() self.name = model.name self.task = model.task + self.task_signature = model.get_verified_task_signature() + try: self.schedule = model.schedule except model.DoesNotExist: @@ -74,7 +80,10 @@ def __init__(self, model, app=None): ) self._disable(model) - self.options = {} + self.options = { + 'link': model.get_verified_callback_signature() + } + for option in ['queue', 'exchange', 'routing_key', 'priority']: value = getattr(model, option) if value is None: @@ -229,6 +238,7 @@ def __init__(self, *args, **kwargs): """Initialize the database scheduler.""" self._dirty = set() Scheduler.__init__(self, *args, **kwargs) + # noinspection PyUnresolvedReferences self._finalize = Finalize(self, self.sync, exitpriority=5) self.max_interval = ( kwargs.get('max_interval') @@ -368,3 +378,33 @@ def schedule(self): repr(entry) for entry in self._schedule.values()), ) return self._schedule + + def apply_async(self, entry, producer=None, advance=True, **kwargs): + entry = self.reserve(entry) if advance else entry + task = entry.task_signature + + if hasattr(self.app.conf, 'call_before_run_periodic_task'): + # if app.conf has a field call_before_run_periodic_task + # then we try to import and run all the specified functions + for func_ref in self.app.conf.call_before_run_periodic_task: + func_ref = func_ref.split('.') + callback = importlib.import_module( + '.'.join(func_ref[:-1]) + ).__getattribute__(func_ref[-1]) + callback(task=task, entry=entry, producer=producer, advance=advance, **kwargs) + + if entry.task_signature is None: + return super(DatabaseScheduler, self).apply_async(entry, producer=producer, advance=advance, **kwargs) + + try: + return task.apply_async(producer=producer, **entry.options) + except Exception as exc: # pylint: disable=broad-except + e = SchedulingError( + "Couldn't apply scheduled task {0.name}: {exc}".format(entry, exc=exc) + ) + raise e.with_traceback(sys.exc_info()[2]) + + finally: + self._tasks_since_sync += 1 + if self.should_sync(): + self._do_sync() diff --git a/django_celery_beat/utils.py b/django_celery_beat/utils.py index c19f4edb..ae308eb5 100644 --- a/django_celery_beat/utils.py +++ b/django_celery_beat/utils.py @@ -1,8 +1,14 @@ """Utilities.""" +import os +from hashlib import sha256 + +import Crypto.PublicKey.RSA as RSA # -- XXX This module must not use translation as that causes # -- a recursive loader import! +from celery.utils.log import get_logger from django.conf import settings from django.utils import timezone +from functools import lru_cache is_aware = timezone.is_aware # celery schedstate return None will make it not work @@ -11,6 +17,66 @@ # see Issue #222 now_localtime = getattr(timezone, 'template_localtime', timezone.localtime) +logger = get_logger(__name__) + + +def generate_keys( + private_key_path=os.environ.get('DJANGO_CELERY_BEAT_PRIVATE_KEY_PATH', './id_rsa'), + public_key_path=os.environ.get('DJANGO_CELERY_BEAT_PUBLIC_KEY_PATH', './id_rsa.pub') +): + private_key = RSA.generate(4096, os.urandom) + public_key = private_key.publickey() + + if os.path.exists(private_key_path): + raise FileExistsError(private_key_path) + + if os.path.exists(public_key_path): + raise FileExistsError(public_key_path) + + open(private_key_path, 'wb').close() + os.chmod(private_key_path, 0o600) + with open(private_key_path, 'wb') as id_rsa: + id_rsa.write(private_key.exportKey()) + + open(public_key_path, 'wb').close() + os.chmod(public_key_path, 0o644) + with open(public_key_path, 'wb') as id_rsa_pub: + id_rsa_pub.write(public_key.exportKey()) + + +@lru_cache(maxsize=None) +def _load_private_key(): + private_key_path = os.environ.get('DJANGO_CELERY_BEAT_PRIVATE_KEY_PATH', './id_rsa') + + if os.path.exists(private_key_path): + with open(private_key_path, 'rb') as id_rsa: + private_key = RSA.importKey(id_rsa.read()) + return private_key + + raise FileNotFoundError( + 'Private key not found. Use `django_celery_beat.utils.generate_keys` ' + 'to generate new RSA keys... [{}]'.format(private_key_path) + ) + + +@lru_cache(maxsize=None) +def _load_public_key(): + public_key_path = os.environ.get('DJANGO_CELERY_BEAT_PUBLIC_KEY_PATH', './id_rsa.pub') + + if os.path.exists(public_key_path): + with open(public_key_path, 'rb') as id_rsa_pub: + _public_key = RSA.importKey(id_rsa_pub.read()) + return _public_key + + raise FileNotFoundError( + 'Private key not found. Use `django_celery_beat.utils.generate_keys` ' + 'to generate new RSA keys... [{}]'.format(public_key_path) + ) + + +def _load_keys(): + return _load_private_key(), _load_public_key() + def make_aware(value): """Force datatime to have timezone information.""" @@ -46,3 +112,18 @@ def is_database_scheduler(scheduler): scheduler == 'django' or issubclass(symbol_by_name(scheduler), DatabaseScheduler) ) + + +def sign_task_signature(serialized_task_signature): + """Sign the bytes data to protect against database changes and return signature in hex""" + private_key = _load_private_key() + + assert isinstance(serialized_task_signature, bytes), ValueError('Data must be bytes') + return hex(private_key.sign(sha256(serialized_task_signature).hexdigest().encode(), '')[0]) + + +def verify_task_signature(serialized_task_signature, sign_in_hex): + """Check the signature and return True if it is correct for the specified data""" + public_key = _load_public_key() + + return public_key.verify(sha256(serialized_task_signature).hexdigest().encode(), (int(sign_in_hex, 16),)) diff --git a/requirements/default.txt b/requirements/default.txt index 467da065..fa6f9df2 100644 --- a/requirements/default.txt +++ b/requirements/default.txt @@ -1,3 +1,6 @@ celery>=4.4,<6.0 django-timezone-field>=4.1.0,<5.0 python-crontab>=2.3.4 +dill +pycrypto +django-appconf diff --git a/t/proj/settings.py b/t/proj/settings.py index a09c4e33..e8d9fa5f 100644 --- a/t/proj/settings.py +++ b/t/proj/settings.py @@ -12,6 +12,8 @@ import os import sys +from django_celery_beat.utils import generate_keys + CELERY_DEFAULT_EXCHANGE = 'testcelery' CELERY_DEFAULT_ROUTING_KEY = 'testcelery' CELERY_DEFAULT_QUEUE = 'testcelery' @@ -122,3 +124,19 @@ STATIC_URL = '/static/' DJANGO_CELERY_BEAT_TZ_AWARE = True + +PRIVATE_KEY_PATH = './test_id_rsa' +PUBLIC_KEY_PATH = './test_id_rsa.pub' + +os.environ.update({ + 'DJANGO_CELERY_BEAT_PRIVATE_KEY_PATH': PRIVATE_KEY_PATH, + 'DJANGO_CELERY_BEAT_PUBLIC_KEY_PATH': PUBLIC_KEY_PATH, +}) + +try: + generate_keys( + private_key_path=PRIVATE_KEY_PATH, + public_key_path=PUBLIC_KEY_PATH + ) +except FileExistsError: + pass diff --git a/t/unit/test_models.py b/t/unit/test_models.py index 627e47b5..bbc1370c 100644 --- a/t/unit/test_models.py +++ b/t/unit/test_models.py @@ -1,24 +1,29 @@ import os +import random +import string +import dill +import timezone_field from celery import schedules -from django.test import TestCase, override_settings +from celery.canvas import Signature from django.apps import apps -from django.db.migrations.state import ProjectState from django.db.migrations.autodetector import MigrationAutodetector from django.db.migrations.loader import MigrationLoader from django.db.migrations.questioner import NonInteractiveMigrationQuestioner +from django.db.migrations.state import ProjectState +from django.test import TestCase, override_settings from django.utils import timezone -import timezone_field - from django_celery_beat import migrations as beat_migrations from django_celery_beat.models import ( crontab_schedule_celery_timezone, SolarSchedule, CrontabSchedule, ClockedSchedule, + PeriodicTask, IntervalSchedule, ) +from django_celery_beat.utils import sign_task_signature class MigrationTests(TestCase): @@ -146,3 +151,33 @@ class ClockedScheduleTestCase(TestCase, TestDuplicatesMixin): def test_duplicate_schedules(self): kwargs = {'clocked_time': timezone.now()} self._test_duplicate_schedules(ClockedSchedule, kwargs) + + +class PeriodicTaskSignatureTestCase(TestCase): + test_private_key_path = './test_id_rsa' + test_public_key_path = './test_id_rsa.pub' + + def test_periodic_task_with_signatures(self): + empty_task_signature = Signature(task='empty_task') + + serialized_empty_task = dill.dumps(empty_task_signature) + s = sign_task_signature(serialized_empty_task) + + interval, _ = IntervalSchedule.objects.get_or_create( + every=2, + period=IntervalSchedule.MINUTES + ) + periodic_task = PeriodicTask.objects.create( + name='test-' + ''.join(random.choices(string.ascii_letters, k=20)), + task_signature=serialized_empty_task, + task_signature_sign=s, + callback_signature=serialized_empty_task, + callback_signature_sign=s, + interval=interval, + ) + + task_signature = periodic_task.get_verified_callback_signature(raise_exceptions=False) + callback_signature = periodic_task.get_verified_callback_signature(raise_exceptions=False) + + self.assertEqual(empty_task_signature, task_signature) + self.assertEqual(empty_task_signature, callback_signature) diff --git a/t/unit/test_utils.py b/t/unit/test_utils.py new file mode 100644 index 00000000..61f4af16 --- /dev/null +++ b/t/unit/test_utils.py @@ -0,0 +1,21 @@ +from unittest import TestCase + +import dill +from celery.canvas import Signature + +from django_celery_beat.utils import sign_task_signature, verify_task_signature + + +class UtilsTests(TestCase): + test_private_key_path = './test_id_rsa' + test_public_key_path = './test_id_rsa.pub' + + def test_sign_verify_task_signature(self): + empty_task_signature = Signature() + + serialized_empty_task = dill.dumps(empty_task_signature) + s = sign_task_signature(serialized_empty_task) + + is_valid = verify_task_signature(serialized_empty_task, s) + + self.assertTrue(is_valid)