@@ -1174,10 +1174,8 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
11741174 n_dims = 1 ;
11751175 }
11761176
1177- std::string new_name = prefix + name;
1178- new_name = prefix == " unet." ? convert_tensor_name (new_name) : new_name;
11791177
1180- TensorStorage tensor_storage (new_name , type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin);
1178+ TensorStorage tensor_storage (prefix + name , type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin);
11811179 tensor_storage.reverse_ne ();
11821180
11831181 size_t tensor_data_size = end - begin;
@@ -1218,13 +1216,18 @@ bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const s
12181216 return false ;
12191217 }
12201218 for (auto ts : tensor_storages) {
1221- if (ts.name .find (" label_emb" ) != std::string::npos) {
1219+ if (ts.name .find (" add_embedding " ) != std::string::npos || ts. name . find ( " label_emb" ) != std::string::npos) {
12221220 // probably SDXL
12231221 LOG_DEBUG (" Fixing name for SDXL output blocks.2.2" );
12241222 for (auto & tensor_storage : tensor_storages) {
1225- auto pos = tensor_storage.name .find (" model.diffusion_model.output_blocks.2.1.conv" );
1223+ int len = 34 ;
1224+ auto pos = tensor_storage.name .find (" unet.up_blocks.0.upsamplers.0.conv" );
1225+ if (pos == std::string::npos) {
1226+ len = 44 ;
1227+ pos = tensor_storage.name .find (" model.diffusion_model.output_blocks.2.1.conv" );
1228+ }
12261229 if (pos != std::string::npos) {
1227- tensor_storage.name = " model.diffusion_model.output_blocks.2.2.conv" + tensor_storage.name .substr (44 );
1230+ tensor_storage.name = " model.diffusion_model.output_blocks.2.2.conv" + tensor_storage.name .substr (len );
12281231 LOG_DEBUG (" NEW NAME: %s" , tensor_storage.name .c_str ());
12291232 add_preprocess_tensor_storage_types (tensor_storages_types, tensor_storage.name , tensor_storage.type );
12301233 }
@@ -1640,7 +1643,7 @@ SDVersion ModelLoader::get_sd_version() {
16401643 break ;
16411644 }
16421645 }
1643- if (tensor_storage.name .find (" model.diffusion_model.input_blocks." ) != std::string::npos) {
1646+ if (tensor_storage.name .find (" model.diffusion_model.input_blocks." ) != std::string::npos || tensor_storage. name . find ( " unet.down_blocks. " ) != std::string::npos ) {
16441647 is_unet = true ;
16451648 if (has_multiple_encoders) {
16461649 is_xl = true ;
@@ -1671,7 +1674,7 @@ SDVersion ModelLoader::get_sd_version() {
16711674 token_embedding_weight = tensor_storage;
16721675 // break;
16731676 }
1674- if (tensor_storage.name == " model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == " model.diffusion_model.img_in.weight" ) {
1677+ if (tensor_storage.name == " model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == " model.diffusion_model.img_in.weight" || tensor_storage. name == " unet.conv_in.weight " ) {
16751678 input_block_weight = tensor_storage;
16761679 input_block_checked = true ;
16771680 if (found_family) {
@@ -1777,7 +1780,7 @@ ggml_type ModelLoader::get_diffusion_model_wtype() {
17771780 continue ;
17781781 }
17791782
1780- if (tensor_storage.name .find (" model.diffusion_model." ) == std::string::npos) {
1783+ if (tensor_storage.name .find (" model.diffusion_model." ) == std::string::npos && tensor_storage. name . find ( " unet. " ) == std::string::npos ) {
17811784 continue ;
17821785 }
17831786
0 commit comments