During traffic, the JWK set can have the same key inserted multiple times #1215

Closed
weihong-xu-hpe opened this issue Oct 14, 2024 · 3 comments

@weihong-xu-hpe

Describe the bug
Because of how our token issuer behaves, my team tried to implement the following logic:

  1. Create a JWKS cache with jwk.NewSet() (call it the local cache).
  2. Use the local JWKS cache to verify incoming tokens (jwt.Parse with jwt.WithKeySet).
  3. When the token parser reports "failed to find key with key ID", kick off the JWKS download to get the key set from the remote issuer (call it the new JWKS).
  4. Loop over the new JWKS and check whether each key already belongs to the local cache; if not, add it to the local JWKS cache with AddKey.
  5. Verify the token again.

Our service currently handles close to 100 RPS per instance, and downloading the JWKS can take up to 1 second.

From reading the code, the AddKey logic lives in jwk/set.go (AddKey), and that part is guarded by an RWMutex. The comments around the set all say that if the key already exists in the local key set, I should get the following error:

AddKey: key already exists

Instead, this is what we saw in our production logs:
[Screenshots of production logs]
localKeySet adds the same key again and again, so localKeySet keeps growing.

The value printed in the logs above is the kid of that key.

Sample code logic:

```go
	// p.localKeySet is created once, at startup, with jwk.NewSet().
	jwtToken, err := jwt.Parse(base64EncodedToken, jwt.WithKeySet(p.localKeySet, jws.WithInferAlgorithmFromKey(p.tknParseOpts.inferAlgorithmFromKey)))
	if err == nil {
		if err := p.validateIssuer(jwtToken); err != nil {
			return noOpToken{}, err
		}

		return p.extractToken(jwtToken)
	}

	// Only if Parse fails because the key cannot be found do we try to update the JWKS.
	// This string match is a workaround: errors.Is does not work here for invalid-token cases.
	if strings.Contains(err.Error(), "failed to find key with key ID") {
		// 2. Verification failed, so try to update the JWKS for the token's issuer.
		// jwtToken is nil here, so an extra ParseInsecure is needed to read the issuer.
		logger := logs.ProvideLogger()
		// Log the original lookup error before err is reassigned below.
		logger.InfoContext(ctx, "Local JWKS cannot verify token, trying to update JWKS", "error", err)
		jwtToken, err = jwt.ParseInsecure(base64EncodedToken)
		if err != nil {
			return noOpToken{}, fmt.Errorf("parsing token insecurely: %w", err)
		}
		issuer := jwtToken.Issuer()
		if err := p.validateIssuer(jwtToken); err != nil {
			return noOpToken{}, err
		}

		// 3. Update the JWKS cache.
		if err := p.updateKeySetFromJWKS(ctx, issuer); err != nil {
			return noOpToken{}, err
		}

		// 4. Parse the token again using the local JWKS, with the same options as the first attempt.
		jwtToken, err = jwt.Parse(base64EncodedToken, jwt.WithKeySet(p.localKeySet, jws.WithInferAlgorithmFromKey(p.tknParseOpts.inferAlgorithmFromKey)))
		if err != nil {
			return noOpToken{}, fmt.Errorf("parsing token: %w", err)
		}

		return p.extractToken(jwtToken)
	}
```
```go
// updateKeySetFromJWKS updates the local JWKS cache from a freshly downloaded JWKS.
func (p *ValidatingParser) updateKeySetFromJWKS(ctx context.Context, issuer string) error {
	logger := logs.ProvideLogger()
	jwksURI, err := p.getJWKSURI(ctx, issuer)
	if err != nil {
		return fmt.Errorf("getting JWKS URI: %w", err)
	}

	// Use the request context so the download is cancelled along with the request.
	newKeySet, err := jwk.Fetch(ctx, jwksURI)
	if err != nil {
		return fmt.Errorf("fetching JWKS URI: %w", err)
	}

	for i := 0; i < newKeySet.Len(); i++ {
		key, ok := newKeySet.Key(i)
		if !ok {
			return fmt.Errorf("getting key %d from newKeySet failed", i)
		}
		if err := p.localKeySet.AddKey(key); err != nil {
			// A key that already exists cannot be added again; ignore only that error.
			if !strings.Contains(err.Error(), "key already exists") {
				return fmt.Errorf("adding public cert to local cache: %w", err)
			}
			continue // nothing was added, so do not log it as added
		}
		logger.InfoContext(ctx, "Added new key to local cache", "key", key.KeyID(), "existing key number", p.localKeySet.Len())
	}

	return nil
}
```

Software Version
Go: go version go1.23.0 darwin/arm64
Library: github.com/lestrrat-go/jwx/v2 v2.1.1

To Reproduce / Expected behavior

  1. Run a Gin server under moderate traffic (around 100 RPS), with each request carrying a token.
  2. The code uses jwk.NewSet to create a JWKS cache.
  3. When the kid in the token is not found in the local JWKS cache, download the remote JWKS from the issuer and add the missing keys one by one.
  4. While traffic is flowing, trigger the issuer to rotate its key.

Expected behaviour: the local JWKS contains only the old and the new key.
What happens: the new key is added multiple times.

@lestrrat
Collaborator

lestrrat commented Oct 14, 2024

Sorry, please disregard my earlier response. I thought you were talking about a synchronization issue in jwk.Cache, but I don't think that's what this is.

You need an explicit lock around your loop.

```go
mu.Lock() // WRITE lock
for i := 0; i < set.Len(); i++ {
	p.localKeySet.AddKey(...)
}
mu.Unlock()
```

The lock inside p.localKeySet only protects p.localKeySet from concurrent reads and writes for the duration of a single method call. For example, without a write lock inside localKeySet, if you were reading the element at position i while another goroutine was writing to that same position (which may include truncating the underlying slice), the results could be wrong. The lock inside the set protects against that.
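To see why a lock held only inside each method is not enough, here is a deterministic, single-goroutine replay of one bad interleaving. `lockedSet` is a hypothetical stand-in whose `Has` and `Add` each take the mutex for their own duration, the way a set method does. For illustration, `Add` simply appends and performs no duplicate check of its own; that is an assumption of the sketch, not a description of jwx's AddKey. Two "requests" both run the existence check before either one adds, so both add the key.

```go
package main

import "sync"

// lockedSet is a hypothetical set whose individual methods are
// goroutine-safe: each call holds the mutex only while it runs.
type lockedSet struct {
	mu   sync.RWMutex
	kids []string
}

func (s *lockedSet) Has(kid string) bool {
	s.mu.RLock()
	defer s.mu.RUnlock()
	for _, k := range s.kids {
		if k == kid {
			return true
		}
	}
	return false
}

// Add appends unconditionally (illustrative assumption: no dedup here).
func (s *lockedSet) Add(kid string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.kids = append(s.kids, kid)
}

func (s *lockedSet) Len() int {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return len(s.kids)
}

// replayRace performs, in order, what two concurrent requests can do:
// both check, both see "missing", both add. Every call is individually
// locked, yet the check-then-add sequence as a whole is not atomic.
func replayRace(s *lockedSet, kid string) {
	missingForA := !s.Has(kid) // request A checks
	missingForB := !s.Has(kid) // request B checks before A has added
	if missingForA {
		s.Add(kid) // request A adds
	}
	if missingForB {
		s.Add(kid) // request B adds a duplicate
	}
}
```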

However, what you are doing is modifying p.localKeySet as a whole, which means you need to protect YOUR ENTIRE WRITE OPERATION (the for loop) with a lock. You also need to prevent other goroutines from reading p.localKeySet while you are writing to it.

P.S. I have a feeling you should be able to get away with just swapping the entire p.localKeySet with the new JWKS if you properly protect access to it.
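The P.S. suggests replacing the whole set instead of merging into it. One way to sketch that with the standard library is `atomic.Pointer` from `sync/atomic` (Go 1.19+): readers load an immutable snapshot, and a refresh publishes a freshly built one in a single store. `keySnapshot` and `swappingCache` are hypothetical names, and keys are again reduced to strings indexed by kid.

```go
package main

import "sync/atomic"

// keySnapshot is an immutable kid -> key material map; it is never
// mutated after being published.
type keySnapshot map[string]string

// swappingCache is a hypothetical cache that swaps whole snapshots.
type swappingCache struct {
	current atomic.Pointer[keySnapshot]
}

func newSwappingCache(initial keySnapshot) *swappingCache {
	c := &swappingCache{}
	c.current.Store(&initial)
	return c
}

// Lookup reads whichever snapshot is current; no lock is needed
// because published snapshots are read-only.
func (c *swappingCache) Lookup(kid string) (string, bool) {
	snap := c.current.Load()
	k, ok := (*snap)[kid]
	return k, ok
}

// Replace publishes the freshly fetched set in one atomic store.
// The map is copied so later changes by the caller cannot leak in.
func (c *swappingCache) Replace(fetched map[string]string) {
	next := make(keySnapshot, len(fetched))
	for kid, material := range fetched {
		next[kid] = material
	}
	c.current.Store(&next)
}
```

One consequence of swapping: the local set then holds exactly what the issuer currently publishes, so a key the issuer has dropped disappears locally as well.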

@lestrrat
Collaborator

lestrrat commented Oct 14, 2024

Sorry, I think I misspoke. The write lock needs to start when you check for the existence of a key. (Not 100% sure, because I'm writing this while commuting.)

Either way, the locking has to happen in your code, not in jwx.

@weihong-xu-hpe
Author

Okay, understood. I think the better option for me is to introduce singleflight around the point where the rotation happens.
Thanks a lot for your help!
