diff --git a/go.mod b/go.mod index 393c95c..5cb024a 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/go-kit/kit v0.8.0 github.com/go-logfmt/logfmt v0.3.0 // indirect github.com/go-stack/stack v1.8.0 // indirect + github.com/google/uuid v1.3.0 github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515 // indirect github.com/stretchr/testify v1.7.0 github.com/ugorji/go/codec v1.2.6 diff --git a/go.sum b/go.sum index 8370579..6684580 100644 --- a/go.sum +++ b/go.sum @@ -7,6 +7,8 @@ github.com/go-logfmt/logfmt v0.3.0 h1:8HUsc87TaSWLKwrnumgC8/YconD2fJQsRJAsWaPg2i github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515 h1:T+h1c/A9Gawja4Y9mFVWj2vyii2bbUNDw3kt9VxK2EY= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/spec_validator.go b/spec_validator.go new file mode 100644 index 0000000..c9bae3f --- /dev/null +++ b/spec_validator.go @@ -0,0 +1,160 @@ +/** + * Copyright (c) 2022 Comcast Cable Communications Management, LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package wrp + +import ( + "errors" + "fmt" + "regexp" + "strconv" + "strings" + "unicode" + + "github.com/google/uuid" +) + +const ( + serialPrefix = "serial" + uuidPrefix = "uuid" + eventPrefix = "event" + dnsPrefix = "dns" +) + +var ( + ErrorInvalidMessageEncoding = errors.New("invalid message encoding") + ErrorInvalidMessageType = errors.New("invalid message type") + ErrorInvalidSource = errors.New("invalid Source name") + ErrorInvalidDestination = errors.New("invalid Destination name") + errInvalidUUID = errors.New("invalid UUID") + errEmptyAuthority = errors.New("invalid empty authority (ID)") + errInvalidMacLength = errors.New("invalid mac length") + errInvalidCharacter = errors.New("invalid character") + errInvalidLocatorPattern = errors.New("value given doesn't match expected locator pattern") +) + +// locatorPattern is the precompiled regular expression that all source and dest locators must match. +// Matching is partial, as everything after the authority (ID) is ignored. https://xmidt.io/docs/wrp/basics/#locators +var locatorPattern = regexp.MustCompile( + `^(?P(?i)` + macPrefix + `|` + uuidPrefix + `|` + eventPrefix + `|` + dnsPrefix + `|` + serialPrefix + `):(?P[^/]+)?`, +) + +// SpecValidators returns a WRP validator that ensures messages are valid based on +// each spec validator in the list. Only validates the opinionated portions of the spec. +// SpecValidators validates the following fields: UTF8 (all string fields), MessageType, Source, Destination +func SpecValidators() Validators { + return Validators{UTF8Validator(), MessageTypeValidator(), SourceValidator(), DestinationValidator()} +} + +// UTF8Validator returns a WRP validator that takes messages and validates that it contains UTF-8 strings. +func UTF8Validator() ValidatorFunc { + return func(m Message) error { + if err := UTF8(m); err != nil { + return fmt.Errorf("%w: %v", ErrorInvalidMessageEncoding, err) + } + + return nil + } +} + +// MessageTypeValidator returns a WRP validator that takes messages and validates their Type. +func MessageTypeValidator() ValidatorFunc { + return func(m Message) error { + if m.Type < Invalid0MessageType || m.Type > lastMessageType { + return ErrorInvalidMessageType + } + + switch m.Type { + case Invalid0MessageType, Invalid1MessageType, lastMessageType: + return ErrorInvalidMessageType + } + + return nil + } +} + +// SourceValidator returns a WRP validator that takes messages and validates their Source. +// Only mac and uuid sources are validated. Serial, event and dns sources are +// not validated. +func SourceValidator() ValidatorFunc { + return func(m Message) error { + if err := validateLocator(m.Source); err != nil { + return fmt.Errorf("%w '%s': %v", ErrorInvalidSource, m.Source, err) + } + + return nil + } +} + +// DestinationValidator returns a WRP validator that takes messages and validates their Destination. +// Only mac and uuid destinations are validated. Serial, event and dns destinations are +// not validated. +func DestinationValidator() ValidatorFunc { + return func(m Message) error { + if err := validateLocator(m.Destination); err != nil { + return fmt.Errorf("%w '%s': %v", ErrorInvalidDestination, m.Destination, err) + } + + return nil + } +} + +// validateLocator validates a given locator's scheme and authority (ID). +// Only mac and uuid schemes' IDs are validated. IDs from serial, event and dns schemes are +// not validated. +func validateLocator(l string) error { + match := locatorPattern.FindStringSubmatch(l) + if match == nil { + return errInvalidLocatorPattern + } + + idPart := match[2] + if len(idPart) == 0 { + return errEmptyAuthority + } + + switch strings.ToLower(match[1]) { + case macPrefix: + var invalidCharacter rune = -1 + idPart = strings.Map( + func(r rune) rune { + switch { + case strings.ContainsRune(hexDigits, r): + return unicode.ToLower(r) + case strings.ContainsRune(macDelimiters, r): + return -1 + default: + invalidCharacter = r + return -1 + } + }, + idPart, + ) + + if invalidCharacter != -1 { + return fmt.Errorf("%w: %v", errInvalidCharacter, strconv.QuoteRune(invalidCharacter)) + } else if len(idPart) != macLength { + return errInvalidMacLength + } + case uuidPrefix: + if _, err := uuid.Parse(idPart); err != nil { + return fmt.Errorf("%w: %v", errInvalidUUID, err) + } + } + + return nil +} diff --git a/spec_validator_test.go b/spec_validator_test.go new file mode 100644 index 0000000..2d64cbc --- /dev/null +++ b/spec_validator_test.go @@ -0,0 +1,616 @@ +/** + * Copyright (c) 2022 Comcast Cable Communications Management, LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package wrp + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSpecHelperValidators(t *testing.T) { + tests := []struct { + description string + test func(*testing.T) + }{ + {"UTF8Validator", testUTF8Validator}, + {"MessageTypeValidator", testMessageTypeValidator}, + {"SourceValidator", testSourceValidator}, + {"DestinationValidator", testDestinationValidator}, + {"validateLocator", testValidateLocator}, + } + + for _, tc := range tests { + t.Run(tc.description, tc.test) + } +} + +func TestSpecValidators(t *testing.T) { + var ( + expectedStatus int64 = 3471 + expectedRequestDeliveryResponse int64 = 34 + expectedIncludeSpans bool = true + ) + + tests := []struct { + description string + msg Message + expectedErr []error + }{ + // Success case + { + description: "Valid spec success", + msg: Message{ + Type: SimpleRequestResponseMessageType, + Source: "dns:external.com", + Destination: "MAC:11:22:33:44:55:66", + TransactionUUID: "DEADBEEF", + ContentType: "ContentType", + Accept: "Accept", + Status: &expectedStatus, + RequestDeliveryResponse: &expectedRequestDeliveryResponse, + Headers: []string{"Header1", "Header2"}, + Metadata: map[string]string{"name": "value"}, + Spans: [][]string{{"1", "2"}, {"3"}}, + IncludeSpans: &expectedIncludeSpans, + Path: "/some/where/over/the/rainbow", + Payload: []byte{1, 2, 3, 4, 0xff, 0xce}, + ServiceName: "ServiceName", + URL: "someURL.com", + PartnerIDs: []string{"foo"}, + SessionID: "sessionID123", + }, + }, + // Failure case + { + description: "Invaild spec error", + msg: Message{ + Type: Invalid0MessageType, + // Missing scheme + Source: "external.com", + // Invalid Mac + Destination: "MAC:+++BB-44-55", + TransactionUUID: "DEADBEEF", + ContentType: "ContentType", + Accept: "Accept", + Status: &expectedStatus, + RequestDeliveryResponse: &expectedRequestDeliveryResponse, + Headers: []string{"Header1", "Header2"}, + Metadata: map[string]string{"name": "value"}, + Spans: [][]string{{"1", "2"}, {"3"}}, + IncludeSpans: &expectedIncludeSpans, + Path: "/some/where/over/the/rainbow", + Payload: []byte{1, 2, 3, 4, 0xff, 0xce}, + ServiceName: "ServiceName", + // Not UFT8 URL string + URL: "someURL\xed\xbf\xbf.com", + PartnerIDs: []string{"foo"}, + SessionID: "sessionID123", + }, + expectedErr: []error{ErrorInvalidMessageType, ErrorInvalidSource, ErrorInvalidDestination, ErrorInvalidMessageEncoding}, + }, + { + description: "Invaild spec error, empty message", + msg: Message{}, + expectedErr: []error{ErrorInvalidMessageType, ErrorInvalidSource, ErrorInvalidDestination}, + }, + { + description: "Invaild spec error, nonexistent MessageType", + msg: Message{ + Type: lastMessageType + 1, + Source: "dns:external.com", + Destination: "MAC:11:22:33:44:55:66", + }, + expectedErr: []error{ErrorInvalidMessageType}, + }, + } + + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + assert := assert.New(t) + err := SpecValidators().Validate(tc.msg) + if tc.expectedErr != nil { + for _, e := range tc.expectedErr { + assert.ErrorIs(err, e) + } + return + } + + assert.NoError(err) + }) + } +} + +func ExampleTypeValidator_Validate_specValidators() { + msgv, err := NewTypeValidator( + // Validates found msg types + map[MessageType]Validator{ + // Validates opinionated portions of the spec + SimpleEventMessageType: SpecValidators(), + // Only validates Source and nothing else + SimpleRequestResponseMessageType: SourceValidator(), + }, + // Validates unfound msg types + AlwaysInvalid()) + if err != nil { + return + } + + var ( + expectedStatus int64 = 3471 + expectedRequestDeliveryResponse int64 = 34 + expectedIncludeSpans bool = true + ) + foundErrFailure := msgv.Validate(Message{ + Type: SimpleEventMessageType, + // Missing scheme + Source: "external.com", + // Invalid Mac + Destination: "MAC:+++BB-44-55", + TransactionUUID: "DEADBEEF", + ContentType: "ContentType", + Accept: "Accept", + Status: &expectedStatus, + RequestDeliveryResponse: &expectedRequestDeliveryResponse, + Headers: []string{"Header1", "Header2"}, + Metadata: map[string]string{"name": "value"}, + Spans: [][]string{{"1", "2"}, {"3"}}, + IncludeSpans: &expectedIncludeSpans, + Path: "/some/where/over/the/rainbow", + // Not UFT8 Payload + Payload: []byte{1, 2, 3, 4, 0xff /* \xed\xbf\xbf is invalid */, 0xce}, + ServiceName: "ServiceName", + // Not UFT8 URL string + URL: "someURL\xed\xbf\xbf.com", + PartnerIDs: []string{"foo"}, + SessionID: "sessionID123", + }) // Found error + foundErrSuccess1 := msgv.Validate(Message{ + Type: SimpleEventMessageType, + Source: "MAC:11:22:33:44:55:66", + Destination: "MAC:11:22:33:44:55:61", + }) // Found success + foundErrSuccess2 := msgv.Validate(Message{ + Type: SimpleRequestResponseMessageType, + Source: "MAC:11:22:33:44:55:66", + Destination: "invalid:a-BB-44-55", + }) // Found success + unfoundErrFailure := msgv.Validate(Message{Type: CreateMessageType}) // Unfound error + fmt.Println(foundErrFailure == nil, foundErrSuccess1 == nil, foundErrSuccess2 == nil, unfoundErrFailure == nil) + // Output: false true true false +} + +func testUTF8Validator(t *testing.T) { + var ( + expectedStatus int64 = 3471 + expectedRequestDeliveryResponse int64 = 34 + expectedIncludeSpans bool = true + ) + + tests := []struct { + description string + msg Message + expectedErr error + }{ + // Success case + { + description: "UTF8 success", + msg: Message{ + Type: SimpleRequestResponseMessageType, + Source: "dns:external.com", + Destination: "MAC:11:22:33:44:55:66", + TransactionUUID: "DEADBEEF", + ContentType: "ContentType", + Accept: "Accept", + Status: &expectedStatus, + RequestDeliveryResponse: &expectedRequestDeliveryResponse, + Headers: []string{"Header1", "Header2"}, + Metadata: map[string]string{"name": "value"}, + Spans: [][]string{{"1", "2"}, {"3"}}, + IncludeSpans: &expectedIncludeSpans, + Path: "/some/where/over/the/rainbow", + Payload: []byte{1, 2, 3, 4, 0xff, 0xce}, + ServiceName: "ServiceName", + URL: "someURL.com", + PartnerIDs: []string{"foo"}, + SessionID: "sessionID123", + }, + }, + { + description: "Not UTF8 error", + msg: Message{ + Type: SimpleRequestResponseMessageType, + Source: "dns:external.com", + // Not UFT8 Destination string + Destination: "MAC:\xed\xbf\xbf", + TransactionUUID: "DEADBEEF", + ContentType: "ContentType", + Accept: "Accept", + Status: &expectedStatus, + RequestDeliveryResponse: &expectedRequestDeliveryResponse, + Headers: []string{"Header1", "Header2"}, + Metadata: map[string]string{"name": "value"}, + Spans: [][]string{{"1", "2"}, {"3"}}, + IncludeSpans: &expectedIncludeSpans, + Path: "/some/where/over/the/rainbow", + Payload: []byte{1, 2, 3, 4, 0xff, 0xce}, + ServiceName: "ServiceName", + URL: "someURL.com", + PartnerIDs: []string{"foo"}, + SessionID: "sessionID123", + }, + expectedErr: ErrorInvalidMessageEncoding, + }, + } + + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + assert := assert.New(t) + err := UTF8Validator()(tc.msg) + if tc.expectedErr != nil { + assert.ErrorIs(err, tc.expectedErr) + return + } + + assert.NoError(err) + }) + } +} + +func testMessageTypeValidator(t *testing.T) { + tests := []struct { + description string + msg Message + expectedErr error + }{ + // Success case + { + description: "AuthorizationMessageType success", + msg: Message{Type: AuthorizationMessageType}, + }, + { + description: "SimpleRequestResponseMessageType success", + msg: Message{Type: SimpleRequestResponseMessageType}, + }, + { + description: "SimpleEventMessageType success", + msg: Message{Type: SimpleEventMessageType}, + }, + { + description: "CreateMessageType success", + msg: Message{Type: CreateMessageType}, + }, + { + description: "RetrieveMessageType success", + msg: Message{Type: RetrieveMessageType}, + }, + { + description: "UpdateMessageType success", + msg: Message{Type: UpdateMessageType}, + }, + { + description: "DeleteMessageType success", + msg: Message{Type: DeleteMessageType}, + }, + { + description: "ServiceRegistrationMessageType success", + msg: Message{Type: ServiceRegistrationMessageType}, + }, + { + description: "ServiceAliveMessageType success", + msg: Message{Type: ServiceAliveMessageType}, + }, + { + description: "UnknownMessageType success", + msg: Message{Type: UnknownMessageType}, + }, + // Failure case + { + description: "Invalid0MessageType error", + msg: Message{Type: Invalid0MessageType}, + expectedErr: ErrorInvalidMessageType, + }, + { + description: "Invalid0MessageType error", + msg: Message{Type: Invalid0MessageType}, + expectedErr: ErrorInvalidMessageType, + }, + { + description: "Invalid1MessageType error", + msg: Message{Type: Invalid1MessageType}, + expectedErr: ErrorInvalidMessageType, + }, + { + description: "lastMessageType error", + msg: Message{Type: lastMessageType}, + expectedErr: ErrorInvalidMessageType, + }, + { + description: "Nonexistent negative MessageType error", + msg: Message{Type: -10}, + expectedErr: ErrorInvalidMessageType, + }, + { + description: "Nonexistent positive MessageType error", + msg: Message{Type: lastMessageType + 1}, + expectedErr: ErrorInvalidMessageType, + }, + } + + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + assert := assert.New(t) + err := MessageTypeValidator()(tc.msg) + if tc.expectedErr != nil { + assert.ErrorIs(err, tc.expectedErr) + return + } + + assert.NoError(err) + }) + } +} + +func testSourceValidator(t *testing.T) { + // SourceValidator is mainly a wrapper for validateLocator. + // This test mainly ensures that SourceValidator returns nil for non errors + // and wraps errors with ErrorInvalidSource. + // testValidateLocator covers the actual spectrum of test cases. + + tests := []struct { + description string + msg Message + expectedErr error + }{ + // Success case + { + description: "Source success", + msg: Message{Source: "MAC:11:22:33:44:55:66"}, + }, + // Failures + { + description: "Source error", + msg: Message{Source: "invalid:a-BB-44-55"}, + expectedErr: ErrorInvalidSource, + }, + } + + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + assert := assert.New(t) + err := SourceValidator()(tc.msg) + if tc.expectedErr != nil { + assert.ErrorIs(err, tc.expectedErr) + return + } + + assert.NoError(err) + }) + } +} + +func testDestinationValidator(t *testing.T) { + // DestinationValidator is mainly a wrapper for validateLocator. + // This test mainly ensures that DestinationValidator returns nil for non errors + // and wraps errors with ErrorInvalidDestination. + // testValidateLocator covers the actual spectrum of test cases. + + tests := []struct { + description string + msg Message + expectedErr error + }{ + // Success case + { + description: "Destination success", + msg: Message{Destination: "MAC:11:22:33:44:55:66"}, + }, + // Failures + { + description: "Destination error", + msg: Message{Destination: "invalid:a-BB-44-55"}, + expectedErr: ErrorInvalidDestination, + }, + } + + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + assert := assert.New(t) + err := DestinationValidator()(tc.msg) + if tc.expectedErr != nil { + assert.ErrorIs(err, tc.expectedErr) + return + } + + assert.NoError(err) + }) + } +} + +func testValidateLocator(t *testing.T) { + tests := []struct { + description string + value string + expectedErr error + }{ + // mac success case + { + description: "Mac ID success, ':' delimiter", + value: "MAC:11:22:33:44:55:66", + expectedErr: nil, + }, + { + description: "Mac ID success, no delimiter", + value: "MAC:11aaBB445566", + expectedErr: nil, + }, + { + description: "Mac ID success, '-' delimiter", + value: "mac:11-aa-BB-44-55-66", + expectedErr: nil, + }, + { + description: "Mac ID success, ',' delimiter", + value: "mac:11,aa,BB,44,55,66", + expectedErr: nil, + }, + { + description: "Mac service success", + value: "mac:11,aa,BB,44,55,66/parodus/tag/test0", + expectedErr: nil, + }, + // Mac failure case + { + description: "Mac ID error, invalid mac ID character", + value: "MAC:invalid45566", + expectedErr: errInvalidCharacter, + }, + { + description: "Mac ID error, invalid mac ID length", + value: "mac:11-aa-BB-44-55", + expectedErr: errInvalidMacLength, + }, + { + description: "Mac ID error, no ID", + value: "mac:", + expectedErr: errEmptyAuthority, + }, + // Serial success case + { + description: "Serial ID success", + value: "serial:anything Goes!", + expectedErr: nil, + }, + // Serial failure case + { + description: "Invalid serial ID error, no ID", + value: "serial:", + expectedErr: errEmptyAuthority, + }, + // UUID success case + // The variant specified in RFC4122 + { + description: "UUID RFC4122 variant ID success", + value: "uuid:f47ac10b-58cc-0372-8567-0e02b2c3d479", + expectedErr: nil, + }, + { + description: "UUID RFC4122 variant ID success, with Microsoft encoding", + value: "uuid:{f47ac10b-58cc-0372-8567-0e02b2c3d479}", + expectedErr: nil, + }, + // Reserved, NCS backward compatibility. + { + description: "UUID Reserved variant ID success, with URN lower case ", + value: "UUID:urn:uuid:f47ac10b-58cc-4372-0567-0e02b2c3d479", + expectedErr: nil, + }, + { + description: "UUID Reserved variant ID success, with URN upper case", + value: "UUID:URN:UUID:f47ac10b-58cc-4372-0567-0e02b2c3d479", + expectedErr: nil, + }, + { + description: "UUID Reserved variant ID success, without URN", + value: "UUID:f47ac10b-58cc-4372-0567-0e02b2c3d479", + expectedErr: nil, + }, + // Reserved, Microsoft Corporation backward compatibility. + { + description: "UUID Microsoft variant ID success", + value: "uuid:f47ac10b-58cc-4372-c567-0e02b2c3d479", + expectedErr: nil, + }, + // Reserved for future definition. + { + description: "UUID Future variant ID success", + value: "uuid:f47ac10b-58cc-4372-e567-0e02b2c3d479", + expectedErr: nil, + }, + // UUID failure case + { + description: "Invalid UUID ID error", + value: "uuid:invalid45566", + expectedErr: errInvalidUUID, + }, + { + description: "Invalid UUID ID error, with URN", + value: "uuid:URN:UUID:invalid45566", + expectedErr: errInvalidUUID, + }, + { + description: "Invalid UUID ID error, with Microsoft encoding", + value: "uuid:{invalid45566}", + expectedErr: errInvalidUUID, + }, + { + description: "Invalid UUID ID error, no ID", + value: "uuid:", + expectedErr: errEmptyAuthority, + }, + // Event success case + { + description: "Event ID success", + value: "event:anything Goes!", + expectedErr: nil, + }, + // Event failure case + { + description: "Invalid event ID error, no ID", + value: "event:", + expectedErr: errEmptyAuthority, + }, + // DNS success case + { + description: "DNS ID success", + value: "dns:anything Goes!", + expectedErr: nil, + }, + // DNS failure case + { + description: "Invalid DNS ID error, no ID", + value: "dns:", + expectedErr: errEmptyAuthority, + }, + // Scheme failure case + { + description: "Invalid scheme error", + value: "invalid:a-BB-44-55", + expectedErr: errInvalidLocatorPattern, + }, + { + description: "Invalid scheme error, empty string", + value: "", + expectedErr: errInvalidLocatorPattern, + }, + } + + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + assert := assert.New(t) + err := validateLocator(tc.value) + if tc.expectedErr != nil { + assert.ErrorIs(err, tc.expectedErr) + return + } + + assert.NoError(err) + }) + } +} diff --git a/utf8.go b/utf8.go index 5aa0417..d1e26a2 100644 --- a/utf8.go +++ b/utf8.go @@ -26,7 +26,7 @@ import ( var ( ErrNotUTF8 = errors.New("field contains non-utf-8 characters") - ErrUnexpectedKind = errors.New("A struct or non-nil pointer to struct is required") + ErrUnexpectedKind = errors.New("a struct or non-nil pointer to struct is required") ) // UTF8 takes any struct verifies that it contains UTF-8 strings. @@ -55,7 +55,6 @@ func UTF8(v interface{}) error { if !utf8.ValidString(s) { return fmt.Errorf("%w: '%s:%v'", ErrNotUTF8, ft.Name, s) } - fmt.Println(s) } } diff --git a/validator.go b/validator.go index 1b3c07e..61d6ac1 100644 --- a/validator.go +++ b/validator.go @@ -29,12 +29,6 @@ var ( ErrInvalidMsgType = errors.New("invalid WRP message type") ) -// AlwaysInvalid doesn't validate anything about the message and always returns an error. -var AlwaysInvalid ValidatorFunc = func(m Message) error { return ErrInvalidMsgType } - -// AlwaysValid doesn't validate anything about the message and always returns nil. -var AlwaysValid ValidatorFunc = func(msg Message) error { return nil } - // Validator is a WRP validator that allows access to the Validate function. type Validator interface { Validate(m Message) error @@ -91,7 +85,7 @@ func NewTypeValidator(m map[MessageType]Validator, defaultValidator Validator) ( } if defaultValidator == nil { - defaultValidator = AlwaysInvalid + defaultValidator = AlwaysInvalid() } return TypeValidator{ @@ -99,3 +93,13 @@ func NewTypeValidator(m map[MessageType]Validator, defaultValidator Validator) ( defaultValidator: defaultValidator, }, nil } + +// AlwaysInvalid returns a WRP validator that doesn't validate anything about the message and always returns an error. +func AlwaysInvalid() ValidatorFunc { + return func(_ Message) error { return ErrInvalidMsgType } +} + +// AlwaysValid returns a WRP validator that doesn't validate anything about the message and always returns nil. +func AlwaysValid() ValidatorFunc { + return func(_ Message) error { return nil } +} diff --git a/validator_test.go b/validator_test.go index 70c3aab..1bafb36 100644 --- a/validator_test.go +++ b/validator_test.go @@ -42,7 +42,7 @@ func TestValidators(t *testing.T) { // Failure case { description: "Mix Validators error", - vs: Validators{AlwaysValid, nil, AlwaysInvalid, Validators{AlwaysValid, nil, AlwaysInvalid}}, + vs: Validators{AlwaysValid(), nil, AlwaysInvalid(), Validators{AlwaysValid(), nil, AlwaysInvalid()}}, msg: Message{Type: SimpleEventMessageType}, expectedErr: []error{ErrInvalidMsgType, ErrInvalidMsgType}, }, @@ -96,9 +96,9 @@ func TestTypeValidator(t *testing.T) { func ExampleNewTypeValidator() { msgv, err := NewTypeValidator( // Validates found msg types - map[MessageType]Validator{SimpleEventMessageType: AlwaysValid}, + map[MessageType]Validator{SimpleEventMessageType: AlwaysValid()}, // Validates unfound msg types - AlwaysInvalid) + AlwaysInvalid()) fmt.Printf("%v %T", err == nil, msgv) // Output: true wrp.TypeValidator } @@ -106,9 +106,9 @@ func ExampleNewTypeValidator() { func ExampleTypeValidator_Validate() { msgv, err := NewTypeValidator( // Validates found msg types - map[MessageType]Validator{SimpleEventMessageType: AlwaysValid}, + map[MessageType]Validator{SimpleEventMessageType: AlwaysValid()}, // Validates unfound msg types - AlwaysInvalid) + AlwaysInvalid()) if err != nil { return } @@ -131,22 +131,22 @@ func testTypeValidatorValidate(t *testing.T) { { description: "Found success", m: map[MessageType]Validator{ - SimpleEventMessageType: AlwaysValid, + SimpleEventMessageType: AlwaysValid(), }, msg: Message{Type: SimpleEventMessageType}, }, { description: "Unfound success", m: map[MessageType]Validator{ - SimpleEventMessageType: AlwaysInvalid, + SimpleEventMessageType: AlwaysInvalid(), }, - defaultValidator: AlwaysValid, + defaultValidator: AlwaysValid(), msg: Message{Type: CreateMessageType}, }, { description: "Unfound success, nil list of default Validators", m: map[MessageType]Validator{ - SimpleEventMessageType: AlwaysInvalid, + SimpleEventMessageType: AlwaysInvalid(), }, defaultValidator: Validators{nil}, msg: Message{Type: CreateMessageType}, @@ -154,7 +154,7 @@ func testTypeValidatorValidate(t *testing.T) { { description: "Unfound success, empty map of default Validators", m: map[MessageType]Validator{ - SimpleEventMessageType: AlwaysInvalid, + SimpleEventMessageType: AlwaysInvalid(), }, defaultValidator: Validators{}, msg: Message{Type: CreateMessageType}, @@ -163,9 +163,9 @@ func testTypeValidatorValidate(t *testing.T) { { description: "Found error", m: map[MessageType]Validator{ - SimpleEventMessageType: AlwaysInvalid, + SimpleEventMessageType: AlwaysInvalid(), }, - defaultValidator: AlwaysValid, + defaultValidator: AlwaysValid(), msg: Message{Type: SimpleEventMessageType}, expectedErr: ErrInvalidMsgType, }, @@ -180,7 +180,7 @@ func testTypeValidatorValidate(t *testing.T) { { description: "Unfound error", m: map[MessageType]Validator{ - SimpleEventMessageType: AlwaysValid, + SimpleEventMessageType: AlwaysValid(), }, msg: Message{Type: CreateMessageType}, expectedErr: ErrInvalidMsgType, @@ -188,7 +188,7 @@ func testTypeValidatorValidate(t *testing.T) { { description: "Unfound error, nil default Validators", m: map[MessageType]Validator{ - SimpleEventMessageType: AlwaysInvalid, + SimpleEventMessageType: AlwaysInvalid(), }, defaultValidator: nil, msg: Message{Type: CreateMessageType}, @@ -232,15 +232,15 @@ func testTypeValidatorFactory(t *testing.T) { { description: "Default Validators success", m: map[MessageType]Validator{ - SimpleEventMessageType: AlwaysValid, + SimpleEventMessageType: AlwaysValid(), }, - defaultValidator: AlwaysValid, + defaultValidator: AlwaysValid(), expectedErr: nil, }, { description: "Omit default Validators success", m: map[MessageType]Validator{ - SimpleEventMessageType: AlwaysValid, + SimpleEventMessageType: AlwaysValid(), }, expectedErr: nil, }, @@ -248,7 +248,7 @@ func testTypeValidatorFactory(t *testing.T) { { description: "Nil map of Validators error", m: nil, - defaultValidator: AlwaysValid, + defaultValidator: AlwaysValid(), expectedErr: ErrInvalidValidator, }, } @@ -307,7 +307,6 @@ func testAlwaysValid(t *testing.T) { SessionID: "sessionID123", }, }, - // Failure case { description: "Filled message success", msg: Message{ @@ -348,7 +347,7 @@ func testAlwaysValid(t *testing.T) { for _, tc := range tests { t.Run(tc.description, func(t *testing.T) { assert := assert.New(t) - err := AlwaysValid.Validate(tc.msg) + err := AlwaysValid().Validate(tc.msg) assert.NoError(err) }) } @@ -430,7 +429,7 @@ func testAlwaysInvalid(t *testing.T) { for _, tc := range tests { t.Run(tc.description, func(t *testing.T) { assert := assert.New(t) - err := AlwaysInvalid.Validate(tc.msg) + err := AlwaysInvalid().Validate(tc.msg) assert.ErrorIs(err, ErrInvalidMsgType) }) }