@@ -135,8 +135,8 @@ def export(
135135 "TorchExportableModuleForDecoderOnlyLM.export Can't infer device from the model. Set to CPU by default."
136136 )
137137
138- if input_ids is not None and inputs_embeds is not None :
139- raise ValueError ("Can't specify both input_ids and inputs_embeds." )
+        if (input_ids is None) == (inputs_embeds is None):
+            raise ValueError("Need to specify exactly one of input_ids or inputs_embeds.")
 
         example_cache_position = (
             cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long, device=model_device)
@@ -150,24 +150,14 @@ def export(
                 dynamic_shapes=dynamic_shapes,
                 strict=strict if strict is not None else True,
             )
-        elif inputs_embeds:
+        else:  # inputs_embeds
             exported_program = torch.export.export(
                 self.model,
                 args=(),
                 kwargs={"inputs_embeds": inputs_embeds, "cache_position": example_cache_position},
                 dynamic_shapes=dynamic_shapes,
                 strict=strict if strict is not None else True,
             )
-        else:
-            # No inputs specified, assume we are exporting with input_ids for legacy reasons.
-            example_input_ids = torch.tensor([[1]], dtype=torch.long, device=model_device)
-            exported_program = torch.export.export(
-                self.model,
-                args=(),
-                kwargs={"input_ids": example_input_ids, "cache_position": example_cache_position},
-                dynamic_shapes=dynamic_shapes,
-                strict=strict if strict is not None else True,
-            )
 
         return exported_program
 
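For reference, a minimal usage sketch of the contract after this change: `export` now requires exactly one of `input_ids` or `inputs_embeds`, since the legacy no-input fallback is removed. The checkpoint name and the wrapper's constructor arguments below are illustrative assumptions, not part of this diff.

```python
# Sketch only: assumes transformers' executorch integration is importable and that the
# wrapper can be built directly from a causal LM; the checkpoint name is illustrative.
import torch
from transformers import AutoModelForCausalLM
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")  # hypothetical checkpoint
wrapper = TorchExportableModuleForDecoderOnlyLM(model)  # extra constructor args may be required

cache_position = torch.tensor([0], dtype=torch.long)

# Export with token ids ...
input_ids = torch.tensor([[1]], dtype=torch.long)
ep_from_ids = wrapper.export(input_ids=input_ids, cache_position=cache_position)

# ... or with embeddings instead. Passing both (or neither) now raises ValueError.
inputs_embeds = model.get_input_embeddings()(input_ids)
ep_from_embeds = wrapper.export(inputs_embeds=inputs_embeds, cache_position=cache_position)
```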