Warnings while saving model
I am using the following Transformer model, built with KerasNLP, for machine translation: `
# Encoder: integer token ids -> token+position embeddings -> one
# TransformerEncoder layer, wrapped as a standalone Keras model.
encoder_inputs = keras.Input(shape=(None,), dtype="int64", name="encoder_inputs")

# Embedding table sized by ENG_VOCAB_SIZE / MAX_SEQUENCE_LENGTH / EMBED_DIM.
# mask_zero=True is passed through to the layer (presumably token id 0 is
# padding -- confirm against the tokenizer).
encoder_embedding = keras_nlp.layers.TokenAndPositionEmbedding(
    vocabulary_size=ENG_VOCAB_SIZE,
    sequence_length=MAX_SEQUENCE_LENGTH,
    embedding_dim=EMBED_DIM,
    mask_zero=True,
)
x = encoder_embedding(encoder_inputs)

encoder_block = keras_nlp.layers.TransformerEncoder(
    intermediate_dim=INTERMEDIATE_DIM,
    num_heads=NUM_HEADS,
)
encoder_outputs = encoder_block(inputs=x)

# Functional model mapping encoder_inputs to encoder_outputs.
encoder = keras.Model(inputs=encoder_inputs, outputs=encoder_outputs)
# Decoder: takes target token ids plus an encoded source sequence and
# produces a softmax distribution of width SND_VOCAB_SIZE per position.
decoder_inputs = keras.Input(shape=(None,), dtype="int64", name="decoder_inputs")
encoded_seq_inputs = keras.Input(
    shape=(None, EMBED_DIM),
    name="decoder_state_inputs",
)

# Target-side embedding; mask_zero=True is passed through to the layer
# (presumably token id 0 is padding -- confirm against the tokenizer).
decoder_embedding = keras_nlp.layers.TokenAndPositionEmbedding(
    vocabulary_size=SND_VOCAB_SIZE,
    sequence_length=MAX_SEQUENCE_LENGTH,
    embedding_dim=EMBED_DIM,
    mask_zero=True,
)
x = decoder_embedding(decoder_inputs)

# The decoder layer receives the embedded targets as decoder_sequence and
# the encoded source as encoder_sequence.
decoder_block = keras_nlp.layers.TransformerDecoder(
    intermediate_dim=INTERMEDIATE_DIM,
    num_heads=NUM_HEADS,
)
x = decoder_block(decoder_sequence=x, encoder_sequence=encoded_seq_inputs)

x = keras.layers.Dropout(0.5)(x)
decoder_outputs = keras.layers.Dense(SND_VOCAB_SIZE, activation="softmax")(x)

# Functional model with two inputs, in this order: target ids, encoded source.
decoder = keras.Model(
    inputs=[decoder_inputs, encoded_seq_inputs],
    outputs=decoder_outputs,
)
# Chain the two sub-models: feed the encoder's output tensor into the
# decoder's second input, then expose the result as one end-to-end model.
# Note that decoder_outputs is rebound here to the chained output tensor.
decoder_outputs = decoder([decoder_inputs, encoder_outputs])
transformer = keras.Model(
    inputs=[encoder_inputs, decoder_inputs],
    outputs=decoder_outputs,
    name="transformer",
)
`
`
# Print the layer/parameter summary, then configure and run training.
transformer.summary()
transformer.compile(
    optimizer="rmsprop",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
# hist holds the object returned by fit() for the training run.
hist = transformer.fit(
    x=train_ds,
    epochs=EPOCHS,
    validation_data=val_ds,
)
`
When I try to save the model, it gives me warnings, and it does not predict correctly:
`
# Serialize the trained model object to model.pkl.
# NOTE(review): this pickle call is what emits the "Found untraced functions
# ... while saving" warning reported below. Keras documents Model.save() /
# keras.models.load_model() (or save_weights() / load_weights()) as its
# persistence API -- consider switching to that if the reloaded model
# misbehaves; TODO confirm against the installed Keras/KerasNLP versions.
with open('model.pkl', 'wb') as file:
    # The dump must sit inside the with-block so the file is still open.
    pickle.dump(transformer, file)
`
WARNING:absl:Found untraced functions such as token_embedding1_layer_call_fn, token_embedding1_layer_call_and_return_conditional_losses, position_embedding1_layer_call_fn, position_embedding1_layer_call_and_return_conditional_losses, multi_head_attention_layer_call_fn while saving (showing 5 of 78). These functions will not be directly callable after loading.
精彩评论