diff --git a/pkg/engine/assert/expression.go b/pkg/engine/assert/expression.go index fb517256..898fb5b5 100644 --- a/pkg/engine/assert/expression.go +++ b/pkg/engine/assert/expression.go @@ -4,9 +4,7 @@ import ( "context" "reflect" "regexp" - "sync" - "github.com/jmespath-community/go-jmespath/pkg/parsing" reflectutils "github.com/kyverno/kyverno-json/pkg/utils/reflect" ) @@ -23,7 +21,6 @@ type expression struct { statement string binding string engine string - ast func() (parsing.ASTNode, error) } func parseExpressionRegex(_ context.Context, in string) *expression { @@ -57,10 +54,6 @@ func parseExpressionRegex(_ context.Context, in string) *expression { if expression.statement == "" { return nil } - expression.ast = sync.OnceValues(func() (parsing.ASTNode, error) { - parser := parsing.NewParser() - return parser.Parse(expression.statement) - }) return expression } diff --git a/pkg/engine/assert/parse.go b/pkg/engine/assert/parse.go index a34aae51..5a690d36 100644 --- a/pkg/engine/assert/parse.go +++ b/pkg/engine/assert/parse.go @@ -7,6 +7,7 @@ import ( "github.com/jmespath-community/go-jmespath/pkg/binding" jpbinding "github.com/jmespath-community/go-jmespath/pkg/binding" + "github.com/jmespath-community/go-jmespath/pkg/parsing" "github.com/kyverno/kyverno-json/pkg/engine/match" "github.com/kyverno/kyverno-json/pkg/engine/template" reflectutils "github.com/kyverno/kyverno-json/pkg/utils/reflect" @@ -16,16 +17,7 @@ import ( func Parse(ctx context.Context, path *field.Path, assertion any) (Assertion, error) { switch reflectutils.GetKind(assertion) { case reflect.Slice: - node := sliceNode{} - valueOf := reflect.ValueOf(assertion) - for i := 0; i < valueOf.Len(); i++ { - sub, err := Parse(ctx, path.Index(i), valueOf.Index(i).Interface()) - if err != nil { - return nil, err - } - node = append(node, sub) - } - return node, nil + return parseSlice(ctx, path, assertion) case reflect.Map: node := mapNode{} iter := reflect.ValueOf(assertion).MapRange() @@ -39,8 +31,96 @@ func Parse(ctx context.Context, path *field.Path, assertion any) (Assertion, err } return node, nil default: - return newScalarNode(ctx, path, assertion) + return parseScalar(ctx, path, assertion) + } +} + +// node implements the Assertion interface using a delegate func +type node func(ctx context.Context, path *field.Path, value any, bindings binding.Bindings, opts ...template.Option) (field.ErrorList, error) + +// TODO: do we need the path in the signature ? +func (n node) assert(ctx context.Context, path *field.Path, value any, bindings binding.Bindings, opts ...template.Option) (field.ErrorList, error) { + return n(ctx, path, value, bindings, opts...) +} + +// parseSlice is the assertion represented by a slice. +// it first compares the length of the analysed resource with the length of the descendants. +// if lengths match all descendants are evaluated with their corresponding items. +func parseSlice(ctx context.Context, path *field.Path, assertion any) (node, error) { + var assertions []Assertion + valueOf := reflect.ValueOf(assertion) + for i := 0; i < valueOf.Len(); i++ { + sub, err := Parse(ctx, path.Index(i), valueOf.Index(i).Interface()) + if err != nil { + return nil, err + } + assertions = append(assertions, sub) } + return func(ctx context.Context, path *field.Path, value any, bindings binding.Bindings, opts ...template.Option) (field.ErrorList, error) { + var errs field.ErrorList + if value == nil { + errs = append(errs, field.Invalid(path, value, "value is null")) + } else if reflectutils.GetKind(value) != reflect.Slice { + return nil, field.TypeInvalid(path, value, "expected a slice") + } else { + valueOf := reflect.ValueOf(value) + if valueOf.Len() != len(assertions) { + errs = append(errs, field.Invalid(path, value, "lengths of slices don't match")) + } else { + for i := range assertions { + if _errs, err := assertions[i].assert(ctx, path.Index(i), valueOf.Index(i).Interface(), bindings, opts...); err != nil { + return nil, err + } else { + errs = append(errs, _errs...) + } + } + } + } + return errs, nil + }, nil +} + +// parseScalar is the assertion represented by a leaf. +// it receives a value and compares it with an expected value. +// the expected value can be the result of an expression. +func parseScalar(ctx context.Context, path *field.Path, assertion any) (node, error) { + expression := parseExpression(ctx, assertion) + // we only project if the expression uses the engine syntax + // this is to avoid the case where the value is a map and the RHS is a string + var project func(ctx context.Context, value any, bindings binding.Bindings, opts ...template.Option) (any, error) + if expression != nil && expression.engine != "" { + if expression.foreachName != "" { + return nil, field.Invalid(path, assertion, "foreach is not supported on the RHS") + } + if expression.binding != "" { + return nil, field.Invalid(path, assertion, "binding is not supported on the RHS") + } + parser := parsing.NewParser() + ast, err := parser.Parse(expression.statement) + if err != nil { + return nil, field.InternalError(path, err) + } + project = func(ctx context.Context, value any, bindings jpbinding.Bindings, opts ...template.Option) (any, error) { + return template.ExecuteAST(ctx, ast, value, bindings, opts...) + } + } + return func(ctx context.Context, path *field.Path, value any, bindings binding.Bindings, opts ...template.Option) (field.ErrorList, error) { + expected := assertion + if project != nil { + projected, err := project(ctx, value, bindings, opts...) + if err != nil { + return nil, field.InternalError(path, err) + } + expected = projected + } + var errs field.ErrorList + if match, err := match.Match(ctx, expected, value); err != nil { + return nil, field.InternalError(path, err) + } else if !match { + errs = append(errs, field.Invalid(path, value, expectValueMessage(expected))) + } + return errs, nil + }, nil } // mapNode is the assertion type represented by a map. @@ -110,75 +190,6 @@ func (n mapNode) assert(ctx context.Context, path *field.Path, value any, bindin return errs, nil } -// sliceNode is the assertion type represented by a slice. -// it first compares the length of the analysed resource with the length of the descendants. -// if lengths match all descendants are evaluated with their corresponding items. -type sliceNode []Assertion - -func (n sliceNode) assert(ctx context.Context, path *field.Path, value any, bindings binding.Bindings, opts ...template.Option) (field.ErrorList, error) { - var errs field.ErrorList - if value == nil { - errs = append(errs, field.Invalid(path, value, "value is null")) - } else if reflectutils.GetKind(value) != reflect.Slice { - return nil, field.TypeInvalid(path, value, "expected a slice") - } else { - valueOf := reflect.ValueOf(value) - if valueOf.Len() != len(n) { - errs = append(errs, field.Invalid(path, value, "lengths of slices don't match")) - } else { - for i := range n { - if _errs, err := n[i].assert(ctx, path.Index(i), valueOf.Index(i).Interface(), bindings, opts...); err != nil { - return nil, err - } else { - errs = append(errs, _errs...) - } - } - } - } - return errs, nil -} - -// scalarNode is a terminal type of assertion. -// it receives a value and compares it with an expected value. -// the expected value can be the result of an expression. -type scalarNode func(value any, bindings binding.Bindings, opts ...template.Option) (any, error) - -func newScalarNode(ctx context.Context, path *field.Path, rhs any) (scalarNode, error) { - expression := parseExpression(ctx, rhs) - // we only project if the expression uses the engine syntax - // this is to avoid the case where the value is a map and the RHS is a string - if expression != nil && expression.engine != "" { - if expression.foreachName != "" { - return nil, field.Invalid(path, rhs, "foreach is not supported on the RHS") - } - if expression.binding != "" { - return nil, field.Invalid(path, rhs, "binding is not supported on the RHS") - } - ast, err := expression.ast() - if err != nil { - return nil, field.InternalError(path, err) - } - return func(value any, bindings binding.Bindings, opts ...template.Option) (any, error) { - return template.ExecuteAST(ctx, ast, value, bindings, opts...) - }, nil - } - return func(value any, bindings binding.Bindings, opts ...template.Option) (any, error) { - return rhs, nil - }, nil -} - -func (n scalarNode) assert(ctx context.Context, path *field.Path, value any, bindings binding.Bindings, opts ...template.Option) (field.ErrorList, error) { - var errs field.ErrorList - if rhs, err := n(value, bindings, opts...); err != nil { - return nil, field.InternalError(path, err) - } else if match, err := match.Match(ctx, rhs, value); err != nil { - return nil, field.InternalError(path, err) - } else if !match { - errs = append(errs, field.Invalid(path, value, expectValueMessage(rhs))) - } - return errs, nil -} - func expectValueMessage(value any) string { switch t := value.(type) { case int64, int32, float64, float32, bool: diff --git a/pkg/engine/assert/project.go b/pkg/engine/assert/project.go index cd057e22..4a22fe54 100644 --- a/pkg/engine/assert/project.go +++ b/pkg/engine/assert/project.go @@ -6,6 +6,7 @@ import ( "reflect" "github.com/jmespath-community/go-jmespath/pkg/binding" + "github.com/jmespath-community/go-jmespath/pkg/parsing" "github.com/kyverno/kyverno-json/pkg/engine/template" reflectutils "github.com/kyverno/kyverno-json/pkg/utils/reflect" ) @@ -21,7 +22,8 @@ func project(ctx context.Context, key any, value any, bindings binding.Bindings, expression := parseExpression(ctx, key) if expression != nil { if expression.engine != "" { - ast, err := expression.ast() + parser := parsing.NewParser() + ast, err := parser.Parse(expression.statement) if err != nil { return nil, err }