diff --git a/server/rating.py b/server/rating.py index 3a286edf5..736bd722c 100644 --- a/server/rating.py +++ b/server/rating.py @@ -8,6 +8,7 @@ class RatingType(): GLOBAL = "global" LADDER_1V1 = "ladder_1v1" + TMM_2V2 = "tmm_2v2" K = Union[RatingType, str] @@ -36,7 +37,9 @@ def __setitem__(self, key: K, value: Tuple[float, float]) -> None: super().__setitem__(key, val) def __getitem__(self, key: K) -> Tuple[float, float]: - if key == "tmm_2v2" and key not in self: + # TODO: Generalize for arbitrary ratings + # https://github.com/FAForever/server/issues/727 + if key == RatingType.TMM_2V2 and key not in self: mean, dev = self[RatingType.GLOBAL] if dev > 250: tmm_2v2_rating = (mean, dev) diff --git a/server/rating_service/rating_service.py b/server/rating_service/rating_service.py index 641a40bf6..eabe8a634 100644 --- a/server/rating_service/rating_service.py +++ b/server/rating_service/rating_service.py @@ -1,4 +1,5 @@ import asyncio +from contextlib import asynccontextmanager from typing import Dict import aiocron @@ -32,6 +33,15 @@ ) +@asynccontextmanager +async def acquire_or_default(db, default=None): + if default is None: + async with db.acquire() as conn: + yield conn + else: + yield default + + @with_logger class RatingService(Service): """ @@ -146,7 +156,7 @@ async def _get_rating_data(self, summary: GameRatingSummary) -> GameRatingData: ] async def _get_player_rating( - self, player_id: int, rating_type: str + self, player_id: int, rating_type: str, conn=None ) -> Rating: if self._rating_type_ids is None: self._logger.warning( @@ -158,7 +168,7 @@ async def _get_player_rating( if rating_type_id is None: raise ValueError(f"Unknown rating type {rating_type}.") - async with self._db.acquire() as conn: + async with acquire_or_default(self._db, conn) as conn: sql = select( [leaderboard_rating.c.mean, leaderboard_rating.c.deviation] ).where( @@ -172,6 +182,13 @@ async def _get_player_rating( row = await result.fetchone() if not row: + # TODO: Generalize for arbitrary ratings + # https://github.com/FAForever/server/issues/727 + if rating_type == RatingType.TMM_2V2: + return await self._create_tmm_2v2_rating( + conn, player_id + ) + try: return await self._get_player_legacy_rating( conn, player_id, rating_type @@ -183,6 +200,27 @@ async def _get_player_rating( return Rating(row["mean"], row["deviation"]) + async def _create_tmm_2v2_rating( + self, conn, player_id: int + ) -> Rating: + mean, dev = await self._get_player_rating( + player_id, RatingType.GLOBAL, conn=conn + ) + if dev < 250: + dev = min(dev + 150, 250) + + insertion_sql = leaderboard_rating.insert().values( + login_id=player_id, + mean=mean, + deviation=dev, + total_games=0, + won_games=0, + leaderboard_id=self._rating_type_ids[RatingType.TMM_2V2], + ) + await conn.execute(insertion_sql) + + return Rating(mean, dev) + async def _get_player_legacy_rating( self, conn, player_id: int, rating_type: str ) -> Rating: diff --git a/tests/integration_tests/test_game.py b/tests/integration_tests/test_game.py index 84ceeb56f..1be5e3575 100644 --- a/tests/integration_tests/test_game.py +++ b/tests/integration_tests/test_game.py @@ -376,7 +376,7 @@ async def test_ladder_game_draw_bug(lobby_server, database): [army2, "defeat -10"] ): for proto in (proto1, proto2): - await proto1.send_message({ + await proto.send_message({ "target": "game", "command": "GameResult", "args": result diff --git a/tests/integration_tests/test_teammatchmaker.py b/tests/integration_tests/test_teammatchmaker.py index b91753651..e9cc16553 100644 --- a/tests/integration_tests/test_teammatchmaker.py +++ b/tests/integration_tests/test_teammatchmaker.py @@ -1,7 +1,14 @@ import asyncio import pytest - +from sqlalchemy import and_, select + +from server.db.models import ( + game_player_stats, + leaderboard, + leaderboard_rating, + leaderboard_rating_journal +) from tests.utils import fast_forward from .conftest import connect_and_sign_in, read_until, read_until_command @@ -416,7 +423,8 @@ async def test_game_matchmaking_timeout(lobby_server): # We don't send the `GameState: Lobby` command so the game should time out await asyncio.gather(*[ - read_until_command(proto, "match_cancelled", timeout=120) for proto in protos + read_until_command(proto, "match_cancelled", timeout=120) + for proto in protos ]) # Player's state is reset once they leave the game @@ -508,7 +516,7 @@ async def test_game_ratings(lobby_server): @fast_forward(60) -async def test_game_ratings_initialized_based_on_global(lobby_server): +async def test_ratings_initialized_based_on_global(lobby_server): test_id, _, proto = await connect_and_sign_in( ("test", "test_password"), lobby_server ) @@ -579,6 +587,107 @@ async def test_game_ratings_initialized_based_on_global(lobby_server): } +@fast_forward(60) +async def test_ratings_initialized_based_on_global_persisted( + lobby_server, + database +): + # 2 ladder and global noobs + _, _, proto1 = await connect_and_sign_in( + ("ladder1", "ladder1"), lobby_server + ) + _, _, proto2 = await connect_and_sign_in( + ("ladder2", "ladder2"), lobby_server + ) + # One global pro with no tmm games + test_id, _, proto3 = await connect_and_sign_in( + ("test", "test_password"), lobby_server + ) + # One tmm pro to balance the match + _, _, proto4 = await connect_and_sign_in( + ("tmm2", "tmm2"), lobby_server + ) + protos = [proto1, proto2, proto3, proto4] + for proto in protos: + await read_until_command(proto, "game_info") + await proto.send_message({ + "command": "game_matchmaking", + "state": "start", + "mod": "tmm2v2" + }) + + msg1, msg2, msg3, msg4 = await asyncio.gather(*[ + client_response(proto) for proto in protos + ]) + # So it doesn't matter who is host + await asyncio.gather(*[ + proto.send_message({ + "command": "GameState", + "target": "game", + "args": ["Launching"] + }) for proto in protos + ]) + + army1 = msg1["map_position"] + army2 = msg2["map_position"] + test_army = msg3["map_position"] + army4 = msg4["map_position"] + + for result in ( + [army1, "defeat -10"], + [army2, "defeat -10"], + [army4, "defeat -10"], + [test_army, "victory 10"], + ): + for proto in protos: + await proto.send_message({ + "target": "game", + "command": "GameResult", + "args": result + }) + + for proto in protos: + await proto.send_message({ + "target": "game", + "command": "GameEnded", + "args": [] + }) + + await read_until( + proto3, + lambda msg: msg["command"] == "player_info" + and any(player["id"] == test_id for player in msg["players"]), + timeout=10 + ) + + async with database.acquire() as conn: + res = await conn.execute( + select([leaderboard_rating]).select_from( + leaderboard.join(leaderboard_rating) + ).where(and_( + leaderboard.c.technical_name == "tmm_2v2", + leaderboard_rating.c.login_id == test_id + )) + ) + row = await res.fetchone() + assert row.mean > 2000 + + res = await conn.execute( + select([leaderboard_rating_journal]).select_from( + leaderboard + .join(leaderboard_rating_journal) + .join(game_player_stats) + ).where(and_( + leaderboard.c.technical_name == "tmm_2v2", + game_player_stats.c.playerId == test_id + )) + ) + rows = await res.fetchall() + assert len(rows) == 1 + assert rows[0].rating_mean_before == 2000 + assert rows[0].rating_deviation_before == 250 + + @fast_forward(30) async def test_party_cleanup_on_abort(lobby_server): for _ in range(3):