diff --git a/examples/providers/bsd/provider.c b/examples/providers/bsd/provider.c index dc9bedf..69ae0a8 100644 --- a/examples/providers/bsd/provider.c +++ b/examples/providers/bsd/provider.c @@ -6,7 +6,7 @@ static int __init dpusm_bsd_provider_init(void) { - const int rc = dpusm_register_bsd(module_name(THIS_MODULE), + const int rc = dpusm_register_bsd(THIS_MODULE, &example_dpusm_provider_functions); printk("%s init: %d\n", module_name(THIS_MODULE), rc); return rc; @@ -14,7 +14,7 @@ dpusm_bsd_provider_init(void) { static void __exit dpusm_bsd_provider_exit(void) { - dpusm_unregister_bsd(module_name(THIS_MODULE)); + dpusm_unregister_bsd(THIS_MODULE); printk("%s exit\n", module_name(THIS_MODULE)); } diff --git a/examples/providers/gpl/provider.c b/examples/providers/gpl/provider.c index 89b83c1..821eea2 100644 --- a/examples/providers/gpl/provider.c +++ b/examples/providers/gpl/provider.c @@ -6,7 +6,7 @@ static int __init dpusm_gpl_provider_init(void) { - const int rc = dpusm_register_gpl(module_name(THIS_MODULE), + const int rc = dpusm_register_gpl(THIS_MODULE, &example_dpusm_provider_functions); printk("%s init: %d\n", module_name(THIS_MODULE), rc); return rc; @@ -14,7 +14,7 @@ dpusm_gpl_provider_init(void) { static void __exit dpusm_gpl_provider_exit(void) { - dpusm_unregister_gpl(module_name(THIS_MODULE)); + dpusm_unregister_gpl(THIS_MODULE); printk("%s exit\n", module_name(THIS_MODULE)); } diff --git a/include/dpusm/provider.h b/include/dpusm/provider.h index db092f4..b133d9e 100644 --- a/include/dpusm/provider.h +++ b/include/dpusm/provider.h @@ -3,13 +3,14 @@ #include #include +#include #include #include /* single provider data */ typedef struct dpusm_provider_handle { - const char *name; /* reference to a string */ + struct module *module; dpusm_pc_t capabilities; /* constant set of capabilities */ const dpusm_pf_t *funcs; /* reference to a struct */ atomic_t refs; /* how many users are holding this provider */ @@ -25,11 +26,11 @@ typedef struct { /* this is not tied to the provider/count */ } dpusm_t; -int dpusm_provider_register(dpusm_t *dpusm, const char *name, const dpusm_pf_t *funcs); +int dpusm_provider_register(dpusm_t *dpusm, struct module *module, const dpusm_pf_t *funcs); /* can't prevent provider module from unloading */ int dpusm_provider_unregister_handle(dpusm_t *dpusm, dpusm_ph_t **provider); -int dpusm_provider_unregister(dpusm_t *dpusm, const char *name); +int dpusm_provider_unregister(dpusm_t *dpusm, struct module *module); dpusm_ph_t **dpusm_provider_get(dpusm_t *dpusm, const char *name); int dpusm_provider_put(dpusm_t *dpusm, void *handle); @@ -42,6 +43,6 @@ void dpusm_provider_write_unlock(dpusm_t *dpusm); * * provider is not unregistered, so dpusm_unregister still needs to be called */ -void dpusm_provider_invalidate(dpusm_t *dpusm, const char *name); +void dpusm_provider_invalidate(dpusm_t *dpusm, struct module *module); #endif diff --git a/include/dpusm/provider_api.h b/include/dpusm/provider_api.h index 3ddb238..1a70894 100644 --- a/include/dpusm/provider_api.h +++ b/include/dpusm/provider_api.h @@ -183,16 +183,16 @@ typedef struct dpusm_provider_functions { } dpusm_pf_t; /* returns -ERRNO instead of DPUSM_* */ -int dpusm_register_bsd(const char *name, const dpusm_pf_t *funcs); -int dpusm_unregister_bsd(const char *name); -int dpusm_register_gpl(const char *name, const dpusm_pf_t *funcs); -int dpusm_unregister_gpl(const char *name); +int dpusm_register_bsd(struct module *module, const dpusm_pf_t *funcs); +int dpusm_unregister_bsd(struct module *module); +int dpusm_register_gpl(struct module *module, const dpusm_pf_t *funcs); +int dpusm_unregister_gpl(struct module *module); /* * call when backing DPU goes down unexpectedly * * provider is not unregistered, so dpusm_unregister still needs to be called */ -void dpusm_invalidate(const char *name); +void dpusm_invalidate(struct module *module); #endif diff --git a/src/dpusm.c b/src/dpusm.c index df1e6bc..ad3d639 100644 --- a/src/dpusm.c +++ b/src/dpusm.c @@ -11,13 +11,13 @@ static dpusm_t dpusm; int -dpusm_register_bsd(const char *name, const dpusm_pf_t *funcs) { - return dpusm_provider_register(&dpusm, name, funcs); +dpusm_register_bsd(struct module *module, const dpusm_pf_t *funcs) { + return dpusm_provider_register(&dpusm, module, funcs); } int -dpusm_unregister_bsd(const char *name) { - return dpusm_provider_unregister(&dpusm, name); +dpusm_unregister_bsd(struct module *module) { + return dpusm_provider_unregister(&dpusm, module); } /* provider facing functions */ @@ -25,13 +25,13 @@ EXPORT_SYMBOL(dpusm_register_bsd); EXPORT_SYMBOL(dpusm_unregister_bsd); int -dpusm_register_gpl(const char *name, const dpusm_pf_t *funcs) { - return dpusm_provider_register(&dpusm, name, funcs); +dpusm_register_gpl(struct module *module, const dpusm_pf_t *funcs) { + return dpusm_provider_register(&dpusm, module, funcs); } int -dpusm_unregister_gpl(const char *name) { - return dpusm_provider_unregister(&dpusm, name); +dpusm_unregister_gpl(struct module *module) { + return dpusm_provider_unregister(&dpusm, module); } /* provider facing functions */ @@ -44,8 +44,8 @@ EXPORT_SYMBOL_GPL(dpusm_unregister_gpl); * Provider is not unregistered, so dpusm_unregister still needs to be called. * Using name instead of handle because the provider handle is not available to the provider. */ -void dpusm_invalidate(const char *name) { - dpusm_provider_invalidate(&dpusm, name); +void dpusm_invalidate(struct module *module) { + dpusm_provider_invalidate(&dpusm, module); } EXPORT_SYMBOL(dpusm_invalidate); diff --git a/src/provider.c b/src/provider.c index 15ce09a..d3f9062 100644 --- a/src/provider.c +++ b/src/provider.c @@ -84,7 +84,7 @@ find_provider(dpusm_t *dpusm, const char *name) { struct list_head *it = NULL; list_for_each(it, &dpusm->providers) { dpusm_ph_t *dpusmph = list_entry(it, dpusm_ph_t, list); - const char *p_name = dpusmph->name; + const char *p_name = module_name(dpusmph->module); const size_t p_name_len = strlen(p_name); if (name_len == p_name_len) { if (memcmp(name, p_name, p_name_len) == 0) { @@ -108,8 +108,9 @@ static void print_supported(const char *name, const char *func) } static dpusm_ph_t * -dpusmph_init(const char *name, const dpusm_pf_t *funcs) +dpusmph_init(struct module *module, const dpusm_pf_t *funcs) { + const char *name = module_name(module); dpusm_ph_t *dpusmph = dpusm_mem_alloc(sizeof(dpusm_ph_t)); if (dpusmph) { /* fill in capabilities bitmasks */ @@ -223,7 +224,7 @@ dpusmph_init(const char *name, const dpusm_pf_t *funcs) dpusmph->capabilities.io &= ~DPUSM_IO_DISK; } - dpusmph->name = name; + dpusmph->module = module; dpusmph->funcs = funcs; dpusmph->self = dpusmph; atomic_set(&dpusmph->refs, 0); @@ -234,7 +235,12 @@ dpusmph_init(const char *name, const dpusm_pf_t *funcs) /* add a new provider */ int -dpusm_provider_register(dpusm_t *dpusm, const char *name, const dpusm_pf_t *funcs) { +dpusm_provider_register(dpusm_t *dpusm, struct module *module, const dpusm_pf_t *funcs) { + if (!try_module_get(module)) { + printk("Error: Could not increment reference count of %s\n", module_name(module)); + return -ECANCELED; + } + const int rc = dpusm_provider_sane_at_load(funcs); if (rc != DPUSM_OK) { static const size_t max = @@ -262,22 +268,25 @@ dpusm_provider_register(dpusm_t *dpusm, const char *name, const dpusm_pf_t *func dpusm_mem_free(buf, size); + module_put(module); return -EINVAL; } dpusm_provider_write_lock(dpusm); - dpusm_ph_t **found = find_provider(dpusm, name); + dpusm_ph_t **found = find_provider(dpusm, module_name(module)); if (found) { printk("%s: DPUSM Provider with the name \"%s\" (%p) already exists. %zu providers registered.\n", __func__, name, *found, dpusm->count); dpusm_provider_write_unlock(dpusm); + module_put(module); return -EEXIST; } - dpusm_ph_t *provider = dpusmph_init(name, funcs); + dpusm_ph_t *provider = dpusmph_init(module, funcs); if (!provider) { dpusm_provider_write_unlock(dpusm); + module_put(module); return -ECANCELED; } @@ -321,19 +330,19 @@ dpusm_provider_unregister_handle(dpusm_t *dpusm, dpusm_ph_t **provider) { } int -dpusm_provider_unregister(dpusm_t *dpusm, const char *name) { +dpusm_provider_unregister(dpusm_t *dpusm, struct module *module) { dpusm_provider_write_lock(dpusm); - dpusm_ph_t **provider = find_provider(dpusm, name); + dpusm_ph_t **provider = find_provider(dpusm, module_name(module)); if (!provider) { - printk("%s: Could not find provider with name \"%s\"\n", __func__, name); + printk("%s: Could not find provider with name \"%s\"\n", __func__, module_name(module); dpusm_provider_write_unlock(dpusm); return DPUSM_ERROR; } void *addr = *provider; const int rc = dpusm_provider_unregister_handle(dpusm, provider); - printk("%s: Unregistered \"%s\" (%p): %d\n", __func__, name, addr, rc); + printk("%s: Unregistered \"%s\" (%p): %d\n", __func__, module_name(module), addr, rc); dpusm_provider_write_unlock(dpusm); return rc; @@ -353,7 +362,7 @@ dpusm_provider_get(dpusm_t *dpusm, const char *name) { atomic_inc(&(*provider)->refs); atomic_inc(&dpusm->active); printk("%s: User has been given a handle to \"%s\" (%p) (now %d users).\n", - __func__, (*provider)->name, *provider, atomic_read(&(*provider)->refs)); + __func__, name, *provider, atomic_read(&(*provider)->refs)); if ((*provider)->funcs->at_connect) { (*provider)->funcs->at_connect(); @@ -378,10 +387,11 @@ dpusm_provider_put(dpusm_t *dpusm, void *handle) { if (!atomic_read(&(*provider)->refs)) { printk("%s Error: Cannot decrement provider \"%s\" user count already at 0.\n", - __func__, (*provider)->name); + __func__, module_name((*provider)->module)); return DPUSM_ERROR; } + module_put((*provider)->module); atomic_dec(&(*provider)->refs); atomic_dec(&dpusm->active); @@ -392,7 +402,7 @@ dpusm_provider_put(dpusm_t *dpusm, void *handle) { } printk("%s: User has returned a handle to \"%s\" (%p) (now %d users).\n", - __func__, (*provider)->name, *provider, atomic_read(&(*provider)->refs)); + __func__, (*provider)->name, module_name((*provider)->module), atomic_read(&(*provider)->refs)); return DPUSM_OK; } @@ -406,18 +416,18 @@ dpusm_provider_write_unlock(dpusm_t *dpusm) { write_unlock(&dpusm->lock); } -void dpusm_provider_invalidate(dpusm_t *dpusm, const char *name) { +void dpusm_provider_invalidate(dpusm_t *dpusm, struct module *module) { dpusm_provider_write_lock(dpusm); - dpusm_ph_t **provider = find_provider(dpusm, name); + dpusm_ph_t **provider = find_provider(dpusm, module_name(module)); if (provider && *provider) { (*provider)->funcs = NULL; memset(&(*provider)->capabilities, 0, sizeof((*provider)->capabilities)); printk("%s: Provider \"%s\" has been invalidated with %d users active.\n", - __func__, name, atomic_read(&(*provider)->refs)); + __func__, module_name(module), atomic_read(&(*provider)->refs)); } else { printk("%s: Error: Did not find provider \"%s\"\n", - __func__, name); + __func__, module_name(module)); } dpusm_provider_write_unlock(dpusm); }