ml_classifiers
Package Summary
ml_classifiers
- Author: Scott Niekum
- License: BSD
- Source: git https://github.com/sniekum/ml_classifiers.git (branch: master)
Contents
Overview
This package provides a ROS service for interfacing with various machine learning algorithms for supervised classification. Three example classifiers are included, among them a nearest-neighbor classifier and a support vector machine (SVM) classifier, and additional classifiers can easily be added through pluginlib.
Installation
$ sudo apt-get install ros-fuerte-ml-classifiers
$ sudo apt-get install ros-groovy-ml-classifiers
$ sudo apt-get install ros-hydro-ml-classifiers
ROS API
ml_classifiers provides a general interface for all plugin classifiers to implement:
Plugins
ml_classifiers comes with three built-in plugins: a "zero" classifier that assigns everything to class 0 no matter what, a nearest-neighbor classifier, and an SVM classifier based on libSVM. More classifiers can easily be added through pluginlib by implementing the above interface.
Usage
Compile any additional plugins you want to use, then run:
$ roslaunch ml_classifiers classifier_server.launch
An example Python script that wraps the service calls and tests the SVM classifier:
1 import roslib; roslib.load_manifest('ml_classifiers')
2 import rospy
3 import ml_classifiers.srv
4 import ml_classifiers.msg
5
class ClassifierWrapper(object):
    """Wrapper for calls to the ROS classifier services and management of classifier data.

    Every public method builds the matching ``ml_classifiers.srv`` request,
    fills in the classifier ``identifier`` (plus any method-specific fields)
    and sends it through a persistent ``rospy.ServiceProxy`` for the
    corresponding ``/ml_classifiers/<service>`` service. Service responses are
    discarded, except by the classify methods, which return the
    classifications reported by the server.
    """

    # Namespace under which the classifier server advertises its services.
    _SERVICE_NS = '/ml_classifiers/'

    def __init__(self):
        """Wait for the classifier server and create one proxy per service."""
        print('Waiting for Classifier services...')
        # Only create_classifier is waited on; this assumes the server
        # advertises all of its services together -- TODO confirm.
        rospy.wait_for_service(self._SERVICE_NS + 'create_classifier')
        srv = ml_classifiers.srv
        self.add_class_data = self._proxy('add_class_data', srv.AddClassData)
        self.classify_data = self._proxy('classify_data', srv.ClassifyData)
        self.clear_classifier = self._proxy('clear_classifier', srv.ClearClassifier)
        self.create_classifier = self._proxy('create_classifier', srv.CreateClassifier)
        self.load_classifier = self._proxy('load_classifier', srv.LoadClassifier)
        self.save_classifier = self._proxy('save_classifier', srv.SaveClassifier)
        self.train_classifier = self._proxy('train_classifier', srv.TrainClassifier)
        print('OK\n')

    @classmethod
    def _proxy(cls, name, service_class):
        """Return a persistent ServiceProxy for ``/ml_classifiers/<name>``."""
        return rospy.ServiceProxy(cls._SERVICE_NS + name, service_class,
                                  persistent=True)

    @staticmethod
    def _make_data_point(point, target_class=None):
        """Build a ClassDataPoint holding ``point``.

        ``target_class`` is only assigned when given; otherwise the message
        field keeps its default (as needed for classification requests).
        """
        dp = ml_classifiers.msg.ClassDataPoint()
        dp.point = point
        if target_class is not None:
            dp.target_class = target_class
        return dp

    def addClassDataPoint(self, identifier, target_class, p):
        """Send one labelled point to classifier ``identifier``.

        Args:
            identifier: name of the classifier instance (e.g. ``'test'``).
            target_class: label for ``p`` (the example below uses strings
                such as ``'1'``).
            p: the point's feature vector, e.g. ``[0.1, 0.2]``.
        """
        self.addClassDataPoints(identifier, [target_class], [p])

    def addClassDataPoints(self, identifier, target_classes, pts):
        """Send several labelled points to classifier ``identifier`` in one request.

        Args:
            identifier: name of the classifier instance.
            target_classes: one label per entry of ``pts``, in the same order.
            pts: sequence of feature vectors.

        Raises:
            ValueError: if ``target_classes`` and ``pts`` differ in length
                (previously this either raised an opaque IndexError or
                silently dropped the surplus labels).
        """
        if len(target_classes) != len(pts):
            raise ValueError('got %d target classes for %d points'
                             % (len(target_classes), len(pts)))
        req = ml_classifiers.srv.AddClassDataRequest()
        req.identifier = identifier
        for target_class, p in zip(target_classes, pts):
            req.data.append(self._make_data_point(p, target_class))
        self.add_class_data(req)

    def classifyPoint(self, identifier, p):
        """Classify a single point with classifier ``identifier``.

        Returns:
            The first (and only) entry of the service's ``classifications``.
        """
        return self.classifyPoints(identifier, [p])[0]

    def classifyPoints(self, identifier, pts):
        """Classify several points with classifier ``identifier`` in one request.

        Returns:
            The ``classifications`` field of the classify_data response,
            presumably one entry per point in ``pts`` -- confirm against the
            service definition.
        """
        req = ml_classifiers.srv.ClassifyDataRequest()
        req.identifier = identifier
        for p in pts:
            req.data.append(self._make_data_point(p))
        return self.classify_data(req).classifications

    def clearClassifier(self, identifier):
        """Call the clear_classifier service for classifier ``identifier``."""
        req = ml_classifiers.srv.ClearClassifierRequest()
        req.identifier = identifier
        self.clear_classifier(req)

    def createClassifier(self, identifier, class_type):
        """Create classifier ``identifier`` of plugin type ``class_type``.

        Args:
            identifier: name to give the new classifier instance.
            class_type: plugin type name, e.g. ``'ml_classifiers/SVMClassifier'``.
        """
        req = ml_classifiers.srv.CreateClassifierRequest()
        req.identifier = identifier
        req.class_type = class_type
        self.create_classifier(req)

    def loadClassifier(self, identifier, class_type, filename):
        """Call load_classifier with ``identifier``, ``class_type`` and ``filename``."""
        req = ml_classifiers.srv.LoadClassifierRequest()
        req.identifier = identifier
        req.class_type = class_type
        req.filename = filename
        self.load_classifier(req)

    def saveClassifier(self, identifier, filename):
        """Call save_classifier for classifier ``identifier`` with ``filename``."""
        req = ml_classifiers.srv.SaveClassifierRequest()
        req.identifier = identifier
        req.filename = filename
        self.save_classifier(req)

    def trainClassifier(self, identifier):
        """Call the train_classifier service for classifier ``identifier``."""
        req = ml_classifiers.srv.TrainClassifierRequest()
        req.identifier = identifier
        self.train_classifier(req)
113
if __name__ == '__main__':
    # Demo: train an SVM on five labelled 2-D points, then classify three
    # query points and print whatever classifications the service returns.
    name = 'test'
    wrapper = ClassifierWrapper()
    wrapper.createClassifier(name, 'ml_classifiers/SVMClassifier')

    labels = ['1', '1', '2', '2', '3']
    training_points = [
        [0.1, 0.2],
        [0.3, 0.1],
        [3.1, 3.2],
        [3.3, 4.1],
        [5.1, 5.2],
    ]
    wrapper.addClassDataPoints(name, labels, training_points)
    wrapper.trainClassifier(name)

    query_points = [[0.0, 0.0], [5.5, 5.5], [2.9, 3.6]]
    print(wrapper.classifyPoints(name, query_points))
Report a Bug
<<TracLink(REPO COMPONENT)>>