LightGBM, XGBoostであれば、途中の学習メトリクスをMlflowで記録するとき、mlflow.lightgbm.autolog()
や mlflow.xgboost.autolog()
を呼び出すだけで記録してくれるので非常に楽である。一方、CatBoost にはオートログの機能はない。
class CatBoostMlFlowCallback:
def after_iteration(self, info):
step = len(info.metrics["learn"]["Logloss"])
mlflow.log_metric("train-logloss", info.metrics["learn"]["Logloss"][-1], step)
mlflow.log_metric("valid-logloss", info.metrics["validation"]["Logloss"][-1], step)
mlflow.log_metric("valid-auc", info.metrics["validation"]["AUC"][-1], step)
return True
上のようなコールバック関数を作って、モデルを fit
するときにcallbacks=[CatBoostMlFlowCallback()]
という形で設定すると、CatBoostの場合もメトリックが記録できる。