Drizzled Public API Documentation

result.py
00001 #!/usr/bin/env python
00002 #
00003 # Drizzle Client & Protocol Library
00004 # 
00005 # Copyright (C) 2008 Eric Day (eday@oddments.org)
00006 # All rights reserved.
00007 #
00008 # Redistribution and use in source and binary forms, with or without
00009 # modification, are permitted provided that the following conditions are
00010 # met:
00011 #
00012 #     * Redistributions of source code must retain the above copyright
00013 # notice, this list of conditions and the following disclaimer.
00014 #
00015 #     * Redistributions in binary form must reproduce the above
00016 # copyright notice, this list of conditions and the following disclaimer
00017 # in the documentation and/or other materials provided with the
00018 # distribution.
00019 #
00020 #     * The names of its contributors may not be used to endorse or
00021 # promote products derived from this software without specific prior
00022 # written permission.
00023 #
00024 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
00025 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
00026 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
00027 # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
00028 # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
00029 # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
00030 # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
00031 # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
00032 # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00033 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
00034 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00035 #
00036 
00037 '''
00038 MySQL Protocol Result Objects
00039 '''
00040 
00041 import struct
00042 import unittest
00043 
class BadFieldCount(Exception):
  '''Raised when the first byte of a packet is not a value accepted by the
  result class that is decoding it.'''
00046 
class OkResult(object):
  '''This class represents an OK result packet sent from the server.

  An instance is either built from keyword arguments (packed is None) or
  decoded from a raw packet payload, which may be a str (Python 2) or a
  bytes object (Python 3).  When version_40 is true the pre-4.1 layout is
  used; it has no warning count, so warning_count is not set.
  '''

  def __init__(self, packed=None, affected_rows=0, insert_id=0, status=0,
               warning_count=0, message='', version_40=False):
    '''Build the result from keyword arguments or decode it from packed.

    Raises BadFieldCount if the first byte of packed is not 0, ValueError if
    affected_rows or insert_id starts with an invalid length prefix (251 or
    255), and IndexError or struct.error if packed is truncated.
    '''
    self.version_40 = version_40
    if packed is None:
      self.affected_rows = affected_rows
      self.insert_id = insert_id
      self.status = status
      self.message = message
      # Truthiness (not "is False") so the flag is interpreted the same way
      # here and in __str__ for any value.
      if not version_40:
        self.warning_count = warning_count
      return

    header = self._byte(packed, 0)
    if header != 0:
      raise BadFieldCount('Expected 0, received ' + str(header))
    # Both counters are length-encoded integers of variable width, so every
    # later field is located relative to pos rather than a fixed offset.
    self.affected_rows, pos = self._unpack_length(packed, 1)
    self.insert_id, pos = self._unpack_length(packed, pos)
    if version_40:
      if len(packed) == pos:
        # A 4.0 packet may end right after insert_id.
        self.status = 0
        self.message = ''
      else:
        self.status = struct.unpack('<H', packed[pos:pos + 2])[0]
        self.message = packed[pos + 2:]
    else:
      self.status, self.warning_count = struct.unpack('<HH',
                                                     packed[pos:pos + 4])
      self.message = packed[pos + 4:]

  @staticmethod
  def _byte(packed, index):
    '''Return packed[index] as an int for str (Python 2) or bytes payloads.'''
    value = packed[index]
    return value if isinstance(value, int) else ord(value)

  @classmethod
  def _unpack_length(cls, packed, pos):
    '''Decode the length-encoded integer starting at packed[pos].

    A first byte below 251 is the value itself; 252, 253 and 254 are
    followed by a 2, 3 or 8 byte little-endian value.  Returns a
    (value, next_pos) tuple.  Raises ValueError for the 251 and 255
    prefixes, which do not encode an integer.
    '''
    first = cls._byte(packed, pos)
    if first < 251:
      return first, pos + 1
    size = {252: 2, 253: 3, 254: 8}.get(first)
    if size is None:
      raise ValueError('Invalid length-encoded integer prefix ' + str(first))
    value = 0
    for shift in range(size):
      value |= cls._byte(packed, pos + 1 + shift) << (8 * shift)
    return value, pos + 1 + size

  def __str__(self):
    if self.version_40:
      return '''OkResult
  affected_rows = %s
  insert_id = %s
  status = %s
  message = %s
  version_40 = %s
''' % (self.affected_rows, self.insert_id, self.status, self.message,
       self.version_40)
    else:
      return '''OkResult
  affected_rows = %s
  insert_id = %s
  status = %s
  warning_count = %s
  message = %s
  version_40 = %s
''' % (self.affected_rows, self.insert_id, self.status, self.warning_count,
       self.message, self.version_40)
00100 
class TestOkResult(unittest.TestCase):
  '''Checks OkResult construction from defaults, keywords and raw packets.'''

  def _assertFields(self, result, expected):
    '''Assert every (attribute, value) pair in expected, in order.'''
    for attribute, value in expected:
      self.assertEqual(getattr(result, attribute), value)

  def testDefaultInit(self):
    result = OkResult()
    self._assertFields(result, (('affected_rows', 0), ('insert_id', 0),
                                ('status', 0), ('warning_count', 0),
                                ('message', ''), ('version_40', False)))
    result.__str__()

  def testDefaultInit40(self):
    result = OkResult(version_40=True)
    self._assertFields(result, (('affected_rows', 0), ('insert_id', 0),
                                ('status', 0), ('message', ''),
                                ('version_40', True)))
    result.__str__()

  def testKeywordInit(self):
    result = OkResult(affected_rows=3, insert_id=5, status=2,
                      warning_count=7, message='test', version_40=False)
    self._assertFields(result, (('affected_rows', 3), ('insert_id', 5),
                                ('status', 2), ('warning_count', 7),
                                ('message', 'test'), ('version_40', False)))

  def testUnpackInit(self):
    packet = (struct.pack('BBB', 0, 3, 5) + struct.pack('<HH', 2, 7) +
              'test')
    result = OkResult(packet)
    self._assertFields(result, (('affected_rows', 3), ('insert_id', 5),
                                ('status', 2), ('warning_count', 7),
                                ('message', 'test'), ('version_40', False)))
    result.__str__()

  def testUnpackInit40(self):
    packet = struct.pack('BBB', 0, 3, 5) + struct.pack('<H', 2) + 'test'
    result = OkResult(packet, version_40=True)
    self._assertFields(result, (('affected_rows', 3), ('insert_id', 5),
                                ('status', 2), ('message', 'test'),
                                ('version_40', True)))
    result.__str__()
00158 
class ErrorResult(object):
  '''This class represents an error result packet sent from the server.

  An instance is either built from keyword arguments (packed is None) or
  decoded from a raw packet payload, which may be a str (Python 2) or bytes
  (Python 3).  When version_40 is true the pre-4.1 layout is used; it has
  no SQLSTATE, so sqlstate_marker and sqlstate are not set.
  '''

  def __init__(self, packed=None, error_code=0, sqlstate_marker='#',
               sqlstate='XXXXX', message='', version_40=False):
    '''Build the result from keyword arguments or decode it from packed.

    Raises BadFieldCount if the first byte of packed is not 255, and
    struct.error if the error code is truncated.
    '''
    self.version_40 = version_40
    if packed is None:
      self.error_code = error_code
      self.message = message
      # Truthiness (not "is False") so the flag is interpreted the same way
      # here and in __str__ for any value.
      if not version_40:
        self.sqlstate_marker = sqlstate_marker
        self.sqlstate = sqlstate
      return

    header = packed[0]
    if not isinstance(header, int):
      # Indexing a Python 2 str yields a one-character string.
      header = ord(header)
    if header != 255:
      raise BadFieldCount('Expected 255, received ' + str(header))
    self.error_code = struct.unpack('<H', packed[1:3])[0]
    if version_40:
      self.message = packed[3:]
    else:
      # Slicing (rather than indexing) keeps the marker the same type as
      # sqlstate and message on both Python 2 and Python 3.
      self.sqlstate_marker = packed[3:4]
      self.sqlstate = packed[4:9]
      self.message = packed[9:]

  def __str__(self):
    if self.version_40:
      return '''ErrorResult
  error_code = %s
  message = %s
  version_40 = %s
''' % (self.error_code, self.message, self.version_40)
    else:
      return '''ErrorResult
  error_code = %s
  sqlstate_marker = %s
  sqlstate = %s
  message = %s
  version_40 = %s
''' % (self.error_code, self.sqlstate_marker, self.sqlstate, self.message,
       self.version_40)
00200 
class TestErrorResult(unittest.TestCase):
  '''Checks ErrorResult construction from defaults, keywords and packets.'''

  def _verify(self, result, expected):
    '''Assert each (attribute, value) pair in order, then format result.'''
    for attribute, value in expected:
      self.assertEqual(getattr(result, attribute), value)
    result.__str__()

  def testDefaultInit(self):
    self._verify(ErrorResult(),
                 (('error_code', 0), ('sqlstate_marker', '#'),
                  ('sqlstate', 'XXXXX'), ('message', ''),
                  ('version_40', False)))

  def testDefaultInit40(self):
    self._verify(ErrorResult(version_40=True),
                 (('error_code', 0), ('message', ''), ('version_40', True)))

  def testKeywordInit(self):
    result = ErrorResult(error_code=3, sqlstate_marker='@', sqlstate='ABCDE',
                         message='test', version_40=False)
    self._verify(result,
                 (('error_code', 3), ('sqlstate_marker', '@'),
                  ('sqlstate', 'ABCDE'), ('message', 'test'),
                  ('version_40', False)))

  def testUnpackInit(self):
    packet = chr(255) + struct.pack('<H', 1234) + '#ABCDE' + 'test'
    self._verify(ErrorResult(packet),
                 (('error_code', 1234), ('sqlstate_marker', '#'),
                  ('sqlstate', 'ABCDE'), ('message', 'test'),
                  ('version_40', False)))

  def testUnpackInit40(self):
    packet = chr(255) + struct.pack('<H', 1234) + 'test'
    self._verify(ErrorResult(packet, version_40=True),
                 (('error_code', 1234), ('message', 'test'),
                  ('version_40', True)))
00253 
class EofResult(object):
  '''This class represents an EOF result packet sent from the server.

  In the 4.1 layout (version_40 false) the packet carries warning_count and
  status.  The pre-4.1 layout carries neither, so those attributes are not
  set.  packed may be a str (Python 2) or bytes (Python 3).
  '''

  def __init__(self, packed=None, warning_count=0, status=0, version_40=False):
    '''Build the result from keyword arguments or decode it from packed.

    Raises BadFieldCount if the first byte of packed is not 254, and
    struct.error if a 4.1 payload is not exactly 4 bytes after that byte.
    '''
    self.version_40 = version_40
    if packed is not None:
      header = packed[0]
      if not isinstance(header, int):
        # Indexing a Python 2 str yields a one-character string.
        header = ord(header)
      if header != 254:
        raise BadFieldCount('Expected 254, received ' + str(header))
      if not version_40:
        # The decoded values replace the keyword arguments below.
        warning_count, status = struct.unpack('<HH', packed[1:])
    # Truthiness (not "is False") so the flag is interpreted the same way
    # here and in __str__ for any value.
    if not version_40:
      self.warning_count = warning_count
      self.status = status

  def __str__(self):
    if self.version_40:
      return '''EofResult
  version_40 = %s
''' % self.version_40
    return '''EofResult
  warning_count = %s
  status = %s
  version_40 = %s
''' % (self.warning_count, self.status, self.version_40)
00283 
class TestEofResult(unittest.TestCase):
  '''Checks EofResult construction from defaults, keywords and packets.'''

  def _verify(self, result, expected):
    '''Assert each (attribute, value) pair in order, then format result.'''
    for attribute, value in expected:
      self.assertEqual(getattr(result, attribute), value)
    result.__str__()

  def testDefaultInit(self):
    self._verify(EofResult(),
                 (('warning_count', 0), ('status', 0), ('version_40', False)))

  def testDefaultInit40(self):
    self._verify(EofResult(version_40=True), (('version_40', True),))

  def testKeywordInit(self):
    self._verify(EofResult(warning_count=3, status=5, version_40=False),
                 (('warning_count', 3), ('status', 5), ('version_40', False)))

  def testUnpackInit(self):
    packet = chr(254) + struct.pack('<HH', 3, 5)
    self._verify(EofResult(packet),
                 (('warning_count', 3), ('status', 5), ('version_40', False)))

  def testUnpackInit40(self):
    self._verify(EofResult(chr(254), version_40=True),
                 (('version_40', True),))
00319 
class CountResult(object):
  '''This class represents a count result packet sent from the server.

  The count is decoded as a length-encoded integer.  A first byte of 1-251
  is the count itself, while 252 and 253 are followed by a 2 or 3 byte
  little-endian count.  packed may be a str (Python 2) or bytes (Python 3).
  '''

  def __init__(self, packed=None, count=0):
    '''Build the result from count or decode it from packed.

    Raises BadFieldCount if the first byte of packed is 0, 254 or 255, and
    IndexError if a multi-byte count is truncated.
    '''
    if packed is None:
      self.count = count
      return

    def byte_at(index):
      # Indexing a Python 2 str yields a one-character string, bytes an int.
      value = packed[index]
      return value if isinstance(value, int) else ord(value)

    first = byte_at(0)
    if first == 0 or first > 253:
      raise BadFieldCount('Expected 1-253, received ' + str(first))
    if first < 252:
      # NOTE(review): 251 is the NULL marker of a length-encoded integer;
      # it is still accepted as a literal count, as before.
      self.count = first
      return
    self.count = 0
    for shift in range(first - 250):  # 252 -> 2 bytes, 253 -> 3 bytes
      self.count |= byte_at(1 + shift) << (8 * shift)

  def __str__(self):
    return '''CountResult
  count = %s
''' % self.count
00335 
class TestCountResult(unittest.TestCase):
  '''Checks CountResult construction from defaults, keywords and packets.'''

  def _checkCount(self, result, expected):
    '''Assert the stored count, then make sure formatting works.'''
    self.assertEqual(result.count, expected)
    result.__str__()

  def testDefaultInit(self):
    self._checkCount(CountResult(), 0)

  def testKeywordInit(self):
    self._checkCount(CountResult(count=3), 3)

  def testUnpackInit(self):
    self._checkCount(CountResult("\x03"), 3)
00352 
def create_result(packed, version_40=False):
  '''Create the result object that matches the first byte of packed.

  0 selects OkResult, 254 EofResult and 255 ErrorResult (each given
  version_40); any other value is handled by CountResult.  The first byte is
  read as an integer whether packed is a str (Python 2) or bytes (Python 3).
  Raises IndexError if packed is empty.
  '''
  first = packed[0]
  if not isinstance(first, int):
    first = ord(first)
  handler = {0: OkResult, 254: EofResult, 255: ErrorResult}.get(first)
  if handler is None:
    return CountResult(packed)
  return handler(packed, version_40=version_40)
00364 
if __name__ == '__main__':
  # Running the module directly executes every unittest.TestCase it defines.
  unittest.main(module='__main__')