Skip to content

Commit 4a35264

Browse files
committed
Fix over allocating nodes and add missing tests
1 parent e6edcd5 commit 4a35264

File tree

5 files changed

+327
-20
lines changed

5 files changed

+327
-20
lines changed

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/Tree.java

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,10 @@ private Double predict(List<Double> features, int nodeIndex) {
6262
}
6363

6464
/**
65-
* Finds null nodes
66-
* @return List of indexes to null nodes
65+
* Finds {@code null} nodes. If constructed properly there should be no {@code null} nodes.
66+
* {@code null} nodes indicates missing leaf or junction nodes
67+
*
68+
* @return List of indexes to the {@code null} nodes
6769
*/
6870
List<Integer> missingNodes() {
6971
List<Integer> nullNodeIndices = new ArrayList<>();
@@ -185,22 +187,20 @@ public static TreeBuilder newTreeBuilder() {
185187
* @return The created node
186188
*/
187189
public Node addJunction(int nodeIndex, int featureIndex, boolean isDefaultLeft, double decisionThreshold) {
188-
// assert nodeIndex < nodes.size() : "node index " + nodeIndex + " >= size " + nodes.size();
189-
190190
int leftChild = numNodes++;
191191
int rightChild = numNodes++;
192192
nodes.ensureCapacity(nodeIndex +1);
193193
for (int i=nodes.size(); i<nodeIndex +1; i++) {
194194
nodes.add(null);
195195
}
196196

197-
198197
Node node = new Node(leftChild, rightChild, featureIndex, isDefaultLeft, decisionThreshold);
199198
nodes.set(nodeIndex, node);
199+
200200
// allocate space for the child nodes
201-
nodes.add(null);
202-
nodes.add(null);
203-
// assert nodes.size() == numNodes : "nodes size " + nodes.size() + " != num nodes " + numNodes;
201+
while (nodes.size() <= rightChild) {
202+
nodes.add(null);
203+
}
204204

205205
return node;
206206
}
@@ -212,15 +212,11 @@ public Node addJunction(int nodeIndex, int featureIndex, boolean isDefaultLeft,
212212
* @return this
213213
*/
214214
public TreeBuilder addLeaf(int nodeIndex, double value) {
215-
// assert nodeIndex < nodes.size();
216-
217215
for (int i=nodes.size(); i<nodeIndex +1; i++) {
218216
nodes.add(null);
219217
}
220218

221-
222-
223-
assert nodes.get(nodeIndex) == null;
219+
assert nodes.get(nodeIndex) == null : "expected null value at index " + nodeIndex;
224220

225221
nodes.set(nodeIndex, new Node(value));
226222
return this;

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/tree/TreeEnsembleModel.java

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,6 @@
66

77
package org.elasticsearch.xpack.ml.inference.tree;
88

9-
import org.apache.log4j.LogManager;
10-
import org.apache.log4j.Logger;
11-
129
import java.util.ArrayList;
1310
import java.util.Arrays;
1411
import java.util.Collections;
@@ -18,8 +15,6 @@
1815

1916
public class TreeEnsembleModel {
2017

21-
private static Logger logger = LogManager.getLogger(TreeEnsembleModel.class);
22-
2318
private final List<Tree> trees;
2419
private final Map<String, Integer> featureMap;
2520

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*/
6+
7+
package org.elasticsearch.xpack.ml.inference.tree;
8+
9+
import org.elasticsearch.test.ESTestCase;
10+
11+
import java.util.Arrays;
12+
import java.util.Collections;
13+
import java.util.HashMap;
14+
import java.util.List;
15+
import java.util.Map;
16+
17+
import static org.hamcrest.Matchers.empty;
18+
import static org.hamcrest.Matchers.hasSize;
19+
import static org.hamcrest.Matchers.not;
20+
21+
public class TreeEnsembleModelTests extends ESTestCase {
22+
23+
public void testMergePredictions() {
24+
TreeEnsembleModel emptyModel = TreeEnsembleModel.modelBuilder(Collections.emptyMap()).build();
25+
26+
List<Double> predictions = Arrays.asList(0.1, 0.2, 0.3);
27+
// prediction is the sum of the component predictions
28+
assertEquals(0.6, emptyModel.mergePredictions(predictions), 0.00001);
29+
}
30+
31+
public void testDocToFeatureVector() {
32+
Map<String, Integer> featureMap = new HashMap<>();
33+
34+
featureMap.put("price", 0);
35+
featureMap.put("number_of_wheels", 1);
36+
featureMap.put("phone_charger", 2);
37+
38+
Map<String, Object> document = new HashMap<>();
39+
document.put("price", 1000.0);
40+
document.put("number_of_wheels", 4.0);
41+
document.put("phone_charger", 0.0);
42+
document.put("unrelated", 1.0);
43+
44+
TreeEnsembleModel emptyModel = TreeEnsembleModel.modelBuilder(featureMap).build();
45+
List<Double> featureVector = emptyModel.docToFeatureVector(document);
46+
assertEquals(Arrays.asList(1000.0, 4.0, 0.0), featureVector);
47+
48+
49+
featureMap.put("missing_field", 3);
50+
emptyModel = TreeEnsembleModel.modelBuilder(featureMap).build();
51+
featureVector = emptyModel.docToFeatureVector(document);
52+
assertEquals(Arrays.asList(1000.0, 4.0, 0.0, null), featureVector);
53+
}
54+
55+
public void testTrace() {
56+
int numFeatures = randomIntBetween(1, 4);
57+
int numTrees = randomIntBetween(1, 5);
58+
TreeEnsembleModel model = buildRandomModel(numFeatures, numTrees);
59+
60+
Map<String, Double> doc = new HashMap<>();
61+
doc.put("a", randomDouble());
62+
doc.put("b", randomDouble());
63+
doc.put("c", randomDouble());
64+
doc.put("d", randomDouble());
65+
List<List<Tree.Node>> traces = model.trace(doc);
66+
assertThat(traces, hasSize(numTrees));
67+
for (var trace : traces) {
68+
assertThat(trace, not(empty()));
69+
}
70+
}
71+
72+
public void testPredict() {
73+
Map<String, Integer> featureMap = new HashMap<>();
74+
featureMap.put("a", 0);
75+
featureMap.put("b", 1);
76+
77+
TreeEnsembleModel.ModelBuilder modelBuilder = TreeEnsembleModel.modelBuilder(featureMap);
78+
79+
// Simple tree
80+
Tree.TreeBuilder treeBuilder1 = Tree.TreeBuilder.newTreeBuilder();
81+
Tree.Node rootNode = treeBuilder1.addJunction(0, 0, true, 0.5);
82+
treeBuilder1.addLeaf(rootNode.leftChild, 0.1);
83+
treeBuilder1.addLeaf(rootNode.rightChild, 0.2);
84+
85+
Tree.TreeBuilder treeBuilder2 = Tree.TreeBuilder.newTreeBuilder();
86+
Tree.Node rootNode2 = treeBuilder2.addJunction(0, 1, true, 0.5);
87+
treeBuilder2.addLeaf(rootNode2.leftChild, 0.1);
88+
treeBuilder2.addLeaf(rootNode2.rightChild, 0.2);
89+
90+
modelBuilder.addTree(treeBuilder1.build());
91+
modelBuilder.addTree(treeBuilder2.build());
92+
93+
TreeEnsembleModel model = modelBuilder.build();
94+
95+
// this doc should result in prediction 0.1 for tree1 and 0.2 for tree2
96+
// the prediction is the sum of the scores (boosted tree)
97+
Map<String, Double> doc = new HashMap<>();
98+
doc.put("a", 0.4);
99+
doc.put("b", 0.7);
100+
101+
double prediction = model.predict(doc);
102+
assertEquals(0.3, prediction, 0.00001);
103+
}
104+
105+
private TreeEnsembleModel buildRandomModel(int numFeatures, int numTrees) {
106+
Map<String, Integer> featureMap = new HashMap<>();
107+
char fieldName = 'a';
108+
int index = 0;
109+
for (int i=0; i<numFeatures; i++) {
110+
featureMap.put(Character.toString(fieldName++), index++);
111+
}
112+
113+
TreeEnsembleModel.ModelBuilder builder = TreeEnsembleModel.modelBuilder(featureMap);
114+
for (int i=0; i<numTrees; i++) {
115+
builder.addTree(TreeTests.buildRandomTree(numFeatures, randomIntBetween(3, 6)));
116+
}
117+
118+
return builder.build();
119+
}
120+
}
Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*/
6+
7+
package org.elasticsearch.xpack.ml.inference.tree;
8+
9+
import org.elasticsearch.test.ESTestCase;
10+
11+
import java.util.ArrayList;
12+
import java.util.Arrays;
13+
import java.util.List;
14+
15+
import static org.hamcrest.Matchers.hasSize;
16+
17+
public class TreeTests extends ESTestCase {
18+
19+
public static Tree buildRandomTree(int numFeatures, int depth) {
20+
21+
Tree.TreeBuilder builder = Tree.TreeBuilder.newTreeBuilder();
22+
23+
Tree.Node node = builder.addJunction(0, randomFeatureIndex(numFeatures), true, randomDecisionThreshold());
24+
List<Integer> childNodes = List.of(node.leftChild, node.rightChild);
25+
26+
for (int i=0; i<depth -1; i++) {
27+
28+
List<Integer> nextNodes = new ArrayList<>();
29+
for (int nodeId : childNodes) {
30+
31+
if (i == depth -2) {
32+
builder.addLeaf(nodeId, randomDecisionThreshold());
33+
} else {
34+
Tree.Node childNode = builder.addJunction(nodeId, randomFeatureIndex(numFeatures), true, randomDecisionThreshold());
35+
nextNodes.add(childNode.leftChild);
36+
nextNodes.add(childNode.rightChild);
37+
}
38+
}
39+
40+
childNodes = nextNodes;
41+
}
42+
43+
return builder.build();
44+
}
45+
46+
static int randomFeatureIndex(int max) {
47+
return randomIntBetween(0, max -1);
48+
}
49+
50+
static double randomDecisionThreshold() {
51+
return randomDouble();
52+
}
53+
54+
public void testPredict() {
55+
// Build a tree with 2 nodes and 3 leaves using 2 features
56+
// The leaves have unique values 0.1, 0.2, 0.3
57+
Tree.TreeBuilder builder = Tree.TreeBuilder.newTreeBuilder();
58+
Tree.Node rootNode = builder.addJunction(0, 0, true, 0.5);
59+
builder.addLeaf(rootNode.rightChild, 0.3);
60+
Tree.Node leftChildNode = builder.addJunction(rootNode.leftChild, 1, true, 0.8);
61+
builder.addLeaf(leftChildNode.leftChild, 0.1);
62+
builder.addLeaf(leftChildNode.rightChild, 0.2);
63+
64+
Tree tree = builder.build();
65+
66+
// This feature vector should hit the right child of the root node
67+
List<Double> featureVector = Arrays.asList(0.6, 0.0);
68+
assertEquals(0.3, tree.predict(featureVector), 0.00001);
69+
70+
// This should hit the left child of the left child of the root node
71+
// i.e. it takes the path left, left
72+
featureVector = Arrays.asList(0.3, 0.7);
73+
assertEquals(0.1, tree.predict(featureVector), 0.00001);
74+
75+
// This should hit the right child of the left child of the root node
76+
// i.e. it takes the path left, right
77+
featureVector = Arrays.asList(0.3, 0.8);
78+
assertEquals(0.2, tree.predict(featureVector), 0.00001);
79+
}
80+
81+
public void testTrace() {
82+
int numFeatures = randomIntBetween(1, 6);
83+
int depth = 6;
84+
Tree tree = buildRandomTree(numFeatures, depth);
85+
86+
List<Double> features = new ArrayList<>(numFeatures);
87+
for (int i=0; i<numFeatures; i++) {
88+
features.add(randomDecisionThreshold());
89+
}
90+
91+
List<Tree.Node> trace = tree.trace(features);
92+
assertThat(trace, hasSize(depth));
93+
for (int i=0; i<trace.size() -2; i++) {
94+
assertFalse(trace.get(i).isLeaf());
95+
}
96+
assertTrue(trace.get(trace.size() -1).isLeaf());
97+
98+
double prediction = tree.predict(features);
99+
assertEquals(trace.get(trace.size() -1).value(), prediction, 0.0001);
100+
101+
// Because the tree is built up breadth first we can figure out o
102+
// a node's id from its child nodes. Then we can trace the route
103+
// taken an assert it's the branch decisions were correct
104+
105+
int expectedNodeId = 0;
106+
for (Tree.Node visitedNode: trace) {
107+
if (visitedNode.isLeaf() == false) {
108+
// Imagine the nodes array is 1 based index. The root node
109+
// has index 1, its children 2 & 3. Because the tree is built
110+
// breadth first node 2 children will be at indexes 4 & 5 and
111+
// node 3 children are at 6 & 7.
112+
// So a nodes children are at nodeindex * 2 and nodeindex * 2 +1
113+
// and the parent is at nodeindex / 2.
114+
// The +/- 1's are adjusting for a 0 based index
115+
int nodeId = ((visitedNode.leftChild + 1) / 2) - 1;
116+
assertEquals(expectedNodeId, nodeId);
117+
118+
// unfortunately this doesn't apply to leaf nodes
119+
// as their children are -1
120+
121+
expectedNodeId = visitedNode.compare(features);
122+
} else {
123+
assertEquals(prediction, visitedNode.value(), 0.0001);
124+
}
125+
}
126+
127+
assertThat(tree.missingNodes(), hasSize(0));
128+
}
129+
130+
public void testCompare() {
131+
int leftChild = 1;
132+
int rightChild = 2;
133+
Tree.Node node = new Tree.Node(leftChild, rightChild, 0, true, 0.5);
134+
135+
List<Double> features = List.of(0.1);
136+
assertEquals(leftChild, node.compare(features));
137+
138+
features = List.of(0.9);
139+
assertEquals(rightChild, node.compare(features));
140+
}
141+
142+
public void testCompare_nonDefaultOperator() {
143+
int leftChild = 1;
144+
int rightChild = 2;
145+
Tree.Node node = new Tree.Node(leftChild, rightChild, 0, true, 0.5, (value, threshold) -> value >= threshold);
146+
147+
List<Double> features = List.of(0.1);
148+
assertEquals(rightChild, node.compare(features));
149+
features = List.of(0.5);
150+
assertEquals(leftChild, node.compare(features));
151+
features = List.of(0.9);
152+
assertEquals(leftChild, node.compare(features));
153+
154+
node = new Tree.Node(leftChild, rightChild, 0, true, 0.5, (value, threshold) -> value <= threshold);
155+
156+
features = List.of(0.1);
157+
assertEquals(leftChild, node.compare(features));
158+
features = List.of(0.5);
159+
assertEquals(leftChild, node.compare(features));
160+
features = List.of(0.9);
161+
assertEquals(rightChild, node.compare(features));
162+
}
163+
164+
public void testCompare_missingFeature() {
165+
int leftChild = 1;
166+
int rightChild = 2;
167+
Tree.Node leftBiasNode = new Tree.Node(leftChild, rightChild, 0, true, 0.5);
168+
List<Double> features = new ArrayList<>();
169+
features.add(null);
170+
assertEquals(leftChild, leftBiasNode.compare(features));
171+
172+
Tree.Node rightBiasNode = new Tree.Node(leftChild, rightChild, 0, false, 0.5);
173+
assertEquals(rightChild, rightBiasNode.compare(features));
174+
}
175+
176+
public void testIsLeaf() {
177+
Tree.Node leaf = new Tree.Node(0.0);
178+
assertTrue(leaf.isLeaf());
179+
180+
Tree.Node node = new Tree.Node(1, 2, 0, false, 0.0);
181+
assertFalse(node.isLeaf());
182+
}
183+
184+
public void testMissingNodes() {
185+
Tree.TreeBuilder builder = Tree.TreeBuilder.newTreeBuilder();
186+
Tree.Node rootNode = builder.addJunction(0, 0, true, randomDecisionThreshold());
187+
188+
Tree.Node node2 = builder.addJunction(rootNode.rightChild, 0, false, 0.1);
189+
builder.addLeaf(node2.leftChild, 0.1);
190+
191+
List<Integer> missingNodeIndexes = builder.build().missingNodes();
192+
assertEquals(Arrays.asList(1, 4), missingNodeIndexes);
193+
}
194+
}

0 commit comments

Comments
 (0)