airflow.providers.google.cloud.utils.mlengine_prediction_summary
¶
DataFlowPythonOperator 呼叫的範本,用於總結 BatchPrediction。
它接受使用者函數來計算預測結果中每個實例的指標,然後匯總以輸出為摘要。
它接受以下引數
--prediction_path
:包含 BatchPrediction 結果的 GCS 資料夾,其中包含 json 格式的prediction.results-NNNNN-of-NNNNN
檔案。輸出也將儲存在此資料夾中,作為 'prediction.summary.json'。--metric_fn_encoded
:編碼函數,用於計算並傳回給定實例(作為字典)的指標元組。它應該透過base64.b64encode(dill.dumps(fn, recurse=True))
進行編碼。--metric_keys
:摘要輸出中聚合指標的逗號分隔鍵。鍵的順序和大小必須與 metric_fn 的輸出相符。摘要將有一個額外的鍵 'count',表示實例總數,因此鍵不應包含 'count'。
使用範例
當輸入檔案如下所示時
{"inputs": "1,x,y,z", "classes": 1, "scores": [0.1, 0.9]}
{"inputs": "0,o,m,g", "classes": 0, "scores": [0.7, 0.3]}
{"inputs": "1,o,m,w", "classes": 0, "scores": [0.6, 0.4]}
{"inputs": "1,b,r,b", "classes": 1, "scores": [0.2, 0.8]}
輸出檔案將會是
{"log_loss": 0.43890510565304547, "count": 4, "mse": 0.25}
在 dag 外部進行測試
subprocess.check_call(
[
"python",
"-m",
"airflow.providers.google.cloud.utils.mlengine_prediction_summary",
"--prediction_path=gs://...",
"--metric_fn_encoded=" + metric_fn_encoded,
"--metric_keys=log_loss,mse",
"--runner=DataflowRunner",
"--staging_location=gs://...",
"--temp_location=gs://...",
]
)
模組內容¶
函數¶
|
Dataflow 中使用的摘要 PTransform。 |
|
取得預測摘要。 |
- class airflow.providers.google.cloud.utils.mlengine_prediction_summary.JsonCoder[source]¶
基底:
apache_beam.coders.coders.Coder
JSON 編碼器/解碼器。