Added integration test test_chord_header_id_duplicated_on_rabbitmq_msg_duplication() (#7692)

When a task that precedes a chord in a chain was duplicated by RabbitMQ (for whatever reason),
the chord header id was not duplicated along with it, so the re-executed chord header ended up with a different id.
This test ensures that the chord header's id is preserved in the face of such an edge case.
Tomer Nosrati 2022-08-14 20:05:57 +03:00 committed by GitHub
parent 6f95c040ae
commit 3db7c9dde9
7 changed files with 124 additions and 10 deletions
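For orientation, here is a minimal sketch of the failure mode this commit guards against. It is an illustration, not code from the commit; the broker/backend URLs and the task definitions are assumptions chosen to mirror the test's RabbitMQ-broker/Redis-backend setup:

    # Hypothetical reproduction sketch, not part of the commit.
    from celery import Celery, chain, chord

    app = Celery('repro', broker='pyamqp://', backend='redis://')

    @app.task
    def identity(x):
        return x

    @app.task
    def xsum(xs):
        return sum(xs)

    # t1 precedes the chord; if RabbitMQ delivers t1's message twice, the chord
    # header is built twice. Before this fix, each build minted a fresh group id,
    # so the two executions disagreed about the chord header's id.
    c = chain(identity.s(42), chord([identity.s(), identity.s()], xsum.s()))
    result = c.delay()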


@@ -83,7 +83,7 @@ jobs:
       fail-fast: false
       matrix:
         python-version: ['3.7', '3.8', '3.9', '3.10']
-        toxenv: ['redis', 'rabbitmq']
+        toxenv: ['redis', 'rabbitmq', 'rabbitmq_redis']

     services:
       redis:


@@ -605,7 +605,7 @@ class Signature(dict):
     def __deepcopy__(self, memo):
         memo[id(self)] = self
-        return dict(self)
+        return dict(self)  # TODO: Potential bug of being a shallow copy

     def __invert__(self):
         return self.apply_async().get()
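Context for the TODO above: dict(self) rebuilds only the outer mapping, so the nested args, kwargs and options of the signature remain shared with the original object; that sharing is what makes this __deepcopy__ effectively a shallow copy despite its name.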
@@ -1687,7 +1687,7 @@ class _chord(Signature):
         body = body.clone(**options)
         app = self._get_app(body)
         tasks = (self.tasks.clone() if isinstance(self.tasks, group)
-                 else group(self.tasks, app=app))
+                 else group(self.tasks, app=app, task_id=self.options.get('task_id', uuid())))
         if app.conf.task_always_eager:
             with allow_join_result():
                 return self.apply(args, kwargs,
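The heart of the fix is the task_id=self.options.get('task_id', uuid()) argument: when the chord's options already carry a task_id, the header group reuses it instead of minting a fresh uuid on every apply, so a redelivered message rebuilds the header with the same group id. A minimal sketch of that lookup semantics, with the stdlib uuid4 standing in for Celery's uuid helper and header_group_id as a hypothetical name:

    from uuid import uuid4

    def header_group_id(options: dict) -> str:
        # Mirrors self.options.get('task_id', uuid()): reuse the stamped id
        # when present, otherwise fall back to a freshly generated one.
        return options.get('task_id', str(uuid4()))

    stamped = {'task_id': str(uuid4())}  # id fixed once, when the chord is first built
    assert header_group_id(stamped) == header_group_id(stamped)  # stable across redeliveries
    assert header_group_id({}) != header_group_id({})            # unstamped: fresh id per call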


@@ -3,7 +3,7 @@ addopts = "--strict-markers"
 testpaths = "t/unit/"
 python_classes = "test_*"
 xfail_strict=true
-markers = ["sleepdeprived_patched_module", "masked_modules", "patched_environ", "patched_module"]
+markers = ["sleepdeprived_patched_module", "masked_modules", "patched_environ", "patched_module", "flaky", "timeout"]

 [tool.mypy]
 warn_unused_configs = true
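Registering flaky and timeout matters because addopts = "--strict-markers" is in effect: pytest then refuses any marker not declared in this list, so tests decorated with them (as the integration suite's @flaky tests are) would otherwise fail at collection; presumably these two were added for exactly that reason.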


@@ -2,4 +2,5 @@ pytz>dev
 git+https://github.com/celery/py-amqp.git
 git+https://github.com/celery/kombu.git
 git+https://github.com/celery/billiard.git
-vine>=5.0.0
+vine>=5.0.0
+isort~=5.10.1


@@ -241,6 +241,12 @@ def redis_echo(message, redis_key="redis-echo"):
     redis_connection.rpush(redis_key, message)


+@shared_task(bind=True)
+def redis_echo_group_id(self, _, redis_key="redis-group-ids"):
+    redis_connection = get_redis_connection()
+    redis_connection.rpush(redis_key, self.request.group)
+
+
 @shared_task
 def redis_count(redis_key="redis-count"):
     """Task that increments a specified or well-known redis key."""


@@ -12,15 +12,16 @@ from celery import chain, chord, group, signature
 from celery.backends.base import BaseKeyValueStoreBackend
 from celery.exceptions import ImproperlyConfigured, TimeoutError
 from celery.result import AsyncResult, GroupResult, ResultSet
+from celery.signals import before_task_publish

 from . import tasks
 from .conftest import TEST_BACKEND, get_active_redis_channels, get_redis_connection
 from .tasks import (ExpectedException, add, add_chord_to_chord, add_replaced, add_to_all, add_to_all_to_chord,
                     build_chain_inside_task, collect_ids, delayed_sum, delayed_sum_with_soft_guard,
                     errback_new_style, errback_old_style, fail, fail_replaced, identity, ids, print_unicode,
-                    raise_error, redis_count, redis_echo, replace_with_chain, replace_with_chain_which_raises,
-                    replace_with_empty_chain, retry_once, return_exception, return_priority, second_order_replace1,
-                    tsum, write_to_file_and_return_int, xsum)
+                    raise_error, redis_count, redis_echo, redis_echo_group_id, replace_with_chain,
+                    replace_with_chain_which_raises, replace_with_empty_chain, retry_once, return_exception,
+                    return_priority, second_order_replace1, tsum, write_to_file_and_return_int, xsum)


 RETRYABLE_EXCEPTIONS = (OSError, ConnectionError, TimeoutError)
@@ -62,12 +63,36 @@ def await_redis_echo(expected_msgs, redis_key="redis-echo", timeout=TIMEOUT):
         )
         retrieved_key, msg = maybe_key_msg
         assert retrieved_key.decode("utf-8") == redis_key
-        expected_msgs[msg] -= 1 # silently accepts unexpected messages
+        expected_msgs[msg] -= 1  # silently accepts unexpected messages

     # There should be no more elements - block momentarily
     assert redis_connection.blpop(redis_key, min(1, timeout)) is None


+def await_redis_list_message_length(expected_length, redis_key="redis-group-ids", timeout=TIMEOUT):
+    """
+    Helper to wait for a specified or well-known redis list to reach an expected length.
+    """
+    sleep(1)
+
+    redis_connection = get_redis_connection()
+    check_interval = 0.1
+    check_max = int(timeout / check_interval)
+
+    for i in range(check_max + 1):
+        length = redis_connection.llen(redis_key)
+
+        if length == expected_length:
+            break
+
+        sleep(check_interval)
+    else:
+        raise TimeoutError(f'{redis_key!r} has length of {length}, but expected to be of length {expected_length}')
+
+    sleep(min(1, timeout))
+    assert redis_connection.llen(redis_key) == expected_length
+
+
 def await_redis_count(expected_count, redis_key="redis-count", timeout=TIMEOUT):
     """
     Helper to wait for a specified or well-known redis key to count to a value.
@@ -95,6 +120,13 @@ def await_redis_count(expected_count, redis_key="redis-count", timeout=TIMEOUT):
     assert int(redis_connection.get(redis_key)) == expected_count


+def compare_group_ids_in_redis(redis_key='redis-group-ids'):
+    redis_connection = get_redis_connection()
+    actual = redis_connection.lrange(redis_key, 0, -1)
+    assert len(actual) >= 2, 'Expected at least 2 group ids in redis'
+    assert actual[0] == actual[1], 'Expected group ids to be equal'
+
+
 class test_link_error:
     @flaky
     def test_link_error_eager(self):
@@ -754,6 +786,78 @@ class test_chain:
         res_obj = orig_sig.delay()
         assert res_obj.get(timeout=TIMEOUT) == 42

+    @pytest.mark.parametrize('redis_key', ['redis-group-ids'])
+    def test_chord_header_id_duplicated_on_rabbitmq_msg_duplication(self, manager, subtests, celery_session_app,
+                                                                    redis_key):
+        """
+        When a task that precedes a chord in a chain was duplicated by RabbitMQ (for whatever reason),
+        the chord header id was not duplicated along with it, so the re-executed chord header ended up
+        with a different id. This test ensures that the chord header's id is preserved in the face of
+        such an edge case. To validate that the correct behavior is implemented, we collect the original
+        and duplicated chord header ids in redis and ensure that they are the same.
+        """
+        try:
+            manager.app.backend.ensure_chords_allowed()
+        except NotImplementedError as e:
+            raise pytest.skip(e.args[0])
+
+        if manager.app.conf.broker_url.startswith('redis'):
+            raise pytest.xfail('Redis broker does not duplicate the task (t1)')
+
+        # Republish t1 to cause the chain to be executed twice
+        @before_task_publish.connect
+        def before_task_publish_handler(sender=None, body=None, exchange=None, routing_key=None, headers=None,
+                                        properties=None, declare=None, retry_policy=None, **kwargs):
+            """We want to republish t1 to ensure that the chain is executed twice."""
+            metadata = {
+                'body': body,
+                'exchange': exchange,
+                'routing_key': routing_key,
+                'properties': properties,
+                'headers': headers,
+            }

+            with celery_session_app.producer_pool.acquire(block=True) as producer:
+                # Publish t1 to the message broker just before Celery publishes it, causing the duplication
+                return producer.publish(
+                    metadata['body'],
+                    exchange=metadata['exchange'],
+                    routing_key=metadata['routing_key'],
+                    retry=None,
+                    retry_policy=retry_policy,
+                    serializer='json',
+                    delivery_mode=None,
+                    headers=headers,
+                    **kwargs
+                )
+
+        # Clean up the redis key before the run
+        redis_connection = get_redis_connection()
+        if redis_connection.exists(redis_key):
+            redis_connection.delete(redis_key)
+
+        # Prepare tasks
+        t1, t2, t3, t4 = identity.s(42), redis_echo_group_id.s(), identity.s(), identity.s()
+        c = chain(t1, chord([t2, t3], t4))
+
+        # Delay the chain
+        r1 = c.delay()
+        r1.get(timeout=TIMEOUT)
+
+        # Cleanup
+        before_task_publish.disconnect(before_task_publish_handler)
+
+        with subtests.test(msg='Compare group ids via redis list'):
+            await_redis_list_message_length(2, redis_key=redis_key, timeout=15)
+            compare_group_ids_in_redis(redis_key=redis_key)
+
+        # Cleanup
+        redis_connection = get_redis_connection()
+        redis_connection.delete(redis_key)
+
+
 class test_result_set:
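A note on the duplication trick above: before_task_publish fires in the publishing (client) process just before a task message goes out, and the handler immediately publishes the same body a second time through the session app's producer pool. Because the test process only publishes t1 itself (the rest of the chain is dispatched by the worker), the net effect is that t1 reaches RabbitMQ twice, simulating broker-level duplication without any broker configuration.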


@@ -3,7 +3,7 @@ requires =
     tox-gh-actions
 envlist =
     {3.7,3.8,3.9,3.10,pypy3}-unit
-    {3.7,3.8,3.9,3.10,pypy3}-integration-{rabbitmq,redis,dynamodb,azureblockblob,cache,cassandra,elasticsearch}
+    {3.7,3.8,3.9,3.10,pypy3}-integration-{rabbitmq_redis,rabbitmq,redis,dynamodb,azureblockblob,cache,cassandra,elasticsearch}

     flake8
     apicheck

@@ -64,6 +64,9 @@ setenv =
     redis: TEST_BROKER=redis://
     redis: TEST_BACKEND=redis://

+    rabbitmq_redis: TEST_BROKER=pyamqp://
+    rabbitmq_redis: TEST_BACKEND=redis://
+
     dynamodb: TEST_BROKER=redis://
     dynamodb: TEST_BACKEND=dynamodb://@localhost:8000
     dynamodb: AWS_ACCESS_KEY_ID=test_aws_key_id
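With this mapping in place, the new environment is selected like any other generated env, e.g. tox -e 3.10-integration-rabbitmq_redis (assuming RabbitMQ and Redis are both running locally). That combination, a pyamqp:// broker with a redis:// result backend, is exactly what the new test needs: a broker that can duplicate messages and a backend that supports chords.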