icub-client
trainSAMModel.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 # """"""""""""""""""""""""""""""""""""""""""""""
3 # The University of Sheffield
4 # WYSIWYD Project
5 #
# A generic training script that is used to train all models.
# This script can train both single-model implementations and multiple-model implementations.
8 #
9 # Created on 20 July 2016
10 #
11 # @author: Daniel Camilleri
12 #
13 # """"""""""""""""""""""""""""""""""""""""""""""
14 
15 import warnings
16 import sys
17 import numpy
18 import numpy as np
19 from SAM.SAM_Core import SAMCore
20 from SAM.SAM_Core import SAMTesting
21 from SAM.SAM_Core.SAM_utils import initialiseModels
22 import logging
23 import os
24 from os.path import join
# Print whole arrays (no "..." summarisation) with 2 decimal places.
# NumPy rejects threshold=nan ("threshold must be non-NAN"); sys.maxsize is
# the documented value for always printing the full array.
np.set_printoptions(threshold=sys.maxsize, precision=2)
# Silence every warning for the whole training run.
warnings.simplefilter("ignore")
27 
28 
33 
34 
def exception_hook(exc_type, exc_value, exc_traceback):
    """Write an otherwise-uncaught exception, with its traceback, to the log.

    Documentation:
        Meant to be installed as ``sys.excepthook``. The error then goes through
        ``logging`` and therefore into the configured log files. This only works
        when trainSAMModel.py is run by plain python. ipython overrides
        ``sys.excepthook``, so the substitution has no effect there.

    Args:
        exc_type: Exception Type.
        exc_value: Exception Value.
        exc_traceback: Exception Traceback.

    Returns:
        None
    """
    # logging expects the (type, value, traceback) triple for exc_info.
    captured = (exc_type, exc_value, exc_traceback)
    logging.error("Uncaught exception", exc_info=captured)
50 
# Send any uncaught exception through exception_hook so it reaches the logs.
sys.excepthook = exception_hook

# Positional command-line arguments read here: argv[1] data path,
# argv[2] model path, argv[3] driver name; argv[6] == 'True' selects windowed
# mode. argv[4] is forwarded to initialiseModels further down.
dataPath = sys.argv[1]
modelPath = sys.argv[2]
driverName = sys.argv[3]
windowedMode = sys.argv[6] == 'True'
baseLogFileName = 'trainErrorLog_' + driverName

# Use the first trainErrorLog_<driver>_<i>.log in dataPath (i = 0, 1, ...)
# that is missing or empty, so an earlier non-empty log is never reused.
# The counter is advanced *before* the next candidate name is built, so each
# candidate file is checked exactly once. The previous version rebuilt and
# re-checked the "_0" name on its first iteration.
file_i = 0
loggerFName = join(dataPath, baseLogFileName + '_' + str(file_i) + '.log')
while os.path.isfile(loggerFName) and os.path.getsize(loggerFName) > 0:
    file_i += 1
    loggerFName = join(dataPath, baseLogFileName + '_' + str(file_i) + '.log')

# Windowed mode uses a plain format. Otherwise messages are timestamped and
# wrapped in ANSI colour escape codes. The same formatter is attached to both
# the file handler and the console handler.
if windowedMode:
    logFormatter = logging.Formatter("[%(levelname)s] %(message)s")
else:
    logFormatter = logging.Formatter("\033[34m%(asctime)s [%(name)-33s] [%(levelname)8s] %(message)s\033[0m")

rootLogger = logging.getLogger('train ' + driverName)
rootLogger.setLevel(logging.DEBUG)

fileHandler = logging.FileHandler(loggerFName)
fileHandler.setFormatter(logFormatter)
rootLogger.addHandler(fileHandler)

consoleHandler = logging.StreamHandler()
consoleHandler.setFormatter(logFormatter)
rootLogger.addHandler(consoleHandler)
# Rebind the logging module's root so module-level calls such as
# logging.info / logging.error (used by exception_hook) go to this logger.
logging.root = rootLogger

logging.info(loggerFName)
85 
# Build the model list. argv[1:4] are the data path, model path and driver
# name; argv[4] is forwarded as initialiseModels' `update` argument.
mm = initialiseModels(sys.argv[1:4], sys.argv[4])

# Calibrate model recall when unknown-class calibration is requested or when
# more than one model was created.
if mm[0].calibrateUnknown or len(mm) > 1:
    SAMTesting.calibrateModelRecall(mm)

# Evaluate performance on the full data (Yall/Lall) and the test data
# (YtestAll/LtestAll). Temporal models additionally take the X inputs.
# Exactly one branch runs and both assign all three results, so no
# placeholder initial value is needed.
if mm[0].model_mode == 'temporal':
    overallPerformance, overallPerformanceLabels, labelComparisonDict = mm[0].testTemporalPerformance(
        mm, mm[0].Xall, mm[0].Yall, mm[0].Lall,
        mm[0].XtestAll, mm[0].YtestAll, mm[0].LtestAll, True)
else:
    overallPerformance, overallPerformanceLabels, labelComparisonDict = mm[0].testPerformance(
        mm, mm[0].Yall, mm[0].Lall, mm[0].YtestAll, mm[0].LtestAll, True)
98 
# Store the evaluation results and shared training settings in every model's
# paramsDict, then save each model together with that dictionary.
# NOTE(review): this code was recovered from a rendering that lost its
# indentation. The nesting below is a reconstruction and should be confirmed
# against the original source. It assumes that `elif numParts > 1` pairs with
# `if k == 0`, and that 'useMaxDistance' is stored only when numParts == 1.
numParts = len(mm[0].participantList)
for k in range(numParts):
    # Evaluation results and settings copied into every model. Apart from
    # model_type, every value is taken from mm[0].
    mm[k].paramsDict['overallPerformance'] = overallPerformance
    mm[k].paramsDict['overallPerformanceLabels'] = overallPerformanceLabels
    mm[k].paramsDict['labelComparisonDict'] = labelComparisonDict
    mm[k].paramsDict['ratioData'] = mm[0].ratioData
    mm[k].paramsDict['model_type'] = mm[k].model_type
    mm[k].paramsDict['model_mode'] = mm[0].model_mode
    mm[k].paramsDict['verbose'] = mm[0].verbose
    mm[k].paramsDict['Quser'] = mm[0].Quser
    mm[k].paramsDict['model_num_inducing'] = mm[0].model_num_inducing
    mm[k].paramsDict['model_num_iterations'] = mm[0].model_num_iterations
    mm[k].paramsDict['model_init_iterations'] = mm[0].model_init_iterations
    mm[k].paramsDict['kernelString'] = mm[0].kernelString
    mm[k].paramsDict['economy_save'] = mm[0].economy_save

    # Mode-specific entries: text labels and SAMObject.Q for non-temporal
    # models, the window size for temporal models.
    if mm[0].model_mode != 'temporal':
        mm[k].paramsDict['textLabels'] = mm[0].textLabels
        mm[k].paramsDict['modelQ'] = mm[0].SAMObject.Q
    elif mm[0].model_mode == 'temporal':
        mm[k].paramsDict['temporalModelWindowSize'] = mm[0].temporalModelWindowSize

    # Log the shapes of the 'L' and 'Y' entries of non-temporal MRD models.
    if mm[k].model_type == 'mrd' and mm[k].model_mode != 'temporal':
        logging.info(mm[k].Y['L'].shape)
        logging.info(mm[k].Y['Y'].shape)

    if k == 0:
        # Entries stored only on the first model.
        mm[0].paramsDict['listOfModels'] = mm[0].listOfModels
        mm[0].paramsDict['avgClassTime'] = mm[0].avgClassTime
        mm[0].paramsDict['optimiseRecall'] = mm[0].optimiseRecall
        # classificationDict is only stored when calibrateUnknown is set.
        if mm[0].calibrateUnknown:
            mm[0].paramsDict['classificationDict'] = mm[0].classificationDict
        mm[0].paramsDict['calibrateUnknown'] = mm[0].calibrateUnknown
        mm[0].paramsDict['calibrated'] = mm[0].calibrated
        if numParts == 1:
            # Store only the shape of X (or None when X is None), not the data.
            if mm[0].X is None:
                mm[0].paramsDict['X'] = mm[0].X
            else:
                mm[0].paramsDict['X'] = mm[0].X.shape

            if mm[0].model_mode != 'temporal':
                mm[0].paramsDict['Y'] = mm[k].Y['Y'].shape

            mm[0].paramsDict['useMaxDistance'] = mm[0].useMaxDistance

    elif numParts > 1:
        # Later models (k > 0): store the shapes of their own X and Y.
        # fname = mm[0].listOfModels[k-1]
        if mm[k].X is None:
            mm[k].paramsDict['X'] = mm[k].X
        else:
            mm[k].paramsDict['X'] = mm[k].X.shape
        mm[k].paramsDict['Y'] = mm[k].Y['Y'].shape
        # else:
        #     pass
        #     fname = fnameProto

    # save model with custom .pickle dictionary by iterating through all nested models
    # saveParameters() is called first. The completed paramsDict is then
    # passed to save_pruned_model as extraDict, together with mm[0]'s
    # economy_save flag.
    logging.info('-------------------')
    logging.info('Saving: ' + mm[k].fname)
    mm[k].saveParameters()
    logging.info('Keys:')
    logging.info(mm[k].paramsDict.keys())
    SAMCore.save_pruned_model(mm[k].SAMObject, mm[k].fname, mm[0].economy_save, extraDict=mm[k].paramsDict)
def initialiseModels(argv, update, initMode='training')
Initialise SAM Model data structure, training parameters and user parameters.
Definition: SAM_utils.py:59
def exception_hook(exc_type, exc_value, exc_traceback)
Callback function to record any errors that occur in the log files.