Skip to content

Commit c202c4e

Browse files
fix: constructor context for internal functions (#3388)
this commit fixes two related issues with initcode generation: - nested internal functions called from the constructor would cause a compiler panic - internal functions called from the constructor would not read/write from the correct immutables space the relevant examples reproducing each issue are in the tests. this commit fixes the issue by - not trying to traverse the call graph to figure out which internal functions to include in the initcode. instead, all internal functions are included, and we rely on the dead code eliminator to remove unused functions - adding a "constructor" flag to the codegen, so we can distinguish between internal calls which are being generated to include in initcode or runtime code.
1 parent 1c8349e commit c202c4e

File tree

9 files changed

+203
-19
lines changed

9 files changed

+203
-19
lines changed
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from vyper.compiler.phases import CompilerData
2+
3+
4+
def test_dead_code_eliminator():
5+
code = """
6+
s: uint256
7+
8+
@internal
9+
def foo():
10+
self.s = 1
11+
12+
@internal
13+
def qux():
14+
self.s = 2
15+
16+
@external
17+
def bar():
18+
self.foo()
19+
20+
@external
21+
def __init__():
22+
self.qux()
23+
"""
24+
25+
c = CompilerData(code, no_optimize=True)
26+
initcode_asm = [i for i in c.assembly if not isinstance(i, list)]
27+
runtime_asm = c.assembly_runtime
28+
29+
foo_label = "_sym_internal_foo___"
30+
qux_label = "_sym_internal_qux___"
31+
32+
# all the labels should be in all the unoptimized asms
33+
for s in (foo_label, qux_label):
34+
assert s in initcode_asm
35+
assert s in runtime_asm
36+
37+
c = CompilerData(code, no_optimize=False)
38+
initcode_asm = [i for i in c.assembly if not isinstance(i, list)]
39+
runtime_asm = c.assembly_runtime
40+
41+
# qux should not be in runtime code
42+
for instr in runtime_asm:
43+
if isinstance(instr, str):
44+
assert not instr.startswith(qux_label), instr
45+
46+
# foo should not be in initcode asm
47+
for instr in initcode_asm:
48+
if isinstance(instr, str):
49+
assert not instr.startswith(foo_label), instr

tests/functional/semantics/analysis/test_for_loop.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,30 +108,30 @@ def main():
108108
for j in range(3):
109109
x: uint256 = j
110110
y: uint16 = j
111-
""", # issue 3212
111+
""", # GH issue 3212
112112
"""
113113
@external
114114
def foo():
115115
for i in [1]:
116116
a:uint256 = i
117117
b:uint16 = i
118-
""", # issue 3374
118+
""", # GH issue 3374
119119
"""
120120
@external
121121
def foo():
122122
for i in [1]:
123123
for j in [1]:
124124
a:uint256 = i
125125
b:uint16 = i
126-
""", # issue 3374
126+
""", # GH issue 3374
127127
"""
128128
@external
129129
def foo():
130130
for i in [1,2,3]:
131131
for j in [1,2,3]:
132132
b:uint256 = j + i
133133
c:uint16 = i
134-
""", # issue 3374
134+
""", # GH issue 3374
135135
]
136136

137137

tests/parser/features/test_comparison.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55

66
def test_3034_verbatim(get_contract):
7-
# test issue #3034 exactly
7+
# test GH issue 3034 exactly
88
code = """
99
@view
1010
@external

tests/parser/features/test_immutable.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,3 +239,90 @@ def get_immutable() -> uint256:
239239

240240
c = get_contract(code, n)
241241
assert c.get_immutable() == n + 2
242+
243+
244+
# GH issue 3292
245+
def test_internal_functions_called_by_ctor_location(get_contract):
246+
code = """
247+
d: uint256
248+
x: immutable(uint256)
249+
250+
@external
251+
def __init__():
252+
self.d = 1
253+
x = 2
254+
self.a()
255+
256+
@external
257+
def test() -> uint256:
258+
return self.d
259+
260+
@internal
261+
def a():
262+
self.d = x
263+
"""
264+
c = get_contract(code)
265+
assert c.test() == 2
266+
267+
268+
# GH issue 3292, extended to nested internal functions
269+
def test_nested_internal_function_immutables(get_contract):
270+
code = """
271+
d: public(uint256)
272+
x: public(immutable(uint256))
273+
274+
@external
275+
def __init__():
276+
self.d = 1
277+
x = 2
278+
self.a()
279+
280+
@internal
281+
def a():
282+
self.b()
283+
284+
@internal
285+
def b():
286+
self.d = x
287+
"""
288+
c = get_contract(code)
289+
assert c.x() == 2
290+
assert c.d() == 2
291+
292+
293+
# GH issue 3292, test immutable read from both ctor and runtime
294+
def test_immutable_read_ctor_and_runtime(get_contract):
295+
code = """
296+
d: public(uint256)
297+
x: public(immutable(uint256))
298+
299+
@external
300+
def __init__():
301+
self.d = 1
302+
x = 2
303+
self.a()
304+
305+
@internal
306+
def a():
307+
self.d = x
308+
309+
@external
310+
def thrash():
311+
self.d += 5
312+
313+
@external
314+
def fix():
315+
self.a()
316+
"""
317+
c = get_contract(code)
318+
assert c.x() == 2
319+
assert c.d() == 2
320+
321+
c.thrash(transact={})
322+
323+
assert c.x() == 2
324+
assert c.d() == 2 + 5
325+
326+
c.fix(transact={})
327+
assert c.x() == 2
328+
assert c.d() == 2

tests/parser/features/test_init.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,29 @@ def baz() -> uint8:
5353

5454
n = 256
5555
assert_compile_failed(lambda: get_contract(code, n))
56+
57+
58+
# GH issue 3206
59+
def test_nested_internal_call_from_ctor(get_contract):
60+
code = """
61+
x: uint256
62+
63+
@external
64+
def __init__():
65+
self.a()
66+
67+
@internal
68+
def a():
69+
self.x += 1
70+
self.b()
71+
72+
@internal
73+
def b():
74+
self.x += 2
75+
76+
@external
77+
def test() -> uint256:
78+
return self.x
79+
"""
80+
c = get_contract(code)
81+
assert c.test() == 3

vyper/codegen/context.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def __init__(
5454
forvars=None,
5555
constancy=Constancy.Mutable,
5656
sig=None,
57+
is_ctor_context=False,
5758
):
5859
# In-memory variables, in the form (name, memory location, type)
5960
self.vars = vars_ or {}
@@ -92,6 +93,9 @@ def __init__(
9293
self._internal_var_iter = 0
9394
self._scope_id_iter = 0
9495

96+
# either the constructor, or called from the constructor
97+
self.is_ctor_context = is_ctor_context
98+
9599
def is_constant(self):
96100
return self.constancy is Constancy.Constant or self.in_assertion or self.in_range_expr
97101

vyper/codegen/expr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def parse_Name(self):
184184

185185
ofst = varinfo.position.offset
186186

187-
if self.context.sig.is_init_func:
187+
if self.context.is_ctor_context:
188188
mutable = True
189189
location = IMMUTABLES
190190
else:

vyper/codegen/function_definitions/common.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def generate_ir_for_function(
1818
sigs: Dict[str, Dict[str, FunctionSignature]], # all signatures in all namespaces
1919
global_ctx: GlobalContext,
2020
skip_nonpayable_check: bool,
21+
is_ctor_context: bool = False,
2122
) -> IRnode:
2223
"""
2324
Parse a function and produce IR code for the function, includes:
@@ -51,6 +52,7 @@ def generate_ir_for_function(
5152
memory_allocator=memory_allocator,
5253
constancy=Constancy.Constant if sig.mutability in ("view", "pure") else Constancy.Mutable,
5354
sig=sig,
55+
is_ctor_context=is_ctor_context,
5456
)
5557

5658
if sig.internal:
@@ -65,13 +67,19 @@ def generate_ir_for_function(
6567

6668
frame_size = context.memory_allocator.size_of_mem - MemoryPositions.RESERVED_MEMORY
6769

68-
sig.set_frame_info(FrameInfo(allocate_start, frame_size, context.vars))
70+
frame_info = FrameInfo(allocate_start, frame_size, context.vars)
71+
72+
if sig.frame_info is None:
73+
sig.set_frame_info(frame_info)
74+
else:
75+
assert frame_info == sig.frame_info
6976

7077
if not sig.internal:
7178
# adjust gas estimate to include cost of mem expansion
7279
# frame_size of external function includes all private functions called
7380
# (note: internal functions do not need to adjust gas estimate since
7481
# it is already accounted for by the caller.)
82+
assert sig.frame_info is not None # mypy hint
7583
o.add_gas_estimate += calc_mem_gas(sig.frame_info.mem_used)
7684

7785
sig.gas_estimate = o.gas

vyper/codegen/module.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -67,19 +67,21 @@ def _runtime_ir(runtime_functions, all_sigs, global_ctx):
6767

6868
# create a map of the IR functions since they might live in both
6969
# runtime and deploy code (if init function calls them)
70-
internal_functions_map: Dict[str, IRnode] = {}
70+
internal_functions_ir: list[IRnode] = []
7171

7272
for func_ast in internal_functions:
7373
func_ir = generate_ir_for_function(func_ast, all_sigs, global_ctx, False)
74-
internal_functions_map[func_ast.name] = func_ir
74+
internal_functions_ir.append(func_ir)
7575

7676
# for some reason, somebody may want to deploy a contract with no
7777
# external functions, or more likely, a "pure data" contract which
7878
# contains immutables
7979
if len(external_functions) == 0:
80-
# TODO: prune internal functions in this case?
81-
runtime = ["seq"] + list(internal_functions_map.values())
82-
return runtime, internal_functions_map
80+
# TODO: prune internal functions in this case? dead code eliminator
81+
# might not eliminate them, since internal function jumpdest is at the
82+
# first instruction in the contract.
83+
runtime = ["seq"] + internal_functions_ir
84+
return runtime
8385

8486
# note: if the user does not provide one, the default fallback function
8587
# reverts anyway. so it does not hurt to batch the payable check.
@@ -125,10 +127,10 @@ def _runtime_ir(runtime_functions, all_sigs, global_ctx):
125127
["label", "fallback", ["var_list"], fallback_ir],
126128
]
127129

128-
# TODO: prune unreachable functions?
129-
runtime.extend(internal_functions_map.values())
130+
# note: dead code eliminator will clean dead functions
131+
runtime.extend(internal_functions_ir)
130132

131-
return runtime, internal_functions_map
133+
return runtime
132134

133135

134136
# take a GlobalContext, which is basically
@@ -159,12 +161,15 @@ def generate_ir_for_module(global_ctx: GlobalContext) -> Tuple[IRnode, IRnode, F
159161
runtime_functions = [f for f in function_defs if not _is_init_func(f)]
160162
init_function = next((f for f in function_defs if _is_init_func(f)), None)
161163

162-
runtime, internal_functions = _runtime_ir(runtime_functions, all_sigs, global_ctx)
164+
runtime = _runtime_ir(runtime_functions, all_sigs, global_ctx)
163165

164166
deploy_code: List[Any] = ["seq"]
165167
immutables_len = global_ctx.immutable_section_bytes
166168
if init_function:
167-
init_func_ir = generate_ir_for_function(init_function, all_sigs, global_ctx, False)
169+
# TODO might be cleaner to separate this into an _init_ir helper func
170+
init_func_ir = generate_ir_for_function(
171+
init_function, all_sigs, global_ctx, skip_nonpayable_check=False, is_ctor_context=True
172+
)
168173
deploy_code.append(init_func_ir)
169174

170175
# pass the amount of memory allocated for the init function
@@ -174,8 +179,13 @@ def generate_ir_for_module(global_ctx: GlobalContext) -> Tuple[IRnode, IRnode, F
174179
deploy_code.append(["deploy", init_mem_used, runtime, immutables_len])
175180

176181
# internal functions come after everything else
177-
for f in init_function._metadata["type"].called_functions:
178-
deploy_code.append(internal_functions[f.name])
182+
internal_functions = [f for f in runtime_functions if _is_internal(f)]
183+
for f in internal_functions:
184+
func_ir = generate_ir_for_function(
185+
f, all_sigs, global_ctx, skip_nonpayable_check=False, is_ctor_context=True
186+
)
187+
# note: we depend on dead code eliminator to clean dead function defs
188+
deploy_code.append(func_ir)
179189

180190
else:
181191
if immutables_len != 0:

0 commit comments

Comments
 (0)