Skip to content
25 changes: 25 additions & 0 deletions tests/ast_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from vyper.ast.nodes import VyperNode


def deepequals(node: VyperNode, other: VyperNode):
# checks two nodes are recursively equal, ignoring metadata
# like line info.
if not isinstance(other, type(node)):
return False

if isinstance(node, list):
if len(node) != len(other):
return False
return all(deepequals(a, b) for a, b in zip(node, other))

if not isinstance(node, VyperNode):
return node == other

if getattr(node, "node_id", None) != getattr(other, "node_id", None):
return False
for field_name in (i for i in node.get_fields() if i not in VyperNode.__slots__):
lhs = getattr(node, field_name, None)
rhs = getattr(other, field_name, None)
if not deepequals(lhs, rhs):
return False
return True
3 changes: 2 additions & 1 deletion tests/unit/ast/nodes/test_binary.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest

from tests.ast_utils import deepequals
from vyper import ast as vy_ast
from vyper.exceptions import SyntaxException

Expand All @@ -18,7 +19,7 @@ def x():
"""
)

assert expected == mutated
assert deepequals(expected, mutated)


def test_binary_length():
Expand Down
11 changes: 6 additions & 5 deletions tests/unit/ast/nodes/test_compare_nodes.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from tests.ast_utils import deepequals
from vyper import ast as vy_ast


Expand All @@ -6,33 +7,33 @@ def test_compare_different_node_clases():
left = vyper_ast.body[0].target
right = vyper_ast.body[0].value

assert left != right
assert not deepequals(left, right)


def test_compare_different_nodes_same_class():
vyper_ast = vy_ast.parse_to_ast("[1, 2]")
left, right = vyper_ast.body[0].value.elements

assert left != right
assert not deepequals(left, right)


def test_compare_different_nodes_same_value():
vyper_ast = vy_ast.parse_to_ast("[1, 1]")
left, right = vyper_ast.body[0].value.elements

assert left != right
assert not deepequals(left, right)


def test_compare_similar_node():
# test equality without node_ids
left = vy_ast.Int(value=1)
right = vy_ast.Int(value=1)

assert left == right
assert deepequals(left, right)


def test_compare_same_node():
vyper_ast = vy_ast.parse_to_ast("42")
node = vyper_ast.body[0].value

assert node == node
assert deepequals(node, node)
3 changes: 2 additions & 1 deletion tests/unit/ast/test_ast_dict.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import copy
import json

from tests.ast_utils import deepequals
from vyper import compiler
from vyper.ast.nodes import NODE_SRC_ATTRIBUTES
from vyper.ast.parse import parse_to_ast
Expand Down Expand Up @@ -138,7 +139,7 @@ def test() -> int128:
new_dict = json.loads(out_json)
new_ast = dict_to_ast(new_dict)

assert new_ast == original_ast
assert deepequals(new_ast, original_ast)


# strip source annotations like lineno, we don't care for inspecting
Expand Down
5 changes: 3 additions & 2 deletions tests/unit/ast/test_parser.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from tests.ast_utils import deepequals
from vyper.ast.parse import parse_to_ast


Expand All @@ -12,7 +13,7 @@ def test() -> int128:
ast1 = parse_to_ast(code)
ast2 = parse_to_ast("\n \n" + code + "\n\n")

assert ast1 == ast2
assert deepequals(ast1, ast2)


def test_ast_unequal():
Expand All @@ -32,4 +33,4 @@ def test() -> int128:
ast1 = parse_to_ast(code1)
ast2 = parse_to_ast(code2)

assert ast1 != ast2
assert not deepequals(ast1, ast2)
16 changes: 0 additions & 16 deletions vyper/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,26 +331,10 @@ def get_fields(cls) -> set:
slot_fields = [x for i in cls.__mro__ for x in getattr(i, "__slots__", [])]
return set(i for i in slot_fields if not i.startswith("_"))

def __hash__(self):
values = [getattr(self, i, None) for i in VyperNode._public_slots]
return hash(tuple(values))

def __deepcopy__(self, memo):
# default implementation of deepcopy is a hotspot
return pickle.loads(pickle.dumps(self))

def __eq__(self, other):
# CMC 2024-03-03 I'm not sure it makes much sense to compare AST
# nodes, especially if they come from other modules
if not isinstance(other, type(self)):
return False
if getattr(other, "node_id", None) != getattr(self, "node_id", None):
return False
for field_name in (i for i in self.get_fields() if i not in VyperNode.__slots__):
if getattr(self, field_name, None) != getattr(other, field_name, None):
return False
return True

def __repr__(self):
cls = type(self)
class_repr = f"{cls.__module__}.{cls.__qualname__}"
Expand Down
10 changes: 5 additions & 5 deletions vyper/semantics/analysis/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ def push_path(self, module_ast: vy_ast.Module) -> None:

def pop_path(self, expected: vy_ast.Module) -> None:
popped = self._path.pop()
if expected != popped:
raise CompilerPanic("unreachable")
assert expected is popped, "unreachable"
self._imports.pop()

@contextlib.contextmanager
Expand All @@ -78,7 +77,7 @@ def __init__(self, input_bundle: InputBundle, graph: _ImportGraph):
self.graph = graph
self._ast_of: dict[int, vy_ast.Module] = {}

self.seen: set[int] = set()
self.seen: set[vy_ast.Module] = set()

self._integrity_sum = None

Expand All @@ -103,7 +102,7 @@ def _calculate_integrity_sum_r(self, module_ast: vy_ast.Module):
return sha256sum("".join(acc))

def _resolve_imports_r(self, module_ast: vy_ast.Module):
if id(module_ast) in self.seen:
if module_ast in self.seen:
return
with self.graph.enter_path(module_ast):
for node in module_ast.body:
Expand All @@ -112,7 +111,8 @@ def _resolve_imports_r(self, module_ast: vy_ast.Module):
self._handle_Import(node)
elif isinstance(node, vy_ast.ImportFrom):
self._handle_ImportFrom(node)
self.seen.add(id(module_ast))

self.seen.add(module_ast)

def _handle_Import(self, node: vy_ast.Import):
# import x.y[name] as y[alias]
Expand Down
Loading