# NOTE(review): This module is stored in a damaged form and is not valid
# Python as-is. Many source lines have been joined onto each physical line.
# The first physical line is a single `#` comment that also swallows the
# module docstring, the imports and the start of MySQLBaseSocket. Some spans
# also appear to have been stripped out, e.g. `struct.pack(' 16777215:` in
# `_prepare_packets`/`send_compressed`, and `struct.unpack(" 0:
# self.connect(*args, **kwargs)`. At that second spot the end of
# `recv_compressed`, the MySQLUnixSocket/MySQLTCPSocket classes used by
# `_get_connection`, and the header/`__init__` of the connection class are
# missing. Restore this module from version control before changing code.
#
# NOTE(review): defects visible in the intact connection-class methods,
# to fix once the file is restored:
# - connect(): `errors.NotSupportedError(...)` for `dsn` is constructed but
#   never raised, so a `dsn` argument is silently ignored.
# - connect(): the `autocommit` argument is not forwarded to
#   `_post_connection()`, which then applies its own default (False).
# - _open_connection(): `"MySQL Version %s is not supported." % version`
#   uses the version tuple as the `%` argument tuple. For a tuple with more
#   than one element this raises TypeError instead of the InterfaceError.
# - disconnect(): if `cmd_quit()` raises, `close_connection()` and
#   `self.protocol = None` are skipped.
# - is_connected()/ping(): once `disconnect()` has set `self.protocol = None`,
#   this raises AttributeError rather than reporting "not connected".
# - set_time_zone()/set_sql_mode(): the value is interpolated without
#   quotes. A time zone such as +00:00 would presumably be rejected by the
#   server (confirm). They also call `protocol.cmd_query()` directly,
#   bypassing the unread-result check in `_execute_query()`.
# - set_database(): `"USE %s" % value` interpolates the name unquoted and
#   unescaped, so it is only safe for trusted input.
# - set_login(): `password.strip()` alters passwords that legitimately
#   begin or end with whitespace.
# - _info_query(): the cursor is not closed if `execute()`/`fetchone()`
#   raises (there is no `finally`).
# - Several `try: ... except: raise` wrappers are no-ops. The bare
#   `except: pass` handlers (close_connection, get_server_version,
#   connection_id, disconnect) hide every error, including
#   KeyboardInterrupt.
# MySQL Connector/Python - MySQL driver written in Python. # Copyright (c) 2009,2011, Oracle and/or its affiliates. All rights reserved. # Use is subject to license terms. (See COPYING) # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation. # # There are special exceptions to the terms and conditions of the GNU # General Public License as it is applied to this software. View the # full text of the exception in file EXCEPTIONS-CLIENT in the directory # of this software distribution or see the FOSS License Exception at # www.mysql.com. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with this program; if not, write to the Free Software # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA """Implementing communication to MySQL servers """ import socket import struct import os import weakref from collections import deque import zlib try: import ssl except ImportError: pass import constants import conversion import protocol import errors import utils import cursor MAX_PACKET_LENGTH = 16777215 class MySQLBaseSocket(object): """Base class for MySQL Connections subclasses. Should not be used directly but overloaded, changing the open_connection part. 
Examples of subclasses are MySQLTCPSocket MySQLUnixSocket """ def __init__(self): self.sock = None # holds the socket connection self.connection_timeout = None self.buffer = deque() self.recvsize = 8192 self.send = self.send_plain self.recv = self.recv_plain def open_connection(self): pass def close_connection(self): try: self.sock.close() except: pass def get_address(self): pass def _prepare_packets(self, buf, pktnr): pkts = [] pllen = len(buf) while pllen > MAX_PACKET_LENGTH: pkts.append('\xff\xff\xff' + struct.pack(' 16777215: pkts = self._prepare_packets(buf,pktnr) tmpbuf = ''.join(pkts) del pkts seqid = 0 zbuf = zlib.compress(tmpbuf[:16384]) zpkts.append(struct.pack(' MAX_PACKET_LENGTH: zbuf = zlib.compress(tmpbuf[:MAX_PACKET_LENGTH]) zpkts.append(struct.pack(' 50: zbuf = zlib.compress(pkt) zpkts.append(struct.pack('= 4: pktsize = struct.unpack(" 0 and totalsize >= pktsize+4: size = pktsize+4 self.buffer.append(buf[0:size]) buf = buf[size:] pktsize = 0 if not buf: try: buf = self.buffer.popleft() if buf[4] == '\xff': errors.raise_error(buf) else: return buf except IndexError, e: break elif totalsize < pktsize+4: buf += self.sock.recv(self.recvsize) except socket.timeout, e: raise errors.InterfaceError(errno=2013) except socket.error, e: raise errors.InterfaceError(errno=2055, values=dict(socketaddr=self.get_address(),errno=e.errno)) except: raise def recv_compressed(self): try: return self.buffer.popleft() except IndexError: pass pkts = [] zpktsize = 0 try: buf = self.sock.recv(self.recvsize) while buf: totalsize = len(buf) if zpktsize == 0 and totalsize >= 7: zpktsize = struct.unpack(" 0 and totalsize >= zpktsize+7: size = zpktsize+7 pkts.append(buf[0:size]) buf = buf[size:] zpktsize = 0 # Keep reading for packets that were to big if pktsize == 16384: buf = self.sock.recv(self.recvsize) elif not buf: break zpktsize = 0 elif totalsize < pktsize+7: buf += self.sock.recv(self.recvsize) except socket.timeout, e: raise errors.InterfaceError(errno=2013) except 
socket.error, e: raise errors.InterfaceError(errno=2055, values=dict(socketaddr=self.get_address(),errno=e.errno)) except: raise bigbuf = '' tmp = [] for pkt in pkts: pktsize = struct.unpack(" 0: self.connect(*args, **kwargs) def connect(self, database=None, user='', password='', host='127.0.0.1', port=3306, unix_socket=None, use_unicode=True, charset='utf8', collation=None, autocommit=False, time_zone=None, sql_mode=None, get_warnings=False, raise_on_warnings=False, connection_timeout=None, client_flags=0, buffered=False, raw=False, ssl_ca=None, ssl_cert=None, ssl_key=None, passwd=None, db=None, connect_timeout=None, dsn=None): if db and not database: database = db if passwd and not password: password = passwd if connect_timeout and not connection_timeout: connection_timeout = connect_timeout if dsn is not None: errors.NotSupportedError("Data source name is not supported") self._server_host = host self._server_port = port self._unix_socket = unix_socket if database is not None: self._database = database.strip() else: self._database = None self._username = user self.set_warnings(get_warnings,raise_on_warnings) self.connection_timeout = connection_timeout self.buffered = buffered self.raw = raw self.use_unicode = use_unicode self.set_client_flags(client_flags) self._charset = constants.CharacterSet.get_charset_info(charset)[0] if user or password: self.set_login(user, password) self.disconnect() self._open_connection(username=user, password=password, database=database, client_flags=self.client_flags, charset=charset, ssl=(ssl_ca, ssl_cert, ssl_key)) self._post_connection(time_zone=time_zone, sql_mode=sql_mode, collation=collation) def _get_connection(self, prtcls=None): """Get connection based on configuration This method will return the appropriated connection object using the connection parameters. Returns subclass of MySQLBaseSocket. 
""" conn = None if self.unix_socket and os.name != 'nt': conn = MySQLUnixSocket(unix_socket=self.unix_socket) else: conn = MySQLTCPSocket(host=self.server_host, port=self.server_port) conn.set_connection_timeout(self.connection_timeout) return conn def _open_connection(self, username=None, password=None, database=None, client_flags=None, charset=None, ssl=None): """Opens the connection Open the connection, check the MySQL version, and set the protocol. """ try: self.protocol = protocol.MySQLProtocol(self._get_connection()) self.protocol.do_handshake() version = self.protocol.server_version if version < (4,1): raise errors.InterfaceError( "MySQL Version %s is not supported." % version) if client_flags & constants.ClientFlag.SSL: self.protocol.conn.set_ssl(*ssl) self.protocol.do_auth(username, password, database, client_flags, self._charset) (self._charset, self.charset_name, c) = \ constants.CharacterSet.get_charset_info(charset) self.set_converter_class(conversion.MySQLConverter) if client_flags & constants.ClientFlag.COMPRESS: self.protocol.conn.recv = self.protocol.conn.recv_compressed self.protocol.conn.send = self.protocol.conn.send_compressed except: raise def _post_connection(self, time_zone=None, autocommit=False, sql_mode=None, collation=None): """Post connection session setup Should be called after a connection was established""" try: if collation is not None: self.collation = collation self.autocommit = autocommit if time_zone is not None: self.time_zone = time_zone if sql_mode is not None: self.sql_mode = sql_mode except: raise def is_connected(self): """ Check whether we are connected to the MySQL server. """ return self.protocol.cmd_ping() ping = is_connected def disconnect(self): """ Disconnect from the MySQL server. 
""" if not self.protocol: return if self.protocol.conn.sock is not None: self.protocol.cmd_quit() try: self.protocol.conn.close_connection() except: pass self.protocol = None def set_converter_class(self, convclass): """ Set the converter class to be used. This should be a class overloading methods and members of conversion.MySQLConverter. """ self.converter_class = convclass self.converter = convclass(self.charset_name, self.use_unicode) def get_server_version(self): """Returns the server version as a tuple""" try: return self.protocol.server_version except: pass return None def get_server_info(self): """Returns the server version as a string""" return self.protocol.server_version_original @property def connection_id(self): """MySQL connection ID""" threadid = None try: threadid = self.protocol.server_threadid except: pass return threadid def set_login(self, username=None, password=None): """Set login information for MySQL Set the username and/or password for the user connecting to the MySQL Server. """ if username is not None: self.username = username.strip() else: self.username = '' if password is not None: self.password = password.strip() else: self.password = '' def set_unicode(self, value=True): """Toggle unicode mode Set whether we return string fields as unicode or not. Default is True. 
""" self.use_unicode = value if self.converter: self.converter.set_unicode(value) def set_charset(self, charset): try: (idx, charset_name, c) = \ constants.CharacterSet.get_charset_info(charset) self._execute_query("SET NAMES '%s'" % charset_name) except: raise else: self._charset = idx self.charset_name = charset_name self.converter.set_charset(charset_name) def get_charset(self): return self._info_query( "SELECT @@session.character_set_connection")[0] charset = property(get_charset, set_charset, doc="Character set for this connection") def set_collation(self, collation): try: self._execute_query( "SET @@session.collation_connection = '%s'" % collation) except: raise def get_collation(self): return self._info_query( "SELECT @@session.collation_connection")[0] collation = property(get_collation, set_collation, doc="Collation for this connection") def set_warnings(self, fetch=False, raise_on_warnings=False): """Set how to handle warnings coming from MySQL Set wheter we should get warnings whenever an operation produced some. If you set raise_on_warnings to True, any warning will be raised as a DataError exception. """ if raise_on_warnings is True: self.get_warnings = True self.raise_on_warnings = True else: self.get_warnings = fetch self.raise_on_warnings = False def set_client_flags(self, flags): """Set the client flags The flags-argument can be either an int or a list (or tuple) of ClientFlag-values. If it is an integer, it will set client_flags to flags as is. If flags is a list (or tuple), each flag will be set or unset when it's negative. 
set_client_flags([ClientFlag.FOUND_ROWS,-ClientFlag.LONG_FLAG]) Returns self.client_flags """ if isinstance(flags,int) and flags > 0: self.client_flags = flags else: if isinstance(flags,(tuple,list)): for f in flags: if f < 0: self.unset_client_flag(abs(f)) else: self.set_client_flag(f) return self.client_flags def set_client_flag(self, flag): if flag > 0: self.client_flags |= flag def unset_client_flag(self, flag): if flag > 0: self.client_flags &= ~flag def isset_client_flag(self, flag): if (self.client_flags & flag) > 0: return True return False @property def user(self): """User used while connecting to MySQL""" return self._username @property def server_host(self): """MySQL server IP address or name""" return self._server_host @property def server_port(self): "MySQL server TCP/IP port" return self._server_port @property def unix_socket(self): "MySQL Unix socket file location" return self._unix_socket def set_database(self, value): try: self.protocol.cmd_query("USE %s" % value) except: raise def get_database(self): """Get the current database""" return self._info_query("SELECT DATABASE()")[0] database = property(get_database, set_database, doc="Current database") def set_time_zone(self, value): try: self.protocol.cmd_query("SET @@session.time_zone = %s" % value) except: raise def get_time_zone(self): return self._info_query("SELECT @@session.time_zone")[0] time_zone = property(get_time_zone, set_time_zone, doc="time_zone value for current MySQL session") def set_sql_mode(self, value): try: self.protocol.cmd_query("SET @@session.sql_mode = %s" % value) except: raise def get_sql_mode(self): return self._info_query("SELECT @@session.sql_mode")[0] sql_mode = property(get_sql_mode, set_sql_mode, doc="sql_mode value for current MySQL session") def set_autocommit(self, value): try: if value: s = 'ON' else: s = 'OFF' self._execute_query("SET @@session.autocommit = %s" % s) except: raise def get_autocommit(self): value = self._info_query("SELECT @@session.autocommit")[0] 
if value == 1: return True return False autocommit = property(get_autocommit, set_autocommit, doc="autocommit value for current MySQL session") def close(self): del self.cursors[:] self.disconnect() def remove_cursor(self, c): try: self.cursors.remove(c) except ValueError: raise errors.ProgrammingError( "Cursor could not be removed.") def cursor(self, buffered=None, raw=None, cursor_class=None): """Instantiates and returns a cursor By default, MySQLCursor is returned. Depending on the options while connecting, a buffered and/or raw cursor instantiated instead. It is possible to also give a custom cursor through the cursor_class paramter, but it needs to be a subclass of mysql.connector.cursor.CursorBase. Returns a cursor-object """ if cursor_class is not None: if not issubclass(cursor_class, cursor.CursorBase): raise errors.ProgrammingError( "Cursor class needs be subclass of cursor.CursorBase") c = (cursor_class)(self) else: buffered = buffered or self.buffered raw = raw or self.raw t = 0 if buffered is True: t |= 1 if raw is True: t |= 2 types = { 0 : cursor.MySQLCursor, 1 : cursor.MySQLCursorBuffered, 2 : cursor.MySQLCursorRaw, 3 : cursor.MySQLCursorBufferedRaw, } c = (types[t])(self) if c not in self.cursors: self.cursors.append(c) return c def commit(self): """Commit current transaction""" self._execute_query("COMMIT") def rollback(self): """Rollback current transaction""" self._execute_query("ROLLBACK") def _execute_query(self, query): if self.unread_result is True: raise errors.InternalError("Unread result found.") self.protocol.cmd_query(query) def _info_query(self, query): try: cur = self.cursor(buffered=True) cur.execute(query) row = cur.fetchone() cur.close() except: raise return row