在服務(wù)一些客戶做商業(yè)問題的機器學習建模時,我們會碰到不少擁有非常大量數(shù)據(jù)且對模型 pipeline 運行有一定要求的情況。相比直接的單機 Python 建模,這類項目有一些難點:
從前兩點來看,我們之前習慣的單機 Pandas + lgb/xgb 建模思路已經(jīng)難以適用(除非搞臺神威·太湖之光之類的機器)。所以我們需要引入目前大數(shù)據(jù)界的當紅炸子雞 -- Spark 來協(xié)助完成此類項目。
要玩 Spark 的第一步,首先是部署。如果是本機測試運行,一般跑一個 pip install pyspark
就能把一個 local 節(jié)點跑起來了,非常的方便。如果想部署一個相對完整一點的 standalone 集群,可以參考以下步驟:
cp spark-env.sh.template spark-env.sh
,然后進行編輯。我改的一些配置具體如下:spark-env.sh
# 配置 master 和 worker
SPARK_MASTER_HOST=0.0.0.0
SPARK_DAEMON_MEMORY=4g
SPARK_WORKER_CORES=6
SPARK_WORKER_MEMORY=36g
slaves
# 指定 slaves 機器的列表,這里就選了本機
localhost
spark-defaults.conf
# 這個文件很多教程都會讓你改,說是 spark-submit 命令會默認從這里讀取相關(guān)配置
# 但要注意我們寫的 PySpark 程序很多時候并不是通過 spark-submit 命令提交的,所以這里改了可能沒用
spark.driver.memory 4g
sbin/start-all.sh
即可?;蛘咭部梢苑謩e起 master 和 slave,運行 ./sbin/start-master.sh
和 ./sbin/start-slave.sh spark://127.0.0.1:7077 -c 6 -m 36G
即可。sbin/stop-all.sh
,或者分別停 slave 和 master 都行。這樣就算部署完了!其中 spark master 會有一個監(jiān)聽 8080 端口的 web-ui,worker 會監(jiān)聽 8081,后面提交 application 就會有監(jiān)聽 4040 端口的管理界面,功能強大,用戶友好度++。
直接上代碼:
from pyspark.sql import SparkSession
spark = (SparkSession.builder
.master('spark://127.0.0.1:7077')
.appName('zijie')
.getOrCreate())
df = spark.read.parquet('data/the_only_data_i_ever_wanted.parquet')
df.show()
我們的大數(shù)據(jù)平臺就這么跑起來了!
這個在網(wǎng)上看一些 Spark 相關(guān)的介紹應(yīng)該很快會有一些認識。有幾個比較明顯的區(qū)別點我大致列一下:
從實際操作來看,在 PySpark 中其實有很多操作長得跟 Pandas 非常類似,比如我們常用的 df[df['date'] > '2020-01-01']
之類的寫法。當然區(qū)別也有不少,所以后來 Databricks 干脆推出了一個 Koalas
的庫來支持更平滑的切換。
這里主要記錄幾個在項目過程中寫的感覺比較好玩的,并對比 pandas 的版本方便大家理解。
pandas version:
# 對每家店每個 SKU 歷史無銷售情況進行填零處理
def fill_dates(df):
new_df = []
for store_id in df.store_id.unique():
for sku in df.query('store_id == @store_id').sku.unique():
tmp = pd.DataFrame()
cond = (df.store_id == store_id) & (df.sku == sku)
min_date = df.loc[cond, 'date'].min()
max_date = df.loc[cond, 'date'].max()
dates_in_between = daterange(min_date, max_date)
tmp['date'] = dates_in_between
tmp['sku'] = sku
tmp['store_id'] = store_id
new_df.append(tmp)
new_df = pd.concat(new_df)
new_df = new_df.merge(df, on=['date', 'sku', 'store_id'], how='left').fillna(0)
return new_df
可以看到整體邏輯就是取所有 store, sku 的組合,然后找到每個組合最小最大的售賣日期,把中間的日期都填上。
PS: 這段代碼應(yīng)該效率不高,后續(xù)我們又迭代了幾個版本。
Spark version:
from pyspark.sql import functions as F
def fill_dates_spark(df):
tmp = df.groupby(['store_id', 'sku']).agg(F.min('date').cast('date').alias('min_date'),
F.max('date').cast('date').alias('max_date'))
tmp = tmp.withColumn('date', F.explode(F.sequence('min_date', 'max_date'))).select(
['date', 'store_id', 'sku'])
new_df = tmp.join(df, ['date', 'store_id', 'sku'], 'left').fillna(0, subset=['y'])
return new_df
用了 sequence+explode 操作,代碼簡潔很多。其中 sequence 會自動生成從 start 到 end 的序列(時間,數(shù)字都支持),explode 操作直接把一行“炸開”成多行,省去了 join 操作,性能也更好。我不太確定 Pandas 里是不是也能這么玩?
這個是我們最常用的一種特征了,在 pandas 里主要就是做循環(huán) join:
def shift_daily_data(df, delay, shift_by='date', shift_value='y'):
groupby_df = [x for x in df.columns if (x != shift_by) and (x != shift_value)]
shift_df = df.copy()
shift_df[shift_by] = shift_df[shift_by].apply(lambda x: x + relativedelta(days=delay))
shift_df = shift_df.rename(columns={shift_value: '%s_%s_day_lag_%d' % ('_'.join(groupby_df), shift_value, delay)})
return shift_df
def add_daily_shifts(df, days, categories, shift_by='date', shift_value='y'):
merge_df = df.copy()
for base_categories in categories:
feat_cols = base_categories + [shift_by]
base_df = df.groupby(feat_cols, as_index=False).agg({shift_value: sum})
for i in days:
delay_df = shift_daily_data(base_df, i, shift_by, shift_value)
merge_df = pd.merge(left=merge_df, right=delay_df, how='left', on=feat_cols, sort=False).reset_index(
drop=True).fillna(0)
gc.collect()
return merge_df
# 按照不同維度生成 lag 自回歸時序特征
def add_lag_features(all_data_df, fcst_type):
lag_days = list(range(1, 11)) + [14, 21, 28, 29, 30, 31]
lag_days = [x for x in lag_days if x >= fcst_type]
groupby_cats = [['sku'], ['store_id'], ['sku', 'store_id']]
all_data_df = add_daily_shifts(all_data_df, lag_days, groupby_cats)
return all_data_df
在遷移到 Spark 時第一版我也采用了類似的寫法,不過發(fā)現(xiàn)性能比較差,而且隨著 lag 數(shù)的增多,join 次數(shù)也增多了,數(shù)據(jù)血緣關(guān)系會拉的非常長……
第二版我們采用了 window function 的寫法:
from pyspark.sql import functions as F
from pyspark.sql import Window
def add_date_index(df, date_col, start_day='2016-01-01'):
df = df.withColumn(f'{date_col}_index', F.datediff(date_col, F.lit(start_day)))
return df
def add_shifts_by_window(df, days, group_by, order_by='date_index', shift_value='y'):
# 取 lag 操作,其實就是要取一個時間點往前一個時間窗口中的值
# 然后這個窗口要考慮時間順序,我們就加上 orderBy,需要分門店分 sku,我們就加上 partitionBy
w = Window.orderBy(order_by).partitionBy(*group_by)
new_col_prefix = f'{'_'.join(group_by)}_{shift_value}_day_lag'
# 再用lag函數(shù)取之前的值即可
new_cols = [F.coalesce(F.lag(shift_value, i).over(w), F.lit(0)).alias(f'{new_col_prefix}_{i}') for i in days]
df = df.select('*', *new_cols)
return df
# 接下來主要就是調(diào)用了
def add_daily_shifts_by_categories(df, days, categories, shift_by='date', shift_value='y'):
df = add_date_index(df, shift_by)
shift_by = f'{shift_by}_index'
cat_cols = ['store_id', 'sku']
merge_df = add_shifts_by_window(df, max(days), cat_cols, shift_by, shift_value)
for base_categories in categories:
if len(base_categories) < len(cat_cols):
# 先聚合,再添加 lag 特征
feat_cols = base_categories + [shift_by]
base_df = df.groupby(feat_cols).agg(F.sum(shift_value).alias(shift_value))
join_df = add_shifts_by_window(base_df, max(days), base_categories, shift_by, shift_value)
join_df = join_df.drop(shift_value)
merge_df = merge_df.join(join_df, feat_cols, 'left').fillna(0)
return merge_df
用這個方法的前提是,先要把日期填充做了,否則 window 中的數(shù)值可能是不連續(xù)的。當時也考慮過不做填充可不可以?比如用 F.create_map 的方法創(chuàng)建出時間點與值的 map:df = df.withColumn('m', F.create_map('date_index', 'y'))
,然后用類似的 collect_list
手法獲取 window 中的多個 map,合并 map,然后按 lag 順序取 key,取不到的就填 0 即可。其中合并 map 需要用 udf,大致如下:combineMap = udf(lambda maps: dict(ChainMap(*maps)), MapType(IntegerType(), DoubleType()))
。
在實驗中發(fā)現(xiàn),這個 udf 使用過程中會報錯,說 pandas udf 目前不支持在 window function 中使用,需要用 Spark 3.0 才行……所以暫時用了以上的方案。實測下來發(fā)現(xiàn),用上了 window function,建 lag 特征的時間從 20 多分鐘降到了 200 秒左右,而且不管建多少個 lag,時間基本都是一樣的,可擴展性棒棒的!
從這個例子中也可以看到,window function 結(jié)合 Spark SQL 中帶的各種方法非常強大靈活。而到了 PySpark 這里,還有更加神奇的 pandas udf,光看 官方示例[2] 就有種騷操作飛起的感覺,感興趣的同學可以去看看。
這個項目中我們用的是類似 frequency encoding 的手法,Pandas 代碼如下:
def y_rank_transform(df, col_name, orderby, ascending=True):
sorted_df = df.groupby(col_name).agg({orderby: np.sum}).reset_index().sort_values(orderby, ascending=ascending)
rank_map = {v: i for i, v in enumerate(sorted_df[col_name].values)}
df[col_name] = df[col_name].map(rank_map)
return df, rank_map
def convert_category_feats(full_df, category_features, orderby):
# 根據(jù) orderby 值的大小對 category_features 進行排序編碼
rank_maps = {}
for c in category_features:
if c in full_df:
full_df, rank_map = y_rank_transform(full_df, c, orderby)
rank_maps[c] = rank_map
gc.collect()
return full_df, rank_maps
還是比較好理解的。然后 Spark 里可以直接用 pyspark.ml.feature 里自帶的一些實現(xiàn)來幫助我們做類似的事情:
from pyspark.ml.feature import StringIndexer
from pyspark.sql import functions as F
def convert_category_feats(full_df):
cat_cols = get_category_cols()
cat_cols = [x for x in cat_cols if x in full_df.columns]
# 根據(jù) orderby 值的大小對 category_features 進行排序編碼
for c in cat_cols:
if c in full_df.columns:
target_col = f'{c}_index'
indexer = StringIndexer(inputCol=c, outputCol=target_col)
model = indexer.fit(full_df)
full_df = model.transform(full_df).withColumn(target_col, F.col(target_col).cast('int'))
return full_df
所以有時候也可以沒事瀏覽下標準庫里的東西,說不定你想要的功能都已經(jīng)有現(xiàn)成實現(xiàn)啦。
構(gòu)建完特征,就到了令人激動的模型訓練環(huán)節(jié)!特征構(gòu)建之類的,總體來說還是盡在掌握的感覺,但十億級數(shù)據(jù)量的訓練,就感覺有點心里發(fā)虛了…… 這部分一開始的工作主要由我們年輕帥氣的實力派選手娜可露露來負責。學弟經(jīng)過一番調(diào)研,最終鎖定了一個名為 mmlspark 的庫:
之前我們在不少場景用了 lgb,而這個 mmlspark 同是微軟出品的框架,感覺應(yīng)該穩(wěn)了!
要嘗試這個庫,第一步肯定就是安裝了!這個庫的安裝比較詭異,沒有提供 pypi/conda 安裝包,官網(wǎng)上給出的用法是這樣的:
import pyspark
spark = pyspark.sql.SparkSession.builder.appName('MyApp') \
.config('spark.jars.packages', 'com.microsoft.ml.spark:mmlspark_2.11:1.0.0-rc1') \
.config('spark.jars.repositories', 'https://mmlspark.azureedge.net/maven') \
.getOrCreate()
import mmlspark
但真正跑起來的時候,碰到了一系列網(wǎng)絡(luò)問題,中間嘗試了好久的更換 maven/ivy2 源等,都沒有很好的解決。最后……我們在 GCP 的服務(wù)器上跑了一下代碼,順利下好了所有的依賴。然后把 ~/.ivy2/cache
下的所有文件打包回來到本地緩存文件夾解壓開……Work like a charm!
代碼能跑起來了,接下來應(yīng)該一帆風順了吧!沒想到?jīng)]過多久,學弟就碰到了第二個困難:
上 官方文檔[3] 瞄了一眼,突然感覺不妙!一般文檔寫成這樣的庫,八成是沒啥人用的了……不過看支持的參數(shù),比起 Spark ML 的 GBDT 還是豐富很多的,基本上應(yīng)該是繼承了原生 lgb 的接口。在這里我們的主要目標是通過 early_stopping 來做最基本的調(diào)參,這樣可以保證模型運行時有比較可靠的表現(xiàn)。但是文檔里根本沒有提這個 early_stopping 應(yīng)該怎么用,該怎么辦呢?
遇到這種情況,一般就只能找 1)有沒有別人的代碼用了這個功能,2)源碼里是怎么實現(xiàn)這個功能的。具體到這個問題,我就直接選擇在 github 的 mmlspark repo 里搜 earlyStoppingRound
這個參數(shù):
然后通過調(diào)用路徑做幾層追蹤,就會看到相關(guān)的實現(xiàn):
所以要用 early_stopping,需要幾個條件:
validationIndicatorCol
參數(shù)true
/false
來指定訓練集和驗證集earlyStoppingRound
參數(shù),需要大于 0通過類似的手段,我們解決了一系列因為文檔和示例缺乏導致的使用困難,包括傳入類別變量等。
之前用原生 lgb 時,訓練數(shù)據(jù)處理基本比較簡單,直接用 lgb.dataset
從 pandas/numpy 數(shù)據(jù)集進行構(gòu)建即可。不過 mmlspark 就很不一樣了,竟然要求傳入 2 個 columns,一個叫 featuresCol
,另一個叫 labelCol
。唔,總不至于只支持 1 個特征吧?
轉(zhuǎn)念一想,lightgbm 的 Python API 分了 native 和 sklearn 兩套,那 mmlspark 這個 API 應(yīng)該同理是為了符合 Spark ml 的標準!順著這個思路,果然發(fā)現(xiàn) Spark ml 都是這個套路,然后 Spark ml 庫里也自帶了一個類, 叫pyspark.ml.feature.VectorAssembler
,直接用上就能把需要的 feature columns 轉(zhuǎn)換成 Vector 類型的單個 column 了!代碼類似:
def vectorize(df, feat_cols):
assembler = VectorAssembler(inputCols=feat_cols, outputCol='features')
df = assembler.transform(df)
return df
終于一切代碼就緒,就開始跑訓練了!沒想到剛開始就出現(xiàn)了問題,訓練啟動后就一直沒反應(yīng),看 Spark 的任務(wù)也完全沒有進度,非常詭異。這個問題的排查繞了一些彎路,看了不少 mmlspark 的源碼,嘗試 callstack 的收集,strace 等,都沒有理想的結(jié)果。最后還是從日志中發(fā)現(xiàn)了問題。
先截取一個酷炫的圖,給大家看下在哪里看日志:
具體日志:
上面這個是正常的日志,當時出錯時發(fā)現(xiàn) Spark executor 在啟動 lgb 時報了一堆的錯誤:
NetworkInit failed with exception on local port...
Retrying NetworkInit with local port...
然后可以用類似的方法,在 git repo 里搜這些錯誤信息,看到底是從哪里報出來的:
再接著往上排查幾層,看了下 lgb 分布式訓練的一些文檔說明,就大致明白問題出在哪了。總體的 mmlspark 訓練過程其實就是把分布在各個機器上的數(shù)據(jù)轉(zhuǎn)化為lgb.dataset
形式,然后再各自起原生的 lightgbm 來訓練。多節(jié)點訓練時各個節(jié)點需要通過網(wǎng)絡(luò)端口來進行同步,因此需要在啟動時設(shè)定好大家各自的端口。而且 mmlspark 里用的是 mapPartitions 方法來做具體的訓練,我是在單機上跑(當然只要一臺機器上有多個 parition 都會有這個問題),所以就出現(xiàn)多個 partition 啟動 lightgbm 時監(jiān)聽的端口沖突問題。要解決的話也比較簡單,只需要 repartition 數(shù)據(jù)到每臺服務(wù)器啟動一個 lgb task 即可!
此外對于類似此類 MPI 的計算 load,官方還提供了一個新的 barrier execution mode 來解決一系列相關(guān)問題:
剛解決完訓練卡死的問題,立刻又來了下一個問題。前面剛提到我們用 early_stopping 來尋找一個合適的樹的數(shù)量參數(shù),不過在 mmlspark 中用完 early_stopping 后,發(fā)現(xiàn)沒有方法可以獲取到這個 best_iteration 到底是多少?!
搜了半天文檔和代碼,都沒發(fā)現(xiàn)隱藏功能,只好在 git 上提了一個 issue[4],至今沒人理……(一年多后終于支持了)。如果要自力更生,怎么解決?
def getFeatureImportances(self, importance_type='split'):
'''
Get the feature importances as a list. The importance_type can be 'split' or 'gain'.
'''
return list(self._java_obj.getFeatureImportances(importance_type))
這么簡單嗎?當然不是,還需要在 Scala 里轉(zhuǎn)一層:
/**
* Calls into LightGBM to retrieve the feature importances.
* @param importanceType Can be 'split' or 'gain'
* @return The feature importance values as an array.
*/
def getFeatureImportances(importanceType: String): Array[Double] = {
val importanceTypeNum = if (importanceType.toLowerCase.trim == 'gain') 1 else 0
if (boosterPtr == null) {
LightGBMUtils.initializeNativeLibrary()
boosterPtr = getBoosterPtrFromModelString(model)
}
val numFeaturesOut = lightgbmlib.new_intp()
LightGBMUtils.validate(
lightgbmlib.LGBM_BoosterGetNumFeature(boosterPtr, numFeaturesOut),
'Booster NumFeature')
val numFeatures = lightgbmlib.intp_value(numFeaturesOut)
val featureImportances = lightgbmlib.new_doubleArray(numFeatures)
LightGBMUtils.validate(
lightgbmlib.LGBM_BoosterFeatureImportance(boosterPtr, -1, importanceTypeNum, featureImportances),
'Booster FeatureImportance')
(0 until numFeatures).map(lightgbmlib.doubleArray_getitem(featureImportances, _)).toArray
}
所以要實現(xiàn)比如原生的 get_current_iteration
方法,也得按照上面這個流程走一遍。
首先看到 mmlspark 實現(xiàn)了 saveNativeModel
方法,一看這個名字,應(yīng)該會把模型存成 lgb native model??戳讼麓a應(yīng)該沒問題,就嘗試存了一個。
接下來拿出我們的原生 lightgbm,來 load 這個存好的模型。因為是 Spark 存模型,還需要考慮分布式文件系統(tǒng)等問題,不過 mmlspark 也比較暴力,直接用了 coalesce(1)
加 write text 的方法來存模型,所以最終肯定就是一個文件啦!
讀到原生模型后,取 current iteration 就易如反掌了!最后實現(xiàn)代碼如下:
def get_native_lgb_model(file_path):
txt_files = list(Path(file_path).glob('*.txt'))
if len(txt_files) != 1:
raise Exception('Aww...cannot read model file!')
native_model = lgb.Booster(model_file=txt_files[0].as_posix())
return native_model
def get_best_iteration(model, path_prefix='/share'):
file_path = f'{path_prefix}/lgb_model'
model.saveNativeModel(file_path)
native_model = get_native_lgb_model(file_path)
best_iteration = int(native_model.current_iteration() * 1.02)
return best_iteration
還是比較輕松的嘛!
接下來經(jīng)過了一陣風平浪靜的開發(fā)日子,我們每天日出而作,日落而息,逐漸實現(xiàn)了一些參數(shù)搜索緩存,自動回測,與數(shù)據(jù)開發(fā)平臺實現(xiàn)對接等功能,并逐漸把訓練數(shù)據(jù)量提升到了一億行。在這個階段我們主要的目標是評估當數(shù)據(jù)量增長時,整體的性能變化,機器資源占用變化如何,進而產(chǎn)出對機器資源需求的規(guī)劃方案來。如果所需的機器數(shù)量過多,就需要做一系列的優(yōu)化控制整體成本。
前面有提到我們的整個數(shù)據(jù)獲取,清洗,構(gòu)建特征,模型訓練預(yù)測,業(yè)務(wù)系統(tǒng)對接產(chǎn)出,必須在 5 個小時以內(nèi)完成,這其中的每一個階段要花的時間都要做好優(yōu)化工作,確保沒有明顯的瓶頸點。
首先,優(yōu)化的前提是監(jiān)控,在本地集群和開發(fā)平臺,我們都設(shè)計了相應(yīng)的日志,用于抓取 pipeline 中每個階段所需要花費的時間。另外開發(fā)平臺部署了 Prometheus 和 Grafana,本地集群我們配備了 dstat, jstat, top 等腳本,主要用于監(jiān)控 Spark, Python 相關(guān)進程的 cpu,內(nèi)存使用情況,為整體的 capacity planning 做準備。
接著在監(jiān)控的基礎(chǔ)上,我們對各個 stage 做了相應(yīng)的優(yōu)化:
describe <table_name>
語句進行查詢。learningRate
: 越大訓練所需的輪次越少numLeaves
: 越小則每棵樹越簡單,但實際可能需要的數(shù)量越多maxBin
: 越小則訓練速度越快,但會損失精度baggingFraction
和 baggingFreq
: 訓練采樣率,采樣之后訓練數(shù)據(jù)少了,自然速度就快了,可以控制每多少輪重新采樣一次featureFraction
: 特征采樣率,原理類似上一條。注意有些情況下這個參數(shù)設(shè)置為 1 效果才比較理想,一個典型的例子是 one hot encoding 后的數(shù)據(jù),采樣后可能導致類別信息的缺失我們采取了一個比較簡單的做法來做訓練速度的優(yōu)化,在原先隨機搜索的基礎(chǔ)上,除了記錄模型的精度指標,我們還會一并記錄訓練所花費的時間。最后在做參數(shù)選擇時,可以靈活選擇可以接受的時間耗費,在訓練時間小于這個要求的前提下,選取效果最優(yōu)的參數(shù)。通過這一步優(yōu)化,整體訓練時間縮短了一半左右,而且訓練精度并沒有下降。
bigint
, double
等類型來存儲數(shù)據(jù),但在我們這個應(yīng)用場景中, int
, float
類型就已經(jīng)足夠。因此可以做一些類型轉(zhuǎn)換,節(jié)約內(nèi)存占用和保存文件的大小。經(jīng)過一系列的優(yōu)化工作,基本上可以達到使用 5 臺 16c/64g 機器完成十億級模型訓練預(yù)測按時產(chǎn)出的需求。
當訓練數(shù)據(jù)擴充到一億規(guī)模時,我們的 mmlspark 又出現(xiàn)了一個奇怪的卡頓問題。在訓練過程中,這一億數(shù)據(jù)并不是進入一個統(tǒng)一的大模型來訓練,而是會根據(jù)策略引擎的規(guī)則,分發(fā)到不同的模型做訓練。前面有提到模型的數(shù)量大約有 40+個,這其中有些模型分到的數(shù)據(jù)量會比較大,因而本身訓練時間就比較慢。但隨著訓練流程的逐步進行,這個訓練時間變得越來長,直至 task 失去響應(yīng)。所以我們又啟動了新一輪的排查流程。
遇到卡頓,首先觀察系統(tǒng)資源情況,例如 cpu, 內(nèi)存,磁盤 io/空間,網(wǎng)絡(luò)等。但運行過程中發(fā)現(xiàn)沒有一個資源吃緊的情況,其中特別奇怪的是 cpu 使用率在 100%(機器是 8 核,正好用滿一個 core),沒有發(fā)揮所有的性能。
觀察模型正常訓練時的情況,Spark 啟動的 lgb 會基本把所有 cpu 資源打滿,因此懷疑是在進入訓練之前的某些環(huán)節(jié)無法并行計算導致的問題。
為了更好的追蹤 jvm 內(nèi)部情況,請出了老朋友 visualvm
。這貨是我 N 年前工作的時候用的主力排查工具,不知道現(xiàn)在還是不是流行。為了用上這個工具,需要對 Spark 的配置做一些修改:
${SPARK_HOME}/conf/metrics.properties
文件,加上 jmx 相關(guān)的一些 sink-Dcom.sun.management.jmxremote -Dcom.sun.management.jmxremote.authenticate=false -Dcom.sun.management.jmxremote.ssl=false -Dcom.sun.management.jmxremote.port=22990
,注意很多文章都說要加在 spark-defaults.conf
里,但是我們直接運行 Python 程序并不會調(diào)用 spark-submit
命令。所以這些參數(shù)需要在程序內(nèi)的 spark session 中指定:metrics_conf = f'{spark_home}/conf/metrics.properties'
jmx_conf = '-Dcom.sun.management.jmxremote -Dcom.sun.management.jmxremote.authenticate=false ' \
'-Dcom.sun.management.jmxremote.ssl=false -Dcom.sun.management.jmxremote.port=22990'
spark = (SparkSession.builder
.master('spark://127.0.0.1:7077')
.appName('zijie')
.config('spark.executor.memory', '36g')
.config('spark.driver.memory', '6g')
.config('spark.jars.repositories', 'https://mmlspark.azureedge.net/maven')
.config('spark.jars.packages', 'com.microsoft.ml.spark:mmlspark_2.11:1.0.0-rc1')
.config('spark.metrics.conf', metrics_conf)
.config('spark.executor.extraJavaOptions', jmx_conf)
.getOrCreate())
配置好之后,啟動應(yīng)用,就能在 visualvm
里添加 jmx 連接做監(jiān)控了。我們獲取到了卡頓時候的 cpu 使用情況截圖如下:
可以看出模型訓練階段,cpu 使用率都是在 80%上下波動,但模型訓練的中間,總有一些只占用了 1 個 cpu 資源的時間段。而且這些 cpu 資源使用是黃色的正常工作線程,而不是垃圾回收。
接下來一個比較自然的思路就是在這些 cpu 使用低谷去獲取 thread dump,看系統(tǒng)到底在忙什么。用 jstack
或者 visualvm
等工具都可以獲取到。一個典型的 thread dump 如下所示:(截取了前面 10%的內(nèi)容)
2020-02-20 16:46:33
Full thread dump OpenJDK 64-Bit Server VM (25.242-b08 mixed mode):
'Barrier task timer for barrier() calls.' - Thread t@2867
java.lang.Thread.State: WAITING
at java.lang.Object.wait(Native Method)
- waiting on <51997c78> (a java.util.TaskQueue)
at java.lang.Object.wait(Object.java:502)
at java.util.TimerThread.mainLoop(Timer.java:526)
at java.util.TimerThread.run(Timer.java:505)
Locked ownable synchronizers:
- None
'JMX server connection timeout 2854' - Thread t@2854
java.lang.Thread.State: TIMED_WAITING
at java.lang.Object.wait(Native Method)
- waiting on <3be7c557> (a [I)
at com.sun.jmx.remote.internal.ServerCommunicatorAdmin$Timeout.run(ServerCommunicatorAdmin.java:168)
at java.lang.Thread.run(Thread.java:748)
Locked ownable synchronizers:
- None
'RMI TCP Connection(6)-10.0.50.59' - Thread t@2853
java.lang.Thread.State: RUNNABLE
at java.net.SocketInputStream.socketRead0(Native Method)
at java.net.SocketInputStream.socketRead(SocketInputStream.java:116)
at java.net.SocketInputStream.read(SocketInputStream.java:171)
at java.net.SocketInputStream.read(SocketInputStream.java:141)
at java.io.BufferedInputStream.fill(BufferedInputStream.java:246)
at java.io.BufferedInputStream.read(BufferedInputStream.java:265)
- locked <3e35bfb4> (a java.io.BufferedInputStream)
at java.io.FilterInputStream.read(FilterInputStream.java:83)
at sun.rmi.transport.tcp.TCPTransport.handleMessages(TCPTransport.java:555)
at sun.rmi.transport.tcp.TCPTransport$ConnectionHandler.run0(TCPTransport.java:834)
at sun.rmi.transport.tcp.TCPTransport$ConnectionHandler.lambda$run$0(TCPTransport.java:688)
at sun.rmi.transport.tcp.TCPTransport$ConnectionHandler$$Lambda$37/13510931.run(Unknown Source)
at java.security.AccessController.doPrivileged(Native Method)
at sun.rmi.transport.tcp.TCPTransport$ConnectionHandler.run(TCPTransport.java:687)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
at java.lang.Thread.run(Thread.java:748)
Locked ownable synchronizers:
- locked <37829806> (a java.util.concurrent.ThreadPoolExecutor$Worker)
想起我當年作為新人來看 thread dump 時,心情是多么的激動!哇靠,這里有個 WAITING,這里有個 locked,是不是找到問題了!但后來發(fā)現(xiàn),其實都不是問題……如果你一開始看 thread dump 沒有頭緒,非常正常,一方面可以去搜索一些這些 thread state 代表什么含義,另一方面也可以在程序正常運行時跑個 thread dump 看看,會發(fā)現(xiàn)其實也有很多 WAITING 和 lock。
在這個具體的問題里,出問題的 thread 主要是以下這個:
'Executor task launch worker for task 8327' - Thread t@2767
java.lang.Thread.State: RUNNABLE
at java.io.FileInputStream.readBytes(Native Method)
at java.io.FileInputStream.read(FileInputStream.java:255)
at org.apache.spark.network.util.LimitedInputStream.read(LimitedInputStream.java:99)
at net.jpountz.lz4.LZ4BlockInputStream.readFully(LZ4BlockInputStream.java:269)
at net.jpountz.lz4.LZ4BlockInputStream.refill(LZ4BlockInputStream.java:245)
at net.jpountz.lz4.LZ4BlockInputStream.read(LZ4BlockInputStream.java:157)
at org.apache.spark.storage.BufferReleasingInputStream.read(ShuffleBlockFetcherIterator.scala:591)
at java.io.BufferedInputStream.fill(BufferedInputStream.java:246)
at java.io.BufferedInputStream.read1(BufferedInputStream.java:286)
at java.io.BufferedInputStream.read(BufferedInputStream.java:345)
- locked <34545dd4> (a java.io.BufferedInputStream)
at java.io.DataInputStream.read(DataInputStream.java:149)
at org.spark_project.guava.io.ByteStreams.read(ByteStreams.java:899)
at org.spark_project.guava.io.ByteStreams.readFully(ByteStreams.java:733)
at org.apache.spark.sql.execution.UnsafeRowSerializerInstance$$anon$2$$anon$3.next(UnsafeRowSerializer.scala:127)
at org.apache.spark.sql.execution.UnsafeRowSerializerInstance$$anon$2$$anon$3.next(UnsafeRowSerializer.scala:110)
at scala.collection.Iterator$$anon$12.next(Iterator.scala:445)
at scala.collection.Iterator$$anon$11.next(Iterator.scala:410)
at org.apache.spark.util.CompletionIterator.next(CompletionIterator.scala:29)
at org.apache.spark.InterruptibleIterator.next(InterruptibleIterator.scala:40)
at scala.collection.Iterator$$anon$11.next(Iterator.scala:410)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage4.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$13$$anon$1.hasNext(WholeStageCodegenExec.scala:636)
at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
at scala.collection.Iterator$$anon$13.hasNext(Iterator.scala:462)
at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
at scala.collection.Iterator$class.foreach(Iterator.scala:891)
at scala.collection.AbstractIterator.foreach(Iterator.scala:1334)
at scala.collection.generic.Growable$class.$plus$plus$eq(Growable.scala:59)
at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:104)
at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:48)
at scala.collection.TraversableOnce$class.to(TraversableOnce.scala:310)
at scala.collection.AbstractIterator.to(Iterator.scala:1334)
at scala.collection.TraversableOnce$class.toBuffer(TraversableOnce.scala:302)
at scala.collection.AbstractIterator.toBuffer(Iterator.scala:1334)
at scala.collection.TraversableOnce$class.toArray(TraversableOnce.scala:289)
at scala.collection.AbstractIterator.toArray(Iterator.scala:1334)
at com.microsoft.ml.spark.lightgbm.TrainUtils$.translate(TrainUtils.scala:229)
at com.microsoft.ml.spark.lightgbm.TrainUtils$.trainLightGBM(TrainUtils.scala:385)
at com.microsoft.ml.spark.lightgbm.LightGBMBase$$anonfun$6.apply(LightGBMBase.scala:145)
at org.apache.spark.rdd.RDDBarrier$$anonfun$mapPartitions$1$$anonfun$apply$1.apply(RDDBarrier.scala:51)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
at org.apache.spark.scheduler.Task.run(Task.scala:123)
at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:408)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:414)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
at java.lang.Thread.run(Thread.java:748)
Locked ownable synchronizers:
- locked <3ddc0126> (a java.util.concurrent.ThreadPoolExecutor$Worker)
調(diào)用棧里面有很多 spark, mmlspark 的關(guān)鍵字,一看就是“自己人”??吹阶钌蠈樱@個調(diào)用主要是在做文件 IO,那么問題來了,為什么這里 IO 不能并行利用多 CPU 呢?一個比較可疑的點是 lz4 那塊的調(diào)用。搜索一番發(fā)現(xiàn)果然 lz4 壓縮算法不是 splittable 的,這導致了在處理壓縮文件時,必須把所有數(shù)據(jù)放在一起來運行(比如 MR 里就是 single mapper 了)。另外這里也跟我們前面用了 1 個 partition 有關(guān),Spark 在做 shuffle, broadcast 時都會用到 lz4 壓縮[5],然后在解壓縮階段只有一個 partition 參與來做,就自然出現(xiàn)了單 CPU 被打滿的現(xiàn)象。
細心的同學可能會發(fā)現(xiàn),上面我們在找 mmlspark 中怎么做 early_stopping 的驗證集時貼了一段 Spark 代碼,其中有一個詭異的 broadcast 調(diào)用。這個 broadcast 被用來把 validation data 發(fā)送到各個數(shù)據(jù)分區(qū)去做驗證集。我們之前代碼中選取了 30 天的數(shù)據(jù)來做 validation set,這就導致有些 broadcast 的數(shù)據(jù)會非常巨大。Spark 的 broadcast 數(shù)據(jù)默認也會觸發(fā) compression,進一步加劇了這個問題。所以我們有幾個改進點:
除了流程優(yōu)化,還可以借鑒一些參數(shù)優(yōu)化的經(jīng)驗(玄學調(diào)參處處有)。這里主要參考了幾篇文章:
通過一系列調(diào)整和實驗,最終確定了一組設(shè)置:
spark.io.compression.lz4.blockSize='512k'
spark.serializer='org.apache.spark.serializer.KryoSerializer'
spark.kryoserializer.buffer.max='512m'
spark.shuffle.file.buffer='1m'
在解決卡頓問題的基礎(chǔ)上,進一步把整體訓練時間從之前的 55 分鐘縮短到了 25 分鐘左右。改完參數(shù)后看到的 CPU 使用曲線就正常多了:
看前面提到了這么多 mmlspark 的問題,我們還一直堅持使用,感覺一定是真愛了!其實在整個使用過程中也調(diào)研過一些其它的框架和方案。
Spark 自己就帶了機器學習相關(guān)的庫,其中就有 GBDT 的實現(xiàn)。在項目推進過程中,我們也嘗試了 Spark ML 中的 GBDT 模型來進行訓練。需要注意的是,應(yīng)該使用 pyspark.ml.regression.GBTRegressor
這個類,而不是之前的 pyspark.mllib.tree.GradientBoostedTrees
。實現(xiàn)起來還是非常順利的,但實測下來發(fā)現(xiàn)性能非常的差,感覺用 Spark 來構(gòu)建整個迭代式的算法流程,整體的效率不高。所以這個方案看起來不可行。
Lightgbm 庫自己也帶了分布式訓練的方案,具體可以參考 官方文檔[8]
從支持的功能上和官方提供的性能報告上感覺效果非常優(yōu)秀。例如可以根據(jù)數(shù)據(jù)與特征的大小,選擇 feature/data/voting 三種不同的并行方案。官方給的例子里,15 億行數(shù)據(jù),60+特征,在多機上做 data parallel 訓練,整體性能可以達到線性擴展的效果。
但有一個問題,lightgbm 本身并沒有帶數(shù)據(jù)分發(fā)的能力。官網(wǎng)上的例子可以看出用戶需要自行做數(shù)據(jù),配置,可執(zhí)行文件的轉(zhuǎn)換和分發(fā),然后自行在多節(jié)點上啟動訓練任務(wù)。其它幾點都還好說,可以用 pssh
, pscp
之類的命令。但數(shù)據(jù)分發(fā)和轉(zhuǎn)換就是一個比較大的問題了。如果仔細往下想,就會發(fā)現(xiàn)整體實現(xiàn)思路可能跟 mmlspark 目前的實現(xiàn)非常類似了。
所以總體看下來,如果要自行集成 native lightgbm 做分布式訓練,可能會需要寫一個類似 mmlspark 的庫,工作量大,也沒有太大必要。
順帶考察了 lgb 的兩個老競爭對手,看看他們的分布式方案如何。Catboost 完全沒有對分布式的支持,率先出局。Xgboost 里提到如果用 Spark 做數(shù)據(jù)處理,建議使用 Xgboost4j-Spark。粗略看了下還挺不錯的,起碼文檔比 mmlspark 好多了!不過美中不足的是這個庫叫 4j,所以只有 Java/Scala 接口,木有 Python 支持,集成起來會有一些難度。
從 mmlspark 的思路出發(fā),自然會想到其實也可以結(jié)合別的并行計算框架,例如 Dask。Xgboost, lightgbm 都有相關(guān)的庫,用 Dask 來支持分布式訓練。這個方案看起來有幾個問題:
不過幾個 Dask 庫里的實現(xiàn)方式還是值得一看的,提供了一些并發(fā)框架集成 Python 算法包的思路。
騰訊家的一個較為知名 分布式機器學習庫[9],基于 parameter server 架構(gòu)實現(xiàn)了一系列算法,支持分布式大規(guī)模的訓練。大致看了一下這個庫,有幾個 concern:
所以結(jié)論還是不傾向使用,或許后面可以了解下 Angel 的實現(xiàn)方式,看看有沒有借鑒意義。
TF 里面也有 GBDT[10],可能很多人都不知道吧……這個我還沒試過,另外真的要用的話還得考察下 TensorFlow on Spark。當然好處是說不定還能試一些網(wǎng)絡(luò)模型看看效果如何。
在 Spark 完成特征構(gòu)建后,就可以通過不同的策略,把需要分模型訓練的數(shù)據(jù)分別存儲到分布式文件系統(tǒng)中,然后利用一些多機任務(wù)管理的框架(例如 Ray,我們的數(shù)據(jù)開發(fā)平臺等),在不同的節(jié)點上分別取對應(yīng)的數(shù)據(jù)進行訓練。這個方案的好處是靈活性非常高,不再局限于 Spark 平臺能支持的算法,可以跑任意我們熟悉的算法模塊。但缺點就是任務(wù)管理,高可用,failover,可運維性等等方面都會有些 concern。
另外一個問題就是數(shù)據(jù)交換的額外開銷。我們在項目中也嘗試了一下在單機做 lgb 訓練,也就是在 Spark 特征構(gòu)建完之后,通過 toPandas
調(diào)用把數(shù)據(jù)集轉(zhuǎn)化為 pandas dataframe,然后再調(diào)用原生 Python lightgbm 庫來做模型訓練與預(yù)測。這個操作相比寫入文件系統(tǒng)還少了磁盤 IO 的開銷,但是整體測試下來用原生 lgb 訓練整體時間需要 41 分鐘,而直接用 mmlspark 用相同的配置只需要 27 分鐘。假設(shè)我們有 10 億行數(shù)據(jù),70 個特征,那么每次訓練的數(shù)據(jù)量達到了 500GB 左右,這部分的開銷還是非??捎^的。H2O 有個產(chǎn)品叫 Sparkling Water,就實現(xiàn)了 internal/external 兩種 backend,其中也提到了內(nèi)外部處理的優(yōu)劣和適用場景等。
總結(jié)來看,目前還是mmlspark方案更加合適。后續(xù)我們也會持續(xù)關(guān)注類似框架,并比較評估大規(guī)模的GBDT模型與深度學習模型的表現(xiàn)差異。
Spark 官方網(wǎng)站: https://spark.apache.org/downloads.html
[2]官方示例: https://databricks.com/blog/2017/10/30/introducing-vectorized-udfs-for-pyspark.html
[3]官方文檔: https://mmlspark.blob.core.windows.net/docs/1.0.0-rc1/pyspark/mmlspark.lightgbm.html
[4]issue: https://github.com/Azure/mmlspark/issues/775
[5]lz4 壓縮: https://spark.apache.org/docs/latest/configuration.html#compression-and-serialization
[6]Facebook 關(guān)于 Spark 性能調(diào)優(yōu)的分享: https://www.slideshare.net/databricks/tuning-apache-spark-for-largescale-workloads-gaoxiang-liu-and-sital-kedia
[7]Intel 關(guān)于壓縮算法的分享: https://www.slideshare.net/databricks/best-practice-of-compressiondecompression-codes-in-apache-spark-with-sophia-sun-and-qi-xie
[8]官方文檔: https://lightgbm.readthedocs.io/en/latest/Parallel-Learning-Guide.html
[9]分布式機器學習庫: https://github.com/Angel-ML/sona
[10]TF 里面也有 GBDT: https://www.tensorflow.org/api_docs/python/tf/estimator/BoostedTreesRegressor
聯(lián)系客服