python - Filtering from index and comparing row value with all values in column - Stack Overflow

Date: 2025-01-06 | admin | Industry

Starting with this DataFrame:

df_1 = pl.DataFrame({
    'name': ['Alpha', 'Alpha', 'Alpha', 'Alpha', 'Alpha'],
    'index': [0, 3, 4, 7, 9],
    'limit': [12, 18, 11, 5, 9],
    'price': [10, 15, 12, 8, 11]
})

┌───────┬───────┬───────┬───────┐
│ name  ┆ index ┆ limit ┆ price │
│ ---   ┆   --- ┆   --- ┆   --- │
│ str   ┆   i64 ┆   i64 ┆   i64 │
╞═══════╪═══════╪═══════╪═══════╡
│ Alpha ┆     0 ┆    12 ┆    10 │
│ Alpha ┆     3 ┆    18 ┆    15 │
│ Alpha ┆     4 ┆    11 ┆    12 │
│ Alpha ┆     7 ┆     5 ┆     8 │
│ Alpha ┆     9 ┆     9 ┆    11 │
└───────┴───────┴───────┴───────┘

I need to add a new column that tells me the first index (greater than the current one) at which the price is equal to or higher than the current row's limit.

With this example above, the expected output is:

┌───────┬───────┬───────┬───────┬───────────┐
│ name  ┆ index ┆ limit ┆ price ┆ min_index │
│ ---   ┆   --- ┆   --- ┆   --- ┆       --- │
│ str   ┆   i64 ┆   i64 ┆   i64 ┆       i64 │
╞═══════╪═══════╪═══════╪═══════╪═══════════╡
│ Alpha ┆     0 ┆    12 ┆    10 ┆         3 │
│ Alpha ┆     3 ┆    18 ┆    15 ┆      null │
│ Alpha ┆     4 ┆    11 ┆    12 ┆         9 │
│ Alpha ┆     7 ┆     5 ┆     8 ┆         9 │
│ Alpha ┆     9 ┆     9 ┆    11 ┆      null │
└───────┴───────┴───────┴───────┴───────────┘

Explaining the "min_index" column results:

  • 1st row, where the limit is 12: from the 2nd row onwards, the minimum index whose price is greater than or equal to 12 is 3.
  • 2nd row, where the limit is 18: from the 3rd row onwards, there is no index whose price is greater than or equal to 18.
  • 3rd row, where the limit is 11: from the 4th row onwards, the minimum index whose price is greater than or equal to 11 is 9.
  • 4th row, where the limit is 5: from the 5th row onwards, the minimum index whose price is greater than or equal to 5 is 9.
  • 5th row, where the limit is 9: as this is the last row, there is no further index whose price is greater than or equal to 9.
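
To make the rule above concrete, here is a minimal plain-Python sketch of the same logic, useful only as a reference to check any Polars solution against. The function name `min_index_reference` is hypothetical and not part of the question; it compares every row with every other row, so it is not meant to be efficient.

```python
# Sample data from the question, as plain lists.
index = [0, 3, 4, 7, 9]
limit = [12, 18, 11, 5, 9]
price = [10, 15, 12, 8, 11]


def min_index_reference(index, limit, price):
    """Hypothetical helper: brute-force version of the rule described above."""
    result = []
    for i, lim in zip(index, limit):
        # Keep every index that is greater than the current one
        # and whose price reaches the current row's limit.
        candidates = [j for j, p in zip(index, price) if j > i and p >= lim]
        # No candidate means there is no such index, which maps to null.
        result.append(min(candidates) if candidates else None)
    return result


print(min_index_reference(index, limit, price))
# [3, None, 9, 9, None]
```

The printed list matches the "min_index" column of the expected output, with `None` standing in for null.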

My solution is shown below, but what would be a neater, more idiomatic Polars way of doing it? I was able to solve it in 8 steps, but I'm sure there is a more efficient way.

# Import Polars.
import polars as pl

# Create a sample DataFrame.
df_1 = pl.DataFrame({
    'name': ['Alpha', 'Alpha', 'Alpha', 'Alpha', 'Alpha'],
    'index': [0, 3, 4, 7, 9],
    'limit': [12, 18, 11, 5, 9],
    'price': [10, 15, 12, 8, 11]
})

# Group by name, so that each column's values are collected into a single list per name.
df_2 = df_1.group_by('name').agg(pl.all())

# Put the lists with the original DataFrame.
df_3 = df_1.join(
    other=df_2,
    on='name',
    suffix='_list'
)

# Explode the dataframe to long format by exploding the given columns.
df_3 = df_3.explode([
    'index_list',
    'limit_list',
    'price_list',
])

# Filter the DataFrame for the condition we want.
df_3 = df_3.filter(
    (pl.col('index_list') > pl.col('index')) &
    (pl.col('price_list') >= pl.col('limit'))
)

# Get the minimum index over the index column.
df_3 = df_3.with_columns(
    pl.col('index_list').min().over('index').alias('min_index')
)

# Select only the relevant columns and drop duplicates.
df_3 = df_3.select(
    pl.col(['index', 'min_index'])
).unique()

# Finally join the result.
df_final = df_1.join(
    other=df_3,
    on='index',
    how='left'
)

print(df_final)

Asked 21 hours ago by Danilo Setton; edited 21 hours ago by jonrsharpe

1 Answer


Option 1: df.join_where (experimental)

out = (
    df_1.join(
        df_1
        .join_where(
            df_1.select('index', 'price'),
            pl.col('index_right') > pl.col('index'),
            pl.col('price_right') >= pl.col('limit')
        )
        .group_by('index')
        .agg(
            pl.col('index_right').min().alias('min_index')
        ),
        on='index',
        how='left'
    )
)

Output:

shape: (5, 5)
┌───────┬───────┬───────┬───────┬───────────┐
│ name  ┆ index ┆ limit ┆ price ┆ min_index │
│ ---   ┆ ---   ┆ ---   ┆ ---   ┆ ---       │
│ str   ┆ i64   ┆ i64   ┆ i64   ┆ i64       │
╞═══════╪═══════╪═══════╪═══════╪═══════════╡
│ Alpha ┆ 0     ┆ 12    ┆ 10    ┆ 3         │
│ Alpha ┆ 3     ┆ 18    ┆ 15    ┆ null      │
│ Alpha ┆ 4     ┆ 11    ┆ 12    ┆ 9         │
│ Alpha ┆ 7     ┆ 5     ┆ 8     ┆ 9         │
│ Alpha ┆ 9     ┆ 9     ┆ 11    ┆ null      │
└───────┴───────┴───────┴───────┴───────────┘

Explanation / Intermediates

  • Use df.join_where, passing df_1.select('index', 'price') as other (the right-hand side does not need 'limit'), and add the two filter predicates.
# df_1.join_where(...)

shape: (4, 6)
┌───────┬───────┬───────┬───────┬─────────────┬─────────────┐
│ name  ┆ index ┆ limit ┆ price ┆ index_right ┆ price_right │
│ ---   ┆ ---   ┆ ---   ┆ ---   ┆ ---         ┆ ---         │
│ str   ┆ i64   ┆ i64   ┆ i64   ┆ i64         ┆ i64         │
╞═══════╪═══════╪═══════╪═══════╪═════════════╪═════════════╡
│ Alpha ┆ 0     ┆ 12    ┆ 10    ┆ 3           ┆ 15          │
│ Alpha ┆ 0     ┆ 12    ┆ 10    ┆ 4           ┆ 12          │
│ Alpha ┆ 4     ┆ 11    ┆ 12    ┆ 9           ┆ 11          │
│ Alpha ┆ 7     ┆ 5     ┆ 8     ┆ 9           ┆ 11          │
└───────┴───────┴───────┴───────┴─────────────┴─────────────┘
  • Since row order is not maintained, use df.group_by with pl.Expr.min to get the smallest matching 'index_right' per 'index'.
# df_1.join_where(...).group_by('index').agg(...)

shape: (3, 2)
┌───────┬───────────┐
│ index ┆ min_index │
│ ---   ┆ ---       │
│ i64   ┆ i64       │
╞═══════╪═══════════╡
│ 0     ┆ 3         │
│ 7     ┆ 9         │
│ 4     ┆ 9         │
└───────┴───────────┘
  • Add the result to df_1 with a left join, so that rows without any match get null in 'min_index'.

Option 2: df.join with "cross" + df.filter

(Adding this option because df.join_where is experimental. It will be more expensive, though, since the cross join builds every pair of rows before the filter is applied.)

out2 = (
    df_1.join(
        df_1
        .join(df_1.select('index', 'price'), how='cross')
        .filter(
            pl.col('index_right') > pl.col('index'),
            pl.col('price_right') >= pl.col('limit')
        )
        .group_by('index')
        .agg(
            pl.col('index_right').min().alias('min_index')
        ),
        on='index',
        how='left'
    )
)

out2.equals(out)
# True