Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(datafusion): add map methods to datafusion compiler #10510

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from
17 changes: 13 additions & 4 deletions .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
@@ -1,15 +1,24 @@
{
"build": { "dockerfile": "Dockerfile", "context": ".." },
"build": {
"dockerfile": "Dockerfile",
"context": ".."
},
"containerUser": "vscode",
"remoteUser": "vscode",
"postStartCommand": "git config --global --add safe.directory ${containerWorkspaceFolder}",
"workspaceFolder": "/app",
"customizations": {
"codespaces": {
"openFiles": ["docs/tutorials/getting_started.qmd"]
"openFiles": [
"docs/tutorials/getting_started.qmd"
]
},
"vscode": {
"extensions": ["ms-toolsai.jupyter", "ms-python.python", "quarto.quarto"]
"extensions": [
"ms-toolsai.jupyter",
"ms-python.python",
"quarto.quarto"
]
}
},
"features": {
Expand All @@ -24,4 +33,4 @@
"yqVersion": "latest"
}
}
}
}
2 changes: 1 addition & 1 deletion .devcontainer/postCreate.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ python3 -m pip install ipython

# avoid using dynamic versioning by grabbing the version from pyproject.toml
POETRY_DYNAMIC_VERSIONING_BYPASS="$(yq '.tool.poetry.version' pyproject.toml)" \
python3 -m pip install -e '.[duckdb,clickhouse,examples,geospatial]'
python3 -m pip install -e '.[duckdb,clickhouse,examples,geospatial,datafusion]'
34 changes: 33 additions & 1 deletion ibis/backends/sql/compilers/datafusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
ops.RowID,
ops.Strftime,
ops.TimeDelta,
ops.TimestampBucket,
# ops.TimestampBucket,
ops.TimestampDelta,
ops.TypeOf,
ops.StringToDate,
Expand All @@ -67,6 +67,10 @@
ops.EndsWith: "ends_with",
ops.ArrayIntersect: "array_intersect",
ops.ArrayUnion: "array_union",
ops.MapKeys: "map_keys",
ops.MapValues: "map_values",
ops.MapLength: "cardinality",
ops.IsNull: "ifnull",
}

def _to_timestamp(self, value, target_dtype, literal=False):
Expand Down Expand Up @@ -541,5 +545,33 @@
map(partial(self.cast, to=op.dtype), arg),
)

def visit_MapGet(self, op, *, arg, key, default):
    """Compile ``MapGet`` for DataFusion.

    Returns NULL when either the map or the key is NULL; otherwise looks the
    key up and falls back to ``default`` when it is absent.

    DataFusion's ``map_extract`` returns a single-element list containing the
    value (or an empty list when the key is missing), so the scalar value is
    pulled out with ``list_extract(..., 1)`` (1-based indexing) and ``ifnull``
    supplies the default.
    """
    return self.if_(
        sg.or_(arg.is_(NULL), key.is_(NULL)),
        NULL,
        self.f.ifnull(
            self.f.list_extract(self.f.map_extract(arg, key), 1),
            default,
        ),
    )

def visit_MapContains(self, op, *, arg, key):
    """Compile ``MapContains`` for DataFusion.

    Returns NULL when either the map or the key is NULL; otherwise tests
    membership by extracting the map's keys and using ``list_contains``.
    """
    return self.if_(
        sg.or_(arg.is_(NULL), key.is_(NULL)),
        NULL,
        self.f.list_contains(self.f.map_keys(arg), key),
    )

    # TODO: ops.MapMerge has no direct DataFusion function mapping; implement it as a visit_MapMerge visitor method.

def visit_TimestampBucket(self, op, *, arg, interval, offset):
    """Compile ``TimestampBucket`` via DataFusion's ``date_bin``.

    See
    https://datafusion.apache.org/user-guide/sql/scalar_functions.html#date-bin

    ``date_bin(stride, source, origin)`` buckets ``arg`` into windows of width
    ``interval`` aligned to ``origin``. Buckets are anchored at the Unix
    epoch, shifted by ``offset`` when one is given.
    """
    origin = self.f.cast(
        "1970-01-01T00:00:00Z", self.type_mapper.from_ibis(dt.timestamp)
    )
    if offset is not None:
        origin += offset
    return self.f.date_bin(interval, arg, origin)

Check warning on line 574 in ibis/backends/sql/compilers/datafusion.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/sql/compilers/datafusion.py#L573-L574

Added lines #L573 - L574 were not covered by tests


compiler = DataFusionCompiler()
27 changes: 0 additions & 27 deletions ibis/backends/tests/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@
param(None, None, id="null_both"),
],
)
@mark_notyet_datafusion
def test_map_nulls(con, k, v):
k = ibis.literal(k, type="array<string>")
v = ibis.literal(v, type="array<string>")
Expand All @@ -79,7 +78,6 @@ def test_map_nulls(con, k, v):
param(None, None, id="null_both"),
],
)
@mark_notyet_datafusion
def test_map_keys_nulls(con, k, v):
k = ibis.literal(k, type="array<string>")
v = ibis.literal(v, type="array<string>")
Expand Down Expand Up @@ -112,7 +110,6 @@ def test_map_keys_nulls(con, k, v):
param(ibis.literal(None, type="map<string, string>"), id="null_map"),
],
)
@mark_notyet_datafusion
def test_map_values_nulls(con, map):
assert con.execute(map.values()) is None

Expand Down Expand Up @@ -181,7 +178,6 @@ def test_map_values_nulls(con, map):
],
)
@pytest.mark.parametrize("method", ["get", "contains"])
@mark_notyet_datafusion
def test_map_get_contains_nulls(con, map, key, method):
expr = getattr(map, method)
assert con.execute(expr(key)) is None
Expand Down Expand Up @@ -219,15 +215,13 @@ def test_map_merge_nulls(con, m1, m2):
assert con.execute(concatted) is None


@mark_notyet_datafusion
def test_map_table(backend):
    """The backend's map fixture table has a map-typed column and rows."""
    t = backend.map
    assert t.kv.type().is_map()
    assert not t.limit(1).execute().empty


@mark_notimpl_risingwave_hstore
@mark_notyet_datafusion
def test_column_map_values(backend):
table = backend.map
expr = table.select("idx", vals=table.kv.values()).order_by("idx")
Expand All @@ -254,7 +248,6 @@ def test_column_map_merge(backend):


@mark_notimpl_risingwave_hstore
@mark_notyet_datafusion
def test_literal_map_keys(con):
mapping = ibis.literal({"1": "a", "2": "b"})
expr = mapping.keys().name("tmp")
Expand All @@ -266,7 +259,6 @@ def test_literal_map_keys(con):


@mark_notimpl_risingwave_hstore
@mark_notyet_datafusion
def test_literal_map_values(con):
mapping = ibis.literal({"1": "a", "2": "b"})
expr = mapping.values().name("tmp")
Expand All @@ -277,7 +269,6 @@ def test_literal_map_values(con):

@mark_notimpl_risingwave_hstore
@mark_notyet_postgres
@mark_notyet_datafusion
def test_scalar_isin_literal_map_keys(con):
mapping = ibis.literal({"a": 1, "b": 2})
a = ibis.literal("a")
Expand All @@ -290,7 +281,6 @@ def test_scalar_isin_literal_map_keys(con):

@mark_notimpl_risingwave_hstore
@mark_notyet_postgres
@mark_notyet_datafusion
def test_map_scalar_contains_key_scalar(con):
mapping = ibis.literal({"a": 1, "b": 2})
a = ibis.literal("a")
Expand All @@ -302,7 +292,6 @@ def test_map_scalar_contains_key_scalar(con):


@mark_notimpl_risingwave_hstore
@mark_notyet_datafusion
def test_map_scalar_contains_key_column(backend, alltypes, df):
value = {"1": "a", "3": "c"}
mapping = ibis.literal(value)
Expand All @@ -314,7 +303,6 @@ def test_map_scalar_contains_key_column(backend, alltypes, df):

@mark_notimpl_risingwave_hstore
@mark_notyet_postgres
@mark_notyet_datafusion
def test_map_column_contains_key_scalar(backend, alltypes, df):
expr = ibis.map(ibis.array([alltypes.string_col]), ibis.array([alltypes.int_col]))
series = df.apply(lambda row: {row["string_col"]: row["int_col"]}, axis=1)
Expand All @@ -327,7 +315,6 @@ def test_map_column_contains_key_scalar(backend, alltypes, df):

@mark_notimpl_risingwave_hstore
@mark_notyet_postgres
@mark_notyet_datafusion
def test_map_column_contains_key_column(alltypes):
map_expr = ibis.map(
ibis.array([alltypes.string_col]), ibis.array([alltypes.int_col])
Expand All @@ -352,7 +339,6 @@ def test_literal_map_merge(con):


@mark_notimpl_risingwave_hstore
@mark_notyet_datafusion
def test_literal_map_getitem_broadcast(backend, alltypes, df):
value = {"1": "a", "2": "b"}

Expand Down Expand Up @@ -499,7 +485,6 @@ def test_literal_map_getitem_broadcast(backend, alltypes, df):
@values
@keys
@mark_notimpl_risingwave_hstore
@mark_notyet_datafusion
def test_map_get_all_types(con, keys, values):
m = ibis.map(ibis.array(keys), ibis.array(values))
for key, val in zip(keys, values):
Expand All @@ -510,7 +495,6 @@ def test_map_get_all_types(con, keys, values):

@keys
@mark_notimpl_risingwave_hstore
@mark_notyet_datafusion
def test_map_contains_all_types(con, keys):
a = ibis.array(keys)
m = ibis.map(a, a)
Expand All @@ -519,7 +503,6 @@ def test_map_contains_all_types(con, keys):


@mark_notimpl_risingwave_hstore
@mark_notyet_datafusion
def test_literal_map_get_broadcast(backend, alltypes, df):
value = {"1": "a", "2": "b"}

Expand Down Expand Up @@ -571,7 +554,6 @@ def test_map_construct_array_column(con, alltypes, df):

@mark_notimpl_risingwave_hstore
@mark_notyet_postgres
@mark_notyet_datafusion
def test_map_get_with_compatible_value_smaller(con):
value = ibis.literal({"A": 1000, "B": 2000})
expr = value.get("C", 3)
Expand All @@ -580,7 +562,6 @@ def test_map_get_with_compatible_value_smaller(con):

@mark_notimpl_risingwave_hstore
@mark_notyet_postgres
@mark_notyet_datafusion
def test_map_get_with_compatible_value_bigger(con):
value = ibis.literal({"A": 1, "B": 2})
expr = value.get("C", 3000)
Expand All @@ -589,7 +570,6 @@ def test_map_get_with_compatible_value_bigger(con):

@mark_notimpl_risingwave_hstore
@mark_notyet_postgres
@mark_notyet_datafusion
def test_map_get_with_incompatible_value_different_kind(con):
value = ibis.literal({"A": 1000, "B": 2000})
expr = value.get("C", 3.0)
Expand All @@ -598,7 +578,6 @@ def test_map_get_with_incompatible_value_different_kind(con):

@mark_notimpl_risingwave_hstore
@mark_notyet_postgres
@mark_notyet_datafusion
@pytest.mark.parametrize("null_value", [None, ibis.null()])
def test_map_get_with_null_on_not_nullable(con, null_value):
map_type = dt.Map(dt.string, dt.Int16(nullable=False))
Expand All @@ -613,7 +592,6 @@ def test_map_get_with_null_on_not_nullable(con, null_value):
["flink"], raises=Py4JJavaError, reason="Flink cannot handle typeless nulls"
)
@mark_notimpl_risingwave_hstore
@mark_notyet_datafusion
def test_map_get_with_null_on_null_type_with_null(con, null_value):
value = ibis.literal({"A": None, "B": None})
expr = value.get("C", null_value)
Expand All @@ -626,7 +604,6 @@ def test_map_get_with_null_on_null_type_with_null(con, null_value):
)
@mark_notimpl_risingwave_hstore
@mark_notyet_postgres
@mark_notyet_datafusion
def test_map_get_with_null_on_null_type_with_non_null(con):
value = ibis.literal({"A": None, "B": None})
expr = value.get("C", 1)
Expand All @@ -639,7 +616,6 @@ def test_map_get_with_null_on_null_type_with_non_null(con):
reason="`tbl_properties` is required when creating table with schema",
)
@mark_notimpl_risingwave_hstore
@mark_notyet_datafusion
def test_map_create_table(con, temp_table):
t = con.create_table(
temp_table,
Expand All @@ -654,21 +630,18 @@ def test_map_create_table(con, temp_table):
reason="No translation rule for <class 'ibis.expr.operations.maps.MapLength'>",
)
@mark_notimpl_risingwave_hstore
@mark_notyet_datafusion
def test_map_length(con):
    """A two-entry map literal reports a length of 2."""
    two_entry_map = ibis.literal({"a": "A", "b": "B"})
    assert con.execute(two_entry_map.length()) == 2


@mark_notyet_datafusion
def test_map_keys_unnest(backend):
    """Unnesting the keys of the map column yields exactly a..f."""
    keys = backend.map.kv.keys().unnest().to_pandas()
    assert frozenset(keys) == frozenset("abcdef")


@mark_notimpl_risingwave_hstore
@mark_notyet_datafusion
def test_map_contains_null(con):
    """contains() is True for a present key even when its value is NULL."""
    null_values = ibis.literal([None], type="array<string>")
    map_with_null = ibis.map(["a"], null_values)
    assert con.execute(map_with_null.contains("a"))
Expand Down
Loading