Skip to content

Commit

Permalink
Implement 'read_timeout' parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
Aliaksandr Akulchyk committed Jan 5, 2024
1 parent 83aa96e commit 185b943
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 3 deletions.
23 changes: 20 additions & 3 deletions aiomysql/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def connect(host="localhost", user=None, password="",
read_default_file=None, conv=decoders, use_unicode=None,
client_flag=0, cursorclass=Cursor, init_command=None,
connect_timeout=None, read_default_group=None,
read_timeout=None,
autocommit=False, echo=False,
local_infile=False, loop=None, ssl=None, auth_plugin='',
program_name='', server_public_key=None):
Expand All @@ -64,6 +65,7 @@ def connect(host="localhost", user=None, password="",
init_command=init_command,
connect_timeout=connect_timeout,
read_default_group=read_default_group,
read_timeout=read_timeout,
autocommit=autocommit, echo=echo,
local_infile=local_infile, loop=loop, ssl=ssl,
auth_plugin=auth_plugin, program_name=program_name)
Expand Down Expand Up @@ -139,7 +141,7 @@ def __init__(self, host="localhost", user=None, password="",
charset='', sql_mode=None,
read_default_file=None, conv=decoders, use_unicode=None,
client_flag=0, cursorclass=Cursor, init_command=None,
connect_timeout=None, read_default_group=None,
connect_timeout=None, read_default_group=None, read_timeout=None,
autocommit=False, echo=False,
local_infile=False, loop=None, ssl=None, auth_plugin='',
program_name='', server_public_key=None):
Expand Down Expand Up @@ -171,6 +173,7 @@ def __init__(self, host="localhost", user=None, password="",
when connecting.
:param read_default_group: Group to read from in the configuration
file.
:param read_timeout: The timeout for reading from the connection in seconds (default: None - no timeout)
:param autocommit: Autocommit mode. None means use server default.
(default: False)
:param local_infile: boolean to enable the use of LOAD DATA LOCAL
Expand Down Expand Up @@ -257,6 +260,7 @@ def __init__(self, host="localhost", user=None, password="",

self.cursorclass = cursorclass
self.connect_timeout = connect_timeout
self.read_timeout = read_timeout

self._result = None
self._affected_rows = 0
Expand Down Expand Up @@ -654,12 +658,25 @@ async def _read_packet(self, packet_type=MysqlPacket):

async def _read_bytes(self, num_bytes):
try:
data = await self._reader.readexactly(num_bytes)
if self.read_timeout:
try:
data = await asyncio.wait_for(
self._reader.readexactly(num_bytes),
self.read_timeout
)
except asyncio.TimeoutError as e:
raise asyncio.TimeoutError("Read timeout exceeded") from e
else:
data = await self._reader.readexactly(num_bytes)
except asyncio.IncompleteReadError as e:
msg = "Lost connection to MySQL server during query"
self.close()
raise OperationalError(CR.CR_SERVER_LOST, msg) from e
except OSError as e:
except (OSError, asyncio.TimeoutError) as e:
msg = f"Lost connection to MySQL server during query ({e})"
self.close()
raise OperationalError(CR.CR_SERVER_LOST, msg) from e
except Exception as e:
msg = f"Lost connection to MySQL server during query ({e})"
self.close()
raise OperationalError(CR.CR_SERVER_LOST, msg) from e
Expand Down
7 changes: 7 additions & 0 deletions tests/sa/test_sa_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ async def connect(**kwargs):
return connect


@pytest.mark.run_loop
async def test_read_timeout(sa_connect):
conn = await sa_connect(read_timeout=0.01)
with pytest.raises(aiomysql.OperationalError):
await conn.execute("DO SLEEP(1)")


@pytest.mark.run_loop
async def test_execute_text_select(sa_connect):
conn = await sa_connect()
Expand Down
8 changes: 8 additions & 0 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ async def test_connect_timeout(connection_creator):
await connection_creator(connect_timeout=0.000000000001)


@pytest.mark.run_loop
async def test_read_timeout(connection_creator):
with pytest.raises(aiomysql.OperationalError):
con = await connection_creator(read_timeout=0.01)
cur = await con.cursor()
await cur.execute("DO SLEEP(1)")


@pytest.mark.run_loop
async def test_config_file(fill_my_cnf, connection_creator, mysql_params):
tests_root = os.path.abspath(os.path.dirname(__file__))
Expand Down

0 comments on commit 185b943

Please sign in to comment.