114 #include <opencv2/opencv.hpp>
116 #include <yarp/os/all.h>
117 #include <yarp/sig/all.h>
118 #include <yarp/cv/Cv.h>
120 #include <iCub/learningMachine/FixedRangeScaler.h>
121 #include <iCub/learningMachine/IMachineLearner.h>
122 #include <iCub/learningMachine/LSSVMLearner.h>
124 #define DEFAULT_STEP 1.0
127 using namespace yarp::os;
128 using namespace yarp::sig;
129 using namespace yarp::cv;
130 using namespace iCub::learningmachine;
// KarmaLearn: RFModule that keeps one scalar regression model (created by
// createLearner()) per named "item" and serves the train, predict, span,
// optimize, items, machine, clear, save and plot commands in respond().
// NOTE(review): this text looks like a line-numbered source listing in which
// the leading numbers were fused into the code and blank, brace-only and
// some statement lines were dropped; it cannot compile until the original
// source is restored.
134 class KarmaLearn:
public RFModule
// Maps raw inputs from [in_lb,in_ub] onto [0,1] (bounds set in configure()).
137 FixedRangeScaler scalerIn;
// Maps raw outputs from [out_lb,out_ub] onto [0,1] (bounds set in configure()).
138 FixedRangeScaler scalerOut;
// Item name -> learner; entries are allocated with new by createLearner().
// NOTE(review): raw owning pointers - confirm every entry is deleted when an
// item is cleared and when the module shuts down.
139 map<string,IMachineLearner*> machines;
// Configuration file path, resolved from the "from" option in configure();
// the item list and the serialized learners are also written to this path.
142 string configFileName;
// Port "/<name>/plot:o" that publishes a mono image of the selected item.
149 BufferedPort<ImageOf<PixelMono> > plotPort;
// Allocates a fresh learner for one item: an LSSVMLearner configured for a
// scalar input (domain size 1) and a scalar output (codomain size 1), with
// its kernel gamma set to 10.0. Callers (train(), configure()) store the
// result in `machines`.
// NOTE(review): some lines of this function, including its return
// statement, are not visible in this listing.
152 IMachineLearner *createLearner()
154 IMachineLearner *learner=
new LSSVMLearner;
// The object was just created as an LSSVMLearner, so this downcast cannot
// fail; it only exposes the LSSVM-specific setters used below.
155 LSSVMLearner *lssvm=
dynamic_cast<LSSVMLearner*
>(learner);
156 lssvm->setDomainSize(1);
157 lssvm->setCoDomainSize(1);
159 lssvm->getKernel()->setGamma(10.0);
// Presumably returns in `min`/`max` the extremes of the values in `b` (each
// read with asDouble()); the plotting code uses it to size the vertical axis.
// The search is seeded "inverted": min starts at the output upper bound and
// max at the output lower bound.
// NOTE(review): the comparisons inside the loop are not visible in this
// listing. If nothing after the loop touches min/max, an empty `b` yields
// min > max whenever out_lb < out_ub.
165 void extractMinMax(
const Bottle &b,
double &min,
double &max)
167 min=scalerOut.getUpperBoundIn();
168 max=scalerOut.getLowerBoundIn();
170 for (
int i=0; i<b.size(); i++)
172 double bi=b.get(i).asDouble();
// Feeds one (input, output) sample to the learner of `item`. On first use
// the learner is created with createLearner() and registered in `machines`.
// NOTE(review): the branch that picks up an existing learner, and anything
// after feedSample() (e.g. a training call), are not visible in this listing.
183 void train(
const string &item,
const double input,
const double output)
185 IMachineLearner *learner;
186 map<string,IMachineLearner*>::const_iterator itr=machines.find(item);
187 if (itr==machines.end())
189 learner=createLearner();
190 machines[item]=learner;
195 Vector in(1,input),out(1,output);
// Clamp the target to the configured output upper bound before scaling.
// NOTE(review): there is no matching clamp at out_lb, so targets below the
// output range reach scalerOut.transform() unclamped - confirm intent.
196 out[0]=std::min(out[0],scalerOut.getUpperBoundIn());
// The learner only sees values normalized into the scalers' [0,1] range.
198 in[0]=scalerIn.transform(in[0]);
199 out[0]=scalerOut.transform(out[0]);
201 learner->feedSample(in,out);
// Evaluates the learner of `item` at every value in `input`. For each value
// it does the following:
//   - scales the value with scalerIn and predicts;
//   - appends the prediction, mapped back through scalerOut.unTransform(),
//     to `output`;
//   - appends one entry to `variance`: the learner's variance if it reports
//     one, otherwise -1.0.
// Callers use the return value as a success flag.
// NOTE(review): the variance is appended as the learner reports it (scaled
// output space), whereas the mean is unTransformed - confirm callers expect
// this. The `variance` parameter declaration and the return / unknown-item
// paths are not visible in this listing.
206 bool predict(
const string &item,
const Bottle &input, Bottle &output,
209 map<string,IMachineLearner*>::const_iterator itr=machines.find(item);
210 if (itr!=machines.end())
214 for (
int i=0; i<input.size(); i++)
216 Vector in(1,input.get(i).asDouble());
217 in[0]=scalerIn.transform(in[0]);
219 IMachineLearner *learner=itr->second;
220 Prediction prediction=learner->predict(in);
222 Vector v=prediction.getPrediction();
223 output.addDouble(scalerOut.unTransform(v[0]));
225 if (prediction.hasVariance())
227 Vector v=prediction.getVariance();
228 variance.addDouble(v[0]);
231 variance.addDouble(-1.0);
// Scans `searchDomain` for the input whose predicted output for `item` is
// best (presumably the maximum, given the name `maxOut`). The chosen input
// is returned through `input` and its unscaled predicted output through
// `output`. The result is seeded with in_lb and its prediction (via
// predict()); each candidate is then scaled and evaluated directly on the
// learner.
// NOTE(review): maxOut is seeded with scalerOut's lower *output* bound (0.0
// in the scaled space) rather than with the scaled prediction at in_lb. The
// seed is therefore not compared on equal terms with the candidates, and
// in_lb may be returned even though it is not in searchDomain. The
// comparison/update lines, the `output` parameter declaration and the
// return paths are not visible in this listing.
241 bool optimize(
const string &item,
const Bottle &searchDomain,
double &input,
244 map<string,IMachineLearner*>::const_iterator itr=machines.find(item);
245 if (itr!=machines.end())
247 IMachineLearner *learner=itr->second;
// Seed: the lowest input and its predicted (unscaled) output.
249 input=scalerIn.getLowerBoundIn();
250 double maxOut=scalerOut.getLowerBoundOut();
252 Bottle bin,bout,bdummy;
253 bin.addDouble(input);
254 predict(item,bin,bout,bdummy);
255 output=bout.get(0).asDouble();
257 for (
int i=0; i<searchDomain.size(); i++)
259 double val=searchDomain.get(i).asDouble();
260 Vector in(1,scalerIn.transform(val));
261 Prediction prediction=learner->predict(in);
262 Vector v=prediction.getPrediction();
267 output=scalerOut.unTransform(v[0]);
282 for (map<string,IMachineLearner*>::const_iterator itr=machines.begin(); itr!=machines.end(); itr++)
283 ret.addString(itr->first);
// Stores the toString() serialization of the learner for `item` in
// `content`. The return statements are not visible in this listing;
// respond() treats the result as a found / not-found flag.
// NOTE(review): the lookup does not modify the map, so a const_iterator
// would suffice.
289 bool machineContent(
const string &item,
string &content)
291 map<string,IMachineLearner*>::iterator itr=machines.find(item);
292 if (itr!=machines.end())
294 content=itr->second->toString();
304 for (map<string,IMachineLearner*>::const_iterator itr=machines.begin(); itr!=machines.end(); itr++)
// Looks up the learner registered for `item`.
// NOTE(review): the statements that act on the found entry (presumably
// deleting the learner and erasing it from `machines`) and the return values
// are not visible in this listing.
312 bool clear(
const string &item)
314 map<string,IMachineLearner*>::iterator itr=machines.find(item);
315 if (itr!=machines.end())
332 ofstream fout(configFileName.c_str());
334 fout<<
"[general]"<<endl;
335 fout<<
"name "<<name<<endl;
336 fout<<
"num_items "<<machines.size()<<endl;
337 fout<<
"in_lb "<<scalerIn.getLowerBoundIn()<<endl;
338 fout<<
"in_ub "<<scalerIn.getUpperBoundIn()<<endl;
339 fout<<
"out_lb "<<scalerOut.getLowerBoundIn()<<endl;
340 fout<<
"out_ub "<<scalerOut.getUpperBoundIn()<<endl;
344 for (map<string,IMachineLearner*>::const_iterator itr=machines.begin(); itr!=machines.end(); itr++, i++)
346 fout<<
"[item_"<<i<<
"]"<<endl;
347 fout<<
"name "<<itr->first<<endl;
348 fout<<
"learner "<<
"("+itr->second->toString()+
")"<<endl;
358 if ((plotItem!=
"") && (plotPort.getOutputCount()>0))
360 Bottle input,output,variance;
361 for (
double d=scalerIn.getLowerBoundIn(); d<scalerIn.getUpperBoundIn(); d+=plotStep)
364 if (predict(plotItem,input,output,variance))
366 ImageOf<PixelMono> &img=plotPort.prepare();
368 cv::Mat imgMat=toCvMat(img);
369 imgMat.setTo(cv::Scalar(255));
371 cv::putText(imgMat,plotItem,cv::Point(250,20),cv::FONT_HERSHEY_SIMPLEX,0.5,cv::Scalar(0));
373 double x_min=scalerIn.getLowerBoundIn();
374 double x_max=scalerIn.getUpperBoundIn();
375 double x_range=x_max-x_min;
378 extractMinMax(output,y_min,y_max);
379 y_min*=(y_min>0.0?0.8:1.2);
380 y_max*=(y_max>0.0?1.2:0.8);
381 double y_range=y_max-y_min;
384 ostringstream tag; tag.precision(3);
386 cv::putText(imgMat,tag.str(),cv::Point(10,230),cv::FONT_HERSHEY_SIMPLEX,0.5,cv::Scalar(0));
390 ostringstream tag; tag.precision(3);
392 cv::putText(imgMat,tag.str(),cv::Point(280,230),cv::FONT_HERSHEY_SIMPLEX,0.5,cv::Scalar(0));
396 ostringstream tag; tag.precision(3);
398 cv::putText(imgMat,tag.str(),cv::Point(10,215),cv::FONT_HERSHEY_SIMPLEX,0.5,cv::Scalar(0));
402 ostringstream tag; tag.precision(3);
404 cv::putText(imgMat,tag.str(),cv::Point(10,20),cv::FONT_HERSHEY_SIMPLEX,0.5,cv::Scalar(0));
408 for (
int i=0; i<input.size(); i++)
411 p.x=int((img.width()/x_range)*(input.get(i).asDouble()-x_min));
412 p.y=img.height()-int((img.height()/y_range)*(output.get(i).asDouble()-y_min));
415 cv::line(imgMat,p,pold,cv::Scalar(0),2);
// RPC handler; the whole request is served while holding `mtx`.
// The first element is a vocab selecting the command; the rest is its
// payload:
//   train <item> <in> <out>          -> ack | nack
//   predict <item> <in>              -> ack <out> <variance> | nack
//   predict <item> (<in> ...)        -> ack (<out> ...) (<variance> ...) | nack
//   span <item> [<step>]             -> ack (<out> ...) (<variance> ...) | nack
//   optimize <item> [<step> | (...)] -> ack <in> <out> | nack
//   items                            -> ack <item names...>
//   machine <item>                   -> ack <serialized learner> | nack
//   clear [<item>]                   -> ack | nack
//   save                             -> ack
//   plot <item> [<step>]             -> ack | nack
// Unrecognized or empty commands presumably get the final nack replies.
// NOTE(review): the grid loops in span/optimize, and the plotting loop that
// uses plotStep, only advance when the requested step is > 0. A zero or
// negative step (presumably also a non-numeric one, which asDouble() would
// read as 0) makes them loop forever; for span/optimize this happens while
// `mtx` is held. Validate step > 0 before use.
426 bool respond(
const Bottle &command, Bottle &reply)
428 lock_guard<mutex> lg(mtx);
429 if (command.size()>=1)
431 int header=command.get(0).asVocab();
432 Bottle payload=command.tail();
// train: add one sample; the numeric fields are read with asDouble().
433 if (header==Vocab::encode(
"train"))
435 if (payload.size()>=3)
437 string item=payload.get(0).asString();
438 double input=payload.get(1).asDouble();
439 double output=payload.get(2).asDouble();
441 train(item,input,output);
442 reply.addVocab(Vocab::encode(
"ack"));
448 reply.addVocab(Vocab::encode(
"nack"));
// predict: the argument is either a scalar (isDouble()) or a list of inputs.
// NOTE(review): an integer-typed scalar presumably fails isDouble() and falls
// through to nack - confirm whether integer inputs should be accepted.
450 else if (header==Vocab::encode(
"predict"))
452 if (payload.size()>=2)
454 Bottle output,variance;
455 string item=payload.get(0).asString();
456 if (payload.get(1).isDouble())
458 Bottle input; input.addDouble(payload.get(1).asDouble());
459 if (predict(item,input,output,variance))
461 reply.addVocab(Vocab::encode(
"ack"));
462 reply.addDouble(output.get(0).asDouble());
463 reply.addDouble(variance.get(0).asDouble());
466 reply.addVocab(Vocab::encode(
"nack"));
468 else if (payload.get(1).isList())
470 if (predict(item,*payload.get(1).asList(),output,variance))
472 reply.addVocab(Vocab::encode(
"ack"));
473 reply.addList().append(output);
474 reply.addList().append(variance);
477 reply.addVocab(Vocab::encode(
"nack"));
480 reply.addVocab(Vocab::encode(
"nack"));
483 reply.addVocab(Vocab::encode(
"nack"));
// span: evaluate the item on the grid in_lb, in_lb+step, ... while < in_ub
// (the statement that adds d to `input` is not visible in this listing).
485 else if (header==Vocab::encode(
"span"))
487 if (payload.size()>=1)
489 double step=DEFAULT_STEP;
490 string item=payload.get(0).asString();
491 if (payload.size()>=2)
492 step=payload.get(1).asDouble();
494 Bottle input,output,variance;
495 for (
double d=scalerIn.getLowerBoundIn(); d<scalerIn.getUpperBoundIn(); d+=step)
498 if (predict(item,input,output,variance))
500 reply.addVocab(Vocab::encode(
"ack"));
501 reply.addList().append(output);
502 reply.addList().append(variance);
505 reply.addVocab(Vocab::encode(
"nack"));
508 reply.addVocab(Vocab::encode(
"nack"));
// optimize: a numeric argument sets the grid step, a list replaces the
// search domain. An absent or empty list falls back to the grid in_lb,
// in_lb+step, ... < in_ub. (The declarations of searchDomain, input and
// output are not visible in this listing.)
510 else if (header==Vocab::encode(
"optimize"))
512 if (payload.size()>=1)
515 double step=DEFAULT_STEP;
517 string item=payload.get(0).asString();
518 if (payload.size()>=2)
520 if (payload.get(1).isDouble())
521 step=payload.get(1).asDouble();
522 else if (payload.get(1).isList())
523 searchDomain=*payload.get(1).asList();
526 if (searchDomain.size()==0)
527 for (
double d=scalerIn.getLowerBoundIn(); d<scalerIn.getUpperBoundIn(); d+=step)
528 searchDomain.addDouble(d);
531 if (optimize(item,searchDomain,input,output))
533 reply.addVocab(Vocab::encode(
"ack"));
534 reply.addDouble(input);
535 reply.addDouble(output);
538 reply.addVocab(Vocab::encode(
"nack"));
541 reply.addVocab(Vocab::encode(
"nack"));
// items: the names of all known items.
543 else if (header==Vocab::encode(
"items"))
545 reply.addVocab(Vocab::encode(
"ack"));
546 reply.append(items());
// machine: the serialized learner of one item (the declaration of `content`
// is not visible in this listing).
548 else if (header==Vocab::encode(
"machine"))
550 if (payload.size()>=1)
552 string item=payload.get(0).asString();
554 if (machineContent(item,content))
556 reply.addVocab(Vocab::encode(
"ack"));
557 reply.addString(content);
560 reply.addVocab(Vocab::encode(
"nack"));
563 reply.addVocab(Vocab::encode(
"nack"));
// clear: the statements that perform the clearing (with and without an item
// name) are not visible in this listing.
565 else if (header==Vocab::encode(
"clear"))
567 if (payload.size()>=1)
569 string item=payload.get(0).asString();
571 reply.addVocab(Vocab::encode(
"ack"));
573 reply.addVocab(Vocab::encode(
"nack"));
578 reply.addVocab(Vocab::encode(
"ack"));
// save: the statements before the ack are not visible; presumably the items
// are written to configFileName.
581 else if (header==Vocab::encode(
"save"))
584 reply.addVocab(Vocab::encode(
"ack"));
// plot: select the item published on the plot port and optionally change
// plotStep. NOTE(review): plotStep is overwritten before the item is checked,
// so a request naming an unknown item still changes it.
586 else if (header==Vocab::encode(
"plot"))
588 if (payload.size()>=1)
590 string item=payload.get(0).asString();
591 if (payload.size()>=2)
592 plotStep=payload.get(1).asDouble();
594 if (machines.find(item)!=machines.end())
597 reply.addVocab(Vocab::encode(
"ack"));
600 reply.addVocab(Vocab::encode(
"nack"));
603 reply.addVocab(Vocab::encode(
"nack"));
// Presumably: unrecognized command.
606 reply.addVocab(Vocab::encode(
"nack"));
// Presumably: empty command.
609 reply.addVocab(Vocab::encode(
"nack"));
// Module setup. It does the following, in order:
//   - reads the [general] group, with these defaults: name ("karmaLearn"),
//     num_items (0), in_lb (0.0), in_ub (360.0), out_lb (0.0),
//     out_ub (2.0);
//   - makes scalerIn map [in_lb,in_ub] onto [0,1] and scalerOut map
//     [out_lb,out_ub] onto [0,1];
//   - restores num_items items from the groups item_0 ... item_<num_items-1>,
//     each with a "name" and an optional serialized "learner";
//   - remembers the path of the "from" configuration file;
//   - opens the ports /<name>/plot:o and /<name>/rpc.
// NOTE(review): the local declarations, the handling of a missing [general]
// group or item name, any further setup and the return value are not visible
// in this listing.
616 bool configure(ResourceFinder &rf)
627 Bottle &generalGroup=rf.findGroup(
"general");
628 if (!generalGroup.isNull())
630 name=generalGroup.check(
"name",Value(
"karmaLearn")).asString();
631 nItems=generalGroup.check(
"num_items",Value(0)).asInt();
632 in_lb=generalGroup.check(
"in_lb",Value(0.0)).asDouble();
633 in_ub=generalGroup.check(
"in_ub",Value(360.0)).asDouble();
634 out_lb=generalGroup.check(
"out_lb",Value(0.0)).asDouble();
635 out_ub=generalGroup.check(
"out_ub",Value(2.0)).asDouble();
// Both scalers normalize into [0,1]; the learners only see normalized values.
638 scalerIn.setLowerBoundIn(in_lb);
639 scalerIn.setUpperBoundIn(in_ub);
640 scalerIn.setLowerBoundOut(0.0);
641 scalerIn.setUpperBoundOut(1.0);
643 scalerOut.setLowerBoundIn(out_lb);
644 scalerOut.setUpperBoundIn(out_ub);
645 scalerOut.setLowerBoundOut(0.0);
646 scalerOut.setUpperBoundOut(1.0);
// Restore the previously saved items from the groups item_0, item_1, ...
649 for (
int i=0; i<nItems; i++)
651 ostringstream item; item<<
"item_"<<i;
652 Bottle &itemGroup=rf.findGroup(item.str());
653 if (!itemGroup.isNull())
655 if (!itemGroup.check(
"name"))
658 IMachineLearner *learner=createLearner();
// NOTE(review): if "learner" is present but is not a list, asList()
// presumably returns nullptr and the ->toString() call below would
// dereference it - confirm against yarp::os::Value::asList().
659 if (itemGroup.check(
"learner"))
660 learner->fromString(itemGroup.find(
"learner").asList()->toString());
662 machines[itemGroup.find(
"name").asString()]=learner;
667 configFileName=rf.findPath(
"from");
672 plotPort.open(
"/"+name+
"/plot:o");
673 rpcPort.open(
"/"+name+
"/rpc");
// RFModule shutdown hook: interrupts the plot port.
// NOTE(review): the rest of the body (e.g. interrupting the RPC port) and
// the return value are not visible in this listing.
680 bool interruptModule()
682 plotPort.interrupt();
706 lock_guard<mutex> lg(mtx);
// Program entry point. It does the following:
//   - prints a message and stops if the YARP network is not reachable;
//   - sets the ResourceFinder defaults (context "karma", config file
//     "karmaLearn.ini"), then lets rf.configure() parse argc/argv;
//   - returns the result of running a KarmaLearn module with that
//     ResourceFinder.
// NOTE(review): the declarations of `yarp` and `rf`, and the value returned
// on the no-network path, are not visible in this listing.
714 int main(
int argc,
char *argv[])
717 if (!yarp.checkNetwork())
719 printf(
"YARP server not available!\n");
724 rf.setDefaultContext(
"karma");
725 rf.setDefaultConfigFile(
"karmaLearn.ini");
726 rf.configure(argc,argv);
728 KarmaLearn karmaLearn;
729 return karmaLearn.runModule(rf);