diff --git a/modin/pandas/dataframe.py b/modin/pandas/dataframe.py index 856b135387e..2ce83913ebb 100644 --- a/modin/pandas/dataframe.py +++ b/modin/pandas/dataframe.py @@ -2074,20 +2074,17 @@ def squeeze( Squeeze 1 dimensional axis objects into scalars. """ axis = self._get_axis_number(axis) if axis is not None else None - len_columns = len(self.columns) - if axis == 1 and len_columns == 1: + if axis is None and (len(self.columns) == 1 or len(self) == 1): + return Series(query_compiler=self._query_compiler).squeeze() + if axis == 1 and len(self.columns) == 1: self._query_compiler._shape_hint = "column" return Series(query_compiler=self._query_compiler) - if axis in [0, None]: - # Only compute the length of the index if axis is 0 or None. - len_index = len(self) - if axis is None and (len_columns == 1 or len_index == 1): - return Series(query_compiler=self._query_compiler).squeeze() - if axis == 0 and len_index == 1: - qc = self.T._query_compiler - qc._shape_hint = "column" - return Series(query_compiler=qc) - return self.copy() + if axis == 0 and len(self) == 1: + qc = self.T._query_compiler + qc._shape_hint = "column" + return Series(query_compiler=qc) + else: + return self.copy() def stack( self, level=-1, dropna=lib.no_default, sort=lib.no_default, future_stack=False