diff --git a/src/lib.rs b/src/lib.rs index a3d0435c..19ded8c4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -256,7 +256,8 @@ //!
//! 4. Dialogue Model //! -//! Conversation model based on Microsoft's [DialoGPT](https://github.com/microsoft/DialoGPT). +//! Conversation model based on Microsoft's [DialoGPT](https://github.com/microsoft/DialoGPT) or +//! [GODEL](https://github.com/microsoft/GODEL). //! This pipeline allows the generation of single or multi-turn conversations between a human and a model. //! The DialoGPT's page states that //! > The human evaluation results indicate that the response generated from DialoGPT is comparable to human response quality diff --git a/src/pipelines/conversation.rs b/src/pipelines/conversation.rs index ed08df7a..8c2bd76e 100644 --- a/src/pipelines/conversation.rs +++ b/src/pipelines/conversation.rs @@ -12,7 +12,8 @@ // limitations under the License. //! # Multi-turn dialogue -//! Conversation model based on Microsoft's [DialoGPT](https://github.com/microsoft/DialoGPT). +//! Conversation model based on Microsoft's [DialoGPT](https://github.com/microsoft/DialoGPT) or +//! [GODEL](https://github.com/microsoft/GODEL). //! This pipeline allows the generation of single or multi-turn conversations between a human and a model. //! The DialoGPT's page states that //! 
> The human evaluation results indicate that the response generated from DialoGPT is comparable to human response quality @@ -59,6 +60,7 @@ use crate::pipelines::common::{ModelType, TokenizerOption}; use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator; use crate::pipelines::generation_utils::{GenerateConfig, LanguageGenerator}; use crate::resources::ResourceProvider; +use crate::t5::T5Generator; use std::collections::HashMap; use tch::{Device, Kind, Tensor}; use uuid::Uuid; @@ -695,14 +697,16 @@ impl Default for ConversationManager { pub enum ConversationOption { /// Conversation based on GPT2 model GPT2(GPT2Generator), + T5(T5Generator), } impl ConversationOption { pub fn new(config: ConversationConfig) -> Result { match config.model_type { ModelType::GPT2 => Ok(ConversationOption::GPT2(GPT2Generator::new(config.into())?)), + ModelType::T5 => Ok(ConversationOption::T5(T5Generator::new(config.into())?)), _ => Err(RustBertError::InvalidConfigurationError( - "GPT2 is currently the only supported model for conversation generation" + "GPT-2 and T5 are currently the only supported model for conversation generation" .to_string(), )), } @@ -717,8 +721,12 @@ impl ConversationOption { config.into(), tokenizer, )?)), + ModelType::T5 => Ok(ConversationOption::T5(T5Generator::new_with_tokenizer( + config.into(), + tokenizer, + )?)), _ => Err(RustBertError::InvalidConfigurationError( - "GPT2 is currently the only supported model for conversation generation" + "GPT-2 and T5 are currently the only supported model for conversation generation" .to_string(), )), } @@ -729,6 +737,7 @@ impl ConversationOption { Self::GPT2(model_ref) => { Ok(*model_ref.get_eos_ids().as_ref().unwrap().first().unwrap()) } + Self::T5(model_ref) => Ok(*model_ref.get_eos_ids().as_ref().unwrap().first().unwrap()), } } @@ -736,6 +745,7 @@ impl ConversationOption { pub fn get_tokenizer(&self) -> &TokenizerOption { match self { Self::GPT2(model_ref) => 
model_ref._get_tokenizer(), + Self::T5(model_ref) => model_ref._get_tokenizer(), } } @@ -743,6 +753,7 @@ impl ConversationOption { pub fn get_tokenizer_mut(&mut self) -> &TokenizerOption { match self { Self::GPT2(model_ref) => model_ref._get_tokenizer_mut(), + Self::T5(model_ref) => model_ref._get_tokenizer_mut(), } } @@ -750,6 +761,7 @@ impl ConversationOption { pub fn model_type(&self) -> ModelType { match *self { Self::GPT2(_) => ModelType::GPT2, + Self::T5(_) => ModelType::T5, } } @@ -765,6 +777,19 @@ impl ConversationOption { .into_iter() .map(|output| output.indices) .collect(), + Self::T5(ref model) => model + .generate_from_ids_and_past(input_ids, attention_mask, None) + .into_iter() + .map(|output| output.indices) + .collect(), + } + } + + /// Interface method indicating whether the underlying model is an encoder-decoder (true for T5-based models such as GODEL) or a decoder-only model (GPT-2) + fn is_encoder_decoder(&self) -> bool { + match *self { + Self::GPT2(ref generator) => generator.is_encoder_decoder(), + Self::T5(ref generator) => generator.is_encoder_decoder(), } } } @@ -915,7 +940,11 @@ impl ConversationModel { .zip(active_uuid.into_iter()) .zip(removed_padding_quantities.into_iter()) { - let generated_response = &generated_sequence[input_length - removed_padding.0..]; + let generated_response = if self.model.is_encoder_decoder() { + generated_sequence.as_slice() + } else { + &generated_sequence[input_length - removed_padding.0..] 
+ }; conversation .generated_responses .push( @@ -1023,9 +1052,14 @@ impl ConversationModel { .get(input_idx as i64) .slice(0, 0, (max_len - input.len()) as i64, 1) .fill_(0); - let mut padded_input = vec![pad_token; max_len - input.len()]; - padded_input.extend(input); - padded_input + let padding = vec![pad_token; max_len - input.len()]; + if self.model.is_encoder_decoder() { + // right padding assumed for encoder-decoders + [input, &padding].concat() + } else { + // left padding assumed for decoders + [&padding, input].concat() + } }) .map(|tokens| Tensor::from_slice(&tokens).to(self.device)) .collect::>(); diff --git a/src/t5/t5_model.rs b/src/t5/t5_model.rs index 4eba98ee..12fd6996 100644 --- a/src/t5/t5_model.rs +++ b/src/t5/t5_model.rs @@ -61,6 +61,16 @@ impl T5ModelResources { "sentence-t5-base/model", "https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/rust_model.ot", ); + /// Shared under MIT license by the Microsoft team at . Modified with conversion to C-array format. + pub const GODEL_V1_1_BASE: (&'static str, &'static str) = ( + "godel-v1-1-base/model", + "https://huggingface.co/microsoft/GODEL-v1_1-base-seq2seq/resolve/main/rust_model.ot", + ); + /// Shared under MIT license by the Microsoft team at . Modified with conversion to C-array format. + pub const GODEL_V1_1_LARGE: (&'static str, &'static str) = ( + "godel-v1-1-large/model", + "https://huggingface.co/microsoft/GODEL-v1_1-large-seq2seq/resolve/main/rust_model.ot", + ); } impl T5ConfigResources { @@ -79,6 +89,16 @@ impl T5ConfigResources { "sentence-t5-base/config", "https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/config.json", ); + /// Shared under MIT license by the Microsoft team at . Modified with conversion to C-array format. 
+ pub const GODEL_V1_1_BASE: (&'static str, &'static str) = ( + "godel-v1-1-base/config", + "https://huggingface.co/microsoft/GODEL-v1_1-base-seq2seq/resolve/main/config.json", + ); + /// Shared under MIT license by the Microsoft team at . Modified with conversion to C-array format. + pub const GODEL_V1_1_LARGE: (&'static str, &'static str) = ( + "godel-v1-1-large/config", + "https://huggingface.co/microsoft/GODEL-v1_1-large-seq2seq/resolve/main/config.json", + ); } impl T5VocabResources { @@ -92,11 +112,31 @@ impl T5VocabResources { "t5-base/spiece", "https://huggingface.co/t5-base/resolve/main/spiece.model", ); + /// Shared under Apache 2.0 license by the Google team at . + pub const T5_LARGE: (&'static str, &'static str) = ( + "t5-large/spiece", + "https://huggingface.co/t5-large/resolve/main/spiece.model", + ); /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. pub const SENTENCE_T5_BASE: (&'static str, &'static str) = ( "sentence-t5-base/spiece", "https://huggingface.co/sentence-transformers/sentence-t5-base/resolve/main/spiece.model", ); + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const SENTENCE_T5_LARGE: (&'static str, &'static str) = ( + "sentence-t5-large/spiece", + "https://huggingface.co/sentence-transformers/sentence-t5-large/resolve/main/spiece.model", + ); + /// Shared under Apache 2.0 license by the Google team at . + pub const GODEL_V1_1_BASE: (&'static str, &'static str) = ( + "godel-v1-1-base/spiece", + "https://huggingface.co/t5-base/resolve/main/spiece.model", + ); + /// Shared under Apache 2.0 license by the Google team at . 
+ pub const GODEL_V1_1_LARGE: (&'static str, &'static str) = ( + "godel-v1-1-large/spiece", + "https://huggingface.co/t5-large/resolve/main/spiece.model", + ); } const T5LANGUAGES: [Language; 3] = [Language::English, Language::French, Language::German]; diff --git a/tests/gpt2.rs b/tests/gpt2.rs index 6c31f699..299cd117 100644 --- a/tests/gpt2.rs +++ b/tests/gpt2.rs @@ -723,6 +723,7 @@ fn gpt2_beam_search_token_scores() -> anyhow::Result<()> { fn dialogpt_single_multi_turn_conversation() -> anyhow::Result<()> { // Set-up conversation model let conversation_config = ConversationConfig { + model_type: ModelType::GPT2, do_sample: false, device: Device::Cpu, ..Default::default() @@ -760,6 +761,7 @@ fn dialogpt_single_multi_turn_conversation() -> anyhow::Result<()> { fn dialogpt_multiple_multi_turn_conversation() -> anyhow::Result<()> { // Set-up conversation model let conversation_config = ConversationConfig { + model_type: ModelType::GPT2, do_sample: false, device: Device::Cpu, ..Default::default() @@ -802,6 +804,7 @@ fn dialogpt_multiple_multi_turn_conversation() -> anyhow::Result<()> { fn dialogpt_multiple_multi_turn_conversation_with_truncation() -> anyhow::Result<()> { // Set-up conversation model let conversation_config = ConversationConfig { + model_type: ModelType::GPT2, max_length: Some(36), min_length_for_response: 24, do_sample: false, @@ -851,6 +854,7 @@ fn dialogpt_multiple_multi_turn_conversation_with_truncation() -> anyhow::Result fn dialogpt_multiple_multi_turn_conversation_with_conversation_deletion() -> anyhow::Result<()> { // Set-up conversation model let conversation_config = ConversationConfig { + model_type: ModelType::GPT2, do_sample: false, device: Device::Cpu, ..Default::default() diff --git a/tests/t5.rs b/tests/t5.rs index 1cebe74d..cd33ae5f 100644 --- a/tests/t5.rs +++ b/tests/t5.rs @@ -1,4 +1,7 @@ use rust_bert::pipelines::common::ModelType; +use rust_bert::pipelines::conversation::{ + ConversationConfig, ConversationManager, 
ConversationModel, +}; use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel}; use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel}; use rust_bert::resources::RemoteResource; @@ -111,3 +114,230 @@ about exoplanets like K2-18b."]; Ok(()) } + +#[test] +#[cfg_attr(not(feature = "all-tests"), ignore)] +fn godel_single_multi_turn_conversation() -> anyhow::Result<()> { + // Set-up conversation model + let conversation_config = ConversationConfig { + model_type: ModelType::T5, + do_sample: false, + device: Device::Cpu, + model_resource: Box::new(RemoteResource::from_pretrained( + T5ModelResources::GODEL_V1_1_LARGE, + )), + config_resource: Box::new(RemoteResource::from_pretrained( + T5ConfigResources::GODEL_V1_1_LARGE, + )), + vocab_resource: Box::new(RemoteResource::from_pretrained( + T5VocabResources::GODEL_V1_1_LARGE, + )), + merges_resource: None, + ..Default::default() + }; + let conversation_model = ConversationModel::new(conversation_config)?; + + // Set-up conversation manager and add a conversation + let mut conversation_manager = ConversationManager::new(); + let conversation_id = + conversation_manager.create("Going to the movies tonight - any suggestions?"); + + // Turn 1 + let output = conversation_model.generate_responses(&mut conversation_manager); + assert_eq!(output.len(), 1); + assert_eq!(output.get(&conversation_id).unwrap(), &" I'd recommend The Last Airbender. It's a great comedy and a great movie if you like comedy."); + + // Turn 2 + let _ = conversation_manager + .get(&conversation_id) + .unwrap() + .add_user_input("Is it an action movie?"); + let output = conversation_model.generate_responses(&mut conversation_manager); + assert_eq!(output.len(), 1); + assert_eq!( + output.get(&conversation_id).unwrap(), + &" I'm not sure, but I've heard it's a great comedy." 
+ ); + + // Turn 3 (no new user input) + let output = conversation_model.generate_responses(&mut conversation_manager); + assert_eq!(output.len(), 0); + + Ok(()) +} + +#[test] +#[cfg_attr(not(feature = "all-tests"), ignore)] +fn godel_multiple_multi_turn_conversation() -> anyhow::Result<()> { + // Set-up conversation model + let conversation_config = ConversationConfig { + model_type: ModelType::T5, + do_sample: false, + device: Device::Cpu, + model_resource: Box::new(RemoteResource::from_pretrained( + T5ModelResources::GODEL_V1_1_LARGE, + )), + config_resource: Box::new(RemoteResource::from_pretrained( + T5ConfigResources::GODEL_V1_1_LARGE, + )), + vocab_resource: Box::new(RemoteResource::from_pretrained( + T5VocabResources::GODEL_V1_1_LARGE, + )), + merges_resource: None, + ..Default::default() + }; + let conversation_model = ConversationModel::new(conversation_config)?; + + // Set-up conversation manager and add a conversation + let mut conversation_manager = ConversationManager::new(); + let conversation_1_id = + conversation_manager.create("Going to the movies tonight - any suggestions?"); + let conversation_2_id = conversation_manager.create("What's the last book you have read?"); + + // Turn 1 + let output = conversation_model.generate_responses(&mut conversation_manager); + assert_eq!(output.len(), 2); + assert_eq!(output.get(&conversation_1_id).unwrap(), &" I'd recommend The Last Airbender. It's a great comedy and a great movie if you like comedy."); + assert_eq!( + output.get(&conversation_2_id).unwrap(), + &" I read The Last of Us. It was a great book." + ); + + // Turn 2 + let _ = conversation_manager + .get(&conversation_1_id) + .unwrap() + .add_user_input("Is it an action movie?"); + let output = conversation_model.generate_responses(&mut conversation_manager); + assert_eq!(output.len(), 1); + assert_eq!( + output.get(&conversation_1_id).unwrap(), + &" I'm not sure, but I've heard it's a great comedy." 
+ ); + + // Turn 3 (no new user input) + let output = conversation_model.generate_responses(&mut conversation_manager); + assert_eq!(output.len(), 0); + + Ok(()) +} + +#[test] +#[cfg_attr(not(feature = "all-tests"), ignore)] +fn godel_multiple_multi_turn_conversation_with_truncation() -> anyhow::Result<()> { + // Set-up conversation model + let conversation_config = ConversationConfig { + model_type: ModelType::T5, + max_length: Some(36), + min_length_for_response: 24, + do_sample: false, + device: Device::Cpu, + model_resource: Box::new(RemoteResource::from_pretrained( + T5ModelResources::GODEL_V1_1_LARGE, + )), + config_resource: Box::new(RemoteResource::from_pretrained( + T5ConfigResources::GODEL_V1_1_LARGE, + )), + vocab_resource: Box::new(RemoteResource::from_pretrained( + T5VocabResources::GODEL_V1_1_LARGE, + )), + merges_resource: None, + ..Default::default() + }; + let conversation_model = ConversationModel::new(conversation_config)?; + + // Set-up conversation manager and add a conversation + let mut conversation_manager = ConversationManager::new(); + let conversation_1_id = + conversation_manager.create("Going to the movies tonight - any suggestions?"); + let conversation_2_id = conversation_manager.create("Hello how are you today?"); + + // Turn 1 + let output = conversation_model.generate_responses(&mut conversation_manager); + assert_eq!(output.len(), 2); + assert_eq!(output.get(&conversation_1_id).unwrap(), &" I'd recommend The Last Airbender. 
It's a great comedy and a great movie if you like comedy."); + assert_eq!( + output.get(&conversation_2_id).unwrap(), + &" i am a little tired from work" + ); + + // Turn 2 + let _ = conversation_manager + .get(&conversation_1_id) + .unwrap() + .add_user_input("Is it an action movie?"); + let _ = conversation_manager + .get(&conversation_2_id) + .unwrap() + .add_user_input("Fine."); + + let output = conversation_model.generate_responses(&mut conversation_manager); + assert_eq!(output.len(), 2); + assert_eq!( + output.get(&conversation_1_id).unwrap(), + &" No, it's a comedy." + ); + + // Turn 3 (no new user input) + let output = conversation_model.generate_responses(&mut conversation_manager); + assert_eq!(output.len(), 0); + + Ok(()) +} + +#[test] +#[cfg_attr(not(feature = "all-tests"), ignore)] +fn godel_multiple_multi_turn_conversation_with_conversation_deletion() -> anyhow::Result<()> { + // Set-up conversation model + let conversation_config = ConversationConfig { + model_type: ModelType::T5, + do_sample: false, + device: Device::Cpu, + model_resource: Box::new(RemoteResource::from_pretrained( + T5ModelResources::GODEL_V1_1_LARGE, + )), + config_resource: Box::new(RemoteResource::from_pretrained( + T5ConfigResources::GODEL_V1_1_LARGE, + )), + vocab_resource: Box::new(RemoteResource::from_pretrained( + T5VocabResources::GODEL_V1_1_LARGE, + )), + merges_resource: None, + ..Default::default() + }; + let conversation_model = ConversationModel::new(conversation_config)?; + + // Set-up conversation manager and add a conversation + let mut conversation_manager = ConversationManager::new(); + let conversation_1_id = + conversation_manager.create("Going to the movies tonight - any suggestions?"); + let conversation_2_id = conversation_manager.create("What's the last book you have read?"); + + // Turn 1 + let output = conversation_model.generate_responses(&mut conversation_manager); + assert_eq!(output.len(), 2); + assert_eq!(output.get(&conversation_1_id).unwrap(), &" 
I'd recommend The Last Airbender. It's a great comedy and a great movie if you like comedy."); + assert_eq!( + output.get(&conversation_2_id).unwrap(), + &" I read The Last of Us. It was a great book." + ); + + // Turn 2 + let _ = conversation_manager.remove(&conversation_1_id); + let _ = conversation_manager + .get(&conversation_2_id) + .unwrap() + .add_user_input("Why do you recommend it?"); + let output = conversation_model.generate_responses(&mut conversation_manager); + assert_eq!(output.len(), 1); + assert_eq!( + output.get(&conversation_2_id).unwrap(), + &" I've read it, but I'm not sure if I'd like it again. I'm not a huge fan of the genre." + ); + + // Turn 3 (no new user input) + let output = conversation_model.generate_responses(&mut conversation_manager); + assert_eq!(output.len(), 0); + + Ok(()) +}