競馬AI⑩の記事で穴馬を予測するモデルの構築を記事にしましたが、今回は損失関数というものを使い、さらに精度を上げていこうという内容になります。
前回のモデルを作成していない方は、まずは以下の記事を参考にモデルを作成してみてください。
損失関数とは
損失関数は、機械学習モデルの予測が実際のデータからどれだけ離れているかを数値化する指標です。特に、競馬の穴馬予想では、不均衡なデータに対応できる損失関数が重要になります。例えば、Focal Lossは、予想が難しい穴馬(マイノリティクラス)にモデルがより注意を払うように調整することで、予測精度を高める助けとなります。これにより、予測モデルは少数派の穴馬を見つけ出す能力を改善し、競馬予想の精度向上に貢献します。
※Focal Lossとは、クラス間の不均衡が大きいデータセットにおいて、モデルがマイノリティクラスのサンプルにより焦点を当てるように設計された損失関数
うまく説明できないのでChatGPTに手伝ってもらって書きました。損失関数についてはコメントをもらっても回答が難しいので調べてみてください。
穴馬学習コード
コード
#穴馬の学習
import lightgbm as lgb
import pandas as pd
from sklearn.metrics import roc_curve,roc_auc_score
import numpy as np
from sklearn import metrics
from scipy.misc import derivative
def sigmoid(x):
return 1 / (1 + np.exp(-x))
def focal_loss(x: np.ndarray, dtrain: lgb.Dataset, gamma: float) -> np.ndarray:
"""損失計算"""
x = sigmoid(x)
x[dtrain == 0] = 1 - x[dtrain == 0]
return -1 * (1 - x)**gamma * np.log(x)
def focal_loss_grad_hess(x: np.ndarray, dtrain: lgb.Dataset, gamma: float) -> (np.ndarray, np.ndarray,):
"""Focal Lossのgradientとhessianを返す"""
t = dtrain.label
grad = derivative(lambda _x: focal_loss(_x, t, gamma=gamma), x, n=1, dx=1e-6)
hess = derivative(lambda _x: focal_loss(_x, t, gamma=gamma), x, n=2, dx=1e-6)
return grad, hess
def split_date(df, test_size):
sorted_id_list = df.sort_values('race_id').index.unique()
train_id_list = sorted_id_list[:round(len(sorted_id_list) * (1-test_size))]
test_id_list = sorted_id_list[round(len(sorted_id_list) * (1-test_size)):]
train = df.loc[train_id_list]
test = df.loc[test_id_list]
return train, test
# モデルファイル
model_file = 'model/model_ana.txt'
# データの読み込み
data = pd.read_csv('encoded/encoded_data.csv')
# ターゲット変数の生成
data['着順'] = ((data['着順'] <= 3) & (data['人気'] >= 6) & (data['オッズ'] >= 20)).astype(int)
# 特徴量とターゲットの分割
train, test = split_date(data, 0.2)
drop_arr=['着順','オッズ','人気','上がり','走破時間','通過順']
X_train = train.drop(drop_arr, axis=1)
# ,'馬の平均着順', '馬の3着内率'
y_train = train['着順']
X_test = test.drop(drop_arr, axis=1)
y_test = test['着順']
# LightGBMデータセットの作成
train_data = lgb.Dataset(X_train, label=y_train)
valid_data = lgb.Dataset(X_test, label=y_test)
weights = len(y_train) / (2 * np.bincount(y_train))
class_weights = {0: weights[0], 1: weights[1]}
gamma = 5.0 # 例としての値
fobj_focal_loss = lambda x,y: focal_loss_grad_hess(x, y, gamma)
params = {
'metric': 'auc', # 評価指標
'learning_rate': 0.05,
'objective': fobj_focal_loss
# 他のパラメータ
}
lgb_clf = lgb.train(
params,
train_data,
valid_sets=[valid_data]
)
# 予測の取得
y_pred_train = lgb_clf.predict(X_train)
y_pred = lgb_clf.predict(X_test)
# モデルの評価
print(roc_auc_score(y_test, y_pred))
# モデルの保存
lgb_clf.save_model(model_file)
# 既に訓練したモデル 'model' とテストデータ 'X_test', 'y_test' があると仮定します
# テストデータに対する予測確率を求める
model = lgb.Booster(model_file=model_file)
test_probs = model.predict(X_test)
# AUCを計算する
fpr, tpr, thresholds = metrics.roc_curve(y_test, test_probs, pos_label=1)
test_auc = metrics.auc(fpr, tpr)
print("Test AUC: ", test_auc)
解説(損失関数部分のみ)
コード内のFocal Loss
このコードは、LightGBMを使用してデータをモデル化し、特にFocal Loss関数をカスタム損失関数として使用する方法を示しています。Focal Lossは、クラス間の不均衡が大きい場合に有効な損失関数で、特に一部のクラスが他のクラスよりもはるかに少ない場合(例えば、レアイベントの予測)に役立ちます。
コード内で定義されているfocal_loss
関数とfocal_loss_grad_hess
関数は、Focal Lossを計算し、その勾配(gradient)とヘッシアン(hessian)を返します。これらはLightGBMのカスタム損失関数として使用されます。
focal_loss
関数: この関数は、与えられた予測値x
に対してFocal Lossを計算します。sigmoid
関数を使用して予測値を確率に変換し、その確率を使用してFocal Lossを計算します。focal_loss_grad_hess
関数: この関数は、focal_loss
関数の勾配とヘッシアンを数値微分により計算します。derivative
関数(scipy.misc
から取得)を使用して、勾配(一階微分)とヘッシアン(二階微分)を求めます。これにより、LightGBMが最適化過程で使用する情報が得られます。
モデルのトレーニングと評価
lgb.train
関数にfobj_focal_loss
をobjective
パラメータとして渡すことで、Focal Lossを使用したモデルのトレーニングを行います。- モデルの評価は、予測された確率と実際のラベルを用いてROC AUCスコアを計算することで行われます。このスコアは、モデルがポジティブクラス(ここでは
着順
が条件を満たすケース)とネガティブクラスをどの程度うまく区別できるかを示します。
このコード例では、不均衡なデータセットに対してより良い予測性能を達成するために、Focal Lossをカスタマイズしています。また、gamma
パラメータの選択が重要であり、この値を調整することでモデルのフォーカスを変更できます。デフォルトの交差エントロピー損失に比べて、Focal Lossを使用することで特にマイノリティクラスの予測精度が向上することが期待されます。
まとめ
損失関数については、私の頭では詳細に理解することが出来ず、解説はChatGPTにお願いしてしまっているので間違っていたらすみません。
やはり競馬で利益を出すには穴馬を的中していく必要があるので、穴馬のモデルの精度はこれからも上げていけるように頑張っていきます!
コメント
fobj_focal_lossをobjectiveパラメータとして渡す箇所でエラーが発生します。
何か文字列形式でobjectiveを指定する必要があるのでしょうか?
当該コードとエラーメッセージはどのようなものでしょうか?