Skip to content

Commit 8c2edbe

Browse files
Merge commit from fork
* fix: fix entrypoints over relations that are reused for arrows Fixes the reachability graph to prevent skipping of necessary entrypoint computations when a relation is reused for an arrow in the same permission * fix: Fix building query planner trees that contain arrows that may contain types that don't have the right side of the arrow. --------- Co-authored-by: Barak Michener <[email protected]>
1 parent b13788c commit 8c2edbe

File tree

4 files changed

+164
-21
lines changed

4 files changed

+164
-21
lines changed
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
schema: |+
2+
definition special_user {}
3+
4+
definition user {
5+
relation special_user_mapping: special_user
6+
permission special_user = special_user_mapping
7+
}
8+
9+
definition group {
10+
relation member: user
11+
permission membership = member + member->special_user
12+
}
13+
14+
definition system {
15+
relation viewer: user | group#membership
16+
permission view = viewer + viewer->special_user
17+
}
18+
19+
20+
relationships: |-
21+
system:somesystem#viewer@group:somegroup#membership
22+
23+
group:somegroup#member@user:someuser1
24+
25+
user:someuser1#special_user_mapping@special_user:specialuser
26+
assertions:
27+
assertTrue:
28+
- system:somesystem#view@special_user:specialuser
29+
assertFalse: []
30+
validation: {}

pkg/query/build_tree.go

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package query
22

33
import (
4+
"errors"
45
"fmt"
56

67
core "github.com/authzed/spicedb/pkg/proto/core/v1"
@@ -89,7 +90,10 @@ func (b *iteratorBuilder) buildIteratorFromSchemaInternal(definitionName string,
8990
} else if r, ok := def.GetRelation(relationName); ok {
9091
result, err = b.buildIteratorFromRelation(r, withSubRelations)
9192
} else {
92-
err = fmt.Errorf("BuildIteratorFromSchema: couldn't find a relation or permission named `%s` in definition `%s`", relationName, definitionName)
93+
err = RelationNotFoundError{
94+
definitionName: definitionName,
95+
relationName: relationName,
96+
}
9397
}
9498

9599
// Remove from building after we're done (allows reuse in other branches)
@@ -284,36 +288,76 @@ func (b *iteratorBuilder) buildBaseRelationIterator(br *schema.BaseRelation, wit
284288
// buildArrowIterators creates a union of arrow iterators for the given relation and right-hand side
285289
func (b *iteratorBuilder) buildArrowIterators(rel *schema.Relation, rightSide string) (Iterator, error) {
286290
union := NewUnion()
291+
hasMultipleBaseRelations := len(rel.BaseRelations()) > 1
292+
var lastNotFoundError error
293+
287294
for _, br := range rel.BaseRelations() {
288295
left, err := b.buildBaseRelationIterator(br, false)
289296
if err != nil {
290297
return nil, err
291298
}
292299
right, err := b.buildIteratorFromSchemaInternal(br.Type(), rightSide, false)
293300
if err != nil {
301+
// If the right side doesn't exist on this type, the arrow produces an empty set.
302+
// This is valid when a relation has multiple types and the arrow only
303+
// applies to some of them. If there's only one base relation, we should error.
304+
if errors.As(err, &RelationNotFoundError{}) {
305+
if hasMultipleBaseRelations {
306+
union.addSubIterator(NewEmptyFixedIterator())
307+
continue
308+
}
309+
lastNotFoundError = err
310+
continue
311+
}
294312
return nil, err
295313
}
296314
arrow := NewArrow(left, right)
297315
union.addSubIterator(arrow)
298316
}
317+
318+
// If we have no sub-iterators and only have a not-found error, return that error
319+
if len(union.Subiterators()) == 0 && lastNotFoundError != nil {
320+
return nil, lastNotFoundError
321+
}
322+
299323
return union, nil
300324
}
301325

302326
// buildIntersectionArrowIterators creates a union of intersection arrow iterators for the given relation and right-hand side
303327
func (b *iteratorBuilder) buildIntersectionArrowIterators(rel *schema.Relation, rightSide string) (Iterator, error) {
304328
union := NewUnion()
329+
hasMultipleBaseRelations := len(rel.BaseRelations()) > 1
330+
var lastNotFoundError error
331+
305332
for _, br := range rel.BaseRelations() {
306333
left, err := b.buildBaseRelationIterator(br, false)
307334
if err != nil {
308335
return nil, err
309336
}
310337
right, err := b.buildIteratorFromSchemaInternal(br.Type(), rightSide, false)
311338
if err != nil {
339+
// If the right side doesn't exist on this type, the intersection arrow produces an empty set.
340+
// This is valid when a relation has multiple types and the arrow only
341+
// applies to some of them. If there's only one base relation, we should error.
342+
if errors.As(err, &RelationNotFoundError{}) {
343+
if hasMultipleBaseRelations {
344+
union.addSubIterator(NewEmptyFixedIterator())
345+
continue
346+
}
347+
lastNotFoundError = err
348+
continue
349+
}
312350
return nil, err
313351
}
314352
intersectionArrow := NewIntersectionArrow(left, right)
315353
union.addSubIterator(intersectionArrow)
316354
}
355+
356+
// If we have no sub-iterators and only have a not-found error, return that error
357+
if len(union.Subiterators()) == 0 && lastNotFoundError != nil {
358+
return nil, lastNotFoundError
359+
}
360+
317361
return union, nil
318362
}
319363

@@ -327,3 +371,13 @@ func functionTypeString(ft schema.FunctionType) string {
327371
return "unknown"
328372
}
329373
}
374+
375+
// RelationNotFoundError is returned when a relation or permission is not found in a definition
376+
type RelationNotFoundError struct {
377+
definitionName string
378+
relationName string
379+
}
380+
381+
func (e RelationNotFoundError) Error() string {
382+
return fmt.Sprintf("BuildIteratorFromSchema: couldn't find a relation or permission named `%s` in definition `%s`", e.relationName, e.definitionName)
383+
}

pkg/schema/reachabilitygraph.go

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func (rg *DefinitionReachability) RelationsEncounteredForResource(
3737
ctx context.Context,
3838
resourceType *core.RelationReference,
3939
) ([]*core.RelationReference, error) {
40-
_, relationNames, err := rg.computeEntrypoints(ctx, resourceType, nil /* include all entrypoints */, reachabilityFull, entrypointLookupFindAll)
40+
_, relationNames, err := rg.computeEntrypoints(ctx, resourceType, nil, reachabilityFull /* include all entrypoints */, entrypointLookupFindAll)
4141
if err != nil {
4242
return nil, err
4343
}
@@ -85,11 +85,12 @@ func (rg *DefinitionReachability) RelationsEncounteredForSubject(
8585
continue
8686
}
8787

88-
encounteredRelations := map[string]struct{}{}
88+
allEncounteredRelations := mapz.NewSet[string]()
89+
encounteredRelationsForComputation := mapz.NewSet[string]()
8990
err := nrg.collectEntrypoints(ctx, &core.RelationReference{
9091
Namespace: nsDef.Name,
9192
Relation: relation.Name,
92-
}, subjectType, collected, encounteredRelations, reachabilityFull, entrypointLookupFindAll)
93+
}, subjectType, collected, allEncounteredRelations, encounteredRelationsForComputation, reachabilityFull, entrypointLookupFindAll)
9394
if err != nil {
9495
return nil, err
9596
}
@@ -193,10 +194,13 @@ func (rg *DefinitionReachability) computeEntrypoints(
193194
}
194195

195196
collected := &[]ReachabilityEntrypoint{}
196-
encounteredRelations := map[string]struct{}{}
197-
err := rg.collectEntrypoints(ctx, resourceType, optionalSubjectType, collected, encounteredRelations, reachabilityOption, entrypointLookupOption)
197+
198+
allEncounteredRelations := mapz.NewSet[string]()
199+
encounteredRelationsForComputation := mapz.NewSet[string]()
200+
201+
err := rg.collectEntrypoints(ctx, resourceType, optionalSubjectType, collected, allEncounteredRelations, encounteredRelationsForComputation, reachabilityOption, entrypointLookupOption)
198202
if err != nil {
199-
return nil, slices.Collect(maps.Keys(encounteredRelations)), err
203+
return nil, allEncounteredRelations.AsSlice(), err
200204
}
201205

202206
collectedEntrypoints := *collected
@@ -213,7 +217,7 @@ func (rg *DefinitionReachability) computeEntrypoints(
213217
for _, entrypoint := range collectedEntrypoints {
214218
hash, err := entrypoint.Hash()
215219
if err != nil {
216-
return nil, slices.Collect(maps.Keys(encounteredRelations)), err
220+
return nil, allEncounteredRelations.AsSlice(), err
217221
}
218222

219223
if _, ok := entrypointMap[hash]; !ok {
@@ -222,7 +226,7 @@ func (rg *DefinitionReachability) computeEntrypoints(
222226
}
223227
}
224228

225-
return uniqueEntrypoints, slices.Collect(maps.Keys(encounteredRelations)), nil
229+
return uniqueEntrypoints, allEncounteredRelations.AsSlice(), nil
226230
}
227231

228232
func (rg *DefinitionReachability) getOrBuildGraph(ctx context.Context, resourceType *core.RelationReference, reachabilityOption reachabilityOption) (*core.ReachabilityGraph, error) {
@@ -253,17 +257,17 @@ func (rg *DefinitionReachability) collectEntrypoints(
253257
resourceType *core.RelationReference,
254258
optionalSubjectType *core.RelationReference,
255259
collected *[]ReachabilityEntrypoint,
256-
encounteredRelations map[string]struct{},
260+
allEncounteredRelations *mapz.Set[string],
261+
encounteredRelationsForComputation *mapz.Set[string],
257262
reachabilityOption reachabilityOption,
258263
entrypointLookupOption entrypointLookupOption,
259264
) error {
260265
// Ensure that we only process each relation once.
261266
key := tuple.JoinRelRef(resourceType.Namespace, resourceType.Relation)
262-
if _, ok := encounteredRelations[key]; ok {
267+
if !encounteredRelationsForComputation.Add(key) {
263268
return nil
264269
}
265-
266-
encounteredRelations[key] = struct{}{}
270+
allEncounteredRelations.Add(key)
267271

268272
rrg, err := rg.getOrBuildGraph(ctx, resourceType, reachabilityOption)
269273
if err != nil {
@@ -274,7 +278,7 @@ func (rg *DefinitionReachability) collectEntrypoints(
274278
// Add subject type entrypoints.
275279
subjectTypeEntrypoints, ok := rrg.EntrypointsBySubjectType[optionalSubjectType.Namespace]
276280
if ok {
277-
addEntrypoints(subjectTypeEntrypoints, resourceType, collected, encounteredRelations)
281+
addEntrypoints(subjectTypeEntrypoints, resourceType, collected, allEncounteredRelations, encounteredRelationsForComputation)
278282
}
279283

280284
if entrypointLookupOption == entrypointLookupFindOne && len(*collected) > 0 {
@@ -284,7 +288,7 @@ func (rg *DefinitionReachability) collectEntrypoints(
284288
// Add subject relation entrypoints.
285289
subjectRelationEntrypoints, ok := rrg.EntrypointsBySubjectRelation[tuple.JoinRelRef(optionalSubjectType.Namespace, optionalSubjectType.Relation)]
286290
if ok {
287-
addEntrypoints(subjectRelationEntrypoints, resourceType, collected, encounteredRelations)
291+
addEntrypoints(subjectRelationEntrypoints, resourceType, collected, allEncounteredRelations, encounteredRelationsForComputation)
288292
}
289293

290294
if entrypointLookupOption == entrypointLookupFindOne && len(*collected) > 0 {
@@ -293,11 +297,11 @@ func (rg *DefinitionReachability) collectEntrypoints(
293297
} else {
294298
// Add all entrypoints.
295299
for _, entrypoints := range rrg.EntrypointsBySubjectType {
296-
addEntrypoints(entrypoints, resourceType, collected, encounteredRelations)
300+
addEntrypoints(entrypoints, resourceType, collected, allEncounteredRelations, encounteredRelationsForComputation)
297301
}
298302

299303
for _, entrypoints := range rrg.EntrypointsBySubjectRelation {
300-
addEntrypoints(entrypoints, resourceType, collected, encounteredRelations)
304+
addEntrypoints(entrypoints, resourceType, collected, allEncounteredRelations, encounteredRelationsForComputation)
301305
}
302306
}
303307

@@ -309,7 +313,7 @@ func (rg *DefinitionReachability) collectEntrypoints(
309313
for _, entrypointSetKey := range keys {
310314
entrypointSet := rrg.EntrypointsBySubjectRelation[entrypointSetKey]
311315
if entrypointSet.SubjectRelation != nil && entrypointSet.SubjectRelation.Relation != tuple.Ellipsis {
312-
err := rg.collectEntrypoints(ctx, entrypointSet.SubjectRelation, optionalSubjectType, collected, encounteredRelations, reachabilityOption, entrypointLookupOption)
316+
err := rg.collectEntrypoints(ctx, entrypointSet.SubjectRelation, optionalSubjectType, collected, allEncounteredRelations, encounteredRelationsForComputation, reachabilityOption, entrypointLookupOption)
313317
if err != nil {
314318
return err
315319
}
@@ -323,13 +327,12 @@ func (rg *DefinitionReachability) collectEntrypoints(
323327
return nil
324328
}
325329

326-
func addEntrypoints(entrypoints *core.ReachabilityEntrypoints, parentRelation *core.RelationReference, collected *[]ReachabilityEntrypoint, encounteredRelations map[string]struct{}) {
330+
func addEntrypoints(entrypoints *core.ReachabilityEntrypoints, parentRelation *core.RelationReference, collected *[]ReachabilityEntrypoint, allEncounteredRelations *mapz.Set[string], encounteredRelationsForComputation *mapz.Set[string]) {
327331
for _, entrypoint := range entrypoints.Entrypoints {
328332
if entrypoint.TuplesetRelation != "" {
329333
key := tuple.JoinRelRef(entrypoint.TargetRelation.Namespace, entrypoint.TuplesetRelation)
330-
encounteredRelations[key] = struct{}{}
334+
allEncounteredRelations.Add(key)
331335
}
332-
333336
*collected = append(*collected, ReachabilityEntrypoint{entrypoint, parentRelation})
334337
}
335338
}

pkg/schema/reachabilitygraph_test.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1180,6 +1180,62 @@ func TestReachabilityGraph(t *testing.T) {
11801180
rrt("document", "view", false),
11811181
},
11821182
},
1183+
{
1184+
"multiple arrows with same starting point",
1185+
`definition special_user {}
1186+
1187+
definition user {
1188+
relation special_user_mapping: special_user
1189+
permission special_user = special_user_mapping
1190+
}
1191+
1192+
definition group {
1193+
relation member: user
1194+
permission membership = member + member->special_user
1195+
}
1196+
1197+
definition system {
1198+
relation viewer: user | group#membership
1199+
permission view = viewer + viewer->special_user
1200+
}`,
1201+
rr("system", "view"),
1202+
rr("special_user", "..."),
1203+
[]rrtStruct{
1204+
rrt("user", "special_user_mapping", true),
1205+
},
1206+
[]rrtStruct{
1207+
rrt("user", "special_user_mapping", true),
1208+
},
1209+
},
1210+
{
1211+
"multiple arrows with same non-terminal starting point",
1212+
`definition special_user {}
1213+
1214+
definition user {
1215+
relation special_user_mapping: special_user
1216+
permission special_user = special_user_mapping
1217+
}
1218+
1219+
definition group {
1220+
relation member: user
1221+
permission membership = member + member->special_user
1222+
}
1223+
1224+
definition system {
1225+
relation viewer: user | group#membership
1226+
permission view = viewer + viewer->special_user
1227+
}`,
1228+
rr("system", "view"),
1229+
rr("user", "special_user"),
1230+
[]rrtStruct{
1231+
rrt("system", "view", true),
1232+
rrt("group", "membership", true),
1233+
},
1234+
[]rrtStruct{
1235+
rrt("system", "view", true),
1236+
rrt("group", "membership", true),
1237+
},
1238+
},
11831239
}
11841240

11851241
for _, tc := range testCases {

0 commit comments

Comments
 (0)