Skip to content

Commit efe1dbe

Browse files
feat: more arithmetic optimizations (#2647)
this is a small rewrite of the IR optimizer. it changes the structure of the binop optimizations so that it is easier to add more optimizations. it also refactors the `clamp` optimizations to be in terms of an `assert` statement, so that the clamp conditions can be optimized using the binop optimizer code. Co-authored-by: El De-dog-lo <[email protected]>
1 parent 4b44ee7 commit efe1dbe

25 files changed

+592
-471
lines changed

examples/auctions/blind_auction.vy

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,7 @@ def reveal(_numBids: int128, _values: uint256[128], _fakes: bool[128], _secrets:
127127

128128
# Bid was not actually revealed
129129
# Do not refund deposit
130-
if (blindedBid != bidToCheck.blindedBid):
131-
assert 1 == 0
132-
continue
130+
assert blindedBid == bidToCheck.blindedBid
133131

134132
# Add deposit to refund if bid was indeed revealed
135133
refund += bidToCheck.deposit

tests/compiler/LLL/test_optimize_lll.py

Lines changed: 0 additions & 21 deletions
This file was deleted.
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@ def test_ir_compile_fail(bad_ir, get_contract_from_ir, assert_compile_failed):
2323

2424
valid_list = [
2525
["pass"],
26-
["clamplt", ["mload", 0], 300],
27-
["clampgt", ["mload", 0], -1],
28-
["uclampgt", 1, ["mload", 0]],
29-
["uclampge", ["mload", 0], 0],
26+
["assert", ["slt", ["mload", 0], 300]],
27+
["assert", ["sgt", ["mload", 0], -1]],
28+
["assert", ["gt", 1, ["mload", 0]]],
29+
["assert", ["ge", ["mload", 0], 0]],
3030
]
3131

3232

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
import pytest
2+
3+
from vyper.codegen.ir_node import IRnode
4+
from vyper.exceptions import StaticAssertionException
5+
from vyper.ir import optimizer
6+
7+
optimize_list = [
8+
(["eq", 1, 2], [0]),
9+
(["lt", 1, 2], [1]),
10+
(["eq", "x", 0], ["iszero", "x"]),
11+
# branch pruner
12+
(["if", ["eq", 1, 2], "pass"], ["seq"]),
13+
(["if", ["eq", 1, 1], 3, 4], [3]),
14+
(["if", ["eq", 1, 2], 3, 4], [4]),
15+
(["seq", ["assert", ["lt", 1, 2]]], ["seq"]),
16+
(["seq", ["assert", ["lt", 1, 2]], 2], [2]),
17+
# condition rewriter
18+
(["if", ["eq", "x", "y"], "pass"], ["if", ["iszero", ["sub", "x", "y"]], "pass"]),
19+
(["if", "cond", 1, 0], ["if", ["iszero", "cond"], 0, 1]),
20+
(["assert", ["eq", "x", "y"]], ["assert", ["iszero", ["sub", "x", "y"]]]),
21+
# nesting
22+
(["mstore", 0, ["eq", 1, 2]], ["mstore", 0, 0]),
23+
# conditions
24+
(["ge", "x", 0], [1]), # x >= 0 == True
25+
(["iszero", ["gt", "x", 2 ** 256 - 1]], [1]), # x >= MAX_UINT256 == False
26+
(["iszero", ["sgt", "x", 2 ** 255 - 1]], [1]), # signed x >= MAX_INT256 == False
27+
(["le", "x", 0], ["iszero", "x"]),
28+
(["le", 0, "x"], [1]),
29+
(["lt", "x", 0], [0]),
30+
(["lt", 0, "x"], ["iszero", ["iszero", "x"]]),
31+
(["gt", 5, "x"], ["lt", "x", 5]),
32+
(["ge", 5, "x"], ["le", "x", 5]),
33+
(["lt", 5, "x"], ["gt", "x", 5]),
34+
(["le", 5, "x"], ["ge", "x", 5]),
35+
(["sgt", 5, "x"], ["slt", "x", 5]),
36+
(["sge", 5, "x"], ["sle", "x", 5]),
37+
(["slt", 5, "x"], ["sgt", "x", 5]),
38+
(["sle", 5, "x"], ["sge", "x", 5]),
39+
(["slt", "x", -(2 ** 255)], ["slt", "x", -(2 ** 255)]), # unimplemented
40+
# tricky conditions
41+
(["sgt", 2 ** 256 - 1, 0], [0]), # -1 > 0
42+
(["gt", 2 ** 256 - 1, 0], [1]), # -1 > 0
43+
(["gt", 2 ** 255, 0], [1]), # 0x80 > 0
44+
(["sgt", 2 ** 255, 0], [0]), # 0x80 > 0
45+
(["sgt", 2 ** 255, 2 ** 255 - 1], [0]), # 0x80 > 0x81
46+
(["gt", -(2 ** 255), 2 ** 255 - 1], [1]), # 0x80 > 0x81
47+
(["slt", 2 ** 255, 2 ** 255 - 1], [1]), # 0x80 < 0x7f
48+
(["lt", -(2 ** 255), 2 ** 255 - 1], [0]), # 0x80 < 0x7f
49+
(["sle", -1, 2 ** 256 - 1], [1]), # -1 <= -1
50+
(["sge", -(2 ** 255), 2 ** 255], [1]), # 0x80 >= 0x80
51+
(["sgt", -(2 ** 255), 2 ** 255], [0]), # 0x80 > 0x80
52+
(["slt", 2 ** 255, -(2 ** 255)], [0]), # 0x80 < 0x80
53+
# arithmetic
54+
(["add", "x", 0], ["x"]),
55+
(["add", 0, "x"], ["x"]),
56+
(["sub", "x", 0], ["x"]),
57+
(["sub", "x", "x"], [0]),
58+
(["sub", ["sload", 0], ["sload", 0]], ["sub", ["sload", 0], ["sload", 0]]), # no-op
59+
(["sub", ["callvalue"], ["callvalue"]], ["sub", ["callvalue"], ["callvalue"]]), # no-op
60+
(["mul", "x", 1], ["x"]),
61+
(["div", "x", 1], ["x"]),
62+
(["sdiv", "x", 1], ["x"]),
63+
(["mod", "x", 1], [0]),
64+
(["smod", "x", 1], [0]),
65+
(["mul", "x", -1], ["sub", 0, "x"]),
66+
(["sdiv", "x", -1], ["sub", 0, "x"]),
67+
(["mul", "x", 0], [0]),
68+
(["div", "x", 0], [0]),
69+
(["sdiv", "x", 0], [0]),
70+
(["mod", "x", 0], [0]),
71+
(["smod", "x", 0], [0]),
72+
(["mul", "x", 32], ["shl", 5, "x"]),
73+
(["div", "x", 64], ["shr", 6, "x"]),
74+
(["mod", "x", 128], ["and", "x", 127]),
75+
(["sdiv", "x", 64], ["sdiv", "x", 64]), # no-op
76+
(["smod", "x", 64], ["smod", "x", 64]), # no-op
77+
# bitwise ops
78+
(["shr", 0, "x"], ["x"]),
79+
(["sar", 0, "x"], ["x"]),
80+
(["shl", 0, "x"], ["x"]),
81+
(["and", 1, 2], [0]),
82+
(["or", 1, 2], [3]),
83+
(["xor", 1, 2], [3]),
84+
(["xor", 3, 2], [1]),
85+
(["and", 0, "x"], [0]),
86+
(["and", "x", 0], [0]),
87+
(["or", "x", 0], ["x"]),
88+
(["or", 0, "x"], ["x"]),
89+
(["xor", "x", 0], ["x"]),
90+
(["xor", "x", 1], ["xor", "x", 1]), # no-op
91+
(["and", "x", 1], ["and", "x", 1]), # no-op
92+
(["or", "x", 1], ["or", "x", 1]), # no-op
93+
(["xor", 0, "x"], ["x"]),
94+
(["iszero", ["or", "x", 1]], [0]),
95+
(["iszero", ["or", 2, "x"]], [0]),
96+
# nested optimizations
97+
(["eq", 0, ["sub", 1, 1]], [1]),
98+
(["eq", 0, ["add", 2 ** 255, 2 ** 255]], [1]), # test compile-time wrapping
99+
(["eq", 0, ["add", 2 ** 255, -(2 ** 255)]], [1]), # test compile-time wrapping
100+
(["eq", -1, ["add", 0, -1]], [1]), # test compile-time wrapping
101+
(["eq", -1, ["add", 2 ** 255, 2 ** 255 - 1]], [1]), # test compile-time wrapping
102+
(["eq", -1, ["add", -(2 ** 255), 2 ** 255 - 1]], [1]), # test compile-time wrapping
103+
(["eq", -2, ["add", 2 ** 256 - 1, 2 ** 256 - 1]], [1]), # test compile-time wrapping
104+
]
105+
106+
107+
@pytest.mark.parametrize("ir", optimize_list)
108+
def test_ir_optimizer(ir):
109+
optimized = optimizer.optimize(IRnode.from_list(ir[0]))
110+
optimized.repr_show_gas = True
111+
hand_optimized = IRnode.from_list(ir[1])
112+
hand_optimized.repr_show_gas = True
113+
assert optimized == hand_optimized
114+
115+
116+
static_assertions_list = [
117+
["assert", ["eq", 2, 1]],
118+
["assert", ["ne", 1, 1]],
119+
["assert", ["sub", 1, 1]],
120+
["assert", ["lt", 2, 1]],
121+
["assert", ["lt", 1, 1]],
122+
["assert", ["lt", "x", 0]], # +x < 0
123+
["assert", ["le", 1, 0]],
124+
["assert", ["le", 2 ** 256 - 1, 0]],
125+
["assert", ["gt", 1, 2]],
126+
["assert", ["gt", 1, 1]],
127+
["assert", ["gt", 0, 2 ** 256 - 1]],
128+
["assert", ["gt", "x", 2 ** 256 - 1]],
129+
["assert", ["ge", 1, 2]],
130+
["assert", ["ge", 1, 2]],
131+
["assert", ["slt", 2, 1]],
132+
["assert", ["slt", 1, 1]],
133+
["assert", ["slt", 0, 2 ** 256 - 1]], # 0 < -1
134+
["assert", ["slt", -(2 ** 255), 2 ** 255]], # 0x80 < 0x80
135+
["assert", ["sle", 0, 2 ** 255]], # 0 < 0x80
136+
["assert", ["sgt", 1, 2]],
137+
["assert", ["sgt", 1, 1]],
138+
["assert", ["sgt", 2 ** 256 - 1, 0]], # -1 > 0
139+
["assert", ["sgt", 2 ** 255, -(2 ** 255)]], # 0x80 > 0x80
140+
["assert", ["sge", 2 ** 255, 0]], # 0x80 > 0
141+
]
142+
143+
144+
@pytest.mark.parametrize("ir", static_assertions_list)
145+
def test_static_assertions(ir, assert_compile_failed):
146+
ir = IRnode.from_list(ir)
147+
assert_compile_failed(lambda: optimizer.optimize(ir), StaticAssertionException)

tests/compiler/test_clamps.py

Lines changed: 0 additions & 94 deletions
This file was deleted.

tests/parser/features/test_assert.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def test_assert_refund(w3, get_contract_with_gas_estimation, assert_tx_failed):
1212
code = """
1313
@external
1414
def foo():
15-
assert 1 == 2
15+
raise
1616
"""
1717
c = get_contract_with_gas_estimation(code)
1818
a0 = w3.eth.accounts[0]

tests/parser/features/test_assert_unreachable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ def test_assure_refund(w3, get_contract):
22
code = """
33
@external
44
def foo():
5-
assert 1 == 2, UNREACHABLE
5+
assert msg.sender != msg.sender, UNREACHABLE
66
"""
77

88
c = get_contract(code)

0 commit comments

Comments
 (0)