From a0bf7f01aba996d3d3d2c6115eb1ebf8ade3c675 Mon Sep 17 00:00:00 2001 From: Jason Lee Date: Mon, 1 Jul 2024 12:16:29 -0600 Subject: [PATCH] use THIS_MODULE instead of the module's name --- examples/providers/bsd/provider.c | 4 +-- examples/providers/gpl/provider.c | 4 +-- include/dpusm/provider.h | 7 ++-- include/dpusm/provider_api.h | 8 ++--- src/dpusm.c | 19 +++++------ src/provider.c | 56 ++++++++++++++++++++----------- src/user.c | 4 +-- 7 files changed, 60 insertions(+), 42 deletions(-) diff --git a/examples/providers/bsd/provider.c b/examples/providers/bsd/provider.c index dc9bedf..69ae0a8 100644 --- a/examples/providers/bsd/provider.c +++ b/examples/providers/bsd/provider.c @@ -6,7 +6,7 @@ static int __init dpusm_bsd_provider_init(void) { - const int rc = dpusm_register_bsd(module_name(THIS_MODULE), + const int rc = dpusm_register_bsd(THIS_MODULE, &example_dpusm_provider_functions); printk("%s init: %d\n", module_name(THIS_MODULE), rc); return rc; @@ -14,7 +14,7 @@ dpusm_bsd_provider_init(void) { static void __exit dpusm_bsd_provider_exit(void) { - dpusm_unregister_bsd(module_name(THIS_MODULE)); + dpusm_unregister_bsd(THIS_MODULE); printk("%s exit\n", module_name(THIS_MODULE)); } diff --git a/examples/providers/gpl/provider.c b/examples/providers/gpl/provider.c index 89b83c1..821eea2 100644 --- a/examples/providers/gpl/provider.c +++ b/examples/providers/gpl/provider.c @@ -6,7 +6,7 @@ static int __init dpusm_gpl_provider_init(void) { - const int rc = dpusm_register_gpl(module_name(THIS_MODULE), + const int rc = dpusm_register_gpl(THIS_MODULE, &example_dpusm_provider_functions); printk("%s init: %d\n", module_name(THIS_MODULE), rc); return rc; @@ -14,7 +14,7 @@ dpusm_gpl_provider_init(void) { static void __exit dpusm_gpl_provider_exit(void) { - dpusm_unregister_gpl(module_name(THIS_MODULE)); + dpusm_unregister_gpl(THIS_MODULE); printk("%s exit\n", module_name(THIS_MODULE)); } diff --git a/include/dpusm/provider.h b/include/dpusm/provider.h index db092f4..324ecde 100644 --- a/include/dpusm/provider.h +++ b/include/dpusm/provider.h @@ -3,13 +3,14 @@ #include #include +#include #include #include /* single provider data */ typedef struct dpusm_provider_handle { - const char *name; /* reference to a string */ + struct module *module; dpusm_pc_t capabilities; /* constant set of capabilities */ const dpusm_pf_t *funcs; /* reference to a struct */ atomic_t refs; /* how many users are holding this provider */ @@ -25,11 +26,11 @@ typedef struct { /* this is not tied to the provider/count */ } dpusm_t; -int dpusm_provider_register(dpusm_t *dpusm, const char *name, const dpusm_pf_t *funcs); +int dpusm_provider_register(dpusm_t *dpusm, struct module *module, const dpusm_pf_t *funcs); /* can't prevent provider module from unloading */ int dpusm_provider_unregister_handle(dpusm_t *dpusm, dpusm_ph_t **provider); -int dpusm_provider_unregister(dpusm_t *dpusm, const char *name); +int dpusm_provider_unregister(dpusm_t *dpusm, struct module *module); dpusm_ph_t **dpusm_provider_get(dpusm_t *dpusm, const char *name); int dpusm_provider_put(dpusm_t *dpusm, void *handle); diff --git a/include/dpusm/provider_api.h b/include/dpusm/provider_api.h index 3ddb238..cb9c9ca 100644 --- a/include/dpusm/provider_api.h +++ b/include/dpusm/provider_api.h @@ -183,10 +183,10 @@ typedef struct dpusm_provider_functions { } dpusm_pf_t; /* returns -ERRNO instead of DPUSM_* */ -int dpusm_register_bsd(const char *name, const dpusm_pf_t *funcs); -int dpusm_unregister_bsd(const char *name); -int dpusm_register_gpl(const char *name, const dpusm_pf_t *funcs); -int dpusm_unregister_gpl(const char *name); +int dpusm_register_bsd(struct module *module, const dpusm_pf_t *funcs); +int dpusm_unregister_bsd(struct module *module); +int dpusm_register_gpl(struct module *module, const dpusm_pf_t *funcs); +int dpusm_unregister_gpl(struct module *module); /* * call when backing DPU goes down unexpectedly diff --git a/src/dpusm.c b/src/dpusm.c index df1e6bc..9d661e3 100644 --- a/src/dpusm.c +++ b/src/dpusm.c @@ -11,13 +11,13 @@ static dpusm_t dpusm; int -dpusm_register_bsd(const char *name, const dpusm_pf_t *funcs) { - return dpusm_provider_register(&dpusm, name, funcs); +dpusm_register_bsd(struct module *module, const dpusm_pf_t *funcs) { + return dpusm_provider_register(&dpusm, module, funcs); } int -dpusm_unregister_bsd(const char *name) { - return dpusm_provider_unregister(&dpusm, name); +dpusm_unregister_bsd(struct module *module) { + return dpusm_provider_unregister(&dpusm, module); } /* provider facing functions */ @@ -25,13 +25,13 @@ EXPORT_SYMBOL(dpusm_register_bsd); EXPORT_SYMBOL(dpusm_unregister_bsd); int -dpusm_register_gpl(const char *name, const dpusm_pf_t *funcs) { - return dpusm_provider_register(&dpusm, name, funcs); +dpusm_register_gpl(struct module *module, const dpusm_pf_t *funcs) { + return dpusm_provider_register(&dpusm, module, funcs); } int -dpusm_unregister_gpl(const char *name) { - return dpusm_provider_unregister(&dpusm, name); +dpusm_unregister_gpl(struct module *module) { + return dpusm_provider_unregister(&dpusm, module); } /* provider facing functions */ @@ -76,8 +76,7 @@ dpusm_init(void) { } static void __exit -dpusm_exit(void) -{ +dpusm_exit(void) { dpusm_provider_write_lock(&dpusm); const int active = atomic_read(&dpusm.active); diff --git a/src/provider.c b/src/provider.c index 15ce09a..2b78252 100644 --- a/src/provider.c +++ b/src/provider.c @@ -84,7 +84,7 @@ find_provider(dpusm_t *dpusm, const char *name) { struct list_head *it = NULL; list_for_each(it, &dpusm->providers) { dpusm_ph_t *dpusmph = list_entry(it, dpusm_ph_t, list); - const char *p_name = dpusmph->name; + const char *p_name = module_name(dpusmph->module); const size_t p_name_len = strlen(p_name); if (name_len == p_name_len) { if (memcmp(name, p_name, p_name_len) == 0) { @@ -108,8 +108,9 @@ static void print_supported(const char *name, const char *func) } static dpusm_ph_t * -dpusmph_init(const char *name, const dpusm_pf_t *funcs) +dpusmph_init(struct module *module, const dpusm_pf_t *funcs) { + const char *name = module_name(module); dpusm_ph_t *dpusmph = dpusm_mem_alloc(sizeof(dpusm_ph_t)); if (dpusmph) { /* fill in capabilities bitmasks */ @@ -223,7 +224,7 @@ dpusmph_init(const char *name, const dpusm_pf_t *funcs) dpusmph->capabilities.io &= ~DPUSM_IO_DISK; } - dpusmph->name = name; + dpusmph->module = module; dpusmph->funcs = funcs; dpusmph->self = dpusmph; atomic_set(&dpusmph->refs, 0); @@ -234,7 +235,13 @@ dpusmph_init(const char *name, const dpusm_pf_t *funcs) /* add a new provider */ int -dpusm_provider_register(dpusm_t *dpusm, const char *name, const dpusm_pf_t *funcs) { +dpusm_provider_register(dpusm_t *dpusm, struct module *module, const dpusm_pf_t *funcs) { + /* make sure provider can't be unloaded before dpusm */ + if (!try_module_get(module)) { + printk("Error: Could not increment reference count of %s\n", module_name(module)); + return -ECANCELED; + } + const int rc = dpusm_provider_sane_at_load(funcs); if (rc != DPUSM_OK) { static const size_t max = @@ -258,33 +265,36 @@ dpusm_provider_register(dpusm_t *dpusm, const char *name, const dpusm_pf_t *func printk("%s: DPUSM Provider \"%s\" does not provide " "a valid set of functions. Bad function groups: %s\n", - __func__, name, buf); + __func__, module_name(module), buf); dpusm_mem_free(buf, size); + module_put(module); return -EINVAL; } dpusm_provider_write_lock(dpusm); - dpusm_ph_t **found = find_provider(dpusm, name); + dpusm_ph_t **found = find_provider(dpusm, module_name(module)); if (found) { printk("%s: DPUSM Provider with the name \"%s\" (%p) already exists. %zu providers registered.\n", - __func__, name, *found, dpusm->count); + __func__, module_name(module), *found, dpusm->count); dpusm_provider_write_unlock(dpusm); + module_put(module); return -EEXIST; } - dpusm_ph_t *provider = dpusmph_init(name, funcs); + dpusm_ph_t *provider = dpusmph_init(module, funcs); if (!provider) { dpusm_provider_write_unlock(dpusm); + module_put(module); return -ECANCELED; } list_add(&provider->list, &dpusm->providers); dpusm->count++; printk("%s: DPUSM Provider \"%s\" (%p) added. Now %zu providers registered.\n", - __func__, name, provider, dpusm->count); + __func__, module_name(module), provider, dpusm->count); dpusm_provider_write_unlock(dpusm); @@ -292,8 +302,6 @@ dpusm_provider_register(dpusm_t *dpusm, const char *name, const dpusm_pf_t *func } /* remove provider from list */ -/* can't prevent provider module from unloading */ -/* locking is done by caller */ int dpusm_provider_unregister_handle(dpusm_t *dpusm, dpusm_ph_t **provider) { if (!provider || !*provider) { @@ -305,12 +313,13 @@ dpusm_provider_unregister_handle(dpusm_t *dpusm, dpusm_ph_t **provider) { const int refs = atomic_read(&(*provider)->refs); if (refs) { printk("%s: Unregistering provider \"%s\" with %d references remaining.\n", - __func__, (*provider)->name, refs); + __func__, module_name((*provider)->module), refs); rc = -EBUSY; } list_del(&(*provider)->list); atomic_sub(refs, &dpusm->active); /* remove this provider's references from the global active count */ + module_put((*provider)->module); dpusmph_destroy(*provider); dpusm->count--; @@ -321,19 +330,19 @@ dpusm_provider_unregister_handle(dpusm_t *dpusm, dpusm_ph_t **provider) { } int -dpusm_provider_unregister(dpusm_t *dpusm, const char *name) { +dpusm_provider_unregister(dpusm_t *dpusm, struct module *module) { dpusm_provider_write_lock(dpusm); - dpusm_ph_t **provider = find_provider(dpusm, name); + dpusm_ph_t **provider = find_provider(dpusm, module_name(module)); if (!provider) { - printk("%s: Could not find provider with name \"%s\"\n", __func__, name); + printk("%s: Could not find provider with name \"%s\"\n", __func__, module_name(module)); dpusm_provider_write_unlock(dpusm); return DPUSM_ERROR; } void *addr = *provider; const int rc = dpusm_provider_unregister_handle(dpusm, provider); - printk("%s: Unregistered \"%s\" (%p): %d\n", __func__, name, addr, rc); + printk("%s: Unregistered \"%s\" (%p): %d\n", __func__, module_name(module), addr, rc); dpusm_provider_write_unlock(dpusm); return rc; @@ -352,8 +361,15 @@ dpusm_provider_get(dpusm_t *dpusm, const char *name) { if (provider) { atomic_inc(&(*provider)->refs); atomic_inc(&dpusm->active); + + /* /\* make sure provider can't be unloaded before caller *\/ */ + /* if (!try_module_get((*provider)->module)) { */ + /* printk("Error: Could not increment reference count of %s\n", name); */ + /* return -ECANCELED; */ + /* } */ + printk("%s: User has been given a handle to \"%s\" (%p) (now %d users).\n", - __func__, (*provider)->name, *provider, atomic_read(&(*provider)->refs)); + __func__, name, *provider, atomic_read(&(*provider)->refs)); if ((*provider)->funcs->at_connect) { (*provider)->funcs->at_connect(); @@ -378,10 +394,11 @@ dpusm_provider_put(dpusm_t *dpusm, void *handle) { if (!atomic_read(&(*provider)->refs)) { printk("%s Error: Cannot decrement provider \"%s\" user count already at 0.\n", - __func__, (*provider)->name); + __func__, module_name((*provider)->module)); return DPUSM_ERROR; } + /* module_put((*provider)->module); */ atomic_dec(&(*provider)->refs); atomic_dec(&dpusm->active); @@ -392,7 +409,7 @@ dpusm_provider_put(dpusm_t *dpusm, void *handle) { } printk("%s: User has returned a handle to \"%s\" (%p) (now %d users).\n", - __func__, (*provider)->name, *provider, atomic_read(&(*provider)->refs)); + __func__, module_name((*provider)->module), *provider, atomic_read(&(*provider)->refs)); return DPUSM_OK; } @@ -414,6 +431,7 @@ void dpusm_provider_invalidate(dpusm_t *dpusm, const char *name) { memset(&(*provider)->capabilities, 0, sizeof((*provider)->capabilities)); printk("%s: Provider \"%s\" has been invalidated with %d users active.\n", __func__, name, atomic_read(&(*provider)->refs)); + /* not decrementing module reference count here - provider is still registered */ } else { printk("%s: Error: Did not find provider \"%s\"\n", diff --git a/src/user.c b/src/user.c index ca3b727..52b063a 100644 --- a/src/user.c +++ b/src/user.c @@ -78,7 +78,7 @@ dpusm_provider_sane(dpusm_ph_t **provider) { } if (!FUNCS(provider)) { - printk("Error: Invalidated provider: %s\n", (*provider)->name); + printk("Error: Invalidated provider: %s\n", module_name((*provider)->module)); return DPUSM_PROVIDER_INVALIDATED; } @@ -144,7 +144,7 @@ dpusm_get_provider(const char *name) { static const char * dpusm_get_provider_name(void *provider) { dpusm_ph_t **dpusmph = (dpusm_ph_t **) provider; - return (dpusmph && *dpusmph)?(*dpusmph)->name:NULL; + return (dpusmph && *dpusmph)?module_name((*dpusmph)->module):NULL; } static int