import sys
from os.path import join
# Print NumPy arrays in full (no "..." summarisation) with 2-decimal precision.
# `threshold` must be a real number: current NumPy rejects NaN here with a
# ValueError and recommends sys.maxsize for untruncated output.
np.set_printoptions(threshold=sys.maxsize, precision=2)
# Suppress every warning process-wide. Note this also hides deprecation
# warnings raised by dependencies.
warnings.simplefilter("ignore")
def exception_hook(exc_type, exc_value, exc_traceback):
    """Callback function to record any errors that occur in the log files.

    Substitutes the standard python exception_hook with one that records the
    error through ``logging.error`` (i.e. into whatever handlers the root
    logger has at the time, such as a log file). Can only work if
    trainSAMModel.py is called from python and not ipython because ipython
    overrides this substitution.

    Args:
        exc_type: Exception Type.
        exc_value: Exception Value.
        exc_traceback: Exception Traceback.

    Returns:
        None
    """
    # Passing the full exc_info triple makes the handler emit the traceback.
    logging.error("Uncaught exception", exc_info=(exc_type, exc_value, exc_traceback))


# Route uncaught exceptions through logging instead of the default stderr printer.
sys.excepthook = exception_hook
# NOTE(review): this region is an extraction with fused original line numbers
# and missing lines (indentation and some control flow were lost); it is not
# valid Python as shown.
# --- Command-line arguments -------------------------------------------------
# Positional argv as read below: [1] dataPath (also used as the directory for
# the error log), [2] modelPath, [3] driverName, [6] windowedMode (True only
# when the argument is the literal string 'True'). argv[4] and argv[5] are not
# read in this excerpt.
53 dataPath = sys.argv[1]
54 modelPath = sys.argv[2]
55 driverName = sys.argv[3]
56 windowedMode = sys.argv[6] ==
'True' 57 baseLogFileName =
'trainErrorLog_' + driverName
# --- Log-file selection -----------------------------------------------------
# Appears intended to pick the first trainErrorLog_<driverName>_<file_i>.log
# in dataPath that does not exist or is empty, so an earlier non-empty error
# log is not overwritten.
# NOTE(review): neither the initialisation of file_i nor an increment inside
# the while loop is visible here; as shown, the loop would never terminate once
# the first candidate is a non-empty file. Presumably `file_i = 0` and
# `file_i += 1` sit on lines missing from this excerpt — confirm.
60 loggerFName = join(dataPath, baseLogFileName +
'_' + str(file_i) +
'.log')
63 while os.path.isfile(loggerFName)
and os.path.getsize(loggerFName) > 0:
64 loggerFName = join(dataPath, baseLogFileName +
'_' + str(file_i) +
'.log')
# --- Logger configuration ---------------------------------------------------
# NOTE(review): this plain formatter is immediately replaced by the next
# assignment unless a line missing from this excerpt (e.g. a conditional)
# separates them; as shown it is dead code.
68 logFormatter = logging.Formatter(
"[%(levelname)s] %(message)s")
# Timestamped formatter wrapped in ANSI escape codes (\033[34m = blue,
# \033[0m = reset). It is attached to both the file and console handlers
# below, so the log file will also contain these escape sequences.
70 logFormatter = logging.Formatter(
"\033[34m%(asctime)s [%(name)-33s] [%(levelname)8s] %(message)s\033[0m")
# Named logger 'train <driverName>' at DEBUG level, writing to the selected
# log file and to the console (StreamHandler defaults to stderr).
72 rootLogger = logging.getLogger(
'train ' + driverName)
73 rootLogger.setLevel(logging.DEBUG)
75 fileHandler = logging.FileHandler(loggerFName)
76 fileHandler.setFormatter(logFormatter)
77 rootLogger.addHandler(fileHandler)
79 consoleHandler = logging.StreamHandler()
80 consoleHandler.setFormatter(logFormatter)
81 rootLogger.addHandler(consoleHandler)
# Rebind logging.root so module-level calls such as logging.info(...) and
# logging.error(...) (including the one in exception_hook) go through
# rootLogger and its handlers.
82 logging.root = rootLogger
# Record which log file was chosen.
84 logging.info(loggerFName)
# --- Calibration and evaluation ---------------------------------------------
# NOTE(review): `mm` (a list of model objects, presumably produced by
# initialiseModels) is defined on lines missing from this excerpt — confirm.
# Run SAMTesting.calibrateModelRecall on all models when unknown-class
# calibration is enabled on mm[0] or when there is more than one model.
89 if mm[0].calibrateUnknown
or len(mm) > 1:
90 SAMTesting.calibrateModelRecall(mm)
# Placeholder value. The two branches below are exhaustive
# (model_mode != / == 'temporal'), so it is always overwritten.
92 overallPerformance = 100000
# Non-temporal models are scored from the Y/L train and test data; temporal
# models additionally receive the X data. The meaning of the trailing
# positional True is not visible here — check testPerformance /
# testTemporalPerformance.
93 if mm[0].model_mode !=
'temporal':
94 overallPerformance, overallPerformanceLabels, labelComparisonDict = mm[0].testPerformance(mm, mm[0].Yall, mm[0].Lall, mm[0].YtestAll, mm[0].LtestAll,
True)
95 elif mm[0].model_mode ==
'temporal':
96 overallPerformance, overallPerformanceLabels, labelComparisonDict = mm[0].testTemporalPerformance(mm, mm[0].Xall, mm[0].Yall, mm[0].Lall,
97 mm[0].XtestAll, mm[0].YtestAll, mm[0].LtestAll,
True)
# --- Copy shared settings into every model's paramsDict ----------------------
# NOTE(review): presumably one model per participant — mm[k] is indexed for
# every k in range(len(participantList)), so len(mm) must be at least that
# long; confirm.
99 numParts = len(mm[0].participantList)
100 for k
in range(numParts):
# Results and configuration are taken from mm[0], except model_type, which is
# read from mm[k] itself.
101 mm[k].paramsDict[
'overallPerformance'] = overallPerformance
102 mm[k].paramsDict[
'overallPerformanceLabels'] = overallPerformanceLabels
103 mm[k].paramsDict[
'labelComparisonDict'] = labelComparisonDict
104 mm[k].paramsDict[
'ratioData'] = mm[0].ratioData
105 mm[k].paramsDict[
'model_type'] = mm[k].model_type
106 mm[k].paramsDict[
'model_mode'] = mm[0].model_mode
107 mm[k].paramsDict[
'verbose'] = mm[0].verbose
108 mm[k].paramsDict[
'Quser'] = mm[0].Quser
109 mm[k].paramsDict[
'model_num_inducing'] = mm[0].model_num_inducing
110 mm[k].paramsDict[
'model_num_iterations'] = mm[0].model_num_iterations
111 mm[k].paramsDict[
'model_init_iterations'] = mm[0].model_init_iterations
112 mm[k].paramsDict[
'kernelString'] = mm[0].kernelString
113 mm[k].paramsDict[
'economy_save'] = mm[0].economy_save
# Mode-specific extras: textLabels and the SAMObject's Q for non-temporal
# models, temporalModelWindowSize for temporal ones.
115 if mm[0].model_mode !=
'temporal':
116 mm[k].paramsDict[
'textLabels'] = mm[0].textLabels
117 mm[k].paramsDict[
'modelQ'] = mm[0].SAMObject.Q
118 elif mm[0].model_mode ==
'temporal':
119 mm[k].paramsDict[
'temporalModelWindowSize'] = mm[0].temporalModelWindowSize
# For 'mrd' models in non-temporal mode, log the shapes of the 'L' and 'Y'
# entries of mm[k].Y.
121 if mm[k].model_type ==
'mrd' and mm[k].model_mode !=
'temporal':
122 logging.info(mm[k].Y[
'L'].shape)
123 logging.info(mm[k].Y[
'Y'].shape)
# --- Settings stored only on mm[0] -------------------------------------------
126 mm[0].paramsDict[
'listOfModels'] = mm[0].listOfModels
127 mm[0].paramsDict[
'avgClassTime'] = mm[0].avgClassTime
128 mm[0].paramsDict[
'optimiseRecall'] = mm[0].optimiseRecall
# Calibration results. Indentation was lost, so which of the following
# assignments fall under this `if` cannot be determined from this excerpt.
129 if mm[0].calibrateUnknown:
130 mm[0].paramsDict[
'classificationDict'] = mm[0].classificationDict
131 mm[0].paramsDict[
'calibrateUnknown'] = mm[0].calibrateUnknown
132 mm[0].paramsDict[
'calibrated'] = mm[0].calibrated
# NOTE(review): 'X' is assigned twice in a row (the full array, then only its
# shape). Presumably these are alternative branches of a conditional (possibly
# on economy_save) that is missing from this excerpt — confirm. As shown, only
# the shape survives.
135 mm[0].paramsDict[
'X'] = mm[0].X
137 mm[0].paramsDict[
'X'] = mm[0].X.shape
# NOTE(review): this reads mm[k], i.e. whatever k was left at by the earlier
# per-participant loop, rather than mm[0]. Unless a loop binding k is among the
# missing lines, mm[0] was likely intended — confirm.
139 if mm[0].model_mode !=
'temporal':
140 mm[0].paramsDict[
'Y'] = mm[k].Y[
'Y'].shape
142 mm[0].paramsDict[
'useMaxDistance'] = mm[0].useMaxDistance
# NOTE(review): the loop header binding k for this section is missing from this
# excerpt, as is the conditional that chooses between storing X itself or only
# X.shape (the two 'X' assignments below). Presumably this runs over the
# remaining models — confirm.
147 mm[k].paramsDict[
'X'] = mm[k].X
149 mm[k].paramsDict[
'X'] = mm[k].X.shape
150 mm[k].paramsDict[
'Y'] = mm[k].Y[
'Y'].shape
# Persist each model. The steps are:
#   1. log a separator and the model's file name;
#   2. call saveParameters();
#   3. log the paramsDict keys;
#   4. call SAMCore.save_pruned_model with paramsDict as extraDict.
# The economy_save flag is always taken from mm[0].
156 logging.info(
'-------------------')
157 logging.info(
'Saving: ' + mm[k].fname)
158 mm[k].saveParameters()
159 logging.info(
'Keys:')
160 logging.info(mm[k].paramsDict.keys())
161 SAMCore.save_pruned_model(mm[k].SAMObject, mm[k].fname, mm[0].economy_save, extraDict=mm[k].paramsDict)
def initialiseModels(argv, update, initMode='training')
Initialise SAM Model data structure, training parameters and user parameters.
def exception_hook(exc_type, exc_value, exc_traceback)
Generic training function.