CW Decoding and Deep Learning (3)

If you would rather use input sequences that correspond more closely to the real world, one approach is to train on data like the following (the bit-stream encoding rule is sketched after the list):

a, 101110
aa, 101110 101110
aal, 101110 101110 1011101010
aalii, 101110 101110 1011101010 1010 1010
    (many lines deleted)
zythia, 111011101010 11101011101110 1110 10101010 1010 101110
zythum, 111011101010 11101011101110 1110 10101010 10101110 11101110
zyzomys, 111011101010 11101011101110 111011101010 111011101110 11101110 11101011101110 101010
zyzzogeton, 111011101010 11101011101110 111011101010 111011101010 111011101110 1110111010 10 1110 111011101110 111010
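
Each letter is represented here as a key-down/key-up bit stream: a dit becomes '10' (one unit of key-down, one unit of key-up) and a dah becomes '1110' (three units down, one up). A minimal sketch of this conversion, assuming the classic dot/dash codes as input:

def to_bits(code):
    # dit = '10' (key down 1 unit, up 1 unit); dah = '1110' (down 3, up 1)
    return ''.join('10' if e == '.' else '1110' for e in code)

print(to_bits('.-'))    # a -> 101110
print(to_bits('--..'))  # z -> 111011101010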

Here we assume that the spacing between letters is detected perfectly and is represented by a space character.

This is the case of hidden_size = 128. The model summary is followed by a few validation samples (input, correct word, predicted word):

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
lstm_1 (LSTM)                (None, 128)               67584     
_________________________________________________________________
repeat_vector_1 (RepeatVecto (None, 4, 128)            0         
_________________________________________________________________
lstm_2 (LSTM)                (None, 4, 128)            131584    
_________________________________________________________________
time_distributed_1 (TimeDist (None, 4, 27)             3483      
=================================================================
Total params: 202,651
Trainable params: 202,651
Non-trainable params: 0
_________________________________________________________________
Train on 4894 samples, validate on 100 samples
10111011101110 10101110 111010111010 1110101110                 juck     juck
1110101010 10101110 1011101010 1011101010                       bull     bull
101110111010 1010 1011101010 11101011101110                     pily     pily
1011101010 10101110 10 101010                                   lues     laes
1110111010 111011101110 1011101110 111010                       gown     gown
10101010 101110 10101110 1011101010                             haul     haul
1110 111011101110 1010 1011101010                               toil     toil

This is the case of hidden_size = 256:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
lstm_1 (LSTM)                (None, 256)               266240    
_________________________________________________________________
repeat_vector_1 (RepeatVecto (None, 4, 256)            0         
_________________________________________________________________
lstm_2 (LSTM)                (None, 4, 256)            525312    
_________________________________________________________________
time_distributed_1 (TimeDist (None, 4, 27)             6939      
=================================================================
Total params: 798,491
Trainable params: 798,491
Non-trainable params: 0
_________________________________________________________________
Train on 4894 samples, validate on 100 samples
This is the case of hidden_size = 64:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
lstm_1 (LSTM)                (None, 64)                17408     
_________________________________________________________________
repeat_vector_1 (RepeatVecto (None, 4, 64)             0         
_________________________________________________________________
lstm_2 (LSTM)                (None, 4, 64)             33024     
_________________________________________________________________
time_distributed_1 (TimeDist (None, 4, 27)             1755      
=================================================================
Total params: 52,187
Trainable params: 52,187
Non-trainable params: 0
_________________________________________________________________
Train on 4894 samples, validate on 100 samples
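
As a sanity check, the Param # column of the summaries above can be reproduced by hand. A minimal sketch, assuming the standard Keras LSTM parameter formula (four gates, each with input weights, recurrent weights, and a bias):

def lstm_params(input_dim, hidden):
    # 4 gates * (input weights + recurrent weights + bias) per unit
    return 4 * (input_dim + hidden + 1) * hidden

def dense_params(input_dim, units):
    return (input_dim + 1) * units

for hidden in (128, 256, 64):
    p1 = lstm_params(3, hidden)        # chars_in = '10 ' -> 3 input symbols
    p2 = lstm_params(hidden, hidden)   # second LSTM reads the first one's output
    p3 = dense_params(hidden, 27)      # chars_out has 27 symbols
    print(hidden, p1, p2, p3, p1 + p2 + p3)
# 128: 67584 131584 3483 202651
# 256: 266240 525312 6939 798491
#  64: 17408 33024 1755 52187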
Here is the full Keras script used in these experiments:

from keras.models import Sequential
from keras import layers
import numpy as np
import matplotlib.pyplot as plt


class CharTable(object):
    def __init__(self, chars):
        self.chars = sorted(set(chars))
        self.char_indices = dict((c, i) for i, c in enumerate(self.chars))
        self.indices_char = dict((i, c) for i, c in enumerate(self.chars))

    def encode(self, token, num_rows):
        # One-hot encode a token, one row per character.
        x = np.zeros((num_rows, len(self.chars)))
        for i, c in enumerate(token):
            x[i, self.char_indices[c]] = 1
        return x

    def decode(self, x, calc_argmax=True):
        # Decode a single one-hot (or probability) row back to its character.
        if calc_argmax:
            x = [x.argmax(axis=-1)]
        return ''.join(self.indices_char[int(v)] for v in x)


def main():
    word_len = 4
    # The longest letter code used here is 14 chars ('j', 'q', 'y'), so 15 per
    # letter plus (word_len - 1) separating spaces is a safe upper bound.
    max_len_x = 15 * word_len + (word_len - 1)
    max_len_y = word_len

    input_list = []
    output_list = []
    fin = 'words_morse10.txt'
    with open(fin, 'r') as file:
        for line in file.read().splitlines():
            mylist = line.split(", ")
            [word, morse] = mylist
            morse = morse + ' ' * (max_len_x - len(morse))
            if len(word) == word_len:
                input_list.append(morse)
                output_list.append(word)

    chars_in = '10 '
    chars_out = 'abcdefghijklmnopqrstuvwxyz '
    ctable_in = CharTable(chars_in)
    ctable_out = CharTable(chars_out)

    x = np.zeros((len(input_list), max_len_x, len(chars_in)))
    y = np.zeros((len(output_list), max_len_y, len(chars_out)))
    for i, token in enumerate(input_list):
        x[i] = ctable_in.encode(token, max_len_x)
    for i, token in enumerate(output_list):
        y[i] = ctable_out.encode(token, max_len_y)

    indices = np.arange(len(y))
    np.random.shuffle(indices)
    x = x[indices]
    y = y[indices]

    m = len(x) - 100
    (x_train, x_val) = x[:m], x[m:]
    (y_train, y_val) = y[:m], y[m:]

    hidden_size = 128
    batch_size = 128
    nlayers = 1
    epochs = 150

    model = Sequential()
    model.add(layers.LSTM(hidden_size, input_shape=(max_len_x, len(chars_in))))
    model.add(layers.RepeatVector(word_len))

    for _ in range(nlayers):
        model.add(layers.LSTM(hidden_size, return_sequences=True))

    model.add(layers.TimeDistributed(layers.Dense(len(chars_out), activation='softmax')))
    model.compile(loss='categorical_crossentropy',
                  optimizer='adam',
                  metrics=['accuracy'])
    model.summary()

    hist = model.fit(x_train, y_train, batch_size=batch_size,
                     epochs=epochs, verbose=2, validation_data=(x_val, y_val))

    # Note: predict_classes was removed in newer Keras; the equivalent is
    # np.argmax(model.predict(x_val), axis=-1).
    predict = model.predict_classes(x_val)

    for i in range(len(x_val)):
        print("".join([ctable_in.decode(code) for code in x_val[i]]),
              "".join([ctable_out.decode(code) for code in y_val[i]]), end="     ")
        for j in range(word_len):
            print(ctable_out.indices_char[predict[i][j]], end="")
        print()

    plt.figure(figsize=(16, 5))
    plt.subplot(121)
    plt.plot(hist.history['acc'])      # 'accuracy' in newer Keras
    plt.plot(hist.history['val_acc'])  # 'val_accuracy' in newer Keras
    plt.title('model accuracy')
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.legend(['train', 'validation'], loc='upper left')
    plt.subplot(122)
    plt.plot(hist.history['loss'])
    plt.plot(hist.history['val_loss'])
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train', 'validation'], loc='upper right')
    plt.show()


main()
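
Once trained, the model can decode a single Morse stream directly. A minimal sketch, assuming model, ctable_in, ctable_out, and max_len_x are still in scope (e.g. at the end of main()):

    # Encode one padded '10' stream, predict, and take the argmax per letter.
    morse = '111010111010 111011101110 11101010 10'  # "code"
    morse = morse + ' ' * (max_len_x - len(morse))
    xq = ctable_in.encode(morse, max_len_x)[np.newaxis, :, :]
    pred = model.predict(xq)[0].argmax(axis=-1)
    print(''.join(ctable_out.indices_char[int(v)] for v in pred))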
The training data file words_morse10.txt was generated from words_alpha.txt with the following script:

def morse_encode(word):
    # Translate a word into its '10'-style Morse representation,
    # one code per letter, separated by single spaces.
    return " ".join(morse_dict[c] for c in word)


def data_gen():
    fin = 'words_alpha.txt'

    # Print one "word, morse" pair per line for every word in the list.
    with open(fin, 'r') as file:
        for word in file.read().lower().splitlines():
            print(word, morse_encode(word), sep=", ")

    return


alphabet = list("abcdefghijklmnopqrstuvwxyz")

# values = ['.-', '-...', '-.-.', '-..', '.', '..-.', '--.', '....', '..', '.---', '-.-',
#           '.-..', '--', '-.', '---', '.--.', '--.-',
#           '.-.', '...', '-', '..-', '...-', '.--', '-..-', '-.--', '--..']

values = ['101110', '1110101010', '111010111010', '11101010', '10', '1010111010',
          '1110111010', '10101010', '1010', '10111011101110', '1110101110',
          '1011101010', '11101110', '111010', '111011101110', '101110111010',
          '11101110101110', '10111010', '101010', '1110', '10101110', '1010101110',
          '1011101110', '111010101110', '11101011101110', '111011101010']

morse_dict = dict(zip(alphabet, values))

data_gen()
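
For example, with the dictionary above:

print(morse_encode("code"))
# -> 111010111010 111011101110 11101010 10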
