Skip to content

Commit 8272821

Browse files
committed
Added code for rearranging / reinserting internal nodes.
1 parent 558528c commit 8272821

File tree

4 files changed

+170
-118
lines changed

4 files changed

+170
-118
lines changed

src/EMTree.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,8 @@ void sigEMTreeCluster(vector<SVector<bool>*> &vectors) {
8888

8989
// EMTree
9090
int depth = 3;
91-
int iters = 2;
92-
vector<int> nodeSizes = {100}; //{10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20};
91+
int iters = 10;
92+
vector<int> nodeSizes = {10}; //{10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20};
9393
for (int m : nodeSizes) {
9494
std::cout << "-------------------" << std::endl;
9595
EMTree<vecType, clustererType, distanceType, protoType> emt(m);
@@ -403,9 +403,9 @@ void journalPaperExperiments(vector<SVector<bool>*>& vectors) {
403403
typedef KMeans<vecType, seederType, distanceType, protoType> clustererType;
404404

405405
// run TSVQ vs EM-tree convergence
406-
if (false) {
406+
if (true) {
407407
int depth = 3, m = 10;
408-
int iterRange = 40; // test RMSE at 1 to maxiters iterations
408+
int iterRange = 10; // test RMSE at 1 to maxiters iterations
409409

410410
//TSVQ
411411
if (true) {
@@ -414,6 +414,7 @@ void journalPaperExperiments(vector<SVector<bool>*>& vectors) {
414414
vector<double> seconds;
415415
for (int maxiters = 1; maxiters <= iterRange; ++maxiters) {
416416
boost::timer::auto_cpu_timer all;
417+
srand(1234);
417418
TSVQ<vecType, clustererType, distanceType, protoType> tsvq(m, depth, maxiters);
418419
tsvq.cluster(vectors);
419420
all.stop();
@@ -439,6 +440,7 @@ void journalPaperExperiments(vector<SVector<bool>*>& vectors) {
439440
splits.push_back(m);
440441
}
441442
boost::timer::auto_cpu_timer all;
443+
srand(1234);
442444
EMTree<vecType, clustererType, distanceType, protoType> emt(m);
443445
// seeding does first iteration
444446
emt.seedSingleThreaded(vectors, splits);
@@ -489,7 +491,7 @@ void journalPaperExperiments(vector<SVector<bool>*>& vectors) {
489491
}
490492

491493
//EM-tree
492-
if (true) {
494+
if (false) {
493495
int maxiters = 6;
494496
vector<double> rmse;
495497
vector<int> clusters;
@@ -616,7 +618,7 @@ int main(int argc, char** argv) {
616618
journalPaperExperiments(subset);
617619
//sigKTreeCluster(vectors);
618620
//sigTSVQCluster(vectors);
619-
//sigEMTreeCluster(vectors);
621+
//sigEMTreeCluster(subset);
620622
//testHistogram(vectors);
621623
//testMeanVersusNNSpeed(vectors);
622624
//testReadVectors();

src/EMTree.h

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class EMTree {
2525
ProtoType _protoF;
2626

2727
vector<T*> removed;
28+
vector<Node<T>*> removedChildren;
2829

2930
// Weights for prototype function (we don't have to use these)
3031
vector<int> weights;
@@ -206,6 +207,18 @@ class EMTree {
206207
removed.clear();
207208
}
208209

210+
void rearrangeInternal() {
211+
for (int depth = 2; depth < getMaxLevelCount(); ++depth) {
212+
removeDataInternal(_root, removed, removedChildren, depth);
213+
for (int i = 0; i < removed.size(); i++) {
214+
pushDownNoUpdateInternal(_root, removed[i], removedChildren[i], depth);
215+
}
216+
prune();
217+
removed.clear();
218+
removedChildren.clear();
219+
}
220+
}
221+
209222
int prune() {
210223
return prune(_root);
211224
}
@@ -416,6 +429,26 @@ void add(T *obj) {
416429
pushDownNoUpdate(nearestChild, vec);
417430
}
418431
}
432+
433+
void pushDownNoUpdateInternal(Node<T> *n, T* key, Node<T>* child, int depth) {
434+
435+
//std::cout << "\n\tPushing down (no update) ...";
436+
437+
if (depth == 1) {
438+
n->add(key, child); // Finished
439+
} else { // It is an internal node.
440+
// recurse via nearest neighbour cluster
441+
442+
vector<T*>& keys = n->getKeys();
443+
vector<Node<T>*>& children = n->getChildren();
444+
445+
size_t nearest = nearestChild(key, keys);
446+
Node<T> *nearestChild = children[nearest];
447+
448+
pushDownNoUpdateInternal(nearestChild, key, child, depth - 1);
449+
}
450+
}
451+
419452

420453
/*
421454
SplitResult<T> pushDown(Node<T> *n, T *vec) {
@@ -591,7 +624,16 @@ SplitResult<T> pushDown(Node<T> *n, T *vec) {
591624
}
592625
}
593626
}
594-
627+
628+
void removeDataInternal(Node<T>* n, vector<T*>& keys, vector<Node<T>*>& children, int depth) {
629+
if (depth == 1) {
630+
n->removeData(keys, children);
631+
} else {
632+
for (Node<T>* child : n->getChildren()) {
633+
removeDataInternal(child, keys, children, depth - 1);
634+
}
635+
}
636+
}
595637
};
596638

597639

0 commit comments

Comments
 (0)