python - Filtering from index and comparing row value with all values in column - Stack Overflow
Starting with this DataFrame:
df_1 = pl.DataFrame({
    'name': ['Alpha', 'Alpha', 'Alpha', 'Alpha', 'Alpha'],
    'index': [0, 3, 4, 7, 9],
    'limit': [12, 18, 11, 5, 9],
    'price': [10, 15, 12, 8, 11]
})
┌───────┬───────┬───────┬───────┐
│ name ┆ index ┆ limit ┆ price │
│ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ i64 ┆ i64 ┆ i64 │
╞═══════╪═══════╪═══════╪═══════╡
│ Alpha ┆ 0 ┆ 12 ┆ 10 │
│ Alpha ┆ 3 ┆ 18 ┆ 15 │
│ Alpha ┆ 4 ┆ 11 ┆ 12 │
│ Alpha ┆ 7 ┆ 5 ┆ 8 │
│ Alpha ┆ 9 ┆ 9 ┆ 11 │
└───────┴───────┴───────┴───────┘
I need to add a new column that tells me the first index (greater than the current one) at which the price is equal to or higher than the current limit.
With this example above, the expected output is:
┌───────┬───────┬───────┬───────┬───────────┐
│ name ┆ index ┆ limit ┆ price ┆ min_index │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ i64 ┆ i64 ┆ i64 ┆ i64 │
╞═══════╪═══════╪═══════╪═══════╪═══════════╡
│ Alpha ┆ 0 ┆ 12 ┆ 10 ┆ 3 │
│ Alpha ┆ 3 ┆ 18 ┆ 15 ┆ null │
│ Alpha ┆ 4 ┆ 11 ┆ 12 ┆ 9 │
│ Alpha ┆ 7 ┆ 5 ┆ 8 ┆ 9 │
│ Alpha ┆ 9 ┆ 9 ┆ 11 ┆ null │
└───────┴───────┴───────┴───────┴───────────┘
Explaining the "min_index" column results:
- 1st row, where the limit is 12: from the 2nd row onwards, the minimum index whose price is equal to or greater than 12 is 3.
- 2nd row, where the limit is 18: from the 3rd row onwards, there is no index whose price is equal to or greater than 18.
- 3rd row, where the limit is 11: from the 4th row onwards, the minimum index whose price is equal to or greater than 11 is 9.
- 4th row, where the limit is 5: from the 5th row onwards, the minimum index whose price is equal to or greater than 5 is 9.
- 5th row, where the limit is 9: as this is the last row, there is no later index whose price could be equal to or greater than 9.
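The rule above can be pinned down with a brute-force reference in plain Python. The helper name `min_index_reference` is hypothetical (it is not part of the question), and it ignores 'name', i.e. it assumes a single group as in the sample. For each row it collects the indices that are greater than its own index and whose price reaches its limit, then takes the smallest one.

```python
def min_index_reference(indices, limits, prices):
    """Hypothetical brute-force reference for the 'min_index' column.

    Assumes all rows belong to one group (the 'name' column is ignored).
    """
    result = []
    for idx, limit in zip(indices, limits):
        # Candidates: any row with a greater index whose price reaches this limit.
        candidates = [
            other_idx
            for other_idx, other_price in zip(indices, prices)
            if other_idx > idx and other_price >= limit
        ]
        # Smallest qualifying index, or None (null) when nothing qualifies.
        result.append(min(candidates) if candidates else None)
    return result


indices = [0, 3, 4, 7, 9]
limits = [12, 18, 11, 5, 9]
prices = [10, 15, 12, 8, 11]

print(min_index_reference(indices, limits, prices))
# [3, None, 9, 9, None]
```

This is quadratic in the number of rows, so it is only meant as a correctness check for the Polars versions.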
My solution is shown below, but what would be a neat Polars way of doing it? I was able to solve it in 8 steps, but I'm sure there is a more efficient way.
# Import Polars.
import polars as pl
# Create a sample DataFrame.
df_1 = pl.DataFrame({
    'name': ['Alpha', 'Alpha', 'Alpha', 'Alpha', 'Alpha'],
    'index': [0, 3, 4, 7, 9],
    'limit': [12, 18, 11, 5, 9],
    'price': [10, 15, 12, 8, 11]
})
# Group by name, collecting each column's values into a single list per name.
df_2 = df_1.group_by('name').agg(pl.all())
# Join the lists back onto the original DataFrame (clashing column names get the '_list' suffix).
df_3 = df_1.join(
    other=df_2,
    on='name',
    suffix='_list'
)
# Explode the list columns to long format: one row per (row, candidate row) pair.
df_3 = df_3.explode([
    'index_list',
    'limit_list',
    'price_list',
])
# Keep candidates with a greater index and a price at or above the current limit.
df_3 = df_3.filter(
    (pl.col('index_list') > pl.col('index')) &
    (pl.col('price_list') >= pl.col('limit'))
)
# Get the minimum candidate index per 'index' value.
df_3 = df_3.with_columns(
    pl.col('index_list').min().over('index').alias('min_index')
)
# Select only the relevant columns and drop duplicates.
df_3 = df_3.select(
    pl.col(['index', 'min_index'])
).unique()
# Finally, left-join the result so rows without a match get null.
df_final = df_1.join(
    other=df_3,
    on='index',
    how='left'
)
print(df_final)
Asked by Danilo Setton, edited by jonrsharpe.
1 Answer
Option 1: df.join_where (experimental)
out = (
df_1.join(
df_1
.join_where(
df_1.select('index', 'price'),
pl.col('index_right') > pl.col('index'),
pl.col('price_right') >= pl.col('limit')
)
.group_by('index')
.agg(
pl.col('index_right').min().alias('min_index')
),
on='index',
how='left'
)
)
Output:
shape: (5, 5)
┌───────┬───────┬───────┬───────┬───────────┐
│ name ┆ index ┆ limit ┆ price ┆ min_index │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ i64 ┆ i64 ┆ i64 ┆ i64 │
╞═══════╪═══════╪═══════╪═══════╪═══════════╡
│ Alpha ┆ 0 ┆ 12 ┆ 10 ┆ 3 │
│ Alpha ┆ 3 ┆ 18 ┆ 15 ┆ null │
│ Alpha ┆ 4 ┆ 11 ┆ 12 ┆ 9 │
│ Alpha ┆ 7 ┆ 5 ┆ 8 ┆ 9 │
│ Alpha ┆ 9 ┆ 9 ┆ 11 ┆ null │
└───────┴───────┴───────┴───────┴───────────┘
Explanation / Intermediates
- Use df.join_where, passing df.select('index', 'price') as other (you don't need 'limit'), and add the two filter predicates.
# df_1.join_where(...)
shape: (4, 6)
┌───────┬───────┬───────┬───────┬─────────────┬─────────────┐
│ name ┆ index ┆ limit ┆ price ┆ index_right ┆ price_right │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ i64 ┆ i64 ┆ i64 ┆ i64 ┆ i64 │
╞═══════╪═══════╪═══════╪═══════╪═════════════╪═════════════╡
│ Alpha ┆ 0 ┆ 12 ┆ 10 ┆ 3 ┆ 15 │
│ Alpha ┆ 0 ┆ 12 ┆ 10 ┆ 4 ┆ 12 │
│ Alpha ┆ 4 ┆ 11 ┆ 12 ┆ 9 ┆ 11 │
│ Alpha ┆ 7 ┆ 5 ┆ 8 ┆ 9 ┆ 11 │
└───────┴───────┴───────┴───────┴─────────────┴─────────────┘
- Since row order is not maintained, use df.group_by with pl.Expr.min to get the smallest 'index_right' per 'index'.
# df_1.join_where(...).group_by('index').agg(...)
shape: (3, 2)
┌───────┬───────────┐
│ index ┆ min_index │
│ --- ┆ --- │
│ i64 ┆ i64 │
╞═══════╪═══════════╡
│ 0 ┆ 3 │
│ 7 ┆ 9 │
│ 4 ┆ 9 │
└───────┴───────────┘
- Add the result to df_1 with a left join, so rows without a match get null.
Option 2: df.join with how='cross' + df.filter
(Adding this option since df.join_where is experimental. It will be more expensive, though.)
out2 = (
df_1.join(
df_1
.join(df_1.select('index', 'price'), how='cross')
.filter(
pl.col('index_right') > pl.col('index'),
pl.col('price_right') >= pl.col('limit')
)
.group_by('index')
.agg(
pl.col('index_right').min().alias('min_index')
),
on='index',
how='left'
)
)
out2.equals(out)
# True