Commit 263fd118 authored by Lars van den Haak's avatar Lars van den Haak
Browse files

Now doesn't throw away annotated variables over untyped ones.

parent 9e7a70e4
......@@ -20,6 +20,7 @@ import re
from IPython import get_ipython
from IPython.core.magic import register_line_magic
import logging
from typing import Set
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
......@@ -58,63 +59,68 @@ class NamesLister(ast.NodeVisitor):
If replace is true, it will also gather assigned variables which have
subscripts or attributes."""
def __init__(self, replace=False):
self.names = set()
def __init__(self, replace: bool = False):
self.var_names = set()
self.annotated_names = set()
self.classfunc_names = set()
self.replace = replace
def visit_FunctionDef(self, node):
self.names.add(node.name)
self.classfunc_names.add(node.name)
def visit_AsyncFunctionDef(self, node):
self.names.add(node.name)
self.classfunc_names.add(node.name)
def visit_ClassDef(self, node):
self.names.add(node.name)
self.classfunc_names.add(node.name)
def visit_Assign(self, node):
namer = Names(self.replace)
for t in node.targets:
namer.visit(t)
self.names.update(namer.names)
self.var_names.update(namer.names)
def visit_AnnAssign(self, node):
namer = Names(self.replace)
namer.visit(node.target)
self.names.update(namer.names)
self.annotated_names.update(namer.names)
def visit_AugAssign(self, node):
namer = Names(self.replace)
namer.visit(node.target)
self.names.update(namer.names)
self.var_names.update(namer.names)
class Replacer(ast.NodeTransformer):
"""Replace all functions, classes and variable declarations
with a Pass node, which are present in a list of names."""
def __init__(self, known):
def __init__(self, known_vars: Set[str], known_annotated: Set[str], known_classfunc: Set[str]):
"""Intialize the Replacer.
known -- The set of names which should be replaced.
"""
self.known = known
self.known_vars = known_vars
self.known_annotated = known_annotated
self.known_classfunc = known_classfunc
def visit_FunctionDef(self, node):
if node.name in self.known:
if node.name in self.known_classfunc:
return ast.Pass()
else:
# Remove the body of the function, we don't have to typecheck it anymore
node.body = [ast.Pass()]
return node
def visit_AsyncFunctionDef(self, node):
if node.name in self.known:
if node.name in self.known_classfunc:
return ast.Pass()
else:
node.body = [ast.Pass()]
return node
def visit_ClassDef(self, node):
if node.name in self.known:
if node.name in self.known_classfunc:
return ast.Pass()
else:
return node
......@@ -122,24 +128,24 @@ class Replacer(ast.NodeTransformer):
def visit_Assign(self, node):
mynames = NamesLister(True)
mynames.visit(node)
for n in mynames.names:
if n in self.known:
for n in mynames.var_names:
if n in self.known_vars:
return ast.Pass()
return node
def visit_AnnAssign(self, node):
def visit_AugAssign(self, node):
mynames = NamesLister(True)
mynames.visit(node)
for n in mynames.names:
if n in self.known:
for n in mynames.var_names:
if n in self.known_vars:
return ast.Pass()
return node
def visit_AugAssign(self, node):
def visit_AnnAssign(self, node):
mynames = NamesLister(True)
mynames.visit(node)
for n in mynames.names:
if n in self.known:
for n in mynames.annotated_names:
if n in self.known_annotated:
return ast.Pass()
return node
......@@ -161,7 +167,9 @@ class __MyPyIPython:
def __init__(self):
self.mypy_cells: str = "from IPython.core.getipython import get_ipython\n"
self.mypy_names = set()
self.mypy_var_names = set()
self.mypy_annotated_names = set()
self.mypy_classfunc_names = set()
mypy_shell = get_ipython()
mypy_tmp_func = mypy_shell.run_cell
self.mypy_typecheck = False
......@@ -237,9 +245,22 @@ class __MyPyIPython:
getCell = NamesLister()
getCell.visit(cell_p)
newnames = getCell.names
remove = newnames & self.mypy_names
if len(remove):
new_var = getCell.var_names
new_annotated = getCell.annotated_names
new_classfunc = getCell.classfunc_names
# Remove if there is a new (annotated) variable or a new function
remove_var = (new_var | new_annotated |
new_classfunc) & self.mypy_var_names
# Remove if there is a new annotatted variable or a new function
remove_annotated = (
new_annotated | new_classfunc) & self.mypy_annotated_names
# Remove a function, if any of the three is introduced with the same name
remove_classfunc = (
new_var | new_annotated | new_classfunc) & self.mypy_classfunc_names
if len(remove_var) or len(remove_annotated) or len(remove_classfunc):
try:
mypy_cells_ast = ast.parse(self.mypy_cells)
except SyntaxError:
......@@ -247,10 +268,18 @@ class __MyPyIPython:
logger.debug(self.mypy_cells)
return result
new_mypy_cells_ast = Replacer(
remove).visit(mypy_cells_ast)
remove_var, remove_annotated, remove_classfunc).visit(mypy_cells_ast)
self.mypy_cells = astor.to_source(new_mypy_cells_ast)
self.mypy_names.update(newnames)
# First remove the removed things from the sets, since it could changed from function to variable
# or visa-versa
self.mypy_var_names.difference_update(remove_var)
self.mypy_var_names.update(new_var)
self.mypy_annotated_names.difference_update(
remove_annotated)
self.mypy_annotated_names.update(new_annotated)
self.mypy_classfunc_names.difference_update(
remove_classfunc)
self.mypy_classfunc_names.update(new_classfunc)
mypy_cells_length = len(self.mypy_cells.split('\n'))-1
self.mypy_cells += (cell_filter + '\n')
......@@ -264,7 +293,7 @@ class __MyPyIPython:
compiled = re.compile(
'(<[a-z]+>:)(\\d+)(.*?)$').findall(line)
if len(compiled) > 0:
l, n, r = compiled[0]
_, n, r = compiled[0]
if int(n) > mypy_cells_length:
n = str(int(n)-mypy_cells_length)
r = fixLineNr(r, mypy_cells_length)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment