Reformatted files using black

pull/475/head
Syrus Akbary 4 years ago
parent 96789b291f
commit 54ef52e1c6
  1. 16
      graphene_django/__init__.py
  2. 2
      graphene_django/compat.py
  3. 32
      graphene_django/converter.py
  4. 2
      graphene_django/debug/__init__.py
  5. 16
      graphene_django/debug/middleware.py
  6. 59
      graphene_django/debug/sql/tracking.py
  7. 125
      graphene_django/debug/tests/test_query.py
  8. 49
      graphene_django/fields.py
  9. 10
      graphene_django/filter/__init__.py
  10. 57
      graphene_django/filter/fields.py
  11. 47
      graphene_django/filter/filterset.py
  12. 17
      graphene_django/filter/tests/filters.py
  13. 445
      graphene_django/filter/tests/test_fields.py
  14. 3
      graphene_django/forms/converter.py
  15. 12
      graphene_django/forms/forms.py
  16. 60
      graphene_django/forms/mutation.py
  17. 24
      graphene_django/forms/tests/test_converter.py
  18. 37
      graphene_django/forms/tests/test_mutation.py
  19. 46
      graphene_django/management/commands/graphql_schema.py
  20. 13
      graphene_django/registry.py
  21. 78
      graphene_django/rest_framework/mutation.py
  22. 15
      graphene_django/rest_framework/serializer_converter.py
  23. 42
      graphene_django/rest_framework/tests/test_field_converter.py
  24. 109
      graphene_django/rest_framework/tests/test_mutation.py
  25. 40
      graphene_django/settings.py
  26. 66
      graphene_django/tests/models.py
  27. 5
      graphene_django/tests/schema.py
  28. 4
      graphene_django/tests/schema_view.py
  29. 4
      graphene_django/tests/test_command.py
  30. 108
      graphene_django/tests/test_converter.py
  31. 10
      graphene_django/tests/test_forms.py
  32. 581
      graphene_django/tests/test_query.py
  33. 35
      graphene_django/tests/test_schema.py
  34. 53
      graphene_django/tests/test_types.py
  35. 509
      graphene_django/tests/test_views.py
  36. 4
      graphene_django/tests/urls.py
  37. 5
      graphene_django/tests/urls_inherited.py
  38. 4
      graphene_django/tests/urls_pretty.py
  39. 43
      graphene_django/types.py
  40. 10
      graphene_django/utils.py
  41. 200
      graphene_django/views.py

@ -1,14 +1,6 @@
from .types import (
DjangoObjectType,
)
from .fields import (
DjangoConnectionField,
)
from .types import DjangoObjectType
from .fields import DjangoConnectionField
__version__ = '2.1rc1'
__version__ = "2.1rc1"
__all__ = [
'__version__',
'DjangoObjectType',
'DjangoConnectionField'
]
__all__ = ["__version__", "DjangoObjectType", "DjangoConnectionField"]

@ -7,7 +7,7 @@ try:
# and we cannot have psycopg2 on PyPy
from django.contrib.postgres.fields import ArrayField, HStoreField, RangeField
except ImportError:
ArrayField, HStoreField, JSONField, RangeField = (MissingType, ) * 4
ArrayField, HStoreField, JSONField, RangeField = (MissingType,) * 4
try:

@ -1,8 +1,22 @@
from django.db import models
from django.utils.encoding import force_text
from graphene import (ID, Boolean, Dynamic, Enum, Field, Float, Int, List,
NonNull, String, UUID, DateTime, Date, Time)
from graphene import (
ID,
Boolean,
Dynamic,
Enum,
Field,
Float,
Int,
List,
NonNull,
String,
UUID,
DateTime,
Date,
Time,
)
from graphene.types.json import JSONString
from graphene.utils.str_converters import to_camel_case, to_const
from graphql import assert_valid_name
@ -32,7 +46,7 @@ def get_choices(choices):
else:
name = convert_choice_name(value)
while name in converted_names:
name += '_' + str(len(converted_names))
name += "_" + str(len(converted_names))
converted_names.append(name)
description = help_text
yield name, value, description
@ -43,16 +57,15 @@ def convert_django_field_with_choices(field, registry=None):
converted = registry.get_converted_field(field)
if converted:
return converted
choices = getattr(field, 'choices', None)
choices = getattr(field, "choices", None)
if choices:
meta = field.model._meta
name = to_camel_case('{}_{}'.format(meta.object_name, field.name))
name = to_camel_case("{}_{}".format(meta.object_name, field.name))
choices = list(get_choices(choices))
named_choices = [(c[0], c[1]) for c in choices]
named_choices_descriptions = {c[0]: c[2] for c in choices}
class EnumWithDescriptionsType(object):
@property
def description(self):
return named_choices_descriptions[self.name]
@ -69,8 +82,8 @@ def convert_django_field_with_choices(field, registry=None):
@singledispatch
def convert_django_field(field, registry=None):
raise Exception(
"Don't know how to convert the Django field %s (%s)" %
(field, field.__class__))
"Don't know how to convert the Django field %s (%s)" % (field, field.__class__)
)
@convert_django_field.register(models.CharField)
@ -147,7 +160,7 @@ def convert_onetoone_field_to_djangomodel(field, registry=None):
# We do this for a bug in Django 1.8, where null attr
# is not available in the OneToOneRel instance
null = getattr(field, 'null', True)
null = getattr(field, "null", True)
return Field(_type, required=not null)
return Dynamic(dynamic_type)
@ -171,6 +184,7 @@ def convert_field_to_list_or_connection(field, registry=None):
# defined filter_fields in the DjangoObjectType Meta
if _type._meta.filter_fields:
from .filter.fields import DjangoFilterConnectionField
return DjangoFilterConnectionField(_type)
return DjangoConnectionField(_type)

@ -1,4 +1,4 @@
from .middleware import DjangoDebugMiddleware
from .types import DjangoDebug
__all__ = ['DjangoDebugMiddleware', 'DjangoDebug']
__all__ = ["DjangoDebugMiddleware", "DjangoDebug"]

@ -7,7 +7,6 @@ from .types import DjangoDebug
class DjangoDebugContext(object):
def __init__(self):
self.debug_promise = None
self.promises = []
@ -38,20 +37,21 @@ class DjangoDebugContext(object):
class DjangoDebugMiddleware(object):
def resolve(self, next, root, info, **args):
context = info.context
django_debug = getattr(context, 'django_debug', None)
django_debug = getattr(context, "django_debug", None)
if not django_debug:
if context is None:
raise Exception('DjangoDebug cannot be executed in None contexts')
raise Exception("DjangoDebug cannot be executed in None contexts")
try:
context.django_debug = DjangoDebugContext()
except Exception:
raise Exception('DjangoDebug need the context to be writable, context received: {}.'.format(
context.__class__.__name__
))
if info.schema.get_type('DjangoDebug') == info.return_type:
raise Exception(
"DjangoDebug need the context to be writable, context received: {}.".format(
context.__class__.__name__
)
)
if info.schema.get_type("DjangoDebug") == info.return_type:
return context.django_debug.get_debug_promise()
promise = next(root, info, **args)
context.django_debug.add_promise(promise)

@ -16,7 +16,6 @@ class SQLQueryTriggered(Exception):
class ThreadLocalState(local):
def __init__(self):
self.enabled = True
@ -35,7 +34,7 @@ recording = state.recording # export function
def wrap_cursor(connection, panel):
if not hasattr(connection, '_graphene_cursor'):
if not hasattr(connection, "_graphene_cursor"):
connection._graphene_cursor = connection.cursor
def cursor():
@ -46,7 +45,7 @@ def wrap_cursor(connection, panel):
def unwrap_cursor(connection):
if hasattr(connection, '_graphene_cursor'):
if hasattr(connection, "_graphene_cursor"):
previous_cursor = connection._graphene_cursor
connection.cursor = previous_cursor
del connection._graphene_cursor
@ -87,15 +86,14 @@ class NormalCursorWrapper(object):
if not params:
return params
if isinstance(params, dict):
return dict((key, self._quote_expr(value))
for key, value in params.items())
return dict((key, self._quote_expr(value)) for key, value in params.items())
return list(map(self._quote_expr, params))
def _decode(self, param):
try:
return force_text(param, strings_only=True)
except UnicodeDecodeError:
return '(encoded string)'
return "(encoded string)"
def _record(self, method, sql, params):
start_time = time()
@ -103,45 +101,48 @@ class NormalCursorWrapper(object):
return method(sql, params)
finally:
stop_time = time()
duration = (stop_time - start_time)
_params = ''
duration = stop_time - start_time
_params = ""
try:
_params = json.dumps(list(map(self._decode, params)))
except Exception:
pass # object not JSON serializable
alias = getattr(self.db, 'alias', 'default')
alias = getattr(self.db, "alias", "default")
conn = self.db.connection
vendor = getattr(conn, 'vendor', 'unknown')
vendor = getattr(conn, "vendor", "unknown")
params = {
'vendor': vendor,
'alias': alias,
'sql': self.db.ops.last_executed_query(
self.cursor, sql, self._quote_params(params)),
'duration': duration,
'raw_sql': sql,
'params': _params,
'start_time': start_time,
'stop_time': stop_time,
'is_slow': duration > 10,
'is_select': sql.lower().strip().startswith('select'),
"vendor": vendor,
"alias": alias,
"sql": self.db.ops.last_executed_query(
self.cursor, sql, self._quote_params(params)
),
"duration": duration,
"raw_sql": sql,
"params": _params,
"start_time": start_time,
"stop_time": stop_time,
"is_slow": duration > 10,
"is_select": sql.lower().strip().startswith("select"),
}
if vendor == 'postgresql':
if vendor == "postgresql":
# If an erroneous query was ran on the connection, it might
# be in a state where checking isolation_level raises an
# exception.
try:
iso_level = conn.isolation_level
except conn.InternalError:
iso_level = 'unknown'
params.update({
'trans_id': self.logger.get_transaction_id(alias),
'trans_status': conn.get_transaction_status(),
'iso_level': iso_level,
'encoding': conn.encoding,
})
iso_level = "unknown"
params.update(
{
"trans_id": self.logger.get_transaction_id(alias),
"trans_status": conn.get_transaction_status(),
"iso_level": iso_level,
"encoding": conn.encoding,
}
)
_sql = DjangoDebugSQL(**params)
# We keep `sql` to maintain backwards compatibility

@ -12,31 +12,31 @@ from ..types import DjangoDebug
class context(object):
pass
# from examples.starwars_django.models import Character
pytestmark = pytest.mark.django_db
def test_should_query_field():
r1 = Reporter(last_name='ABA')
r1 = Reporter(last_name="ABA")
r1.save()
r2 = Reporter(last_name='Griffin')
r2 = Reporter(last_name="Griffin")
r2.save()
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
interfaces = (Node,)
class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType)
debug = graphene.Field(DjangoDebug, name='__debug')
debug = graphene.Field(DjangoDebug, name="__debug")
def resolve_reporter(self, info, **args):
return Reporter.objects.first()
query = '''
query = """
query ReporterQuery {
reporter {
lastName
@ -47,43 +47,40 @@ def test_should_query_field():
}
}
}
'''
"""
expected = {
'reporter': {
'lastName': 'ABA',
"reporter": {"lastName": "ABA"},
"__debug": {
"sql": [{"rawSql": str(Reporter.objects.order_by("pk")[:1].query)}]
},
'__debug': {
'sql': [{
'rawSql': str(Reporter.objects.order_by('pk')[:1].query)
}]
}
}
schema = graphene.Schema(query=Query)
result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()])
result = schema.execute(
query, context_value=context(), middleware=[DjangoDebugMiddleware()]
)
assert not result.errors
assert result.data == expected
def test_should_query_list():
r1 = Reporter(last_name='ABA')
r1 = Reporter(last_name="ABA")
r1.save()
r2 = Reporter(last_name='Griffin')
r2 = Reporter(last_name="Griffin")
r2.save()
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
interfaces = (Node,)
class Query(graphene.ObjectType):
all_reporters = graphene.List(ReporterType)
debug = graphene.Field(DjangoDebug, name='__debug')
debug = graphene.Field(DjangoDebug, name="__debug")
def resolve_all_reporters(self, info, **args):
return Reporter.objects.all()
query = '''
query = """
query ReporterQuery {
allReporters {
lastName
@ -94,45 +91,38 @@ def test_should_query_list():
}
}
}
'''
"""
expected = {
'allReporters': [{
'lastName': 'ABA',
}, {
'lastName': 'Griffin',
}],
'__debug': {
'sql': [{
'rawSql': str(Reporter.objects.all().query)
}]
}
"allReporters": [{"lastName": "ABA"}, {"lastName": "Griffin"}],
"__debug": {"sql": [{"rawSql": str(Reporter.objects.all().query)}]},
}
schema = graphene.Schema(query=Query)
result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()])
result = schema.execute(
query, context_value=context(), middleware=[DjangoDebugMiddleware()]
)
assert not result.errors
assert result.data == expected
def test_should_query_connection():
r1 = Reporter(last_name='ABA')
r1 = Reporter(last_name="ABA")
r1.save()
r2 = Reporter(last_name='Griffin')
r2 = Reporter(last_name="Griffin")
r2.save()
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
interfaces = (Node,)
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
debug = graphene.Field(DjangoDebug, name='__debug')
debug = graphene.Field(DjangoDebug, name="__debug")
def resolve_all_reporters(self, info, **args):
return Reporter.objects.all()
query = '''
query = """
query ReporterQuery {
allReporters(first:1) {
edges {
@ -147,48 +137,41 @@ def test_should_query_connection():
}
}
}
'''
expected = {
'allReporters': {
'edges': [{
'node': {
'lastName': 'ABA',
}
}]
},
}
"""
expected = {"allReporters": {"edges": [{"node": {"lastName": "ABA"}}]}}
schema = graphene.Schema(query=Query)
result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()])
result = schema.execute(
query, context_value=context(), middleware=[DjangoDebugMiddleware()]
)
assert not result.errors
assert result.data['allReporters'] == expected['allReporters']
assert 'COUNT' in result.data['__debug']['sql'][0]['rawSql']
assert result.data["allReporters"] == expected["allReporters"]
assert "COUNT" in result.data["__debug"]["sql"][0]["rawSql"]
query = str(Reporter.objects.all()[:1].query)
assert result.data['__debug']['sql'][1]['rawSql'] == query
assert result.data["__debug"]["sql"][1]["rawSql"] == query
def test_should_query_connectionfilter():
from ...filter import DjangoFilterConnectionField
r1 = Reporter(last_name='ABA')
r1 = Reporter(last_name="ABA")
r1.save()
r2 = Reporter(last_name='Griffin')
r2 = Reporter(last_name="Griffin")
r2.save()
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
interfaces = (Node,)
class Query(graphene.ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterType, fields=['last_name'])
all_reporters = DjangoFilterConnectionField(ReporterType, fields=["last_name"])
s = graphene.String(resolver=lambda *_: "S")
debug = graphene.Field(DjangoDebug, name='__debug')
debug = graphene.Field(DjangoDebug, name="__debug")
def resolve_all_reporters(self, info, **args):
return Reporter.objects.all()
query = '''
query = """
query ReporterQuery {
allReporters(first:1) {
edges {
@ -203,20 +186,14 @@ def test_should_query_connectionfilter():
}
}
}
'''
expected = {
'allReporters': {
'edges': [{
'node': {
'lastName': 'ABA',
}
}]
},
}
"""
expected = {"allReporters": {"edges": [{"node": {"lastName": "ABA"}}]}}
schema = graphene.Schema(query=Query)
result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()])
result = schema.execute(
query, context_value=context(), middleware=[DjangoDebugMiddleware()]
)
assert not result.errors
assert result.data['allReporters'] == expected['allReporters']
assert 'COUNT' in result.data['__debug']['sql'][0]['rawSql']
assert result.data["allReporters"] == expected["allReporters"]
assert "COUNT" in result.data["__debug"]["sql"][0]["rawSql"]
query = str(Reporter.objects.all()[:1].query)
assert result.data['__debug']['sql'][1]['rawSql'] == query
assert result.data["__debug"]["sql"][1]["rawSql"] == query

@ -13,7 +13,6 @@ from .utils import maybe_queryset
class DjangoListField(Field):
def __init__(self, _type, *args, **kwargs):
super(DjangoListField, self).__init__(List(_type), *args, **kwargs)
@ -30,25 +29,28 @@ class DjangoListField(Field):
class DjangoConnectionField(ConnectionField):
def __init__(self, *args, **kwargs):
self.on = kwargs.pop('on', False)
self.on = kwargs.pop("on", False)
self.max_limit = kwargs.pop(
'max_limit',
graphene_settings.RELAY_CONNECTION_MAX_LIMIT
"max_limit", graphene_settings.RELAY_CONNECTION_MAX_LIMIT
)
self.enforce_first_or_last = kwargs.pop(
'enforce_first_or_last',
graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST
"enforce_first_or_last",
graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST,
)
super(DjangoConnectionField, self).__init__(*args, **kwargs)
@property
def type(self):
from .types import DjangoObjectType
_type = super(ConnectionField, self).type
assert issubclass(_type, DjangoObjectType), "DjangoConnectionField only accepts DjangoObjectType types"
assert _type._meta.connection, "The type {} doesn't have a connection".format(_type.__name__)
assert issubclass(
_type, DjangoObjectType
), "DjangoConnectionField only accepts DjangoObjectType types"
assert _type._meta.connection, "The type {} doesn't have a connection".format(
_type.__name__
)
return _type._meta.connection
@property
@ -100,28 +102,37 @@ class DjangoConnectionField(ConnectionField):
return connection
@classmethod
def connection_resolver(cls, resolver, connection, default_manager, max_limit,
enforce_first_or_last, root, info, **args):
first = args.get('first')
last = args.get('last')
def connection_resolver(
cls,
resolver,
connection,
default_manager,
max_limit,
enforce_first_or_last,
root,
info,
**args
):
first = args.get("first")
last = args.get("last")
if enforce_first_or_last:
assert first or last, (
'You must provide a `first` or `last` value to properly paginate the `{}` connection.'
"You must provide a `first` or `last` value to properly paginate the `{}` connection."
).format(info.field_name)
if max_limit:
if first:
assert first <= max_limit, (
'Requesting {} records on the `{}` connection exceeds the `first` limit of {} records.'
"Requesting {} records on the `{}` connection exceeds the `first` limit of {} records."
).format(first, info.field_name, max_limit)
args['first'] = min(first, max_limit)
args["first"] = min(first, max_limit)
if last:
assert last <= max_limit, (
'Requesting {} records on the `{}` connection exceeds the `last` limit of {} records.'
"Requesting {} records on the `{}` connection exceeds the `last` limit of {} records."
).format(last, info.field_name, max_limit)
args['last'] = min(last, max_limit)
args["last"] = min(last, max_limit)
iterable = resolver(root, info, **args)
on_resolve = partial(cls.resolve_connection, connection, default_manager, args)
@ -138,5 +149,5 @@ class DjangoConnectionField(ConnectionField):
self.type,
self.get_manager(),
self.max_limit,
self.enforce_first_or_last
self.enforce_first_or_last,
)

@ -4,11 +4,15 @@ from ..utils import DJANGO_FILTER_INSTALLED
if not DJANGO_FILTER_INSTALLED:
warnings.warn(
"Use of django filtering requires the django-filter package "
"be installed. You can do so using `pip install django-filter`", ImportWarning
"be installed. You can do so using `pip install django-filter`",
ImportWarning,
)
else:
from .fields import DjangoFilterConnectionField
from .filterset import GlobalIDFilter, GlobalIDMultipleChoiceFilter
__all__ = ['DjangoFilterConnectionField',
'GlobalIDFilter', 'GlobalIDMultipleChoiceFilter']
__all__ = [
"DjangoFilterConnectionField",
"GlobalIDFilter",
"GlobalIDMultipleChoiceFilter",
]

@ -7,10 +7,16 @@ from .utils import get_filtering_args_from_filterset, get_filterset_class
class DjangoFilterConnectionField(DjangoConnectionField):
def __init__(self, type, fields=None, order_by=None,
extra_filter_meta=None, filterset_class=None,
*args, **kwargs):
def __init__(
self,
type,
fields=None,
order_by=None,
extra_filter_meta=None,
filterset_class=None,
*args,
**kwargs
):
self._fields = fields
self._provided_filterset_class = filterset_class
self._filterset_class = None
@ -30,12 +36,13 @@ class DjangoFilterConnectionField(DjangoConnectionField):
def filterset_class(self):
if not self._filterset_class:
fields = self._fields or self.node_type._meta.filter_fields
meta = dict(model=self.model,
fields=fields)
meta = dict(model=self.model, fields=fields)
if self._extra_filter_meta:
meta.update(self._extra_filter_meta)
self._filterset_class = get_filterset_class(self._provided_filterset_class, **meta)
self._filterset_class = get_filterset_class(
self._provided_filterset_class, **meta
)
return self._filterset_class
@ -52,28 +59,40 @@ class DjangoFilterConnectionField(DjangoConnectionField):
# See related PR: https://github.com/graphql-python/graphene-django/pull/126
assert not (default_queryset.query.low_mark and queryset.query.low_mark), (
'Received two sliced querysets (low mark) in the connection, please slice only in one.'
)
assert not (default_queryset.query.high_mark and queryset.query.high_mark), (
'Received two sliced querysets (high mark) in the connection, please slice only in one.'
)
assert not (
default_queryset.query.low_mark and queryset.query.low_mark
), "Received two sliced querysets (low mark) in the connection, please slice only in one."
assert not (
default_queryset.query.high_mark and queryset.query.high_mark
), "Received two sliced querysets (high mark) in the connection, please slice only in one."
low = default_queryset.query.low_mark or queryset.query.low_mark
high = default_queryset.query.high_mark or queryset.query.high_mark
default_queryset.query.clear_limits()
queryset = super(DjangoFilterConnectionField, cls).merge_querysets(default_queryset, queryset)
queryset = super(DjangoFilterConnectionField, cls).merge_querysets(
default_queryset, queryset
)
queryset.query.set_limits(low, high)
return queryset
@classmethod
def connection_resolver(cls, resolver, connection, default_manager, max_limit,
enforce_first_or_last, filterset_class, filtering_args,
root, info, **args):
def connection_resolver(
cls,
resolver,
connection,
default_manager,
max_limit,
enforce_first_or_last,
filterset_class,
filtering_args,
root,
info,
**args
):
filter_kwargs = {k: v for k, v in args.items() if k in filtering_args}
qs = filterset_class(
data=filter_kwargs,
queryset=default_manager.get_queryset(),
request=info.context
request=info.context,
).qs
return super(DjangoFilterConnectionField, cls).connection_resolver(
@ -96,5 +115,5 @@ class DjangoFilterConnectionField(DjangoConnectionField):
self.max_limit,
self.enforce_first_or_last,
self.filterset_class,
self.filtering_args
self.filtering_args,
)

@ -28,26 +28,19 @@ class GlobalIDMultipleChoiceFilter(MultipleChoiceFilter):
GRAPHENE_FILTER_SET_OVERRIDES = {
models.AutoField: {
'filter_class': GlobalIDFilter,
},
models.OneToOneField: {
'filter_class': GlobalIDFilter,
},
models.ForeignKey: {
'filter_class': GlobalIDFilter,
},
models.ManyToManyField: {
'filter_class': GlobalIDMultipleChoiceFilter,
}
models.AutoField: {"filter_class": GlobalIDFilter},
models.OneToOneField: {"filter_class": GlobalIDFilter},
models.ForeignKey: {"filter_class": GlobalIDFilter},
models.ManyToManyField: {"filter_class": GlobalIDMultipleChoiceFilter},
}
class GrapheneFilterSetMixin(BaseFilterSet):
FILTER_DEFAULTS = dict(itertools.chain(
FILTER_FOR_DBFIELD_DEFAULTS.items(),
GRAPHENE_FILTER_SET_OVERRIDES.items()
))
FILTER_DEFAULTS = dict(
itertools.chain(
FILTER_FOR_DBFIELD_DEFAULTS.items(), GRAPHENE_FILTER_SET_OVERRIDES.items()
)
)
@classmethod
def filter_for_reverse_field(cls, f, name):
@ -62,10 +55,7 @@ class GrapheneFilterSetMixin(BaseFilterSet):
except AttributeError:
rel = f.field.rel
default = {
'name': name,
'label': capfirst(rel.related_name)
}
default = {"name": name, "label": capfirst(rel.related_name)}
if rel.multiple:
# For to-many relationships
return GlobalIDMultipleChoiceFilter(**default)
@ -78,25 +68,20 @@ def setup_filterset(filterset_class):
""" Wrap a provided filterset in Graphene-specific functionality
"""
return type(
'Graphene{}'.format(filterset_class.__name__),
"Graphene{}".format(filterset_class.__name__),
(filterset_class, GrapheneFilterSetMixin),
{},
)
def custom_filterset_factory(model, filterset_base_class=FilterSet,
**meta):
def custom_filterset_factory(model, filterset_base_class=FilterSet, **meta):
""" Create a filterset for the given model using the provided meta data
"""
meta.update({
'model': model,
})
meta_class = type(str('Meta'), (object,), meta)
meta.update({"model": model})
meta_class = type(str("Meta"), (object,), meta)
filterset = type(
str('%sFilterSet' % model._meta.object_name),
str("%sFilterSet" % model._meta.object_name),
(filterset_base_class, GrapheneFilterSetMixin),
{
'Meta': meta_class
}
{"Meta": meta_class},
)
return filterset

@ -5,29 +5,26 @@ from graphene_django.tests.models import Article, Pet, Reporter
class ArticleFilter(django_filters.FilterSet):
class Meta:
model = Article
fields = {
'headline': ['exact', 'icontains'],
'pub_date': ['gt', 'lt', 'exact'],
'reporter': ['exact'],
"headline": ["exact", "icontains"],
"pub_date": ["gt", "lt", "exact"],
"reporter": ["exact"],
}
order_by = OrderingFilter(fields=('pub_date',))
order_by = OrderingFilter(fields=("pub_date",))
class ReporterFilter(django_filters.FilterSet):
class Meta:
model = Reporter
fields = ['first_name', 'last_name', 'email', 'pets']
fields = ["first_name", "last_name", "email", "pets"]
order_by = OrderingFilter(fields=('pub_date',))
order_by = OrderingFilter(fields=("pub_date",))
class PetFilter(django_filters.FilterSet):
class Meta:
model = Pet
fields = ['name']
fields = ["name"]

@ -5,8 +5,7 @@ import pytest
from graphene import Field, ObjectType, Schema, Argument, Float, Boolean, String
from graphene.relay import Node
from graphene_django import DjangoObjectType
from graphene_django.forms import (GlobalIDFormField,
GlobalIDMultipleChoiceField)
from graphene_django.forms import GlobalIDFormField, GlobalIDMultipleChoiceField
from graphene_django.tests.models import Article, Pet, Reporter
from graphene_django.utils import DJANGO_FILTER_INSTALLED
@ -20,36 +19,43 @@ if DJANGO_FILTER_INSTALLED:
import django_filters
from django_filters import FilterSet, NumberFilter
from graphene_django.filter import (GlobalIDFilter, DjangoFilterConnectionField,
GlobalIDMultipleChoiceFilter)
from graphene_django.filter.tests.filters import ArticleFilter, PetFilter, ReporterFilter
from graphene_django.filter import (
GlobalIDFilter,
DjangoFilterConnectionField,
GlobalIDMultipleChoiceFilter,
)
from graphene_django.filter.tests.filters import (
ArticleFilter,
PetFilter,
ReporterFilter,
)
else:
pytestmark.append(pytest.mark.skipif(True, reason='django_filters not installed or not compatible'))
pytestmark.append(
pytest.mark.skipif(
True, reason="django_filters not installed or not compatible"
)
)
pytestmark.append(pytest.mark.django_db)
if DJANGO_FILTER_INSTALLED:
class ArticleNode(DjangoObjectType):
class ArticleNode(DjangoObjectType):
class Meta:
model = Article
interfaces = (Node, )
filter_fields = ('headline', )
interfaces = (Node,)
filter_fields = ("headline",)
class ReporterNode(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
interfaces = (Node,)
class PetNode(DjangoObjectType):
class Meta:
model = Pet
interfaces = (Node, )
interfaces = (Node,)
# schema = Schema()
@ -59,58 +65,47 @@ def get_args(field):
def assert_arguments(field, *arguments):
ignore = ('after', 'before', 'first', 'last', 'order_by')
ignore = ("after", "before", "first", "last", "order_by")
args = get_args(field)
actual = [
name
for name in args
if name not in ignore and not name.startswith('_')
]
assert set(arguments) == set(actual), \
'Expected arguments ({}) did not match actual ({})'.format(
arguments,
actual
)
actual = [name for name in args if name not in ignore and not name.startswith("_")]
assert set(arguments) == set(
actual
), "Expected arguments ({}) did not match actual ({})".format(arguments, actual)
def assert_orderable(field):
args = get_args(field)
assert 'order_by' in args, \
'Field cannot be ordered'
assert "order_by" in args, "Field cannot be ordered"
def assert_not_orderable(field):
args = get_args(field)
assert 'order_by' not in args, \
'Field can be ordered'
assert "order_by" not in args, "Field can be ordered"
def test_filter_explicit_filterset_arguments():
field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleFilter)
assert_arguments(field,
'headline', 'headline__icontains',
'pub_date', 'pub_date__gt', 'pub_date__lt',
'reporter',
)
assert_arguments(
field,
"headline",
"headline__icontains",
"pub_date",
"pub_date__gt",
"pub_date__lt",
"reporter",
)
def test_filter_shortcut_filterset_arguments_list():
field = DjangoFilterConnectionField(ArticleNode, fields=['pub_date', 'reporter'])
assert_arguments(field,
'pub_date',
'reporter',
)
field = DjangoFilterConnectionField(ArticleNode, fields=["pub_date", "reporter"])
assert_arguments(field, "pub_date", "reporter")
def test_filter_shortcut_filterset_arguments_dict():
field = DjangoFilterConnectionField(ArticleNode, fields={
'headline': ['exact', 'icontains'],
'reporter': ['exact'],
})
assert_arguments(field,
'headline', 'headline__icontains',
'reporter',
)
field = DjangoFilterConnectionField(
ArticleNode, fields={"headline": ["exact", "icontains"], "reporter": ["exact"]}
)
assert_arguments(field, "headline", "headline__icontains", "reporter")
def test_filter_explicit_filterset_orderable():
@ -134,15 +129,14 @@ def test_filter_explicit_filterset_not_orderable():
def test_filter_shortcut_filterset_extra_meta():
field = DjangoFilterConnectionField(ArticleNode, extra_filter_meta={
'exclude': ('headline', )
})
assert 'headline' not in field.filterset_class.get_fields()
field = DjangoFilterConnectionField(
ArticleNode, extra_filter_meta={"exclude": ("headline",)}
)
assert "headline" not in field.filterset_class.get_fields()
def test_filter_shortcut_filterset_context():
class ArticleContextFilter(django_filters.FilterSet):
class Meta:
model = Article
exclude = set()
@ -153,17 +147,31 @@ def test_filter_shortcut_filterset_context():
return qs.filter(reporter=self.request.reporter)
class Query(ObjectType):
context_articles = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleContextFilter)
context_articles = DjangoFilterConnectionField(
ArticleNode, filterset_class=ArticleContextFilter
)
r1 = Reporter.objects.create(first_name='r1', last_name='r1', email='r1@test.com')
r2 = Reporter.objects.create(first_name='r2', last_name='r2', email='r2@test.com')
Article.objects.create(headline='a1', pub_date=datetime.now(), pub_date_time=datetime.now(), reporter=r1, editor=r1)
Article.objects.create(headline='a2', pub_date=datetime.now(), pub_date_time=datetime.now(), reporter=r2, editor=r2)
r1 = Reporter.objects.create(first_name="r1", last_name="r1", email="r1@test.com")
r2 = Reporter.objects.create(first_name="r2", last_name="r2", email="r2@test.com")
Article.objects.create(
headline="a1",
pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=r1,
editor=r1,
)
Article.objects.create(
headline="a2",
pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=r2,
editor=r2,
)
class context(object):
reporter = r2
query = '''
query = """
query {
contextArticles {
edges {
@ -173,42 +181,39 @@ def test_filter_shortcut_filterset_context():
}
}
}
'''
"""
schema = Schema(query=Query)
result = schema.execute(query, context_value=context())
assert not result.errors
assert len(result.data['contextArticles']['edges']) == 1
assert result.data['contextArticles']['edges'][0]['node']['headline'] == 'a2'
assert len(result.data["contextArticles"]["edges"]) == 1
assert result.data["contextArticles"]["edges"][0]["node"]["headline"] == "a2"
def test_filter_filterset_information_on_meta():
class ReporterFilterNode(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
filter_fields = ['first_name', 'articles']
interfaces = (Node,)
filter_fields = ["first_name", "articles"]
field = DjangoFilterConnectionField(ReporterFilterNode)
assert_arguments(field, 'first_name', 'articles')
assert_arguments(field, "first_name", "articles")
assert_not_orderable(field)
def test_filter_filterset_information_on_meta_related():
class ReporterFilterNode(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
filter_fields = ['first_name', 'articles']
interfaces = (Node,)
filter_fields = ["first_name", "articles"]
class ArticleFilterNode(DjangoObjectType):
class Meta:
model = Article
interfaces = (Node, )
filter_fields = ['headline', 'reporter']
interfaces = (Node,)
filter_fields = ["headline", "reporter"]
class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterFilterNode)
@ -217,25 +222,23 @@ def test_filter_filterset_information_on_meta_related():
article = Field(ArticleFilterNode)
schema = Schema(query=Query)
articles_field = ReporterFilterNode._meta.fields['articles'].get_type()
assert_arguments(articles_field, 'headline', 'reporter')
articles_field = ReporterFilterNode._meta.fields["articles"].get_type()
assert_arguments(articles_field, "headline", "reporter")
assert_not_orderable(articles_field)
def test_filter_filterset_related_results():
class ReporterFilterNode(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
filter_fields = ['first_name', 'articles']
interfaces = (Node,)
filter_fields = ["first_name", "articles"]
class ArticleFilterNode(DjangoObjectType):
class Meta:
interfaces = (Node, )
interfaces = (Node,)
model = Article
filter_fields = ['headline', 'reporter']
filter_fields = ["headline", "reporter"]
class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterFilterNode)
@ -243,12 +246,22 @@ def test_filter_filterset_related_results():
reporter = Field(ReporterFilterNode)
article = Field(ArticleFilterNode)
r1 = Reporter.objects.create(first_name='r1', last_name='r1', email='r1@test.com')
r2 = Reporter.objects.create(first_name='r2', last_name='r2', email='r2@test.com')
Article.objects.create(headline='a1', pub_date=datetime.now(), pub_date_time=datetime.now(), reporter=r1)
Article.objects.create(headline='a2', pub_date=datetime.now(), pub_date_time=datetime.now(), reporter=r2)
r1 = Reporter.objects.create(first_name="r1", last_name="r1", email="r1@test.com")
r2 = Reporter.objects.create(first_name="r2", last_name="r2", email="r2@test.com")
Article.objects.create(
headline="a1",
pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=r1,
)
Article.objects.create(
headline="a2",
pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=r2,
)
query = '''
query = """
query {
allReporters {
edges {
@ -264,123 +277,134 @@ def test_filter_filterset_related_results():
}
}
}
'''
"""
schema = Schema(query=Query)
result = schema.execute(query)
assert not result.errors
# We should only get back a single article for each reporter
assert len(result.data['allReporters']['edges'][0]['node']['articles']['edges']) == 1
assert len(result.data['allReporters']['edges'][1]['node']['articles']['edges']) == 1
assert (
len(result.data["allReporters"]["edges"][0]["node"]["articles"]["edges"]) == 1
)
assert (
len(result.data["allReporters"]["edges"][1]["node"]["articles"]["edges"]) == 1
)
def test_global_id_field_implicit():
field = DjangoFilterConnectionField(ArticleNode, fields=['id'])
field = DjangoFilterConnectionField(ArticleNode, fields=["id"])
filterset_class = field.filterset_class
id_filter = filterset_class.base_filters['id']
id_filter = filterset_class.base_filters["id"]
assert isinstance(id_filter, GlobalIDFilter)
assert id_filter.field_class == GlobalIDFormField
def test_global_id_field_explicit():
class ArticleIdFilter(django_filters.FilterSet):
class Meta:
model = Article
fields = ['id']
fields = ["id"]
field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleIdFilter)
filterset_class = field.filterset_class
id_filter = filterset_class.base_filters['id']
id_filter = filterset_class.base_filters["id"]
assert isinstance(id_filter, GlobalIDFilter)
assert id_filter.field_class == GlobalIDFormField