/*
 *
 *    MultiBoost - Multi-purpose boosting package
 *
 *    Copyright (C) 2010   AppStat group
 *                         Laboratoire de l'Accelerateur Lineaire
 *                         Universite Paris-Sud, 11, CNRS
 *
 *    This file is part of the MultiBoost library
 *
 *    This library is free software; you can redistribute it
 *    and/or modify it under the terms of the GNU General Public
 *    License as published by the Free Software Foundation; either
 *    version 2.1 of the License, or (at your option) any later version.
 *
 *    This library is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 *    General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public
 *    License along with this library; if not, write to the Free Software
 *    Foundation, Inc., 51 Franklin St, 5th Floor, Boston, MA 02110-1301 USA
 *
 *    Contact: Balazs Kegl (balazs.kegl@gmail.com)
 *             Norman Casagrande (nova77@gmail.com)
 *             Robert Busa-Fekete (busarobi@gmail.com)
 *
 *    For more information and up-to-date version, please visit
 *
 *                       http://www.multiboost.org/
 *
 */

#include "TreeLearner3.h"
#include "IO/Serialization.h"
#include "Others/Example.h"
#include "Utils/StreamTokenizer.h"

// NOTE(review): the names of the system headers included here were lost
// (only bare #include directives remained). This list covers what the file
// visibly uses (string, vector) plus common numeric/container headers;
// confirm against the upstream file.
#include <cmath>
#include <limits>
#include <queue>
#include <string>
#include <vector>

namespace MultiBoost {

//REGISTER_LEARNER_NAME(Product, TreeLearner3)
REGISTER_LEARNER(TreeLearner3)

// -----------------------------------------------------------------------

/**
 * Declares the command-line arguments understood by TreeLearner3, on top of
 * the ones declared by BaseLearner.
 *
 * "baselearnertype" takes two values: the registered name of the learner used
 * for every tree node, and the number of such base learners to create. The
 * parameters of the chosen base learner must be given as well, because
 * initLearningOptions() forwards the same argument set to it.
 */
void TreeLearner3::declareArguments(nor_utils::Args& args)
{
    BaseLearner::declareArguments(args);

    args.declareArgument("baselearnertype",
                         "The name of the learner that serves as a basis for the tree nodes\n"
                         "  and the number of base learners (tree nodes)\n"
                         "  Don't forget to add its parameters\n",
                         2, "<baseLearnerType> <numBaseLearners>");
}

// ------------------------------------------------------------------------------

/**
 * Reads the two "baselearnertype" values, looks the base learner up in the
 * learner registry and creates _numBaseLearners fresh instances of it.
 *
 * The registered prototype and every created instance receive the full
 * argument set via initLearningOptions(). For each created learner a child
 * index pair {-1, -1} is appended to _idxPairs. classify() treats an index
 * <= 0 as "no child", so every node starts out as a leaf.
 */
void TreeLearner3::initLearningOptions(const nor_utils::Args& args)
{
    BaseLearner::initLearningOptions(args);

    string baseLearnerName;
    args.getValue("baselearnertype", 0, baseLearnerName);
    args.getValue("baselearnertype", 1, _numBaseLearners);

    // get the registered weak learner (type from name)
    BaseLearner* pWeakHypothesisSource =
        BaseLearner::RegisteredLearners().getLearner(baseLearnerName);
    pWeakHypothesisSource->initLearningOptions(args);

    for( int ib = 0; ib < _numBaseLearners; ++ib )
    {
        _baseLearners.push_back(pWeakHypothesisSource->create());
        // Initialise the learner just appended. Indexing with ib would only
        // hit it if _baseLearners was empty on entry; back() always does.
        _baseLearners.back()->initLearningOptions(args);

        // {left/positive child, right/non-positive child}; -1 = no child.
        _idxPairs.push_back( vector< int >( 2, -1 ) );
    }
}

// ------------------------------------------------------------------------------

/**
 * Classifies example idx of pData for class classIdx by walking the tree of
 * base learners, starting from node 0.
 *
 * At each node, the node's output for class 0 selects the branch. A positive
 * value follows _idxPairs[node][0]; anything else (zero, negative, NaN)
 * follows _idxPairs[node][1]. A child index <= 0 means there is no child: the
 * walk stops and the current node's output for classIdx is returned.
 *
 * NOTE(review): assumes _baseLearners is non-empty and that the child indices
 * form an acyclic structure rooted at 0; neither is checked here.
 */
float TreeLearner3::classify(InputData* pData, int idx, int classIdx)
{
    int ib = 0;
    for (;;)
    {
        // The routing decision always uses class 0, independent of classIdx.
        const float phix = _baseLearners[ib]->classify( pData, idx, 0 );
        const int childSlot = ( phix > 0 ) ? 0 : 1;
        const int child = _idxPairs[ ib ][ childSlot ];

        if ( child <= 0 )
            return _baseLearners[ib]->classify( pData, idx, classIdx );

        ib = child;
    }
}

//-------------------------------------------------------------------------------

float TreeLearner3::getEdge( BaseLearner* learner, InputData* d )
{
    float edge = 0.0;
    for( int i = 0; i < d->getNumExamples(); i++ )
    {
        vector< Label > l = d->getLabels( i );
        //cout << d->getRawIndex( i ) << " " << endl;
        for( vector