Skip to content

Commit 4b44ee7

Browse files
fix: call internal functions from constructor (#2496)
this commit allows the user to call internal functions from the `__init__` function. it does this by generating a call graph during the annotation phase and then generating code for the functions called from the init function for during deploy code generation this also has a performance benefit (compiler time) because we can get rid of the two-pass method for tracing frame size. now that we have a call graph, this commit also introduces a topsort of functions based on the call dependency tree. this ensures we can compile functions that call functions that occur after them in the source code. lastly, this commit also refactors vyper/codegen/module.py so that the payable logic is cleaner, it uses properties instead of calculations more, and cleans up properties on IRnode, FunctionSignature and Context.
1 parent 03b2f1d commit 4b44ee7

File tree

22 files changed

+383
-341
lines changed

22 files changed

+383
-341
lines changed

examples/stock/company.vy

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,6 @@ def __init__(_company: address, _total_shares: uint256, initial_price: uint256):
3939
# The company holds all the shares at first, but can sell them all.
4040
self.holdings[self.company] = _total_shares
4141

42-
# Find out how much stock the company holds
43-
@view
44-
@internal
45-
def _stockAvailable() -> uint256:
46-
return self.holdings[self.company]
47-
4842
# Public function to allow external access to _stockAvailable
4943
@view
5044
@external
@@ -69,12 +63,6 @@ def buyStock():
6963
# Log the buy event.
7064
log Buy(msg.sender, buy_order)
7165

72-
# Find out how much stock any address (that's owned by someone) has.
73-
@view
74-
@internal
75-
def _getHolding(_stockholder: address) -> uint256:
76-
return self.holdings[_stockholder]
77-
7866
# Public function to allow external access to _getHolding
7967
@view
8068
@external
@@ -135,12 +123,6 @@ def payBill(vendor: address, amount: uint256):
135123
# Log the payment event.
136124
log Pay(vendor, amount)
137125

138-
# Return the amount in wei that a company has raised in stock offerings.
139-
@view
140-
@internal
141-
def _debt() -> uint256:
142-
return (self.totalShares - self._stockAvailable()) * self.price
143-
144126
# Public function to allow external access to _debt
145127
@view
146128
@external
@@ -154,3 +136,21 @@ def debt() -> uint256:
154136
@external
155137
def worth() -> uint256:
156138
return self.balance - self._debt()
139+
140+
# Return the amount in wei that a company has raised in stock offerings.
141+
@view
142+
@internal
143+
def _debt() -> uint256:
144+
return (self.totalShares - self._stockAvailable()) * self.price
145+
146+
# Find out how much stock the company holds
147+
@view
148+
@internal
149+
def _stockAvailable() -> uint256:
150+
return self.holdings[self.company]
151+
152+
# Find out how much stock any address (that's owned by someone) has.
153+
@view
154+
@internal
155+
def _getHolding(_stockholder: address) -> uint256:
156+
return self.holdings[_stockholder]

tests/conftest.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,11 +114,11 @@ def get_contract_module(source_code, *args, **kwargs):
114114

115115

116116
def get_compiler_gas_estimate(code, func):
117-
ir_runtime = compiler.phases.CompilerData(code).ir_runtime
117+
sigs = compiler.phases.CompilerData(code).function_signatures
118118
if func:
119-
return compiler.utils.build_gas_estimates(ir_runtime)[func] + 22000
119+
return compiler.utils.build_gas_estimates(sigs)[func] + 22000
120120
else:
121-
return sum(compiler.utils.build_gas_estimates(ir_runtime).values()) + 22000
121+
return sum(compiler.utils.build_gas_estimates(sigs).values()) + 22000
122122

123123

124124
def check_gas_on_chain(w3, tester, code, func=None, res=None):

tests/parser/features/test_immutable.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,3 +212,30 @@ def get_idx_two() -> int128:
212212
c = get_contract(code, *values)
213213
assert c.get_my_list() == expected_values
214214
assert c.get_idx_two() == expected_values[2][2][2]
215+
216+
217+
@pytest.mark.parametrize("n", range(5))
218+
def test_internal_function_with_immutables(get_contract, n):
219+
code = """
220+
@internal
221+
def foo() -> uint256:
222+
self.counter += 1
223+
return self.counter
224+
225+
counter: uint256
226+
VALUE: immutable(uint256)
227+
228+
@external
229+
def __init__(x: uint256):
230+
self.counter = x
231+
self.foo()
232+
VALUE = self.foo()
233+
self.foo()
234+
235+
@external
236+
def get_immutable() -> uint256:
237+
return VALUE
238+
"""
239+
240+
c = get_contract(code, n)
241+
assert c.get_immutable() == n + 2

tests/parser/features/test_init.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,34 @@ def __init__(a: uint256):
2121
assert "CALLDATALOAD" in opcodes
2222
assert "CALLDATACOPY" not in opcodes[:ir_return_idx]
2323
assert "CALLDATALOAD" not in opcodes[:ir_return_idx]
24+
25+
26+
def test_init_calls_internal(get_contract, assert_compile_failed, assert_tx_failed):
27+
code = """
28+
foo: public(uint8)
29+
@internal
30+
def bar(x: uint256) -> uint8:
31+
return convert(x, uint8) * 7
32+
@external
33+
def __init__(a: uint256):
34+
self.foo = self.bar(a)
35+
36+
@external
37+
def baz() -> uint8:
38+
return self.bar(convert(self.foo, uint256))
39+
"""
40+
n = 5
41+
c = get_contract(code, n)
42+
assert c.foo() == n * 7
43+
assert c.baz() == 245 # 5*7*7
44+
45+
n = 6
46+
c = get_contract(code, n)
47+
assert c.foo() == n * 7
48+
assert_tx_failed(lambda: c.baz())
49+
50+
n = 255
51+
assert_compile_failed(lambda: get_contract(code, n))
52+
53+
n = 256
54+
assert_compile_failed(lambda: get_contract(code, n))

vyper/ast/signatures/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .function_signature import FunctionSignature, VariableRecord
1+
from .function_signature import FrameInfo, FunctionSignature, VariableRecord

vyper/ast/signatures/function_signature.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from vyper.codegen.ir_node import Encoding
88
from vyper.codegen.types import NodeType, parse_type
99
from vyper.exceptions import StructureException
10-
from vyper.utils import cached_property, mkalphanum
10+
from vyper.utils import MemoryPositions, cached_property, mkalphanum
1111

1212
# dict from function names to signatures
1313
FunctionSignatures = Dict[str, "FunctionSignature"]
@@ -66,6 +66,16 @@ class FunctionArg:
6666
ast_source: vy_ast.VyperNode
6767

6868

69+
@dataclass
70+
class FrameInfo:
71+
frame_start: int
72+
frame_size: int
73+
74+
@property
75+
def mem_used(self):
76+
return self.frame_size + MemoryPositions.RESERVED_MEMORY
77+
78+
6979
# Function signature object
7080
class FunctionSignature:
7181
def __init__(
@@ -84,19 +94,25 @@ def __init__(
8494
self.return_type = return_type
8595
self.mutability = mutability
8696
self.internal = internal
87-
self.gas = None
97+
self.gas_estimate = None
8898
self.nonreentrant_key = nonreentrant_key
8999
self.func_ast_code = func_ast_code
90100
self.is_from_json = is_from_json
91101

92102
self.set_default_args()
93103

104+
# frame info is metadata that will be generated during codegen.
105+
self.frame_info: Optional[FrameInfo] = None
106+
94107
def __str__(self):
95108
input_name = "def " + self.name + "(" + ",".join([str(arg.typ) for arg in self.args]) + ")"
96109
if self.return_type:
97110
return input_name + " -> " + str(self.return_type) + ":"
98111
return input_name + ":"
99112

113+
def set_frame_info(self, frame_info):
114+
self.frame_info = frame_info
115+
100116
@cached_property
101117
def _ir_identifier(self) -> str:
102118
# we could do a bit better than this but it just needs to be unique
@@ -228,3 +244,7 @@ def is_default_func(self):
228244
@property
229245
def is_init_func(self):
230246
return self.name == "__init__"
247+
248+
@property
249+
def is_regular_function(self):
250+
return not self.is_default_func and not self.is_init_func

vyper/ast/signatures/interface.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,9 @@ def mk_full_signature_from_json(abi):
9898
def _get_external_signatures(global_ctx, sig_formatter=lambda x: x):
9999
ret = []
100100

101-
for code in global_ctx._defs:
101+
for func_ast in global_ctx._function_defs:
102102
sig = FunctionSignature.from_definition(
103-
code,
103+
func_ast,
104104
sigs=global_ctx._contracts,
105105
custom_structs=global_ctx._structs,
106106
)

vyper/codegen/context.py

Lines changed: 17 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from vyper.ast import VyperNode
66
from vyper.ast.signatures.function_signature import VariableRecord
77
from vyper.codegen.types import NodeType
8-
from vyper.exceptions import CompilerPanic, FunctionDeclarationException
8+
from vyper.exceptions import CompilerPanic
99

1010

1111
class Constancy(enum.Enum):
@@ -22,11 +22,7 @@ def __init__(
2222
vars_=None,
2323
sigs=None,
2424
forvars=None,
25-
return_type=None,
2625
constancy=Constancy.Mutable,
27-
is_internal=False,
28-
is_payable=False,
29-
# method_id="",
3026
sig=None,
3127
):
3228
# In-memory variables, in the form (name, memory location, type)
@@ -41,9 +37,6 @@ def __init__(
4137
# Variables defined in for loops, e.g. for i in range(6): ...
4238
self.forvars = forvars or {}
4339

44-
# Return type of the function
45-
self.return_type = return_type
46-
4740
# Is the function constant?
4841
self.constancy = constancy
4942

@@ -53,14 +46,9 @@ def __init__(
5346
# Whether we are currently parsing a range expression
5447
self.in_range_expr = False
5548

56-
# Is the function payable?
57-
self.is_payable = is_payable
58-
5949
# List of custom structs that have been defined.
6050
self.structs = global_ctx._structs
6151

62-
self.is_internal = is_internal
63-
6452
# store global context
6553
self.global_ctx = global_ctx
6654

@@ -73,23 +61,25 @@ def __init__(
7361
# Not intended to be accessed directly
7462
self.memory_allocator = memory_allocator
7563

76-
self._callee_frame_sizes = []
77-
78-
# Intermented values, used for internal IDs
64+
# Incremented values, used for internal IDs
7965
self._internal_var_iter = 0
8066
self._scope_id_iter = 0
8167

8268
def is_constant(self):
8369
return self.constancy is Constancy.Constant or self.in_assertion or self.in_range_expr
8470

85-
def register_callee(self, frame_size):
86-
self._callee_frame_sizes.append(frame_size)
71+
# convenience propreties
72+
@property
73+
def is_payable(self):
74+
return self.sig.mutability == "payable"
8775

8876
@property
89-
def max_callee_frame_size(self):
90-
if len(self._callee_frame_sizes) == 0:
91-
return 0
92-
return max(self._callee_frame_sizes)
77+
def is_internal(self):
78+
return self.sig.internal
79+
80+
@property
81+
def return_type(self):
82+
return self.sig.return_type
9383

9484
#
9585
# Context Managers
@@ -248,23 +238,16 @@ def lookup_internal_function(self, method_name, args_ir, ast_source):
248238
the kwargs which need to be filled in by the compiler
249239
"""
250240

241+
sig = self.sigs["self"].get(method_name, None)
242+
251243
def _check(cond, s="Unreachable"):
252244
if not cond:
253245
raise CompilerPanic(s)
254246

255-
sig = self.sigs["self"].get(method_name, None)
256-
if sig is None:
257-
raise FunctionDeclarationException(
258-
"Function does not exist or has not been declared yet "
259-
"(reminder: functions cannot call functions later in code "
260-
"than themselves)",
261-
ast_source,
262-
)
263-
264-
_check(sig.internal) # sanity check
265-
# should have been caught during type checking, sanity check anyway
247+
# these should have been caught during type checking; sanity check
248+
_check(sig is not None)
249+
_check(sig.internal)
266250
_check(len(sig.base_args) <= len(args_ir) <= len(sig.args))
267-
268251
# more sanity check, that the types match
269252
# _check(all(l.typ == r.typ for (l, r) in zip(args_ir, sig.args))
270253

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .common import generate_ir_for_function, is_default_func, is_initializer # noqa
1+
from .common import generate_ir_for_function # noqa

0 commit comments

Comments
 (0)