diff --git a/MANIFEST.in b/MANIFEST.in index 43965e228..1b3b1101c 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -2,7 +2,6 @@ include LICENSE include README.rst include requirements.txt include requirements_mturk.txt -recursive-include otree/certs * recursive-include otree/static * recursive-include otree/templates * recursive-include otree/project_template * diff --git a/PKG-INFO b/PKG-INFO index 363beb1f4..5ab59f0d6 100644 --- a/PKG-INFO +++ b/PKG-INFO @@ -1,6 +1,6 @@ Metadata-Version: 2.1 Name: otree -Version: 2.3.2 +Version: 2.5.0 Summary: oTree is a toolset that makes it easy to create and administer web-based social science experiments. Home-page: http://otree.org/ Author: chris@otree.org diff --git a/otree/__init__.py b/otree/__init__.py index 911a4bbad..e74605d16 100644 --- a/otree/__init__.py +++ b/otree/__init__.py @@ -1,4 +1,4 @@ # setup.py imports this module, so this module must not import django # or any other 3rd party packages. -__version__ = '2.3.2' +__version__ = '2.5.0' default_app_config = 'otree.apps.OtreeConfig' diff --git a/otree/api.py b/otree/api.py index c38cca557..37b237b92 100644 --- a/otree/api.py +++ b/otree/api.py @@ -3,8 +3,8 @@ from otree.models import BaseSubsession, BaseGroup, BasePlayer # noqa from otree.constants import BaseConstants # noqa from otree.views import Page, WaitPage # noqa -from otree.common import Currency, currency_range, safe_json # noqa -from otree.bots import Bot, Submission, SubmissionMustFail # noqa - -models = _import_module('otree.models') -widgets = _import_module('otree.forms.widgets') +from otree.currency import Currency, currency_range # noqa +from otree.common import safe_json +from otree.bots import Bot, Submission, SubmissionMustFail, expect # noqa +from otree import models # noqa +from otree.forms import widgets # noqa diff --git a/otree/api.pyi b/otree/api.pyi index df669b7a1..f160caeea 100644 --- a/otree/api.pyi +++ b/otree/api.pyi @@ -1,5 +1,6 @@ -from typing import Union, List, Any -from otree.common import RealWorldCurrency, Currency +from typing import Union, List, Any, Optional + +from otree.currency import RealWorldCurrency, Currency class Currency(Currency): ''' @@ -7,9 +8,11 @@ class Currency(Currency): (if I import, it says the reference to Currency is not found) ''' -def currency_range(first, last, increment) -> List[Currency]: pass -def safe_json(obj): pass +def currency_range(first, last, increment) -> List[Currency]: + pass +def safe_json(obj): + pass # mocking the public API for PyCharm autocomplete. # one downside is that PyCharm doesn't seem to fully autocomplete arguments @@ -54,51 +57,49 @@ class models: def __getattr__(self, item): pass - class BooleanField(bool): def __init__( - self, - *, - choices=None, - widget=None, - initial=None, - label=None, - doc='', - blank=False, - **kwargs): + self, + *, + choices=None, + widget=None, + initial=None, + label=None, + doc='', + blank=False, + **kwargs + ): pass - class StringField(str): def __init__( - self, - *, - choices=None, - widget=None, - initial=None, - label=None, - doc='', - max_length=10000, - blank=False, - **kwargs): + self, + *, + choices=None, + widget=None, + initial=None, + label=None, + doc='', + max_length=10000, + blank=False, + **kwargs + ): pass - class LongStringField(str): def __init__( - self, - *, - initial=None, - label=None, - doc='', - max_length=None, - blank=False, - **kwargs): + self, + *, + initial=None, + label=None, + doc='', + max_length=None, + blank=False, + **kwargs + ): pass - # need to copy-paste the __init__ between # Integer, Float, and Currency # because if I use inheritance, PyCharm doesn't auto-complete # while typing args - class IntegerField(int): def __init__( self, @@ -111,9 +112,9 @@ class models: min=None, max=None, blank=False, - **kwargs): - pass - + **kwargs + ): + pass class FloatField(float): def __init__( self, @@ -126,9 +127,9 @@ class models: min=None, max=None, blank=False, - **kwargs): - pass - + **kwargs + ): + pass class CurrencyField(Currency): def __init__( self, @@ -141,149 +142,189 @@ class models: min=None, max=None, blank=False, - **kwargs): - pass - - + **kwargs + ): + pass class widgets: def __getattr__(self, item): pass - # don't need HiddenInput because you can just write # and then you know the element's selector - class CheckboxInput: pass - class RadioSelect: pass - class RadioSelectHorizontal: pass - class Slider: pass - + class CheckboxInput: + pass + class RadioSelect: + pass + class RadioSelectHorizontal: + pass class Session: - config = None # type: dict - vars = None # type: dict - num_participants = None # type: int - def get_participants(self) -> List[Participant]: pass - def get_subsessions(self) -> List[BaseSubsession]: pass + config: dict + vars: dict + num_participants: int + def get_participants(self) -> List[Participant]: + pass + def get_subsessions(self) -> List[BaseSubsession]: + pass class Participant: - session = None # type: Session - vars = None # type: dict - label = None # type: str - id_in_session = None # type: int - payoff = None # type: Currency - - def get_players(self) -> List[BasePlayer]: pass - def payoff_plus_participation_fee(self) -> RealWorldCurrency: pass - - -class BaseConstants: pass + session: Session + vars: dict + label: str + id_in_session: int + payoff: Currency + def get_players(self) -> List[BasePlayer]: + pass + def payoff_plus_participation_fee(self) -> RealWorldCurrency: + pass +class BaseConstants: + pass class BaseSubsession: - session = None # type: Session - round_number = None # type: int - - def get_groups(self) -> List[BaseGroup]: pass - def get_group_matrix(self) -> List[List[BasePlayer]]: pass + session: Session + round_number: int + def get_groups(self) -> List[BaseGroup]: + pass + def get_group_matrix(self) -> List[List[BasePlayer]]: + pass def set_group_matrix( - self, - group_matrix: Union[List[List[BasePlayer]],List[List[int]]]): pass - def get_players(self) -> List[BasePlayer]: pass - def in_previous_rounds(self) -> List['BaseSubsession']: pass - def in_all_rounds(self) -> List['BaseSubsession']: pass - def creating_session(self): pass - def in_round(self, round_number) -> 'BaseSubsession': pass - def in_rounds(self, first, last) -> List['BaseSubsession']: pass - def group_like_round(self, round_number: int): pass - def group_randomly(self, fixed_id_in_group: bool=False): pass - def vars_for_admin_report(self): pass - + self, group_matrix: Union[List[List[BasePlayer]], List[List[int]]] + ): + pass + def get_players(self) -> List[BasePlayer]: + pass + def in_previous_rounds(self) -> List[BaseSubsession]: + pass + def in_all_rounds(self) -> List[BaseSubsession]: + pass + def creating_session(self): + pass + def in_round(self, round_number) -> BaseSubsession: + pass + def in_rounds(self, first, last) -> List[BaseSubsession]: + pass + def group_like_round(self, round_number: int): + pass + def group_randomly(self, fixed_id_in_group: bool = False): + pass + def vars_for_admin_report(self) -> dict: + pass # this is so PyCharm doesn't flag attributes that are only defined on the app's Subsession, # not on the BaseSubsession - def __getattribute__(self, item): pass + def __getattribute__(self, item): + pass class BaseGroup: - session = None # type: Session - subsession = None # type: BaseSubsession - round_number = None # type: int - - def get_players(self) -> List[BasePlayer]: pass - def get_player_by_role(self, role) -> BasePlayer: pass - def get_player_by_id(self, id_in_group) -> BasePlayer: pass - def in_previous_rounds(self) -> List['BaseGroup']: pass - def in_all_rounds(self) -> List['BaseGroup']: pass - def in_round(self, round_number) -> 'BaseGroup': pass - def in_rounds(self, first: int, last: int) -> List['BaseGroup']: pass - - def __getattribute__(self, item): pass + session: Session + subsession: BaseSubsession + round_number: int + def get_players(self) -> List[BasePlayer]: + pass + def get_player_by_role(self, role) -> BasePlayer: + pass + def get_player_by_id(self, id_in_group) -> BasePlayer: + pass + def in_previous_rounds(self) -> List[BaseGroup]: + pass + def in_all_rounds(self) -> List[BaseGroup]: + pass + def in_round(self, round_number) -> BaseGroup: + pass + def in_rounds(self, first: int, last: int) -> List[BaseGroup]: + pass + def __getattribute__(self, item): + pass class BasePlayer: - id_in_group = None # type: int - payoff = None # type: Currency - participant = None # type: Participant - session = None # type: Session - group = None # type: BaseGroup - subsession = None # type: BaseSubsession - round_number = None # type: int - - def in_previous_rounds(self) -> List['BasePlayer']: pass - def in_all_rounds(self) -> List['BasePlayer']: pass - def get_others_in_group(self) -> List['BasePlayer']: pass - def get_others_in_subsession(self) -> List['BasePlayer']: pass - def role(self) -> str: pass - def in_round(self, round_number) -> 'BasePlayer': pass - def in_rounds(self, first, last) -> List['BasePlayer']: pass - - def __getattribute__(self, item): pass - + id_in_group: int + payoff: Currency + participant: Participant + session: Session + group: BaseGroup + subsession: BaseSubsession + round_number: int + def in_previous_rounds(self) -> List[BasePlayer]: + pass + def in_all_rounds(self) -> List[BasePlayer]: + pass + def get_others_in_group(self) -> List[BasePlayer]: + pass + def get_others_in_subsession(self) -> List[BasePlayer]: + pass + def role(self) -> str: + pass + def in_round(self, round_number) -> BasePlayer: + pass + def in_rounds(self, first, last) -> List[BasePlayer]: + pass + def __getattribute__(self, item): + pass class WaitPage: wait_for_all_groups = False group_by_arrival_time = False - title_text = None - body_text = None - template_name = None - round_number = None # type: int - participant = None # type: Participant - session = None # type: Session - - def is_displayed(self) -> bool: pass - def after_all_players_arrive(self): pass - def get_players_for_group(self, waiting_players): pass - + title_text: str + body_text: str + template_name: str + round_number: int + participant: Participant + session: Session + def is_displayed(self) -> bool: + pass + def after_all_players_arrive(self): + pass + def get_players_for_group(self, waiting_players) -> Optional[list]: + pass class Page: - round_number = None # type: int - template_name = None # type: str - timeout_seconds = None # type: int - timeout_submission = None # type: dict - timeout_happened = None # type: bool - timer_text = None # type: str - participant = None # type: Participant - session = None # type: Session - form_model = None # - form_fields = None # type: List[str] - - def get_form_fields(self) -> List['str']: pass - def vars_for_template(self) -> dict: pass - def before_next_page(self): pass - def is_displayed(self) -> bool: pass - def error_message(self, values): pass - def get_timeout_seconds(self): pass - + round_number: int + template_name: str + timeout_seconds: int + timeout_submission: dict + timeout_happened: bool + timer_text: str + participant: Participant + session: Session + form_model: str + form_fields: List[str] + def get_form_fields(self) -> List[str]: + pass + def vars_for_template(self) -> dict: + pass + def before_next_page(self): + pass + def is_displayed(self) -> bool: + pass + def error_message(self, values) -> Optional[str]: + pass + def get_timeout_seconds(self) -> Optional[float]: + pass + def app_after_this_page(self, upcoming_apps: List[str]) -> Optional[str]: + pass class Bot: - html = '' # type: str - case = None # type: Any - cases = [] # type: List - participant = None # type: Participant - session = None # type: Participant - round_number = None # type: int - -def Submission(PageClass, post_data: dict={}, *, check_html=True, timeout_happened=False): pass -def SubmissionMustFail(PageClass, post_data: dict={}, *, check_html=True, error_fields=[]): pass + html: str + case: Any + cases: List + participant: Participant + session: Participant + round_number: int + +def Submission( + PageClass, post_data: dict = {}, *, check_html=True, timeout_happened=False +): + pass + +def SubmissionMustFail( + PageClass, post_data: dict = {}, *, check_html=True, error_fields=[] +): + pass + +def expect(*args): + pass diff --git a/otree/app_template/models.py b/otree/app_template/models.py index 3ec45e612..561c8fdc3 100644 --- a/otree/app_template/models.py +++ b/otree/app_template/models.py @@ -1,6 +1,12 @@ from otree.api import ( - models, widgets, BaseConstants, BaseSubsession, BaseGroup, BasePlayer, - Currency as c, currency_range + models, + widgets, + BaseConstants, + BaseSubsession, + BaseGroup, + BasePlayer, + Currency as c, + currency_range, ) diff --git a/otree/app_template/pages.py b/otree/app_template/pages.py index 5e7439afd..bc06d2716 100644 --- a/otree/app_template/pages.py +++ b/otree/app_template/pages.py @@ -8,7 +8,6 @@ class MyPage(Page): class ResultsWaitPage(WaitPage): - def after_all_players_arrive(self): pass @@ -17,8 +16,4 @@ class Results(Page): pass -page_sequence = [ - MyPage, - ResultsWaitPage, - Results -] +page_sequence = [MyPage, ResultsWaitPage, Results] diff --git a/otree/app_template/tests.py b/otree/app_template/tests.py index 236de688c..fbb60094e 100644 --- a/otree/app_template/tests.py +++ b/otree/app_template/tests.py @@ -5,6 +5,5 @@ class PlayerBot(Bot): - def play_round(self): pass diff --git a/otree/apps.py b/otree/apps.py index f70def866..58334da91 100644 --- a/otree/apps.py +++ b/otree/apps.py @@ -1,26 +1,32 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- import logging import sys +import sqlite3 +import django.db.utils import colorama from django.apps import AppConfig -from django.conf import settings from django.db.models import signals -import otree -import otree.common_internal -from otree.common_internal import ( - ensure_superuser_exists -) +from otree.common import ensure_superuser_exists from otree.strict_templates import patch_template_silent_failures +try: + from psycopg2.errors import UndefinedColumn, UndefinedTable +except ModuleNotFoundError: + + class UndefinedColumn(Exception): + pass + + class UndefinedTable(Exception): + pass + logger = logging.getLogger('otree') def create_singleton_objects(sender, **kwargs): from otree.models_concrete import UndefinedFormModel + for ModelClass in [UndefinedFormModel]: # if it doesn't already exist, create one. ModelClass.objects.get_or_create() @@ -33,6 +39,33 @@ def create_singleton_objects(sender, **kwargs): ) +def patched_execute(self, sql, params=None): + try: + return self._execute_with_wrappers( + sql, params, many=False, executor=self._execute + ) + except Exception as exc: + + ExceptionClass = type(exc) + tb = sys.exc_info()[2] + # Django seems to reraise with new exceptions, so we need to look at the __cause__: + # sqlite3.OperationalError -> django.db.utils.OperationalError + # psycopg2.errors.UndefinedColumn -> django.db.utils.ProgrammingError + CauseClass = type(exc.__cause__) + + if CauseClass == sqlite3.OperationalError and 'locked' in str(exc): + raise ExceptionClass(f'{exc} - {SQLITE_LOCKING_ADVICE}.').with_traceback( + tb + ) from None + + # this will only work on postgres, but if they are using sqlite they should be using + # devserver anyway. + if CauseClass in (UndefinedColumn, UndefinedTable): + msg = f'{exc} - try resetting the database.' + raise ExceptionClass(msg).with_traceback(tb) from None + raise + + def monkey_patch_db_cursor(): '''Monkey-patch the DB cursor, to catch ProgrammingError and OperationalError. The alternative is to use middleware, but (1) @@ -42,57 +75,21 @@ def monkey_patch_db_cursor(): unrelated to resetdb. This is the most targeted location. ''' - - # In Django 2.0, this method is renamed to _execute. - def execute(self, sql, params=None): - self.db.validate_no_broken_transaction() - with self.db.wrap_database_errors: - try: - if params is None: - return self.cursor.execute(sql) - else: - return self.cursor.execute(sql, params) - except Exception as exc: - ExceptionClass = type(exc) - # it seems there are different exceptions all named - # OperationalError (django.db.OperationalError, - # sqlite.OperationalError, mysql....) - # so, simplest to use the string name - if ExceptionClass.__name__ in ( - 'OperationalError', 'ProgrammingError'): - # these error messages are localized, so we can't - # just check for substring 'column' or 'table' - # all the ProgrammingError and OperationalError - # instances I've seen so far are related to resetdb, - # except for "database is locked" - tb = sys.exc_info()[2] - if 'locked' in str(exc): - advice = SQLITE_LOCKING_ADVICE - import django.db.transaction - else: - advice = 'try resetting the database ("otree resetdb")' - - raise ExceptionClass('{} - {}.'.format( - exc, advice)).with_traceback(tb) from None - else: - raise - from django.db.backends import utils - utils.CursorWrapper.execute = execute + + utils.CursorWrapper.execute = patched_execute def setup_create_default_superuser(): signals.post_migrate.connect( - ensure_superuser_exists, - dispatch_uid='otree.create_superuser' + ensure_superuser_exists, dispatch_uid='otree.create_superuser' ) def setup_create_singleton_objects(): - signals.post_migrate.connect(create_singleton_objects, - dispatch_uid='create_singletons') - - + signals.post_migrate.connect( + create_singleton_objects, dispatch_uid='create_singletons' + ) class OtreeConfig(AppConfig): @@ -109,7 +106,6 @@ def ready(self): colorama.init(autoreset=True) import otree.checks + otree.checks.register_system_checks() patch_template_silent_failures() - - diff --git a/otree/bots/__init__.py b/otree/bots/__init__.py index 0d08d0874..2d1f003ef 100644 --- a/otree/bots/__init__.py +++ b/otree/bots/__init__.py @@ -1,27 +1 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -# NOTE: this imports the following submodules and then subclasses several -# classes importing is done via import_module rather than an ordinary import. -# -# The only reason for this is to hide the base classes from IDEs like PyCharm, -# so that those members/attributes don't show up in autocomplete, -# including all the built-in django fields that an ordinary oTree programmer -# will never need or want. if this was a conventional Django project I wouldn't -# do it this way, but because oTree is aimed at newcomers who may need more -# assistance from their IDE, I want to try this approach out. -# -# This module is also a form of documentation of the public API. - -# 2016-07-18: not using the import_module trick for now, because currently, -# the PlayerBot class doesn't have any methods we need to hide -# from importlib import import_module -# otree_bot = import_module('otree.bots.bot') - -from importlib import import_module - -_bot_module = import_module('otree.bots.bot') - -Bot = _bot_module.PlayerBot -Submission = _bot_module.Submission -SubmissionMustFail = _bot_module.SubmissionMustFail +from .bot import PlayerBot as Bot, Submission, SubmissionMustFail, expect diff --git a/otree/bots/bot.py b/otree/bots/bot.py index 4118465cf..7f41b1b66 100644 --- a/otree/bots/bot.py +++ b/otree/bots/bot.py @@ -2,23 +2,24 @@ import re import decimal import logging -import abc -import six +import operator + from urllib.parse import unquote, urlsplit -from six.moves.html_parser import HTMLParser +from html.parser import HTMLParser +import otree.constants from otree.models_concrete import ParticipantToPlayerLookup from django import test -from django.core.urlresolvers import resolve +from django.urls import resolve from django.conf import settings from otree.currency import Currency -from django.apps import apps -from otree import constants_internal from otree.models import Participant, Session -from otree import common_internal -from otree.common_internal import ( - get_dotted_name, get_bots_module, get_admin_secret_code, - get_models_module +from otree import common +from otree.common import ( + get_dotted_name, + get_bots_module, + get_admin_secret_code, + get_models_module, ) ADMIN_SECRET_CODE = get_admin_secret_code() @@ -26,8 +27,11 @@ logger = logging.getLogger('otree.bots') INTERNAL_FORM_FIELDS = { - 'csrfmiddlewaretoken', 'must_fail', 'timeout_happened', - 'admin_secret_code', 'error_fields' + 'csrfmiddlewaretoken', + 'must_fail', + 'timeout_happened', + 'admin_secret_code', + 'error_fields', } DISABLE_CHECK_HTML_INSTRUCTIONS = ''' @@ -39,32 +43,92 @@ yield Submission(views.PageName, {{...}}, check_html=False) ''' -HTML_MISSING_BUTTON_WARNING = (''' +HTML_MISSING_BUTTON_WARNING = ( + ( + ''' Bot is trying to submit page {page_name}, but no button was found in the HTML of the page. (searched for with type='submit' or with type != 'button'). -''' + DISABLE_CHECK_HTML_INSTRUCTIONS).replace('\n', ' ').strip() +''' + + DISABLE_CHECK_HTML_INSTRUCTIONS + ) + .replace('\n', ' ') + .strip() +) -HTML_MISSING_FIELD_WARNING = (''' +HTML_MISSING_FIELD_WARNING = ( + ( + ''' Bot is trying to submit page {page_name} with fields: "{fields}", but these form fields were not found in the HTML of the page (searched for tags {tags} with name= attribute matching the field name). -''' + DISABLE_CHECK_HTML_INSTRUCTIONS).replace('\n', ' ').strip() +''' + + DISABLE_CHECK_HTML_INSTRUCTIONS + ) + .replace('\n', ' ') + .strip() +) -class ParticipantBot(test.Client): +class ExpectError(AssertionError): + pass + + +def expect(*args): + if len(args) == 2: + lhs, rhs = args + op = '==' + elif len(args) == 3: + lhs, op, rhs = args + else: + msg = f'expect() takes 2 or 3 arguments' + raise ValueError(msg) + + operators = { + '==': operator.eq, + '!=': operator.ne, + '>': operator.gt, + '<': operator.lt, + '>=': operator.ge, + '<=': operator.le, + # operator.contains() has args in opposite order (rhs, lhs), so use this: + 'in': lambda a, b: a in b, + 'not in': lambda a, b: a not in b, + } + + if op not in operators: + msg = f'"{op}" not allowed in expect()' + raise ValueError(msg) + res = operators[op](lhs, rhs) + if not res: + error_messages = { + '==': f'Expected {rhs!r}, actual value is {lhs!r}', + # rhs might be huge, can't print it + 'in': f'{lhs!r} was not found', + 'not in': f'{lhs!r} was not expected but was found anyway', + } + default_msg = f'Assertion failed: {lhs!r} {op} {rhs!r}' + msg = error_messages.get(op, default_msg) + raise ExpectError(msg) + + +class ParticipantBot(test.Client): def __init__( - self, participant: Participant=None, *, - lookups: List[ParticipantToPlayerLookup] = None, - load_player_bots=True, case_number=None + self, + participant: Participant = None, + *, + lookups: List[ParticipantToPlayerLookup] = None, + load_player_bots=True, + case_number=None, ): # usually lookups should be passed in. for ad-hoc testing, # ok to pass a participant if not lookups: lookups_with_duplicates = ParticipantToPlayerLookup.objects.filter( - participant_id=participant.id).order_by('player_pk') + participant_id=participant.id + ).order_by('player_pk') seen_player_pks = set() lookups = [] for lookup in lookups_with_duplicates: @@ -92,18 +156,14 @@ def __init__( bots_module = get_bots_module(app_name) player_bot = bots_module.PlayerBot( - lookup=lookup, case_number=case_number, - participant_bot=self + lookup=lookup, case_number=case_number, participant_bot=self ) self.player_bots.append(player_bot) self.submits_generator = self.get_submits() def open_start_url(self): - start_url = common_internal.participant_start_url(self.participant_code) - self.response = self.get( - start_url, - follow=True - ) + start_url = common.participant_start_url(self.participant_code) + self.response = self.get(start_url, follow=True) def get_submits(self): for player_bot in self.player_bots: @@ -132,6 +192,14 @@ def get_submits(self): pass else: raise + except ExpectError as exc: + # the point is to re-raise so that i can reference the original + # exception as exc.__cause__ or exc.__context__, since that exception + # is much smaller and doesn't have all the extra layers. + # pass it to response_for_exception. + # this results in much nicer output for browser bots (devserver and runprodserver) + # but keep the original message, which is needed for CLI bots + raise ExpectError(str(exc)) def _play_individually(self): '''convenience method for testing''' @@ -142,8 +210,8 @@ def _play_individually(self): def assert_html_ok(self, submission): if submission['check_html']: fields_to_check = [ - f for f in submission['post_data'] - if f not in INTERNAL_FORM_FIELDS] + f for f in submission['post_data'] if f not in INTERNAL_FORM_FIELDS + ] checker = PageHtmlChecker(fields_to_check) missing_fields = checker.get_missing_fields(self.html) if missing_fields: @@ -152,12 +220,16 @@ def assert_html_ok(self, submission): HTML_MISSING_FIELD_WARNING.format( page_name=page_name, fields=', '.join(missing_fields), - tags=', '.join('<{}>'.format(tag) - for tag in checker.field_tags))) + tags=', '.join( + '<{}>'.format(tag) for tag in checker.field_tags + ), + ) + ) if not checker.submit_button_found: page_name = submission['page_class'].url_name() - raise MissingHtmlButtonError(HTML_MISSING_BUTTON_WARNING.format( - page_name=page_name)) + raise MissingHtmlButtonError( + HTML_MISSING_BUTTON_WARNING.format(page_name=page_name) + ) def assert_correct_page(self, submission): PageClass = submission['page_class'] @@ -168,8 +240,9 @@ def assert_correct_page(self, submission): raise AssertionError( "Bot expects to be on page {}, " "but current page is {}. " - "Check your bot in tests.py, " - "then create a new session.".format(expected_url, actual_url)) + "Check your bot code, " + "then create a new session.".format(expected_url, actual_url) + ) @property def response(self): @@ -227,8 +300,11 @@ class PlayerBot: cases = [] def __init__( - self, case_number: int, participant_bot: ParticipantBot, - lookup: ParticipantToPlayerLookup): + self, + case_number: int, + participant_bot: ParticipantBot, + lookup: ParticipantToPlayerLookup, + ): app_name = lookup.app_name models_module = get_models_module(app_name) @@ -286,8 +362,6 @@ def html(self): return self.participant_bot.html - - class MissingHtmlButtonError(AssertionError): pass @@ -301,8 +375,14 @@ class BOTS_CHECK_HTML: def _Submission( - PageClass, post_data=None, *, check_html=BOTS_CHECK_HTML, - must_fail=False, error_fields=None, timeout_happened=False): + PageClass, + post_data=None, + *, + check_html=BOTS_CHECK_HTML, + must_fail=False, + error_fields=None, + timeout_happened=False, +): post_data = post_data or {} @@ -322,8 +402,8 @@ def _Submission( post_data['error_fields'] = error_fields if timeout_happened: - post_data[constants_internal.timeout_happened] = True - post_data[constants_internal.admin_secret_code] = ADMIN_SECRET_CODE + post_data[otree.constants.timeout_happened] = True + post_data[otree.constants.admin_secret_code] = ADMIN_SECRET_CODE # easy way to check if it's a wait page, without any messy imports if hasattr(PageClass, 'wait_for_all_groups'): @@ -347,24 +427,25 @@ def _Submission( def Submission( - PageClass, post_data=None, *, check_html=BOTS_CHECK_HTML, - timeout_happened=False): + PageClass, post_data=None, *, check_html=BOTS_CHECK_HTML, timeout_happened=False +): return _Submission( - PageClass, post_data, check_html=check_html, - timeout_happened=timeout_happened) + PageClass, post_data, check_html=check_html, timeout_happened=timeout_happened + ) def SubmissionMustFail( - PageClass, post_data=None, *, check_html=BOTS_CHECK_HTML, - error_fields=None + PageClass, post_data=None, *, check_html=BOTS_CHECK_HTML, error_fields=None ): '''lets you intentionally submit with invalid input to ensure it's correctly rejected''' return _Submission( PageClass, - post_data=post_data, check_html=check_html, must_fail=True, - error_fields=error_fields + post_data=post_data, + check_html=check_html, + must_fail=True, + error_fields=error_fields, ) @@ -390,7 +471,6 @@ def normalize_html_whitespace(html): class HtmlString(str): - def truncated(self): ''' Make output more readable by truncating everything before the @@ -412,7 +492,6 @@ def __repr__(self): # inherit from object for Python2.7 support. # otherwise, get class PageHtmlChecker(HTMLParser, object): - def __init__(self, fields_to_check): super().__init__() self.missing_fields = set(fields_to_check) @@ -448,9 +527,9 @@ def handle_starttag(self, tag, attrs): def is_wait_page(response): return ( - response.get(constants_internal.wait_page_http_header) == - constants_internal.get_param_truth_value) - + response.get(otree.constants.wait_page_http_header) + == otree.constants.get_param_truth_value + ) def bot_prettify_post_data(post_data): @@ -462,4 +541,4 @@ def bot_prettify_post_data(post_data): # 2018-03-25: why not use dict()? post_data = post_data.dict() - return {k: v for k,v in post_data.items() if k not in INTERNAL_FORM_FIELDS} + return {k: v for k, v in post_data.items() if k not in INTERNAL_FORM_FIELDS} diff --git a/otree/bots/browser.py b/otree/bots/browser.py index 6e2403168..952364bb1 100644 --- a/otree/bots/browser.py +++ b/otree/bots/browser.py @@ -1,25 +1,18 @@ -from typing import Dict import json -import threading import logging +import random +import threading +import traceback from collections import OrderedDict +from typing import Dict -import channels import otree.channels.utils as channel_utils -import traceback - -import otree.common_internal -from otree import common_internal - -from otree.common_internal import get_redis_conn - -from .runner import make_bots -from .bot import ParticipantBot -import random - +import otree.common +from otree import common +from otree.common import get_redis_conn from otree.models import Session -from channels.layers import get_channel_layer - +from .bot import ParticipantBot +from .runner import make_bots REDIS_KEY_PREFIX = 'otree-bots' @@ -48,6 +41,7 @@ class BotRequestError(Exception): and passed through Redis. if USE_REDIS==False, this will raise normally. ''' + pass @@ -66,7 +60,7 @@ class Worker: def __init__(self, redis_conn=None): self.redis_conn = redis_conn self.participants_by_session = OrderedDict() - self.browser_bots = {} # type: Dict[str, ParticipantBot] + self.browser_bots = {} # type: Dict[str, ParticipantBot] def initialize_session(self, session_pk, case_number): self.prune() @@ -76,6 +70,7 @@ def initialize_session(self, session_pk, case_number): if case_number is None: # choose one randomly from otree.session import SessionConfig + config = SessionConfig(session.config) num_cases = config.get_num_bot_cases() case_number = random.choice(range(num_cases)) @@ -84,8 +79,7 @@ def initialize_session(self, session_pk, case_number): session_pk=session_pk, case_number=case_number, use_browser_bots=True ) for bot in bots: - self.participants_by_session[session_pk].append( - bot.participant_code) + self.participants_by_session[session_pk].append(bot.participant_code) self.browser_bots[bot.participant_code] = bot def prune(self): @@ -101,7 +95,8 @@ def get_bot(self, participant_code): return self.browser_bots[participant_code] except KeyError: msg = PARTICIPANT_NOT_IN_BOTWORKER_MSG.format( - participant_code=participant_code, prune_limit=SESSIONS_PRUNE_LIMIT) + participant_code=participant_code, prune_limit=SESSIONS_PRUNE_LIMIT + ) raise BotRequestError(msg) def get_next_post_data(self, participant_code): @@ -164,10 +159,7 @@ def try_process_one_redis_message(self): response = {'error': str(exc)} except Exception as exc: # un-anticipated error - response = { - 'error': repr(exc), - 'traceback': traceback.format_exc() - } + response = {'error': repr(exc), 'traceback': traceback.format_exc()} # don't raise, because then this would crash. # logger.exception() will record the full traceback logger.exception(repr(exc)) @@ -187,7 +179,8 @@ def ping(redis_conn, *, timeout): timeouts piling up. ''' response_key = redis_enqueue_method_call( - redis_conn=redis_conn, method_name='ping', method_kwargs={}) + redis_conn=redis_conn, method_name='ping', method_kwargs={} + ) # make it very long, so we don't get spurious ping errors result = redis_conn.blpop(response_key, timeout) @@ -211,7 +204,7 @@ def load_redis_response_dict(response_bytes: bytes): if 'traceback' in response: # cram the other traceback in this traceback message. # note: - raise common_internal.BotError(response['traceback']) + raise common.BotError(response['traceback']) elif 'error' in response: # handled exception raise BotRequestError(response['error']) @@ -224,12 +217,8 @@ def redis_flush_bots(redis_conn): def redis_enqueue_method_call(redis_conn, method_name, method_kwargs) -> str: - response_key = '{}-{}'.format(REDIS_KEY_PREFIX, random.randint(1,10**9)) - msg = { - 'method': method_name, - 'kwargs': method_kwargs, - 'response_key': response_key, - } + response_key = '{}-{}'.format(REDIS_KEY_PREFIX, random.randint(1, 10 ** 9)) + msg = {'method': method_name, 'kwargs': method_kwargs, 'response_key': response_key} redis_conn.rpush(REDIS_KEY_PREFIX, json.dumps(msg)) return response_key @@ -250,21 +239,18 @@ def redis_get_method_retval(redis_conn, response_key: str) -> dict: if result is None: # ping will raise if it times out ping(redis_conn, timeout=3) - raise Exception( - 'botworker is running but did not return a submission.' - ) + raise Exception('botworker is running but did not return a submission.') key, submit_bytes = result return load_redis_response_dict(submit_bytes) def wrap_method_call(method_name: str, method_kwargs): - if otree.common_internal.USE_REDIS: + if otree.common.USE_REDIS: redis_conn = get_redis_conn() response_key = redis_enqueue_method_call( - redis_conn=redis_conn, method_name=method_name, - method_kwargs=method_kwargs) - return redis_get_method_retval( - redis_conn=redis_conn, response_key=response_key) + redis_conn=redis_conn, method_name=method_name, method_kwargs=method_kwargs + ) + return redis_get_method_retval(redis_conn=redis_conn, response_key=response_key) else: method = getattr(browser_bot_worker, method_name) return method(**method_kwargs) @@ -283,20 +269,18 @@ def initialize_session(**kwargs): # timeout must be int. # my tests show that it can initialize about 3000 players per second. # so 300-500 is conservative, plus pad for a few seconds - #timeout = int(6 + num_players_total / 500) + # timeout = int(6 + num_players_total / 500) # maybe number of ParticipantToPlayerLookups? - timeout = 6 # FIXME: adjust to number of players + timeout = 6 # FIXME: adjust to number of players return wrap_method_call('initialize_session', kwargs) def send_completion_message(*, session_code, participant_code): group_name = channel_utils.browser_bots_launcher_group(session_code) - channel_utils.sync_group_send( - group_name, - { - 'text': participant_code, - 'type': 'send_completion_message' - } + channel_utils.sync_group_send_wrapper( + group=group_name, + type='send_completion_message', + event={'text': participant_code}, ) diff --git a/otree/bots/browser_launcher.py b/otree/bots/browser_launcher.py index e8d893098..1065af59e 100644 --- a/otree/bots/browser_launcher.py +++ b/otree/bots/browser_launcher.py @@ -1,17 +1,17 @@ import logging -from subprocess import check_output, Popen +import os import sys import time -import os -from requests import session as requests_session +from enum import Enum +from subprocess import check_output, Popen +from urllib.parse import urljoin + from django.conf import settings from django.urls import reverse -from urllib.parse import urljoin -import otree.channels.utils as channel_utils +from ws4py.client.threadedclient import WebSocketClient +import otree.channels.utils as channel_utils from otree.session import SESSION_CONFIGS_DICT -from ws4py.client.threadedclient import WebSocketClient -from enum import Enum AUTH_FAILURE_MESSAGE = """ Could not login to the server using your ADMIN_USERNAME @@ -23,6 +23,15 @@ logger = logging.getLogger(__name__) +try: + from requests import session as requests_session +except ModuleNotFoundError: + sys.exit( + 'To use command-line browser bots, you need to install the "requests" library locally. ' + 'Do: "pip3 install requests"' + ) + + class OSEnum(Enum): windows = 'windows' mac = 'mac' @@ -34,15 +43,13 @@ class OSEnum(Enum): 'chrome': [ 'C:/Program Files (x86)/Google/Chrome/Application/chrome.exe', 'C:/Program Files/Google/Chrome/Application/chrome.exe', - os.getenv('LOCALAPPDATA', '') + r"\Google\Chrome\Application\chrome.exe", - ], + os.getenv('LOCALAPPDATA', '') + r"/Google/Chrome/Application/chrome.exe", + ] }, OSEnum.mac: { - 'chrome': ['/Applications/Google Chrome.app/Contents/MacOS/Google Chrome'], + 'chrome': ['/Applications/Google Chrome.app/Contents/MacOS/Google Chrome'] }, - OSEnum.linux: { - 'chrome': ['google-chrome'], - } + OSEnum.linux: {'chrome': ['google-chrome']}, } @@ -63,10 +70,10 @@ class URLs: WEBSOCKET_COMPLETED_MESSAGE = b'closed_by_browser_launcher' +WEBSOCKET_1000 = 1000 class OtreeWebSocketClient(WebSocketClient): - def __init__(self, *args, session_size, **kwargs): self.session_size = session_size self.seen_participant_codes = set() @@ -82,16 +89,17 @@ def received_message(self, message): self.seen_participant_codes.add(code) self.participants_finished += 1 if self.participants_finished == self.session_size: - self.close(reason=WEBSOCKET_COMPLETED_MESSAGE, code=1000) + self.close(reason=WEBSOCKET_COMPLETED_MESSAGE, code=WEBSOCKET_1000) def closed(self, code, reason=None): ''' make sure the websocket closed properly, not because of server-side exception etc. ''' - if reason != WEBSOCKET_COMPLETED_MESSAGE: + # i used to check "reason", but for some reason it's always an empty string. + if code != WEBSOCKET_1000: logger.error( - f'Lost connection with server.' + f'Lost connection with server. ' f'code: {code}, reason: "{reason}".' 'Check the oTree server logs for errors.' ) @@ -111,7 +119,6 @@ def run_websocket_client_until_finished(*, websocket_url, session_size) -> float class Launcher: - def __init__(self, *, session_config_name, server_url, num_participants): self.session_config_name = session_config_name self.server_url = server_url @@ -131,8 +138,7 @@ def run(self): if session_config_name: if session_config_name not in SESSION_CONFIGS_DICT: raise ValueError( - 'No session config named "{}"'.format( - session_config_name) + 'No session config named "{}"'.format(session_config_name) ) session_config_names = [session_config_name] @@ -148,13 +154,16 @@ def run(self): session_config = SESSION_CONFIGS_DICT[session_config_name] num_bot_cases = session_config.get_num_bot_cases() for case_number in range(num_bot_cases): - num_participants = (self.num_participants or - session_config['num_demo_participants']) - sessions_to_create.append({ - 'session_config_name': session_config_name, - 'num_participants': num_participants, - 'case_number': case_number, - }) + num_participants = ( + self.num_participants or session_config['num_demo_participants'] + ) + sessions_to_create.append( + { + 'session_config_name': session_config_name, + 'num_participants': num_participants, + 'case_number': case_number, + } + ) total_time_spent = 0 # run in a separate loop, because we want to validate upfront @@ -163,9 +172,7 @@ def run(self): for session_to_create in sessions_to_create: total_time_spent += self.run_session(**session_to_create) - print('Total: {} seconds'.format( - round(total_time_spent, 1) - )) + print('Total: {} seconds'.format(round(total_time_spent, 1))) # don't delete sessions -- it's too susceptible to race conditions # between sending the completion message and loading the last page @@ -173,8 +180,7 @@ def run(self): # just label these sessions clearly in the admin UI # and make it easy to delete manually - def run_session( - self, session_config_name, num_participants, case_number): + def run_session(self, session_config_name, num_participants, case_number): self.close_existing_session() browser_process = self.launch_browser(num_participants) @@ -183,7 +189,8 @@ def run_session( print(row_fmt.format(session_config_name, num_participants), end='') session_code = self.create_session( - session_config_name, num_participants, case_number) + session_config_name, num_participants, case_number + ) time_spent = self.websocket_listen(session_code, num_participants) print('...finished in {} seconds'.format(time_spent)) @@ -200,15 +207,14 @@ def websocket_listen(self, session_code, num_participants) -> float: # seems that urljoin doesn't work with ws:// urls # so do the ws replace after URLjoin websocket_url = urljoin( - self.server_url, - channel_utils.browser_bots_launcher_path(session_code) + self.server_url, channel_utils.browser_bots_launcher_path(session_code) + ) + websocket_url = websocket_url.replace('http://', 'ws://').replace( + 'https://', 'wss://' ) - websocket_url = websocket_url.replace( - 'http://', 'ws://').replace('https://', 'wss://') return run_websocket_client_until_finished( - websocket_url=websocket_url, - session_size=num_participants, + websocket_url=websocket_url, session_size=num_participants ) def set_urls(self): @@ -221,10 +227,7 @@ def set_urls(self): self.server_url = server_url # CREATE_SESSION URL - self.create_session_url = urljoin( - server_url, - URLs.create_browser_bots, - ) + self.create_session_url = urljoin(server_url, URLs.create_browser_bots) # LOGIN URL # TODO: use reverse? reverse('django.contrib.auth.views.login') @@ -264,7 +267,6 @@ def server_configuration_check(self): resp = self.client.get(self.create_session_url) assert resp.ok - def ping_server(self): logging.getLogger("requests").setLevel(logging.WARNING) @@ -276,32 +278,24 @@ def ping_server(self): except: raise Exception( - 'Could not connect to server at {}.' + f'Could not connect to server at {self.server_url}.' 'Before running this command, ' - 'you need to run the server (see --server-url flag).'.format( - self.server_url, - ) + 'you need to run the server (see --server-url flag).' ) if not resp.ok: raise Exception( - 'Could not open page at {}.' - '(HTTP status code: {})'.format( - self.login_url, - resp.status_code, - ) + f'Could not open page at {self.login_url}.' + f'(HTTP status code: {resp.status_code})' ) - def create_session( - self, session_config_name, num_participants, case_number - ): - + def create_session(self, session_config_name, num_participants, case_number): resp = self.post( self.create_session_url, - data={ - 'session_config_name': session_config_name, - 'num_participants': num_participants, - 'case_number': case_number, - } + data=dict( + session_config_name=session_config_name, + num_participants=num_participants, + case_number=case_number, + ), ) assert resp.ok, 'Failed to create session. Check the server logs.' session_code = resp.text @@ -330,7 +324,9 @@ def check_browser(self): process_list_args = ['tasklist'] else: process_list_args = ['ps', 'axw'] - ps_output = check_output(process_list_args).decode(sys.stdout.encoding, 'ignore') + ps_output = check_output(process_list_args).decode( + sys.stdout.encoding, 'ignore' + ) is_running = browser_type.lower() in ps_output.lower() if is_running: @@ -342,8 +338,7 @@ def check_browser(self): def close_existing_session(self): # make sure room is closed - resp = self.post( - urljoin(self.server_url, URLs.close_browser_bots)) + resp = self.post(urljoin(self.server_url, URLs.close_browser_bots)) if not resp.ok: raise AssertionError( 'Request to close existing browser bots session failed. ' @@ -351,13 +346,21 @@ def close_existing_session(self): ) def launch_browser(self, num_participants): - wait_room_url = urljoin( - self.server_url, - URLs.browser_bots_start, - ) + wait_room_url = urljoin(self.server_url, URLs.browser_bots_start) for browser_cmd in self.browser_cmds: args = [browser_cmd] + if os.environ.get('BROWSER_BOTS_USE_HEADLESS'): + args.append('--headless') + # needed in windows + args.append('--disable-gpu') + + # for some reason --screenshot OR --remote-debugging-port is necessary to get my JS to execute?!? + # NO idea why. --remote-debugging-port gets me further than --screenshot, which gets stuck + # on skip_lookahead + # --remote-debugging-port=9222 works also + args.append('--remote-debugging-port=9222') + for i in range(num_participants): args.append(wait_room_url) try: @@ -374,4 +377,4 @@ def launch_browser(self, num_participants): # we should show the original exception, because it might have # valuable info about why the browser didn't launch, # not raise from None. - raise FileNotFoundError(msg) \ No newline at end of file + raise FileNotFoundError(msg) diff --git a/otree/bots/conftest.py b/otree/bots/conftest.py deleted file mode 100644 index 3634f21bf..000000000 --- a/otree/bots/conftest.py +++ /dev/null @@ -1,19 +0,0 @@ -""" -Conftest module to be used with bots, not the tests themselves. -""" - - -def pytest_addoption(parser): - parser.addoption("--session_config_name") - parser.addoption("--num_participants", type=int) - parser.addoption("--export_path") - - -def pytest_generate_tests(metafunc): - '''pass command line args to the test function''' - option = metafunc.config.option - - metafunc.parametrize( - "session_config_name,num_participants,export_path", - [[option.session_config_name, option.num_participants, option.export_path]] - ) diff --git a/otree/bots/runner.py b/otree/bots/runner.py index 4471acc65..c4d0d23e3 100644 --- a/otree/bots/runner.py +++ b/otree/bots/runner.py @@ -1,22 +1,20 @@ -from typing import List +import datetime import logging +import os from collections import OrderedDict, defaultdict from pathlib import Path -from django.conf import settings -import pytest +from typing import List -import otree.session -import otree.common_internal +from django.conf import settings -from .bot import ParticipantBot -import datetime -import os -import codecs +import otree.common import otree.export -from otree.constants_internal import AUTO_NAME_BOTS_EXPORT_FOLDER -from otree.models_concrete import ParticipantToPlayerLookup +import otree.session +from otree.constants import AUTO_NAME_BOTS_EXPORT_FOLDER from otree.models import Session, Participant +from otree.models_concrete import ParticipantToPlayerLookup from otree.session import SESSION_CONFIGS_DICT +from .bot import ParticipantBot logger = logging.getLogger(__name__) @@ -73,8 +71,9 @@ def make_bots(*, session_pk, case_number, use_browser_bots) -> List[ParticipantB # can't use .distinct('player_pk') because it only works on Postgres # this implicitly orders by round also - lookups = ParticipantToPlayerLookup.objects.filter( - session_pk=session_pk).order_by('page_index') + lookups = ParticipantToPlayerLookup.objects.filter(session_pk=session_pk).order_by( + 'page_index' + ) seen_players = set() lookups_per_participant = defaultdict(list) @@ -99,16 +98,12 @@ def run_bots(session: Session, case_number=None): runner.play() -# function name needs to start with "test" for pytest to discover it -# in this module -@pytest.mark.django_db(transaction=True) -def test_all_bots_for_session_config( - session_config_name, num_participants, export_path): +def run_all_bots_for_session_config( + session_config_name, num_participants, export_path +): """ - this means all configs and test cases are in 1 big test case. + this means all test cases are in 1 big test case. so if 1 fails, the others will not get run. - to separate them, we would need to move some of this code - to pytest_generate_tests in conftest.py """ if session_config_name: session_config_names = [session_config_name] @@ -116,22 +111,29 @@ def test_all_bots_for_session_config( session_config_names = SESSION_CONFIGS_DICT.keys() for config_name in session_config_names: - config = SESSION_CONFIGS_DICT[config_name] - - bot_modules = [f'{app_name}.tests' for app_name in config['app_sequence']] - pytest.register_assert_rewrite(*bot_modules) + try: + config = SESSION_CONFIGS_DICT[config_name] + except KeyError: + # important to alert the user, since people might be trying to enter app names. + msg = f"No session config with name '{config_name}'." + raise Exception(msg) from None num_bot_cases = config.get_num_bot_cases() for case_number in range(num_bot_cases): - logger.info("Creating '{}' session (test case {})".format( - config_name, case_number)) + logger.info( + "Creating '{}' session (test case {})".format( + config_name, case_number + ) + ) session = otree.session.create_session( session_config_name=config_name, - num_participants=(num_participants or config['num_demo_participants']), + num_participants=( + num_participants or config['num_demo_participants'] + ), ) - run_bots(session, case_number=case_number) + logger.info('Bots completed session') if export_path: @@ -143,9 +145,8 @@ def test_all_bots_for_session_config( os.makedirs(export_path, exist_ok=True) - for app in settings.INSTALLED_OTREE_APPS: - model_module = otree.common_internal.get_models_module(app) + model_module = otree.common.get_models_module(app) if model_module.Player.objects.exists(): fpath = Path(export_path, "{}.csv".format(app)) with fpath.open("w", encoding="utf8") as fp: @@ -155,3 +156,8 @@ def test_all_bots_for_session_config( otree.export.export_wide(fp, 'csv') logger.info('Exported CSV to folder "{}"'.format(export_path)) + else: + logger.info( + 'Tip: Run this command with the --export flag' + ' to save the data generated by bots.' + ) diff --git a/otree/certs/development.crt b/otree/certs/development.crt deleted file mode 100644 index efb3265d8..000000000 --- a/otree/certs/development.crt +++ /dev/null @@ -1,18 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIC7TCCAlYCCQDYvIq6zKvBwDANBgkqhkiG9w0BAQUFADCBujELMAkGA1UEBhMC -VVMxGjAYBgNVBAgMEURFVkVMT1BNRU5UIFNUQVRFMRkwFwYDVQQHDBBERVZFTE9Q -TUVOVCBDSVRZMRwwGgYDVQQKDBNERVZFTE9QTUVOVCBDT01QQU5ZMRowGAYDVQQL -DBFESkFOR08gREVWRUxPUEVSUzESMBAGA1UEAwwJbG9jYWxob3N0MSYwJAYJKoZI -hvcNAQkBFhdkZXZlbG9wbWVudEBleGFtcGxlLmNvbTAeFw0xMzAzMDEwMzQ1NDda -Fw0yMzAyMjcwMzQ1NDdaMIG6MQswCQYDVQQGEwJVUzEaMBgGA1UECAwRREVWRUxP -UE1FTlQgU1RBVEUxGTAXBgNVBAcMEERFVkVMT1BNRU5UIENJVFkxHDAaBgNVBAoM -E0RFVkVMT1BNRU5UIENPTVBBTlkxGjAYBgNVBAsMEURKQU5HTyBERVZFTE9QRVJT -MRIwEAYDVQQDDAlsb2NhbGhvc3QxJjAkBgkqhkiG9w0BCQEWF2RldmVsb3BtZW50 -QGV4YW1wbGUuY29tMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCTC7RORlCZ -L25D5HEUTrXQIiKCoGesuw3Mbyl05PK2HmyiOKBD2dL/l1JIqfZKIFrT4RxDpGnr -R3RgQ4J3mc7osladyrlyqCvLjPXCJV8aaks87fNPW4fE8z0liaeiqbOCwJAZsLE3 -JmkvhFLOtsd2CfDRxGHMyBka0ou3N6l3cQIDAQABMA0GCSqGSIb3DQEBBQUAA4GB -AFLhm2A7w1q4KfcP7HyWJFuDF1LaKNq3+z3qrEHZXnL2hLW8dAfQISMCfNFqZuio -x8QLRRiJ/qvI4P+oQGkPs8Yz31uMmstZenDnjl8fmopBm4mJqtLi1VT/O/pZmFUG -dmNM3HnuRwqIdrKmxWgI1e7vErV8vVNStWhL0ukNHapr ------END CERTIFICATE----- diff --git a/otree/certs/development.key b/otree/certs/development.key deleted file mode 100644 index 124f50f23..000000000 --- a/otree/certs/development.key +++ /dev/null @@ -1,15 +0,0 @@ ------BEGIN RSA PRIVATE KEY----- -MIICWwIBAAKBgQCTC7RORlCZL25D5HEUTrXQIiKCoGesuw3Mbyl05PK2HmyiOKBD -2dL/l1JIqfZKIFrT4RxDpGnrR3RgQ4J3mc7osladyrlyqCvLjPXCJV8aaks87fNP -W4fE8z0liaeiqbOCwJAZsLE3JmkvhFLOtsd2CfDRxGHMyBka0ou3N6l3cQIDAQAB -AoGAdWG2gXWoCWDPiOrnSeq7QHa/Tb92g3Cexz9FvMa26aLH3YeOiBtuUBIf4Vmr -/ehuGQ1uXqD03Jih0eaSU584h1tuESnKMJjmN3EPdIzx3VhRP4Oo28khFGAbsNFe -NRFPj9w2IZk9wRbgipnm4CcqNGmJ/plCmVq7/txjKM5t/EECQQDEAvSdfFV1jZkF -A1Equr5sJRsZ6GQyNjcAVxqSL0Q63JMjkrq8xcLlSuTQ9tcV2gIOZSrVD60V1d63 -FE3riKiNAkEAwAxkLmqSQDotA9ALBPwJa/Fc+XwTLd/7g8K/XoJW4UguhewtOE/u -rsKkX7IKd8DT7raTZ37RC0aS89IMuifrdQJAc3MuMyhNiay6KVq3zwwpJreAS/U2 -NuD56mhjjSDr9iN/Qt+kv5VX4wgG2BHbw9IhjesGnHHcR9UtlfYOoyFd7QJASj3A -EK2EIi4bLskjKWchYUgqMAwGAgr/WR1VC30Jhwd3bLAzfvxvgcGe95uFLmwtwa90 -5mKA/4Hl1znRT7mU7QJAAVZi9pj9WXQrOJVT9Lj0/Dq9HB4O9kEG7ofKV95HbvZp -mcrEPOc1LyIiIJKuLe2gCP4FgEsCWmLc0qU62B/QBg== ------END RSA PRIVATE KEY----- diff --git a/otree/channels/asgi_redis.py b/otree/channels/asgi_redis.py deleted file mode 100644 index f8fc0b79f..000000000 --- a/otree/channels/asgi_redis.py +++ /dev/null @@ -1,17 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -import sys - -import asgi_redis -import redis.exceptions -import six - - -class RedisChannelLayer(asgi_redis.RedisChannelLayer): - - # In SAL experiment, we got 503 "queue full" errors when using ~50 - # browser bots. This occurred even after i enabled multiple botworkers. - def __init__(self, *args, **kwargs): - kwargs.setdefault('capacity', 10000) - super().__init__(*args, **kwargs) diff --git a/otree/channels/consumers.py b/otree/channels/consumers.py index ca09d2abc..08f12d995 100644 --- a/otree/channels/consumers.py +++ b/otree/channels/consumers.py @@ -1,35 +1,36 @@ -import json +import base64 +import datetime +import io import logging +import time +import traceback +import urllib.parse + import django.db import django.utils.timezone -import traceback -import time -from channels.generic.websocket import ( - JsonWebsocketConsumer, WebsocketConsumer) -from channels.consumer import SyncConsumer +from channels.db import database_sync_to_async +from channels.generic.websocket import AsyncJsonWebsocketConsumer, WebsocketConsumer +from django.conf import settings from django.core.signing import Signer, BadSignature +from django.shortcuts import reverse + +import otree.bots.browser +import otree.channels.utils as channel_utils import otree.session from otree.channels.utils import get_chat_group -from otree.models import Participant, Session +from otree.common import get_models_module +from otree.export import export_wide, export_app +from otree.models import Participant from otree.models_concrete import ( - CompletedGroupWaitPage, CompletedSubsessionWaitPage, ChatMessage) -from otree.common_internal import ( - get_models_module + CompletedGroupWaitPage, + CompletedSubsessionWaitPage, + ChatMessage, + WaitPagePassage, ) -import otree.channels.utils as channel_utils -from otree.models_concrete import ( - ParticipantRoomVisit, - BrowserBotsLauncherSessionCode) +from otree.models_concrete import ParticipantRoomVisit, BrowserBotsLauncherSessionCode from otree.room import ROOM_DICT -import otree.bots.browser -from otree.export import export_wide, export_app -import io -import base64 -import datetime -from django.conf import settings -from django.shortcuts import reverse -from otree.views.admin import CreateSessionForm from otree.session import SESSION_CONFIGS_DICT +from otree.views.admin import CreateSessionForm logger = logging.getLogger(__name__) @@ -41,15 +42,10 @@ class InvalidWebSocketParams(Exception): '''exception to raise when websocket params are invalid''' -class _OTreeJsonWebsocketConsumer(JsonWebsocketConsumer): +class _OTreeAsyncJsonWebsocketConsumer(AsyncJsonWebsocketConsumer): """ This is not public API, might change at any time. """ - def group_send_channel(self, type: str, groups=None, **event): - for group in (groups or self.groups): - channel_utils.sync_group_send(group, {'type': type, **event}) - #print('call_args', channel_utils.sync_group_send.call_args) - #assert channel_utils.sync_group_send.call_args def clean_kwargs(self, **kwargs): ''' @@ -67,187 +63,253 @@ def group_name(self, **kwargs): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.cleaned_kwargs = self.clean_kwargs(**self.scope['url_route']['kwargs']) - self.groups = self.connection_groups() - - def connection_groups(self, **kwargs): group_name = self.group_name(**self.cleaned_kwargs) - return [group_name] + self.groups = [group_name] if group_name else [] unrestricted_when = '' # there is no login_required for channels # so we need to make our own # https://github.com/django/channels/issues/1241 - def connect(self): + async def connect(self): AUTH_LEVEL = settings.AUTH_LEVEL auth_required = ( - (not self.unrestricted_when) and AUTH_LEVEL - or - self.unrestricted_when == UNRESTRICTED_IN_DEMO_MODE and AUTH_LEVEL == 'STUDY' + (not self.unrestricted_when) + and AUTH_LEVEL + or self.unrestricted_when == UNRESTRICTED_IN_DEMO_MODE + and AUTH_LEVEL == 'STUDY' ) if auth_required and not self.scope['user'].is_staff: - msg = 'rejected un-authenticated access to websocket path {}'.format(self.scope['path']) - logger.warning(msg) + msg = 'rejected un-authenticated access to websocket path {}'.format( + self.scope['path'] + ) + # print(msg) + logger.error(msg) # consider also self.accept() then send error message then self.close(code=1008) # this only affects otree core websockets. else: # need to accept no matter what, so we can at least send # an error message - self.accept() - self.post_connect(**self.cleaned_kwargs) + await self.accept() + await self.post_connect(**self.cleaned_kwargs) - def post_connect(self, **kwargs): + async def post_connect(self, **kwargs): pass - def disconnect(self, message, **kwargs): - self.pre_disconnect(**self.cleaned_kwargs) + async def disconnect(self, message, **kwargs): + await self.pre_disconnect(**self.cleaned_kwargs) - def pre_disconnect(self, **kwargs): + async def pre_disconnect(self, **kwargs): pass - def receive_json(self, content, **etc): - self.post_receive_json(content, **self.cleaned_kwargs) + async def receive_json(self, content, **etc): + await self.post_receive_json(content, **self.cleaned_kwargs) - def post_receive_json(self, content, **kwargs): + async def post_receive_json(self, content, **kwargs): pass -class GroupByArrivalTime(_OTreeJsonWebsocketConsumer): +class BaseWaitPage(_OTreeAsyncJsonWebsocketConsumer): + unrestricted_when = ALWAYS_UNRESTRICTED + kwarg_names: list + + def clean_kwargs(self): + d = parse_querystring(self.scope['query_string']) + kwargs = {} + for k in self.kwarg_names: + kwargs[k] = int(d[k]) + return kwargs + + async def wait_page_ready(self, event=None): + await self.send_json({'status': 'ready'}) + + async def pre_disconnect(self, session_pk, participant_id, **kwargs): + + await create_waitpage_passage( + participant_id=participant_id, session_pk=session_pk, is_enter=False + ) + + +class SubsessionWaitPage(BaseWaitPage): + + kwarg_names = ('session_pk', 'page_index', 'participant_id') + + def group_name(self, session_pk, page_index, participant_id): + return channel_utils.subsession_wait_page_name(session_pk, page_index) + + def completion_exists(self, **kwargs): + return CompletedSubsessionWaitPage.objects.filter(**kwargs).exists() + + async def post_connect(self, session_pk, page_index, participant_id): + if await database_sync_to_async(self.completion_exists)( + page_index=page_index, session_id=session_pk + ): + await self.wait_page_ready() + await create_waitpage_passage( + participant_id=participant_id, session_pk=session_pk, is_enter=True + ) + + +class GroupWaitPage(BaseWaitPage): + + kwarg_names = SubsessionWaitPage.kwarg_names + ('group_id_in_subsession',) + + def group_name( + self, session_pk, page_index, group_id_in_subsession, participant_id + ): + return channel_utils.group_wait_page_name( + session_pk, page_index, group_id_in_subsession + ) + + def completion_exists(self, **kwargs): + return CompletedGroupWaitPage.objects.filter(**kwargs).exists() + + async def post_connect( + self, session_pk, page_index, group_id_in_subsession, participant_id + ): + if await database_sync_to_async(self.completion_exists)( + page_index=page_index, + id_in_subsession=group_id_in_subsession, + session_id=session_pk, + ): + await self.wait_page_ready() + await create_waitpage_passage( + participant_id=participant_id, session_pk=session_pk, is_enter=True + ) + + +class GroupByArrivalTime(_OTreeAsyncJsonWebsocketConsumer): unrestricted_when = ALWAYS_UNRESTRICTED - def clean_kwargs(self, params): - session_pk, page_index, app_name, player_id = params.split(',') + def clean_kwargs(self): + d = parse_querystring(self.scope['query_string']) return { - 'app_name': app_name, - 'session_pk': int(session_pk), - 'page_index': int(page_index), - 'player_id': int(player_id) + 'app_name': d['app_name'], + 'session_pk': int(d['session_pk']), + 'participant_id': int(d['participant_id']), + 'page_index': int(d['page_index']), + 'player_id': int(d['player_id']), } - def group_name(self, app_name, player_id, page_index, session_pk): - gn = channel_utils.gbat_group_name( - session_pk, page_index) + def group_name(self, app_name, player_id, page_index, session_pk, participant_id): + gn = channel_utils.gbat_group_name(session_pk, page_index) return gn - def post_connect(self, app_name, player_id, page_index, session_pk): + def is_ready(self, *, app_name, player_id, page_index, session_pk): models_module = get_models_module(app_name) - group_id_in_subsession = models_module.Group.objects.filter( - player__id=player_id).values_list( - 'id_in_subsession', flat=True)[0] + group_id_in_subsession = ( + models_module.Group.objects.filter(player__id=player_id) + .values_list('id_in_subsession', flat=True) + .get() + ) - ready = CompletedGroupWaitPage.objects.filter( + return CompletedGroupWaitPage.objects.filter( page_index=page_index, id_in_subsession=int(group_id_in_subsession), session_id=session_pk, ).exists() - if ready: - self.wait_page_ready() - def wait_page_ready(self, event=None): - self.send_json({'status': 'ready'}) - - -class WaitPage(_OTreeJsonWebsocketConsumer): + async def post_connect( + self, app_name, player_id, page_index, session_pk, participant_id + ): + if await database_sync_to_async(self.is_ready)( + app_name=app_name, + player_id=player_id, + page_index=page_index, + session_pk=session_pk, + ): + await self.gbat_ready() + await create_waitpage_passage( + participant_id=participant_id, session_pk=session_pk, is_enter=True + ) - unrestricted_when = ALWAYS_UNRESTRICTED + async def gbat_ready(self, event=None): + await self.send_json({'status': 'ready'}) - def clean_kwargs(self, params): - session_pk, page_index, group_id_in_subsession = params.split(',') - return { - 'session_pk': int(session_pk), - 'page_index': int(page_index), - # don't convert group_id_in_subsession to int yet, it might be null - 'group_id_in_subsession': group_id_in_subsession, - } + async def pre_disconnect( + self, app_name, player_id, page_index, session_pk, participant_id + ): + await create_waitpage_passage( + participant_id=participant_id, session_pk=session_pk, is_enter=False + ) - def group_name(self, session_pk, page_index, group_id_in_subsession): - return channel_utils.wait_page_group_name( - session_pk, page_index, group_id_in_subsession) - def post_connect(self, session_pk, page_index, group_id_in_subsession): - # in case message was sent before this web socket connects - if group_id_in_subsession: - ready = CompletedGroupWaitPage.objects.filter( - page_index=page_index, - id_in_subsession=int(group_id_in_subsession), - session_id=session_pk, - ).exists() - else: # subsession - ready = CompletedSubsessionWaitPage.objects.filter( - page_index=page_index, - session_id=session_pk, - ).exists() - if ready: - self.wait_page_ready() - - def wait_page_ready(self, event=None): - self.send_json({'status': 'ready'}) - - -class DetectAutoAdvance(_OTreeJsonWebsocketConsumer): +class DetectAutoAdvance(_OTreeAsyncJsonWebsocketConsumer): unrestricted_when = ALWAYS_UNRESTRICTED - def clean_kwargs(self, params): - participant_code, page_index = params.split(',') + def clean_kwargs(self): + d = parse_querystring(self.scope['query_string']) return { - 'participant_code': participant_code, - 'page_index': int(page_index), + 'participant_code': d['participant_code'], + 'page_index': int(d['page_index']), } def group_name(self, page_index, participant_code): return channel_utils.auto_advance_group(participant_code) - def post_connect(self, page_index, participant_code): - # in case message was sent before this web socket connects - result = Participant.objects.filter( - code=participant_code).values_list( - '_index_in_pages', flat=True) + def page_should_be_on(self, participant_code): try: - page_should_be_on = result[0] - except IndexError: - # doesn't get shown because not yet localized - self.send_json({'error': 'Participant not found in database.'}) + return ( + Participant.objects.filter(code=participant_code) + .values_list('_index_in_pages', flat=True) + .get() + ) + except Participant.DoesNotExist: return - if page_should_be_on > page_index: - self.auto_advanced() - def auto_advanced(self, event=None): - self.send_json({'auto_advanced': True}) + async def post_connect(self, page_index, participant_code): + # in case message was sent before this web socket connects + page_should_be_on = await database_sync_to_async(self.page_should_be_on)( + participant_code + ) + if page_should_be_on is None: + await self.send_json({'error': 'Participant not found in database.'}) + elif page_should_be_on > page_index: + await self.auto_advanced() + async def auto_advanced(self, event=None): + await self.send_json({'auto_advanced': True}) -class BaseCreateSession(_OTreeJsonWebsocketConsumer): - def connection_groups(self, **kwargs): - return [] +class BaseCreateSession(_OTreeAsyncJsonWebsocketConsumer): + def group_name(self, **kwargs): + return None + + async def send_response_to_browser(self, event: dict): + raise NotImplemented + + async def create_session_then_send_start_link( + self, use_browser_bots, **session_kwargs + ): - def create_session_then_send_start_link(self, use_browser_bots, **session_kwargs): try: - session = otree.session.create_session(**session_kwargs) + session = await database_sync_to_async(otree.session.create_session)( + **session_kwargs + ) if use_browser_bots: - otree.bots.browser.initialize_session( - session_pk=session.pk, - case_number=None + await database_sync_to_async(otree.bots.browser.initialize_session)( + session_pk=session.pk, case_number=None ) - session.save() except Exception as e: # full error message is printed to console (though sometimes not?) error_message = 'Failed to create session: "{}"'.format(e) traceback_str = traceback.format_exc() - self.send_json(dict( - error=error_message, - traceback=traceback_str, - )) + await self.send_response_to_browser( + dict(error=error_message, traceback=traceback_str) + ) raise - session_home_view = 'MTurkCreateHIT' if session.is_mturk() else 'SessionStartLinks' + session_home_view = ( + 'MTurkCreateHIT' if session.is_mturk() else 'SessionStartLinks' + ) - self.send_json( + await self.send_response_to_browser( {'session_url': reverse(session_home_view, args=[session.code])} ) @@ -256,23 +318,25 @@ class CreateDemoSession(BaseCreateSession): unrestricted_when = UNRESTRICTED_IN_DEMO_MODE - def post_receive_json(self, form_data: dict): + async def send_response_to_browser(self, event: dict): + await self.send_json(event) + + async def post_receive_json(self, form_data: dict): session_config_name = form_data['session_config'] config = SESSION_CONFIGS_DICT.get(session_config_name) if not config: msg = f'Session config "{session_config_name}" does not exist.' - self.send_json( - {'validation_errors': msg}) + await self.send_json({'validation_errors': msg}) return num_participants = config['num_demo_participants'] use_browser_bots = config.get('use_browser_bots', False) - self.create_session_then_send_start_link( + await self.create_session_then_send_start_link( session_config_name=session_config_name, use_browser_bots=use_browser_bots, num_participants=num_participants, - is_demo=True + is_demo=True, ) @@ -280,13 +344,13 @@ class CreateSession(BaseCreateSession): unrestricted_when = None - def connection_groups(self, **kwargs): - return [] + def group_name(self, **kwargs): + return 'create_session' - def post_receive_json(self, form_data: dict): + async def post_receive_json(self, form_data: dict): form = CreateSessionForm(data=form_data) if not form.is_valid(): - self.send_json({'validation_errors': form.errors}) + await self.send_json({'validation_errors': form.errors}) return session_config_name = form.cleaned_data['session_config'] @@ -330,14 +394,13 @@ def post_receive_json(self, form_data: dict): edited_session_config_fields[field] = new_value use_browser_bots = edited_session_config_fields.get( - 'use_browser_bots', - config.get('use_browser_bots', False) + 'use_browser_bots', config.get('use_browser_bots', False) ) # if room_name is missing, it will be empty string room_name = form.cleaned_data['room_name'] or None - self.create_session_then_send_start_link( + await self.create_session_then_send_start_link( session_config_name=session_config_name, num_participants=num_participants, is_demo=False, @@ -348,78 +411,101 @@ def post_receive_json(self, form_data: dict): ) if room_name: - self.group_send_channel( + await channel_utils.group_send_wrapper( type='room_session_ready', group=channel_utils.room_participants_group_name(room_name), - status='session_ready' + event={}, ) + async def send_response_to_browser(self, event: dict): + ''' + Send to a group instead of the channel only, + because if the websocket disconnects during creation of a large session, + (due to temporary network error, etc, or Heroku H15, 55 seconds without ping) + the user could be stuck on "please wait" forever. + the downside is that if two admins create sessions around the same time, + your page could automatically redirect to the other admin's session. + ''' + [group] = self.groups + await channel_utils.group_send_wrapper( + type='session_created', group=group, event=event + ) + + async def session_created(self, event): + await self.send_json(event) + -class RoomAdmin(_OTreeJsonWebsocketConsumer): +class RoomAdmin(_OTreeAsyncJsonWebsocketConsumer): unrestricted_when = None def group_name(self, room): return channel_utils.room_admin_group_name(room) - def post_connect(self, room): + def get_list(self, **kwargs): + + # make it JSON serializable + return list( + ParticipantRoomVisit.objects.filter(**kwargs).values_list( + 'participant_label', flat=True + ) + ) + + async def post_connect(self, room): room_object = ROOM_DICT[room] now = time.time() stale_threshold = now - 15 - present_list = ParticipantRoomVisit.objects.filter( - room_name=room_object.name, - last_updated__gte=stale_threshold, - ).values_list('participant_label', flat=True) - - # make it JSON serializable - present_list = list(present_list) + present_list = await database_sync_to_async(self.get_list)( + room_name=room_object.name, last_updated__gte=stale_threshold + ) - self.send_json({ - 'status': 'load_participant_lists', - 'participants_present': present_list, - }) + await self.send_json( + {'status': 'load_participant_lists', 'participants_present': present_list} + ) # prune very old visits -- don't want a resource leak # because sometimes not getting deleted on WebSocket disconnect very_stale_threshold = now - 10 * 60 - ParticipantRoomVisit.objects.filter( - room_name=room_object.name, - last_updated__lt=very_stale_threshold, - ).delete() + await database_sync_to_async(self.delete_old_visits)( + room_name=room_object.name, last_updated__lt=very_stale_threshold + ) - def roomadmin_update(self, event): + def delete_old_visits(self, **kwargs): + ParticipantRoomVisit.objects.filter(**kwargs).delete() + + async def roomadmin_update(self, event): del event['type'] - self.send_json(event) + await self.send_json(event) -class RoomParticipant(_OTreeJsonWebsocketConsumer): +class RoomParticipant(_OTreeAsyncJsonWebsocketConsumer): unrestricted_when = ALWAYS_UNRESTRICTED - def clean_kwargs(self, params): - room_name, participant_label, tab_unique_id = params.split(',') - return { - 'room_name': room_name, - 'participant_label': participant_label, - 'tab_unique_id': tab_unique_id, - } + def clean_kwargs(self): + d = parse_querystring(self.scope['query_string']) + d.setdefault('participant_label', '') + return d def group_name(self, room_name, participant_label, tab_unique_id): return channel_utils.room_participants_group_name(room_name) - def post_connect(self, room_name, participant_label, tab_unique_id): + def create_participant_room_visit(self, **kwargs): + ParticipantRoomVisit.objects.create(**kwargs) + + async def post_connect(self, room_name, participant_label, tab_unique_id): if room_name in ROOM_DICT: room = ROOM_DICT[room_name] else: # doesn't get shown because not yet localized - self.send_json({'error': 'Invalid room name "{}".'.format(room_name)}) + await self.send_json({'error': 'Invalid room name "{}".'.format(room_name)}) return - if room.has_session(): - self.room_session_ready() + if await database_sync_to_async(room.has_session)(): + await self.room_session_ready() else: try: - ParticipantRoomVisit.objects.create( + await database_sync_to_async(self.create_participant_room_visit)( participant_label=participant_label, room_name=room_name, tab_unique_id=tab_unique_id, @@ -434,52 +520,56 @@ def post_connect(self, room_name, participant_label, tab_unique_id): # 2017-09-17: I saw the integrityerror on macOS. # previously, we logged this, but i see no need to do that. pass - channel_utils.sync_group_send( - channel_utils.room_admin_group_name(room_name), - { - 'type': 'roomadmin.update', - 'status': 'add_participant', - 'participant': participant_label - } + await channel_utils.group_send_wrapper( + type='roomadmin_update', + group=channel_utils.room_admin_group_name(room_name), + event={'status': 'add_participant', 'participant': participant_label}, ) - def pre_disconnect(self, room_name, participant_label, tab_unique_id): + def delete_visit(self, **kwargs): + ParticipantRoomVisit.objects.filter(**kwargs).delete() + + def visit_exists(self, **kwargs): + return ParticipantRoomVisit.objects.filter(**kwargs).exists() + + async def pre_disconnect(self, room_name, participant_label, tab_unique_id): + if room_name in ROOM_DICT: room = ROOM_DICT[room_name] else: # doesn't get shown because not yet localized - self.send_json({'error': 'Invalid room name "{}".'.format(room_name)}) + await self.send_json({'error': 'Invalid room name "{}".'.format(room_name)}) return # should use filter instead of get, # because if the DB is recreated, # the record could already be deleted - ParticipantRoomVisit.objects.filter( + await database_sync_to_async(self.delete_visit)( participant_label=participant_label, room_name=room_name, - tab_unique_id=tab_unique_id).delete() + tab_unique_id=tab_unique_id, + ) - event = { - 'type': 'roomadmin.update', - 'status': 'remove_participant', - } + event = {'status': 'remove_participant'} if room.has_participant_labels(): - if ParticipantRoomVisit.objects.filter( - participant_label=participant_label, - room_name=room_name - ).exists(): + if await database_sync_to_async(self.visit_exists)( + participant_label=participant_label, room_name=room_name + ): return # it's ok if there is a race condition -- # in JS removing a participant is idempotent event['participant'] = participant_label admin_group = channel_utils.room_admin_group_name(room_name) - channel_utils.sync_group_send(admin_group, event) - def room_session_ready(self): - self.send_json({'status': 'session_ready'}) + await channel_utils.group_send_wrapper( + group=admin_group, type='roomadmin_update', event=event + ) + + async def room_session_ready(self, event=None): + await self.send_json({'status': 'session_ready'}) -class BrowserBotsLauncher(_OTreeJsonWebsocketConsumer): +class BrowserBotsLauncher(_OTreeAsyncJsonWebsocketConsumer): # OK to be unrestricted because this websocket doesn't create the session, # or do anything sensitive. @@ -488,27 +578,30 @@ class BrowserBotsLauncher(_OTreeJsonWebsocketConsumer): def group_name(self, session_code): return channel_utils.browser_bots_launcher_group(session_code) - def send_completion_message(self, event): + async def send_completion_message(self, event): # don't need to put in JSON since it's just a participant code - self.send(event['text']) + await self.send(event['text']) -class BrowserBot(_OTreeJsonWebsocketConsumer): +class BrowserBot(_OTreeAsyncJsonWebsocketConsumer): unrestricted_when = ALWAYS_UNRESTRICTED def group_name(self): return 'browser_bot_wait' - def post_connect(self): - launcher_session_info = BrowserBotsLauncherSessionCode.objects.first() - if launcher_session_info: - self.browserbot_sessionready() + def session_exists(self): + return BrowserBotsLauncherSessionCode.objects.exists() + + async def post_connect(self): + if await database_sync_to_async(self.session_exists)(): + await self.browserbot_sessionready() + + async def browserbot_sessionready(self, event=None): + await self.send_json({'status': 'session_ready'}) - def browserbot_sessionready(self): - self.send_json({'status': 'session_ready'}) -class ChatConsumer(_OTreeJsonWebsocketConsumer): +class ChatConsumer(_OTreeAsyncJsonWebsocketConsumer): unrestricted_when = ALWAYS_UNRESTRICTED @@ -522,26 +615,28 @@ def clean_kwargs(self, params): channel, participant_id = original_params.split('/') - return { - 'channel': channel, - 'participant_id': int(participant_id), - } + return {'channel': channel, 'participant_id': int(participant_id)} def group_name(self, channel, participant_id): return get_chat_group(channel) - def post_connect(self, channel, participant_id): - - history = ChatMessage.objects.filter( - channel=channel).order_by('timestamp').values( - 'nickname', 'body', 'participant_id' + def _get_history(self, channel): + return list( + ChatMessage.objects.filter(channel=channel) + .order_by('timestamp') + .values('nickname', 'body', 'participant_id') ) + async def post_connect(self, channel, participant_id): + + history = await database_sync_to_async(self._get_history)(channel=channel) + # Convert ValuesQuerySet to list # but is it ok to send a list (not a dict) as json? - self.send_json(list(history)) + await self.send_json(history) + + async def post_receive_json(self, content, channel, participant_id): - def post_receive_json(self, content, channel, participant_id): # in the Channels docs, the example has a separate msg_consumer # channel, so this can be done asynchronously. # but i think the perf is probably good enough. @@ -550,28 +645,26 @@ def post_receive_json(self, content, channel, participant_id): nickname = Signer().unsign(nickname_signed) body = content['body'] - chat_message = dict( - nickname=nickname, - body=body, - participant_id=participant_id - ) + chat_message = dict(nickname=nickname, body=body, participant_id=participant_id) - self.group_send_channel('chat_sendmessages', chats=[chat_message]) + [group] = self.groups + await channel_utils.group_send_wrapper( + type='chat_sendmessages', group=group, event={'chats': [chat_message]} + ) - ChatMessage.objects.create( - participant_id=participant_id, - channel=channel, - body=body, - nickname=nickname + await database_sync_to_async(self._create_message)( + participant_id=participant_id, channel=channel, body=body, nickname=nickname ) + def _create_message(self, **kwargs): + ChatMessage.objects.create(**kwargs) - def chat_sendmessages(self, event): + async def chat_sendmessages(self, event): chats = event['chats'] - self.send_json(chats) + await self.send_json(chats) -class ExportData(_OTreeJsonWebsocketConsumer): +class ExportData(_OTreeAsyncJsonWebsocketConsumer): ''' I load tested this locally with sqlite/redis and: @@ -581,7 +674,7 @@ class ExportData(_OTreeJsonWebsocketConsumer): unrestricted_when = None - def post_receive_json(self, content: dict): + async def post_receive_json(self, content: dict): ''' if an app name is given, export the app. otherwise, export all the data (wide). @@ -592,7 +685,9 @@ def post_receive_json(self, content: dict): app_name = content.get('app_name') if file_extension == 'xlsx': - mime_type = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" + mime_type = ( + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" + ) IOClass = io.BytesIO else: mime_type = 'text/csv' @@ -601,53 +696,54 @@ def post_receive_json(self, content: dict): iso_date = datetime.date.today().isoformat() with IOClass() as fp: if app_name: - export_app(app_name, fp, file_extension=file_extension) + await database_sync_to_async(export_app)( + app_name, fp, file_extension=file_extension + ) file_name_prefix = app_name else: - export_wide(fp, file_extension=file_extension) + await database_sync_to_async(export_wide)( + fp, file_extension=file_extension + ) file_name_prefix = 'all_apps_wide' data = fp.getvalue() - file_name = '{}_{}.{}'.format( - file_name_prefix, iso_date, file_extension) + file_name = f'{file_name_prefix}_{iso_date}.{file_extension}' if file_extension == 'xlsx': data = base64.b64encode(data).decode('utf-8') - content.update( - file_name=file_name, - data=data, - mime_type=mime_type - ) + content.update(file_name=file_name, data=data, mime_type=mime_type) # this doesn't go through channel layer, so it is probably safer # in terms of sending large data + await self.send_json(content) - self.send_json(content) - - def connection_groups(self, **kwargs): - return [] + def group_name(self, **kwargs): + return None class NoOp(WebsocketConsumer): pass -class LifespanApp: - ''' - temporary shim for https://github.com/django/channels/issues/1216 - needed so that hypercorn doesn't display an error. - this uses ASGI 2.0 format, not the newer 3.0 single callable - ''' +def parse_querystring(query_string) -> dict: + '''it seems parse_qs omits keys with empty values''' + return {k: v[0] for k, v in urllib.parse.parse_qs(query_string.decode()).items()} + + +async def create_waitpage_passage(*, participant_id, session_pk, is_enter): + await database_sync_to_async(_create_waitpage_passage)( + participant_id=participant_id, session_pk=session_pk, is_enter=is_enter + ) + - def __init__(self, scope): - self.scope = scope - - async def __call__(self, receive, send): - if self.scope['type'] == 'lifespan': - while True: - message = await receive() - if message['type'] == 'lifespan.startup': - await send({'type': 'lifespan.startup.complete'}) - elif message['type'] == 'lifespan.shutdown': - await send({'type': 'lifespan.shutdown.complete'}) - return +def _create_waitpage_passage(*, participant_id, session_pk, is_enter): + '''if the session was deleted, this would raise''' + try: + WaitPagePassage.objects.create( + participant_id=participant_id, + session_id=session_pk, + is_enter=is_enter, + epoch_time=time.time(), + ) + except: + pass diff --git a/otree/channels/routing.py b/otree/channels/routing.py index b3195d4ec..ee39323fe 100644 --- a/otree/channels/routing.py +++ b/otree/channels/routing.py @@ -9,72 +9,33 @@ websocket_routes = [ # WebSockets - url( - r'^wait_page/(?P[\w,]+)/$', - consumers.WaitPage, - ), - url( - r'^group_by_arrival_time/(?P[\w,\.]+)/$', - consumers.GroupByArrivalTime, - ), - url( - r'^auto_advance/(?P[\w,]+)/$', - consumers.DetectAutoAdvance, - ), - url( - r'^create_session/$', - consumers.CreateSession, - ), - url( - r'^create_demo_session/$', - consumers.CreateDemoSession, - ), - url( - r'^wait_for_session_in_room/(?P[\w,]+)/$', - consumers.RoomParticipant, - ), - url( - r'^room_without_session/(?P\w+)/$', - consumers.RoomAdmin, - ), - url( - r'^browser_bots_client/(?P\w+)/$', - consumers.BrowserBotsLauncher, - ), - url( - r'^browser_bot_wait/$', - consumers.BrowserBot, - ), + url(r'^wait_page/$', consumers.GroupWaitPage), + url(r'^subsession_wait_page/$', consumers.SubsessionWaitPage), + url(r'^group_by_arrival_time/$', consumers.GroupByArrivalTime), + url(r'^auto_advance/$', consumers.DetectAutoAdvance), + url(r'^create_session/$', consumers.CreateSession), + url(r'^create_demo_session/$', consumers.CreateDemoSession), + url(r'^wait_for_session_in_room/$', consumers.RoomParticipant), + url(r'^room_without_session/(?P\w+)/$', consumers.RoomAdmin), + url(r'^browser_bots_client/(?P\w+)/$', consumers.BrowserBotsLauncher), + url(r'^browser_bot_wait/$', consumers.BrowserBot), url( # so it doesn't clash with addon r"^otreechat_core/(?P[a-zA-Z0-9_/-]+)/$", consumers.ChatConsumer, ), - url( - r"^export/$", - consumers.ExportData, - ), + url(r"^export/$", consumers.ExportData), # for django autoreloader # just so client can detect when server has finished restarting - url( - r'^no_op/$', - consumers.NoOp, - ), + url(r'^no_op/$', consumers.NoOp), ] extensions_modules = get_extensions_modules('routing') for extensions_module in extensions_modules: - if hasattr(extensions_module, 'channel_routing'): - raise Exception( - f'The extension {extensions_module} is built for an older version ' - 'of oTree (2.2). You should remove it from your EXTENSION_APPS setting.' - ) websocket_routes += getattr(extensions_module, 'websocket_routes', []) -application = ProtocolTypeRouter({ - # WebSocket chat handler - "websocket": AuthMiddlewareStack(URLRouter(websocket_routes)), - "lifespan": consumers.LifespanApp, -}) +application = ProtocolTypeRouter( + {"websocket": AuthMiddlewareStack(URLRouter(websocket_routes))} +) diff --git a/otree/channels/utils.py b/otree/channels/utils.py index d32c6a838..63d6910e8 100644 --- a/otree/channels/utils.py +++ b/otree/channels/utils.py @@ -1,67 +1,93 @@ from django.core.signing import Signer from asgiref.sync import async_to_sync from channels.layers import get_channel_layer +from urllib.parse import urlencode -group_send = get_channel_layer().group_send -send = get_channel_layer().send -sync_group_send = async_to_sync(group_send) -sync_send = async_to_sync(send) +_group_send = get_channel_layer().group_send +_sync_group_send = async_to_sync(_group_send) -def wait_page_group_name(session_id, page_index, - group_id_in_subsession=''): +def sync_group_send_wrapper(*, type: str, group: str, event: dict): + '''make it a function that takes proper args that are intuitive. + enforces correct use. + ''' + return _sync_group_send(group, {'type': type, **event}) + + +def group_send_wrapper(*, type: str, group: str, event: dict): + '''make it a function that takes proper args that are intuitive. + ''' + return _group_send(group, {'type': type, **event}) + + +def group_wait_page_name(session_id, page_index, group_id_in_subsession): return 'wait-page-{}-page{}-{}'.format( - session_id, page_index, group_id_in_subsession) + session_id, page_index, group_id_in_subsession + ) + + +def subsession_wait_page_name(session_id, page_index): + + return 'wait-page-{}-page{}'.format(session_id, page_index) def gbat_group_name(session_id, page_index): - return 'group_by_arrival_time_session{}_page{}'.format( - session_id, page_index) + return 'group_by_arrival_time_session{}_page{}'.format(session_id, page_index) + + +def gbat_path(**kwargs): + return '/group_by_arrival_time/?' + urlencode(kwargs) -def gbat_path(session_id, index_in_pages, app_name, player_id): - return '/group_by_arrival_time/{},{},{},{}/'.format( - session_id, index_in_pages, app_name, player_id - ) def room_participants_group_name(room_name): return 'room-participants-{}'.format(room_name) -def room_participant_path(room_name, participant_label, tab_unique_id): - return '/wait_for_session_in_room/{},{},{}/'.format( - room_name, participant_label, tab_unique_id - ) +def room_participant_path(**kwargs): + return '/wait_for_session_in_room/?' + urlencode(kwargs) + def room_admin_group_name(room_name): return f'room-admin-{room_name}' + def room_admin_path(room_name): return '/room_without_session/{}/'.format(room_name) + def create_session_path(): return '/create_session/' + def create_demo_session_path(): return '/create_demo_session/' -def wait_page_path(session_pk, index_in_pages, group_id_in_subsession=''): - return '/wait_page/{},{},{}/'.format( - session_pk, index_in_pages, group_id_in_subsession - ) + +def group_wait_page_path(**kwargs): + return '/wait_page/?' + urlencode(kwargs) + + +def subsession_wait_page_path(**kwargs): + return '/subsession_wait_page/?' + urlencode(kwargs) + def browser_bots_launcher_group(session_code): return 'browser-bots-client-{}'.format(session_code) + def browser_bots_launcher_path(session_code): return '/browser_bots_client/{}/'.format(session_code) -def auto_advance_path(participant_code, page_index): - return '/auto_advance/{},{}/'.format(participant_code, page_index) + +def auto_advance_path(**kwargs): + return '/auto_advance/?' + urlencode(kwargs) + def auto_advance_group(participant_code): return f'auto-advance-{participant_code}' + def chat_path(channel, participant_id): channel_and_id = '{}/{}'.format(channel, participant_id) channel_and_id_signed = Signer(sep='/').sign(channel_and_id) @@ -70,4 +96,4 @@ def chat_path(channel, participant_id): def get_chat_group(channel): - return 'otreechat-{}'.format(channel) \ No newline at end of file + return 'otreechat-{}'.format(channel) diff --git a/otree/chat.py b/otree/chat.py index bac955eaa..2b3755d27 100644 --- a/otree/chat.py +++ b/otree/chat.py @@ -3,9 +3,11 @@ from otree.channels import utils as channel_utils from django.utils.translation import ugettext as _ + class ChatTagError(Exception): pass + class UNDEFINED: pass @@ -24,14 +26,15 @@ def chat_template_tag(context, *, channel=UNDEFINED, nickname=UNDEFINED): if not re.match(r'^[a-zA-Z0-9_-]+$', channel): raise ChatTagError( "'channel' can only contain ASCII letters, numbers, underscores, and hyphens. " - "Value given was: {}".format(channel)) + "Value given was: {}".format(channel) + ) # prefix the channel name with session code and app name prefixed_channel = '{}-{}-{}'.format( context['session'].id, Constants.name_in_url, # previously used a hash() here to ensure name_in_url is the same, # but hash() is non-reproducible across processes - channel + channel, ) context['channel'] = prefixed_channel @@ -53,9 +56,9 @@ def chat_template_tag(context, *, channel=UNDEFINED, nickname=UNDEFINED): # Translators: the name someone sees displayed for themselves in a chat. # It's their nickname followed by "(Me)". For example: # "Michael (Me)" or "Player 1 (Me)". - 'nickname_i_see_for_myself': _("{nickname} (Me)").format(nickname=nickname) + 'nickname_i_see_for_myself': _("{nickname} (Me)").format(nickname=nickname), } context['chat_vars_for_js'] = chat_vars_for_js - return context \ No newline at end of file + return context diff --git a/otree/checks.py b/otree/checks.py new file mode 100644 index 000000000..9cfe2918d --- /dev/null +++ b/otree/checks.py @@ -0,0 +1,291 @@ +import inspect +import os +from importlib import import_module +from pathlib import Path + +from django.core.checks import Error, Warning, register + +from otree import common +from otree.api import BasePlayer, BaseGroup, BaseSubsession, Currency, WaitPage, Page +from otree.common import _get_all_configs + + +class AppCheckHelper: + """Basically a wrapper around the AppConfig + """ + + def __init__(self, app_config, errors): + self.app_config = app_config + self.errors = errors + + def add_error(self, title, numeric_id: int, **kwargs): + issue_id = 'otree.E' + str(numeric_id).zfill(3) + kwargs.setdefault('obj', self.app_config.label) + return self.errors.append(Error(title, id=issue_id, **kwargs)) + + def add_warning(self, title, numeric_id: int, **kwargs): + kwargs.setdefault('obj', self.app_config.label) + issue_id = 'otree.W' + str(numeric_id).zfill(3) + return self.errors.append(Warning(title, id=issue_id, **kwargs)) + + # Helper meythods + + def get_path(self, name): + return os.path.join(self.app_config.path, name) + + def get_rel_path(self, name): + basepath = os.getcwd() + return os.path.relpath(name, basepath) + + def get_module(self, name): + return import_module(self.app_config.name + '.' + name) + + def get_template_names(self): + path = self.get_path('templates') + template_names = [] + for root, dirs, files in os.walk(path): + for filename in [f for f in files if f.endswith('.html')]: + template_names.append(os.path.join(root, filename)) + return template_names + + def module_exists(self, module): + try: + self.get_module(module) + return True + except ImportError as e: + return False + + def class_exists(self, module, name): + module = self.get_module(module) + cls = getattr(module, name, None) + return inspect.isclass(cls) + + +def files(helper: AppCheckHelper, **kwargs): + # don't check views.py because it might be pages.py + for fn in ['models.py']: + if not os.path.isfile(helper.get_path(fn)): + helper.add_error('No "%s" file found in app folder' % fn, numeric_id=102) + + templates_dir = Path(helper.get_path('templates')) + app_label = helper.app_config.label + if templates_dir.is_dir(): + # check for files in templates/, but not in templates/ + misplaced_files = list(templates_dir.glob('*.html')) + if misplaced_files: + hint = ( + 'Move template files from "{app}/templates/" ' + 'to "{app}/templates/{app}" subfolder'.format(app=app_label) + ) + + helper.add_error( + "Templates files in wrong folder", hint=hint, numeric_id=103 + ) + + all_subfolders = set(templates_dir.glob('*/')) + correctly_named_subfolders = set(templates_dir.glob('{}/'.format(app_label))) + other_subfolders = all_subfolders - correctly_named_subfolders + if other_subfolders and not correctly_named_subfolders: + msg = ( + "The 'templates' folder has a subfolder called '{}', " + "but it should be renamed '{}' to match the name of the app. " + ).format(other_subfolders.pop().name, app_label) + helper.add_error(msg, numeric_id=104) + + +base_model_attrs = { + 'Player': set(dir(BasePlayer)), + 'Group': set(dir(BaseGroup)), + 'Subsession': set(dir(BaseSubsession)), +} +model_field_substitutes = { + int: 'IntegerField', + float: 'FloatField', + bool: 'BooleanField', + str: 'CharField', + Currency: 'CurrencyField', + type(None): 'IntegerField' + # not always int, but it's a reasonable suggestion +} + + +def model_classes(helper: AppCheckHelper, **kwargs): + for name in ['Subsession', 'Group', 'Player']: + try: + helper.app_config.get_model(name) + except LookupError: + helper.add_error( + 'MissingModel: Model "%s" not defined' % name, numeric_id=110 + ) + + app_config = helper.app_config + Player = app_config.get_model('Player') + Group = app_config.get_model('Group') + Subsession = app_config.get_model('Subsession') + + for Model in [Player, Group, Subsession]: + for attr_name in dir(Model): + if attr_name not in base_model_attrs[Model.__name__]: + try: + attr_value = getattr(Model, attr_name) + _type = type(attr_value) + except AttributeError: + # I got "The 'q_country' attribute can only be accessed + # from Player instances." + # can just filter/ignore these. + pass + else: + if _type in model_field_substitutes.keys(): + msg = ( + 'NonModelFieldAttr: ' + '{} has attribute "{}", which is not a model field, ' + 'and will therefore not be saved ' + 'to the database.'.format(Model.__name__, attr_name) + ) + + helper.add_error( + msg, + numeric_id=111, + hint='Consider changing to "{} = models.{}(initial={})"'.format( + attr_name, + model_field_substitutes[_type], + repr(getattr(Model, attr_name)), + ), + ) + # if people just need an iterable of choices for a model field, + # they should use a tuple, not list or dict + elif _type in {list, dict, set}: + warning = ( + 'MutableModelClassAttr: ' + '{ModelName}.{attr} is a {type_name}. ' + 'Modifying it during a session (e.g. appending or setting values) ' + 'will have unpredictable results; ' + 'you should use ' + 'session.vars or participant.vars instead. ' + 'Or, if this {type_name} is read-only, ' + "then it's recommended to move it outside of this class " + '(e.g. put it in Constants).' + ).format( + ModelName=Model.__name__, + attr=attr_name, + type_name=_type.__name__, + ) + + helper.add_error(warning, numeric_id=112) + + +def constants(helper: AppCheckHelper, **kwargs): + if not helper.module_exists('models'): + return + if not helper.class_exists('models', 'Constants'): + helper.add_error('models.py does not contain Constants class', numeric_id=11) + return + + models = helper.get_module('models') + Constants = getattr(models, 'Constants') + attrs = ['name_in_url', 'players_per_group', 'num_rounds'] + for attr_name in attrs: + if not hasattr(Constants, attr_name): + msg = "models.py: 'Constants' class needs to define '{}'" + helper.add_error(msg.format(attr_name), numeric_id=12) + ppg = Constants.players_per_group + if ppg == 0 or ppg == 1: + helper.add_error( + "models.py: Constants.players_per_group cannot be {}. You " + "should set it to None, which makes the group " + "all players in the subsession.".format(ppg), + numeric_id=13, + ) + if ' ' in Constants.name_in_url: + helper.add_error( + "models.py: Constants.name_in_url must not contain spaces", numeric_id=14 + ) + + +def pages_function(helper: AppCheckHelper, **kwargs): + pages_module = common.get_pages_module(helper.app_config.name) + try: + page_list = pages_module.page_sequence + except: + helper.add_error( + 'pages.py is missing the variable page_sequence.', numeric_id=21 + ) + return + else: + for i, ViewCls in enumerate(page_list): + # there is no good reason to include Page in page_sequence. + # As for WaitPage: even though it works fine currently + # and can save the effort of subclassing, + # we should restrict it, because: + # - one user had "class WaitPage(Page):". + # - if someone makes "class WaitPage(WaitPage):", they might + # not realize why it's inheriting the extra behavior. + # overall, I think the small inconvenience of having to subclass + # once per app + # is outweighed by the unexpected behavior if someone subclasses + # it without understanding inheritance. + # BUT: built-in Trust game had a wait page called WaitPage. + # that was fixed on Aug 24, 2017, need to wait a while... + # see below in ensure_no_misspelled_attributes, + # we can get rid of a check there also + if ViewCls.__name__ == 'Page': + msg = "page_sequence cannot contain a class called 'Page'." + helper.add_error(msg, numeric_id=22) + if ViewCls.__name__ == 'WaitPage' and helper.app_config.name != 'trust': + msg = "page_sequence cannot contain a class called 'WaitPage'." + helper.add_error(msg, numeric_id=221) + + if issubclass(ViewCls, WaitPage): + if ViewCls.group_by_arrival_time: + if i > 0: + helper.add_error( + '"{}" has group_by_arrival_time=True, so ' + 'it must be placed first in page_sequence.'.format( + ViewCls.__name__ + ), + numeric_id=23, + ) + if ViewCls.wait_for_all_groups: + helper.add_error( + 'Page "{}" has group_by_arrival_time=True, so ' + 'it cannot have wait_for_all_groups=True also.'.format( + ViewCls.__name__ + ), + numeric_id=24, + ) + # alternative technique is to not define the method on WaitPage + # and then use hasattr, but I want to keep all complexity + # out of views.abstract + elif ViewCls.get_players_for_group != WaitPage.get_players_for_group: + helper.add_error( + 'Page "{}" defines get_players_for_group, ' + 'but in order to use this method, you must set ' + 'group_by_arrival_time=True'.format(ViewCls.__name__), + numeric_id=25, + ) + elif issubclass(ViewCls, Page): + pass # ok + else: + msg = '"{}" is not a valid page'.format(ViewCls) + helper.add_error(msg, numeric_id=26) + + +def make_check_function(func): + def check_function(app_configs, **kwargs): + # if app_configs list is given (e.g. otree check app1 app2), run on those + # if it's None, run on all apps + # (system check API requires this) + app_configs = app_configs or _get_all_configs() + errors = [] + for app_config in app_configs: + helper = AppCheckHelper(app_config, errors) + func(helper, **kwargs) + return errors + + return check_function + + +def register_system_checks(): + for func in [model_classes, files, constants, pages_function]: + check_function = make_check_function(func) + register(check_function) diff --git a/otree/checks/__init__.py b/otree/checks/__init__.py deleted file mode 100644 index 43812ba55..000000000 --- a/otree/checks/__init__.py +++ /dev/null @@ -1,586 +0,0 @@ -import glob -import inspect -import io -import os - -from otree import common_internal -from importlib import import_module - -from django.apps import apps -from django.conf import settings -from django.core.checks import register, Error, Warning -from django.template import Template -from django.template import TemplateSyntaxError -import django.db.models.fields -from otree.api import ( - BasePlayer, BaseGroup, BaseSubsession, Currency, WaitPage, Page) -from otree.common_internal import _get_all_configs -from pathlib import Path -import re - - -class AppCheckHelper: - """Basically a wrapper around the AppConfig - """ - - def __init__(self, app_config, errors): - self.app_config = app_config - self.errors = errors - - def add_error(self, title, numeric_id: int, **kwargs): - issue_id = 'otree.E' + str(numeric_id).zfill(3) - kwargs.setdefault('obj', self.app_config.label) - return self.errors.append(Error(title, id=issue_id, **kwargs)) - - def add_warning(self, title, numeric_id: int, **kwargs): - kwargs.setdefault('obj', self.app_config.label) - issue_id = 'otree.W' + str(numeric_id).zfill(3) - return self.errors.append(Warning(title, id=issue_id, **kwargs)) - - # Helper meythods - - def get_path(self, name): - return os.path.join(self.app_config.path, name) - - def get_rel_path(self, name): - basepath = os.getcwd() - return os.path.relpath(name, basepath) - - def get_module(self, name): - return import_module(self.app_config.name + '.' + name) - - def get_template_names(self): - path = self.get_path('templates') - template_names = [] - for root, dirs, files in os.walk(path): - for filename in [f for f in files if f.endswith('.html')]: - template_names.append(os.path.join(root, filename)) - return template_names - - def module_exists(self, module): - try: - self.get_module(module) - return True - except ImportError as e: - return False - - def class_exists(self, module, name): - module = self.get_module(module) - cls = getattr(module, name, None) - return inspect.isclass(cls) - - -# CHECKS - -def files(helper: AppCheckHelper, **kwargs): - # don't check views.py because it might be pages.py - for fn in ['models.py']: - if not os.path.isfile(helper.get_path(fn)): - helper.add_error( - 'No "%s" file found in game folder' % fn, - numeric_id=102 - ) - - templates_dir = Path(helper.get_path('templates')) - app_label = helper.app_config.label - if templates_dir.is_dir(): - # check for files in templates/, but not in templates/ - misplaced_files = list(templates_dir.glob('*.html')) - if misplaced_files: - hint = ( - 'Move template files from "{app}/templates/" ' - 'to "{app}/templates/{app}" subfolder'.format( - app=app_label) - ) - - helper.add_error( - "Templates files in wrong folder", - hint=hint, numeric_id=103, - ) - - all_subfolders = set(templates_dir.glob('*/')) - correctly_named_subfolders = set( - templates_dir.glob('{}/'.format(app_label))) - other_subfolders = all_subfolders - correctly_named_subfolders - if other_subfolders and not correctly_named_subfolders: - msg = ( - "The 'templates' folder has a subfolder called '{}', " - "but it should be renamed '{}' to match the name of the app. " - ).format(other_subfolders.pop().name, app_label) - helper.add_error(msg, numeric_id=104) - - -base_model_attrs = { - 'Player': set(dir(BasePlayer)), - 'Group': set(dir(BaseGroup)), - 'Subsession': set(dir(BaseSubsession)), -} - -model_field_substitutes = { - int: 'IntegerField', - float: 'FloatField', - bool: 'BooleanField', - str: 'CharField', - Currency: 'CurrencyField', - type(None): 'IntegerField' - # not always int, but it's a reasonable suggestion -} - - -def model_classes(helper: AppCheckHelper, **kwargs): - for name in ['Subsession', 'Group', 'Player']: - try: - helper.app_config.get_model(name) - except LookupError: - helper.add_error( - 'MissingModel: Model "%s" not defined' % name, numeric_id=110) - - app_config = helper.app_config - Player = app_config.get_model('Player') - Group = app_config.get_model('Group') - Subsession = app_config.get_model('Subsession') - - for Model in [Player, Group, Subsession]: - for attr_name in dir(Model): - if attr_name not in base_model_attrs[Model.__name__]: - try: - attr_value = getattr(Model, attr_name) - _type = type(attr_value) - except AttributeError: - # I got "The 'q_country' attribute can only be accessed - # from Player instances." - # can just filter/ignore these. - pass - else: - if _type in model_field_substitutes.keys(): - msg = ( - 'NonModelFieldAttr: ' - '{} has attribute "{}", which is not a model field, ' - 'and will therefore not be saved ' - 'to the database.'.format(Model.__name__, - attr_name)) - - helper.add_error( - msg, - numeric_id=111, - hint='Consider changing to "{} = models.{}(initial={})"'.format( - attr_name, model_field_substitutes[_type], - repr(getattr(Model, attr_name))) - ) - # if people just need an iterable of choices for a model field, - # they should use a tuple, not list or dict - elif _type in {list, dict, set}: - warning = ( - 'MutableModelClassAttr: ' - '{ModelName}.{attr} is a {type_name}. ' - 'Modifying it during a session (e.g. appending or setting values) ' - 'will have unpredictable results; ' - 'you should use ' - 'session.vars or participant.vars instead. ' - 'Or, if this {type_name} is read-only, ' - "then it's recommended to move it outside of this class " - '(e.g. put it in Constants).' - ).format(ModelName=Model.__name__, - attr=attr_name, - type_name=_type.__name__) - - helper.add_error(warning, numeric_id=112) - # isinstance(X, type) means X is a class, not instance - elif (isinstance(attr_value, type) and - issubclass(attr_value, - django.db.models.fields.Field)): - msg = ( - '{}.{} is missing parentheses.' - ).format(Model.__name__, attr_name) - helper.add_error( - msg, numeric_id=113, - hint=( - 'Consider changing to "{} = models.{}()"' - ).format(attr_name, attr_value.__name__) - ) - - -def constants(helper: AppCheckHelper, **kwargs): - if not helper.module_exists('models'): - return - if not helper.class_exists('models', 'Constants'): - helper.add_error( - 'models.py does not contain Constants class', numeric_id=11 - ) - return - - models = helper.get_module('models') - Constants = getattr(models, 'Constants') - attrs = ['name_in_url', 'players_per_group', 'num_rounds'] - for attr_name in attrs: - if not hasattr(Constants, attr_name): - msg = "models.py: 'Constants' class needs to define '{}'" - helper.add_error(msg.format(attr_name), numeric_id=12) - ppg = Constants.players_per_group - if ppg == 0 or ppg == 1: - helper.add_error( - "models.py: Constants.players_per_group cannot be {}. You " - "should set it to None, which makes the group " - "all players in the subsession.".format(ppg), - numeric_id=13 - ) - if ' ' in Constants.name_in_url: - helper.add_error( - "models.py: Constants.name_in_url must not contain spaces", - numeric_id=14 - ) - - -def orphan_methods(helper: AppCheckHelper, **kwargs): - '''i saw several people making this mistake in the workshop''' - pages_module = common_internal.get_pages_module(helper.app_config.name) - for method_name in ['vars_for_template', 'is_displayed', - 'after_all_players_arrive']: - if hasattr(pages_module, method_name): - helper.add_error( - 'pages.py has a function {} that is not inside a class.'.format( - method_name), - numeric_id=70 - ) - - return - - -def pages_function(helper: AppCheckHelper, **kwargs): - pages_module = common_internal.get_pages_module(helper.app_config.name) - views_or_pages = pages_module.__name__.split('.')[-1] - try: - page_list = pages_module.page_sequence - except: - helper.add_error( - '{}.py is missing the variable page_sequence.'.format( - views_or_pages), - numeric_id=21 - ) - return - else: - for i, ViewCls in enumerate(page_list): - # there is no good reason to include Page in page_sequence. - # As for WaitPage: even though it works fine currently - # and can save the effort of subclassing, - # we should restrict it, because: - # - one user had "class WaitPage(Page):". - # - if someone makes "class WaitPage(WaitPage):", they might - # not realize why it's inheriting the extra behavior. - # overall, I think the small inconvenience of having to subclass - # once per app - # is outweighed by the unexpected behavior if someone subclasses - # it without understanding inheritance. - # BUT: built-in Trust game had a wait page called WaitPage. - # that was fixed on Aug 24, 2017, need to wait a while... - # see below in ensure_no_misspelled_attributes, - # we can get rid of a check there also - if ViewCls.__name__ == 'Page': - msg = ( - "page_sequence cannot contain " - "a class called 'Page'." - ) - helper.add_error(msg, numeric_id=22) - if ViewCls.__name__ == 'WaitPage' and helper.app_config.name != 'trust': - msg = ( - "page_sequence cannot contain " - "a class called 'WaitPage'." - ) - helper.add_error(msg, numeric_id=221) - - if issubclass(ViewCls, WaitPage): - if ViewCls.group_by_arrival_time: - if i > 0: - helper.add_error( - '"{}" has group_by_arrival_time=True, so ' - 'it must be placed first in page_sequence.'.format( - ViewCls.__name__), numeric_id=23) - if ViewCls.wait_for_all_groups: - helper.add_error( - 'Page "{}" has group_by_arrival_time=True, so ' - 'it cannot have wait_for_all_groups=True also.'.format( - ViewCls.__name__), numeric_id=24) - # alternative technique is to not define the method on WaitPage - # and then use hasattr, but I want to keep all complexity - # out of views.abstract - elif ( - ViewCls.get_players_for_group != WaitPage.get_players_for_group): - helper.add_error( - 'Page "{}" defines get_players_for_group, ' - 'but in order to use this method, you must set ' - 'group_by_arrival_time=True'.format( - ViewCls.__name__), numeric_id=25) - elif issubclass(ViewCls, Page): - pass # ok - else: - msg = '"{}" is not a valid page'.format(ViewCls) - helper.add_error(msg, numeric_id=26) - - ensure_no_misspelled_attributes(ViewCls, helper) - - -def ensure_no_misspelled_attributes(ViewCls: type, helper: AppCheckHelper): - '''just a helper function''' - - # this messes with the logic of base classes. - # do this instead of ViewCls == WaitPage, because _builtin already - # subclasses it, so you would get a warning like: - # Page "WaitPage" has the following method that is not recognized by oTree: - # "z_autocomplete". - if ViewCls.__name__ == 'WaitPage' or ViewCls.__name__ == 'Page': - return - - # make sure no misspelled attributes - base_members = set() - for Cls in ViewCls.__bases__: - base_members.update(dir(Cls)) - child_members = set(dir(ViewCls)) - child_only_members = child_members - base_members - - dynamic_form_methods = set() # needs to be a set - for member in child_only_members: - # error_message, not _error_message - for valid_ending in ['error_message', '_min', '_max', '_choices']: - if member.endswith(valid_ending): - dynamic_form_methods.add(member) - invalid_members = child_only_members - dynamic_form_methods - if invalid_members: - ALLOW_CUSTOM_ATTRIBUTES = '_allow_custom_attributes' - if getattr(ViewCls, ALLOW_CUSTOM_ATTRIBUTES, False): - return - - page_attrs = set(dir(Page)) - wait_page_attrs = set(dir(WaitPage)) - ATTRS_ON_PAGE_ONLY = page_attrs - wait_page_attrs - ATTRS_ON_WAITPAGE_ONLY = wait_page_attrs - page_attrs - - for member in invalid_members: - # this assumes that ViewCls is a Page or WaitPage - if member in ATTRS_ON_PAGE_ONLY: - assert issubclass(ViewCls, WaitPage), (ViewCls, member) - msg = ( - 'WaitPage "{ViewClsName}" has the attribute "{member}" that is not ' - 'allowed on a WaitPage. ' - ) - numeric_id = 27 - elif member in ATTRS_ON_WAITPAGE_ONLY: - assert issubclass(ViewCls, Page), (ViewCls, member) - msg = ( - 'Page "{ViewClsName}" has the attribute "{member}" that is ' - 'only allowed on a WaitPage, not a regular Page. ' - ) - numeric_id=271 - elif callable(getattr(ViewCls, member)): - msg = ( - 'Page "{ViewClsName}" has the following method that is not ' - 'recognized by oTree: "{member}". ' - 'Consider moving it into ' - 'the Player class in models.py. ' - ) - - numeric_id=28 - else: - msg = ( - 'Page "{ViewClsName}" has the following attribute that is not ' - 'recognized by oTree: "{member}". ' - ) - numeric_id=29 - - fmt_kwargs = { - 'ViewClsName': ViewCls.__name__, - 'FLAG': ALLOW_CUSTOM_ATTRIBUTES, - 'member': member, - } - # when i make this an error, should add this workaround. - #msg += 'If you want to keep it here, you need to set ' - # '{FLAG}=True on the page class.' - - # at first, just make it a warning. - helper.add_error(msg.format(**fmt_kwargs), numeric_id) - - -def template_content_is_in_blocks(template_name: str, helper: AppCheckHelper): - from otree.checks.templates import get_unreachable_content - from otree.checks.templates import has_valid_encoding - from otree.checks.templates import format_source_snippet - - # Only test files that are valid templates. - if not has_valid_encoding(template_name): - return - - try: - with io.open(template_name, 'r', encoding='utf8') as f: - - # when we upgraded to Django 1.11, we got an error - # if someone used "{% include %}" with a relative - # path (like ../Foo.html): - # File "c:\otree\ve_dj11\lib\site-packages\django\template\loader_tags.py", line 278, in construct_relative_path - # posixpath.dirname(current_template_name.lstrip('/')), - # AttributeError: 'NoneType' object has no attribute 'lstrip' - # can fix this by passing a dummy 'Origin' param. - # i tried also with Engin.get_default().from_string(template_name), - # but got the same error. - class Origin: - name = '' - template_name = '' - - compiled_template = Template(f.read(), origin=Origin) - except (IOError, OSError, TemplateSyntaxError): - # When we used Django 1.8 - # we used to show the line from the source that caused the error, - # but django_template_source was removed at some point, - # so it's better to let the yellow error page show the error nicely - return - - def format_content(text): - text = text.strip() - lines = text.splitlines() - lines = ['> {0}'.format(line) for line in lines] - return '\n'.join(lines) - - contents = get_unreachable_content(compiled_template) - - content_bits = '\n\n'.join( - format_content(bit) - for bit in contents) - # note: this seems to not detect unreachable content - # if the template has a relative include, - # like {% include "../Foo.html" %} - # not sure why, but that's not common usage. - if contents: - helper.add_error( - 'Template contains the following text outside of a ' - '{% block %}. This text will never be displayed.' - '\n\n' + content_bits, - # why do we do this? isn't template_name already the full path? - obj=os.path.join(helper.app_config.label, - helper.get_rel_path(template_name)), - numeric_id=7) - - -def templates_valid(helper: AppCheckHelper, **kwargs): - for template_name in helper.get_template_names(): - template_content_is_in_blocks(template_name, helper) - - -def no_model_class_references(helper: AppCheckHelper, **kwargs): - models_path = helper.get_path('models.py') - - from otree.checks.templates import has_valid_encoding - - # Only test files that are valid templates. - if not has_valid_encoding(models_path): - return - - try: - with io.open(models_path, 'r', encoding='utf8') as f: - content = f.read() - except (IOError, OSError): - # when would these errors occur? - # (this was originally copied from template check) - return - - allowed_attr_names = ['add_to_class', 'objects'] - matches = re.finditer(r'(Player|Group|Subsession)\.(\w+)', content) - - for m in matches: - matched_text = m.group(0) - ModelName = m.group(1) - attr_name = m.group(2) - if attr_name in allowed_attr_names: - continue - position = m.start(0) - num_newlines = content[:position].count('\n') - line_number = num_newlines + 1 - first_letter_uppercase = ModelName[0] - helper.add_error( - f'models.py contains a reference to "{matched_text}" around line {line_number}. ' - f'You should not refer to {ModelName} with an uppercase {ModelName[0]}, ' - f'because that refers to the whole class, ' - f'rather than any individual {ModelName.lower()}. ' - 'Learn about how to access instances using "self".', - numeric_id=120) - - -def unique_sessions_names(helper: AppCheckHelper, **kwargs): - already_seen = set() - for st in settings.SESSION_CONFIGS: - st_name = st["name"] - if st_name in already_seen: - msg = "Duplicate SESSION_CONFIG name '{}'".format(st_name) - helper.add_error(msg, numeric_id=40) - else: - already_seen.add(st_name) - - -def unique_room_names(helper: AppCheckHelper, **kwargs): - already_seen = set() - for room in getattr(settings, 'ROOMS', []): - room_name = room["name"] - if room_name in already_seen: - msg = "Duplicate ROOM name '{}'".format(room_name) - helper.add_error(msg, numeric_id=50) - else: - already_seen.add(room_name) - - -def template_encoding(helper: AppCheckHelper, **kwargs): - from otree.checks.templates import has_valid_encoding - for template_name in helper.get_template_names(): - if not has_valid_encoding(template_name): - helper.add_error( - 'The template {template} is not UTF-8 encoded. ' - 'Please configure your text editor to always save files ' - 'as UTF-8. Then open the file and save it again.' - .format(template=helper.get_rel_path(template_name)), - numeric_id=60, - ) - - -def make_check_function(func): - def check_function(app_configs, **kwargs): - # if app_configs list is given (e.g. otree check app1 app2), run on those - # if it's None, run on all apps - # (system check API requires this) - app_configs = app_configs or _get_all_configs() - errors = [] - for app_config in app_configs: - helper = AppCheckHelper(app_config, errors) - func(helper, **kwargs) - return errors - - return check_function - - -def make_check_function_run_once(func): - def check_function(app_configs, **kwargs): - otree_app_config = apps.get_app_config('otree') - # ignore app_configs list -- just run once - errors = [] - helper = AppCheckHelper(otree_app_config, errors) - func(helper, **kwargs) - return errors - - return check_function - - -def register_system_checks(): - for func in [ - unique_room_names, - unique_sessions_names, - ]: - check_function = make_check_function_run_once(func) - register(check_function) - - for func in [ - model_classes, - files, - constants, - pages_function, - templates_valid, - template_encoding, - orphan_methods, - no_model_class_references, - ]: - check_function = make_check_function(func) - register(check_function) diff --git a/otree/checks/mturk.py b/otree/checks/mturk.py deleted file mode 100644 index 09792736b..000000000 --- a/otree/checks/mturk.py +++ /dev/null @@ -1,64 +0,0 @@ -from django.template.loader import select_template -from django.contrib import messages - -from otree.views import Page, WaitPage -import otree.common_internal -from otree.checks.templates import check_next_button - - -class MTurkValidator(object): - ''' - This validation is based on issue #314 - ''' - def __init__(self, session): - self.session = session - - def get_no_next_buttons_pages(self): - ''' - Check that every page in every app has next_button. - Also including the last page. Next button on last page is - necessary to trigger an externalSubmit to the MTurk server. - ''' - missing_next_button_pages = [] - for app in self.session.config['app_sequence']: - views_module = otree.common_internal.get_pages_module(app) - for page_class in views_module.page_sequence: - page = page_class() - if isinstance(page, Page): - path_template = page.get_template_names() - template = select_template(path_template) - # The returned ``template`` variable is only a wrapper - # around Django's internal ``Template`` object. - template = template.template - if not check_next_button(template): - # can't use template.origin.name because it's not - # available when DEBUG is off. So use path_template - # instead - missing_next_button_pages.append((page, path_template)) - return missing_next_button_pages - - def validation_message(self): - missing_next_button_pages = self.get_no_next_buttons_pages() - if missing_next_button_pages: - page_listing = '; '.join([ - 'Template {} for page {}'.format(template_name, page.__class__.__name__) - for page, template_name in missing_next_button_pages]) - return ( - 'The following templates appear to have no next button. <{}> ' - 'When using oTree on MTurk, even the last page should have a next button. ' - ).format(page_listing) - - return '' - # 2017-05-06: I removed the check for timeouts, because I added - # get_timeout_seconds. - # i could base the warning on whether timeout_seconds is defined, - # but it seems like the warning would generate false positives. - # It's a bit complicated, and doesn't seem worth the code complexity. - - - def app_has_no_wait_pages(self, app): - views_module = otree.common_internal.get_pages_module(app) - return not any(issubclass(page_class, WaitPage) - for page_class in views_module.page_sequence) - - diff --git a/otree/checks/templates.py b/otree/checks/templates.py deleted file mode 100644 index 2cc7d7f90..000000000 --- a/otree/checks/templates.py +++ /dev/null @@ -1,173 +0,0 @@ -from collections import namedtuple -from django.template.base import TextNode -from django.template.library import InclusionNode -from django.template.loader_tags import ExtendsNode, BlockNode -from django.utils.encoding import force_text -from itertools import chain -import unicodedata -import io - -from otree.templatetags.otree import NEXT_BUTTON_TEMPLATE_PATH - -class TemplateCheckContent(object): - def __init__(self, root): - self.root = root - - def node_is_empty(self, node): - if isinstance(node, TextNode): - return node.s.isspace() - return False - - def is_extending(self, root): - return any( - isinstance(node, ExtendsNode) - for node in root.nodelist) - - def is_content_node(self, node): - """ - Returns if the node is an unempty text node. - """ - if isinstance(node, TextNode): - return not self.node_is_empty(node) - return False - - def get_toplevel_content_nodes(self, root): - nodes = [] - for node in root.nodelist: - if isinstance(node, ExtendsNode): - new_child_nodes = self.get_toplevel_content_nodes(node) - nodes.extend(new_child_nodes) - if self.is_content_node(node): - nodes.append(node) - return nodes - - def get_unreachable_content(self): - """ - Return all top-level text nodes when the template is extending another - template. Those text nodes won't be displayed during rendering since - only content inside of blocks is considered in inheritance. - """ - if not self.is_extending(self.root): - return [] - - textnodes = self.get_toplevel_content_nodes(self.root) - return [node.s for node in textnodes] - - -def get_unreachable_content(root): - check = TemplateCheckContent(root) - return check.get_unreachable_content() - - -class TemplateCheckNextButton(object): - def __init__(self, root): - self.root = root - - def get_next_button_nodes(self, root): - nodes = [] - for node in root.nodelist: - if isinstance(node, (ExtendsNode, BlockNode)): - new_child_nodes = self.get_next_button_nodes(node) - nodes.extend(new_child_nodes) - elif isinstance(node, InclusionNode) and node.filename == NEXT_BUTTON_TEMPLATE_PATH: - nodes.append(node) - return nodes - - def check_next_button(self): - next_button_nodes = self.get_next_button_nodes(self.root) - return len(next_button_nodes) > 0 - - -def check_next_button(root): - check = TemplateCheckNextButton(root) - return check.check_next_button() - - -def has_valid_encoding(file_name): - try: - # need to open the file with an explicit encoding='utf8' - # otherwise Windows may use another encoding if. - # io.open provides the encoding= arg and is Py2/Py3 compatible - with io.open(file_name, 'r', encoding='utf8') as f: - template_string = f.read() - force_text(template_string) - except UnicodeDecodeError: - return False - return True - - -Line = namedtuple('Line', ('source', 'lineno', 'start', 'end')) - - -def format_error_line(line): - # We need to make sure here that the output does not contain any unicode - # characters. Django's check framework cannot print errors that contain - # unicode. - source = line.source - source = unicodedata.normalize('NFKD', source).encode('ascii', 'replace') - return '{line.lineno:4d} | {source}'.format(line=line, source=source) - - -def split_source_lines(source): - """ - Split source string into a list of ``Line`` objects. They contain - contextual information like line number, start position, end position. - """ - lines = source.splitlines(True) - start = 0 - annotated_lines = [] - for i, line in enumerate(lines): - # Windows line endings end with '\r\n'. - if line.endswith('\r\n'): - ending_length = 2 - # In case of '\n' or '\r' ending the line. - else: - ending_length = 1 - end = start + len(line) - annotated_lines.append(Line( - # Don't include line endings in snippet source. - source=line[:-ending_length], - lineno=i + 1, - start=start, - end=end)) - start = end - return annotated_lines - - -def format_source_snippet(source, arrow_position, context=5): - """ - Display parts of a source file with an arrow pointing at an exact location. - Will display ``context`` number of lines before and after the arrow - position. - - Example:: - - 15 | - 16 | Please provide your information in the form below. - 17 | - 18 | - 19 | {% formrow form.my_field with label = "foo" %} - -----------^ - 20 | - 21 | {% next_button %} - 22 | {% endblock %} - """ - lines = split_source_lines(source) - error_line = 0 - for line in lines: - if line.start <= arrow_position < line.end: - error_line = line - break - start_context = max(error_line.lineno - 1 - context, 0) - end_context = min(error_line.lineno + context, len(lines)) - before = lines[start_context:error_line.lineno] - after = lines[error_line.lineno:end_context] - - error_prefix = max(len(str(error_line.lineno)), 4) + len(' | ') - error_length = max(arrow_position - error_line.start, 0) - error_arrow = ('-' * (error_prefix + error_length)) + '^' - return '\n'.join(chain( - [format_error_line(line) for line in before], - [error_arrow], - [format_error_line(line) for line in after], - )) diff --git a/otree/common.py b/otree/common.py index 9e20ed22d..6f4a954ed 100644 --- a/otree/common.py +++ b/otree/common.py @@ -1,35 +1,41 @@ -"""oTree Public API utilities""" - +import contextlib +import hashlib +import importlib.util +import itertools +import logging +import random +import re +import string +import threading +from collections import OrderedDict +from importlib import import_module +from typing import Iterable, ItemsView, Tuple +from django.apps import apps +from django.conf import settings +from django.db import connection +from django.db import transaction +from django.utils.safestring import mark_safe +from huey.contrib.djhuey import HUEY +import urllib +import os +import model_utils.tracker import json -from decimal import Decimal from django.conf import settings -from django.utils import formats, numberformat from django.utils.safestring import mark_safe -from otree.currency import Currency, RealWorldCurrency -import six - - -# ============================================================================= -# MONKEY PATCH - fix for https://github.com/oTree-org/otree-core/issues/387 -# ============================================================================= - -# Black Magic: The original number format of django used inside templates don't -# work if the currency code contains non-ascii characters. This ugly hack -# remplace the original number format and when you has a easy_money instance -# simple use the old unicode casting. -_original_number_format = numberformat.format +# until 2016, otree apps imported currency from otree.common. +from otree.currency import Currency, RealWorldCurrency, currency_range -def otree_number_format(number, *args, **kwargs): - if isinstance(number, (Currency, RealWorldCurrency)): - return six.text_type(number) - return _original_number_format(number, *args, **kwargs) - -numberformat.format = otree_number_format +# set to False if using runserver +USE_REDIS = bool(os.environ.get('OTREE_USE_REDIS', '')) +# these locks need to be here rather than views.abstract or views.participant +# because they need to be imported when the main thread runs. +start_link_thread_lock = threading.RLock() +wait_page_thread_lock = threading.RLock() class _CurrencyEncoder(json.JSONEncoder): @@ -46,26 +52,345 @@ def safe_json(obj): return mark_safe(json.dumps(obj, cls=_CurrencyEncoder)) -def currency_range(first, last, increment): - assert last >= first - if Currency(increment) == 0: - if settings.USE_POINTS: - setting_name = 'POINTS_DECIMAL_PLACES' - else: - setting_name = 'REAL_WORLD_CURRENCY_DECIMAL_PLACES' - raise ValueError( - ('currency_range() step argument must not be zero. ' - 'Maybe your {} setting is ' - 'causing it to be rounded to 0.').format(setting_name) +def add_params_to_url(url, params): + url_parts = list(urllib.parse.urlparse(url)) + + # use OrderedDict because sometimes we want certain params at end + # for readability/consistency + query = OrderedDict(urllib.parse.parse_qsl(url_parts[4])) + query.update(params) + url_parts[4] = urllib.parse.urlencode(query) + return urllib.parse.urlunparse(url_parts) + + +SESSION_CODE_CHARSET = string.ascii_lowercase + string.digits + + +def random_chars(num_chars): + return ''.join(random.choice(SESSION_CODE_CHARSET) for _ in range(num_chars)) + + +def random_chars_8(): + return random_chars(8) + + +def random_chars_10(): + return random_chars(10) + + +def get_models_module(app_name): + '''shouldn't rely on app registry because the app might have been removed + from SESSION_CONFIGS, especially if the session was created a long time + ago and you want to export it''' + return import_module(f'{app_name}.models') + + +def get_bots_module(app_name): + return import_module(f'{app_name}.tests') + + +def get_pages_module(app_name): + '''views.py is deprecated, remove it soon''' + for module_name in ['pages', 'views']: + dotted = '{}.{}'.format(app_name, module_name) + if importlib.util.find_spec(dotted): + return import_module(dotted) + msg = 'No pages module found for app {}'.format(app_name) + raise ImportError(msg) + + +def get_app_constants(app_name): + return get_models_module(app_name).Constants + + +def get_dotted_name(Cls): + return '{}.{}'.format(Cls.__module__, Cls.__name__) + + +def get_app_label_from_import_path(import_path): + '''App authors must not override AppConfig.label''' + return import_path.split('.')[-2] + + +def get_app_label_from_name(app_name): + '''App authors must not override AppConfig.label''' + return app_name.split('.')[-1] + + +def expand_choice_tuples(choices): + '''allows the programmer to define choices as a list of values rather + than (value, display_value) + + ''' + if not choices: + return None + if not isinstance(choices[0], (list, tuple)): + choices = [(value, value) for value in choices] + return choices + + +def missing_db_tables(): + """Try to execute a simple select * for every model registered + """ + + # need to normalize to lowercase because MySQL converts DB names to lower + expected_table_names_dict = { + Model._meta.db_table.lower(): '{}.{}'.format( + Model._meta.app_label, Model.__name__ + ) + for Model in apps.get_models() + } + + expected_table_names = set(expected_table_names_dict.keys()) + + # again, normalize to lowercase + actual_table_names = set( + tn.lower() for tn in connection.introspection.table_names() + ) + + missing_table_names = expected_table_names - actual_table_names + + # don't use the SQL table name because it could be uppercase or lowercase, + # depending on whether it's MySQL + return [ + expected_table_names_dict[missing_table] + for missing_table in missing_table_names + ] + + +def make_hash(s): + s += settings.SECRET_KEY + return hashlib.sha224(s.encode()).hexdigest()[:8] + + +def get_admin_secret_code(): + s = settings.SECRET_KEY + return hashlib.sha224(s.encode()).hexdigest()[:8] + + +def validate_alphanumeric(identifier, identifier_description): + if re.match(r'^[a-zA-Z0-9_]+$', identifier): + return identifier + raise ValueError( + '{} "{}" can only contain letters, numbers, ' + 'and underscores (_)'.format(identifier_description, identifier) + ) + + +EMPTY_ADMIN_USERNAME_MSG = 'settings.ADMIN_USERNAME is empty' +EMPTY_ADMIN_PASSWORD_MSG = 'settings.ADMIN_PASSWORD is empty' + + +def ensure_superuser_exists(*args, **kwargs) -> str: + """ + Creates our default superuser. + If it fails, it returns a failure message + """ + username = settings.ADMIN_USERNAME + password = settings.ADMIN_PASSWORD + if not username: + return EMPTY_ADMIN_USERNAME_MSG + if not password: + return EMPTY_ADMIN_PASSWORD_MSG + from django.contrib.auth.models import User + + if User.objects.filter(username=username).exists(): + # msg = 'Default superuser exists.' + # logger.info(msg) + return '' + User.objects.create_superuser(username, email='', password=password) + msg = 'Created superuser "{}"'.format(username) + logging.getLogger('otree').info(msg) + return '' + + +def release_any_stale_locks(): + ''' + Need to release locks in case the server was stopped abruptly, + and the 'finally' block in each lock did not execute + ''' + from otree.models_concrete import ParticipantLockModel + + for LockModel in [ParticipantLockModel]: + try: + LockModel.objects.filter(locked=True).update(locked=False) + except: + # if server is started before DB is synced, + # this will raise + # django.db.utils.OperationalError: no such table: + # otree_globallockmodel + # we can ignore that because we just want to make sure there are no + # active locks + pass + + +def get_redis_conn(): + '''reuse Huey Redis connection''' + return HUEY.storage.conn + + +def has_group_by_arrival_time(app_name): + page_sequence = get_pages_module(app_name).page_sequence + if len(page_sequence) == 0: + return False + # it might not be a waitpage + return getattr(page_sequence[0], 'group_by_arrival_time', False) + + +def is_sqlite(): + return settings.DATABASES['default']['ENGINE'].endswith('sqlite3') + + +@contextlib.contextmanager +def transaction_except_for_sqlite(): + ''' + On SQLite, transactions tend to result in "database locked" errors. + So, skip the transaction on SQLite, to allow local dev. + Should only be used if omitting the transaction rarely causes problems. + ''' + if is_sqlite(): + yield + else: + with transaction.atomic(): + yield + + +class DebugTable: + def __init__(self, title, rows: Iterable[Tuple]): + self.title = title + self.rows = [] + for k, v in rows: + if isinstance(v, str): + v = v.strip().replace("\n", "") + v = mark_safe(v) + self.rows.append((k, v)) + + +class InvalidRoundError(ValueError): + pass + + +def in_round(ModelClass, round_number, **kwargs): + if round_number < 1: + msg = 'Invalid round number: {}'.format(round_number) + raise InvalidRoundError(msg) + try: + return ModelClass.objects.get(round_number=round_number, **kwargs) + except ModelClass.DoesNotExist: + raise InvalidRoundError( + 'No corresponding {} found with round_number={}'.format( + ModelClass.__name__, round_number + ) + ) from None + + +def in_rounds(ModelClass, first, last, **kwargs): + if first < 1: + msg = 'Invalid round number: {}'.format(first) + raise InvalidRoundError(msg) + qs = ModelClass.objects.filter( + round_number__range=(first, last), **kwargs + ).order_by('round_number') + + ret = list(qs) + num_results = len(ret) + expected_num_results = last - first + 1 + if num_results != expected_num_results: + raise InvalidRoundError( + 'Database contains {} records for rounds {}-{}, but expected {}'.format( + num_results, first, last, expected_num_results + ) ) + return ret + + +class BotError(AssertionError): + pass + + +def _get_all_configs(): + return [ + app + for app in apps.get_app_configs() + if app.name in settings.INSTALLED_OTREE_APPS + ] + + +def participant_start_url(code): + return '/InitializeParticipant/{}'.format(code) + + +def patch_migrations_module(): + from django.db.migrations.loader import MigrationLoader + + def migrations_module(*args, **kwargs): + # need to return None so that load_disk() considers it + # unmigrated, and False so that load_disk() considers it + # non-explicit + return None, False + + MigrationLoader.migrations_module = migrations_module + + +class ResponseForException(Exception): + ''' + allows us to show a much simplified traceback without + framework code. + ''' + + pass + + +def add_field_tracker(cls): + # need to do it here because FieldTracker doesnt work on abstract classes + _ft = model_utils.tracker.FieldTracker() + _ft.contribute_to_class(cls, '_ft') + # need to call this, because class_prepared has already been fired + # (it is currently executing) + _ft.finalize_class(sender=cls) + + +class FieldInstanceTrackerWithVarsNumpySupport( + model_utils.tracker.FieldInstanceTracker +): + def has_changed(self, field): + try: + return super().has_changed(field) + except ValueError as exc: + # we just assume it's always changed, so then we always save that field. + # it could be "The truth value of an array..." or "...of a DataFrame" + if 'The truth value of' in str(exc): + return True + raise + + +class FieldTrackerWithVarsSupport(model_utils.tracker.FieldTracker): + tracker_class = FieldInstanceTrackerWithVarsNumpySupport + + +def _group_by_rank(ranked_list, players_per_group): + ppg = players_per_group + players = ranked_list + group_matrix = [] + for i in range(0, len(players), ppg): + group_matrix.append(players[i : i + ppg]) + return group_matrix + - assert increment > 0 # not negative +def _group_randomly(group_matrix, fixed_id_in_group=False): + """Random Uniform distribution of players in every group""" - values = [] - current_value = Currency(first) + players = list(itertools.chain.from_iterable(group_matrix)) + sizes = [len(group) for group in group_matrix] + if sizes and any(size != sizes[0] for size in sizes): + raise ValueError('This algorithm does not work with unevenly sized groups') + players_per_group = sizes[0] - while True: - if current_value > last: - return values - values.append(current_value) - current_value += increment + if fixed_id_in_group: + group_matrix = [list(col) for col in zip(*group_matrix)] + for column in group_matrix: + random.shuffle(column) + return list(zip(*group_matrix)) + else: + random.shuffle(players) + return _group_by_rank(players, players_per_group) diff --git a/otree/common_internal.py b/otree/common_internal.py deleted file mode 100644 index 4f2d3eb20..000000000 --- a/otree/common_internal.py +++ /dev/null @@ -1,315 +0,0 @@ -import contextlib -import hashlib -import importlib.util -import logging -import random -import re -import string -import sys -import threading -import uuid -from collections import OrderedDict -from importlib import import_module -from io import StringIO -from channels.layers import get_channel_layer -import otree.channels.utils as channel_utils -import six -from django.apps import apps -from django.conf import settings -from django.db import connection -from django.db import transaction -from django.http import HttpResponseRedirect -from django.urls import reverse -from django.utils.safestring import mark_safe -from huey.contrib.djhuey import HUEY -from six.moves import urllib -from django.shortcuts import redirect - -# set to False if using runserver -USE_REDIS = True - -# these locks need to be here rather than views.abstract or views.participant -# because they need to be imported when the main thread runs. -start_link_thread_lock = threading.RLock() -wait_page_thread_lock = threading.RLock() - - -def add_params_to_url(url, params): - url_parts = list(urllib.parse.urlparse(url)) - - # use OrderedDict because sometimes we want certain params at end - # for readability/consistency - query = OrderedDict(urllib.parse.parse_qsl(url_parts[4])) - query.update(params) - url_parts[4] = urllib.parse.urlencode(query) - return urllib.parse.urlunparse(url_parts) - - -SESSION_CODE_CHARSET = string.ascii_lowercase + string.digits - - -def random_chars(num_chars): - return ''.join(random.choice(SESSION_CODE_CHARSET) for _ in range(num_chars)) - - -def random_chars_8(): - return random_chars(8) - - -def random_chars_10(): - return random_chars(10) - - -def get_models_module(app_name): - '''shouldn't rely on app registry because the app might have been removed - from SESSION_CONFIGS, especially if the session was created a long time - ago and you want to export it''' - module_name = '{}.models'.format(app_name) - return import_module(module_name) - - -def get_bots_module(app_name): - for module_name in ['tests', 'bots']: - dotted = '{}.{}'.format(app_name, module_name) - if importlib.util.find_spec(dotted): - return import_module(dotted) - raise ImportError('No tests module found for app {}'.format(app_name)) - - -def get_pages_module(app_name): - for module_name in ['pages', 'views']: - dotted = '{}.{}'.format(app_name, module_name) - if importlib.util.find_spec(dotted): - return import_module(dotted) - raise ImportError('No pages module found for app {}'.format(app_name)) - - -def get_app_constants(app_name): - '''Return the ``Constants`` object of a app defined in the models.py file. - - Example:: - - >>> from otree.common_internal import get_app_constants - >>> get_app_constants('demo') - - - ''' - return get_models_module(app_name).Constants - - -def get_dotted_name(Cls): - return '{}.{}'.format(Cls.__module__, Cls.__name__) - - -def get_app_label_from_import_path(import_path): - '''App authors must not override AppConfig.label''' - return import_path.split('.')[-2] - - -def get_app_label_from_name(app_name): - '''App authors must not override AppConfig.label''' - return app_name.split('.')[-1] - - -def expand_choice_tuples(choices): - '''allows the programmer to define choices as a list of values rather - than (value, display_value) - - ''' - if not choices: - return None - if not isinstance(choices[0], (list, tuple)): - choices = [(value, value) for value in choices] - return choices - - -def missing_db_tables(): - """Try to execute a simple select * for every model registered - """ - - # need to normalize to lowercase because MySQL converts DB names to lower - expected_table_names_dict = { - Model._meta.db_table.lower(): '{}.{}'.format(Model._meta.app_label, Model.__name__) - for Model in apps.get_models() - } - - expected_table_names = set(expected_table_names_dict.keys()) - - # again, normalize to lowercase - actual_table_names = set( - tn.lower() for tn in connection.introspection.table_names()) - - missing_table_names = expected_table_names - actual_table_names - - # don't use the SQL table name because it could be uppercase or lowercase, - # depending on whether it's MySQL - return [expected_table_names_dict[missing_table] - for missing_table in missing_table_names] - - -def make_hash(s): - s += settings.SECRET_KEY - return hashlib.sha224(s.encode()).hexdigest()[:8] - - -def get_admin_secret_code(): - s = settings.SECRET_KEY - return hashlib.sha224(s.encode()).hexdigest()[:8] - -def validate_alphanumeric(identifier, identifier_description): - if re.match(r'^[a-zA-Z0-9_]+$', identifier): - return identifier - raise ValueError( - '{} "{}" can only contain letters, numbers, ' - 'and underscores (_)'.format( - identifier_description, - identifier - ) - ) - - -EMPTY_ADMIN_USERNAME_MSG = 'settings.ADMIN_USERNAME is empty' -EMPTY_ADMIN_PASSWORD_MSG = 'settings.ADMIN_PASSWORD is empty' - -def ensure_superuser_exists(*args, **kwargs) -> str: - """ - Creates our default superuser. - If it fails, it returns a failure message - """ - username = settings.ADMIN_USERNAME - password = settings.ADMIN_PASSWORD - if not username: - return EMPTY_ADMIN_USERNAME_MSG - if not password: - return EMPTY_ADMIN_PASSWORD_MSG - from django.contrib.auth.models import User - if User.objects.filter(username=username).exists(): - # msg = 'Default superuser exists.' - # logger.info(msg) - return '' - User.objects.create_superuser(username, email='', password=password) - msg = 'Created superuser "{}"'.format(username) - logging.getLogger('otree').info(msg) - return '' - - -def release_any_stale_locks(): - ''' - Need to release locks in case the server was stopped abruptly, - and the 'finally' block in each lock did not execute - ''' - from otree.models_concrete import ParticipantLockModel - for LockModel in [ParticipantLockModel]: - try: - LockModel.objects.filter(locked=True).update(locked=False) - except: - # if server is started before DB is synced, - # this will raise - # django.db.utils.OperationalError: no such table: - # otree_globallockmodel - # we can ignore that because we just want to make sure there are no - # active locks - pass - - -def get_redis_conn(): - '''reuse Huey Redis connection''' - return HUEY.storage.conn - -def has_group_by_arrival_time(app_name): - page_sequence = get_pages_module(app_name).page_sequence - if len(page_sequence) == 0: - return False - return getattr(page_sequence[0], 'group_by_arrival_time', False) - - -@contextlib.contextmanager -def transaction_except_for_sqlite(): - ''' - On SQLite, transactions tend to result in "database locked" errors. - So, skip the transaction on SQLite, to allow local dev. - Should only be used if omitting the transaction rarely causes problems. - ''' - if settings.DATABASES['default']['ENGINE'].endswith('sqlite3'): - yield - else: - with transaction.atomic(): - yield - - -class DebugTable(object): - def __init__(self, title, rows): - self.title = title - self.rows = [] - for k, v in rows: - if isinstance(v, six.string_types): - v = v.strip().replace("\n", "") - v = mark_safe(v) - self.rows.append((k, v)) - - -class InvalidRoundError(ValueError): - pass - - -def in_round(ModelClass, round_number, **kwargs): - if round_number < 1: - raise InvalidRoundError('Invalid round number: {}'.format(round_number)) - try: - return ModelClass.objects.get(round_number=round_number, **kwargs) - except ModelClass.DoesNotExist: - raise InvalidRoundError( - 'No corresponding {} found with round_number={}'.format( - ModelClass.__name__, round_number)) from None - - -def in_rounds(ModelClass, first, last, **kwargs): - if first < 1: - raise InvalidRoundError('Invalid round number: {}'.format(first)) - qs = ModelClass.objects.filter( - round_number__range=(first, last), - **kwargs - ).order_by('round_number') - - ret = list(qs) - num_results = len(ret) - expected_num_results = last-first+1 - if num_results != expected_num_results: - raise InvalidRoundError( - 'Database contains {} records for rounds {}-{}, but expected {}'.format( - num_results, first, last, expected_num_results)) - return ret - - -class BotError(AssertionError): - pass - - -def _get_all_configs(): - return [ - app - for app in apps.get_app_configs() - if app.name in settings.INSTALLED_OTREE_APPS] - - -def participant_start_url(code): - return '/InitializeParticipant/{}'.format(code) - - -def patch_migrations_module(): - from django.db.migrations.loader import MigrationLoader - def migrations_module(*args, **kwargs): - # need to return None so that load_disk() considers it - # unmigrated, and False so that load_disk() considers it - # non-explicit - return None, False - MigrationLoader.migrations_module = migrations_module - - -class ResponseForException(Exception): - ''' - allows us to show a much simplified traceback without - framework code. - ''' - pass - diff --git a/otree/constants.py b/otree/constants.py index 72bc4c98c..6a3332ad5 100644 --- a/otree/constants.py +++ b/otree/constants.py @@ -1,6 +1,10 @@ +from django.utils.translation import ugettext_lazy + + class MustCopyError(Exception): pass + def _raise_must_copy(*args, **kwargs): raise MustCopyError( "Cannot modify a list that originated in Constants. " @@ -11,6 +15,7 @@ def _raise_must_copy(*args, **kwargs): "This is to prevent accidentally modifying the original list. " ) + class ConstantsList(list): __setitem__ = _raise_must_copy @@ -42,3 +47,13 @@ def __new__(mcs, name, bases, attrs): class BaseConstants(metaclass=BaseConstantsMeta): pass + + +get_param_truth_value = '1' +admin_secret_code = 'admin_secret_code' +timeout_happened = 'timeout_happened' +participant_label = 'participant_label' +wait_page_http_header = 'oTree-Wait-Page' +redisplay_with_errors_http_header = 'oTree-Redisplay-With-Errors' +field_required_msg = ugettext_lazy('This field is required.') +AUTO_NAME_BOTS_EXPORT_FOLDER = 'auto_name' \ No newline at end of file diff --git a/otree/constants_internal.py b/otree/constants_internal.py deleted file mode 100644 index 065f68e7b..000000000 --- a/otree/constants_internal.py +++ /dev/null @@ -1,58 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - - -# ============================================================================= -# IMPORTS -# ============================================================================= - -from django.utils.translation import ugettext_lazy - -# ============================================================================= -# CONSTANTS -# ============================================================================= - -SubsessionClass = 'SubsessionClass' -GroupClass = 'GroupClass' -PlayerClass = 'PlayerClass' -UserClass = 'UserClass' - -group_id = 'group_id' - -user_code = 'user_code' -subsession_code = 'subsession_code' -subsession_code_obfuscated = 'exp_code' - -nickname = 'nickname' - -completed_views = 'completed_views' - -form_invalid = 'form_invalid' -precondition = 'precondition' -mturk_worker_id = 'mturk_worker_id' -debug_values_built_in = 'debug_values_built_in' -debug_values = 'debug_values' -get_param_truth_value = '1' - -admin_secret_code = 'admin_secret_code' -timeout_seconds = 'timeout_seconds' -timeout_happened = 'timeout_happened' -check_auto_submit = 'check_auto_submit' -page_expiration_times = 'page_timeouts' -participant_label = 'participant_label' -participant_id = 'participant_id' -participant_code = 'participant_code' -session_id = 'session_id' -session_code = 'session_code' -wait_page_http_header = 'oTree-Wait-Page' -redisplay_with_errors_http_header = 'oTree-Redisplay-With-Errors' -user_type = 'user_type' -user_type_participant = 'p' -success = True -failure = False - - -# Translators: for required form fields -field_required_msg = ugettext_lazy('This field is required.') - -AUTO_NAME_BOTS_EXPORT_FOLDER = 'auto_name' \ No newline at end of file diff --git a/otree/currency/__init__.py b/otree/currency/__init__.py index e5b6b3e1c..56566004e 100644 --- a/otree/currency/__init__.py +++ b/otree/currency/__init__.py @@ -1,17 +1,19 @@ from django.utils import numberformat, formats from otree.currency.locale import CURRENCY_SYMBOLS, get_currency_format -from six import __init__ + _original_number_format = numberformat.format + def otree_number_format(number, *args, **kwargs): if isinstance(number, BaseCurrency): - return six.text_type(number) + return str(number) return _original_number_format(number, *args, **kwargs) + from decimal import Decimal, ROUND_HALF_UP -import six + from django.conf import settings from django.utils import formats, numberformat from django.utils.translation import ungettext @@ -57,14 +59,17 @@ def _prepare_operand(self, other): def _make_binary_operator(name): method = getattr(Decimal, name, None) + def binary_function(self, other, context=None): other = _prepare_operand(self, other) return self.__class__(method(self, other)) + return binary_function # Data class + class BaseCurrency(Decimal): # what's this for?? can't money have any # of decimal places? @@ -112,17 +117,17 @@ def _format_currency(cls, number): lc, LO = LANGUAGE_CODE.split('-') else: lc, LO = LANGUAGE_CODE, '' - return format_currency(number, lc=lc, LO=LO, - CUR=settings.REAL_WORLD_CURRENCY_CODE + return format_currency( + number, lc=lc, LO=LO, CUR=settings.REAL_WORLD_CURRENCY_CODE ) def __format__(self, format_spec): if format_spec in {'', 's'}: - formatted = six.text_type(self) + formatted = str(self) else: formatted = format(Decimal(self), format_spec) - if isinstance(format_spec, six.binary_type): + if isinstance(format_spec, bytes): return formatted.encode('utf-8') else: return formatted @@ -133,7 +138,7 @@ def __repr__(self): def __eq__(self, other): if isinstance(other, BaseCurrency): return Decimal.__eq__(self, other) - elif isinstance(other, six.integer_types + (float, Decimal)): + elif isinstance(other, (int, float, Decimal)): return Decimal.__eq__(self, self._sanitize(other)) else: return False @@ -172,8 +177,11 @@ def __pow__(self, other, modulo=None): __rpow__ = _make_binary_operator('__rpow__') def deconstruct(self): - return '{}.{}'.format(self.__module__, self.__class__.__name__), \ - [Decimal.__str__(self)], {} + return ( + '{}.{}'.format(self.__module__, self.__class__.__name__), + [Decimal.__str__(self)], + {}, + ) @classmethod def get_num_decimal_places(cls): @@ -181,7 +189,6 @@ def get_num_decimal_places(cls): class Currency(BaseCurrency): - @classmethod def get_num_decimal_places(cls): if settings.USE_POINTS: @@ -192,8 +199,8 @@ def get_num_decimal_places(cls): def to_real_world_currency(self, session): if settings.USE_POINTS: return RealWorldCurrency( - float(self) * - session.config['real_world_currency_per_point']) + float(self) * session.config['real_world_currency_per_point'] + ) else: return self @@ -203,8 +210,7 @@ def _format_currency(cls, number): formatted_number = formats.number_format(number) if hasattr(settings, 'POINTS_CUSTOM_NAME'): - return '{} {}'.format( - formatted_number, settings.POINTS_CUSTOM_NAME) + return '{} {}'.format(formatted_number, settings.POINTS_CUSTOM_NAME) # Translators: display a number of points, # like "1 point", "2 points", ... @@ -215,8 +221,7 @@ def _format_currency(cls, number): # and msgstr[1] is plural # the {} represents the number; # don't forget to include it in your translation - return ungettext('{} point', '{} points', number).format( - formatted_number) + return ungettext('{} point', '{} points', number).format(formatted_number) else: return super()._format_currency(number) @@ -231,8 +236,10 @@ def to_real_world_currency(self, session): def get_num_decimal_places(cls): return settings.REAL_WORLD_CURRENCY_DECIMAL_PLACES + # Utils + def to_dec(value): return Decimal(value) if isinstance(value, Currency) else value @@ -245,4 +252,31 @@ def format_currency(number, lc, LO, CUR): retval = c_format.replace('¤', symbol).replace('#', formatted_abs) if number < 0: retval = '-{}'.format(retval) - return retval \ No newline at end of file + return retval + + +def currency_range(first, last, increment): + assert last >= first + if Currency(increment) == 0: + if settings.USE_POINTS: + setting_name = 'POINTS_DECIMAL_PLACES' + else: + setting_name = 'REAL_WORLD_CURRENCY_DECIMAL_PLACES' + raise ValueError( + ( + 'currency_range() step argument must not be zero. ' + 'Maybe your {} setting is ' + 'causing it to be rounded to 0.' + ).format(setting_name) + ) + + assert increment > 0 # not negative + + values = [] + current_value = Currency(first) + + while True: + if current_value > last: + return values + values.append(current_value) + current_value += increment diff --git a/otree/currency/locale.py b/otree/currency/locale.py index b61e2a39f..465547517 100644 --- a/otree/currency/locale.py +++ b/otree/currency/locale.py @@ -59,6 +59,15 @@ def get_currency_format(lc: str, LO: str, CUR: str) -> str: return '₹ #' if CUR == 'SGD': return '$#' + # override for CNY/JPY/KRW, otherwise it would be written as 원10 + # need to use the chinese character because that's already what's used in + # form inputs + if CUR == 'CNY': + return '#元' + if CUR == 'JPY': + return '#円' + if CUR == 'KRW': + return '#원' return '¤#' if lc == 'zh': @@ -372,4 +381,4 @@ def get_currency_format(lc: str, LO: str, CUR: str) -> str: TWD en_US NT$1.00 ZAR en_US R1.00 -''' \ No newline at end of file +''' diff --git a/otree/db/idmap.py b/otree/db/idmap.py index 91c8855e4..105a0e1b3 100644 --- a/otree/db/idmap.py +++ b/otree/db/idmap.py @@ -66,13 +66,8 @@ def get_cached_instance(cls, *args, **kwargs): if is_active(): return super().get_cached_instance(*args, **kwargs) -CLASSES_TO_SAVE = { - 'Session', - 'Participant', - 'Subsession', - 'Group', - 'Player' -} + +CLASSES_TO_SAVE = {'Session', 'Participant', 'Subsession', 'Group', 'Player'} def _get_save_objects_model_instances(): diff --git a/otree/db/models.py b/otree/db/models.py index 7174d769c..74504e0fe 100644 --- a/otree/db/models.py +++ b/otree/db/models.py @@ -1,34 +1,21 @@ -from django.db import models -from django.db.models.fields import related +import logging +from decimal import Decimal + from django.core import exceptions +from django.db import models from django.utils.translation import ugettext_lazy -from django.conf import settings -from django.apps import apps -from decimal import Decimal -from otree.currency import ( - Currency, RealWorldCurrency -) -import logging from idmap.models import IdMapModelBase -from .idmap import IdMapModel -import otree.common -from otree.common_internal import ( - expand_choice_tuples, get_app_label_from_import_path) -from otree.constants_internal import field_required_msg -from otree_save_the_change.mixins import SaveTheChange - -# this is imported from other modules -from .serializedfields import _PickleField +from otree.common import expand_choice_tuples, get_app_label_from_import_path +from otree.constants import field_required_msg +from otree.currency import Currency, RealWorldCurrency +from .idmap import IdMapModel +from django.forms import widgets as dj_widgets +from .serializedfields import _PickleField # noqa logger = logging.getLogger(__name__) -class _JSONField(models.TextField): - '''just keeping around so that Migrations don't crash''' - pass - - class OTreeModelBase(IdMapModelBase): def __new__(mcs, name, bases, attrs): meta = attrs.get("Meta") @@ -47,17 +34,10 @@ def __new__(mcs, name, bases, attrs): meta.use_strong_refs = True attrs["Meta"] = meta - - new_class = super().__new__(mcs, name, bases, attrs) if not hasattr(new_class._meta, 'use_strong_refs'): new_class._meta.use_strong_refs = False - - # 2015-12-22: this probably doesn't work anymore, - # since we moved _choices to views.py - # but we can tell users they can define FOO_choices in models.py, - # and then call it in the equivalent method in views.py for f in new_class._meta.fields: if hasattr(new_class, f.name + '_choices'): attr_name = 'get_%s_display' % f.name @@ -69,10 +49,6 @@ def __new__(mcs, name, bases, attrs): return new_class -def get_model(*args, **kwargs): - return apps.get_model(*args, **kwargs) - - def make_get_display(field): def get_FIELD_display(self): choices = getattr(self, field.name + '_choices')() @@ -82,15 +58,13 @@ def get_FIELD_display(self): return get_FIELD_display -class OTreeModel(SaveTheChange, IdMapModel, metaclass=OTreeModelBase): - +class OTreeModel(IdMapModel, metaclass=OTreeModelBase): class Meta: abstract = True def __repr__(self): return '<{} pk={}>'.format(self.__class__.__name__, self.pk) - _is_frozen = False NoneType = type(None) @@ -113,7 +87,6 @@ def __repr__(self): # used by Prefetch. '_ordered_players', '_is_frozen', - # extras on 2018-11-24 'id', '_changed_fields', @@ -141,27 +114,30 @@ def __setattr__(self, field_name: str, value): if field_type_name in self._setattr_datatypes: allowed_types = self._setattr_datatypes[field_type_name] if ( - isinstance(value, allowed_types) - # numpy uses its own datatypes, e.g. numpy._bool, - # which doesn't inherit from python bool. - or 'numpy' in str(type(value)) - # 2018-07-18: - # have an exception for the bug in the 'quiz' sample game - # after a while, we can remove this - or field_name == 'question_id' + isinstance(value, allowed_types) + # numpy uses its own datatypes, e.g. numpy._bool, + # which doesn't inherit from python bool. + or 'numpy' in str(type(value)) + # 2018-07-18: + # have an exception for the bug in the 'quiz' sample game + # after a while, we can remove this + or field_name == 'question_id' ): pass else: - msg = ( - '{} should be set to {}, not {}.' - ).format(field_type_name, allowed_types[0].__name__, type(value).__name__) + msg = ('{} should be set to {}, not {}.').format( + field_type_name, + allowed_types[0].__name__, + type(value).__name__, + ) raise TypeError(msg) elif ( - field_name in self._setattr_attributes or - field_name in self._setattr_whitelist or - # idmap uses _group_cache, _subsession_cache, - # _prefetched_objects_cache, etc - field_name.endswith('_cache') + field_name in self._setattr_attributes + or field_name in self._setattr_whitelist + or + # idmap uses _group_cache, _subsession_cache, + # _prefetched_objects_cache, etc + field_name.endswith('_cache') ): # django sometimes reassigns to non-field attributes that # were set before the class was frozen, such as @@ -169,9 +145,9 @@ def __setattr__(self, field_name: str, value): # or assigning to a property like Player.payoff pass else: - msg = ( - '{} has no field "{}".' - ).format(self.__class__.__name__, field_name) + msg = ('{} has no field "{}".').format( + self.__class__.__name__, field_name + ) raise AttributeError(msg) self._super_setattr(field_name, value) @@ -179,6 +155,12 @@ def __setattr__(self, field_name: str, value): # super() is a bit slower but only gets run during __init__ super().__setattr__(field_name, value) + def save(self, *args, **kwargs): + # Use with FieldTracker + if self.pk and hasattr(self, '_ft') and 'update_fields' not in kwargs: + kwargs['update_fields'] = [k for k in self._ft.changed()] + super().save(*args, **kwargs) + Model = OTreeModel @@ -195,18 +177,18 @@ def fix_choices_arg(kwargs): kwargs['choices'] = choices -class _OtreeModelFieldMixin(object): - +class _OtreeModelFieldMixin: def __init__( - self, - *, - initial=None, - label=None, - min=None, - max=None, - doc='', - widget=None, - **kwargs): + self, + *, + initial=None, + label=None, + min=None, + max=None, + doc='', + widget=None, + **kwargs, + ): self.widget = widget self.doc = doc @@ -238,23 +220,27 @@ def __init__( super().__init__(**kwargs) + def formfield(self, **kwargs): + if self.widget: + kwargs['widget'] = self.widget + return super().formfield(**kwargs) + class _OtreeNumericFieldMixin(_OtreeModelFieldMixin): auto_submit_default = 0 -class BaseCurrencyField( - _OtreeNumericFieldMixin, models.DecimalField): - MONEY_CLASS = None # need to set in subclasses +class BaseCurrencyField(_OtreeNumericFieldMixin, models.DecimalField): + + MONEY_CLASS = None # need to set in subclasses def __init__(self, **kwargs): # i think it's sufficient just to store a high number; # this needs to be higher than decimal_places decimal_places = self.MONEY_CLASS.get_num_decimal_places() # where does this come from? - max_digits=12 - super().__init__( - max_digits=max_digits, decimal_places=decimal_places, **kwargs) + max_digits = 12 + super().__init__(max_digits=max_digits, decimal_places=decimal_places, **kwargs) def deconstruct(self): name, path, args, kwargs = super().deconstruct() @@ -275,7 +261,7 @@ def get_prep_value(self, value): return None return Decimal(self.to_python(value)) - def from_db_value(self, value, expression, connection, context): + def from_db_value(self, value, expression, connection): return self.to_python(value) @@ -285,6 +271,7 @@ class CurrencyField(BaseCurrencyField): def formfield(self, **kwargs): import otree.forms + defaults = { 'form_class': otree.forms.CurrencyField, 'choices_form_class': otree.forms.CurrencyChoiceField, @@ -299,6 +286,7 @@ class RealWorldCurrencyField(BaseCurrencyField): def formfield(self, **kwargs): import otree.forms + defaults = { 'form_class': otree.forms.RealWorldCurrencyField, 'choices_form_class': otree.forms.CurrencyChoiceField, @@ -307,67 +295,22 @@ def formfield(self, **kwargs): return super().formfield(**defaults) -class BooleanField(_OtreeModelFieldMixin, models.NullBooleanField): - # 2014/3/28: i just define the allowable choices on the model field, - # instead of customizing the widget since then it works for any widget - - def __init__(self, - *, - choices=None, - **kwargs): - # 2015-1-19: why is this here? isn't this the default behavior? - # 2013-1-26: ah, because we don't want the "----" (None) choice - if choices is None: - choices = ( - (True, ugettext_lazy('Yes')), - (False, ugettext_lazy('No')) - ) - - # We need to store whether blank is explicitly specified or not. If - # it's not specified explicitly (which will make it default to False) - # we need to special case validation logic in the form field if a - # checkbox input is used. - self._blank_is_explicit = 'blank' in kwargs - - super().__init__( - choices=choices, - **kwargs) - - # you cant override "blank" or you will destroy the migration system - self.allow_blank = bool(kwargs.get("blank")) +class BooleanField(_OtreeModelFieldMixin, models.BooleanField): + def __init__(self, **kwargs): + # usually checkbox is not required, except for consent forms. + widget = kwargs.get('widget') + if isinstance(widget, dj_widgets.CheckboxInput): + kwargs.setdefault('blank', True) + + # we need to set explicitly because otherwise the empty choice will show up as + # "Unknown" in a select widget. This makes it set to '-----------'. + kwargs.setdefault( + 'choices', [(True, ugettext_lazy('Yes')), (False, ugettext_lazy('No'))] + ) + super().__init__(**kwargs) auto_submit_default = False - def clean(self, value, model_instance): - if value is None and not self.allow_blank: - raise exceptions.ValidationError(field_required_msg) - return super().clean(value, model_instance) - - def formfield(self, *args, **kwargs): - from otree import widgets - - is_checkbox_widget = isinstance(self.widget, widgets.CheckboxInput) - if not self._blank_is_explicit and is_checkbox_widget: - kwargs.setdefault('required', False) - else: - # this use the allow_blank for the form fields - kwargs.setdefault('required', not self.allow_blank) - - return super().formfield(*args, **kwargs) - - -class AutoField(_OtreeModelFieldMixin, models.AutoField): - pass - - -class BigIntegerField( - _OtreeNumericFieldMixin, models.BigIntegerField): - auto_submit_default = 0 - - -class BinaryField(_OtreeModelFieldMixin, models.BinaryField): - pass - class StringField(_OtreeModelFieldMixin, models.CharField): ''' @@ -375,105 +318,85 @@ class StringField(_OtreeModelFieldMixin, models.CharField): causing any problems, even though Django recommends against that, but that's for forms on pages that get viewed multiple times ''' + def __init__( - self, - *, - # varchar max length doesn't affect performance or even storage - # size; it's just for validation. so, to be easy to use, - # there is no reason for oTree to set a short default length - # for CharFields. The main consideration is that MySQL cannot index - # varchar longer than 255 chars, but that is not relevant here - # because oTree only uses indexes for fields defined in otree-core, - # which have explicit max_lengths anyway. - max_length=10000, - **kwargs): - - super().__init__( - max_length=max_length, - **kwargs) + self, + *, + # varchar max length doesn't affect performance or even storage + # size; it's just for validation. so, to be easy to use, + # there is no reason for oTree to set a short default length + # for CharFields. The main consideration is that MySQL cannot index + # varchar longer than 255 chars, but that is not relevant here + # because oTree only uses indexes for fields defined in otree-core, + # which have explicit max_lengths anyway. + max_length=10000, + **kwargs, + ): + + super().__init__(max_length=max_length, **kwargs) auto_submit_default = '' -class DateField(_OtreeModelFieldMixin, models.DateField): - pass - - -class DateTimeField(_OtreeModelFieldMixin, models.DateTimeField): +class DecimalField(_OtreeNumericFieldMixin, models.DecimalField): pass -class DecimalField( - _OtreeNumericFieldMixin, - models.DecimalField): +class FloatField(_OtreeNumericFieldMixin, models.FloatField): pass -class EmailField(_OtreeModelFieldMixin, models.EmailField): +class IntegerField(_OtreeNumericFieldMixin, models.IntegerField): pass -class FileField(_OtreeModelFieldMixin, models.FileField): +class PositiveIntegerField(_OtreeNumericFieldMixin, models.PositiveIntegerField): pass -class FilePathField(_OtreeModelFieldMixin, models.FilePathField): - pass - - -class FloatField( - _OtreeNumericFieldMixin, - models.FloatField): - pass - - -class IntegerField( - _OtreeNumericFieldMixin, models.IntegerField): - pass - - -class GenericIPAddressField(_OtreeModelFieldMixin, - models.GenericIPAddressField): - pass - - -class PositiveIntegerField( - _OtreeNumericFieldMixin, - models.PositiveIntegerField): - pass - - -class PositiveSmallIntegerField( - _OtreeNumericFieldMixin, - models.PositiveSmallIntegerField): - pass - - -class SlugField(_OtreeModelFieldMixin, models.SlugField): - pass - +class LongStringField(_OtreeModelFieldMixin, models.TextField): + auto_submit_default = '' -class SmallIntegerField( - _OtreeNumericFieldMixin, models.SmallIntegerField): - pass +MSG_DEPRECATED_FIELD = """ +{FieldName} does not exist in oTree. +You should either replace it with one of oTree's field types, or import it from Django directly. +Note that Django model fields do not accept oTree-specific arguments like label= and widget=. +""".replace( + '\n', ' ' +) -class LongStringField(_OtreeModelFieldMixin, models.TextField): - auto_submit_default = '' +def make_deprecated_field(FieldName): + def DeprecatedField(*args, **kwargs): + # putting the msg on a separate line gives better tracebacks + raise Exception(MSG_DEPRECATED_FIELD.format(FieldName)) -class TimeField(_OtreeModelFieldMixin, models.TimeField): - pass + return DeprecatedField -class URLField(_OtreeModelFieldMixin, models.URLField): - pass +ManyToOneRel = make_deprecated_field("ManyToOneRel") +ManyToManyField = make_deprecated_field("ManyToManyField") +OneToOneField = make_deprecated_field("OneToOneField") +AutoField = make_deprecated_field("AutoField") +BigIntegerField = make_deprecated_field("BigIntegerField") +BinaryField = make_deprecated_field("BinaryField") +EmailField = make_deprecated_field("EmailField") +FileField = make_deprecated_field("FileField") +GenericIPAddressField = make_deprecated_field("GenericIPAddressField") +PositiveSmallIntegerField = make_deprecated_field("PositiveSmallIntegerField") +SlugField = make_deprecated_field("SlugField") +SmallIntegerField = make_deprecated_field("SmallIntegerField") +TimeField = make_deprecated_field("TimeField") +URLField = make_deprecated_field("URLField") +DateField = make_deprecated_field("DateField") +DateTimeField = make_deprecated_field("DateTimeField") CharField = StringField TextField = LongStringField +# keep ForeignKey around ForeignKey = models.ForeignKey -ManyToOneRel = related.ManyToOneRel -ManyToManyField = models.ManyToManyField -OneToOneField = models.OneToOneField -CASCADE = models.CASCADE \ No newline at end of file + + +CASCADE = models.CASCADE diff --git a/otree/db/serializedfields.py b/otree/db/serializedfields.py index 7e638ab9a..401894d65 100644 --- a/otree/db/serializedfields.py +++ b/otree/db/serializedfields.py @@ -32,25 +32,21 @@ def inspect_obj(obj): ) -def scan_for_model_instances(data): +def scan_for_model_instances(vars_dict: dict): ''' I don't know how to entirely block pickle from storing model instances, (I tried overriding __reduce__ but that interferes with deepcopy()) so this simple shallow scan should be good enough. ''' - # vars should always be a dict - if isinstance(data, dict): - for k, v in data.items(): - inspect_obj(k) - inspect_obj(v) - if isinstance(v, dict): - for kk, vv in v.items(): - inspect_obj(kk) - inspect_obj(vv) - elif isinstance(v, list): - for ele in v: - inspect_obj(ele) + for v in vars_dict.values(): + inspect_obj(v) + if isinstance(v, dict): + for vv in v.values(): + inspect_obj(vv) + elif isinstance(v, list): + for ele in v: + inspect_obj(ele) class _PickleField(models.TextField): @@ -80,9 +76,9 @@ def get_prep_value(self, value): value = serialize_to_string(value) return force_text(value) - def from_db_value(self, value, expression, connection, context): + def from_db_value(self, value, expression, connection): return self.to_python(value) def value_to_string(self, obj): value = self.value_from_object(obj) - return pickle.dumps(value) \ No newline at end of file + return pickle.dumps(value) diff --git a/otree/export.py b/otree/export.py index c2a3caaa0..10a40e4f6 100644 --- a/otree/export.py +++ b/otree/export.py @@ -1,35 +1,30 @@ -from otree.common import Currency, RealWorldCurrency -from django.db.models import BinaryField, ForeignKey -from importlib import import_module -import datetime -import inspect -import otree import collections -import six -from django.utils.encoding import force_text +import csv +import logging +import numbers from collections import OrderedDict -from django.conf import settings -from django.db.models import Max, Count, Sum from decimal import Decimal -import otree.constants_internal +from importlib import import_module + +import xlsxwriter +from django.db.models import BinaryField, ForeignKey +from django.db.models import Max +from django.utils.encoding import force_text + +import otree +from otree.currency import Currency, RealWorldCurrency +from otree.common import get_models_module +from otree.models.group import BaseGroup from otree.models.participant import Participant +from otree.models.player import BasePlayer from otree.models.session import Session from otree.models.subsession import BaseSubsession -from otree.models.group import BaseGroup -from otree.models.player import BasePlayer +from otree.models_concrete import PageCompletion from otree.session import SessionConfig -from otree.models_concrete import ( - PageCompletion) -from otree.common_internal import get_models_module -import numbers - -import csv -import xlsxwriter -import logging - logger = logging.getLogger(__name__) + def inspect_field_names(Model): # filter out BinaryField, because it's not useful for CSV export or # live results. could be very big, and causes problems with utf-8 export @@ -83,7 +78,6 @@ def _get_table_fields(Model, for_export=False): # even rows for different rounds. #'_round_number', '_current_page_name', - 'ip_address', 'time_started', 'visited', 'mturk_worker_id', @@ -106,10 +100,11 @@ def _get_table_fields(Model, for_export=False): if issubclass(Model, BasePlayer): subclass_fields = [ - f for f in inspect_field_names(Model) + f + for f in inspect_field_names(Model) if f not in inspect_field_names(BasePlayer) and f not in ['id', 'group_id', 'subsession_id'] - ] + ] if for_export: return ['id_in_group'] + subclass_fields + ['payoff'] @@ -118,19 +113,20 @@ def _get_table_fields(Model, for_export=False): if issubclass(Model, BaseGroup): subclass_fields = [ - f for f in inspect_field_names(Model) + f + for f in inspect_field_names(Model) if f not in inspect_field_names(BaseGroup) and f not in ['id', 'subsession_id'] - ] + ] return ['id_in_subsession'] + subclass_fields if issubclass(Model, BaseSubsession): subclass_fields = [ - f for f in inspect_field_names(Model) - if f not in inspect_field_names(BaseGroup) - and f != 'id' - ] + f + for f in inspect_field_names(Model) + if f not in inspect_field_names(BaseGroup) and f != 'id' + ] return ['round_number'] + subclass_fields @@ -162,6 +158,7 @@ def sanitize_for_live_update(value): return value[:MAX_LENGTH] + '...' return value + def get_payoff_plus_participation_fee(session, participant_values_dict): payoff = Currency(participant_values_dict['payoff']) return session._get_payoff_plus_participation_fee(payoff) @@ -187,18 +184,21 @@ def get_rows_for_wide_csv(): participant_fields = get_field_names_for_csv(Participant) participant_fields.append('payoff_plus_participation_fee') header_row = ['participant.{}'.format(fname) for fname in participant_fields] - header_row += ['session.{}'.format(fname) - for fname in session_fields] - header_row += ['session.config.{}'.format(fname) - for fname in session_config_fields] + header_row += ['session.{}'.format(fname) for fname in session_fields] + header_row += ['session.config.{}'.format(fname) for fname in session_config_fields] rows = [header_row] for participant in participants: session = session_cache[participant['session_id']] - participant['payoff_plus_participation_fee'] = get_payoff_plus_participation_fee(session, participant) + participant[ + 'payoff_plus_participation_fee' + ] = get_payoff_plus_participation_fee(session, participant) row = [sanitize_for_csv(participant[fname]) for fname in participant_fields] row += [sanitize_for_csv(getattr(session, fname)) for fname in session_fields] - row += [sanitize_for_csv(session.config.get(fname)) for fname in session_config_fields] + row += [ + sanitize_for_csv(session.config.get(fname)) + for fname in session_config_fields + ] rows.append(row) # heuristic to get the most relevant order of apps @@ -217,8 +217,8 @@ def get_rows_for_wide_csv(): app_names_with_data.add(app_name) apps_not_in_popular_sequence = [ - app for app in app_names_with_data - if app not in most_common_app_sequence] + app for app in app_names_with_data if app not in most_common_app_sequence + ] order_of_apps = list(most_common_app_sequence) + apps_not_in_popular_sequence @@ -249,7 +249,7 @@ def get_rows_for_wide_csv(): def get_rows_for_wide_csv_round(app_name, round_number, sessions): - models_module = otree.common_internal.get_models_module(app_name) + models_module = otree.common.get_models_module(app_name) Player = models_module.Player Group = models_module.Group Subsession = models_module.Subsession @@ -268,21 +268,27 @@ def get_rows_for_wide_csv_round(app_name, round_number, sessions): header_row = [] for model_name in model_order: for colname in columns_for_models[model_name]: - header_row.append('{}.{}.{}.{}'.format( - app_name, round_number, model_name, colname)) + header_row.append( + '{}.{}.{}.{}'.format(app_name, round_number, model_name, colname) + ) rows.append(header_row) empty_row = ['' for _ in range(len(header_row))] for session in sessions: subsession = Subsession.objects.filter( - session_id=session.id, round_number=round_number).values() + session_id=session.id, round_number=round_number + ).values() if not subsession: subsession_rows = [empty_row for _ in range(session.num_participants)] else: subsession = subsession[0] subsession_id = subsession['id'] - players = Player.objects.filter(subsession_id=subsession_id).order_by('id').values() + players = ( + Player.objects.filter(subsession_id=subsession_id) + .order_by('id') + .values() + ) if len(players) != session.num_participants: msg = ( @@ -290,10 +296,13 @@ def get_rows_for_wide_csv_round(app_name, round_number, sessions): "has {} players. The number of players in the subsession " "should always match the number of players in the session. " "Reset the database and examine your code." - ).format(session.code, session.num_participants, - round_number, - app_name, - len(players)) + ).format( + session.code, + session.num_participants, + round_number, + app_name, + len(players), + ) raise AssertionError(msg) subsession_rows = [] @@ -305,7 +314,8 @@ def get_rows_for_wide_csv_round(app_name, round_number, sessions): all_objects = { 'player': player, 'group': group_cache[player['group_id']], - 'subsession': subsession} + 'subsession': subsession, + } for model_name in model_order: for colname in columns_for_models[model_name]: @@ -319,7 +329,7 @@ def get_rows_for_wide_csv_round(app_name, round_number, sessions): def get_rows_for_csv(app_name): # need to use app_name and not app_label because the app might have been # removed from SESSION_CONFIGS - models_module = otree.common_internal.get_models_module(app_name) + models_module = otree.common.get_models_module(app_name) Player = models_module.Player Group = models_module.Group Subsession = models_module.Subsession @@ -337,19 +347,26 @@ def get_rows_for_csv(app_name): value_dicts = { 'group': {row['id']: row for row in Group.objects.values()}, 'subsession': {row['id']: row for row in Subsession.objects.values()}, - 'participant': {row['id']: row for row in - Participant.objects.filter( - id__in=participant_ids).values()}, - 'session': {row['id']: row for row in - Session.objects.filter(id__in=session_ids).values()} + 'participant': { + row['id']: row + for row in Participant.objects.filter(id__in=participant_ids).values() + }, + 'session': { + row['id']: row + for row in Session.objects.filter(id__in=session_ids).values() + }, } model_order = ['participant', 'player', 'group', 'subsession', 'session'] # header row - rows = [['{}.{}'.format(model_name, colname) - for model_name in model_order - for colname in columns_for_models[model_name]]] + rows = [ + [ + '{}.{}'.format(model_name, colname) + for model_name in model_order + for colname in columns_for_models[model_name] + ] + ] for player in players: # because player.payoff is a property @@ -384,9 +401,11 @@ def get_rows_for_live_update(subsession: BaseSubsession): # we had a strange result on one person's heroku instance # where Meta.ordering on the Player was being ingnored # when you use a filter. So we add one explicitly. - players = Player.objects.filter( - subsession_id=subsession.pk).select_related( - 'group', 'subsession').order_by('pk') + players = ( + Player.objects.filter(subsession_id=subsession.pk) + .select_related('group', 'subsession') + .order_by('pk') + ) model_order = ['player', 'group', 'subsession'] @@ -402,15 +421,11 @@ def get_rows_for_live_update(subsession: BaseSubsession): for colname in columns_for_models[model_name]: attr = getattr(model_instance, colname, '') - if isinstance(attr, collections.Callable): - if model_name == 'player' and colname == 'role' \ - and model_instance.group is None: - attr = '' - else: - try: - attr = attr() - except: - attr = "(error)" + if callable(attr): + try: + attr = attr() + except Exception: + attr = "" row.append(sanitize_for_live_update(attr)) rows.append(row) @@ -463,7 +478,7 @@ def export_time_spent(fp): 'page_index', 'app_name', 'page_name', - 'time_stamp', + 'epoch_time', 'seconds_on_page', 'subsession_pk', 'auto_submitted', @@ -475,112 +490,3 @@ def export_time_spent(fp): writer = csv.writer(fp) writer.writerows([column_names]) writer.writerows(rows) - - -def export_docs(fp, app_name): - """Write the dcos of the given app name as csv into the file-like object - - """ - - # generate doct_dict - models_module = get_models_module(app_name) - - model_names = ["Participant", "Player", "Group", "Subsession", "Session"] - line_break = '\r\n' - - def choices_readable(choices): - lines = [] - for value, name in choices: - # unicode() call is for lazy translation strings - lines.append(u'{}: {}'.format(value, six.text_type(name))) - return lines - - def generate_doc_dict(): - doc_dict = OrderedDict() - - data_types_readable = { - 'PositiveIntegerField': 'positive integer', - 'IntegerField': 'integer', - 'BooleanField': 'boolean', - 'CharField': 'text', - 'TextField': 'text', - 'FloatField': 'decimal', - 'DecimalField': 'decimal', - 'CurrencyField': 'currency'} - - for model_name in model_names: - if model_name == 'Participant': - Model = Participant - elif model_name == 'Session': - Model = Session - else: - Model = getattr(models_module, model_name) - - field_names = set(field.name for field in Model._meta.fields) - - members = get_field_names_for_csv(Model) - doc_dict[model_name] = OrderedDict() - - for member_name in members: - member = getattr(Model, member_name, None) - doc_dict[model_name][member_name] = OrderedDict() - if member_name == 'id': - doc_dict[model_name][member_name]['type'] = [ - 'positive integer'] - doc_dict[model_name][member_name]['doc'] = ['Unique ID'] - elif member_name in field_names: - member = Model._meta.get_field(member_name) - - internal_type = member.get_internal_type() - data_type = data_types_readable.get( - internal_type, internal_type) - - doc_dict[model_name][member_name]['type'] = [data_type] - - # flag error if the model doesn't have a doc attribute, - # which it should unless the field is a 3rd party field - doc = getattr(member, 'doc', '[error]') or '' - doc_dict[model_name][member_name]['doc'] = [ - line.strip() for line in doc.splitlines() - if line.strip()] - - choices = getattr(member, 'choices', None) - if choices: - doc_dict[model_name][member_name]['choices'] = ( - choices_readable(choices)) - elif isinstance(member, collections.Callable): - doc_dict[model_name][member_name]['doc'] = [ - inspect.getdoc(member)] - return doc_dict - - def docs_as_string(doc_dict): - - first_line = '{}: Documentation'.format(app_name) - second_line = '*' * len(first_line) - - lines = [ - first_line, second_line, '', - 'Accessed: {}'.format(datetime.date.today().isoformat()), ''] - - app_doc = getattr(models_module, 'doc', '') - if app_doc: - lines += [app_doc, ''] - - for model_name in doc_dict: - lines.append(model_name) - - for member in doc_dict[model_name]: - lines.append('\t{}'.format(member)) - for info_type in doc_dict[model_name][member]: - lines.append('\t\t{}'.format(info_type)) - for info_line in doc_dict[model_name][member][info_type]: - lines.append(u'{}{}'.format('\t' * 3, info_line)) - - output = u'\n'.join(lines) - return output.replace('\n', line_break).replace('\t', ' ') - - doc_dict = generate_doc_dict() - doc = docs_as_string(doc_dict) - fp.write(doc) - - diff --git a/otree/extensions.py b/otree/extensions.py index c6657cc69..2df3f7a3a 100644 --- a/otree/extensions.py +++ b/otree/extensions.py @@ -1,6 +1,7 @@ from importlib import import_module from django.conf import settings import importlib.util +import sys """ @@ -19,12 +20,8 @@ routing.py ---------- -Should contain a variable ``channel_routing``, -with a list of channel routes, as described in the Django channels documentation: - -https://channels.readthedocs.io/en/stable/getting-started.html#routing - -These routes will be appended to oTree's built-in channel routes. +Should contain a variable ``websocket_routes``, +with a list of channel routes, as described in the Django channels documentation. admin.py -------- @@ -50,26 +47,22 @@ def get(self, request, *args, **kwargs): You don't need to worry about login_required and AUTH_LEVEL; oTree will handle this automatically. +""" -(In the future, admin.py may be used for other admin customizations, -not just data export.) +from logging import getLogger + +logger = getLogger(__name__) -""" def get_extensions_modules(submodule_name): modules = [] - extension_apps = getattr(settings, 'EXTENSION_APPS', []) - # legacy support for otreechat - if 'otreechat' in settings.INSTALLED_APPS: - extension_apps.append('otreechat') find_spec = importlib.util.find_spec - for app_name in extension_apps: - package_dotted = '{}.otree_extensions'.format(app_name) - submodule_dotted = '{}.{}'.format(package_dotted, submodule_name) + for app_name in getattr(settings, 'EXTENSION_APPS', []): + package_dotted = f'{app_name}.otree_extensions' + submodule_dotted = f'{package_dotted}.{submodule_name}' # need to check if base package exists; otherwise we get ImportError if find_spec(package_dotted) and find_spec(submodule_dotted): - module = import_module(submodule_dotted) - modules.append(module) + modules.append(import_module(submodule_dotted)) return modules @@ -77,4 +70,4 @@ def get_extensions_data_export_views(): view_classes = [] for module in get_extensions_modules('admin'): view_classes += getattr(module, 'data_export_views', []) - return view_classes \ No newline at end of file + return view_classes diff --git a/otree/forms/__init__.py b/otree/forms/__init__.py index 62d841cb5..12d6d71d4 100644 --- a/otree/forms/__init__.py +++ b/otree/forms/__init__.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- # flake8: noqa from django.forms import * diff --git a/otree/forms/fields.py b/otree/forms/fields.py index 85fe654cd..584791e24 100644 --- a/otree/forms/fields.py +++ b/otree/forms/fields.py @@ -7,7 +7,6 @@ class BaseCurrencyField(forms.DecimalField): - def __init__(self, *args, **kwargs): kwargs.setdefault('widget', self.widget) super().__init__(*args, **kwargs) diff --git a/otree/forms/forms.py b/otree/forms/forms.py index 2dc4459fc..97eaa1881 100644 --- a/otree/forms/forms.py +++ b/otree/forms/forms.py @@ -1,72 +1,18 @@ import copy -import six from decimal import Decimal -from six.moves import map - from django import forms -from django.forms import models as django_model_forms from django.utils.translation import ugettext as _ -from django.db.models.options import FieldDoesNotExist -import otree.common_internal -from otree.common_internal import ResponseForException +import otree.common +import otree.constants import otree.models -import otree.constants_internal -from otree.db import models +from otree.common import ResponseForException from otree.currency import Currency, RealWorldCurrency - -__all__ = ( - 'formfield_callback', 'modelform_factory', 'ModelForm') - - - - -def formfield_callback(db_field, **kwargs): - # Take the `widget` attribute into account that might be set for a db - # field. - widget = getattr(db_field, 'widget', None) - if widget: - # dynamic methods like FOO_choices, FOO_min, etc - # modify the form field's widget (self.widget) - # Django is not designed for this kind of dynamic modification, - # self.widget can actually be shared across - # all instances of that form field, meaning you are modifying the - # widget globally. However, this doesn't happen if the widget= arg - # is a class, because then it gets instantiated, which - # basically makes a copy. - # i reproduced this for FOO_choices, but not for FOO_min. - # if it's min, it sets the attrs on the widget, which means - # a shallow copy is not enough. but until i can reproduce this, - # leaving as is. - if not isinstance(widget, type): - widget = copy.copy(widget) - kwargs['widget'] = widget - return db_field.formfield(**kwargs) - - -def modelform_factory(*args, **kwargs): - """ - 2018-07-11: now this exists only to make a copy of the widget if necessary. - maybe there is a better way. - """ - kwargs.setdefault('formfield_callback', formfield_callback) - return django_model_forms.modelform_factory(*args, **kwargs) - -import django.forms.models - -class ModelFormMetaclass(django.forms.models.ModelFormMetaclass): - """ - Metaclass for BaseModelForm in order to inject our custom implementation of - `formfield_callback`. - """ - def __new__(mcs, name, bases, attrs): - attrs.setdefault('formfield_callback', formfield_callback) - return super(ModelFormMetaclass, mcs).__new__( - mcs, name, bases, attrs) +from otree.db import models -class ModelForm(forms.ModelForm, metaclass=ModelFormMetaclass): +class ModelForm(forms.ModelForm): def _get_method_from_page_or_model(self, method_name): for obj in [self.view, self.instance]: if hasattr(obj, method_name): @@ -99,22 +45,21 @@ def __init__(self, *args, view=None, **kwargs): for field_name in self.fields: field = self.fields[field_name] - choices_method = self._get_method_from_page_or_model(f'{field_name}_choices') + choices_method = self._get_method_from_page_or_model( + f'{field_name}_choices' + ) if choices_method: choices = choices_method() - choices = otree.common_internal.expand_choice_tuples(choices) + choices = otree.common.expand_choice_tuples(choices) model_field = self.instance._meta.get_field(field_name) + # this is necessary so we don't modify the field for other players model_field_copy = copy.copy(model_field) - - # in Django 1.11, _choices renamed to choices model_field_copy.choices = choices - - field = formfield_callback(model_field_copy) + field = model_field_copy.formfield() self.fields[field_name] = field - if isinstance(field.widget, forms.RadioSelect): # Fields with a RadioSelect should be rendered without the # '---------' option, and with nothing selected by default, to @@ -134,76 +79,26 @@ def __init__(self, *args, view=None, **kwargs): self._set_min_max_on_widgets() - def _get_field_min_max(self, field_name): - """ - Get the field boundaries from a methods defined on the view. - - Example (will get boundaries from `amount_`): - - - class Offer(Page): - ... - form_model = models.Group - form_fields = ['amount'] - - def amount_min(self): - return 1 - - def amount_max(self): - return 5 - - If the method is not found, it will return ``(None, None)``. - """ + def _get_field_bound(self, field_name, min_or_max: str): + model_field = self.instance._meta.get_field(field_name) - # SessionEditProperties is a ModelForm with extra field which is not - # part of the model. In case your ModelForm has an extra field. - try: - model_field = self.instance._meta.get_field(field_name) - except FieldDoesNotExist: - return [None, None] - - min_method = self._get_method_from_page_or_model(f'{field_name}_min') + min_method = self._get_method_from_page_or_model(f'{field_name}_{min_or_max}') if min_method: - min_value = min_method() - else: - min_value = getattr(model_field, 'min', None) - - max_method = self._get_method_from_page_or_model(f'{field_name}_max') - if max_method: - max_value = max_method() + return min_method() else: - max_value = getattr(model_field, 'max', None) - - return [min_value, max_value] + return getattr(model_field, min_or_max, None) def _set_min_max_on_widgets(self): for field_name, field in self.fields.items(): if isinstance(field.widget, forms.NumberInput): - min_bound, max_bound = self._get_field_min_max(field_name) - if isinstance(min_bound, (Currency, RealWorldCurrency)): - min_bound = Decimal(min_bound) - if isinstance(max_bound, (Currency, RealWorldCurrency)): - max_bound = Decimal(max_bound) - if min_bound is not None: - field.widget.attrs['min'] = min_bound - if max_bound is not None: - field.widget.attrs['max'] = max_bound - # is this UI too intrusive? - # if min_bound is not None and max_bound is not None: - # field.widget.attrs['placeholder'] = '({} - {})'.format( - # min_bound, max_bound - # ) - - def boolean_field_names(self): - boolean_fields_in_model = [ - field.name for field in self.Meta.model._meta.fields - if isinstance(field, models.BooleanField) - ] - return [field_name for field_name in self.fields - if field_name in boolean_fields_in_model] + for min_or_max in ['min', 'max']: + bound = self._get_field_bound(field_name, min_or_max) + if isinstance(bound, (Currency, RealWorldCurrency)): + bound = Decimal(bound) + if bound is not None: + field.widget.attrs[min_or_max] = bound def _clean_fields(self): - boolean_field_names = self.boolean_field_names() for name, field in self.fields.items(): # value_from_datadict() gets the data from the data dictionaries. # Each widget type knows how to retrieve its own data, because some @@ -219,31 +114,32 @@ def _clean_fields(self): value = field.clean(value) self.cleaned_data[name] = value - if name in boolean_field_names and value is None: - mfield = self.instance._meta.get_field(name) - if not mfield.allow_blank: - msg = otree.constants_internal.field_required_msg - raise forms.ValidationError(msg) + model_field = self.instance._meta.get_field(name) + if ( + isinstance(model_field, models.BooleanField) + and value is None + and not model_field.blank + ): + msg = otree.constants.field_required_msg + raise forms.ValidationError(msg) - lower, upper = self._get_field_min_max(name) + lower = self._get_field_bound(name, 'min') + upper = self._get_field_bound(name, 'max') # allow blank=True and min/max to be used together # the field is optional, but # if a value is submitted, it must be within [min,max] - if lower is None or value is None: - pass - elif value < lower: - msg = _('Value must be greater than or equal to {}.') - raise forms.ValidationError(msg.format(lower)) - - if upper is None or value is None: - pass - elif value > upper: - msg = _('Value must be less than or equal to {}.') - raise forms.ValidationError(msg.format(upper)) + if value is not None: + if lower is not None and value < lower: + msg = _('Value must be greater than or equal to {}.') + raise forms.ValidationError(msg.format(lower)) + if upper is not None and value > upper: + msg = _('Value must be less than or equal to {}.') + raise forms.ValidationError(msg.format(upper)) error_message_method = self._get_method_from_page_or_model( - f'{name}_error_message') + f'{name}_error_message' + ) if error_message_method: try: error_string = error_message_method(value) @@ -252,10 +148,6 @@ def _clean_fields(self): if error_string: raise forms.ValidationError(error_string) - if hasattr(self, 'clean_%s' % name): - value = getattr(self, 'clean_%s' % name)() - self.cleaned_data[name] = value - except forms.ValidationError as e: self.add_error(name, e) if not self.errors and hasattr(self.view, 'error_message'): diff --git a/otree/forms/widgets.py b/otree/forms/widgets.py index bdcae715f..b9380d7ff 100644 --- a/otree/forms/widgets.py +++ b/otree/forms/widgets.py @@ -4,9 +4,53 @@ from django.utils.encoding import force_text from django.utils.translation import ugettext_lazy from otree.currency import Currency, RealWorldCurrency -from django.forms.widgets import * # noqa from django import forms +# TextInput could be useful if someone wants to set choices= but doesn't +# want a dropdown. Same for NumberInput actually. But they could also +# just use FOO_error_message, so they don't need to know the name of each input. +from django.forms.widgets import ( + CheckboxInput, + HiddenInput, + RadioSelect, + TextInput, + Textarea, +) # noqa + + +def make_deprecated_widget(WidgetName): + def DeprecatedWidget(*args, **kwargs): + # putting the msg on a separate line gives better tracebacks + msg = ( + f'{WidgetName} does not exist in oTree. You should either delete it, ' + f'or import it from Django directly.' + ) + raise Exception(msg) + + return DeprecatedWidget + + +Media = make_deprecated_widget('Media') +MediaDefiningClass = make_deprecated_widget('MediaDefiningClass') +Widget = make_deprecated_widget('Widget') +NumberInput = make_deprecated_widget('NumberInput') +EmailInput = make_deprecated_widget('EmailInput') +URLInput = make_deprecated_widget('URLInput') +PasswordInput = make_deprecated_widget('PasswordInput') +MultipleHiddenInput = make_deprecated_widget('MultipleHiddenInput') +FileInput = make_deprecated_widget('FileInput') +ClearableFileInput = make_deprecated_widget('ClearableFileInput') +DateInput = make_deprecated_widget('DateInput') +DateTimeInput = make_deprecated_widget('DateTimeInput') +TimeInput = make_deprecated_widget('TimeInput') +Select = make_deprecated_widget('Select') +NullBooleanSelect = make_deprecated_widget('NullBooleanSelect') +SelectMultiple = make_deprecated_widget('SelectMultiple') +CheckboxSelectMultiple = make_deprecated_widget('CheckboxSelectMultiple') +MultiWidget = make_deprecated_widget('MultiWidget') +SplitDateTimeWidget = make_deprecated_widget('SplitDateTimeWidget') +SplitHiddenDateTimeWidget = make_deprecated_widget('SplitHiddenDateTimeWidget') +SelectDateWidget = make_deprecated_widget('SelectDateWidget') from otree.currency.locale import CURRENCY_SYMBOLS @@ -20,7 +64,7 @@ def get_context(self, *args, **kwargs): context['currency_symbol'] = self.CURRENCY_SYMBOL return context - def _format_value(self, value): + def format_value(self, value): if isinstance(value, (Currency, RealWorldCurrency)): value = Decimal(value) return force_text(value) @@ -28,14 +72,15 @@ def _format_value(self, value): class _RealWorldCurrencyInput(_BaseMoneyInput): '''it's a class attribute so take care with patching it in tests''' + CURRENCY_SYMBOL = CURRENCY_SYMBOLS.get( - settings.REAL_WORLD_CURRENCY_CODE, - settings.REAL_WORLD_CURRENCY_CODE, + settings.REAL_WORLD_CURRENCY_CODE, settings.REAL_WORLD_CURRENCY_CODE ) class _CurrencyInput(_RealWorldCurrencyInput): '''it's a class attribute so take care with patching it in tests''' + if settings.USE_POINTS: if hasattr(settings, 'POINTS_CUSTOM_NAME'): CURRENCY_SYMBOL = settings.POINTS_CUSTOM_NAME @@ -57,14 +102,14 @@ def __init__(self, *args, show_value=None, **kwargs): try: # fix bug where currency "step" values were ignored. step = kwargs['attrs']['step'] - kwargs['attrs']['step'] = self._format_value(step) + kwargs['attrs']['step'] = self.format_value(step) except KeyError: pass if show_value is not None: self.show_value = show_value super().__init__(*args, **kwargs) - def _format_value(self, value): + def format_value(self, value): if isinstance(value, (Currency, RealWorldCurrency)): value = Decimal(value) return force_text(value) @@ -74,6 +119,8 @@ def get_context(self, *args, **kwargs): context['show_value'] = self.show_value return context + class SliderInput(Slider): '''old name for Slider widget''' + pass diff --git a/otree/locale/zh_CN/LC_MESSAGES/django.mo b/otree/locale/zh_CN/LC_MESSAGES/django.mo deleted file mode 100644 index 340c21787..000000000 Binary files a/otree/locale/zh_CN/LC_MESSAGES/django.mo and /dev/null differ diff --git a/otree/locale/zh_CN/LC_MESSAGES/django.po b/otree/locale/zh_CN/LC_MESSAGES/django.po deleted file mode 100644 index 8ba2638d9..000000000 --- a/otree/locale/zh_CN/LC_MESSAGES/django.po +++ /dev/null @@ -1,142 +0,0 @@ -# SOME DESCRIPTIVE TITLE. -# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER -# This file is distributed under the same license as the PACKAGE package. -# FIRST AUTHOR , YEAR. -# -#, fuzzy -msgid "" -msgstr "" -"Project-Id-Version: PACKAGE VERSION\n" -"Report-Msgid-Bugs-To: \n" -"POT-Creation-Date: 2018-05-27 20:55-0600\n" -"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" -"Last-Translator: FULL NAME \n" -"Language-Team: LANGUAGE \n" -"Language: \n" -"MIME-Version: 1.0\n" -"Content-Type: text/plain; charset=UTF-8\n" -"Content-Transfer-Encoding: 8bit\n" -"Plural-Forms: nplurals=1; plural=0;\n" - -#. Translators: A player's default chat nickname, -#. which is "Player" + their ID in group. For example: -#. "Player 2". -#: .\chat.py:21 -#, python-brace-format -msgid "Player {id_in_group}" -msgstr "" - -#. Translators: the name someone sees displayed for themselves in a chat. -#. It's their nickname followed by "(Me)". For example: -#. "Michael (Me)" or "Player 1 (Me)". -#: .\chat.py:60 -#, python-brace-format -msgid "{nickname} (Me)" -msgstr "" - -#. Translators: for required form fields -#: .\constants_internal.py:56 -msgid "This field is required." -msgstr "Use LANGUAGE_CODE=zh-hans" - -#. Translators: display a number of points, -#. like "1 point", "2 points", ... -#. See "Plural-Forms" above for pluralization rules -#. in this language. -#. Explanation at http://bit.ly/1IurMu7 -#. In most languages, msgstr[0] is singular, -#. and msgstr[1] is plural -#. the {} represents the number; -#. don't forget to include it in your translation -#: .\currency\__init__.py:225 -msgid "{} point" -msgid_plural "{} points" -msgstr[0] "Use LANGUAGE_CODE=zh-hans" -msgstr[1] "" - -#: .\db\models.py:254 -msgid "Yes" -msgstr "Use LANGUAGE_CODE=zh-hans" - -#: .\db\models.py:255 -msgid "No" -msgstr "Use LANGUAGE_CODE=zh-hans" - -#: .\forms\forms.py:238 -msgid "Value must be greater than or equal to {}." -msgstr "Use LANGUAGE_CODE=zh-hans" - -#: .\forms\forms.py:244 -msgid "Value must be less than or equal to {}." -msgstr "Use LANGUAGE_CODE=zh-hans" - -#. Translators: the label next to a "points" input field -#: .\forms\widgets.py:43 -msgid "points" -msgstr "Use LANGUAGE_CODE=zh-hans" - -#: .\templates\otree\OutOfRangeNotification.html:9 -msgid "No more pages left to show." -msgstr "Use LANGUAGE_CODE=zh-hans" - -#: .\templates\otree\Page.html:17 -#: .\templates\otree\admin\MTurkCreateHIT.html:67 -msgid "Please fix the errors in the form." -msgstr "Use LANGUAGE_CODE=zh-hans" - -#: .\templates\otree\RoomInputLabel.html:7 -msgid "Welcome" -msgstr "" - -#. Translators: If the user enters an invalid participant label -#: .\templates\otree\RoomInputLabel.html:13 -msgid "Invalid entry; try again." -msgstr "" - -#: .\templates\otree\RoomInputLabel.html:15 -msgid "Please enter your participant label." -msgstr "" - -#: .\templates\otree\WaitPage.html:49 -msgid "" -"An error occurred. Please check the logs or ask the administrator for help." -msgstr "Use LANGUAGE_CODE=zh-hans" - -#: .\templates\otree\login.html:47 -msgid "Forgotten your password or username?" -msgstr "Use LANGUAGE_CODE=zh-hans" - -#: .\templates\otree\login.html:52 -msgid "Log in" -msgstr "Use LANGUAGE_CODE=zh-hans" - -#. Translators: The text on the button the user clicks to get to the next page -#: .\templates\otree\tags\NextButton.html:5 -msgid "Next" -msgstr "Use LANGUAGE_CODE=zh-hans" - -#. Translators: Chat widget "send" button text -#: .\templates\otreechat_core\widget.html:7 -msgid "Send" -msgstr "" - -#: .\views\abstract.py:886 -msgid "Time left to complete this page:" -msgstr "Use LANGUAGE_CODE=zh-hans" - -#. Translators: the default title of a wait page -#: .\views\abstract.py:961 .\views\participant.py:291 -msgid "Please wait" -msgstr "Use LANGUAGE_CODE=zh-hans" - -#: .\views\abstract.py:1403 -msgid "Waiting for the other participants." -msgstr "Use LANGUAGE_CODE=zh-hans" - -#: .\views\abstract.py:1405 -msgid "Waiting for the other participant." -msgstr "Use LANGUAGE_CODE=zh-hans" - -#: .\views\participant.py:292 -msgid "Waiting for your session to begin" -msgstr "" diff --git a/otree/management/commands/bots.py b/otree/management/commands/bots.py index ae84865b5..87770fafb 100644 --- a/otree/management/commands/bots.py +++ b/otree/management/commands/bots.py @@ -1,22 +1,24 @@ -import sys import logging - -from pytest import main as pytest_main from django.conf import settings, global_settings from django.core.management.base import BaseCommand +from django.test.utils import ( + setup_databases, + setup_test_environment, + teardown_databases, + teardown_test_environment, +) -import otree.bots.runner +from otree.bots.runner import run_all_bots_for_session_config -import otree.common_internal +import otree.common logger = logging.getLogger('otree') -from otree.constants_internal import AUTO_NAME_BOTS_EXPORT_FOLDER - -from sys import exit as sys_exit +from otree.constants import AUTO_NAME_BOTS_EXPORT_FOLDER MSG_BOTS_HELP = 'Run oTree bots' + class Command(BaseCommand): help = MSG_BOTS_HELP @@ -29,17 +31,17 @@ def _get_action(self, parser, signature): def add_arguments(self, parser): # Positional arguments parser.add_argument( - 'session_config_name', nargs='?', - help='If omitted, all sessions in SESSION_CONFIGS are run' + 'session_config_name', + nargs='?', + help='If omitted, all sessions in SESSION_CONFIGS are run', ) - ahelp = ( - 'Number of participants. ' - 'Defaults to minimum for the session config.' - ) parser.add_argument( - 'num_participants', type=int, nargs='?', - help=ahelp) + 'num_participants', + type=int, + nargs='?', + help='Number of participants (if omitted, use num_demo_participants)', + ) # don't call it --data because then people might think that # that's the *input* data folder @@ -52,22 +54,23 @@ def add_arguments(self, parser): 'Saves the data generated by the tests. ' 'Runs the "export data" command, ' 'outputting the CSV files to the specified directory, ' - 'or an auto-generated one.'), - ) + 'or an auto-generated one.' + ), + ) parser.add_argument( '--save', nargs='?', const=AUTO_NAME_BOTS_EXPORT_FOLDER, dest='export_path', - help=( - 'Alias for --export.'), - ) + help=('Alias for --export.'), + ) v_action = self._get_action(parser, ("-v", "--verbosity")) v_action.default = '1' v_action.help = ( 'Verbosity level; 0=minimal output, 1=normal output,' - '2=verbose output (DEFAULT), 3=very verbose output') + '2=verbose output (DEFAULT), 3=very verbose output' + ) def prepare_global_state(self): ''' @@ -75,12 +78,6 @@ def prepare_global_state(self): these are optimizations that are mostly redundant with what runtests.py does. ''' - # use in-memory. - # this is the simplest way to patch tests to use in-memory, - # while still using Redis in production - settings.CHANNEL_LAYERS['default'] = settings.CHANNEL_LAYERS['inmemory'] - # so we know not to use Huey - otree.common_internal.USE_REDIS = False # To make tests run faster, autorefresh should be set to True # http://whitenoise.evans.io/en/latest/django.html#whitenoise-makes-my-tests-run-slow @@ -89,45 +86,31 @@ def prepare_global_state(self): # same hack as in resetdb code # because pytest.main() uses the serializer # it breaks if the app has migrations but they aren't up to date - otree.common_internal.patch_migrations_module() + otree.common.patch_migrations_module() settings.STATICFILES_STORAGE = global_settings.STATICFILES_STORAGE - def handle( - self, *, verbosity, - **options): + self, + *, + verbosity, + session_config_name, + num_participants, + export_path, + **options + ): self.prepare_global_state() - # '-s' is to see print output - # --tb=short is to show short tracebacks. I think this is - # more expected and less verbose. - # With the default pytest long tracebacks, - # often the code that gets printed is in otree-core, which is not relevant. - # also, this is better than using --tb=native, which loses line breaks - # when a unicode char is contained in the output, and also doesn't get - # color coded with colorama, the way short tracebacks do. - argv = [ - otree.bots.runner.__file__, - '-s', - '--tb', 'short' - ] - if verbosity == 0: - argv.append('--quiet') - if verbosity == 2: - argv.append('--verbose') - - for k in ['session_config_name', 'num_participants', 'export_path']: - v = options[k] - if v: - argv.extend([f'--{k}', v]) - - exit_code = pytest_main(argv) - - if not options['export_path']: - logger.info('Tip: Run this command with the --export flag' - ' to save the data generated by bots.') - - # exit with the exit code, so that CI systems can know if - # the tests succeeded or failed. - sys_exit(exit_code) + setup_test_environment() + old_config = setup_databases( + interactive=False, verbosity=verbosity, aliases={'default'} + ) + try: + run_all_bots_for_session_config( + session_config_name=session_config_name, + num_participants=num_participants, + export_path=export_path, + ) + finally: + teardown_databases(old_config, verbosity=verbosity) + teardown_test_environment() diff --git a/otree/management/commands/botworker.py b/otree/management/commands/botworker.py index b511170ae..144a8dee0 100644 --- a/otree/management/commands/botworker.py +++ b/otree/management/commands/botworker.py @@ -1,7 +1,7 @@ import logging from django.core.management.base import BaseCommand import otree.bots.browser -from otree.common_internal import get_redis_conn +from otree.common import get_redis_conn logger = logging.getLogger('otree.botworker') diff --git a/otree/management/commands/browser_bots.py b/otree/management/commands/browser_bots.py index ee111e5f7..da26feaf1 100644 --- a/otree/management/commands/browser_bots.py +++ b/otree/management/commands/browser_bots.py @@ -7,28 +7,26 @@ class Command(BaseCommand): def add_arguments(self, parser): parser.add_argument( - 'session_config_name', nargs='?', - help='If omitted, all sessions in SESSION_CONFIGS are run' + 'session_config_name', + nargs='?', + help='If omitted, all sessions in SESSION_CONFIGS are run', ) parser.add_argument( - '--server-url', action='store', type=str, dest='server_url', + '--server-url', + action='store', + type=str, + dest='server_url', default='http://127.0.0.1:8000', - help="Server's root URL") - ahelp = ( - 'Number of participants. ' - 'Defaults to minimum for the session config.' + help="Server's root URL", ) - parser.add_argument( - 'num_participants', type=int, nargs='?', - help=ahelp) + ahelp = 'Number of participants. ' 'Defaults to minimum for the session config.' + parser.add_argument('num_participants', type=int, nargs='?', help=ahelp) def handle(self, session_config_name, server_url, num_participants, **options): launcher = Launcher( session_config_name=session_config_name, server_url=server_url, - num_participants=num_participants + num_participants=num_participants, ) launcher.run() - - diff --git a/otree/management/commands/create_session.py b/otree/management/commands/create_session.py index f1ace4e61..b7ec17ce8 100644 --- a/otree/management/commands/create_session.py +++ b/otree/management/commands/create_session.py @@ -1,6 +1,6 @@ import logging -import six + from django.core.management.base import BaseCommand @@ -16,28 +16,35 @@ class Command(BaseCommand): def add_arguments(self, parser): parser.add_argument( - 'session_config_name', type=six.u, help="The session config name") - parser.add_argument( - 'num_participants', type=int, - help="Number of participants for the created session") + 'session_config_name', help="The session config name" + ) parser.add_argument( - "-l", "--label", action="store", type=six.u, - dest="label", default='', help="label for the created session") + 'num_participants', + type=int, + help="Number of participants for the created session", + ) parser.add_argument( - "--room", action="store", type=six.u, - dest="room_name", default=None, - help="Name of room to create the session in") + "--room", + action="store", + dest="room_name", + default=None, + help="Name of room to create the session in", + ) - def handle(self, session_config_name, num_participants, label, room_name, **kwargs): + def handle(self, session_config_name, num_participants, room_name, **kwargs): session = create_session( session_config_name=session_config_name, - num_participants=num_participants, label=label) + num_participants=num_participants, + ) if room_name: room = ROOM_DICT[room_name] room.set_session(session) - logger.info("Created session with code {} in room '{}'\n".format( - session.code, room_name)) + logger.info( + "Created session with code {} in room '{}'\n".format( + session.code, room_name + ) + ) else: logger.info("Created session with code {}\n".format(session.code)) diff --git a/otree/management/commands/devserver.py b/otree/management/commands/devserver.py index a84825c18..470980176 100644 --- a/otree/management/commands/devserver.py +++ b/otree/management/commands/devserver.py @@ -1,26 +1,36 @@ import importlib +import logging import os import os.path -import pathlib +import shutil import sys -import termcolor import time import traceback +from pathlib import Path from unittest.mock import patch + +import termcolor +from channels.management.commands import runserver +from daphne.endpoints import build_endpoint_description_strings +from django.apps import apps from django.conf import settings from django.core.management import call_command -from pathlib import Path -from django.apps import apps -from . import runserver +import otree.bots.browser +import otree.common +import otree_startup +from otree import __version__ as CURRENT_VERSION -TMP_MIGRATIONS_DIR = '__temp_migrations' +TMP_MIGRATIONS_DIR = Path('__temp_migrations') +VERSION_FILE = TMP_MIGRATIONS_DIR.joinpath('otree-version.txt') ADVICE_DELETE_TMP = ( "ADVICE: Try deleting the folder {}. If that doesn't work, " "look for the error in your models.py." ).format(TMP_MIGRATIONS_DIR) +# this happens when I add a non-nullable field to oTree-core +# (includes renaming a non-nullable field) ADVICE_FIX_NOT_NULL_FIELD = ( 'You may have added a non-nullable field without a default. ' 'This typically happens when importing model fields from django instead of otree.' @@ -34,7 +44,7 @@ db_engine = settings.DATABASES['default']['ENGINE'].lower() -if 'sqlite' in db_engine: +if otree.common.is_sqlite(): ADVICE_DELETE_DB = ( 'ADVICE: Stop the server, ' 'then delete the file db.sqlite3 in your project folder, ' @@ -53,47 +63,139 @@ 'command or the other.' ).format(db_engine) +# They should start fresh so that: +# (1) performance refresh +# (2) don't have to worry about old references to things that were removed from otree-core. +MSG_OTREE_UPDATE_DELETE_DB = ( + 'oTree has been updated. Please delete your database (usually "db.sqlite3") ' + f'and the folder "{TMP_MIGRATIONS_DIR}".' +) -class Command(runserver.Command): - - inside_runzip = False +class Command(runserver.Command): def add_arguments(self, parser): super().add_arguments(parser) + # see log_action below; we only show logs of each request + # if verbosity >= 1. + # this still allows logger.info and logger.warning to be shown. + # NOTE: if we change this back to 1, then need to update devserver + # not to show traceback of errors. + parser.set_defaults(verbosity=0) + parser.add_argument( - '--inside-runzip', action='store_true', dest='inside_runzip', default=False, + '--inside-runzip', action='store_true', dest='inside_runzip', default=False ) + + def handle(self, *args, **options): + self.verbosity = options.get("verbosity", 1) + from otree.common import release_any_stale_locks + + release_any_stale_locks() + + # for performance, + # only run checks when the server starts, not when it reloads + # (RUN_MAIN is set by Django autoreloader). + if not os.environ.get('RUN_MAIN'): + + try: + # don't suppress output. it's good to know that check is + # not failing silently or not being run. + # also, intercepting stdout doesn't even seem to work here. + self.check(display_num_errors=True) + + except Exception as exc: + otree_startup.print_colored_traceback_and_exit(exc) + + # better to do this here, because: + # (1) it's redundant to do it on every reload + # (2) we can exit if we run this before the autoreloader is started + if TMP_MIGRATIONS_DIR.exists() and ( + not VERSION_FILE.exists() or VERSION_FILE.read_text() != CURRENT_VERSION + ): + # - Don't delete the DB, because it might have important data + # - Don't delete __temp_migrations, because then we erase the knowledge that + # oTree was updated. If the user starts the server at a later time, we can't remind them + # that they needed to delete the DB. So, the two things must be deleted together. + self.stdout.write(MSG_OTREE_UPDATE_DELETE_DB) + sys.exit(0) + TMP_MIGRATIONS_DIR.mkdir(exist_ok=True) + VERSION_FILE.write_text(CURRENT_VERSION) + TMP_MIGRATIONS_DIR.joinpath('__init__.py').touch(exist_ok=True) + + super().handle(*args, **options) + def inner_run(self, *args, inside_runzip, **options): + ''' + inner_run does not get run twice with runserver, unlike .handle() + ''' self.inside_runzip = inside_runzip - self.handle_migrations() - - super().inner_run(*args, **options) + self.makemigrations_and_migrate() + + # initialize browser bot worker in process memory + otree.bots.browser.browser_bot_worker = otree.bots.browser.Worker() + + # silence the lines like: + # 2018-01-10 18:51:18,092 - INFO - worker - Listening on channels + # http.request, otree.create_session, websocket.connect, + # websocket.disconnect, websocket.receive + daphne_logger = logging.getLogger('django.channels') + original_log_level = daphne_logger.level + daphne_logger.level = logging.WARNING + + endpoints = build_endpoint_description_strings(host=self.addr, port=self.port) + application = self.get_application(options) + + # silence the lines like: + # INFO HTTP/2 support not enabled (install the http2 and tls Twisted extras) + # INFO Configuring endpoint tcp:port=8000:interface=127.0.0.1 + # INFO Listening on TCP address 127.0.0.1:8000 + logging.getLogger('daphne.server').level = logging.WARNING + + # I removed the IPV6 stuff here because its not commonly used yet + addr = self.addr + # 0.0.0.0 is not a regular IP address, so we can't tell the user + # to open their browser to that address + if addr == '127.0.0.1': + addr = 'localhost' + elif addr == '0.0.0.0': + addr = '' + self.stdout.write( + ( + f"Open your browser to http://{addr}:{self.port}/\n" + "To quit the server, press Control+C.\n" + ) + ) - def handle_migrations(self): + try: + self.server_cls( + application=application, + endpoints=endpoints, + signal_handlers=not options["use_reloader"], + action_logger=self.log_action, + http_timeout=self.http_timeout, + root_path=getattr(settings, "FORCE_SCRIPT_NAME", "") or "", + websocket_handshake_timeout=self.websocket_handshake_timeout, + ).run() + daphne_logger.debug("Daphne exited") + except KeyboardInterrupt: + shutdown_message = options.get("shutdown_message", "") + if shutdown_message: + self.stdout.write(shutdown_message) + return + + def makemigrations_and_migrate(self): # only get apps with labels, otherwise migrate will raise an error # when it tries to migrate that app but no migrations dir was created - app_labels = set( - model._meta.app_config.label - for model in apps.get_models() - ) + app_labels = set(model._meta.app_config.label for model in apps.get_models()) migrations_modules = { app_label: '{}.{}'.format(TMP_MIGRATIONS_DIR, app_label) for app_label in app_labels } - settings.MIGRATION_MODULES = migrations_modules - migrations_dir_path = os.path.join(settings.BASE_DIR, TMP_MIGRATIONS_DIR) - pathlib.Path(TMP_MIGRATIONS_DIR).mkdir(exist_ok=True) - - init_file_path = os.path.join(migrations_dir_path, '__init__.py') - pathlib.Path(init_file_path).touch(exist_ok=True) - - self.perf_check() - start = time.time() try: @@ -136,9 +238,10 @@ def handle_migrations(self): # so, simplest to use the string name if type(exc).__name__ in ( - 'OperationalError', - 'ProgrammingError', - 'InconsistentMigrationHistory'): + 'OperationalError', + 'ProgrammingError', + 'InconsistentMigrationHistory', + ): self.print_error_and_exit(ADVICE_DELETE_DB) else: raise @@ -162,25 +265,20 @@ def print_error_and_exit(self, advice): self.stdout.write(ADVICE_PRINT_DETAILS) sys.exit(0) - def perf_check(self): - '''after about 150 migrations, - load time increased from 0.6 to 1.2+ second''' - - MAX_MIGRATIONS = 200 - - # we want to delete migrations files, but keep __init__.py - # and directories, because then we don't need to - # migrations files are named 0001_xxx.py, 0002_xxx.py, etc. - # so, we assume they will all - file_glob = '{}/*/0*.py'.format(TMP_MIGRATIONS_DIR) - python_fns = list(Path('.').glob(file_glob)) - num_files = len(python_fns) - - if num_files > MAX_MIGRATIONS: - advice = ( - 'You have too many migrations files ({}). ' - 'This can slow down performance. ' - 'You should delete the directory {} ' - 'and also delete your database.' - ).format(num_files, TMP_MIGRATIONS_DIR) - termcolor.cprint(advice, 'white', 'on_red') + def log_action(self, protocol, action, details): + ''' + Override log_action method. + Need this until https://github.com/django/channels/issues/612 + is fixed. + maybe for some minimal output use this? + self.stderr.write('.', ending='') + so that you can see that the server is running + (useful if you are accidentally running multiple servers) + + idea: maybe only show details if it's a 4xx or 5xx. + + ''' + if self.verbosity >= 1: + super().log_action(protocol, action, details) + + inside_runzip = False diff --git a/otree/management/commands/django_test.py b/otree/management/commands/django_test.py index be724268f..cd08f4eb7 100644 --- a/otree/management/commands/django_test.py +++ b/otree/management/commands/django_test.py @@ -1,4 +1,4 @@ # make Django's native 'test' command available to those who need it # because oTree overrides it. -from django.core.management.commands.test import Command # noqa \ No newline at end of file +from django.core.management.commands.test import Command # noqa diff --git a/otree/management/commands/prodserver.py b/otree/management/commands/prodserver.py index 74ee4ae82..3168204ec 100644 --- a/otree/management/commands/prodserver.py +++ b/otree/management/commands/prodserver.py @@ -1 +1 @@ -from .runprodserver import Command # noqa \ No newline at end of file +from .runprodserver import Command # noqa diff --git a/otree/management/commands/prodserver1of2.py b/otree/management/commands/prodserver1of2.py index dedd8b24a..a3d10a36d 100644 --- a/otree/management/commands/prodserver1of2.py +++ b/otree/management/commands/prodserver1of2.py @@ -1 +1 @@ -from .runprodserver1of2 import Command # noqa \ No newline at end of file +from .runprodserver1of2 import Command # noqa diff --git a/otree/management/commands/resetdb.py b/otree/management/commands/resetdb.py index 47b0c8234..0ac3764fe 100644 --- a/otree/management/commands/resetdb.py +++ b/otree/management/commands/resetdb.py @@ -1,13 +1,12 @@ import logging -import six from django.conf import settings from django.core.management.base import BaseCommand from django.core.management import call_command from django.db import connection, transaction -import django.apps +from dataclasses import dataclass -from otree import common_internal +from otree import common from typing import Tuple, List logger = logging.getLogger('otree') @@ -15,26 +14,35 @@ MSG_RESETDB_SUCCESS_FOR_HUB = 'Created new tables and columns.' MSG_DB_ENGINE_FOR_HUB = 'Database engine' -def db_label_and_drop_cmd(db_engine: str) -> Tuple[str, str]: + +@dataclass +class DBDeletionInfo: + db_engine: str + table_delete_command: str + + +def db_label_and_drop_cmd(db_engine: str) -> DBDeletionInfo: db_engine_lower = db_engine.lower() if 'oracle' in db_engine_lower: - return ('Oracle', 'DROP TABLE "{table}" CASCADE CONSTRAINTS;') + return DBDeletionInfo('Oracle', 'DROP TABLE "{table}" CASCADE CONSTRAINTS;') if 'postgres' in db_engine_lower: - return ('Postgres', 'DROP TABLE "{table}" CASCADE;') + return DBDeletionInfo('Postgres', 'DROP TABLE "{table}" CASCADE;') if 'mysql' in db_engine_lower: - return ( + return DBDeletionInfo( 'MySQL', ( 'SET FOREIGN_KEY_CHECKS = 0;' 'DROP TABLE {table} CASCADE;' 'SET FOREIGN_KEY_CHECKS = 1;' - ) + ), ) # put this last for test coverage - if 'sqlite3' in db_engine_lower: - return ('SQLite', 'DROP TABLE {table};') + if common.is_sqlite(): + return DBDeletionInfo('SQLite', 'DROP TABLE {table};') raise ValueError( - 'resetdb command does not recognize DB engine "{}"'.format(db_engine)) + 'resetdb command does not recognize DB engine "{}"'.format(db_engine) + ) + def cursor_execute_drop_cmd(cursor, stmt): cursor.execute(stmt) @@ -45,42 +53,44 @@ def migrate_db(options): # it doesn't exist. # Tried setting MIGRATIONS_MODULES but doesn't work # (causes ModuleNotFoundError) - common_internal.patch_migrations_module() + common.patch_migrations_module() - call_command( - 'migrate', interactive=False, run_syncdb=True, **options - ) + call_command('migrate', interactive=False, run_syncdb=True, **options) class Command(BaseCommand): help = ( "Resets your development database to a fresh state. " - "All data will be deleted.") + "All data will be deleted." + ) def add_arguments(self, parser): ahelp = ( - 'Tells the resetdb command to NOT prompt the user for ' - 'input of any kind.') + 'Tells the resetdb command to NOT prompt the user for ' 'input of any kind.' + ) parser.add_argument( - '--noinput', action='store_false', dest='interactive', - default=True, help=ahelp) + '--noinput', + action='store_false', + dest='interactive', + default=True, + help=ahelp, + ) def _confirm(self) -> bool: - self.stdout.write( - "This will delete and recreate your database. ") - answer = six.moves.input("Proceed? (y or n): ") + self.stdout.write("This will delete and recreate your database. ") + answer = input("Proceed? (y or n): ") if answer: return answer[0].lower() == 'y' return False - def _get_tables(self) -> List[str]: with connection.cursor() as cursor: tables = connection.introspection.get_table_list(cursor) # in the old version, juan reversed the list, not sure why, # maybe something about foreign key dependencies? return [ - t.name for t in tables + t.name + for t in tables # do this so it will fail loudly if the "type" doesn't match if {'t': True, 'v': False, 'p': False}[t.type] ] @@ -97,23 +107,32 @@ def handle(self, *, interactive, **options): return dbconf = settings.DATABASES['default'] - db_engine, drop_cmd_template = db_label_and_drop_cmd(dbconf['ENGINE']) + drop_db_info = db_label_and_drop_cmd(dbconf['ENGINE']) + db_engine = drop_db_info.db_engine # hub depends on this string logger.info(f"{MSG_DB_ENGINE_FOR_HUB}: {db_engine}") tables = self._get_tables() - # use a transaction to prevent the DB from getting in an erroneous - # state, which can result in a different error message when resetdb - # is run again, making the original error hard to trace. - with transaction.atomic( - savepoint=connection.features.can_rollback_ddl - ): - logger.info(f"Dropping {len(tables)} tables...") - self._drop_tables(tables, drop_cmd_template) - - migrate_db(options) + logger.info(f"Dropping {len(tables)} tables...") + is_sqlite = common.is_sqlite() + if is_sqlite: + # otherwise, when I drop tables I get: + # django.db.utils.IntegrityError: FOREIGN KEY constraint failed + # see: https://www.sqlitetutorial.net/sqlite-drop-table/ + # potentially i could do 'DELETE FROM {table};' + # for each table before i drop any table, + # but that's more code, and I'm not sure it will work. + connection.disable_constraint_checking() + try: + self._drop_tables(tables, drop_db_info.table_delete_command) + finally: + if is_sqlite: + # maybe overkill but perhaps resetdb could be called via call_command() + connection.enable_constraint_checking() + + migrate_db(options) # mention the word 'columns' here, so people make the connection # between columns and resetdb, so that when they get a 'no such column' diff --git a/otree/management/commands/runprodserver.py b/otree/management/commands/runprodserver.py index f27b58366..0dea5c2d5 100644 --- a/otree/management/commands/runprodserver.py +++ b/otree/management/commands/runprodserver.py @@ -20,23 +20,22 @@ def add_arguments(self, parser): ahelp = ( 'By default we will collect all static files into the directory ' 'configured in your settings. Disable it with this switch if you ' - 'want to do it manually.') + 'want to do it manually.' + ) parser.add_argument( - '--no-collectstatic', action='store_false', dest='collectstatic', - default=True, help=ahelp) + '--no-collectstatic', + action='store_false', + dest='collectstatic', + default=True, + help=ahelp, + ) def setup_honcho(self, **options): super().setup_honcho(**options) honcho = self.honcho - honcho.add_otree_process( - 'botworker', - 'otree botworker', - ) - honcho.add_otree_process( - 'timeoutworkeronly', - 'otree timeoutworkeronly', - ) + honcho.add_otree_process('botworker', 'otree botworker') + honcho.add_otree_process('timeoutworkeronly', 'otree timeoutworkeronly') def handle(self, *args, collectstatic, **options): diff --git a/otree/management/commands/runprodserver1of2.py b/otree/management/commands/runprodserver1of2.py index ded02b56e..e7a211bc2 100644 --- a/otree/management/commands/runprodserver1of2.py +++ b/otree/management/commands/runprodserver1of2.py @@ -4,19 +4,22 @@ import logging import honcho.manager - +from django.conf import settings from django.core.management.base import BaseCommand from django.core.management.base import CommandError import otree logger = logging.getLogger(__name__) -naiveip_re = re.compile(r"""^(?: +naiveip_re = re.compile( + r"""^(?: (?P (?P\d{1,3}(?:\.\d{1,3}){3}) | # IPv4 address (?P\[[a-fA-F0-9:]+\]) | # IPv6 address (?P[a-zA-Z0-9-]+(?:\.[a-zA-Z0-9-]+)*) # FQDN -):)?(?P\d+)$""", re.X) +):)?(?P\d+)$""", + re.X, +) DEFAULT_PORT = "8000" DEFAULT_ADDR = '0.0.0.0' @@ -31,10 +34,6 @@ else: NUM_WORKERS = 3 -def get_ssl_file_path(filename): - otree_dir = os.path.dirname(otree.__file__) - pth = os.path.join(otree_dir, 'certs', filename) - return pth.replace('\\', '/') # made this simple class to reduce code duplication, # and to make testing easier (I didn't know how to check that it was called @@ -49,30 +48,27 @@ class Command(BaseCommand): def add_arguments(self, parser): - parser.add_argument('addrport', nargs='?', - help='Optional port number, or ipaddr:port') - - ahelp = ( - 'Run an SSL server directly in Daphne with a self-signed cert/key' - ) parser.add_argument( - '--dev-https', action='store_true', dest='dev_https', default=False, - help=ahelp) + 'addrport', nargs='?', help='Optional port number, or ipaddr:port' + ) - def handle(self, *args, addrport=None, verbosity=1, dev_https, **kwargs): + def handle(self, *args, addrport=None, verbosity=1, **kwargs): self.verbosity = verbosity + os.environ['OTREE_USE_REDIS'] = '1' self.honcho = OTreeHonchoManager() - self.setup_honcho(addrport=addrport, dev_https=dev_https) + self.setup_honcho(addrport=addrport) self.honcho.loop() sys.exit(self.honcho.returncode) - def setup_honcho(self, *, addrport, dev_https): + def setup_honcho(self, *, addrport): if addrport: m = re.match(naiveip_re, addrport) if m is None: - raise CommandError('"%s" is not a valid port number ' - 'or address:port pair.' % addrport) + raise CommandError( + '"%s" is not a valid port number ' + 'or address:port pair.' % addrport + ) addr, _, _, _, port = m.groups() else: addr = None @@ -84,25 +80,12 @@ def setup_honcho(self, *, addrport, dev_https): # https://github.com/encode/uvicorn/issues/185 - #asgi_server_cmd = f'uvicorn --host={addr} --port={port} --workers={NUM_WORKERS} otree_startup.asgi:application --log-level=debug' - #asgi_server_cmd += ' --ws=wsproto' - asgi_server_cmd = f'hypercorn -b {addr}:{port} --workers={NUM_WORKERS} otree_startup.asgi:application' - - if dev_https: - # Because of HSTS, Chrome and other browsers will "get stuck" forcing HTTPS, - # which makes it impossible to run regular devserver again on that port - if int(port) == 8000: - self.stderr.write('ERROR: oTree cannot use HTTPS on port 8000. Please specify a different port.') - raise SystemExit(-1) - asgi_server_cmd += ' --keyfile="{}" --certfile="{}"'.format( - get_ssl_file_path('development.key'), - get_ssl_file_path('development.crt'), - ) + # asgi_server_cmd = f'uvicorn --host={addr} --port={port} --workers={NUM_WORKERS} otree_startup.asgi:application --log-level=debug' + # keep-alive is needed, otherwise pages that take more than 5 seconds to load will trigger h13 + # asgi_server_cmd = f'hypercorn -b {addr}:{port} --workers={NUM_WORKERS} --keep-alive=35 otree_startup.asgi:application' + asgi_server_cmd = f'daphne -b {addr} -p {port} otree_startup.asgi:application' logger.info(asgi_server_cmd) honcho = self.honcho - honcho.add_otree_process( - 'asgiserver', - asgi_server_cmd - ) + honcho.add_otree_process('asgiserver', asgi_server_cmd) diff --git a/otree/management/commands/runprodserver2of2.py b/otree/management/commands/runprodserver2of2.py index efd135922..7b58246e7 100644 --- a/otree/management/commands/runprodserver2of2.py +++ b/otree/management/commands/runprodserver2of2.py @@ -1,6 +1,4 @@ -#!/usr/bin/env python import os -import sys from sys import exit as sys_exit from honcho.manager import Manager as HonchoManager @@ -22,23 +20,18 @@ def handle(self, *args, verbosity=1, **options): def get_honcho_manager(self): + # this env var is necessary because if the botworker submits a wait page, + # it needs to broadcast to redis channel layer, not in-memory. + # this caused an obscure bug on 2019-09-21. + os.environ['OTREE_USE_REDIS'] = '1' env_copy = os.environ.copy() manager = HonchoManager() # if I change these, I need to modify the ServerCheck also + manager.add_process('botworker', 'otree botworker', quiet=False, env=env_copy) manager.add_process( - 'botworker', - 'otree botworker', - quiet=False, - env=env_copy, - ) - manager.add_process( - 'timeoutworkeronly', - 'otree timeoutworkeronly', - quiet=False, - env=env_copy, + 'timeoutworkeronly', 'otree timeoutworkeronly', quiet=False, env=env_copy ) return manager - diff --git a/otree/management/commands/runserver.py b/otree/management/commands/runserver.py deleted file mode 100644 index b2fc579ac..000000000 --- a/otree/management/commands/runserver.py +++ /dev/null @@ -1,140 +0,0 @@ -from channels.management.commands import runserver -import otree.bots.browser -from django.conf import settings -import otree.common_internal -import logging -from daphne.endpoints import build_endpoint_description_strings -from daphne.server import Server -import otree_startup -from otree import common_internal -import os -import sys -from channels.worker import Worker -import threading -from channels.layers import get_channel_layer - -class Command(runserver.Command): - - def handle(self, *args, **options): - - self.verbosity = options.get("verbosity", 1) - - # i think this won't work, because channels reads this setting - # during django.setup() - #settings.CHANNEL_LAYERS['default'] = settings.CHANNEL_LAYERS['inmemory'] - - from otree.common_internal import release_any_stale_locks - release_any_stale_locks() - - # don't use cached template loader, so that users can refresh files - # and see the update. - # kind of a hack to patch it here and to refer it as [0], - # but can't think of a better way. - settings.TEMPLATES[0]['OPTIONS']['loaders'] = [ - 'django.template.loaders.filesystem.Loader', - 'django.template.loaders.app_directories.Loader', - ] - - # so we know not to use Huey - otree.common_internal.USE_REDIS = False - - # for performance, - # only run checks when the server starts, not when it reloads - # (RUN_MAIN is set by Django autoreloader). - if not os.environ.get('RUN_MAIN'): - - try: - # don't suppress output. it's good to know that check is - # not failing silently or not being run. - # also, intercepting stdout doesn't even seem to work here. - self.check(display_num_errors=True) - - except Exception as exc: - otree_startup.print_colored_traceback_and_exit(exc) - - super().handle(*args, **options) - - def inner_run(self, *args, **options): - - ''' - inner_run does not get run twice with runserver, unlike .handle() - ''' - - # initialize browser bot worker in process memory - otree.bots.browser.browser_bot_worker = otree.bots.browser.Worker() - - addr = f'[{self.addr}]' if self._raw_ipv6 else self.addr - # 0.0.0.0 is not a regular IP address, so we can't tell the user - # to open their browser to that address - if addr == '127.0.0.1': - addr = 'localhost' - elif addr == '0.0.0.0': - addr = '' - self.stdout.write(( - "Starting server.\n" - "Open your browser to http://%(addr)s:%(port)s/\n" - "To quit the server, press Control+C.\n" - ) % { - "addr": addr, - "port": self.port, - }) - - # silence the lines like: - # 2018-01-10 18:51:18,092 - INFO - worker - Listening on channels - # http.request, otree.create_session, websocket.connect, - # websocket.disconnect, websocket.receive - daphne_logger = logging.getLogger('django.channels') - original_log_level = daphne_logger.level - daphne_logger.level = logging.WARNING - - endpoints = build_endpoint_description_strings(host=self.addr, port=self.port) - application = self.get_application(options) - - # silence the lines like: - # INFO HTTP/2 support not enabled (install the http2 and tls Twisted extras) - # INFO Configuring endpoint tcp:port=8000:interface=127.0.0.1 - # INFO Listening on TCP address 127.0.0.1:8000 - logging.getLogger('daphne.server').level = logging.WARNING - - try: - self.server_cls( - application=application, - endpoints=endpoints, - signal_handlers=not options["use_reloader"], - action_logger=self.log_action, - http_timeout=self.http_timeout, - root_path=getattr(settings, "FORCE_SCRIPT_NAME", "") or "", - websocket_handshake_timeout=self.websocket_handshake_timeout, - ).run() - daphne_logger.debug("Daphne exited") - except KeyboardInterrupt: - shutdown_message = options.get("shutdown_message", "") - if shutdown_message: - self.stdout.write(shutdown_message) - return - - - def add_arguments(self, parser): - super().add_arguments(parser) - # see log_action below; we only show logs of each request - # if verbosity >= 1. - # this still allows logger.info and logger.warning to be shown. - # NOTE: if we change this back to 1, then need to update devserver - # not to show traceback of errors. - parser.set_defaults(verbosity=0) - - def log_action(self, protocol, action, details): - ''' - Override log_action method. - Need this until https://github.com/django/channels/issues/612 - is fixed. - maybe for some minimal output use this? - self.stderr.write('.', ending='') - so that you can see that the server is running - (useful if you are accidentally running multiple servers) - - idea: maybe only show details if it's a 4xx or 5xx. - - ''' - if self.verbosity >= 1: - super().log_action(protocol, action, details) diff --git a/otree/management/commands/startapp.py b/otree/management/commands/startapp.py index ef23e1416..5986a95bd 100644 --- a/otree/management/commands/startapp.py +++ b/otree/management/commands/startapp.py @@ -1,24 +1,15 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - import os from django.core.management.commands import startapp import otree -from otree_startup import pypi_updates_cli class Command(startapp.Command): def get_default_template(self): - return os.path.join( - os.path.dirname(otree.__file__), 'app_template') + return os.path.join(os.path.dirname(otree.__file__), 'app_template') def handle(self, *args, **options): options['template'] = self.get_default_template() super().handle(*args, **options) - try: - pypi_updates_cli() - except: # noqa - pass # noqa self.stdout.write('Created app folder.') diff --git a/otree/management/commands/startproject.py b/otree/management/commands/startproject.py index d31f09b4f..892aed7d4 100644 --- a/otree/management/commands/startproject.py +++ b/otree/management/commands/startproject.py @@ -3,18 +3,17 @@ from django.core.management.base import CommandError import sys import otree -from otree_startup import pypi_updates_cli class Command(startproject.Command): - help = ("Creates a new oTree project.") + help = "Creates a new oTree project." def add_arguments(self, parser): super().add_arguments(parser) '''need this so we can test startproject automatically''' parser.add_argument( - '--noinput', action='store_false', dest='interactive', - default=True) + '--noinput', action='store_false', dest='interactive', default=True + ) def handle(self, *args, **options): project_name = options['name'] @@ -38,10 +37,12 @@ def handle(self, *args, **options): answer = 'n' if answer and answer[0].lower() == "y": project_template_path = ( - "https://github.com/oTree-org/oTree/archive/master.zip") + "https://github.com/oTree-org/oTree/archive/master.zip" + ) else: project_template_path = os.path.join( - os.path.dirname(otree.__file__), 'project_template') + os.path.dirname(otree.__file__), 'project_template' + ) options['template'] = project_template_path @@ -56,28 +57,12 @@ def handle(self, *args, **options): if os.path.exists(project_name): os.rmdir(project_name) - is_macos = sys.platform.startswith('darwin') - if is_macos and 'CERTIFICATE_VERIFY_FAILED' in str(exc): - py_major, py_minor = sys.version_info[:2] - msg = ( - 'CERTIFICATE_VERIFY_FAILED: ' - 'Before downloading the sample games, ' - 'you need to install SSL certificates. ' - 'Usually this can be resolved by entering this command:\n' - '/Applications/Python\\ {}.{}/Install\\ Certificates.command' - ).format(py_major, py_minor) - self.stdout.write(msg) - sys.exit(-1) raise - try: - pypi_updates_cli() - except: - pass # this assumes the 'directory' arg was unused, which will be true # for 99% of oTree users. msg = ( 'Created project folder.\n' 'Enter "cd {}" to move inside the project folder, ' - 'then start the server with "otree devserver".' # + 'then start the server with "otree devserver".' # ).format(project_name) self.stdout.write(msg) diff --git a/otree/management/commands/test.py b/otree/management/commands/test.py index bb27aa781..74ba64129 100644 --- a/otree/management/commands/test.py +++ b/otree/management/commands/test.py @@ -1,4 +1 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - from .bots import Command # noqa diff --git a/otree/management/commands/timeoutworker.py b/otree/management/commands/timeoutworker.py index f2cfddbc1..5c341ced1 100644 --- a/otree/management/commands/timeoutworker.py +++ b/otree/management/commands/timeoutworker.py @@ -1 +1 @@ -from .runprodserver2of2 import Command # noqa \ No newline at end of file +from .runprodserver2of2 import Command # noqa diff --git a/otree/management/commands/timeoutworkeronly.py b/otree/management/commands/timeoutworkeronly.py index f22ef5176..865ec7bc7 100644 --- a/otree/management/commands/timeoutworkeronly.py +++ b/otree/management/commands/timeoutworkeronly.py @@ -1,10 +1,6 @@ -#!/usr/bin/env python - # run the worker to enforce page timeouts # even if the user closes their browser -from huey.contrib.djhuey.management.commands.run_huey import ( - Command as HueyCommand -) +from huey.contrib.djhuey.management.commands.run_huey import Command as HueyCommand class Command(HueyCommand): @@ -15,5 +11,12 @@ def handle(self, *args, **options): # this code is also in asgi.py. it should be in both places, # to ensure the database is flushed in all circumstances. from huey.contrib.djhuey import HUEY + HUEY.flush() + # need to set USE_REDIS = True, because it uses the test client + # to submit pages, and if the next page has a timeout as well, + # its timeout task should be queued. + import otree.common + + otree.common.USE_REDIS = True super().handle(*args, **options) diff --git a/otree/management/commands/unzip.py b/otree/management/commands/unzip.py index 9fe0d3823..2e38a70fd 100644 --- a/otree/management/commands/unzip.py +++ b/otree/management/commands/unzip.py @@ -13,26 +13,24 @@ class Command(BaseCommand): help = "Unzip a zipped oTree project" def add_arguments(self, parser): - parser.add_argument( - 'zip_file', type=str, help="The .otreezip file") + parser.add_argument('zip_file', type=str, help="The .otreezip file") # it's good to require this arg because then it's obvious that the files # will be put in that subfolder, and not dumped in the current dir parser.add_argument( - 'output_folder', type=str, nargs='?', - help="What to call the new project folder") + 'output_folder', + type=str, + nargs='?', + help="What to call the new project folder", + ) def handle(self, **options): zip_file = options['zip_file'] output_folder = options['output_folder'] or auto_named_output_folder(zip_file) unzip(zip_file, output_folder) - msg = ( - f'Unzipped file. Enter this:\n' - f'cd {esc_fn(output_folder)}\n' - ) + msg = f'Unzipped file. Enter this:\n' f'cd {esc_fn(output_folder)}\n' logger.info(msg) - def run_from_argv(self, argv): ''' override this because the built-in django one executes system checks, @@ -59,6 +57,7 @@ def esc_fn(fn): return f'\"{fn}\"' return fn + def auto_named_output_folder(zip_file_name) -> str: default_folder_name = Path(zip_file_name).stem @@ -89,4 +88,3 @@ def unzip(zip_file: str, output_folder): with tarfile.open(zip_file) as tar: tar.extractall(output_folder) - diff --git a/otree/management/commands/webandworkers.py b/otree/management/commands/webandworkers.py index dedd8b24a..a3d10a36d 100644 --- a/otree/management/commands/webandworkers.py +++ b/otree/management/commands/webandworkers.py @@ -1 +1 @@ -from .runprodserver1of2 import Command # noqa \ No newline at end of file +from .runprodserver1of2 import Command # noqa diff --git a/otree/management/commands/zip.py b/otree/management/commands/zip.py index 0365bec7c..52d6f7cc0 100644 --- a/otree/management/commands/zip.py +++ b/otree/management/commands/zip.py @@ -26,6 +26,7 @@ # TODO: make sure we recognize and exclude virtualenvs, even if not called venv + def filter_func(tar_info: tarfile.TarInfo): path = tar_info.path @@ -94,12 +95,22 @@ def zip_project(project_path: Path): logger.error(str(exc)) sys.exit(1) - # w:gz - with tarfile.open(archive_name, 'w:gz') as tar: - # if i omit arcname, it nests the project 2 levels deep. - # if i say arcname=proj, it puts the whole project in a folder. - # if i say arcname='', it has 0 levels of nesting. - tar.add(project_path, arcname='', filter=filter_func) + # once Heroku uses py 3.7 by default, we can remove this runtime stuff. + runtime_txt = project_path / 'runtime.txt' + runtime_existed = runtime_txt.exists() + if not runtime_existed: + # don't use sys.version_info because it might be newer than what + # heroku supports + runtime_txt.write_text(f'python-3.7.3') + try: + with tarfile.open(archive_name, 'w:gz') as tar: + # if i omit arcname, it nests the project 2 levels deep. + # if i say arcname=proj, it puts the whole project in a folder. + # if i say arcname='', it has 0 levels of nesting. + tar.add(project_path, arcname='', filter=filter_func) + finally: + if not runtime_existed: + runtime_txt.unlink() logger.info(f'Saved your code into file "{archive_name}"') @@ -112,7 +123,8 @@ def get_non_comment_lines(f): return lines -class RequirementsError(Exception): pass +class RequirementsError(Exception): + pass def check_requirements_files(project_path: Path): @@ -127,9 +139,7 @@ def check_requirements_files(project_path: Path): reqs_base_exists = reqs_base_path.exists() if not reqs_path.exists(): - raise RequirementsError( - "You need a requirements.txt in your project folder" - ) + raise RequirementsError("You need a requirements.txt in your project folder") with reqs_path.open() as f: all_req_lines = get_non_comment_lines(f) @@ -172,9 +182,7 @@ def check_requirements_files(project_path: Path): already_seen = set() for ln in all_req_lines: - m = re.match( - '(^[\w-]+).*?', - ln) + m = re.match('(^[\w-]+).*?', ln) if m: package = m.group(1) if package in already_seen: @@ -188,4 +196,4 @@ def check_requirements_files(project_path: Path): f'"{package}" is listed more than once ' 'in your requirements.txt. ' ) - already_seen.add(package) \ No newline at end of file + already_seen.add(package) diff --git a/otree/matching.py b/otree/matching.py deleted file mode 100644 index 6f3565499..000000000 --- a/otree/matching.py +++ /dev/null @@ -1,48 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -"""Multiple algorithms for sorting oTree players - -""" - -# ============================================================================= -# IMPORTS -# ============================================================================= - -import itertools -import random - -from six.moves import range - -# ============================================================================= -# MATCH -# ============================================================================= - - -def by_rank(ranked_list, players_per_group): - ppg = players_per_group - players = ranked_list - group_matrix = [] - for i in range(0, len(players), ppg): - group_matrix.append(players[i:i + ppg]) - return group_matrix - - -def randomly(group_matrix, fixed_id_in_group=False): - """Random Uniform distribution of players in every group""" - - players = list(itertools.chain.from_iterable(group_matrix)) - sizes = [len(group) for group in group_matrix] - if sizes and any(size != sizes[0] for size in sizes): - raise ValueError( - 'This algorithm does not work with unevenly sized groups') - players_per_group = sizes[0] - - if fixed_id_in_group: - group_matrix = [list(col) for col in zip(*group_matrix)] - for column in group_matrix: - random.shuffle(column) - return list(zip(*group_matrix)) - else: - random.shuffle(players) - return by_rank(players, players_per_group) diff --git a/otree/middleware.py b/otree/middleware.py index 88ae5d1a7..fa0b28b86 100644 --- a/otree/middleware.py +++ b/otree/middleware.py @@ -1,10 +1,11 @@ from django.http import HttpResponseServerError import time -from otree.common_internal import missing_db_tables +from otree.common import missing_db_tables import logging logger = logging.getLogger('otree.perf') + def perf_middleware(get_response): # One-time configuration and initialization. @@ -47,8 +48,7 @@ def __call__(self, request): msg = ( "Your database is not ready. Try resetting the database " "(Missing tables for {}, and {} other models). " - ).format( - ', '.join(listed_tables), len(unlisted_tables)) + ).format(', '.join(listed_tables), len(unlisted_tables)) return HttpResponseServerError(msg) else: CheckDBMiddleware.synced = True diff --git a/otree/models/__init__.py b/otree/models/__init__.py index 2d1541ef3..0dda32a22 100644 --- a/otree/models/__init__.py +++ b/otree/models/__init__.py @@ -1,31 +1,14 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -from importlib import import_module +# it's OK to do this because we have .pyi files from django.db.models.signals import class_prepared from otree.db.models import * # noqa - -# NOTE: this imports the following submodules and then subclasses several -# classes importing is done via import_module rather than an ordinary import. -# The only reason for this is to hide the base classes from IDEs like PyCharm, -# so that those members/attributes don't show up in autocomplete, -# including all the built-in django model fields that an ordinary oTree -# programmer will never need or want. if this was a conventional Django -# project I wouldn't do it this way, but because oTree is aimed at newcomers -# who may need more assistance from their IDE, I want to try this approach out. -# this module is also a form of documentation of the public API. - -subsession_module = import_module('otree.models.subsession') -group_module = import_module('otree.models.group') -player_module = import_module('otree.models.player') - - -# so that oTree users don't see internal details -session_module = import_module('otree.models.session') -participant_module = import_module('otree.models.participant') +from otree.models.subsession import BaseSubsession +from otree.models.group import BaseGroup +from otree.models.player import BasePlayer +from otree.models.session import Session +from otree.models.participant import Participant def ensure_required_fields(sender, **kwargs): @@ -41,10 +24,3 @@ def ensure_required_fields(sender, **kwargs): class_prepared.connect(ensure_required_fields) - - -Session = session_module.Session -Participant = participant_module.Participant -BaseSubsession = subsession_module.BaseSubsession -BaseGroup = group_module.BaseGroup -BasePlayer = player_module.BasePlayer diff --git a/otree/models/fieldchecks.py b/otree/models/fieldchecks.py index 48f026530..ee521b0cd 100644 --- a/otree/models/fieldchecks.py +++ b/otree/models/fieldchecks.py @@ -1,7 +1,8 @@ from django.core.exceptions import FieldDoesNotExist +from django.db.models import Field -def ensure_field(model, name, field): +def ensure_field(model, name: str, field: Field): try: existing_field = model._meta.get_field(name) except FieldDoesNotExist: @@ -9,6 +10,8 @@ def ensure_field(model, name, field): else: if not isinstance(existing_field, field.__class__): raise TypeError( - '{model} requires a field with name {name} of type {type}.' - .format(model=model, name=name, type=field.__class__.__name__)) + '{model} requires a field with name {name} of type {type}.'.format( + model=model, name=name, type=field.__class__.__name__ + ) + ) return field diff --git a/otree/models/group.py b/otree/models/group.py index a99759d77..00944c1fe 100644 --- a/otree/models/group.py +++ b/otree/models/group.py @@ -1,7 +1,10 @@ from otree.db import models -from otree.common_internal import ( - get_models_module, in_round, in_rounds, InvalidRoundError, - +from otree.common import ( + get_models_module, + in_round, + in_rounds, + InvalidRoundError, + add_field_tracker, ) from otree.models.fieldchecks import ensure_field import django.core.exceptions @@ -20,8 +23,9 @@ class Meta: id_in_subsession = models.PositiveIntegerField(db_index=True) session = models.ForeignKey( - 'otree.Session', related_name='%(app_label)s_%(class)s', - on_delete=models.CASCADE + 'otree.Session', + related_name='%(app_label)s_%(class)s', + on_delete=models.CASCADE, ) round_number = models.PositiveIntegerField(db_index=True) @@ -37,7 +41,8 @@ def get_player_by_id(self, id_in_group): return self.player_set.get(id_in_group=id_in_group) except django.core.exceptions.ObjectDoesNotExist: raise ValueError( - 'No player with id_in_group {}'.format(id_in_group)) from None + 'No player with id_in_group {}'.format(id_in_group) + ) from None def get_player_by_role(self, role): for p in self.get_players(): @@ -53,30 +58,47 @@ def set_players(self, players_list): def in_round(self, round_number): try: - return in_round(type(self), round_number, session=self.session, - id_in_subsession=self.id_in_subsession) + return in_round( + type(self), + round_number, + session=self.session, + id_in_subsession=self.id_in_subsession, + ) except InvalidRoundError as exc: - msg = str(exc) + '; ' + ( - 'Hint: you should not use this ' - 'method if you are rearranging groups between rounds.' + msg = ( + str(exc) + + '; ' + + ( + 'Hint: you should not use this ' + 'method if you are rearranging groups between rounds.' + ) ) ExceptionClass = type(exc) raise ExceptionClass(msg) from None def in_rounds(self, first, last): try: - return in_rounds(type(self), first, last, session=self.session, - id_in_subsession=self.id_in_subsession) + return in_rounds( + type(self), + first, + last, + session=self.session, + id_in_subsession=self.id_in_subsession, + ) except InvalidRoundError as exc: - msg = str(exc) + '; ' + ( - 'Hint: you should not use this ' - 'method if you are rearranging groups between rounds.' + msg = ( + str(exc) + + '; ' + + ( + 'Hint: you should not use this ' + 'method if you are rearranging groups between rounds.' + ) ) ExceptionClass = type(exc) raise ExceptionClass(msg) from None def in_previous_rounds(self): - return self.in_rounds(1, self.round_number-1) + return self.in_rounds(1, self.round_number - 1) def in_all_rounds(self): return self.in_previous_rounds() + [self] @@ -88,8 +110,11 @@ def _ensure_required_fields(cls): model of the same app. """ subsession_model = '{app_label}.Subsession'.format( - app_label=cls._meta.app_label) + app_label=cls._meta.app_label + ) subsession_field = djmodels.ForeignKey( subsession_model, on_delete=models.CASCADE ) ensure_field(cls, 'subsession', subsession_field) + + add_field_tracker(cls) diff --git a/otree/models/participant.py b/otree/models/participant.py index 12e106219..f47db0d57 100644 --- a/otree/models/participant.py +++ b/otree/models/participant.py @@ -1,41 +1,41 @@ - -from django.db.models import permalink, Sum +from django.db import models as dj_models from django.urls import reverse -import otree.common_internal -from otree import constants_internal -from otree.common_internal import random_chars_8 +import otree.common +from otree.common import random_chars_8, FieldTrackerWithVarsSupport from otree.db import models from otree.models_concrete import ParticipantToPlayerLookup -from .varsmixin import ModelWithVars -class Participant(ModelWithVars): + +class Participant(models.Model): class Meta: ordering = ['pk'] app_label = "otree" index_together = ['session', 'mturk_worker_id', 'mturk_assignment_id'] - session = models.ForeignKey( - 'otree.Session', on_delete=models.CASCADE - ) + _ft = FieldTrackerWithVarsSupport() + vars: dict = models._PickleField(default=dict) + + + session = models.ForeignKey('otree.Session', on_delete=models.CASCADE) label = models.CharField( - max_length=50, null=True, doc=( + max_length=50, + null=True, + doc=( "Label assigned by the experimenter. Can be assigned by passing a " "GET param called 'participant_label' to the participant's start " "URL" - ) + ), ) id_in_session = models.PositiveIntegerField(null=True) payoff = models.CurrencyField(default=0) - time_started = models.DateTimeField(null=True) - user_type_in_url = constants_internal.user_type_participant - mturk_assignment_id = models.CharField( - max_length=50, null=True) + time_started = dj_models.DateTimeField(null=True) + mturk_assignment_id = models.CharField(max_length=50, null=True) mturk_worker_id = models.CharField(max_length=50, null=True) _index_in_subsessions = models.PositiveIntegerField(default=0, null=True) @@ -60,16 +60,22 @@ def _id_in_session(self): "would like to merge this dataset with those from another " "subsession in the same session, you should join on this field, " "which will be the same across subsessions." - ) + ), ) - visited = models.BooleanField( - default=False, db_index=True, - doc="""Whether this user's start URL was opened""" + default=False, db_index=True, doc="""Whether this user's start URL was opened""" ) - ip_address = models.GenericIPAddressField(null=True) + # deprecated on 2019-10-16. eventually get rid of this + @property + def ip_address(self): + return 'deprecated' + + @ip_address.setter + def ip_address(self, value): + if value: + raise ValueError('Do not store anything into participant.ip_address') # stores when the page was first visited _last_page_timestamp = models.PositiveIntegerField(null=True) @@ -80,17 +86,15 @@ def _id_in_session(self): # these are both for the admin # In the changelist, simply call these "page" and "app" - _current_page_name = models.CharField(max_length=200, null=True, - verbose_name='page') - _current_app_name = models.CharField(max_length=200, null=True, - verbose_name='app') + _current_page_name = models.CharField( + max_length=200, null=True, verbose_name='page' + ) + _current_app_name = models.CharField(max_length=200, null=True, verbose_name='app') # only to be displayed in the admin participants changelist - _round_number = models.PositiveIntegerField( - null=True - ) + _round_number = models.PositiveIntegerField(null=True) - _current_form_page_url = models.URLField() + _current_form_page_url = dj_models.URLField() _max_page_index = models.PositiveIntegerField() @@ -115,8 +119,7 @@ def player_lookup(self): # to log2(n). similar to the way arraylists grow. num_extra_lookups = len(self._player_lookups) + 1 qs = ParticipantToPlayerLookup.objects.filter( - participant=self, - page_index__range=(index, index+num_extra_lookups) + participant=self, page_index__range=(index, index + num_extra_lookups) ).values() for player_lookup in qs: self._player_lookups[player_lookup['page_index']] = player_lookup @@ -134,10 +137,10 @@ def get_players(self): lst = [] app_sequence = self.session.config['app_sequence'] for app in app_sequence: - models_module = otree.common_internal.get_models_module(app) - players = models_module.Player.objects.filter( - participant=self - ).order_by('round_number') + models_module = otree.common.get_models_module(app) + players = models_module.Player.objects.filter(participant=self).order_by( + 'round_number' + ) lst.extend(list(players)) return lst @@ -156,33 +159,13 @@ def _url_i_should_be_on(self): return self._start_url() if self._index_in_pages <= self._max_page_index: return self.player_lookup()['url'] - if self.session.mturk_HITId: - assignment_id = self.mturk_assignment_id - if self.session.mturk_use_sandbox: - url = 'https://workersandbox.mturk.com/mturk/externalSubmit' - else: - url = "https://www.mturk.com/mturk/externalSubmit" - url = otree.common_internal.add_params_to_url( - url, - { - 'assignmentId': assignment_id, - 'extra_param': '1' # required extra param? - } - ) - return url return reverse('OutOfRangeNotification') def _start_url(self): - return otree.common_internal.participant_start_url(self.code) + return otree.common.participant_start_url(self.code) def payoff_in_real_world_currency(self): - return self.payoff.to_real_world_currency( - self.session - ) - - def money_to_pay(self): - '''deprecated''' - return self.payoff_plus_participation_fee() + return self.payoff.to_real_world_currency(self.session) def payoff_plus_participation_fee(self): return self.session._get_payoff_plus_participation_fee(self.payoff) diff --git a/otree/models/player.py b/otree/models/player.py index d14ef0303..da75ca5cd 100644 --- a/otree/models/player.py +++ b/otree/models/player.py @@ -1,12 +1,11 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -from otree.common_internal import ( - get_models_module, in_round, in_rounds) +from otree.common import ( + add_field_tracker, + in_round, + in_rounds, +) from otree.db import models from otree.models.fieldchecks import ensure_field -from django.db import models as djmodels class BasePlayer(models.Model): @@ -22,25 +21,27 @@ class Meta: id_in_group = models.PositiveIntegerField( null=True, db_index=True, - doc=("Index starting from 1. In multiplayer games, " - "indicates whether this is player 1, player 2, etc.") + doc=( + "Index starting from 1. In multiplayer games, " + "indicates whether this is player 1, player 2, etc." + ), ) # don't modify this directly! Set player.payoff instead _payoff = models.CurrencyField( - null=True, - doc="""The payoff the player made in this subsession""", - default=0 + null=True, doc="""The payoff the player made in this subsession""", default=0 ) participant = models.ForeignKey( - 'otree.Participant', related_name='%(app_label)s_%(class)s', - on_delete=models.CASCADE + 'otree.Participant', + related_name='%(app_label)s_%(class)s', + on_delete=models.CASCADE, ) session = models.ForeignKey( - 'otree.Session', related_name='%(app_label)s_%(class)s', - on_delete=models.CASCADE + 'otree.Session', + related_name='%(app_label)s_%(class)s', + on_delete=models.CASCADE, ) round_number = models.PositiveIntegerField(db_index=True) @@ -106,12 +107,15 @@ def _ensure_required_fields(cls): ``Group`` model of the same app. """ subsession_model = '{app_label}.Subsession'.format( - app_label=cls._meta.app_label) + app_label=cls._meta.app_label + ) subsession_field = models.ForeignKey(subsession_model, on_delete=models.CASCADE) ensure_field(cls, 'subsession', subsession_field) - group_model = '{app_label}.Group'.format( - app_label=cls._meta.app_label) - group_field = models.ForeignKey(group_model, null=True, on_delete=models.CASCADE) + group_model = '{app_label}.Group'.format(app_label=cls._meta.app_label) + group_field = models.ForeignKey( + group_model, null=True, on_delete=models.CASCADE + ) ensure_field(cls, 'group', group_field) + add_field_tracker(cls) diff --git a/otree/models/session.py b/otree/models/session.py index 653d25566..8f35f14a0 100644 --- a/otree/models/session.py +++ b/otree/models/session.py @@ -1,27 +1,21 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - import logging -import channels -import json -from django.urls import reverse -from asgiref.sync import async_to_sync -from channels.layers import get_channel_layer -from otree.channels.utils import auto_advance_group - -from otree import constants_internal -import otree.common_internal -from otree.common_internal import ( - random_chars_8, random_chars_10, get_admin_secret_code, - get_app_label_from_name -) import time -from otree.db import models -from otree.models_concrete import ParticipantToPlayerLookup, RoomToSession -from django.template.loader import select_template from django.template import TemplateDoesNotExist -from .varsmixin import ModelWithVars +from django.template.loader import select_template + +import otree.common +import otree.constants +from otree.channels.utils import auto_advance_group +from otree.common import ( + random_chars_8, + random_chars_10, + get_admin_secret_code, + get_app_label_from_name, + FieldTrackerWithVarsSupport, +) +from otree.db import models +from otree.models_concrete import RoomToSession logger = logging.getLogger('otree') @@ -29,24 +23,25 @@ ADMIN_SECRET_CODE = get_admin_secret_code() -class Session(ModelWithVars): +class Session(models.Model): class Meta: app_label = "otree" # if i don't set this, it could be in an unpredictable order ordering = ['pk'] - _pickle_fields = ['vars', 'config'] - config = models._PickleField(default=dict, null=True) # type: dict + _ft = FieldTrackerWithVarsSupport() + vars: dict = models._PickleField(default=dict) + config: dict = models._PickleField(default=dict, null=True) # label of this session instance label = models.CharField( - max_length=300, null=True, blank=True, - help_text='For internal record-keeping') + max_length=300, null=True, blank=True, help_text='For internal record-keeping' + ) experimenter_name = models.CharField( - max_length=300, null=True, blank=True, - help_text='For internal record-keeping') + max_length=300, null=True, blank=True, help_text='For internal record-keeping' + ) code = models.CharField( default=random_chars_8, @@ -54,46 +49,52 @@ class Meta: # set non-nullable, until we make our CharField non-nullable null=False, unique=True, - doc="Randomly generated unique identifier for the session.") + doc="Randomly generated unique identifier for the session.", + ) mturk_HITId = models.CharField( - max_length=300, null=True, blank=True, - help_text='Hit id for this session on MTurk') + max_length=300, + null=True, + blank=True, + help_text='Hit id for this session on MTurk', + ) mturk_HITGroupId = models.CharField( - max_length=300, null=True, blank=True, - help_text='Hit id for this session on MTurk') + max_length=300, + null=True, + blank=True, + help_text='Hit id for this session on MTurk', + ) # since workers can drop out number of participants on server should be # greater than number of participants on mturk # value -1 indicates that this session it not intended to run on mturk mturk_num_participants = models.IntegerField( - default=-1, - help_text="Number of participants on MTurk") + default=-1, help_text="Number of participants on MTurk" + ) mturk_use_sandbox = models.BooleanField( - default=True, - help_text="Should this session be created in mturk sandbox?") + default=True, help_text="Should this session be created in mturk sandbox?" + ) # use Float instead of DateTime because DateTime # is a pain to work with (e.g. naive vs aware datetime objects) # and there is no need here for DateTime - mturk_expiration = models.FloatField( - null=True - ) + mturk_expiration = models.FloatField(null=True) archived = models.BooleanField( default=False, db_index=True, - doc=("If set to True the session won't be visible on the " - "main ViewList for sessions")) + doc=( + "If set to True the session won't be visible on the " + "main ViewList for sessions" + ), + ) comment = models.TextField(blank=True) _anonymous_code = models.CharField( - default=random_chars_10, max_length=10, null=False, db_index=True) - - def use_browser_bots(self): - return self.participant_set.filter(is_browser_bot=True).exists() + default=random_chars_10, max_length=10, null=False, db_index=True + ) is_demo = models.BooleanField(default=False) @@ -117,6 +118,10 @@ def real_world_currency_per_point(self): but still useful internally (like data export)''' return self.config['real_world_currency_per_point'] + @property + def use_browser_bots(self): + return self.config.get('use_browser_bots', False) + def is_mturk(self): return (not self.is_demo) and (self.mturk_num_participants > 0) @@ -124,7 +129,7 @@ def get_subsessions(self): lst = [] app_sequence = self.config['app_sequence'] for app in app_sequence: - models_module = otree.common_internal.get_models_module(app) + models_module = otree.common.get_models_module(app) subsessions = models_module.Subsession.objects.filter( session=self ).order_by('round_number') @@ -144,7 +149,6 @@ def mturk_worker_url(self): # not work is if the HIT was deleted from the server, but in that case, # the HIT itself should be canceled. - # 2018-06-04: # the format seems to have changed to this: # https://worker.mturk.com/projects/{group_id}/tasks?ref=w_pl_prvw @@ -153,8 +157,7 @@ def mturk_worker_url(self): # because it's more precise. subdomain = "workersandbox" if self.mturk_use_sandbox else 'www' return "https://{}.mturk.com/mturk/preview?groupId={}".format( - subdomain, - self.mturk_HITGroupId + subdomain, self.mturk_HITGroupId ) def mturk_is_expired(self): @@ -172,6 +175,7 @@ def advance_last_place_participants(self): # so best to do it only here # it gets cached import django.test + client = django.test.Client() participants = self.get_participants() @@ -189,8 +193,7 @@ def advance_last_place_participants(self): last_place_page_index = min([p._index_in_pages for p in participants]) last_place_participants = [ - p for p in participants - if p._index_in_pages == last_place_page_index + p for p in participants if p._index_in_pages == last_place_page_index ] for p in last_place_participants: @@ -200,17 +203,19 @@ def advance_last_place_participants(self): resp = client.post( current_form_page_url, data={ - constants_internal.timeout_happened: True, - constants_internal.admin_secret_code: ADMIN_SECRET_CODE + otree.constants.timeout_happened: True, + otree.constants.admin_secret_code: ADMIN_SECRET_CODE, }, - follow=True + follow=True, ) # not sure why, but many users are getting HttpResponseNotFound if resp.status_code >= 400: - msg = ('Submitting page {} failed, ' + msg = ( + 'Submitting page {} failed, ' 'returned HTTP status code {}.'.format( current_form_page_url, resp.status_code - )) + ) + ) content = resp.content if len(content) < 600: msg += ' response content: {}'.format(content) @@ -231,14 +236,13 @@ def advance_last_place_participants(self): # do the auto-advancing here, # rather than in increment_index_in_pages, # because it's only needed here. - otree.channels.utils.sync_group_send( - auto_advance_group(p.code), - {'type': 'auto_advanced'} + otree.channels.utils.sync_group_send_wrapper( + type='auto_advanced', group=auto_advance_group(p.code), event={} ) - def get_room(self): from otree.room import ROOM_DICT + try: room_name = RoomToSession.objects.get(session=self).room_name return ROOM_DICT[room_name] @@ -251,23 +255,19 @@ def _get_payoff_plus_participation_fee(self, payoff): Useful to define it here, for data export ''' - return ( - self.config['participation_fee'] + - payoff.to_real_world_currency(self) - ) + return self.config['participation_fee'] + payoff.to_real_world_currency(self) def _set_admin_report_app_names(self): admin_report_app_names = [] num_rounds_list = [] for app_name in self.config['app_sequence']: - models_module = otree.common_internal.get_models_module(app_name) + models_module = otree.common.get_models_module(app_name) app_label = get_app_label_from_name(app_name) try: - select_template([ - f'{app_label}/admin_report.html', - f'{app_label}/AdminReport.html', - ]) + select_template( + [f'{app_label}/admin_report.html', f'{app_label}/AdminReport.html'] + ) except TemplateDoesNotExist: pass else: diff --git a/otree/models/subsession.py b/otree/models/subsession.py index 8e9d5c10b..cd592c67f 100644 --- a/otree/models/subsession.py +++ b/otree/models/subsession.py @@ -1,18 +1,12 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -import six from django.db.models import Prefetch + +import otree.common from otree.db import models -from otree.common_internal import ( - get_models_module, in_round, in_rounds) -from otree import matching +from otree.common import get_models_module, in_round, in_rounds import copy -from collections import defaultdict -from otree.common_internal import has_group_by_arrival_time -from django.template.loader import select_template -from django.template import TemplateDoesNotExist, Template -from typing import Optional +from otree.common import has_group_by_arrival_time, add_field_tracker +from django.apps import apps + class GroupMatrixError(ValueError): pass @@ -21,6 +15,7 @@ class GroupMatrixError(ValueError): class RoundMismatchError(GroupMatrixError): pass + class BaseSubsession(models.Model): """Base class for all Subsessions. """ @@ -31,7 +26,9 @@ class Meta: index_together = ['session', 'round_number'] session = models.ForeignKey( - 'otree.Session', related_name='%(app_label)s_%(class)s', null=True, + 'otree.Session', + related_name='%(app_label)s_%(class)s', + null=True, on_delete=models.CASCADE, ) @@ -40,19 +37,17 @@ class Meta: doc='''If this subsession is repeated (i.e. has multiple rounds), this field stores the position of this subsession, among subsessions in the same app. - ''' + ''', ) def in_round(self, round_number): - return in_round(type(self), round_number, - session=self.session, - ) + return in_round(type(self), round_number, session=self.session) def in_rounds(self, first, last): return in_rounds(type(self), first, last, session=self.session) def in_previous_rounds(self): - return self.in_rounds(1, self.round_number-1) + return self.in_rounds(1, self.round_number - 1) def in_all_rounds(self): return self.in_previous_rounds() + [self] @@ -70,10 +65,14 @@ def get_group_matrix(self): players_prefetch = Prefetch( 'player_set', queryset=self._PlayerClass().objects.order_by('id_in_group'), - to_attr='_ordered_players') - return [group._ordered_players - for group in self.group_set.order_by('id_in_subsession') - .prefetch_related(players_prefetch)] + to_attr='_ordered_players', + ) + return [ + group._ordered_players + for group in self.group_set.order_by('id_in_subsession').prefetch_related( + players_prefetch + ) + ] def set_group_matrix(self, matrix): """ @@ -84,9 +83,7 @@ def set_group_matrix(self, matrix): try: players_flat = [p for g in matrix for p in g] except TypeError: - raise GroupMatrixError( - 'Group matrix must be a list of lists.' - ) from None + raise GroupMatrixError('Group matrix must be a list of lists.') from None try: matrix_pks = sorted(p.pk for p in players_flat) except AttributeError: @@ -113,19 +110,19 @@ def set_group_matrix(self, matrix): ) from None else: existing_pks = list( - self.player_set.values_list( - 'pk', flat=True - ).order_by('pk')) + self.player_set.values_list('pk', flat=True).order_by('pk') + ) if matrix_pks != existing_pks: wrong_round_numbers = [ - p.round_number for p in players_flat - if p.round_number != self.round_number] + p.round_number + for p in players_flat + if p.round_number != self.round_number + ] if wrong_round_numbers: raise GroupMatrixError( 'You are setting the groups for round {}, ' 'but the matrix contains players from round {}.'.format( - self.round_number, - wrong_round_numbers[0] + self.round_number, wrong_round_numbers[0] ) ) raise GroupMatrixError( @@ -143,8 +140,11 @@ def set_group_matrix(self, matrix): GroupClass = self._GroupClass() for i, row in enumerate(matrix, start=1): group = GroupClass.objects.create( - subsession=self, id_in_subsession=i, - session=self.session, round_number=self.round_number) + subsession=self, + id_in_subsession=i, + session=self.session, + round_number=self.round_number, + ) group.set_players(row) @@ -157,9 +157,8 @@ def group_like_round(self, round_number): ).prefetch_related( Prefetch( 'player_set', - queryset=self._PlayerClass().objects.order_by( - 'id_in_group'), - to_attr='_ordered_players' + queryset=self._PlayerClass().objects.order_by('id_in_group'), + to_attr='_ordered_players', ) ) ] @@ -171,72 +170,23 @@ def group_like_round(self, round_number): self.set_group_matrix(group_matrix) - def new_group_like_round(self, round_number): - '''test this, could work''' - matrix = self.in_round(round_number).get_group_matrix() - for row in matrix: - for col in row: - matrix[row][col] = matrix[row][col].id_in_subsession - self.set_group_matrix(matrix) - - ''' - def group_like_round(self, round_number): - PlayerClass = self._PlayerClass() - last_round_info = PlayerClass.objects.filter( - session_id=self.session_id, - round_number=round_number - ).values( - 'id_in_group', 'participant_id', 'group__id_in_subsession' - ).order_by('group__id_in_subsession', 'id_in_group') - - player_lookups = {p.participant_id: p for p in self.get_players()} - - self.player_set.update(group=None) - self.group_set.all().delete() - - # UNFINISHED - GroupClass = self._GroupClass() - for i, row in enumerate(matrix, start=1): - group = GroupClass.objects.create( - subsession=self, id_in_subsession=i, - session=self.session, round_number=self.round_number) - - group.set_players(row) - ''' - - - def set_groups(self, matrix): - '''renamed this to set_group_matrix, but keeping in for compat''' - return self.set_group_matrix(matrix) - @property def _Constants(self): return get_models_module(self._meta.app_config.name).Constants def _GroupClass(self): - return models.get_model(self._meta.app_config.label, 'Group') + return apps.get_model(self._meta.app_config.label, 'Group') def _PlayerClass(self): - return models.get_model(self._meta.app_config.label, 'Player') + return apps.get_model(self._meta.app_config.label, 'Player') @classmethod def _has_group_by_arrival_time(cls): return has_group_by_arrival_time(cls._meta.app_config.name) - def group_randomly(self, *, fixed_id_in_group=False): group_matrix = self.get_group_matrix() - group_matrix = matching.randomly( - group_matrix, - fixed_id_in_group) - self.set_group_matrix(group_matrix) - - def _group_by_rank(self, ranked_list): - # FIXME: delete this - group_matrix = matching.by_rank( - ranked_list, - self._Constants.players_per_group - ) + group_matrix = otree.common._group_randomly(group_matrix, fixed_id_in_group) self.set_group_matrix(group_matrix) def before_session_starts(self): @@ -248,3 +198,7 @@ def creating_session(self): def vars_for_admin_report(self): return {} + + @classmethod + def _ensure_required_fields(cls): + add_field_tracker(cls) diff --git a/otree/models/varsmixin.py b/otree/models/varsmixin.py deleted file mode 100644 index 1c872384b..000000000 --- a/otree/models/varsmixin.py +++ /dev/null @@ -1,46 +0,0 @@ -from otree.db import models - -class _SaveTheChangeWithCustomFieldSupport: - ''' - 2017-08-07: kept around because old migrations files reference it. - after a few months when i squash migrations, - the references to this will be deleted, so i can delete it. - - 2017-09-05: I found a bug with NumPy + SaveTheChange; - https://github.com/karanlyons/django-save-the-change/issues/27 - So I need to use this again. Implementing a simplified version of what - Gregor made a while back. - ''' - - _pickle_fields = ['vars'] - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._save_the_change_store_initial_pickle_fields() - - def save(self, *args, **kwargs): - self._save_the_change_check_pickle_field_changes() - return super().save(*args, **kwargs) - - def _save_the_change_store_initial_pickle_fields(self): - self._initial_prep_values = {} - for field_name in self._pickle_fields: - field = self._meta.get_field(field_name) - self._initial_prep_values[field_name] = field.get_prep_value( - getattr(self, field_name)) - - def _save_the_change_check_pickle_field_changes(self): - for field_name in self._pickle_fields: - field = self._meta.get_field(field_name) - new_value = field.get_prep_value(getattr(self, field_name)) - initial_prep_value = self._initial_prep_values[field_name] - if new_value != initial_prep_value: - self._changed_fields[field_name] = field.to_python(initial_prep_value) - - -class ModelWithVars(_SaveTheChangeWithCustomFieldSupport, models.Model): - - class Meta: - abstract = True - - vars = models._PickleField(default=dict) # type: dict diff --git a/otree/models_concrete.py b/otree/models_concrete.py index 6c7dcf9c4..ec2a523a0 100644 --- a/otree/models_concrete.py +++ b/otree/models_concrete.py @@ -1,6 +1,7 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- import time +from collections import defaultdict +from typing import Iterable + from django.db import models @@ -11,7 +12,10 @@ class Meta: app_name = models.CharField(max_length=300) page_index = models.PositiveIntegerField() page_name = models.CharField(max_length=300) - time_stamp = models.PositiveIntegerField() + # it needs a default, otherwise i get "added a non-nullable field without default" + # eventually i can remove the default, if I am sure people did not skip over all + # the intermediate versions. + epoch_time = models.PositiveIntegerField(null=True) seconds_on_page = models.PositiveIntegerField() subsession_pk = models.PositiveIntegerField() participant = models.ForeignKey('otree.Participant', on_delete=models.CASCADE) @@ -19,6 +23,15 @@ class Meta: auto_submitted = models.BooleanField() +class WaitPagePassage(models.Model): + participant = models.ForeignKey('otree.Participant', on_delete=models.CASCADE) + session = models.ForeignKey('otree.Session', on_delete=models.CASCADE) + # don't set default=time.time because that's harder to patch + epoch_time = models.PositiveIntegerField(null=True) + # if False, means they exit the wait page + is_enter = models.BooleanField() + + class PageTimeout(models.Model): class Meta: app_label = "otree" @@ -70,9 +83,7 @@ class ParticipantLockModel(models.Model): class Meta: app_label = "otree" - participant_code = models.CharField( - max_length=16, unique=True - ) + participant_code = models.CharField(max_length=16, unique=True) locked = models.BooleanField(default=False) @@ -98,7 +109,6 @@ class Meta: session = models.ForeignKey('otree.Session', on_delete=models.CASCADE) - class ParticipantRoomVisit(models.Model): class Meta: app_label = "otree" @@ -127,12 +137,46 @@ class Meta: channel = models.CharField(max_length=255) # related_name necessary to disambiguate with otreechat add on participant = models.ForeignKey( - 'otree.Participant', related_name='chat_messages_core', - on_delete=models.CASCADE + 'otree.Participant', related_name='chat_messages_core', on_delete=models.CASCADE ) nickname = models.CharField(max_length=255) # call it 'body' instead of 'message' or 'content' because those terms # are already used by channels body = models.TextField() - timestamp = models.FloatField(default=time.time) \ No newline at end of file + timestamp = models.FloatField(default=time.time) + + +def add_time_spent_waiting(participants): + session_passages_qs = WaitPagePassage.objects.filter( + participant__in=participants + ).order_by('id') + _add_time_spent_waiting_inner( + participants=participants, session_passages_qs=session_passages_qs + ) + + +def _add_time_spent_waiting_inner( + *, participants, session_passages_qs: Iterable[WaitPagePassage] +): + '''adds the attribute to each participant object so it can be shown in the template''' + + session_passages = defaultdict(list) + for passage in session_passages_qs: + session_passages[passage.participant_id].append(passage) + + for participant in participants: + total = 0 + enter_time = None + passages = session_passages.get(participant.id, []) + for p in passages: + if p.is_enter and not enter_time: + enter_time = p.epoch_time + if not p.is_enter and enter_time: + total += p.epoch_time - enter_time + enter_time = None + # means they are still waiting + if enter_time: + total += time.time() - enter_time + participant._is_frozen = False + participant.waiting_seconds = int(total) diff --git a/otree/project_template/manage.py b/otree/project_template/manage.py index a6f04e9cf..d6f4caac3 100644 --- a/otree/project_template/manage.py +++ b/otree/project_template/manage.py @@ -7,4 +7,5 @@ os.environ.setdefault("DJANGO_SETTINGS_MODULE", "settings") from otree.management.cli import execute_from_command_line + execute_from_command_line(sys.argv, script_file=__file__) diff --git a/otree/project_template/settings.py b/otree/project_template/settings.py index 9ae40eb2b..8a8815b14 100644 --- a/otree/project_template/settings.py +++ b/otree/project_template/settings.py @@ -5,19 +5,17 @@ # the session config can be accessed from methods in your apps as self.session.config, # e.g. self.session.config['participation_fee'] -SESSION_CONFIG_DEFAULTS = { - 'real_world_currency_per_point': 1.00, - 'participation_fee': 0.00, - 'doc': "", -} +SESSION_CONFIG_DEFAULTS = dict( + real_world_currency_per_point=1.00, participation_fee=0.00, doc="" +) SESSION_CONFIGS = [ - #{ - # 'name': 'public_goods', - # 'display_name': "Public Goods", - # 'num_demo_participants': 3, - # 'app_sequence': ['public_goods', 'payment_info'], - #}, + # dict( + # name='public_goods', + # display_name="Public Goods", + # num_demo_participants=3, + # app_sequence=['public_goods', 'payment_info'] + # ), ] diff --git a/otree/room.py b/otree/room.py index fdc776a57..20178bb73 100644 --- a/otree/room.py +++ b/otree/room.py @@ -1,21 +1,18 @@ -import codecs -from collections import OrderedDict - -import schema - +from pathlib import Path from otree.models_concrete import RoomToSession -from otree.common_internal import ( - add_params_to_url, make_hash, validate_alphanumeric) +from otree.common import add_params_to_url, make_hash, validate_alphanumeric from django.conf import settings from django.urls import reverse from django.db import transaction -class Room(object): - def __init__(self, name, display_name, use_secure_urls, participant_label_file=None): +class Room: + def __init__( + self, name, display_name, use_secure_urls=False, participant_label_file=None + ): self.name = validate_alphanumeric( - name, - identifier_description='settings.ROOMS room name') + name, identifier_description='settings.ROOMS room name' + ) if use_secure_urls and not participant_label_file: raise ValueError( 'Room "{}": you must either set "participant_label_file", ' @@ -31,8 +28,11 @@ def has_session(self): def get_session(self): try: - return RoomToSession.objects.select_related('session').get( - room_name=self.name).session + return ( + RoomToSession.objects.select_related('session') + .get(room_name=self.name) + .session + ) except RoomToSession.DoesNotExist: return None @@ -40,65 +40,23 @@ def set_session(self, session): with transaction.atomic(): RoomToSession.objects.filter(room_name=self.name).delete() if session: - RoomToSession.objects.create( - room_name=self.name, - session=session - ) + RoomToSession.objects.create(room_name=self.name, session=session) def has_participant_labels(self): return bool(self.participant_label_file) def get_participant_labels(self): - ''' - Decided to just re-read the file on every request, - rather than loading in the DB. Reasons: - - (1) Simplifies the code; we don't need an ExpectedRoomParticipant model, - and don't need to load data into there (which involves considerations - of race conditions) - (2) Don't need any complicated rule deciding when to reload the file, - whether it's upon starting the process or resetting the database, - or both. Should the status be stored in the DB or in the process? - (3) Checking if a given participant label is in the file is actually faster - than looking it up in the DB table, even with .filter() and and index! - (tested on Postgres with 10000 iterations: 17s vs 18s) - ''' - - # if i refactor this, i should use chardet instead - encodings = ['ascii', 'utf-8', 'utf-16'] - for e in encodings: - try: - plabel_path = self.participant_label_file - with codecs.open(plabel_path, "r", encoding=e) as f: - seen = set() - labels = [] - for line in f: - label = line.strip() - if not label: - continue - validate_alphanumeric( - label, - identifier_description='participant label' - ) - if label not in seen: - labels.append(label) - seen.add(label) - except UnicodeDecodeError: - continue - except FileNotFoundError: - msg = ( - 'settings.ROOMS: The room "{}" references ' - ' nonexistent participant_label_file "{}".' - ) - raise FileNotFoundError( - msg.format(self.name, self.participant_label_file) - ) from None - else: - return labels - raise Exception( - 'settings.ROOMS: participant_label_file "{}" ' - 'not encoded correctly.'.format(self.participant_label_file) + lines = ( + Path(self.participant_label_file).read_text(encoding='utf8').splitlines() ) + labels = [] + for line in lines: + label = line.strip() + if label: + validate_alphanumeric(label, identifier_description='participant label') + labels.append(label) + # eliminate duplicates + return list(dict.fromkeys(labels)) def get_room_wide_url(self, request): url = reverse('AssignVisitorToRoom', args=(self.name,)) @@ -120,44 +78,16 @@ def get_participant_urls(self, request): return participant_urls -def augment_room(room, ROOM_DEFAULTS): - new_room = {} - new_room.update(ROOM_DEFAULTS) - new_room.update(room) - return new_room - - def get_room_dict(): - room_defaults_schema = schema.Schema( - { - schema.Optional('use_secure_urls', default=False): bool, - schema.Optional('participant_label_file'): str, - } - ) - - room_schema = schema.Schema( - { - 'name': str, - 'display_name': str, - schema.Optional('use_secure_urls'): bool, - schema.Optional('participant_label_file'): str, - } - ) - - ROOM_DICT = OrderedDict() ROOM_DEFAULTS = getattr(settings, 'ROOM_DEFAULTS', {}) - try: - ROOM_DEFAULTS = room_defaults_schema.validate(ROOM_DEFAULTS) - except schema.SchemaError as e: - raise (ValueError('settings.ROOM_DEFAULTS: {}'.format(e))) from None - for room in getattr(settings, 'ROOMS', []): - room = augment_room(room, ROOM_DEFAULTS) - try: - room = room_schema.validate(room) - except schema.SchemaError as e: - raise(ValueError('settings.ROOMS: {}'.format(e))) from None - room_object = Room(**room) + ROOMS = getattr(settings, 'ROOMS', []) + ROOM_DICT = {} + for room in ROOMS: + # extra layer in case ROOM_DEFAULTS has the same key + # as a room + room_object = Room(**dict(ROOM_DEFAULTS, **room)) ROOM_DICT[room_object.name] = room_object return ROOM_DICT + ROOM_DICT = get_room_dict() diff --git a/otree/session.py b/otree/session.py index f543acebc..526e35a0e 100644 --- a/otree/session.py +++ b/otree/session.py @@ -1,28 +1,24 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -import random -import sys -from functools import reduce from collections import OrderedDict +from collections import defaultdict from decimal import Decimal -import warnings -from django.urls import reverse +from functools import reduce +from typing import List, Dict from django.conf import settings from django.db import transaction -from django.db.utils import OperationalError +import otree.bots.browser +import otree.common import otree.db.idmap +from otree import common +from otree.common import ( + get_models_module, + get_app_constants, + validate_alphanumeric, + get_bots_module, +) +from otree.currency import RealWorldCurrency from otree.models import Participant, Session -from otree.common_internal import ( - get_models_module, get_app_constants, validate_alphanumeric, - get_bots_module) -from otree import common_internal -import otree.common_internal -from otree.common import RealWorldCurrency from otree.models_concrete import ParticipantLockModel, ParticipantToPlayerLookup -import otree.bots.browser -from collections import defaultdict def gcd(a, b): @@ -76,57 +72,13 @@ def get_num_bot_cases(self): def clean(self): - required_keys = [ - 'name', - 'app_sequence', - 'num_demo_participants', - 'participation_fee', - 'real_world_currency_per_point', - ] - - for key in required_keys: - if key not in self: - raise SessionConfigError( - 'settings.SESSION_CONFIGS: all configs must have a ' - '"{}"'.format(key) - ) - - datatypes = { - 'app_sequence': list, - 'num_demo_participants': int, - 'name': str, - } - - for key, datatype in datatypes.items(): - if not isinstance(self[key], datatype): - msg = ( - 'SESSION_CONFIGS "{}": ' - 'the entry "{}" must be of type {}' - ).format(self['name'], key, datatype.__name__) + for k in ['name', 'app_sequence', 'num_demo_participants']: + if k not in self: + msg = f'Session config is missing "{k}"' raise SessionConfigError(msg) - # Allow non-ASCII chars in session config keys, because they are - # configurable in the admin, so need to be readable by non-English - # speakers. However, don't allow punctuation, spaces, etc. - # They make it harder to reason about and could cause problems - # later on. also could uglify the user's code. - - INVALID_IDENTIFIER_MSG = ( - 'Key "{}" in settings.SESSION_CONFIGS ' - 'must not contain spaces, punctuation, ' - 'or other special characters. ' - 'It can contain non-English characters, ' - 'but it must be a valid Python variable name ' - 'according to string.isidentifier().' - ) - - for key in self: - if not key.isidentifier(): - raise SessionConfigError(INVALID_IDENTIFIER_MSG.format(key)) - validate_alphanumeric( - self['name'], - identifier_description='settings.SESSION_CONFIGS name' + self['name'], identifier_description='settings.SESSION_CONFIGS name' ) app_sequence = self['app_sequence'] @@ -136,18 +88,19 @@ def clean(self): 'app_sequence of "{}" ' 'must not contain duplicate elements. ' 'If you want multiple rounds, ' - 'you should set Constants.num_rounds.') + 'you should set Constants.num_rounds.' + ) raise SessionConfigError(msg.format(self['name'])) if len(app_sequence) == 0: raise SessionConfigError( - 'settings.SESSION_CONFIGS: app_sequence cannot be empty.') + 'settings.SESSION_CONFIGS: app_sequence cannot be empty.' + ) self.setdefault('display_name', self['name']) self.setdefault('doc', '') - self['participation_fee'] = RealWorldCurrency( - self['participation_fee']) + self['participation_fee'] = RealWorldCurrency(self['participation_fee']) def app_sequence_display(self): app_sequence = [] @@ -155,13 +108,13 @@ def app_sequence_display(self): models_module = get_models_module(app_name) num_rounds = models_module.Constants.num_rounds if num_rounds > 1: - formatted_app_name = '{} ({} rounds)'.format( - app_name, num_rounds) + formatted_app_name = '{} ({} rounds)'.format(app_name, num_rounds) else: formatted_app_name = app_name subsssn = { 'doc': getattr(models_module, 'doc', ''), - 'name': formatted_app_name} + 'name': formatted_app_name, + } app_sequence.append(subsssn) return app_sequence @@ -190,10 +143,12 @@ def custom_editable_fields(self): # so i'll just put a general recommendation in the docs return [ - k for k, v in self.items() + k + for k, v in self.items() if k not in self.non_editable_fields and k not in self.builtin_editable_fields() - and type(v) in [bool, int, float, str]] + and type(v) in [bool, int, float, str] + ] def editable_fields(self): return self.builtin_editable_fields() + self.custom_editable_fields() @@ -234,39 +189,53 @@ def editable_field_html(self, field_name): attrs = [ "type='text'", "value='{}'".format(existing_value), - "class='form-control'" + "class='form-control'", ] html = ''' {} - '''.format(field_name, ' '.join(base_attrs + attrs)) + '''.format( + field_name, ' '.join(base_attrs + attrs) + ) return html def builtin_editable_fields_html(self): - return [self.editable_field_html(k) - for k in self.builtin_editable_fields()] + return [self.editable_field_html(k) for k in self.builtin_editable_fields()] def custom_editable_fields_html(self): - return [self.editable_field_html(k) - for k in self.custom_editable_fields()] + return [self.editable_field_html(k) for k in self.custom_editable_fields()] -def get_session_configs_dict(): - SESSION_CONFIGS_DICT = OrderedDict() - for config_dict in settings.SESSION_CONFIGS: - config_obj = SessionConfig(settings.SESSION_CONFIG_DEFAULTS) +def get_session_configs_dict( + SESSION_CONFIGS: List[Dict], SESSION_CONFIG_DEFAULTS: Dict +): + SESSION_CONFIGS_DICT = {} + for config_dict in SESSION_CONFIGS: + config_obj = SessionConfig(SESSION_CONFIG_DEFAULTS) config_obj.update(config_dict) config_obj.clean() - SESSION_CONFIGS_DICT[config_dict['name']] = config_obj + config_name = config_dict['name'] + if config_name in SESSION_CONFIGS_DICT: + msg = f"Duplicate SESSION_CONFIG name: {config_name}" + raise SessionConfigError(msg) + SESSION_CONFIGS_DICT[config_name] = config_obj return SESSION_CONFIGS_DICT -SESSION_CONFIGS_DICT = get_session_configs_dict() + +SESSION_CONFIGS_DICT = get_session_configs_dict( + settings.SESSION_CONFIGS, settings.SESSION_CONFIG_DEFAULTS +) def create_session( - session_config_name, *, num_participants, label='', - room_name=None, is_mturk=False, - is_demo=False, - edited_session_config_fields=None) -> Session: + session_config_name, + *, + num_participants, + label='', + room_name=None, + is_mturk=False, + is_demo=False, + edited_session_config_fields=None, +) -> Session: num_subsessions = 0 edited_session_config_fields = edited_session_config_fields or {} @@ -293,10 +262,23 @@ def create_session( # to be a bit discouraged: http://goo.gl/dEXZpv # 2014-9-22: preassign to groups for demo mode. + # check that it divides evenly + session_lcm = session_config.get_lcm() + if num_participants is None: + # most games are multiplayer, so if it's under 2, we bump it to 2 + num_participants = max(session_lcm, 2) + else: + if num_participants % session_lcm: + msg = ( + 'Session Config {}: Number of participants ({}) is not a multiple ' + 'of group size ({})' + ).format(session_config['name'], num_participants, session_lcm) + raise ValueError(msg) + if is_mturk: mturk_num_participants = ( - num_participants / - settings.MTURK_NUM_PARTICIPANTS_MULTIPLE) + num_participants / settings.MTURK_NUM_PARTICIPANTS_MULTIPLE + ) else: mturk_num_participants = -1 @@ -305,37 +287,31 @@ def create_session( label=label, is_demo=is_demo, num_participants=num_participants, - mturk_num_participants=mturk_num_participants - ) # type: Session - - # check that it divides evenly - session_lcm = session_config.get_lcm() - if num_participants % session_lcm: - msg = ( - 'Session Config {}: Number of participants ({}) is not a multiple ' - 'of group size ({})' - ).format(session_config['name'], num_participants, session_lcm) - raise ValueError(msg) + mturk_num_participants=mturk_num_participants, + ) # type: Session Participant.objects.bulk_create( [ Participant(id_in_session=id_in_session, session=session) - for id_in_session in list(range(1, num_participants+1)) + for id_in_session in list(range(1, num_participants + 1)) ] ) participant_values = session.participant_set.order_by('id').values('code', 'id') - ParticipantLockModel.objects.bulk_create([ - ParticipantLockModel(participant_code=participant['code']) - for participant in participant_values]) + ParticipantLockModel.objects.bulk_create( + [ + ParticipantLockModel(participant_code=participant['code']) + for participant in participant_values + ] + ) participant_to_player_lookups = [] page_index = 0 for app_name in session_config['app_sequence']: - views_module = common_internal.get_pages_module(app_name) + views_module = common.get_pages_module(app_name) models_module = get_models_module(app_name) Constants = models_module.Constants num_subsessions += Constants.num_rounds @@ -353,19 +329,21 @@ def create_session( ] ) - subsessions = Subsession.objects.filter( - session=session).order_by('round_number').values( - 'id', 'round_number') + subsessions = ( + Subsession.objects.filter(session=session) + .order_by('round_number') + .values('id', 'round_number') + ) ppg = Constants.players_per_group if ppg is None or Subsession._has_group_by_arrival_time(): ppg = num_participants - num_groups_per_round = int(num_participants/ppg) + num_groups_per_round = int(num_participants / ppg) groups_to_create = [] for subsession in subsessions: - for id_in_subsession in range(1, num_groups_per_round+1): + for id_in_subsession in range(1, num_groups_per_round + 1): groups_to_create.append( Group( session=session, @@ -377,9 +355,11 @@ def create_session( Group.objects.bulk_create(groups_to_create) - groups = Group.objects.filter(session=session).values( - 'id_in_subsession', 'subsession_id', 'id' - ).order_by('id_in_subsession') + groups = ( + Group.objects.filter(session=session) + .values('id_in_subsession', 'subsession_id', 'id') + .order_by('id_in_subsession') + ) groups_lookup = defaultdict(list) @@ -390,11 +370,11 @@ def create_session( players_to_create = [] for subsession in subsessions: - subsession_id=subsession['id'] - round_number=subsession['round_number'] + subsession_id = subsession['id'] + round_number = subsession['round_number'] participant_index = 0 for group_id in groups_lookup[subsession_id]: - for id_in_group in range(1, ppg+1): + for id_in_group in range(1, ppg + 1): participant = participant_values[participant_index] players_to_create.append( Player( @@ -403,7 +383,7 @@ def create_session( round_number=round_number, participant_id=participant['id'], group_id=group_id, - id_in_group=id_in_group + id_in_group=id_in_group, ) ) participant_index += 1 @@ -412,13 +392,16 @@ def create_session( Player.objects.bulk_create(players_to_create) players_flat = Player.objects.filter(session=session).values( - 'id', 'participant__code', 'participant__id', 'subsession__id', - 'round_number' + 'id', + 'participant__code', + 'participant__id', + 'subsession__id', + 'round_number', ) players_by_round = [[] for _ in range(Constants.num_rounds)] for p in players_flat: - players_by_round[p['round_number']-1].append(p) + players_by_round[p['round_number'] - 1].append(p) for round_number, round_players in enumerate(players_by_round, start=1): for View in views_module.page_sequence: @@ -430,7 +413,7 @@ def create_session( url = View.get_url( participant_code=participant_code, name_in_url=Constants.name_in_url, - page_index=page_index + page_index=page_index, ) participant_to_player_lookups.append( @@ -442,11 +425,11 @@ def create_session( player_pk=p['id'], subsession_pk=p['subsession__id'], session_pk=session.pk, - url=url)) + url=url, + ) + ) - ParticipantToPlayerLookup.objects.bulk_create( - participant_to_player_lookups - ) + ParticipantToPlayerLookup.objects.bulk_create(participant_to_player_lookups) session.participant_set.update(_max_page_index=page_index) with otree.db.idmap.use_cache(): @@ -469,6 +452,7 @@ def create_session( # this should happen after session.ready = True if room_name is not None: from otree.room import ROOM_DICT + room = ROOM_DICT[room_name] room.set_session(session) diff --git a/otree/static/otree/js/common.js b/otree/static/otree/js/common.js index 30415a0e8..4bf9f3dec 100644 --- a/otree/static/otree/js/common.js +++ b/otree/static/otree/js/common.js @@ -2,11 +2,11 @@ function makeReconnectingWebSocket(path) { // https://github.com/pladaria/reconnecting-websocket/issues/91#issuecomment-431244323 var ws_scheme = window.location.protocol === "https:" ? "wss" : "ws"; var ws_path = `${ws_scheme}://${window.location.host}${path}`; - var socket = new ReconnectingWebSocket(ws_path, '', {minReconnectionDelay: 1}); + var socket = new ReconnectingWebSocket(ws_path); socket.onclose = function (e) { if (e.code === 1011) { // this may or may not exist in child pages. - let serverErrorDiv = document.getElementById("websocket-server-error"); + var serverErrorDiv = document.getElementById("websocket-server-error"); if (serverErrorDiv) { // better to put the message here rather than the div, otherwise it's confusing when // you do "view source" and there's an error message. @@ -40,8 +40,8 @@ function makeReconnectingWebSocket(path) { // submission. $('#form').submit(function () { $('.otree-btn-next').each(function () { - let nextButton = this; - let originalState = nextButton.disabled; + var nextButton = this; + var originalState = nextButton.disabled; nextButton.disabled = true; setTimeout(function () { // restore original state. diff --git a/otree/static/otree/js/reconnecting-websocket-iife.min.js b/otree/static/otree/js/reconnecting-websocket-iife.min.js index 8c7a3167d..81cc9a9f2 100644 --- a/otree/static/otree/js/reconnecting-websocket-iife.min.js +++ b/otree/static/otree/js/reconnecting-websocket-iife.min.js @@ -1 +1 @@ -var ReconnectingWebSocket=function(){"use strict";var e=function(t,n){return(e=Object.setPrototypeOf||{__proto__:[]}instanceof Array&&function(e,t){e.__proto__=t}||function(e,t){for(var n in t)t.hasOwnProperty(n)&&(e[n]=t[n])})(t,n)};function t(t,n){function o(){this.constructor=t}e(t,n),t.prototype=null===n?Object.create(n):(o.prototype=n.prototype,new o)}var n=function(){return function(e,t){this.target=t,this.type=e}}(),o=function(e){function n(t,n){var o=e.call(this,"error",n)||this;return o.message=t.message,o.error=t,o}return t(n,e),n}(n),r=function(e){function n(t,n,o){void 0===t&&(t=1e3),void 0===n&&(n="");var r=e.call(this,"close",o)||this;return r.wasClean=!0,r.code=t,r.reason=n,r}return t(n,e),n}(n),i=function(){if("undefined"!=typeof WebSocket)return WebSocket},s={maxReconnectionDelay:1e4,minReconnectionDelay:1e3+4e3*Math.random(),minUptime:5e3,reconnectionDelayGrowFactor:1.3,connectionTimeout:4e3,maxRetries:1/0,debug:!1};return function(){function e(e,t,n){void 0===n&&(n={});var o=this;this._listeners={error:[],message:[],open:[],close:[]},this._retryCount=-1,this._shouldReconnect=!0,this._connectLock=!1,this._binaryType="blob",this._closeCalled=!1,this._messageQueue=[],this.onclose=void 0,this.onerror=void 0,this.onmessage=void 0,this.onopen=void 0,this._handleOpen=function(e){o._debug("open event");var t=o._options.minUptime,n=void 0===t?s.minUptime:t;clearTimeout(o._connectTimeout),o._uptimeTimeout=setTimeout(function(){return o._acceptOpen()},n),o._ws.binaryType=o._binaryType,o._messageQueue.forEach(function(e){return o._ws.send(e)}),o._messageQueue=[],o.onopen&&o.onopen(e),o._listeners.open.forEach(function(t){return o._callEventListener(e,t)})},this._handleMessage=function(e){o._debug("message event"),o.onmessage&&o.onmessage(e),o._listeners.message.forEach(function(t){return o._callEventListener(e,t)})},this._handleError=function(e){o._debug("error event",e.message),o._disconnect(void 0,"TIMEOUT"===e.message?"timeout":void 0),o.onerror&&o.onerror(e),o._debug("exec error listeners"),o._listeners.error.forEach(function(t){return o._callEventListener(e,t)}),o._connect()},this._handleClose=function(e){o._debug("close event"),o._clearTimeouts(),o._shouldReconnect&&o._connect(),o.onclose&&o.onclose(e),o._listeners.close.forEach(function(t){return o._callEventListener(e,t)})},this._url=e,this._protocols=t,this._options=n,this._connect()}return Object.defineProperty(e,"CONNECTING",{get:function(){return 0},enumerable:!0,configurable:!0}),Object.defineProperty(e,"OPEN",{get:function(){return 1},enumerable:!0,configurable:!0}),Object.defineProperty(e,"CLOSING",{get:function(){return 2},enumerable:!0,configurable:!0}),Object.defineProperty(e,"CLOSED",{get:function(){return 3},enumerable:!0,configurable:!0}),Object.defineProperty(e.prototype,"CONNECTING",{get:function(){return e.CONNECTING},enumerable:!0,configurable:!0}),Object.defineProperty(e.prototype,"OPEN",{get:function(){return e.OPEN},enumerable:!0,configurable:!0}),Object.defineProperty(e.prototype,"CLOSING",{get:function(){return e.CLOSING},enumerable:!0,configurable:!0}),Object.defineProperty(e.prototype,"CLOSED",{get:function(){return e.CLOSED},enumerable:!0,configurable:!0}),Object.defineProperty(e.prototype,"binaryType",{get:function(){return this._ws?this._ws.binaryType:this._binaryType},set:function(e){this._binaryType=e,this._ws&&(this._ws.binaryType=e)},enumerable:!0,configurable:!0}),Object.defineProperty(e.prototype,"retryCount",{get:function(){return Math.max(this._retryCount,0)},enumerable:!0,configurable:!0}),Object.defineProperty(e.prototype,"bufferedAmount",{get:function(){return this._messageQueue.reduce(function(e,t){return"string"==typeof t?e+=t.length:t instanceof Blob?e+=t.size:e+=t.byteLength,e},0)+(this._ws?this._ws.bufferedAmount:0)},enumerable:!0,configurable:!0}),Object.defineProperty(e.prototype,"extensions",{get:function(){return this._ws?this._ws.extensions:""},enumerable:!0,configurable:!0}),Object.defineProperty(e.prototype,"protocol",{get:function(){return this._ws?this._ws.protocol:""},enumerable:!0,configurable:!0}),Object.defineProperty(e.prototype,"readyState",{get:function(){return this._ws?this._ws.readyState:e.CONNECTING},enumerable:!0,configurable:!0}),Object.defineProperty(e.prototype,"url",{get:function(){return this._ws?this._ws.url:""},enumerable:!0,configurable:!0}),e.prototype.close=function(e,t){void 0===e&&(e=1e3),this._closeCalled=!0,this._shouldReconnect=!1,this._clearTimeouts(),this._ws?this._ws.readyState!==this.CLOSED?this._ws.close(e,t):this._debug("close: already closed"):this._debug("close enqueued: no ws instance")},e.prototype.reconnect=function(e,t){this._shouldReconnect=!0,this._closeCalled=!1,this._retryCount=-1,this._ws&&this._ws.readyState!==this.CLOSED?(this._disconnect(e,t),this._connect()):this._connect()},e.prototype.send=function(e){this._ws&&this._ws.readyState===this.OPEN?(this._debug("send",e),this._ws.send(e)):(this._debug("enqueue",e),this._messageQueue.push(e))},e.prototype.addEventListener=function(e,t){this._listeners[e]&&this._listeners[e].push(t)},e.prototype.removeEventListener=function(e,t){this._listeners[e]&&(this._listeners[e]=this._listeners[e].filter(function(e){return e!==t}))},e.prototype._debug=function(){for(var e=[],t=0;t"].concat(e))},e.prototype._getNextDelay=function(){var e=this._options,t=e.reconnectionDelayGrowFactor,n=void 0===t?s.reconnectionDelayGrowFactor:t,o=e.minReconnectionDelay,r=void 0===o?s.minReconnectionDelay:o,i=e.maxReconnectionDelay,c=void 0===i?s.maxReconnectionDelay:i,u=r;return this._retryCount>0&&(u=r*Math.pow(n,this._retryCount-1))>c&&(u=c),this._debug("next delay",u),u},e.prototype._wait=function(){var e=this;return new Promise(function(t){setTimeout(t,e._getNextDelay())})},e.prototype._getNextUrl=function(e){if("string"==typeof e)return Promise.resolve(e);if("function"==typeof e){var t=e();if("string"==typeof t)return Promise.resolve(t);if(t.then)return t}throw Error("Invalid URL")},e.prototype._connect=function(){var e=this;if(!this._connectLock&&this._shouldReconnect){this._connectLock=!0;var t=this._options,n=t.maxRetries,o=void 0===n?s.maxRetries:n,r=t.connectionTimeout,c=void 0===r?s.connectionTimeout:r,u=t.WebSocket,a=void 0===u?i():u;if(this._retryCount>=o)this._debug("max retries reached",this._retryCount,">=",o);else{if(this._retryCount++,this._debug("connect",this._retryCount),this._removeListeners(),"function"!=typeof(h=a)||2!==h.CLOSING)throw Error("No valid WebSocket class provided");var h;this._wait().then(function(){return e._getNextUrl(e._url)}).then(function(t){e._closeCalled?e._connectLock=!1:(e._debug("connect",{url:t,protocols:e._protocols}),e._ws=e._protocols?new a(t,e._protocols):new a(t),e._ws.binaryType=e._binaryType,e._connectLock=!1,e._addListeners(),e._connectTimeout=setTimeout(function(){return e._handleTimeout()},c))})}}},e.prototype._handleTimeout=function(){this._debug("timeout event"),this._handleError(new o(Error("TIMEOUT"),this))},e.prototype._disconnect=function(e,t){if(void 0===e&&(e=1e3),this._clearTimeouts(),this._ws){this._removeListeners();try{this._ws.close(e,t),this._handleClose(new r(e,t,this))}catch(e){}}},e.prototype._acceptOpen=function(){this._debug("accept open"),this._retryCount=0},e.prototype._callEventListener=function(e,t){"handleEvent"in t?t.handleEvent(e):t(e)},e.prototype._removeListeners=function(){this._ws&&(this._debug("removeListeners"),this._ws.removeEventListener("open",this._handleOpen),this._ws.removeEventListener("close",this._handleClose),this._ws.removeEventListener("message",this._handleMessage),this._ws.removeEventListener("error",this._handleError))},e.prototype._addListeners=function(){this._ws&&(this._debug("addListeners"),this._ws.addEventListener("open",this._handleOpen),this._ws.addEventListener("close",this._handleClose),this._ws.addEventListener("message",this._handleMessage),this._ws.addEventListener("error",this._handleError))},e.prototype._clearTimeouts=function(){clearTimeout(this._connectTimeout),clearTimeout(this._uptimeTimeout)},e}()}(); \ No newline at end of file +var ReconnectingWebSocket=function(){"use strict";var e=function(t,n){return(e=Object.setPrototypeOf||{__proto__:[]}instanceof Array&&function(e,t){e.__proto__=t}||function(e,t){for(var n in t)t.hasOwnProperty(n)&&(e[n]=t[n])})(t,n)};function t(t,n){function o(){this.constructor=t}e(t,n),t.prototype=null===n?Object.create(n):(o.prototype=n.prototype,new o)}function n(e,t){var n="function"==typeof Symbol&&e[Symbol.iterator];if(!n)return e;var o,r,i=n.call(e),s=[];try{for(;(void 0===t||t-- >0)&&!(o=i.next()).done;)s.push(o.value)}catch(e){r={error:e}}finally{try{o&&!o.done&&(n=i.return)&&n.call(i)}finally{if(r)throw r.error}}return s}var o=function(){return function(e,t){this.target=t,this.type=e}}(),r=function(e){function n(t,n){var o=e.call(this,"error",n)||this;return o.message=t.message,o.error=t,o}return t(n,e),n}(o),i=function(e){function n(t,n,o){void 0===t&&(t=1e3),void 0===n&&(n="");var r=e.call(this,"close",o)||this;return r.wasClean=!0,r.code=t,r.reason=n,r}return t(n,e),n}(o),s=function(){if("undefined"!=typeof WebSocket)return WebSocket},c={maxReconnectionDelay:1e4,minReconnectionDelay:1e3+4e3*Math.random(),minUptime:5e3,reconnectionDelayGrowFactor:1.3,connectionTimeout:4e3,maxRetries:1/0,maxEnqueuedMessages:1/0,startClosed:!1,debug:!1};return function(){function e(e,t,n){var o=this;void 0===n&&(n={}),this._listeners={error:[],message:[],open:[],close:[]},this._retryCount=-1,this._shouldReconnect=!0,this._connectLock=!1,this._binaryType="blob",this._closeCalled=!1,this._messageQueue=[],this.onclose=void 0,this.onerror=void 0,this.onmessage=void 0,this.onopen=void 0,this._handleOpen=function(e){o._debug("open event");var t=o._options.minUptime,n=void 0===t?c.minUptime:t;clearTimeout(o._connectTimeout),o._uptimeTimeout=setTimeout(function(){return o._acceptOpen()},n),o._ws.binaryType=o._binaryType,o._messageQueue.forEach(function(e){return o._ws.send(e)}),o._messageQueue=[],o.onopen&&o.onopen(e),o._listeners.open.forEach(function(t){return o._callEventListener(e,t)})},this._handleMessage=function(e){o._debug("message event"),o.onmessage&&o.onmessage(e),o._listeners.message.forEach(function(t){return o._callEventListener(e,t)})},this._handleError=function(e){o._debug("error event",e.message),o._disconnect(void 0,"TIMEOUT"===e.message?"timeout":void 0),o.onerror&&o.onerror(e),o._debug("exec error listeners"),o._listeners.error.forEach(function(t){return o._callEventListener(e,t)}),o._connect()},this._handleClose=function(e){o._debug("close event"),o._clearTimeouts(),o._shouldReconnect&&o._connect(),o.onclose&&o.onclose(e),o._listeners.close.forEach(function(t){return o._callEventListener(e,t)})},this._url=e,this._protocols=t,this._options=n,this._options.startClosed&&(this._shouldReconnect=!1),this._connect()}return Object.defineProperty(e,"CONNECTING",{get:function(){return 0},enumerable:!0,configurable:!0}),Object.defineProperty(e,"OPEN",{get:function(){return 1},enumerable:!0,configurable:!0}),Object.defineProperty(e,"CLOSING",{get:function(){return 2},enumerable:!0,configurable:!0}),Object.defineProperty(e,"CLOSED",{get:function(){return 3},enumerable:!0,configurable:!0}),Object.defineProperty(e.prototype,"CONNECTING",{get:function(){return e.CONNECTING},enumerable:!0,configurable:!0}),Object.defineProperty(e.prototype,"OPEN",{get:function(){return e.OPEN},enumerable:!0,configurable:!0}),Object.defineProperty(e.prototype,"CLOSING",{get:function(){return e.CLOSING},enumerable:!0,configurable:!0}),Object.defineProperty(e.prototype,"CLOSED",{get:function(){return e.CLOSED},enumerable:!0,configurable:!0}),Object.defineProperty(e.prototype,"binaryType",{get:function(){return this._ws?this._ws.binaryType:this._binaryType},set:function(e){this._binaryType=e,this._ws&&(this._ws.binaryType=e)},enumerable:!0,configurable:!0}),Object.defineProperty(e.prototype,"retryCount",{get:function(){return Math.max(this._retryCount,0)},enumerable:!0,configurable:!0}),Object.defineProperty(e.prototype,"bufferedAmount",{get:function(){return this._messageQueue.reduce(function(e,t){return"string"==typeof t?e+=t.length:t instanceof Blob?e+=t.size:e+=t.byteLength,e},0)+(this._ws?this._ws.bufferedAmount:0)},enumerable:!0,configurable:!0}),Object.defineProperty(e.prototype,"extensions",{get:function(){return this._ws?this._ws.extensions:""},enumerable:!0,configurable:!0}),Object.defineProperty(e.prototype,"protocol",{get:function(){return this._ws?this._ws.protocol:""},enumerable:!0,configurable:!0}),Object.defineProperty(e.prototype,"readyState",{get:function(){return this._ws?this._ws.readyState:this._options.startClosed?e.CLOSED:e.CONNECTING},enumerable:!0,configurable:!0}),Object.defineProperty(e.prototype,"url",{get:function(){return this._ws?this._ws.url:""},enumerable:!0,configurable:!0}),e.prototype.close=function(e,t){void 0===e&&(e=1e3),this._closeCalled=!0,this._shouldReconnect=!1,this._clearTimeouts(),this._ws?this._ws.readyState!==this.CLOSED?this._ws.close(e,t):this._debug("close: already closed"):this._debug("close enqueued: no ws instance")},e.prototype.reconnect=function(e,t){this._shouldReconnect=!0,this._closeCalled=!1,this._retryCount=-1,this._ws&&this._ws.readyState!==this.CLOSED?(this._disconnect(e,t),this._connect()):this._connect()},e.prototype.send=function(e){if(this._ws&&this._ws.readyState===this.OPEN)this._debug("send",e),this._ws.send(e);else{var t=this._options.maxEnqueuedMessages,n=void 0===t?c.maxEnqueuedMessages:t;this._messageQueue.length"],e))},e.prototype._getNextDelay=function(){var e=this._options,t=e.reconnectionDelayGrowFactor,n=void 0===t?c.reconnectionDelayGrowFactor:t,o=e.minReconnectionDelay,r=void 0===o?c.minReconnectionDelay:o,i=e.maxReconnectionDelay,s=void 0===i?c.maxReconnectionDelay:i,u=0;return this._retryCount>0&&(u=r*Math.pow(n,this._retryCount-1))>s&&(u=s),this._debug("next delay",u),u},e.prototype._wait=function(){var e=this;return new Promise(function(t){setTimeout(t,e._getNextDelay())})},e.prototype._getNextUrl=function(e){if("string"==typeof e)return Promise.resolve(e);if("function"==typeof e){var t=e();if("string"==typeof t)return Promise.resolve(t);if(t.then)return t}throw Error("Invalid URL")},e.prototype._connect=function(){var e=this;if(!this._connectLock&&this._shouldReconnect){this._connectLock=!0;var t=this._options,n=t.maxRetries,o=void 0===n?c.maxRetries:n,r=t.connectionTimeout,i=void 0===r?c.connectionTimeout:r,u=t.WebSocket,a=void 0===u?s():u;if(this._retryCount>=o)this._debug("max retries reached",this._retryCount,">=",o);else{if(this._retryCount++,this._debug("connect",this._retryCount),this._removeListeners(),void 0===(h=a)||!h||2!==h.CLOSING)throw Error("No valid WebSocket class provided");var h;this._wait().then(function(){return e._getNextUrl(e._url)}).then(function(t){e._closeCalled||(e._debug("connect",{url:t,protocols:e._protocols}),e._ws=e._protocols?new a(t,e._protocols):new a(t),e._ws.binaryType=e._binaryType,e._connectLock=!1,e._addListeners(),e._connectTimeout=setTimeout(function(){return e._handleTimeout()},i))})}}},e.prototype._handleTimeout=function(){this._debug("timeout event"),this._handleError(new r(Error("TIMEOUT"),this))},e.prototype._disconnect=function(e,t){if(void 0===e&&(e=1e3),this._clearTimeouts(),this._ws){this._removeListeners();try{this._ws.close(e,t),this._handleClose(new i(e,t,this))}catch(e){}}},e.prototype._acceptOpen=function(){this._debug("accept open"),this._retryCount=0},e.prototype._callEventListener=function(e,t){"handleEvent"in t?t.handleEvent(e):t(e)},e.prototype._removeListeners=function(){this._ws&&(this._debug("removeListeners"),this._ws.removeEventListener("open",this._handleOpen),this._ws.removeEventListener("close",this._handleClose),this._ws.removeEventListener("message",this._handleMessage),this._ws.removeEventListener("error",this._handleError))},e.prototype._addListeners=function(){this._ws&&(this._debug("addListeners"),this._ws.addEventListener("open",this._handleOpen),this._ws.addEventListener("close",this._handleClose),this._ws.addEventListener("message",this._handleMessage),this._ws.addEventListener("error",this._handleError))},e.prototype._clearTimeouts=function(){clearTimeout(this._connectTimeout),clearTimeout(this._uptimeTimeout)},e}()}(); \ No newline at end of file diff --git a/otree/staticfiles.py b/otree/staticfiles.py index e4b2aac07..ad031d3f1 100644 --- a/otree/staticfiles.py +++ b/otree/staticfiles.py @@ -12,7 +12,7 @@ class BackslashError(ValueError): class StaticNode(DjStaticNode): - def __init__(self, varname=None, path:FilterExpression=None): + def __init__(self, varname=None, path: FilterExpression = None): # path.token is the literal string, not the value of the variable # it resolves to, # so there should never be a \ diff --git a/otree/strict_templates.py b/otree/strict_templates.py index 22d81001b..a46e6d9d1 100644 --- a/otree/strict_templates.py +++ b/otree/strict_templates.py @@ -12,6 +12,7 @@ which are unlikely to rely on silent failures. ''' + def patch_filter_expression(): ''' don't allow code like {{ bogus }} or {{ player.bogus }} to fail silently @@ -38,8 +39,10 @@ def resolve(self, context, ignore_failures=False): # such as in django/forms/templates/django/forms/widgets: # {% if widget.attrs.class %} return original_resolve(self, context, ignore_failures=False) + FilterExpression.resolve = resolve + def patch_smartif(): ''' SmartIf is for if-statements with multiple tokens, like {% if bogus == 1 %} @@ -55,14 +58,18 @@ def patch_smartif(): def make_infix_eval(func): '''see infix()'s Operator.eval()''' + def new_eval(self, context): return func(context, self.first, self.second) + return new_eval def make_prefix_eval(func): '''see prefix()'s Operator.eval()''' + def new_eval(self, context): return func(context, self.first) + return new_eval infix_operators = { @@ -81,6 +88,7 @@ def new_eval(self, context): } from django.template.smartif import OPERATORS + for operator, func in infix_operators.items(): OPERATORS[operator].eval = make_infix_eval(func) OPERATORS['not'].eval = make_prefix_eval(lambda context, x: not x.eval(context)) @@ -97,6 +105,7 @@ def render_annotated(self, context): if context.template.engine.debug and not hasattr(e, 'template_debug'): e.template_debug = context.template.get_exception_info(e, self.token) raise + Node.render_annotated = render_annotated @@ -104,6 +113,3 @@ def patch_template_silent_failures(): patch_filter_expression() patch_smartif() patch_28935() - - - diff --git a/otree/templates/django/forms/widgets/input_option.html b/otree/templates/django/forms/widgets/input_option.html index 6e77a22c6..a384f1884 100644 --- a/otree/templates/django/forms/widgets/input_option.html +++ b/otree/templates/django/forms/widgets/input_option.html @@ -3,4 +3,4 @@ this was discovered when someone couldn't use RadioGridField. I have confirmed that this new version works with RadioGridField {% endcomment %} -{% if wrap_label|default:False %}{% endif %}{% include "django/forms/widgets/input.html" %}{% if wrap_label|default:False %} {{ widget.label }}{% endif %} \ No newline at end of file +{% if widget.wrap_label %}{% endif %}{% include "django/forms/widgets/input.html" %}{% if widget.wrap_label %} {{ widget.label }}{% endif %} diff --git a/otree/templates/otree/BaseAdmin.html b/otree/templates/otree/BaseAdmin.html index ab2f4d19f..db92be7c8 100644 --- a/otree/templates/otree/BaseAdmin.html +++ b/otree/templates/otree/BaseAdmin.html @@ -28,19 +28,6 @@ } }); - $.ajax({ - url: '{% url 'OtreeCoreUpdateCheck' %}', - type: 'GET', - success: function(data) { - if (!data.pypi_connection_error) { - if (data.update_needed) { - $('._otree-updates').html( - 'Update available') - } - } - } - }); - // i guess sockets must be global variables, and i should name it // something other than 'socket' because there might be other sockets // in descendant pages @@ -86,7 +73,6 @@ Rooms Data Server Check - {% if request.user.is_authenticated|default:False %} Logout diff --git a/otree/templates/otree/DemoIndex.html b/otree/templates/otree/DemoIndex.html index 2f62ce274..8bf4ef471 100644 --- a/otree/templates/otree/DemoIndex.html +++ b/otree/templates/otree/DemoIndex.html @@ -18,8 +18,7 @@ {% if is_debug %} - You can add entries to this list in - settings.py. + To add to this list, create a new session config. {% endif %} @@ -29,14 +28,6 @@ {% for s in session_info %} - {% if s.num_demo_participants == 1 %} - - {% elif s.num_demo_participants == 2 %} - - {% else %} - - {% endif %} - Participants: {{s.num_demo_participants}} {{ s.display_name }} {% endfor %} diff --git a/otree/templates/otree/MTurkHTMLQuestion.html b/otree/templates/otree/MTurkHTMLQuestion.html new file mode 100644 index 000000000..49127f7cc --- /dev/null +++ b/otree/templates/otree/MTurkHTMLQuestion.html @@ -0,0 +1,35 @@ + + + + {% include user_template %} + + +
- 16 | Please provide your information in the form below. - 17 |
settings.py