Broadcast in agg when needed
In Polars, the select and with_column methods broadcast any scalars they receive, including literals:
import polars as pl
df = pl.DataFrame(dict(x=[1, 2, 3]))
df.with_column(pl.lit(1).alias("y"))
# shape: (3, 2)
# ┌─────┬─────┐
# │ x ┆ y │
# │ --- ┆ --- │
# │ i64 ┆ i64 │
# ╞═════╪═════╡
# │ 1 ┆ 1 │
# ├╌╌╌╌╌┼╌╌╌╌╌┤
# │ 2 ┆ 1 │
# ├╌╌╌╌╌┼╌╌╌╌╌┤
# │ 3 ┆ 1 │
# └─────┴─────┘
The agg method does not broadcast literals:
import polars as pl
df = pl.DataFrame(dict(x=[1,1,0,0])).groupby("x")
df.agg(pl.lit(1).alias("y"))
# exceptions.ComputeError: returned aggregation is a different length: 1 than the group lengths: 2
Is there an operation I can apply that will broadcast a scalar and ignore a non-scalar? Something like this:
df.agg(something(pl.lit(1)).alias("y"))
# shape: (2, 2)
# ┌─────┬─────┐
# │ x ┆ y │
# │ --- ┆ --- │
# │ i64 ┆ i64 │
# ╞═════╪═════╡
# │ 0 ┆ 1 │
# ├╌╌╌╌╌┼╌╌╌╌╌┤
# │ 1 ┆ 1 │
# └─────┴─────┘
Solution 1:[1]
You will need to use pl.repeat(1, pl.count()) to expand the literal to the group size.
(answer from the Polars Issue tracker - https://github.com/pola-rs/polars/issues/2987#issuecomment-1079617229)
Solution 2:[2]
The following may be helpful:
df = pl.DataFrame(dict(x=[1,1,0,0])).groupby("x")
df.agg(pl.repeat(1, pl.count())).explode(pl.col('literal'))
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
Solution | Source |
---|---|
Solution 1 | jvz |
Solution 2 |