diff --git a/raven/contrib/celery/__init__.py b/raven/contrib/celery/__init__.py index 02f3ab00e..22f13c681 100644 --- a/raven/contrib/celery/__init__.py +++ b/raven/contrib/celery/__init__.py @@ -8,6 +8,7 @@ from __future__ import absolute_import import logging +import inspect from celery.exceptions import SoftTimeLimitExceeded from celery.signals import ( @@ -26,13 +27,12 @@ def filter(self, record): return extra_data.get('internal', record.funcName != '_log_error') -def register_signal(client, ignore_expected=False): - SentryCeleryHandler(client, ignore_expected=ignore_expected).install() +def register_signal(client, ignore_expected=False, context_args=None): + SentryCeleryHandler(client, ignore_expected=ignore_expected, context_args=context_args).install() def register_logger_signal(client, logger=None, loglevel=logging.ERROR): filter_ = CeleryFilter() - handler = SentryHandler(client) handler.setLevel(loglevel) handler.addFilter(filter_) @@ -46,16 +46,16 @@ def process_logger_event(sender, logger, loglevel, logfile, format, if isinstance(h, SentryHandler): h.addFilter(filter_) return False - logger.addHandler(handler) after_setup_logger.connect(process_logger_event, weak=False) class SentryCeleryHandler(object): - def __init__(self, client, ignore_expected=False): + def __init__(self, client, ignore_expected=False, context_args=None): self.client = client self.ignore_expected = ignore_expected + self.context_args = context_args def install(self): task_prerun.connect(self.handle_task_prerun, weak=False) @@ -89,8 +89,31 @@ def process_failure_signal(self, sender, task_id, args, kwargs, einfo, **kw): def handle_task_prerun(self, sender, task_id, task, **kw): self.client.context.activate() + if self.context_args: + context = self.infer_context(task, **kw) + self.set_logger_context(context) self.client.transaction.push(task.name) def handle_task_postrun(self, sender, task_id, task, **kw): self.client.transaction.pop(task.name) self.client.context.clear() + + def infer_context(self, task, **kw): + args = inspect.getargspec(task.run).args + if task._app: + args.pop(0) + tags = {} + for i, arg in enumerate(args): + if arg in self.context_args: + value = kw['args'][i] + tags.update({arg: value}) + for k, v in kw['kwargs'].iteritems(): + if k in self.context_args: + tags.update({k, v}) + return {'tags': tags} + + def set_logger_context(self, context): + logger = logging.getLogger() + for h in logger.handlers: + if isinstance(h, SentryHandler): + h.client.context.merge(context) diff --git a/tests/contrib/test_celery.py b/tests/contrib/test_celery.py index cfa5b0ff8..d6c2ff6fa 100644 --- a/tests/contrib/test_celery.py +++ b/tests/contrib/test_celery.py @@ -19,7 +19,7 @@ def setUp(self): self.celery.conf.CELERY_ALWAYS_EAGER = True self.client = InMemoryClient() - self.handler = SentryCeleryHandler(self.client, ignore_expected=True) + self.handler = SentryCeleryHandler(self.client, ignore_expected=True, context_args=['foo', 'kw_bar']) self.handler.install() self.addCleanup(self.handler.uninstall) @@ -45,6 +45,20 @@ def dummy_task(x, y): dummy_task.delay(1, 0) assert len(self.client.events) == 0 + def test_context_args(self): + @self.celery.task(name='dummy_task') + def dummy_task(foo, bar): + return foo / bar + dummy_task.delay(1, 2) + assert self.client.context + + def test_context_kwargs(self): + @self.celery.task(name='dummy_task') + def dummy_task(foo, kw_bar=2): + return foo / kw_bar + dummy_task.delay(1, kw_bar=2) + assert self.client.context + class CeleryLoggingHandlerTestCase(TestCase): def setUp(self):