You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
246 lines
8.0 KiB
Python
246 lines
8.0 KiB
Python
"""Astroid hooks for various builtins."""
|
|
|
|
import sys
|
|
from functools import partial
|
|
from textwrap import dedent
|
|
|
|
import six
|
|
from astroid import (MANAGER, UseInferenceDefault,
|
|
inference_tip, YES, InferenceError, UnresolvableName)
|
|
from astroid import nodes
|
|
from astroid.builder import AstroidBuilder
|
|
|
|
|
|
def _extend_str(class_node, rvalue):
    """Extend a builtin string class node with stub methods.

    A throw-away class named ``whatever`` is built from source, and each of
    its methods is grafted onto *class_node*. Each method replaces any local
    of the same name and is re-parented to *class_node*.

    :param class_node: astroid class node of the builtin string type being
        extended (e.g. the ``str`` class of the builtins module).
    :param rvalue: source text of an empty literal of that type, e.g.
        ``"''"`` or ``"b''"``. Methods that produce a string of the same
        type (``join``, ``strip``, ``upper``, ...) return this literal.
    """
    # TODO(cpopa): this approach will make astroid to believe
    # that some arguments can be passed by keyword, but
    # unfortunately, strings and bytes don't accept keyword arguments.
    #
    # ``encode`` returns ``b''``. That literal is ``bytes`` on Python 3,
    # where encoding always yields bytes, and is still the plain ``str``
    # type on Python 2, so it is right on both. ``decode`` returns ``u''``,
    # which is ``unicode`` on Python 2 and ``str`` on Python 3.
    code = dedent('''
    class whatever(object):
        def join(self, iterable):
            return {rvalue}
        def replace(self, old, new, count=None):
            return {rvalue}
        def format(self, *args, **kwargs):
            return {rvalue}
        def encode(self, encoding='ascii', errors=None):
            return b''
        def decode(self, encoding='ascii', errors=None):
            return u''
        def capitalize(self):
            return {rvalue}
        def title(self):
            return {rvalue}
        def lower(self):
            return {rvalue}
        def upper(self):
            return {rvalue}
        def swapcase(self):
            return {rvalue}
        def index(self, sub, start=None, end=None):
            return 0
        def find(self, sub, start=None, end=None):
            return 0
        def count(self, sub, start=None, end=None):
            return 0
        def strip(self, chars=None):
            return {rvalue}
        def lstrip(self, chars=None):
            return {rvalue}
        def rstrip(self, chars=None):
            return {rvalue}
        def rjust(self, width, fillchar=None):
            return {rvalue}
        def center(self, width, fillchar=None):
            return {rvalue}
        def ljust(self, width, fillchar=None):
            return {rvalue}
    ''')
    code = code.format(rvalue=rvalue)
    fake = AstroidBuilder(MANAGER).string_build(code)['whatever']
    for method in fake.mymethods():
        class_node.locals[method.name] = [method]
        method.parent = class_node
|
|
|
|
def extend_builtins(class_transforms):
    """Run per-class transforms against the cached builtins module.

    :param class_transforms: mapping from a builtin class name to a callable
        that is invoked with that class's node, looked up in the builtins
        module held by the astroid cache.
    """
    # Imported inside the function rather than at module level
    # (presumably to avoid an import cycle -- confirm before moving it).
    from astroid.bases import BUILTINS
    builtins_module = MANAGER.astroid_cache[BUILTINS]
    for name in class_transforms:
        apply_transform = class_transforms[name]
        apply_transform(builtins_module[name])
|
|
|
|
# Graft string-method stubs onto the builtin string types: ``str`` and
# ``unicode`` on Python 2, ``bytes`` and ``str`` on Python 3. The ``rvalue``
# literal is what the same-type-returning stub methods return.
if sys.version_info < (3, 0):
    extend_builtins({
        'str': partial(_extend_str, rvalue="''"),
        'unicode': partial(_extend_str, rvalue="u''"),
    })
else:
    extend_builtins({
        'bytes': partial(_extend_str, rvalue="b''"),
        'str': partial(_extend_str, rvalue="''"),
    })
|
|
|
|
|
|
def register_builtin_transform(transform, builtin_name):
    """Register a new transform function for the given *builtin_name*.

    The transform function must accept two parameters, a node and
    an optional context.

    The transform is installed as an inference tip on every ``CallFunc``
    node whose callee is a plain ``Name`` equal to *builtin_name*.

    :param transform: callable ``transform(node, context=None)`` returning
        the node the call should be inferred as.
    :param builtin_name: name of the called builtin, e.g. ``'dict'``.
    """
    def _transform_wrapper(node, context=None):
        result = transform(node, context=context)
        if result is not None:
            # A transform may return a node that already belongs to the
            # tree (e.g. the argument node itself, or the value a name was
            # inferred to). Only fill in missing position information so
            # such nodes are not re-parented or moved to the call's line.
            if getattr(result, 'parent', None) is None:
                result.parent = node
            if getattr(result, 'lineno', None) is None:
                result.lineno = node.lineno
            if getattr(result, 'col_offset', None) is None:
                result.col_offset = node.col_offset
        return iter([result])

    MANAGER.register_transform(
        nodes.CallFunc,
        inference_tip(_transform_wrapper),
        lambda n: (isinstance(n.func, nodes.Name) and
                   n.func.name == builtin_name))
|
|
|
|
|
|
def _generic_inference(node, context, node_type, transform):
    """Infer a call to a builtin container constructor such as ``tuple()``.

    :param node: the call node being inferred.
    :param context: inference context forwarded to ``infer()``.
    :param node_type: class instantiated with no arguments when the call
        has no arguments (e.g. ``tuple()``).
    :param transform: callable turning a single argument node into the
        result node, or returning a falsy value when it cannot.
    :returns: the transformed node.
    :raises UseInferenceDefault: when the call cannot be inferred here.
    """
    # ``*args`` / ``**kwargs`` hide the real arguments: ``tuple(*values)``
    # has an empty ``args`` list but is not an empty tuple.
    if (getattr(node, 'starargs', None) is not None or
            getattr(node, 'kwargs', None) is not None):
        raise UseInferenceDefault()

    args = node.args
    if not args:
        return node_type()
    if len(args) > 1:
        raise UseInferenceDefault()

    arg, = args
    transformed = transform(arg)
    if not transformed:
        # Not a literal the transform understands (e.g. a name): retry
        # with the argument's first inferred value.
        try:
            inferred = next(arg.infer(context=context))
        except (InferenceError, StopIteration):
            raise UseInferenceDefault()
        if inferred is YES:
            raise UseInferenceDefault()
        transformed = transform(inferred)
    if not transformed or transformed is YES:
        raise UseInferenceDefault()
    return transformed
|
|
|
|
|
|
def _generic_transform(arg, klass, iterables, build_elts):
    """Convert a literal argument node into a *klass* node, when possible.

    :param arg: the argument node (or its inferred value).
    :param klass: node class to produce. An *arg* already of that class is
        returned unchanged.
    :param iterables: other literal node classes whose constant elements
        can be converted.
    :param build_elts: Python type used to collect the element values.
    :returns: a *klass* node, or ``None`` when *arg* is not a supported
        literal.
    :raises UseInferenceDefault: when a container holds non-constant
        elements (or a dict has non-constant keys).
    """
    if isinstance(arg, klass):
        return arg

    if isinstance(arg, iterables):
        members = arg.elts
        # TODO(cpopa): Don't support heterogenous elements.
        # Not yet, though.
        if any(not isinstance(member, nodes.Const) for member in members):
            raise UseInferenceDefault()
        values = [member.value for member in members]
    elif isinstance(arg, nodes.Dict):
        # Only the keys of a dict literal become elements.
        keys = [pair[0] for pair in arg.items]
        if any(not isinstance(key, nodes.Const) for key in keys):
            raise UseInferenceDefault()
        values = [key.value for key in keys]
    elif (isinstance(arg, nodes.Const) and
          isinstance(arg.value, (six.string_types, six.binary_type))):
        # A string constant is iterated character by character (or
        # byte by byte) by build_elts.
        values = arg.value
    else:
        return None

    return klass(elts=build_elts(values))
|
|
|
|
|
|
def _infer_builtin(node, context,
                   klass=None, iterables=None,
                   build_elts=None):
    """Infer a ``tuple()`` / ``list()`` / ``set()``-style builtin call.

    :param node: the call node.
    :param context: inference context, passed through unchanged.
    :param klass: node class of the result; also instantiated empty for a
        call without arguments.
    :param iterables: literal node classes whose elements can be converted.
    :param build_elts: Python type used to collect the element values.
    """
    def _convert(candidate):
        return _generic_transform(candidate,
                                  klass=klass,
                                  iterables=iterables,
                                  build_elts=build_elts)

    return _generic_inference(node, context, klass, _convert)
|
|
|
|
# pylint: disable=invalid-name
# Inference functions for tuple(), list() and set() calls. Each one binds
# _infer_builtin to the node class it builds, the literal node classes it
# accepts as its argument, and the Python type used to collect elements.
infer_tuple = partial(_infer_builtin, klass=nodes.Tuple,
                      iterables=(nodes.List, nodes.Set), build_elts=tuple)
infer_list = partial(_infer_builtin, klass=nodes.List,
                     iterables=(nodes.Tuple, nodes.Set), build_elts=list)
infer_set = partial(_infer_builtin, klass=nodes.Set,
                    iterables=(nodes.List, nodes.Tuple), build_elts=set)
|
|
|
|
|
|
def _get_elts(arg, context):
    """Return the ``(key, value)`` node pairs a ``dict()`` argument yields.

    *arg* is inferred. A ``Dict`` contributes its items. A list, tuple or
    set contributes its elements, each of which must be a two-element
    list/tuple/set whose first element is a ``Tuple``, ``Const`` or
    ``Name`` node.

    :param arg: the positional argument node passed to ``dict()``.
    :param context: inference context forwarded to ``infer()``.
    :returns: a new list of ``(key_node, value_node)`` pairs.
    :raises UseInferenceDefault: when *arg* cannot be inferred or has an
        unsupported shape.
    """
    iterable_types = (nodes.List, nodes.Tuple, nodes.Set)
    try:
        inferred = next(arg.infer(context))
    except (InferenceError, UnresolvableName, StopIteration):
        # StopIteration: inference produced no value at all.
        raise UseInferenceDefault()

    if isinstance(inferred, nodes.Dict):
        # Copy, so the Dict built by the caller does not share its items
        # list with the inferred node.
        return list(inferred.items)
    if not isinstance(inferred, iterable_types):
        raise UseInferenceDefault()

    items = []
    for elt in inferred.elts:
        # Every element must be a pair, otherwise fall back to the default
        # inference.
        if not isinstance(elt, iterable_types) or len(elt.elts) != 2:
            raise UseInferenceDefault()
        # Only keys that can be hashable are accepted: tuples and
        # constants. Names are accepted too; their values are not checked.
        if not isinstance(elt.elts[0],
                          (nodes.Tuple, nodes.Const, nodes.Name)):
            raise UseInferenceDefault()
        items.append(tuple(elt.elts))
    return items
|
|
|
|
def infer_dict(node, context=None):
    """Try to infer a dict call to a Dict node.

    The function treats the following cases:

        * dict()
        * dict(mapping)
        * dict(iterable)
        * dict(iterable, **kwargs)
        * dict(mapping, **kwargs)
        * dict(**kwargs)

    Here ``**kwargs`` stands for explicit keyword arguments such as
    ``dict(a=1, b=2)``. Calls that unpack ``*args`` or ``**mapping`` are
    not inferred, because the unpacked contents are not known here.

    If a case can't be inferred, we'll fallback to default inference.

    :param node: the ``dict(...)`` call node.
    :param context: inference context forwarded to argument inference.
    :returns: a new ``Dict`` node.
    :raises UseInferenceDefault: when the call cannot be inferred here.
    """
    def _all_keywords(args):
        return all(isinstance(arg, nodes.Keyword) for arg in args)

    if (getattr(node, 'starargs', None) is not None or
            getattr(node, 'kwargs', None) is not None):
        # dict(*args) or dict(..., **mapping): inferring only the visible
        # arguments would silently drop the unpacked entries.
        raise UseInferenceDefault()

    if not node.args:
        # dict()
        return nodes.Dict()

    if _all_keywords(node.args):
        # dict(a=1, b=2, c=4)
        items = [(nodes.Const(arg.arg), arg.value) for arg in node.args]
    elif len(node.args) >= 2 and _all_keywords(node.args[1:]):
        # dict(some_iterable, b=2, c=4)
        elts = _get_elts(node.args[0], context)
        keys = [(nodes.Const(arg.arg), arg.value) for arg in node.args[1:]]
        items = elts + keys
    elif len(node.args) == 1:
        # dict(mapping) or dict(iterable)
        items = _get_elts(node.args[0], context)
    else:
        raise UseInferenceDefault()

    result = nodes.Dict()
    result.items = items
    return result
|
|
|
|
# Builtins inference: route calls to tuple(), set(), list() and dict()
# through their inference functions.
register_builtin_transform(transform=infer_tuple, builtin_name='tuple')
register_builtin_transform(transform=infer_set, builtin_name='set')
register_builtin_transform(transform=infer_list, builtin_name='list')
register_builtin_transform(transform=infer_dict, builtin_name='dict')
|