Files
gtav-src/tools_ng/bin/mpLogGrabber/mysql/connector/connection.py
T
2025-09-29 00:52:08 +02:00

775 lines
26 KiB
Python
Executable File

# MySQL Connector/Python - MySQL driver written in Python.
# Copyright (c) 2009,2011, Oracle and/or its affiliates. All rights reserved.
# Use is subject to license terms. (See COPYING)
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation.
#
# There are special exceptions to the terms and conditions of the GNU
# General Public License as it is applied to this software. View the
# full text of the exception in file EXCEPTIONS-CLIENT in the directory
# of this software distribution or see the FOSS License Exception at
# www.mysql.com.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
"""Implementing communication to MySQL servers
"""
import socket
import struct
import os
import weakref
from collections import deque
import zlib
try:
import ssl
except ImportError:
pass
import constants
import conversion
import protocol
import errors
import utils
import cursor
# Largest payload a single MySQL packet can carry: the packet header stores
# the payload length in 3 bytes, so 0xffffff (2**24 - 1) is the maximum.
MAX_PACKET_LENGTH = 16777215


class MySQLBaseSocket(object):
    """Base class for MySQL Connections subclasses.

    Should not be used directly but overloaded, changing the
    open_connection part. Examples of subclasses are
      MySQLTCPSocket
      MySQLUnixSocket

    Data travels as MySQL packets: a 3-byte little-endian payload length,
    a 1-byte sequence id, then the payload. ``send`` and ``recv`` are
    instance attributes bound to send_plain/recv_plain in __init__; they can
    be rebound to send_compressed/recv_compressed.
    """

    def __init__(self):
        self.sock = None  # holds the socket connection
        self.connection_timeout = None
        # Complete packets that were received but not yet handed out.
        self.buffer = deque()
        self.recvsize = 8192  # bytes requested per socket recv() call
        self.send = self.send_plain
        self.recv = self.recv_plain

    def open_connection(self):
        """Open the connection; implemented by subclasses."""
        pass

    def close_connection(self):
        """Close the socket, ignoring any error (best effort)."""
        try:
            self.sock.close()
        except Exception:
            # Nothing useful can be done if closing fails, or if no socket
            # was ever opened (self.sock is still None).
            pass

    def get_address(self):
        """Return a printable server address; implemented by subclasses."""
        pass

    def _prepare_packets(self, buf, pktnr):
        """Split *buf* into MySQL packets, numbering them from *pktnr*.

        A payload of MAX_PACKET_LENGTH bytes or more is cut into full-size
        packets followed by a final shorter packet. That final packet is
        empty when the payload is an exact multiple of MAX_PACKET_LENGTH;
        the protocol requires it so the server knows the payload ended.

        Returns a list of packet strings (header + payload).
        """
        pkts = []
        full_header = struct.pack('<I', MAX_PACKET_LENGTH)[0:3]
        while len(buf) >= MAX_PACKET_LENGTH:
            pkts.append(full_header + struct.pack('<B', pktnr)
                        + buf[:MAX_PACKET_LENGTH])
            buf = buf[MAX_PACKET_LENGTH:]
            pktnr = pktnr + 1
        pkts.append(struct.pack('<I', len(buf))[0:3]
                    + struct.pack('<B', pktnr) + buf)
        return pkts

    def send(self):
        """Placeholder; replaced per instance in __init__."""
        pass

    def send_plain(self, buf, pktnr):
        """Send *buf* as uncompressed MySQL packets starting at *pktnr*.

        Raises errors.OperationalError when writing to the socket fails.
        """
        pkts = self._prepare_packets(buf, pktnr)
        for pkt in pkts:
            try:
                # sendall() writes the remainder after a partial send; a
                # loop calling send(pkt) would re-send the packet's start.
                self.sock.sendall(pkt)
            except Exception as e:
                raise errors.OperationalError('%s' % e)

    @staticmethod
    def _compressed_header(zlen, seqid, ulen):
        """Build a 7-byte compressed-protocol header.

        The header holds the 3-byte compressed length, the 1-byte sequence
        id and the 3-byte uncompressed length. An uncompressed length of 0
        means the data is stored as-is.
        """
        return (struct.pack('<I', zlen)[0:3] + struct.pack('<B', seqid)
                + struct.pack('<I', ulen)[0:3])

    def send_compressed(self, buf, pktnr):
        """Send *buf* using the MySQL compressed protocol.

        The payload is first framed as regular MySQL packets. That byte
        stream is then wrapped in compressed packets whose sequence ids
        start at 0. Streams of 50 bytes or fewer are sent uncompressed.
        Raises errors.OperationalError when writing to the socket fails.
        """
        header = self._compressed_header
        # Framing with _prepare_packets means payloads at or just below
        # MAX_PACKET_LENGTH get a correct length field and, when needed,
        # their terminating empty packet.
        plain = b''.join(self._prepare_packets(buf, pktnr))
        zpkts = []
        if len(plain) > MAX_PACKET_LENGTH:
            # Too large for one compressed packet: send the first 16384
            # bytes, then MAX_PACKET_LENGTH-sized chunks, then the rest.
            seqid = 0
            zbuf = zlib.compress(plain[:16384])
            zpkts.append(header(len(zbuf), seqid, 16384) + zbuf)
            rest = plain[16384:]
            seqid = seqid + 1
            while len(rest) > MAX_PACKET_LENGTH:
                zbuf = zlib.compress(rest[:MAX_PACKET_LENGTH])
                zpkts.append(header(len(zbuf), seqid, MAX_PACKET_LENGTH)
                             + zbuf)
                rest = rest[MAX_PACKET_LENGTH:]
                seqid = seqid + 1
            if rest:
                zbuf = zlib.compress(rest)
                zpkts.append(header(len(zbuf), seqid, len(rest)) + zbuf)
        elif len(plain) > 50:
            zbuf = zlib.compress(plain)
            zpkts.append(header(len(zbuf), 0, len(plain)) + zbuf)
        else:
            # Not worth compressing; uncompressed length 0 marks raw data.
            zpkts.append(header(len(plain), 0, 0) + plain)
        for zpkt in zpkts:
            try:
                self.sock.sendall(zpkt)
            except Exception as e:
                raise errors.OperationalError('%s' % e)

    def recv(self):
        """Placeholder; replaced per instance in __init__."""
        pass

    def _recv_chunk(self):
        """Read up to self.recvsize bytes from the socket.

        An empty read means the server closed the connection. It is
        reported as errno 2013 (the code also used for a read timeout)
        instead of being retried forever.
        """
        chunk = self.sock.recv(self.recvsize)
        if not chunk:
            raise errors.InterfaceError(errno=2013)
        return chunk

    @staticmethod
    def _raise_if_error(pkt):
        """Return *pkt*, or hand it to errors.raise_error() if it is an
        error packet (first payload byte 0xff)."""
        if pkt[4:5] == b'\xff':
            errors.raise_error(pkt)
        else:
            return pkt

    @staticmethod
    def _ends_on_packet_boundary(data):
        """Return True if *data* consists only of complete MySQL packets."""
        pos = 0
        while pos + 4 <= len(data):
            pos += struct.unpack('<I', data[pos:pos + 3] + b'\x00')[0] + 4
        return pos == len(data)

    def recv_plain(self):
        """Return the next uncompressed MySQL packet, header included.

        Packets buffered by an earlier read are returned first. Otherwise the
        socket is read until the received bytes end exactly on a packet
        boundary. Every complete packet is queued in self.buffer and the
        first one is returned. Error packets are passed to
        errors.raise_error().

        Raises errors.InterfaceError with errno 2013 on a timeout or a
        closed connection, and errno 2055 on other socket errors.
        """
        if self.buffer:
            return self._raise_if_error(self.buffer.popleft())
        try:
            buf = self._recv_chunk()
            while True:
                if len(buf) < 4:
                    # The 4-byte header is not complete yet.
                    buf += self._recv_chunk()
                    continue
                size = struct.unpack('<I', buf[0:3] + b'\x00')[0] + 4
                if len(buf) < size:
                    buf += self._recv_chunk()
                    continue
                self.buffer.append(buf[:size])
                buf = buf[size:]
                if not buf:
                    break
        except socket.timeout:
            raise errors.InterfaceError(errno=2013)
        except socket.error as e:
            raise errors.InterfaceError(errno=2055,
                values=dict(socketaddr=self.get_address(), errno=e.errno))
        return self._raise_if_error(self.buffer.popleft())

    def recv_compressed(self):
        """Return the next MySQL packet received over the compressed protocol.

        Buffered packets are returned first. Otherwise compressed packets are
        read (see _compressed_header for their layout). Reading continues
        until all received bytes are consumed and the decompressed data ends
        on a MySQL packet boundary. The decompressed stream is then split
        into packets and queued in self.buffer. The first packet is returned,
        or None if the stream was empty. Error packets are not inspected.

        Raises errors.InterfaceError with errno 2013 on a timeout or a
        closed connection, and errno 2055 on other socket errors.
        """
        try:
            return self.buffer.popleft()
        except IndexError:
            pass

        parts = []  # decompressed payloads, in arrival order
        buf = b''
        try:
            while True:
                while len(buf) < 7:
                    buf += self._recv_chunk()
                zpktsize = struct.unpack('<I', buf[0:3] + b'\x00')[0]
                pktsize = struct.unpack('<I', buf[4:7] + b'\x00')[0]
                end = zpktsize + 7
                # Wait for the whole compressed payload (its size is
                # zpktsize, not the uncompressed pktsize).
                while len(buf) < end:
                    buf += self._recv_chunk()
                payload = buf[7:end]
                buf = buf[end:]
                if pktsize == 0:
                    parts.append(payload)  # stored uncompressed
                else:
                    parts.append(zlib.decompress(payload))
                if not buf and self._ends_on_packet_boundary(b''.join(parts)):
                    break
        except socket.timeout:
            raise errors.InterfaceError(errno=2013)
        except socket.error as e:
            raise errors.InterfaceError(errno=2055,
                values=dict(socketaddr=self.get_address(), errno=e.errno))

        data = b''.join(parts)
        while data:
            size = struct.unpack('<I', data[0:3] + b'\x00')[0] + 4
            self.buffer.append(data[:size])
            data = data[size:]
        try:
            return self.buffer.popleft()
        except IndexError:
            return None

    def set_connection_timeout(self, timeout):
        """Set the timeout (seconds, or None) used when opening the socket."""
        self.connection_timeout = timeout

    def set_ssl(self, ssl_ca, ssl_cert, ssl_key):
        """Store the CA, certificate and key files used by switch_to_ssl()."""
        self._ssl_ca = ssl_ca
        self._ssl_cert = ssl_cert
        self._ssl_key = ssl_key

    def switch_to_ssl(self):
        """Wrap the open socket in TLS using the files given to set_ssl().

        The peer certificate is verified against the CA file
        (CERT_REQUIRED), and TLSv1 is used. Raises errors.NotSupportedError
        when the ssl module is unavailable, and errors.InterfaceError on SSL
        errors.
        """
        try:
            self.sock = ssl.wrap_socket(self.sock,
                keyfile=self._ssl_key, certfile=self._ssl_cert,
                ca_certs=self._ssl_ca, cert_reqs=ssl.CERT_REQUIRED,
                do_handshake_on_connect=False,
                ssl_version=ssl.PROTOCOL_TLSv1)
            self.sock.do_handshake()
        except NameError:
            # The guarded module-level "import ssl" failed.
            raise errors.NotSupportedError(
                "Python installation has no SSL support")
        except ssl.SSLError as e:
            raise errors.InterfaceError("SSL error: %s" % e)
class MySQLUnixSocket(MySQLBaseSocket):
    """Opens a connection through the UNIX socket of the MySQL Server."""

    def __init__(self, unix_socket='/tmp/mysql.sock'):
        MySQLBaseSocket.__init__(self)
        self.unix_socket = unix_socket  # path of the server's socket file

    def get_address(self):
        """Return the socket file path (used in error messages)."""
        return self.unix_socket

    def open_connection(self):
        """Connect to the server's UNIX socket and store it in self.sock.

        connection_timeout (None meaning blocking) is applied before
        connecting. If connecting fails, the unconnected socket is closed
        instead of being leaked.

        Raises errors.InterfaceError with errno 2002 on socket errors, or
        with the exception text for any other failure (for instance a
        platform without socket.AF_UNIX).
        """
        sock = None
        try:
            sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            sock.settimeout(self.connection_timeout)
            sock.connect(self.unix_socket)
        except socket.error as e:
            if sock is not None:
                sock.close()
            raise errors.InterfaceError(errno=2002,
                values=dict(socketaddr=self.get_address(),
                            errno=getattr(e, 'errno', e)))
        except Exception as e:
            if sock is not None:
                sock.close()
            raise errors.InterfaceError('%s' % e)
        self.sock = sock
class MySQLTCPSocket(MySQLBaseSocket):
    """Opens a TCP connection to the MySQL Server."""

    def __init__(self, host='127.0.0.1', port=3306):
        MySQLBaseSocket.__init__(self)
        self.server_host = host
        self.server_port = port

    def get_address(self):
        """Return "host:port" (used in error messages)."""
        return "%s:%s" % (self.server_host, self.server_port)

    def open_connection(self):
        """Open a TCP connection to server_host:server_port into self.sock.

        socket.create_connection() resolves the host and tries each address
        returned, so IPv6 as well as IPv4 servers are reachable. It also
        closes the sockets of failed attempts. connection_timeout (None
        meaning blocking) is set on the socket before connecting and stays
        in effect afterwards.

        Raises errors.InterfaceError with errno 2003 on socket errors, or
        with the exception text for any other failure.
        """
        try:
            self.sock = socket.create_connection(
                (self.server_host, self.server_port),
                self.connection_timeout)
        except socket.error as e:
            raise errors.InterfaceError(errno=2003,
                values=dict(socketaddr=self.get_address(),
                            errno=getattr(e, 'errno', e)))
        except Exception as e:
            raise errors.InterfaceError('%s' % e)
class MySQLConnection(object):
    """Connection to a MySQL server.

    Keyword arguments given to the constructor are forwarded to connect().
    Without keyword arguments the object is created unconnected; positional
    arguments alone do not trigger a connect.
    """

    def __init__(self, *args, **kwargs):
        """Initialize default settings; connect if keyword args are given."""
        self.protocol = None
        self.converter = None
        self.cursors = []
        self.client_flags = constants.ClientFlag.get_default()
        self._charset = 33
        self._username = ''
        self._database = ''
        self._server_host = '127.0.0.1'
        self._server_port = 3306
        self._unix_socket = None
        self.client_host = ''
        self.client_port = 0
        self.affected_rows = 0
        self.server_status = 0
        self.warning_count = 0
        self.field_count = 0
        self.insert_id = 0
        self.info_msg = ''
        self.use_unicode = True
        self.get_warnings = False
        self.raise_on_warnings = False
        self.connection_timeout = None
        self.buffered = False
        self.unread_result = False
        self.raw = False
        if len(kwargs) > 0:
            self.connect(*args, **kwargs)

    def connect(self, database=None, user='', password='',
                host='127.0.0.1', port=3306, unix_socket=None,
                use_unicode=True, charset='utf8', collation=None,
                autocommit=False,
                time_zone=None, sql_mode=None,
                get_warnings=False, raise_on_warnings=False,
                connection_timeout=None, client_flags=0,
                buffered=False, raw=False,
                ssl_ca=None, ssl_cert=None, ssl_key=None,
                passwd=None, db=None, connect_timeout=None, dsn=None):
        """Connect to a MySQL server, dropping any current connection first.

        db, passwd and connect_timeout are aliases for database, password
        and connection_timeout; the primary name wins when both are set.
        Passing dsn raises errors.NotSupportedError. Once connected,
        collation, autocommit, time_zone and sql_mode are applied to the
        session by _post_connection().
        """
        if db and not database:
            database = db
        if passwd and not password:
            password = passwd
        if connect_timeout and not connection_timeout:
            connection_timeout = connect_timeout
        if dsn is not None:
            # Previously the exception was created but never raised, so a
            # DSN was silently ignored.
            raise errors.NotSupportedError("Data source name is not supported")
        self._server_host = host
        self._server_port = port
        self._unix_socket = unix_socket
        if database is not None:
            self._database = database.strip()
        else:
            self._database = None
        self._username = user
        self.set_warnings(get_warnings, raise_on_warnings)
        self.connection_timeout = connection_timeout
        self.buffered = buffered
        self.raw = raw
        self.use_unicode = use_unicode
        self.set_client_flags(client_flags)
        self._charset = constants.CharacterSet.get_charset_info(charset)[0]
        if user or password:
            self.set_login(user, password)
        self.disconnect()
        self._open_connection(username=user, password=password,
                              database=database,
                              client_flags=self.client_flags,
                              charset=charset,
                              ssl=(ssl_ca, ssl_cert, ssl_key))
        # autocommit used to be dropped here, so the session always ended
        # up with autocommit OFF regardless of the argument.
        self._post_connection(time_zone=time_zone, autocommit=autocommit,
                              sql_mode=sql_mode, collation=collation)

    def _get_connection(self, prtcls=None):
        """Get connection based on configuration

        Returns a MySQLUnixSocket when a unix_socket is configured and the
        platform is not Windows ('nt'), otherwise a MySQLTCPSocket for
        server_host:server_port. The connection timeout is applied to it.
        """
        if self.unix_socket and os.name != 'nt':
            conn = MySQLUnixSocket(unix_socket=self.unix_socket)
        else:
            conn = MySQLTCPSocket(host=self.server_host,
                                  port=self.server_port)
        conn.set_connection_timeout(self.connection_timeout)
        return conn

    def _open_connection(self, username=None, password=None, database=None,
                         client_flags=None, charset=None, ssl=None):
        """Open the connection, check the server version and authenticate.

        Raises errors.InterfaceError for servers older than 4.1. When the
        SSL client flag is set, the (ca, cert, key) tuple in *ssl* is handed
        to the socket. When the COMPRESS flag is set, the socket switches to
        its compressed send/recv after authentication.
        """
        self.protocol = protocol.MySQLProtocol(self._get_connection())
        self.protocol.do_handshake()
        version = self.protocol.server_version
        if version < (4, 1):
            raise errors.InterfaceError(
                "MySQL Version %s is not supported." % version)
        if client_flags & constants.ClientFlag.SSL:
            self.protocol.conn.set_ssl(*ssl)
        self.protocol.do_auth(username, password, database, client_flags,
                              self._charset)
        (self._charset, self.charset_name, _) = \
            constants.CharacterSet.get_charset_info(charset)
        self.set_converter_class(conversion.MySQLConverter)
        if client_flags & constants.ClientFlag.COMPRESS:
            self.protocol.conn.recv = self.protocol.conn.recv_compressed
            self.protocol.conn.send = self.protocol.conn.send_compressed

    def _post_connection(self, time_zone=None, autocommit=False,
                         sql_mode=None, collation=None):
        """Configure the session right after the connection is established.

        collation, time_zone and sql_mode are only set when not None;
        autocommit is always set.
        """
        if collation is not None:
            self.collation = collation
        self.autocommit = autocommit
        if time_zone is not None:
            self.time_zone = time_zone
        if sql_mode is not None:
            self.sql_mode = sql_mode

    def is_connected(self):
        """Ping the MySQL server; returns False when there is no connection.

        Otherwise returns whatever protocol.cmd_ping() returns.
        """
        if self.protocol is None:
            return False
        return self.protocol.cmd_ping()
    ping = is_connected

    def disconnect(self):
        """Disconnect from the MySQL server.

        Sends COM_QUIT when a socket is open, then always closes the socket
        and forgets the protocol. A dead connection does not prevent cleanup,
        so connect() can still be called again afterwards.
        """
        if not self.protocol:
            return
        try:
            if self.protocol.conn.sock is not None:
                try:
                    self.protocol.cmd_quit()
                except (errors.OperationalError, errors.InterfaceError):
                    # The link is already broken; nothing left to tell the
                    # server.
                    pass
        finally:
            try:
                self.protocol.conn.close_connection()
            except Exception:
                pass
            self.protocol = None

    def set_converter_class(self, convclass):
        """
        Set the converter class to be used. This should be a class overloading
        methods and members of conversion.MySQLConverter.
        """
        self.converter_class = convclass
        self.converter = convclass(self.charset_name, self.use_unicode)

    def get_server_version(self):
        """Return the server version as a tuple, or None if not connected."""
        try:
            return self.protocol.server_version
        except AttributeError:
            return None

    def get_server_info(self):
        """Returns the server version as a string"""
        return self.protocol.server_version_original

    @property
    def connection_id(self):
        """MySQL connection (thread) ID, or None if not connected."""
        try:
            return self.protocol.server_threadid
        except AttributeError:
            return None

    def set_login(self, username=None, password=None):
        """Set login information for MySQL

        Stores the username and password (each stripped of surrounding
        whitespace, '' when None) in self.username / self.password.
        """
        if username is not None:
            self.username = username.strip()
        else:
            self.username = ''
        if password is not None:
            self.password = password.strip()
        else:
            self.password = ''

    def set_unicode(self, value=True):
        """Toggle unicode mode

        Set whether we return string fields as unicode or not.
        Default is True.
        """
        self.use_unicode = value
        if self.converter:
            self.converter.set_unicode(value)

    def set_charset(self, charset):
        """Switch the session character set with SET NAMES.

        The connection's charset id/name and the converter are only updated
        after the query succeeded.
        """
        (idx, charset_name, _) = \
            constants.CharacterSet.get_charset_info(charset)
        self._execute_query("SET NAMES '%s'" % charset_name)
        self._charset = idx
        self.charset_name = charset_name
        self.converter.set_charset(charset_name)

    def get_charset(self):
        """Return the session's character_set_connection."""
        return self._info_query(
            "SELECT @@session.character_set_connection")[0]
    charset = property(get_charset, set_charset,
                       doc="Character set for this connection")

    def set_collation(self, collation):
        """Set the session's collation_connection."""
        self._execute_query(
            "SET @@session.collation_connection = '%s'" % collation)

    def get_collation(self):
        """Return the session's collation_connection."""
        return self._info_query(
            "SELECT @@session.collation_connection")[0]
    collation = property(get_collation, set_collation,
                         doc="Collation for this connection")

    def set_warnings(self, fetch=False, raise_on_warnings=False):
        """Set how to handle warnings coming from MySQL

        Set whether we should get warnings whenever an operation produced
        some. If you set raise_on_warnings to True, any warning will be raised
        as a DataError exception (and fetching is enabled implicitly).
        """
        if raise_on_warnings is True:
            self.get_warnings = True
            self.raise_on_warnings = True
        else:
            self.get_warnings = fetch
            self.raise_on_warnings = False

    def set_client_flags(self, flags):
        """Set the client flags

        The flags-argument can be either an int or a list (or tuple) of
        ClientFlag-values. If it is a positive integer, it will set
        client_flags to flags as is.
        If flags is a list (or tuple), each flag will be set or unset
        when it's negative.
          set_client_flags([ClientFlag.FOUND_ROWS,-ClientFlag.LONG_FLAG])

        Returns self.client_flags
        """
        if isinstance(flags, int) and flags > 0:
            self.client_flags = flags
        elif isinstance(flags, (tuple, list)):
            for f in flags:
                if f < 0:
                    self.unset_client_flag(abs(f))
                else:
                    self.set_client_flag(f)
        return self.client_flags

    def set_client_flag(self, flag):
        """Turn on a single (positive) client flag."""
        if flag > 0:
            self.client_flags |= flag

    def unset_client_flag(self, flag):
        """Turn off a single (positive) client flag."""
        if flag > 0:
            self.client_flags &= ~flag

    def isset_client_flag(self, flag):
        """Return True if any bit of *flag* is set in client_flags."""
        return (self.client_flags & flag) > 0

    @property
    def user(self):
        """User used while connecting to MySQL"""
        return self._username

    @property
    def server_host(self):
        """MySQL server IP address or name"""
        return self._server_host

    @property
    def server_port(self):
        "MySQL server TCP/IP port"
        return self._server_port

    @property
    def unix_socket(self):
        "MySQL Unix socket file location"
        return self._unix_socket

    def set_database(self, value):
        """Switch the current database (the name is inserted unquoted)."""
        self.protocol.cmd_query("USE %s" % value)

    def get_database(self):
        """Get the current database"""
        return self._info_query("SELECT DATABASE()")[0]
    database = property(get_database, set_database,
                        doc="Current database")

    def set_time_zone(self, value):
        """Set the session time_zone.

        *value* is inserted into the SQL verbatim, so a string value such
        as +00:00 must carry its own SQL quotes.
        """
        self.protocol.cmd_query("SET @@session.time_zone = %s" % value)

    def get_time_zone(self):
        """Return the session time_zone."""
        return self._info_query("SELECT @@session.time_zone")[0]
    time_zone = property(get_time_zone, set_time_zone,
                         doc="time_zone value for current MySQL session")

    def set_sql_mode(self, value):
        """Set the session sql_mode (*value* is inserted verbatim)."""
        self.protocol.cmd_query("SET @@session.sql_mode = %s" % value)

    def get_sql_mode(self):
        """Return the session sql_mode."""
        return self._info_query("SELECT @@session.sql_mode")[0]
    sql_mode = property(get_sql_mode, set_sql_mode,
                        doc="sql_mode value for current MySQL session")

    def set_autocommit(self, value):
        """Turn session autocommit ON (truthy *value*) or OFF."""
        s = 'ON' if value else 'OFF'
        self._execute_query("SET @@session.autocommit = %s" % s)

    def get_autocommit(self):
        """Return True when session autocommit is 1."""
        value = self._info_query("SELECT @@session.autocommit")[0]
        return value == 1
    autocommit = property(get_autocommit, set_autocommit,
                          doc="autocommit value for current MySQL session")

    def close(self):
        """Forget all cursors and disconnect."""
        del self.cursors[:]
        self.disconnect()

    def remove_cursor(self, c):
        """Unregister cursor *c*; raises errors.ProgrammingError if unknown."""
        try:
            self.cursors.remove(c)
        except ValueError:
            raise errors.ProgrammingError(
                "Cursor could not be removed.")

    def cursor(self, buffered=None, raw=None, cursor_class=None):
        """Instantiates and returns a cursor

        By default, MySQLCursor is returned. Depending on the options
        while connecting, a buffered and/or raw cursor instantiated
        instead.

        It is possible to also give a custom cursor through the
        cursor_class parameter, but it needs to be a subclass of
        mysql.connector.cursor.CursorBase.

        Returns a cursor-object
        """
        if cursor_class is not None:
            if not issubclass(cursor_class, cursor.CursorBase):
                raise errors.ProgrammingError(
                    "Cursor class needs be subclass of cursor.CursorBase")
            c = cursor_class(self)
        else:
            buffered = buffered or self.buffered
            raw = raw or self.raw
            t = 0
            if buffered is True:
                t |= 1
            if raw is True:
                t |= 2
            types = {
                0: cursor.MySQLCursor,
                1: cursor.MySQLCursorBuffered,
                2: cursor.MySQLCursorRaw,
                3: cursor.MySQLCursorBufferedRaw,
            }
            c = types[t](self)
        if c not in self.cursors:
            self.cursors.append(c)
        return c

    def commit(self):
        """Commit current transaction"""
        self._execute_query("COMMIT")

    def rollback(self):
        """Rollback current transaction"""
        self._execute_query("ROLLBACK")

    def _execute_query(self, query):
        """Send *query*; raises errors.InternalError if a result is unread."""
        if self.unread_result is True:
            raise errors.InternalError("Unread result found.")
        self.protocol.cmd_query(query)

    def _info_query(self, query):
        """Run *query* on a temporary buffered cursor; return the first row.

        The cursor is closed even when the query fails.
        """
        cur = self.cursor(buffered=True)
        try:
            cur.execute(query)
            return cur.fetchone()
        finally:
            cur.close()