-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathmodel_inference.py
More file actions
54 lines (32 loc) · 1.35 KB
/
model_inference.py
File metadata and controls
54 lines (32 loc) · 1.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# Databricks notebook source
# MAGIC %md ### Apply model to new records
# COMMAND ----------
import mlflow.spark
from mlflow.tracking import MlflowClient
from databricks.feature_store import FeatureStoreClient
# MlflowClient queries the MLflow Model Registry (used to resolve the
# Production model's run id below); FeatureStoreClient performs the
# feature lookup + batch scoring.
client = MlflowClient()
fs = FeatureStoreClient()
# COMMAND ----------
# MAGIC %md Simulate new records; notice that only the record IDs need to be passed. The MLflow model has recorded the feature-lookup logic and will join the necessary features to the record IDs.
# COMMAND ----------
# Grab 20 passenger IDs to stand in for "new" unscored records.
# NOTE(review): `spark` and `display` are Databricks notebook globals.
new_passenger_records = (spark.table('default.passenger_labels')
.select('PassengerId')
.limit(20))
display(new_passenger_records)
# COMMAND ----------
# MAGIC %md Get model's unique identifier
# COMMAND ----------
def get_run_id(model_name, stage='Production'):
    """Return the MLflow run id of the first model version found in *stage*.

    Parameters
    ----------
    model_name : str
        Name of the registered model in the MLflow Model Registry.
    stage : str, optional
        Registry stage to look up (default ``'Production'``).

    Returns
    -------
    str
        The ``run_id`` of the matching model version.

    Raises
    ------
    ValueError
        If no version of ``model_name`` is currently in ``stage``.
        (The original code raised a bare ``IndexError`` from ``[0]``
        on an empty list, which gave no hint about the cause.)
    """
    # search_model_versions returns ModelVersion objects (not runs);
    # take the first one whose current_stage matches, without building
    # the full intermediate list.
    version = next(
        (v for v in client.search_model_versions(f"name='{model_name}'")
         if v.current_stage == stage),
        None,
    )
    if version is None:
        raise ValueError(
            f"No version of model '{model_name}' found in stage '{stage}'"
        )
    return version.run_id
# Replace the first parameter with your model's name
# Resolve the run id of the model version currently in Production;
# this id identifies the logged model artifact to score with.
run_id = get_run_id('feature_store_models', stage='Production')
run_id
# COMMAND ----------
# MAGIC %md Score records
# COMMAND ----------
# 'runs:/<run_id>/model' is the MLflow URI of the model artifact logged by that run.
model_uri = f'runs:/{run_id}/model'
# score_batch looks up the recorded features for each PassengerId and scores
# the joined rows; presumably the result adds a prediction column — confirm
# against the Feature Store docs for your runtime version.
with_predictions = fs.score_batch(model_uri, new_passenger_records)
display(with_predictions)