Skip to content

Commit a99dc93

Browse files
[Image generation] Non public API changes for image generation (openvinotoolkit#1146)
Now the LCM, SD and SDXL pipelines for image-to-image generation work correctly. The public API will be added later.
1 parent c36fe02 commit a99dc93

File tree

14 files changed

+285
-94
lines changed

14 files changed

+285
-94
lines changed

samples/python/visual_language_chat/visual_language_chat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def read_image(path: str) -> Tensor:
3636
3737
'''
3838
pic = Image.open(path).convert("RGB")
39-
image_data = np.array(pic.getdata()).reshape(1, pic.size[1], pic.size[0], 3).astype(np.byte)
39+
image_data = np.array(pic.getdata()).reshape(1, pic.size[1], pic.size[0], 3).astype(np.uint8)
4040
return Tensor(image_data)
4141

4242

src/cpp/include/openvino/genai/image_generation/autoencoder_kl.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "openvino/runtime/properties.hpp"
1414

1515
#include "openvino/genai/visibility.hpp"
16+
#include "openvino/genai/image_generation/generation_config.hpp"
1617

1718
namespace ov {
1819
namespace genai {
@@ -74,7 +75,7 @@ class OPENVINO_GENAI_EXPORTS AutoencoderKL {
7475

7576
ov::Tensor decode(ov::Tensor latent);
7677

77-
ov::Tensor encode(ov::Tensor image);
78+
ov::Tensor encode(ov::Tensor image, std::shared_ptr<Generator> generator);
7879

7980
const Config& get_config() const;
8081

src/cpp/src/debug_utils.hpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
#include <string>
77
#include <iostream>
8+
#include <fstream>
89

910
#include <openvino/runtime/tensor.hpp>
1011

@@ -31,3 +32,42 @@ inline void print_tensor(std::string name, ov::Tensor tensor) {
3132
print_array(tensor.data<ov::float16>(), tensor.get_size());
3233
}
3334
}
35+
36+
/**
 * @brief Reads the next reference value from a text stream and compares it with one tensor element.
 *
 * @tparam tensor_T element type stored in the tensor buffer
 * @tparam file_T   type used to parse the reference value from the stream
 * @param data             pointer to the tensor elements
 * @param i                index of the element to check
 * @param file             stream the reference value is extracted from with operator>>
 * @param printed_elements running count of diagnostics already printed; at most 10 are emitted
 * @param assign           when true, data[i] is overwritten with the reference value before comparison
 *
 * If the extraction fails (for example, the file holds fewer values than the tensor),
 * a diagnostic is printed and data[i] is left untouched.
 */
template <typename tensor_T, typename file_T>
void _read_tensor_step(tensor_T* data, size_t i, std::ifstream& file, size_t& printed_elements, bool assign) {
    const size_t print_size = 10;

    // value-initialize: operator>> does not touch its target when the stream is already in a failed state,
    // so an uninitialized variable would be read below once the file is exhausted
    file_T value{};
    if (!(file >> value)) {
        if (printed_elements < print_size) {
            std::cout << i << ") failed to read reference value" << std::endl;
            ++printed_elements;
        }
        return;
    }

    // this mode is used to fallback to reference data to check further execution
    if (assign)
        data[i] = static_cast<tensor_T>(value);

    // explicit two-sided comparison instead of std::abs, so the check does not depend on <cmath> being included
    const auto diff = value - data[i];
    if ((diff > 1e-7 || diff < -1e-7) && printed_elements < print_size) {
        std::cout << i << ") ref = " << value << " act = " << static_cast<file_T>(data[i]) << std::endl;
        ++printed_elements;
    }
}
52+
53+
inline void read_tensor(const std::string& file_name, ov::Tensor tensor, bool assign = false) {
54+
std::ifstream file(file_name.c_str());
55+
OPENVINO_ASSERT(file.is_open(), "Failed to open file ", file_name);
56+
57+
std::cout << "Opening " << file_name << std::endl;
58+
std::cout << "tensor shape " << tensor.get_shape() << std::endl;
59+
60+
for (size_t i = 0, printed_elements = 0; i < tensor.get_size(); ++i) {
61+
if (tensor.get_element_type() == ov::element::f32)
62+
_read_tensor_step<float, float>(tensor.data<float>(), i, file, printed_elements, assign);
63+
else if (tensor.get_element_type() == ov::element::f64)
64+
_read_tensor_step<double, double>(tensor.data<double>(), i, file, printed_elements, assign);
65+
else if (tensor.get_element_type() == ov::element::u8)
66+
_read_tensor_step<uint8_t, float>(tensor.data<uint8_t>(), i, file, printed_elements, assign);
67+
else {
68+
OPENVINO_THROW("Unsupported tensor type ", tensor.get_element_type(), " by read_tensor");
69+
}
70+
}
71+
72+
std::cout << "Closing " << file_name << std::endl;
73+
}

src/cpp/src/image_generation/models/autoencoder_kl.cpp

Lines changed: 72 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,48 @@
2222
namespace ov {
2323
namespace genai {
2424

25+
class DiagonalGaussianDistribution {
26+
public:
27+
explicit DiagonalGaussianDistribution(ov::Tensor parameters)
28+
: m_parameters(parameters) {
29+
ov::Shape shape = parameters.get_shape();
30+
OPENVINO_ASSERT(shape[0] == 1, "Batch size must be 1");
31+
shape[1] /= 2;
32+
33+
m_mean = ov::Tensor(parameters.get_element_type(), shape, parameters.data());
34+
m_std = ov::Tensor(m_mean.get_element_type(), shape);
35+
ov::Tensor logvar(parameters.get_element_type(), shape, m_mean.data<float>() + m_mean.get_size());
36+
37+
float * logvar_data = logvar.data<float>();
38+
float * std_data = m_std.data<float>();
39+
40+
for (size_t i = 0; i < logvar.get_size(); ++i) {
41+
logvar_data[i] = std::min(std::max(logvar_data[i], -30.0f), 20.0f);
42+
std_data[i] = std::exp(0.5 * logvar_data[i]);
43+
}
44+
}
45+
46+
ov::Tensor sample(std::shared_ptr<Generator> generator) const {
47+
OPENVINO_ASSERT(generator, "Generator must not be nullptr");
48+
49+
ov::Tensor rand_tensor = generator->randn_tensor(m_mean.get_shape());
50+
51+
float * rand_tensor_data = rand_tensor.data<float>();
52+
const float * mean_data = m_mean.data<float>();
53+
const float * std_data = m_std.data<float>();
54+
55+
for (size_t i = 0; i < rand_tensor.get_size(); ++i) {
56+
rand_tensor_data[i] = mean_data[i] + std_data[i] * rand_tensor_data[i];
57+
}
58+
59+
return rand_tensor;
60+
}
61+
62+
private:
63+
ov::Tensor m_parameters;
64+
ov::Tensor m_mean, m_std;
65+
};
66+
2567
size_t get_vae_scale_factor(const std::filesystem::path& vae_config_path) {
2668
std::ifstream file(vae_config_path);
2769
OPENVINO_ASSERT(file.is_open(), "Failed to open ", vae_config_path);
@@ -141,12 +183,34 @@ ov::Tensor AutoencoderKL::decode(ov::Tensor latent) {
141183
return m_decoder_request.get_output_tensor();
142184
}
143185

144-
ov::Tensor AutoencoderKL::encode(ov::Tensor image) {
186+
/**
 * Runs the VAE encoder on `image` and returns the latent, shifted and scaled as
 * (latent - shift_factor) * scaling_factor.
 *
 * The single model output is interpreted by its name:
 *  - "latent_sample":     the output is used as the latent directly (and is scaled in place);
 *  - "latent_parameters": the output is treated as DiagonalGaussianDistribution parameters
 *                         and the latent is sampled using `generator`.
 * Any other output name results in OPENVINO_THROW.
 */
ov::Tensor AutoencoderKL::encode(ov::Tensor image, std::shared_ptr<Generator> generator) {
    OPENVINO_ASSERT(m_encoder_request, "VAE encoder model must be compiled first. Cannot infer non-compiled model");

    m_encoder_request.set_input_tensor(image);
    m_encoder_request.infer();

    ov::Tensor raw_output = m_encoder_request.get_output_tensor();

    // figure out what the encoder produces based on the name of its only output
    auto model_outputs = m_encoder_request.get_compiled_model().outputs();
    OPENVINO_ASSERT(model_outputs.size() == 1, "AutoencoderKL encoder model is expected to have a single output");

    const std::string output_name = model_outputs[0].get_any_name();
    const bool is_sample = output_name == "latent_sample";
    if (!is_sample && output_name != "latent_parameters") {
        OPENVINO_THROW("Unexpected output name for AutoencoderKL encoder '", output_name, "'");
    }

    ov::Tensor latent = is_sample ? raw_output : DiagonalGaussianDistribution(raw_output).sample(generator);

    // apply shift and scaling factor
    float * const first = latent.data<float>();
    float * const last = first + latent.get_size();
    for (float * value = first; value != last; ++value) {
        *value = (*value - m_config.shift_factor) * m_config.scaling_factor;
    }

    return latent;
}
151215

152216
const AutoencoderKL::Config& AutoencoderKL::get_config() const {
@@ -171,25 +235,21 @@ void AutoencoderKL::merge_vae_image_pre_processing() const {
171235
ppp.input().preprocess()
172236
.convert_layout()
173237
.convert_element_type(ov::element::f32)
174-
.scale(255.0f / 2.0f)
238+
// this is less accurate than in VaeImageProcessor::normalize
239+
.scale(255.0 / 2.0)
175240
.mean(1.0f);
176241

177-
// apply m_config.scaling_factor as last step
178-
ppp.output().postprocess().custom([scaling_factor = m_config.scaling_factor](const ov::Output<ov::Node>& port) {
179-
auto c_scaling_factor = std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{1}, scaling_factor);
180-
return std::make_shared<ov::op::v1::Multiply>(port, c_scaling_factor);
181-
});
182-
183242
ppp.build();
184243
}
185244

186245
void AutoencoderKL::merge_vae_image_post_processing() const {
187246
ov::preprocess::PrePostProcessor ppp(m_decoder_model);
188247

189248
// scale and shift input before VAE decoder
190-
ppp.input().preprocess()
191-
.scale(m_config.scaling_factor)
192-
.mean(-m_config.shift_factor);
249+
if (m_config.scaling_factor != 1.0f)
250+
ppp.input().preprocess().scale(m_config.scaling_factor);
251+
if (m_config.shift_factor != 0.0f)
252+
ppp.input().preprocess().mean(-m_config.shift_factor);
193253

194254
// apply VaeImageProcessor normalization steps
195255
// https://github.com/huggingface/diffusers/blob/v0.30.1/src/diffusers/image_processor.py#L159

src/cpp/src/image_generation/models/sd3_transformer_2d_model.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ SD3Transformer2DModel::Config::Config(const std::filesystem::path& config_path)
2929
SD3Transformer2DModel::SD3Transformer2DModel(const std::filesystem::path& root_dir)
3030
: m_config(root_dir / "config.json") {
3131
m_model = utils::singleton_core().read_model((root_dir / "openvino_model.xml").string());
32-
m_vae_scale_factor = ov::genai::get_vae_scale_factor(root_dir.parent_path() / "vae_decoder" / "config.json");
32+
m_vae_scale_factor = get_vae_scale_factor(root_dir.parent_path() / "vae_decoder" / "config.json");
3333
}
3434

3535
SD3Transformer2DModel::SD3Transformer2DModel(const std::filesystem::path& root_dir,

src/cpp/src/image_generation/schedulers/ddim.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,14 +114,22 @@ void DDIMScheduler::set_timesteps(size_t num_inference_steps, float strength) {
114114
default:
115115
OPENVINO_THROW("Unsupported value for 'timestep_spacing'");
116116
}
117+
118+
// apply 'strength' used in image generation
119+
// in diffusers, it's https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py#L711
120+
{
121+
size_t init_timestep = std::min<size_t>(num_inference_steps * strength, num_inference_steps);
122+
size_t t_start = std::max<size_t>(num_inference_steps - init_timestep, 0);
123+
m_timesteps = std::vector<int64_t>(m_timesteps.begin() + t_start, m_timesteps.end());
124+
}
117125
}
118126

119127
std::map<std::string, ov::Tensor> DDIMScheduler::step(ov::Tensor noise_pred, ov::Tensor latents, size_t inference_step, std::shared_ptr<Generator> generator) {
120128
// noise_pred - model_output
121129
// latents - sample
122130
// inference_step
123131

124-
size_t timestep = get_timesteps()[inference_step];
132+
size_t timestep = m_timesteps[inference_step];
125133

126134
// get previous step value (=t-1)
127135
int prev_timestep = timestep - m_config.num_train_timesteps / m_num_inference_steps;
@@ -205,7 +213,7 @@ void DDIMScheduler::add_noise(ov::Tensor init_latent, std::shared_ptr<Generator>
205213
int64_t latent_timestep = m_timesteps.front();
206214

207215
float sqrt_alpha_prod = std::sqrt(m_alphas_cumprod[latent_timestep]);
208-
float sqrt_one_minus_alpha_prod = std::sqrt(1.0f - m_alphas_cumprod[latent_timestep]);
216+
float sqrt_one_minus_alpha_prod = std::sqrt(1.0 - m_alphas_cumprod[latent_timestep]);
209217

210218
ov::Tensor rand_tensor = generator->randn_tensor(init_latent.get_shape());
211219

src/cpp/src/image_generation/schedulers/euler_discrete.cpp

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -102,13 +102,14 @@ EulerDiscreteScheduler::EulerDiscreteScheduler(const Config& scheduler_config) :
102102
m_sigmas.push_back(0);
103103

104104
m_step_index = -1;
105+
m_begin_index = -1;
105106
}
106107

107108
void EulerDiscreteScheduler::set_timesteps(size_t num_inference_steps, float strength) {
108109
// TODO: support `timesteps` and `sigmas` inputs
109110
m_timesteps.clear();
110111
m_sigmas.clear();
111-
m_step_index = -1;
112+
m_step_index = m_begin_index = -1;
112113

113114
m_num_inference_steps = num_inference_steps;
114115
std::vector<float> sigmas;
@@ -192,17 +193,29 @@ void EulerDiscreteScheduler::set_timesteps(size_t num_inference_steps, float str
192193
OPENVINO_THROW("Unsupported value for 'final_sigmas_type'");
193194
}
194195
m_sigmas.push_back(sigma_last);
196+
197+
// apply 'strength' used in image generation
198+
// in diffusers, it's https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py#L650
199+
{
200+
size_t init_timestep = std::min<size_t>(num_inference_steps * strength, num_inference_steps);
201+
size_t t_start = std::max<size_t>(num_inference_steps - init_timestep, 0);
202+
// keep original timesteps
203+
m_schedule_timesteps = m_timesteps;
204+
// while exposing the ones truncated according to the 'strength' parameter
205+
m_timesteps = std::vector<int64_t>(m_timesteps.begin() + t_start, m_timesteps.end());
206+
m_begin_index = t_start;
207+
}
195208
}
196209

197210
std::map<std::string, ov::Tensor> EulerDiscreteScheduler::step(ov::Tensor noise_pred, ov::Tensor latents, size_t inference_step, std::shared_ptr<Generator> generator) {
198211
// noise_pred - model_output
199212
// latents - sample
200213
// inference_step
201214

202-
size_t timestep = get_timesteps()[inference_step];
215+
size_t timestep = m_timesteps[inference_step];
203216

204217
if (m_step_index == -1)
205-
m_step_index = 0;
218+
m_step_index = m_begin_index;
206219

207220
float sigma = m_sigmas[m_step_index];
208221
// TODO: hardcoded gamma
@@ -273,7 +286,7 @@ float EulerDiscreteScheduler::get_init_noise_sigma() const {
273286

274287
void EulerDiscreteScheduler::scale_model_input(ov::Tensor sample, size_t inference_step) {
275288
if (m_step_index == -1)
276-
m_step_index = 0;
289+
m_step_index = m_begin_index;
277290

278291
float sigma = m_sigmas[m_step_index];
279292
float* sample_data = sample.data<float>();
@@ -282,9 +295,28 @@ void EulerDiscreteScheduler::scale_model_input(ov::Tensor sample, size_t inferen
282295
}
283296
}
284297

298+
/**
 * Returns the position of the first entry of m_schedule_timesteps equal to `timestep`.
 * Throws via OPENVINO_THROW when no entry matches.
 */
size_t EulerDiscreteScheduler::_index_for_timestep(int64_t timestep) const {
    size_t position = 0;

    for (const int64_t candidate : m_schedule_timesteps) {
        if (candidate == timestep)
            return position;
        ++position;
    }

    OPENVINO_THROW("Failed to find index for timestep ", timestep);
}
307+
285308
/**
 * Adds noise to `init_latent` in place: latent = latent + sigma * noise.
 *
 * `sigma` is taken from m_sigmas at the position of the first remaining timestep
 * (m_timesteps.front()) within the full schedule; `noise` comes from generator->randn_tensor.
 *
 * @param init_latent tensor to modify; accessed as float data
 * @param generator   noise source; must not be nullptr
 */
void EulerDiscreteScheduler::add_noise(ov::Tensor init_latent, std::shared_ptr<Generator> generator) const {
    OPENVINO_ASSERT(generator, "Generator must not be nullptr");
    // front() on an empty vector is undefined behaviour; the list is empty when 'strength'
    // cuts off all timesteps or when timesteps have not been set
    OPENVINO_ASSERT(!m_timesteps.empty(), "Timesteps are empty, cannot add noise to initial latent");

    const int64_t latent_timestep = m_timesteps.front();
    const float sigma = m_sigmas[_index_for_timestep(latent_timestep)];

    ov::Tensor rand_tensor = generator->randn_tensor(init_latent.get_shape());

    float * init_latent_data = init_latent.data<float>();
    const float * rand_tensor_data = rand_tensor.data<float>();

    for (size_t i = 0; i < init_latent.get_size(); ++i) {
        init_latent_data[i] = init_latent_data[i] + sigma * rand_tensor_data[i];
    }
}
289321

290322
} // namespace genai

src/cpp/src/image_generation/schedulers/euler_discrete.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,12 @@ class EulerDiscreteScheduler : public IScheduler {
5555
Config m_config;
5656

5757
std::vector<float> m_alphas_cumprod, m_sigmas;
58-
std::vector<int64_t> m_timesteps;
58+
std::vector<int64_t> m_timesteps, m_schedule_timesteps;
5959
size_t m_num_inference_steps;
6060

61-
size_t m_step_index;
61+
int m_step_index, m_begin_index;
62+
63+
size_t _index_for_timestep(int64_t timestep) const;
6264
};
6365

6466
} // namespace genai

src/cpp/src/image_generation/stable_diffusion_3_pipeline.hpp

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,11 @@ class StableDiffusion3Pipeline : public DiffusionPipeline {
261261
ImageGenerationConfig generation_config = m_generation_config;
262262
generation_config.update_generation_config(properties);
263263

264+
if (!initial_image) {
265+
// in case of typical text to image generation, we need to ignore 'strength'
266+
generation_config.strength = 1.0f;
267+
}
268+
264269
const auto& transformer_config = m_transformer->get_config();
265270
const size_t batch_size_multiplier = do_classifier_free_guidance(generation_config.guidance_scale)
266271
? 2
@@ -558,7 +563,7 @@ class StableDiffusion3Pipeline : public DiffusionPipeline {
558563
// 6. Denoising loop
559564
ov::Tensor noisy_residual_tensor(ov::element::f32, {});
560565

561-
for (size_t inference_step = 0; inference_step < generation_config.num_inference_steps; ++inference_step) {
566+
for (size_t inference_step = 0; inference_step < timesteps.size(); ++inference_step) {
562567
// concat the same latent twice along a batch dimension in case of CFG
563568
if (batch_size_multiplier > 1) {
564569
batch_copy(latent, latent_cfg, 0, 0, generation_config.num_images_per_prompt);
@@ -650,16 +655,14 @@ class StableDiffusion3Pipeline : public DiffusionPipeline {
650655
OPENVINO_ASSERT(is_classifier_free_guidance || generation_config.negative_prompt_3 == std::nullopt,
651656
"Negative prompt 3 is not used when guidance scale < 1.0");
652657

653-
if (m_pipeline_type == PipelineType::IMAGE_2_IMAGE) {
654-
if (initial_image) {
655-
ov::Shape initial_image_shape = initial_image.get_shape();
656-
size_t height = initial_image_shape[1], width = initial_image_shape[2];
658+
if (m_pipeline_type == PipelineType::IMAGE_2_IMAGE && initial_image) {
659+
ov::Shape initial_image_shape = initial_image.get_shape();
660+
size_t height = initial_image_shape[1], width = initial_image_shape[2];
657661

658-
OPENVINO_ASSERT(generation_config.height == height,
659-
"Height for initial (", height, ") and generated (", generation_config.height,") images must be the same");
660-
OPENVINO_ASSERT(generation_config.width == width,
661-
"Width for initial (", width, ") and generated (", generation_config.width,") images must be the same");
662-
}
662+
OPENVINO_ASSERT(generation_config.height == height,
663+
"Height for initial (", height, ") and generated (", generation_config.height,") images must be the same");
664+
OPENVINO_ASSERT(generation_config.width == width,
665+
"Width for initial (", width, ") and generated (", generation_config.width,") images must be the same");
663666

664667
OPENVINO_ASSERT(generation_config.strength >= 0.0f && generation_config.strength <= 1.0f,
665668
"'Strength' generation parameter must be withion [0, 1] range");

0 commit comments

Comments
 (0)