Skip to content

Commit eb2efa3

Browse files
refactor[lang]: refactor decorator parsing (#4490)
minor refactor of decorator parsing, which simplifies the code and should make it easier to add new builtin decorators in the future
1 parent efdfa7f commit eb2efa3

File tree

2 files changed

+113
-59
lines changed

2 files changed

+113
-59
lines changed
Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,61 @@
11
import pytest
22

33
from vyper import compiler
4-
from vyper.exceptions import StructureException
4+
from vyper.exceptions import FunctionDeclarationException, StructureException
55

66
FAILING_CONTRACTS = [
7-
"""
7+
(
8+
"""
89
@external
910
@pure
1011
@nonreentrant
1112
def nonreentrant_foo() -> uint256:
1213
return 1
1314
""",
14-
"""
15+
StructureException,
16+
),
17+
(
18+
"""
1519
@external
1620
@nonreentrant
1721
@nonreentrant
1822
def nonreentrant_foo() -> uint256:
1923
return 1
2024
""",
21-
"""
25+
StructureException,
26+
),
27+
(
28+
"""
2229
@external
2330
@nonreentrant("foo")
2431
def nonreentrant_foo() -> uint256:
2532
return 1
2633
""",
34+
StructureException,
35+
),
36+
(
37+
"""
38+
@deploy
39+
@nonreentrant
40+
def __init__():
41+
pass
42+
""",
43+
FunctionDeclarationException,
44+
),
2745
]
2846

2947

30-
@pytest.mark.parametrize("failing_contract_code", FAILING_CONTRACTS)
31-
def test_invalid_function_decorators(failing_contract_code):
32-
with pytest.raises(StructureException):
33-
compiler.compile_code(failing_contract_code)
48+
@pytest.mark.parametrize("bad_code,exc", FAILING_CONTRACTS)
49+
def test_invalid_function_decorators(bad_code, exc):
50+
with pytest.raises(exc):
51+
compiler.compile_code(bad_code)
52+
53+
54+
def test_invalid_function_decorator_vyi():
55+
code = """
56+
@nonreentrant
57+
def foo():
58+
...
59+
"""
60+
with pytest.raises(FunctionDeclarationException):
61+
compiler.compile_code(code, contract_path="foo.vyi", output_formats=["abi"])

vyper/semantics/types/function.py

Lines changed: 77 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -326,25 +326,21 @@ def from_vyi(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT":
326326
-------
327327
ContractFunctionT
328328
"""
329-
function_visibility, state_mutability, nonreentrant = _parse_decorators(funcdef)
329+
decorators = _parse_decorators(funcdef)
330330

331-
if nonreentrant:
332-
# TODO: refactor so parse_decorators returns the AST location
333-
decorator = next(d for d in funcdef.decorator_list if d.id == "nonreentrant")
331+
if decorators.nonreentrant:
334332
raise FunctionDeclarationException(
335-
"`@nonreentrant` not allowed in interfaces", decorator
333+
"`@nonreentrant` not allowed in interfaces", decorators.nonreentrant_node
336334
)
337335

338336
# it's redundant to specify visibility in vyi - always should be external
337+
function_visibility = decorators.visibility
339338
if function_visibility is None:
340339
function_visibility = FunctionVisibility.EXTERNAL
341340

342341
if function_visibility != FunctionVisibility.EXTERNAL:
343-
nonexternal = next(
344-
d for d in funcdef.decorator_list if d.id in FunctionVisibility.values()
345-
)
346342
raise FunctionDeclarationException(
347-
"Interface functions can only be marked as `@external`", nonexternal
343+
"Interface functions can only be marked as `@external`", decorators.visibility_node
348344
)
349345

350346
if funcdef.name == "__init__":
@@ -374,9 +370,9 @@ def from_vyi(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT":
374370
keyword_args,
375371
return_type,
376372
function_visibility,
377-
state_mutability,
373+
decorators.state_mutability,
378374
from_interface=True,
379-
nonreentrant=nonreentrant,
375+
nonreentrant=decorators.nonreentrant,
380376
ast_def=funcdef,
381377
)
382378

@@ -394,9 +390,10 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT":
394390
-------
395391
ContractFunctionT
396392
"""
397-
function_visibility, state_mutability, nonreentrant = _parse_decorators(funcdef)
393+
decorators = _parse_decorators(funcdef)
398394

399395
# it's redundant to specify internal visibility - it's implied by not being external
396+
function_visibility = decorators.visibility
400397
if function_visibility is None:
401398
function_visibility = FunctionVisibility.INTERNAL
402399

@@ -420,7 +417,7 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT":
420417
"Only constructors can be marked as `@deploy`!", funcdef
421418
)
422419
if funcdef.name == "__init__":
423-
if state_mutability in (StateMutability.PURE, StateMutability.VIEW):
420+
if decorators.state_mutability in (StateMutability.PURE, StateMutability.VIEW):
424421
raise FunctionDeclarationException(
425422
"Constructor cannot be marked as `@pure` or `@view`", funcdef
426423
)
@@ -438,20 +435,19 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT":
438435
raise FunctionDeclarationException(
439436
"Constructor may not use default arguments", funcdef.args.defaults[0]
440437
)
441-
if nonreentrant:
442-
decorator = next(d for d in funcdef.decorator_list if d.id == "nonreentrant")
438+
if decorators.nonreentrant:
443439
msg = "`@nonreentrant` decorator disallowed on `__init__`"
444-
raise FunctionDeclarationException(msg, decorator)
440+
raise FunctionDeclarationException(msg, decorators.nonreentrant_node)
445441

446442
return cls(
447443
funcdef.name,
448444
positional_args,
449445
keyword_args,
450446
return_type,
451447
function_visibility,
452-
state_mutability,
448+
decorators.state_mutability,
453449
from_interface=False,
454-
nonreentrant=nonreentrant,
450+
nonreentrant=decorators.nonreentrant,
455451
ast_def=funcdef,
456452
)
457453

@@ -723,12 +719,61 @@ def _parse_return_type(funcdef: vy_ast.FunctionDef) -> Optional[VyperType]:
723719
return type_from_annotation(funcdef.returns, DataLocation.MEMORY)
724720

725721

726-
def _parse_decorators(
727-
funcdef: vy_ast.FunctionDef,
728-
) -> tuple[Optional[FunctionVisibility], StateMutability, bool]:
729-
function_visibility = None
730-
state_mutability = None
731-
nonreentrant_node = None
722+
@dataclass
723+
class _ParsedDecorators:
724+
visibility_node: Optional[vy_ast.Name] = None
725+
state_mutability_node: Optional[vy_ast.Name] = None
726+
nonreentrant_node: Optional[vy_ast.Name] = None
727+
728+
def set_visibility(self, decorator_node: vy_ast.Name):
729+
assert FunctionVisibility.is_valid_value(decorator_node.id), "unreachable"
730+
if self.visibility_node is not None:
731+
raise FunctionDeclarationException(
732+
f"Visibility is already set to: {self.visibility}",
733+
self.visibility_node,
734+
decorator_node,
735+
hint="only one visibility decorator is allowed per function",
736+
)
737+
self.visibility_node = decorator_node
738+
739+
@property
740+
def visibility(self) -> Optional[FunctionVisibility]:
741+
if self.visibility_node is None:
742+
return None
743+
return FunctionVisibility(self.visibility_node.id)
744+
745+
def set_state_mutability(self, decorator_node: vy_ast.Name):
746+
assert StateMutability.is_valid_value(decorator_node.id), "unreachable"
747+
if self.state_mutability_node is not None:
748+
raise FunctionDeclarationException(
749+
f"Mutability is already set to: {self.state_mutability}",
750+
self.state_mutability_node,
751+
decorator_node,
752+
hint="only one state mutability decorator is allowed per function",
753+
)
754+
self.state_mutability_node = decorator_node
755+
756+
@property
757+
def state_mutability(self) -> StateMutability:
758+
if self.state_mutability_node is None:
759+
return StateMutability.NONPAYABLE # default
760+
return StateMutability(self.state_mutability_node.id)
761+
762+
def set_nonreentrant(self, decorator_node: vy_ast.Name):
763+
if self.nonreentrant_node is not None:
764+
raise StructureException(
765+
"nonreentrant decorator is already set", self.nonreentrant_node, decorator_node
766+
)
767+
768+
self.nonreentrant_node = decorator_node
769+
770+
@property
771+
def nonreentrant(self) -> bool:
772+
return self.nonreentrant_node is not None
773+
774+
775+
def _parse_decorators(funcdef: vy_ast.FunctionDef) -> _ParsedDecorators:
776+
ret = _ParsedDecorators()
732777

733778
for decorator in funcdef.decorator_list:
734779
if isinstance(decorator, vy_ast.Call):
@@ -741,44 +786,25 @@ def _parse_decorators(
741786
raise StructureException(msg, decorator, hint=hint)
742787

743788
if decorator.get("id") == "nonreentrant":
744-
if nonreentrant_node is not None:
745-
raise StructureException("nonreentrant decorator is already set", nonreentrant_node)
746-
747-
nonreentrant_node = decorator
789+
ret.set_nonreentrant(decorator)
748790

749791
elif isinstance(decorator, vy_ast.Name):
750792
if FunctionVisibility.is_valid_value(decorator.id):
751-
if function_visibility is not None:
752-
raise FunctionDeclarationException(
753-
f"Visibility is already set to: {function_visibility}",
754-
decorator,
755-
hint="only one visibility decorator is allowed per function",
756-
)
757-
758-
function_visibility = FunctionVisibility(decorator.id)
759-
793+
ret.set_visibility(decorator)
760794
elif StateMutability.is_valid_value(decorator.id):
761-
if state_mutability is not None:
762-
raise FunctionDeclarationException(
763-
f"Mutability is already set to: {state_mutability}", funcdef
764-
)
765-
state_mutability = StateMutability(decorator.id)
766-
795+
ret.set_state_mutability(decorator)
767796
else:
768797
raise FunctionDeclarationException(f"Unknown decorator: {decorator.id}", decorator)
769798

770799
else:
771800
raise StructureException("Bad decorator syntax", decorator)
772801

773-
if state_mutability is None:
774-
# default to nonpayable
775-
state_mutability = StateMutability.NONPAYABLE
776-
777-
if state_mutability == StateMutability.PURE and nonreentrant_node is not None:
778-
raise StructureException("Cannot use reentrancy guard on pure functions", nonreentrant_node)
802+
if ret.state_mutability == StateMutability.PURE and ret.nonreentrant_node is not None:
803+
raise StructureException(
804+
"Cannot use reentrancy guard on pure functions", ret.nonreentrant_node
805+
)
779806

780-
nonreentrant = nonreentrant_node is not None
781-
return function_visibility, state_mutability, nonreentrant
807+
return ret
782808

783809

784810
def _parse_args(

0 commit comments

Comments
 (0)