karma
All Modules
main.cpp
1 /*
2  * Copyright (C) 2012 Department of Robotics Brain and Cognitive Sciences - Istituto Italiano di Tecnologia
3  * Author: Ugo Pattacini
4  * email: ugo.pattacini@iit.it
5  * Permission is granted to copy, distribute, and/or modify this program
6  * under the terms of the GNU General Public License, version 2 or any
7  * later version published by the Free Software Foundation.
8  *
9  * A copy of the license can be found at
10  * http://www.robotcub.org/icub/license/gpl.txt
11  *
12  * This program is distributed in the hope that it will be useful, but
13  * WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
15  * Public License for more details
16 */
17 
106 #include <cstdio>
107 #include <mutex>
108 #include <fstream>
109 #include <sstream>
110 #include <string>
111 #include <map>
112 #include <algorithm>
113 
114 #include <opencv2/opencv.hpp>
115 
116 #include <yarp/os/all.h>
117 #include <yarp/sig/all.h>
118 #include <yarp/cv/Cv.h>
119 
120 #include <iCub/learningMachine/FixedRangeScaler.h>
121 #include <iCub/learningMachine/IMachineLearner.h>
122 #include <iCub/learningMachine/LSSVMLearner.h>
123 
124 #define DEFAULT_STEP 1.0
125 
126 using namespace std;
127 using namespace yarp::os;
128 using namespace yarp::sig;
129 using namespace yarp::cv;
130 using namespace iCub::learningmachine;
131 
132 
133 /************************************************************************/
134 class KarmaLearn: public RFModule
135 {
136 protected:
137  FixedRangeScaler scalerIn;
138  FixedRangeScaler scalerOut;
139  map<string,IMachineLearner*> machines;
140 
141  string name;
142  string configFileName;
143 
144  string plotItem;
145  double plotStep;
146 
147  mutex mtx;
148  RpcServer rpcPort;
149  BufferedPort<ImageOf<PixelMono> > plotPort;
150 
151  /************************************************************************/
152  IMachineLearner *createLearner()
153  {
154  IMachineLearner *learner=new LSSVMLearner;
155  LSSVMLearner *lssvm=dynamic_cast<LSSVMLearner*>(learner);
156  lssvm->setDomainSize(1);
157  lssvm->setCoDomainSize(1);
158  lssvm->setC(100.0);
159  lssvm->getKernel()->setGamma(10.0);
160 
161  return learner;
162  }
163 
164  /************************************************************************/
165  void extractMinMax(const Bottle &b, double &min, double &max)
166  {
167  min=scalerOut.getUpperBoundIn();
168  max=scalerOut.getLowerBoundIn();
169 
170  for (int i=0; i<b.size(); i++)
171  {
172  double bi=b.get(i).asDouble();
173 
174  if (max<bi)
175  max=bi;
176 
177  if (min>bi)
178  min=bi;
179  }
180  }
181 
182  /************************************************************************/
183  void train(const string &item, const double input, const double output)
184  {
185  IMachineLearner *learner;
186  map<string,IMachineLearner*>::const_iterator itr=machines.find(item);
187  if (itr==machines.end())
188  {
189  learner=createLearner();
190  machines[item]=learner;
191  }
192  else
193  learner=itr->second;
194 
195  Vector in(1,input),out(1,output);
196  out[0]=std::min(out[0],scalerOut.getUpperBoundIn());
197 
198  in[0]=scalerIn.transform(in[0]);
199  out[0]=scalerOut.transform(out[0]);
200 
201  learner->feedSample(in,out);
202  learner->train();
203  }
204 
205  /************************************************************************/
206  bool predict(const string &item, const Bottle &input, Bottle &output,
207  Bottle &variance)
208  {
209  map<string,IMachineLearner*>::const_iterator itr=machines.find(item);
210  if (itr!=machines.end())
211  {
212  output.clear();
213  variance.clear();
214  for (int i=0; i<input.size(); i++)
215  {
216  Vector in(1,input.get(i).asDouble());
217  in[0]=scalerIn.transform(in[0]);
218 
219  IMachineLearner *learner=itr->second;
220  Prediction prediction=learner->predict(in);
221 
222  Vector v=prediction.getPrediction();
223  output.addDouble(scalerOut.unTransform(v[0]));
224 
225  if (prediction.hasVariance())
226  {
227  Vector v=prediction.getVariance();
228  variance.addDouble(v[0]);
229  }
230  else
231  variance.addDouble(-1.0);
232  }
233 
234  return true;
235  }
236  else
237  return false;
238  }
239 
240  /************************************************************************/
241  bool optimize(const string &item, const Bottle &searchDomain, double &input,
242  double &output)
243  {
244  map<string,IMachineLearner*>::const_iterator itr=machines.find(item);
245  if (itr!=machines.end())
246  {
247  IMachineLearner *learner=itr->second;
248 
249  input=scalerIn.getLowerBoundIn();
250  double maxOut=scalerOut.getLowerBoundOut();
251 
252  Bottle bin,bout,bdummy;
253  bin.addDouble(input);
254  predict(item,bin,bout,bdummy);
255  output=bout.get(0).asDouble();
256 
257  for (int i=0; i<searchDomain.size(); i++)
258  {
259  double val=searchDomain.get(i).asDouble();
260  Vector in(1,scalerIn.transform(val));
261  Prediction prediction=learner->predict(in);
262  Vector v=prediction.getPrediction();
263 
264  if (v[0]>maxOut)
265  {
266  input=val;
267  output=scalerOut.unTransform(v[0]);
268  maxOut=v[0];
269  }
270  }
271 
272  return true;
273  }
274  else
275  return false;
276  }
277 
278  /************************************************************************/
279  Bottle items()
280  {
281  Bottle ret;
282  for (map<string,IMachineLearner*>::const_iterator itr=machines.begin(); itr!=machines.end(); itr++)
283  ret.addString(itr->first);
284 
285  return ret;
286  }
287 
288  /************************************************************************/
289  bool machineContent(const string &item, string &content)
290  {
291  map<string,IMachineLearner*>::iterator itr=machines.find(item);
292  if (itr!=machines.end())
293  {
294  content=itr->second->toString();
295  return true;
296  }
297  else
298  return false;
299  }
300 
301  /************************************************************************/
302  void clear()
303  {
304  for (map<string,IMachineLearner*>::const_iterator itr=machines.begin(); itr!=machines.end(); itr++)
305  delete itr->second;
306 
307  machines.clear();
308  plotItem="";
309  }
310 
311  /************************************************************************/
312  bool clear(const string &item)
313  {
314  map<string,IMachineLearner*>::iterator itr=machines.find(item);
315  if (itr!=machines.end())
316  {
317  delete itr->second;
318  machines.erase(itr);
319 
320  if (plotItem==item)
321  plotItem="";
322 
323  return true;
324  }
325  else
326  return false;
327  }
328 
329  /************************************************************************/
330  void save()
331  {
332  ofstream fout(configFileName.c_str());
333 
334  fout<<"[general]"<<endl;
335  fout<<"name "<<name<<endl;
336  fout<<"num_items "<<machines.size()<<endl;
337  fout<<"in_lb "<<scalerIn.getLowerBoundIn()<<endl;
338  fout<<"in_ub "<<scalerIn.getUpperBoundIn()<<endl;
339  fout<<"out_lb "<<scalerOut.getLowerBoundIn()<<endl;
340  fout<<"out_ub "<<scalerOut.getUpperBoundIn()<<endl;
341  fout<<endl;
342 
343  int i=0;
344  for (map<string,IMachineLearner*>::const_iterator itr=machines.begin(); itr!=machines.end(); itr++, i++)
345  {
346  fout<<"[item_"<<i<<"]"<<endl;
347  fout<<"name "<<itr->first<<endl;
348  fout<<"learner "<<"("+itr->second->toString()+")"<<endl;
349  fout<<endl;
350  }
351 
352  fout.close();
353  }
354 
355  /************************************************************************/
356  void plot()
357  {
358  if ((plotItem!="") && (plotPort.getOutputCount()>0))
359  {
360  Bottle input,output,variance;
361  for (double d=scalerIn.getLowerBoundIn(); d<scalerIn.getUpperBoundIn(); d+=plotStep)
362  input.addDouble(d);
363 
364  if (predict(plotItem,input,output,variance))
365  {
366  ImageOf<PixelMono> &img=plotPort.prepare();
367  img.resize(320,240);
368  cv::Mat imgMat=toCvMat(img);
369  imgMat.setTo(cv::Scalar(255));
370 
371  cv::putText(imgMat,plotItem,cv::Point(250,20),cv::FONT_HERSHEY_SIMPLEX,0.5,cv::Scalar(0));
372 
373  double x_min=scalerIn.getLowerBoundIn();
374  double x_max=scalerIn.getUpperBoundIn();
375  double x_range=x_max-x_min;
376 
377  double y_min,y_max;
378  extractMinMax(output,y_min,y_max);
379  y_min*=(y_min>0.0?0.8:1.2);
380  y_max*=(y_max>0.0?1.2:0.8);
381  double y_range=y_max-y_min;
382 
383  {
384  ostringstream tag; tag.precision(3);
385  tag<<x_min;
386  cv::putText(imgMat,tag.str(),cv::Point(10,230),cv::FONT_HERSHEY_SIMPLEX,0.5,cv::Scalar(0));
387  }
388 
389  {
390  ostringstream tag; tag.precision(3);
391  tag<<x_max;
392  cv::putText(imgMat,tag.str(),cv::Point(280,230),cv::FONT_HERSHEY_SIMPLEX,0.5,cv::Scalar(0));
393  }
394 
395  {
396  ostringstream tag; tag.precision(3);
397  tag<<y_min;
398  cv::putText(imgMat,tag.str(),cv::Point(10,215),cv::FONT_HERSHEY_SIMPLEX,0.5,cv::Scalar(0));
399  }
400 
401  {
402  ostringstream tag; tag.precision(3);
403  tag<<y_max;
404  cv::putText(imgMat,tag.str(),cv::Point(10,20),cv::FONT_HERSHEY_SIMPLEX,0.5,cv::Scalar(0));
405  }
406 
407  cv::Point pold;
408  for (int i=0; i<input.size(); i++)
409  {
410  cv::Point p;
411  p.x=int((img.width()/x_range)*(input.get(i).asDouble()-x_min));
412  p.y=img.height()-int((img.height()/y_range)*(output.get(i).asDouble()-y_min));
413 
414  if (i>0)
415  cv::line(imgMat,p,pold,cv::Scalar(0),2);
416 
417  pold=p;
418  }
419 
420  plotPort.write();
421  }
422  }
423  }
424 
425  /************************************************************************/
426  bool respond(const Bottle &command, Bottle &reply)
427  {
428  lock_guard<mutex> lg(mtx);
429  if (command.size()>=1)
430  {
431  int header=command.get(0).asVocab();
432  Bottle payload=command.tail();
433  if (header==Vocab::encode("train"))
434  {
435  if (payload.size()>=3)
436  {
437  string item=payload.get(0).asString();
438  double input=payload.get(1).asDouble();
439  double output=payload.get(2).asDouble();
440 
441  train(item,input,output);
442  reply.addVocab(Vocab::encode("ack"));
443 
444  // trigger a change for the "plot"
445  plotItem=item;
446  }
447  else
448  reply.addVocab(Vocab::encode("nack"));
449  }
450  else if (header==Vocab::encode("predict"))
451  {
452  if (payload.size()>=2)
453  {
454  Bottle output,variance;
455  string item=payload.get(0).asString();
456  if (payload.get(1).isDouble())
457  {
458  Bottle input; input.addDouble(payload.get(1).asDouble());
459  if (predict(item,input,output,variance))
460  {
461  reply.addVocab(Vocab::encode("ack"));
462  reply.addDouble(output.get(0).asDouble());
463  reply.addDouble(variance.get(0).asDouble());
464  }
465  else
466  reply.addVocab(Vocab::encode("nack"));
467  }
468  else if (payload.get(1).isList())
469  {
470  if (predict(item,*payload.get(1).asList(),output,variance))
471  {
472  reply.addVocab(Vocab::encode("ack"));
473  reply.addList().append(output);
474  reply.addList().append(variance);
475  }
476  else
477  reply.addVocab(Vocab::encode("nack"));
478  }
479  else
480  reply.addVocab(Vocab::encode("nack"));
481  }
482  else
483  reply.addVocab(Vocab::encode("nack"));
484  }
485  else if (header==Vocab::encode("span"))
486  {
487  if (payload.size()>=1)
488  {
489  double step=DEFAULT_STEP;
490  string item=payload.get(0).asString();
491  if (payload.size()>=2)
492  step=payload.get(1).asDouble();
493 
494  Bottle input,output,variance;
495  for (double d=scalerIn.getLowerBoundIn(); d<scalerIn.getUpperBoundIn(); d+=step)
496  input.addDouble(d);
497 
498  if (predict(item,input,output,variance))
499  {
500  reply.addVocab(Vocab::encode("ack"));
501  reply.addList().append(output);
502  reply.addList().append(variance);
503  }
504  else
505  reply.addVocab(Vocab::encode("nack"));
506  }
507  else
508  reply.addVocab(Vocab::encode("nack"));
509  }
510  else if (header==Vocab::encode("optimize"))
511  {
512  if (payload.size()>=1)
513  {
514  Bottle searchDomain;
515  double step=DEFAULT_STEP;
516 
517  string item=payload.get(0).asString();
518  if (payload.size()>=2)
519  {
520  if (payload.get(1).isDouble())
521  step=payload.get(1).asDouble();
522  else if (payload.get(1).isList())
523  searchDomain=*payload.get(1).asList();
524  }
525 
526  if (searchDomain.size()==0)
527  for (double d=scalerIn.getLowerBoundIn(); d<scalerIn.getUpperBoundIn(); d+=step)
528  searchDomain.addDouble(d);
529 
530  double input,output;
531  if (optimize(item,searchDomain,input,output))
532  {
533  reply.addVocab(Vocab::encode("ack"));
534  reply.addDouble(input);
535  reply.addDouble(output);
536  }
537  else
538  reply.addVocab(Vocab::encode("nack"));
539  }
540  else
541  reply.addVocab(Vocab::encode("nack"));
542  }
543  else if (header==Vocab::encode("items"))
544  {
545  reply.addVocab(Vocab::encode("ack"));
546  reply.append(items());
547  }
548  else if (header==Vocab::encode("machine"))
549  {
550  if (payload.size()>=1)
551  {
552  string item=payload.get(0).asString();
553  string content;
554  if (machineContent(item,content))
555  {
556  reply.addVocab(Vocab::encode("ack"));
557  reply.addString(content);
558  }
559  else
560  reply.addVocab(Vocab::encode("nack"));
561  }
562  else
563  reply.addVocab(Vocab::encode("nack"));
564  }
565  else if (header==Vocab::encode("clear"))
566  {
567  if (payload.size()>=1)
568  {
569  string item=payload.get(0).asString();
570  if (clear(item))
571  reply.addVocab(Vocab::encode("ack"));
572  else
573  reply.addVocab(Vocab::encode("nack"));
574  }
575  else
576  {
577  clear();
578  reply.addVocab(Vocab::encode("ack"));
579  }
580  }
581  else if (header==Vocab::encode("save"))
582  {
583  save();
584  reply.addVocab(Vocab::encode("ack"));
585  }
586  else if (header==Vocab::encode("plot"))
587  {
588  if (payload.size()>=1)
589  {
590  string item=payload.get(0).asString();
591  if (payload.size()>=2)
592  plotStep=payload.get(1).asDouble();
593 
594  if (machines.find(item)!=machines.end())
595  {
596  plotItem=item;
597  reply.addVocab(Vocab::encode("ack"));
598  }
599  else
600  reply.addVocab(Vocab::encode("nack"));
601  }
602  else
603  reply.addVocab(Vocab::encode("nack"));
604  }
605  else
606  reply.addVocab(Vocab::encode("nack"));
607  }
608  else
609  reply.addVocab(Vocab::encode("nack"));
610 
611  return true;
612  }
613 
614 public:
615  /************************************************************************/
616  bool configure(ResourceFinder &rf)
617  {
618  // default values
619  name="karmaLearn";
620  int nItems=0;
621 
622  double in_lb=0.0;
623  double in_ub=360.0;
624  double out_lb=0.0;
625  double out_ub=2.0;
626 
627  Bottle &generalGroup=rf.findGroup("general");
628  if (!generalGroup.isNull())
629  {
630  name=generalGroup.check("name",Value("karmaLearn")).asString();
631  nItems=generalGroup.check("num_items",Value(0)).asInt();
632  in_lb=generalGroup.check("in_lb",Value(0.0)).asDouble();
633  in_ub=generalGroup.check("in_ub",Value(360.0)).asDouble();
634  out_lb=generalGroup.check("out_lb",Value(0.0)).asDouble();
635  out_ub=generalGroup.check("out_ub",Value(2.0)).asDouble();
636  }
637 
638  scalerIn.setLowerBoundIn(in_lb);
639  scalerIn.setUpperBoundIn(in_ub);
640  scalerIn.setLowerBoundOut(0.0);
641  scalerIn.setUpperBoundOut(1.0);
642 
643  scalerOut.setLowerBoundIn(out_lb);
644  scalerOut.setUpperBoundIn(out_ub);
645  scalerOut.setLowerBoundOut(0.0);
646  scalerOut.setUpperBoundOut(1.0);
647 
648  // retrieve machines for each item
649  for (int i=0; i<nItems; i++)
650  {
651  ostringstream item; item<<"item_"<<i;
652  Bottle &itemGroup=rf.findGroup(item.str());
653  if (!itemGroup.isNull())
654  {
655  if (!itemGroup.check("name"))
656  continue;
657 
658  IMachineLearner *learner=createLearner();
659  if (itemGroup.check("learner"))
660  learner->fromString(itemGroup.find("learner").asList()->toString());
661 
662  machines[itemGroup.find("name").asString()]=learner;
663  }
664  }
665 
666  // save the file name
667  configFileName=rf.findPath("from");
668 
669  plotItem="";
670  plotStep=1.0;
671 
672  plotPort.open("/"+name+"/plot:o");
673  rpcPort.open("/"+name+"/rpc");
674  attach(rpcPort);
675 
676  return true;
677  }
678 
679  /************************************************************************/
680  bool interruptModule()
681  {
682  plotPort.interrupt();
683  rpcPort.interrupt();
684  return true;
685  }
686 
687  /************************************************************************/
688  bool close()
689  {
690  save();
691  clear();
692  plotPort.close();
693  rpcPort.close();
694  return true;
695  }
696 
697  /************************************************************************/
698  double getPeriod()
699  {
700  return 0.25;
701  }
702 
703  /************************************************************************/
704  bool updateModule()
705  {
706  lock_guard<mutex> lg(mtx);
707  plot();
708  return true;
709  }
710 };
711 
712 
713 /************************************************************************/
714 int main(int argc, char *argv[])
715 {
716  Network yarp;
717  if (!yarp.checkNetwork())
718  {
719  printf("YARP server not available!\n");
720  return 1;
721  }
722 
723  ResourceFinder rf;
724  rf.setDefaultContext("karma");
725  rf.setDefaultConfigFile("karmaLearn.ini");
726  rf.configure(argc,argv);
727 
728  KarmaLearn karmaLearn;
729  return karmaLearn.runModule(rf);
730 }
731 
732 
733