
Canvas Header Stamping (#7384)

* Strip down the header-stamping PR to the basics.

* Serialize groups.

* Add groups to result backend meta data.

* Fix spelling mistake.

* Revert changes to canvas.py

* Revert changes to app/base.py

* Add stamping implementation to canvas.py

* Send task to AMQP with groups.

* Successfully pass single group to result.

* _freeze_gid dict merge fixed

* First draft of the visitor API.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* OptionsVisitor created

* Fixed canvas.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Added a simple test for chord and fixed chord implementation

* Changed _IMMUTABLE_OPTIONS

* Fixed chord interface

* Fixed chord interface

* Fixed chord interface

* Fixed chord interface

* Fixed list order

* Fixed tests (stamp test and chord test), fixed order in groups

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixed lint and elements

* Changed implementation of stamp API and fix lint

* Added documentation to Stamping API. Added chord with groups test

* Implemented stamping inside replace and added test for an implementation

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Added additional tests for chord, improved coverage

* Added additional tests for chord, improved coverage

* Added additional tests for chord, improved coverage

* Split into subtests

* Group stamping rollback

* group.id is None fixed

* Added integration test

* Added integration test

* apply_async fixed

* Integration test and test_chord fixed

* Lint fixed

* chord freeze fixed

* Minor fixes.

* Chain apply_async fixed and tests fixed

* lint fixed

* Added integration test for chord

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* type -> isinstance

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Redo header stamping (#7341)

* _freeze_gid dict merge fixed

* OptionsVisitor created

* Fixed canvas.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Added a simple test for chord and fixed chord implementation

* Changed _IMMUTABLE_OPTIONS

* Fixed chord interface

* Fixed chord interface

* Fixed chord interface

* Fixed chord interface

* Fixed list order

* Fixed tests (stamp test and chord test), fixed order in groups

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixed lint and elements

* Changed implementation of stamp API and fix lint

* Added documentation to Stamping API. Added chord with groups test

* Implemented stamping inside replace and added test for an implementation

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Added additional tests for chord, improved coverage

* Added additional tests for chord, improved coverage

* Added additional tests for chord, improved coverage

* Split into subtests

* Group stamping rollback

* group.id is None fixed

* Added integration test

* Added integration test

* apply_async fixed

* Integration test and test_chord fixed

* Lint fixed

* chord freeze fixed

* Minor fixes.

* Chain apply_async fixed and tests fixed

* lint fixed

* Added integration test for chord

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* type -> isinstance

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Omer Katz <omer.katz@omerkatz.com>

* Added stamping mechanism

* Manual stamping improved

* flake8 fixed

* Added subtests

* Add comma.

* Moved groups to stamps

* Fixed chord and added test for that

* Strip down the header-stamping PR to the basics.

* Serialize groups.

* Add groups to result backend meta data.

* Fix spelling mistake.

* Revert changes to canvas.py

* Revert changes to app/base.py

* Add stamping implementation to canvas.py

* Send task to AMQP with groups.

* Successfully pass single group to result.

* _freeze_gid dict merge fixed

* First draft of the visitor API.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* OptionsVisitor created

* Fixed canvas.py

* Added a simple test for chord and fixed chord implementation

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Changed _IMMUTABLE_OPTIONS

* Fixed chord interface

* Fixed chord interface

* Fixed chord interface

* Fixed chord interface

* Fixed list order

* Fixed tests (stamp test and chord test), fixed order in groups

* Fixed lint and elements

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Changed implementation of stamp API and fix lint

* Added documentation to Stamping API. Added chord with groups test

* Implemented stamping inside replace and added test for an implementation

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Added additional tests for chord, improved coverage

* Added additional tests for chord, improved coverage

* Added additional tests for chord, improved coverage

* Split into subtests

* Group stamping rollback

* group.id is None fixed

* Added integration test

* Added integration test

* apply_async fixed

* Integration test and test_chord fixed

* Lint fixed

* chord freeze fixed

* Minor fixes.

* Chain apply_async fixed and tests fixed

* lint fixed

* Added integration test for chord

* type -> isinstance

* Added stamping mechanism

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Manual stamping improved

* fail_ci_if_error uncommented

* flake8 fixed

* Added subtests

* Changes

* Add comma.

* Fixed chord and added test for that

* canvas.py fixed

* Test chord.py fixed

* Fixed stamped_headers

* collections import fixed

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* collections import fixed

* Update celery/backends/base.py

Co-authored-by: Omer Katz <omer.katz@omerkatz.com>

* amqp.py fixed

* Refrain from using deprecated import path.

* Fix test_complex_chain regression.

Whenever we stamp a group we need to freeze it first if it wasn't already frozen.
Somewhere along the line, the group id changed because we were freezing twice.
This commit places the stamping operation after preparing the chain's steps which fixes the problem somehow.

We don't know why yet.

* Fixed integration tests

* Fixed integration tests

* Fixed integration tests

* Fixed integration tests

* Fixed issues with maybe_list. Add documentation

* Fixed potential issue with integration tests

* Fixed issues with _regen

* Fixed issues with _regen

* Fixed test_generator issues

* Fixed _regen stamping

* Fixed _regen stamping

* Fixed TimeOut issue

* Fixed TimeOut issue

* Fixed TimeOut issue

* Update docs/userguide/canvas.rst

Co-authored-by: Omer Katz <omer.katz@omerkatz.com>

* Fixed Couchbase

* Better stamping intro

* New GroupVisitor example

* Adjust documentation.

Co-authored-by: Naomi Elstein <naomi.els@omerkatz.com>
Co-authored-by: Omer Katz <omer.katz@omerkatz.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Asif Saif Uddin <auvipy@gmail.com>
Co-authored-by: Omer Katz <omer.katz@kcg.tech>
pull/7597/head
dobosevych 2 months ago committed by GitHub
parent
commit
1c4ff33bd2
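  1. .github/workflows/python-package.yml (1 changed line)
  2. celery/app/amqp.py (45 changed lines)
  3. celery/app/base.py (4 changed lines)
  4. celery/app/task.py (18 changed lines)
  5. celery/backends/base.py (7 changed lines)
  6. celery/canvas.py (308 changed lines)
  7. celery/utils/functional.py (4 changed lines)
  8. celery/worker/request.py (12 changed lines)
  9. docs/userguide/canvas.rst (88 changed lines)
  10. t/integration/conftest.py (3 changed lines)
  11. t/integration/tasks.py (7 changed lines)
  12. t/integration/test_canvas.py (225 changed lines)
  13. t/unit/conftest.py (2 changed lines)
  14. t/unit/tasks/test_canvas.py (595 changed lines)
  15. t/unit/tasks/test_chord.py (33 changed lines)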

1
.github/workflows/python-package.yml

@ -105,6 +105,7 @@ jobs:
- name: Install apt packages
run: |
sudo apt update && sudo apt-get install -f libcurl4-openssl-dev libssl-dev libgnutls28-dev httping expect libmemcached-dev
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4

45
celery/app/amqp.py

@ -284,7 +284,9 @@ class AMQP:
time_limit=None, soft_time_limit=None,
create_sent_event=False, root_id=None, parent_id=None,
shadow=None, chain=None, now=None, timezone=None,
origin=None, ignore_result=False, argsrepr=None, kwargsrepr=None):
origin=None, ignore_result=False, argsrepr=None, kwargsrepr=None, stamped_headers=None,
**options):
args = args or ()
kwargs = kwargs or {}
if not isinstance(args, (list, tuple)):
@ -319,25 +321,30 @@ class AMQP:
if not root_id: # empty root_id defaults to task_id
root_id = task_id
stamps = {header: maybe_list(options[header]) for header in stamped_headers or []}
headers = {
'lang': 'py',
'task': name,
'id': task_id,
'shadow': shadow,
'eta': eta,
'expires': expires,
'group': group_id,
'group_index': group_index,
'retries': retries,
'timelimit': [time_limit, soft_time_limit],
'root_id': root_id,
'parent_id': parent_id,
'argsrepr': argsrepr,
'kwargsrepr': kwargsrepr,
'origin': origin or anon_nodename(),
'ignore_result': ignore_result,
'stamped_headers': stamped_headers,
'stamps': stamps,
}
return task_message(
headers={
'lang': 'py',
'task': name,
'id': task_id,
'shadow': shadow,
'eta': eta,
'expires': expires,
'group': group_id,
'group_index': group_index,
'retries': retries,
'timelimit': [time_limit, soft_time_limit],
'root_id': root_id,
'parent_id': parent_id,
'argsrepr': argsrepr,
'kwargsrepr': kwargsrepr,
'origin': origin or anon_nodename(),
'ignore_result': ignore_result,
},
headers=headers,
properties={
'correlation_id': task_id,
'reply_to': reply_to or '',
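
The `stamps` header built in this hunk can be pictured with a small standalone sketch. It is not Celery's code: `maybe_list` below is a simplified stand-in for the helper in `celery.utils.functional` (its exact behavior is an assumption here), and `build_stamps` is a hypothetical name used only for illustration.

```python
def maybe_list(value):
    """Simplified stand-in (assumption): wrap scalars in a list, keep lists and None."""
    if value is None or isinstance(value, (list, tuple)):
        return value
    return [value]


def build_stamps(stamped_headers, **options):
    """Pick the stamped option values out of the task options, as the hunk above does."""
    return {header: maybe_list(options[header]) for header in stamped_headers or []}


# Only the options named in stamped_headers end up in the stamps mapping.
stamps = build_stamps(
    ["groups", "stamp"],
    groups=["g-1"],
    stamp="your_custom_stamp",
    priority=5,
)
print(stamps)

# With no stamped headers the mapping is simply empty.
no_stamps = build_stamps(None, priority=5)
```

Note that, in this sketch as in the hunk, a header listed in `stamped_headers` but absent from the options raises `KeyError`, so the two are expected to stay in sync.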

4
celery/app/base.py

@ -766,6 +766,7 @@ class Celery:
options.setdefault('priority',
parent.request.delivery_info.get('priority'))
# alias for 'task_as_v2'
message = amqp.create_task_message(
task_id, name, args, kwargs, countdown, eta, group_id, group_index,
expires, retries, chord,
@ -774,8 +775,7 @@ class Celery:
self.conf.task_send_sent_event,
root_id, parent_id, shadow, chain,
ignore_result=ignore_result,
argsrepr=options.get('argsrepr'),
kwargsrepr=options.get('kwargsrepr'),
**options
)
if connection:

18
celery/app/task.py

@ -8,7 +8,7 @@ from kombu.utils.uuid import uuid
from celery import current_app, states
from celery._state import _task_stack
from celery.canvas import _chain, group, signature
from celery.canvas import GroupStampingVisitor, _chain, group, signature
from celery.exceptions import Ignore, ImproperlyConfigured, MaxRetriesExceededError, Reject, Retry
from celery.local import class_property
from celery.result import EagerResult, denied_join_result
@ -93,6 +93,8 @@ class Context:
taskset = None # compat alias to group
timelimit = None
utc = None
stamped_headers = None
stamps = None
def __init__(self, *args, **kwargs):
self.update(*args, **kwargs)
@ -794,8 +796,14 @@ class Task:
'exchange': options.get('exchange'),
'routing_key': options.get('routing_key'),
'priority': options.get('priority'),
},
}
}
if 'stamped_headers' in options:
request['stamped_headers'] = maybe_list(options['stamped_headers'])
request['stamps'] = {
header: maybe_list(options.get(header, [])) for header in request['stamped_headers']
}
tb = None
tracer = build_tracer(
task.name, task, eager=True,
@ -942,6 +950,12 @@ class Task:
# retain their original task IDs as well
for t in reversed(self.request.chain or []):
sig |= signature(t, app=self.app)
# Stamping sig with parents groups
stamped_headers = self.request.stamped_headers
if self.request.stamps:
groups = self.request.stamps.get("groups")
sig.stamp(visitor=GroupStampingVisitor(groups=groups, stamped_headers=stamped_headers))
# Finally, either apply or delay the new signature!
if self.request.is_eager:
return sig.apply().get()

7
celery/backends/base.py

@ -230,7 +230,7 @@ class Backend:
hasattr(errback.type, '__header__') and
# workaround to support tasks with bind=True executed as
# link errors. Otherwise retries can't be used
# link errors. Otherwise, retries can't be used
not isinstance(errback.type.__header__, partial) and
arity_greater(errback.type.__header__, 1)
):
@ -488,8 +488,11 @@ class Backend:
'retries': getattr(request, 'retries', None),
'queue': request.delivery_info.get('routing_key')
if hasattr(request, 'delivery_info') and
request.delivery_info else None
request.delivery_info else None,
}
if getattr(request, 'stamps', None):
request_meta['stamped_headers'] = request.stamped_headers
request_meta.update(request.stamps)
if encode:
# args and kwargs need to be encoded properly before saving

308
celery/canvas.py

@ -7,6 +7,7 @@
import itertools
import operator
from abc import ABCMeta, abstractmethod
from collections import deque
from collections.abc import MutableSequence
from copy import deepcopy
@ -56,6 +57,155 @@ def task_name_from(task):
return getattr(task, 'name', task)
def _stamp_regen_task(task, visitor, **headers):
task.stamp(visitor=visitor, **headers)
return task
def _merge_dictionaries(d1, d2):
for key, value in d1.items():
if key in d2:
if isinstance(value, dict):
_merge_dictionaries(d1[key], d2[key])
else:
if isinstance(value, (int, float, str)):
d1[key] = [value]
if isinstance(d2[key], list):
d1[key].extend(d2[key])
else:
if d1[key] is None:
d1[key] = []
else:
d1[key] = list(d1[key])
d1[key].append(d2[key])
for key, value in d2.items():
if key not in d1:
d1[key] = value
class StampingVisitor(metaclass=ABCMeta):
"""Stamping API. A class that provides a stamping API possibility for
canvas primitives. If you want to implement stamping behavior for
a canvas primitive override method that represents it.
"""
@abstractmethod
def on_group_start(self, group, **headers) -> dict:
"""Method that is called on group stamping start.
Arguments:
group (group): Group that is stamped.
headers (Dict): Partial headers that could be merged with existing headers.
Returns:
Dict: headers to update.
"""
pass
def on_group_end(self, group, **headers) -> None:
"""Method that is called on group stamping end.
Arguments:
group (group): Group that is stamped.
headers (Dict): Partial headers that could be merged with existing headers.
"""
pass
@abstractmethod
def on_chain_start(self, chain, **headers) -> dict:
"""Method that is called on chain stamping start.
Arguments:
chain (chain): Chain that is stamped.
headers (Dict): Partial headers that could be merged with existing headers.
Returns:
Dict: headers to update.
"""
pass
def on_chain_end(self, chain, **headers) -> None:
"""Method that is called on chain stamping end.
Arguments:
chain (chain): Chain that is stamped.
headers (Dict): Partial headers that could be merged with existing headers.
"""
pass
@abstractmethod
def on_signature(self, sig, **headers) -> dict:
"""Method that is called on signature stamping.
Arguments:
sig (Signature): Signature that is stamped.
headers (Dict): Partial headers that could be merged with existing headers.
Returns:
Dict: headers to update.
"""
pass
def on_chord_header_start(self, chord, **header) -> dict:
Method that is called on chord header stamping start.
Arguments:
chord (chord): chord that is stamped.
headers (Dict): Partial headers that could be merged with existing headers.
Returns:
Dict: headers to update.
"""
if not isinstance(chord.tasks, group):
chord.tasks = group(chord.tasks)
return self.on_group_start(chord.tasks, **header)
def on_chord_header_end(self, chord, **header) -> None:
Method that is called on chord header stamping end.
Arguments:
chord (chord): chord that is stamped.
headers (Dict): Partial headers that could be merged with existing headers.
"""
self.on_group_end(chord.tasks, **header)
def on_chord_body(self, chord, **header) -> dict:
"""Method that is called on chord body stamping.
Arguments:
chord (chord): chord that is stamped.
headers (Dict): Partial headers that could be merged with existing headers.
Returns:
Dict: headers to update.
"""
return self.on_signature(chord.body, **header)
class GroupStampingVisitor(StampingVisitor):
"""
Group stamping implementation based on Stamping API.
"""
def __init__(self, groups=None, stamped_headers=None):
self.groups = groups or []
self.stamped_headers = stamped_headers or []
if "groups" not in self.stamped_headers:
self.stamped_headers.append("groups")
def on_group_start(self, group, **headers) -> dict:
if group.id is None:
group.set(task_id=uuid())
if group.id not in self.groups:
self.groups.append(group.id)
return {'groups': list(self.groups), "stamped_headers": list(self.stamped_headers)}
def on_group_end(self, group, **headers) -> None:
self.groups.pop()
def on_chain_start(self, chain, **headers) -> dict:
return {'groups': list(self.groups), "stamped_headers": list(self.stamped_headers)}
def on_signature(self, sig, **headers) -> dict:
return {'groups': list(self.groups), "stamped_headers": list(self.stamped_headers)}
@abstract.CallableSignature.register
class Signature(dict):
"""Task Signature.
@ -118,7 +268,7 @@ class Signature(dict):
_app = _type = None
# The following fields must not be changed during freezing/merging because
# to do so would disrupt completion of parent tasks
_IMMUTABLE_OPTIONS = {"group_id"}
_IMMUTABLE_OPTIONS = {"group_id", "stamped_headers"}
@classmethod
def register_type(cls, name=None):
@ -178,6 +328,9 @@ class Signature(dict):
"""
args = args if args else ()
kwargs = kwargs if kwargs else {}
groups = self.options.get("groups")
stamped_headers = self.options.get("stamped_headers")
self.stamp(visitor=GroupStampingVisitor(groups=groups, stamped_headers=stamped_headers))
# Extra options set to None are dismissed
options = {k: v for k, v in options.items() if v is not None}
# For callbacks: extra args are prepended to the stored args.
@ -201,6 +354,9 @@ class Signature(dict):
"""
args = args if args else ()
kwargs = kwargs if kwargs else {}
groups = self.options.get("groups")
stamped_headers = self.options.get("stamped_headers")
self.stamp(visitor=GroupStampingVisitor(groups=groups, stamped_headers=stamped_headers))
# Extra options set to None are dismissed
options = {k: v for k, v in options.items() if v is not None}
try:
@ -225,10 +381,13 @@ class Signature(dict):
# override values in `self.options` except for keys which are
# noted as being immutable (unrelated to signature immutability)
# implying that allowing their value to change would stall tasks
new_options = dict(self.options, **{
immutable_options = self._IMMUTABLE_OPTIONS
if "stamped_headers" in self.options:
immutable_options = self._IMMUTABLE_OPTIONS.union(set(self.options["stamped_headers"]))
new_options = {**self.options, **{
k: v for k, v in options.items()
if k not in self._IMMUTABLE_OPTIONS or k not in self.options
})
if k not in immutable_options or k not in self.options
}}
else:
new_options = self.options
if self.immutable and not force:
@ -334,6 +493,21 @@ class Signature(dict):
def set_immutable(self, immutable):
self.immutable = immutable
def stamp(self, visitor=None, **headers):
"""Stamp this signature with additional headers.
Arguments:
visitor (StampingVisitor): Visitor API object.
headers (Dict): Stamps that should be added to headers.
"""
headers = headers.copy()
if visitor is not None:
headers.update(visitor.on_signature(self, **headers))
else:
headers["stamped_headers"] = [header for header in headers.keys() if header not in self.options]
_merge_dictionaries(headers, self.options)
return self.set(**headers)
def _with_list_option(self, key):
items = self.options.setdefault(key, [])
if not isinstance(items, MutableSequence):
@ -633,6 +807,7 @@ class _chain(Signature):
args = args if args else ()
kwargs = kwargs if kwargs else []
app = self.app
if app.conf.task_always_eager:
with allow_join_result():
return self.apply(args, kwargs, **options)
@ -659,6 +834,10 @@ class _chain(Signature):
task_id, group_id, chord, group_index=group_index,
)
groups = self.options.get("groups")
stamped_headers = self.options.get("stamped_headers")
self.stamp(visitor=GroupStampingVisitor(groups=groups, stamped_headers=stamped_headers))
if results_from_prepare:
if link:
tasks[0].extend_list_option('link', link)
@ -689,6 +868,17 @@ class _chain(Signature):
)
return results[0]
def stamp(self, visitor=None, **headers):
if visitor is not None:
headers.update(visitor.on_chain_start(self, **headers))
super().stamp(visitor=visitor, **headers)
for task in self.tasks:
task.stamp(visitor=visitor, **headers)
if visitor is not None:
visitor.on_chain_end(self, **headers)
def prepare_steps(self, args, kwargs, tasks,
root_id=None, parent_id=None, link_error=None, app=None,
last_task_id=None, group_id=None, chord_body=None,
@ -728,7 +918,7 @@ class _chain(Signature):
task = from_dict(task, app=app)
if isinstance(task, group):
# when groups are nested, they are unrolled - all tasks within
# groups within groups should be called in parallel
# groups should be called in parallel
task = maybe_unroll_group(task)
# first task gets partial args from chain
@ -816,6 +1006,9 @@ class _chain(Signature):
def apply(self, args=None, kwargs=None, **options):
args = args if args else ()
kwargs = kwargs if kwargs else {}
groups = self.options.get("groups")
stamped_headers = self.options.get("stamped_headers")
self.stamp(visitor=GroupStampingVisitor(groups=groups, stamped_headers=stamped_headers))
last, (fargs, fkwargs) = None, (args, kwargs)
for task in self.tasks:
res = task.clone(fargs, fkwargs).apply(
@ -1097,6 +1290,11 @@ class group(Signature):
options, group_id, root_id = self._freeze_gid(options)
tasks = self._prepared(self.tasks, [], group_id, root_id, app)
groups = self.options.get("groups")
stamped_headers = self.options.get("stamped_headers")
self.stamp(visitor=GroupStampingVisitor(groups=groups, stamped_headers=stamped_headers))
p = barrier()
results = list(self._apply_tasks(tasks, producer, app, p,
args=args, kwargs=kwargs, **options))
@ -1120,6 +1318,9 @@ class group(Signature):
def apply(self, args=None, kwargs=None, **options):
args = args if args else ()
kwargs = kwargs if kwargs else {}
groups = self.options.get("groups")
stamped_headers = self.options.get("stamped_headers")
self.stamp(visitor=GroupStampingVisitor(groups=groups, stamped_headers=stamped_headers))
app = self.app
if not self.tasks:
return self.freeze() # empty group returns GroupResult
@ -1133,6 +1334,28 @@ class group(Signature):
for task in self.tasks:
task.set_immutable(immutable)
def stamp(self, visitor=None, **headers):
if visitor is not None:
headers.update(visitor.on_group_start(self, **headers))
super().stamp(visitor=visitor, **headers)
if isinstance(self.tasks, _regen):
self.tasks.map(_partial(_stamp_regen_task, visitor=visitor, **headers))
else:
new_tasks = []
for task in self.tasks:
task = maybe_signature(task, app=self.app)
task.stamp(visitor=visitor, **headers)
new_tasks.append(task)
if isinstance(self.tasks, MutableSequence):
self.tasks[:] = new_tasks
else:
self.tasks = new_tasks
if visitor is not None:
visitor.on_group_end(self, **headers)
def link(self, sig):
# Simply link to first task. Doing this is slightly misleading because
# the callback may be executed before all children in the group are
@ -1225,7 +1448,10 @@ class group(Signature):
def _freeze_gid(self, options):
# remove task_id and use that as the group_id,
# if we don't remove it then every task will have the same id...
options = dict(self.options, **options)
options = {**self.options, **{
k: v for k, v in options.items()
if k not in self._IMMUTABLE_OPTIONS or k not in self.options
}}
options['group_id'] = group_id = (
options.pop('task_id', uuid()))
return options, group_id, options.get('root_id')
@ -1403,26 +1629,52 @@ class _chord(Signature):
# first freeze all tasks in the header
header_result = self.tasks.freeze(
parent_id=parent_id, root_id=root_id, chord=self.body)
# secondly freeze all tasks in the body: those that should be called after the header
body_result = self.body.freeze(
_id, root_id=root_id, chord=chord, group_id=group_id,
group_index=group_index)
# we need to link the body result back to the group result,
# but the body may actually be a chain,
# so find the first result without a parent
node = body_result
seen = set()
while node:
if node.id in seen:
raise RuntimeError('Recursive result parents')
seen.add(node.id)
if node.parent is None:
node.parent = header_result
break
node = node.parent
self.id = self.tasks.id
# secondly freeze all tasks in the body: those that should be called after the header
body_result = None
if self.body:
body_result = self.body.freeze(
_id, root_id=root_id, chord=chord, group_id=group_id,
group_index=group_index)
# we need to link the body result back to the group result,
# but the body may actually be a chain,
# so find the first result without a parent
node = body_result
seen = set()
while node:
if node.id in seen:
raise RuntimeError('Recursive result parents')
seen.add(node.id)
if node.parent is None:
node.parent = header_result
break
node = node.parent
return body_result
def stamp(self, visitor=None, **headers):
if visitor is not None and self.body is not None:
headers.update(visitor.on_chord_body(self, **headers))
self.body.stamp(visitor=visitor, **headers)
if visitor is not None:
headers.update(visitor.on_chord_header_start(self, **headers))
super().stamp(visitor=visitor, **headers)
tasks = self.tasks
if isinstance(tasks, group):
tasks = tasks.tasks
if isinstance(tasks, _regen):
tasks.map(_partial(_stamp_regen_task, visitor=visitor, **headers))
else:
for task in tasks:
task.stamp(visitor=visitor, **headers)
if visitor is not None:
visitor.on_chord_header_end(self, **headers)
def apply_async(self, args=None, kwargs=None, task_id=None,
producer=None, publisher=None, connection=None,
router=None, result_cls=None, **options):
@ -1441,7 +1693,13 @@ class _chord(Signature):
return self.apply(args, kwargs,
body=body, task_id=task_id, **options)
groups = self.options.get("groups")
stamped_headers = self.options.get("stamped_headers")
self.stamp(visitor=GroupStampingVisitor(groups=groups, stamped_headers=stamped_headers))
tasks.stamp(visitor=GroupStampingVisitor(groups=groups, stamped_headers=stamped_headers))
merged_options = dict(self.options, **options) if options else self.options
option_task_id = merged_options.pop("task_id", None)
if task_id is None:
task_id = option_task_id
@ -1453,9 +1711,13 @@ class _chord(Signature):
propagate=True, body=None, **options):
args = args if args else ()
kwargs = kwargs if kwargs else {}
stamped_headers = self.options.get("stamped_headers")
groups = self.options.get("groups")
body = self.body if body is None else body
tasks = (self.tasks.clone() if isinstance(self.tasks, group)
else group(self.tasks, app=self.app))
self.stamp(visitor=GroupStampingVisitor(groups=groups, stamped_headers=stamped_headers))
tasks.stamp(visitor=GroupStampingVisitor(groups=groups, stamped_headers=stamped_headers))
return body.apply(
args=(tasks.apply(args, kwargs).get(propagate=propagate),),
)
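
The option-merging rule introduced in the `Signature` hunks above (incoming options override stored ones, except for immutable keys that are already set) can be sketched in isolation. This is a minimal illustration under stated assumptions, not the library's implementation: `merge_options` is a hypothetical helper name, and plain dicts stand in for signature options.

```python
_IMMUTABLE_OPTIONS = {"group_id", "stamped_headers"}


def merge_options(current, incoming):
    """Merge incoming options into current ones, protecting immutable keys.

    Every header listed under "stamped_headers" is treated as immutable too,
    so an existing stamp cannot be overwritten by a later call.
    """
    immutable = set(_IMMUTABLE_OPTIONS)
    if "stamped_headers" in current:
        immutable |= set(current["stamped_headers"])
    return {
        **current,
        **{
            k: v for k, v in incoming.items()
            # An immutable key may still be set once, when it is not present yet.
            if k not in immutable or k not in current
        },
    }


current = {"groups": ["g-1"], "stamped_headers": ["groups"], "priority": 1}
incoming = {"groups": ["other"], "stamped_headers": ["x"], "priority": 9, "queue": "fast"}
merged = merge_options(current, incoming)
print(merged)
```

Here `priority` and `queue` are taken from the incoming options, while `groups` and `stamped_headers` keep their stored values.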

4
celery/utils/functional.py

@ -200,6 +200,10 @@ class _regen(UserList, list):
def __reduce__(self):
return list, (self.data,)
def map(self, func):
self.__consumed = [func(el) for el in self.__consumed]
self.__it = map(func, self.__it)
def __length_hint__(self):
return self.__it.__length_hint__()
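
The idea behind the new `map` method, applying a function to the already consumed items immediately and to the remaining items lazily, can be shown with a toy class. `LazyList` is a hypothetical stand-in written for this sketch and is not Celery's `_regen`.

```python
class LazyList:
    """Toy wrapper around an iterator that remembers what it has consumed."""

    def __init__(self, iterable):
        self._it = iter(iterable)
        self._consumed = []

    def take_one(self):
        """Consume a single item from the underlying iterator."""
        item = next(self._it)
        self._consumed.append(item)
        return item

    def map(self, func):
        # Items already pulled from the iterator are transformed right away;
        # the rest are transformed only when they are eventually consumed.
        self._consumed = [func(el) for el in self._consumed]
        self._it = map(func, self._it)

    def __iter__(self):
        yield from self._consumed
        for item in self._it:
            self._consumed.append(item)
            yield item


calls = []


def stamp(item):
    calls.append(item)
    return f"stamped:{item}"


lazy = LazyList(["a", "b", "c"])
lazy.take_one()                 # "a" is now in the consumed part
lazy.map(stamp)
calls_after_map = list(calls)   # only "a" has been stamped so far
result = list(lazy)             # draining the iterator stamps "b" and "c"
```

This is why stamping a group whose tasks come from a generator does not force the generator to be exhausted at stamping time.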

12
celery/worker/request.py

@ -314,6 +314,18 @@ class Request:
def replaced_task_nesting(self):
return self._request_dict.get('replaced_task_nesting', 0)
@property
def groups(self):
return self._request_dict.get('groups', [])
@property
def stamped_headers(self) -> list:
return self._request_dict.get('stamped_headers', [])
@property
def stamps(self) -> dict:
return {header: self._request_dict[header] for header in self.stamped_headers}
@property
def correlation_id(self):
# used similarly to reply_to

88
docs/userguide/canvas.rst

@ -1130,3 +1130,91 @@ of one:
This means that the first task will have a countdown of one second, the second
task a countdown of two seconds, and so on.
Stamping
========
.. versionadded:: 5.3
The goal of the Stamping API is to make it possible to label
a signature and its components for debugging purposes.
For example, when the canvas is a complex structure, it may be necessary to
label some or all elements of the formed structure. The complexity
increases even more when nested groups are unrolled or chain
elements are replaced. In such cases, it may be necessary to
understand which group an element is a part of or on what nested
level it is. This requires a mechanism that traverses the canvas
elements and marks them with specific metadata. The stamping API
allows doing that based on the Visitor pattern.
For example,
.. code-block:: pycon
>>> sig1 = add.si(2, 2)
>>> sig1_res = sig1.freeze()
>>> g = group(sig1, add.si(3, 3))
>>> g.stamp(stamp='your_custom_stamp')
>>> res = g.apply_async()
>>> res.get(timeout=TIMEOUT)
[4, 6]
>>> sig1_res._get_task_meta()['stamp']
['your_custom_stamp']
will initialize a group ``g`` and mark its components with stamp ``your_custom_stamp``.
For this feature to be useful, you need to set the :setting:`result_extended`
configuration option to ``True`` (``result_extended = True``).
Group stamping
--------------
When the ``apply`` or ``apply_async`` method is called,
the signature is automatically stamped with the group id.
Stamps are stored in the task header.
For example, after
.. code-block:: pycon
>>> g.apply_async()
the header of task ``sig1`` will store the stamp ``groups`` containing ``g.id``.
In the case of nested groups, the order of the stamps corresponds
to the nesting level. The group stamping is idempotent;
the task cannot be stamped twice with the same group id.
Canvas stamping
----------------
In addition to the default group stamping, we can also stamp
canvas with custom stamps, as shown in the example.
Custom stamping
----------------
If more complex stamping logic is required, it is possible
to implement custom stamping behavior based on the Visitor
pattern. The class that implements this custom logic must
inherit from ``StampingVisitor`` and implement the appropriate methods.
For example, the following ``InGroupVisitor`` labels
tasks that are inside a group with the label ``in_group``.
.. code-block:: python
class InGroupVisitor(StampingVisitor):
def __init__(self):
self.in_group = False
def on_group_start(self, group, **headers) -> dict:
self.in_group = True
return {"in_group": [self.in_group], "stamped_headers": ["in_group"]}
def on_group_end(self, group, **headers) -> None:
self.in_group = False
def on_chain_start(self, chain, **headers) -> dict:
return {"in_group": [self.in_group], "stamped_headers": ["in_group"]}
def on_signature(self, sig, **headers) -> dict:
return {"in_group": [self.in_group], "stamped_headers": ["in_group"]}
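
The statement above that the order of the stamps follows the nesting level can be illustrated with a toy model of the visitor walk. Everything in this sketch (`ToyTask`, `ToyGroup`, `ToyGroupVisitor`, the free function `stamp`) is hypothetical and only mirrors the shape of the API in this diff; it does not import or reproduce Celery's classes.

```python
from dataclasses import dataclass, field


@dataclass
class ToyTask:
    name: str
    options: dict = field(default_factory=dict)


@dataclass
class ToyGroup:
    id: str
    tasks: list
    options: dict = field(default_factory=dict)


class ToyGroupVisitor:
    """Keeps a stack of the group ids that enclose the element being visited."""

    def __init__(self):
        self.groups = []

    def on_group_start(self, group):
        if group.id not in self.groups:
            self.groups.append(group.id)
        return {"groups": list(self.groups)}

    def on_group_end(self, group):
        self.groups.pop()

    def on_signature(self, task):
        return {"groups": list(self.groups)}


def stamp(node, visitor):
    """Walk the toy canvas depth-first and record the visitor's headers."""
    if isinstance(node, ToyGroup):
        node.options.update(visitor.on_group_start(node))
        for child in node.tasks:
            stamp(child, visitor)
        visitor.on_group_end(node)
    else:
        node.options.update(visitor.on_signature(node))


inner = ToyGroup("inner", [ToyTask("b"), ToyTask("c")])
outer = ToyGroup("outer", [ToyTask("a"), inner, ToyTask("d")])
stamp(outer, ToyGroupVisitor())

stamps_by_task = {
    "a": outer.tasks[0].options["groups"],
    "b": inner.tasks[0].options["groups"],
    "d": outer.tasks[2].options["groups"],
}
print(stamps_by_task)
```

Tasks directly inside the outer group carry only its id, while tasks in the nested group carry both ids, outermost first. Popping the id in `on_group_end` is what restores the shorter list for task `d`.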

3
t/integration/conftest.py

@ -38,7 +38,8 @@ def celery_config():
'cassandra_keyspace': 'tests',
'cassandra_table': 'tests',
'cassandra_read_consistency': 'ONE',
'cassandra_write_consistency': 'ONE'
'cassandra_write_consistency': 'ONE',
'result_extended': True
}

7
t/integration/tasks.py

@ -1,3 +1,4 @@
from collections.abc import Iterable
from time import sleep
from celery import Signature, Task, chain, chord, group, shared_task
@ -87,6 +88,12 @@ def tsum(nums):
    return sum(nums)


@shared_task
def xsum(nums):
    """Sum of ints and lists."""
    return sum(sum(num) if isinstance(num, Iterable) else num for num in nums)


@shared_task(bind=True)
def add_replaced(self, x, y):
    """Add two numbers (via the add task)."""

225
t/integration/test_canvas.py

@ -20,7 +20,7 @@ from .tasks import (ExpectedException, add, add_chord_to_chord, add_replaced, ad
errback_new_style, errback_old_style, fail, fail_replaced, identity, ids, print_unicode,
raise_error, redis_count, redis_echo, replace_with_chain, replace_with_chain_which_raises,
replace_with_empty_chain, retry_once, return_exception, return_priority, second_order_replace1,
tsum, write_to_file_and_return_int)
tsum, write_to_file_and_return_int, xsum)
RETRYABLE_EXCEPTIONS = (OSError, ConnectionError, TimeoutError)
@ -31,7 +31,6 @@ def is_retryable_exception(exc):
TIMEOUT = 60
_flaky = pytest.mark.flaky(reruns=5, reruns_delay=1, cause=is_retryable_exception)
_timeout = pytest.mark.timeout(timeout=300)
@ -47,7 +46,7 @@ def await_redis_echo(expected_msgs, redis_key="redis-echo", timeout=TIMEOUT):
redis_connection = get_redis_connection()
if isinstance(expected_msgs, (str, bytes, bytearray)):
expected_msgs = (expected_msgs, )
expected_msgs = (expected_msgs,)
expected_msgs = collections.Counter(
e if not isinstance(e, str) else e.encode("utf-8")
for e in expected_msgs
@ -127,7 +126,7 @@ class test_link_error:
args=("test",),
link_error=retry_once.s(countdown=None)
)
assert result.get(timeout=TIMEOUT, propagate=False) == exception
assert result.get(timeout=TIMEOUT / 10, propagate=False) == exception
@flaky
def test_link_error_using_signature_eager(self):
@ -148,7 +147,7 @@ class test_link_error:
fail.link_error(retrun_exception)
exception = ExpectedException("Task expected to fail", "test")
assert (fail.delay().get(timeout=TIMEOUT, propagate=False), True) == (
assert (fail.delay().get(timeout=TIMEOUT / 10, propagate=False), True) == (
exception, True)
@ -166,11 +165,11 @@ class test_chain:
@flaky
def test_complex_chain(self, manager):
g = group(add.s(i) for i in range(4))
c = (
add.s(2, 2) | (
add.s(4) | add_replaced.s(8) | add.s(16) | add.s(32)
) |
group(add.s(i) for i in range(4))
) | g
)
res = c()
assert res.get(timeout=TIMEOUT) == [64, 65, 66, 67]
@ -187,7 +186,7 @@ class test_chain:
)
)
res = c()
assert res.get(timeout=TIMEOUT) == [4, 5]
assert res.get(timeout=TIMEOUT / 10) == [4, 5]
def test_chain_of_chain_with_a_single_task(self, manager):
sig = signature('any_taskname', queue='any_q')
@ -482,7 +481,7 @@ class test_chain:
group(identity.s(42), identity.s(42)), # [42, 42]
)
res = sig.delay()
assert res.get(timeout=TIMEOUT) == [42, 42]
assert res.get(timeout=TIMEOUT / 10) == [42, 42]
def test_nested_chain_group_mid(self, manager):
"""
@ -494,9 +493,9 @@ class test_chain:
raise pytest.skip(e.args[0])
sig = chain(
identity.s(42), # 42
group(identity.s(), identity.s()), # [42, 42]
identity.s(), # [42, 42]
identity.s(42), # 42
group(identity.s(), identity.s()), # [42, 42]
identity.s(), # [42, 42]
)
res = sig.delay()
assert res.get(timeout=TIMEOUT) == [42, 42]
@ -506,8 +505,8 @@ class test_chain:
Test that a final group in a chain with preceding tasks completes.
"""
sig = chain(
identity.s(42), # 42
group(identity.s(), identity.s()), # [42, 42]
identity.s(42), # 42
group(identity.s(), identity.s()), # [42, 42]
)
res = sig.delay()
assert res.get(timeout=TIMEOUT) == [42, 42]
@ -777,6 +776,46 @@ class test_result_set:
class test_group:
    def test_group_stamping(self, manager, subtests):
        if not manager.app.conf.result_backend.startswith('redis'):
            raise pytest.skip('Requires redis result backend.')

        sig1 = add.s(1, 1000)
        sig1_res = sig1.freeze()
        g1 = group(sig1, add.s(1, 2000))
        g1_res = g1.freeze()
        res = g1.apply_async()
        res.get(timeout=TIMEOUT)

        with subtests.test("sig_1 is stamped", groups=[g1_res.id]):
            assert sig1_res._get_task_meta()["groups"] == [g1_res.id]

    def test_nested_group_stamping(self, manager, subtests):
        if not manager.app.conf.result_backend.startswith('redis'):
            raise pytest.skip('Requires redis result backend.')

        sig1 = add.s(2, 2)
        sig2 = add.s(2)
        sig1_res = sig1.freeze()
        sig2_res = sig2.freeze()
        g2 = group(sig2, chain(add.s(4), add.s(2)))
        g2_res = g2.freeze()
        g1 = group(sig1, chain(add.s(1, 1), g2))
        g1_res = g1.freeze()
        res = g1.apply_async()
        res.get(timeout=TIMEOUT)

        with subtests.test("sig1 is stamped", groups=[g1_res.id]):
            assert sig1_res._get_task_meta()['groups'] == [g1_res.id]

        with subtests.test("sig2 is stamped", groups=[g1_res.id, g2_res.id]):
            assert sig2_res._get_task_meta()['groups'] == \
                [g1_res.id, g2_res.id]
@flaky
def test_ready_with_exception(self, manager):
if not manager.app.conf.result_backend.startswith('redis'):
@ -850,7 +889,7 @@ class test_group:
"""
Test that a simple group completes.
"""
sig = group(identity.s(42), identity.s(42)) # [42, 42]
sig = group(identity.s(42), identity.s(42)) # [42, 42]
res = sig.delay()
assert res.get(timeout=TIMEOUT) == [42, 42]
@ -860,7 +899,7 @@ class test_group:
"""
sig = group(
group(identity.s(42), identity.s(42)), # [42, 42]
) # [42, 42] due to unrolling
) # [42, 42] due to unrolling
res = sig.delay()
assert res.get(timeout=TIMEOUT) == [42, 42]
@ -871,8 +910,8 @@ class test_group:
raise pytest.skip(e.args[0])
gchild_sig = identity.si(42)
child_chord = chord((gchild_sig, ), identity.s())
group_sig = group((child_chord, ))
child_chord = chord((gchild_sig,), identity.s())
group_sig = group((child_chord,))
res = group_sig.delay()
# Wait for the result to land and confirm its value is as expected
assert res.get(timeout=TIMEOUT) == [[42]]
@ -884,9 +923,9 @@ class test_group:
raise pytest.skip(e.args[0])
gchild_count = 42
gchild_sig = chain((identity.si(1337), ) * gchild_count)
child_chord = chord((gchild_sig, ), identity.s())
group_sig = group((child_chord, ))
gchild_sig = chain((identity.si(1337),) * gchild_count)
child_chord = chord((gchild_sig,), identity.s())
group_sig = group((child_chord,))
res = group_sig.delay()
# Wait for the result to land and confirm its value is as expected
assert res.get(timeout=TIMEOUT) == [[1337]]
@ -898,9 +937,9 @@ class test_group:
raise pytest.skip(e.args[0])
gchild_count = 42
gchild_sig = group((identity.si(1337), ) * gchild_count)
child_chord = chord((gchild_sig, ), identity.s())
group_sig = group((child_chord, ))
gchild_sig = group((identity.si(1337),) * gchild_count)
child_chord = chord((gchild_sig,), identity.s())
group_sig = group((child_chord,))
res = group_sig.delay()
# Wait for the result to land and confirm its value is as expected
assert res.get(timeout=TIMEOUT) == [[1337] * gchild_count]
@ -913,10 +952,10 @@ class test_group:
gchild_count = 42
gchild_sig = chord(
(identity.si(1337), ) * gchild_count, identity.si(31337),
(identity.si(1337),) * gchild_count, identity.si(31337),
)
child_chord = chord((gchild_sig, ), identity.s())
group_sig = group((child_chord, ))
child_chord = chord((gchild_sig,), identity.s())
group_sig = group((child_chord,))
res = group_sig.delay()
# Wait for the result to land and confirm its value is as expected
assert res.get(timeout=TIMEOUT) == [[31337]]
@ -931,19 +970,19 @@ class test_group:
child_chord = chord(
(
identity.si(42),
chain((identity.si(42), ) * gchild_count),
group((identity.si(42), ) * gchild_count),
chord((identity.si(42), ) * gchild_count, identity.si(1337)),
chain((identity.si(42),) * gchild_count),
group((identity.si(42),) * gchild_count),
chord((identity.si(42),) * gchild_count, identity.si(1337)),
),
identity.s(),
)
group_sig = group((child_chord, ))
group_sig = group((child_chord,))
res = group_sig.delay()
# Wait for the result to land and confirm its value is as expected. The
# group result gets unrolled into the encapsulating chord, hence the
# weird unpacking below
assert res.get(timeout=TIMEOUT) == [
[42, 42, *((42, ) * gchild_count), 1337]
[42, 42, *((42,) * gchild_count), 1337]
]
@pytest.mark.xfail(raises=TimeoutError, reason="#6734")
@ -953,8 +992,8 @@ class test_group:
except NotImplementedError as e:
raise pytest.skip(e.args[0])
child_chord = chord(identity.si(42), chain((identity.s(), )))
group_sig = group((child_chord, ))
child_chord = chord(identity.si(42), chain((identity.s(),)))
group_sig = group((child_chord,))
res = group_sig.delay()
# The result can be expected to timeout since it seems like its
# underlying promise might not be getting fulfilled (ref #6734). Pick a
@ -1219,6 +1258,43 @@ def assert_ping(manager):
class test_chord:
    def test_chord_stamping_two_levels(self, manager, subtests):
        """
        For a group within a chord, test that group stamps are stored in
        the correct order.
        """
        try:
            manager.app.backend.ensure_chords_allowed()
        except NotImplementedError as e:
            raise pytest.skip(e.args[0])

        sig_1 = add.s(2, 2)
        sig_2 = add.s(2)
        sig_1_res = sig_1.freeze()
        sig_2_res = sig_2.freeze()
        g2 = group(
            sig_2,
            add.s(4),
        )
        g2_res = g2.freeze()
        sig_sum = xsum.s()
        sig_sum.freeze()
        g1 = chord([sig_1, chain(add.s(4, 4), g2)], sig_sum)
        g1.freeze()
        res = g1.apply_async()
        res.get(timeout=TIMEOUT)

        with subtests.test("sig_1_res is stamped", groups=[g1.tasks.id]):
            assert sig_1_res._get_task_meta()['groups'] == [g1.tasks.id]

        with subtests.test("sig_2_res is stamped", groups=[g1.id]):
            assert sig_2_res._get_task_meta()['groups'] == [g1.tasks.id, g2_res.id]
@flaky
def test_simple_chord_with_a_delay_in_group_save(self, manager, monkeypatch):
try:
@ -1589,6 +1665,7 @@ class test_chord:
with open(file_name) as file_handle:
# ensures chord header generators tasks are processed incrementally #3021
assert file_handle.readline() == '0\n', "Chord header was unrolled too early"
yield write_to_file_and_return_int.s(file_name, i)
with tempfile.NamedTemporaryFile(mode='w', delete=False) as tmp_file:
@ -1752,7 +1829,7 @@ class test_chord:
(
group(identity.s(42), identity.s(42)), # [42, 42]
),
identity.s() # [42, 42]
identity.s() # [42, 42]
)
res = sig.delay()
assert res.get(timeout=TIMEOUT) == [42, 42]
@ -1772,14 +1849,14 @@ class test_chord:
sig = chord(
group(
chain(
identity.s(42), # 42
identity.s(42), # 42
group(
identity.s(), # 42
identity.s(), # 42
), # [42, 42]
), # [42, 42]
), # [[42, 42]] since the chain prevents unrolling
identity.s(), # [[42, 42]]
identity.s(), # 42
identity.s(), # 42
), # [42, 42]
), # [42, 42]
), # [[42, 42]] since the chain prevents unrolling
identity.s(), # [[42, 42]]
)
res = sig.delay()
assert res.get(timeout=TIMEOUT) == [[42, 42]]
@ -1817,13 +1894,13 @@ class test_chord:
child_sig = fail.s()
chord_sig = chord((child_sig, ), identity.s())
chord_sig = chord((child_sig,), identity.s())
with subtests.test(msg="Error propagates from simple header task"):
res = chord_sig.delay()
with pytest.raises(ExpectedException):
res.get(timeout=TIMEOUT)
chord_sig = chord((identity.si(42), ), child_sig)
chord_sig = chord((identity.si(42),), child_sig)
with subtests.test(msg="Error propagates from simple body task"):
res = chord_sig.delay()
with pytest.raises(ExpectedException):
@ -1841,7 +1918,7 @@ class test_chord:
errback = redis_echo.si(errback_msg, redis_key=redis_key)
child_sig = fail.s()
chord_sig = chord((child_sig, ), identity.s())
chord_sig = chord((child_sig,), identity.s())
chord_sig.link_error(errback)
redis_connection.delete(redis_key)
with subtests.test(msg="Error propagates from simple header task"):
@ -1853,7 +1930,7 @@ class test_chord:
):
await_redis_echo({errback_msg, }, redis_key=redis_key)
chord_sig = chord((identity.si(42), ), child_sig)
chord_sig = chord((identity.si(42),), child_sig)
chord_sig.link_error(errback)
redis_connection.delete(redis_key)
with subtests.test(msg="Error propagates from simple body task"):
@ -1879,7 +1956,7 @@ class test_chord:
errback = errback_task.s()
child_sig = fail.s()
chord_sig = chord((child_sig, ), identity.s())
chord_sig = chord((child_sig,), identity.s())
chord_sig.link_error(errback)
expected_redis_key = chord_sig.body.freeze().id
redis_connection.delete(expected_redis_key)
@ -1892,7 +1969,7 @@ class test_chord:
):
await_redis_count(1, redis_key=expected_redis_key)
chord_sig = chord((identity.si(42), ), child_sig)
chord_sig = chord((identity.si(42),), child_sig)
chord_sig.link_error(errback)
expected_redis_key = chord_sig.body.freeze().id
redis_connection.delete(expected_redis_key)
@ -1914,7 +1991,7 @@ class test_chord:
child_sig = chain(identity.si(42), fail.s(), identity.si(42))
chord_sig = chord((child_sig, ), identity.s())
chord_sig = chord((child_sig,), identity.s())
with subtests.test(
msg="Error propagates from header chain which fails before the end"
):
@ -1922,7 +1999,7 @@ class test_chord:
with pytest.raises(ExpectedException):
res.get(timeout=TIMEOUT)
chord_sig = chord((identity.si(42), ), child_sig)
chord_sig = chord((identity.si(42),), child_sig)
with subtests.test(
msg="Error propagates from body chain which fails before the end"
):
@ -1942,7 +2019,7 @@ class test_chord:
errback = redis_echo.si(errback_msg, redis_key=redis_key)
child_sig = chain(identity.si(42), fail.s(), identity.si(42))