44# This source code is licensed under the terms described in the LICENSE file in
55# the root directory of this source tree.
66
7+ import json
78from typing import Optional
89
910import click
@@ -93,12 +94,23 @@ def get_model(ctx, model_id: str):
9394 console .print (table )
9495
9596
97+ class JSONParamType (click .ParamType ):
98+ name = "json"
99+
100+ def convert (self , value , param , ctx ):
101+ try :
102+ return json .loads (value )
103+ except json .JSONDecodeError as e :
104+ self .fail (f"Invalid JSON: { e } " , param , ctx )
105+
106+
96107@click .command (name = "register" , help = "Register a new model at distribution endpoint" )
97108@click .help_option ("-h" , "--help" )
98109@click .argument ("model_id" )
99110@click .option ("--provider-id" , help = "Provider ID for the model" , default = None )
100111@click .option ("--provider-model-id" , help = "Provider's model ID" , default = None )
101- @click .option ("--metadata" , help = "JSON metadata for the model" , default = None )
112+ @click .option ("--metadata" , type = JSONParamType (), help = "JSON metadata for the model" , default = None )
113+ @click .option ("--model-type" , type = click .Choice (["llm" , "embedding" ]), default = "llm" , help = "Model type: llm, embedding" )
102114@click .pass_context
103115@handle_client_errors ("register model" )
104116def register_model (
@@ -107,6 +119,7 @@ def register_model(
107119 provider_id : Optional [str ],
108120 provider_model_id : Optional [str ],
109121 metadata : Optional [str ],
122+ model_type : Optional [str ],
110123):
111124 """Register a new model at distribution endpoint"""
112125 client = ctx .obj ["client" ]
@@ -117,6 +130,7 @@ def register_model(
117130 provider_id = provider_id ,
118131 provider_model_id = provider_model_id ,
119132 metadata = metadata ,
133+ model_type = model_type ,
120134 )
121135 if response :
122136 console .print (f"[green]Successfully registered model { model_id } [/green]" )
0 commit comments