1
+ /*
2
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3
+ * or more contributor license agreements. Licensed under the Elastic License;
4
+ * you may not use this file except in compliance with the Elastic License.
5
+ */
6
+
7
+ package org .elasticsearch .xpack .ml .inference .tree ;
8
+
9
+ import org .elasticsearch .test .ESTestCase ;
10
+
11
+ import java .util .ArrayList ;
12
+ import java .util .Arrays ;
13
+ import java .util .List ;
14
+
15
+ import static org .hamcrest .Matchers .hasSize ;
16
+
17
+ public class TreeTests extends ESTestCase {
18
+
19
+ public static Tree buildRandomTree (int numFeatures , int depth ) {
20
+
21
+ Tree .TreeBuilder builder = Tree .TreeBuilder .newTreeBuilder ();
22
+
23
+ Tree .Node node = builder .addJunction (0 , randomFeatureIndex (numFeatures ), true , randomDecisionThreshold ());
24
+ List <Integer > childNodes = List .of (node .leftChild , node .rightChild );
25
+
26
+ for (int i =0 ; i <depth -1 ; i ++) {
27
+
28
+ List <Integer > nextNodes = new ArrayList <>();
29
+ for (int nodeId : childNodes ) {
30
+
31
+ if (i == depth -2 ) {
32
+ builder .addLeaf (nodeId , randomDecisionThreshold ());
33
+ } else {
34
+ Tree .Node childNode = builder .addJunction (nodeId , randomFeatureIndex (numFeatures ), true , randomDecisionThreshold ());
35
+ nextNodes .add (childNode .leftChild );
36
+ nextNodes .add (childNode .rightChild );
37
+ }
38
+ }
39
+
40
+ childNodes = nextNodes ;
41
+ }
42
+
43
+ return builder .build ();
44
+ }
45
+
46
+ static int randomFeatureIndex (int max ) {
47
+ return randomIntBetween (0 , max -1 );
48
+ }
49
+
50
+ static double randomDecisionThreshold () {
51
+ return randomDouble ();
52
+ }
53
+
54
+ public void testPredict () {
55
+ // Build a tree with 2 nodes and 3 leaves using 2 features
56
+ // The leaves have unique values 0.1, 0.2, 0.3
57
+ Tree .TreeBuilder builder = Tree .TreeBuilder .newTreeBuilder ();
58
+ Tree .Node rootNode = builder .addJunction (0 , 0 , true , 0.5 );
59
+ builder .addLeaf (rootNode .rightChild , 0.3 );
60
+ Tree .Node leftChildNode = builder .addJunction (rootNode .leftChild , 1 , true , 0.8 );
61
+ builder .addLeaf (leftChildNode .leftChild , 0.1 );
62
+ builder .addLeaf (leftChildNode .rightChild , 0.2 );
63
+
64
+ Tree tree = builder .build ();
65
+
66
+ // This feature vector should hit the right child of the root node
67
+ List <Double > featureVector = Arrays .asList (0.6 , 0.0 );
68
+ assertEquals (0.3 , tree .predict (featureVector ), 0.00001 );
69
+
70
+ // This should hit the left child of the left child of the root node
71
+ // i.e. it takes the path left, left
72
+ featureVector = Arrays .asList (0.3 , 0.7 );
73
+ assertEquals (0.1 , tree .predict (featureVector ), 0.00001 );
74
+
75
+ // This should hit the right child of the left child of the root node
76
+ // i.e. it takes the path left, right
77
+ featureVector = Arrays .asList (0.3 , 0.8 );
78
+ assertEquals (0.2 , tree .predict (featureVector ), 0.00001 );
79
+ }
80
+
81
+ public void testTrace () {
82
+ int numFeatures = randomIntBetween (1 , 6 );
83
+ int depth = 6 ;
84
+ Tree tree = buildRandomTree (numFeatures , depth );
85
+
86
+ List <Double > features = new ArrayList <>(numFeatures );
87
+ for (int i =0 ; i <numFeatures ; i ++) {
88
+ features .add (randomDecisionThreshold ());
89
+ }
90
+
91
+ List <Tree .Node > trace = tree .trace (features );
92
+ assertThat (trace , hasSize (depth ));
93
+ for (int i =0 ; i <trace .size () -2 ; i ++) {
94
+ assertFalse (trace .get (i ).isLeaf ());
95
+ }
96
+ assertTrue (trace .get (trace .size () -1 ).isLeaf ());
97
+
98
+ double prediction = tree .predict (features );
99
+ assertEquals (trace .get (trace .size () -1 ).value (), prediction , 0.0001 );
100
+
101
+ // Because the tree is built up breadth first we can figure out o
102
+ // a node's id from its child nodes. Then we can trace the route
103
+ // taken an assert it's the branch decisions were correct
104
+
105
+ int expectedNodeId = 0 ;
106
+ for (Tree .Node visitedNode : trace ) {
107
+ if (visitedNode .isLeaf () == false ) {
108
+ // Imagine the nodes array is 1 based index. The root node
109
+ // has index 1, its children 2 & 3. Because the tree is built
110
+ // breadth first node 2 children will be at indexes 4 & 5 and
111
+ // node 3 children are at 6 & 7.
112
+ // So a nodes children are at nodeindex * 2 and nodeindex * 2 +1
113
+ // and the parent is at nodeindex / 2.
114
+ // The +/- 1's are adjusting for a 0 based index
115
+ int nodeId = ((visitedNode .leftChild + 1 ) / 2 ) - 1 ;
116
+ assertEquals (expectedNodeId , nodeId );
117
+
118
+ // unfortunately this doesn't apply to leaf nodes
119
+ // as their children are -1
120
+
121
+ expectedNodeId = visitedNode .compare (features );
122
+ } else {
123
+ assertEquals (prediction , visitedNode .value (), 0.0001 );
124
+ }
125
+ }
126
+
127
+ assertThat (tree .missingNodes (), hasSize (0 ));
128
+ }
129
+
130
+ public void testCompare () {
131
+ int leftChild = 1 ;
132
+ int rightChild = 2 ;
133
+ Tree .Node node = new Tree .Node (leftChild , rightChild , 0 , true , 0.5 );
134
+
135
+ List <Double > features = List .of (0.1 );
136
+ assertEquals (leftChild , node .compare (features ));
137
+
138
+ features = List .of (0.9 );
139
+ assertEquals (rightChild , node .compare (features ));
140
+ }
141
+
142
+ public void testCompare_nonDefaultOperator () {
143
+ int leftChild = 1 ;
144
+ int rightChild = 2 ;
145
+ Tree .Node node = new Tree .Node (leftChild , rightChild , 0 , true , 0.5 , (value , threshold ) -> value >= threshold );
146
+
147
+ List <Double > features = List .of (0.1 );
148
+ assertEquals (rightChild , node .compare (features ));
149
+ features = List .of (0.5 );
150
+ assertEquals (leftChild , node .compare (features ));
151
+ features = List .of (0.9 );
152
+ assertEquals (leftChild , node .compare (features ));
153
+
154
+ node = new Tree .Node (leftChild , rightChild , 0 , true , 0.5 , (value , threshold ) -> value <= threshold );
155
+
156
+ features = List .of (0.1 );
157
+ assertEquals (leftChild , node .compare (features ));
158
+ features = List .of (0.5 );
159
+ assertEquals (leftChild , node .compare (features ));
160
+ features = List .of (0.9 );
161
+ assertEquals (rightChild , node .compare (features ));
162
+ }
163
+
164
+ public void testCompare_missingFeature () {
165
+ int leftChild = 1 ;
166
+ int rightChild = 2 ;
167
+ Tree .Node leftBiasNode = new Tree .Node (leftChild , rightChild , 0 , true , 0.5 );
168
+ List <Double > features = new ArrayList <>();
169
+ features .add (null );
170
+ assertEquals (leftChild , leftBiasNode .compare (features ));
171
+
172
+ Tree .Node rightBiasNode = new Tree .Node (leftChild , rightChild , 0 , false , 0.5 );
173
+ assertEquals (rightChild , rightBiasNode .compare (features ));
174
+ }
175
+
176
+ public void testIsLeaf () {
177
+ Tree .Node leaf = new Tree .Node (0.0 );
178
+ assertTrue (leaf .isLeaf ());
179
+
180
+ Tree .Node node = new Tree .Node (1 , 2 , 0 , false , 0.0 );
181
+ assertFalse (node .isLeaf ());
182
+ }
183
+
184
+ public void testMissingNodes () {
185
+ Tree .TreeBuilder builder = Tree .TreeBuilder .newTreeBuilder ();
186
+ Tree .Node rootNode = builder .addJunction (0 , 0 , true , randomDecisionThreshold ());
187
+
188
+ Tree .Node node2 = builder .addJunction (rootNode .rightChild , 0 , false , 0.1 );
189
+ builder .addLeaf (node2 .leftChild , 0.1 );
190
+
191
+ List <Integer > missingNodeIndexes = builder .build ().missingNodes ();
192
+ assertEquals (Arrays .asList (1 , 4 ), missingNodeIndexes );
193
+ }
194
+ }
0 commit comments