diff --git a/src/bots/bot_manager.py b/src/bots/bot_manager.py index 8d0e490..1fbaae6 100644 --- a/src/bots/bot_manager.py +++ b/src/bots/bot_manager.py @@ -1,13 +1,178 @@ -from helpers.database_helper import UserDatabase, StatisticsDatabase +import asyncio +import datetime +import os +import sqlite3 +from itertools import islice +from multiprocessing import Process, Lock +from typing import List + +import requests +from azure.core.exceptions import ResourceNotFoundError +from azure.servicebus import ServiceBusMessage +from azure.servicebus.aio import ServiceBusClient +from azure.servicebus.aio.management import ServiceBusAdministrationClient + +from bots.irc_bot import IrcBot +from bots.twitch_bot import TwitchBot +from helpers.logger import RonniaLogger + + +def batcher(iterable, batch_size): + iterator = iter(iterable) + while batch := list(islice(iterator, batch_size)): + yield batch + + +class TwitchAPI: + def __init__(self, client_id, client_secret): + self.client_id = client_id + self.client_secret = client_secret + + self.access_token = self.get_token() + + def get_token(self): + """ + Gets access token from Twitch API + """ + url = "https://id.twitch.tv/oauth2/token?client_id={}&client_secret={}&grant_type=client_credentials".format( + self.client_id, self.client_secret) + response = requests.post(url) + return response.json()['access_token'] + + def get_streams(self, user_ids: List[int]): + """ + Gets streams from Twitch API helix/streams only users playing osu! + """ + headers = {'Authorization': 'Bearer {}'.format(self.access_token), + 'Client-ID': self.client_id} + streams = [] + for user_id in batcher(user_ids, 100): + # game_id = 21465 is osu! + url = f"https://api.twitch.tv/helix/streams?first=100&game_id=21465&" + "&".join( + [f"user_id={user}" for user in user_id]) + response = requests.get(url, headers=headers) + streams += response.json()['data'] + return streams + + +class TwitchProcess(Process): + def __init__(self, user_list: List[str], join_lock: Lock): + super().__init__() + self.join_lock = join_lock + self.user_list = user_list + self.bot = None + + def initialize(self): + self.bot = TwitchBot(initial_channel_ids=self.user_list, join_lock=self.join_lock) + + def run(self) -> None: + self.initialize() + self.bot.run() + + +class IRCProcess(Process): + def __init__(self): + super().__init__() + self.bot = None + + def initialize(self) -> None: + self.bot = IrcBot("#osu", os.getenv('OSU_USERNAME'), "irc.ppy.sh", password=os.getenv("IRC_PASSWORD")) + + def run(self) -> None: + self.initialize() + self.bot.start() class BotManager: def __init__(self, ): + self.users_db = sqlite3.connect(os.path.join(os.getenv('DB_DIR'), 'users.db')) + self.join_lock = Lock() + self.instance_message_queue = None + + self.twitch_client = TwitchAPI(os.getenv('CLIENT_ID'), os.getenv('CLIENT_SECRET')) + self._loop = asyncio.get_event_loop() + + self.servicebus_connection_string = os.getenv('SERVICE_BUS_CONN_STRING') + self.servicebus_webserver_queue_name = 'webserver-signups' + self.servicebus_webserver_reply_queue_name = 'webserver-signups-reply' + self.servicebus_bot_queue_name = 'bot-signups' + self.servicebus_bot_reply_queue_name = 'bot-signups-reply' + self.servicebus_queues = {'webserver-signups': {'max_delivery_count': 100, + 'default_message_time_to_live': datetime.timedelta(seconds=10)}, + 'webserver-signups-reply': {'max_delivery_count': 100, + 'default_message_time_to_live': datetime.timedelta( + seconds=10)}, + 'bot-signups': {'max_delivery_count': 100, + 'default_message_time_to_live': datetime.timedelta(seconds=10)}, + 'bot-signups-reply': {'max_delivery_count': 100, + 'default_message_time_to_live': datetime.timedelta(seconds=10)}, + 'twitch-to-irc': {'max_delivery_count': 100, + 'default_message_time_to_live': datetime.timedelta(seconds=10)}, + } + + self.servicebus_mgmt = ServiceBusAdministrationClient.from_connection_string(self.servicebus_connection_string) + self.servicebus_client = ServiceBusClient.from_connection_string(conn_str=self.servicebus_connection_string) + + self.bot_instances = [] + self.bot_processes = [] + + self.irc_process = IRCProcess() + + def start(self): + self._loop.run_until_complete(self.initialize_queues()) + + all_users = self.users_db.execute('SELECT * FROM users;').fetchall() + all_user_twitch_ids = [user[4] for user in all_users] + streaming_user_ids = [user['user_id'] for user in self.twitch_client.get_streams(all_user_twitch_ids)] + + for user_id in all_user_twitch_ids: + if user_id not in streaming_user_ids: + streaming_user_ids.append(user_id) + + self.irc_process.start() + + for user_id_list in batcher(streaming_user_ids, 100): + p = TwitchProcess(user_id_list, self.join_lock) + p.start() + self.bot_processes.append(p) + + async def initialize_queues(self): + """ + Initializes webserver & bot, signup and reply queues + """ + logger.info('Initializing queues...') + for queue_name, queue_properties in self.servicebus_queues.items(): + try: + queue_details = await self.servicebus_mgmt.get_queue(queue_name) + except ResourceNotFoundError: + await self.servicebus_mgmt.create_queue(queue_name, **queue_properties) + + async def run_service_bus_receiver(self): + """ + Creates the receiver for the webserver queue + Forwards incoming messages to the bot instance + Replies to the webserver with a reply queue + """ + receiver = self.servicebus_client.get_queue_receiver(queue_name=self.servicebus_webserver_queue_name) + logger.info('Started servicebus receiver, listening for messages...') + async for message in receiver: + await self.receive_and_parse_message(message) + await receiver.complete_message(message) + + async def receive_and_parse_message(self, message): + """ + Receive a message from the webserver signup queue and parse it, forward it to bot queue. + """ + logger.info(f'Received signup message: {message}') + async with ServiceBusClient.from_connection_string(self.servicebus_connection_string) as sb_client: + sender = sb_client.get_queue_sender(queue_name=self.servicebus_bot_queue_name) + logger.debug(f'Sending message to bot: {message}') + await sender.send_messages(message) + - self.users_db = UserDatabase() - self.messages_db = StatisticsDatabase() - pass +if __name__ == '__main__': + logger = RonniaLogger(__name__) - async def get_user_data(self, user_id): - await self.users_db.get_user_data(user_id) - pass + bot_manager = BotManager() + bot_manager.start() + asyncio.run(bot_manager.run_service_bus_receiver()) diff --git a/src/bots/irc_bot.py b/src/bots/irc_bot.py index 94ff928..a0cab6d 100644 --- a/src/bots/irc_bot.py +++ b/src/bots/irc_bot.py @@ -1,17 +1,18 @@ import asyncio +import json +import os import sqlite3 from typing import Union import attr -import logging -from threading import Lock - import irc.bot +from azure.servicebus.aio import ServiceBusClient from irc.client import Event, ServerConnection from helpers.database_helper import UserDatabase, StatisticsDatabase +from helpers.logger import RonniaLogger -logger = logging.getLogger('ronnia') +logger = RonniaLogger(__name__) @attr.s @@ -28,20 +29,30 @@ def __init__(self, channel, nickname, server, port=6667, password=None): self.channel = channel self.users_db = UserDatabase() self.messages_db = StatisticsDatabase() - self._loop = asyncio.get_event_loop() - self.message_lock = Lock() + self.servicebus_connection_string = os.getenv('SERVICE_BUS_CONN_STRING') + self.servicebus_client = ServiceBusClient.from_connection_string(conn_str=self.servicebus_connection_string) + self.listen_queue_name = 'twitch-to-irc' + self._loop = asyncio.get_event_loop() self._commands = {'disable': self.disable_requests_on_channel, 'echo': self.toggle_notifications, 'feedback': self.toggle_notifications, 'enable': self.enable_requests_on_channel, - 'register': self.register_bot_on_channel, 'help': self.show_help_message, 'setsr': self.set_sr_rating } self.connection.set_rate_limit(1) + async def receive_servicebus_queue(self): + receiver = self.servicebus_client.get_queue_receiver(queue_name=self.listen_queue_name) + async for message in receiver: + logger.info(f'Received message from service bus: {str(message)}') + message_dict = json.loads(str(message)) + target_channel = message_dict['target_channel'] + message_contents = message_dict['message'] + self.send_message(target_channel, message_contents) + def on_welcome(self, c: ServerConnection, e: Event): logger.info(f"Successfully joined irc!") self._loop.run_until_complete(self.users_db.initialize()) @@ -73,21 +84,10 @@ def do_command(self, e: Event): # Check if the user is registered existing_user = self.users_db.get_user_from_osu_username(db_nick) - if existing_user is None and cmd == 'register': + if existing_user is None: self.send_message(e.source.nick, - f'Hello! Thanks for your interest in this bot! ' - f'But, registering for the bot automatically is not supported currently. ' - f'I\'m hosting this bot with the free tier compute engine... ' - f'So, if it gets too many requests it might blow up! ' - f'That\'s why I\'m manually allowing requests right now. ' - f'(Check out the project page if you haven\'t already.)' - f'[https://github.com/aticie/ronnia] ' - f'Contact me on discord and I can enable it for you! heyronii#9925') + f'Please register your osu! account (from here)[https://ronnia.me/].') return - elif existing_user is None: - self.send_message(e.source.nick, f'Sorry, you are not registered. ' - f'(Check out the project page for details.)' - f'[https://github.com/aticie/ronnia]') else: # Check if command is valid try: @@ -111,16 +111,6 @@ def disable_requests_on_channel(self, event: Event, *args, user_details: Union[d f'If you want to re-enable requests, type !enable anytime.') self.messages_db.add_command('disable', 'osu_irc', event.source.nick) - def register_bot_on_channel(self, event: Event, *args, user_details: Union[dict, sqlite3.Row]): - """ - Registers bot on twitch channel - :param event: Event of the current message - :param user_details: User Details Sqlite row factory - - Currently not supported... TODO: Register user -> ask twitch - """ - logger.debug(f'Register bot on channel: {user_details}') - def enable_requests_on_channel(self, event: Event, *args, user_details: Union[dict, sqlite3.Row]): """ Enables requests on twitch channel @@ -160,7 +150,8 @@ def show_help_message(self, event: Event, *args, user_details: Union[dict, sqlit self.send_message(event.source.nick, f'Check out the (project page)[https://github.com/aticie/ronnia] for more information. ' f'List of available commands are (listed here)' - f'[https://github.com/aticie/ronnia/wiki/Commands].') + f'[https://github.com/aticie/ronnia/wiki/Commands]. ' + f'(Click here)[https://ronnia.me/ to access your dashboard. )') self.messages_db.add_command('help', 'osu_irc', event.source.nick) def set_sr_rating(self, event: Event, *args, user_details: Union[dict, sqlite3.Row]): diff --git a/src/bots/twitch_bot.py b/src/bots/twitch_bot.py index ce3a3e0..969a191 100644 --- a/src/bots/twitch_bot.py +++ b/src/bots/twitch_bot.py @@ -1,22 +1,24 @@ import datetime -import logging +import json import os import time from abc import ABC -from threading import Thread +from multiprocessing import Lock from typing import AnyStr, Tuple, Union, List import aiohttp +from azure.servicebus import ServiceBusMessage +from azure.servicebus.aio import ServiceBusClient from twitchio import Message, Channel, Chatter -from twitchio.ext import commands, routines +from twitchio.ext import commands -from bots.irc_bot import IrcBot from helpers.beatmap_link_parser import parse_beatmap_link from helpers.database_helper import UserDatabase, StatisticsDatabase +from helpers.logger import RonniaLogger from helpers.osu_api_helper import OsuApi from helpers.utils import convert_seconds_to_readable -logger = logging.getLogger('ronnia') +logger = RonniaLogger(__name__) class TwitchBot(commands.Bot, ABC): @@ -29,12 +31,20 @@ class TwitchBot(commands.Bot, ABC): "-1": 'WIP', "-2": 'Graveyard'} - def __init__(self, initial_channel_ids: List[int]): + def __init__(self, initial_channel_ids: List[int], join_lock: Lock): self.users_db = UserDatabase() self.messages_db = StatisticsDatabase() + self.osu_api = OsuApi(self.messages_db) + self.environment = os.getenv('ENVIRONMENT') self.initial_channel_ids = initial_channel_ids + self.servicebus_connection_string = os.getenv('SERVICE_BUS_CONN_STRING') + self.servicebus_client = ServiceBusClient.from_connection_string(conn_str=self.servicebus_connection_string) + self.signup_queue_name = 'bot-signups' + self.signup_reply_queue_name = 'bot-signups-reply' + self.twitch_to_irc_queue_name = 'twitch-to-irc' self.all_user_details = [] + self.channel_names = [] args = { 'token': os.getenv('TMI_TOKEN'), @@ -44,14 +54,52 @@ def __init__(self, initial_channel_ids: List[int]): } super().__init__(**args) + self._join_lock = join_lock + self.main_prefix = None - self.osu_api = OsuApi(self.messages_db) self.user_last_request = {} - self.irc_bot = IrcBot("#osu", os.getenv('OSU_USERNAME'), "irc.ppy.sh", password=os.getenv("IRC_PASSWORD")) - self.irc_bot_thread = Thread(target=self.irc_bot.start) self.join_channels_first_time = True + async def servicebus_message_receiver(self): + receiver = self.servicebus_client.get_queue_receiver(queue_name=self.signup_queue_name) + async for message in receiver: + logger.info(f'Received sign-up message: {message}') + reply_message = await self.receive_and_parse_message(message) + await receiver.complete_message(message) + + async with ServiceBusClient.from_connection_string( + conn_str=self.servicebus_connection_string) as servicebus_client: + sender = servicebus_client.get_queue_sender(queue_name=self.signup_reply_queue_name) + logger.info(f'Sending reply message: {reply_message}') + await sender.send_messages(reply_message) + + async def receive_and_parse_message(self, message): + """ + {'command': 'signup', + 'osu_username': 'heyronii', + 'osu_id': 5642779, + 'twitch_username': 'heyronii', + 'twitch_id': '68427964', + 'avatar_url': 'https://static-cdn.jtvnw.net/jtv_user_pictures/18057641-820c-44d0-af8d-032e129086fb-profile_image-300x300.png'} + """ + message_dict = json.loads(str(message)) + twitch_username = message_dict['twitch_username'] + osu_username = message_dict['osu_username'] + osu_id = message_dict['osu_id'] + twitch_id = message_dict['twitch_id'] + await self.users_db.add_user(twitch_username=twitch_username, + twitch_id=twitch_id, + osu_username=osu_username, + osu_user_id=osu_id) + user_db_details = await self.users_db.get_user_from_twitch_username(twitch_username) + message_dict['user_id'] = user_db_details['user_id'] + return ServiceBusMessage(json.dumps(message_dict)) + + def run(self): + self.loop.create_task(self.servicebus_message_receiver()) + super().run() + @staticmethod async def _get_access_token(): client_id = os.getenv('CLIENT_ID'), @@ -92,7 +140,7 @@ async def handle_request(self, message: Message): if await self.users_db.get_echo_status(twitch_username=message.channel.name): await self._send_twitch_message(message, beatmap_info) - await self._send_irc_message(message, beatmap_info, given_mods) + await self._send_beatmap_to_irc(message, beatmap_info, given_mods) await self.messages_db.add_request(requested_beatmap_id=int(beatmap_info['beatmap_id']), requested_channel_name=message.channel.name, requester_channel_name=message.author.name, @@ -104,7 +152,7 @@ async def inform_user_on_updates(self, osu_username: str, twitch_username: str, if os.path.exists(message_txt_path): with open(message_txt_path) as f: update_message = f.read().strip() - self.irc_bot.send_message(osu_username, update_message) + await self._send_irc_message(osu_username, update_message) else: logger.warning(f'Looking for {message_txt_path}, but it does not exist!') await self.users_db.set_channel_updated(twitch_username) @@ -127,7 +175,7 @@ async def check_beatmap_star_rating(self, message: Message, beatmap_info): async def check_request_criteria(self, message: Message, beatmap_info: dict): test_status = await self.users_db.get_test_status(message.channel.name) - if not test_status: + if not test_status or self.environment != 'testing': await self.check_sub_only_mode(message) await self.check_cp_only_mode(message) await self.check_user_excluded(message) @@ -141,7 +189,8 @@ async def check_request_criteria(self, message: Message, beatmap_info: dict): raise AssertionError async def check_user_excluded(self, message: Message): - excluded_users = await self.users_db.get_excluded_users(twitch_username=message.channel.name, return_mode='list') + excluded_users = await self.users_db.get_excluded_users(twitch_username=message.channel.name, + return_mode='list') assert message.author.name.lower() not in excluded_users, f'{message.author.name} is excluded' async def check_sub_only_mode(self, message: Message): @@ -264,7 +313,7 @@ async def _prune_cooldowns(self, time_right_now: datetime.datetime): return - async def _send_irc_message(self, message: Message, beatmap_info: dict, given_mods: str): + async def _send_beatmap_to_irc(self, message: Message, beatmap_info: dict, given_mods: str): """ Sends the beatmap request message to osu!irc bot :param message: Twitch Message object @@ -273,11 +322,19 @@ async def _send_irc_message(self, message: Message, beatmap_info: dict, given_mo :return: """ irc_message = await self._prepare_irc_message(message, beatmap_info, given_mods) - irc_target_channel = (await self.users_db.get_user_from_twitch_username(message.channel.name))['osu_username'] - self.irc_bot.send_message(irc_target_channel, irc_message) + await self._send_irc_message(irc_message, irc_target_channel) + return + async def _send_irc_message(self, irc_message: str, irc_target_channel): + message = json.dumps({'message': irc_message, 'target_channel': irc_target_channel}) + async with ServiceBusClient.from_connection_string(self.servicebus_connection_string) as sb_client: + sender = sb_client.get_queue_sender(queue_name=self.twitch_to_irc_queue_name) + msg = ServiceBusMessage(message) + await sender.send_messages(msg) + logger.debug(f'Sending message from twitch to irc: {message}') + @staticmethod async def _send_twitch_message(message: Message, beatmap_info: dict): """ @@ -360,33 +417,16 @@ async def event_ready(self): logger.info(f'Successfully initialized databases!') - # TODO: Fix here - # self.all_user_details = await self.users_db.get_all_users() - # self.initial_channel_ids = [user['twitch_id'] for user in self.all_user_details] - logger.debug(f'Populating users: {self.initial_channel_ids}') - # Get channel names from ids - list_batcher = lambda sample_list, chunk_size: [sample_list[i:i + chunk_size] for i in - range(0, len(sample_list), chunk_size)] - - channel_names = [] - for batch in list_batcher(self.initial_channel_ids, 100): - channel_names.extend(await self.fetch_users(ids=batch)) - - channels_to_join = [ch.name for ch in channel_names] - - if self.nick not in channels_to_join: - channels_to_join.append(self.nick) + self.channel_names = await self.fetch_users(ids=self.initial_channel_ids) + channels_to_join = [ch.name for ch in self.channel_names] logger.debug(f'Joining channels: {channels_to_join}') # Join channels channel_join_start = time.time() - await self.join_channels(channels_to_join) + # await self.join_channels(channels_to_join) logger.debug(f'Joined all channels after {time.time() - channel_join_start:.2f}s') - # Start update users routine - self.update_users.start() - self.join_channels_routine.start() initial_extensions = ['cogs.request_cog', 'cogs.admin_cog'] for extension in initial_extensions: @@ -394,85 +434,3 @@ async def event_ready(self): logger.debug(f'Successfully loaded: {extension}') logger.info(f'Ready | {self.nick}') - - @routines.routine(hours=1) - async def update_users(self): - logger.info('Started updating user routine') - user_details = await self.users_db.get_all_users() - channel_ids = [ch['twitch_id'] for ch in user_details] - channel_details = await self.fetch_users(ids=channel_ids) - - # Remove banned twitch users from database - if len(user_details) != len(channel_details): - logger.warning('There\'s a banned user.') - logger.info(f'Users in database vs from twitch api: {len(user_details)} - {len(channel_details)}.') - banned_users = set([user['twitch_id'] for user in user_details]).difference( - set([str(user.id) for user in channel_details])) - logger.info(f'Banned user ids: {banned_users}') - new_user_details = [] - for user in user_details: - if user['twitch_id'] in banned_users: - await self.users_db.remove_user(user['twitch_username']) - else: - new_user_details.append(user) - user_details = new_user_details.copy() - - user_details.sort(key=lambda x: int(x['twitch_id'])) - channel_details.sort(key=lambda x: x.id) - - for db_user, new_twitch_user in zip(user_details, channel_details): - try: - osu_details = await self.osu_api.get_user_info(db_user['osu_username']) - except aiohttp.ClientError as client_error: - logger.error(client_error) - osu_details = {'user_id': db_user['osu_id'], - 'username': db_user['osu_username']} - - # Remove banned osu! users from database - if osu_details is None: - await self.users_db.remove_user(twitch_username=db_user['twitch_username']) - continue - new_twitch_username = new_twitch_user.name.lower() - new_osu_username = osu_details['username'].lower().replace(' ', '_') - twitch_id = new_twitch_user.id - osu_user_id = osu_details['user_id'] - - if new_osu_username != db_user['osu_username'] or new_twitch_username != db_user['twitch_username']: - logger.info(f'Username change:') - logger.info(f'osu! old: {db_user["osu_username"]} - new: {new_osu_username}') - logger.info(f'Twitch old: {db_user["twitch_username"]} - new: {new_twitch_username}') - await self.users_db.update_user(new_twitch_username=new_twitch_username, - new_osu_username=new_osu_username, - twitch_id=twitch_id, - osu_user_id=osu_user_id) - - @routines.routine(hours=1) - async def join_channels_routine(self): - logger.debug('Started join channels routine') - if self.join_channels_first_time: - self.join_channels_first_time = False - return - all_user_details = await self.users_db.get_all_users() - twitch_users = [user['twitch_username'] for user in all_user_details] - logger.debug(f'Joining: {twitch_users}') - await self.join_channels(twitch_users) - - async def close(self): - logger.info('Closing bot') - self.update_users.cancel() - self.join_channels_routine.cancel() - await self.users_db.close() - await self.messages_db.close() - self._connection._keeper.cancel() - self._connection.is_ready.clear() - - futures = self._connection._fetch_futures() - - for fut in futures: - fut.cancel() - - if self._connection._websocket: - await self._connection._websocket.close() - if self._connection._client._http.session: - await self._connection._client._http.session.close() - self._connection._loop.stop() diff --git a/src/cogs/admin_cog.py b/src/cogs/admin_cog.py index f0078ca..8b02137 100644 --- a/src/cogs/admin_cog.py +++ b/src/cogs/admin_cog.py @@ -35,16 +35,6 @@ async def add_user_to_db(self, ctx: Context, *args): logger.info(f'Adding {twitch_username} - {osu_username} to user database!') await ctx.send(f'Added {twitch_username} -> {osu_username}.') - @commands.command(name="rmuser") - async def remove_user_from_db(self, ctx: Context, *args): - - twitch_username = args[0].lower() - - self.bot.users_db.remove_user(twitch_username=twitch_username) - await self.bot.part_channel([twitch_username]) - await ctx.send(f'Removed {twitch_username}.') - logger.info(f'Removed {twitch_username}!') - @commands.command(name="test") async def toggle_test_for_user(self, ctx: Context, *args): @@ -53,22 +43,6 @@ async def toggle_test_for_user(self, ctx: Context, *args): await ctx.send(f'Setting test to {new_value} for {twitch_username}.') logger.info(f'Setting test to {new_value} for {twitch_username}.') - @commands.command(name="status") - async def get_active_channels(self, ctx: Context): - - all_users = await self.bot.users_db.get_all_users() - - not_joined = [] - for user in all_users: - connected_channel_names = [ch.name for ch in self.bot.connected_channels] - if user['twitch_username'] not in connected_channel_names: - not_joined.append(user['twitch_username']) - - if len(not_joined) != 0: - await ctx.send('Not joined to: ' + ','.join(not_joined)) - else: - await ctx.send('We are connected to every channel') - def prepare(bot: TwitchBot): # Load our cog with this module... diff --git a/src/helpers/database_helper.py b/src/helpers/database_helper.py index 6f49bc7..fe12524 100644 --- a/src/helpers/database_helper.py +++ b/src/helpers/database_helper.py @@ -16,6 +16,7 @@ async def initialize(self): self.conn = await aiosqlite.connect(self.db_path, check_same_thread=False, detect_types=sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES) + await self.conn.execute('PRAGMA journal_mode = DELETE') self.conn.row_factory = aiosqlite.Row self.c = await self.conn.cursor() @@ -393,7 +394,16 @@ async def set_range_setting(self, twitch_username: str, setting_key: str, range_ await self.conn.commit() return range_low, range_high - async def get_all_users(self, limit: int = 100, offset: int = 0) -> List[sqlite3.Row]: + async def get_all_user_count(self) -> List[sqlite3.Row]: + """ + Gets all user count in users table + :return: + """ + result = await self.c.execute("SELECT COUNT(*) FROM users;") + value = await result.fetchone() + return value[0] + + async def get_users(self, limit: int = 100, offset: int = 0) -> List[sqlite3.Row]: """ Gets all users in db :return: diff --git a/src/helpers/logger.py b/src/helpers/logger.py new file mode 100644 index 0000000..4c840a9 --- /dev/null +++ b/src/helpers/logger.py @@ -0,0 +1,19 @@ +import logging +import os + + +class RonniaLogger(object): + def __new__(cls, name, *args, **kwargs): + logger = logging.getLogger(name) + logger.setLevel(os.getenv('LOG_LEVEL').upper()) + loggers_formatter = logging.Formatter( + '%(asctime)s | %(levelname)s | %(process)d | %(name)s | %(funcName)s | %(message)s', + datefmt='%d/%m/%Y %I:%M:%S') + + ch = logging.StreamHandler() + ch.setFormatter(loggers_formatter) + logger.addHandler(ch) + + logger.propagate = False + + return logger diff --git a/src/helpers/osu_api_helper.py b/src/helpers/osu_api_helper.py index 3062266..b8c8ba2 100644 --- a/src/helpers/osu_api_helper.py +++ b/src/helpers/osu_api_helper.py @@ -3,7 +3,6 @@ import json import logging import os -import time from typing import Union import aiohttp @@ -54,7 +53,7 @@ async def get_user_info(self, username: Union[str, int]): return None async def _get_endpoint(self, params: dict, endpoint: str): - self._wait_for_rate_limit() + await self._wait_for_rate_limit() timeout = aiohttp.ClientTimeout(total=5) async with aiohttp.ClientSession(timeout=timeout) as session: async with session.get(f'http://osu.ppy.sh/api/{endpoint}', params=params) as response: @@ -67,11 +66,11 @@ async def _get_endpoint(self, params: dict, endpoint: str): await self._messages_db.add_api_usage(endpoint) return r - def _wait_for_rate_limit(self): + async def _wait_for_rate_limit(self): now = datetime.datetime.now() time_passed = now - self._last_request_time if time_passed.total_seconds() < self._cooldown_seconds: - time.sleep(self._cooldown_seconds - time_passed.total_seconds()) + await asyncio.sleep(self._cooldown_seconds - time_passed.total_seconds()) self._last_request_time = datetime.datetime.now() diff --git a/src/requirements.txt b/src/requirements.txt index 4b7f56a..2ed336b 100644 --- a/src/requirements.txt +++ b/src/requirements.txt @@ -1,4 +1,6 @@ -twitchio==2.1.4 +twitchio==2.1.5 irc>=19.0.0,<21.0.0 -aiohttp>=3.7.0,<4.0.0 -aiosqlite==0.17.0 \ No newline at end of file +aiosqlite==0.17.0 +pyzmq==22.3.0 +tornado==6.1 +azure-servicebus==7.6.0 \ No newline at end of file diff --git a/tests/unit_tests/test_twitch_bot.py b/tests/unit_tests/test_twitch_bot.py index 9138156..e3c0d0b 100644 --- a/tests/unit_tests/test_twitch_bot.py +++ b/tests/unit_tests/test_twitch_bot.py @@ -93,7 +93,7 @@ async def test_handle_request_calls_check_request_criteria(self): send_irc_message_return_value = asyncio.Future() send_irc_message_return_value.set_result(None) - self.bot._send_irc_message = MagicMock(return_value=send_irc_message_return_value) + self.bot._send_beatmap_to_irc = MagicMock(return_value=send_irc_message_return_value) self.bot._check_message_contains_beatmap_link = MagicMock(return_value=(0, 'test_beatmap_id')) @@ -128,7 +128,7 @@ async def test_handle_request_calls_send_twitch_message_when_echo_enabled(self): send_irc_message_return_value = asyncio.Future() send_irc_message_return_value.set_result(None) - self.bot._send_irc_message = MagicMock(return_value=send_irc_message_return_value) + self.bot._send_beatmap_to_irc = MagicMock(return_value=send_irc_message_return_value) self.bot._check_user_cooldown = MagicMock() @@ -155,7 +155,7 @@ async def test_handle_request_adds_request_to_statistics_db(self): send_irc_message_return_value = asyncio.Future() send_irc_message_return_value.set_result(None) - self.bot._send_irc_message = MagicMock(return_value=send_irc_message_return_value) + self.bot._send_beatmap_to_irc = MagicMock(return_value=send_irc_message_return_value) self.bot._check_user_cooldown = MagicMock() @@ -241,7 +241,7 @@ async def test__send_irc_message_calls_irc_bot_send_message(self): _prepare_irc_message_return_value.set_result(None) self.bot._prepare_irc_message = MagicMock(return_value=_prepare_irc_message_return_value) - await self.bot._send_irc_message(msg, beatmap_info, mods) + await self.bot._send_beatmap_to_irc(msg, beatmap_info, mods) self.bot.irc_bot.send_message.assert_called_once() async def test__send_irc_message_calls__prepare_irc_message(self): @@ -257,7 +257,7 @@ async def test__send_irc_message_calls__prepare_irc_message(self): _prepare_irc_message_return_value.set_result(None) self.bot._prepare_irc_message = AsyncMock(return_value=_prepare_irc_message_return_value) - await self.bot._send_irc_message(msg, beatmap_info, mods) + await self.bot._send_beatmap_to_irc(msg, beatmap_info, mods) self.bot._prepare_irc_message.assert_called_once() async def test_event_ready_calls_fetch_users(self):