Error with Custom Layer

Hello,

I am trying to write a custom normalization layer that works like batch normalization, except that it normalizes input by the running mean and variance rather than by the batch mean and variance (see Revisit Batch Normalization). Unfortunately, I have been getting one error after another trying to get it to work. Here is my code:

library(keras)

# Custom layer class.
# R6 subclass of KerasLayer. build() creates four weights (gamma, beta,
# mean_x1, mean_x2); call() centres the input across the last axis, optionally
# refreshes the running moments, and hands the result to
# k_batch_normalization().
# NOTE(review): this is the version that raises the "Shape must be rank 1 but
# is rank 0" error quoted below it; it is kept as posted.
FullNormalizationLayer <- R6::R6Class("KerasLayer",
  inherit = KerasLayer,
  public = list(
    # Normalization parameters.
    off_axes = NULL,   # every axis except the last; set in build()
    momentum = NULL,   # weight given to the previous running moments
    epsilon  = NULL,   # passed to k_batch_normalization() as `epsilon`
    gamma = NULL,      # trainable weight, initialised with Ones()
    beta  = NULL,      # trainable weight, initialised with Zeros()
    mean_x1 = NULL,    # non-trainable running first moment
    mean_x2 = NULL,    # non-trainable running second moment
    # Store the hyper-parameters; the weights are created later in build().
    initialize = function(momentum, epsilon = 1e-6) {
      self$momentum <- momentum
      self$epsilon <- epsilon
    },
    # Create the weights. Their shape is `input_shape` with every axis except
    # the last set to 1 (e.g. [1,1,1,64] in the error message below).
    build = function(input_shape) {
      param_shape <- input_shape
      self$off_axes <- 1:(length(input_shape)-1)
      param_shape[self$off_axes] <- 1
      
      # Output standard deviation across all samples.
      self$gamma <- self$add_weight(
        name = "gamma", shape = param_shape,
        initializer = Ones(), trainable = TRUE
      )
      
      # Output bias.
      self$beta <- self$add_weight(
        name = "beta", shape = param_shape,
        initializer = Zeros(), trainable = TRUE
      )
      
      # First moment of input across all samples.
      self$mean_x1 <- self$add_weight(
        name = "mean_x1", shape = param_shape,
        initializer = Zeros(), trainable = FALSE
      )
      
      # Second moment of input across all samples.
      self$mean_x2 <- self$add_weight(
        name = "mean_x2", shape = param_shape,
        initializer = Zeros(), trainable = FALSE
      )
    },
    # Forward pass. `training` falls back to the value of the Keras learning
    # phase when it is not supplied.
    call = function(x, training = NULL) {
      # Subtract off the mean across feature channels (lateral competition).
      x <- x - k_mean(x, axis = -1, keepdims = TRUE)
      
      # Deal with first and second moments if in the training phase.
      if (is.null(training)) training <- k_get_value(k_learning_phase())
      # NOTE(review): `training` can no longer be NULL here, so the
      # `is.null(training)` operand never changes the outcome.
      if (training | is.null(training)) {
        # Find the first and second moments across all non-feature axes.
        # NOTE(review): the field declared above is `off_axes`; `self$off_axis`
        # is not declared in this class, so `axis` is presumably NULL here.
        m1 <- k_mean(  x, axis = self$off_axis, keepdims = TRUE)
        m2 <- k_mean(x*x, axis = self$off_axis, keepdims = TRUE)
        
        # Combine moments from this batch with the running average across all samples.
        # NOTE(review): `first_norm` is neither declared in the public list nor
        # set in initialize() in this version.
        if (self$first_norm) {
          self$mean_x1 <- m1
          self$mean_x2 <- m2
          self$first_norm <- FALSE
        } else {
          self$mean_x1 <- self$momentum * self$mean_x1 + (1 - self$momentum) * m1
          self$mean_x2 <- self$momentum * self$mean_x2 + (1 - self$momentum) * m2
        }
      }
      
      # Perform final normalization.
      # Variance is derived as E[x^2] - (E[x])^2 from the two running moments.
      # The traceback quoted below points at this call: the backend's
      # batch_normalization fails while reshaping `mean`.
      k_batch_normalization(x,
                            mean = self$mean_x1,
                            var = self$mean_x2 - k_square(self$mean_x1),
                            beta = self$beta,
                            gamma = self$gamma,
                            axis = -1,
                            epsilon = self$epsilon)
    }
  )
)

# Pipe-friendly constructor for FullNormalizationLayer.
# `object` is the model or layer to attach to; the remaining arguments are
# collected into a named list and forwarded to create_layer().
layer_full_normalization <- function(object,
                                     momentum = 0.9,
                                     epsilon = 1e-6,
                                     name = NULL,
                                     trainable = TRUE) {
  layer_args <- list(
    momentum = momentum,
    epsilon = epsilon,
    name = name,
    trainable = trainable
  )
  create_layer(FullNormalizationLayer, object, layer_args)
}

# Build sequential model for CIFAR-10 classification.

# Append one conv -> full-normalization -> ReLU stage to `model`.
# `size` is the side of the square kernel. Layer names follow the pattern
# "<kind><size>_<filters>_<tag>", e.g. "conv3_64_A". Extra arguments in `...`
# (such as `input_shape`) are forwarded to layer_conv_2d().
add_conv_norm_relu <- function(model, filters, size, tag, ...) {
  suffix <- paste0(size, "_", filters, "_", tag)
  model %>%
    layer_conv_2d(filters, c(size, size), use_bias = FALSE,
                  padding = "same", name = paste0("conv", suffix), ...) %>%
    layer_full_normalization(name = paste0("norm", suffix)) %>%
    layer_activation("relu", name = paste0("relu", suffix))
}

# Two groups of four stages (64 then 128 filters), each followed by max
# pooling, then a 1x1 convolution to 10 channels, global average pooling and
# softmax.
full_norm_model <- keras_model_sequential() %>%
  add_conv_norm_relu(64, 3, "A", input_shape = dim(test_x)[2:4]) %>%
  add_conv_norm_relu(64, 1, "A") %>%
  add_conv_norm_relu(64, 3, "B") %>%
  add_conv_norm_relu(64, 1, "B") %>%
  layer_max_pooling_2d(name = "pool_64") %>%
  add_conv_norm_relu(128, 3, "A") %>%
  add_conv_norm_relu(128, 1, "A") %>%
  add_conv_norm_relu(128, 3, "B") %>%
  add_conv_norm_relu(128, 1, "B") %>%
  layer_max_pooling_2d(name = "pool_128") %>%
  layer_conv_2d(10, c(1, 1), name = "conv_10") %>%
  layer_global_average_pooling_2d(name = "pool_10") %>%
  layer_activation("softmax", name = "class_10")

and here is the error message I get:

Error in py_call_impl(callable, dots$args, dots$keywords) : 
  RuntimeError: Evaluation error: ValueError: Shape must be rank 1 but is rank 0 for 'norm3_64_A_14/Reshape' (op: 'Reshape') with input shapes: [1,1,1,64], [].

Detailed traceback: 
  File "/home/rstudio-user/.virtualenvs/r-tensorflow/lib/python2.7/site-packages/keras/backend/tensorflow_backend.py", line 1908, in batch_normalization
    mean = tf.reshape(mean, (-1))
  File "/home/rstudio-user/.virtualenvs/r-tensorflow/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 7715, in reshape
    "Reshape", tensor=tensor, shape=shape, name=name)
  File "/home/rstudio-user/.virtualenvs/r-tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 788, in _apply_op_helper
    op_def=op_def)
  File "/home/rstudio-user/.virtualenvs/r-tensorflow/lib/python2.7/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/home/rstudio-user/.virtualenvs/r-tensorflow/lib/python2.7/site
Called from: py_call_impl(callable, dots$args, dots$keywords)

For your reference, here is the code for defining the BatchNormalizationLayer for Keras in Python. It's a mess, but I have been trying to make sense of the essential parts and apply it to my code. The R Keras documentation for defining custom layers is also very sparse, and I haven't found anything on the internet that relates to my particular issue.

If you have any experience with custom layers (especially those that function differently between training phase and testing phase), please let me know if you see any basic errors in my code.

Thanks!

Update: I found a solution that seems to work. The dimensionality error was coming from the k_batch_normalization() call. (Perhaps it expects mean/variance/gamma/beta to have a different rank than x?) For anyone with a similar issue, here is the code that ended up working:

# Custom layer class.
# Normalizes its input with running first and second moments instead of the
# per-batch statistics used by batch normalization. The input is first centred
# across the feature (last) axis, then scaled by the running mean/variance and
# finally by the trainable gamma/beta.
FullNormalizationLayer <- R6::R6Class("KerasLayer",
  inherit = KerasLayer,
  public = list(
    # Normalization parameters.
    first_norm = NULL,  # TRUE until the first training-phase call has run
    off_axes = NULL,    # every axis except the last (feature) axis
    momentum = NULL,    # weight given to the previous running moments
    epsilon  = NULL,    # added to the variance before the square root
    gamma = NULL,       # trainable output scale
    beta  = NULL,       # trainable output bias
    mean_x1 = NULL,     # running first moment  E[x]
    mean_x2 = NULL,     # running second moment E[x^2]

    # Store the hyper-parameters; the weights are created later in build().
    initialize = function(momentum, epsilon = 1e-6) {
      # This has to go through `self`: a bare `first_norm <- TRUE` only creates
      # a local variable and leaves the field NULL.
      self$first_norm <- TRUE
      self$momentum <- momentum
      self$epsilon <- epsilon
    },

    # Create the weights. Their shape is `input_shape` with every axis except
    # the last set to 1, so they broadcast against the input.
    build = function(input_shape) {
      # seq_len() yields an empty vector for a rank-1 shape, where
      # 1:(n - 1) would count backwards as c(1, 0).
      self$off_axes <- seq_len(length(input_shape) - 1L)
      param_shape <- input_shape
      param_shape[self$off_axes] <- 1

      # Output standard deviation across all samples.
      self$gamma <- self$add_weight(
        name = "gamma", shape = param_shape,
        initializer = initializer_ones(), trainable = TRUE
      )

      # Output bias.
      self$beta <- self$add_weight(
        name = "beta", shape = param_shape,
        initializer = initializer_zeros(), trainable = TRUE
      )

      # First moment of input across all samples.
      self$mean_x1 <- self$add_weight(
        name = "mean_x1", shape = param_shape,
        initializer = initializer_zeros(), trainable = FALSE
      )

      # Second moment of input across all samples. Starting at one makes the
      # initial variance (mean_x2 - mean_x1^2) equal to one.
      self$mean_x2 <- self$add_weight(
        name = "mean_x2", shape = param_shape,
        initializer = initializer_ones(), trainable = FALSE
      )
    },

    # Forward pass. `training` falls back to the value of the Keras learning
    # phase when it is not supplied. Returns the normalized tensor.
    call = function(x, training = NULL) {
      # Subtract off the mean across feature channels (lateral competition).
      x <- x - k_mean(x, axis = -1, keepdims = TRUE)

      # Deal with first and second moments if in the training phase.
      if (is.null(training)) training <- k_get_value(k_learning_phase())
      if (isTRUE(as.logical(training))) {
        # Find the first and second moments across all non-feature axes.
        m1 <- k_mean(x, axis = self$off_axes, keepdims = TRUE)
        m2 <- k_mean(x * x, axis = self$off_axes, keepdims = TRUE)

        # Combine moments from this batch with the running average across all
        # samples.
        # NOTE(review): these assignments replace the weights created in
        # build() with new tensors instead of updating them in place, so the
        # running statistics are presumably not stored in the layer's weights.
        # Confirm, and consider registering moving-average updates instead.
        if (self$first_norm) {
          self$mean_x1 <- m1
          self$mean_x2 <- m2
          self$first_norm <- FALSE
        } else {
          keep <- self$momentum
          self$mean_x1 <- keep * self$mean_x1 + (1 - keep) * m1
          self$mean_x2 <- keep * self$mean_x2 + (1 - keep) * m2
        }
      }

      # Perform final normalization. Written out by hand rather than through
      # k_batch_normalization(), which raised a reshape error with these
      # full-rank statistics.
      variance <- self$mean_x2 - k_square(self$mean_x1)
      centred <- x - self$mean_x1
      self$gamma * centred / k_sqrt(variance + self$epsilon) + self$beta
    }
  )
)