Skip to content

Commit

Permalink
Unit test coverage for failed issuer
Browse files Browse the repository at this point in the history
Test the behavior around an issuer failure -- as long as the cached
pubkey isn't expired, we shouldn't try again for 5 minutes.
  • Loading branch information
bbockelm committed Nov 1, 2022
1 parent 7ec2670 commit 2f14506
Showing 1 changed file with 293 additions and 0 deletions.
293 changes: 293 additions & 0 deletions test/main.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
#include "../src/scitokens.h"

#include <pwd.h>
#include <memory>
#include <gtest/gtest.h>

#include <openssl/bio.h>
#include <openssl/err.h>
#include <openssl/ec.h>
#include <openssl/pem.h>

#ifndef PICOJSON_USE_INT64
#define PICOJSON_USE_INT64
#endif
#include <picojson/picojson.h>
#include <sqlite3.h>

namespace {

const char ec_private[] = "-----BEGIN EC PRIVATE KEY-----\n"
Expand All @@ -27,6 +39,216 @@ const char ec_public_2[] = "-----BEGIN PUBLIC KEY-----\n"
"XWCq4E/g2ME/uBOdP8RE0tqle8fxYcaPikgMcppGq2ycTiLGgEYXgsq2JA==\n"
"-----END PUBLIC KEY-----\n";

/**
* Duplicate of get_cache_file from scitokens_cache.cpp; used for direct
* SQLite manipulation.
*/
std::string
get_cache_file() {

const char *xdg_cache_home = getenv("XDG_CACHE_HOME");

auto bufsize = sysconf(_SC_GETPW_R_SIZE_MAX);
bufsize = (bufsize == -1) ? 16384 : bufsize;

std::unique_ptr<char[]> buf(new char[bufsize]);

std::string home_dir;
struct passwd pwd, *result = NULL;
getpwuid_r(geteuid(), &pwd, buf.get(), bufsize, &result);
if (result && result->pw_dir) {
home_dir = result->pw_dir;
home_dir += "/.cache";
}

std::string cache_dir(xdg_cache_home ? xdg_cache_home : home_dir.c_str());
if (cache_dir.size() == 0) {
return "";
}

int r = mkdir(cache_dir.c_str(), 0700);
if ((r < 0) && errno != EEXIST) {
return "";
}

std::string keycache_dir = cache_dir + "/scitokens";
r = mkdir(keycache_dir.c_str(), 0700);
if ((r < 0) && errno != EEXIST) {
return "";
}

std::string keycache_file = keycache_dir + "/scitokens_cpp.sqllite";
// Assume this isn't needed; we'll trigger it via the "real" cache routines.
//initialize_cachedb(keycache_file);

return keycache_file;
}

/**
* Duplicate of remove_issuer_entry from scitokens_cache.cpp; used for direct cache manipulation
*/
void
remove_issuer_entry(sqlite3 *db, const std::string &issuer, bool new_transaction) {

if (new_transaction) sqlite3_exec(db, "BEGIN", 0, 0 , 0);

sqlite3_stmt *stmt;
int rc = sqlite3_prepare_v2(db, "DELETE FROM keycache WHERE issuer = ?", -1, &stmt, NULL);
if (rc != SQLITE_OK) {
sqlite3_close(db);
return;
}

if (sqlite3_bind_text(stmt, 1, issuer.c_str(), issuer.size(), SQLITE_STATIC) != SQLITE_OK) {
sqlite3_finalize(stmt);
sqlite3_close(db);
return;
}

rc = sqlite3_step(stmt);
if (rc != SQLITE_DONE) {
sqlite3_finalize(stmt);
sqlite3_close(db);
return;
}

sqlite3_finalize(stmt);

if (new_transaction) sqlite3_exec(db, "COMMIT", 0, 0 , 0);
}

/**
* Duplicate of store_public_keys from scitokens_cache.cpp; used for direct cache manipulation.
*/
bool
store_public_keys(const std::string &issuer, const std::string &keys, int64_t next_update, int64_t expires) {

picojson::value json_obj;
auto err = picojson::parse(json_obj, keys);
if (!err.empty() || !json_obj.is<picojson::object>()) {
return false;
}

picojson::object top_obj;
top_obj["jwks"] = json_obj;
top_obj["next_update"] = picojson::value(next_update);
top_obj["expires"] = picojson::value(expires);
picojson::value db_value(top_obj);
std::string db_str = db_value.serialize();

auto cache_fname = get_cache_file();
if (cache_fname.size() == 0) {return false;}

sqlite3 *db;
int rc = sqlite3_open(cache_fname.c_str(), &db);
if (rc) {
sqlite3_close(db);
return false;
}

sqlite3_exec(db, "BEGIN", 0, 0 , 0);

remove_issuer_entry(db, issuer, false);

sqlite3_stmt *stmt;
rc = sqlite3_prepare_v2(db, "INSERT INTO keycache VALUES (?, ?)", -1, &stmt, NULL);
if (rc != SQLITE_OK) {
sqlite3_close(db);
return false;
}

if (sqlite3_bind_text(stmt, 1, issuer.c_str(), issuer.size(), SQLITE_STATIC) != SQLITE_OK) {
sqlite3_finalize(stmt);
sqlite3_close(db);
return false;
}

if (sqlite3_bind_text(stmt, 2, db_str.c_str(), db_str.size(), SQLITE_STATIC) != SQLITE_OK) {
sqlite3_finalize(stmt);
sqlite3_close(db);
return false;
}

rc = sqlite3_step(stmt);
if (rc != SQLITE_DONE) {
sqlite3_finalize(stmt);
sqlite3_close(db);
return false;
}

sqlite3_exec(db, "COMMIT", 0, 0 , 0);

sqlite3_finalize(stmt);
sqlite3_close(db);
return true;
}

bool
get_public_keys_from_db(const std::string issuer, int64_t &expires, int64_t &next_update) {
auto cache_fname = get_cache_file();
if (cache_fname.size() == 0) {return false;}

sqlite3 *db;
int rc = sqlite3_open(cache_fname.c_str(), &db);
if (rc) {
sqlite3_close(db);
return false;
}

sqlite3_stmt *stmt;
rc = sqlite3_prepare_v2(db, "SELECT keys from keycache where issuer = ?", -1, &stmt, NULL);
if (rc != SQLITE_OK) {
sqlite3_close(db);
return false;
}

if (sqlite3_bind_text(stmt, 1, issuer.c_str(), issuer.size(), SQLITE_STATIC) != SQLITE_OK) {
sqlite3_finalize(stmt);
sqlite3_close(db);
return false;
}

rc = sqlite3_step(stmt);
if (rc == SQLITE_ROW) {
const unsigned char * data = sqlite3_column_text(stmt, 0);
std::string metadata(reinterpret_cast<const char *>(data));
sqlite3_finalize(stmt);
picojson::value json_obj;
auto err = picojson::parse(json_obj, metadata);
if (!err.empty() || !json_obj.is<picojson::object>()) {
sqlite3_close(db);
return false;
}
auto top_obj = json_obj.get<picojson::object>();
auto iter = top_obj.find("jwks");
auto keys_local = iter->second;
iter = top_obj.find("expires");
if (iter == top_obj.end() || !iter->second.is<int64_t>()) {
sqlite3_close(db);
return false;
}
auto expiry = iter->second.get<int64_t>();
sqlite3_close(db);
iter = top_obj.find("next_update");
if (iter == top_obj.end() || !iter->second.is<int64_t>()) {
next_update = expiry - 4*3600;
} else {
next_update = iter->second.get<int64_t>();
}
expires = expiry;
return true;
} else if (rc == SQLITE_DONE) {
sqlite3_finalize(stmt);
sqlite3_close(db);
return false;
} else {
// TODO: log error?
sqlite3_finalize(stmt);
sqlite3_close(db);
return false;
}
}

TEST(SciTokenTest, CreateToken) {
SciToken token = scitoken_create(nullptr);
ASSERT_TRUE(token != nullptr);
Expand Down Expand Up @@ -63,6 +285,7 @@ class KeycacheTest : public ::testing::Test
{
protected:
std::string demo_scitokens_url = "https://demo.scitokens.org";
std::string demo_invalid_url = "https://demo.scitokens.org/invalid";

void SetUp() override {
char *err_msg;
Expand All @@ -77,6 +300,76 @@ class KeycacheTest : public ::testing::Test
};


// Emulate the case of an issuer failure. Store a public key that
// is in the need of an update. Make sure, on failure, the next_update
// is 5 minutes ahead of the present.
TEST_F(KeycacheTest, FailureTest) {
time_t now = time(NULL);
const time_t expiry = now + 86400;
// Insert a public key that requires an update on next token verification.
ASSERT_TRUE(store_public_keys(demo_invalid_url, demo_scitokens2, now - 600, expiry));

// Create a new token with an invalid signature.
OpenSSL_add_all_algorithms();
ERR_load_BIO_strings();
ERR_load_crypto_strings();
auto outbio = BIO_new(BIO_s_mem());
ASSERT_TRUE(outbio != nullptr);
auto eccgrp = OBJ_txt2nid("secp256k1");
auto ecc = EC_KEY_new_by_curve_name(eccgrp);
ASSERT_TRUE(1 == EC_KEY_generate_key(ecc));

auto pkey = EVP_PKEY_new();
ASSERT_TRUE(1 == EVP_PKEY_assign_EC_KEY(pkey, ecc));
ASSERT_TRUE(1 == PEM_write_bio_PrivateKey(outbio, pkey, NULL, NULL, 0, 0, NULL));

char *pem_data;
long pem_len = BIO_get_mem_data(outbio, &pem_data);
std::string pem_str(pem_data, pem_len);

// Generate a serialized token from the new key.
auto key = scitoken_key_create("test_key", "ES256", "", pem_str.c_str(), nullptr);
ASSERT_TRUE(key != nullptr);

auto token = scitoken_create(key);
ASSERT_TRUE(token != nullptr);

auto rv = scitoken_set_claim_string(token, "iss", demo_invalid_url.c_str(), nullptr);
ASSERT_TRUE(rv == 0);

rv = scitoken_set_claim_string(token, "sub", "test_user", nullptr);
ASSERT_TRUE(rv == 0);

scitoken_set_lifetime(token, 86400);

char *token_encoded;
rv = scitoken_serialize(token, &token_encoded, nullptr);
ASSERT_TRUE(rv == 0);
std::string token_str(token_encoded);
free(token_encoded);

// Try to deserialize the newly generated token. Should fail as the key doesn't match.
auto token_read = scitoken_create(nullptr);
ASSERT_TRUE(token_read != nullptr);
rv = scitoken_deserialize_v2(token_str.c_str(), token_read, nullptr, nullptr);
ASSERT_FALSE(rv == 0);

// Now, for the real test -- what's the value of expired and next_update?
int64_t new_expiry, new_next_update;
ASSERT_TRUE(get_public_keys_from_db(demo_invalid_url, new_expiry, new_next_update));

EXPECT_EQ(new_expiry, expiry);
EXPECT_GE(new_next_update, now + 300);

// Second test: if the expiration is behind us, fetching the key should trigger
// a deletion of the key cache.
ASSERT_TRUE(store_public_keys(demo_invalid_url, demo_scitokens2, now - 600, now - 600));

rv = scitoken_deserialize_v2(token_str.c_str(), token_read, nullptr, nullptr);

ASSERT_FALSE(get_public_keys_from_db(demo_invalid_url, new_expiry, new_next_update));
}

TEST_F(KeycacheTest, RefreshTest) {
char *err_msg;
auto rv = keycache_refresh_jwks(demo_scitokens_url.c_str(), &err_msg);
Expand Down

0 comments on commit 2f14506

Please sign in to comment.