Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

feat(api): support rollup, cube and grouping_sets APIs #9945

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .codespellrc
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[codespell]
# local codespell matches `./docs`, pre-commit codespell matches `docs`
skip = *.lock,.direnv,.git,./docs/_freeze,./docs/_output/**,./docs/_inv/**,docs/_freeze/**,*.svg,*.css,*.html,*.js,ibis/backends/tests/tpc/queries/duckdb/ds/*.sql
ignore-regex = \b(i[if]f|I[IF]F|AFE|alls)\b
ignore-regex = \b(i[if]f|I[IF]F|AFE|alls|ND)\b
builtin = clear,rare,names
ignore-words-list = tim,notin,ang
42 changes: 37 additions & 5 deletions ibis/backends/sql/compilers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1400,13 +1400,42 @@ def visit_JoinLink(self, op, *, how, table, predicates):
def _generate_groups(groups):
    """Yield 1-based select-list ordinals, one per grouping expression."""
    # SQL allows GROUP BY to reference output columns by position; emit
    # literal ordinals 1..len(groups) rather than repeating the expressions.
    ordinals = range(1, len(groups) + 1)
    return map(sge.convert, ordinals)

def visit_Aggregate(self, op, *, parent, groups, metrics):
sel = sg.select(
*self._cleanup_names(groups), *self._cleanup_names(metrics), copy=False
def _compile_agg_select(self, op, *, parent, keys, metrics):
    """Build the ``SELECT keys..., metrics... FROM parent`` projection.

    Grouping keys come first so that ordinal GROUP BY references line up
    with the select list.
    """
    projections = [*self._cleanup_names(keys), *self._cleanup_names(metrics)]
    return sg.select(*projections, copy=False).from_(parent, copy=False)

if groups:
sel = sel.group_by(*self._generate_groups(groups.values()), copy=False)
def _compile_group_by(self, sel, *, groups, grouping_sets, rollups, cubes):
    """Attach a GROUP BY clause covering plain keys, GROUPING SETS, ROLLUP and CUBE.

    Parameters
    ----------
    sel
        The sqlglot SELECT to extend.
    groups
        Mapping of name -> expression for plain group-by keys; referenced
        by 1-based select-list ordinal via ``_generate_groups``.
    grouping_sets
        Iterable of grouping sets, each an iterable of key-expression tuples.
    rollups, cubes
        Iterables of expression lists for ROLLUP(...) / CUBE(...).
    """
    expressions = list(self._generate_groups(groups.values()))
    # NOTE(review): the original inner comprehension reused the name
    # `expressions`, shadowing the ordinal list above; renamed to
    # `set_exprs` to keep the two roles distinct.
    group = sge.Group(
        expressions=expressions,
        grouping_sets=[
            sge.GroupingSets(
                expressions=[
                    sge.Tuple(expressions=set_exprs)
                    for set_exprs in grouping_set
                ]
            )
            for grouping_set in grouping_sets
        ],
        rollup=[sge.Rollup(expressions=rollup) for rollup in rollups],
        cube=[sge.Cube(expressions=cube) for cube in cubes],
    )
    return sel.group_by(group, copy=False)

def visit_Aggregate(
    self, op, *, parent, keys, groups, metrics, grouping_sets, rollups, cubes
):
    """Compile an aggregation into a SELECT, adding GROUP BY only when needed."""
    sel = self._compile_agg_select(op, parent=parent, keys=keys, metrics=metrics)

    # Ungrouped (whole-table) aggregation: no GROUP BY clause at all.
    if not any((groups, grouping_sets, rollups, cubes)):
        return sel

    return self._compile_group_by(
        sel,
        groups=groups,
        grouping_sets=grouping_sets,
        rollups=rollups,
        cubes=cubes,
    )

Expand Down Expand Up @@ -1609,6 +1638,9 @@ def _make_sample_backwards_compatible(self, *, sample, parent):
parent.args["sample"] = sample
return sg.select(STAR).from_(parent)

def visit_GroupID(self, op, *, arg):
    # Delegate to the backend's GROUPING() function, spreading the compiled
    # column expressions in `arg` as its arguments.
    return self.f.grouping(*arg)


# `__init_subclass__` is uncalled for subclasses - we manually call it here to
# autogenerate the base class implementations as well.
Expand Down
17 changes: 6 additions & 11 deletions ibis/backends/sql/compilers/datafusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,12 +464,12 @@ def visit_ArgMax(self, op, *, arg, key, where):
arg, where=where, order_by=[sge.Ordered(this=key, desc=True)]
)

def visit_Aggregate(self, op, *, parent, groups, metrics):
def _compile_agg_select(self, op, *, parent, keys, metrics):
"""Support `GROUP BY` expressions in `SELECT` since DataFusion does not."""
quoted = self.quoted
metrics = tuple(self._cleanup_names(metrics))

if groups:
if keys:
# datafusion doesn't support count distinct aggregations alongside
# computed grouping keys so create a projection of the key and all
# existing columns first, followed by the usual group by
Expand All @@ -484,11 +484,11 @@ def visit_Aggregate(self, op, *, parent, groups, metrics):
),
# can't use set subtraction here since the schema keys'
# order matters and set subtraction doesn't preserve order
(k for k in op.parent.schema.keys() if k not in groups),
(k for k in op.parent.schema.keys() if k not in keys),
)
)
table = (
sg.select(*cols, *self._cleanup_names(groups))
sg.select(*cols, *self._cleanup_names(keys))
.from_(parent)
.subquery(parent.alias)
)
Expand All @@ -497,19 +497,14 @@ def visit_Aggregate(self, op, *, parent, groups, metrics):
# quoted=True is required here for correctness
by_names_quoted = tuple(
sg.column(key, table=getattr(value, "table", None), quoted=quoted)
for key, value in groups.items()
for key, value in keys.items()
)
selections = by_names_quoted + metrics
else:
selections = metrics or (STAR,)
table = parent

sel = sg.select(*selections).from_(table)

if groups:
sel = sel.group_by(*by_names_quoted)

return sel
return sg.select(*selections).from_(table)

def visit_StructColumn(self, op, *, names, values):
args = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ FROM (
FROM "countries" AS "t0"
) AS t0
GROUP BY
"cont"
1
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ WITH "t5" AS (
) AS "t4"
) AS t4
GROUP BY
"t4"."field_of_study"
1
)
SELECT
*
Expand Down
91 changes: 91 additions & 0 deletions ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import itertools
import sqlite3
from datetime import date
from operator import methodcaller

Expand All @@ -10,6 +11,7 @@
import ibis
import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.selectors as s
from ibis import _
from ibis import literal as L
from ibis.backends.tests.errors import (
Expand All @@ -19,7 +21,9 @@
GoogleBadRequest,
ImpalaHiveServer2Error,
MySQLNotSupportedError,
MySQLOperationalError,
OracleDatabaseError,
PolarsColumnNotFoundError,
PolarsInvalidOperationError,
PsycoPg2InternalError,
Py4JError,
Expand Down Expand Up @@ -1735,3 +1739,90 @@ def test_group_by_scalar(alltypes, df, value):
result = expr.execute()
n = result["n"].values[0].item()
assert n == len(df)


@pytest.fixture(scope="session")
def grouping_set_table():
    """Small in-memory table shared by the cube/rollup/grouping-set tests."""
    data = {
        "a": [1, 1, 2, 2, 3, 5],
        "b": ["a", "a", "b", "a", "a", "c"],
        "c": [12, 10, 5, 7, 5, 2],
    }
    return ibis.memtable(data)


@pytest.mark.notyet(["sqlite"], raises=sqlite3.OperationalError)
@pytest.mark.notyet(["mysql"], raises=MySQLOperationalError)
@pytest.mark.notyet(["polars"], raises=PolarsColumnNotFoundError)
@pytest.mark.notyet(["druid"])
def test_cube(con, backend, grouping_set_table):
    """CUBE over `b`: one total per distinct value of `b` plus a grand total."""
    grouped = grouping_set_table.group_by(ibis.cube("b"))
    expr = grouped.agg(sum_a=lambda t: t.a.sum()).order_by(s.all())

    expected = pd.DataFrame({"b": ["a", "b", "c", None], "sum_a": [7, 2, 5, 14]})
    result = con.to_pandas(expr)

    backend.assert_frame_equal(result, expected)


@pytest.mark.notyet(["sqlite"], raises=sqlite3.OperationalError)
@pytest.mark.notyet(["mysql"], raises=MySQLOperationalError)
@pytest.mark.notyet(["polars"], raises=PolarsColumnNotFoundError)
@pytest.mark.notyet(["druid"])
def test_rollup(con, backend, grouping_set_table):
    """ROLLUP(b, c): per-(b, c) totals, per-b subtotals, and a grand total."""
    grouped = grouping_set_table.group_by(ibis.rollup("b", "c"))
    expr = grouped.agg(sum_a=lambda t: t.a.sum()).order_by(s.all())

    expected = pd.DataFrame(
        {
            "b": ["a"] * 5 + ["b"] * 2 + ["c"] * 2 + [None],
            "c": [5, 7, 10, 12, None, 5, None, 2, None, None],
            "sum_a": [3, 2, 1, 1, 7, 2, 2, 5, 5, 14],
        }
    )
    result = con.to_pandas(expr)

    backend.assert_frame_equal(result, expected)


@pytest.mark.notyet(["sqlite"], raises=sqlite3.OperationalError)
@pytest.mark.notyet(["mysql"], raises=MySQLOperationalError)
@pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError)
@pytest.mark.notyet(["druid"], raises=Exception)
@pytest.mark.notyet(
    [
        "exasol",
        "oracle",
        "clickhouse",
        "datafusion",
        "risingwave",
        "mssql",
        "pyspark",
        "impala",
        "snowflake",
        "bigquery",
    ],
    raises=AssertionError,
    reason="returns empty for a grouping set on an empty table, which disagrees with half the backends",
)
def test_grouping_empty_table(con, backend):
    """CUBE on an empty table still produces the single all-NULL grouping row."""
    # TODO(cpcloud): group_id doesn't allow string columns
    t = ibis.memtable({"c1": []}, schema={"c1": "int"})
    expr = (
        t.group_by(ibis.cube("c1"))
        .agg(gid=lambda t: ibis.group_id(t.c1))
        .order_by(s.first())
    )
    result = con.to_pandas(expr)

    nrows, ncols = result.shape
    assert nrows == 1
    assert ncols == 2
    assert pd.isnull(result.at[0, "c1"])
    assert result.at[0, "gid"] == 1
Loading
Loading