From 3fff8f6d8a7dd8b42f6a5e84425ff0977d5d8fee Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Mon, 7 Oct 2024 10:05:44 +0900 Subject: [PATCH 1/6] Change accessors in jwk to return (T, bool) --- jwe/interface.go | 2 +- jwe/jwe.go | 8 +- jwe/key_provider.go | 10 ++- jwk/ecdsa.go | 39 ++++++--- jwk/ecdsa_gen.go | 153 +++++++++++++++++--------------- jwk/headers_test.go | 26 +++--- jwk/interface_gen.go | 18 ++-- jwk/jwk.go | 7 ++ jwk/jwk_test.go | 67 ++++++++++---- jwk/okp.go | 45 +++++++--- jwk/okp_gen.go | 135 +++++++++++++++-------------- jwk/rsa.go | 22 +++-- jwk/rsa_gen.go | 182 +++++++++++++++++++++++---------------- jwk/set.go | 3 +- jwk/symmetric.go | 3 +- jwk/symmetric_gen.go | 55 ++++++------ jws/jws.go | 2 +- jws/key_provider.go | 13 +-- jws/message.go | 2 +- tools/cmd/genjwk/main.go | 28 +++--- 20 files changed, 492 insertions(+), 328 deletions(-) diff --git a/jwe/interface.go b/jwe/interface.go index ddecc8df..1804db58 100644 --- a/jwe/interface.go +++ b/jwe/interface.go @@ -28,7 +28,7 @@ type KeyEncrypter interface { // As of this writing, this is solely used to identify KeyEncrypter // objects that also carry a key ID on its own. type KeyIDer interface { - KeyID() string + KeyID() (string, bool) } // KeyDecrypter is an interface for objects that can decrypt a content diff --git a/jwe/jwe.go b/jwe/jwe.go index 349f60ae..085ef944 100644 --- a/jwe/jwe.go +++ b/jwe/jwe.go @@ -94,11 +94,15 @@ func (b *recipientBuilder) Build(cek []byte, calg jwa.ContentEncryptionAlgorithm if ke, ok := b.key.(KeyEncrypter); ok { enc = &keyEncrypterWrapper{encrypter: ke} if kider, ok := enc.(KeyIDer); ok { - keyID = kider.KeyID() + if v, ok := kider.KeyID(); ok { + keyID = v + } } } else if jwkKey, ok := b.key.(jwk.Key); ok { // Meanwhile, grab the kid as well - keyID = jwkKey.KeyID() + if v, ok := jwkKey.KeyID(); ok { + keyID = v + } var raw interface{} if err := jwk.Export(jwkKey, &raw); err != nil { diff --git a/jwe/key_provider.go b/jwe/key_provider.go index 965fabfd..1df0e747 100644 --- a/jwe/key_provider.go +++ b/jwe/key_provider.go @@ -107,14 +107,16 @@ type keySetProvider struct { } func (kp *keySetProvider) selectKey(sink KeySink, key jwk.Key, _ Recipient, _ *Message) error { - if usage := key.KeyUsage(); usage != "" && usage != jwk.ForEncryption.String() { - return nil + if usage, ok := key.KeyUsage(); ok { + if usage != "" && usage != jwk.ForEncryption.String() { + return nil + } } - if v := key.Algorithm(); v != nil { + if v, ok := key.Algorithm(); ok { kalg, ok := jwa.LookupKeyEncryptionAlgorithm(v.String()) if !ok { - return fmt.Errorf(`invalid key encryption algorithm %s`, key.Algorithm()) + return fmt.Errorf(`invalid key encryption algorithm %s`, v) } sink.Key(kalg, key) diff --git a/jwk/ecdsa.go b/jwk/ecdsa.go index 5aebc1ce..9b4027a2 100644 --- a/jwk/ecdsa.go +++ b/jwk/ecdsa.go @@ -113,7 +113,12 @@ func ecdsaJWKToRaw(keyif Key, hint interface{}) (interface{}, error) { k.mu.RLock() defer k.mu.RUnlock() - return buildECDSAPublicKey(k.Crv(), k.x, k.y) + + crv, ok := k.Crv() + if !ok { + return nil, fmt.Errorf(`missing "crv" field`) + } + return buildECDSAPublicKey(crv, k.x, k.y) case *ecdsaPrivateKey: switch hint.(type) { case ecdsa.PrivateKey, *ecdsa.PrivateKey, interface{}: @@ -123,7 +128,12 @@ func ecdsaJWKToRaw(keyif Key, hint interface{}) (interface{}, error) { k.mu.RLock() defer k.mu.RUnlock() - pubk, err := buildECDSAPublicKey(k.Crv(), k.x, k.y) + + crv, ok := k.Crv() + if !ok { + return nil, fmt.Errorf(`missing "crv" field`) + } + pubk, err := buildECDSAPublicKey(crv, k.x, k.y) if err != nil { return nil, fmt.Errorf(`failed to build public key: %w`, err) } @@ -232,28 +242,33 @@ func (k ecdsaPrivateKey) Thumbprint(hash crypto.Hash) ([]byte, error) { } func ecdsaValidateKey(k interface { - Crv() jwa.EllipticCurveAlgorithm - X() []byte - Y() []byte + Crv() (jwa.EllipticCurveAlgorithm, bool) + X() ([]byte, bool) + Y() ([]byte, bool) }, checkPrivate bool) error { - crv, err := ourecdsa.CurveFromAlgorithm(k.Crv()) + crvtyp, ok := k.Crv() + if !ok { + return fmt.Errorf(`missing "crv" field`) + } + + crv, err := ourecdsa.CurveFromAlgorithm(crvtyp) if err != nil { - return fmt.Errorf(`invalid curve algorithm %q: %w`, k.Crv(), err) + return fmt.Errorf(`invalid curve algorithm %q: %w`, crvtyp, err) } keySize := ecutil.CalculateKeySize(crv) - if x := k.X(); len(x) != keySize { + if x, ok := k.X(); !ok || len(x) != keySize { return fmt.Errorf(`invalid "x" length (%d) for curve %q`, len(x), crv.Params().Name) } - if y := k.Y(); len(y) != keySize { + if y, ok := k.Y(); !ok || len(y) != keySize { return fmt.Errorf(`invalid "y" length (%d) for curve %q`, len(y), crv.Params().Name) } if checkPrivate { - if priv, ok := k.(interface{ D() []byte }); ok { - if len(priv.D()) != keySize { - return fmt.Errorf(`invalid "d" length (%d) for curve %q`, len(priv.D()), crv.Params().Name) + if priv, ok := k.(keyWithD); ok { + if d, ok := priv.D(); !ok || len(d) != keySize { + return fmt.Errorf(`invalid "d" length (%d) for curve %q`, len(d), crv.Params().Name) } } else { return fmt.Errorf(`missing "d" value`) diff --git a/jwk/ecdsa_gen.go b/jwk/ecdsa_gen.go index ef12c857..9cd786cc 100644 --- a/jwk/ecdsa_gen.go +++ b/jwk/ecdsa_gen.go @@ -25,9 +25,9 @@ const ( type ECDSAPublicKey interface { Key - Crv() jwa.EllipticCurveAlgorithm - X() []byte - Y() []byte + Crv() (jwa.EllipticCurveAlgorithm, bool) + X() ([]byte, bool) + Y() ([]byte, bool) } type ecdsaPublicKey struct { @@ -65,72 +65,78 @@ func (h ecdsaPublicKey) IsPrivate() bool { return false } -func (h *ecdsaPublicKey) Algorithm() jwa.KeyAlgorithm { +func (h *ecdsaPublicKey) Algorithm() (jwa.KeyAlgorithm, bool) { if h.algorithm != nil { - return *(h.algorithm) + return *(h.algorithm), true } - return nil + return nil, false } -func (h *ecdsaPublicKey) Crv() jwa.EllipticCurveAlgorithm { +func (h *ecdsaPublicKey) Crv() (jwa.EllipticCurveAlgorithm, bool) { if h.crv != nil { - return *(h.crv) + return *(h.crv), true } - return jwa.InvalidEllipticCurve() + return jwa.InvalidEllipticCurve(), false } -func (h *ecdsaPublicKey) KeyID() string { +func (h *ecdsaPublicKey) KeyID() (string, bool) { if h.keyID != nil { - return *(h.keyID) + return *(h.keyID), true } - return "" + return "", false } -func (h *ecdsaPublicKey) KeyOps() KeyOperationList { +func (h *ecdsaPublicKey) KeyOps() (KeyOperationList, bool) { if h.keyOps != nil { - return *(h.keyOps) + return *(h.keyOps), true } - return nil + return nil, false } -func (h *ecdsaPublicKey) KeyUsage() string { +func (h *ecdsaPublicKey) KeyUsage() (string, bool) { if h.keyUsage != nil { - return *(h.keyUsage) + return *(h.keyUsage), true } - return "" + return "", false } -func (h *ecdsaPublicKey) X() []byte { - return h.x +func (h *ecdsaPublicKey) X() ([]byte, bool) { + if h.x != nil { + return h.x, true + } + return nil, false } -func (h *ecdsaPublicKey) X509CertChain() *cert.Chain { - return h.x509CertChain +func (h *ecdsaPublicKey) X509CertChain() (*cert.Chain, bool) { + return h.x509CertChain, true } -func (h *ecdsaPublicKey) X509CertThumbprint() string { +func (h *ecdsaPublicKey) X509CertThumbprint() (string, bool) { if h.x509CertThumbprint != nil { - return *(h.x509CertThumbprint) + return *(h.x509CertThumbprint), true } - return "" + return "", false } -func (h *ecdsaPublicKey) X509CertThumbprintS256() string { +func (h *ecdsaPublicKey) X509CertThumbprintS256() (string, bool) { if h.x509CertThumbprintS256 != nil { - return *(h.x509CertThumbprintS256) + return *(h.x509CertThumbprintS256), true } - return "" + return "", false } -func (h *ecdsaPublicKey) X509URL() string { +func (h *ecdsaPublicKey) X509URL() (string, bool) { if h.x509URL != nil { - return *(h.x509URL) + return *(h.x509URL), true } - return "" + return "", false } -func (h *ecdsaPublicKey) Y() []byte { - return h.y +func (h *ecdsaPublicKey) Y() ([]byte, bool) { + if h.y != nil { + return h.y, true + } + return nil, false } func (h *ecdsaPublicKey) Has(name string) bool { @@ -686,10 +692,10 @@ func (h *ecdsaPublicKey) Keys() []string { type ECDSAPrivateKey interface { Key - Crv() jwa.EllipticCurveAlgorithm - D() []byte - X() []byte - Y() []byte + Crv() (jwa.EllipticCurveAlgorithm, bool) + D() ([]byte, bool) + X() ([]byte, bool) + Y() ([]byte, bool) } type ecdsaPrivateKey struct { @@ -728,76 +734,85 @@ func (h ecdsaPrivateKey) IsPrivate() bool { return true } -func (h *ecdsaPrivateKey) Algorithm() jwa.KeyAlgorithm { +func (h *ecdsaPrivateKey) Algorithm() (jwa.KeyAlgorithm, bool) { if h.algorithm != nil { - return *(h.algorithm) + return *(h.algorithm), true } - return nil + return nil, false } -func (h *ecdsaPrivateKey) Crv() jwa.EllipticCurveAlgorithm { +func (h *ecdsaPrivateKey) Crv() (jwa.EllipticCurveAlgorithm, bool) { if h.crv != nil { - return *(h.crv) + return *(h.crv), true } - return jwa.InvalidEllipticCurve() + return jwa.InvalidEllipticCurve(), false } -func (h *ecdsaPrivateKey) D() []byte { - return h.d +func (h *ecdsaPrivateKey) D() ([]byte, bool) { + if h.d != nil { + return h.d, true + } + return nil, false } -func (h *ecdsaPrivateKey) KeyID() string { +func (h *ecdsaPrivateKey) KeyID() (string, bool) { if h.keyID != nil { - return *(h.keyID) + return *(h.keyID), true } - return "" + return "", false } -func (h *ecdsaPrivateKey) KeyOps() KeyOperationList { +func (h *ecdsaPrivateKey) KeyOps() (KeyOperationList, bool) { if h.keyOps != nil { - return *(h.keyOps) + return *(h.keyOps), true } - return nil + return nil, false } -func (h *ecdsaPrivateKey) KeyUsage() string { +func (h *ecdsaPrivateKey) KeyUsage() (string, bool) { if h.keyUsage != nil { - return *(h.keyUsage) + return *(h.keyUsage), true } - return "" + return "", false } -func (h *ecdsaPrivateKey) X() []byte { - return h.x +func (h *ecdsaPrivateKey) X() ([]byte, bool) { + if h.x != nil { + return h.x, true + } + return nil, false } -func (h *ecdsaPrivateKey) X509CertChain() *cert.Chain { - return h.x509CertChain +func (h *ecdsaPrivateKey) X509CertChain() (*cert.Chain, bool) { + return h.x509CertChain, true } -func (h *ecdsaPrivateKey) X509CertThumbprint() string { +func (h *ecdsaPrivateKey) X509CertThumbprint() (string, bool) { if h.x509CertThumbprint != nil { - return *(h.x509CertThumbprint) + return *(h.x509CertThumbprint), true } - return "" + return "", false } -func (h *ecdsaPrivateKey) X509CertThumbprintS256() string { +func (h *ecdsaPrivateKey) X509CertThumbprintS256() (string, bool) { if h.x509CertThumbprintS256 != nil { - return *(h.x509CertThumbprintS256) + return *(h.x509CertThumbprintS256), true } - return "" + return "", false } -func (h *ecdsaPrivateKey) X509URL() string { +func (h *ecdsaPrivateKey) X509URL() (string, bool) { if h.x509URL != nil { - return *(h.x509URL) + return *(h.x509URL), true } - return "" + return "", false } -func (h *ecdsaPrivateKey) Y() []byte { - return h.y +func (h *ecdsaPrivateKey) Y() ([]byte, bool) { + if h.y != nil { + return h.y, true + } + return nil, false } func (h *ecdsaPrivateKey) Has(name string) bool { diff --git a/jwk/headers_test.go b/jwk/headers_test.go index 4cb19840..b3a446d0 100644 --- a/jwk/headers_test.go +++ b/jwk/headers_test.go @@ -70,16 +70,22 @@ func TestHeader(t *testing.T) { } } require.NoError(t, h.Set("Default", dummy), `Setting "Default" should succeed`) - require.Nil(t, h.Algorithm(), "Algorithm should be nil") - if h.KeyID() != "" { - t.Fatalf("KeyID should be empty string") - } - if h.KeyUsage() != "" { - t.Fatalf("KeyUsage should be empty string") - } - if h.KeyOps() != nil { - t.Fatalf("KeyOps should be empty string") - } + + alg, ok := h.Algorithm() + require.True(t, !ok, `Algorithm should not be set`) + require.Nil(t, alg, "Algorithm should be nil") + + kid, ok := h.KeyID() + require.False(t, ok, `KeyID should not be set`) + require.Empty(t, kid, "KeyID should be empty") + + use, ok := h.KeyUsage() + require.False(t, ok, `KeyUsage should not be set`) + require.Empty(t, use, "KeyUsage should be empty") + + ops, ok := h.KeyOps() + require.False(t, ok, `KeyOps should not be set`) + require.Nil(t, ops, "KeyOps should be nil") }) t.Run("Algorithm", func(t *testing.T) { diff --git a/jwk/interface_gen.go b/jwk/interface_gen.go index 7522e043..85e6e8db 100644 --- a/jwk/interface_gen.go +++ b/jwk/interface_gen.go @@ -86,24 +86,24 @@ type Key interface { // KeyType returns the `kty` of a JWK KeyType() jwa.KeyType // KeyUsage returns `use` of a JWK - KeyUsage() string + KeyUsage() (string, bool) // KeyOps returns `key_ops` of a JWK - KeyOps() KeyOperationList + KeyOps() (KeyOperationList, bool) // Algorithm returns `alg` of a JWK // Algorithm returns the value of the `alg` field. // - // This field may contain either `jwk.SignatureAlgorithm` or `jwk.KeyEncryptionAlgorithm`. + // This field may contain either `jwk.SignatureAlgorithm`, `jwk.KeyEncryptionAlgorithm`, or `jwk.ContentEncryptionAlgorithm`. // This is why there exists a `jwa.KeyAlgorithm` type that encompasses both types. - Algorithm() jwa.KeyAlgorithm + Algorithm() (jwa.KeyAlgorithm, bool) // KeyID returns `kid` of a JWK - KeyID() string + KeyID() (string, bool) // X509URL returns `x5u` of a JWK - X509URL() string + X509URL() (string, bool) // X509CertChain returns `x5c` of a JWK - X509CertChain() *cert.Chain + X509CertChain() (*cert.Chain, bool) // X509CertThumbprint returns `x5t` of a JWK - X509CertThumbprint() string + X509CertThumbprint() (string, bool) // X509CertThumbprintS256 returns `x5t#S256` of a JWK - X509CertThumbprintS256() string + X509CertThumbprintS256() (string, bool) } diff --git a/jwk/jwk.go b/jwk/jwk.go index 13a4a55e..14532f26 100644 --- a/jwk/jwk.go +++ b/jwk/jwk.go @@ -636,3 +636,10 @@ func Configure(options ...GlobalOption) { strictKeyUsage.Store(*strictKeyUsagePtr) } } + +// These are used when validating keys. +type keyWithD interface { + D() ([]byte, bool) +} + +var _ keyWithD = &okpPrivateKey{} diff --git a/jwk/jwk_test.go b/jwk/jwk_test.go index aff6054b..d4f474a5 100644 --- a/jwk/jwk_test.go +++ b/jwk/jwk_test.go @@ -155,25 +155,33 @@ func expectedRawKeyType(key jwk.Key) interface{} { case jwk.SymmetricKey: return []byte(nil) case jwk.OKPPrivateKey: - switch key.Crv() { + crv, ok := key.Crv() + if !ok { + panic("missing crv") + } + switch crv { case jwa.Ed25519(): return ed25519.PrivateKey(nil) case jwa.X25519(): return &ecdh.PrivateKey{} default: - panic("unknown curve type for OKPPrivateKey:" + key.Crv().String()) + panic("unknown curve type for OKPPrivateKey:" + crv.String()) } case jwk.OKPPublicKey: - switch key.Crv() { + crv, ok := key.Crv() + if !ok { + panic("missing crv") + } + switch crv { case jwa.Ed25519(): return ed25519.PublicKey(nil) case jwa.X25519(): return &ecdh.PublicKey{} default: - panic("unknown curve type for OKPPublicKey:" + key.Crv().String()) + panic("unknown curve type for OKPPublicKey:" + crv.String()) } default: - panic("unknown key type:" + reflect.TypeOf(key).String()) + panic(fmt.Sprintf("unknown key type: %T", key)) } } @@ -201,7 +209,11 @@ func VerifyKey(t *testing.T, def map[string]keyDef) { require.NotEqual(t, zeroval, method, `method should not be a zero value`) retvals := method.Call(nil) - require.Len(t, retvals, 1, `there should be 1 return value`) + expectedReturnValues := 2 + if mname == "KeyType" { + expectedReturnValues = 1 + } + require.Len(t, retvals, expectedReturnValues, `there should be 1 return value`) require.Equal(t, expected, retvals[0].Interface()) } }) @@ -362,7 +374,9 @@ func TestParse(t *testing.T) { crawkey = &rawkey case jwk.OKPPrivateKey: require.True(t, isPrivate, `jwk.IsPrivateKey(&ed25519.PrivateKey) should be true`) - switch k.Crv() { + crv, ok := k.Crv() + require.True(t, ok, `k.Crv() should succeed`) + switch crv { case jwa.Ed25519(): var rawkey ed25519.PrivateKey require.NoError(t, jwk.Export(key, &rawkey), `key.Raw(&ed25519.PrivateKey) should succeed`) @@ -372,14 +386,16 @@ func TestParse(t *testing.T) { require.NoError(t, jwk.Export(key, &rawkey), `key.Raw(&ecdh.PrivateKey) should succeed`) crawkey = &rawkey default: - t.Errorf(`invalid curve %s`, k.Crv()) + t.Errorf(`invalid curve %s`, crv) } // NOTE: Has to come after private // key, since it's a subset of the // private key variant. case jwk.OKPPublicKey: require.False(t, isPrivate, `jwk.IsPrivateKey(&ed25519.PublicKey) should be false`) - switch k.Crv() { + crv, ok := k.Crv() + require.True(t, ok, `k.Crv() should succeed`) + switch crv { case jwa.Ed25519(): var rawkey ed25519.PublicKey require.NoError(t, jwk.Export(key, &rawkey), `key.Raw(&ed25519.PublicKey) should succeed`) @@ -389,7 +405,7 @@ func TestParse(t *testing.T) { require.NoError(t, jwk.Export(key, &rawkey), `key.Raw(&ecdh.PublicKey) should succeed`) crawkey = &rawkey default: - t.Errorf(`invalid curve %s`, k.Crv()) + t.Errorf(`invalid curve %s`, crv) } default: t.Errorf(`invalid key type %T`, key) @@ -718,9 +734,13 @@ func TestAssignKeyID(t *testing.T) { for _, generator := range generators { k, err := generator() require.NoError(t, err, `jwk generation should be successful`) - require.Empty(t, k.KeyID(), `k.KeyID should be non-empty`) + kid, ok := k.KeyID() + require.False(t, ok, `k.KeyID should be empty`) + require.Empty(t, kid, `k.KeyID should be non-empty`) require.NoError(t, jwk.AssignKeyID(k), `AssignKeyID shuld be successful`) - require.NotEmpty(t, k.KeyID(), `k.KeyID should be non-empty`) + kid, ok = k.KeyID() + require.True(t, ok, `k.KeyID should be non-empty`) + require.NotEmpty(t, kid, `k.KeyID should be non-empty`) } } @@ -852,7 +872,9 @@ func TestPublicKeyOf(t *testing.T) { for i, key := range setKeys { setKey, ok := newSet.Key(i) require.True(t, ok, `element %d should be present`, i) - require.Equal(t, fmt.Sprintf("key%d", i), setKey.KeyID(), `KeyID() should match for %T`, setKey) + kid, ok := setKey.KeyID() + require.True(t, ok, `KeyID() should be present`) + require.Equal(t, fmt.Sprintf("key%d", i), kid, `KeyID() should match for %T`, setKey) // Get the raw key to compare var rawKey interface{} @@ -1473,7 +1495,9 @@ func TestGH412(t *testing.T) { k, ok := set.Key(idx) require.True(t, ok, `set.Get should succeed`) require.NoError(t, set.RemoveKey(k), `set.Remove should succeed`) - t.Logf("deleted key %s", k.KeyID()) + kid, ok := k.KeyID() + require.True(t, ok, `k.KeyID should succeed`) + t.Logf("deleted key %s", kid) require.Equal(t, iterations-1, set.Len(), `set.Len should be %d`, iterations-1) @@ -1488,8 +1512,10 @@ func TestGH412(t *testing.T) { for i := range set.Len() { key, ok := set.Key(i) require.True(t, ok, `set.Key() should succeed`) - require.NotEqual(t, k.KeyID(), key.KeyID(), `key id should not match`) - delete(expected, key.KeyID()) + gotkid, ok := key.KeyID() + require.True(t, ok, `key.KeyID should succeed`) + require.NotEqual(t, kid, gotkid, `key id should not match`) + delete(expected, gotkid) } require.Len(t, expected, 0, `expected map should be empty`) @@ -1504,7 +1530,8 @@ func TestGH491(t *testing.T) { // there should be 2 keys , get the first key k, _ := keys.Key(0) - ops := k.KeyOps() + ops, ok := k.KeyOps() + require.True(t, ok, `k.KeyOps should succeed`) require.Equal(t, jwk.KeyOperationList{jwk.KeyOpDeriveKey}, ops, `k.KeyOps should match`) } @@ -1904,7 +1931,8 @@ func TestValidation(t *testing.T) { require.NoError(t, err, `jwx.GenerateEcdsaJwk should succeed`) require.NoError(t, key.Validate(), `key.Validate should succeed`) - x := key.(jwk.ECDSAPrivateKey).X() + x, ok := key.(jwk.ECDSAPrivateKey).X() + require.True(t, ok, `key.(jwk.ECDSAPrivateKey).X should succeed`) require.NoError(t, key.Set(jwk.ECDSAXKey, x[:len(x)/2]), `key.Set should succeed`) require.Error(t, key.Validate(), `key.Validate should fail`) @@ -1919,7 +1947,8 @@ func TestValidation(t *testing.T) { key, err := jwxtest.GenerateEd25519Jwk() require.NoError(t, err, `jwx.GenerateEd25519Jwk should succeed`) require.NoError(t, key.Validate(), `key.Validate should succeed`) - x := key.(jwk.OKPPrivateKey).X() + x, ok := key.(jwk.OKPPrivateKey).X() + require.True(t, ok, `key.(jwk.OKPPrivateKey).X should succeed`) require.NoError(t, key.Set(jwk.OKPXKey, []byte(nil)), `key.Set should succeed`) require.Error(t, key.Validate(), `key.Validate should fail`) diff --git a/jwk/okp.go b/jwk/okp.go index fdabeed7..c27eaaa0 100644 --- a/jwk/okp.go +++ b/jwk/okp.go @@ -97,7 +97,12 @@ func (k *okpPublicKey) Raw(v interface{}) error { k.mu.RLock() defer k.mu.RUnlock() - pubk, err := buildOKPPublicKey(k.Crv(), k.x) + crv, ok := k.Crv() + if !ok { + return fmt.Errorf(`missing "crv" field`) + } + + pubk, err := buildOKPPublicKey(crv, k.x) if err != nil { return fmt.Errorf(`jwk.OKPPublicKey: failed to build public key: %w`, err) } @@ -141,7 +146,12 @@ func okpJWKToRaw(key Key, _ interface{} /* this is unused because this is half b key.mu.RLock() defer key.mu.RUnlock() - privk, err := buildOKPPrivateKey(key.Crv(), key.x, key.d) + crv, ok := key.Crv() + if !ok { + return nil, fmt.Errorf(`missing "crv" field`) + } + + privk, err := buildOKPPrivateKey(crv, key.x, key.d) if err != nil { return nil, fmt.Errorf(`jwk.OKPPrivateKey: failed to build private key: %w`, err) } @@ -150,7 +160,11 @@ func okpJWKToRaw(key Key, _ interface{} /* this is unused because this is half b key.mu.RLock() defer key.mu.RUnlock() - pubk, err := buildOKPPublicKey(key.Crv(), key.x) + crv, ok := key.Crv() + if !ok { + return nil, fmt.Errorf(`missing "crv" field`) + } + pubk, err := buildOKPPublicKey(crv, key.x) if err != nil { return nil, fmt.Errorf(`jwk.OKPPublicKey: failed to build public key: %w`, err) } @@ -207,9 +221,13 @@ func (k okpPublicKey) Thumbprint(hash crypto.Hash) ([]byte, error) { k.mu.RLock() defer k.mu.RUnlock() + crv, ok := k.Crv() + if !ok { + return nil, fmt.Errorf(`missing "crv" field`) + } return okpThumbprint( hash, - k.Crv().String(), + crv.String(), base64.EncodeToString(k.x), ), nil } @@ -220,27 +238,32 @@ func (k okpPrivateKey) Thumbprint(hash crypto.Hash) ([]byte, error) { k.mu.RLock() defer k.mu.RUnlock() + crv, ok := k.Crv() + if !ok { + return nil, fmt.Errorf(`missing "crv" field`) + } + return okpThumbprint( hash, - k.Crv().String(), + crv.String(), base64.EncodeToString(k.x), ), nil } func validateOKPKey(key interface { - Crv() jwa.EllipticCurveAlgorithm - X() []byte + Crv() (jwa.EllipticCurveAlgorithm, bool) + X() ([]byte, bool) }) error { - if key.Crv() == jwa.InvalidEllipticCurve() { + if v, ok := key.Crv(); !ok || v == jwa.InvalidEllipticCurve() { return fmt.Errorf(`invalid curve algorithm`) } - if len(key.X()) == 0 { + if v, ok := key.X(); !ok || len(v) == 0 { return fmt.Errorf(`missing "x" field`) } - if priv, ok := key.(interface{ D() []byte }); ok { - if len(priv.D()) == 0 { + if priv, ok := key.(keyWithD); ok { + if d, ok := priv.D(); !ok || len(d) == 0 { return fmt.Errorf(`missing "d" field`) } } diff --git a/jwk/okp_gen.go b/jwk/okp_gen.go index 43ef145c..efdb651e 100644 --- a/jwk/okp_gen.go +++ b/jwk/okp_gen.go @@ -24,8 +24,8 @@ const ( type OKPPublicKey interface { Key - Crv() jwa.EllipticCurveAlgorithm - X() []byte + Crv() (jwa.EllipticCurveAlgorithm, bool) + X() ([]byte, bool) } type okpPublicKey struct { @@ -62,68 +62,71 @@ func (h okpPublicKey) IsPrivate() bool { return false } -func (h *okpPublicKey) Algorithm() jwa.KeyAlgorithm { +func (h *okpPublicKey) Algorithm() (jwa.KeyAlgorithm, bool) { if h.algorithm != nil { - return *(h.algorithm) + return *(h.algorithm), true } - return nil + return nil, false } -func (h *okpPublicKey) Crv() jwa.EllipticCurveAlgorithm { +func (h *okpPublicKey) Crv() (jwa.EllipticCurveAlgorithm, bool) { if h.crv != nil { - return *(h.crv) + return *(h.crv), true } - return jwa.InvalidEllipticCurve() + return jwa.InvalidEllipticCurve(), false } -func (h *okpPublicKey) KeyID() string { +func (h *okpPublicKey) KeyID() (string, bool) { if h.keyID != nil { - return *(h.keyID) + return *(h.keyID), true } - return "" + return "", false } -func (h *okpPublicKey) KeyOps() KeyOperationList { +func (h *okpPublicKey) KeyOps() (KeyOperationList, bool) { if h.keyOps != nil { - return *(h.keyOps) + return *(h.keyOps), true } - return nil + return nil, false } -func (h *okpPublicKey) KeyUsage() string { +func (h *okpPublicKey) KeyUsage() (string, bool) { if h.keyUsage != nil { - return *(h.keyUsage) + return *(h.keyUsage), true } - return "" + return "", false } -func (h *okpPublicKey) X() []byte { - return h.x +func (h *okpPublicKey) X() ([]byte, bool) { + if h.x != nil { + return h.x, true + } + return nil, false } -func (h *okpPublicKey) X509CertChain() *cert.Chain { - return h.x509CertChain +func (h *okpPublicKey) X509CertChain() (*cert.Chain, bool) { + return h.x509CertChain, true } -func (h *okpPublicKey) X509CertThumbprint() string { +func (h *okpPublicKey) X509CertThumbprint() (string, bool) { if h.x509CertThumbprint != nil { - return *(h.x509CertThumbprint) + return *(h.x509CertThumbprint), true } - return "" + return "", false } -func (h *okpPublicKey) X509CertThumbprintS256() string { +func (h *okpPublicKey) X509CertThumbprintS256() (string, bool) { if h.x509CertThumbprintS256 != nil { - return *(h.x509CertThumbprintS256) + return *(h.x509CertThumbprintS256), true } - return "" + return "", false } -func (h *okpPublicKey) X509URL() string { +func (h *okpPublicKey) X509URL() (string, bool) { if h.x509URL != nil { - return *(h.x509URL) + return *(h.x509URL), true } - return "" + return "", false } func (h *okpPublicKey) Has(name string) bool { @@ -646,9 +649,9 @@ func (h *okpPublicKey) Keys() []string { type OKPPrivateKey interface { Key - Crv() jwa.EllipticCurveAlgorithm - D() []byte - X() []byte + Crv() (jwa.EllipticCurveAlgorithm, bool) + D() ([]byte, bool) + X() ([]byte, bool) } type okpPrivateKey struct { @@ -686,72 +689,78 @@ func (h okpPrivateKey) IsPrivate() bool { return true } -func (h *okpPrivateKey) Algorithm() jwa.KeyAlgorithm { +func (h *okpPrivateKey) Algorithm() (jwa.KeyAlgorithm, bool) { if h.algorithm != nil { - return *(h.algorithm) + return *(h.algorithm), true } - return nil + return nil, false } -func (h *okpPrivateKey) Crv() jwa.EllipticCurveAlgorithm { +func (h *okpPrivateKey) Crv() (jwa.EllipticCurveAlgorithm, bool) { if h.crv != nil { - return *(h.crv) + return *(h.crv), true } - return jwa.InvalidEllipticCurve() + return jwa.InvalidEllipticCurve(), false } -func (h *okpPrivateKey) D() []byte { - return h.d +func (h *okpPrivateKey) D() ([]byte, bool) { + if h.d != nil { + return h.d, true + } + return nil, false } -func (h *okpPrivateKey) KeyID() string { +func (h *okpPrivateKey) KeyID() (string, bool) { if h.keyID != nil { - return *(h.keyID) + return *(h.keyID), true } - return "" + return "", false } -func (h *okpPrivateKey) KeyOps() KeyOperationList { +func (h *okpPrivateKey) KeyOps() (KeyOperationList, bool) { if h.keyOps != nil { - return *(h.keyOps) + return *(h.keyOps), true } - return nil + return nil, false } -func (h *okpPrivateKey) KeyUsage() string { +func (h *okpPrivateKey) KeyUsage() (string, bool) { if h.keyUsage != nil { - return *(h.keyUsage) + return *(h.keyUsage), true } - return "" + return "", false } -func (h *okpPrivateKey) X() []byte { - return h.x +func (h *okpPrivateKey) X() ([]byte, bool) { + if h.x != nil { + return h.x, true + } + return nil, false } -func (h *okpPrivateKey) X509CertChain() *cert.Chain { - return h.x509CertChain +func (h *okpPrivateKey) X509CertChain() (*cert.Chain, bool) { + return h.x509CertChain, true } -func (h *okpPrivateKey) X509CertThumbprint() string { +func (h *okpPrivateKey) X509CertThumbprint() (string, bool) { if h.x509CertThumbprint != nil { - return *(h.x509CertThumbprint) + return *(h.x509CertThumbprint), true } - return "" + return "", false } -func (h *okpPrivateKey) X509CertThumbprintS256() string { +func (h *okpPrivateKey) X509CertThumbprintS256() (string, bool) { if h.x509CertThumbprintS256 != nil { - return *(h.x509CertThumbprintS256) + return *(h.x509CertThumbprintS256), true } - return "" + return "", false } -func (h *okpPrivateKey) X509URL() string { +func (h *okpPrivateKey) X509URL() (string, bool) { if h.x509URL != nil { - return *(h.x509URL) + return *(h.x509URL), true } - return "" + return "", false } func (h *okpPrivateKey) Has(name string) bool { diff --git a/jwk/rsa.go b/jwk/rsa.go index 0dfbeb75..1245de85 100644 --- a/jwk/rsa.go +++ b/jwk/rsa.go @@ -255,21 +255,31 @@ func rsaThumbprint(hash crypto.Hash, key *rsa.PublicKey) ([]byte, error) { } func validateRSAKey(key interface { - N() []byte - E() []byte + N() ([]byte, bool) + E() ([]byte, bool) }, checkPrivate bool) error { - if len(key.N()) == 0 { + n, ok := key.N() + if !ok { + return fmt.Errorf(`missing "n" value`) + } + + e, ok := key.E() + if !ok { + return fmt.Errorf(`missing "e" value`) + } + + if len(n) == 0 { // Ideally we would like to check for the actual length, but unlike // EC keys, we have nothing in the key itself that will tell us // how many bits this key should have. return fmt.Errorf(`missing "n" value`) } - if len(key.E()) == 0 { + if len(e) == 0 { return fmt.Errorf(`missing "e" value`) } if checkPrivate { - if priv, ok := key.(interface{ D() []byte }); ok { - if len(priv.D()) == 0 { + if priv, ok := key.(keyWithD); ok { + if d, ok := priv.D(); !ok || len(d) == 0 { return fmt.Errorf(`missing "d" value`) } } else { diff --git a/jwk/rsa_gen.go b/jwk/rsa_gen.go index 409894c6..049e5005 100644 --- a/jwk/rsa_gen.go +++ b/jwk/rsa_gen.go @@ -29,8 +29,8 @@ const ( type RSAPublicKey interface { Key - E() []byte - N() []byte + E() ([]byte, bool) + N() ([]byte, bool) } type rsaPublicKey struct { @@ -67,65 +67,71 @@ func (h rsaPublicKey) IsPrivate() bool { return false } -func (h *rsaPublicKey) Algorithm() jwa.KeyAlgorithm { +func (h *rsaPublicKey) Algorithm() (jwa.KeyAlgorithm, bool) { if h.algorithm != nil { - return *(h.algorithm) + return *(h.algorithm), true } - return nil + return nil, false } -func (h *rsaPublicKey) E() []byte { - return h.e +func (h *rsaPublicKey) E() ([]byte, bool) { + if h.e != nil { + return h.e, true + } + return nil, false } -func (h *rsaPublicKey) KeyID() string { +func (h *rsaPublicKey) KeyID() (string, bool) { if h.keyID != nil { - return *(h.keyID) + return *(h.keyID), true } - return "" + return "", false } -func (h *rsaPublicKey) KeyOps() KeyOperationList { +func (h *rsaPublicKey) KeyOps() (KeyOperationList, bool) { if h.keyOps != nil { - return *(h.keyOps) + return *(h.keyOps), true } - return nil + return nil, false } -func (h *rsaPublicKey) KeyUsage() string { +func (h *rsaPublicKey) KeyUsage() (string, bool) { if h.keyUsage != nil { - return *(h.keyUsage) + return *(h.keyUsage), true } - return "" + return "", false } -func (h *rsaPublicKey) N() []byte { - return h.n +func (h *rsaPublicKey) N() ([]byte, bool) { + if h.n != nil { + return h.n, true + } + return nil, false } -func (h *rsaPublicKey) X509CertChain() *cert.Chain { - return h.x509CertChain +func (h *rsaPublicKey) X509CertChain() (*cert.Chain, bool) { + return h.x509CertChain, true } -func (h *rsaPublicKey) X509CertThumbprint() string { +func (h *rsaPublicKey) X509CertThumbprint() (string, bool) { if h.x509CertThumbprint != nil { - return *(h.x509CertThumbprint) + return *(h.x509CertThumbprint), true } - return "" + return "", false } -func (h *rsaPublicKey) X509CertThumbprintS256() string { +func (h *rsaPublicKey) X509CertThumbprintS256() (string, bool) { if h.x509CertThumbprintS256 != nil { - return *(h.x509CertThumbprintS256) + return *(h.x509CertThumbprintS256), true } - return "" + return "", false } -func (h *rsaPublicKey) X509URL() string { +func (h *rsaPublicKey) X509URL() (string, bool) { if h.x509URL != nil { - return *(h.x509URL) + return *(h.x509URL), true } - return "" + return "", false } func (h *rsaPublicKey) Has(name string) bool { @@ -646,14 +652,14 @@ func (h *rsaPublicKey) Keys() []string { type RSAPrivateKey interface { Key - D() []byte - DP() []byte - DQ() []byte - E() []byte - N() []byte - P() []byte - Q() []byte - QI() []byte + D() ([]byte, bool) + DP() ([]byte, bool) + DQ() ([]byte, bool) + E() ([]byte, bool) + N() ([]byte, bool) + P() ([]byte, bool) + Q() ([]byte, bool) + QI() ([]byte, bool) } type rsaPrivateKey struct { @@ -696,89 +702,113 @@ func (h rsaPrivateKey) IsPrivate() bool { return true } -func (h *rsaPrivateKey) Algorithm() jwa.KeyAlgorithm { +func (h *rsaPrivateKey) Algorithm() (jwa.KeyAlgorithm, bool) { if h.algorithm != nil { - return *(h.algorithm) + return *(h.algorithm), true } - return nil + return nil, false } -func (h *rsaPrivateKey) D() []byte { - return h.d +func (h *rsaPrivateKey) D() ([]byte, bool) { + if h.d != nil { + return h.d, true + } + return nil, false } -func (h *rsaPrivateKey) DP() []byte { - return h.dp +func (h *rsaPrivateKey) DP() ([]byte, bool) { + if h.dp != nil { + return h.dp, true + } + return nil, false } -func (h *rsaPrivateKey) DQ() []byte { - return h.dq +func (h *rsaPrivateKey) DQ() ([]byte, bool) { + if h.dq != nil { + return h.dq, true + } + return nil, false } -func (h *rsaPrivateKey) E() []byte { - return h.e +func (h *rsaPrivateKey) E() ([]byte, bool) { + if h.e != nil { + return h.e, true + } + return nil, false } -func (h *rsaPrivateKey) KeyID() string { +func (h *rsaPrivateKey) KeyID() (string, bool) { if h.keyID != nil { - return *(h.keyID) + return *(h.keyID), true } - return "" + return "", false } -func (h *rsaPrivateKey) KeyOps() KeyOperationList { +func (h *rsaPrivateKey) KeyOps() (KeyOperationList, bool) { if h.keyOps != nil { - return *(h.keyOps) + return *(h.keyOps), true } - return nil + return nil, false } -func (h *rsaPrivateKey) KeyUsage() string { +func (h *rsaPrivateKey) KeyUsage() (string, bool) { if h.keyUsage != nil { - return *(h.keyUsage) + return *(h.keyUsage), true } - return "" + return "", false } -func (h *rsaPrivateKey) N() []byte { - return h.n +func (h *rsaPrivateKey) N() ([]byte, bool) { + if h.n != nil { + return h.n, true + } + return nil, false } -func (h *rsaPrivateKey) P() []byte { - return h.p +func (h *rsaPrivateKey) P() ([]byte, bool) { + if h.p != nil { + return h.p, true + } + return nil, false } -func (h *rsaPrivateKey) Q() []byte { - return h.q +func (h *rsaPrivateKey) Q() ([]byte, bool) { + if h.q != nil { + return h.q, true + } + return nil, false } -func (h *rsaPrivateKey) QI() []byte { - return h.qi +func (h *rsaPrivateKey) QI() ([]byte, bool) { + if h.qi != nil { + return h.qi, true + } + return nil, false } -func (h *rsaPrivateKey) X509CertChain() *cert.Chain { - return h.x509CertChain +func (h *rsaPrivateKey) X509CertChain() (*cert.Chain, bool) { + return h.x509CertChain, true } -func (h *rsaPrivateKey) X509CertThumbprint() string { +func (h *rsaPrivateKey) X509CertThumbprint() (string, bool) { if h.x509CertThumbprint != nil { - return *(h.x509CertThumbprint) + return *(h.x509CertThumbprint), true } - return "" + return "", false } -func (h *rsaPrivateKey) X509CertThumbprintS256() string { +func (h *rsaPrivateKey) X509CertThumbprintS256() (string, bool) { if h.x509CertThumbprintS256 != nil { - return *(h.x509CertThumbprintS256) + return *(h.x509CertThumbprintS256), true } - return "" + return "", false } -func (h *rsaPrivateKey) X509URL() string { +func (h *rsaPrivateKey) X509URL() (string, bool) { if h.x509URL != nil { - return *(h.x509URL) + return *(h.x509URL), true } - return "" + return "", false } func (h *rsaPrivateKey) Has(name string) bool { diff --git a/jwk/set.go b/jwk/set.go index 1afaf1e2..6ab79ecf 100644 --- a/jwk/set.go +++ b/jwk/set.go @@ -277,7 +277,8 @@ func (s *set) LookupKeyID(kid string) (Key, bool) { if !ok { return nil, false } - if key.KeyID() == kid { + gotkid, ok := key.KeyID() + if ok && gotkid == kid { return key, true } } diff --git a/jwk/symmetric.go b/jwk/symmetric.go index 29525e13..2ead5d01 100644 --- a/jwk/symmetric.go +++ b/jwk/symmetric.go @@ -77,7 +77,8 @@ func (k *symmetricKey) PublicKey() (Key, error) { } func (k *symmetricKey) Validate() error { - if len(k.Octets()) == 0 { + octets, ok := k.Octets() + if !ok || len(octets) == 0 { return NewKeyValidationError(fmt.Errorf(`jwk.SymmetricKey: missing "k" field`)) } return nil diff --git a/jwk/symmetric_gen.go b/jwk/symmetric_gen.go index 0f0df750..ec522067 100644 --- a/jwk/symmetric_gen.go +++ b/jwk/symmetric_gen.go @@ -22,7 +22,7 @@ const ( type SymmetricKey interface { Key - Octets() []byte + Octets() ([]byte, bool) } type symmetricKey struct { @@ -54,61 +54,64 @@ func (h symmetricKey) KeyType() jwa.KeyType { return jwa.OctetSeq() } -func (h *symmetricKey) Algorithm() jwa.KeyAlgorithm { +func (h *symmetricKey) Algorithm() (jwa.KeyAlgorithm, bool) { if h.algorithm != nil { - return *(h.algorithm) + return *(h.algorithm), true } - return nil + return nil, false } -func (h *symmetricKey) KeyID() string { +func (h *symmetricKey) KeyID() (string, bool) { if h.keyID != nil { - return *(h.keyID) + return *(h.keyID), true } - return "" + return "", false } -func (h *symmetricKey) KeyOps() KeyOperationList { +func (h *symmetricKey) KeyOps() (KeyOperationList, bool) { if h.keyOps != nil { - return *(h.keyOps) + return *(h.keyOps), true } - return nil + return nil, false } -func (h *symmetricKey) KeyUsage() string { +func (h *symmetricKey) KeyUsage() (string, bool) { if h.keyUsage != nil { - return *(h.keyUsage) + return *(h.keyUsage), true } - return "" + return "", false } -func (h *symmetricKey) Octets() []byte { - return h.octets +func (h *symmetricKey) Octets() ([]byte, bool) { + if h.octets != nil { + return h.octets, true + } + return nil, false } -func (h *symmetricKey) X509CertChain() *cert.Chain { - return h.x509CertChain +func (h *symmetricKey) X509CertChain() (*cert.Chain, bool) { + return h.x509CertChain, true } -func (h *symmetricKey) X509CertThumbprint() string { +func (h *symmetricKey) X509CertThumbprint() (string, bool) { if h.x509CertThumbprint != nil { - return *(h.x509CertThumbprint) + return *(h.x509CertThumbprint), true } - return "" + return "", false } -func (h *symmetricKey) X509CertThumbprintS256() string { +func (h *symmetricKey) X509CertThumbprintS256() (string, bool) { if h.x509CertThumbprintS256 != nil { - return *(h.x509CertThumbprintS256) + return *(h.x509CertThumbprintS256), true } - return "" + return "", false } -func (h *symmetricKey) X509URL() string { +func (h *symmetricKey) X509URL() (string, bool) { if h.x509URL != nil { - return *(h.x509URL) + return *(h.x509URL), true } - return "" + return "", false } func (h *symmetricKey) Has(name string) bool { diff --git a/jws/jws.go b/jws/jws.go index 21218b10..88884867 100644 --- a/jws/jws.go +++ b/jws/jws.go @@ -249,7 +249,7 @@ func Sign(payload []byte, options ...SignOption) ([]byte, error) { } if key, ok := signer.key.(jwk.Key); ok { - if kid := key.KeyID(); kid != "" { + if kid, ok := key.KeyID(); ok && kid != "" { if err := protected.Set(KeyIDKey, kid); err != nil { return nil, fmt.Errorf(`failed to set "kid" header: %w`, err) } diff --git a/jws/key_provider.go b/jws/key_provider.go index fbfbacc6..6a2820c5 100644 --- a/jws/key_provider.go +++ b/jws/key_provider.go @@ -110,14 +110,17 @@ type keySetProvider struct { } func (kp *keySetProvider) selectKey(sink KeySink, key jwk.Key, sig *Signature, _ *Message) error { - if usage := key.KeyUsage(); usage != "" && usage != jwk.ForSignature.String() { - return nil + if usage, ok := key.KeyUsage(); ok { + // it's okay if use: "". we'll assume it's "sig" + if usage != "" && usage != jwk.ForSignature.String() { + return nil + } } - if v := key.Algorithm(); v != nil { + if v, ok := key.Algorithm(); ok { salg, ok := jwa.LookupSignatureAlgorithm(v.String()) if !ok { - return fmt.Errorf(`invalid signature algorithm %q`, key.Algorithm()) + return fmt.Errorf(`invalid signature algorithm %q`, v) } sink.Key(salg, key) @@ -187,7 +190,7 @@ func (kp *keySetProvider) FetchKeys(_ context.Context, sink KeySink, sig *Signat ok = false for i := range kp.set.Len() { key, _ := kp.set.Key(i) - if key.KeyID() != wantedKid { + if kid, ok := key.KeyID(); !ok || kid != wantedKid { continue } diff --git a/jws/message.go b/jws/message.go index 902d1823..86a6344d 100644 --- a/jws/message.go +++ b/jws/message.go @@ -113,7 +113,7 @@ func (s *Signature) Sign(payload []byte, signer Signer, key interface{}) ([]byte // If the key is a jwk.Key instance, obtain the raw key if jwkKey, ok := key.(jwk.Key); ok { // If we have a key ID specified by this jwk.Key, use that in the header - if kid := jwkKey.KeyID(); kid != "" { + if kid, ok := jwkKey.KeyID(); ok && kid != "" { if err := hdrs.Set(jwk.KeyIDKey, kid); err != nil { return nil, nil, fmt.Errorf(`set key ID from jwk.Key: %w`, err) } diff --git a/tools/cmd/genjwk/main.go b/tools/cmd/genjwk/main.go index c1e39d3a..2351887b 100644 --- a/tools/cmd/genjwk/main.go +++ b/tools/cmd/genjwk/main.go @@ -180,7 +180,7 @@ func generateObject(o *codegen.Output, kt *KeyType, obj *codegen.Object) error { if f.Bool(`is_std`) { continue } - o.L("%s() %s", f.GetterMethod(true), f.Type()) + o.L("%s() (%s, bool)", f.GetterMethod(true), f.Type()) } o.L("}") @@ -217,7 +217,7 @@ func generateObject(o *codegen.Output, kt *KeyType, obj *codegen.Object) error { } for _, f := range obj.Fields() { - o.LL("func (h *%s) %s() ", structName, f.GetterMethod(true)) + o.LL("func (h *%s) %s() (", structName, f.GetterMethod(true)) if v := f.String(`getter_return_value`); v != "" { o.R("%s", v) } else if IsPointer(f) && f.Bool(`noDeref`) { @@ -225,24 +225,29 @@ func generateObject(o *codegen.Output, kt *KeyType, obj *codegen.Object) error { } else { o.R("%s", PointerElem(f)) } - o.R(" {") + o.R(", bool) {") if f.Bool(`hasGet`) { o.L("if h.%s != nil {", f.Name(false)) - o.L("return h.%s.Get()", f.Name(false)) + o.L("return h.%s.Get(), true", f.Name(false)) o.L("}") - o.L("return %s", codegen.ZeroVal(PointerElem(f))) + o.L("return %s, false", codegen.ZeroVal(PointerElem(f))) } else if !IsPointer(f) { if fieldStorageTypeIsIndirect(f.Type()) { o.L("if h.%s != nil {", f.Name(false)) - o.L("return *(h.%s)", f.Name(false)) + o.L("return *(h.%s), true", f.Name(false)) o.L("}") - o.L("return %s", codegen.ZeroVal(PointerElem(f))) + o.L("return %s, false", codegen.ZeroVal(PointerElem(f))) + } else if strings.HasPrefix(f.Type(), `[]`) { + o.L("if h.%s != nil {", f.Name(false)) + o.L("return h.%s, true", f.Name(false)) + o.L("}") + o.L("return nil, false") } else { - o.L("return h.%s", f.Name(false)) + o.L("return h.%s, true", f.Name(false)) } } else { - o.L(`return h.%s`, f.Name(false)) + o.L(`return h.%s, true`, f.Name(false)) } o.L("}") // func (h *stdHeaders) %s() %s } @@ -706,10 +711,10 @@ func generateGenericHeaders(fields codegen.FieldList) error { if f.Name(false) == "algorithm" { o.LL("// Algorithm returns the value of the `alg` field.") o.L("//") - o.L("// This field may contain either `jwk.SignatureAlgorithm` or `jwk.KeyEncryptionAlgorithm`.") + o.L("// This field may contain either `jwk.SignatureAlgorithm`, `jwk.KeyEncryptionAlgorithm`, or `jwk.ContentEncryptionAlgorithm`.") o.L("// This is why there exists a `jwa.KeyAlgorithm` type that encompasses both types.") } - o.L("%s() ", f.GetterMethod(true)) + o.L("%s() (", f.GetterMethod(true)) if v := f.String(`getter_return_value`); v != "" { o.R("%s", v) } else if IsPointer(f) && f.Bool(`noDeref`) { @@ -717,6 +722,7 @@ func generateGenericHeaders(fields codegen.FieldList) error { } else { o.R("%s", PointerElem(f)) } + o.R(", bool)") } o.L("}") From 79cddb55062db6c22a410cba87c7be7c847c864d Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Mon, 7 Oct 2024 10:10:52 +0900 Subject: [PATCH 2/6] fix jwe test --- jwe/jwe_test.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/jwe/jwe_test.go b/jwe/jwe_test.go index 9461b6fd..69bd71f1 100644 --- a/jwe/jwe_test.go +++ b/jwe/jwe_test.go @@ -614,7 +614,9 @@ func TestGH554(t *testing.T) { pubkey, err := jwk.PublicKeyOf(privkey) require.NoError(t, err, `jwk.PublicKeyOf() should succeed`) - require.Equal(t, keyID, pubkey.KeyID(), `key ID should match`) + kid, ok := pubkey.KeyID() + require.True(t, ok, `key ID should be present`) + require.Equal(t, keyID, kid, `key ID should match`) encrypted, err := jwe.Encrypt([]byte(plaintext), jwe.WithKey(jwa.ECDH_ES(), pubkey)) require.NoError(t, err, `jwk.Encrypt() should succeed`) @@ -625,9 +627,9 @@ func TestGH554(t *testing.T) { recipients := msg.Recipients() // The epk must have the same key ID as the original - kid, ok := recipients[0].Headers().KeyID() + epkKid, ok := recipients[0].Headers().KeyID() require.True(t, ok, `key ID should be present`) - require.Equal(t, keyID, kid, `key ID in epk should match`) + require.Equal(t, keyID, epkKid, `key ID in epk should match`) } func TestGH803(t *testing.T) { From cce2543cbd5120e614fa2f95baaa4c30c468a81f Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Mon, 7 Oct 2024 10:14:31 +0900 Subject: [PATCH 3/6] fix jwx tests --- jws/jws_test.go | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/jws/jws_test.go b/jws/jws_test.go index 8a9d9b1d..5a5afdf0 100644 --- a/jws/jws_test.go +++ b/jws/jws_test.go @@ -1420,7 +1420,14 @@ func TestJKU(t *testing.T) { for _, key := range []jwk.Key{unusedKeys[0], keys[1], unusedKeys[1]} { pubkey, err := jwk.PublicKeyOf(key) require.NoError(t, err, `jwk.PublicKeyOf should succeed`) - require.Equal(t, pubkey.KeyID(), key.KeyID(), `key ID should be populated`) + + kid, ok := key.KeyID() + require.True(t, ok, `key ID should be populated`) + + pubkid, ok := pubkey.KeyID() + require.True(t, ok, `key ID should be populated`) + + require.Equal(t, kid, pubkid, `key ID should be populated`) set.AddKey(pubkey) } srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { @@ -1844,7 +1851,8 @@ func TestValidateKey(t *testing.T) { pubKey, err := jwk.PublicKeyOf(privKey) require.NoError(t, err, `jwk.PublicKeyOf should succeed`) - n := pubKey.(jwk.RSAPublicKey).N() + n, ok := pubKey.(jwk.RSAPublicKey).N() + require.True(t, ok, `N should be present`) // Set N to an empty value require.NoError(t, pubKey.Set(jwk.RSANKey, []byte(nil)), `jwk.Set should succeed`) From 1e5423f7f5c5524aeed40997b6629afe5c55fae8 Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Mon, 7 Oct 2024 10:18:38 +0900 Subject: [PATCH 4/6] Fix jwt test --- jwt/jwt_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/jwt/jwt_test.go b/jwt/jwt_test.go index a09c9e01..fdd97a3b 100644 --- a/jwt/jwt_test.go +++ b/jwt/jwt_test.go @@ -553,7 +553,9 @@ func TestSignJWK(t *testing.T) { require.NoError(t, key.Set(jwk.AlgorithmKey, jwa.RS256()), `key.Set should succeed`) tok := jwt.New() - signed, err := jwt.Sign(tok, jwt.WithKey(key.Algorithm(), key)) + alg, ok := key.Algorithm() + require.True(t, ok, `key.Algorithm should succeed`) + signed, err := jwt.Sign(tok, jwt.WithKey(alg, key)) require.Nil(t, err) header, err := jws.ParseString(string(signed)) From ca120945fd15f3625cc3309f730d6fae1938cbbf Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Mon, 7 Oct 2024 10:24:30 +0900 Subject: [PATCH 5/6] fix example --- .../jwk_key_specific_methods_example_test.go | 16 ++++++------- .../jwx_register_ec_and_key_example_test.go | 24 +++++++++++++++---- 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/examples/jwk_key_specific_methods_example_test.go b/examples/jwk_key_specific_methods_example_test.go index d39cc2cd..ef5124f2 100644 --- a/examples/jwk_key_specific_methods_example_test.go +++ b/examples/jwk_key_specific_methods_example_test.go @@ -31,14 +31,14 @@ func ExampleJWK_KeySpecificMethods() { // generated the contents will be different, and thus our // tests would fail. But here you can see that once you // convert the type you can access the RSA-specific methods - _ = rsakey.D() - _ = rsakey.DP() - _ = rsakey.DQ() - _ = rsakey.E() - _ = rsakey.N() - _ = rsakey.P() - _ = rsakey.Q() - _ = rsakey.QI() + _, _ = rsakey.D() + _, _ = rsakey.DP() + _, _ = rsakey.DQ() + _, _ = rsakey.E() + _, _ = rsakey.N() + _, _ = rsakey.P() + _, _ = rsakey.Q() + _, _ = rsakey.QI() // OUTPUT: // } diff --git a/examples/jwx_register_ec_and_key_example_test.go b/examples/jwx_register_ec_and_key_example_test.go index f232b90f..2631bcbe 100644 --- a/examples/jwx_register_ec_and_key_example_test.go +++ b/examples/jwx_register_ec_and_key_example_test.go @@ -55,8 +55,8 @@ func convertJWKToShangMiSm2(key jwk.Key, hint interface{}) (interface{}, error) if !ok { return nil, fmt.Errorf(`invalid key type %T: %w`, key, jwk.ContinueError()) } - if ecdsaKey.Crv() != SM2 { - return nil, fmt.Errorf(`cannot convert curve of type %s to ShangMi key: %w`, ecdsaKey.Crv(), jwk.ContinueError()) + if v, ok := ecdsaKey.Crv(); !ok || v != SM2 { + return nil, fmt.Errorf(`cannote convert curve of type %s to ShangMi key: %w`, v, jwk.ContinueError()) } switch hint.(type) { @@ -67,9 +67,23 @@ func convertJWKToShangMiSm2(key jwk.Key, hint interface{}) (interface{}, error) var ret sm2.PrivateKey ret.PublicKey.Curve = sm2.P256() - ret.D = (&big.Int{}).SetBytes(ecdsaKey.D()) - ret.PublicKey.X = (&big.Int{}).SetBytes(ecdsaKey.X()) - ret.PublicKey.Y = (&big.Int{}).SetBytes(ecdsaKey.Y()) + d, ok := ecdsaKey.D() + if !ok { + return nil, fmt.Errorf(`missing D field in ECDSA private key: %w`, jwk.ContinueError()) + } + ret.D = (&big.Int{}).SetBytes(d) + + x, ok := ecdsaKey.X() + if !ok { + return nil, fmt.Errorf(`missing X field in ECDSA private key: %w`, jwk.ContinueError()) + } + ret.PublicKey.X = (&big.Int{}).SetBytes(x) + + y, ok := ecdsaKey.Y() + if !ok { + return nil, fmt.Errorf(`missing Y field in ECDSA private key: %w`, jwk.ContinueError()) + } + ret.PublicKey.Y = (&big.Int{}).SetBytes(y) return &ret, nil } From de07b8cd5a8679f2f0c369e7dc3a31c62fc78e57 Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Mon, 7 Oct 2024 10:33:32 +0900 Subject: [PATCH 6/6] Update changes --- Changes-v3.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Changes-v3.md b/Changes-v3.md index 57dbb616..c96b38ff 100644 --- a/Changes-v3.md +++ b/Changes-v3.md @@ -38,6 +38,9 @@ These are changes that are incompatible with the v2.x.x version. ## JWK +* All convenience accssors (e.g. `Algorithm`, `Crv`) now return `(T, bool)` instead + of just `T`, except `KeyType`, which _always_ returns a valid value. + * `jwk.KeyUsageType` can now be configured so that it's possible to assign values other than "sig" and "enc" via `jwk.RegisterKeyUsage()`. Furthermore, strict checks can be turned on/off against these registered values