diff --git a/README.md b/README.md index 9806c21..5a8f048 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ give it a try. ## Elasticsearch version -* Currently designed for Elasticsearch 5.6.0. +* Currently designed for Elasticsearch 6.0.0. * for Elasticsearch 5.2.2 use branch `es-5.2.2` * for Elasticsearch 2.4.4 use branch `es-2.4.4` @@ -42,28 +42,29 @@ give it a try. * The vector can be of any dimension ### Converting a vector to Base64 -to convert an array of doubles to a base64 string we use these example methods: +To convert an array of float32 values to a Base64 string, use these example methods: **Java** ``` -public static final String convertArrayToBase64(double[] array) { - final int capacity = 8 * array.length; - final ByteBuffer bb = ByteBuffer.allocate(capacity); - for (int i = 0; i < array.length; i++) { - bb.putDouble(array[i]); - } - bb.rewind(); - final ByteBuffer encodedBB = Base64.getEncoder().encode(bb); - return new String(encodedBB.array()); +public static float[] convertBase64ToArray(String base64Str) { + final byte[] decode = Base64.getDecoder().decode(base64Str.getBytes()); + final FloatBuffer floatBuffer = ByteBuffer.wrap(decode).asFloatBuffer(); + final float[] dims = new float[floatBuffer.capacity()]; + floatBuffer.get(dims); + + return dims; } -public static double[] convertBase64ToArray(String base64Str) { - final byte[] decode = Base64.getDecoder().decode(base64Str.getBytes()); - final DoubleBuffer doubleBuffer = ByteBuffer.wrap(decode).asDoubleBuffer(); +public static String convertArrayToBase64(float[] array) { + final int capacity = Float.BYTES * array.length; + final ByteBuffer bb = ByteBuffer.allocate(capacity); + for (float v : array) { + bb.putFloat(v); + } + bb.rewind(); + final ByteBuffer encodedBB = Base64.getEncoder().encode(bb); - final double[] dims = new double[doubleBuffer.capacity()]; - doubleBuffer.get(dims); - return dims; + return new String(encodedBB.array()); } ``` **Python** @@ -71,14 +72,14 @@ public static 
double[] convertBase64ToArray(String base64Str) { import base64 import numpy as np -dbig = np.dtype('>f8') +dfloat32 = np.dtype('>f4') def decode_float_list(base64_string): bytes = base64.b64decode(base64_string) - return np.frombuffer(bytes, dtype=dbig).tolist() + return np.frombuffer(bytes, dtype=dfloat32).tolist() def encode_array(arr): - base64_str = base64.b64encode(np.array(arr).astype(dbig)).decode("utf-8") + base64_str = base64.b64encode(np.array(arr).astype(dfloat32)).decode("utf-8") return base64_str ``` @@ -87,11 +88,11 @@ def encode_array(arr): require 'base64' def decode_float_list(base64_string) - Base64.strict_decode64(base64_string).unpack('G*') + Base64.strict_decode64(base64_string).unpack('g*') end def encode_array(arr) - Base64.strict_encode64(arr.pack('G*')) + Base64.strict_encode64(arr.pack('g*')) end ``` @@ -103,12 +104,12 @@ import( "encoding/base64" ) -func convertArrayToBase64(array []float64) string { - bytes := make([]byte, 0, 8*len(array)) +func convertArrayToBase64(array []float32) string { + bytes := make([]byte, 0, 4*len(array)) for _, a := range array { - bits := math.Float64bits(a) - b := make([]byte, 8) - binary.BigEndian.PutUint64(b, bits) + bits := math.Float32bits(a) + b := make([]byte, 4) + binary.BigEndian.PutUint32(b, bits) bytes = append(bytes, b...) 
} @@ -116,18 +117,18 @@ func convertArrayToBase64(array []float64) string { return encoded } -func convertBase64ToArray(base64Str string) ([]float64, error) { +func convertBase64ToArray(base64Str string) ([]float32, error) { decoded, err := base64.StdEncoding.DecodeString(base64Str) if err != nil { return nil, err } length := len(decoded) - array := make([]float64, 0, length/8) + array := make([]float32, 0, length/4) - for i := 0; i < len(decoded); i += 8 { - bits := binary.BigEndian.Uint64(decoded[i : i+8]) - f := math.Float64frombits(bits) + for i := 0; i < len(decoded); i += 4 { + bits := binary.BigEndian.Uint32(decoded[i : i+4]) + f := math.Float32frombits(bits) array = append(array, f) } return array, nil @@ -146,7 +147,7 @@ func convertBase64ToArray(base64Str string) ([]float64, error) { "boost_mode": "replace", "script_score": { "script": { - "inline": "binary_vector_score", + "source": "binary_vector_score", "lang": "knn", "params": { "cosine": false, diff --git a/pom.xml b/pom.xml index d765043..79d7430 100755 --- a/pom.xml +++ b/pom.xml @@ -7,7 +7,7 @@ elasticsearch-binary-vector-scoring com.liorkn.elasticsearch elasticsearch-binary-vector-scoring - 5.6.0 + 6.0.0 ElasticSearch Plugin for Binary Vector Scoring @@ -27,11 +27,11 @@ ${project.basedir}/src/main/resources/license-check/license_header_definition.xml warn - 5.6.0 + 6.0.0 2.4 4.4.8 4.12 - 2.7.4 + 2.8.11.3 @@ -65,7 +65,7 @@ org.elasticsearch.plugin - transport-netty3-client + transport-netty4-client ${elasticsearch.version} test @@ -86,12 +86,6 @@ - - - - - - org.codelibs.elasticsearch.module lang-painless diff --git a/src/main/java/com/liorkn/elasticsearch/Util.java b/src/main/java/com/liorkn/elasticsearch/Util.java index de81af8..53cf1f6 100644 --- a/src/main/java/com/liorkn/elasticsearch/Util.java +++ b/src/main/java/com/liorkn/elasticsearch/Util.java @@ -1,7 +1,7 @@ package com.liorkn.elasticsearch; import java.nio.ByteBuffer; -import java.nio.DoubleBuffer; +import java.nio.FloatBuffer; import 
java.util.Base64; /** @@ -9,23 +9,24 @@ */ public class Util { - public static final double[] convertBase64ToArray(String base64Str) { + public static float[] convertBase64ToArray(String base64Str) { final byte[] decode = Base64.getDecoder().decode(base64Str.getBytes()); - final DoubleBuffer doubleBuffer = ByteBuffer.wrap(decode).asDoubleBuffer(); + final FloatBuffer floatBuffer = ByteBuffer.wrap(decode).asFloatBuffer(); + final float[] dims = new float[floatBuffer.capacity()]; + floatBuffer.get(dims); - final double[] dims = new double[doubleBuffer.capacity()]; - doubleBuffer.get(dims); return dims; } - public static final String convertArrayToBase64(double[] array) { - final int capacity = 8 * array.length; + public static String convertArrayToBase64(float[] array) { + final int capacity = Float.BYTES * array.length; final ByteBuffer bb = ByteBuffer.allocate(capacity); - for (int i = 0; i < array.length; i++) { - bb.putDouble(array[i]); + for (double v : array) { + bb.putFloat((float) v); } bb.rewind(); final ByteBuffer encodedBB = Base64.getEncoder().encode(bb); + return new String(encodedBB.array()); } } diff --git a/src/main/java/com/liorkn/elasticsearch/engine/VectorScoringScriptEngine.java b/src/main/java/com/liorkn/elasticsearch/engine/VectorScoringScriptEngine.java new file mode 100644 index 0000000..6c00692 --- /dev/null +++ b/src/main/java/com/liorkn/elasticsearch/engine/VectorScoringScriptEngine.java @@ -0,0 +1,36 @@ +package com.liorkn.elasticsearch.engine; + +import com.liorkn.elasticsearch.script.VectorScoreScript; + +import java.util.Map; + +import org.elasticsearch.script.ScriptContext; +import org.elasticsearch.script.ScriptEngine; +import org.elasticsearch.script.SearchScript; + +/** This {@link ScriptEngine} uses Lucene segment details to implement document scoring based on their similarity with submitted document. 
*/ +public class VectorScoringScriptEngine implements ScriptEngine { + + public static final String NAME = "knn"; + private static final String SCRIPT_SOURCE = "binary_vector_score"; + + @Override + public String getType() { + return NAME; + } + + @Override + public T compile(String scriptName, String scriptSource, ScriptContext context, Map params) { + if (context.equals(SearchScript.CONTEXT) == false) { + throw new IllegalArgumentException(getType() + " scripts cannot be used for context [" + context.name + "]"); + } + + // we use the script "source" as the script identifier + if (!SCRIPT_SOURCE.equals(scriptSource)) { + throw new IllegalArgumentException("Unknown script name " + scriptSource); + } + + SearchScript.Factory factory = VectorScoreScript.VectorScoreScriptFactory::new; + return context.factoryClazz.cast(factory); + } +} diff --git a/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java b/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java index 88a3599..c279636 100755 --- a/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java +++ b/src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java @@ -13,13 +13,15 @@ */ package com.liorkn.elasticsearch.plugin; -import com.liorkn.elasticsearch.service.VectorScoringScriptEngineService; +import com.liorkn.elasticsearch.engine.VectorScoringScriptEngine; + +import java.util.Collection; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.ScriptPlugin; -import org.elasticsearch.script.ScriptEngineService; - +import org.elasticsearch.script.ScriptContext; +import org.elasticsearch.script.ScriptEngine; /** * This class is instantiated when Elasticsearch loads the plugin for the * first time. 
If you change the name of this plugin, make sure to update @@ -27,9 +29,8 @@ */ public final class VectorScoringPlugin extends Plugin implements ScriptPlugin { - public final ScriptEngineService getScriptEngineService(Settings settings) { - return new VectorScoringScriptEngineService(settings); + @Override + public ScriptEngine getScriptEngine(Settings settings, Collection> contexts) { + return new VectorScoringScriptEngine(); } - - } \ No newline at end of file diff --git a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java old mode 100755 new mode 100644 index 0b87cf3..be06fd4 --- a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java +++ b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java @@ -1,114 +1,90 @@ -/* -Based on: https://discuss.elastic.co/t/vector-scoring/85227/4 -and https://github.com/MLnick/elasticsearch-vector-scoring - -another slower implementation using strings: https://github.com/ginobefun/elasticsearch-feature-vector-scoring - -storing arrays is no luck - lucine index doesn't keep the array members orders -https://www.elastic.co/guide/en/elasticsearch/guide/current/complex-core-fields.html - -Delimited Payload Token Filter: https://www.elastic.co/guide/en/elasticsearch/reference/2.4/analysis-delimited-payload-tokenfilter.html - - - */ - package com.liorkn.elasticsearch.script; -import com.liorkn.elasticsearch.Util; -import org.apache.lucene.index.BinaryDocValues; -import org.apache.lucene.store.ByteArrayDataInput; -import org.elasticsearch.common.Nullable; -import org.elasticsearch.script.ExecutableScript; -import org.elasticsearch.script.LeafSearchScript; -import org.elasticsearch.script.ScriptException; - -import java.nio.ByteBuffer; -import java.nio.DoubleBuffer; +import java.io.IOException; +import java.io.UncheckedIOException; import java.util.ArrayList; -import java.util.Base64; import java.util.Map; -/** - * Script 
that scores documents based on cosine similarity embedding vectors. - */ -public final class VectorScoreScript implements LeafSearchScript, ExecutableScript { - - public static final String SCRIPT_NAME = "binary_vector_score"; +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.store.ByteArrayDataInput; +import org.elasticsearch.script.SearchScript; +import org.elasticsearch.search.lookup.SearchLookup; - private static final int DOUBLE_SIZE = 8; +import com.liorkn.elasticsearch.Util; - // the field containing the vectors to be scored against - public final String field; +public final class VectorScoreScript extends SearchScript { - private int docId; private BinaryDocValues binaryEmbeddingReader; - - private final double[] inputVector; - private final double magnitude; - + + private final String field; private final boolean cosine; + private final float[] inputVector; + private final float magnitude; + @Override public long runAsLong() { - return ((Number)this.run()).longValue(); - } - @Override - public double runAsDouble() { - return ((Number)this.run()).doubleValue(); - } - @Override - public void setNextVar(String name, Object value) {} - @Override - public void setDocument(int docId) { - this.docId = docId; + return (long) runAsDouble(); } + + @Override + public double runAsDouble() { + try { + final byte[] bytes = binaryEmbeddingReader.binaryValue().bytes; + final ByteArrayDataInput input = new ByteArrayDataInput(bytes); + + input.readVInt(); // returns the number of values, which should be 1; MUST appear here since it affects the next calls + + final int len = input.readVInt(); + // in case vector is of different size + if (len != inputVector.length * Float.BYTES) { + return 0.0; + } + + float score = 0; + + if (cosine) { + float docVectorNorm = 0.0f; + for (int i = 0; i < inputVector.length; i++) { + float v = Float.intBitsToFloat(input.readInt()); + docVectorNorm += v * v; // 
document vector norm + score += v * inputVector[i]; // dot product + } + + if (docVectorNorm == 0 || magnitude == 0) { + return 0f; + } else { + return score / (Math.sqrt(docVectorNorm) * magnitude); + } + } else { + for (int i = 0; i < inputVector.length; i++) { + float v = Float.intBitsToFloat(input.readInt()); + score += v * inputVector[i]; // dot product + } - public void setBinaryEmbeddingReader(BinaryDocValues binaryEmbeddingReader) { - if(binaryEmbeddingReader == null) { - throw new IllegalStateException("binaryEmbeddingReader can't be null"); - } - this.binaryEmbeddingReader = binaryEmbeddingReader; - } - - - /** - * Factory that is registered in - * {@link VectorScoringPlugin#onModule(org.elasticsearch.script.ScriptModule)} - * method when the plugin is loaded. - */ - public static class Factory { - - /** - * This method is called for every search on every shard. - * - * @param params - * list of script parameters passed with the query - * @return new native script - */ - public ExecutableScript newScript(@Nullable Map params) throws ScriptException { - return new VectorScoreScript(params); - } - - /** - * Indicates if document scores may be needed by the produced scripts. - * - * @return {@code true} if scores are needed. - */ - public boolean needsScores() { - return false; + return score; + } + } catch (Exception e) { + return 0.0; } - + } + + @Override + public void setDocument(int docId) { + try { + this.binaryEmbeddingReader.advanceExact(docId); + } catch (IOException e) { + throw new UncheckedIOException(e); + } } - - /** - * Init - * @param params index that a scored are placed in this parameter. Initialize them here. - */ @SuppressWarnings("unchecked") - public VectorScoreScript(Map params) { - final Object cosineBool = params.get("cosine"); - cosine = cosineBool != null ? 
+ public VectorScoreScript(Map params, SearchLookup lookup, LeafReaderContext leafContext) { + super(params, lookup, leafContext); + + final Object cosineBool = params.get("cosine"); + this.cosine = cosineBool != null ? (boolean)cosineBool : true; @@ -121,9 +97,9 @@ public VectorScoreScript(Map params) { final Object vector = params.get("vector"); if(vector != null) { final ArrayList tmp = (ArrayList) vector; - inputVector = new double[tmp.size()]; + inputVector = new float[tmp.size()]; for (int i = 0; i < inputVector.length; i++) { - inputVector[i] = tmp.get(i); + inputVector[i] = tmp.get(i).floatValue(); } } else { final Object encodedVector = params.get("encoded_vector"); @@ -133,61 +109,41 @@ public VectorScoreScript(Map params) { inputVector = Util.convertBase64ToArray((String) encodedVector); } - if(cosine) { + if (this.cosine) { // calc magnitude - double queryVectorNorm = 0.0; + float queryVectorNorm = 0.0f; // compute query inputVector norm once - for (double v : inputVector) { + for (float v: this.inputVector) { queryVectorNorm += v * v; } - magnitude = Math.sqrt(queryVectorNorm); + this.magnitude = (float) Math.sqrt(queryVectorNorm); } else { - magnitude = 0.0; + this.magnitude = 0.0f; } - } - - - /** - * Called for each document - * @return cosine similarity of the current document against the input inputVector - */ - @Override - public final Object run() { - final int size = inputVector.length; - - final byte[] bytes = binaryEmbeddingReader.get(docId).bytes; - final ByteArrayDataInput input = new ByteArrayDataInput(bytes); - input.readVInt(); // returns the number of values which should be 1, MUST appear hear since it affect the next calls - final int len = input.readVInt(); // returns the number of bytes to read - if(len != size * DOUBLE_SIZE) { - return 0.0; - } - final int position = input.getPosition(); - final DoubleBuffer doubleBuffer = ByteBuffer.wrap(bytes, position, len).asDoubleBuffer(); - - final double[] docVector = new double[size]; - 
doubleBuffer.get(docVector); - - double docVectorNorm = 0.0f; - double score = 0; - for (int i = 0; i < size; i++) { - // doc inputVector norm - if(cosine) { - docVectorNorm += docVector[i]*docVector[i]; - } - // dot product - score += docVector[i] * inputVector[i]; + + try { + this.binaryEmbeddingReader = leafContext.reader().getBinaryDocValues(this.field); + } catch (IOException e) { + throw new IllegalStateException("binaryEmbeddingReader can't be null"); + } + } + + public static class VectorScoreScriptFactory implements LeafFactory { + private final Map params; + private final SearchLookup lookup; + + public VectorScoreScriptFactory(Map params, SearchLookup lookup) { + this.params = params; + this.lookup = lookup; } - if(cosine) { - // cosine similarity score - if (docVectorNorm == 0 || magnitude == 0){ - return 0f; - } else { - return score / (Math.sqrt(docVectorNorm) * magnitude); - } - } else { - return score; + + public boolean needs_score() { + return false; } - } + @Override + public SearchScript newInstance(LeafReaderContext ctx) throws IOException { + return new VectorScoreScript(this.params, this.lookup, ctx); + } + } } \ No newline at end of file diff --git a/src/main/java/com/liorkn/elasticsearch/service/VectorScoringScriptEngineService.java b/src/main/java/com/liorkn/elasticsearch/service/VectorScoringScriptEngineService.java deleted file mode 100755 index 58db087..0000000 --- a/src/main/java/com/liorkn/elasticsearch/service/VectorScoringScriptEngineService.java +++ /dev/null @@ -1,78 +0,0 @@ -package com.liorkn.elasticsearch.service; - -import com.liorkn.elasticsearch.script.VectorScoreScript; -import org.apache.lucene.index.LeafReaderContext; -import org.elasticsearch.common.Nullable; -import org.elasticsearch.common.component.AbstractComponent; -import org.elasticsearch.common.inject.Inject; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.script.CompiledScript; -import org.elasticsearch.script.ExecutableScript; 
-import org.elasticsearch.script.LeafSearchScript; -import org.elasticsearch.script.ScriptEngineService; -import org.elasticsearch.script.SearchScript; -import org.elasticsearch.search.lookup.SearchLookup; - -import java.io.IOException; -import java.util.Map; - -/** - * Created by Lior Knaany on 5/14/17. - */ -public class VectorScoringScriptEngineService extends AbstractComponent implements ScriptEngineService{ - - public static final String NAME = "knn"; - - @Inject - public VectorScoringScriptEngineService(Settings settings) { - super(settings); - } - - @Override - public Object compile(String scriptName, String scriptSource, Map params) { - return new VectorScoreScript.Factory(); - } - - - @Override - public boolean isInlineScriptEnabled() { - return true; - } - - @Override - public String getType() { - return NAME; - } - - @Override - public String getExtension() { - return NAME; - } - - @Override - public ExecutableScript executable(CompiledScript compiledScript, @Nullable Map vars) { - VectorScoreScript.Factory scriptFactory = (VectorScoreScript.Factory) compiledScript.compiled(); - return scriptFactory.newScript(vars); - } - - @Override - public SearchScript search(CompiledScript compiledScript, final SearchLookup lookup, @Nullable final Map vars) { - final VectorScoreScript.Factory scriptFactory = (VectorScoreScript.Factory) compiledScript.compiled(); - final VectorScoreScript script = (VectorScoreScript) scriptFactory.newScript(vars); - return new SearchScript() { - @Override - public LeafSearchScript getLeafSearchScript(LeafReaderContext context) throws IOException { - script.setBinaryEmbeddingReader(context.reader().getBinaryDocValues(script.field)); - return script; - } - @Override - public boolean needsScores() { - return scriptFactory.needsScores(); - } - }; - } - - @Override - public void close() { - } -} diff --git a/src/test/java/com/liorkn/elasticsearch/EmbeddedElasticsearchServer.java 
b/src/test/java/com/liorkn/elasticsearch/EmbeddedElasticsearchServer.java index 5627240..1cbdb9c 100644 --- a/src/test/java/com/liorkn/elasticsearch/EmbeddedElasticsearchServer.java +++ b/src/test/java/com/liorkn/elasticsearch/EmbeddedElasticsearchServer.java @@ -10,7 +10,7 @@ import org.elasticsearch.node.Node; import org.elasticsearch.node.NodeValidationException; import org.elasticsearch.painless.PainlessPlugin; -import org.elasticsearch.transport.Netty3Plugin; +import org.elasticsearch.transport.Netty4Plugin; import java.io.File; import java.io.IOException; @@ -41,13 +41,10 @@ private EmbeddedElasticsearchServer(String defaultDataDirectory, int port) throw Settings.Builder settings = Settings.builder() .put("http.enabled", "true") - .put("transport.type", "local") - .put("http.type", "netty3") + .put("http.type", "netty4") .put("path.data", dataDirectory) .put("path.home", DEFAULT_HOME_DIRECTORY) - .put("script.inline", "on") - .put("node.max_local_storage_nodes", 10000) - .put("script.stored", "on"); + .put("node.max_local_storage_nodes", 10000); startNodeInAvailablePort(settings); } @@ -61,7 +58,7 @@ private void startNodeInAvailablePort(Settings.Builder settings) throws NodeVali settings.put("http.port", String.valueOf(this.port)); // this a hack in order to load Groovy plug in since we want to enable the usage of scripts - node = new NodeExt(settings.build() , Arrays.asList(Netty3Plugin.class, PainlessPlugin.class, ReindexPlugin.class, VectorScoringPlugin.class)); + node = new NodeExt(settings.build() , Arrays.asList(Netty4Plugin.class, PainlessPlugin.class, ReindexPlugin.class, VectorScoringPlugin.class)); node.start(); success = true; System.out.println(EmbeddedElasticsearchServer.class.getName() + ": Using port: " + this.port); diff --git a/src/test/java/com/liorkn/elasticsearch/PluginTest.java b/src/test/java/com/liorkn/elasticsearch/PluginTest.java index b95b65c..8427932 100644 --- a/src/test/java/com/liorkn/elasticsearch/PluginTest.java +++ 
b/src/test/java/com/liorkn/elasticsearch/PluginTest.java @@ -2,6 +2,8 @@ import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ArrayNode; + import org.apache.http.HttpHost; import org.apache.http.entity.ContentType; import org.apache.http.entity.StringEntity; @@ -70,9 +72,8 @@ public static void init() throws Exception { public void test() throws Exception { final Map params = new HashMap<>(); params.put("refresh", "true"); - final ObjectMapper mapper = new ObjectMapper(); - final TestObject[] objs = {new TestObject(1, new double[] {0.0, 0.5, 1.0}), - new TestObject(2, new double[] {0.2, 0.6, 0.99})}; + final TestObject[] objs = {new TestObject(1, new float[] {0.0f, 0.5f, 1.0f}), + new TestObject(2, new float[] {0.2f, 0.6f, 0.99f})}; for (int i = 0; i < objs.length; i++) { final TestObject t = objs[i]; @@ -92,10 +93,10 @@ public void test() throws Exception { " \"boost_mode\": \"replace\"," + " \"script_score\": {" + " \"script\": {" + - " \"inline\": \"binary_vector_score\"," + + " \"source\": \"binary_vector_score\"," + " \"lang\": \"knn\"," + " \"params\": {" + - " \"cosine\": false," + + " \"cosine\": true," + " \"field\": \"embedding_vector\"," + " \"vector\": [" + " 0.1, 0.2, 0.3" + @@ -113,6 +114,10 @@ public void test() throws Exception { System.out.println(resBody); Assert.assertEquals("search should return status code 200", 200, res.getStatusLine().getStatusCode()); Assert.assertTrue(String.format("There should be %d documents in the search response", objs.length), resBody.contains("\"hits\":{\"total\":" + objs.length)); + // Testing Scores + final ArrayNode hitsJson = (ArrayNode)mapper.readTree(resBody).get("hits").get("hits"); + Assert.assertEquals(0.9941734, hitsJson.get(0).get("_score").asDouble(), 0); + Assert.assertEquals(0.95618284, hitsJson.get(1).get("_score").asDouble(), 0); } @AfterClass diff --git a/src/test/java/com/liorkn/elasticsearch/TestObject.java 
b/src/test/java/com/liorkn/elasticsearch/TestObject.java index f37d98a..8338e31 100644 --- a/src/test/java/com/liorkn/elasticsearch/TestObject.java +++ b/src/test/java/com/liorkn/elasticsearch/TestObject.java @@ -10,7 +10,7 @@ public class TestObject { int jobId; String embeddingVector; - double[] vector; + float[] vector; public int getJobId() { return jobId; @@ -20,13 +20,13 @@ public String getEmbeddingVector() { return embeddingVector; } - public double[] getVector() { + public float[] getVector() { return vector; } - public TestObject(int jobId, double[] vector) { + public TestObject(int jobId, float[] vector) { this.jobId = jobId; this.vector = vector; this.embeddingVector = Util.convertArrayToBase64(vector); } -} +} \ No newline at end of file