mirror of https://github.com/celery/celery.git
Added integration test test_chord_header_id_duplicated_on_rabbitmq_msg_duplication() (#7692)
When a task that precedes a chord in a chain is duplicated by RabbitMQ (for whatever reason), the chord header id was not duplicated with it, so the duplicated chord header ended up with a different id. This test ensures that the chord header's id is preserved in the face of such an edge case.
parent 6f95c040ae
commit 3db7c9dde9
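For orientation, here is a minimal sketch of the scenario this commit addresses. It assumes a locally reachable RabbitMQ broker and Redis backend; the task names (`identity`, `record_group_id`) are illustrative stand-ins for the integration tasks added below, not part of the commit itself, and the broker redelivery is not simulated here.

from celery import Celery, chain, chord

app = Celery('repro', broker='pyamqp://', backend='redis://')

@app.task
def identity(x):
    # Placeholder task that simply returns its argument.
    return x

@app.task(bind=True)
def record_group_id(self, _):
    # Inside a chord header the task runs as part of a group;
    # self.request.group carries that group's id.
    return self.request.group

# t1 precedes the chord. If the broker redelivers t1, the chord header is
# rebuilt on the second execution; before this fix, the rebuild minted a
# fresh header (group) id instead of reusing the one from the first build.
workflow = chain(identity.s(42),
                 chord([record_group_id.s(), record_group_id.s()], identity.s()))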
@@ -83,7 +83,7 @@ jobs:
       fail-fast: false
       matrix:
         python-version: ['3.7', '3.8', '3.9', '3.10']
-        toxenv: ['redis', 'rabbitmq']
+        toxenv: ['redis', 'rabbitmq', 'rabbitmq_redis']

     services:
       redis:
@@ -605,7 +605,7 @@ class Signature(dict):

     def __deepcopy__(self, memo):
         memo[id(self)] = self
-        return dict(self)
+        return dict(self)  # TODO: Potential bug of being a shallow copy

     def __invert__(self):
         return self.apply_async().get()
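The TODO added above flags that returning `dict(self)` copies only the top level; any nested mappings inside the new dict are still shared with the original. A quick standalone illustration of that shallow-copy behaviour, using plain dicts rather than Celery internals:

import copy

original = {'options': {'task_id': 'abc'}}

clone = dict(original)             # shallow: the nested dict is shared
clone['options']['task_id'] = 'xyz'
assert original['options']['task_id'] == 'xyz'   # mutation leaks back

deep = copy.deepcopy(original)     # deep: the nested dict is copied
deep['options']['task_id'] = 'restored'
assert original['options']['task_id'] == 'xyz'   # original untouched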
@@ -1687,7 +1687,7 @@ class _chord(Signature):
         body = body.clone(**options)
         app = self._get_app(body)
         tasks = (self.tasks.clone() if isinstance(self.tasks, group)
-                 else group(self.tasks, app=app))
+                 else group(self.tasks, app=app, task_id=self.options.get('task_id', uuid())))
         if app.conf.task_always_eager:
             with allow_join_result():
                 return self.apply(args, kwargs,
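The one-line change above is the heart of the fix: when the chord header group is constructed, it now reuses a `task_id` already present in the chord's own options instead of letting `group()` mint a fresh id on every build. A simplified, self-contained sketch of why reusing a stored id makes rebuilds idempotent; `build_header` is a hypothetical helper for illustration, not Celery's actual group constructor:

from uuid import uuid4

def build_header(tasks, options):
    # Reuse the previously assigned id if the header is rebuilt after a
    # duplicated delivery; only mint a new one on the very first build.
    group_id = options.get('task_id') or str(uuid4())
    options['task_id'] = group_id
    return {'tasks': list(tasks), 'group_id': group_id}

opts = {}
first = build_header(['t2', 't3'], opts)
second = build_header(['t2', 't3'], opts)   # simulated duplicate build
assert first['group_id'] == second['group_id']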
@@ -3,7 +3,7 @@ addopts = "--strict-markers"
 testpaths = "t/unit/"
 python_classes = "test_*"
 xfail_strict=true
-markers = ["sleepdeprived_patched_module", "masked_modules", "patched_environ", "patched_module"]
+markers = ["sleepdeprived_patched_module", "masked_modules", "patched_environ", "patched_module", "flaky", "timeout"]

 [tool.mypy]
 warn_unused_configs = true
@@ -2,4 +2,5 @@ pytz>dev
 git+https://github.com/celery/py-amqp.git
 git+https://github.com/celery/kombu.git
 git+https://github.com/celery/billiard.git
 vine>=5.0.0
+isort~=5.10.1
@@ -241,6 +241,12 @@ def redis_echo(message, redis_key="redis-echo"):
     redis_connection.rpush(redis_key, message)


+@shared_task(bind=True)
+def redis_echo_group_id(self, _, redis_key="redis-group-ids"):
+    redis_connection = get_redis_connection()
+    redis_connection.rpush(redis_key, self.request.group)
+
+
 @shared_task
 def redis_count(redis_key="redis-count"):
     """Task that increments a specified or well-known redis key."""
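The new `redis_echo_group_id` task relies on `bind=True`, which passes the task instance in as `self` so the running request context is available; `self.request.group` is the id of the group the task executes in, or None when it runs on its own. A small illustrative sketch of the same pattern (app name and broker/backend URLs are assumptions):

from celery import Celery, group

app = Celery('sketch', broker='pyamqp://', backend='redis://')

@app.task(bind=True)
def show_group_id(self, value):
    # bind=True injects the task instance; request.group identifies the
    # enclosing group, which is what the integration test collects.
    print(f'group id: {self.request.group}')
    return value

# Both members of a group see the same group id.
g = group(show_group_id.s(1), show_group_id.s(2))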
@@ -12,15 +12,16 @@ from celery import chain, chord, group, signature
 from celery.backends.base import BaseKeyValueStoreBackend
 from celery.exceptions import ImproperlyConfigured, TimeoutError
 from celery.result import AsyncResult, GroupResult, ResultSet
+from celery.signals import before_task_publish

 from . import tasks
 from .conftest import TEST_BACKEND, get_active_redis_channels, get_redis_connection
 from .tasks import (ExpectedException, add, add_chord_to_chord, add_replaced, add_to_all, add_to_all_to_chord,
                     build_chain_inside_task, collect_ids, delayed_sum, delayed_sum_with_soft_guard,
                     errback_new_style, errback_old_style, fail, fail_replaced, identity, ids, print_unicode,
-                    raise_error, redis_count, redis_echo, replace_with_chain, replace_with_chain_which_raises,
-                    replace_with_empty_chain, retry_once, return_exception, return_priority, second_order_replace1,
-                    tsum, write_to_file_and_return_int, xsum)
+                    raise_error, redis_count, redis_echo, redis_echo_group_id, replace_with_chain,
+                    replace_with_chain_which_raises, replace_with_empty_chain, retry_once, return_exception,
+                    return_priority, second_order_replace1, tsum, write_to_file_and_return_int, xsum)

 RETRYABLE_EXCEPTIONS = (OSError, ConnectionError, TimeoutError)
@@ -62,12 +63,36 @@ def await_redis_echo(expected_msgs, redis_key="redis-echo", timeout=TIMEOUT):
         )
         retrieved_key, msg = maybe_key_msg
         assert retrieved_key.decode("utf-8") == redis_key
         expected_msgs[msg] -= 1  # silently accepts unexpected messages

         # There should be no more elements - block momentarily
         assert redis_connection.blpop(redis_key, min(1, timeout)) is None


+def await_redis_list_message_length(expected_length, redis_key="redis-group-ids", timeout=TIMEOUT):
+    """
+    Helper to wait for a specified or well-known redis key to contain a list of the expected length.
+    """
+    sleep(1)
+    redis_connection = get_redis_connection()
+
+    check_interval = 0.1
+    check_max = int(timeout / check_interval)
+
+    for i in range(check_max + 1):
+        length = redis_connection.llen(redis_key)
+
+        if length == expected_length:
+            break
+
+        sleep(check_interval)
+    else:
+        raise TimeoutError(f'{redis_key!r} has length of {length}, but expected to be of length {expected_length}')
+
+    sleep(min(1, timeout))
+    assert redis_connection.llen(redis_key) == expected_length
+
+
 def await_redis_count(expected_count, redis_key="redis-count", timeout=TIMEOUT):
     """
     Helper to wait for a specified or well-known redis key to count to a value.
@@ -95,6 +120,13 @@ def await_redis_count(expected_count, redis_key="redis-count", timeout=TIMEOUT):
     assert int(redis_connection.get(redis_key)) == expected_count


+def compare_group_ids_in_redis(redis_key='redis-group-ids'):
+    redis_connection = get_redis_connection()
+    actual = redis_connection.lrange(redis_key, 0, -1)
+    assert len(actual) >= 2, 'Expected at least 2 group ids in redis'
+    assert actual[0] == actual[1], 'Expected group ids to be equal'
+
+
 class test_link_error:
     @flaky
     def test_link_error_eager(self):
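A subtlety in `await_redis_list_message_length` above is Python's `for ... else`: the `else` suite runs only when the loop finishes without hitting `break`, i.e. when polling never observed the expected length. A standalone illustration of that control flow, with a hypothetical `poll` helper:

def poll(values, expected, max_checks=5):
    for attempt in range(max_checks):
        if values[attempt] == expected:
            break          # found it; the else branch is skipped
    else:
        # Runs only if the loop exhausted all attempts without break.
        raise TimeoutError(f'never saw {expected!r}')
    return attempt

assert poll([0, 1, 2, 2, 2], 2) == 2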
@@ -754,6 +786,78 @@ class test_chain:
         res_obj = orig_sig.delay()
         assert res_obj.get(timeout=TIMEOUT) == 42

+    @pytest.mark.parametrize('redis_key', ['redis-group-ids'])
+    def test_chord_header_id_duplicated_on_rabbitmq_msg_duplication(self, manager, subtests, celery_session_app,
+                                                                    redis_key):
+        """
+        When a task that precedes a chord in a chain is duplicated by RabbitMQ (for whatever reason),
+        the chord header id is not duplicated with it, so the duplicated chord header ends up with a different id.
+        This test ensures that the chord header's id is preserved in the face of such an edge case.
+        To validate that the correct behavior is implemented, we collect the original and duplicated chord header
+        ids in redis and ensure that they are the same.
+        """
+
+        try:
+            manager.app.backend.ensure_chords_allowed()
+        except NotImplementedError as e:
+            raise pytest.skip(e.args[0])
+
+        if manager.app.conf.broker_url.startswith('redis'):
+            raise pytest.xfail('Redis broker does not duplicate the task (t1)')
+
+        # Republish t1 to cause the chain to be executed twice
+        @before_task_publish.connect
+        def before_task_publish_handler(sender=None, body=None, exchange=None, routing_key=None, headers=None,
+                                        properties=None, declare=None, retry_policy=None, **kwargs):
+            """We want to republish t1 to ensure that the chain is executed twice."""
+
+            metadata = {
+                'body': body,
+                'exchange': exchange,
+                'routing_key': routing_key,
+                'properties': properties,
+                'headers': headers,
+            }
+
+            with celery_session_app.producer_pool.acquire(block=True) as producer:
+                # Publish t1 to the message broker just before its regular publish, which causes the duplication
+                return producer.publish(
+                    metadata['body'],
+                    exchange=metadata['exchange'],
+                    routing_key=metadata['routing_key'],
+                    retry=None,
+                    retry_policy=retry_policy,
+                    serializer='json',
+                    delivery_mode=None,
+                    headers=headers,
+                    **kwargs
+                )
+
+        # Clean redis key
+        redis_connection = get_redis_connection()
+        if redis_connection.exists(redis_key):
+            redis_connection.delete(redis_key)
+
+        # Prepare tasks
+        t1, t2, t3, t4 = identity.s(42), redis_echo_group_id.s(), identity.s(), identity.s()
+        c = chain(t1, chord([t2, t3], t4))
+
+        # Delay chain
+        r1 = c.delay()
+        r1.get(timeout=TIMEOUT)
+
+        # Cleanup
+        before_task_publish.disconnect(before_task_publish_handler)
+
+        with subtests.test(msg='Compare group ids via redis list'):
+            await_redis_list_message_length(2, redis_key=redis_key, timeout=15)
+            compare_group_ids_in_redis(redis_key=redis_key)
+
+        # Cleanup
+        redis_connection = get_redis_connection()
+        redis_connection.delete(redis_key)
+
+
 class test_result_set:
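The duplication trick in the test above works because `before_task_publish` fires while a task message is being published, so publishing the same body again from inside the handler yields two broker deliveries of t1. A stripped-down sketch of connecting and cleaning up such a handler; `spy` and the `published` list are illustrative, not part of the test:

from celery.signals import before_task_publish

published = []

def spy(sender=None, body=None, routing_key=None, **kwargs):
    # Record each outgoing task message; a real handler could republish
    # `body` here to force a duplicate delivery, as the test above does.
    published.append((sender, routing_key))

before_task_publish.connect(spy)
try:
    pass  # ... trigger task publishes here ...
finally:
    # Always disconnect, otherwise the handler leaks into later tests.
    before_task_publish.disconnect(spy)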
tox.ini
@@ -3,7 +3,7 @@ requires =
     tox-gh-actions
 envlist =
     {3.7,3.8,3.9,3.10,pypy3}-unit
-    {3.7,3.8,3.9,3.10,pypy3}-integration-{rabbitmq,redis,dynamodb,azureblockblob,cache,cassandra,elasticsearch}
+    {3.7,3.8,3.9,3.10,pypy3}-integration-{rabbitmq_redis,rabbitmq,redis,dynamodb,azureblockblob,cache,cassandra,elasticsearch}

     flake8
     apicheck
@@ -64,6 +64,9 @@ setenv =
     redis: TEST_BROKER=redis://
     redis: TEST_BACKEND=redis://

+    rabbitmq_redis: TEST_BROKER=pyamqp://
+    rabbitmq_redis: TEST_BACKEND=redis://
+
     dynamodb: TEST_BROKER=redis://
     dynamodb: TEST_BACKEND=dynamodb://@localhost:8000
     dynamodb: AWS_ACCESS_KEY_ID=test_aws_key_id
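With both tox.ini changes in place, the hybrid environment should be selectable like any other factor combination, e.g. `tox -e 3.10-integration-rabbitmq_redis`, which per the setenv block above runs the integration suite against a RabbitMQ broker (`pyamqp://`) with a Redis result backend.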