WSDM-iQIYI: User Retention Prediction Challenge (online score 0.865)
Competition Overview
http://challenge.ai.iqiyi.com/detail?raceId=61600f6cef1b65639cd5eaa6
https://www.datafountain.cn/competitions/551
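Run Instructions [Very Important]
Problem Description
iQIYI is a leading high-quality video entertainment streaming platform in China and worldwide, with more than 500 million users enjoying its services every month. Under the brand slogan "Enjoy Quality", iQIYI has built a library of professional licensed video covering films, TV series, variety shows, and animation, along with a large body of user-generated content on products such as Suike, to give users a rich professional video experience.
The iQIYI mobile app uses deep learning and other recent AI techniques to personalize the product experience and let users enjoy customized entertainment services. We use the "N-day retention score" as the key metric of user satisfaction. For example, if a user's "7-day retention score" on October 1 equals 3, the user visits the iQIYI app on 3 of the following 7 days (October 2 to 8). Predicting the retention score is a challenging problem: users differ widely in preferences and activity levels, and other factors such as available leisure time and trends in popular content are strongly periodic.
This competition uses desensitized and sampled data from the iQIYI app to predict users' 7-day retention score. Participating teams need to design algorithms for data analysis and prediction.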
Data Description
The competition provides a rich dataset that includes video data, user portrait data, app launch logs, and user playback and interaction logs. For each test-set user, the task is to predict the "7-day retention score" on a given day. The 7-day retention score ranges from 0 to 7, and predictions are kept to 2 decimal places.
Evaluation Metric
This competition is a numeric prediction problem. The evaluation function is:
$$100-\left(1-\frac{1}{n}\sum_{i=1}^{n}\left|\frac{F_i-A_i}{7}\right|\right)$$
Here $n$ is the number of test-set users, $F_i$ is the participant's predicted 7-day retention score for user $i$, and $A_i$ is the true 7-day retention score for user $i$.
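As a rough local check, the error term inside the metric (the mean of $|F_i - A_i| / 7$) can be computed in a few lines. This is a minimal sketch in plain Python; the function name `normalized_mae` and the sample values are illustrative assumptions, and the outer constants of the official formula are left to the platform.

```python
def normalized_mae(predictions, actuals, horizon=7):
    """Mean of |F_i - A_i| / horizon over all users (the inner term of the metric)."""
    if len(predictions) != len(actuals):
        raise ValueError("predictions and actuals must have the same length")
    if not predictions:
        raise ValueError("need at least one user")
    total = sum(abs(f - a) / horizon for f, a in zip(predictions, actuals))
    return total / len(predictions)


# Illustrative values only: three users with predicted and true retention scores.
preds = [3.0, 0.0, 7.0]
truth = [3, 1, 0]
error = normalized_mae(preds, truth)
print(error)
```

A lower value means the predictions are closer to the true retention scores, so it can be tracked on a held-out fold during training.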
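Review Notes
Submissions must be UTF-8 encoded CSV files. The format and row order must match the test set; see "sample-a" in the dataset download section. All predictions are kept to 2 decimal places. Files that do not follow the submission format are treated as invalid and still use up one submission attempt.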
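The formatting rules above can be sketched with the standard `csv` module. The two-column layout (`user_id`, prediction, no header row) is an assumption taken from the baseline's own `to_csv` call later in this post, not from the official "sample-a" file, and `write_submission` is a hypothetical helper name.

```python
import csv
import io


def write_submission(rows, stream):
    """Write (user_id, score) pairs as CSV rows with scores fixed to 2 decimals."""
    writer = csv.writer(stream, lineterminator="\n")
    for user_id, score in rows:
        writer.writerow([user_id, f"{score:.2f}"])


# Illustrative rows; in practice they must follow the test-set order.
buffer = io.StringIO()
write_submission([(10007813, 1.23456), (10052988, 0.0)], buffer)
submission_text = buffer.getvalue()
print(submission_text)
```

When writing to disk, opening the file with `open(path, "w", encoding="utf-8", newline="")` keeps the encoding explicit.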
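The competition has two stages, A and B. Both stages share the same training set, but the test sets to be predicted differ.
- Stage A ends on 2022.01.17. The stage A test set contains 15001 users to predict and is used for the stage A competition and leaderboard. Each user comes with a user id and an end_date. Participants predict that user's 7-day retention score over the following 7 days, [end_date+1 ~ end_date+7].
- Stage B starts on 2022.01.17 and ends on 2022.01.20, when the system provides a new stage B test set. The stage B test set is larger, with 35000 users to predict. Stage B uses a separate leaderboard; all other details are the same as in stage A.
The final result is based on the stage B score, and participants must also submit supporting material showing that their score is legitimate and valid.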
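The label definition, the number of distinct days in [end_date+1 ~ end_date+7] on which the user launches the app, can be sketched as follows. This is a minimal plain-Python illustration rather than the organizers' labeling code; `retention_score` and the sample log are assumptions.

```python
from collections import defaultdict


def retention_score(launch_days, end_date, horizon=7):
    """Count distinct days in [end_date + 1, end_date + horizon] with at least one launch."""
    window = range(end_date + 1, end_date + horizon + 1)
    return sum(1 for day in window if day in launch_days)


# Illustrative (user_id, date) launch records; repeated launches on one day count once.
launch_log = [
    (10000000, 131),
    (10000000, 132),
    (10000000, 132),
    (10000000, 141),
    (10157996, 129),
]

days_by_user = defaultdict(set)
for user_id, day in launch_log:
    days_by_user[user_id].add(day)

score = retention_score(days_by_user[10000000], end_date=130)
print(score)
```

Storing launch days as a set per user makes the per-day lookup cheap and removes duplicate launches automatically.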
Special Notes
- The iQIYI AI competition platform is the official site and the main venue of the challenge. To take part in the main venue, participants must register on the official site and must submit their predictions there.
- Each team may have at most 5 members.
- The DataFountain platform serves as the practice ground for the 2022 WSDM User Retention Prediction Challenge. During stage A it gives participants 2 extra scoring submissions per day, to help them do well in the main venue.
- During stage A, predictions can be submitted on both DataFountain and the official main venue. During stage B, predictions must be submitted on the official main venue. The final ranking follows the results published on the official main venue.
Dataset Explanation
1. User portrait data
| user_id | |
| device_type | iOS, Android |
| device_rom | rom of the device |
| device_ram | ram of the device |
| sex | |
| age | |
| education | |
| occupation_status | |
| territory_code | |
2. App launch logs
| user_id | |
| date | Desensitization, started from 0 |
| launch_type | spontaneous or launched by other apps & deep-links |
3. Video related data
| item_id | id of the video |
| father_id | album id, if the video is an episode of an album collection |
| cast | a list of actors/actresses |
| duration | video length |
| tag_list | a list of tags |
4. User playback data
| user_id | |
| item_id | |
| playtime | video playback time |
| date | timestamp of the behavior |
5. User interaction data
| user_id | |
| item_id | |
| interact_type | interaction types such as posting comments, etc. |
| date | timestamp of the behavior |
Timeline
- 2021.10.15: Competition launch; the task is officially released, the dataset is opened, and team registration opens.
- 2021.11.15: The public leaderboard opens and participants can submit predictions.
- 2021.12.20: Registration closes.
- 2022.01.17: Stage A submissions close; the stage B test set and leaderboard open.
- 2022.01.20: Stage B submissions close.
- 2022.01.21: Deadline for the stage B top-5 teams to submit their explanatory documents (submission method to be announced).
- 2022.01.25: Final results announced.
- 2022.02.17: Top-3 team presentations and award ceremony.
Prizes
- Champion: one team ($2000)
- Runner-up: one team ($800)
- Third place: one team ($500)
Basic Field Analysis
user_portrait
| user_id | |
| device_type | iOS, Android |
| device_rom | rom of the device |
| device_ram | ram of the device |
| sex | |
| age | |
| education | |
| occupation_status | |
| territory_code |
| 10209854 | 2.0 | 5731 | 109581 | 1.0 | 2.0 | 0.0 | 1.0 | 865101.0 |
| 10230057 | 2.0 | 1877 | 20888 | 1.0 | 4.0 | 0.0 | 1.0 | 864102.0 |
| 10268855 | 2.0 | NaN | NaN | 1.0 | 3.0 | NaN | NaN | NaN |
| 10268855 | 2.0 | NaN | NaN | 1.0 | 3.0 | NaN | NaN | NaN |
One user record is duplicated; consider dropping it.

    user_portrait = user_portrait.drop_duplicates()

device_type
device_type is a categorical field. Judging by mobile OS market share, a reasonable guess is 2 = Android, 1 = iOS, 3 = WP, and 4 = unknown or other.

    user_portrait['device_type'].value_counts()

    2.0    480055
    1.0     85322
    3.0     28909
    4.0      2280
    Name: device_type, dtype: int64

RAM and ROM
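On a phone, ROM stores data such as system programs, applications, audio, video, and documents. Because video and similar files take up a lot of space, ROM is much larger than RAM; mainstream phones now come with 8 GB of storage.
RAM, also called running memory, holds programs while they run and is much faster than ROM. Mainstream phones now come with 1 GB of RAM; the more RAM a phone has, the faster it runs and the smoother large games play.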

    import seaborn as sns

    # Extract device info: keep the first ';'-separated value of each field.
    # Work on an explicit copy so the assignments below do not raise
    # pandas' SettingWithCopyWarning.
    user_portrait = user_portrait.copy()
    user_portrait['device_ram'] = user_portrait['device_ram'].apply(lambda x: str(x).split(';')[0])
    user_portrait['device_rom'] = user_portrait['device_rom'].apply(lambda x: str(x).split(';')[0])

    # Note: sns.distplot is deprecated since seaborn 0.11.
    sns.distplot(user_portrait['device_ram'])

[Figure: distribution of device_ram (output_16_1.png)]

    sns.distplot(user_portrait['device_rom'])

sex

    user_portrait['sex'].value_counts()

    1.0    308846
    2.0    281612
    Name: sex, dtype: int64

age

    sns.distplot(user_portrait['age'])

education

    sns.distplot(user_portrait['education'])

occupation_status

    sns.distplot(user_portrait['occupation_status'])

territory_code
Code of the region where the user usually resides.

    sns.distplot(user_portrait['territory_code'])

app_launch
| user_id | |
| date | Desensitization, started from 0 |
| launch_type | spontaneous or launched by other apps & deep-links |
| 10157996 | 0 | 129 |
| 10139583 | 0 | 129 |
| 10000000 | 0 | 131 |
| 10000000 | 0 | 132 |
| 10000000 | 0 | 141 |
| 10000000 | 0 | 164 |
| 10000000 | 0 | 179 |
video_related
| item_id | id of the video |
| father_id | album id, if the video is an episode of an album collection |
| cast | a list of actors/actresses |
| duration | video length |
| tag_list | a list of tags |
| 24403453.0 | 6.0 | NaN | 50365080;50338575;50313222;50165986 | NaN |
| 22838795.0 | 7.0 | NaN | 50001708;50323515;50125414 | NaN |
user_playback

    user_playback.head()

| 10057286 | 20628283.0 | 2208.612 | 145 |
| 10522615 | 23930557.0 | 31.054 | 145 |
| 10494028 | 20173699.0 | 115.952 | 145 |
| 10181987 | 21350426.0 | 1.585 | 145 |
| 10439175 | 22946929.0 | 51.726 | 145 |
user_interaction
| user_id | |
| item_id | |
| interact_type | interaction types such as posting comments, etc. |
| date | timestamp of the behavior |
| 10243056 | 22635954 | 1 | 213 |
| 10203565 | 24723827 | 3 | 213 |
Exploratory Data Analysis
- app_launch
- Behavior over the past day, three days, week, month, and three months
| 10052988 | 0 | 147 |
| 10052988 | 0 | 149 |
| 10007813 | 205 |
| 10052988 | 210 |
| 10279068 | 200 |
| 10546696 | 216 |
| 10406659 | 183 |
| ... | ... |
| 10355586 | 205 |
| 10589773 | 210 |
| 10181954 | 218 |
| 10544736 | 164 |
| 10354569 | 187 |
15001 rows × 2 columns
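The history windows listed above (one day, three days, one week, one month, three months) can be turned into simple count features relative to each user's end_date. The following is a minimal sketch in plain Python; `window_launch_counts` and the feature names are hypothetical, and the baseline's own feature script may build its features differently.

```python
def window_launch_counts(launch_days, end_date, windows=(1, 3, 7, 30, 90)):
    """Count distinct active days in each trailing window ending at end_date (inclusive)."""
    features = {}
    for width in windows:
        start = end_date - width + 1
        active = sum(1 for day in launch_days if start <= day <= end_date)
        features[f"launch_days_last_{width}"] = active
    return features


# Illustrative set of distinct launch days for one user.
feats = window_launch_counts({131, 132, 141, 164, 179}, end_date=179)
print(feats)
```

Only days up to and including end_date are counted, so the features never look into the 7-day window that the label is computed from.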
Feature Engineering

    # del user_interaction, user_portrait, user_playback, app_launch, video_related
    # -p avoids the "File exists" error when the directory is already there.
    !mkdir -p wsdm_model_data
    !python3 baseline_feature_engineering.py

Model Building + Training

    # -n skips files that already exist instead of prompting.
    !unzip -n data.zip

    import json
    import math

    import numpy as np
    import pandas as pd

    data_dir = "./wsdm_model_data/"

    # Load the training data; the sequence columns are stored as JSON strings.
    data = pd.read_csv(data_dir + "train_data.txt", sep="\t")
    data["launch_seq"] = data.launch_seq.apply(json.loads)
    data["playtime_seq"] = data.playtime_seq.apply(json.loads)
    data["duration_prefer"] = data.duration_prefer.apply(json.loads)
    data["interact_prefer"] = data.interact_prefer.apply(json.loads)

    # Shuffle the rows.
    data = data.sample(frac=1).reset_index(drop=True)
    data.columns

    Index(['user_id', 'end_date', 'label', 'launch_seq', 'playtime_seq',
           'duration_prefer', 'father_id_score', 'cast_id_score', 'tag_score',
           'device_type', 'device_ram', 'device_rom', 'sex', 'age', 'education',
           'occupation_status', 'territory_score', 'interact_prefer'],
          dtype='object')

    data

| user_id | end_date | label | launch_seq | playtime_seq | duration_prefer | father_id_score | cast_id_score | tag_score | device_type | device_ram | device_rom | sex | age | education | occupation_status | territory_score | interact_prefer |
| 10309777 | 165 | 6 | [0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, ... | [0, 0, 0, 0, 0, 0, 0.9414, 0, 0, 0.9998, 0.943... | [0.0, 0.0, 0.0, 0.0, 0.08, 0.0, 0.04, 0.0, 0.0... | 1.209317 | 1.353447 | 0.178947 | 0.194954 | -0.740852 | 1.043355 | -0.955892 | -0.319111 | -0.544818 | 0.746096 | 0.167180 | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, ... |
| 10117035 | 123 | 0 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] | 0.000000 | 0.000000 | 0.000000 | 0.194954 | -1.195884 | -1.173106 | -0.955892 | -0.319111 | -0.544818 | -1.340308 | 0.000000 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] |
| 10413843 | 149 | 0 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] | 0.000000 | 0.000000 | 0.000000 | -2.041925 | -0.637283 | -0.701308 | -0.955892 | -0.319111 | 0.755516 | 0.746096 | -1.106625 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] |
| 10209341 | 165 | 0 | [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, ... | [0, 0, 0, 0, 0, 0, 0, 0.0475, 0, 0, 0, 0, 0, 0... | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] | 0.000000 | 0.000000 | 0.000000 | 0.194954 | 0.150032 | -0.117076 | -0.955892 | -0.319111 | -0.544818 | 0.746096 | 0.940850 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] |
| 10430657 | 162 | 0 | [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, ... | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.0492... | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] | 0.000000 | 0.000000 | 0.000000 | 0.194954 | 1.012626 | -0.145958 | 1.046141 | 0.000000 | -0.544818 | 0.000000 | -0.743187 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 10070331 | 122 | 1 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] | 0.000000 | 0.000000 | 0.000000 | 0.194954 | 0.191747 | 1.228884 | -0.955892 | -0.319111 | -0.544818 | 0.746096 | -0.480041 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] |
| 10056030 | 115 | 2 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... | [0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, ... | -0.299726 | 0.000000 | 0.388082 | 0.194954 | -1.195884 | -0.834187 | 1.046141 | 0.828011 | -0.544818 | -1.340308 | -1.524485 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] |
| 10235314 | 137 | 0 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... | [0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 1.0, 0.5, 0.0, ... | -0.866054 | 0.000000 | -0.084836 | 0.194954 | 1.020778 | 1.262729 | -0.955892 | -0.319111 | -0.544818 | 0.746096 | 0.838748 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] |
| 10014483 | 195 | 1 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... | 0.288450 | 0.760564 | 0.511767 | 0.194954 | -0.796952 | -0.111235 | -0.955892 | 1.975134 | -0.544818 | 0.746096 | -1.638692 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] |
| 10446094 | 157 | 0 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] | 0.000000 | 0.000000 | 0.000000 | 0.194954 | 0.000000 | -0.857147 | -0.955892 | -0.319111 | -0.544818 | -1.340308 | -0.891480 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] |
600001 rows × 18 columns

    import paddle
    from paddle.io import DataLoader, Dataset

    # Identifier, label, and sequence columns that are not dense features.
    NON_FEATURE_COLS = ['user_id', 'end_date', 'label', 'launch_seq',
                        'playtime_seq', 'duration_prefer', 'interact_prefer']

    # Dataset wrapper around the feature DataFrame.
    class CoggleDataset(Dataset):
        def __init__(self, df):
            super(CoggleDataset, self).__init__()
            self.df = df
            # Keep the DataFrame's column order; a set difference would not
            # guarantee the same feature order for train and test data.
            self.feat_col = [c for c in self.df.columns if c not in NON_FEATURE_COLS]
            self.df_feat = self.df[self.feat_col]

        # Build the tensors that take part in training.
        def __getitem__(self, index):
            launch_seq = self.df['launch_seq'].iloc[index]
            playtime_seq = self.df['playtime_seq'].iloc[index]
            duration_prefer = self.df['duration_prefer'].iloc[index]
            interact_prefer = self.df['interact_prefer'].iloc[index]
            feat = self.df_feat.iloc[index].values.astype(np.float32)

            launch_seq = paddle.to_tensor(launch_seq).astype(paddle.float32)
            playtime_seq = paddle.to_tensor(playtime_seq).astype(paddle.float32)
            duration_prefer = paddle.to_tensor(duration_prefer).astype(paddle.float32)
            interact_prefer = paddle.to_tensor(interact_prefer).astype(paddle.float32)
            feat = paddle.to_tensor(feat).astype(paddle.float32)
            label = paddle.to_tensor(self.df['label'].iloc[index]).astype(paddle.float32)
            return launch_seq, playtime_seq, duration_prefer, interact_prefer, feat, label

        def __len__(self):
            return len(self.df)

    # Model definition: GRU sequence encoders + fully connected layers.
    class CoggleModel(paddle.nn.Layer):
        def __init__(self):
            super(CoggleModel, self).__init__()
            # Sequence modeling
            self.launch_seq_gru = paddle.nn.GRU(1, 32)
            self.playtime_seq_gru = paddle.nn.GRU(1, 32)
            # Fully connected layers; 102 = 32 + 32 + 16 + 11 + 11 concatenated inputs
            self.fc1 = paddle.nn.Linear(102, 64)
            self.fc2 = paddle.nn.Linear(64, 1)

        def forward(self, launch_seq, playtime_seq, duration_prefer, interact_prefer, feat):
            launch_seq = launch_seq.reshape((-1, 32, 1))
            playtime_seq = playtime_seq.reshape((-1, 32, 1))
            # GRU output is [batch, 32 steps, 32 hidden]; keeping hidden unit 0
            # at every step gives a [batch, 32] feature per sequence.
            launch_seq_feat = self.launch_seq_gru(launch_seq)[0][:, :, 0]
            playtime_seq_feat = self.playtime_seq_gru(playtime_seq)[0][:, :, 0]
            all_feat = paddle.concat([launch_seq_feat, playtime_seq_feat, duration_prefer, interact_prefer, feat], 1)
            all_feat_fc1 = self.fc1(all_feat)
            all_feat_fc2 = self.fc2(all_feat_fc1)
            return all_feat_fc2

Model Training

    from tqdm import tqdm
    import warnings
    warnings.filterwarnings("ignore")

    # Train for one epoch and return the mean training loss.
    def train(model, train_loader, optimizer, criterion):
        model.train()
        train_loss = []
        for launch_seq, playtime_seq, duration_prefer, interact_prefer, feat, label in tqdm(train_loader):
            pred = model(launch_seq, playtime_seq, duration_prefer, interact_prefer, feat)
            # Match the label shape to pred ([batch, 1]) so the loss does not broadcast.
            loss = criterion(pred, label.reshape(pred.shape))
            loss.backward()
            optimizer.step()
            optimizer.clear_grad()
            train_loss.append(loss.item())
        return np.mean(train_loss)

    # Validation: evaluation only, so no backward pass and no optimizer step.
    # The optimizer argument is kept only for call compatibility.
    def validate(model, val_loader, optimizer, criterion):
        model.eval()
        val_loss = []
        with paddle.no_grad():
            for launch_seq, playtime_seq, duration_prefer, interact_prefer, feat, label in tqdm(val_loader):
                pred = model(launch_seq, playtime_seq, duration_prefer, interact_prefer, feat)
                loss = criterion(pred, label.reshape(pred.shape))
                val_loss.append(loss.item())
        return np.mean(val_loss)

    # Prediction: collect the model outputs batch by batch.
    def predict(model, test_loader):
        model.eval()
        test_pred = []
        with paddle.no_grad():
            for launch_seq, playtime_seq, duration_prefer, interact_prefer, feat, label in tqdm(test_loader):
                pred = model(launch_seq, playtime_seq, duration_prefer, interact_prefer, feat)
                test_pred.append(pred.numpy())
        return test_pred

    from sklearn.model_selection import StratifiedKFold

    # K-fold training, stratified on the label.
    skf = StratifiedKFold(n_splits=7)
    fold = 0
    for tr_idx, val_idx in skf.split(data, data['label']):
        train_dataset = CoggleDataset(data.iloc[tr_idx])
        val_dataset = CoggleDataset(data.iloc[val_idx])

        # Model, optimizer, and loss function
        model = CoggleModel()
        optimizer = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=0.001)
        criterion = paddle.nn.MSELoss()

        train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
        val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

        # Train for a few epochs
        for epoch in range(3):
            train_loss = train(model, train_loader, optimizer, criterion)
            val_loss = validate(model, val_loader, optimizer, criterion)
            print(fold, epoch, train_loss, val_loss)

        paddle.save(model.state_dict(), f"model_{fold}.pdparams")
        fold += 1

    W1128 20:18:14.128268 128 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
    W1128 20:18:14.132313 128 device_context.cc:465] device: 0, cuDNN Version: 7.6.
    1%| | 131/16072 [00:05<09:05, 29.24it/s]

Model Prediction

    test_data = pd.read_csv(data_dir + "test_data.txt", sep="\t")
    test_data["launch_seq"] = test_data.launch_seq.apply(json.loads)
    test_data["playtime_seq"] = test_data.playtime_seq.apply(json.loads)
    test_data["duration_prefer"] = test_data.duration_prefer.apply(json.loads)
    test_data["interact_prefer"] = test_data.interact_prefer.apply(json.loads)
    # Placeholder label so the test set fits the same Dataset class.
    test_data['label'] = 0

    test_dataset = CoggleDataset(test_data)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)
    test_pred_fold = np.zeros(test_data.shape[0])

    # Average the predictions of the 7 fold models.
    for idx in range(7):
        model = CoggleModel()
        layer_state_dict = paddle.load(f"model_{idx}.pdparams")
        model.set_state_dict(layer_state_dict)
        model.eval()
        test_pred = predict(model, test_loader)
        test_pred = np.vstack(test_pred)
        test_pred_fold += test_pred[:, 0]
    test_pred_fold /= 7

    100%|██████████| 235/235 [00:02<00:00, 98.58it/s]
    100%|██████████| 235/235 [00:02<00:00, 79.41it/s]
    100%|██████████| 235/235 [00:02<00:00, 78.44it/s]
    100%|██████████| 235/235 [00:02<00:00, 78.63it/s]
    100%|██████████| 235/235 [00:03<00:00, 77.96it/s]
    100%|██████████| 235/235 [00:02<00:00, 78.47it/s]
    100%|██████████| 235/235 [00:03<00:00, 77.44it/s]

    # Use the fold-averaged predictions, not the output of the last fold only.
    test_data["prediction"] = test_pred_fold
    test_data = test_data[["user_id", "prediction"]]
    # can clip outputs to [0, 7] or use other tricks
    test_data.to_csv("./baseline_submission.csv", index=False, header=False, float_format="%.2f")

Summary
Ideas for Improvement
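One small post-processing step hinted at in the prediction code is clipping outputs to the valid range. Because the true 7-day retention score always lies in [0, 7], moving an out-of-range prediction back to the nearest bound can never increase its absolute error. The sketch below is a plain-Python illustration with made-up values; `clip_score` is a hypothetical helper, not part of the baseline.

```python
def clip_score(value, low=0.0, high=7.0):
    """Clamp a predicted retention score to the valid range [low, high]."""
    return min(max(value, low), high)


# Illustrative raw regression outputs, two of them outside the valid range.
raw = [-0.31, 2.4, 7.8]
clipped = [clip_score(v) for v in raw]
print(clipped)
```

With NumPy arrays, `np.clip(predictions, 0, 7)` applies the same operation to the whole prediction vector before it is written to the submission file.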