From 60f951201cab49227991f9a4678d2b4a88bd22e8 Mon Sep 17 00:00:00 2001 From: arunjose696 Date: Wed, 14 Aug 2024 15:10:35 +0200 Subject: [PATCH] PR comments and tests --- .../pandas/query_compiler_validator.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/modin/core/storage_formats/pandas/query_compiler_validator.py b/modin/core/storage_formats/pandas/query_compiler_validator.py index 3b2001f09d9..57925c62b2f 100644 --- a/modin/core/storage_formats/pandas/query_compiler_validator.py +++ b/modin/core/storage_formats/pandas/query_compiler_validator.py @@ -23,6 +23,8 @@ from types import FunctionType, MethodType from typing import Any, Dict, Tuple, TypeVar +from pandas.core.indexes.frozen import FrozenList + from modin.core.storage_formats.base.query_compiler import BaseQueryCompiler Fn = TypeVar("Fn", bound=Any) @@ -56,10 +58,13 @@ def cast_arg_to_current_qc(arg): else: return arg - if isinstance(arguments, tuple): + imutable_types = (FrozenList, tuple) + if isinstance(arguments, imutable_types): + args_type = type(arguments) arguments = list(arguments) arguments = cast_nested_args_to_current_qc_type(arguments, current_qc) - return tuple(arguments) + + return args_type(arguments) if isinstance(arguments, list): for i in range(len(arguments)): if isinstance(arguments[i], (list, dict)): @@ -88,8 +93,6 @@ def apply_argument_cast(): def decorator(obj: Fn) -> Fn: """Decorate function or class to cast all arguments that are query compilers to the current query compiler""" if isinstance(obj, type): - seen: Dict[Any, Any] = {} - all_attrs = dict(inspect.getmembers(obj)) all_attrs.pop("__abstractmethods__") @@ -103,10 +106,7 @@ def decorator(obj: Fn) -> Fn: if isinstance( attr_value, (FunctionType, MethodType, classmethod, staticmethod) ): - try: - wrapped = seen[attr_value] - except KeyError: - wrapped = seen[attr_value] = apply_argument_cast()(attr_value) + wrapped = apply_argument_cast()(attr_value) setattr(obj, attr_name, wrapped) return obj # type: ignore [return-value] elif isinstance(obj, classmethod):