diff --git a/docs/source/user-guide/common-operations/functions.md b/docs/source/user-guide/common-operations/functions.md index 50411825d..9939e2f00 100644 --- a/docs/source/user-guide/common-operations/functions.md +++ b/docs/source/user-guide/common-operations/functions.md @@ -52,7 +52,7 @@ df = ctx.table("pokemon") DataFusion offers mathematical functions such as {py:func}`~datafusion.functions.pow` or {py:func}`~datafusion.functions.log` ```{code-cell} ipython3 -from datafusion import col, literal, string_literal, str_lit +from datafusion import col, literal from datafusion import functions as f df.select( @@ -122,8 +122,8 @@ Casting expressions to different data types using {py:func}`~datafusion.function ```{code-cell} ipython3 df.select( - f.arrow_cast(col('"Total"'), string_literal("Float64")).alias("total_as_float"), - f.arrow_cast(col('"Total"'), str_lit("Int32")).alias("total_as_int") + f.arrow_cast(col('"Total"'), "Float64").alias("total_as_float"), + f.arrow_cast(col('"Total"'), "Int32").alias("total_as_int") ) ``` diff --git a/python/datafusion/functions/__init__.py b/python/datafusion/functions/__init__.py index 54783f086..9fb20dc15 100644 --- a/python/datafusion/functions/__init__.py +++ b/python/datafusion/functions/__init__.py @@ -64,13 +64,16 @@ from datafusion.functions import spark -def _warn_expr_for_literal_arg(function_name: str, arg_name: str) -> None: - warnings.warn( - f"Passing Expr for {function_name}() argument {arg_name!r} is deprecated; " - "pass a Python literal instead.", - DeprecationWarning, - stacklevel=4, - ) +def _warn_if_expr_for_literal_arg( + value: Any, function_name: str, arg_name: str +) -> None: + if isinstance(value, Expr): + warnings.warn( + f"Passing Expr for {function_name}() argument {arg_name!r} is deprecated; " + "pass a Python literal instead.", + DeprecationWarning, + stacklevel=3, + ) __all__ = [ @@ -437,6 +440,7 @@ def encode(expr: Expr, encoding: Expr | str) -> Expr: >>> result.collect_column("enc")[0].as_py() 'aGVsbG8' """ + _warn_if_expr_for_literal_arg(encoding, "encode", "encoding") encoding = coerce_to_expr(encoding) return Expr(f.encode(expr.expr, encoding.expr)) @@ -452,6 +456,7 @@ def decode(expr: Expr, encoding: Expr | str) -> Expr: >>> result.collect_column("dec")[0].as_py() b'hello' """ + _warn_if_expr_for_literal_arg(encoding, "decode", "encoding") encoding = coerce_to_expr(encoding) return Expr(f.decode(expr.expr, encoding.expr)) @@ -742,6 +747,7 @@ def digest(value: Expr, method: Expr | str) -> Expr: >>> len(result.collect_column("d")[0].as_py()) > 0 True """ + _warn_if_expr_for_literal_arg(method, "digest", "method") method = coerce_to_expr(method) return Expr(f.digest(value.expr, method.expr)) @@ -2723,8 +2729,7 @@ def date_part(part: Expr | str, date: Expr) -> Expr: def _date_part(part: Expr | str, date: Expr, function_name: str) -> Expr: - if isinstance(part, Expr): - _warn_expr_for_literal_arg(function_name, "part") + _warn_if_expr_for_literal_arg(part, function_name, "part") part = coerce_to_expr(part) return Expr(f.date_part(part.expr, date.expr)) @@ -2760,8 +2765,7 @@ def date_trunc(part: Expr | str, date: Expr) -> Expr: def _date_trunc(part: Expr | str, date: Expr, function_name: str) -> Expr: - if isinstance(part, Expr): - _warn_expr_for_literal_arg(function_name, "part") + _warn_if_expr_for_literal_arg(part, function_name, "part") part = coerce_to_expr(part) return Expr(f.date_trunc(part.expr, date.expr)) @@ -3096,6 +3100,7 @@ def arrow_cast(expr: Expr, data_type: Expr | str | pa.DataType) -> Expr: >>> result.collect_column("c")[0].as_py() 1.0 """ + _warn_if_expr_for_literal_arg(data_type, "arrow_cast", "data_type") if isinstance(data_type, pa.DataType): return expr.cast(data_type) if isinstance(data_type, str): @@ -3128,6 +3133,7 @@ def arrow_try_cast(expr: Expr, data_type: Expr | str | pa.DataType) -> Expr: >>> result.collect_column("c")[0].as_py() is None True """ + _warn_if_expr_for_literal_arg(data_type, "arrow_try_cast", "data_type") if isinstance(data_type, pa.DataType): return expr.try_cast(data_type) if isinstance(data_type, str): @@ -3235,6 +3241,7 @@ def arrow_metadata(expr: Expr, key: Expr | str | None = None) -> Expr: """ if key is None: return Expr(f.arrow_metadata(expr.expr)) + _warn_if_expr_for_literal_arg(key, "arrow_metadata", "key") if isinstance(key, str): key = Expr.string_literal(key) return Expr(f.arrow_metadata(expr.expr, key.expr)) diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py index 43dd70660..894d81fcf 100644 --- a/python/tests/test_functions.py +++ b/python/tests/test_functions.py @@ -1014,7 +1014,7 @@ def test_string_functions(df, function, expected_result): def test_hash_functions(df): exprs = [ - f.digest(column("a"), literal(m)) + f.digest(column("a"), m) for m in ( "md5", "sha224", @@ -1602,12 +1602,10 @@ def test_regr_funcs_df(func, expected): def test_binary_string_functions(df): df = df.select( - f.encode(column("a").cast(pa.string()), literal("base64").cast(pa.string())), + f.encode(column("a").cast(pa.string()), "base64"), f.decode( - f.encode( - column("a").cast(pa.string()), literal("base64").cast(pa.string()) - ), - literal("base64").cast(pa.string()), + f.encode(column("a").cast(pa.string()), "base64"), + "base64", ), ) result = df.collect() @@ -2355,6 +2353,68 @@ def test_regexp_replace_native(self): ).collect() assert result[0].column(0)[0].as_py() == "aX bX cX" + @pytest.mark.parametrize( + ("func", "arg_name", "expr"), + [ + pytest.param( + f.encode, + "encoding", + lambda: f.encode(column("a"), literal("base64")), + id="encode-encoding", + ), + pytest.param( + f.decode, + "encoding", + lambda: f.decode(column("a"), literal("base64")), + id="decode-encoding", + ), + pytest.param( + f.digest, + "method", + lambda: f.digest(column("a"), literal("sha256")), + id="digest-method", + ), + pytest.param( + f.arrow_cast, + "data_type", + lambda: f.arrow_cast(column("a"), literal("Float64")), + id="arrow-cast-data-type", + ), + pytest.param( + f.arrow_try_cast, + "data_type", + lambda: f.arrow_try_cast(column("a"), literal("Float64")), + id="arrow-try-cast-data-type", + ), + pytest.param( + f.arrow_metadata, + "key", + lambda: f.arrow_metadata(column("a"), literal("k")), + id="arrow-metadata-key", + ), + ], + ) + def test_literal_only_expr_args_warn_deprecated(self, func, arg_name, expr): + with pytest.warns( + DeprecationWarning, + match=( + rf"Passing Expr for {func.__name__}\(\) argument " + rf"'{arg_name}' is deprecated" + ), + ): + result = expr() + assert result is not None + + def test_literal_only_native_args_do_not_warn(self): + with warnings.catch_warnings(): + warnings.simplefilter("error", DeprecationWarning) + assert f.encode(column("a"), "base64") is not None + assert f.decode(column("a"), "base64") is not None + assert f.digest(column("a"), "sha256") is not None + assert f.arrow_cast(column("a"), "Float64") is not None + assert f.arrow_try_cast(column("a"), pa.float64()) is not None + assert f.arrow_metadata(column("a"), "k") is not None + def test_backward_compat_with_lit(self): """Verify that existing code using lit() still works.""" ctx = SessionContext()