Skip to content

Commit 42b3bf3

Browse files
authored
feat(cloudformation): support default values and list results in Fn::FindInMap (#9515)
Signed-off-by: nikpivkin <[email protected]>
1 parent 8e40d27 commit 42b3bf3

File tree

2 files changed

+93
-87
lines changed

2 files changed

+93
-87
lines changed
Lines changed: 52 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,80 @@
11
package parser
22

33
import (
4+
"fmt"
5+
46
"github.com/aquasecurity/trivy/pkg/iac/scanners/cloudformation/cftypes"
57
)
68

7-
func ResolveFindInMap(property *Property) (resolved *Property, success bool) {
9+
func ResolveFindInMap(property *Property) (*Property, bool) {
810
if !property.isFunction() {
911
return property, true
1012
}
1113

1214
refValue := property.AsMap()["Fn::FindInMap"].AsList()
1315

14-
if len(refValue) != 3 {
15-
return abortIntrinsic(property, "Fn::FindInMap should have exactly 3 values, returning original Property")
16+
if len(refValue) < 3 || len(refValue) > 4 {
17+
return abortIntrinsic(property, "Fn::FindInMap expects 3 or 4 arguments")
1618
}
1719

18-
mapName := refValue[0].AsString()
19-
topLevelKey := refValue[1].AsString()
20-
secondaryLevelKey := refValue[2].AsString()
21-
2220
if property.ctx == nil {
23-
return abortIntrinsic(property, "the property does not have an attached context, returning original Property")
21+
return abortIntrinsic(property, "property context is missing")
2422
}
2523

26-
m, ok := property.ctx.Mappings[mapName]
27-
if !ok {
28-
return abortIntrinsic(property, "could not find map %s, returning original Property")
24+
var defaultValue any
25+
if len(refValue) == 4 {
26+
if m := refValue[3].AsMap(); m != nil {
27+
if defProp, exists := m["DefaultValue"]; exists && defProp != nil {
28+
defaultValue = defProp.RawValue()
29+
}
30+
}
31+
}
32+
33+
mapName := refValue[0].AsString()
34+
topKey := refValue[1].AsString()
35+
secKey := refValue[2].AsString()
36+
37+
value, err := resolveMapping(property.ctx, mapName, topKey, secKey)
38+
if err != nil {
39+
if defaultValue == nil {
40+
return abortIntrinsic(property, err.Error())
41+
}
42+
value = defaultValue
2943
}
3044

31-
mapContents := m.(map[string]any)
45+
switch v := value.(type) {
46+
case string:
47+
return property.deriveResolved(cftypes.String, v), true
48+
case []any:
49+
elems := make([]*Property, len(v))
50+
for i, el := range v {
51+
elems[i] = property.deriveResolved(cftypes.String, el)
52+
}
53+
return property.deriveResolved(cftypes.List, elems), true
54+
default:
55+
return abortIntrinsic(property, fmt.Sprintf("unsupported type in mapping: %T", v))
56+
}
57+
}
3258

33-
k, ok := mapContents[topLevelKey]
59+
func resolveMapping(ctx *FileContext, mapName, topKey, secKey string) (any, error) {
60+
m, ok := ctx.Mappings[mapName]
3461
if !ok {
35-
return abortIntrinsic(property, "could not find %s in the %s map, returning original Property", topLevelKey, mapName)
62+
return nil, fmt.Errorf("map %s not found", mapName)
63+
}
64+
mapContents, ok := m.(map[string]any)
65+
if !ok {
66+
return nil, fmt.Errorf("map %s has invalid type", mapName)
3667
}
3768

69+
k, ok := mapContents[topKey]
70+
if !ok {
71+
return nil, fmt.Errorf("key %s not found in map %s", topKey, mapName)
72+
}
3873
mapValues := k.(map[string]any)
3974

40-
prop, ok := mapValues[secondaryLevelKey]
75+
prop, ok := mapValues[secKey]
4176
if !ok {
42-
return abortIntrinsic(property, "could not find a value for %s in %s, returning original Property", secondaryLevelKey, topLevelKey)
77+
return nil, fmt.Errorf("key %s not found in %s", secKey, topKey)
4378
}
44-
return property.deriveResolved(cftypes.String, prop), true
79+
return prop, nil
4580
}

pkg/iac/scanners/cloudformation/parser/fn_find_in_map_test.go

Lines changed: 41 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -3,100 +3,71 @@ package parser
33
import (
44
"testing"
55

6+
"github.com/samber/lo"
67
"github.com/stretchr/testify/assert"
78
"github.com/stretchr/testify/require"
89
)
910

10-
func Test_resolve_find_in_map_value(t *testing.T) {
11-
11+
func Test_FindInMap(t *testing.T) {
1212
source := `---
1313
Parameters:
1414
Environment:
1515
Type: String
16-
Default: production
16+
Default: dev
1717
Mappings:
1818
CacheNodeTypes:
1919
production:
2020
NodeType: cache.t2.large
21+
CacheSecurityGroupNames: [ "sg-1", "sg-2" ]
2122
test:
2223
NodeType: cache.t2.small
24+
CacheSecurityGroupNames: [ "sg-3" ]
2325
dev:
2426
NodeType: cache.t2.micro
27+
CacheSecurityGroupNames: [ "sg-4" ]
2528
Resources:
26-
ElasticacheSecurityGroup:
27-
Type: 'AWS::EC2::SecurityGroup'
28-
Properties:
29-
GroupDescription: Elasticache Security Group
30-
SecurityGroupIngress:
31-
- IpProtocol: tcp
32-
FromPort: 11211
33-
ToPort: 11211
34-
SourceSecurityGroupName: !Ref InstanceSecurityGroup
35-
ElasticacheCluster:
36-
Type: 'AWS::ElastiCache::CacheCluster'
37-
Properties:
38-
Engine: memcached
39-
CacheNodeType: !FindInMap [ CacheNodeTypes, production, NodeType ]
40-
NumCacheNodes: '1'
41-
VpcSecurityGroupIds:
42-
- !GetAtt
43-
- ElasticacheSecurityGroup
44-
- GroupId
45-
`
46-
ctx := createTestFileContext(t, source)
47-
require.NotNil(t, ctx)
48-
49-
testRes := ctx.GetResourceByLogicalID("ElasticacheCluster")
50-
assert.NotNil(t, testRes)
51-
52-
nodeTypeProp := testRes.GetStringProperty("CacheNodeType", "")
53-
assert.Equal(t, "cache.t2.large", nodeTypeProp.Value())
54-
}
29+
ElasticacheCluster:
30+
Type: 'AWS::ElastiCache::CacheCluster'
31+
Properties:
32+
Engine: memcached
33+
CacheNodeType: !FindInMap [ CacheNodeTypes, !Ref Environment, NodeType ]
34+
NumCacheNodes: '1'
5535
56-
func Test_resolve_find_in_map_with_nested_intrinsic_value(t *testing.T) {
36+
ElasticacheClusterWithDefault:
37+
Type: 'AWS::ElastiCache::CacheCluster'
38+
Properties:
39+
Engine: memcached
40+
CacheNodeType: !FindInMap [ CacheNodeTypes, staging, NodeType, DefaultValue: cache.t2.medium ]
41+
NumCacheNodes: '1'
5742
58-
source := `---
59-
Parameters:
60-
Environment:
61-
Type: String
62-
Default: dev
63-
Mappings:
64-
CacheNodeTypes:
65-
production:
66-
NodeType: cache.t2.large
67-
test:
68-
NodeType: cache.t2.small
69-
dev:
70-
NodeType: cache.t2.micro
71-
Resources:
72-
ElasticacheSecurityGroup:
73-
Type: 'AWS::EC2::SecurityGroup'
74-
Properties:
75-
GroupDescription: Elasticache Security Group
76-
SecurityGroupIngress:
77-
- IpProtocol: tcp
78-
FromPort: 11211
79-
ToPort: 11211
80-
SourceSecurityGroupName: !Ref InstanceSecurityGroup
81-
ElasticacheCluster:
82-
Type: 'AWS::ElastiCache::CacheCluster'
83-
Properties:
84-
Engine: memcached
85-
CacheNodeType: !FindInMap [ CacheNodeTypes, !Ref Environment, NodeType ]
86-
NumCacheNodes: '1'
87-
VpcSecurityGroupIds:
88-
- !GetAtt
89-
- ElasticacheSecurityGroup
90-
- GroupId
43+
ElasticacheClusterList:
44+
Type: 'AWS::ElastiCache::CacheCluster'
45+
Properties:
46+
Engine: memcached
47+
CacheSecurityGroupNames: !FindInMap [ CacheNodeTypes, production, CacheSecurityGroupNames ]
48+
NumCacheNodes: '1'
9149
`
50+
9251
ctx := createTestFileContext(t, source)
9352
require.NotNil(t, ctx)
9453

95-
testRes := ctx.GetResourceByLogicalID("ElasticacheCluster")
96-
assert.NotNil(t, testRes)
97-
98-
nodeTypeProp := testRes.GetStringProperty("CacheNodeType", "")
54+
cluster := ctx.GetResourceByLogicalID("ElasticacheCluster")
55+
require.NotNil(t, cluster)
56+
nodeTypeProp := cluster.GetStringProperty("CacheNodeType", "")
9957
assert.Equal(t, "cache.t2.micro", nodeTypeProp.Value())
58+
59+
clusterDefault := ctx.GetResourceByLogicalID("ElasticacheClusterWithDefault")
60+
require.NotNil(t, clusterDefault)
61+
nodeTypePropDefault := clusterDefault.GetStringProperty("CacheNodeType", "")
62+
assert.Equal(t, "cache.t2.medium", nodeTypePropDefault.Value())
63+
64+
clusterList := ctx.GetResourceByLogicalID("ElasticacheClusterList")
65+
require.NotNil(t, clusterList)
66+
sgNamesProp := clusterList.GetProperty("CacheSecurityGroupNames").AsList()
67+
groupNames := lo.Map(sgNamesProp, func(prop *Property, _ int) any {
68+
return prop.AsString()
69+
})
70+
assert.ElementsMatch(t, []any{"sg-1", "sg-2"}, groupNames)
10071
}
10172

10273
func Test_InferType(t *testing.T) {

0 commit comments

Comments
 (0)