Skip to content

Commit b25709c

Browse files
committed
Tree and XgBoost tree parser
1 parent e15eafd commit b25709c

File tree

7 files changed

+792
-1
lines changed

7 files changed

+792
-1
lines changed

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,8 @@
203203
import org.elasticsearch.xpack.ml.inference.InferenceProcessor;
204204
import org.elasticsearch.xpack.ml.inference.ModelLoader;
205205
import org.elasticsearch.xpack.ml.inference.sillymodel.SillyModelLoader;
206+
import org.elasticsearch.xpack.ml.inference.tree.Tree;
207+
import org.elasticsearch.xpack.ml.inference.tree.TreeModelLoader;
206208
import org.elasticsearch.xpack.ml.job.JobManager;
207209
import org.elasticsearch.xpack.ml.job.JobManagerHolder;
208210
import org.elasticsearch.xpack.ml.job.UpdateJobProcessNotifier;
@@ -628,7 +630,8 @@ public Map<String, Processor.Factory> getProcessors(Processor.Parameters paramet
628630
}
629631

630632
private Map<String, ModelLoader> getModelLoaders(Client client) {
631-
return Map.of(SillyModelLoader.MODEL_TYPE, new SillyModelLoader(client));
633+
return Map.of(SillyModelLoader.MODEL_TYPE, new SillyModelLoader(client),
634+
TreeModelLoader.MODEL_TYPE, new TreeModelLoader(client));
632635
}
633636

634637
@Override
Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
package org.elasticsearch.xpack.ml.inference.tree;
2+
3+
import java.util.ArrayList;
4+
import java.util.Collections;
5+
import java.util.List;
6+
import java.util.function.BiPredicate;
7+
8+
9+
/**
 * A decision tree that can make predictions given a feature vector.
 * <p>
 * Nodes are held in a flat list and refer to their children by list index;
 * index 0 is the root. A slot may be {@code null} when a junction reserved
 * space for a child that was never filled in, see {@link #missingNodes()}.
 */
public class Tree {
    private final List<Node> nodes;

    /**
     * @param nodes The nodes of the tree, root first. The list is copied so
     *              later changes by the caller do not affect this tree.
     */
    Tree(List<Node> nodes) {
        // Copy via ArrayList rather than List.copyOf: placeholder slots may legitimately be null
        this.nodes = Collections.unmodifiableList(new ArrayList<>(nodes));
    }

    /**
     * Trace the route predicting on the feature vector takes.
     * @param features The feature vector
     * @return The list of traversed nodes ordered from root to leaf
     * @throws IllegalStateException if the route reaches a node that was never set
     */
    public List<Node> trace(List<Double> features) {
        List<Node> visited = new ArrayList<>();
        descend(features, visited);
        return visited;
    }

    /**
     * Make a prediction based on the feature vector
     * @param features The feature vector
     * @return The prediction, i.e. the value of the leaf the features lead to
     * @throws IllegalStateException if the route reaches a node that was never set
     */
    public Double predict(List<Double> features) {
        return descend(features, null).value();
    }

    /**
     * Walks from the root to a leaf, following the decision at each junction.
     * @param features The feature vector
     * @param visited  If not {@code null} every node on the route is appended to it
     * @return The leaf node reached
     */
    private Node descend(List<Double> features, List<Node> visited) {
        int index = 0;
        // A route through a well formed tree visits each node at most once,
        // so more than nodes.size() steps can only mean the links form a cycle.
        for (int step = 0; step <= nodes.size(); step++) {
            Node node = nodes.get(index);
            if (node == null) {
                throw new IllegalStateException("tree node at index [" + index + "] has not been set");
            }
            if (visited != null) {
                visited.add(node);
            }
            if (node.isLeaf()) {
                return node;
            }
            index = node.compare(features);
        }
        throw new IllegalStateException("tree contains a cycle, no leaf reached after [" + nodes.size() + "] nodes");
    }

    /**
     * Finds null nodes
     * @return List of indexes to null nodes
     */
    List<Integer> missingNodes() {
        List<Integer> nullNodeIndices = new ArrayList<>();
        for (int i = 0; i < nodes.size(); i++) {
            if (nodes.get(i) == null) {
                nullNodeIndices.add(i);
            }
        }
        return nullNodeIndices;
    }

    @Override
    public String toString() {
        return nodes.toString();
    }

    /**
     * A single node of the tree: either a junction that routes to one of two
     * children, or a leaf holding a prediction value. Leaves are created with
     * child and feature indices of -1 and keep their value in {@code thresholdValue}.
     */
    public static class Node {
        int leftChild;
        int rightChild;
        int featureIndex;
        boolean isDefaultLeft;
        double thresholdValue;
        BiPredicate<Double, Double> operator;

        /**
         * Creates a junction that goes left when the feature value is less than the threshold.
         */
        Node(int leftChild, int rightChild, int featureIndex, boolean isDefaultLeft, double thresholdValue) {
            this(leftChild, rightChild, featureIndex, isDefaultLeft, thresholdValue,
                (value, threshold) -> value < threshold); // less than
        }

        /**
         * Creates a junction with a custom comparison.
         * @param operator Tested with (feature value, threshold); {@code true} selects the left child
         */
        Node(int leftChild, int rightChild, int featureIndex, boolean isDefaultLeft, double thresholdValue,
             BiPredicate<Double, Double> operator) {
            this.leftChild = leftChild;
            this.rightChild = rightChild;
            this.featureIndex = featureIndex;
            this.isDefaultLeft = isDefaultLeft;
            this.thresholdValue = thresholdValue;
            this.operator = operator;
        }

        /**
         * Creates a leaf.
         * @param value The prediction value of the leaf
         */
        Node(double value) {
            this(-1, -1, -1, false, value);
        }

        boolean isLeaf() {
            // Index 0 is the root and so can never be a child; anything below 1 marks a leaf
            return leftChild < 1;
        }

        /**
         * Decides which child to visit next.
         * @param features The feature vector, a {@code null} element means the feature is missing
         * @return The index of the left or right child
         */
        int compare(List<Double> features) {
            Double feature = features.get(featureIndex);
            if (isMissing(feature)) {
                return isDefaultLeft ? leftChild : rightChild;
            }

            return operator.test(feature, thresholdValue) ? leftChild : rightChild;
        }

        boolean isMissing(Double feature) {
            return feature == null;
        }

        Double value() {
            return thresholdValue;
        }

        @Override
        public String toString() {
            StringBuilder builder = new StringBuilder("{\n");
            builder.append("left: ").append(leftChild).append('\n');
            builder.append("right: ").append(rightChild).append('\n');
            builder.append("isDefaultLeft: ").append(isDefaultLeft).append('\n');
            builder.append("isLeaf: ").append(isLeaf()).append('\n');
            builder.append("featureIndex: ").append(featureIndex).append('\n');
            builder.append("value: ").append(thresholdValue).append('\n');
            builder.append("}\n");
            return builder.toString();
        }
    }


    /**
     * Builds a {@link Tree} node by node. A new builder holds a single root
     * leaf with the value 0.0 which can be replaced by a junction or another leaf.
     */
    public static class TreeBuilder {

        private final List<Node> nodes;

        public static TreeBuilder newTreeBuilder() {
            return new TreeBuilder();
        }

        TreeBuilder() {
            nodes = new ArrayList<>();
            // the root starts out as a leaf so that an untouched builder still builds a usable tree
            nodes.add(new Node(0.0));
        }

        /**
         * Add a decision node. Space for the child nodes is allocated
         * at the end of the node list and the junction points at exactly those slots.
         * @param nodeIndex Where to place the node. This is either 0 (root) or an existing child node index
         * @param featureIndex The feature index the decision is made on
         * @param isDefaultLeft Default left branch if the feature is missing
         * @param decisionThreshold The decision threshold
         * @return The created node
         */
        public Node addJunction(int nodeIndex, int featureIndex, boolean isDefaultLeft, double decisionThreshold) {
            padTo(nodeIndex);

            // Take the child indices from the list size after padding so they always
            // refer to the two slots allocated below, never to an existing slot.
            int leftChild = nodes.size();
            int rightChild = leftChild + 1;

            Node node = new Node(leftChild, rightChild, featureIndex, isDefaultLeft, decisionThreshold);
            nodes.set(nodeIndex, node);
            // allocate space for the child nodes
            nodes.add(null);
            nodes.add(null);

            return node;
        }

        /**
         * Sets the node at {@code nodeIndex} to a leaf node.
         * @param nodeIndex The index as allocated by a call to {@link #addJunction(int, int, boolean, double)}
         *                  or 0 to set the value of a tree that consists of a single leaf
         * @param value The prediction value
         * @return this
         */
        public TreeBuilder addLeaf(int nodeIndex, double value) {
            padTo(nodeIndex);

            Node existing = nodes.get(nodeIndex);
            // The root is pre-filled with a leaf by the constructor, that one may be replaced
            assert existing == null || (nodeIndex == 0 && existing.isLeaf())
                : "node at index " + nodeIndex + " is already set";

            nodes.set(nodeIndex, new Node(value));
            return this;
        }

        public Tree build() {
            return new Tree(nodes);
        }

        /**
         * Grows the node list with {@code null} placeholders until {@code nodeIndex} is a valid index.
         */
        private void padTo(int nodeIndex) {
            while (nodes.size() <= nodeIndex) {
                nodes.add(null);
            }
        }
    }
}
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
package org.elasticsearch.xpack.ml.inference.tree;
2+
3+
import org.apache.log4j.LogManager;
4+
import org.apache.log4j.Logger;
5+
6+
import java.util.ArrayList;
7+
import java.util.Arrays;
8+
import java.util.Collections;
9+
import java.util.List;
10+
import java.util.Map;
11+
import java.util.stream.Collectors;
12+
13+
public class TreeEnsembleModel {
14+
15+
private static Logger logger = LogManager.getLogger(TreeEnsembleModel.class);
16+
17+
private final List<Tree> trees;
18+
private final Map<String, Integer> featureMap;
19+
20+
private TreeEnsembleModel(List<Tree> trees, Map<String, Integer> featureMap) {
21+
this.trees = Collections.unmodifiableList(trees);
22+
this.featureMap = featureMap;
23+
}
24+
25+
public int numFeatures() {
26+
return featureMap.size();
27+
}
28+
29+
public int numTrees() {
30+
return trees.size();
31+
}
32+
33+
public List<Integer> checkForNull() {
34+
List<Integer> missing = new ArrayList<>();
35+
for (Tree tree : trees) {
36+
missing.addAll(tree.missingNodes());
37+
}
38+
return missing;
39+
}
40+
41+
public double predictFromDoc(Map<String, Object> features) {
42+
List<Double> featureVec = docToFeatureVector(features);
43+
List<Double> predictions = trees.stream().map(tree -> tree.predict(featureVec)).collect(Collectors.toList());
44+
return mergePredictions(predictions);
45+
}
46+
47+
public double predict(Map<String, Double> features) {
48+
List<Double> featureVec = doubleDocToFeatureVector(features);
49+
List<Double> predictions = trees.stream().map(tree -> tree.predict(featureVec)).collect(Collectors.toList());
50+
return mergePredictions(predictions);
51+
}
52+
53+
public List<List<Tree.Node>> trace(Map<String, Double> features) {
54+
List<Double> featureVec = doubleDocToFeatureVector(features);
55+
return trees.stream().map(tree -> tree.trace(featureVec)).collect(Collectors.toList());
56+
}
57+
58+
double mergePredictions(List<Double> predictions) {
59+
return predictions.stream().mapToDouble(f -> f).summaryStatistics().getSum();
60+
}
61+
62+
List<Double> doubleDocToFeatureVector(Map<String, Double> features) {
63+
List<Double> featureVec = Arrays.asList(new Double[featureMap.size()]);
64+
65+
for (Map.Entry<String, Double> keyValue : features.entrySet()) {
66+
if (featureMap.containsKey(keyValue.getKey())) {
67+
featureVec.set(featureMap.get(keyValue.getKey()), keyValue.getValue());
68+
}
69+
}
70+
71+
return featureVec;
72+
}
73+
74+
List<Double> docToFeatureVector(Map<String, Object> features) {
75+
List<Double> featureVec = Arrays.asList(new Double[featureMap.size()]);
76+
77+
for (Map.Entry<String, Object> keyValue : features.entrySet()) {
78+
if (featureMap.containsKey(keyValue.getKey())) {
79+
Double value = (Double)keyValue.getValue();
80+
if (value != null) {
81+
featureVec.set(featureMap.get(keyValue.getKey()), value);
82+
}
83+
}
84+
}
85+
86+
return featureVec;
87+
}
88+
89+
public static ModelBuilder modelBuilder(Map<String, Integer> featureMap) {
90+
return new ModelBuilder(featureMap);
91+
}
92+
93+
public static class ModelBuilder {
94+
private List<Tree> trees;
95+
private Map<String, Integer> featureMap;
96+
97+
public ModelBuilder(Map<String, Integer> featureMap) {
98+
this.featureMap = featureMap;
99+
trees = new ArrayList<>();
100+
}
101+
102+
public ModelBuilder addTree(Tree tree) {
103+
trees.add(tree);
104+
return this;
105+
}
106+
107+
public TreeEnsembleModel build() {
108+
return new TreeEnsembleModel(trees, featureMap);
109+
}
110+
}
111+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package org.elasticsearch.xpack.ml.inference.tree;
2+
3+
import org.elasticsearch.ingest.IngestDocument;
4+
import org.elasticsearch.xpack.ml.inference.Model;
5+
6+
public class TreeModel implements Model {
7+
8+
private String targetFieldName;
9+
private TreeEnsembleModel ensemble;
10+
11+
12+
TreeModel(TreeEnsembleModel ensemble, String targetFieldName) {
13+
this.ensemble = ensemble;
14+
this.targetFieldName = targetFieldName;
15+
}
16+
17+
@Override
18+
public IngestDocument infer(IngestDocument document) {
19+
Double prediction = ensemble.predictFromDoc(document.getSourceAndMetadata());
20+
document.setFieldValue(targetFieldName, prediction);
21+
return document;
22+
}
23+
}

0 commit comments

Comments
 (0)