Skip to content

Commit

Permalink
rwlock to mutex
Browse files Browse the repository at this point in the history
  • Loading branch information
calccrypto committed Jul 2, 2024
1 parent 8da724f commit 80d1f20
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 29 deletions.
7 changes: 2 additions & 5 deletions include/dpusm/provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#include <linux/atomic.h>
#include <linux/list.h>
#include <linux/module.h>
#include <linux/spinlock.h>
#include <linux/mutex.h>

#include <dpusm/provider_api.h>

Expand All @@ -20,7 +20,7 @@ typedef struct dpusm_provider_handle {
typedef struct {
struct list_head providers; /* list of providers */
size_t count; /* count of registered providers */
rwlock_t lock;
struct mutex lock;
atomic_t active; /* how many providers are active (may be larger than count) */
/* this is not tied to the provider/count */
} dpusm_t;
Expand All @@ -34,9 +34,6 @@ int dpusm_provider_unregister(dpusm_t *dpusm, struct module *module);
dpusm_ph_t **dpusm_provider_get(dpusm_t *dpusm, const char *name);
int dpusm_provider_put(dpusm_t *dpusm, void *handle);

void dpusm_provider_write_lock(dpusm_t *dpusm);
void dpusm_provider_write_unlock(dpusm_t *dpusm);

/*
* call when backing DPU goes down unexpectedly
*
Expand Down
6 changes: 3 additions & 3 deletions src/dpusm.c
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ static int __init
dpusm_init(void) {
INIT_LIST_HEAD(&dpusm.providers);
dpusm.count = 0;
rwlock_init(&dpusm.lock);
mutex_init(&dpusm.lock);

atomic_set(&dpusm.active, 0);
dpusm_mem_init();
Expand All @@ -77,7 +77,7 @@ dpusm_init(void) {

static void __exit
dpusm_exit(void) {
dpusm_provider_write_lock(&dpusm);
mutex_lock(&dpusm.lock);

const int active = atomic_read(&dpusm.active);
if (unlikely(active)) {
Expand All @@ -100,7 +100,7 @@ dpusm_exit(void) {
dpusm_provider_unregister_handle(&dpusm, &provider->self);
}

dpusm_provider_write_unlock(&dpusm);
mutex_unlock(&dpusm.lock);

#if DPUSM_TRACK_ALLOCS
size_t alloc_count = 0;
Expand Down
32 changes: 11 additions & 21 deletions src/provider.c
Original file line number Diff line number Diff line change
Expand Up @@ -265,19 +265,19 @@ dpusm_provider_register(dpusm_t *dpusm, struct module *module, const dpusm_pf_t
return -EINVAL;
}

dpusm_provider_write_lock(dpusm);
mutex_lock(&dpusm->lock);;

dpusm_ph_t **found = find_provider(dpusm, module_name(module));
if (found) {
printk("%s: DPUSM Provider with the name \"%s\" (%p) already exists. %zu providers registered.\n",
__func__, module_name(module), *found, dpusm->count);
dpusm_provider_write_unlock(dpusm);
mutex_unlock(&dpusm->lock);;
return -EEXIST;
}

dpusm_ph_t *provider = dpusmph_init(module, funcs);
if (!provider) {
dpusm_provider_write_unlock(dpusm);
mutex_unlock(&dpusm->lock);;
return -ECANCELED;
}

Expand All @@ -286,7 +286,7 @@ dpusm_provider_register(dpusm_t *dpusm, struct module *module, const dpusm_pf_t
printk("%s: DPUSM Provider \"%s\" (%p) added. Now %zu providers registered.\n",
__func__, module_name(module), provider, dpusm->count);

dpusm_provider_write_unlock(dpusm);
mutex_unlock(&dpusm->lock);;

return 0;
}
Expand Down Expand Up @@ -320,20 +320,20 @@ dpusm_provider_unregister_handle(dpusm_t *dpusm, dpusm_ph_t **provider) {

int
dpusm_provider_unregister(dpusm_t *dpusm, struct module *module) {
dpusm_provider_write_lock(dpusm);
mutex_lock(&dpusm->lock);;

dpusm_ph_t **provider = find_provider(dpusm, module_name(module));
if (!provider) {
printk("%s: Could not find provider with name \"%s\"\n", __func__, module_name(module));
dpusm_provider_write_unlock(dpusm);
mutex_unlock(&dpusm->lock);;
return DPUSM_ERROR;
}

void *addr = *provider;
const int rc = dpusm_provider_unregister_handle(dpusm, provider);
printk("%s: Unregistered \"%s\" (%p): %d\n", __func__, module_name(module), addr, rc);

dpusm_provider_write_unlock(dpusm);
mutex_unlock(&dpusm->lock);;
return rc;
}

Expand All @@ -345,7 +345,7 @@ dpusm_provider_unregister(dpusm_t *dpusm, struct module *module) {
/* get a provider by name */
dpusm_ph_t **
dpusm_provider_get(dpusm_t *dpusm, const char *name) {
read_lock(&dpusm->lock);
mutex_lock(&dpusm->lock);
dpusm_ph_t **provider = find_provider(dpusm, name);
if (provider) {
struct module *module = (*provider)->module;
Expand All @@ -369,7 +369,7 @@ dpusm_provider_get(dpusm_t *dpusm, const char *name) {
printk("%s: Error: Did not find provider \"%s\"\n",
__func__, name);
}
read_unlock(&dpusm->lock);
mutex_unlock(&dpusm->lock);
return provider;
}

Expand Down Expand Up @@ -404,18 +404,8 @@ dpusm_provider_put(dpusm_t *dpusm, void *handle) {
return DPUSM_OK;
}

void
dpusm_provider_write_lock(dpusm_t *dpusm) {
write_lock(&dpusm->lock);
}

void
dpusm_provider_write_unlock(dpusm_t *dpusm) {
write_unlock(&dpusm->lock);
}

void dpusm_provider_invalidate(dpusm_t *dpusm, const char *name) {
dpusm_provider_write_lock(dpusm);
mutex_lock(&dpusm->lock);;
dpusm_ph_t **provider = find_provider(dpusm, name);
if (provider && *provider) {
(*provider)->funcs = NULL;
Expand All @@ -428,5 +418,5 @@ void dpusm_provider_invalidate(dpusm_t *dpusm, const char *name) {
printk("%s: Error: Did not find provider \"%s\"\n",
__func__, name);
}
dpusm_provider_write_unlock(dpusm);
mutex_unlock(&dpusm->lock);;
}

0 comments on commit 80d1f20

Please sign in to comment.