[SPARK-47669][SQL][CONNECT][PYTHON] Add Column.try_cast
### What changes were proposed in this pull request?
Add `try_cast` function in Column APIs

### Why are the changes needed?
For functionality parity with SQL, which already supports `try_cast`.

### Does this PR introduce _any_ user-facing change?
Yes, this PR adds a new `Column.try_cast` method:

```
        >>> from pyspark.sql import functions as sf
        >>> df = spark.createDataFrame(
        ...      [(2, "123"), (5, "Bob"), (3, None)], ["age", "name"])
        >>> df.select(df.name.try_cast("double")).show()
        +-----+
        | name|
        +-----+
        |123.0|
        | NULL|
        | NULL|
        +-----+

```
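
For comparison, a plain `cast` raises at runtime under ANSI mode where `try_cast` yields NULL. A minimal sketch, assuming an active `spark` session:

```
>>> spark.conf.set("spark.sql.ansi.enabled", True)
>>> df = spark.createDataFrame([(5, "Bob")], ["age", "name"])
>>> df.select(df.name.try_cast("double")).show()  # never raises, yields NULL
+----+
|name|
+----+
|NULL|
+----+
```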

### How was this patch tested?
New tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#45796 from zhengruifeng/connect_try_cast.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
zhengruifeng committed Apr 3, 2024
1 parent 27fdf96 commit 55b5ff6
Showing 14 changed files with 301 additions and 52 deletions.
@@ -1090,6 +1090,41 @@ class Column private[sql] (@DeveloperApi val expr: proto.Expression) extends Logging
*/
def cast(to: String): Column = cast(DataTypeParser.parseDataType(to))

/**
* Casts the column to a different data type and the result is null on failure.
* {{{
* // Casts colA to IntegerType.
* import org.apache.spark.sql.types.IntegerType
* df.select(df("colA").try_cast(IntegerType))
*
* // equivalent to
* df.select(df("colA").try_cast("int"))
* }}}
*
* @group expr_ops
* @since 4.0.0
*/
def try_cast(to: DataType): Column = Column { builder =>
  builder.getCastBuilder
    .setExpr(expr)
    .setType(DataTypeProtoConverter.toConnectProtoType(to))
    .setEvalMode(proto.Expression.Cast.EvalMode.EVAL_MODE_TRY)
}

/**
* Casts the column to a different data type and the result is null on failure.
* {{{
* // Casts colA to integer.
* df.select(df("colA").try_cast("int"))
* }}}
*
* @group expr_ops
* @since 4.0.0
*/
def try_cast(to: String): Column = {
  try_cast(DataTypeParser.parseDataType(to))
}

/**
* Returns a sort expression based on the descending order of the column.
* {{{
@@ -871,6 +871,10 @@ class PlanGenerationTestSuite
  fn.col("a").cast("long")
}

columnTest("try_cast") {
  fn.col("a").try_cast("long")
}

orderColumnTest("desc") {
  fn.col("b").desc
}
@@ -147,6 +147,16 @@ message Expression {
      // If this is set, Server will use Catalyst parser to parse this string to DataType.
      string type_str = 3;
    }

    // (Optional) The expression evaluation mode.
    EvalMode eval_mode = 4;

    enum EvalMode {
      EVAL_MODE_UNSPECIFIED = 0;
      EVAL_MODE_LEGACY = 1;
      EVAL_MODE_ANSI = 2;
      EVAL_MODE_TRY = 3;
    }
  }
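
For illustration, a client could set the new field through the generated Python bindings — a sketch, assuming only the regenerated `pyspark.sql.connect.proto.expressions_pb2` module from this change:

```
from pyspark.sql.connect.proto import expressions_pb2

cast = expressions_pb2.Expression.Cast()
cast.type_str = "double"
cast.eval_mode = expressions_pb2.Expression.Cast.EVAL_MODE_TRY
# A message that never sets eval_mode serializes it as
# EVAL_MODE_UNSPECIFIED (0), so existing clients are unaffected.
```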

message Literal {
@@ -0,0 +1,2 @@
Project [try_cast(a#0 as bigint) AS a#0L]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
@@ -0,0 +1,29 @@
{
  "common": {
    "planId": "1"
  },
  "project": {
    "input": {
      "common": {
        "planId": "0"
      },
      "localRelation": {
        "schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e"
      }
    },
    "expressions": [{
      "cast": {
        "expr": {
          "unresolvedAttribute": {
            "unparsedIdentifier": "a"
          }
        },
        "type": {
          "long": {
          }
        },
        "evalMode": "EVAL_MODE_TRY"
      }
    }]
  }
}
Binary file not shown.
@@ -2097,11 +2097,19 @@ class SparkConnectPlanner(
}

   private def transformCast(cast: proto.Expression.Cast): Expression = {
-    cast.getCastToTypeCase match {
-      case proto.Expression.Cast.CastToTypeCase.TYPE =>
-        Cast(transformExpression(cast.getExpr), transformDataType(cast.getType))
-      case _ =>
-        Cast(transformExpression(cast.getExpr), parser.parseDataType(cast.getTypeStr))
+    val dataType = cast.getCastToTypeCase match {
+      case proto.Expression.Cast.CastToTypeCase.TYPE => transformDataType(cast.getType)
+      case _ => parser.parseDataType(cast.getTypeStr)
     }
+    val mode = cast.getEvalMode match {
+      case proto.Expression.Cast.EvalMode.EVAL_MODE_LEGACY => Some(EvalMode.LEGACY)
+      case proto.Expression.Cast.EvalMode.EVAL_MODE_ANSI => Some(EvalMode.ANSI)
+      case proto.Expression.Cast.EvalMode.EVAL_MODE_TRY => Some(EvalMode.TRY)
+      case _ => None
+    }
+    mode match {
+      case Some(m) => Cast(transformExpression(cast.getExpr), dataType, None, m)
+      case _ => Cast(transformExpression(cast.getExpr), dataType)
+    }
   }

1 change: 1 addition & 0 deletions python/docs/source/reference/pyspark.sql/column.rst
@@ -57,5 +57,6 @@ Column
Column.rlike
Column.startswith
Column.substr
Column.try_cast
Column.when
Column.withField
62 changes: 62 additions & 0 deletions python/pyspark/sql/column.py
@@ -1282,6 +1282,68 @@ def cast(self, dataType: Union[DataType, str]) -> "Column":
        )
    return Column(jc)

def try_cast(self, dataType: Union[DataType, str]) -> "Column":
    """
    This is a special version of `cast` that performs the same operation, but returns NULL
    instead of raising an error if the cast fails.

    .. versionadded:: 4.0.0

    Parameters
    ----------
    dataType : :class:`DataType` or str
        a :class:`DataType` or a DDL-formatted type string to cast the column into.

    Returns
    -------
    :class:`Column`
        Column with each element cast to the new type; NULL where the cast fails.

    Examples
    --------
    Example 1: Cast with a DataType

    >>> from pyspark.sql.types import LongType
    >>> df = spark.createDataFrame(
    ...      [(2, "123"), (5, "Bob"), (3, None)], ["age", "name"])
    >>> df.select(df.name.try_cast(LongType())).show()
    +----+
    |name|
    +----+
    | 123|
    |NULL|
    |NULL|
    +----+

    Example 2: Cast with a DDL string

    >>> df = spark.createDataFrame(
    ...      [(2, "123"), (5, "Bob"), (3, None)], ["age", "name"])
    >>> df.select(df.name.try_cast("double")).show()
    +-----+
    | name|
    +-----+
    |123.0|
    | NULL|
    | NULL|
    +-----+
    """
    if isinstance(dataType, str):
        jc = self._jc.try_cast(dataType)
    elif isinstance(dataType, DataType):
        from pyspark.sql import SparkSession

        spark = SparkSession._getActiveSessionOrCreate()
        jdt = spark._jsparkSession.parseDataType(dataType.json())
        jc = self._jc.try_cast(jdt)
    else:
        raise PySparkTypeError(
            error_class="NOT_DATATYPE_OR_STR",
            message_parameters={"arg_name": "dataType", "arg_type": type(dataType).__name__},
        )
    return Column(jc)
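
A quick sketch of the error path above, assuming an active `spark` session: anything that is neither a `DataType` nor a `str` raises `PySparkTypeError` with error class `NOT_DATATYPE_OR_STR`.

```
from pyspark.errors import PySparkTypeError

df = spark.range(1)
try:
    df.id.try_cast(123)  # neither DataType nor str
except PySparkTypeError as e:
    print(e.getErrorClass())  # NOT_DATATYPE_OR_STR
```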

def astype(self, dataType: Union[DataType, str]) -> "Column":
    """
    :func:`astype` is an alias for :func:`cast`.
17 changes: 17 additions & 0 deletions python/pyspark/sql/connect/column.py
@@ -331,6 +331,23 @@ def cast(self, dataType: Union[DataType, str]) -> "Column":

astype = cast

def try_cast(self, dataType: Union[DataType, str]) -> "Column":
    if isinstance(dataType, (DataType, str)):
        return Column(
            CastExpression(
                expr=self._expr,
                data_type=dataType,
                eval_mode="try",
            )
        )
    else:
        raise PySparkTypeError(
            error_class="NOT_DATATYPE_OR_STR",
            message_parameters={"arg_name": "dataType", "arg_type": type(dataType).__name__},
        )

try_cast.__doc__ = PySparkColumn.try_cast.__doc__

def __repr__(self) -> str:
return "Column<'%s'>" % self._expr.__repr__()

14 changes: 14 additions & 0 deletions python/pyspark/sql/connect/expressions.py
@@ -837,10 +837,15 @@ def __init__(
    self,
    expr: Expression,
    data_type: Union[DataType, str],
    eval_mode: Optional[str] = None,
) -> None:
    super().__init__()
    self._expr = expr
    self._data_type = data_type
    if eval_mode is not None:
        assert isinstance(eval_mode, str)
        assert eval_mode in ["legacy", "ansi", "try"]
    self._eval_mode = eval_mode

def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
    fun = proto.Expression()
@@ -849,6 +854,15 @@ def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
        fun.cast.type_str = self._data_type
    else:
        fun.cast.type.CopyFrom(pyspark_types_to_proto_types(self._data_type))

    if self._eval_mode is not None:
        if self._eval_mode == "legacy":
            fun.cast.eval_mode = proto.Expression.Cast.EvalMode.EVAL_MODE_LEGACY
        elif self._eval_mode == "ansi":
            fun.cast.eval_mode = proto.Expression.Cast.EvalMode.EVAL_MODE_ANSI
        elif self._eval_mode == "try":
            fun.cast.eval_mode = proto.Expression.Cast.EvalMode.EVAL_MODE_TRY

    return fun
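
The if/elif ladder above could equally be written as a lookup table — a hypothetical alternative (not the committed code), using the same `proto` module imported by `expressions.py`:

```
_EVAL_MODES = {
    "legacy": proto.Expression.Cast.EvalMode.EVAL_MODE_LEGACY,
    "ansi": proto.Expression.Cast.EvalMode.EVAL_MODE_ANSI,
    "try": proto.Expression.Cast.EvalMode.EVAL_MODE_TRY,
}

if self._eval_mode is not None:
    fun.cast.eval_mode = _EVAL_MODES[self._eval_mode]
```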

def __repr__(self) -> str:
96 changes: 49 additions & 47 deletions python/pyspark/sql/connect/proto/expressions_pb2.py

Large diffs are not rendered by default.

28 changes: 28 additions & 0 deletions python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -309,22 +309,48 @@ class Expression(google.protobuf.message.Message):
class Cast(google.protobuf.message.Message):
    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    class _EvalMode:
        ValueType = typing.NewType("ValueType", builtins.int)
        V: typing_extensions.TypeAlias = ValueType

    class _EvalModeEnumTypeWrapper(
        google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[
            Expression.Cast._EvalMode.ValueType
        ],
        builtins.type,
    ):  # noqa: F821
        DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
        EVAL_MODE_UNSPECIFIED: Expression.Cast._EvalMode.ValueType  # 0
        EVAL_MODE_LEGACY: Expression.Cast._EvalMode.ValueType  # 1
        EVAL_MODE_ANSI: Expression.Cast._EvalMode.ValueType  # 2
        EVAL_MODE_TRY: Expression.Cast._EvalMode.ValueType  # 3

    class EvalMode(_EvalMode, metaclass=_EvalModeEnumTypeWrapper): ...
    EVAL_MODE_UNSPECIFIED: Expression.Cast.EvalMode.ValueType  # 0
    EVAL_MODE_LEGACY: Expression.Cast.EvalMode.ValueType  # 1
    EVAL_MODE_ANSI: Expression.Cast.EvalMode.ValueType  # 2
    EVAL_MODE_TRY: Expression.Cast.EvalMode.ValueType  # 3

    EXPR_FIELD_NUMBER: builtins.int
    TYPE_FIELD_NUMBER: builtins.int
    TYPE_STR_FIELD_NUMBER: builtins.int
    EVAL_MODE_FIELD_NUMBER: builtins.int
    @property
    def expr(self) -> global___Expression:
        """(Required) the expression to be casted."""
    @property
    def type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
    type_str: builtins.str
    """If this is set, Server will use Catalyst parser to parse this string to DataType."""
    eval_mode: global___Expression.Cast.EvalMode.ValueType
    """(Optional) The expression evaluation mode."""
    def __init__(
        self,
        *,
        expr: global___Expression | None = ...,
        type: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
        type_str: builtins.str = ...,
        eval_mode: global___Expression.Cast.EvalMode.ValueType = ...,
    ) -> None: ...
    def HasField(
        self,
@@ -344,6 +370,8 @@ def WhichOneof(
    field_name: typing_extensions.Literal[
        "cast_to_type",
        b"cast_to_type",
        "eval_mode",
        b"eval_mode",
        "expr",
        b"expr",
        "type",
37 changes: 37 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -1222,6 +1222,43 @@ class Column(val expr: Expression) extends Logging {
*/
def cast(to: String): Column = cast(CatalystSqlParser.parseDataType(to))

/**
* Casts the column to a different data type and the result is null on failure.
* {{{
* // Casts colA to IntegerType.
* import org.apache.spark.sql.types.IntegerType
* df.select(df("colA").try_cast(IntegerType))
*
* // equivalent to
* df.select(df("colA").try_cast("int"))
* }}}
*
* @group expr_ops
* @since 4.0.0
*/
def try_cast(to: DataType): Column = withExpr {
  val cast = Cast(
    child = expr,
    dataType = CharVarcharUtils.replaceCharVarcharWithStringForCast(to),
    evalMode = EvalMode.TRY)
  cast.setTagValue(Cast.USER_SPECIFIED_CAST, ())
  cast
}

/**
* Casts the column to a different data type and the result is null on failure.
* {{{
* // Casts colA to integer.
* df.select(df("colA").try_cast("int"))
* }}}
*
* @group expr_ops
* @since 4.0.0
*/
def try_cast(to: String): Column = {
  try_cast(CatalystSqlParser.parseDataType(to))
}
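
As a usage note, the new Column API should line up with the `try_cast` expression that Spark SQL already exposes — a hedged sketch, assuming an active `spark` session:

```
df = spark.createDataFrame([("123",), ("Bob",)], ["colA"])

# Both should produce a TRY-mode cast: NULL for "Bob" instead of an error.
df.select(df.colA.try_cast("int")).show()
df.selectExpr("try_cast(colA AS int)").show()
```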

/**
* Returns a sort expression based on the descending order of the column.
* {{{
