Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions go/vt/sqlparser/normalizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ type (
bindVars map[string]*querypb.BindVariable
reserved *ReservedVars
vals map[Literal]string
tupleVals map[string]string
err error
inDerived int
inSelect int
Expand Down Expand Up @@ -145,6 +146,7 @@ func newNormalizer(
bindVars: bindVars,
reserved: reserved,
vals: make(map[Literal]string),
tupleVals: make(map[string]string),
bindVarNeeds: &BindVarNeeds{},
keyspace: keyspace,
selectLimit: selectLimit,
Expand Down Expand Up @@ -470,8 +472,22 @@ func (nz *normalizer) rewriteInComparisons(node *ComparisonExpr) {
Value: bval.Value,
})
}
bvname := nz.reserved.nextUnusedVar()
nz.bindVars[bvname] = bvals

var bvname string

if key, err := bvals.MarshalVT(); err != nil {
bvname = nz.reserved.nextUnusedVar()
nz.bindVars[bvname] = bvals
} else {
// Check if we can find key in tuplevals
if bvname, ok = nz.tupleVals[string(key)]; !ok {
bvname = nz.reserved.nextUnusedVar()
}

nz.bindVars[bvname] = bvals
nz.tupleVals[string(key)] = bvname
}
Comment on lines +476 to +489
Copy link
Member Author

@arthurschreiber arthurschreiber May 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit awkward - We can't use bval directly as the map key, because Values on querypb.BindVariable is a slice, and structs containing slices can't be used as map keys.

So I'm using MarshalVT in combination with string to generate a value that can be used as a key. This is probably not super-optimal, but 🤷


node.Right = ListArg(bvname)
}

Expand Down
108 changes: 57 additions & 51 deletions go/vt/sqlparser/normalizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,13 @@ func TestNormalize(t *testing.T) {
"bv1": sqltypes.TestBindVariable([]any{1, "2"}),
},
}, {
// EXPLAIN query will be normalized and not parameterized
// repeated IN clause with vals
in: "select * from t where v1 in (1, '2') OR v2 in (1, '2')",
outstmt: "select * from t where v1 in ::bv1 or v2 in ::bv1",
outbv: map[string]*querypb.BindVariable{
"bv1": sqltypes.TestBindVariable([]any{1, "2"}),
},
}, { // EXPLAIN query will be normalized and not parameterized
in: "explain select @x from t where v1 in (1, '2')",
outstmt: "explain select :__vtudvx as `@x` from t where v1 in (1, '2')",
outbv: map[string]*querypb.BindVariable{},
Expand Down Expand Up @@ -1325,9 +1331,9 @@ JOIN warehouse%d AS w ON c_w_id=w_id
WHERE w_id = %d
AND c_d_id = %d
AND c_id = %d`,
`SELECT d_next_o_id, d_tax
FROM district%d
WHERE d_w_id = %d
`SELECT d_next_o_id, d_tax
FROM district%d
WHERE d_w_id = %d
AND d_id = %d FOR UPDATE`,
`UPDATE district%d
SET d_next_o_id = %d
Expand All @@ -1337,130 +1343,130 @@ WHERE d_id = %d AND d_w_id= %d`,
VALUES (%d,%d,%d,%d,NOW(),%d,%d)`,
`INSERT INTO new_orders%d (no_o_id, no_d_id, no_w_id)
VALUES (%d,%d,%d)`,
`SELECT i_price, i_name, i_data
`SELECT i_price, i_name, i_data
FROM item%d
WHERE i_id = %d`,
`SELECT s_quantity, s_data, s_dist_%s s_dist
FROM stock%d
`SELECT s_quantity, s_data, s_dist_%s s_dist
FROM stock%d
WHERE s_i_id = %d AND s_w_id= %d FOR UPDATE`,
`UPDATE stock%d
SET s_quantity = %d
WHERE s_i_id = %d
WHERE s_i_id = %d
AND s_w_id= %d`,
`INSERT INTO order_line%d
(ol_o_id, ol_d_id, ol_w_id, ol_number, ol_i_id, ol_supply_w_id, ol_quantity, ol_amount, ol_dist_info)
VALUES (%d,%d,%d,%d,%d,%d,%d,%d,'%s')`,
`UPDATE warehouse%d
SET w_ytd = w_ytd + %d
SET w_ytd = w_ytd + %d
WHERE w_id = %d`,
`SELECT w_street_1, w_street_2, w_city, w_state, w_zip, w_name
FROM warehouse%d
WHERE w_id = %d`,
`UPDATE district%d
SET d_ytd = d_ytd + %d
WHERE d_w_id = %d
`UPDATE district%d
SET d_ytd = d_ytd + %d
WHERE d_w_id = %d
AND d_id= %d`,
`SELECT d_street_1, d_street_2, d_city, d_state, d_zip, d_name
`SELECT d_street_1, d_street_2, d_city, d_state, d_zip, d_name
FROM district%d
WHERE d_w_id = %d
WHERE d_w_id = %d
AND d_id = %d`,
`SELECT count(c_id) namecnt
FROM customer%d
WHERE c_w_id = %d
WHERE c_w_id = %d
AND c_d_id= %d
AND c_last='%s'`,
`SELECT c_first, c_middle, c_last, c_street_1,
c_street_2, c_city, c_state, c_zip, c_phone,
c_credit, c_credit_lim, c_discount, c_balance, c_ytd_payment, c_since
FROM customer%d
WHERE c_w_id = %d
WHERE c_w_id = %d
AND c_d_id= %d
AND c_id=%d FOR UPDATE`,
`SELECT c_data
FROM customer%d
WHERE c_w_id = %d
WHERE c_w_id = %d
AND c_d_id=%d
AND c_id= %d`,
`UPDATE customer%d
SET c_balance=%f, c_ytd_payment=%f, c_data='%s'
WHERE c_w_id = %d
WHERE c_w_id = %d
AND c_d_id=%d
AND c_id=%d`,
`UPDATE customer%d
SET c_balance=%f, c_ytd_payment=%f
WHERE c_w_id = %d
WHERE c_w_id = %d
AND c_d_id=%d
AND c_id=%d`,
`INSERT INTO history%d
(h_c_d_id, h_c_w_id, h_c_id, h_d_id, h_w_id, h_date, h_amount, h_data)
VALUES (%d,%d,%d,%d,%d,NOW(),%d,'%s')`,
`SELECT count(c_id) namecnt
FROM customer%d
WHERE c_w_id = %d
WHERE c_w_id = %d
AND c_d_id= %d
AND c_last='%s'`,
`SELECT c_balance, c_first, c_middle, c_id
FROM customer%d
WHERE c_w_id = %d
WHERE c_w_id = %d
AND c_d_id= %d
AND c_last='%s' ORDER BY c_first`,
`SELECT c_balance, c_first, c_middle, c_last
FROM customer%d
WHERE c_w_id = %d
WHERE c_w_id = %d
AND c_d_id=%d
AND c_id=%d`,
`SELECT o_id, o_carrier_id, o_entry_d
FROM orders%d
WHERE o_w_id = %d
AND o_d_id = %d
AND o_c_id = %d
FROM orders%d
WHERE o_w_id = %d
AND o_d_id = %d
AND o_c_id = %d
ORDER BY o_id DESC`,
`SELECT ol_i_id, ol_supply_w_id, ol_quantity, ol_amount, ol_delivery_d
FROM order_line%d WHERE ol_w_id = %d AND ol_d_id = %d AND ol_o_id = %d`,
`SELECT no_o_id
FROM new_orders%d
WHERE no_d_id = %d
AND no_w_id = %d
FROM new_orders%d
WHERE no_d_id = %d
AND no_w_id = %d
ORDER BY no_o_id ASC LIMIT 1 FOR UPDATE`,
`DELETE FROM new_orders%d
WHERE no_o_id = %d
AND no_d_id = %d
WHERE no_o_id = %d
AND no_d_id = %d
AND no_w_id = %d`,
`SELECT o_c_id
FROM orders%d
WHERE o_id = %d
AND o_d_id = %d
FROM orders%d
WHERE o_id = %d
AND o_d_id = %d
AND o_w_id = %d`,
`UPDATE orders%d
`UPDATE orders%d
SET o_carrier_id = %d
WHERE o_id = %d
AND o_d_id = %d
WHERE o_id = %d
AND o_d_id = %d
AND o_w_id = %d`,
`UPDATE order_line%d
`UPDATE order_line%d
SET ol_delivery_d = NOW()
WHERE ol_o_id = %d
AND ol_d_id = %d
WHERE ol_o_id = %d
AND ol_d_id = %d
AND ol_w_id = %d`,
`SELECT SUM(ol_amount) sm
FROM order_line%d
WHERE ol_o_id = %d
AND ol_d_id = %d
FROM order_line%d
WHERE ol_o_id = %d
AND ol_d_id = %d
AND ol_w_id = %d`,
`UPDATE customer%d
`UPDATE customer%d
SET c_balance = c_balance + %f,
c_delivery_cnt = c_delivery_cnt + 1
WHERE c_id = %d
AND c_d_id = %d
WHERE c_id = %d
AND c_d_id = %d
AND c_w_id = %d`,
`SELECT d_next_o_id
`SELECT d_next_o_id
FROM district%d
WHERE d_id = %d AND d_w_id= %d`,
`SELECT COUNT(DISTINCT(s.s_i_id))
FROM stock%d AS s
JOIN order_line%d AS ol ON ol.ol_w_id=s.s_w_id AND ol.ol_i_id=s.s_i_id
WHERE ol.ol_w_id = %d
JOIN order_line%d AS ol ON ol.ol_w_id=s.s_w_id AND ol.ol_i_id=s.s_i_id
WHERE ol.ol_w_id = %d
AND ol.ol_d_id = %d
AND ol.ol_o_id < %d
AND ol.ol_o_id < %d
AND ol.ol_o_id >= %d
AND s.s_w_id= %d
AND s.s_quantity < %d `,
Expand All @@ -1471,7 +1477,7 @@ AND ol_o_id < %d AND ol_o_id >= %d`,
WHERE s_w_id = %d AND s_i_id = %d
AND s_quantity < %d`,
`SELECT min(no_o_id) mo
FROM new_orders%d
FROM new_orders%d
WHERE no_w_id = %d AND no_d_id = %d`,
`SELECT o_id FROM orders%d o, (SELECT o_c_id,o_w_id,o_d_id,count(distinct o_id) FROM orders%d WHERE o_w_id=%d AND o_d_id=%d AND o_id > 2100 AND o_id < %d GROUP BY o_c_id,o_d_id,o_w_id having count( distinct o_id) > 1 limit 1) t WHERE t.o_w_id=o.o_w_id and t.o_d_id=o.o_d_id and t.o_c_id=o.o_c_id limit 1 `,
`DELETE FROM order_line%d where ol_w_id=%d AND ol_d_id=%d AND ol_o_id=%d`,
Expand Down
20 changes: 2 additions & 18 deletions go/vt/vtgate/executor_select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3333,33 +3333,17 @@ func TestSelectWithUnionAll(t *testing.T) {
bv1, _ := sqltypes.BuildBindVariable([]int64{1, 2})
bv2, _ := sqltypes.BuildBindVariable([]int64{3})
sbc1WantQueries := []*querypb.BoundQuery{{
Sql: "select id from `user` where id in ::__vals",
BindVariables: map[string]*querypb.BindVariable{
"__vals": bv1,
"vtg1": bv,
"vtg2": bv,
},
}, {
Sql: "select id from `user` where id in ::__vals",
Sql: "select id from `user` where id in ::__vals union all select id from `user` where id in ::vtg1",
BindVariables: map[string]*querypb.BindVariable{
"__vals": bv1,
"vtg1": bv,
"vtg2": bv,
},
}}
sbc2WantQueries := []*querypb.BoundQuery{{
Sql: "select id from `user` where id in ::__vals",
BindVariables: map[string]*querypb.BindVariable{
"__vals": bv2,
"vtg1": bv,
"vtg2": bv,
},
}, {
Sql: "select id from `user` where id in ::__vals",
Sql: "select id from `user` where id in ::__vals union all select id from `user` where id in ::vtg1",
BindVariables: map[string]*querypb.BindVariable{
"__vals": bv2,
"vtg1": bv,
"vtg2": bv,
},
}}
session := &vtgatepb.Session{
Expand Down
18 changes: 15 additions & 3 deletions go/vt/vtgate/planbuilder/operators/route_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -478,14 +478,26 @@ func gen4ValuesEqual(ctx *plancontext.PlanningContext, a, b []sqlparser.Expr) (b
return false, nil
}
if c != nil {
conditions = append(conditions, *c)
conditions = append(conditions, c...)
}
}
return true, conditions
}

func gen4ValEqual(ctx *plancontext.PlanningContext, a, b sqlparser.Expr) (bool, *engine.Condition) {
func gen4ValEqual(ctx *plancontext.PlanningContext, a, b sqlparser.Expr) (bool, []engine.Condition) {
switch a := a.(type) {
case sqlparser.ValTuple:
if b, ok := b.(sqlparser.ValTuple); ok {
return gen4ValuesEqual(ctx, a, b)
}

return false, nil

case sqlparser.ListArg:
if b, ok := b.(sqlparser.ListArg); ok {
return a == b, nil
}

case *sqlparser.ColName:
if b, ok := b.(*sqlparser.ColName); ok {
if !a.Name.Equal(b.Name) {
Expand Down Expand Up @@ -518,7 +530,7 @@ func gen4ValEqual(ctx *plancontext.PlanningContext, a, b sqlparser.Expr) (bool,
}

return aVal.Type == bVal.Type && bytes.Equal(aVal.Value, bVal.Value),
&engine.Condition{A: a.Name, B: b.Name}
[]engine.Condition{{A: a.Name, B: b.Name}}

case *sqlparser.Literal:
b, ok := b.(*sqlparser.Literal)
Expand Down
8 changes: 5 additions & 3 deletions go/vt/vtgate/planbuilder/operators/union_merging.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,6 @@ func tryMergeUnionShardedRouting(

scatterA := tblA.RouteOpCode == engine.Scatter
scatterB := tblB.RouteOpCode == engine.Scatter
uniqueA := tblA.RouteOpCode == engine.EqualUnique
uniqueB := tblB.RouteOpCode == engine.EqualUnique

switch {
case scatterA:
Expand All @@ -156,7 +154,11 @@ func tryMergeUnionShardedRouting(
case scatterB:
return createMergedUnion(ctx, routeA, routeB, exprsA, exprsB, distinct, tblB, nil)

case uniqueA && uniqueB:
case tblA.RouteOpCode == engine.EqualUnique && tblB.RouteOpCode == engine.EqualUnique:
fallthrough
case tblA.RouteOpCode == engine.Equal && tblB.RouteOpCode == engine.Equal:
fallthrough
case tblA.RouteOpCode == engine.IN && tblB.RouteOpCode == engine.IN:
aVdx := tblA.SelectedVindex()
bVdx := tblB.SelectedVindex()
aExpr := tblA.VindexExpressions()
Expand Down
Loading
Loading