00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
'''
MySQL protocol result packet objects (OK, error, EOF and count results),
together with their unit tests.
'''
00040
00041 import struct
00042 import unittest
00043
class BadFieldCount(Exception):
    '''Raised when a packet's leading byte does not match the expected result type.'''
00046
class OkResult(object):
    '''This class represents an OK result packet sent from the server.

    Instances are built either from keyword arguments (``packed`` is None)
    or by parsing a raw packet passed as ``packed``.  ``packed`` may be a
    Python 2 ``str`` or a Python 3 ``bytes``/``bytearray``; ``message`` is a
    slice of that object.

    Layout parsed after the leading 0x00 byte:
      * affected_rows -- length-coded integer
      * insert_id     -- length-coded integer
      * status        -- 2-byte little-endian (may be absent in 4.0 format)
      * warning_count -- 2-byte little-endian (4.1 format only)
      * message       -- rest of the packet

    ``warning_count`` is only set when ``version_40`` is false.

    Raises:
        BadFieldCount: the first byte of ``packed`` is not 0.
        ValueError: a length-coded integer starts with 251 or 255.
        IndexError / struct.error: the packet is truncated.
    '''

    # Number of value bytes following each multi-byte length-coded prefix.
    _LENGTH_CODED_WIDTHS = {252: 2, 253: 3, 254: 8}

    def __init__(self, packed=None, affected_rows=0, insert_id=0, status=0,
                 warning_count=0, message='', version_40=False):
        self.version_40 = version_40
        if packed is None:
            self.affected_rows = affected_rows
            self.insert_id = insert_id
            self.status = status
            self.message = message
            # The 4.0 format carries no warning count.
            if not version_40:
                self.warning_count = warning_count
            return

        marker = self._byte_at(packed, 0)
        if marker != 0:
            raise BadFieldCount('Expected 0, received ' + str(marker))
        self.affected_rows, offset = self._unpack_length_coded(packed, 1)
        self.insert_id, offset = self._unpack_length_coded(packed, offset)
        if version_40:
            # A 4.0 OK packet may end right after insert_id.
            if offset == len(packed):
                self.status = 0
                self.message = ''
            else:
                self.status = struct.unpack('<H', packed[offset:offset + 2])[0]
                self.message = packed[offset + 2:]
        else:
            self.status, self.warning_count = struct.unpack(
                '<HH', packed[offset:offset + 4])
            self.message = packed[offset + 4:]

    @staticmethod
    def _byte_at(packed, index):
        '''Return byte ``index`` of ``packed`` as an int.

        Indexing bytes/bytearray already yields an int; indexing a str yields
        a 1-character string, which is converted with ``ord``.
        '''
        value = packed[index]
        if isinstance(value, int):
            return value
        return ord(value)

    @classmethod
    def _unpack_length_coded(cls, packed, offset):
        '''Decode a length-coded integer at ``offset``.

        Returns ``(value, next_offset)``.  A first byte below 251 is the value
        itself; 252, 253 and 254 prefix a 2-, 3- and 8-byte little-endian
        value.  Any other prefix (251, 255) raises ValueError.
        '''
        first = cls._byte_at(packed, offset)
        if first < 251:
            return first, offset + 1
        width = cls._LENGTH_CODED_WIDTHS.get(first)
        if width is None:
            raise ValueError('Invalid length-coded integer prefix ' + str(first))
        value = 0
        for shift in range(width):
            value |= cls._byte_at(packed, offset + 1 + shift) << (8 * shift)
        return value, offset + 1 + width

    def __str__(self):
        if self.version_40:
            return ('OkResult\n'
                    'affected_rows = %s\n'
                    'insert_id = %s\n'
                    'status = %s\n'
                    'message = %s\n'
                    'version_40 = %s\n') % (self.affected_rows, self.insert_id,
                                            self.status, self.message,
                                            self.version_40)
        return ('OkResult\n'
                'affected_rows = %s\n'
                'insert_id = %s\n'
                'status = %s\n'
                'warning_count = %s\n'
                'message = %s\n'
                'version_40 = %s\n') % (self.affected_rows, self.insert_id,
                                        self.status, self.warning_count,
                                        self.message, self.version_40)
00100
class TestOkResult(unittest.TestCase):
    '''Unit tests for OkResult keyword construction and packet parsing.

    Packet fixtures append bytes literals (``b'...'``) to ``struct.pack``
    output so the concatenation is valid on Python 3 as well as Python 2.
    '''

    def testDefaultInit(self):
        result = OkResult()
        self.assertEqual(result.affected_rows, 0)
        self.assertEqual(result.insert_id, 0)
        self.assertEqual(result.status, 0)
        self.assertEqual(result.warning_count, 0)
        self.assertEqual(result.message, '')
        self.assertEqual(result.version_40, False)
        str(result)

    def testDefaultInit40(self):
        result = OkResult(version_40=True)
        self.assertEqual(result.affected_rows, 0)
        self.assertEqual(result.insert_id, 0)
        self.assertEqual(result.status, 0)
        self.assertEqual(result.message, '')
        self.assertEqual(result.version_40, True)
        str(result)

    def testKeywordInit(self):
        result = OkResult(affected_rows=3, insert_id=5, status=2,
                          warning_count=7, message='test', version_40=False)
        self.assertEqual(result.affected_rows, 3)
        self.assertEqual(result.insert_id, 5)
        self.assertEqual(result.status, 2)
        self.assertEqual(result.warning_count, 7)
        self.assertEqual(result.message, 'test')
        self.assertEqual(result.version_40, False)

    def testUnpackInit(self):
        data = struct.pack('BBB', 0, 3, 5)
        data += struct.pack('<HH', 2, 7)
        data += b'test'

        result = OkResult(data)
        self.assertEqual(result.affected_rows, 3)
        self.assertEqual(result.insert_id, 5)
        self.assertEqual(result.status, 2)
        self.assertEqual(result.warning_count, 7)
        self.assertEqual(result.message, b'test')
        self.assertEqual(result.version_40, False)
        str(result)

    def testUnpackInit40(self):
        data = struct.pack('BBB', 0, 3, 5)
        data += struct.pack('<H', 2)
        data += b'test'

        result = OkResult(data, version_40=True)
        self.assertEqual(result.affected_rows, 3)
        self.assertEqual(result.insert_id, 5)
        self.assertEqual(result.status, 2)
        self.assertEqual(result.message, b'test')
        self.assertEqual(result.version_40, True)
        str(result)
00158
class ErrorResult(object):
    '''This class represents an error result packet sent from the server.

    Instances are built either from keyword arguments (``packed`` is None)
    or by parsing a raw packet passed as ``packed``, which may be a Python 2
    ``str`` or a Python 3 ``bytes``/``bytearray``.

    Layout parsed after the leading 0xff byte:
      * error_code      -- 2-byte little-endian
      * sqlstate_marker -- 1 byte (4.1 format only)
      * sqlstate        -- 5 bytes (4.1 format only)
      * message         -- rest of the packet

    ``sqlstate_marker`` and ``sqlstate`` are only set when ``version_40``
    is false.

    Raises:
        BadFieldCount: the first byte of ``packed`` is not 255.
    '''

    def __init__(self, packed=None, error_code=0, sqlstate_marker='#',
                 sqlstate='XXXXX', message='', version_40=False):
        self.version_40 = version_40
        if packed is None:
            self.error_code = error_code
            self.message = message
            if not version_40:
                self.sqlstate_marker = sqlstate_marker
                self.sqlstate = sqlstate
            return

        marker = packed[0]
        if not isinstance(marker, int):
            # str input: indexing yields a 1-character string, not an int.
            marker = ord(marker)
        if marker != 255:
            raise BadFieldCount('Expected 255, received ' + str(marker))
        self.error_code = struct.unpack('<H', packed[1:3])[0]
        if version_40:
            self.message = packed[3:]
        else:
            # Slice rather than index so the marker stays a 1-byte str/bytes
            # on Python 3 (indexing bytes would produce an int).
            self.sqlstate_marker = packed[3:4]
            self.sqlstate = packed[4:9]
            self.message = packed[9:]

    def __str__(self):
        if self.version_40:
            return ('ErrorResult\n'
                    'error_code = %s\n'
                    'message = %s\n'
                    'version_40 = %s\n') % (self.error_code, self.message,
                                            self.version_40)
        return ('ErrorResult\n'
                'error_code = %s\n'
                'sqlstate_marker = %s\n'
                'sqlstate = %s\n'
                'message = %s\n'
                'version_40 = %s\n') % (self.error_code, self.sqlstate_marker,
                                        self.sqlstate, self.message,
                                        self.version_40)
00200
class TestErrorResult(unittest.TestCase):
    '''Unit tests for ErrorResult keyword construction and packet parsing.

    Packet fixtures are built from bytes literals so that concatenating them
    with ``struct.pack`` output is valid on Python 3 as well as Python 2.
    '''

    def testDefaultInit(self):
        result = ErrorResult()
        self.assertEqual(result.error_code, 0)
        self.assertEqual(result.sqlstate_marker, '#')
        self.assertEqual(result.sqlstate, 'XXXXX')
        self.assertEqual(result.message, '')
        self.assertEqual(result.version_40, False)
        str(result)

    def testDefaultInit40(self):
        result = ErrorResult(version_40=True)
        self.assertEqual(result.error_code, 0)
        self.assertEqual(result.message, '')
        self.assertEqual(result.version_40, True)
        str(result)

    def testKeywordInit(self):
        result = ErrorResult(error_code=3, sqlstate_marker='@', sqlstate='ABCDE',
                             message='test', version_40=False)
        self.assertEqual(result.error_code, 3)
        self.assertEqual(result.sqlstate_marker, '@')
        self.assertEqual(result.sqlstate, 'ABCDE')
        self.assertEqual(result.message, 'test')
        self.assertEqual(result.version_40, False)
        str(result)

    def testUnpackInit(self):
        data = b'\xff'
        data += struct.pack('<H', 1234)
        data += b'#ABCDE'
        data += b'test'

        result = ErrorResult(data)
        self.assertEqual(result.error_code, 1234)
        self.assertEqual(result.sqlstate_marker, b'#')
        self.assertEqual(result.sqlstate, b'ABCDE')
        self.assertEqual(result.message, b'test')
        self.assertEqual(result.version_40, False)
        str(result)

    def testUnpackInit40(self):
        data = b'\xff'
        data += struct.pack('<H', 1234)
        data += b'test'

        result = ErrorResult(data, version_40=True)
        self.assertEqual(result.error_code, 1234)
        self.assertEqual(result.message, b'test')
        self.assertEqual(result.version_40, True)
        str(result)
00253
class EofResult(object):
    '''This class represents an EOF result packet sent from the server.

    Instances are built either from keyword arguments (``packed`` is None)
    or by parsing a raw packet passed as ``packed``, which may be a Python 2
    ``str`` or a Python 3 ``bytes``/``bytearray``.

    In the 4.1 format the leading 0xfe byte is followed by warning_count and
    status, both 2-byte little-endian; the 4.0 format reads nothing after
    the first byte.  ``warning_count`` and ``status`` are only set when
    ``version_40`` is false.

    Raises:
        BadFieldCount: the first byte of ``packed`` is not 254.
        struct.error: a 4.1 packet is not exactly 5 bytes long.
    '''

    def __init__(self, packed=None, warning_count=0, status=0, version_40=False):
        self.version_40 = version_40
        if packed is None:
            if not version_40:
                self.warning_count = warning_count
                self.status = status
            return

        marker = packed[0]
        if not isinstance(marker, int):
            # str input: indexing yields a 1-character string, not an int.
            marker = ord(marker)
        if marker != 254:
            raise BadFieldCount('Expected 254, received ' + str(marker))
        if not version_40:
            self.warning_count, self.status = struct.unpack('<HH', packed[1:])

    def __str__(self):
        if self.version_40:
            return 'EofResult\nversion_40 = %s\n' % (self.version_40,)
        return ('EofResult\n'
                'warning_count = %s\n'
                'status = %s\n'
                'version_40 = %s\n') % (self.warning_count, self.status,
                                        self.version_40)
00283
class TestEofResult(unittest.TestCase):
    '''Unit tests for EofResult keyword construction and packet parsing.'''

    def testDefaultInit(self):
        result = EofResult()
        self.assertEqual(result.warning_count, 0)
        self.assertEqual(result.status, 0)
        self.assertEqual(result.version_40, False)
        str(result)

    def testDefaultInit40(self):
        result = EofResult(version_40=True)
        self.assertEqual(result.version_40, True)
        str(result)

    def testKeywordInit(self):
        result = EofResult(warning_count=3, status=5, version_40=False)
        self.assertEqual(result.warning_count, 3)
        self.assertEqual(result.status, 5)
        self.assertEqual(result.version_40, False)
        str(result)

    def testUnpackInit(self):
        # A bytes literal keeps the concatenation with struct.pack() output
        # valid on Python 3 as well as Python 2.
        data = b'\xfe'
        data += struct.pack('<HH', 3, 5)

        result = EofResult(data)
        self.assertEqual(result.warning_count, 3)
        self.assertEqual(result.status, 5)
        self.assertEqual(result.version_40, False)
        str(result)

    def testUnpackInit40(self):
        # A one-byte packet: the 4.0 EOF parser reads nothing after byte 0.
        result = EofResult(chr(254), version_40=True)
        self.assertEqual(result.version_40, True)
        str(result)
00319
class CountResult(object):
    '''This class represents a count result packet sent from the server.

    Instances are built either from a keyword ``count`` (``packed`` is None)
    or by parsing ``packed``, which may be a Python 2 ``str`` or a Python 3
    ``bytes``/``bytearray``.  The count is decoded as a length-coded
    integer: a first byte of 1-250 is the count itself, while 252 and 253
    prefix a 2- and 3-byte little-endian count.

    Raises:
        BadFieldCount: the first byte cannot start a count (0, 251, 254 or
            255; 251 is the protocol's NULL marker, not a number).
        IndexError: the packet is truncated.
    '''

    # Number of value bytes following each accepted multi-byte prefix.
    _PREFIX_WIDTHS = {252: 2, 253: 3}

    def __init__(self, packed=None, count=0):
        if packed is None:
            self.count = count
            return

        first = self._byte_at(packed, 0)
        if 0 < first < 251:
            self.count = first
        elif first in self._PREFIX_WIDTHS:
            width = self._PREFIX_WIDTHS[first]
            self.count = sum(self._byte_at(packed, 1 + i) << (8 * i)
                             for i in range(width))
        else:
            raise BadFieldCount('Expected 1-253, received ' + str(first))

    @staticmethod
    def _byte_at(packed, index):
        '''Return byte ``index`` of ``packed`` as an int (str or bytes input).'''
        value = packed[index]
        if isinstance(value, int):
            return value
        return ord(value)

    def __str__(self):
        return 'CountResult\ncount = %s\n' % (self.count,)
00335
class TestCountResult(unittest.TestCase):
    '''Checks CountResult construction from defaults, keywords and a packet.'''

    def testDefaultInit(self):
        counted = CountResult()
        self.assertEqual(counted.count, 0)
        str(counted)

    def testKeywordInit(self):
        counted = CountResult(count=3)
        self.assertEqual(counted.count, 3)
        str(counted)

    def testUnpackInit(self):
        # A single-byte packet whose value is the count itself.
        counted = CountResult("\x03")
        self.assertEqual(counted.count, 3)
        str(counted)
00352
def create_result(packed, version_40=False):
    '''Create the result object that matches the first byte of ``packed``.

    A first byte of 0 selects OkResult, 254 EofResult and 255 ErrorResult;
    any other value is handed to CountResult, which validates it.
    ``version_40`` is forwarded to the OK, EOF and error parsers.  The
    dispatch byte is read correctly from a Python 2 ``str`` or a Python 3
    ``bytes``/``bytearray``.

    Raises:
        IndexError: ``packed`` is empty.
    '''
    first = packed[0]
    if not isinstance(first, int):
        # str input: indexing yields a 1-character string, not an int.
        first = ord(first)
    if first == 0:
        return OkResult(packed, version_40=version_40)
    if first == 254:
        return EofResult(packed, version_40=version_40)
    if first == 255:
        return ErrorResult(packed, version_40=version_40)
    return CountResult(packed)
00364
if __name__ == "__main__":
    # Running this module directly executes the unittest cases defined above.
    unittest.main()