diff --git a/app/services/foreman/renderer/scope/macros/loaders.rb b/app/services/foreman/renderer/scope/macros/loaders.rb index ba930efb3a2..0fd15a72cd4 100644 --- a/app/services/foreman/renderer/scope/macros/loaders.rb +++ b/app/services/foreman/renderer/scope/macros/loaders.rb @@ -66,7 +66,7 @@ module Loaders <% end %>", desc: "Prints users in Europe, their login, when a user was logged on for the last time and the authentication source" end define_method name do |search: '', includes: nil, preload: nil, joins: nil, select: nil, batch: 1_000, limit: nil| - load_resource(klass: model, search: search, permission: permission, includes: includes, preload: preload, joins: joins, select: select, batch: batch, limit: limit) + load_resource(klass: model, search: search, permission: permission, includes: includes, preload: preload, joins: joins, select: select, batch: batch, limit: limit, defined_as: name) end end @@ -76,9 +76,12 @@ module Loaders # .each { |batch| batch.each { |record| record.name }} # or # .each_record { |record| record.name } - def load_resource(klass:, search:, permission:, batch: 1_000, includes: nil, limit: nil, select: nil, joins: nil, where: nil, preload: nil) + def load_resource(klass:, search:, permission:, batch: 1_000, includes: nil, limit: nil, select: nil, joins: nil, where: nil, preload: nil, defined_as: __method__) limit ||= 10 if preview? + type_check!(defined_as, 'select', select, [Symbol, [Array, Symbol]]) + type_check!(defined_as, 'joins', joins, [Symbol, Hash, [Array, Symbol]]) + base = klass base = base.search_for(search) base = base.preload(preload) unless preload.nil? @@ -90,6 +93,18 @@ def load_resource(klass:, search:, permission:, batch: 1_000, includes: nil, lim base = base.select(select) unless select.nil? base.in_batches(of: batch) end + + private + + def type_check!(method, label, what, spec) + return if what.nil? + return if spec.any? { |type, subtype| what.is_a?(type) && (subtype.nil? || what.all? { |value| value.is_a?(subtype) }) } + + options = spec.map { |type, subtype| subtype.nil? ? type : "#{type} of #{subtype.to_s.pluralize}" } + last = options.pop + options = options.join(', ') + " or #{last}" + raise ArgumentError, "Value of '#{label}' passed to #{method} must be #{options}" + end end end end diff --git a/test/unit/foreman/renderer/scope/macros/loader_macros_test.rb b/test/unit/foreman/renderer/scope/macros/loader_macros_test.rb new file mode 100644 index 00000000000..a8395bcf7af --- /dev/null +++ b/test/unit/foreman/renderer/scope/macros/loader_macros_test.rb @@ -0,0 +1,49 @@ +require 'test_helper' + +class LoaderMacrosTest < ActiveSupport::TestCase + setup do + host = FactoryBot.build_stubbed(:host) + template = OpenStruct.new( + name: 'Test', + template: 'Test' + ) + source = Foreman::Renderer::Source::Database.new( + template + ) + @scope = Class.new(Foreman::Renderer::Scope::Base) do + include Foreman::Renderer::Scope::Macros::Base + end.send(:new, source: source) + end + + describe '#load_resources' do + it 'should accept custom selects' do + @scope.load_hosts(select: :id) + @scope.load_hosts(select: [:id, :name]) + end + + it 'should reject unacceptable selects' do + assert_raises(ArgumentError) { @scope.load_hosts(select: 'a string value') } + assert_raises(ArgumentError) { @scope.load_hosts(select: {:key => :value}) } + assert_raises(ArgumentError) { @scope.load_hosts(select: [:mixed, 'array']) } + error = assert_raises(ArgumentError) { @scope.load_hosts(select: 7) } + assert_match /Value of 'select'/, error.message + assert_match /load_hosts/, error.message + assert_match /Symbol or Array of Symbols/, error.message + end + + it 'should accept custom joins' do + @scope.load_hosts(joins: :interfaces) + @scope.load_hosts(joins: {:interfaces => :subnet}) + @scope.load_hosts(joins: [:interfaces, :domain]) + end + + it 'should reject unacceptable joins' do + assert_raises(ArgumentError) { @scope.load_hosts(joins: 'a string value') } + assert_raises(ArgumentError) { @scope.load_hosts(joins: [:mixed, 'array']) } + error = assert_raises(ArgumentError) { @scope.load_hosts(joins: 7) } + assert_match /Value of 'joins'/, error.message + assert_match /load_hosts/, error.message + assert_match /Symbol, Hash or Array of Symbols/, error.message + end + end +end