链式方法是当前比较流行的一种语法规则。
在过去的几个版本中,我们已经提到了几个支持链式方法的函数:
本文将从一个简单的例子说起:
In [2]:
import numpy as np import pandas as pd import seaborn as sns import matplotlib.pyplot as plt def read(fp): df = (pd.read_csv(fp) .rename(columns=str.lower) .drop('unnamed: 36', axis=1) .pipe(extract_city_name) .pipe(time_to_datetime, ['dep_time', 'arr_time', 'crs_arr_time', 'crs_dep_time']) .assign(fl_date=lambda x: pd.to_datetime(x['fl_date']), dest=lambda x: pd.Categorical(x['dest']), origin=lambda x: pd.Categorical(x['origin']), tail_num=lambda x: pd.Categorical(x['tail_num']), unique_carrier=lambda x: pd.Categorical(x['unique_carrier']), cancellation_code=lambda x: pd.Categorical(x['cancellation_code']))) return df def extract_city_name(df): ''' Chicago, IL -> Chicago for origin_city_name and dest_city_name ''' cols = ['origin_city_name', 'dest_city_name'] city = df[cols].apply(lambda x: x.str.extract("(.*), /w{2}", expand=False)) df = df.copy() df[['origin_city_name', 'dest_city_name']] = city return df def time_to_datetime(df, columns): ''' Combine all time items into datetimes. 2014-01-01,0914 -> 2014-01-01 09:14:00 ''' df = df.copy() def converter(col): timepart = (col.astype(str) .str.replace('/.0$', '') # NaNs force float dtype .str.pad(4, fillchar='0')) return pd.to_datetime(df['fl_date'] + ' ' + timepart.str.slice(0, 2) + ':' + timepart.str.slice(2, 4), errors='coerce') return datetime_part df[columns] = df[columns].apply(converter) return df df = read("878167309_T_ONTIME.csv") df.info()
<class 'pandas.core.frame.DataFrame'> RangeIndex: 471949 entries, 0 to 471948 Data columns (total 36 columns): fl_date 471949 non-null datetime64[ns] unique_carrier 471949 non-null category airline_id 471949 non-null int64 tail_num 467903 non-null category fl_num 471949 non-null int64 origin_airport_id 471949 non-null int64 origin_airport_seq_id 471949 non-null int64 origin_city_market_id 471949 non-null int64 origin 471949 non-null category origin_city_name 471949 non-null object origin_state_nm 471949 non-null object dest_airport_id 471949 non-null int64 dest_airport_seq_id 471949 non-null int64 dest_city_market_id 471949 non-null int64 dest 471949 non-null category dest_city_name 471949 non-null object dest_state_nm 471949 non-null object crs_dep_time 471949 non-null datetime64[ns] dep_time 441586 non-null datetime64[ns] dep_delay 441622 non-null float64 taxi_out 441266 non-null float64 wheels_off 441266 non-null float64 wheels_on 440453 non-null float64 taxi_in 440453 non-null float64 crs_arr_time 471949 non-null datetime64[ns] arr_time 440302 non-null datetime64[ns] arr_delay 439620 non-null float64 cancelled 471949 non-null float64 cancellation_code 30852 non-null category diverted 471949 non-null float64 distance 471949 non-null float64 carrier_delay 119994 non-null float64 weather_delay 119994 non-null float64 nas_delay 119994 non-null float64 security_delay 119994 non-null float64 late_aircraft_delay 119994 non-null float64 dtypes: category(5), datetime64[ns](5), float64(14), int64(8), object(4) memory usage: 115.3+ MB
我觉得链式方法的代码非常易读,但是有些人却并了解它。它并不像重嵌套函数那样循环调用参数,它的所有代码和流程都是自上而下运行的,这大大增强了代码的可读性。
我最喜欢的示例来自 Jeff Allen ,比较以下这两段功能相同但风格迥异的代码:
tumble_after( broke( fell_down( fetch(went_up(jack_jill, "hill"), "water"), jack), "crown"), "jill" )
和
jack_jill %>% went_up("hill") %>% fetch("water") %>% fell_down("jack") %>% broke("crown") %>% tumble_after("jill")
对比上述两种风格的代码,你会发现即使你不知道 R 语言中管道符号 %>%
的功能,你也能很轻易地看懂第二段代码。而对于第一段代码而言,你需要弄清楚代码的执行顺序以及如何处理相应的函数参数。
作为读者,你可能会说你不会写出类似于重嵌套风格的代码,但是大多数情况下你的代码应该是如下所示:
on_hill = went_up(jack_jill, 'hill') with_water = fetch(on_hill, 'water') fallen = fell_down(with_water, 'jack') broken = broke(fallen, 'jack') after = tmple_after(broken, 'jill')
我非常不喜欢这个风格的代码,因为我需要花费很多时间来思考如何对变量进行命名。这是非常令人困扰的事情,因为我们根本不关心 on_hill
这些中间变量。
上述代码的第四种实现方法是可行的,假设你拥有一个 JackAndJill
对象并且你可以自定义一些方法。那么你可以实现类似于 R 语言中的管道功能:
jack_jill = JackAndJill() (jack_jill.went_up('hill') .fetch('water') .fell_down('jack') .broke('crown') .tumble_after('jill') )
但是这种方法的问题在于如果的数据不是 ndarray
或者 DataFrame
或者 DataArray
,那么上述的方法就不存在了。而且我们很难对 DataFrame 的子类进行拓展从而来适应自定义的方法。同时,你所创建的从 DataFrame 中继承的子类可能仅适用于你自己的代码,无法和其他方法进行交互操作,因此你的代码将会非常零散。
或者你可以往 pandas 的项目中提交新的 pull request,从而实现自己的方法。但是你需要说服该项目的维护者,你的新方法值得加入到该项目中并维护之。而且 DataFrame
目前已经拥有超过 250 种的方法,因此我们不愿意增加更多的方法。
jack_jill = pd.DataFrame() (jack_jill.pipe(went_up, 'hill') .pipe(fetch, 'water') .pipe(fell_down, 'jack') .pipe(broke, 'crown') .pipe(tumble_after, 'jill') )
DataFrame.pipe
的第一个参数是 DataFrame,我们只需要指明后续的参数即可。
过长的链式代码的缺点是调试比较麻烦。由于没有生成中间变量值,所以如果代码出问题了,我们无法直接定位出问题在哪。Python 中的生成器也有类似的问题,借助生成器机制我们可以降低计算机内存消耗,但是此时我们比较难调试程序。
就我常用的探索分析过程而言,这并不是一个大问题。我平常处理的都是不会再更新的数据集,而且对原始数据集进行加工的步骤也不多。
对于规模较大的工作流程,你可能需要借助 pandas 的其他功能,比如 Airflow 或者 Luigi。
对于需要重复运行的中等规模 ETL 工作流程,我将借助装饰器来审查 DataFrame 每个工作步骤所产生的属性日志。
from functools import wraps import logging def log_shape(func): @wraps(func) def wrapper(*args, **kwargs): result = func(*args, **kwargs) logging.info("%s,%s" % (func.__name__, result.shape)) return result return wrapper def log_dtypes(func): @wraps(func) def wrapper(*args, **kwargs): result = func(*args, **kwargs) logging.info("%s,%s" % (func.__name__, result.dtypes)) return result return wrapper @log_shape @log_dtypes def load(fp): df = pd.read_csv(fp, index_col=0, parse_dates=True) @log_shape @log_dtypes def update_events(df, new_events): df.loc[new_events.index, 'foo'] = new_events return df
借助我之前制作的一个用于验证管道中数据集有效性的软件库 engarde
,我们可以很好地完成工作。
大多数 pandas 的方法都有一个默认值为 False
的关键词 inplace
。通常来说,你不应该做 inplace 运算。
首先,如果你喜欢用链式规则来写代码的话,你肯定不会用 inplace 运算,因为这会导致最终返回的结果是 None
,并中断相应的管道链。
其次,我怀疑存在一个适合 inplace 运算的构思模型。也就是说,最终结果并不会被分配到额外的存储器中。但实际上这可能是不真实的,pandas 中还存在许多下述用法:
def dataframe_method(self, inplace=False) data = self if inplace else self.copy() # result = ... if inplace: self._update_inplace(result) else: return result
最后,类似于 ibis 或者 dask 这种类型的项目 inplace 运算并没有任何意义,因为此时你需要处理表达式或者建立可执行的 DAG 任务,而不仅仅是处理数据而已。
我觉得到此为止我并没有怎么写代码,更多的是在介绍一些额外的东西,我对此感到非常抱歉。接下来,让我们做一些探索性分析吧。
In [132]:
sns.set(style='white', context='talk')
In [133]:
import statsmodels.api as sm
一架一天执行多趟航班执飞任务的飞机“堵机”了,会导致靠后的航班延误更长时间吗?
In [162]:
flights = (df[['fl_date', 'tail_num', 'dep_time', 'dep_delay', 'distance']] .dropna() .sort_values('dep_time') .assign(turn = lambda x: x.groupby(['fl_date', 'tail_num']) .dep_time .transform('rank').astype(int))) fig, ax = plt.subplots(figsize=(15, 5)) sns.boxplot(x='turn', y='dep_delay', data=flights, ax=ax) sns.despine() plt.savefig('images/mc_turn.svg', transparent=True)
一天中较晚起飞的航班会延误更长时间吗?
In [180]:
plt.figure(figsize=(15, 5)) (df[['fl_date', 'tail_num', 'dep_time', 'dep_delay', 'distance']] .dropna() .assign(hour=lambda x: x.dep_time.dt.hour) .query('5 < dep_delay < 600') .pipe((sns.boxplot, 'data'), 'hour', 'dep_delay')) sns.despine() plt.savefig('images/delay_by_hour.svg', transparent=True)
我们将延误超过十小时的数据视为异常值并将其剔除掉。
In [164]:
fig, ax = plt.subplots(figsize=(15, 5)) sns.boxplot(x='hour', y='dep_delay', data=flights[flights.dep_delay < 600], ax=ax) sns.despine() plt.savefig('images/mc_no_days.svg', transparent=True)
接下来,我们仅考虑确实发生延误的航班数据。
In [166]:
fig, ax = plt.subplots(figsize=(15, 5)) sns.boxplot(x='hour', y='dep_delay', data=flights.query('5 < dep_delay < 600'), ax=ax) sns.despine() plt.savefig('images/mc_delays.svg', transparent=True)
哪个航班的延误情况最严重呢?
In [175]:
# Groupby.agg accepts dict of {column: {ouput_name: agg_func}} air = (df.groupby(['origin', 'dest']) .agg({'dep_delay': {'dep_mean': 'mean', 'dep_count': 'count'}, 'arr_delay': {'arr_mean': 'mean', 'arr_count': 'count'}})) air.columns = air.columns.droplevel()
In [171]:
air[air.arr_count > 50].sort_values('dep_mean', ascending=False).head(10)
Out[171]:
arr_mean | arr_count | dep_count | dep_mean | ||
---|---|---|---|---|---|
origin | dest | ||||
MDW | MSY | 47.740741 | 54 | 54 | 55.111111 |
ORD | HSV | 56.578125 | 64 | 65 | 52.800000 |
IAD | EWR | 47.887500 | 80 | 81 | 52.333333 |
JFK | ATL | 49.647887 | 142 | 142 | 51.464789 |
CMH | MIA | 54.000000 | 61 | 61 | 51.344262 |
FLL | BOS | 40.033784 | 148 | 148 | 51.033784 |
IAD | CLT | 51.111111 | 54 | 54 | 50.888889 |
MDW | BDL | 45.807692 | 52 | 52 | 50.442308 |
SJU | BOS | 32.081081 | 111 | 112 | 49.660714 |
PBI | BWI | 40.655172 | 87 | 87 | 49.643678 |
哪个航空公司的延误情况最严重呢?
In [174]:
airlines = df.groupby('unique_carrier').dep_delay.agg(['mean', 'count']) airlines['mean'].sort_values().plot.barh() sns.despine()
B6 是美国捷蓝航空公司。
I wanted to try out scikit-learn's new Gaussian Process module so here's a pretty picture.
In [192]:
print(delay.head().to_html())
<table border="0" class="dataframe"> <thead> <tr style="text-align: right;"> <th></th> <th>count</th> <th>delay</th> <th>dist</th> </tr> <tr> <th>tail_num</th> <th></th> <th></th> <th></th> </tr> </thead> <tbody> <tr> <th>D942DN</th> <td>120</td> <td>9.232143</td> <td>829.783333</td> </tr> <tr> <th>N001AA</th> <td>139</td> <td>13.818182</td> <td>616.043165</td> </tr> <tr> <th>N002AA</th> <td>135</td> <td>9.570370</td> <td>570.377778</td> </tr> <tr> <th>N003AA</th> <td>125</td> <td>5.722689</td> <td>641.184000</td> </tr> <tr> <th>N004AA</th> <td>138</td> <td>2.037879</td> <td>630.391304</td> </tr> </tbody> </table>
In [190]:
planes = df.assign(year=df.fl_date.dt.year).groupby("tail_num") delay = (planes.agg({"year": "count", "distance": "mean", "arr_delay": "mean"}) .rename(columns={"distance": "dist", "arr_delay": "delay", "year": "count"}) .query("count > 20 & dist < 2000")) delay.head()
Out[190]:
count | delay | dist | |
---|---|---|---|
tail_num | |||
D942DN | 120 | 9.232143 | 829.783333 |
N001AA | 139 | 13.818182 | 616.043165 |
N002AA | 135 | 9.570370 | 570.377778 |
N003AA | 125 | 5.722689 | 641.184000 |
N004AA | 138 | 2.037879 | 630.391304 |
In [253]:
X = delay['dist'] y = delay['delay']
In [254]:
from sklearn.gaussian_process import GaussianProcessRegressor from sklearn.gaussian_process.kernels import RBF, WhiteKernel kernel = (1.0 * RBF(length_scale=10.0, length_scale_bounds=(1e2, 1e4)) + WhiteKernel(noise_level=.5, noise_level_bounds=(1e-1, 1e+5))) gp = GaussianProcessRegressor(kernel=kernel, alpha=0.0).fit(X.reshape(-1, 1), y) X_ = np.linspace(X.min(), X.max(), 1000) y_mean, y_cov = gp.predict(X_[:, np.newaxis], return_cov=True)
In [255]:
%matplotlib inline sns.set(style='white', context='talk')
In [256]:
ax = delay.plot(kind='scatter', x='dist', y = 'delay', figsize=(12, 6), color='k', alpha=.25, s=delay['count'] / 10) ax.plot(X_, y_mean, lw=2, zorder=9) ax.fill_between(X_, y_mean - np.sqrt(np.diag(y_cov)), y_mean + np.sqrt(np.diag(y_cov)), alpha=0.25) sizes = (delay['count'] / 10).round(0) for area in np.linspace(sizes.min(), sizes.max(), 3).astype(int): plt.scatter([], [], c='k', alpha=0.7, s=area, label=str(area * 10) + ' flights') plt.legend(scatterpoints=1, frameon=False, labelspacing=1) ax.set_xlim(0, 2100) ax.set_ylim(-20, 65) sns.despine() plt.tight_layout() plt.savefig("images/mc_flights.svg", transparent=True) plt.savefig("images/mc_flights.png")
谢谢阅读本文!由于我们更多地讨论了关于代码风格的问题而不是介绍实际案例操作,所以本文所介绍的内容比较抽象。谢谢你们的包容,下次我将介绍一个偏实务的话题!
原文链接: http://tomaugspurger.github.io/method-chaining.html
原文作者:Tom Augspurger
译者:Fibears