Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/source/user-guide/common-operations/functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ df = ctx.table("pokemon")
DataFusion offers mathematical functions such as {py:func}`~datafusion.functions.pow` or {py:func}`~datafusion.functions.log`

```{code-cell} ipython3
from datafusion import col, literal, string_literal, str_lit
from datafusion import col, literal
from datafusion import functions as f

df.select(
Expand Down Expand Up @@ -122,8 +122,8 @@ Casting expressions to different data types using {py:func}`~datafusion.function

```{code-cell} ipython3
df.select(
f.arrow_cast(col('"Total"'), string_literal("Float64")).alias("total_as_float"),
f.arrow_cast(col('"Total"'), str_lit("Int32")).alias("total_as_int")
f.arrow_cast(col('"Total"'), "Float64").alias("total_as_float"),
f.arrow_cast(col('"Total"'), "Int32").alias("total_as_int")
)
```

Expand Down
29 changes: 18 additions & 11 deletions python/datafusion/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,16 @@
from datafusion.functions import spark


def _warn_expr_for_literal_arg(function_name: str, arg_name: str) -> None:
warnings.warn(
f"Passing Expr for {function_name}() argument {arg_name!r} is deprecated; "
"pass a Python literal instead.",
DeprecationWarning,
stacklevel=4,
)
def _warn_if_expr_for_literal_arg(
value: Any, function_name: str, arg_name: str
) -> None:
if isinstance(value, Expr):
warnings.warn(
f"Passing Expr for {function_name}() argument {arg_name!r} is deprecated; "
"pass a Python literal instead.",
DeprecationWarning,
stacklevel=3,
)


__all__ = [
Expand Down Expand Up @@ -437,6 +440,7 @@ def encode(expr: Expr, encoding: Expr | str) -> Expr:
>>> result.collect_column("enc")[0].as_py()
'aGVsbG8'
"""
_warn_if_expr_for_literal_arg(encoding, "encode", "encoding")
encoding = coerce_to_expr(encoding)
return Expr(f.encode(expr.expr, encoding.expr))

Expand All @@ -452,6 +456,7 @@ def decode(expr: Expr, encoding: Expr | str) -> Expr:
>>> result.collect_column("dec")[0].as_py()
b'hello'
"""
_warn_if_expr_for_literal_arg(encoding, "decode", "encoding")
encoding = coerce_to_expr(encoding)
return Expr(f.decode(expr.expr, encoding.expr))

Expand Down Expand Up @@ -742,6 +747,7 @@ def digest(value: Expr, method: Expr | str) -> Expr:
>>> len(result.collect_column("d")[0].as_py()) > 0
True
"""
_warn_if_expr_for_literal_arg(method, "digest", "method")
method = coerce_to_expr(method)
return Expr(f.digest(value.expr, method.expr))

Expand Down Expand Up @@ -2723,8 +2729,7 @@ def date_part(part: Expr | str, date: Expr) -> Expr:


def _date_part(part: Expr | str, date: Expr, function_name: str) -> Expr:
if isinstance(part, Expr):
_warn_expr_for_literal_arg(function_name, "part")
_warn_if_expr_for_literal_arg(part, function_name, "part")
part = coerce_to_expr(part)
return Expr(f.date_part(part.expr, date.expr))

Expand Down Expand Up @@ -2760,8 +2765,7 @@ def date_trunc(part: Expr | str, date: Expr) -> Expr:


def _date_trunc(part: Expr | str, date: Expr, function_name: str) -> Expr:
if isinstance(part, Expr):
_warn_expr_for_literal_arg(function_name, "part")
_warn_if_expr_for_literal_arg(part, function_name, "part")
part = coerce_to_expr(part)
return Expr(f.date_trunc(part.expr, date.expr))

Expand Down Expand Up @@ -3096,6 +3100,7 @@ def arrow_cast(expr: Expr, data_type: Expr | str | pa.DataType) -> Expr:
>>> result.collect_column("c")[0].as_py()
1.0
"""
_warn_if_expr_for_literal_arg(data_type, "arrow_cast", "data_type")
if isinstance(data_type, pa.DataType):
return expr.cast(data_type)
if isinstance(data_type, str):
Expand Down Expand Up @@ -3128,6 +3133,7 @@ def arrow_try_cast(expr: Expr, data_type: Expr | str | pa.DataType) -> Expr:
>>> result.collect_column("c")[0].as_py() is None
True
"""
_warn_if_expr_for_literal_arg(data_type, "arrow_try_cast", "data_type")
if isinstance(data_type, pa.DataType):
return expr.try_cast(data_type)
if isinstance(data_type, str):
Expand Down Expand Up @@ -3235,6 +3241,7 @@ def arrow_metadata(expr: Expr, key: Expr | str | None = None) -> Expr:
"""
if key is None:
return Expr(f.arrow_metadata(expr.expr))
_warn_if_expr_for_literal_arg(key, "arrow_metadata", "key")
if isinstance(key, str):
key = Expr.string_literal(key)
return Expr(f.arrow_metadata(expr.expr, key.expr))
Expand Down
72 changes: 66 additions & 6 deletions python/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1014,7 +1014,7 @@ def test_string_functions(df, function, expected_result):

def test_hash_functions(df):
exprs = [
f.digest(column("a"), literal(m))
f.digest(column("a"), m)
for m in (
"md5",
"sha224",
Expand Down Expand Up @@ -1602,12 +1602,10 @@ def test_regr_funcs_df(func, expected):

def test_binary_string_functions(df):
df = df.select(
f.encode(column("a").cast(pa.string()), literal("base64").cast(pa.string())),
f.encode(column("a").cast(pa.string()), "base64"),
f.decode(
f.encode(
column("a").cast(pa.string()), literal("base64").cast(pa.string())
),
literal("base64").cast(pa.string()),
f.encode(column("a").cast(pa.string()), "base64"),
"base64",
),
)
result = df.collect()
Expand Down Expand Up @@ -2355,6 +2353,68 @@ def test_regexp_replace_native(self):
).collect()
assert result[0].column(0)[0].as_py() == "aX bX cX"

@pytest.mark.parametrize(
("func", "arg_name", "expr"),
[
pytest.param(
f.encode,
"encoding",
lambda: f.encode(column("a"), literal("base64")),
id="encode-encoding",
),
pytest.param(
f.decode,
"encoding",
lambda: f.decode(column("a"), literal("base64")),
id="decode-encoding",
),
pytest.param(
f.digest,
"method",
lambda: f.digest(column("a"), literal("sha256")),
id="digest-method",
),
pytest.param(
f.arrow_cast,
"data_type",
lambda: f.arrow_cast(column("a"), literal("Float64")),
id="arrow-cast-data-type",
),
pytest.param(
f.arrow_try_cast,
"data_type",
lambda: f.arrow_try_cast(column("a"), literal("Float64")),
id="arrow-try-cast-data-type",
),
pytest.param(
f.arrow_metadata,
"key",
lambda: f.arrow_metadata(column("a"), literal("k")),
id="arrow-metadata-key",
),
],
)
def test_literal_only_expr_args_warn_deprecated(self, func, arg_name, expr):
with pytest.warns(
DeprecationWarning,
match=(
rf"Passing Expr for {func.__name__}\(\) argument "
rf"'{arg_name}' is deprecated"
),
):
result = expr()
assert result is not None

def test_literal_only_native_args_do_not_warn(self):
with warnings.catch_warnings():
warnings.simplefilter("error", DeprecationWarning)
assert f.encode(column("a"), "base64") is not None
assert f.decode(column("a"), "base64") is not None
assert f.digest(column("a"), "sha256") is not None
assert f.arrow_cast(column("a"), "Float64") is not None
assert f.arrow_try_cast(column("a"), pa.float64()) is not None
assert f.arrow_metadata(column("a"), "k") is not None

def test_backward_compat_with_lit(self):
"""Verify that existing code using lit() still works."""
ctx = SessionContext()
Expand Down
Loading