GTSAM  4.0.2
C++ library for smoothing and mapping (SAM)
DecisionTree-inl.h
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------------------------------------------------- */
11 
20 #pragma once
21 
23 
24 #include <algorithm>
25 
26 #include <cmath>
27 #include <fstream>
28 #include <list>
29 #include <map>
30 #include <set>
31 #include <sstream>
32 #include <string>
33 #include <vector>
34 #include <optional>
35 #include <cassert>
36 
37 namespace gtsam {
38 
39  /****************************************************************************/
40  // Node
41  /****************************************************************************/
#ifdef DT_DEBUG_MEMORY
  // Live-node counter; only compiled in when memory debugging is enabled.
  template <typename L, typename Y>
  int DecisionTree<L, Y>::Node::nrNodes{0};
#endif
46 
47  /****************************************************************************/
48  // Leaf
49  /****************************************************************************/
50  template <typename L, typename Y>
51  struct DecisionTree<L, Y>::Leaf : public DecisionTree<L, Y>::Node {
54 
59 
61  Leaf() {}
62 
64  Leaf(const Y& constant, size_t nrAssignments = 1)
65  : constant_(constant), nrAssignments_(nrAssignments) {}
66 
68  const Y& constant() const {
69  return constant_;
70  }
71 
73  size_t nrAssignments() const { return nrAssignments_; }
74 
76  bool sameLeaf(const Leaf& q) const override {
77  return constant_ == q.constant_;
78  }
79 
81  bool sameLeaf(const Node& q) const override {
82  return (q.isLeaf() && q.sameLeaf(*this));
83  }
84 
86  bool equals(const Node& q, const CompareFunc& compare) const override {
87  const Leaf* other = dynamic_cast<const Leaf*>(&q);
88  if (!other) return false;
89  return compare(this->constant_, other->constant_);
90  }
91 
93  void print(const std::string& s, const LabelFormatter& labelFormatter,
94  const ValueFormatter& valueFormatter) const override {
95  std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
96  }
97 
99  void dot(std::ostream& os, const LabelFormatter& labelFormatter,
100  const ValueFormatter& valueFormatter,
101  bool showZero) const override {
102  std::string value = valueFormatter(constant_);
103  if (showZero || value.compare("0"))
104  os << "\"" << this->id() << "\" [label=\"" << value
105  << "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n";
106  }
107 
109  const Y& operator()(const Assignment<L>& x) const override {
110  return constant_;
111  }
112 
114  NodePtr apply(const Unary& op) const override {
115  NodePtr f(new Leaf(op(constant_), nrAssignments_));
116  return f;
117  }
118 
120  NodePtr apply(const UnaryAssignment& op,
121  const Assignment<L>& assignment) const override {
122  NodePtr f(new Leaf(op(assignment, constant_), nrAssignments_));
123  return f;
124  }
125 
126  // Apply binary operator "h = f op g" on Leaf node
127  // Note op is not assumed commutative so we need to keep track of order
128  // Simply calls apply on argument to call correct virtual method:
129  // fL.apply_f_op_g(gL) -> gL.apply_g_op_fL(fL) (below)
130  // fL.apply_f_op_g(gC) -> gC.apply_g_op_fL(fL) (Choice)
131  NodePtr apply_f_op_g(const Node& g, const Binary& op) const override {
132  return g.apply_g_op_fL(*this, op);
133  }
134 
135  // Applying binary operator to two leaves results in a leaf
136  NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override {
137  // fL op gL
138  NodePtr h(new Leaf(op(fL.constant_, constant_), nrAssignments_));
139  return h;
140  }
141 
142  // If second argument is a Choice node, call it's apply with leaf as second
143  NodePtr apply_g_op_fC(const Choice& fC, const Binary& op) const override {
144  return fC.apply_fC_op_gL(*this, op); // operand order back to normal
145  }
146 
148  NodePtr choose(const L& label, size_t index) const override {
149  return NodePtr(new Leaf(constant(), nrAssignments()));
150  }
151 
152  bool isLeaf() const override { return true; }
153 
154  private:
156 
157 #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
158 
159  friend class boost::serialization::access;
160  template <class ARCHIVE>
161  void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
162  ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
163  ar& BOOST_SERIALIZATION_NVP(constant_);
164  ar& BOOST_SERIALIZATION_NVP(nrAssignments_);
165  }
166 #endif
167  }; // Leaf
168 
169  /****************************************************************************/
170  // Choice
171  /****************************************************************************/
172  template<typename L, typename Y>
173  struct DecisionTree<L, Y>::Choice: public DecisionTree<L, Y>::Node {
176 
178  std::vector<NodePtr> branches_;
179 
180  private:
185  size_t allSame_;
186 
187  using ChoicePtr = std::shared_ptr<const Choice>;
188 
189  public:
191  Choice() {}
192 
193  ~Choice() override {
194 #ifdef DT_DEBUG_MEMORY
195  std::cout << Node::nrNodes << " destructing (Choice) " << this->id()
196  << std::endl;
197 #endif
198  }
199 
201  static NodePtr Unique(const ChoicePtr& f) {
202 #ifndef GTSAM_DT_NO_PRUNING
203  if (f->allSame_) {
204  assert(f->branches().size() > 0);
205  NodePtr f0 = f->branches_[0];
206 
207  size_t nrAssignments = 0;
208  for(auto branch: f->branches()) {
209  assert(branch->isLeaf());
210  nrAssignments +=
211  std::dynamic_pointer_cast<const Leaf>(branch)->nrAssignments();
212  }
213  NodePtr newLeaf(
214  new Leaf(std::dynamic_pointer_cast<const Leaf>(f0)->constant(),
215  nrAssignments));
216  return newLeaf;
217  } else
218 #endif
219  return f;
220  }
221 
222  bool isLeaf() const override { return false; }
223 
225  Choice(const L& label, size_t count) :
226  label_(label), allSame_(true) {
227  branches_.reserve(count);
228  }
229 
231  Choice(const Choice& f, const Choice& g, const Binary& op) :
232  allSame_(true) {
233  // Choose what to do based on label
234  if (f.label() > g.label()) {
235  // f higher than g
236  label_ = f.label();
237  size_t count = f.nrChoices();
238  branches_.reserve(count);
239  for (size_t i = 0; i < count; i++)
240  push_back(f.branches_[i]->apply_f_op_g(g, op));
241  } else if (g.label() > f.label()) {
242  // f lower than g
243  label_ = g.label();
244  size_t count = g.nrChoices();
245  branches_.reserve(count);
246  for (size_t i = 0; i < count; i++)
247  push_back(g.branches_[i]->apply_g_op_fC(f, op));
248  } else {
249  // f same level as g
250  label_ = f.label();
251  size_t count = f.nrChoices();
252  branches_.reserve(count);
253  for (size_t i = 0; i < count; i++)
254  push_back(f.branches_[i]->apply_f_op_g(*g.branches_[i], op));
255  }
256  }
257 
259  const L& label() const {
260  return label_;
261  }
262 
263  size_t nrChoices() const {
264  return branches_.size();
265  }
266 
267  const std::vector<NodePtr>& branches() const {
268  return branches_;
269  }
270 
272  void push_back(const NodePtr& node) {
273  // allSame_ is restricted to leaf nodes in a decision tree
274  if (allSame_ && !branches_.empty()) {
275  allSame_ = node->sameLeaf(*branches_.back());
276  }
277  branches_.push_back(node);
278  }
279 
281  void print(const std::string& s, const LabelFormatter& labelFormatter,
282  const ValueFormatter& valueFormatter) const override {
283  std::cout << s << " Choice(";
284  std::cout << labelFormatter(label_) << ") " << std::endl;
285  for (size_t i = 0; i < branches_.size(); i++) {
286  branches_[i]->print(s + " " + std::to_string(i), labelFormatter, valueFormatter);
287  }
288  }
289 
291  void dot(std::ostream& os, const LabelFormatter& labelFormatter,
292  const ValueFormatter& valueFormatter,
293  bool showZero) const override {
294  os << "\"" << this->id() << "\" [shape=circle, label=\"" << label_
295  << "\"]\n";
296  size_t B = branches_.size();
297  for (size_t i = 0; i < B; i++) {
298  const NodePtr& branch = branches_[i];
299 
300  // Check if zero
301  if (!showZero) {
302  const Leaf* leaf = dynamic_cast<const Leaf*>(branch.get());
303  if (leaf && valueFormatter(leaf->constant()).compare("0")) continue;
304  }
305 
306  os << "\"" << this->id() << "\" -> \"" << branch->id() << "\"";
307  if (B == 2 && i == 0) os << " [style=dashed]";
308  os << std::endl;
309  branch->dot(os, labelFormatter, valueFormatter, showZero);
310  }
311  }
312 
314  bool sameLeaf(const Leaf& q) const override {
315  return false;
316  }
317 
319  bool sameLeaf(const Node& q) const override {
320  return (q.isLeaf() && q.sameLeaf(*this));
321  }
322 
324  bool equals(const Node& q, const CompareFunc& compare) const override {
325  const Choice* other = dynamic_cast<const Choice*>(&q);
326  if (!other) return false;
327  if (this->label_ != other->label_) return false;
328  if (branches_.size() != other->branches_.size()) return false;
329  // we don't care about shared pointers being equal here
330  for (size_t i = 0; i < branches_.size(); i++)
331  if (!(branches_[i]->equals(*(other->branches_[i]), compare)))
332  return false;
333  return true;
334  }
335 
337  const Y& operator()(const Assignment<L>& x) const override {
338 #ifndef NDEBUG
339  typename Assignment<L>::const_iterator it = x.find(label_);
340  if (it == x.end()) {
341  std::cout << "Trying to find value for " << label_ << std::endl;
342  throw std::invalid_argument(
343  "DecisionTree::operator(): value undefined for a label");
344  }
345 #endif
346  size_t index = x.at(label_);
347  NodePtr child = branches_[index];
348  return (*child)(x);
349  }
350 
352  Choice(const L& label, const Choice& f, const Unary& op) :
353  label_(label), allSame_(true) {
354  branches_.reserve(f.branches_.size()); // reserve space
355  for (const NodePtr& branch : f.branches_) {
356  push_back(branch->apply(op));
357  }
358  }
359 
370  Choice(const L& label, const Choice& f, const UnaryAssignment& op,
371  const Assignment<L>& assignment)
372  : label_(label), allSame_(true) {
373  branches_.reserve(f.branches_.size()); // reserve space
374 
375  Assignment<L> assignment_ = assignment;
376 
377  for (size_t i = 0; i < f.branches_.size(); i++) {
378  assignment_[label_] = i; // Set assignment for label to i
379 
380  const NodePtr branch = f.branches_[i];
381  push_back(branch->apply(op, assignment_));
382 
383  // Remove the assignment so we are backtracking
384  auto assignment_it = assignment_.find(label_);
385  assignment_.erase(assignment_it);
386  }
387  }
388 
390  NodePtr apply(const Unary& op) const override {
391  auto r = std::make_shared<Choice>(label_, *this, op);
392  return Unique(r);
393  }
394 
396  NodePtr apply(const UnaryAssignment& op,
397  const Assignment<L>& assignment) const override {
398  auto r = std::make_shared<Choice>(label_, *this, op, assignment);
399  return Unique(r);
400  }
401 
402  // Apply binary operator "h = f op g" on Choice node
403  // Note op is not assumed commutative so we need to keep track of order
404  // Simply calls apply on argument to call correct virtual method:
405  // fC.apply_f_op_g(gL) -> gL.apply_g_op_fC(fC) -> (Leaf)
406  // fC.apply_f_op_g(gC) -> gC.apply_g_op_fC(fC) -> (below)
407  NodePtr apply_f_op_g(const Node& g, const Binary& op) const override {
408  return g.apply_g_op_fC(*this, op);
409  }
410 
411  // If second argument of binary op is Leaf node, recurse on branches
412  NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override {
413  auto h = std::make_shared<Choice>(label(), nrChoices());
414  for (auto&& branch : branches_)
415  h->push_back(fL.apply_f_op_g(*branch, op));
416  return Unique(h);
417  }
418 
419  // If second argument of binary op is Choice, call constructor
420  NodePtr apply_g_op_fC(const Choice& fC, const Binary& op) const override {
421  auto h = std::make_shared<Choice>(fC, *this, op);
422  return Unique(h);
423  }
424 
425  // If second argument of binary op is Leaf
426  template<typename OP>
427  NodePtr apply_fC_op_gL(const Leaf& gL, OP op) const {
428  auto h = std::make_shared<Choice>(label(), nrChoices());
429  for (auto&& branch : branches_)
430  h->push_back(branch->apply_f_op_g(gL, op));
431  return Unique(h);
432  }
433 
435  NodePtr choose(const L& label, size_t index) const override {
436  if (label_ == label) return branches_[index]; // choose branch
437 
438  // second case, not label of interest, just recurse
439  auto r = std::make_shared<Choice>(label_, branches_.size());
440  for (auto&& branch : branches_)
441  r->push_back(branch->choose(label, index));
442  return Unique(r);
443  }
444 
445  private:
447 
448 #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
449 
450  friend class boost::serialization::access;
451  template <class ARCHIVE>
452  void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
453  ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
454  ar& BOOST_SERIALIZATION_NVP(label_);
455  ar& BOOST_SERIALIZATION_NVP(branches_);
456  ar& BOOST_SERIALIZATION_NVP(allSame_);
457  }
458 #endif
459  }; // Choice
460 
461  /****************************************************************************/
462  // DecisionTree
463  /****************************************************************************/
464  template<typename L, typename Y>
466  }
467 
468  template<typename L, typename Y>
470  root_(root) {
471  }
472 
473  /****************************************************************************/
474  template<typename L, typename Y>
476  root_ = NodePtr(new Leaf(y));
477  }
478 
479  /****************************************************************************/
480  template <typename L, typename Y>
481  DecisionTree<L, Y>::DecisionTree(const L& label, const Y& y1, const Y& y2) {
482  auto a = std::make_shared<Choice>(label, 2);
483  NodePtr l1(new Leaf(y1)), l2(new Leaf(y2));
484  a->push_back(l1);
485  a->push_back(l2);
486  root_ = Choice::Unique(a);
487  }
488 
489  /****************************************************************************/
490  template <typename L, typename Y>
491  DecisionTree<L, Y>::DecisionTree(const LabelC& labelC, const Y& y1,
492  const Y& y2) {
493  if (labelC.second != 2) throw std::invalid_argument(
494  "DecisionTree: binary constructor called with non-binary label");
495  auto a = std::make_shared<Choice>(labelC.first, 2);
496  NodePtr l1(new Leaf(y1)), l2(new Leaf(y2));
497  a->push_back(l1);
498  a->push_back(l2);
499  root_ = Choice::Unique(a);
500  }
501 
502  /****************************************************************************/
503  template<typename L, typename Y>
504  DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
505  const std::vector<Y>& ys) {
506  // call recursive Create
507  root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
508  }
509 
510  /****************************************************************************/
511  template<typename L, typename Y>
512  DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
513  const std::string& table) {
514  // Convert std::string to values of type Y
515  std::vector<Y> ys;
516  std::istringstream iss(table);
517  copy(std::istream_iterator<Y>(iss), std::istream_iterator<Y>(),
518  back_inserter(ys));
519 
520  // now call recursive Create
521  root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
522  }
523 
524  /****************************************************************************/
525  template<typename L, typename Y>
526  template<typename Iterator> DecisionTree<L, Y>::DecisionTree(
527  Iterator begin, Iterator end, const L& label) {
528  root_ = compose(begin, end, label);
529  }
530 
531  /****************************************************************************/
532  template<typename L, typename Y>
534  const DecisionTree& f0, const DecisionTree& f1) {
535  const std::vector<DecisionTree> functions{f0, f1};
536  root_ = compose(functions.begin(), functions.end(), label);
537  }
538 
539  /****************************************************************************/
540  template <typename L, typename Y>
541  template <typename X, typename Func>
543  Func Y_of_X) {
544  // Define functor for identity mapping of node label.
545  auto L_of_L = [](const L& label) { return label; };
546  root_ = convertFrom<L, X>(other.root_, L_of_L, Y_of_X);
547  }
548 
549  /****************************************************************************/
550  template <typename L, typename Y>
551  template <typename M, typename X, typename Func>
553  const std::map<M, L>& map, Func Y_of_X) {
554  auto L_of_M = [&map](const M& label) -> L { return map.at(label); };
555  root_ = convertFrom<M, X>(other.root_, L_of_M, Y_of_X);
556  }
557 
558  /****************************************************************************/
559  // Called by two constructors above.
560  // Takes a label and a corresponding range of decision trees, and creates a
561  // new decision tree. However, the order of the labels needs to be respected,
562  // so we cannot just create a root Choice node on the label: if the label is
563  // not the highest label, we need a complicated/ expensive recursive call.
564  template <typename L, typename Y>
565  template <typename Iterator>
567  Iterator begin, Iterator end, const L& label) const {
568  // find highest label among branches
569  std::optional<L> highestLabel;
570  size_t nrChoices = 0;
571  for (Iterator it = begin; it != end; it++) {
572  if (it->root_->isLeaf())
573  continue;
574  std::shared_ptr<const Choice> c =
575  std::dynamic_pointer_cast<const Choice>(it->root_);
576  if (!highestLabel || c->label() > *highestLabel) {
577  highestLabel = c->label();
578  nrChoices = c->nrChoices();
579  }
580  }
581 
582  // if label is already in correct order, just put together a choice on label
583  if (!nrChoices || !highestLabel || label > *highestLabel) {
584  auto choiceOnLabel = std::make_shared<Choice>(label, end - begin);
585  for (Iterator it = begin; it != end; it++)
586  choiceOnLabel->push_back(it->root_);
587  return Choice::Unique(choiceOnLabel);
588  } else {
589  // Set up a new choice on the highest label
590  auto choiceOnHighestLabel =
591  std::make_shared<Choice>(*highestLabel, nrChoices);
592  // now, for all possible values of highestLabel
593  for (size_t index = 0; index < nrChoices; index++) {
594  // make a new set of functions for composing by iterating over the given
595  // functions, and selecting the appropriate branch.
596  std::vector<DecisionTree> functions;
597  for (Iterator it = begin; it != end; it++) {
598  // by restricting the input functions to value i for labelBelow
599  DecisionTree chosen = it->choose(*highestLabel, index);
600  functions.push_back(chosen);
601  }
602  // We then recurse, for all values of the highest label
603  NodePtr fi = compose(functions.begin(), functions.end(), label);
604  choiceOnHighestLabel->push_back(fi);
605  }
606  return Choice::Unique(choiceOnHighestLabel);
607  }
608  }
609 
610  /****************************************************************************/
611  // "create" is a bit of a complicated thing, but very useful.
612  // It takes a range of labels and a corresponding range of values,
613  // and creates a decision tree, as follows:
614  // - if there is only one label, creates a choice node with values in leaves
615  // - otherwise, it evenly splits up the range of values and creates a tree for
616  // each sub-range, and assigns that tree to first label's choices
617  // Example:
618  // create([B A],[1 2 3 4]) would call
619  // create([A],[1 2])
620  // create([A],[3 4])
621  // and produce
622  // B=0
623  // A=0: 1
624  // A=1: 2
625  // B=1
626  // A=0: 3
627  // A=1: 4
628  // Note, through the magic of "compose", create([A B],[1 2 3 4]) will produce
629  // exactly the same tree as above: the highest label is always the root.
630  // However, it will be *way* faster if labels are given highest to lowest.
631  template<typename L, typename Y>
632  template<typename It, typename ValueIt>
634  It begin, It end, ValueIt beginY, ValueIt endY) const {
635  // get crucial counts
636  size_t nrChoices = begin->second;
637  size_t size = endY - beginY;
638 
639  // Find the next key to work on
640  It labelC = begin + 1;
641  if (labelC == end) {
642  // Base case: only one key left
643  // Create a simple choice node with values as leaves.
644  if (size != nrChoices) {
645  std::cout << "Trying to create DD on " << begin->first << std::endl;
646  std::cout << "DecisionTree::create: expected " << nrChoices
647  << " values but got " << size << " instead" << std::endl;
648  throw std::invalid_argument("DecisionTree::create invalid argument");
649  }
650  auto choice = std::make_shared<Choice>(begin->first, endY - beginY);
651  for (ValueIt y = beginY; y != endY; y++)
652  choice->push_back(NodePtr(new Leaf(*y)));
653  return Choice::Unique(choice);
654  }
655 
656  // Recursive case: perform "Shannon expansion"
657  // Creates one tree (i.e.,function) for each choice of current key
658  // by calling create recursively, and then puts them all together.
659  std::vector<DecisionTree> functions;
660  size_t split = size / nrChoices;
661  for (size_t i = 0; i < nrChoices; i++, beginY += split) {
662  NodePtr f = create<It, ValueIt>(labelC, end, beginY, beginY + split);
663  functions.emplace_back(f);
664  }
665  return compose(functions.begin(), functions.end(), begin->first);
666  }
667 
668  /****************************************************************************/
669  template <typename L, typename Y>
670  template <typename M, typename X>
672  const typename DecisionTree<M, X>::NodePtr& f,
673  std::function<L(const M&)> L_of_M,
674  std::function<Y(const X&)> Y_of_X) const {
675  using LY = DecisionTree<L, Y>;
676 
677  // Ugliness below because apparently we can't have templated virtual
678  // functions.
679  // If leaf, apply unary conversion "op" and create a unique leaf.
680  using MXLeaf = typename DecisionTree<M, X>::Leaf;
681  if (auto leaf = std::dynamic_pointer_cast<const MXLeaf>(f)) {
682  return NodePtr(new Leaf(Y_of_X(leaf->constant()), leaf->nrAssignments()));
683  }
684 
685  // Check if Choice
686  using MXChoice = typename DecisionTree<M, X>::Choice;
687  auto choice = std::dynamic_pointer_cast<const MXChoice>(f);
688  if (!choice) throw std::invalid_argument(
689  "DecisionTree::convertFrom: Invalid NodePtr");
690 
691  // get new label
692  const M oldLabel = choice->label();
693  const L newLabel = L_of_M(oldLabel);
694 
695  // put together via Shannon expansion otherwise not sorted.
696  std::vector<LY> functions;
697  for (auto&& branch : choice->branches()) {
698  functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X));
699  }
700  return LY::compose(functions.begin(), functions.end(), newLabel);
701  }
702 
703  /****************************************************************************/
714  template <typename L, typename Y>
715  struct Visit {
716  using F = std::function<void(const Y&)>;
717  explicit Visit(F f) : f(f) {}
718  F f;
719 
721  void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const {
722  using Leaf = typename DecisionTree<L, Y>::Leaf;
723  if (auto leaf = std::dynamic_pointer_cast<const Leaf>(node))
724  return f(leaf->constant());
725 
726  using Choice = typename DecisionTree<L, Y>::Choice;
727  auto choice = std::dynamic_pointer_cast<const Choice>(node);
728  if (!choice)
729  throw std::invalid_argument("DecisionTree::Visit: Invalid NodePtr");
730  for (auto&& branch : choice->branches()) (*this)(branch); // recurse!
731  }
732  };
733 
734  template <typename L, typename Y>
735  template <typename Func>
736  void DecisionTree<L, Y>::visit(Func f) const {
737  Visit<L, Y> visit(f);
738  visit(root_);
739  }
740 
741  /****************************************************************************/
751  template <typename L, typename Y>
752  struct VisitLeaf {
753  using F = std::function<void(const typename DecisionTree<L, Y>::Leaf&)>;
754  explicit VisitLeaf(F f) : f(f) {}
755  F f;
756 
758  void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const {
759  using Leaf = typename DecisionTree<L, Y>::Leaf;
760  if (auto leaf = std::dynamic_pointer_cast<const Leaf>(node))
761  return f(*leaf);
762 
763  using Choice = typename DecisionTree<L, Y>::Choice;
764  auto choice = std::dynamic_pointer_cast<const Choice>(node);
765  if (!choice)
766  throw std::invalid_argument("DecisionTree::VisitLeaf: Invalid NodePtr");
767  for (auto&& branch : choice->branches()) (*this)(branch); // recurse!
768  }
769  };
770 
771  template <typename L, typename Y>
772  template <typename Func>
773  void DecisionTree<L, Y>::visitLeaf(Func f) const {
775  visit(root_);
776  }
777 
778  /****************************************************************************/
785  template <typename L, typename Y>
786  struct VisitWith {
787  using F = std::function<void(const Assignment<L>&, const Y&)>;
788  explicit VisitWith(F f) : f(f) {}
790  F f;
791 
793  void operator()(const typename DecisionTree<L, Y>::NodePtr& node) {
794  using Leaf = typename DecisionTree<L, Y>::Leaf;
795  if (auto leaf = std::dynamic_pointer_cast<const Leaf>(node))
796  return f(assignment, leaf->constant());
797 
798  using Choice = typename DecisionTree<L, Y>::Choice;
799  auto choice = std::dynamic_pointer_cast<const Choice>(node);
800  if (!choice)
801  throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr");
802  for (size_t i = 0; i < choice->nrChoices(); i++) {
803  assignment[choice->label()] = i; // Set assignment for label to i
804 
805  (*this)(choice->branches()[i]); // recurse!
806 
807  // Remove the choice so we are backtracking
808  auto choice_it = assignment.find(choice->label());
809  assignment.erase(choice_it);
810  }
811  }
812  };
813 
814  template <typename L, typename Y>
815  template <typename Func>
816  void DecisionTree<L, Y>::visitWith(Func f) const {
818  visit(root_);
819  }
820 
821  /****************************************************************************/
822  template <typename L, typename Y>
824  size_t total = 0;
825  visit([&total](const Y& node) { total += 1; });
826  return total;
827  }
828 
829  /****************************************************************************/
830  // fold is just done with a visit
831  template <typename L, typename Y>
832  template <typename Func, typename X>
833  X DecisionTree<L, Y>::fold(Func f, X x0) const {
834  visit([&](const Y& y) { x0 = f(y, x0); });
835  return x0;
836  }
837 
838  /****************************************************************************/
852  template <typename L, typename Y>
853  std::set<L> DecisionTree<L, Y>::labels() const {
854  std::set<L> unique;
855  auto f = [&](const Assignment<L>& assignment, const Y&) {
856  for (auto&& kv : assignment) {
857  unique.insert(kv.first);
858  }
859  };
860  visitWith(f);
861  return unique;
862  }
863 
864 /****************************************************************************/
865  template <typename L, typename Y>
866  bool DecisionTree<L, Y>::equals(const DecisionTree& other,
867  const CompareFunc& compare) const {
868  return root_->equals(*other.root_, compare);
869  }
870 
871  template <typename L, typename Y>
872  void DecisionTree<L, Y>::print(const std::string& s,
873  const LabelFormatter& labelFormatter,
874  const ValueFormatter& valueFormatter) const {
875  root_->print(s, labelFormatter, valueFormatter);
876  }
877 
878  template<typename L, typename Y>
879  bool DecisionTree<L, Y>::operator==(const DecisionTree& other) const {
880  return root_->equals(*other.root_);
881  }
882 
883  template<typename L, typename Y>
885  return root_->operator ()(x);
886  }
887 
888  template<typename L, typename Y>
890  // It is unclear what should happen if tree is empty:
891  if (empty()) {
892  throw std::runtime_error(
893  "DecisionTree::apply(unary op) undefined for empty tree.");
894  }
895  return DecisionTree(root_->apply(op));
896  }
897 
899  template <typename L, typename Y>
901  const UnaryAssignment& op) const {
902  // It is unclear what should happen if tree is empty:
903  if (empty()) {
904  throw std::runtime_error(
905  "DecisionTree::apply(unary op) undefined for empty tree.");
906  }
907  Assignment<L> assignment;
908  return DecisionTree(root_->apply(op, assignment));
909  }
910 
911  /****************************************************************************/
912  template<typename L, typename Y>
914  const Binary& op) const {
915  // It is unclear what should happen if either tree is empty:
916  if (empty() || g.empty()) {
917  throw std::runtime_error(
918  "DecisionTree::apply(binary op) undefined for empty trees.");
919  }
920  // apply the operaton on the root of both diagrams
921  NodePtr h = root_->apply_f_op_g(*g.root_, op);
922  // create a new class with the resulting root "h"
923  DecisionTree result(h);
924  return result;
925  }
926 
927  /****************************************************************************/
928  // The way this works:
929  // We have an ADT, picture it as a tree.
930  // At a certain depth, we have a branch on "label".
931  // The function "choose(label,index)" will return a tree of one less depth,
932  // where there is no more branch on "label": only the subtree under that
933  // branch point corresponding to the value "index" is left instead.
934  // The function below get all these smaller trees and "ops" them together.
935  // This implements marginalization in Darwiche09book, pg 330
936  template<typename L, typename Y>
938  size_t cardinality, const Binary& op) const {
939  DecisionTree result = choose(label, 0);
940  for (size_t index = 1; index < cardinality; index++) {
941  DecisionTree chosen = choose(label, index);
942  result = result.apply(chosen, op);
943  }
944  return result;
945  }
946 
947  /****************************************************************************/
948  template <typename L, typename Y>
949  void DecisionTree<L, Y>::dot(std::ostream& os,
950  const LabelFormatter& labelFormatter,
951  const ValueFormatter& valueFormatter,
952  bool showZero) const {
953  os << "digraph G {\n";
954  root_->dot(os, labelFormatter, valueFormatter, showZero);
955  os << " [ordering=out]}" << std::endl;
956  }
957 
958  template <typename L, typename Y>
959  void DecisionTree<L, Y>::dot(const std::string& name,
960  const LabelFormatter& labelFormatter,
961  const ValueFormatter& valueFormatter,
962  bool showZero) const {
963  std::ofstream os((name + ".dot").c_str());
964  dot(os, labelFormatter, valueFormatter, showZero);
965  int result =
966  system(("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null")
967  .c_str());
968  if (result == -1)
969  throw std::runtime_error("DecisionTree::dot system call failed");
970  }
971 
972  template <typename L, typename Y>
973  std::string DecisionTree<L, Y>::dot(const LabelFormatter& labelFormatter,
974  const ValueFormatter& valueFormatter,
975  bool showZero) const {
976  std::stringstream ss;
977  dot(ss, labelFormatter, valueFormatter, showZero);
978  return ss.str();
979  }
980 
981 /******************************************************************************/
982 
983  } // namespace gtsam
const Y & operator()(const Assignment< L > &x) const override
evaluate
Definition: DecisionTree-inl.h:337
Decision Tree for use in DiscreteFactors.
NodePtr convertFrom(const typename DecisionTree< M, X >::NodePtr &f, std::function< L(const M &)> L_of_M, std::function< Y(const X &)> Y_of_X) const
Convert from a DecisionTree<M, X> to DecisionTree<L, Y>.
Definition: DecisionTree-inl.h:671
void print(const std::string &s, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter) const override
print (as a tree).
Definition: DecisionTree-inl.h:281
NodePtr apply(const Unary &op) const override
apply unary operator.
Definition: DecisionTree-inl.h:390
void visit(Func f) const
Visit all leaves in depth-first fashion.
Definition: DecisionTree-inl.h:736
size_t nrAssignments() const
Return the number of assignments contained within this leaf.
Definition: DecisionTree-inl.h:73
L label_
Definition: DecisionTree-inl.h:175
std::vector< NodePtr > branches_
Definition: DecisionTree-inl.h:178
std::string serialize(const T &input)
serializes to a string
Definition: serialization.h:113
Definition: DecisionTree-inl.h:715
size_t nrAssignments_
Definition: DecisionTree-inl.h:58
Definition: DecisionTree-inl.h:752
VisitLeaf(F f)
Construct from folding function.
Definition: DecisionTree-inl.h:754
DecisionTree()
Definition: DecisionTree-inl.h:465
std::function< Y(const Y &)> Unary
Definition: DecisionTree.h:62
Choice(const Choice &f, const Choice &g, const Binary &op)
Construct from applying binary op to two Choice nodes.
Definition: DecisionTree-inl.h:231
std::set< L > labels() const
Definition: DecisionTree-inl.h:853
Choice(const L &label, const Choice &f, const Unary &op)
Construct from applying unary op to a Choice node.
Definition: DecisionTree-inl.h:352
NodePtr apply(const UnaryAssignment &op, const Assignment< L > &assignment) const override
Apply unary operator with assignment.
Definition: DecisionTree-inl.h:120
const L & label() const
Return the label of this choice node.
Definition: DecisionTree-inl.h:259
Definition: DecisionTree-inl.h:786
F f
folding function object.
Definition: DecisionTree-inl.h:755
Leaf(const Y &constant, size_t nrAssignments=1)
Constructor from constant.
Definition: DecisionTree-inl.h:64
NodePtr root_
A DecisionTree just contains the root. TODO(dellaert): make protected.
Definition: DecisionTree.h:136
void split(const G &g, const PredecessorMap< KEY > &tree, G &Ab1, G &Ab2)
Definition: graph-inl.h:245
NodePtr choose(const L &label, size_t index) const override
Definition: DecisionTree-inl.h:148
Definition: Testable.h:112
DecisionTree combine(const L &label, size_t cardinality, const Binary &op) const
Definition: DecisionTree-inl.h:937
Choice(const L &label, size_t count)
Constructor, given choice label and mandatory expected branch count.
Definition: DecisionTree-inl.h:225
void visitWith(Func f) const
Visit all leaves in depth-first fashion.
Definition: DecisionTree-inl.h:816
void operator()(const typename DecisionTree< L, Y >::NodePtr &node) const
Do a depth-first visit on the tree rooted at node.
Definition: DecisionTree-inl.h:758
Definition: DecisionTree.h:49
size_t nrLeaves() const
Return the number of leaves in the tree.
Definition: DecisionTree-inl.h:823
void dot(std::ostream &os, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter, bool showZero) const override
Definition: DecisionTree-inl.h:99
void dot(std::ostream &os, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter, bool showZero) const override
Definition: DecisionTree-inl.h:291
Y constant_
Definition: DecisionTree-inl.h:53
NodePtr apply(const UnaryAssignment &op, const Assignment< L > &assignment) const override
Apply unary operator with assignment.
Definition: DecisionTree-inl.h:396
bool sameLeaf(const Node &q) const override
polymorphic equality: is q a leaf and is it the same as this leaf?
Definition: DecisionTree-inl.h:81
void visitLeaf(Func f) const
Visit all leaves in depth-first fashion.
Definition: DecisionTree-inl.h:773
void print(const std::string &s, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter) const override
print
Definition: DecisionTree-inl.h:93
void push_back(const NodePtr &node)
Definition: DecisionTree-inl.h:272
Leaf()
Default constructor for serialization.
Definition: DecisionTree-inl.h:61
Definition: DecisionTree-inl.h:173
static NodePtr Unique(const ChoicePtr &f)
If all branches of a choice node f are the same, just return a branch.
Definition: DecisionTree-inl.h:201
X fold(Func f, X x0) const
Fold a binary function over the tree, returning accumulator.
Definition: DecisionTree-inl.h:833
NodePtr choose(const L &label, size_t index) const override
Definition: DecisionTree-inl.h:435
DecisionTree apply(const Unary &op) const
Definition: DecisionTree-inl.h:889
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const
Definition: DecisionTree-inl.h:633
typename Node::Ptr NodePtr
Definition: DecisionTree.h:133
void operator()(const typename DecisionTree< L, Y >::NodePtr &node)
Do a depth-first visit on the tree rooted at node.
Definition: DecisionTree-inl.h:793
Definition: Assignment.h:37
Definition: chartTesting.h:28
NodePtr apply(const Unary &op) const override
Definition: DecisionTree-inl.h:114
void print(const std::string &s, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter) const
GTSAM-style print.
Definition: DecisionTree-inl.h:872
bool empty() const
Check if tree is empty.
Definition: DecisionTree.h:240
const Y & operator()(const Assignment< L > &x) const
Definition: DecisionTree-inl.h:884
bool equals(const Node &q, const CompareFunc &compare) const override
equality
Definition: DecisionTree-inl.h:324
Visit(F f)
Construct from folding function.
Definition: DecisionTree-inl.h:717
Definition: DecisionTree-inl.h:51
Assignment< L > assignment
Assignment, mutating through recursion.
Definition: DecisionTree-inl.h:789
bool operator==(const DecisionTree &q) const
Definition: DecisionTree-inl.h:879
void dot(std::ostream &os, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter, bool showZero=true) const
Definition: DecisionTree-inl.h:949
bool sameLeaf(const Leaf &q) const override
Choice-Leaf equality: always false.
Definition: DecisionTree-inl.h:314
void operator()(const typename DecisionTree< L, Y >::NodePtr &node) const
Do a depth-first visit on the tree rooted at node.
Definition: DecisionTree-inl.h:721
F f
folding function object.
Definition: DecisionTree-inl.h:718
Definition: DecisionTree.h:74
bool equals(const Node &q, const CompareFunc &compare) const override
equality up to tolerance
Definition: DecisionTree-inl.h:86
std::pair< Key, size_t > LabelC
Definition: DecisionTree.h:67
VisitWith(F f)
Construct from folding function.
Definition: DecisionTree-inl.h:788
Choice(const L &label, const Choice &f, const UnaryAssignment &op, const Assignment< L > &assignment)
Constructor which accepts a UnaryAssignment op and the corresponding assignment.
Definition: DecisionTree-inl.h:370
bool sameLeaf(const Leaf &q) const override
Leaf-Leaf equality.
Definition: DecisionTree-inl.h:76
bool sameLeaf(const Node &q) const override
polymorphic equality: if q is a leaf, could be...
Definition: DecisionTree-inl.h:319
DecisionTree choose(const L &label, size_t index) const
Definition: DecisionTree.h:341
F f
folding function object.
Definition: DecisionTree-inl.h:790
const Y & constant() const
Return the constant.
Definition: DecisionTree-inl.h:68
const Y & operator()(const Assignment< L > &x) const override
Definition: DecisionTree-inl.h:109
Choice()
Default constructor for serialization.
Definition: DecisionTree-inl.h:191