より現実世界に即した入力シーケンスを使いたい場合は、以下のようなデータでトレーニングを行うのが1つの方法です。
a, 101110 aa, 101110 101110 aal, 101110 101110 1011101010 aalii, 101110 101110 1011101010 1010 1010 (many lines deleted) zythia, 111011101010 11101011101110 1110 10101010 1010 101110 zythum, 111011101010 11101011101110 1110 10101010 10101110 11101110 zyzomys, 111011101010 11101011101110 111011101010 111011101110 11101110 11101011101110 101010 zyzzogeton, 111011101010 11101011101110 111011101010 111011101010 111011101110 1110111010 10 1110 111011101110 111010
ここでは、文字間のスペーシング(間隔)は完璧に検出され、空白文字で表されていると仮定しています。
_________________________________________________________________ Layer (type) Output Shape Param # ================================================================= lstm_1 (LSTM) (None, 128) 67584 _________________________________________________________________ repeat_vector_1 (RepeatVecto (None, 4, 128) 0 _________________________________________________________________ lstm_2 (LSTM) (None, 4, 128) 131584 _________________________________________________________________ time_distributed_1 (TimeDist (None, 4, 27) 3483 ================================================================= Total params: 202,651 Trainable params: 202,651 Non-trainable params: 0 _________________________________________________________________ Train on 4894 samples, validate on 100 samples
10111011101110 10101110 111010111010 1110101110 juck juck 1110101010 10101110 1011101010 1011101010 bull bull 101110111010 1010 1011101010 11101011101110 pily pily 1011101010 10101110 10 101010 lues laes 1110111010 111011101110 1011101110 111010 gown gown 10101010 101110 10101110 1011101010 haul haul 1110 111011101110 1010 1011101010 toil toil
これは、hidden_size = 256 の場合です。
これは、hidden_size = 64 の場合です。
_________________________________________________________________ Layer (type) Output Shape Param # ================================================================= lstm_1 (LSTM) (None, 256) 266240 _________________________________________________________________ repeat_vector_1 (RepeatVecto (None, 4, 256) 0 _________________________________________________________________ lstm_2 (LSTM) (None, 4, 256) 525312 _________________________________________________________________ time_distributed_1 (TimeDist (None, 4, 27) 6939 ================================================================= Total params: 798,491 Trainable params: 798,491 Non-trainable params: 0 _________________________________________________________________ Train on 4894 samples, validate on 100 samples
_________________________________________________________________ Layer (type) Output Shape Param # ================================================================= lstm_1 (LSTM) (None, 64) 17408 _________________________________________________________________ repeat_vector_1 (RepeatVecto (None, 4, 64) 0 _________________________________________________________________ lstm_2 (LSTM) (None, 4, 64) 33024 _________________________________________________________________ time_distributed_1 (TimeDist (None, 4, 27) 1755 ================================================================= Total params: 52,187 Trainable params: 52,187 Non-trainable params: 0 _________________________________________________________________ Train on 4894 samples, validate on 100 samples
import numpy as np


class CharTable(object):
    """One-hot encoder/decoder over a fixed set of characters.

    The character set is de-duplicated and sorted, so the index assigned to
    each character depends only on which characters are given, not on their
    order.
    """

    def __init__(self, chars):
        """Build the character <-> index lookup tables from *chars*."""
        self.chars = sorted(set(chars))
        self.char_indices = dict((c, i) for i, c in enumerate(self.chars))
        self.indices_char = dict((i, c) for i, c in enumerate(self.chars))

    def encode(self, token, num_rows):
        """Return a ``(num_rows, len(self.chars))`` one-hot matrix for *token*.

        Row ``i`` has a single 1 in the column of ``token[i]``; rows beyond
        ``len(token)`` stay all-zero.

        Raises:
            KeyError: if *token* contains a character not in the table.
            IndexError: if *token* is longer than *num_rows*.
        """
        x = np.zeros((num_rows, len(self.chars)))
        for i, c in enumerate(token):
            # Look up the column for this character (indexing with the whole
            # dict, as the code previously did, is not a valid array index).
            x[i, self.char_indices[c]] = 1
        return x

    def decode(self, x, calc_argmax=True):
        """Turn one-hot/probability rows (or class indices) back into text.

        With ``calc_argmax=True`` *x* may be a single 1-D vector (one
        character is returned) or a 2-D matrix with one row per character.
        With ``calc_argmax=False`` *x* is an iterable of class indices.
        """
        if calc_argmax:
            # atleast_1d keeps the single-vector case iterable.
            x = np.atleast_1d(np.asarray(x).argmax(axis=-1))
        return ''.join(self.indices_char[int(v)] for v in x)


def _read_word_pairs(path, word_len, max_len_x):
    """Read ``word, morse`` lines from *path*, keeping words of *word_len*.

    Returns two parallel lists: the Morse strings right-padded with spaces to
    *max_len_x* characters, and the corresponding words.

    Raises:
        ValueError: if a non-blank line does not split into exactly two
            fields on ``", "``.
    """
    morse_inputs = []
    word_targets = []
    with open(path, 'r') as file:
        for line in file.read().splitlines():
            if not line.strip():
                continue  # tolerate blank lines (e.g. a trailing newline)
            word, morse = line.split(", ")
            if len(word) == word_len:
                morse_inputs.append(morse.ljust(max_len_x))
                word_targets.append(word)
    return morse_inputs, word_targets


def _plot_history(history):
    """Plot training/validation accuracy and loss from a Keras history dict."""
    import matplotlib.pyplot as plt

    # Older Keras records the metric as 'acc', newer releases as 'accuracy'.
    acc_key = 'acc' if 'acc' in history else 'accuracy'

    plt.figure(figsize=(16, 5))
    plt.subplot(121)
    plt.plot(history[acc_key])
    plt.plot(history['val_' + acc_key])
    plt.title('model accuracy')
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.legend(['train', 'validation'], loc='upper left')
    plt.subplot(122)
    plt.plot(history['loss'])
    plt.plot(history['val_loss'])
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train', 'validation'], loc='upper right')
    plt.show()


def main():
    """Train an LSTM encoder-decoder that maps Morse strings to 4-letter words.

    Reads ``words_morse10.txt``, holds out the last 100 shuffled samples for
    validation, trains the model, prints the validation predictions and plots
    the learning curves.
    """
    # Imported here rather than at module level so that CharTable can be
    # imported without the training stack being installed.
    from keras.models import Sequential
    from keras import layers

    word_len = 4
    # 15 symbols per letter plus one separator between letters
    # (presumably an upper bound on the encoded length -- TODO confirm).
    max_len_x = 15 * word_len + (word_len - 1)
    max_len_y = word_len
    n_val = 100

    input_list, output_list = _read_word_pairs(
        'words_morse10.txt', word_len, max_len_x)
    if len(input_list) <= n_val:
        raise ValueError(
            'need more than {} samples of length {}, found {}'.format(
                n_val, word_len, len(input_list)))

    chars_in = '10 '
    chars_out = 'abcdefghijklmnopqrstuvwxyz '
    ctable_in = CharTable(chars_in)
    ctable_out = CharTable(chars_out)

    x = np.zeros((len(input_list), max_len_x, len(chars_in)))
    y = np.zeros((len(output_list), max_len_y, len(chars_out)))
    for i, token in enumerate(input_list):
        x[i] = ctable_in.encode(token, max_len_x)
    for i, token in enumerate(output_list):
        y[i] = ctable_out.encode(token, max_len_y)

    # Shuffle inputs and targets with the same permutation before splitting.
    indices = np.arange(len(y))
    np.random.shuffle(indices)
    x = x[indices]
    y = y[indices]

    m = len(x) - n_val
    (x_train, x_val) = x[:m], x[m:]
    (y_train, y_val) = y[:m], y[m:]

    hidden_size = 128
    batch_size = 128
    nlayers = 1
    epochs = 150

    model = Sequential()
    model.add(layers.LSTM(hidden_size,
                          input_shape=(max_len_x, len(chars_in))))
    model.add(layers.RepeatVector(word_len))
    for _ in range(nlayers):
        model.add(layers.LSTM(hidden_size, return_sequences=True))
    model.add(layers.TimeDistributed(
        layers.Dense(len(chars_out), activation='softmax')))
    model.compile(loss='categorical_crossentropy', optimizer='adam',
                  metrics=['accuracy'])
    model.summary()

    hist = model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs,
                     verbose=2, validation_data=(x_val, y_val))

    # Sequential.predict_classes is gone from recent Keras; taking the argmax
    # of the predicted distribution works on old and new versions alike.
    predicted = np.argmax(model.predict(x_val), axis=-1)
    for morse_row, word_row, pred_row in zip(x_val, y_val, predicted):
        print(ctable_in.decode(morse_row),
              ctable_out.decode(word_row),
              ctable_out.decode(pred_row, calc_argmax=False))

    _plot_history(hist.history)


if __name__ == "__main__":
    main()
import numpy as np


alphabet = list("abcdefghijklmnopqrstuvwxyz")
# Morse code for a..z written as an on/off keying sequence: a dot ('.') is
# '10' and a dash ('-') is '1110', e.g. 'a' = '.-' = '10' + '1110'.
values = ['101110', '1110101010', '111010111010', '11101010', '10',
          '1010111010', '1110111010', '10101010', '1010', '10111011101110',
          '1110101110', '1011101010', '11101110', '111010', '111011101110',
          '101110111010', '11101110101110', '10111010', '101010', '1110',
          '10101110', '1010101110', '1011101110', '111010101110',
          '11101011101110', '111011101010']
morse_dict = dict(zip(alphabet, values))


def morse_encode(word):
    """Return *word* as keyed Morse, one space between letters.

    Letters are looked up case-insensitively; whitespace inside *word* is
    ignored. An empty word yields an empty string.

    Raises:
        KeyError: if *word* contains a character other than a-z / A-Z
            or whitespace.
    """
    letters = "".join(word.split()).lower()
    return " ".join(morse_dict[ch] for ch in letters)


def data_gen(fin='words_alpha.txt'):
    """Print a ``word, morse`` line for every word listed in file *fin*.

    Each line of the file is taken as one word, stripped of surrounding
    whitespace and lower-cased; blank lines are skipped.
    """
    with open(fin, 'r') as file:
        for line in file:
            word = line.strip().lower()
            if not word:
                continue  # a blank line would only emit a junk ", " record
            print(word, morse_encode(word), sep=", ")


if __name__ == "__main__":
    data_gen()