diff --git a/lib/searchkick/query.rb b/lib/searchkick/query.rb index a207612..da90336 100644 --- a/lib/searchkick/query.rb +++ b/lib/searchkick/query.rb @@ -167,14 +167,17 @@ module Searchkick custom_filters = [] boost_by = options[:boost_by] || {} + if boost_by.is_a?(Array) - boost_by = Hash[boost_by.map { |f| [f, {factor: 1}] }] + boost_by_sum = Hash[boost_by.map { |f| [f, {factor: 1}] }] + elsif boost_by.is_a?(Hash) + boost_by_multiply, boost_by_sum = boost_by.partition { |k,v| v[:boost_mode] == "multiply" }.map{|i| Hash[i] } end if options[:boost] - boost_by[options[:boost]] = {factor: 1} + boost_by_sum[options[:boost]] = {factor: 1} end - boost_by.each do |field, value| + boost_by_sum.each do |field, value| script_score = if below12 {script_score: {script: "#{value[:factor].to_f} * log(doc['#{field}'].value + 2.718281828)"}} @@ -191,6 +194,28 @@ module Searchkick }.merge(script_score) end + if boost_by_multiply + multiply_filters = [] + + boost_by_multiply.each do |field, value| + script_score = + if below12 + {script_score: {script: "#{value[:factor].to_f} * doc['#{field}'].value"}} + else + value[:factor] ||= 1 + {field_value_factor: {field: field, factor: value[:factor].to_f}} + end + + multiply_filters << { + filter: { + exists: { + field: field + } + } + }.merge(script_score) + end + end + boost_where = options[:boost_where] || {} if options[:user_id] && personalize_field boost_where[personalize_field] = options[:user_id] @@ -238,31 +263,7 @@ module Searchkick } end - multiply_filters = [] - - multiply_by = options[:multiply_by] || {} - if multiply_by.is_a?(Array) - multiply_by = Hash[multiply_by.map { |f| [f, {factor: 1}] }] - end - - multiply_by.each do |field, value| - script_score = - if below12 - {script_score: {script: "#{value[:factor].to_f} * doc['#{field}'].value"}} - else - {field_value_factor: {field: field, factor: value[:factor].to_f}} - end - - multiply_filters << { - filter: { - exists: { - field: field - } - } - }.merge(script_score) - end - - if multiply_filters.any? + if multiply_filters && multiply_filters.any? payload = { function_score: { functions: multiply_filters, diff --git a/test/boost_test.rb b/test/boost_test.rb index a3f0868..11435ad 100644 --- a/test/boost_test.rb +++ b/test/boost_test.rb @@ -108,8 +108,7 @@ class TestBoost < Minitest::Test {name: "Tomato C", found_rate: 0.5} ] - assert_order "tomato", ["Tomato B", "Tomato A", "Tomato C"], multiply_by: [:found_rate] - assert_order "tomato", ["Tomato B", "Tomato A", "Tomato C"], multiply_by: {found_rate: {factor: 1}} + assert_order "tomato", ["Tomato B", "Tomato A", "Tomato C"], boost_by: {found_rate: {factor: 1, boost_mode: "multiply"}} end def test_boost_where -- libgit2 0.21.0