// Listing fragment of EliminationTree<BAYESNET,GRAPH>::Node::eliminate (a const member function).
// NOTE(review): the leading integers (33, 35, ...) are line numbers carried over from the rendered
// listing; the gaps between them mean some original lines are absent here (the return type, the
// braces, and the declarations of gatheredFactors and keyAsVector are not visible).
//
// Parameters visible in the signature:
//   output          - shared pointer to the BayesNetType that receives the first part of the result
//   function        - the Eliminate callable invoked on the gathered factors
//   childrenResults - vector of sharedFactor; asserted to hold exactly one entry per child
// Returns: the second member of the pair produced by `function`.
33 template<
class BAYESNET,
class GRAPH>
35 EliminationTree<BAYESNET,GRAPH>::Node::eliminate(
36 const std::shared_ptr<BayesNetType>& output,
37 const Eliminate&
function,
const FastVector<sharedFactor>& childrenResults)
// Precondition check: the caller must supply as many results as this node has children.
const 41 assert(childrenResults.size() ==
children.size());
// Append the children's results to the collection of factors to be eliminated.
47 gatheredFactors.push_back(childrenResults.begin(), childrenResults.end());
// Run the supplied elimination function with an Ordering built from keyAsVector.
// NOTE(review): keyAsVector presumably holds just this node's key - its construction is not visible.
51 auto eliminationResult =
function(gatheredFactors, Ordering(keyAsVector));
// The first member of the result pair is appended to the output.
54 output->push_back(eliminationResult.first);
// The second member of the result pair is handed back to the caller.
57 return eliminationResult.second;
// Listing fragment of EliminationTree<BAYESNET,GRAPH>::Node::print (const, returns void).
// Parameters:
//   str          - prefix written before each line of output
//   keyFormatter - callable used to turn this node's key into text
// NOTE(review): only two output statements are visible; the control flow that decides when
// "null factor" is printed is not present in this listing.
61 template<
class BAYESNET,
class GRAPH>
62 void EliminationTree<BAYESNET,GRAPH>::Node::print(
63 const std::string& str,
const KeyFormatter& keyFormatter)
// Writes: <str>(<formatted key>) followed by a newline.
const 65 std::cout << str <<
"(" << keyFormatter(
key) <<
")\n";
// Writes: <str>null factor followed by a newline.
70 std::cout << str <<
"null factor\n";
// Listing fragment of a member template whose signature is not visible here.
// NOTE(review): the profiling label below suggests this is an EliminationTree constructor - confirm
// against the header. The visible statements read `graph` and `order` and fill the members
// `roots_` and `remainingFactors_`. Declarations of nodes, parents, prevCol, factorUsed and the
// per-variable `factors` collection, as well as the `try` that matches the `catch`, are not visible.
76 template<
class BAYESNET,
class GRAPH>
// NOTE(review): gttic is a macro defined elsewhere (presumably starts a timing section). The label
// is spelled "Contructor" in the source; it is an identifier, so it is left untouched here.
80 gttic(EliminationTree_Contructor);
// m: number of entries in the factor graph; n: number of entries in the ordering.
84 const size_t m = graph.
size();
85 const size_t n = order.size();
// Sentinel index (largest size_t); compared below against prevCol[] and parents[] entries.
87 static const size_t none = std::numeric_limits<size_t>::max();
// One iteration per position j of the ordering.
97 for (
size_t j = 0; j < n; j++)
// Create the node for position j and give it the j-th key of the ordering.
101 const sharedNode node = std::make_shared<Node>();
102 node->key = order[j];
// Pre-size both containers to factors.size() entries.
105 node->children.reserve(factors.size());
106 node->factors.reserve(factors.size());
// Visit each factor index i in `factors`.
107 for(
const size_t i: factors) {
// Taken when prevCol[i] has been set to something other than the sentinel.
113 if (prevCol[i] != none) {
114 size_t k = prevCol[i];
// Loop while parents[r] is set; the loop body and the initialisation of r are not visible.
// NOTE(review): presumably walks r up to the root of its current subtree - confirm.
118 while (parents[r] != none)
// nodes[r] becomes a child of the node being built.
125 node->children.push_back(nodes[r]);
// Store factor i on this node and mark it as consumed.
// NOTE(review): whether this sits in an else-branch of the test above is not visible here.
129 node->factors.push_back(graph[i]);
130 factorUsed[i] =
true;
// A std::invalid_argument raised in the (not visible) try block is replaced by one with a
// message specific to this class.
136 }
catch(std::invalid_argument& e) {
140 throw std::invalid_argument(
"EliminationTree: given ordering contains variables that are not involved in the factor graph");
// Sanity check: if parents is non-empty, its last entry must be the sentinel.
146 assert(parents.empty() || parents.back() == none);
// Every node whose parent entry is still the sentinel is recorded as a root.
147 for(
size_t j = 0; j < n; ++j)
148 if(parents[j] == none)
149 roots_.push_back(nodes[j]);
// Every non-null factor that was never marked as used is kept in remainingFactors_.
152 for(
size_t i = 0; i < m; ++i)
153 if(!factorUsed[i] && graph[i])
154 remainingFactors_.push_back(graph[i]);
// Listing fragment of a member template whose signature is not visible here.
// The single visible statement builds a local `This` named temp from factorGraph, variableIndex
// and order.
// NOTE(review): what is done with temp afterwards is not present in this listing.
158 template<
class BAYESNET,
class GRAPH>
165 This temp(factorGraph, variableIndex, order);
// Listing fragment; the index at the end of this listing places `This & operator=(const This &other)`
// at listing line 172, i.e. directly after this template header.
// The only visible statement copies other's remainingFactors_ into this object.
// NOTE(review): how roots_ is handled and the return statement are not visible here.
170 template<
class BAYESNET,
class GRAPH>
179 remainingFactors_ = other.remainingFactors_;
// Listing fragment; the index at the end of this listing places ~EliminationTree() at listing
// line 192, i.e. directly after this template header.
// The visible statements walk each tree in roots_ with an explicit queue: the root is enqueued,
// the front node is read, and each of its children is enqueued.
// NOTE(review): the queue pop and whatever is done to each visited node are not visible here.
191 template<
class BAYESNET,
class GRAPH>
// One traversal per root.
196 for (
auto&& root :
roots_) {
197 std::queue<sharedNode> bfs_queue;
// Seed the queue with the root of this tree.
200 bfs_queue.push(root);
// Continue until no nodes remain queued.
205 while (!bfs_queue.empty()) {
207 auto node = bfs_queue.front();
// Queue every child of the current node.
211 for (
auto&& child : node->children) {
212 bfs_queue.push(child);
// Listing fragment of EliminationTree::eliminate (the index at the end of this listing places
// `eliminate(Eliminate function) const` at listing line 224).
// Returns a pair: a newly allocated BayesNetType and a newly allocated FactorGraphType.
// NOTE(review): the step that fills `result` and produces the local `remainingFactors` is not
// visible in this listing.
222 template<
class BAYESNET,
class GRAPH>
223 std::pair<std::shared_ptr<BAYESNET>, std::shared_ptr<GRAPH> >
// NOTE(review): gttic is a macro defined elsewhere (presumably starts a timing section).
226 gttic(EliminationTree_eliminate);
// Output container for the first half of the returned pair.
228 auto result = std::make_shared<BayesNetType>();
// Second half of the returned pair: first the member remainingFactors_ is appended, then the
// local remainingFactors (a distinct variable - note the missing trailing underscore).
234 auto allRemainingFactors = std::make_shared<FactorGraphType>();
235 allRemainingFactors->push_back(remainingFactors_.begin(), remainingFactors_.end());
236 allRemainingFactors->push_back(remainingFactors.begin(), remainingFactors.end());
239 return {result, allRemainingFactors};
// Template header only; the definition that follows it is not present in this listing.
// NOTE(review): the index at the end of this listing places EliminationTree::print at listing
// line 244, so this header presumably belongs to print - confirm against the original file.
243 template<
class BAYESNET,
class GRAPH>
// Listing fragment of EliminationTree::equals (the index at the end of this listing places
// `bool equals(const This &other, double tol=1e-9) const` at listing line 251; the parameter is
// referred to as `expected` in the visible statements).
// The visible statements compare two forests using two explicit stacks, one per forest.
// NOTE(review): the declaration of `keys`, the popping of node1/node2 from the stacks, any
// clearing of `keys` between uses, and all return statements are not visible in this listing.
250 template<
class BAYESNET,
class GRAPH>
// stack1 holds nodes still to visit in this tree, stack2 those in `expected`.
254 std::stack<sharedNode, FastVector<sharedNode> > stack1, stack2;
// Roots of this tree are inserted into `keys` under their key, then pushed in `keys` iteration order.
// NOTE(review): presumably `keys` is an ordered map so both forests are visited in the same key
// order - confirm its type.
259 for(
const sharedNode& root: this->
roots_) { keys.emplace(root->key, root); }
261 for(
const Key_Node& key_node: keys) { stack1.push(key_node.second); }
// Same procedure for the roots of `expected`, feeding stack2.
265 for(
const sharedNode& root: expected.
roots_) { keys.emplace(root->key, root); }
267 for(
const Key_Node& key_node: keys) { stack2.push(key_node.second); }
// Process nodes pairwise while both stacks still have entries.
271 while(!stack1.empty() && !stack2.empty()) {
// Mismatch test: the two nodes carry different keys.
279 if(node1->key != node2->key)
// Mismatch test: the two nodes hold a different number of factors.
281 if(node1->factors.size() != node2->factors.size()) {
// Walk both factor lists in lockstep; only it1 is tested against its end.
284 for(
typename Node::Factors::const_iterator it1 = node1->factors.begin(), it2 = node2->factors.begin();
285 it1 != node1->factors.end(); ++it1, ++it2)
// Mismatch test: the factors differ by more than tol according to the factor's own equals().
288 if(!(*it1)->equals(**it2, tol))
// Mismatch test: exactly one of the two factor pointers is null.
290 }
else if((*it1 && !*it2) || (*it2 && !*it1)) {
// Children of node1 are pushed onto stack1 the same way the roots were.
299 for(
const sharedNode& node: node1->children) { keys.emplace(node->key, node); }
301 for(
const Key_Node& key_node: keys) { stack1.push(key_node.second); }
// Children of node2 are pushed onto stack2.
305 for(
const sharedNode& node: node2->children) { keys.emplace(node->key, node); }
307 for(
const Key_Node& key_node: keys) { stack2.push(key_node.second); }
// After the loop: tests whether either stack still has unvisited nodes.
312 if(!stack1.empty() || !stack2.empty())
// Listing fragment of EliminationTree::swap (the index at the end of this listing places
// `void swap(This &other)` at listing line 320).
// The only visible statement exchanges remainingFactors_ with other's.
// NOTE(review): the handling of roots_ is not visible in this listing.
319 template<
class BAYESNET,
class GRAPH>
322 remainingFactors_.swap(other.remainingFactors_);
FastVector< sharedNode > roots_
Definition: EliminationTree.h:86
Factors factors
factors associated with root
Definition: EliminationTree.h:71
bool equals(const This &other, double tol=1e-9) const
Definition: EliminationTree-inst.h:251
GRAPH FactorGraphType
The factor graph type.
Definition: EliminationTree.h:58
FastVector< FactorIndex > FactorIndices
Define collection types:
Definition: Factor.h:36
std::vector< T, typename internal::FastDefaultVectorAllocator< T >::type > FastVector
Definition: FastVector.h:34
void swap(This &other)
Definition: EliminationTree-inst.h:320
Contains generic inference algorithms that convert between templated graphical models, i.e., factor graphs, Bayes nets, and Bayes trees.
Variable ordering for the elimination algorithm.
Definition: Ordering.h:37
Children children
sub-trees
Definition: EliminationTree.h:72
EliminationTree()
Protected default constructor.
Definition: EliminationTree.h:162
size_t size() const
Definition: FactorGraph.h:334
void PrintForest(const FOREST &forest, std::string str, const KeyFormatter &keyFormatter)
Definition: treeTraversal-inst.h:219
Key key
key associated with root
Definition: EliminationTree.h:70
const FastVector< sharedFactor > & remainingFactors() const
Definition: EliminationTree.h:155
std::function< std::string(Key)> KeyFormatter
Typedef for a function to format a key, i.e. to convert it to a string.
Definition: Key.h:35
FastVector< std::shared_ptr< typename FOREST::Node > > CloneForest(const FOREST &forest)
Definition: treeTraversal-inst.h:189
Definition: chartTesting.h:28
FastVector< Key > KeyVector
Define collection type once and for all - also used in wrappers.
Definition: Key.h:86
~EliminationTree()
Definition: EliminationTree-inst.h:192
Definition: VariableIndex.h:41
This & operator=(const This &other)
Definition: EliminationTree-inst.h:172
std::pair< std::shared_ptr< BayesNetType >, std::shared_ptr< FactorGraphType > > eliminate(Eliminate function) const
Definition: EliminationTree-inst.h:224
std::shared_ptr< FactorType > sharedFactor
Shared pointer to a factor.
Definition: EliminationTree.h:60
std::shared_ptr< Node > sharedNode
Shared pointer to Node.
Definition: EliminationTree.h:80
Definition: GaussianFactorGraph.h:73
void print(const std::string &name="EliminationTree: ", const KeyFormatter &formatter=DefaultKeyFormatter) const
Definition: EliminationTree-inst.h:244