diff --git a/CHANGELOG.md b/CHANGELOG.md index acae331..a9e9bb8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,11 @@ they have the usual target under test aspects applied. This allows testing multiple targets in one test with a mixture of configurations. ([#67](https://github.com/bazelbuild/rules_testing/issues/67)) + * `analysis_test` now takes the parameter `provider_subject_factories`. + If you want to perform assertions on custom providers, you no longer need + to use the factory parameter each time you want to retrieve the provider. + instead, you now write `analysis_test(..., provider_subject_factories = [ + type = FooInfo, name = "FooInfo", factory = FooSubjectFactory])`. ## [0.5.0] - 2023-10-04 diff --git a/lib/private/analysis_test.bzl b/lib/private/analysis_test.bzl index c491ebb..df2b647 100644 --- a/lib/private/analysis_test.bzl +++ b/lib/private/analysis_test.bzl @@ -21,6 +21,7 @@ load("@bazel_skylib//lib:dicts.bzl", "dicts") load("@bazel_skylib//lib:types.bzl", "types") load("//lib:truth.bzl", "truth") load("//lib:util.bzl", "recursive_testing_aspect", "testing_aspect") +load("//lib/private:target_subject.bzl", "PROVIDER_SUBJECT_FACTORIES") load("//lib/private:util.bzl", "get_test_name_from_function") def _fail(env, msg): @@ -37,7 +38,7 @@ def _fail(env, msg): print(full_msg) env.failures.append(full_msg) -def _begin_analysis_test(ctx): +def _begin_analysis_test(ctx, provider_subject_factories): """Begins a unit test. This should be the first function called in a unit test implementation @@ -48,6 +49,10 @@ def _begin_analysis_test(ctx): Args: ctx: The Starlark context. Pass the implementation function's `ctx` argument in verbatim. + provider_subject_factories: list of ProviderSubjectFactory structs, these are + additional provider factories on top of built in ones. + See analysis_test's provider_subject_factory arg for more details on + the type. Returns: An analysis_test "environment" struct. The following fields are public: @@ -86,6 +91,7 @@ def _begin_analysis_test(ctx): truth_env = struct( ctx = ctx, fail = lambda msg: _fail(failures_env, msg), + provider_subject_factories = PROVIDER_SUBJECT_FACTORIES + provider_subject_factories, ) analysis_test_env = struct( ctx = ctx, @@ -126,7 +132,8 @@ def analysis_test( fragments = [], config_settings = {}, extra_target_under_test_aspects = [], - collect_actions_recursively = False): + collect_actions_recursively = False, + provider_subject_factories = []): """Creates an analysis test from its implementation function. An analysis test verifies the behavior of a "real" rule target by examining @@ -189,6 +196,7 @@ def analysis_test( analysis test target itself (e.g. common attributes like `tags`, `target_compatible_with`, or attributes from `attrs`). Note that these are for the analysis test target itself, not the target under test. + fragments: An optional list of fragment names that can be used to give rules access to language-specific parts of configuration. config_settings: A dictionary of configuration settings to change for the target under @@ -202,6 +210,13 @@ def analysis_test( in addition to those set up by default for the test harness itself. collect_actions_recursively: If true, runs testing_aspect over all attributes, otherwise it is only applied to the target under test. + provider_subject_factories: Optional list of ProviderSubjectFactory structs, + these are additional provider factories on top of built in ones. + A ProviderSubjectFactory is a struct with the following fields: + * type: A provider object, e.g. the callable FooInfo object + * name: A human-friendly name of the provider (eg. "FooInfo") + * factory: A callable to convert an instance of the provider to a + subject; see TargetSubject.provider()'s factory arg for the signature. Returns: (None) @@ -290,7 +305,7 @@ def analysis_test( ) def wrapped_impl(ctx): - env, target = _begin_analysis_test(ctx) + env, target = _begin_analysis_test(ctx, provider_subject_factories) impl(env, target) return _end_analysis_test(env) diff --git a/lib/private/target_subject.bzl b/lib/private/target_subject.bzl index 47d8b94..a8d7312 100644 --- a/lib/private/target_subject.bzl +++ b/lib/private/target_subject.bzl @@ -187,7 +187,7 @@ def _target_subject_has_provider(self, provider): if self.meta.has_provider(self.target, provider): return self.meta.add_failure( - "expected to have provider: {}".format(_provider_name(provider)), + "expected to have provider: {}".format(_provider_subject_factory(self, provider).name), "but provider was not found", ) @@ -233,23 +233,30 @@ def _target_subject_provider(self, provider_key, factory = None): the subject for the found provider. Required if the provider key is not an inherently supported provider. It must have the following signature: `def factory(value, /, *, meta)`. + Additional types of providers can be pre-registered by using the + `provider_subject_factories` arg of `analysis_test`. Returns: A subject wrapper of the provider value. """ - if not factory: - for key, value in _PROVIDER_SUBJECT_FACTORIES: - if key == provider_key: - factory = value - break + if factory: + provider_subject_factory = struct( + type = provider_key, + # str(provider_key) just returns "", which isn't helpful. + # For lack of a better option, just call it unknown + name = "", + factory = factory, + ) + else: + provider_subject_factory = _provider_subject_factory(self, provider_key) - if not factory: - fail("Unsupported provider: {}".format(provider_key)) + if not provider_subject_factory.factory: + fail("Unsupported provider: {}".format(provider_subject_factory.name)) info = self.target[provider_key] - return factory( + return provider_subject_factory.factory( info, - meta = self.meta.derive("provider({})".format(provider_key)), + meta = self.meta.derive("provider({})".format(provider_subject_factory.name)), ) def _target_subject_action_generating(self, short_path): @@ -385,18 +392,35 @@ def _target_subject_attr(self, name, *, factory = None): meta = self.meta.derive("attr({})".format(name)), ) -# Providers aren't hashable, so we have to use a list of (key, value) -_PROVIDER_SUBJECT_FACTORIES = [ - (InstrumentedFilesInfo, InstrumentedFilesInfoSubject.new), - (RunEnvironmentInfo, RunEnvironmentInfoSubject.new), - (testing.ExecutionInfo, ExecutionInfoSubject.new), -] +def _provider_subject_factory(self, provider): + for provider_subject_factory in self.meta.env.provider_subject_factories: + if provider_subject_factory.type == provider: + return provider_subject_factory -def _provider_name(provider): - # This relies on implementation details of how Starlark represents - # providers, and isn't entirely accurate, but works well enough - # for error messages. - return str(provider).split("")[0] + return struct( + type = provider, + name = "", + factory = None, + ) + +# Providers aren't hashable, so we have to use a list of structs. +PROVIDER_SUBJECT_FACTORIES = [ + struct( + type = InstrumentedFilesInfo, + name = "InstrumentedFilesInfo", + factory = InstrumentedFilesInfoSubject.new, + ), + struct( + type = RunEnvironmentInfo, + name = "RunEnvironmentInfo", + factory = RunEnvironmentInfoSubject.new, + ), + struct( + type = testing.ExecutionInfo, + name = "testing.ExecutionInfo", + factory = ExecutionInfoSubject.new, + ), +] # We use this name so it shows up nice in docs. # buildifier: disable=name-conventions diff --git a/tests/truth_tests.bzl b/tests/truth_tests.bzl index 322c528..a544034 100644 --- a/tests/truth_tests.bzl +++ b/tests/truth_tests.bzl @@ -25,11 +25,13 @@ _IS_BAZEL_6_OR_HIGHER = (testing.ExecutionInfo == testing.ExecutionInfo) _suite = [] def _fake_env(env): + provider_subject_factories = env.expect.meta.env.provider_subject_factories failures = [] env1 = struct( ctx = env.ctx, failures = failures, fail = lambda msg: failures.append(msg), # Silent fail + provider_subject_factories = provider_subject_factories, ) env2 = struct( ctx = env.ctx, @@ -37,6 +39,7 @@ def _fake_env(env): fail = lambda msg: failures.append(msg), # Silent fail expect = truth.expect(env1), reset = lambda: failures.clear(), + provider_subject_factories = provider_subject_factories, ) return env2