35#ifndef VIGRA_RF_ALGORITHM_HXX
36#define VIGRA_RF_ALGORITHM_HXX
66 int columnCount = std::distance(b, e);
67 int rowCount =
in.shape(0);
70 for(Iter iter = b; iter != e; ++iter, ++
ii)
72 columnVector(
out,
ii) = columnVector(
in, *iter);
100 template<
class Feature_t,
class Response_t>
109 return oob.oob_breiman;
125 typedef std::vector<int> FeatureList_t;
126 typedef std::vector<double> ErrorList_t;
127 typedef FeatureList_t::iterator Pivot_t;
166 vigra_precondition(std::distance(b, e) ==
static_cast<std::ptrdiff_t
>(
selected.size()),
167 "Number of features in ranking != number of features matrix");
223 std::map<typename ResponseT::value_type, int>
res_map;
224 std::vector<int>
cts;
226 for(
int ii = 0;
ii < response.shape(0); ++
ii)
238 /
double(response.shape(0));
293template<
class FeatureT,
class ResponseT,
class ErrorRateCallBack>
299 VariableSelectionResult::FeatureList_t & selected = result.
selected;
300 VariableSelectionResult::ErrorList_t & errors = result.
errors;
301 VariableSelectionResult::Pivot_t & pivot = result.pivot;
302 int featureCount = features.shape(1);
308 vigra_precondition(
static_cast<int>(selected.size()) == featureCount,
309 "forward_selection(): Number of features in Feature "
310 "matrix and number of features in previously used "
311 "result struct mismatch!");
319 VariableSelectionResult::Pivot_t next = pivot;
322 std::swap(*pivot, *next);
324 detail::choose( features,
330 std::swap(*pivot, *next);
336 std::advance(next, pos);
337 std::swap(*pivot, *next);
338 errors[std::distance(selected.begin(), pivot)] =
current_errors[pos];
341 std::cerr <<
"Choosing " << *pivot <<
" at error of " <<
current_errors[pos] << std::endl;
347template<
class FeatureT,
class ResponseT>
350 VariableSelectionResult & result)
395template<
class FeatureT,
class ResponseT,
class ErrorRateCallBack>
401 int featureCount = features.shape(1);
402 VariableSelectionResult::FeatureList_t & selected = result.
selected;
403 VariableSelectionResult::ErrorList_t & errors = result.
errors;
404 VariableSelectionResult::Pivot_t & pivot = result.pivot;
411 vigra_precondition(
static_cast<int>(selected.size()) == featureCount,
412 "backward_elimination(): Number of features in Feature "
413 "matrix and number of features in previously used "
414 "result struct mismatch!");
416 pivot = selected.end() - 1;
421 VariableSelectionResult::Pivot_t next = selected.begin();
425 std::swap(*pivot, *next);
427 detail::choose( features,
433 std::swap(*pivot, *next);
438 next = selected.begin();
439 std::advance(next, pos);
440 std::swap(*pivot, *next);
442 errors[std::distance(selected.begin(), pivot)-1] =
current_errors[pos];
446 std::cerr <<
"Eliminating " << *pivot <<
" at error of " <<
current_errors[pos] << std::endl;
452template<
class FeatureT,
class ResponseT>
455 VariableSelectionResult & result)
492template<
class FeatureT,
class ResponseT,
class ErrorRateCallBack>
498 VariableSelectionResult::FeatureList_t & selected = result.
selected;
499 VariableSelectionResult::ErrorList_t & errors = result.
errors;
500 VariableSelectionResult::Pivot_t & iter = result.pivot;
501 int featureCount = features.shape(1);
507 vigra_precondition(
static_cast<int>(selected.size()) == featureCount,
508 "forward_selection(): Number of features in Feature "
509 "matrix and number of features in previously used "
510 "result struct mismatch!");
514 for(; iter != selected.end(); ++iter)
518 detail::choose( features,
523 errors[std::distance(selected.begin(), iter)] =
error;
525 std::copy(selected.begin(), iter+1, std::ostream_iterator<int>(std::cerr,
", "));
526 std::cerr <<
"Choosing " << *(iter+1) <<
" at error of " <<
error << std::endl;
532template<
class FeatureT,
class ResponseT>
535 VariableSelectionResult & result)
542enum ClusterLeafTypes{c_Leaf = 95, c_Node = 99};
558 ClusterNode(
int nCol,
559 BT::T_Container_type & topology,
560 BT::P_Container_type & split_param)
561 : BT(nCol + 5, 5,topology, split_param)
571 ClusterNode( BT::T_Container_type
const & topology,
572 BT::P_Container_type
const & split_param,
574 :
NodeBase(5 , 5,topology, split_param, n)
580 ClusterNode( BT & node_)
585 BT::parameter_size_ += 0;
591 void set_index(
int in)
618 : parent(p), level(
l), addr(a), infm(
in)
647 double dist_func(
double a,
double b)
649 return std::min(a, b);
655 template<
class Functor>
659 std::vector<int>
stack;
660 stack.push_back(begin_addr);
661 while(!
stack.empty())
663 ClusterNode node(topology_, parameters_,
stack.back());
667 if(node.columns_size() != 1)
669 stack.push_back(node.child(0));
670 stack.push_back(node.child(1));
678 template<
class Functor>
682 std::queue<HC_Entry>
queue;
688 while(!
queue.empty())
690 level =
queue.front().level;
691 parent =
queue.front().parent;
692 addr =
queue.front().addr;
693 infm =
queue.front().infm;
694 ClusterNode node(topology_, parameters_,
queue.front().addr);
698 parnt = ClusterNode(topology_, parameters_, parent);
702 if(node.columns_size() != 1)
712 void save(std::string
file, std::string
prefix)
717 Shp(topology_.
size(),1),
721 Shp(parameters_.
size(), 1),
722 parameters_.data()));
732 template<
class T,
class C>
736 std::vector<std::pair<int, int> > addr;
738 for(
int ii = 0;
ii < distance.shape(0); ++
ii)
740 addr.push_back(std::make_pair(topology_.
size(),
ii));
741 ClusterNode
leaf(1, topology_, parameters_);
742 leaf.set_index(index);
744 leaf.columns_begin()[0] =
ii;
747 while(addr.size() != 1)
753 (addr.begin()+
jj_min)->second);
754 for(
unsigned int ii = 0;
ii < addr.
size(); ++
ii)
759 (addr.begin()+
jj_min)->second)
760 >
dist((addr.begin()+
ii)->second,
761 (addr.begin()+
jj)->second))
764 (addr.begin()+
jj)->second);
778 (addr.begin() +
ii_min)->first);
781 (addr.begin() +
jj_min)->first);
792 (addr.begin() +
ii_min)->first);
795 (addr.begin() +
jj_min)->first);
796 parent.parameters_begin()[0] =
min_dist;
797 parent.set_index(index);
801 parent.columns_begin());
805 if(*parent.columns_begin() == *
firstChild.columns_begin())
807 parent.child(0) = (addr.begin()+
ii_min)->first;
808 parent.child(1) = (addr.begin()+
jj_min)->first;
812 addr.erase(addr.begin()+
jj_min);
816 parent.child(1) = (addr.begin()+
ii_min)->first;
817 parent.child(0) = (addr.begin()+
jj_min)->first;
821 addr.erase(addr.begin()+
ii_min);
829 double bla = dist_func(
832 (addr.begin()+
jj)->second));
835 (addr.begin()+
jj)->second) =
bla;
836 dist((addr.begin()+
jj)->second,
858 bool operator()(Node& node)
871template<
class Iter,
class DT>
886 template<
class Feat_T,
class Label_T>
895 :tmp_mem_(_spl(a, b).size(),
feats.shape(1)),
898 feats_(_spl(a,b).size(),
feats.shape(1)),
899 labels_(_spl(a,b).size(),1),
905 copy_splice(_spl(a,b),
906 _spl(
feats.shape(1)),
909 copy_splice(_spl(a,b),
910 _spl(
labls.shape(1)),
916 bool operator()(Node& node)
920 int class_count = perm_imp.shape(1) - 1;
922 for(
int kk = 0;
kk < nPerm; ++
kk)
925 for(
int ii = 0;
ii < rowCount(feats_); ++
ii)
927 int index = random.uniformInt(rowCount(feats_) -
ii) +
ii;
928 for(
int jj = 0;
jj < node.columns_size(); ++
jj)
930 if(node.columns_begin()[
jj] != feats_.shape(1))
931 tmp_mem_(
ii, node.columns_begin()[
jj])
932 = tmp_mem_(index, node.columns_begin()[
jj]);
936 for(
int ii = 0;
ii < rowCount(tmp_mem_); ++
ii)
939 .predictLabel(rowVector(tmp_mem_,
ii))
943 ++perm_imp(index,labels_(
ii, 0));
945 ++perm_imp(index, class_count);
975 void save(std::string
file, std::string
prefix)
983 bool operator()(Node& node)
985 for(
int ii = 0;
ii < node.columns_size(); ++
ii)
1000 bool operator()(
Nde &
cur,
int ,
Nde parent,
bool )
1003 cur.status() = std::min(parent.status(),
cur.status());
1030 std::ofstream graphviz;
1035 std::string
const gz)
1036 :features_(features), labels_(labels),
1037 graphviz(
gz.c_str(), std::ios::out)
1039 graphviz <<
"digraph G\n{\n node [shape=\"record\"]";
1043 graphviz <<
"\n}\n";
1048 bool operator()(
Nde &
cur,
int ,
Nde parent,
bool )
1050 graphviz <<
"node" <<
cur.index() <<
" [style=\"filled\"][label = \" #Feats: "<<
cur.columns_size() <<
"\\n";
1051 graphviz <<
" status: " <<
cur.status() <<
"\\n";
1052 for(
int kk = 0;
kk <
cur.columns_size(); ++
kk)
1054 graphviz <<
cur.columns_begin()[
kk] <<
" ";
1058 graphviz <<
"\"] [color = \"" <<
cur.status() <<
" 1.000 1.000\"];\n";
1060 graphviz <<
"\"node" << parent.index() <<
"\" -> \"node" <<
cur.index() <<
"\";\n";
1080 int repetition_count_;
1086 void save(std::string filename, std::string
prefix)
1106 template<
class RF,
class PR>
1109 Int32 const class_count = rf.ext_param_.class_count_;
1110 Int32 const column_count = rf.ext_param_.column_count_+1;
1131 template<
class RF,
class PR,
class SM,
class ST>
1135 Int32 column_count = rf.ext_param_.column_count_ +1;
1136 Int32 class_count = rf.ext_param_.class_count_;
1140 typename PR::Feature_t & features
1141 =
const_cast<typename PR::Feature_t &
>(
pr.features());
1145 ArrayVector<Int32>::iterator
1148 if(rf.ext_param_.actual_msample_ <
pr.features().shape(0)- 10000)
1152 for(
int ii = 0;
ii <
pr.features().shape(0); ++
ii)
1153 indices.push_back(
ii);
1154 std::random_shuffle(indices.
begin(), indices.
end());
1155 for(
int ii = 0;
ii < rf.ext_param_.row_count_; ++
ii)
1157 if(!
sm.is_used()[indices[
ii]] &&
cts[
pr.response()(indices[
ii], 0)] < 3000)
1160 ++
cts[
pr.response()(indices[
ii], 0)];
1166 for(
int ii = 0;
ii < rf.ext_param_.row_count_; ++
ii)
1167 if(!
sm.is_used()[
ii])
1186 .predictLabel(rowVector(features, *iter))
1187 ==
pr.response()(*iter, 0))
1223 template<
class RF,
class PR,
class SM,
class ST>
1231 template<
class RF,
class PR>
1271template<
class FeatureT,
class ResponseT>
1279 opt.tree_count(100);
1280 if(features.shape(0) > 40000)
1281 opt.samples_per_tree(20000).use_stratification(RF_EQUAL);
1287 RF.learn(features, response,
1289 distance =
missc.distance;
1316template<
class FeatureT,
class ResponseT>
1326template<
class Array1,
class Vector1>
1329 std::map<double, int>
mymap;
1332 for(std::map<double, int>::reverse_iterator iter =
mymap.rbegin(); iter!=
mymap.rend(); ++iter)
1334 out.push_back(iter->second);
void reshape(const difference_type &shape)
Definition multi_array.hxx:2861
Topology_type column_data() const
Definition rf_nodeproxy.hxx:159
INT & typeID()
Definition rf_nodeproxy.hxx:136
NodeBase()
Definition rf_nodeproxy.hxx:237
Parameter_type parameters_begin() const
Definition rf_nodeproxy.hxx:207
Class for a single RGB value.
Definition rgbvalue.hxx:128
Options object for the random forest.
Definition rf_common.hxx:171
size_type size() const
Definition tinyvector.hxx:913
iterator end()
Definition tinyvector.hxx:864
iterator begin()
Definition tinyvector.hxx:861
Class for fixed size vectors.
Definition tinyvector.hxx:1008
Definition rf_algorithm.hxx:1068
MultiArray< 2, double > cluster_importance_
Definition rf_algorithm.hxx:1076
MultiArray< 2, int > variables
Definition rf_algorithm.hxx:1073
void visit_at_end(RF &rf, PR &)
Definition rf_algorithm.hxx:1232
void visit_after_tree(RF &rf, PR &pr, SM &sm, ST &st, int index)
Definition rf_algorithm.hxx:1224
MultiArray< 2, double > cluster_stdev_
Definition rf_algorithm.hxx:1079
void after_tree_ip_impl(RF &rf, PR &pr, SM &sm, ST &, int index)
Definition rf_algorithm.hxx:1132
void visit_at_beginning(RF const &rf, PR const &)
Definition rf_algorithm.hxx:1107
Definition rf_algorithm.hxx:997
Definition rf_algorithm.hxx:1025
Definition rf_algorithm.hxx:964
MultiArrayView< 2, int > variables
Definition rf_algorithm.hxx:969
Definition rf_algorithm.hxx:639
void iterate(Functor &tester)
Definition rf_algorithm.hxx:656
void cluster(MultiArrayView< 2, T, C > distance)
Definition rf_algorithm.hxx:733
void breadth_first_traversal(Functor &tester)
Definition rf_algorithm.hxx:679
Definition rf_algorithm.hxx:848
NormalizeStatus(double m)
Definition rf_algorithm.hxx:854
Definition rf_algorithm.hxx:873
Definition rf_algorithm.hxx:84
double operator()(Feature_t const &features, Response_t const &response)
Definition rf_algorithm.hxx:101
RFErrorCallback(RandomForestOptions opt=RandomForestOptions())
Definition rf_algorithm.hxx:93
Definition rf_algorithm.hxx:117
double no_features
Definition rf_algorithm.hxx:151
ErrorList_t errors
Definition rf_algorithm.hxx:146
FeatureList_t selected
Definition rf_algorithm.hxx:133
bool init(FeatureT const &all_features, ResponseT const &response, ErrorRateCallBack errorcallback)
Definition rf_algorithm.hxx:205
Definition rf_visitors.hxx:1496
Definition rf_visitors.hxx:864
Definition rf_visitors.hxx:1460
Definition rf_visitors.hxx:102
void backward_elimination(FeatureT const &features, ResponseT const &response, VariableSelectionResult &result, ErrorRateCallBack errorcallback)
Definition rf_algorithm.hxx:396
void rank_selection(FeatureT const &features, ResponseT const &response, VariableSelectionResult &result, ErrorRateCallBack errorcallback)
Definition rf_algorithm.hxx:493
void forward_selection(FeatureT const &features, ResponseT const &response, VariableSelectionResult &result, ErrorRateCallBack errorcallback)
Definition rf_algorithm.hxx:294
void cluster_permutation_importance(FeatureT const &features, ResponseT const &response, HClustering &linkage, MultiArray< 2, double > &distance)
Definition rf_algorithm.hxx:1272
detail::VisitorNode< A > create_visitor(A &a)
Definition rf_visitors.hxx:344
void writeHDF5(...)
Store array data in an HDF5 file.
detail::SelectIntegerType< 32, detail::SignedIntTypes >::type Int32
32-bit signed int
Definition sized_int.hxx:175
Definition metaprogramming.hxx:123
Definition rf_algorithm.hxx:612