Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions tests/compiler/asm/test_asm_optimizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from vyper.compiler.phases import CompilerData


def test_dead_code_eliminator():
code = """
s: uint256

@internal
def foo():
self.s = 1

@internal
def qux():
self.s = 2

@external
def bar():
self.foo()

@external
def __init__():
self.qux()
"""

c = CompilerData(code, no_optimize=True)
initcode_asm = [i for i in c.assembly if not isinstance(i, list)]
runtime_asm = c.assembly_runtime

foo_label = "_sym_internal_foo___"
qux_label = "_sym_internal_qux___"

# all the labels should be in all the unoptimized asms
for s in (foo_label, qux_label):
assert s in initcode_asm
assert s in runtime_asm

c = CompilerData(code, no_optimize=False)
initcode_asm = [i for i in c.assembly if not isinstance(i, list)]
runtime_asm = c.assembly_runtime

# qux should not be in runtime code
for instr in runtime_asm:
if isinstance(instr, str):
assert not instr.startswith(qux_label), instr

# foo should not be in initcode asm
for instr in initcode_asm:
if isinstance(instr, str):
assert not instr.startswith(foo_label), instr
8 changes: 4 additions & 4 deletions tests/functional/semantics/analysis/test_for_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,30 +108,30 @@ def main():
for j in range(3):
x: uint256 = j
y: uint16 = j
""", # issue 3212
""", # GH issue 3212
"""
@external
def foo():
for i in [1]:
a:uint256 = i
b:uint16 = i
""", # issue 3374
""", # GH issue 3374
"""
@external
def foo():
for i in [1]:
for j in [1]:
a:uint256 = i
b:uint16 = i
""", # issue 3374
""", # GH issue 3374
"""
@external
def foo():
for i in [1,2,3]:
for j in [1,2,3]:
b:uint256 = j + i
c:uint16 = i
""", # issue 3374
""", # GH issue 3374
]


Expand Down
2 changes: 1 addition & 1 deletion tests/parser/features/test_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


def test_3034_verbatim(get_contract):
# test issue #3034 exactly
# test GH issue 3034 exactly
code = """
@view
@external
Expand Down
87 changes: 87 additions & 0 deletions tests/parser/features/test_immutable.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,90 @@ def get_immutable() -> uint256:

c = get_contract(code, n)
assert c.get_immutable() == n + 2


# GH issue 3292
def test_internal_functions_called_by_ctor_location(get_contract):
code = """
d: uint256
x: immutable(uint256)

@external
def __init__():
self.d = 1
x = 2
self.a()

@external
def test() -> uint256:
return self.d

@internal
def a():
self.d = x
"""
c = get_contract(code)
assert c.test() == 2


# GH issue 3292, extended to nested internal functions
def test_nested_internal_function_immutables(get_contract):
code = """
d: public(uint256)
x: public(immutable(uint256))

@external
def __init__():
self.d = 1
x = 2
self.a()

@internal
def a():
self.b()

@internal
def b():
self.d = x
"""
c = get_contract(code)
assert c.x() == 2
assert c.d() == 2


# GH issue 3292, test immutable read from both ctor and runtime
def test_immutable_read_ctor_and_runtime(get_contract):
code = """
d: public(uint256)
x: public(immutable(uint256))

@external
def __init__():
self.d = 1
x = 2
self.a()

@internal
def a():
self.d = x

@external
def thrash():
self.d += 5

@external
def fix():
self.a()
"""
c = get_contract(code)
assert c.x() == 2
assert c.d() == 2

c.thrash(transact={})

assert c.x() == 2
assert c.d() == 2 + 5

c.fix(transact={})
assert c.x() == 2
assert c.d() == 2
26 changes: 26 additions & 0 deletions tests/parser/features/test_init.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you think it's worth adding a separate test that combines both issues? Also, it might make sense to add an additional test where the immutable variable is declared as public since this involves the generation of an additional getter that also must read the state correctly.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added the addl test in ddbeed0

Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,29 @@ def baz() -> uint8:

n = 256
assert_compile_failed(lambda: get_contract(code, n))


# GH issue 3206
def test_nested_internal_call_from_ctor(get_contract):
code = """
x: uint256

@external
def __init__():
self.a()

@internal
def a():
self.x += 1
self.b()

@internal
def b():
self.x += 2

@external
def test() -> uint256:
return self.x
"""
c = get_contract(code)
assert c.test() == 3
4 changes: 4 additions & 0 deletions vyper/codegen/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(
forvars=None,
constancy=Constancy.Mutable,
sig=None,
is_ctor_context=False,
):
# In-memory variables, in the form (name, memory location, type)
self.vars = vars_ or {}
Expand Down Expand Up @@ -89,6 +90,9 @@ def __init__(
self._internal_var_iter = 0
self._scope_id_iter = 0

# either the constructor, or called from the constructor
self.is_ctor_context = is_ctor_context

def is_constant(self):
return self.constancy is Constancy.Constant or self.in_assertion or self.in_range_expr

Expand Down
2 changes: 1 addition & 1 deletion vyper/codegen/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def parse_Name(self):

ofst = varinfo.position.offset

if self.context.sig.is_init_func:
if self.context.is_ctor_context:
mutable = True
location = IMMUTABLES
else:
Expand Down
10 changes: 9 additions & 1 deletion vyper/codegen/function_definitions/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def generate_ir_for_function(
sigs: Dict[str, Dict[str, FunctionSignature]], # all signatures in all namespaces
global_ctx: GlobalContext,
skip_nonpayable_check: bool,
is_ctor_context: bool = False,
) -> IRnode:
"""
Parse a function and produce IR code for the function, includes:
Expand Down Expand Up @@ -51,6 +52,7 @@ def generate_ir_for_function(
memory_allocator=memory_allocator,
constancy=Constancy.Constant if sig.mutability in ("view", "pure") else Constancy.Mutable,
sig=sig,
is_ctor_context=is_ctor_context,
)

if sig.internal:
Expand All @@ -65,13 +67,19 @@ def generate_ir_for_function(

frame_size = context.memory_allocator.size_of_mem - MemoryPositions.RESERVED_MEMORY

sig.set_frame_info(FrameInfo(allocate_start, frame_size, context.vars))
frame_info = FrameInfo(allocate_start, frame_size, context.vars)

if sig.frame_info is None:
sig.set_frame_info(frame_info)
else:
assert frame_info == sig.frame_info

if not sig.internal:
# adjust gas estimate to include cost of mem expansion
# frame_size of external function includes all private functions called
# (note: internal functions do not need to adjust gas estimate since
# it is already accounted for by the caller.)
assert sig.frame_info is not None # mypy hint
o.add_gas_estimate += calc_mem_gas(sig.frame_info.mem_used)

sig.gas_estimate = o.gas
Expand Down
34 changes: 22 additions & 12 deletions vyper/codegen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,19 +67,21 @@ def _runtime_ir(runtime_functions, all_sigs, global_ctx):

# create a map of the IR functions since they might live in both
# runtime and deploy code (if init function calls them)
internal_functions_map: Dict[str, IRnode] = {}
internal_functions_ir: list[IRnode] = []

for func_ast in internal_functions:
func_ir = generate_ir_for_function(func_ast, all_sigs, global_ctx, False)
internal_functions_map[func_ast.name] = func_ir
internal_functions_ir.append(func_ir)

# for some reason, somebody may want to deploy a contract with no
# external functions, or more likely, a "pure data" contract which
# contains immutables
if len(external_functions) == 0:
# TODO: prune internal functions in this case?
runtime = ["seq"] + list(internal_functions_map.values())
return runtime, internal_functions_map
# TODO: prune internal functions in this case? dead code eliminator
# might not eliminate them, since internal function jumpdest is at the
# first instruction in the contract.
runtime = ["seq"] + internal_functions_ir
return runtime

# note: if the user does not provide one, the default fallback function
# reverts anyway. so it does not hurt to batch the payable check.
Expand Down Expand Up @@ -125,10 +127,10 @@ def _runtime_ir(runtime_functions, all_sigs, global_ctx):
["label", "fallback", ["var_list"], fallback_ir],
]

# TODO: prune unreachable functions?
runtime.extend(internal_functions_map.values())
# note: dead code eliminator will clean dead functions
runtime.extend(internal_functions_ir)

return runtime, internal_functions_map
return runtime


# take a GlobalContext, which is basically
Expand Down Expand Up @@ -159,12 +161,15 @@ def generate_ir_for_module(global_ctx: GlobalContext) -> Tuple[IRnode, IRnode, F
runtime_functions = [f for f in function_defs if not _is_init_func(f)]
init_function = next((f for f in function_defs if _is_init_func(f)), None)

runtime, internal_functions = _runtime_ir(runtime_functions, all_sigs, global_ctx)
runtime = _runtime_ir(runtime_functions, all_sigs, global_ctx)

deploy_code: List[Any] = ["seq"]
immutables_len = global_ctx.immutable_section_bytes
if init_function:
init_func_ir = generate_ir_for_function(init_function, all_sigs, global_ctx, False)
# TODO might be cleaner to separate this into an _init_ir helper func
init_func_ir = generate_ir_for_function(
init_function, all_sigs, global_ctx, skip_nonpayable_check=False, is_ctor_context=True
)
deploy_code.append(init_func_ir)

# pass the amount of memory allocated for the init function
Expand All @@ -174,8 +179,13 @@ def generate_ir_for_module(global_ctx: GlobalContext) -> Tuple[IRnode, IRnode, F
deploy_code.append(["deploy", init_mem_used, runtime, immutables_len])

# internal functions come after everything else
for f in init_function._metadata["type"].called_functions:
deploy_code.append(internal_functions[f.name])
internal_functions = [f for f in runtime_functions if _is_internal(f)]
for f in internal_functions:
func_ir = generate_ir_for_function(
f, all_sigs, global_ctx, skip_nonpayable_check=False, is_ctor_context=True
)
# note: we depend on dead code eliminator to clean dead function defs
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So are you 100% sure that if internal functions run only within ctor are eliminated by the dead code eliminator or does this need further investigations? Can we maybe add here an additional test? Just to ensure to no dead code is deployed (this happened to Solidity already btw).

Copy link
Member Author

@charles-cooper charles-cooper May 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea i'm 100% sure, i double checked by hand. we could add tests for the dead code eliminator altho it seems a bit out of scope for this PR

ps:

@internal
def foo():
    pass

@internal
def qux():
    pass

@external
def bar():
    self.foo()

@external
def __init__():
    self.qux()

produces the following asm:

CALLVALUE _sym_revert1 JUMPI _sym_external___init______cleanup _sym_internal_qux____cleanup JUMP _sym_external___init______cleanup JUMPDEST _sym_subcode_size _sym_runtime_begin2 _mem_deploy_start CODECOPY _OFST _sym_subcode_size 0 _mem_deploy_start RETURN _sym_runtime_begin2 BLANK { _DEPLOY_MEM_OFST_64 PUSH1 0x00 CALLDATALOAD PUSH1 0xe0 SHR CALLVALUE _sym_revert1 JUMPI PUSH4 0xfebb0f7e DUP2 XOR _sym_join3 JUMPI PUSH1 0x04 CALLDATASIZE LT _sym_revert1 JUMPI _sym_external_bar____cleanup _sym_internal_foo____cleanup JUMP _sym_external_bar____cleanup JUMPDEST STOP _sym_join3 JUMPDEST POP PUSH1 0x00 PUSH1 0x00 REVERT _sym_internal_foo____cleanup JUMPDEST JUMP _sym_revert1 JUMPDEST PUSH1 0x00 DUP1 REVERT } _sym_internal_qux____cleanup JUMPDEST JUMP _sym_revert1 JUMPDEST PUSH1 0x00 DUP1 REVERT 

which you can manually verify that qux does not appear in the runtime code and foo does not appear in the initcode.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok - i added the dead code eliminator test in 2897ff4

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also quickly verified it manually. Also, for the record, I tested this morning manually a contract with 10 nested internal calls. +1 for the code eliminator test.

deploy_code.append(func_ir)

else:
if immutables_len != 0:
Expand Down