Training Chinese text with clstm was far too slow, so in search of a better approach I kept looking for other LSTM examples to test.
This post tests the image_ocr.py example that ships with Keras (https://github.com/keras-team/keras/blob/master/examples/image_ocr.py).
import tensorflow as tf
from keras.backend.tensorflow_backend import set_session

config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.3
set_session(tf.Session(config=config))

import os
import itertools
import codecs
import re
import datetime
import cairocffi as cairo
import editdistance
import numpy as np
from scipy import ndimage
import pylab
from keras import backend as K
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.layers import Input, Dense, Activation
from keras.layers import Reshape, Lambda
from keras.layers.merge import add, concatenate
from keras.models import Model
from keras.layers.recurrent import GRU
from keras.optimizers import SGD
from keras.utils.data_utils import get_file
from keras.preprocessing import image
import keras.callbacks
import matplotlib.pyplot as plt

OUTPUT_DIR = 'image_ocr'

# character classes and matching regex filter
# regex used to filter out invalid characters
regex = r'^[a-z ]+$'
alphabet = u'abcdefghijklmnopqrstuvwxyz '

np.random.seed(55)
# this creates larger "blotches" of noise which look
# more realistic than just adding gaussian noise
# assumes greyscale with pixels ranging from 0 to 1
# adds blurred Gaussian (speckle) noise to the image
def speckle(img):
    severity = np.random.uniform(0, 0.6)
    blur = ndimage.gaussian_filter(np.random.randn(*img.shape) * severity, 1)
    img_speck = (img + blur)
    img_speck[img_speck > 1] = 1
    img_speck[img_speck <= 0] = 0
    return img_speck
# paints the string in a random location within the bounding box
# also uses a random font, a slight random rotation,
# and a random amount of speckle noise
# renders a string of characters into an image
def paint_text(text, w, h, rotate=False, ud=False, multi_fonts=False):
    surface = cairo.ImageSurface(cairo.FORMAT_RGB24, w, h)
    with cairo.Context(surface) as context:
        context.set_source_rgb(1, 1, 1)  # White
        context.paint()
        # this font list works in CentOS 7
        if multi_fonts:
            fonts = ['Century Schoolbook', 'Courier', 'STIX',
                     'URW Chancery L', 'FreeMono']
            context.select_font_face(np.random.choice(fonts),
                                     cairo.FONT_SLANT_NORMAL,
                                     np.random.choice([cairo.FONT_WEIGHT_BOLD,
                                                       cairo.FONT_WEIGHT_NORMAL]))
        else:
            context.select_font_face('Courier',
                                     cairo.FONT_SLANT_NORMAL,
                                     cairo.FONT_WEIGHT_BOLD)
        context.set_font_size(25)
        box = context.text_extents(text)
        border_w_h = (4, 4)
        if box[2] > (w - 2 * border_w_h[1]) or box[3] > (h - 2 * border_w_h[0]):
            raise IOError('Could not fit string into image. '
                          'Max char count is too large for given image width.')

        # teach the RNN translational invariance by
        # fitting text box randomly on canvas, with some room to rotate
        max_shift_x = w - box[2] - border_w_h[0]
        max_shift_y = h - box[3] - border_w_h[1]
        top_left_x = np.random.randint(0, int(max_shift_x))
        if ud:
            top_left_y = np.random.randint(0, int(max_shift_y))
        else:
            top_left_y = h // 2
        context.move_to(top_left_x - int(box[0]), top_left_y - int(box[1]))
        context.set_source_rgb(0, 0, 0)
        context.show_text(text)

    buf = surface.get_data()
    a = np.frombuffer(buf, np.uint8)
    a.shape = (h, w, 4)
    a = a[:, :, 0]  # grab single channel
    a = a.astype(np.float32) / 255
    a = np.expand_dims(a, 0)
    if rotate:
        a = image.random_rotation(a, 3 * (w - top_left_x) / w + 1)
    a = speckle(a)

    return a
a = paint_text("hello", w=128, h=128)
print(a.shape)
a = a.reshape((128, 128))
plt.imshow(a)
# shuffles the order of the images in the lists, applying the same
# permutation to every matrix/list in matrix_list
def shuffle_mats_or_lists(matrix_list, stop_ind=None):
    ret = []
    assert all([len(i) == len(matrix_list[0]) for i in matrix_list])
    len_val = len(matrix_list[0])
    if stop_ind is None:
        stop_ind = len_val
    assert stop_ind <= len_val

    a = list(range(stop_ind))
    np.random.shuffle(a)
    a += list(range(stop_ind, len_val))
    for mat in matrix_list:
        if isinstance(mat, np.ndarray):
            ret.append(mat[a])
        elif isinstance(mat, list):
            ret.append([mat[i] for i in a])
        else:
            raise TypeError('`shuffle_mats_or_lists` only supports '
                            'numpy.array and list objects.')
    return ret
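A quick sanity check (my own snippet, not part of the Keras example): only the first stop_ind entries are permuted, and every array gets the same permutation, which is what keeps images, labels and label lengths aligned.

words = ['ant', 'bee', 'cat', 'dog']
ids = np.array([[0], [1], [2], [3]])
shuffled_words, shuffled_ids = shuffle_mats_or_lists([words, ids], stop_ind=3)
# only the first 3 entries are shuffled ('dog' stays in place), and both
# outputs are reordered with the same permutation
print(shuffled_words, shuffled_ids.ravel())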
# two helper functions that translate characters to integers and back,
# since the network's labels must be integer class indices

# Translation of characters to unique integer values
def text_to_labels(text):
    ret = []
    for char in text:
        ret.append(alphabet.find(char))
    return ret

# Reverse translation of numerical classes back to characters
def labels_to_text(labels):
    ret = []
    for c in labels:
        if c == len(alphabet):  # CTC Blank
            ret.append("")
        else:
            ret.append(alphabet[c])
    return "".join(ret)
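For example (my own illustration): with the 27-character alphabet above, 'a' through 'z' map to 0-25, the space maps to 26, and index 27 (len(alphabet)) is reserved for the CTC blank.

labels = text_to_labels('ocr test')
print(labels)                  # [14, 2, 17, 26, 19, 4, 18, 19]
print(labels_to_text(labels))  # 'ocr test'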
# only a-z and space..probably not too difficult
# to expand to uppercase and symbols
# checks whether the input string contains only valid characters
def is_valid_str(in_str):
    search = re.compile(regex, re.UNICODE).search
    return bool(search(in_str))
# generates varied training images in memory
# Uses generator functions to supply train/test with
# data. Image renderings and text are created on the fly
# each time with random perturbations
class TextImageGenerator(keras.callbacks.Callback):

    def __init__(self, monogram_file, bigram_file, minibatch_size,
                 img_w, img_h, downsample_factor, val_split,
                 absolute_max_string_len=16):
        self.minibatch_size = minibatch_size
        self.img_w = img_w
        self.img_h = img_h
        self.monogram_file = monogram_file
        self.bigram_file = bigram_file
        self.downsample_factor = downsample_factor
        self.val_split = val_split
        self.blank_label = self.get_output_size() - 1
        self.absolute_max_string_len = absolute_max_string_len

    def get_output_size(self):
        return len(alphabet) + 1

    # num_words can be independent of the epoch size due to the use of generators
    # as max_string_len grows, num_words can grow
    def build_word_list(self, num_words, max_string_len=None, mono_fraction=0.5):
        assert max_string_len <= self.absolute_max_string_len
        assert num_words % self.minibatch_size == 0
        assert (self.val_split * num_words) % self.minibatch_size == 0
        self.num_words = num_words
        self.string_list = [''] * self.num_words
        tmp_string_list = []
        self.max_string_len = max_string_len
        self.Y_data = np.ones([self.num_words, self.absolute_max_string_len]) * -1
        self.X_text = []
        self.Y_len = [0] * self.num_words

        # monogram file is sorted by frequency in english speech
        with codecs.open(self.monogram_file, mode='rt', encoding='utf-8') as f:
            for line in f:
                if len(tmp_string_list) == int(self.num_words * mono_fraction):
                    break
                word = line.rstrip()
                if max_string_len == -1 or max_string_len is None or len(word) <= max_string_len:
                    tmp_string_list.append(word)

        # bigram file contains common word pairings in english speech
        with codecs.open(self.bigram_file, mode='rt', encoding='utf-8') as f:
            lines = f.readlines()
            for line in lines:
                if len(tmp_string_list) == self.num_words:
                    break
                columns = line.lower().split()
                word = columns[0] + ' ' + columns[1]
                if is_valid_str(word) and \
                        (max_string_len == -1 or max_string_len is None or len(word) <= max_string_len):
                    tmp_string_list.append(word)
        if len(tmp_string_list) != self.num_words:
            raise IOError('Could not pull enough words from supplied monogram and bigram files. ')
        # interlace to mix up the easy and hard words
        self.string_list[::2] = tmp_string_list[:self.num_words // 2]
        self.string_list[1::2] = tmp_string_list[self.num_words // 2:]

        for i, word in enumerate(self.string_list):
            self.Y_len[i] = len(word)
            self.Y_data[i, 0:len(word)] = text_to_labels(word)
            self.X_text.append(word)
        self.Y_len = np.expand_dims(np.array(self.Y_len), 1)

        self.cur_val_index = self.val_split
        self.cur_train_index = 0

    # each time an image is requested from train/val/test, a new random
    # painting of the text is performed
    def get_batch(self, index, size, train):
        # width and height are backwards from typical Keras convention
        # because width is the time dimension when it gets fed into the RNN
        if K.image_data_format() == 'channels_first':
            X_data = np.ones([size, 1, self.img_w, self.img_h])
        else:
            X_data = np.ones([size, self.img_w, self.img_h, 1])

        labels = np.ones([size, self.absolute_max_string_len])
        input_length = np.zeros([size, 1])
        label_length = np.zeros([size, 1])
        source_str = []
        for i in range(size):
            # Mix in some blank inputs.  This seems to be important for
            # achieving translational invariance
            if train and i > size - 4:
                if K.image_data_format() == 'channels_first':
                    X_data[i, 0, 0:self.img_w, :] = self.paint_func('')[0, :, :].T
                else:
                    X_data[i, 0:self.img_w, :, 0] = self.paint_func('')[0, :, :].T
                labels[i, 0] = self.blank_label
                input_length[i] = self.img_w // self.downsample_factor - 2
                label_length[i] = 1
                source_str.append('')
            else:
                if K.image_data_format() == 'channels_first':
                    X_data[i, 0, 0:self.img_w, :] = self.paint_func(self.X_text[index + i])[0, :, :].T
                else:
                    X_data[i, 0:self.img_w, :, 0] = self.paint_func(self.X_text[index + i])[0, :, :].T
                labels[i, :] = self.Y_data[index + i]
                input_length[i] = self.img_w // self.downsample_factor - 2
                label_length[i] = self.Y_len[index + i]
                source_str.append(self.X_text[index + i])
        inputs = {'the_input': X_data,
                  'the_labels': labels,
                  'input_length': input_length,
                  'label_length': label_length,
                  'source_str': source_str  # used for visualization only
                  }
        outputs = {'ctc': np.zeros([size])}  # dummy data for dummy loss function
        return (inputs, outputs)

    def next_train(self):
        while 1:
            ret = self.get_batch(self.cur_train_index, self.minibatch_size, train=True)
            self.cur_train_index += self.minibatch_size
            if self.cur_train_index >= self.val_split:
                self.cur_train_index = self.cur_train_index % 32
                (self.X_text, self.Y_data, self.Y_len) = shuffle_mats_or_lists(
                    [self.X_text, self.Y_data, self.Y_len], self.val_split)
            yield ret

    def next_val(self):
        while 1:
            ret = self.get_batch(self.cur_val_index, self.minibatch_size, train=False)
            self.cur_val_index += self.minibatch_size
            if self.cur_val_index >= self.num_words:
                self.cur_val_index = self.val_split + self.cur_val_index % 32
            yield ret

    def on_train_begin(self, logs={}):
        self.build_word_list(16000, 4, 1)
        self.paint_func = lambda text: paint_text(text, self.img_w, self.img_h,
                                                  rotate=False, ud=False, multi_fonts=False)

    def on_epoch_begin(self, epoch, logs={}):
        # rebind the paint function to implement curriculum learning
        if 3 <= epoch < 6:
            self.paint_func = lambda text: paint_text(text, self.img_w, self.img_h,
                                                      rotate=False, ud=True, multi_fonts=False)
        elif 6 <= epoch < 9:
            self.paint_func = lambda text: paint_text(text, self.img_w, self.img_h,
                                                      rotate=False, ud=True, multi_fonts=True)
        elif epoch >= 9:
            self.paint_func = lambda text: paint_text(text, self.img_w, self.img_h,
                                                      rotate=True, ud=True, multi_fonts=True)
        if epoch >= 21 and self.max_string_len < 12:
            self.build_word_list(32000, 12, 0.5)
# computes the CTC loss between the predictions and the ground truth
# the actual loss calc occurs here despite it not being
# an internal Keras loss function
def ctc_lambda_func(args):
    y_pred, labels, input_length, label_length = args
    # the 2 is critical here since the first couple outputs of the RNN
    # tend to be garbage:
    y_pred = y_pred[:, 2:, :]
    return K.ctc_batch_cost(labels, y_pred, input_length, label_length)
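As a sanity check on the numbers (my own arithmetic, not from the example): with img_w = 128 and two 2x2 max-pooling layers, downsample_factor = pool_size ** 2 = 4, so the RNN sees 128 / 4 = 32 time steps. Since ctc_lambda_func drops the first 2 steps, the generator reports input_length = 128 // 4 - 2 = 30 usable steps per sample; the two values have to stay in sync or the CTC loss will see more steps than exist.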
# in general you would run a beam search (guided by a language model) to find
# the best output sequence; as an example here we just take the most confident
# character at each time step
# For a real OCR application, this should be beam search with a dictionary
# and language model. For this example, best path is sufficient.
def decode_batch(test_func, word_batch):
    out = test_func([word_batch])[0]
    ret = []
    for j in range(out.shape[0]):
        out_best = list(np.argmax(out[j, 2:], 1))
        out_best = [k for k, g in itertools.groupby(out_best)]
        outstr = labels_to_text(out_best)
        ret.append(outstr)
    return ret
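To see what the itertools.groupby step is doing (my own toy example): best-path decoding first collapses adjacent repeated predictions, then labels_to_text drops the blanks (index 27).

# per-timestep argmax indices: h h e e <blank> l l <blank> l o
raw = [7, 7, 4, 4, 27, 11, 11, 27, 11, 14]
collapsed = [k for k, g in itertools.groupby(raw)]  # [7, 4, 27, 11, 27, 11, 14]
print(labels_to_text(collapsed))  # 'hello'

Without the blank between the two l-runs, the repeated letter would be collapsed into a single 'l', which is exactly why CTC needs the blank symbol.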
# callback that saves weights and visualizes a few decoded results
# at the end of each epoch
class VizCallback(keras.callbacks.Callback):

    def __init__(self, run_name, test_func, text_img_gen, num_display_words=6):
        self.test_func = test_func
        self.output_dir = os.path.join(OUTPUT_DIR, run_name)
        self.text_img_gen = text_img_gen
        self.num_display_words = num_display_words
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

    def show_edit_distance(self, num):
        num_left = num
        mean_norm_ed = 0.0
        mean_ed = 0.0
        while num_left > 0:
            word_batch = next(self.text_img_gen)[0]
            num_proc = min(word_batch['the_input'].shape[0], num_left)
            decoded_res = decode_batch(self.test_func, word_batch['the_input'][0:num_proc])
            for j in range(num_proc):
                edit_dist = editdistance.eval(decoded_res[j], word_batch['source_str'][j])
                mean_ed += float(edit_dist)
                mean_norm_ed += float(edit_dist) / len(word_batch['source_str'][j])
            num_left -= num_proc
        mean_norm_ed = mean_norm_ed / num
        mean_ed = mean_ed / num
        print('\nOut of %d samples:  Mean edit distance: %.3f '
              'Mean normalized edit distance: %0.3f'
              % (num, mean_ed, mean_norm_ed))

    def on_epoch_end(self, epoch, logs={}):
        self.model.save_weights(os.path.join(self.output_dir, 'weights%02d.h5' % (epoch)))
        self.show_edit_distance(256)
        word_batch = next(self.text_img_gen)[0]
        res = decode_batch(self.test_func, word_batch['the_input'][0:self.num_display_words])
        if word_batch['the_input'][0].shape[0] < 256:
            cols = 2
        else:
            cols = 1
        for i in range(self.num_display_words):
            pylab.subplot(self.num_display_words // cols, cols, i + 1)
            if K.image_data_format() == 'channels_first':
                the_input = word_batch['the_input'][i, 0, :, :]
            else:
                the_input = word_batch['the_input'][i, :, :, 0]
            pylab.imshow(the_input.T, cmap='Greys_r')
            pylab.xlabel('Truth = \'%s\'\nDecoded = \'%s\'' %
                         (word_batch['source_str'][i], res[i]))
        fig = pylab.gcf()
        fig.set_size_inches(10, 13)
        pylab.savefig(os.path.join(self.output_dir, 'e%02d.png' % (epoch)))
        pylab.close()
run_name = datetime.datetime.now().strftime('%Y:%m:%d:%H:%M:%S')
start_epoch = 0
stop_epoch = 20

# Input Parameters
img_w = 128
img_h = 64
words_per_epoch = 16000
val_split = 0.2
val_words = int(words_per_epoch * (val_split))

# Network parameters
conv_filters = 16
kernel_size = (3, 3)
pool_size = 2
time_dense_size = 32
rnn_size = 512
minibatch_size = 32

if K.image_data_format() == 'channels_first':
    input_shape = (1, img_w, img_h)
else:
    input_shape = (img_w, img_h, 1)

fdir = os.path.dirname(
    get_file('wordlists.tgz',
             origin='http://www.mythic-ai.com/datasets/wordlists.tgz',
             untar=True))

img_gen = TextImageGenerator(monogram_file=os.path.join(fdir, 'wordlist_mono_clean.txt'),
                             bigram_file=os.path.join(fdir, 'wordlist_bi_clean.txt'),
                             minibatch_size=minibatch_size,
                             img_w=img_w,
                             img_h=img_h,
                             downsample_factor=(pool_size ** 2),
                             val_split=words_per_epoch - val_words)
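One detail worth spelling out (my own note): the val_split argument passed to TextImageGenerator is not the 0.2 fraction but an absolute index. With words_per_epoch = 16000 and val_words = 16000 * 0.2 = 3200, the generator receives val_split = 16000 - 3200 = 12800, so words 0-12799 are served by next_train and words 12800-15999 by next_val.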
# quick test of TextImageGenerator
img_gen.on_train_begin()
ret = img_gen.get_batch(0, 20, train=True)
X_data = ret[0]['the_input']
one_image = X_data[0].reshape((128, 64))
print(one_image.shape)
plt.imshow(one_image)
Build the RNN model
act = 'relu'
input_data = Input(name='the_input', shape=input_shape, dtype='float32')
inner = Conv2D(conv_filters, kernel_size, padding='same',
               activation=act, kernel_initializer='he_normal',
               name='conv1')(input_data)
inner = MaxPooling2D(pool_size=(pool_size, pool_size), name='max1')(inner)
inner = Conv2D(conv_filters, kernel_size, padding='same',
               activation=act, kernel_initializer='he_normal',
               name='conv2')(inner)
inner = MaxPooling2D(pool_size=(pool_size, pool_size), name='max2')(inner)

conv_to_rnn_dims = (img_w // (pool_size ** 2),
                    (img_h // (pool_size ** 2)) * conv_filters)
inner = Reshape(target_shape=conv_to_rnn_dims, name='reshape')(inner)

# cuts down input size going into RNN:
inner = Dense(time_dense_size, activation=act, name='dense1')(inner)

# Two layers of bidirectional GRUs
# GRU seems to work as well, if not better than LSTM:
gru_1 = GRU(rnn_size, return_sequences=True,
            kernel_initializer='he_normal', name='gru1')(inner)
gru_1b = GRU(rnn_size, return_sequences=True, go_backwards=True,
             kernel_initializer='he_normal', name='gru1_b')(inner)
gru1_merged = add([gru_1, gru_1b])
gru_2 = GRU(rnn_size, return_sequences=True,
            kernel_initializer='he_normal', name='gru2')(gru1_merged)
gru_2b = GRU(rnn_size, return_sequences=True, go_backwards=True,
             kernel_initializer='he_normal', name='gru2_b')(gru1_merged)

# transforms RNN output to character activations:
inner = Dense(img_gen.get_output_size(), kernel_initializer='he_normal',
              name='dense2')(concatenate([gru_2, gru_2b]))
y_pred = Activation('softmax', name='softmax')(inner)
Model(inputs=input_data, outputs=y_pred).summary()

labels = Input(name='the_labels', shape=[img_gen.absolute_max_string_len], dtype='float32')
input_length = Input(name='input_length', shape=[1], dtype='int64')
label_length = Input(name='label_length', shape=[1], dtype='int64')
# Keras doesn't currently support loss funcs with extra parameters
# so CTC loss is implemented in a lambda layer
loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')(
    [y_pred, labels, input_length, label_length])

# clipnorm seems to speed up convergence
sgd = SGD(lr=0.02, decay=1e-6, momentum=0.9, nesterov=True, clipnorm=5)

model = Model(inputs=[input_data, labels, input_length, label_length], outputs=loss_out)

# the loss calc occurs elsewhere, so use a dummy lambda func for the loss
model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer=sgd)
if start_epoch > 0:
    weight_file = os.path.join(OUTPUT_DIR,
                               os.path.join(run_name, 'weights%02d.h5' % (start_epoch - 1)))
    model.load_weights(weight_file)
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
the_input (InputLayer)          (None, 128, 64, 1)   0
__________________________________________________________________________________________________
conv1 (Conv2D)                  (None, 128, 64, 16)  160         the_input[0][0]
__________________________________________________________________________________________________
max1 (MaxPooling2D)             (None, 64, 32, 16)   0           conv1[0][0]
__________________________________________________________________________________________________
conv2 (Conv2D)                  (None, 64, 32, 16)   2320        max1[0][0]
__________________________________________________________________________________________________
max2 (MaxPooling2D)             (None, 32, 16, 16)   0           conv2[0][0]
__________________________________________________________________________________________________
reshape (Reshape)               (None, 32, 256)      0           max2[0][0]
__________________________________________________________________________________________________
dense1 (Dense)                  (None, 32, 32)       8224        reshape[0][0]
__________________________________________________________________________________________________
gru1 (GRU)                      (None, 32, 512)      837120      dense1[0][0]
__________________________________________________________________________________________________
gru1_b (GRU)                    (None, 32, 512)      837120      dense1[0][0]
__________________________________________________________________________________________________
add_2 (Add)                     (None, 32, 512)      0           gru1[0][0]
                                                                 gru1_b[0][0]
__________________________________________________________________________________________________
gru2 (GRU)                      (None, 32, 512)      1574400     add_2[0][0]
__________________________________________________________________________________________________
gru2_b (GRU)                    (None, 32, 512)      1574400     add_2[0][0]
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 32, 1024)     0           gru2[0][0]
                                                                 gru2_b[0][0]
__________________________________________________________________________________________________
dense2 (Dense)                  (None, 32, 28)       28700       concatenate_2[0][0]
__________________________________________________________________________________________________
softmax (Activation)            (None, 32, 28)       0           dense2[0][0]
==================================================================================================
Total params: 4,862,444
Trainable params: 4,862,444
Non-trainable params: 0
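The shapes above line up with the earlier parameters (my own bookkeeping): each 2x2 pooling halves both spatial dimensions, so (128, 64, 1) becomes (64, 32, 16) and then (32, 16, 16); the Reshape folds height and channels together into (32, 16 * 16) = (32, 256), giving the GRUs 32 time steps of 256-dim features; and the final dense layer outputs len(alphabet) + 1 = 28 classes (26 letters + space + CTC blank) per time step.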
# function used to get the character output
# captures output of softmax so we can decode the output during visualization
test_func = K.function([input_data], [y_pred])

# visualization callback
viz_cb = VizCallback(run_name, test_func, img_gen.next_val())
# start training
model.fit_generator(generator=img_gen.next_train(),
                    steps_per_epoch=(words_per_epoch - val_words) // minibatch_size,
                    epochs=stop_epoch,
                    validation_data=img_gen.next_val(),
                    validation_steps=val_words // minibatch_size,
                    callbacks=[viz_cb, img_gen],
                    initial_epoch=start_epoch)
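After training finishes, test_func plus decode_batch can be reused directly for inference. A minimal sketch (my own, assuming the model and generator above have been built and trained):

# pull one validation batch and decode it (not part of the original example)
word_batch = next(img_gen.next_val())[0]
decoded = decode_batch(test_func, word_batch['the_input'][0:8])
for truth, pred in zip(word_batch['source_str'][0:8], decoded):
    print('truth: %-16s decoded: %s' % (truth, pred))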