2 changes: 1 addition & 1 deletion README.md
@@ -1,2 +1,2 @@
# SMIPG-NLPCC2017
Emotional Conversation Generation Task in NLPCC2017
A Seq2Seq model with GRU + Attention + Beam Search (+ sampling)
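The README line above names beam search as part of the decoder. As background only (not code from this repository), a minimal sketch of one beam-search expansion step over next-token log-probabilities:

```python
import heapq
import math

def expand_beams(beams, next_log_probs, beam_size):
    # beams: list of (score, token_list); next_log_probs: token -> log-probability.
    candidates = [(score + lp, tokens + [tok])
                  for score, tokens in beams
                  for tok, lp in next_log_probs.items()]
    # Keep only the beam_size highest-scoring hypotheses.
    return heapq.nlargest(beam_size, candidates, key=lambda c: c[0])

beams = [(0.0, [])]
step = {"我": math.log(0.6), "你": math.log(0.3), "他": math.log(0.1)}
print(expand_beams(beams, step, beam_size=2))  # keeps ["我"] and ["你"]
```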
35 changes: 35 additions & 0 deletions legacy_models/main.py
@@ -0,0 +1,35 @@
# -*- coding:utf8 -*-
import tensorflow as tf

from legacy_models.model.lstm.model import Config, Model

# Split the training data
# splitData()


config = Config()
#config.is_pretrained = False
model = Model(config)
sess = tf.Session()
model.variables_init(sess)
model.restore(sess, 24000)
model.train(sess)
model.loss_tracker.savefig(config.save_path)

response = model.generate(sess, "我 对此 感到 非常 开心")  # word-segmented input: "I feel very happy about this"
print(response)

'''
vocab_to_idx, idx_to_vocab, vocab_embed = loadPretrainedVector(30, 50, "./dict/vector/wiki.zh.text200.vector")

for k in vocab_to_idx.keys():
if u"他"==k:
print(k, vocab_to_idx[k])

'''

#for k in idx_to_vocab.keys():
# print(k, idx_to_vocab[k])

#for i in vocab_embed:
# print(i)
5 changes: 5 additions & 0 deletions tf_chatbot/configs/config.py
@@ -11,14 +11,19 @@
tf.app.flags.DEFINE_float('learning_rate_decay_factor', 0.99, 'Learning rate decays by this much.')
tf.app.flags.DEFINE_float('max_gradient_norm', 5.0, 'Clip gradients to this norm')
tf.app.flags.DEFINE_integer('batch_size', 128, 'Batch size to use during training')
tf.app.flags.DEFINE_integer('epoch_size', 20, 'Size of epoch')

tf.app.flags.DEFINE_integer('vocab_size', 20000, 'Dialog vocabulary size')
tf.app.flags.DEFINE_integer('size', 128, 'size of each model layer')
tf.app.flags.DEFINE_integer('num_layers', 1, 'Numbers of layers in the model')
tf.app.flags.DEFINE_integer('beam_search_size', 3, 'Size of beam search op')

tf.app.flags.DEFINE_integer('max_train_data_size', 0, 'Limit on the size of training data (0: no limit)')
tf.app.flags.DEFINE_integer('steps_per_checkpoint', 100, 'How many training steps to do per checkpoint')

tf.app.flags.DEFINE_boolean('use_sample', True, 'use sample while generating')
tf.app.flags.DEFINE_boolean('use_beam_search', True, 'use beam search while generating')

FLAGS = tf.app.flags.FLAGS

BUCKETS = [(5,10), (10, 15), (20, 25), (40, 50)]
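The BUCKETS list here pairs a maximum source length with a maximum target length; read_data in data_utils.py (later in this diff) drops each dialog pair into the first bucket it fits. A minimal sketch of that assignment using the same bucket sizes; the helper name assign_bucket is hypothetical, not part of the repository:

```python
# Sketch of the bucket assignment performed in read_data: a (source, target)
# pair goes into the first bucket whose limits it fits under, otherwise it
# is dropped.
BUCKETS = [(5, 10), (10, 15), (20, 25), (40, 50)]

def assign_bucket(source_ids, target_ids, buckets=BUCKETS):
    for bucket_id, (source_size, target_size) in enumerate(buckets):
        if len(source_ids) < source_size and len(target_ids) < target_size:
            return bucket_id
    return None  # too long for every bucket

# A 4-token question with a 12-token answer lands in bucket 1, i.e. (10, 15).
print(assign_bucket([1, 2, 3, 4], list(range(12))))  # -> 1
```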
415 changes: 0 additions & 415 deletions tf_chatbot/lib/basic/advanced_seq2seq.py

This file was deleted.

11 changes: 8 additions & 3 deletions tf_chatbot/lib/chat.py
@@ -7,22 +7,27 @@
from tf_chatbot.lib import data_utils
from tf_chatbot.lib.seq2seq_model_utils import create_model, get_predicted_sentence


def chat():
with tf.Session() as sess:

model = create_model(sess, forward_only=True)
model.batch_size = 1

vocab_path = os.path.join(FLAGS.data_dir, "vocab%d.in" % FLAGS.vocab_size)
vocab_path = os.path.join(
FLAGS.data_dir,
"vocab%d.in" %
FLAGS.vocab_size)
vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path)

sys.stdout.write("> ")
sys.stdout.flush()
sentence = sys.stdin.readline()

while sentence:
predicted_sentence = get_predicted_sentence(sentence, vocab, rev_vocab, model, sess)
predicted_sentence = get_predicted_sentence(
sentence, vocab, rev_vocab, model, sess)
print(predicted_sentence)
print("> ")
sys.stdout.flush()
sentence = sys.stdin.readline()
sentence = sys.stdin.readline()
79 changes: 54 additions & 25 deletions tf_chatbot/lib/data_utils.py
@@ -5,6 +5,7 @@
import os
import re
import sys
import platform
import json

from tensorflow.python.platform import gfile
@@ -28,51 +29,62 @@

_ENCODING = "utf8"


def get_dialog_train_set_path(path):
return os.path.join(path, 'train_data')


def get_dialog_dev_set_path(path):
return os.path.join(path, 'dev_data')


def basic_tokenizer(sentence):
words = []
for space_separated_fragment in sentence.strip().split():
words.extend(re.split(_WORD_SPLIT, space_separated_fragment))
return [w.lower() for w in words if w]


def create_vocabulary_bak(vocabulary_path, data_path, max_vocabulary_size,
tokenizer=None, normalize_digits=True):
tokenizer=None, normalize_digits=True):
if not gfile.Exists(vocabulary_path):
print("Creating vocabulary %s from data %s" % (vocabulary_path, data_path))
print(
"Creating vocabulary %s from data %s" %
(vocabulary_path, data_path))
vocab = {}
with gfile.GFile(data_path, mode='r') as f:
counter = 0
for line in f:
counter += 1
if counter % 100000 == 0:
print(" processing line %d" % counter)
tokens = tokenizer(line) if tokenizer else basic_tokenizer(line)
tokens = tokenizer(
line) if tokenizer else basic_tokenizer(line)
for w in tokens:
word = re.sub(_DIGIT_RE, "0", w) if normalize_digits else w
if word in vocab:
vocab[word] += 1
else:
vocab[word] = 1
vocab_list = _START_VOCAB + sorted(vocab, key=vocab.get, reverse=True)
vocab_list = _START_VOCAB + \
sorted(vocab, key=vocab.get, reverse=True)
if len(vocab_list) > max_vocabulary_size:
vocab_list = vocab_list[:max_vocabulary_size]
with gfile.GFile(vocabulary_path, mode='w') as vocab_file:
for w in vocab_list:
vocab_file.write(w + '\n')


def create_vocabulary(vocabulary_path, data_path, max_vocabulary_size,
tokenizer=None, normalize_digits=True):
tokenizer=None, normalize_digits=True):
if not gfile.Exists(vocabulary_path):
print("Creating vocabulary %s from data %s" % (vocabulary_path, data_path))
print(
"Creating vocabulary %s from data %s" %
(vocabulary_path, data_path))
vocab = {}
data = json.load(open(data_path, encoding=_ENCODING))
data = json.load(open(data_path), encoding=_ENCODING)
counter = 0
for ((q,qe),(a,ae)) in data:
for ((q, qe), (a, ae)) in data:
counter += 1
if counter % 50000 == 0:
print(" Create_vocabulary: processing line %d" % counter)
@@ -110,23 +122,32 @@ def initialize_vocabulary(vocabulary_path):
rev_vocab.extend(f.readlines())

rev_vocab = [line.strip() for line in rev_vocab]
vocab = dict([(x,y) for (y,x) in enumerate(rev_vocab)]) # {'word':index}
vocab = dict([(x, y)
for (y, x) in enumerate(rev_vocab)]) # {'word':index}
return vocab, rev_vocab
else:
raise ValueError("Vocabulary file %s not found" % vocabulary_path)


def sentence_to_token_ids(sentence, vocabulary,
tokenizer=None, normalize_digits=True):
if tokenizer:
words = tokenizer(sentence)
else:
words = basic_tokenizer(sentence)
if not normalize_digits:
return [vocabulary.get(w, UNK_ID) for w in words]
return [vocabulary.get(re.sub(_DIGIT_RE, "0", w), UNK_ID) for w in words]
if platform.system() == "Windows":
if not normalize_digits:
return [vocabulary.get(w, UNK_ID) for w in words]
return [vocabulary.get(re.sub(_DIGIT_RE, "0", w), UNK_ID) for w in words]

else:
if not normalize_digits:
return [vocabulary.get(w.encode('utf8'), UNK_ID) for w in words]
return [vocabulary.get(re.sub(_DIGIT_RE, "0", w.encode('utf8')), UNK_ID) for w in words]


def data_to_token_ids_bak(data_path, target_path, vocabulary_path,
tokenizer=None, normalize_digits=True):
tokenizer=None, normalize_digits=True):
if not gfile.Exists(target_path):
print("Tokenizing data in %s" % data_path)
vocab, _ = initialize_vocabulary(vocabulary_path)
@@ -139,31 +160,38 @@ def data_to_token_ids_bak(data_path, target_path, vocabulary_path,
print(" tokenizing line %d" % counter)
token_ids = sentence_to_token_ids(line, vocab, tokenizer,
normalize_digits)
tokens_file.write(" ".join([str(tok) for tok in token_ids]) + '\n')
tokens_file.write(
" ".join([str(tok) for tok in token_ids]) + '\n')


def data_to_token_ids(data_path, target_path, vocabulary_path,
tokenizer=None, normalize_digits=True):
tokenizer=None, normalize_digits=True):
if not gfile.Exists(target_path):
print("Tokenizing data in %s" % data_path)
vocab, _ = initialize_vocabulary(vocabulary_path)
with gfile.GFile(target_path, mode='w') as tokens_file:
data = json.load(open(data_path, encoding=_ENCODING))
data = json.load(open(data_path), encoding=_ENCODING)
counter = 0
for ((q,qe),(a,ae)) in data:
for ((q, qe), (a, ae)) in data:
counter += 1
if counter % 50000 == 0:
print(" Data_to_token_ids: tokenizing line %d" % counter)
token_ids_q = sentence_to_token_ids(q, vocab, tokenizer, normalize_digits)
tokens_file.write(" ".join([str(tok) for tok in token_ids_q]) + '\n')
token_ids_a = sentence_to_token_ids(a, vocab, tokenizer, normalize_digits)
tokens_file.write(" ".join([str(tok) for tok in token_ids_a]) + '\n')
token_ids_q = sentence_to_token_ids(
q, vocab, tokenizer, normalize_digits)
tokens_file.write(" ".join([str(tok)
for tok in token_ids_q]) + '\n')
token_ids_a = sentence_to_token_ids(
a, vocab, tokenizer, normalize_digits)
tokens_file.write(" ".join([str(tok)
for tok in token_ids_a]) + '\n')


def prepare_dialog_data(data_dir, vocabulary_size):
train_path = get_dialog_train_set_path(data_dir)
dev_path = get_dialog_dev_set_path(data_dir)

vocab_path = os.path.join(data_dir, "vocab%d.in" % vocabulary_size)
create_vocabulary(vocab_path, train_path+".json", vocabulary_size)
create_vocabulary(vocab_path, train_path + ".json", vocabulary_size)

train_ids_path = train_path + (".ids%d.in" % vocabulary_size)
data_to_token_ids(train_path + ".json", train_ids_path, vocab_path)
@@ -173,6 +201,7 @@ def prepare_dialog_data(data_dir, vocabulary_size):

return (train_ids_path, dev_ids_path, vocab_path)


def read_data(tokenized_dialog_path, max_size=None):

data_set = [[] for _ in BUCKETS]
@@ -184,16 +213,16 @@ def read_data(tokenized_dialog_path, max_size=None):
counter += 1
if counter % 100000 == 0:
print(" reading data line %d" % counter)
#sys.stdout.flush()
# sys.stdout.flush()

source_ids = [int(x) for x in source.split()]
target_ids = [int(x) for x in target.split()]
target_ids.append(EOS_ID)

for bucket_id, (source_size, target_size) in enumerate(BUCKETS):
if len(source_ids) < source_size and len(target_ids) < target_size:
if len(source_ids) < source_size and len(
target_ids) < target_size:
data_set[bucket_id].append([source_ids, target_ids])
break
source, target = fh.readline(), fh.readline()
return data_set

21 changes: 14 additions & 7 deletions tf_chatbot/lib/predict.py
@@ -7,27 +7,34 @@
from tf_chatbot.lib.seq2seq_model_utils import create_model, get_predicted_sentence
import json


def predict():
def _get_test_dataset():
data = json.load(open(TEST_DATASET_PATH, encoding=data_utils._ENCODING))
data = json.load(open(TEST_DATASET_PATH))
test_sentences = [q for ((q, qe), _) in data]
return test_sentences

results_filename = '_'.join(['results', str(FLAGS.num_layers), str(FLAGS.size), str(FLAGS.vocab_size)])
results_filename = '_'.join(
['results', str(FLAGS.num_layers), str(FLAGS.size), str(FLAGS.vocab_size)])
results_path = os.path.join(FLAGS.results_dir, results_filename)

with tf.Session() as sess, open(results_path, 'w') as results_fh:

model = create_model(sess, forward_only=True)
model = create_model(sess, forward_only=True, use_sample=FLAGS.use_sample)
model.batch_size = 1

vocab_path = os.path.join(FLAGS.data_dir, "vocab%d.in" % FLAGS.vocab_size)
vocab_path = os.path.join(
FLAGS.data_dir,
"vocab%d.in" %
FLAGS.vocab_size)
vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path)

test_dataset = _get_test_dataset()

for sentence in test_dataset:
predicted_sentence = get_predicted_sentence(sentence, vocab, rev_vocab, model, sess)
print(sentence, '->', predicted_sentence)
predicted_sentence = get_predicted_sentence(
sentence, vocab, rev_vocab, model, sess, use_beam_search=FLAGS.use_beam_search)
print(sentence.strip(), '->')
print(predicted_sentence)

results_fh.write(predicted_sentence + '\n')
results_fh.write(predicted_sentence + '\n')
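predict.py now forwards the new use_sample and use_beam_search flags into model creation and decoding. get_predicted_sentence itself is not shown in this diff, so the following is only a generic illustration of what a sampling switch typically changes, namely drawing the next token from the softmax instead of taking the argmax; it is not the project's decoder:

```python
import numpy as np

def next_token(logits, use_sample=False, rng=np.random):
    # Softmax over the vocabulary, numerically stabilised by subtracting the max.
    probs = np.exp(logits - np.max(logits))
    probs /= probs.sum()
    if use_sample:
        return int(rng.choice(len(probs), p=probs))  # stochastic: sample a token id
    return int(np.argmax(probs))                     # greedy: most likely token id

logits = np.array([0.1, 2.0, 0.3])
print(next_token(logits))                    # always 1
print(next_token(logits, use_sample=True))   # usually 1, occasionally 0 or 2
```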