Browse Source

Set TypedDict module correctly when using class definition syntax (#2) (#4)

I shuffled the code so that we only override `__module__` when TypedDict is instantiated (and not when it's subclassed). 

And I tested that `__module__` is correct with both syntaxes.
pull/8/head
Anirudh Padmarao 4 years ago committed by Ivan Levkivskyi
parent
commit
508d25ef34
  1. 17
      mypy_extensions.py
  2. 2
      tests/testextensions.py

17
mypy_extensions.py

@ -36,8 +36,15 @@ def _typeddict_new(cls, _typename, _fields=None, **kwargs):
elif kwargs:
raise TypeError("TypedDict takes either a dict or keyword arguments,"
" but not both")
return _TypedDictMeta(_typename, (), {'__annotations__': dict(_fields),
'__total__': total})
ns = {'__annotations__': dict(_fields), '__total__': total}
try:
# Setting correct module is necessary to make typed dict classes pickleable.
ns['__module__'] = sys._getframe(1).f_globals.get('__name__', '__main__')
except (AttributeError, ValueError):
pass
return _TypedDictMeta(_typename, (), ns)
class _TypedDictMeta(type):
@ -50,11 +57,7 @@ class _TypedDictMeta(type):
# via _dict_new.
ns['__new__'] = _typeddict_new if name == 'TypedDict' else _dict_new
tp_dict = super(_TypedDictMeta, cls).__new__(cls, name, (dict,), ns)
try:
# Setting correct module is necessary to make typed dict classes pickleable.
tp_dict.__module__ = sys._getframe(2).f_globals.get('__name__', '__main__')
except (AttributeError, ValueError):
pass
anns = ns.get('__annotations__', {})
msg = "TypedDict('Name', {f0: t0, f1: t1, ...}); each t must be a type"
anns = {n: _type_check(tp, msg) for n, tp in anns.items()}

2
tests/testextensions.py

@ -97,6 +97,8 @@ class TypedDictTests(BaseTestCase):
@skipUnless(PY36, 'Python 3.6 required')
def test_py36_class_syntax_usage(self):
self.assertEqual(LabelPoint2D.__name__, 'LabelPoint2D') # noqa
self.assertEqual(LabelPoint2D.__module__, __name__) # noqa
self.assertEqual(LabelPoint2D.__annotations__, {'x': int, 'y': int, 'label': str}) # noqa
self.assertEqual(LabelPoint2D.__bases__, (dict,)) # noqa
self.assertEqual(LabelPoint2D.__total__, True) # noqa

Loading…
Cancel
Save