Skip to content

Commit 7580ff9

Browse files
authored
Add support for predict_proba in classification notebook (#3531)
* modify to support predict_proba * fix code style
1 parent 201edfd commit 7580ff9

File tree

1 file changed

+26
-27
lines changed

1 file changed

+26
-27
lines changed

sdk/python/jobs/automl-standalone-jobs/automl-classification-task-bankmarketing/automl-classification-task-bankmarketing.ipynb

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -655,15 +655,19 @@
655655
"outputs": [],
656656
"source": [
657657
"model_name = \"bankmarketing-model\"\n",
658+
"\n",
658659
"model = Model(\n",
659-
" path=f\"azureml://jobs/{best_run.info.run_id}/outputs/artifacts/outputs/mlflow-model/\",\n",
660+
" path=f\"azureml://jobs/{best_run.info.run_id}/outputs/artifacts/outputs/model.pkl\",\n",
660661
" name=model_name,\n",
661662
" description=\"my sample classification model\",\n",
662-
" type=AssetTypes.MLFLOW_MODEL,\n",
663663
")\n",
664664
"\n",
665-
"# for downloaded file\n",
666-
"# model = Model(path=\"artifact_downloads/outputs/model.pkl\", name=model_name)\n",
665+
"env = Environment(\n",
666+
" name=\"automl-tabular-env\",\n",
667+
" description=\"environment for automl inference\",\n",
668+
" image=\"mcr.microsoft.com/azureml/openmpi4.1.0-ubuntu20.04:latest\",\n",
669+
" conda_file=\"artifact_downloads/outputs/conda_env_v_1_0_0.yml\",\n",
670+
")\n",
667671
"\n",
668672
"registered_model = ml_client.models.create_or_update(model)"
669673
]
@@ -690,26 +694,18 @@
690694
"metadata": {},
691695
"outputs": [],
692696
"source": [
697+
"code_configuration = CodeConfiguration(\n",
698+
" code=\"artifact_downloads/outputs/\", scoring_script=\"scoring_file_v_2_0_0.py\"\n",
699+
")\n",
700+
"\n",
693701
"deployment = ManagedOnlineDeployment(\n",
694702
" name=\"bankmarketing-deploy\",\n",
695703
" endpoint_name=online_endpoint_name,\n",
696704
" model=registered_model.id,\n",
697705
" instance_type=\"Standard_DS2_V2\",\n",
698706
" instance_count=1,\n",
699-
" liveness_probe=ProbeSettings(\n",
700-
" failure_threshold=30,\n",
701-
" success_threshold=1,\n",
702-
" timeout=2,\n",
703-
" period=10,\n",
704-
" initial_delay=2000,\n",
705-
" ),\n",
706-
" readiness_probe=ProbeSettings(\n",
707-
" failure_threshold=10,\n",
708-
" success_threshold=1,\n",
709-
" timeout=10,\n",
710-
" period=10,\n",
711-
" initial_delay=2000,\n",
712-
" ),\n",
707+
" code_configuration=code_configuration,\n",
708+
" environment=env,\n",
713709
")"
714710
]
715711
},
@@ -753,29 +749,32 @@
753749
"source": [
754750
"# test the blue deployment with some sample data\n",
755751
"import pandas as pd\n",
752+
"import json\n",
756753
"\n",
757754
"test_data = pd.read_csv(\"./data/test-mltable-folder/bank_marketing_test_data.csv\")\n",
758755
"\n",
759756
"test_data = test_data.drop(\"y\", axis=1)\n",
760757
"\n",
761758
"test_data_json = test_data.to_json(orient=\"records\", indent=4)\n",
762-
"data = (\n",
763-
" '{ \\\n",
764-
" \"input_data\": {\"data\": '\n",
765-
" + test_data_json\n",
766-
" + \"}}\"\n",
767-
")\n",
759+
"\n",
760+
"data = {\n",
761+
" \"Inputs\": {\"data\": json.loads(test_data_json)},\n",
762+
" \"GlobalParameters\": {\n",
763+
" \"method\": \"predict_proba\" # use method \"predict\" when probability is not needed\n",
764+
" },\n",
765+
"}\n",
768766
"\n",
769767
"request_file_name = \"sample-request-bankmarketing.json\"\n",
770768
"\n",
771769
"with open(request_file_name, \"w\") as request_file:\n",
772-
" request_file.write(data)\n",
770+
" json.dump(data, request_file)\n",
773771
"\n",
774-
"ml_client.online_endpoints.invoke(\n",
772+
"res = ml_client.online_endpoints.invoke(\n",
775773
" endpoint_name=online_endpoint_name,\n",
776774
" deployment_name=\"bankmarketing-deploy\",\n",
777775
" request_file=request_file_name,\n",
778-
")"
776+
")\n",
777+
"res"
779778
]
780779
},
781780
{

0 commit comments

Comments
 (0)