Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change accessors in jwk to return (T, bool) #1207

Merged
merged 6 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Changes-v3.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ These are changes that are incompatible with the v2.x.x version.

## JWK

* All convenience accssors (e.g. `Algorithm`, `Crv`) now return `(T, bool)` instead
of just `T`, except `KeyType`, which _always_ returns a valid value.

* `jwk.KeyUsageType` can now be configured so that it's possible to assign values
other than "sig" and "enc" via `jwk.RegisterKeyUsage()`. Furthermore, strict
checks can be turned on/off against these registered values
Expand Down
16 changes: 8 additions & 8 deletions examples/jwk_key_specific_methods_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ func ExampleJWK_KeySpecificMethods() {
// generated the contents will be different, and thus our
// tests would fail. But here you can see that once you
// convert the type you can access the RSA-specific methods
_ = rsakey.D()
_ = rsakey.DP()
_ = rsakey.DQ()
_ = rsakey.E()
_ = rsakey.N()
_ = rsakey.P()
_ = rsakey.Q()
_ = rsakey.QI()
_, _ = rsakey.D()
_, _ = rsakey.DP()
_, _ = rsakey.DQ()
_, _ = rsakey.E()
_, _ = rsakey.N()
_, _ = rsakey.P()
_, _ = rsakey.Q()
_, _ = rsakey.QI()
// OUTPUT:
//
}
24 changes: 19 additions & 5 deletions examples/jwx_register_ec_and_key_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ func convertJWKToShangMiSm2(key jwk.Key, hint interface{}) (interface{}, error)
if !ok {
return nil, fmt.Errorf(`invalid key type %T: %w`, key, jwk.ContinueError())
}
if ecdsaKey.Crv() != SM2 {
return nil, fmt.Errorf(`cannot convert curve of type %s to ShangMi key: %w`, ecdsaKey.Crv(), jwk.ContinueError())
if v, ok := ecdsaKey.Crv(); !ok || v != SM2 {
return nil, fmt.Errorf(`cannote convert curve of type %s to ShangMi key: %w`, v, jwk.ContinueError())
}

switch hint.(type) {
Expand All @@ -67,9 +67,23 @@ func convertJWKToShangMiSm2(key jwk.Key, hint interface{}) (interface{}, error)

var ret sm2.PrivateKey
ret.PublicKey.Curve = sm2.P256()
ret.D = (&big.Int{}).SetBytes(ecdsaKey.D())
ret.PublicKey.X = (&big.Int{}).SetBytes(ecdsaKey.X())
ret.PublicKey.Y = (&big.Int{}).SetBytes(ecdsaKey.Y())
d, ok := ecdsaKey.D()
if !ok {
return nil, fmt.Errorf(`missing D field in ECDSA private key: %w`, jwk.ContinueError())
}
ret.D = (&big.Int{}).SetBytes(d)

x, ok := ecdsaKey.X()
if !ok {
return nil, fmt.Errorf(`missing X field in ECDSA private key: %w`, jwk.ContinueError())
}
ret.PublicKey.X = (&big.Int{}).SetBytes(x)

y, ok := ecdsaKey.Y()
if !ok {
return nil, fmt.Errorf(`missing Y field in ECDSA private key: %w`, jwk.ContinueError())
}
ret.PublicKey.Y = (&big.Int{}).SetBytes(y)
return &ret, nil
}

Expand Down
2 changes: 1 addition & 1 deletion jwe/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ type KeyEncrypter interface {
// As of this writing, this is solely used to identify KeyEncrypter
// objects that also carry a key ID on its own.
type KeyIDer interface {
KeyID() string
KeyID() (string, bool)
}

// KeyDecrypter is an interface for objects that can decrypt a content
Expand Down
8 changes: 6 additions & 2 deletions jwe/jwe.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,15 @@ func (b *recipientBuilder) Build(cek []byte, calg jwa.ContentEncryptionAlgorithm
if ke, ok := b.key.(KeyEncrypter); ok {
enc = &keyEncrypterWrapper{encrypter: ke}
if kider, ok := enc.(KeyIDer); ok {
keyID = kider.KeyID()
if v, ok := kider.KeyID(); ok {
keyID = v
}
}
} else if jwkKey, ok := b.key.(jwk.Key); ok {
// Meanwhile, grab the kid as well
keyID = jwkKey.KeyID()
if v, ok := jwkKey.KeyID(); ok {
keyID = v
}

var raw interface{}
if err := jwk.Export(jwkKey, &raw); err != nil {
Expand Down
8 changes: 5 additions & 3 deletions jwe/jwe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,9 @@ func TestGH554(t *testing.T) {

pubkey, err := jwk.PublicKeyOf(privkey)
require.NoError(t, err, `jwk.PublicKeyOf() should succeed`)
require.Equal(t, keyID, pubkey.KeyID(), `key ID should match`)
kid, ok := pubkey.KeyID()
require.True(t, ok, `key ID should be present`)
require.Equal(t, keyID, kid, `key ID should match`)

encrypted, err := jwe.Encrypt([]byte(plaintext), jwe.WithKey(jwa.ECDH_ES(), pubkey))
require.NoError(t, err, `jwk.Encrypt() should succeed`)
Expand All @@ -625,9 +627,9 @@ func TestGH554(t *testing.T) {
recipients := msg.Recipients()

// The epk must have the same key ID as the original
kid, ok := recipients[0].Headers().KeyID()
epkKid, ok := recipients[0].Headers().KeyID()
require.True(t, ok, `key ID should be present`)
require.Equal(t, keyID, kid, `key ID in epk should match`)
require.Equal(t, keyID, epkKid, `key ID in epk should match`)
}

func TestGH803(t *testing.T) {
Expand Down
10 changes: 6 additions & 4 deletions jwe/key_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,16 @@ type keySetProvider struct {
}

func (kp *keySetProvider) selectKey(sink KeySink, key jwk.Key, _ Recipient, _ *Message) error {
if usage := key.KeyUsage(); usage != "" && usage != jwk.ForEncryption.String() {
return nil
if usage, ok := key.KeyUsage(); ok {
if usage != "" && usage != jwk.ForEncryption.String() {
return nil
}
}

if v := key.Algorithm(); v != nil {
if v, ok := key.Algorithm(); ok {
kalg, ok := jwa.LookupKeyEncryptionAlgorithm(v.String())
if !ok {
return fmt.Errorf(`invalid key encryption algorithm %s`, key.Algorithm())
return fmt.Errorf(`invalid key encryption algorithm %s`, v)
}

sink.Key(kalg, key)
Expand Down
39 changes: 27 additions & 12 deletions jwk/ecdsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,12 @@ func ecdsaJWKToRaw(keyif Key, hint interface{}) (interface{}, error) {

k.mu.RLock()
defer k.mu.RUnlock()
return buildECDSAPublicKey(k.Crv(), k.x, k.y)

crv, ok := k.Crv()
if !ok {
return nil, fmt.Errorf(`missing "crv" field`)
}
return buildECDSAPublicKey(crv, k.x, k.y)
case *ecdsaPrivateKey:
switch hint.(type) {
case ecdsa.PrivateKey, *ecdsa.PrivateKey, interface{}:
Expand All @@ -123,7 +128,12 @@ func ecdsaJWKToRaw(keyif Key, hint interface{}) (interface{}, error) {

k.mu.RLock()
defer k.mu.RUnlock()
pubk, err := buildECDSAPublicKey(k.Crv(), k.x, k.y)

crv, ok := k.Crv()
if !ok {
return nil, fmt.Errorf(`missing "crv" field`)
}
pubk, err := buildECDSAPublicKey(crv, k.x, k.y)
if err != nil {
return nil, fmt.Errorf(`failed to build public key: %w`, err)
}
Expand Down Expand Up @@ -232,28 +242,33 @@ func (k ecdsaPrivateKey) Thumbprint(hash crypto.Hash) ([]byte, error) {
}

func ecdsaValidateKey(k interface {
Crv() jwa.EllipticCurveAlgorithm
X() []byte
Y() []byte
Crv() (jwa.EllipticCurveAlgorithm, bool)
X() ([]byte, bool)
Y() ([]byte, bool)
}, checkPrivate bool) error {
crv, err := ourecdsa.CurveFromAlgorithm(k.Crv())
crvtyp, ok := k.Crv()
if !ok {
return fmt.Errorf(`missing "crv" field`)
}

crv, err := ourecdsa.CurveFromAlgorithm(crvtyp)
if err != nil {
return fmt.Errorf(`invalid curve algorithm %q: %w`, k.Crv(), err)
return fmt.Errorf(`invalid curve algorithm %q: %w`, crvtyp, err)
}

keySize := ecutil.CalculateKeySize(crv)
if x := k.X(); len(x) != keySize {
if x, ok := k.X(); !ok || len(x) != keySize {
return fmt.Errorf(`invalid "x" length (%d) for curve %q`, len(x), crv.Params().Name)
}

if y := k.Y(); len(y) != keySize {
if y, ok := k.Y(); !ok || len(y) != keySize {
return fmt.Errorf(`invalid "y" length (%d) for curve %q`, len(y), crv.Params().Name)
}

if checkPrivate {
if priv, ok := k.(interface{ D() []byte }); ok {
if len(priv.D()) != keySize {
return fmt.Errorf(`invalid "d" length (%d) for curve %q`, len(priv.D()), crv.Params().Name)
if priv, ok := k.(keyWithD); ok {
if d, ok := priv.D(); !ok || len(d) != keySize {
return fmt.Errorf(`invalid "d" length (%d) for curve %q`, len(d), crv.Params().Name)
}
} else {
return fmt.Errorf(`missing "d" value`)
Expand Down
Loading
Loading