Skip to content

Commit 9036ee5

Browse files
committed
Avoid converting tensor names multiple times
1 parent fe3073b commit 9036ee5

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

model.cpp

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1174,10 +1174,8 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
11741174
n_dims = 1;
11751175
}
11761176

1177-
std::string new_name = prefix + name;
1178-
new_name = prefix == "unet." ? convert_tensor_name(new_name) : new_name;
11791177

1180-
TensorStorage tensor_storage(new_name, type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin);
1178+
TensorStorage tensor_storage(prefix + name, type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin);
11811179
tensor_storage.reverse_ne();
11821180

11831181
size_t tensor_data_size = end - begin;
@@ -1218,13 +1216,18 @@ bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const s
12181216
return false;
12191217
}
12201218
for (auto ts : tensor_storages) {
1221-
if (ts.name.find("label_emb") != std::string::npos) {
1219+
if (ts.name.find("add_embedding") != std::string::npos || ts.name.find("label_emb") != std::string::npos) {
12221220
// probably SDXL
12231221
LOG_DEBUG("Fixing name for SDXL output blocks.2.2");
12241222
for (auto& tensor_storage : tensor_storages) {
1225-
auto pos = tensor_storage.name.find("model.diffusion_model.output_blocks.2.1.conv");
1223+
int len = 34;
1224+
auto pos = tensor_storage.name.find("unet.up_blocks.0.upsamplers.0.conv");
1225+
if (pos == std::string::npos) {
1226+
len = 44;
1227+
pos = tensor_storage.name.find("model.diffusion_model.output_blocks.2.1.conv");
1228+
}
12261229
if (pos != std::string::npos) {
1227-
tensor_storage.name = "model.diffusion_model.output_blocks.2.2.conv" + tensor_storage.name.substr(44);
1230+
tensor_storage.name = "model.diffusion_model.output_blocks.2.2.conv" + tensor_storage.name.substr(len);
12281231
LOG_DEBUG("NEW NAME: %s", tensor_storage.name.c_str());
12291232
add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type);
12301233
}
@@ -1640,7 +1643,7 @@ SDVersion ModelLoader::get_sd_version() {
16401643
break;
16411644
}
16421645
}
1643-
if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos) {
1646+
if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos || tensor_storage.name.find("unet.down_blocks.") != std::string::npos) {
16441647
is_unet = true;
16451648
if (has_multiple_encoders) {
16461649
is_xl = true;
@@ -1671,7 +1674,7 @@ SDVersion ModelLoader::get_sd_version() {
16711674
token_embedding_weight = tensor_storage;
16721675
// break;
16731676
}
1674-
if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == "model.diffusion_model.img_in.weight") {
1677+
if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == "model.diffusion_model.img_in.weight" || tensor_storage.name == "unet.conv_in.weight") {
16751678
input_block_weight = tensor_storage;
16761679
input_block_checked = true;
16771680
if (found_family) {
@@ -1777,7 +1780,7 @@ ggml_type ModelLoader::get_diffusion_model_wtype() {
17771780
continue;
17781781
}
17791782

1780-
if (tensor_storage.name.find("model.diffusion_model.") == std::string::npos) {
1783+
if (tensor_storage.name.find("model.diffusion_model.") == std::string::npos && tensor_storage.name.find("unet.") == std::string::npos) {
17811784
continue;
17821785
}
17831786

0 commit comments

Comments
 (0)