diff --git a/aiomysql/connection.py b/aiomysql/connection.py index 3520dfcc..3dbb8312 100644 --- a/aiomysql/connection.py +++ b/aiomysql/connection.py @@ -51,6 +51,7 @@ def connect(host="localhost", user=None, password="", read_default_file=None, conv=decoders, use_unicode=None, client_flag=0, cursorclass=Cursor, init_command=None, connect_timeout=None, read_default_group=None, + read_timeout=None, autocommit=False, echo=False, local_infile=False, loop=None, ssl=None, auth_plugin='', program_name='', server_public_key=None): @@ -64,6 +65,7 @@ def connect(host="localhost", user=None, password="", init_command=init_command, connect_timeout=connect_timeout, read_default_group=read_default_group, + read_timeout=read_timeout, autocommit=autocommit, echo=echo, local_infile=local_infile, loop=loop, ssl=ssl, auth_plugin=auth_plugin, program_name=program_name) @@ -139,7 +141,7 @@ def __init__(self, host="localhost", user=None, password="", charset='', sql_mode=None, read_default_file=None, conv=decoders, use_unicode=None, client_flag=0, cursorclass=Cursor, init_command=None, - connect_timeout=None, read_default_group=None, + connect_timeout=None, read_default_group=None, read_timeout=None, autocommit=False, echo=False, local_infile=False, loop=None, ssl=None, auth_plugin='', program_name='', server_public_key=None): @@ -171,6 +173,7 @@ def __init__(self, host="localhost", user=None, password="", when connecting. :param read_default_group: Group to read from in the configuration file. + :param read_timeout: The timeout for reading from the connection in seconds (default: None - no timeout) :param autocommit: Autocommit mode. None means use server default. (default: False) :param local_infile: boolean to enable the use of LOAD DATA LOCAL @@ -257,6 +260,7 @@ def __init__(self, host="localhost", user=None, password="", self.cursorclass = cursorclass self.connect_timeout = connect_timeout + self.read_timeout = read_timeout self._result = None self._affected_rows = 0 @@ -654,12 +658,25 @@ async def _read_packet(self, packet_type=MysqlPacket): async def _read_bytes(self, num_bytes): try: - data = await self._reader.readexactly(num_bytes) + if self.read_timeout: + try: + data = await asyncio.wait_for( + self._reader.readexactly(num_bytes), + self.read_timeout + ) + except asyncio.TimeoutError as e: + raise asyncio.TimeoutError("Read timeout exceeded") from e + else: + data = await self._reader.readexactly(num_bytes) except asyncio.IncompleteReadError as e: msg = "Lost connection to MySQL server during query" self.close() raise OperationalError(CR.CR_SERVER_LOST, msg) from e - except OSError as e: + except (OSError, asyncio.TimeoutError) as e: + msg = f"Lost connection to MySQL server during query ({e})" + self.close() + raise OperationalError(CR.CR_SERVER_LOST, msg) from e + except Exception as e: msg = f"Lost connection to MySQL server during query ({e})" self.close() raise OperationalError(CR.CR_SERVER_LOST, msg) from e diff --git a/tests/sa/test_sa_connection.py b/tests/sa/test_sa_connection.py index a68e9032..321a682a 100644 --- a/tests/sa/test_sa_connection.py +++ b/tests/sa/test_sa_connection.py @@ -35,6 +35,13 @@ async def connect(**kwargs): return connect +@pytest.mark.run_loop +async def test_read_timeout(sa_connect): + conn = await sa_connect(read_timeout=0.01) + with pytest.raises(aiomysql.OperationalError): + await conn.execute("DO SLEEP(1)") + + @pytest.mark.run_loop async def test_execute_text_select(sa_connect): conn = await sa_connect() diff --git a/tests/test_connection.py b/tests/test_connection.py index c0c1be3d..3e07795b 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -32,6 +32,14 @@ async def test_connect_timeout(connection_creator): await connection_creator(connect_timeout=0.000000000001) +@pytest.mark.run_loop +async def test_read_timeout(connection_creator): + with pytest.raises(aiomysql.OperationalError): + con = await connection_creator(read_timeout=0.01) + cur = await con.cursor() + await cur.execute("DO SLEEP(1)") + + @pytest.mark.run_loop async def test_config_file(fill_my_cnf, connection_creator, mysql_params): tests_root = os.path.abspath(os.path.dirname(__file__))