Skip to content

Commit

Permalink
Change accessors in jwk to return (T, bool) (#1207)
Browse files Browse the repository at this point in the history
* Change accessors in jwk to return (T, bool)

* fix jwe test

* fix jwx tests

* Fix jwt test

* fix example

* Update changes

---------

Co-authored-by: Daisuke Maki <lestrrat+github@users.noreplay.github.com>
  • Loading branch information
lestrrat and Daisuke Maki authored Oct 7, 2024
1 parent b7cb51f commit 5b97182
Show file tree
Hide file tree
Showing 26 changed files with 540 additions and 347 deletions.
3 changes: 3 additions & 0 deletions Changes-v3.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ These are changes that are incompatible with the v2.x.x version.

## JWK

* All convenience accssors (e.g. `Algorithm`, `Crv`) now return `(T, bool)` instead
of just `T`, except `KeyType`, which _always_ returns a valid value.

* `jwk.KeyUsageType` can now be configured so that it's possible to assign values
other than "sig" and "enc" via `jwk.RegisterKeyUsage()`. Furthermore, strict
checks can be turned on/off against these registered values
Expand Down
16 changes: 8 additions & 8 deletions examples/jwk_key_specific_methods_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ func ExampleJWK_KeySpecificMethods() {
// generated the contents will be different, and thus our
// tests would fail. But here you can see that once you
// convert the type you can access the RSA-specific methods
_ = rsakey.D()
_ = rsakey.DP()
_ = rsakey.DQ()
_ = rsakey.E()
_ = rsakey.N()
_ = rsakey.P()
_ = rsakey.Q()
_ = rsakey.QI()
_, _ = rsakey.D()
_, _ = rsakey.DP()
_, _ = rsakey.DQ()
_, _ = rsakey.E()
_, _ = rsakey.N()
_, _ = rsakey.P()
_, _ = rsakey.Q()
_, _ = rsakey.QI()
// OUTPUT:
//
}
24 changes: 19 additions & 5 deletions examples/jwx_register_ec_and_key_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ func convertJWKToShangMiSm2(key jwk.Key, hint interface{}) (interface{}, error)
if !ok {
return nil, fmt.Errorf(`invalid key type %T: %w`, key, jwk.ContinueError())
}
if ecdsaKey.Crv() != SM2 {
return nil, fmt.Errorf(`cannot convert curve of type %s to ShangMi key: %w`, ecdsaKey.Crv(), jwk.ContinueError())
if v, ok := ecdsaKey.Crv(); !ok || v != SM2 {
return nil, fmt.Errorf(`cannote convert curve of type %s to ShangMi key: %w`, v, jwk.ContinueError())
}

switch hint.(type) {
Expand All @@ -67,9 +67,23 @@ func convertJWKToShangMiSm2(key jwk.Key, hint interface{}) (interface{}, error)

var ret sm2.PrivateKey
ret.PublicKey.Curve = sm2.P256()
ret.D = (&big.Int{}).SetBytes(ecdsaKey.D())
ret.PublicKey.X = (&big.Int{}).SetBytes(ecdsaKey.X())
ret.PublicKey.Y = (&big.Int{}).SetBytes(ecdsaKey.Y())
d, ok := ecdsaKey.D()
if !ok {
return nil, fmt.Errorf(`missing D field in ECDSA private key: %w`, jwk.ContinueError())
}
ret.D = (&big.Int{}).SetBytes(d)

x, ok := ecdsaKey.X()
if !ok {
return nil, fmt.Errorf(`missing X field in ECDSA private key: %w`, jwk.ContinueError())
}
ret.PublicKey.X = (&big.Int{}).SetBytes(x)

y, ok := ecdsaKey.Y()
if !ok {
return nil, fmt.Errorf(`missing Y field in ECDSA private key: %w`, jwk.ContinueError())
}
ret.PublicKey.Y = (&big.Int{}).SetBytes(y)
return &ret, nil
}

Expand Down
2 changes: 1 addition & 1 deletion jwe/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ type KeyEncrypter interface {
// As of this writing, this is solely used to identify KeyEncrypter
// objects that also carry a key ID on its own.
type KeyIDer interface {
KeyID() string
KeyID() (string, bool)
}

// KeyDecrypter is an interface for objects that can decrypt a content
Expand Down
8 changes: 6 additions & 2 deletions jwe/jwe.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,15 @@ func (b *recipientBuilder) Build(cek []byte, calg jwa.ContentEncryptionAlgorithm
if ke, ok := b.key.(KeyEncrypter); ok {
enc = &keyEncrypterWrapper{encrypter: ke}
if kider, ok := enc.(KeyIDer); ok {
keyID = kider.KeyID()
if v, ok := kider.KeyID(); ok {
keyID = v
}
}
} else if jwkKey, ok := b.key.(jwk.Key); ok {
// Meanwhile, grab the kid as well
keyID = jwkKey.KeyID()
if v, ok := jwkKey.KeyID(); ok {
keyID = v
}

var raw interface{}
if err := jwk.Export(jwkKey, &raw); err != nil {
Expand Down
8 changes: 5 additions & 3 deletions jwe/jwe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,9 @@ func TestGH554(t *testing.T) {

pubkey, err := jwk.PublicKeyOf(privkey)
require.NoError(t, err, `jwk.PublicKeyOf() should succeed`)
require.Equal(t, keyID, pubkey.KeyID(), `key ID should match`)
kid, ok := pubkey.KeyID()
require.True(t, ok, `key ID should be present`)
require.Equal(t, keyID, kid, `key ID should match`)

encrypted, err := jwe.Encrypt([]byte(plaintext), jwe.WithKey(jwa.ECDH_ES(), pubkey))
require.NoError(t, err, `jwk.Encrypt() should succeed`)
Expand All @@ -625,9 +627,9 @@ func TestGH554(t *testing.T) {
recipients := msg.Recipients()

// The epk must have the same key ID as the original
kid, ok := recipients[0].Headers().KeyID()
epkKid, ok := recipients[0].Headers().KeyID()
require.True(t, ok, `key ID should be present`)
require.Equal(t, keyID, kid, `key ID in epk should match`)
require.Equal(t, keyID, epkKid, `key ID in epk should match`)
}

func TestGH803(t *testing.T) {
Expand Down
10 changes: 6 additions & 4 deletions jwe/key_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,16 @@ type keySetProvider struct {
}

func (kp *keySetProvider) selectKey(sink KeySink, key jwk.Key, _ Recipient, _ *Message) error {
if usage := key.KeyUsage(); usage != "" && usage != jwk.ForEncryption.String() {
return nil
if usage, ok := key.KeyUsage(); ok {
if usage != "" && usage != jwk.ForEncryption.String() {
return nil
}
}

if v := key.Algorithm(); v != nil {
if v, ok := key.Algorithm(); ok {
kalg, ok := jwa.LookupKeyEncryptionAlgorithm(v.String())
if !ok {
return fmt.Errorf(`invalid key encryption algorithm %s`, key.Algorithm())
return fmt.Errorf(`invalid key encryption algorithm %s`, v)
}

sink.Key(kalg, key)
Expand Down
39 changes: 27 additions & 12 deletions jwk/ecdsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,12 @@ func ecdsaJWKToRaw(keyif Key, hint interface{}) (interface{}, error) {

k.mu.RLock()
defer k.mu.RUnlock()
return buildECDSAPublicKey(k.Crv(), k.x, k.y)

crv, ok := k.Crv()
if !ok {
return nil, fmt.Errorf(`missing "crv" field`)
}
return buildECDSAPublicKey(crv, k.x, k.y)
case *ecdsaPrivateKey:
switch hint.(type) {
case ecdsa.PrivateKey, *ecdsa.PrivateKey, interface{}:
Expand All @@ -123,7 +128,12 @@ func ecdsaJWKToRaw(keyif Key, hint interface{}) (interface{}, error) {

k.mu.RLock()
defer k.mu.RUnlock()
pubk, err := buildECDSAPublicKey(k.Crv(), k.x, k.y)

crv, ok := k.Crv()
if !ok {
return nil, fmt.Errorf(`missing "crv" field`)
}
pubk, err := buildECDSAPublicKey(crv, k.x, k.y)
if err != nil {
return nil, fmt.Errorf(`failed to build public key: %w`, err)
}
Expand Down Expand Up @@ -232,28 +242,33 @@ func (k ecdsaPrivateKey) Thumbprint(hash crypto.Hash) ([]byte, error) {
}

func ecdsaValidateKey(k interface {
Crv() jwa.EllipticCurveAlgorithm
X() []byte
Y() []byte
Crv() (jwa.EllipticCurveAlgorithm, bool)
X() ([]byte, bool)
Y() ([]byte, bool)
}, checkPrivate bool) error {
crv, err := ourecdsa.CurveFromAlgorithm(k.Crv())
crvtyp, ok := k.Crv()
if !ok {
return fmt.Errorf(`missing "crv" field`)
}

crv, err := ourecdsa.CurveFromAlgorithm(crvtyp)
if err != nil {
return fmt.Errorf(`invalid curve algorithm %q: %w`, k.Crv(), err)
return fmt.Errorf(`invalid curve algorithm %q: %w`, crvtyp, err)
}

keySize := ecutil.CalculateKeySize(crv)
if x := k.X(); len(x) != keySize {
if x, ok := k.X(); !ok || len(x) != keySize {
return fmt.Errorf(`invalid "x" length (%d) for curve %q`, len(x), crv.Params().Name)
}

if y := k.Y(); len(y) != keySize {
if y, ok := k.Y(); !ok || len(y) != keySize {
return fmt.Errorf(`invalid "y" length (%d) for curve %q`, len(y), crv.Params().Name)
}

if checkPrivate {
if priv, ok := k.(interface{ D() []byte }); ok {
if len(priv.D()) != keySize {
return fmt.Errorf(`invalid "d" length (%d) for curve %q`, len(priv.D()), crv.Params().Name)
if priv, ok := k.(keyWithD); ok {
if d, ok := priv.D(); !ok || len(d) != keySize {
return fmt.Errorf(`invalid "d" length (%d) for curve %q`, len(d), crv.Params().Name)
}
} else {
return fmt.Errorf(`missing "d" value`)
Expand Down
Loading

0 comments on commit 5b97182

Please sign in to comment.