Skip to content
Merged
28 changes: 24 additions & 4 deletions vyper/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,6 @@ class VyperNode:
Field names that, if present, must be set to None or a `SyntaxException`
is raised. This attribute is used to exclude syntax that is valid in Python
but not in Vyper.
_is_terminus : bool, optional
If `True`, indicates that execution halts upon reaching this node.
_translated_fields : Dict, optional
Field names that are reassigned if encountered. Used to normalize fields
across different Python versions.
Expand Down Expand Up @@ -380,6 +378,13 @@ def is_literal_value(self):
"""
return False

@property
def is_terminus(self):
"""
Check if execution halts upon reaching this node.
"""
return False

@property
def has_folded_value(self):
"""
Expand Down Expand Up @@ -717,7 +722,10 @@ class Stmt(VyperNode):

class Return(Stmt):
__slots__ = ("value",)
_is_terminus = True

@property
def is_terminus(self):
return True


class Expr(Stmt):
Expand Down Expand Up @@ -1302,6 +1310,15 @@ def _op(self, left, right):
class Call(ExprNode):
__slots__ = ("func", "args", "keywords")

@property
def is_terminus(self):
# cursed import cycle!
from vyper.builtins.functions import DISPATCH_TABLE

func_name = self.func.id
builtin_t = DISPATCH_TABLE[func_name]
return getattr(builtin_t, "_is_terminus", False)

# try checking if this is a builtin, which is foldable
def _try_fold(self):
if not isinstance(self.func, Name):
Expand Down Expand Up @@ -1483,7 +1500,10 @@ class AugAssign(Stmt):
class Raise(Stmt):
__slots__ = ("exc",)
_only_empty_fields = ("cause",)
_is_terminus = True

@property
def is_terminus(self):
return True


class Assert(Stmt):
Expand Down
40 changes: 1 addition & 39 deletions vyper/codegen/core.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import contextlib
from typing import Generator

from vyper import ast as vy_ast
from vyper.codegen.ir_node import Encoding, IRnode
from vyper.compiler.settings import OptimizationLevel
from vyper.evm.address_space import CALLDATA, DATA, IMMUTABLES, MEMORY, STORAGE, TRANSIENT
from vyper.evm.opcodes import version_check
from vyper.exceptions import CompilerPanic, StructureException, TypeCheckFailure, TypeMismatch
from vyper.exceptions import CompilerPanic, TypeCheckFailure, TypeMismatch
from vyper.semantics.types import (
AddressT,
BoolT,
Expand Down Expand Up @@ -1035,43 +1034,6 @@ def eval_seq(ir_node):
return None


def is_return_from_function(node):
if isinstance(node, vy_ast.Expr) and node.get("value.func.id") in (
"raw_revert",
"selfdestruct",
):
return True
if isinstance(node, (vy_ast.Return, vy_ast.Raise)):
return True
return False


# TODO this is almost certainly duplicated with check_terminus_node
# in vyper/semantics/analysis/local.py
def check_single_exit(fn_node):
_check_return_body(fn_node, fn_node.body)
for node in fn_node.get_descendants(vy_ast.If):
_check_return_body(node, node.body)
if node.orelse:
_check_return_body(node, node.orelse)


def _check_return_body(node, node_list):
return_count = len([n for n in node_list if is_return_from_function(n)])
if return_count > 1:
raise StructureException(
"Too too many exit statements (return, raise or selfdestruct).", node
)
# Check for invalid code after returns.
last_node_pos = len(node_list) - 1
for idx, n in enumerate(node_list):
if is_return_from_function(n) and idx < last_node_pos:
# is not last statement in body.
raise StructureException(
"Exit statement with succeeding code (that will not execute).", node_list[idx + 1]
)


def mzero(dst, nbytes):
# calldatacopy from past-the-end gives zero bytes.
# cf. YP H.2 (ops section) with CALLDATACOPY spec.
Expand Down
5 changes: 0 additions & 5 deletions vyper/codegen/function_definitions/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import vyper.ast as vy_ast
from vyper.codegen.context import Constancy, Context
from vyper.codegen.core import check_single_exit
from vyper.codegen.function_definitions.external_function import generate_ir_for_external_function
from vyper.codegen.function_definitions.internal_function import generate_ir_for_internal_function
from vyper.codegen.ir_node import IRnode
Expand Down Expand Up @@ -115,10 +114,6 @@ def generate_ir_for_function(
# generate _FuncIRInfo
func_t._ir_info = _FuncIRInfo(func_t)

# Validate return statements.
# XXX: This should really be in semantics pass.
check_single_exit(code)

callees = func_t.called_functions

# we start our function frame from the largest callee frame
Expand Down
3 changes: 1 addition & 2 deletions vyper/codegen/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
get_dyn_array_count,
get_element_ptr,
getpos,
is_return_from_function,
make_byte_array_copier,
make_setter,
pop_dyn_array,
Expand Down Expand Up @@ -406,7 +405,7 @@ def parse_stmt(stmt, context):
def _is_terminated(code):
last_stmt = code[-1]

if is_return_from_function(last_stmt):
if last_stmt.is_terminus:
return True

if isinstance(last_stmt, vy_ast.If):
Expand Down
31 changes: 20 additions & 11 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,19 +69,28 @@ def validate_functions(vy_module: vy_ast.Module) -> None:
err_list.raise_if_not_empty()


def _is_terminus_node(node: vy_ast.VyperNode) -> bool:
if getattr(node, "_is_terminus", None):
return True
if isinstance(node, vy_ast.Expr) and isinstance(node.value, vy_ast.Call):
func = get_exact_type_from_node(node.value.func)
if getattr(func, "_is_terminus", None):
return True
return False


def check_for_terminus(node_list: list) -> bool:
if next((i for i in node_list if _is_terminus_node(i)), None):
terminus_nodes = []

# Check for invalid code after returns
last_node_pos = len(node_list) - 1
for idx, n in enumerate(node_list):
if n.is_terminus:
terminus_nodes.append(n)
if idx < last_node_pos:
# is not last statement in body.
raise StructureException(
"Exit statement with succeeding code (that will not execute).",
node_list[idx + 1],
)

if len(terminus_nodes) > 1:
raise StructureException(
"Too many exit statements (return, raise or selfdestruct).", terminus_nodes[-1]
)
elif len(terminus_nodes) == 1:
return True

for node in [i for i in node_list if isinstance(i, vy_ast.If)][::-1]:
if not node.orelse or not check_for_terminus(node.orelse):
continue
Expand Down