Abort keras training on user demand in a shiny app without leaving the app

I would like to create a Shiny app in which users can train a keras model on their own data.

However, such training can take a long time, and depending on the settings the fit sometimes does not converge. So I would like users to be able to interrupt training smoothly on demand (e.g. by clicking a button) without having to close the app.

It is possible to stop training from a callback (Training Callbacks • keras).
However, my problem is that once training has started, the Abort button cannot be observed until training has ended.
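A callback can also check for a stop request after every batch rather than only at the start of each epoch. Below is a minimal sketch, assuming the request is signalled by the existence of a file. `AbortOnFlag` and `flag_file` are hypothetical names, not part of keras. On its own this does not solve the Shiny problem: while `fit()` blocks the R session, the Abort observer still cannot run, so the file has to be created from outside that session (for example by another R process).

```r
# Hypothetical callback class: requests a stop as soon as `flag_file` exists.
# R6 evaluates `inherit` lazily (when $new() is called), so keras must be
# installed by the time the callback is instantiated.
AbortOnFlag <- R6::R6Class("AbortOnFlag",
  inherit = keras::KerasCallback,
  public = list(
    flag_file = NULL,
    initialize = function(flag_file) {
      self$flag_file <- flag_file
    },
    # Runs after every batch, so the request is noticed mid-epoch.
    on_batch_end = function(batch, logs = list()) {
      if (file.exists(self$flag_file)) {
        self$model$stop_training <- TRUE
      }
    }
  )
)
```

It would be passed to `fit()` as `callbacks = list(AbortOnFlag$new(flag_file = "abort.flag"))`. As far as I know, recent TensorFlow 2.x versions check `stop_training` after each batch, so training ends without completing the current epoch. With older versions, the current epoch may still run to the end.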

So how can I detect that the user has clicked the Abort button?

One workaround:

Another solution:

  • I also investigated the use of future and promises,
    but I did not find a working solution (I got the error "'what' must be a function or character string"; see tensorflow - Parallelizing keras models in R using doParallel - Stack Overflow).
    Besides, keras is already multi-threaded, and I fear that running it inside a parallel worker will lead to unexpected behaviour (seeds, ...). It might work with a GPU backend for keras, but I don't want to rely on that because not all users have a GPU.

Here is a reproducible example:

library(keras)
library(tensorflow)
library(shiny)

batch_size <- 128
num_classes <- 10
epochs <- 5

# adapted from https://keras.rstudio.com/articles/examples/mnist_cnn.html
# Input image dimensions
img_rows <- 28
img_cols <- 28

# The data, shuffled and split between train and test sets
mnist <- dataset_mnist()
x_train <- mnist$train$x
y_train <- mnist$train$y
x_test <- mnist$test$x
y_test <- mnist$test$y

# Reshape train/test inputs to (samples, rows, cols, channels)
x_train <- array_reshape(x_train, c(nrow(x_train), img_rows, img_cols, 1))
x_test <- array_reshape(x_test, c(nrow(x_test), img_rows, img_cols, 1))
input_shape <- c(img_rows, img_cols, 1)

# Rescale grayscale pixel values from [0, 255] to [0, 1]
x_train <- x_train / 255
x_test <- x_test / 255

# Convert class vectors to binary class matrices
y_train <- to_categorical(y_train, num_classes)
y_test <- to_categorical(y_test, num_classes)

model <- keras_model_sequential() %>%
  layer_conv_2d(filters = 8, kernel_size = c(3,3), activation = 'relu',
                input_shape = input_shape) %>% 
  layer_conv_2d(filters = 8, kernel_size = c(3,3), activation = 'relu') %>% 
  layer_max_pooling_2d(pool_size = c(2, 2)) %>% 
  layer_dropout(rate = 0.5) %>% 
  layer_flatten() %>% 
  layer_dense(units = 16, activation = 'relu') %>% 
  layer_dropout(rate = 0.5) %>% 
  layer_dense(units = num_classes, activation = 'softmax')

model %>% compile(
  loss = loss_categorical_crossentropy,
  optimizer = optimizer_adadelta(),
  metrics = c('accuracy')
)

ui <- fluidPage(
  tags$div(actionButton("start", "Start training"),
           actionButton("abort", "Abort training", onclick="this.classList.add('clicked')"))
)

server <- function(input, output, session) {
  observeEvent(input$abort, { # this observer cannot be triggered while the model is training
    str(input$abort)
  })
  observeEvent(input$start, {
    keras::k_clear_session()
    
    # adapted from https://keras.rstudio.com/articles/training_callbacks.html
    LossHistory <- R6::R6Class("LossHistory",
                               inherit = KerasCallback,
                               public = list(
                                 losses = NULL,
                                 on_epoch_begin = function(epoch, logs = list()) {
                                   cat("callback epoch:", epoch + 1, "\n")
                                   if (epoch >= 2) { # hard-coded stand-in for "the user clicked Abort"
                                     self$model$stop_training <- TRUE # stops training before all epochs have run
                                     cat("training has been interrupted by user\nplease wait until the current epoch ends\n")
                                   }
                                 }
                               ))
    mycallback <- LossHistory$new()
    
    model %>% 
      fit(
        x_train, y_train,
        batch_size = batch_size,
        epochs = epochs,
        validation_split = 0.2,
        callbacks = list(mycallback),
        verbose = 1
      )
  })
}

shinyApp(ui = ui, server = server)
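One way around the blocking `fit()` call in the reprex above is to run training in a separate R process, so that the Shiny process stays free to handle the Abort button. The sketch below substitutes `callr::r_bg()` for future/promises. That is a different package, and choosing it is an assumption about what fits this use case. The abort request travels through a flag file, which a per-batch `callback_lambda()` checks. The model is a trimmed-down version of the MNIST CNN above. `train_in_background()`, `flag_file` and `log_file` are hypothetical names.

```r
library(shiny)

# Hypothetical paths shared by the Shiny process and the training process.
flag_file <- file.path(tempdir(), "abort_training.flag")
log_file  <- file.path(tempdir(), "training.log")

# Runs in a separate R session started by callr::r_bg(), so the Shiny
# process is not blocked while fit() is running.
train_in_background <- function(flag_file, epochs) {
  library(keras)
  mnist <- dataset_mnist()
  x_train <- array_reshape(mnist$train$x, c(nrow(mnist$train$x), 28, 28, 1)) / 255
  y_train <- to_categorical(mnist$train$y, 10)

  model <- keras_model_sequential() %>%
    layer_conv_2d(filters = 8, kernel_size = c(3, 3), activation = "relu",
                  input_shape = c(28, 28, 1)) %>%
    layer_flatten() %>%
    layer_dense(units = 10, activation = "softmax")
  model %>% compile(loss = "categorical_crossentropy",
                    optimizer = optimizer_adadelta(), metrics = "accuracy")

  # Checked after every batch: request a stop as soon as the flag file exists.
  abort_cb <- callback_lambda(on_batch_end = function(batch, logs) {
    if (file.exists(flag_file)) model$stop_training <- TRUE
  })

  history <- model %>% fit(x_train, y_train, batch_size = 128, epochs = epochs,
                           validation_split = 0.2, callbacks = list(abort_cb),
                           verbose = 2)
  history$metrics  # plain list of per-epoch metrics, sent back to the Shiny process
}

ui <- fluidPage(
  actionButton("start", "Start training"),
  actionButton("abort", "Abort training"),
  verbatimTextOutput("status")
)

server <- function(input, output, session) {
  proc   <- reactiveVal(NULL)  # callr background process, or NULL when idle
  status <- reactiveVal("idle")

  observeEvent(input$start, {
    req(is.null(proc()) || !proc()$is_alive())  # one run at a time
    if (file.exists(flag_file)) file.remove(flag_file)
    proc(callr::r_bg(train_in_background,
                     args = list(flag_file = flag_file, epochs = 5),
                     stdout = log_file, stderr = "2>&1"))
    status("training...")
  })

  # This observer can fire during training because fit() runs in another process.
  observeEvent(input$abort, {
    req(proc())
    file.create(flag_file)
    status("abort requested, stopping after the current batch...")
  })

  # Poll the background process once per second until it exits.
  observe({
    p <- proc()
    req(p)
    if (p$is_alive()) {
      invalidateLater(1000)
    } else {
      res <- tryCatch(p$get_result(), error = function(e) conditionMessage(e))
      status(paste(capture.output(str(res)), collapse = "\n"))
      proc(NULL)
    }
  })

  output$status <- renderText(status())
}

app <- shinyApp(ui, server)
if (interactive()) runApp(app)
```

Each click on Start launches a fresh R session, so every run pays the start-up cost of loading TensorFlow and the data. The child's console output is written to `log_file` instead of an unread pipe, which avoids the child blocking on a full output buffer.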
