# Sequential model example
# A linear stack of dense layers: two 32-unit ReLU layers followed by a
# 10-unit softmax output. The first layer declares the 64-feature input.
seq_model <- keras_model_sequential() %>%
  layer_dense(units = 32, activation = "relu", input_shape = 64) %>%
  layer_dense(units = 32, activation = "relu") %>%
  layer_dense(units = 10, activation = "softmax")
# Its functional-API equivalent
# The same three-layer network expressed with the functional API: start from
# an explicit 64-feature input tensor, pipe it through the layers, then wrap
# the input/output pair into a model object.
input_tensor <- layer_input(shape = 64)

output_tensor <- input_tensor %>%
  layer_dense(units = 32, activation = "relu") %>%
  layer_dense(units = 32, activation = "relu") %>%
  layer_dense(units = 10, activation = "softmax")

model <- keras_model(inputs = input_tensor, outputs = output_tensor)
# E.g.: a question-answering model with two inputs (1: the question, 2: the reference text) and one output (the answer)
# Vocabulary sizes for the two text inputs, and the number of candidate
# answers (the width of the final softmax).
text_vocabulary_size <- 10000
ques_vocabulary_size <- 10000
answer_vocabulary_size <- 500

# Reference-text branch: variable-length integer sequence -> embedding -> LSTM.
# For layer_embedding(), `input_dim` is the vocabulary size (largest token
# index + 1) and `output_dim` is the embedding width. The original had the two
# swapped. The "+ 1" is needed because this script feeds 1-based token indices
# (1:text_vocabulary_size), so the largest index equals the vocabulary size.
text_input <- layer_input(shape = list(NULL), dtype = "int32", name = "text")
encoded_text <- text_input %>%
  layer_embedding(input_dim = text_vocabulary_size + 1, output_dim = 64) %>%
  layer_lstm(units = 32)

# Question branch: same structure with a smaller embedding and LSTM.
question_input <- layer_input(
  shape = list(NULL), dtype = "int32", name = "question"
)
encoded_question <- question_input %>%
  layer_embedding(input_dim = ques_vocabulary_size + 1, output_dim = 32) %>%
  layer_lstm(units = 16)

# Merge the two encoded branches and classify over the answer vocabulary.
concatenated <- layer_concatenate(list(encoded_text, encoded_question))
answer <- concatenated %>%
  layer_dense(units = answer_vocabulary_size, activation = "softmax")

# Two-input, single-output model.
model <- keras_model(list(text_input, question_input), answer)
# Configure training for the question-answering model: RMSprop optimizer,
# categorical cross-entropy loss, and accuracy as the reported metric.
# The closing parenthesis was missing, which made the following statements
# parse as further arguments of compile().
model %>% compile(
  optimizer = "rmsprop",
  loss = "categorical_crossentropy",
  metrics = c("acc")
)
# Dimensions of the synthetic data set: number of samples, and number of
# tokens per sequence.
num_samples <- 1000
max_length <- 100

# Build an `nrow` x `ncol` matrix whose entries are sampled with replacement
# from `range`.
#
# range: vector of candidate values. Note that sample() treats a single
#        number n >= 1 as 1:n, so pass a vector of length >= 2.
# nrow, ncol: dimensions of the returned matrix.
# Returns a matrix with `nrow * ncol` sampled values.
random_matrix <- function(range, nrow, ncol) {
  matrix(
    sample(range, size = nrow * ncol, replace = TRUE),
    nrow = nrow,
    ncol = ncol
  )
}
# Synthetic inputs: one row per sample, one column per token position, with
# token indices drawn from 1..vocabulary size.
text <- random_matrix(
  seq_len(text_vocabulary_size),
  nrow = num_samples, ncol = max_length
)
question <- random_matrix(
  seq_len(ques_vocabulary_size),
  nrow = num_samples, ncol = max_length
)

# Synthetic targets: a random 0/1 matrix with one column per candidate answer.
# NOTE(review): rows are not guaranteed to contain exactly one 1, so this is
# only a stand-in for properly one-hot-encoded answers.
answers <- random_matrix(
  c(0L, 1L),
  nrow = num_samples, ncol = answer_vocabulary_size
)
# Train the two-input model. The inputs are passed as a named list whose
# names match the `name` given to each layer_input() ("text", "question").
# The closing parenthesis of the call was missing.
model %>% fit(
  list(text = text, question = question), answers,
  epochs = 10, batch_size = 128
)
# E.g.: a multi-output model that takes as input a series of social media posts from a single anonymous person and tries to predict attributes of that person, such as age, gender, and income level
# Number of distinct token indices in the posts, and number of income classes.
# These were two assignments on a single line with no separator, which does
# not parse.
vocabulary_size <- 50000
num_income_groups <- 10

# Variable-length integer sequence of tokens.
posts_input <- layer_input(
  shape = list(NULL), dtype = "int32", name = "posts"
)

# For layer_embedding(), `input_dim` is the vocabulary size and `output_dim`
# is the embedding width. The original had the two swapped.
# Token indices fed to the model must be < vocabulary_size.
embedded_posts <- posts_input %>%
  layer_embedding(input_dim = vocabulary_size, output_dim = 256)

# Shared 1D-convolutional trunk, reduced to a single 128-unit feature vector
# that all three prediction heads consume.
base_model <- embedded_posts %>%
  layer_conv_1d(filters = 128, kernel_size = 5, activation = "relu") %>%
  layer_max_pooling_1d(pool_size = 5) %>%
  layer_conv_1d(filters = 256, kernel_size = 5, activation = "relu") %>%
  layer_conv_1d(filters = 256, kernel_size = 5, activation = "relu") %>%
  layer_max_pooling_1d(pool_size = 5) %>%
  layer_conv_1d(filters = 256, kernel_size = 5, activation = "relu") %>%
  layer_conv_1d(filters = 256, kernel_size = 5, activation = "relu") %>%
  layer_global_max_pooling_1d() %>%
  layer_dense(units = 128, activation = "relu")

# Three output heads. The layer names ("age", "income", "gender") are the
# keys used to address per-output losses and targets in compile() and fit().
age_prediction <- base_model %>%
  layer_dense(units = 1, name = "age")
income_prediction <- base_model %>%
  layer_dense(units = num_income_groups, activation = "softmax", name = "income")
gender_prediction <- base_model %>%
  layer_dense(units = 1, activation = "sigmoid", name = "gender")

# Single-input, three-output model.
model <- keras_model(
  posts_input,
  list(age_prediction, income_prediction, gender_prediction)
)
# Configure training for the three-output model. Each output gets its own
# loss, keyed by the output layer's name, and `loss_weights` sets how much
# each loss contributes to the total that is minimized.
# The original was missing the parenthesis closing `loss`, the comma before
# `loss_weights`, and the parentheses closing `loss_weights` and compile().
model %>% compile(
  optimizer = "rmsprop",
  loss = list(
    age = "mse",
    income = "categorical_crossentropy",
    gender = "binary_crossentropy"
  ),
  loss_weights = list(
    age = 0.25, income = 1, gender = 10
  )
)
# Train the three-output model; targets are passed as a list keyed by output
# layer name. The closing parenthesis of the call was missing.
# NOTE(review): `posts`, `age_targets`, `income_targets` and `gender_targets`
# are not defined in the visible part of this script. They must be created
# before this call can run.
model %>% fit(
  posts,
  list(
    age = age_targets,
    income = income_targets,
    gender = gender_targets
  ),
  epochs = 10, batch_size = 64
)
# Stopping training once the validation loss is no longer improving can be achieved with a Keras callback. Typical uses of callbacks:
# - Model checkpointing: saving the current weights of the model at different points during training. See callback_model_checkpoint().
# - Early stopping: interrupting training when the validation loss is no longer improving (and saving the best model obtained during training). See callback_early_stopping().
# - Dynamically adjusting the value of certain parameters during training, such as the learning rate. See callback_learning_rate_scheduler() and callback_reduce_lr_on_plateau().
# - Logging training and validation metrics during training, or visualizing the representations learned by the model as they are updated. See callback_csv_logger().