diff --git a/deepsegment/train.py b/deepsegment/train.py index 0d01ef0..d8b1d86 100644 --- a/deepsegment/train.py +++ b/deepsegment/train.py @@ -106,7 +106,7 @@ def generate_data(lines, max_sents_per_example=6, n_examples=1000): return x, y -def train(x, y, vx, vy, epochs, batch_size, save_folder, glove_path): +def train(x, y, vx, vy, epochs, batch_size, save_folder, glove_path,embedding_dim): """ Trains a deepsegment model. @@ -126,7 +126,9 @@ def train(x, y, vx, vy, epochs, batch_size, save_folder, glove_path): save_folder (str): path for the directory where checkpoints should be saved. - glove_path (str): path to 100d word vectors. + glove_path (str): path to word vectors. + + embedding_dim (int): dimension of the word vectors. (50, 100, 200, 300) """ @@ -141,7 +143,7 @@ def train(x, y, vx, vy, epochs, batch_size, save_folder, glove_path): checkpoint = ModelCheckpoint(checkpoint_path, verbose=1, save_best_only=True, mode='max', monitor='f1') earlystop = EarlyStopping(patience=3, monitor='f1', mode='max') - model = seqtag_keras.Sequence(embeddings=embeddings) + model = seqtag_keras.Sequence(embeddings=embeddings,word_embedding_dim=embedding_dim,word_lstm_size=embedding_dim) model.fit(x, y, x_valid=vx, y_valid=vy, epochs=epochs, batch_size=batch_size, callbacks=[checkpoint, earlystop]) @@ -154,7 +156,7 @@ def train(x, y, vx, vy, epochs, batch_size, save_folder, glove_path): 'italian': 'it' } -def finetune(lang_code, x, y, vx, vy, name=None, epochs=5, batch_size=16, lr=0.0001): +def finetune(lang_code, x, y, vx, vy, name=None, epochs=5, batch_size=16, lr=0.0001,embedding_dim=100): """ Finetunes an existing deepsegment model. 
@@ -197,9 +199,9 @@ def finetune(lang_code, x, y, vx, vy, name=None, epochs=5, batch_size=16, lr=0.0 model = BiLSTMCRF(char_vocab_size=p.char_vocab_size, word_vocab_size=p.word_vocab_size, num_labels=p.label_size, - word_embedding_dim=100, + word_embedding_dim=embedding_dim, char_embedding_dim=25, - word_lstm_size=100, + word_lstm_size=embedding_dim, char_lstm_size=25, fc_dim=100, dropout=0.2,