diff --git a/README.md b/README.md
index 9806c21..5a8f048 100644
--- a/README.md
+++ b/README.md
@@ -10,7 +10,7 @@ give it a try.
## Elasticsearch version
-* Currently designed for Elasticsearch 5.6.0.
+* Currently designed for Elasticsearch 6.0.0.
* for Elasticsearch 5.2.2 use branch `es-5.2.2`
* for Elasticsearch 2.4.4 use branch `es-2.4.4`
@@ -42,28 +42,29 @@ give it a try.
* The vector can be of any dimension
### Converting a vector to Base64
-to convert an array of doubles to a base64 string we use these example methods:
+To convert an array of float32 values to a Base64 string, use these example methods:
**Java**
```
-public static final String convertArrayToBase64(double[] array) {
- final int capacity = 8 * array.length;
- final ByteBuffer bb = ByteBuffer.allocate(capacity);
- for (int i = 0; i < array.length; i++) {
- bb.putDouble(array[i]);
- }
- bb.rewind();
- final ByteBuffer encodedBB = Base64.getEncoder().encode(bb);
- return new String(encodedBB.array());
+public static float[] convertBase64ToArray(String base64Str) {
+ final byte[] decode = Base64.getDecoder().decode(base64Str.getBytes());
+ final FloatBuffer floatBuffer = ByteBuffer.wrap(decode).asFloatBuffer();
+ final float[] dims = new float[floatBuffer.capacity()];
+ floatBuffer.get(dims);
+
+ return dims;
}
-public static double[] convertBase64ToArray(String base64Str) {
- final byte[] decode = Base64.getDecoder().decode(base64Str.getBytes());
- final DoubleBuffer doubleBuffer = ByteBuffer.wrap(decode).asDoubleBuffer();
+public static String convertArrayToBase64(float[] array) {
+ final int capacity = Float.BYTES * array.length;
+ final ByteBuffer bb = ByteBuffer.allocate(capacity);
+ for (float v : array) {
+ bb.putFloat(v);
+ }
+ bb.rewind();
+ final ByteBuffer encodedBB = Base64.getEncoder().encode(bb);
- final double[] dims = new double[doubleBuffer.capacity()];
- doubleBuffer.get(dims);
- return dims;
+ return new String(encodedBB.array());
}
```
**Python**
@@ -71,14 +72,14 @@ public static double[] convertBase64ToArray(String base64Str) {
import base64
import numpy as np
-dbig = np.dtype('>f8')
+dfloat32 = np.dtype('>f4')
def decode_float_list(base64_string):
bytes = base64.b64decode(base64_string)
- return np.frombuffer(bytes, dtype=dbig).tolist()
+ return np.frombuffer(bytes, dtype=dfloat32).tolist()
def encode_array(arr):
- base64_str = base64.b64encode(np.array(arr).astype(dbig)).decode("utf-8")
+ base64_str = base64.b64encode(np.array(arr).astype(dfloat32)).decode("utf-8")
return base64_str
```
@@ -87,11 +88,11 @@ def encode_array(arr):
require 'base64'
def decode_float_list(base64_string)
- Base64.strict_decode64(base64_string).unpack('G*')
+ Base64.strict_decode64(base64_string).unpack('g*')
end
def encode_array(arr)
- Base64.strict_encode64(arr.pack('G*'))
+ Base64.strict_encode64(arr.pack('g*'))
end
```
@@ -103,12 +104,12 @@ import(
"encoding/base64"
)
-func convertArrayToBase64(array []float64) string {
- bytes := make([]byte, 0, 8*len(array))
+func convertArrayToBase64(array []float32) string {
+ bytes := make([]byte, 0, 4*len(array))
for _, a := range array {
- bits := math.Float64bits(a)
- b := make([]byte, 8)
- binary.BigEndian.PutUint64(b, bits)
+ bits := math.Float32bits(a)
+ b := make([]byte, 4)
+ binary.BigEndian.PutUint32(b, bits)
bytes = append(bytes, b...)
}
@@ -116,18 +117,18 @@ func convertArrayToBase64(array []float64) string {
return encoded
}
-func convertBase64ToArray(base64Str string) ([]float64, error) {
+func convertBase64ToArray(base64Str string) ([]float32, error) {
decoded, err := base64.StdEncoding.DecodeString(base64Str)
if err != nil {
return nil, err
}
length := len(decoded)
- array := make([]float64, 0, length/8)
+ array := make([]float32, 0, length/4)
- for i := 0; i < len(decoded); i += 8 {
- bits := binary.BigEndian.Uint64(decoded[i : i+8])
- f := math.Float64frombits(bits)
+ for i := 0; i < len(decoded); i += 4 {
+ bits := binary.BigEndian.Uint32(decoded[i : i+4])
+ f := math.Float32frombits(bits)
array = append(array, f)
}
return array, nil
@@ -146,7 +147,7 @@ func convertBase64ToArray(base64Str string) ([]float64, error) {
"boost_mode": "replace",
"script_score": {
"script": {
- "inline": "binary_vector_score",
+ "source": "binary_vector_score",
"lang": "knn",
"params": {
"cosine": false,
diff --git a/pom.xml b/pom.xml
index d765043..79d7430 100755
--- a/pom.xml
+++ b/pom.xml
@@ -7,7 +7,7 @@
elasticsearch-binary-vector-scoring
com.liorkn.elasticsearch
elasticsearch-binary-vector-scoring
- 5.6.0
+ 6.0.0
ElasticSearch Plugin for Binary Vector Scoring
@@ -27,11 +27,11 @@
${project.basedir}/src/main/resources/license-check/license_header_definition.xml
warn
- 5.6.0
+ 6.0.0
2.4
4.4.8
4.12
- 2.7.4
+ 2.8.11.3
@@ -65,7 +65,7 @@
org.elasticsearch.plugin
- transport-netty3-client
+ transport-netty4-client
${elasticsearch.version}
test
@@ -86,12 +86,6 @@
-
-
-
-
-
-
org.codelibs.elasticsearch.module
lang-painless
diff --git a/src/main/java/com/liorkn/elasticsearch/Util.java b/src/main/java/com/liorkn/elasticsearch/Util.java
index de81af8..53cf1f6 100644
--- a/src/main/java/com/liorkn/elasticsearch/Util.java
+++ b/src/main/java/com/liorkn/elasticsearch/Util.java
@@ -1,7 +1,7 @@
package com.liorkn.elasticsearch;
import java.nio.ByteBuffer;
-import java.nio.DoubleBuffer;
+import java.nio.FloatBuffer;
import java.util.Base64;
/**
@@ -9,23 +9,24 @@
*/
public class Util {
- public static final double[] convertBase64ToArray(String base64Str) {
+ public static float[] convertBase64ToArray(String base64Str) {
final byte[] decode = Base64.getDecoder().decode(base64Str.getBytes());
- final DoubleBuffer doubleBuffer = ByteBuffer.wrap(decode).asDoubleBuffer();
+ final FloatBuffer floatBuffer = ByteBuffer.wrap(decode).asFloatBuffer();
+ final float[] dims = new float[floatBuffer.capacity()];
+ floatBuffer.get(dims);
- final double[] dims = new double[doubleBuffer.capacity()];
- doubleBuffer.get(dims);
return dims;
}
- public static final String convertArrayToBase64(double[] array) {
- final int capacity = 8 * array.length;
+ public static String convertArrayToBase64(float[] array) {
+ final int capacity = Float.BYTES * array.length;
final ByteBuffer bb = ByteBuffer.allocate(capacity);
- for (int i = 0; i < array.length; i++) {
- bb.putDouble(array[i]);
+ for (double v : array) {
+ bb.putFloat((float) v);
}
bb.rewind();
final ByteBuffer encodedBB = Base64.getEncoder().encode(bb);
+
return new String(encodedBB.array());
}
}
diff --git a/src/main/java/com/liorkn/elasticsearch/engine/VectorScoringScriptEngine.java b/src/main/java/com/liorkn/elasticsearch/engine/VectorScoringScriptEngine.java
new file mode 100644
index 0000000..6c00692
--- /dev/null
+++ b/src/main/java/com/liorkn/elasticsearch/engine/VectorScoringScriptEngine.java
@@ -0,0 +1,36 @@
+package com.liorkn.elasticsearch.engine;
+
+import com.liorkn.elasticsearch.script.VectorScoreScript;
+
+import java.util.Map;
+
+import org.elasticsearch.script.ScriptContext;
+import org.elasticsearch.script.ScriptEngine;
+import org.elasticsearch.script.SearchScript;
+
+/** This {@link ScriptEngine} uses Lucene segment details to score documents based on their similarity to the submitted vector. */
+public class VectorScoringScriptEngine implements ScriptEngine {
+
+ public static final String NAME = "knn";
+ private static final String SCRIPT_SOURCE = "binary_vector_score";
+
+ @Override
+ public String getType() {
+ return NAME;
+ }
+
+ @Override
+ public <T> T compile(String scriptName, String scriptSource, ScriptContext<T> context, Map<String, String> params) {
+ if (context.equals(SearchScript.CONTEXT) == false) {
+ throw new IllegalArgumentException(getType() + " scripts cannot be used for context [" + context.name + "]");
+ }
+
+ // we use the script "source" as the script identifier
+ if (!SCRIPT_SOURCE.equals(scriptSource)) {
+ throw new IllegalArgumentException("Unknown script name " + scriptSource);
+ }
+
+ SearchScript.Factory factory = VectorScoreScript.VectorScoreScriptFactory::new;
+ return context.factoryClazz.cast(factory);
+ }
+}
diff --git a/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java b/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java
index 88a3599..c279636 100755
--- a/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java
+++ b/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java
@@ -13,13 +13,15 @@
*/
package com.liorkn.elasticsearch.plugin;
-import com.liorkn.elasticsearch.service.VectorScoringScriptEngineService;
+import com.liorkn.elasticsearch.engine.VectorScoringScriptEngine;
+
+import java.util.Collection;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.ScriptPlugin;
-import org.elasticsearch.script.ScriptEngineService;
-
+import org.elasticsearch.script.ScriptContext;
+import org.elasticsearch.script.ScriptEngine;
/**
* This class is instantiated when Elasticsearch loads the plugin for the
* first time. If you change the name of this plugin, make sure to update
@@ -27,9 +29,8 @@
*/
public final class VectorScoringPlugin extends Plugin implements ScriptPlugin {
- public final ScriptEngineService getScriptEngineService(Settings settings) {
- return new VectorScoringScriptEngineService(settings);
+ @Override
+ public ScriptEngine getScriptEngine(Settings settings, Collection<ScriptContext<?>> contexts) {
+ return new VectorScoringScriptEngine();
}
-
-
}
\ No newline at end of file
diff --git a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java
old mode 100755
new mode 100644
index 0b87cf3..be06fd4
--- a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java
+++ b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java
@@ -1,114 +1,90 @@
-/*
-Based on: https://discuss.elastic.co/t/vector-scoring/85227/4
-and https://github.com/MLnick/elasticsearch-vector-scoring
-
-another slower implementation using strings: https://github.com/ginobefun/elasticsearch-feature-vector-scoring
-
-storing arrays is no luck - lucine index doesn't keep the array members orders
-https://www.elastic.co/guide/en/elasticsearch/guide/current/complex-core-fields.html
-
-Delimited Payload Token Filter: https://www.elastic.co/guide/en/elasticsearch/reference/2.4/analysis-delimited-payload-tokenfilter.html
-
-
- */
-
package com.liorkn.elasticsearch.script;
-import com.liorkn.elasticsearch.Util;
-import org.apache.lucene.index.BinaryDocValues;
-import org.apache.lucene.store.ByteArrayDataInput;
-import org.elasticsearch.common.Nullable;
-import org.elasticsearch.script.ExecutableScript;
-import org.elasticsearch.script.LeafSearchScript;
-import org.elasticsearch.script.ScriptException;
-
-import java.nio.ByteBuffer;
-import java.nio.DoubleBuffer;
+import java.io.IOException;
+import java.io.UncheckedIOException;
import java.util.ArrayList;
-import java.util.Base64;
import java.util.Map;
-/**
- * Script that scores documents based on cosine similarity embedding vectors.
- */
-public final class VectorScoreScript implements LeafSearchScript, ExecutableScript {
-
- public static final String SCRIPT_NAME = "binary_vector_score";
+import org.apache.lucene.index.BinaryDocValues;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.store.ByteArrayDataInput;
+import org.elasticsearch.script.SearchScript;
+import org.elasticsearch.search.lookup.SearchLookup;
- private static final int DOUBLE_SIZE = 8;
+import com.liorkn.elasticsearch.Util;
- // the field containing the vectors to be scored against
- public final String field;
+public final class VectorScoreScript extends SearchScript {
- private int docId;
private BinaryDocValues binaryEmbeddingReader;
-
- private final double[] inputVector;
- private final double magnitude;
-
+
+ private final String field;
private final boolean cosine;
+ private final float[] inputVector;
+ private final float magnitude;
+
@Override
public long runAsLong() {
- return ((Number)this.run()).longValue();
- }
- @Override
- public double runAsDouble() {
- return ((Number)this.run()).doubleValue();
- }
- @Override
- public void setNextVar(String name, Object value) {}
- @Override
- public void setDocument(int docId) {
- this.docId = docId;
+ return (long) runAsDouble();
}
+
+ @Override
+ public double runAsDouble() {
+ try {
+ final byte[] bytes = binaryEmbeddingReader.binaryValue().bytes;
+ final ByteArrayDataInput input = new ByteArrayDataInput(bytes);
+
+ input.readVInt(); // returns the number of values, which should be 1; MUST appear here since it affects the next calls
+
+ final int len = input.readVInt();
+ // in case vector is of different size
+ if (len != inputVector.length * Float.BYTES) {
+ return 0.0;
+ }
+
+ float score = 0;
+
+ if (cosine) {
+ float docVectorNorm = 0.0f;
+ for (int i = 0; i < inputVector.length; i++) {
+ float v = Float.intBitsToFloat(input.readInt());
+ docVectorNorm += v * v; // squared norm of the document vector
+ score += v * inputVector[i]; // dot product
+ }
+
+ if (docVectorNorm == 0 || magnitude == 0) {
+ return 0f;
+ } else {
+ return score / (Math.sqrt(docVectorNorm) * magnitude);
+ }
+ } else {
+ for (int i = 0; i < inputVector.length; i++) {
+ float v = Float.intBitsToFloat(input.readInt());
+ score += v * inputVector[i]; // dot product
+ }
- public void setBinaryEmbeddingReader(BinaryDocValues binaryEmbeddingReader) {
- if(binaryEmbeddingReader == null) {
- throw new IllegalStateException("binaryEmbeddingReader can't be null");
- }
- this.binaryEmbeddingReader = binaryEmbeddingReader;
- }
-
-
- /**
- * Factory that is registered in
- * {@link VectorScoringPlugin#onModule(org.elasticsearch.script.ScriptModule)}
- * method when the plugin is loaded.
- */
- public static class Factory {
-
- /**
- * This method is called for every search on every shard.
- *
- * @param params
- * list of script parameters passed with the query
- * @return new native script
- */
- public ExecutableScript newScript(@Nullable Map params) throws ScriptException {
- return new VectorScoreScript(params);
- }
-
- /**
- * Indicates if document scores may be needed by the produced scripts.
- *
- * @return {@code true} if scores are needed.
- */
- public boolean needsScores() {
- return false;
+ return score;
+ }
+ } catch (Exception e) {
+ return 0.0;
}
-
+ }
+
+ @Override
+ public void setDocument(int docId) {
+ try {
+ this.binaryEmbeddingReader.advanceExact(docId);
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
}
-
- /**
- * Init
- * @param params index that a scored are placed in this parameter. Initialize them here.
- */
@SuppressWarnings("unchecked")
- public VectorScoreScript(Map params) {
- final Object cosineBool = params.get("cosine");
- cosine = cosineBool != null ?
+ public VectorScoreScript(Map params, SearchLookup lookup, LeafReaderContext leafContext) {
+ super(params, lookup, leafContext);
+
+ final Object cosineBool = params.get("cosine");
+ this.cosine = cosineBool != null ?
(boolean)cosineBool :
true;
@@ -121,9 +97,9 @@ public VectorScoreScript(Map params) {
final Object vector = params.get("vector");
if(vector != null) {
final ArrayList tmp = (ArrayList) vector;
- inputVector = new double[tmp.size()];
+ inputVector = new float[tmp.size()];
for (int i = 0; i < inputVector.length; i++) {
- inputVector[i] = tmp.get(i);
+ inputVector[i] = tmp.get(i).floatValue();
}
} else {
final Object encodedVector = params.get("encoded_vector");
@@ -133,61 +109,41 @@ public VectorScoreScript(Map params) {
inputVector = Util.convertBase64ToArray((String) encodedVector);
}
- if(cosine) {
+ if (this.cosine) {
// calc magnitude
- double queryVectorNorm = 0.0;
+ float queryVectorNorm = 0.0f;
// compute query inputVector norm once
- for (double v : inputVector) {
+ for (float v: this.inputVector) {
queryVectorNorm += v * v;
}
- magnitude = Math.sqrt(queryVectorNorm);
+ this.magnitude = (float) Math.sqrt(queryVectorNorm);
} else {
- magnitude = 0.0;
+ this.magnitude = 0.0f;
}
- }
-
-
- /**
- * Called for each document
- * @return cosine similarity of the current document against the input inputVector
- */
- @Override
- public final Object run() {
- final int size = inputVector.length;
-
- final byte[] bytes = binaryEmbeddingReader.get(docId).bytes;
- final ByteArrayDataInput input = new ByteArrayDataInput(bytes);
- input.readVInt(); // returns the number of values which should be 1, MUST appear hear since it affect the next calls
- final int len = input.readVInt(); // returns the number of bytes to read
- if(len != size * DOUBLE_SIZE) {
- return 0.0;
- }
- final int position = input.getPosition();
- final DoubleBuffer doubleBuffer = ByteBuffer.wrap(bytes, position, len).asDoubleBuffer();
-
- final double[] docVector = new double[size];
- doubleBuffer.get(docVector);
-
- double docVectorNorm = 0.0f;
- double score = 0;
- for (int i = 0; i < size; i++) {
- // doc inputVector norm
- if(cosine) {
- docVectorNorm += docVector[i]*docVector[i];
- }
- // dot product
- score += docVector[i] * inputVector[i];
+
+ try {
+ this.binaryEmbeddingReader = leafContext.reader().getBinaryDocValues(this.field);
+ } catch (IOException e) {
+ throw new IllegalStateException("binaryEmbeddingReader can't be null");
+ }
+ }
+
+ public static class VectorScoreScriptFactory implements LeafFactory {
+ private final Map params;
+ private final SearchLookup lookup;
+
+ public VectorScoreScriptFactory(Map params, SearchLookup lookup) {
+ this.params = params;
+ this.lookup = lookup;
}
- if(cosine) {
- // cosine similarity score
- if (docVectorNorm == 0 || magnitude == 0){
- return 0f;
- } else {
- return score / (Math.sqrt(docVectorNorm) * magnitude);
- }
- } else {
- return score;
+
+ public boolean needs_score() {
+ return false;
}
- }
+ @Override
+ public SearchScript newInstance(LeafReaderContext ctx) throws IOException {
+ return new VectorScoreScript(this.params, this.lookup, ctx);
+ }
+ }
}
\ No newline at end of file
diff --git a/src/main/java/com/liorkn/elasticsearch/service/VectorScoringScriptEngineService.java b/src/main/java/com/liorkn/elasticsearch/service/VectorScoringScriptEngineService.java
deleted file mode 100755
index 58db087..0000000
--- a/src/main/java/com/liorkn/elasticsearch/service/VectorScoringScriptEngineService.java
+++ /dev/null
@@ -1,78 +0,0 @@
-package com.liorkn.elasticsearch.service;
-
-import com.liorkn.elasticsearch.script.VectorScoreScript;
-import org.apache.lucene.index.LeafReaderContext;
-import org.elasticsearch.common.Nullable;
-import org.elasticsearch.common.component.AbstractComponent;
-import org.elasticsearch.common.inject.Inject;
-import org.elasticsearch.common.settings.Settings;
-import org.elasticsearch.script.CompiledScript;
-import org.elasticsearch.script.ExecutableScript;
-import org.elasticsearch.script.LeafSearchScript;
-import org.elasticsearch.script.ScriptEngineService;
-import org.elasticsearch.script.SearchScript;
-import org.elasticsearch.search.lookup.SearchLookup;
-
-import java.io.IOException;
-import java.util.Map;
-
-/**
- * Created by Lior Knaany on 5/14/17.
- */
-public class VectorScoringScriptEngineService extends AbstractComponent implements ScriptEngineService{
-
- public static final String NAME = "knn";
-
- @Inject
- public VectorScoringScriptEngineService(Settings settings) {
- super(settings);
- }
-
- @Override
- public Object compile(String scriptName, String scriptSource, Map params) {
- return new VectorScoreScript.Factory();
- }
-
-
- @Override
- public boolean isInlineScriptEnabled() {
- return true;
- }
-
- @Override
- public String getType() {
- return NAME;
- }
-
- @Override
- public String getExtension() {
- return NAME;
- }
-
- @Override
- public ExecutableScript executable(CompiledScript compiledScript, @Nullable Map vars) {
- VectorScoreScript.Factory scriptFactory = (VectorScoreScript.Factory) compiledScript.compiled();
- return scriptFactory.newScript(vars);
- }
-
- @Override
- public SearchScript search(CompiledScript compiledScript, final SearchLookup lookup, @Nullable final Map vars) {
- final VectorScoreScript.Factory scriptFactory = (VectorScoreScript.Factory) compiledScript.compiled();
- final VectorScoreScript script = (VectorScoreScript) scriptFactory.newScript(vars);
- return new SearchScript() {
- @Override
- public LeafSearchScript getLeafSearchScript(LeafReaderContext context) throws IOException {
- script.setBinaryEmbeddingReader(context.reader().getBinaryDocValues(script.field));
- return script;
- }
- @Override
- public boolean needsScores() {
- return scriptFactory.needsScores();
- }
- };
- }
-
- @Override
- public void close() {
- }
-}
diff --git a/src/test/java/com/liorkn/elasticsearch/EmbeddedElasticsearchServer.java b/src/test/java/com/liorkn/elasticsearch/EmbeddedElasticsearchServer.java
index 5627240..1cbdb9c 100644
--- a/src/test/java/com/liorkn/elasticsearch/EmbeddedElasticsearchServer.java
+++ b/src/test/java/com/liorkn/elasticsearch/EmbeddedElasticsearchServer.java
@@ -10,7 +10,7 @@
import org.elasticsearch.node.Node;
import org.elasticsearch.node.NodeValidationException;
import org.elasticsearch.painless.PainlessPlugin;
-import org.elasticsearch.transport.Netty3Plugin;
+import org.elasticsearch.transport.Netty4Plugin;
import java.io.File;
import java.io.IOException;
@@ -41,13 +41,10 @@ private EmbeddedElasticsearchServer(String defaultDataDirectory, int port) throw
Settings.Builder settings = Settings.builder()
.put("http.enabled", "true")
- .put("transport.type", "local")
- .put("http.type", "netty3")
+ .put("http.type", "netty4")
.put("path.data", dataDirectory)
.put("path.home", DEFAULT_HOME_DIRECTORY)
- .put("script.inline", "on")
- .put("node.max_local_storage_nodes", 10000)
- .put("script.stored", "on");
+ .put("node.max_local_storage_nodes", 10000);
startNodeInAvailablePort(settings);
}
@@ -61,7 +58,7 @@ private void startNodeInAvailablePort(Settings.Builder settings) throws NodeVali
settings.put("http.port", String.valueOf(this.port));
// this a hack in order to load Groovy plug in since we want to enable the usage of scripts
- node = new NodeExt(settings.build() , Arrays.asList(Netty3Plugin.class, PainlessPlugin.class, ReindexPlugin.class, VectorScoringPlugin.class));
+ node = new NodeExt(settings.build() , Arrays.asList(Netty4Plugin.class, PainlessPlugin.class, ReindexPlugin.class, VectorScoringPlugin.class));
node.start();
success = true;
System.out.println(EmbeddedElasticsearchServer.class.getName() + ": Using port: " + this.port);
diff --git a/src/test/java/com/liorkn/elasticsearch/PluginTest.java b/src/test/java/com/liorkn/elasticsearch/PluginTest.java
index b95b65c..8427932 100644
--- a/src/test/java/com/liorkn/elasticsearch/PluginTest.java
+++ b/src/test/java/com/liorkn/elasticsearch/PluginTest.java
@@ -2,6 +2,8 @@
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.node.ArrayNode;
+
import org.apache.http.HttpHost;
import org.apache.http.entity.ContentType;
import org.apache.http.entity.StringEntity;
@@ -70,9 +72,8 @@ public static void init() throws Exception {
public void test() throws Exception {
final Map params = new HashMap<>();
params.put("refresh", "true");
- final ObjectMapper mapper = new ObjectMapper();
- final TestObject[] objs = {new TestObject(1, new double[] {0.0, 0.5, 1.0}),
- new TestObject(2, new double[] {0.2, 0.6, 0.99})};
+ final TestObject[] objs = {new TestObject(1, new float[] {0.0f, 0.5f, 1.0f}),
+ new TestObject(2, new float[] {0.2f, 0.6f, 0.99f})};
for (int i = 0; i < objs.length; i++) {
final TestObject t = objs[i];
@@ -92,10 +93,10 @@ public void test() throws Exception {
" \"boost_mode\": \"replace\"," +
" \"script_score\": {" +
" \"script\": {" +
- " \"inline\": \"binary_vector_score\"," +
+ " \"source\": \"binary_vector_score\"," +
" \"lang\": \"knn\"," +
" \"params\": {" +
- " \"cosine\": false," +
+ " \"cosine\": true," +
" \"field\": \"embedding_vector\"," +
" \"vector\": [" +
" 0.1, 0.2, 0.3" +
@@ -113,6 +114,10 @@ public void test() throws Exception {
System.out.println(resBody);
Assert.assertEquals("search should return status code 200", 200, res.getStatusLine().getStatusCode());
Assert.assertTrue(String.format("There should be %d documents in the search response", objs.length), resBody.contains("\"hits\":{\"total\":" + objs.length));
+ // Testing Scores
+ final ArrayNode hitsJson = (ArrayNode)mapper.readTree(resBody).get("hits").get("hits");
+ Assert.assertEquals(0.9941734, hitsJson.get(0).get("_score").asDouble(), 0);
+ Assert.assertEquals(0.95618284, hitsJson.get(1).get("_score").asDouble(), 0);
}
@AfterClass
diff --git a/src/test/java/com/liorkn/elasticsearch/TestObject.java b/src/test/java/com/liorkn/elasticsearch/TestObject.java
index f37d98a..8338e31 100644
--- a/src/test/java/com/liorkn/elasticsearch/TestObject.java
+++ b/src/test/java/com/liorkn/elasticsearch/TestObject.java
@@ -10,7 +10,7 @@
public class TestObject {
int jobId;
String embeddingVector;
- double[] vector;
+ float[] vector;
public int getJobId() {
return jobId;
@@ -20,13 +20,13 @@ public String getEmbeddingVector() {
return embeddingVector;
}
- public double[] getVector() {
+ public float[] getVector() {
return vector;
}
- public TestObject(int jobId, double[] vector) {
+ public TestObject(int jobId, float[] vector) {
this.jobId = jobId;
this.vector = vector;
this.embeddingVector = Util.convertArrayToBase64(vector);
}
-}
+}
\ No newline at end of file