I'm following the TensorFlow Keras tutorial for text generation. The training part works perfectly, but when I try to predict the next token, I get an error. Here's all the important code:
- Making the vocabulary and the dataset:
vocab = sorted(set(text))
char2index = { c:i for i, c in enumerate(vocab) }
index2char = np.array(vocab)
chars_to_int = np.array([char2index[c] for c in text])
char_dataset = tf.data.Dataset.from_tensor_slices(chars_to_int)
sequences = char_dataset.batch(seq_length + 1, drop_remainder=True)
def split_input_and_target(sequence):
    input_ = sequence[:-1]
    target_ = sequence[1:]
    return input_, target_
dataset = sequences.map(split_input_and_target)
dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)
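To make it clearer what this pipeline produces, here is a minimal pure-Python sketch of the same chunk-and-split step on a toy string. The `text` and `seq_length` values are made up for illustration, the list slicing only emulates `batch(seq_length + 1, drop_remainder=True)`, and `decode` is a hypothetical helper that is not part of my real code:

```python
# Toy stand-ins for the real corpus and hyperparameter (assumed values).
text = "hello world"
seq_length = 4

vocab = sorted(set(text))
char2index = {c: i for i, c in enumerate(vocab)}
chars_to_int = [char2index[c] for c in text]

# Emulate Dataset.batch(seq_length + 1, drop_remainder=True) with slicing:
# consecutive chunks of seq_length + 1 items, with the short tail dropped.
chunk = seq_length + 1
sequences = [
    chars_to_int[start:start + chunk]
    for start in range(0, len(chars_to_int) - chunk + 1, chunk)
]


def split_input_and_target(sequence):
    # The target is the input shifted one position to the left.
    return sequence[:-1], sequence[1:]


def decode(indices):
    # Hypothetical helper: map indices back to characters for display.
    return "".join(vocab[i] for i in indices)


pairs = [split_input_and_target(s) for s in sequences]
for input_, target_ in pairs:
    print(repr(decode(input_)), "->", repr(decode(target_)))
```

Each input and target has length `seq_length`, so after the final `batch(BATCH_SIZE, drop_remainder=True)` every training batch should have shape `(BATCH_SIZE, seq_length)`.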
- Building the model
- (the most important part here is that BATCH_SIZE = 64):
model = tf.keras.Sequential()
model.add(tf.keras.layers.Embedding(len(vocab), EMBEDDING_DIM,
                                    batch_input_shape=[BATCH_SIZE, None]))
# here are a few more layers
model.compile(loss="sparse_categorical_crossentropy", optimizer="adam")
model.fit(dataset, epochs=EPOCHS)
- Actually trying to generate text (copied almost directly from the tutorial after I started getting desperate):
num_tokens = 100
seed = "some text"
input_eval = [char2index[c] for c in seed]
input_eval = tf.expand_dims(input_eval, 0)
text_generated = []
model.reset_states()
for i in range(num_tokens):
    predictions = model(input_eval)
    predictions = tf.squeeze(predictions, 0)
    # more stuff
First I get this warning:
WARNING:tensorflow:Model was constructed with shape (64, None) for input Tensor("embedding_14_input:0", shape=(64, None), dtype=float32), but it was called on an input with incompatible shape (1, 9).
Then it raises an error:
---->3 predictions = model(input_eval)
...
ValueError: Tensor's shape (9, 64, 256) is not compatible with supplied shape [9, 1, 256]
The second number, 64, is my batch size. If I change BATCH_SIZE to 1, everything works fine, but that's obviously not the solution I'm hoping for.