Skip to content

Commit

Permalink
fix: Only suggest moves for the active color for a position instead o…
Browse files Browse the repository at this point in the history
…f for both

Fixes #5
  • Loading branch information
nitzel committed Sep 26, 2023
1 parent 99a091b commit 9f57e7c
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 116 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,7 @@ The coverage is not very good yet.
```sh
pipenv run hupper -m pytest --verbose # automatically reruns unit tests on filechange
```
### TQDM
If you encounter the warning `UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown` on restarts - we get that because we're using TQDM.
I'm not sure what can be done about it.
20 changes: 19 additions & 1 deletion base_types.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,26 @@
from typing import NewType
from typing import NewType, Union, Literal


TpsSymmetry = NewType("TpsSymmetry", int)
TpsString = NewType("TpsString", str) # with xn collapsed (x,x,x,... -> xn)
TpsStringExpanded = NewType("TpsStringExpanded", str) # with xn expanded to x,x,x...
NormalizedTpsString = NewType("NormalizedTpsString", str)
BoardSize = NewType("BoardSize", int)
PlayerToMove = Union[Literal["white"], Literal["black"]]

def color_to_place_from_tps(tps: str) -> PlayerToMove:
"""
The color of the next piece to place.
After move 1 this equals the player that makes the move.
"""
[_tps_str, player_to_move, move_counter] = tps.split(" ")
player_to_move = "white" if player_to_move == "1" else "black"
if int(move_counter) == 1: # first move -> apply swap
player_to_move = get_opponent(player_to_move)

return player_to_move

def get_opponent(player_to_move: PlayerToMove) -> PlayerToMove:
if player_to_move == "black":
return "white"
return "black"
184 changes: 94 additions & 90 deletions position_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import sqlite3
from typing import Optional, Union
from base_types import BoardSize, TpsString
from base_types import BoardSize, PlayerToMove, TpsString

import symmetry_normalisator
from position_processor import PositionProcessor
Expand Down Expand Up @@ -36,7 +36,8 @@ def __enter__(self):
"""
CREATE TABLE IF NOT EXISTS positions (
id integer PRIMARY KEY,
tps text UNIQUE,
tps text NOT NULL,
player_to_move text NOT NULL,
moves text
);
""",
Expand All @@ -51,16 +52,16 @@ def __enter__(self):
"""]

create_index_sql = [
"CREATE INDEX IF NOT EXISTS idx_xref_game_id ON game_position_xref (game_id);",
"CREATE INDEX IF NOT EXISTS idx_xref_position_id ON game_position_xref (position_id);",
"CREATE INDEX IF NOT EXISTS idx_position_tps ON positions (tps);",
"CREATE INDEX IF NOT EXISTS idx_games_white ON games (white);",
"CREATE INDEX IF NOT EXISTS idx_games_black ON games (black);",
"CREATE INDEX IF NOT EXISTS idx_games_rating_white ON games (rating_white);",
"CREATE INDEX IF NOT EXISTS idx_games_rating_black ON games (rating_black);",
"CREATE INDEX IF NOT EXISTS idx_games_komi ON games (komi);",
"CREATE INDEX IF NOT EXISTS idx_games_date ON games (date);",
"CREATE INDEX IF NOT EXISTS idx_games_tournament ON games (tournament);",
"CREATE INDEX IF NOT EXISTS idx_xref_game_id ON game_position_xref (game_id);",
"CREATE INDEX IF NOT EXISTS idx_xref_position_id ON game_position_xref (position_id);",
"CREATE UNIQUE INDEX IF NOT EXISTS idx_position_tps ON positions (tps, player_to_move);",
"CREATE INDEX IF NOT EXISTS idx_games_white ON games (white);",
"CREATE INDEX IF NOT EXISTS idx_games_black ON games (black);",
"CREATE INDEX IF NOT EXISTS idx_games_rating_white ON games (rating_white);",
"CREATE INDEX IF NOT EXISTS idx_games_rating_black ON games (rating_black);",
"CREATE INDEX IF NOT EXISTS idx_games_komi ON games (komi);",
"CREATE INDEX IF NOT EXISTS idx_games_date ON games (date);",
"CREATE INDEX IF NOT EXISTS idx_games_tournament ON games (tournament);",
]

try:
Expand Down Expand Up @@ -122,89 +123,93 @@ def add_position(
) -> int:
assert self.conn is not None
assert bool(next_tps) == bool(move) # either none or both must be set
curr = self.conn.cursor()
with closing(self.conn.cursor()) as curr:

# normalize for symmetries
tps_normalized, own_symmetry = symmetry_normalisator.get_tps_orientation(tps)
# In the beginning of the game, on ply 2 and 3, white is placed consecutively
color_to_place = tak.colour_to_play(tak.ply_counter - 1)
color_to_place_next = tak.colour_to_play(tak.ply_counter)

select_position_row_sql = f"""
SELECT *
FROM positions
WHERE tps = '{tps_normalized}'
;
"""

curr.execute(select_position_row_sql)
row = curr.fetchone()
# normalize for symmetries
tps_normalized, own_symmetry = symmetry_normalisator.get_tps_orientation(tps)
select_position_row_sql = f"""
SELECT *
FROM positions
WHERE tps = '{tps_normalized}' AND player_to_move = '{color_to_place}'
;
"""

# if this position does not exist, create it
if row is None:
self.create_position_entry(tps_normalized)
curr.execute(select_position_row_sql)
row = curr.fetchone()

# update the game-move crossreference table
row_dict = dict(row)
position_id = row_dict['id']
# if this position does not exist, create it
if row is None:
self.create_position_entry(tps_normalized, color_to_place)
curr.execute(select_position_row_sql)
row = curr.fetchone()

curr.execute(
"INSERT INTO game_position_xref (game_id, position_id) VALUES (:game_id, :position_id);",
{ 'game_id': game_id, 'position_id': position_id }
)
# update the game-move crossreference table
row_dict = dict(row)
position_id = row_dict['id']

if next_tps is not None and move is not None:
next_tps_normalized, _next_symmetry = symmetry_normalisator.get_tps_orientation(next_tps)
select_next_position_row_sql = f"""
SELECT *
FROM positions
WHERE tps = '{next_tps_normalized}'
;
"""
curr.execute(select_next_position_row_sql)
next_pos = curr.fetchone()
curr.execute(
"INSERT INTO game_position_xref (game_id, position_id) VALUES (:game_id, :position_id);",
{ 'game_id': game_id, 'position_id': position_id }
)

# if next position does not exist, create it
if next_pos is None:
self.create_position_entry(next_tps_normalized)
if next_tps is not None and move is not None:
next_tps_normalized, _next_symmetry = symmetry_normalisator.get_tps_orientation(next_tps)
select_next_position_row_sql = f"""
SELECT *
FROM positions
WHERE tps = '{next_tps_normalized}'
AND player_to_move = '{color_to_place_next}'
;
"""
curr.execute(select_next_position_row_sql)
next_pos = curr.fetchone()

next_pos_id = dict(next_pos)['id']
# if next position does not exist, create it
if next_pos is None:
self.create_position_entry(next_tps_normalized, color_to_place_next)
curr.execute(select_next_position_row_sql)
next_pos = curr.fetchone()

# if a move is given also update the move table
# orient move to previous symmetry
move = symmetry_normalisator.transform_move(
move=move,
orientation=own_symmetry,
board_size=tak.size,
)
position_moves = row_dict['moves']
if position_moves != '':
position_moves = row_dict['moves'].split(';')
else:
position_moves = []
moves_list = list(map(lambda x: x.split(','), position_moves))

# if move is in moves_list, update count
move_found = False
for moves in moves_list:
if moves[0] == move:
move_found = True
break

if not move_found:
# append new move to moves_list
moves_list.append((move, str(next_pos_id)))

# transform moves_list into db string format
position_moves = ';'.join(map(','.join, moves_list))

curr.execute(
"UPDATE positions SET moves=:position_moves WHERE id=:position_id",
{ 'position_moves': position_moves, 'position_id': position_id }
)
next_pos_id = dict(next_pos)['id']

return own_symmetry
# if a move is given also update the move table
# orient move to previous symmetry
move = symmetry_normalisator.transform_move(
move=move,
orientation=own_symmetry,
board_size=tak.size,
)
position_moves = row_dict['moves']
if position_moves != '':
position_moves = row_dict['moves'].split(';')
else:
position_moves = []
moves_list = list(map(lambda x: x.split(','), position_moves))

# if move is in moves_list, update count
move_found = False
for moves in moves_list:
if moves[0] == move:
move_found = True
break

if not move_found:
# append new move to moves_list
moves_list.append((move, str(next_pos_id)))

# transform moves_list into db string format
position_moves = ';'.join(map(','.join, moves_list))

curr.execute(
"UPDATE positions SET moves=:position_moves WHERE id=:position_id",
{ 'position_moves': position_moves, 'position_id': position_id }
)

return own_symmetry

def dump(self):
assert self.conn is not None
Expand Down Expand Up @@ -232,14 +237,13 @@ def add_game(
RETURNING id;
""" # use RETURNING so that we can get the inserted id after the query

curr = self.conn.cursor()
curr.execute(insert_game_data_sql)
inserted_id = curr.fetchone()[0]
return inserted_id
with closing(self.conn.cursor()) as curr:
curr.execute(insert_game_data_sql)
inserted_id = curr.fetchone()[0]
return inserted_id

def create_position_entry(self, tps: str):
def create_position_entry(self, tps: str, player_to_move: PlayerToMove):
assert self.conn is not None

insert_position_data_sql = "INSERT INTO positions (tps, moves) VALUES (:tps, '');"
curr = self.conn.cursor()
curr.execute(insert_position_data_sql, { 'tps': tps })
insert_position_data_sql = "INSERT INTO positions (tps, player_to_move, moves) VALUES (:tps, :player_to_move, '');"
with closing(self.conn.cursor()) as curr:
curr.execute(insert_position_data_sql, { 'tps': tps, 'player_to_move': player_to_move })
23 changes: 13 additions & 10 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import symmetry_normalisator
from db_extractor import BOTLIST, get_games_from_db, get_ptn
from position_db import PositionDataBase
from base_types import BoardSize, NormalizedTpsString, TpsString, TpsSymmetry
from base_types import BoardSize, NormalizedTpsString, TpsString, TpsSymmetry, color_to_place_from_tps

DATA_DIR = 'data'
PLAYTAK_GAMES_DB = os.path.join(DATA_DIR, 'games_anon.db')
Expand Down Expand Up @@ -219,7 +219,9 @@ def get_position_analysis(
settings: AnalysisSettings,
tps: TpsString,
) -> PositionAnalysis:
print(f'requested position with white: {settings.white}, black: {settings.black}, min. min_rating: {settings.min_rating}, tps: {tps}')
print(f'requested position with white: {settings.white}, black: {settings.black}, min rating: {settings.min_rating}, tps: {tps}')

player_to_move = color_to_place_from_tps(tps)

settings.min_rating = max(config.min_rating, settings.min_rating) if settings.min_rating else config.min_rating
settings.include_bot_games = config.include_bot_games and settings.include_bot_games
Expand All @@ -230,29 +232,27 @@ def get_position_analysis(
else:
raise ValueError(f"tournament field is '{settings.tournament}' of type '{type(settings.tournament)}' but should be bool or null")

print("Searching with", config, settings)

# we don't care about move number:
sym_tps, symmetry = to_symmetric_tps(tps)
print(f"Searching for player={player_to_move} with", config, settings, "sym_tps=", sym_tps)

select_results_sql = "SELECT * FROM positions WHERE tps=:sym_tps;"
select_results_sql = "SELECT * FROM positions WHERE tps=:sym_tps AND player_to_move=:player_to_move"

with closing(sqlite3.connect(config.db_file_name)) as db:
db.row_factory = sqlite3.Row
with closing(db.cursor()) as cur:
cur.execute(select_results_sql, {"sym_tps": sym_tps})
cur.execute(select_results_sql, {"sym_tps": sym_tps, "player_to_move": player_to_move})

rows = cur.fetchone()
if rows is None:
return PositionAnalysis(config=config, settings=settings)

rows = dict(rows)

position_moves = rows['moves']
if position_moves == '':
if rows['moves'] == '':
position_moves = []
else:
position_moves = position_moves.split(';')
position_moves = rows['moves'].split(';')

moves_list: list[tuple[str, str]] = list(map(lambda x: x.split(','), position_moves))

Expand Down Expand Up @@ -310,6 +310,7 @@ def build_condition(
tournament_str, tournament_vals = build_condition("tournament", settings.tournament)

default_query_vars = {
"player_to_move": player_to_move,
"min_rating": settings.min_rating,
"min_date": playtak_timestamp_from(settings.min_date) if settings.min_date else None,
"max_date": playtak_timestamp_from(settings.max_date) if settings.max_date else None,
Expand All @@ -326,6 +327,7 @@ def build_condition(
continue
explored_position_ids.add(position_id)

# no need to specify player_to_move here, because we're already walking by positions.id
select_games_sql = f"""
SELECT games.result, count(games.result) AS count
FROM game_position_xref, games, positions
Expand Down Expand Up @@ -393,6 +395,7 @@ def build_condition(
WHERE game_position_xref.position_id=positions.id
AND games.id = game_position_xref.game_id
AND positions.tps = :sym_tps
AND positions.player_to_move = :player_to_move
AND games.rating_white >= :min_rating
AND games.rating_black >= :min_rating
{tournament_str}
Expand Down Expand Up @@ -443,7 +446,7 @@ def get_position_with_db_id(db_id: int, tps: str):

if db_id >= len(openings_db_configs):
raise NotFound("database index out of range, query api/v1/databases for options")
tps_string: TpsString = tps # type: ignore
tps_string = TpsString(tps)
analysis = get_position_analysis(openings_db_configs[db_id], settings, tps_string)
return jsonify(analysis)

Expand Down
Loading

0 comments on commit 9f57e7c

Please sign in to comment.