Reformatted files using black
This commit is contained in:
parent
96789b291f
commit
54ef52e1c6
|
@ -1,14 +1,6 @@
|
|||
from .types import (
|
||||
DjangoObjectType,
|
||||
)
|
||||
from .fields import (
|
||||
DjangoConnectionField,
|
||||
)
|
||||
from .types import DjangoObjectType
|
||||
from .fields import DjangoConnectionField
|
||||
|
||||
__version__ = '2.1rc1'
|
||||
__version__ = "2.1rc1"
|
||||
|
||||
__all__ = [
|
||||
'__version__',
|
||||
'DjangoObjectType',
|
||||
'DjangoConnectionField'
|
||||
]
|
||||
__all__ = ["__version__", "DjangoObjectType", "DjangoConnectionField"]
|
||||
|
|
|
@ -7,7 +7,7 @@ try:
|
|||
# and we cannot have psycopg2 on PyPy
|
||||
from django.contrib.postgres.fields import ArrayField, HStoreField, RangeField
|
||||
except ImportError:
|
||||
ArrayField, HStoreField, JSONField, RangeField = (MissingType, ) * 4
|
||||
ArrayField, HStoreField, JSONField, RangeField = (MissingType,) * 4
|
||||
|
||||
|
||||
try:
|
||||
|
|
|
@ -1,8 +1,22 @@
|
|||
from django.db import models
|
||||
from django.utils.encoding import force_text
|
||||
|
||||
from graphene import (ID, Boolean, Dynamic, Enum, Field, Float, Int, List,
|
||||
NonNull, String, UUID, DateTime, Date, Time)
|
||||
from graphene import (
|
||||
ID,
|
||||
Boolean,
|
||||
Dynamic,
|
||||
Enum,
|
||||
Field,
|
||||
Float,
|
||||
Int,
|
||||
List,
|
||||
NonNull,
|
||||
String,
|
||||
UUID,
|
||||
DateTime,
|
||||
Date,
|
||||
Time,
|
||||
)
|
||||
from graphene.types.json import JSONString
|
||||
from graphene.utils.str_converters import to_camel_case, to_const
|
||||
from graphql import assert_valid_name
|
||||
|
@ -32,7 +46,7 @@ def get_choices(choices):
|
|||
else:
|
||||
name = convert_choice_name(value)
|
||||
while name in converted_names:
|
||||
name += '_' + str(len(converted_names))
|
||||
name += "_" + str(len(converted_names))
|
||||
converted_names.append(name)
|
||||
description = help_text
|
||||
yield name, value, description
|
||||
|
@ -43,16 +57,15 @@ def convert_django_field_with_choices(field, registry=None):
|
|||
converted = registry.get_converted_field(field)
|
||||
if converted:
|
||||
return converted
|
||||
choices = getattr(field, 'choices', None)
|
||||
choices = getattr(field, "choices", None)
|
||||
if choices:
|
||||
meta = field.model._meta
|
||||
name = to_camel_case('{}_{}'.format(meta.object_name, field.name))
|
||||
name = to_camel_case("{}_{}".format(meta.object_name, field.name))
|
||||
choices = list(get_choices(choices))
|
||||
named_choices = [(c[0], c[1]) for c in choices]
|
||||
named_choices_descriptions = {c[0]: c[2] for c in choices}
|
||||
|
||||
class EnumWithDescriptionsType(object):
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return named_choices_descriptions[self.name]
|
||||
|
@ -69,8 +82,8 @@ def convert_django_field_with_choices(field, registry=None):
|
|||
@singledispatch
|
||||
def convert_django_field(field, registry=None):
|
||||
raise Exception(
|
||||
"Don't know how to convert the Django field %s (%s)" %
|
||||
(field, field.__class__))
|
||||
"Don't know how to convert the Django field %s (%s)" % (field, field.__class__)
|
||||
)
|
||||
|
||||
|
||||
@convert_django_field.register(models.CharField)
|
||||
|
@ -147,7 +160,7 @@ def convert_onetoone_field_to_djangomodel(field, registry=None):
|
|||
|
||||
# We do this for a bug in Django 1.8, where null attr
|
||||
# is not available in the OneToOneRel instance
|
||||
null = getattr(field, 'null', True)
|
||||
null = getattr(field, "null", True)
|
||||
return Field(_type, required=not null)
|
||||
|
||||
return Dynamic(dynamic_type)
|
||||
|
@ -171,6 +184,7 @@ def convert_field_to_list_or_connection(field, registry=None):
|
|||
# defined filter_fields in the DjangoObjectType Meta
|
||||
if _type._meta.filter_fields:
|
||||
from .filter.fields import DjangoFilterConnectionField
|
||||
|
||||
return DjangoFilterConnectionField(_type)
|
||||
|
||||
return DjangoConnectionField(_type)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from .middleware import DjangoDebugMiddleware
|
||||
from .types import DjangoDebug
|
||||
|
||||
__all__ = ['DjangoDebugMiddleware', 'DjangoDebug']
|
||||
__all__ = ["DjangoDebugMiddleware", "DjangoDebug"]
|
||||
|
|
|
@ -7,7 +7,6 @@ from .types import DjangoDebug
|
|||
|
||||
|
||||
class DjangoDebugContext(object):
|
||||
|
||||
def __init__(self):
|
||||
self.debug_promise = None
|
||||
self.promises = []
|
||||
|
@ -38,20 +37,21 @@ class DjangoDebugContext(object):
|
|||
|
||||
|
||||
class DjangoDebugMiddleware(object):
|
||||
|
||||
def resolve(self, next, root, info, **args):
|
||||
context = info.context
|
||||
django_debug = getattr(context, 'django_debug', None)
|
||||
django_debug = getattr(context, "django_debug", None)
|
||||
if not django_debug:
|
||||
if context is None:
|
||||
raise Exception('DjangoDebug cannot be executed in None contexts')
|
||||
raise Exception("DjangoDebug cannot be executed in None contexts")
|
||||
try:
|
||||
context.django_debug = DjangoDebugContext()
|
||||
except Exception:
|
||||
raise Exception('DjangoDebug need the context to be writable, context received: {}.'.format(
|
||||
context.__class__.__name__
|
||||
))
|
||||
if info.schema.get_type('DjangoDebug') == info.return_type:
|
||||
raise Exception(
|
||||
"DjangoDebug need the context to be writable, context received: {}.".format(
|
||||
context.__class__.__name__
|
||||
)
|
||||
)
|
||||
if info.schema.get_type("DjangoDebug") == info.return_type:
|
||||
return context.django_debug.get_debug_promise()
|
||||
promise = next(root, info, **args)
|
||||
context.django_debug.add_promise(promise)
|
||||
|
|
|
@ -16,7 +16,6 @@ class SQLQueryTriggered(Exception):
|
|||
|
||||
|
||||
class ThreadLocalState(local):
|
||||
|
||||
def __init__(self):
|
||||
self.enabled = True
|
||||
|
||||
|
@ -35,7 +34,7 @@ recording = state.recording # export function
|
|||
|
||||
|
||||
def wrap_cursor(connection, panel):
|
||||
if not hasattr(connection, '_graphene_cursor'):
|
||||
if not hasattr(connection, "_graphene_cursor"):
|
||||
connection._graphene_cursor = connection.cursor
|
||||
|
||||
def cursor():
|
||||
|
@ -46,7 +45,7 @@ def wrap_cursor(connection, panel):
|
|||
|
||||
|
||||
def unwrap_cursor(connection):
|
||||
if hasattr(connection, '_graphene_cursor'):
|
||||
if hasattr(connection, "_graphene_cursor"):
|
||||
previous_cursor = connection._graphene_cursor
|
||||
connection.cursor = previous_cursor
|
||||
del connection._graphene_cursor
|
||||
|
@ -87,15 +86,14 @@ class NormalCursorWrapper(object):
|
|||
if not params:
|
||||
return params
|
||||
if isinstance(params, dict):
|
||||
return dict((key, self._quote_expr(value))
|
||||
for key, value in params.items())
|
||||
return dict((key, self._quote_expr(value)) for key, value in params.items())
|
||||
return list(map(self._quote_expr, params))
|
||||
|
||||
def _decode(self, param):
|
||||
try:
|
||||
return force_text(param, strings_only=True)
|
||||
except UnicodeDecodeError:
|
||||
return '(encoded string)'
|
||||
return "(encoded string)"
|
||||
|
||||
def _record(self, method, sql, params):
|
||||
start_time = time()
|
||||
|
@ -103,45 +101,48 @@ class NormalCursorWrapper(object):
|
|||
return method(sql, params)
|
||||
finally:
|
||||
stop_time = time()
|
||||
duration = (stop_time - start_time)
|
||||
_params = ''
|
||||
duration = stop_time - start_time
|
||||
_params = ""
|
||||
try:
|
||||
_params = json.dumps(list(map(self._decode, params)))
|
||||
except Exception:
|
||||
pass # object not JSON serializable
|
||||
|
||||
alias = getattr(self.db, 'alias', 'default')
|
||||
alias = getattr(self.db, "alias", "default")
|
||||
conn = self.db.connection
|
||||
vendor = getattr(conn, 'vendor', 'unknown')
|
||||
vendor = getattr(conn, "vendor", "unknown")
|
||||
|
||||
params = {
|
||||
'vendor': vendor,
|
||||
'alias': alias,
|
||||
'sql': self.db.ops.last_executed_query(
|
||||
self.cursor, sql, self._quote_params(params)),
|
||||
'duration': duration,
|
||||
'raw_sql': sql,
|
||||
'params': _params,
|
||||
'start_time': start_time,
|
||||
'stop_time': stop_time,
|
||||
'is_slow': duration > 10,
|
||||
'is_select': sql.lower().strip().startswith('select'),
|
||||
"vendor": vendor,
|
||||
"alias": alias,
|
||||
"sql": self.db.ops.last_executed_query(
|
||||
self.cursor, sql, self._quote_params(params)
|
||||
),
|
||||
"duration": duration,
|
||||
"raw_sql": sql,
|
||||
"params": _params,
|
||||
"start_time": start_time,
|
||||
"stop_time": stop_time,
|
||||
"is_slow": duration > 10,
|
||||
"is_select": sql.lower().strip().startswith("select"),
|
||||
}
|
||||
|
||||
if vendor == 'postgresql':
|
||||
if vendor == "postgresql":
|
||||
# If an erroneous query was ran on the connection, it might
|
||||
# be in a state where checking isolation_level raises an
|
||||
# exception.
|
||||
try:
|
||||
iso_level = conn.isolation_level
|
||||
except conn.InternalError:
|
||||
iso_level = 'unknown'
|
||||
params.update({
|
||||
'trans_id': self.logger.get_transaction_id(alias),
|
||||
'trans_status': conn.get_transaction_status(),
|
||||
'iso_level': iso_level,
|
||||
'encoding': conn.encoding,
|
||||
})
|
||||
iso_level = "unknown"
|
||||
params.update(
|
||||
{
|
||||
"trans_id": self.logger.get_transaction_id(alias),
|
||||
"trans_status": conn.get_transaction_status(),
|
||||
"iso_level": iso_level,
|
||||
"encoding": conn.encoding,
|
||||
}
|
||||
)
|
||||
|
||||
_sql = DjangoDebugSQL(**params)
|
||||
# We keep `sql` to maintain backwards compatibility
|
||||
|
|
|
@ -12,31 +12,31 @@ from ..types import DjangoDebug
|
|||
class context(object):
|
||||
pass
|
||||
|
||||
|
||||
# from examples.starwars_django.models import Character
|
||||
|
||||
pytestmark = pytest.mark.django_db
|
||||
|
||||
|
||||
def test_should_query_field():
|
||||
r1 = Reporter(last_name='ABA')
|
||||
r1 = Reporter(last_name="ABA")
|
||||
r1.save()
|
||||
r2 = Reporter(last_name='Griffin')
|
||||
r2 = Reporter(last_name="Griffin")
|
||||
r2.save()
|
||||
|
||||
class ReporterType(DjangoObjectType):
|
||||
|
||||
class Meta:
|
||||
model = Reporter
|
||||
interfaces = (Node, )
|
||||
interfaces = (Node,)
|
||||
|
||||
class Query(graphene.ObjectType):
|
||||
reporter = graphene.Field(ReporterType)
|
||||
debug = graphene.Field(DjangoDebug, name='__debug')
|
||||
debug = graphene.Field(DjangoDebug, name="__debug")
|
||||
|
||||
def resolve_reporter(self, info, **args):
|
||||
return Reporter.objects.first()
|
||||
|
||||
query = '''
|
||||
query = """
|
||||
query ReporterQuery {
|
||||
reporter {
|
||||
lastName
|
||||
|
@ -47,43 +47,40 @@ def test_should_query_field():
|
|||
}
|
||||
}
|
||||
}
|
||||
'''
|
||||
"""
|
||||
expected = {
|
||||
'reporter': {
|
||||
'lastName': 'ABA',
|
||||
"reporter": {"lastName": "ABA"},
|
||||
"__debug": {
|
||||
"sql": [{"rawSql": str(Reporter.objects.order_by("pk")[:1].query)}]
|
||||
},
|
||||
'__debug': {
|
||||
'sql': [{
|
||||
'rawSql': str(Reporter.objects.order_by('pk')[:1].query)
|
||||
}]
|
||||
}
|
||||
}
|
||||
schema = graphene.Schema(query=Query)
|
||||
result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()])
|
||||
result = schema.execute(
|
||||
query, context_value=context(), middleware=[DjangoDebugMiddleware()]
|
||||
)
|
||||
assert not result.errors
|
||||
assert result.data == expected
|
||||
|
||||
|
||||
def test_should_query_list():
|
||||
r1 = Reporter(last_name='ABA')
|
||||
r1 = Reporter(last_name="ABA")
|
||||
r1.save()
|
||||
r2 = Reporter(last_name='Griffin')
|
||||
r2 = Reporter(last_name="Griffin")
|
||||
r2.save()
|
||||
|
||||
class ReporterType(DjangoObjectType):
|
||||
|
||||
class Meta:
|
||||
model = Reporter
|
||||
interfaces = (Node, )
|
||||
interfaces = (Node,)
|
||||
|
||||
class Query(graphene.ObjectType):
|
||||
all_reporters = graphene.List(ReporterType)
|
||||
debug = graphene.Field(DjangoDebug, name='__debug')
|
||||
debug = graphene.Field(DjangoDebug, name="__debug")
|
||||
|
||||
def resolve_all_reporters(self, info, **args):
|
||||
return Reporter.objects.all()
|
||||
|
||||
query = '''
|
||||
query = """
|
||||
query ReporterQuery {
|
||||
allReporters {
|
||||
lastName
|
||||
|
@ -94,45 +91,38 @@ def test_should_query_list():
|
|||
}
|
||||
}
|
||||
}
|
||||
'''
|
||||
"""
|
||||
expected = {
|
||||
'allReporters': [{
|
||||
'lastName': 'ABA',
|
||||
}, {
|
||||
'lastName': 'Griffin',
|
||||
}],
|
||||
'__debug': {
|
||||
'sql': [{
|
||||
'rawSql': str(Reporter.objects.all().query)
|
||||
}]
|
||||
}
|
||||
"allReporters": [{"lastName": "ABA"}, {"lastName": "Griffin"}],
|
||||
"__debug": {"sql": [{"rawSql": str(Reporter.objects.all().query)}]},
|
||||
}
|
||||
schema = graphene.Schema(query=Query)
|
||||
result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()])
|
||||
result = schema.execute(
|
||||
query, context_value=context(), middleware=[DjangoDebugMiddleware()]
|
||||
)
|
||||
assert not result.errors
|
||||
assert result.data == expected
|
||||
|
||||
|
||||
def test_should_query_connection():
|
||||
r1 = Reporter(last_name='ABA')
|
||||
r1 = Reporter(last_name="ABA")
|
||||
r1.save()
|
||||
r2 = Reporter(last_name='Griffin')
|
||||
r2 = Reporter(last_name="Griffin")
|
||||
r2.save()
|
||||
|
||||
class ReporterType(DjangoObjectType):
|
||||
|
||||
class Meta:
|
||||
model = Reporter
|
||||
interfaces = (Node, )
|
||||
interfaces = (Node,)
|
||||
|
||||
class Query(graphene.ObjectType):
|
||||
all_reporters = DjangoConnectionField(ReporterType)
|
||||
debug = graphene.Field(DjangoDebug, name='__debug')
|
||||
debug = graphene.Field(DjangoDebug, name="__debug")
|
||||
|
||||
def resolve_all_reporters(self, info, **args):
|
||||
return Reporter.objects.all()
|
||||
|
||||
query = '''
|
||||
query = """
|
||||
query ReporterQuery {
|
||||
allReporters(first:1) {
|
||||
edges {
|
||||
|
@ -147,48 +137,41 @@ def test_should_query_connection():
|
|||
}
|
||||
}
|
||||
}
|
||||
'''
|
||||
expected = {
|
||||
'allReporters': {
|
||||
'edges': [{
|
||||
'node': {
|
||||
'lastName': 'ABA',
|
||||
}
|
||||
}]
|
||||
},
|
||||
}
|
||||
"""
|
||||
expected = {"allReporters": {"edges": [{"node": {"lastName": "ABA"}}]}}
|
||||
schema = graphene.Schema(query=Query)
|
||||
result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()])
|
||||
result = schema.execute(
|
||||
query, context_value=context(), middleware=[DjangoDebugMiddleware()]
|
||||
)
|
||||
assert not result.errors
|
||||
assert result.data['allReporters'] == expected['allReporters']
|
||||
assert 'COUNT' in result.data['__debug']['sql'][0]['rawSql']
|
||||
assert result.data["allReporters"] == expected["allReporters"]
|
||||
assert "COUNT" in result.data["__debug"]["sql"][0]["rawSql"]
|
||||
query = str(Reporter.objects.all()[:1].query)
|
||||
assert result.data['__debug']['sql'][1]['rawSql'] == query
|
||||
assert result.data["__debug"]["sql"][1]["rawSql"] == query
|
||||
|
||||
|
||||
def test_should_query_connectionfilter():
|
||||
from ...filter import DjangoFilterConnectionField
|
||||
|
||||
r1 = Reporter(last_name='ABA')
|
||||
r1 = Reporter(last_name="ABA")
|
||||
r1.save()
|
||||
r2 = Reporter(last_name='Griffin')
|
||||
r2 = Reporter(last_name="Griffin")
|
||||
r2.save()
|
||||
|
||||
class ReporterType(DjangoObjectType):
|
||||
|
||||
class Meta:
|
||||
model = Reporter
|
||||
interfaces = (Node, )
|
||||
interfaces = (Node,)
|
||||
|
||||
class Query(graphene.ObjectType):
|
||||
all_reporters = DjangoFilterConnectionField(ReporterType, fields=['last_name'])
|
||||
all_reporters = DjangoFilterConnectionField(ReporterType, fields=["last_name"])
|
||||
s = graphene.String(resolver=lambda *_: "S")
|
||||
debug = graphene.Field(DjangoDebug, name='__debug')
|
||||
debug = graphene.Field(DjangoDebug, name="__debug")
|
||||
|
||||
def resolve_all_reporters(self, info, **args):
|
||||
return Reporter.objects.all()
|
||||
|
||||
query = '''
|
||||
query = """
|
||||
query ReporterQuery {
|
||||
allReporters(first:1) {
|
||||
edges {
|
||||
|
@ -203,20 +186,14 @@ def test_should_query_connectionfilter():
|
|||
}
|
||||
}
|
||||
}
|
||||
'''
|
||||
expected = {
|
||||
'allReporters': {
|
||||
'edges': [{
|
||||
'node': {
|
||||
'lastName': 'ABA',
|
||||
}
|
||||
}]
|
||||
},
|
||||
}
|
||||
"""
|
||||
expected = {"allReporters": {"edges": [{"node": {"lastName": "ABA"}}]}}
|
||||
schema = graphene.Schema(query=Query)
|
||||
result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()])
|
||||
result = schema.execute(
|
||||
query, context_value=context(), middleware=[DjangoDebugMiddleware()]
|
||||
)
|
||||
assert not result.errors
|
||||
assert result.data['allReporters'] == expected['allReporters']
|
||||
assert 'COUNT' in result.data['__debug']['sql'][0]['rawSql']
|
||||
assert result.data["allReporters"] == expected["allReporters"]
|
||||
assert "COUNT" in result.data["__debug"]["sql"][0]["rawSql"]
|
||||
query = str(Reporter.objects.all()[:1].query)
|
||||
assert result.data['__debug']['sql'][1]['rawSql'] == query
|
||||
assert result.data["__debug"]["sql"][1]["rawSql"] == query
|
||||
|
|
|
@ -13,7 +13,6 @@ from .utils import maybe_queryset
|
|||
|
||||
|
||||
class DjangoListField(Field):
|
||||
|
||||
def __init__(self, _type, *args, **kwargs):
|
||||
super(DjangoListField, self).__init__(List(_type), *args, **kwargs)
|
||||
|
||||
|
@ -30,25 +29,28 @@ class DjangoListField(Field):
|
|||
|
||||
|
||||
class DjangoConnectionField(ConnectionField):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.on = kwargs.pop('on', False)
|
||||
self.on = kwargs.pop("on", False)
|
||||
self.max_limit = kwargs.pop(
|
||||
'max_limit',
|
||||
graphene_settings.RELAY_CONNECTION_MAX_LIMIT
|
||||
"max_limit", graphene_settings.RELAY_CONNECTION_MAX_LIMIT
|
||||
)
|
||||
self.enforce_first_or_last = kwargs.pop(
|
||||
'enforce_first_or_last',
|
||||
graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST
|
||||
"enforce_first_or_last",
|
||||
graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST,
|
||||
)
|
||||
super(DjangoConnectionField, self).__init__(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
from .types import DjangoObjectType
|
||||
|
||||
_type = super(ConnectionField, self).type
|
||||
assert issubclass(_type, DjangoObjectType), "DjangoConnectionField only accepts DjangoObjectType types"
|
||||
assert _type._meta.connection, "The type {} doesn't have a connection".format(_type.__name__)
|
||||
assert issubclass(
|
||||
_type, DjangoObjectType
|
||||
), "DjangoConnectionField only accepts DjangoObjectType types"
|
||||
assert _type._meta.connection, "The type {} doesn't have a connection".format(
|
||||
_type.__name__
|
||||
)
|
||||
return _type._meta.connection
|
||||
|
||||
@property
|
||||
|
@ -100,28 +102,37 @@ class DjangoConnectionField(ConnectionField):
|
|||
return connection
|
||||
|
||||
@classmethod
|
||||
def connection_resolver(cls, resolver, connection, default_manager, max_limit,
|
||||
enforce_first_or_last, root, info, **args):
|
||||
first = args.get('first')
|
||||
last = args.get('last')
|
||||
def connection_resolver(
|
||||
cls,
|
||||
resolver,
|
||||
connection,
|
||||
default_manager,
|
||||
max_limit,
|
||||
enforce_first_or_last,
|
||||
root,
|
||||
info,
|
||||
**args
|
||||
):
|
||||
first = args.get("first")
|
||||
last = args.get("last")
|
||||
|
||||
if enforce_first_or_last:
|
||||
assert first or last, (
|
||||
'You must provide a `first` or `last` value to properly paginate the `{}` connection.'
|
||||
"You must provide a `first` or `last` value to properly paginate the `{}` connection."
|
||||
).format(info.field_name)
|
||||
|
||||
if max_limit:
|
||||
if first:
|
||||
assert first <= max_limit, (
|
||||
'Requesting {} records on the `{}` connection exceeds the `first` limit of {} records.'
|
||||
"Requesting {} records on the `{}` connection exceeds the `first` limit of {} records."
|
||||
).format(first, info.field_name, max_limit)
|
||||
args['first'] = min(first, max_limit)
|
||||
args["first"] = min(first, max_limit)
|
||||
|
||||
if last:
|
||||
assert last <= max_limit, (
|
||||
'Requesting {} records on the `{}` connection exceeds the `last` limit of {} records.'
|
||||
"Requesting {} records on the `{}` connection exceeds the `last` limit of {} records."
|
||||
).format(last, info.field_name, max_limit)
|
||||
args['last'] = min(last, max_limit)
|
||||
args["last"] = min(last, max_limit)
|
||||
|
||||
iterable = resolver(root, info, **args)
|
||||
on_resolve = partial(cls.resolve_connection, connection, default_manager, args)
|
||||
|
@ -138,5 +149,5 @@ class DjangoConnectionField(ConnectionField):
|
|||
self.type,
|
||||
self.get_manager(),
|
||||
self.max_limit,
|
||||
self.enforce_first_or_last
|
||||
self.enforce_first_or_last,
|
||||
)
|
||||
|
|
|
@ -4,11 +4,15 @@ from ..utils import DJANGO_FILTER_INSTALLED
|
|||
if not DJANGO_FILTER_INSTALLED:
|
||||
warnings.warn(
|
||||
"Use of django filtering requires the django-filter package "
|
||||
"be installed. You can do so using `pip install django-filter`", ImportWarning
|
||||
"be installed. You can do so using `pip install django-filter`",
|
||||
ImportWarning,
|
||||
)
|
||||
else:
|
||||
from .fields import DjangoFilterConnectionField
|
||||
from .filterset import GlobalIDFilter, GlobalIDMultipleChoiceFilter
|
||||
|
||||
__all__ = ['DjangoFilterConnectionField',
|
||||
'GlobalIDFilter', 'GlobalIDMultipleChoiceFilter']
|
||||
__all__ = [
|
||||
"DjangoFilterConnectionField",
|
||||
"GlobalIDFilter",
|
||||
"GlobalIDMultipleChoiceFilter",
|
||||
]
|
||||
|
|
|
@ -7,10 +7,16 @@ from .utils import get_filtering_args_from_filterset, get_filterset_class
|
|||
|
||||
|
||||
class DjangoFilterConnectionField(DjangoConnectionField):
|
||||
|
||||
def __init__(self, type, fields=None, order_by=None,
|
||||
extra_filter_meta=None, filterset_class=None,
|
||||
*args, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
type,
|
||||
fields=None,
|
||||
order_by=None,
|
||||
extra_filter_meta=None,
|
||||
filterset_class=None,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
self._fields = fields
|
||||
self._provided_filterset_class = filterset_class
|
||||
self._filterset_class = None
|
||||
|
@ -30,12 +36,13 @@ class DjangoFilterConnectionField(DjangoConnectionField):
|
|||
def filterset_class(self):
|
||||
if not self._filterset_class:
|
||||
fields = self._fields or self.node_type._meta.filter_fields
|
||||
meta = dict(model=self.model,
|
||||
fields=fields)
|
||||
meta = dict(model=self.model, fields=fields)
|
||||
if self._extra_filter_meta:
|
||||
meta.update(self._extra_filter_meta)
|
||||
|
||||
self._filterset_class = get_filterset_class(self._provided_filterset_class, **meta)
|
||||
self._filterset_class = get_filterset_class(
|
||||
self._provided_filterset_class, **meta
|
||||
)
|
||||
|
||||
return self._filterset_class
|
||||
|
||||
|
@ -52,28 +59,40 @@ class DjangoFilterConnectionField(DjangoConnectionField):
|
|||
|
||||
# See related PR: https://github.com/graphql-python/graphene-django/pull/126
|
||||
|
||||
assert not (default_queryset.query.low_mark and queryset.query.low_mark), (
|
||||
'Received two sliced querysets (low mark) in the connection, please slice only in one.'
|
||||
)
|
||||
assert not (default_queryset.query.high_mark and queryset.query.high_mark), (
|
||||
'Received two sliced querysets (high mark) in the connection, please slice only in one.'
|
||||
)
|
||||
assert not (
|
||||
default_queryset.query.low_mark and queryset.query.low_mark
|
||||
), "Received two sliced querysets (low mark) in the connection, please slice only in one."
|
||||
assert not (
|
||||
default_queryset.query.high_mark and queryset.query.high_mark
|
||||
), "Received two sliced querysets (high mark) in the connection, please slice only in one."
|
||||
low = default_queryset.query.low_mark or queryset.query.low_mark
|
||||
high = default_queryset.query.high_mark or queryset.query.high_mark
|
||||
default_queryset.query.clear_limits()
|
||||
queryset = super(DjangoFilterConnectionField, cls).merge_querysets(default_queryset, queryset)
|
||||
queryset = super(DjangoFilterConnectionField, cls).merge_querysets(
|
||||
default_queryset, queryset
|
||||
)
|
||||
queryset.query.set_limits(low, high)
|
||||
return queryset
|
||||
|
||||
@classmethod
|
||||
def connection_resolver(cls, resolver, connection, default_manager, max_limit,
|
||||
enforce_first_or_last, filterset_class, filtering_args,
|
||||
root, info, **args):
|
||||
def connection_resolver(
|
||||
cls,
|
||||
resolver,
|
||||
connection,
|
||||
default_manager,
|
||||
max_limit,
|
||||
enforce_first_or_last,
|
||||
filterset_class,
|
||||
filtering_args,
|
||||
root,
|
||||
info,
|
||||
**args
|
||||
):
|
||||
filter_kwargs = {k: v for k, v in args.items() if k in filtering_args}
|
||||
qs = filterset_class(
|
||||
data=filter_kwargs,
|
||||
queryset=default_manager.get_queryset(),
|
||||
request=info.context
|
||||
request=info.context,
|
||||
).qs
|
||||
|
||||
return super(DjangoFilterConnectionField, cls).connection_resolver(
|
||||
|
@ -96,5 +115,5 @@ class DjangoFilterConnectionField(DjangoConnectionField):
|
|||
self.max_limit,
|
||||
self.enforce_first_or_last,
|
||||
self.filterset_class,
|
||||
self.filtering_args
|
||||
self.filtering_args,
|
||||
)
|
||||
|
|
|
@ -28,26 +28,19 @@ class GlobalIDMultipleChoiceFilter(MultipleChoiceFilter):
|
|||
|
||||
|
||||
GRAPHENE_FILTER_SET_OVERRIDES = {
|
||||
models.AutoField: {
|
||||
'filter_class': GlobalIDFilter,
|
||||
},
|
||||
models.OneToOneField: {
|
||||
'filter_class': GlobalIDFilter,
|
||||
},
|
||||
models.ForeignKey: {
|
||||
'filter_class': GlobalIDFilter,
|
||||
},
|
||||
models.ManyToManyField: {
|
||||
'filter_class': GlobalIDMultipleChoiceFilter,
|
||||
}
|
||||
models.AutoField: {"filter_class": GlobalIDFilter},
|
||||
models.OneToOneField: {"filter_class": GlobalIDFilter},
|
||||
models.ForeignKey: {"filter_class": GlobalIDFilter},
|
||||
models.ManyToManyField: {"filter_class": GlobalIDMultipleChoiceFilter},
|
||||
}
|
||||
|
||||
|
||||
class GrapheneFilterSetMixin(BaseFilterSet):
|
||||
FILTER_DEFAULTS = dict(itertools.chain(
|
||||
FILTER_FOR_DBFIELD_DEFAULTS.items(),
|
||||
GRAPHENE_FILTER_SET_OVERRIDES.items()
|
||||
))
|
||||
FILTER_DEFAULTS = dict(
|
||||
itertools.chain(
|
||||
FILTER_FOR_DBFIELD_DEFAULTS.items(), GRAPHENE_FILTER_SET_OVERRIDES.items()
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def filter_for_reverse_field(cls, f, name):
|
||||
|
@ -62,10 +55,7 @@ class GrapheneFilterSetMixin(BaseFilterSet):
|
|||
except AttributeError:
|
||||
rel = f.field.rel
|
||||
|
||||
default = {
|
||||
'name': name,
|
||||
'label': capfirst(rel.related_name)
|
||||
}
|
||||
default = {"name": name, "label": capfirst(rel.related_name)}
|
||||
if rel.multiple:
|
||||
# For to-many relationships
|
||||
return GlobalIDMultipleChoiceFilter(**default)
|
||||
|
@ -78,25 +68,20 @@ def setup_filterset(filterset_class):
|
|||
""" Wrap a provided filterset in Graphene-specific functionality
|
||||
"""
|
||||
return type(
|
||||
'Graphene{}'.format(filterset_class.__name__),
|
||||
"Graphene{}".format(filterset_class.__name__),
|
||||
(filterset_class, GrapheneFilterSetMixin),
|
||||
{},
|
||||
)
|
||||
|
||||
|
||||
def custom_filterset_factory(model, filterset_base_class=FilterSet,
|
||||
**meta):
|
||||
def custom_filterset_factory(model, filterset_base_class=FilterSet, **meta):
|
||||
""" Create a filterset for the given model using the provided meta data
|
||||
"""
|
||||
meta.update({
|
||||
'model': model,
|
||||
})
|
||||
meta_class = type(str('Meta'), (object,), meta)
|
||||
meta.update({"model": model})
|
||||
meta_class = type(str("Meta"), (object,), meta)
|
||||
filterset = type(
|
||||
str('%sFilterSet' % model._meta.object_name),
|
||||
str("%sFilterSet" % model._meta.object_name),
|
||||
(filterset_base_class, GrapheneFilterSetMixin),
|
||||
{
|
||||
'Meta': meta_class
|
||||
}
|
||||
{"Meta": meta_class},
|
||||
)
|
||||
return filterset
|
||||
|
|
|
@ -5,29 +5,26 @@ from graphene_django.tests.models import Article, Pet, Reporter
|
|||
|
||||
|
||||
class ArticleFilter(django_filters.FilterSet):
|
||||
|
||||
class Meta:
|
||||
model = Article
|
||||
fields = {
|
||||
'headline': ['exact', 'icontains'],
|
||||
'pub_date': ['gt', 'lt', 'exact'],
|
||||
'reporter': ['exact'],
|
||||
"headline": ["exact", "icontains"],
|
||||
"pub_date": ["gt", "lt", "exact"],
|
||||
"reporter": ["exact"],
|
||||
}
|
||||
|
||||
order_by = OrderingFilter(fields=('pub_date',))
|
||||
order_by = OrderingFilter(fields=("pub_date",))
|
||||
|
||||
|
||||
class ReporterFilter(django_filters.FilterSet):
|
||||
|
||||
class Meta:
|
||||
model = Reporter
|
||||
fields = ['first_name', 'last_name', 'email', 'pets']
|
||||
fields = ["first_name", "last_name", "email", "pets"]
|
||||
|
||||
order_by = OrderingFilter(fields=('pub_date',))
|
||||
order_by = OrderingFilter(fields=("pub_date",))
|
||||
|
||||
|
||||
class PetFilter(django_filters.FilterSet):
|
||||
|
||||
class Meta:
|
||||
model = Pet
|
||||
fields = ['name']
|
||||
fields = ["name"]
|
||||
|
|
|
@ -5,8 +5,7 @@ import pytest
|
|||
from graphene import Field, ObjectType, Schema, Argument, Float, Boolean, String
|
||||
from graphene.relay import Node
|
||||
from graphene_django import DjangoObjectType
|
||||
from graphene_django.forms import (GlobalIDFormField,
|
||||
GlobalIDMultipleChoiceField)
|
||||
from graphene_django.forms import GlobalIDFormField, GlobalIDMultipleChoiceField
|
||||
from graphene_django.tests.models import Article, Pet, Reporter
|
||||
from graphene_django.utils import DJANGO_FILTER_INSTALLED
|
||||
|
||||
|
@ -20,36 +19,43 @@ if DJANGO_FILTER_INSTALLED:
|
|||
import django_filters
|
||||
from django_filters import FilterSet, NumberFilter
|
||||
|
||||
from graphene_django.filter import (GlobalIDFilter, DjangoFilterConnectionField,
|
||||
GlobalIDMultipleChoiceFilter)
|
||||
from graphene_django.filter.tests.filters import ArticleFilter, PetFilter, ReporterFilter
|
||||
from graphene_django.filter import (
|
||||
GlobalIDFilter,
|
||||
DjangoFilterConnectionField,
|
||||
GlobalIDMultipleChoiceFilter,
|
||||
)
|
||||
from graphene_django.filter.tests.filters import (
|
||||
ArticleFilter,
|
||||
PetFilter,
|
||||
ReporterFilter,
|
||||
)
|
||||
else:
|
||||
pytestmark.append(pytest.mark.skipif(True, reason='django_filters not installed or not compatible'))
|
||||
pytestmark.append(
|
||||
pytest.mark.skipif(
|
||||
True, reason="django_filters not installed or not compatible"
|
||||
)
|
||||
)
|
||||
|
||||
pytestmark.append(pytest.mark.django_db)
|
||||
|
||||
|
||||
if DJANGO_FILTER_INSTALLED:
|
||||
class ArticleNode(DjangoObjectType):
|
||||
|
||||
class ArticleNode(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Article
|
||||
interfaces = (Node, )
|
||||
filter_fields = ('headline', )
|
||||
|
||||
interfaces = (Node,)
|
||||
filter_fields = ("headline",)
|
||||
|
||||
class ReporterNode(DjangoObjectType):
|
||||
|
||||
class Meta:
|
||||
model = Reporter
|
||||
interfaces = (Node, )
|
||||
|
||||
interfaces = (Node,)
|
||||
|
||||
class PetNode(DjangoObjectType):
|
||||
|
||||
class Meta:
|
||||
model = Pet
|
||||
interfaces = (Node, )
|
||||
interfaces = (Node,)
|
||||
|
||||
# schema = Schema()
|
||||
|
||||
|
@ -59,58 +65,47 @@ def get_args(field):
|
|||
|
||||
|
||||
def assert_arguments(field, *arguments):
|
||||
ignore = ('after', 'before', 'first', 'last', 'order_by')
|
||||
ignore = ("after", "before", "first", "last", "order_by")
|
||||
args = get_args(field)
|
||||
actual = [
|
||||
name
|
||||
for name in args
|
||||
if name not in ignore and not name.startswith('_')
|
||||
]
|
||||
assert set(arguments) == set(actual), \
|
||||
'Expected arguments ({}) did not match actual ({})'.format(
|
||||
arguments,
|
||||
actual
|
||||
)
|
||||
actual = [name for name in args if name not in ignore and not name.startswith("_")]
|
||||
assert set(arguments) == set(
|
||||
actual
|
||||
), "Expected arguments ({}) did not match actual ({})".format(arguments, actual)
|
||||
|
||||
|
||||
def assert_orderable(field):
|
||||
args = get_args(field)
|
||||
assert 'order_by' in args, \
|
||||
'Field cannot be ordered'
|
||||
assert "order_by" in args, "Field cannot be ordered"
|
||||
|
||||
|
||||
def assert_not_orderable(field):
|
||||
args = get_args(field)
|
||||
assert 'order_by' not in args, \
|
||||
'Field can be ordered'
|
||||
assert "order_by" not in args, "Field can be ordered"
|
||||
|
||||
|
||||
def test_filter_explicit_filterset_arguments():
|
||||
field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleFilter)
|
||||
assert_arguments(field,
|
||||
'headline', 'headline__icontains',
|
||||
'pub_date', 'pub_date__gt', 'pub_date__lt',
|
||||
'reporter',
|
||||
)
|
||||
assert_arguments(
|
||||
field,
|
||||
"headline",
|
||||
"headline__icontains",
|
||||
"pub_date",
|
||||
"pub_date__gt",
|
||||
"pub_date__lt",
|
||||
"reporter",
|
||||
)
|
||||
|
||||
|
||||
def test_filter_shortcut_filterset_arguments_list():
|
||||
field = DjangoFilterConnectionField(ArticleNode, fields=['pub_date', 'reporter'])
|
||||
assert_arguments(field,
|
||||
'pub_date',
|
||||
'reporter',
|
||||
)
|
||||
field = DjangoFilterConnectionField(ArticleNode, fields=["pub_date", "reporter"])
|
||||
assert_arguments(field, "pub_date", "reporter")
|
||||
|
||||
|
||||
def test_filter_shortcut_filterset_arguments_dict():
|
||||
field = DjangoFilterConnectionField(ArticleNode, fields={
|
||||
'headline': ['exact', 'icontains'],
|
||||
'reporter': ['exact'],
|
||||
})
|
||||
assert_arguments(field,
|
||||
'headline', 'headline__icontains',
|
||||
'reporter',
|
||||
)
|
||||
field = DjangoFilterConnectionField(
|
||||
ArticleNode, fields={"headline": ["exact", "icontains"], "reporter": ["exact"]}
|
||||
)
|
||||
assert_arguments(field, "headline", "headline__icontains", "reporter")
|
||||
|
||||
|
||||
def test_filter_explicit_filterset_orderable():
|
||||
|
@ -134,15 +129,14 @@ def test_filter_explicit_filterset_not_orderable():
|
|||
|
||||
|
||||
def test_filter_shortcut_filterset_extra_meta():
|
||||
field = DjangoFilterConnectionField(ArticleNode, extra_filter_meta={
|
||||
'exclude': ('headline', )
|
||||
})
|
||||
assert 'headline' not in field.filterset_class.get_fields()
|
||||
field = DjangoFilterConnectionField(
|
||||
ArticleNode, extra_filter_meta={"exclude": ("headline",)}
|
||||
)
|
||||
assert "headline" not in field.filterset_class.get_fields()
|
||||
|
||||
|
||||
def test_filter_shortcut_filterset_context():
|
||||
class ArticleContextFilter(django_filters.FilterSet):
|
||||
|
||||
class Meta:
|
||||
model = Article
|
||||
exclude = set()
|
||||
|
@ -153,17 +147,31 @@ def test_filter_shortcut_filterset_context():
|
|||
return qs.filter(reporter=self.request.reporter)
|
||||
|
||||
class Query(ObjectType):
|
||||
context_articles = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleContextFilter)
|
||||
context_articles = DjangoFilterConnectionField(
|
||||
ArticleNode, filterset_class=ArticleContextFilter
|
||||
)
|
||||
|
||||
r1 = Reporter.objects.create(first_name='r1', last_name='r1', email='r1@test.com')
|
||||
r2 = Reporter.objects.create(first_name='r2', last_name='r2', email='r2@test.com')
|
||||
Article.objects.create(headline='a1', pub_date=datetime.now(), pub_date_time=datetime.now(), reporter=r1, editor=r1)
|
||||
Article.objects.create(headline='a2', pub_date=datetime.now(), pub_date_time=datetime.now(), reporter=r2, editor=r2)
|
||||
r1 = Reporter.objects.create(first_name="r1", last_name="r1", email="r1@test.com")
|
||||
r2 = Reporter.objects.create(first_name="r2", last_name="r2", email="r2@test.com")
|
||||
Article.objects.create(
|
||||
headline="a1",
|
||||
pub_date=datetime.now(),
|
||||
pub_date_time=datetime.now(),
|
||||
reporter=r1,
|
||||
editor=r1,
|
||||
)
|
||||
Article.objects.create(
|
||||
headline="a2",
|
||||
pub_date=datetime.now(),
|
||||
pub_date_time=datetime.now(),
|
||||
reporter=r2,
|
||||
editor=r2,
|
||||
)
|
||||
|
||||
class context(object):
|
||||
reporter = r2
|
||||
|
||||
query = '''
|
||||
query = """
|
||||
query {
|
||||
contextArticles {
|
||||
edges {
|
||||
|
@ -173,42 +181,39 @@ def test_filter_shortcut_filterset_context():
|
|||
}
|
||||
}
|
||||
}
|
||||
'''
|
||||
"""
|
||||
schema = Schema(query=Query)
|
||||
result = schema.execute(query, context_value=context())
|
||||
assert not result.errors
|
||||
|
||||
assert len(result.data['contextArticles']['edges']) == 1
|
||||
assert result.data['contextArticles']['edges'][0]['node']['headline'] == 'a2'
|
||||
assert len(result.data["contextArticles"]["edges"]) == 1
|
||||
assert result.data["contextArticles"]["edges"][0]["node"]["headline"] == "a2"
|
||||
|
||||
|
||||
def test_filter_filterset_information_on_meta():
|
||||
class ReporterFilterNode(DjangoObjectType):
|
||||
|
||||
class Meta:
|
||||
model = Reporter
|
||||
interfaces = (Node, )
|
||||
filter_fields = ['first_name', 'articles']
|
||||
interfaces = (Node,)
|
||||
filter_fields = ["first_name", "articles"]
|
||||
|
||||
field = DjangoFilterConnectionField(ReporterFilterNode)
|
||||
assert_arguments(field, 'first_name', 'articles')
|
||||
assert_arguments(field, "first_name", "articles")
|
||||
assert_not_orderable(field)
|
||||
|
||||
|
||||
def test_filter_filterset_information_on_meta_related():
|
||||
class ReporterFilterNode(DjangoObjectType):
|
||||
|
||||
class Meta:
|
||||
model = Reporter
|
||||
interfaces = (Node, )
|
||||
filter_fields = ['first_name', 'articles']
|
||||
interfaces = (Node,)
|
||||
filter_fields = ["first_name", "articles"]
|
||||
|
||||
class ArticleFilterNode(DjangoObjectType):
|
||||
|
||||
class Meta:
|
||||
model = Article
|
||||
interfaces = (Node, )
|
||||
filter_fields = ['headline', 'reporter']
|
||||
interfaces = (Node,)
|
||||
filter_fields = ["headline", "reporter"]
|
||||
|
||||
class Query(ObjectType):
|
||||
all_reporters = DjangoFilterConnectionField(ReporterFilterNode)
|
||||
|
@ -217,25 +222,23 @@ def test_filter_filterset_information_on_meta_related():
|
|||
article = Field(ArticleFilterNode)
|
||||
|
||||
schema = Schema(query=Query)
|
||||
articles_field = ReporterFilterNode._meta.fields['articles'].get_type()
|
||||
assert_arguments(articles_field, 'headline', 'reporter')
|
||||
articles_field = ReporterFilterNode._meta.fields["articles"].get_type()
|
||||
assert_arguments(articles_field, "headline", "reporter")
|
||||
assert_not_orderable(articles_field)
|
||||
|
||||
|
||||
def test_filter_filterset_related_results():
|
||||
class ReporterFilterNode(DjangoObjectType):
|
||||
|
||||
class Meta:
|
||||
model = Reporter
|
||||
interfaces = (Node, )
|
||||
filter_fields = ['first_name', 'articles']
|
||||
interfaces = (Node,)
|
||||
filter_fields = ["first_name", "articles"]
|
||||
|
||||
class ArticleFilterNode(DjangoObjectType):
|
||||
|
||||
class Meta:
|
||||
interfaces = (Node, )
|
||||
interfaces = (Node,)
|
||||
model = Article
|
||||
filter_fields = ['headline', 'reporter']
|
||||
filter_fields = ["headline", "reporter"]
|
||||
|
||||
class Query(ObjectType):
|
||||
all_reporters = DjangoFilterConnectionField(ReporterFilterNode)
|
||||
|
@ -243,12 +246,22 @@ def test_filter_filterset_related_results():
|
|||
reporter = Field(ReporterFilterNode)
|
||||
article = Field(ArticleFilterNode)
|
||||
|
||||
r1 = Reporter.objects.create(first_name='r1', last_name='r1', email='r1@test.com')
|
||||
r2 = Reporter.objects.create(first_name='r2', last_name='r2', email='r2@test.com')
|
||||
Article.objects.create(headline='a1', pub_date=datetime.now(), pub_date_time=datetime.now(), reporter=r1)
|
||||
Article.objects.create(headline='a2', pub_date=datetime.now(), pub_date_time=datetime.now(), reporter=r2)
|
||||
r1 = Reporter.objects.create(first_name="r1", last_name="r1", email="r1@test.com")
|
||||
r2 = Reporter.objects.create(first_name="r2", last_name="r2", email="r2@test.com")
|
||||
Article.objects.create(
|
||||
headline="a1",
|
||||
pub_date=datetime.now(),
|
||||
pub_date_time=datetime.now(),
|
||||
reporter=r1,
|
||||
)
|
||||
Article.objects.create(
|
||||
headline="a2",
|
||||
pub_date=datetime.now(),
|
||||
pub_date_time=datetime.now(),
|
||||
reporter=r2,
|
||||
)
|
||||
|
||||
query = '''
|
||||
query = """
|
||||
query {
|
||||
allReporters {
|
||||
edges {
|
||||
|
@ -264,123 +277,134 @@ def test_filter_filterset_related_results():
|
|||
}
|
||||
}
|
||||
}
|
||||
'''
|
||||
"""
|
||||
schema = Schema(query=Query)
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
# We should only get back a single article for each reporter
|
||||
assert len(result.data['allReporters']['edges'][0]['node']['articles']['edges']) == 1
|
||||
assert len(result.data['allReporters']['edges'][1]['node']['articles']['edges']) == 1
|
||||
assert (
|
||||
len(result.data["allReporters"]["edges"][0]["node"]["articles"]["edges"]) == 1
|
||||
)
|
||||
assert (
|
||||
len(result.data["allReporters"]["edges"][1]["node"]["articles"]["edges"]) == 1
|
||||
)
|
||||
|
||||
|
||||
def test_global_id_field_implicit():
|
||||
field = DjangoFilterConnectionField(ArticleNode, fields=['id'])
|
||||
field = DjangoFilterConnectionField(ArticleNode, fields=["id"])
|
||||
filterset_class = field.filterset_class
|
||||
id_filter = filterset_class.base_filters['id']
|
||||
id_filter = filterset_class.base_filters["id"]
|
||||
assert isinstance(id_filter, GlobalIDFilter)
|
||||
assert id_filter.field_class == GlobalIDFormField
|
||||
|
||||
|
||||
def test_global_id_field_explicit():
|
||||
class ArticleIdFilter(django_filters.FilterSet):
|
||||
|
||||
class Meta:
|
||||
model = Article
|
||||
fields = ['id']
|
||||
fields = ["id"]
|
||||
|
||||
field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleIdFilter)
|
||||
filterset_class = field.filterset_class
|
||||
id_filter = filterset_class.base_filters['id']
|
||||
id_filter = filterset_class.base_filters["id"]
|
||||
assert isinstance(id_filter, GlobalIDFilter)
|
||||
assert id_filter.field_class == GlobalIDFormField
|
||||
|
||||
|
||||
def test_filterset_descriptions():
|
||||
class ArticleIdFilter(django_filters.FilterSet):
|
||||
|
||||
class Meta:
|
||||
model = Article
|
||||
fields = ['id']
|
||||
fields = ["id"]
|
||||
|
||||
max_time = django_filters.NumberFilter(method='filter_max_time', label="The maximum time")
|
||||
max_time = django_filters.NumberFilter(
|
||||
method="filter_max_time", label="The maximum time"
|
||||
)
|
||||
|
||||
field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleIdFilter)
|
||||
max_time = field.args['max_time']
|
||||
max_time = field.args["max_time"]
|
||||
assert isinstance(max_time, Argument)
|
||||
assert max_time.type == Float
|
||||
assert max_time.description == 'The maximum time'
|
||||
assert max_time.description == "The maximum time"
|
||||
|
||||
|
||||
def test_global_id_field_relation():
|
||||
field = DjangoFilterConnectionField(ArticleNode, fields=['reporter'])
|
||||
field = DjangoFilterConnectionField(ArticleNode, fields=["reporter"])
|
||||
filterset_class = field.filterset_class
|
||||
id_filter = filterset_class.base_filters['reporter']
|
||||
id_filter = filterset_class.base_filters["reporter"]
|
||||
assert isinstance(id_filter, GlobalIDFilter)
|
||||
assert id_filter.field_class == GlobalIDFormField
|
||||
|
||||
|
||||
def test_global_id_multiple_field_implicit():
|
||||
field = DjangoFilterConnectionField(ReporterNode, fields=['pets'])
|
||||
field = DjangoFilterConnectionField(ReporterNode, fields=["pets"])
|
||||
filterset_class = field.filterset_class
|
||||
multiple_filter = filterset_class.base_filters['pets']
|
||||
multiple_filter = filterset_class.base_filters["pets"]
|
||||
assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter)
|
||||
assert multiple_filter.field_class == GlobalIDMultipleChoiceField
|
||||
|
||||
|
||||
def test_global_id_multiple_field_explicit():
|
||||
class ReporterPetsFilter(django_filters.FilterSet):
|
||||
|
||||
class Meta:
|
||||
model = Reporter
|
||||
fields = ['pets']
|
||||
fields = ["pets"]
|
||||
|
||||
field = DjangoFilterConnectionField(ReporterNode, filterset_class=ReporterPetsFilter)
|
||||
field = DjangoFilterConnectionField(
|
||||
ReporterNode, filterset_class=ReporterPetsFilter
|
||||
)
|
||||
filterset_class = field.filterset_class
|
||||
multiple_filter = filterset_class.base_filters['pets']
|
||||
multiple_filter = filterset_class.base_filters["pets"]
|
||||
assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter)
|
||||
assert multiple_filter.field_class == GlobalIDMultipleChoiceField
|
||||
|
||||
|
||||
def test_global_id_multiple_field_implicit_reverse():
|
||||
field = DjangoFilterConnectionField(ReporterNode, fields=['articles'])
|
||||
field = DjangoFilterConnectionField(ReporterNode, fields=["articles"])
|
||||
filterset_class = field.filterset_class
|
||||
multiple_filter = filterset_class.base_filters['articles']
|
||||
multiple_filter = filterset_class.base_filters["articles"]
|
||||
assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter)
|
||||
assert multiple_filter.field_class == GlobalIDMultipleChoiceField
|
||||
|
||||
|
||||
def test_global_id_multiple_field_explicit_reverse():
|
||||
class ReporterPetsFilter(django_filters.FilterSet):
|
||||
|
||||
class Meta:
|
||||
model = Reporter
|
||||
fields = ['articles']
|
||||
fields = ["articles"]
|
||||
|
||||
field = DjangoFilterConnectionField(ReporterNode, filterset_class=ReporterPetsFilter)
|
||||
field = DjangoFilterConnectionField(
|
||||
ReporterNode, filterset_class=ReporterPetsFilter
|
||||
)
|
||||