Custom Metric Using the Full Validation Set

I have created a model using Keras and would like R2 on the validation set to be reported after each epoch. To do so, I create a metric with `custom_metric()` and pass it to the `metrics` argument of `compile()`. However, when I compare the metric reported for the final epoch with an R2 calculated manually on the full validation set, the two values differ.

Below is an example of what I am trying to do.

```r
library(keras)

use_session_with_seed(seed = 1)

# Get Boston Housing Data
boston_housing <- dataset_boston_housing()

c(train_x, train_y) %<-% boston_housing$train
c(val_x, val_y) %<-% boston_housing$test

# Custom R2 Metric
metric_r2 <- custom_metric("r2", function(y_true, y_pred) {
  ss_res <- k_sum(k_square(y_true - y_pred))
  ss_tot <- k_sum(k_square(y_true - k_mean(y_true)))

  1 - ss_res / ss_tot
})

# Build Model
model <- keras_model_sequential()

model %>%
  layer_dense(units = 64, activation = "relu",
              input_shape = dim(train_x)[2]) %>%
  layer_dense(units = 64, activation = "relu") %>%
  layer_dense(units = 1)

model %>% compile(
  loss = "mse",
  optimizer = optimizer_rmsprop(),
  metrics = list(metric_r2)  # <- Notice custom metric
)

# Fit model
history <- model %>% fit(
  train_x,
  train_y,
  epochs = 100,
  validation_data = list(val_x, val_y),
  verbose = 0
)

# Predict on Full Validation Data
pred <- model %>% predict(val_x)

ss_res <- sum((val_y - pred[, 1])^2)
ss_tot <- sum((val_y - mean(val_y))^2)

paste("Full Data R2 =", round(1 - ss_res / ss_tot, 3))  # Will print "Full Data R2 = 0.202"
paste("Custom Metric R2 =", round(history$metrics$val_r2[100], 3))  # Will print "Custom Metric R2 = 0.085"
```

According to the accepted answer in this link, Keras calculates the metric on the validation set batch by batch and then reports the average across all batches. For R2 this is not appropriate: each batch computes `ss_tot` around its own mean of `y_true`, so the average of the per-batch R2 values is not the R2 of the whole validation set, and this must be what is causing the difference. In the last answer in the same link, "GuangshengZuo" offers a way to circumvent this behavior in Python.
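To illustrate why batch averaging changes the number, here is a small base-R sketch on synthetic data (not the Boston housing model above). It assumes the validation set is split into batches of 32 consecutive rows (the default `fit()` batch size) and that the per-batch values are combined with a batch-size-weighted mean; `r2()` and `batched_r2()` are hypothetical helper names used only for this illustration.

```r
# Synthetic illustration only: compares R2 on the full set with a
# batch-size-weighted mean of per-batch R2 values (assumed Keras-like averaging).

# R2 computed on whatever observations are passed in
r2 <- function(y_true, y_pred) {
  ss_res <- sum((y_true - y_pred)^2)
  ss_tot <- sum((y_true - mean(y_true))^2)
  1 - ss_res / ss_tot
}

# Hypothetical helper: R2 per batch of consecutive rows, then weighted mean
batched_r2 <- function(y_true, y_pred, batch_size = 32) {
  # Batch index for each observation; the last batch may be smaller
  batch_id <- ceiling(seq_along(y_true) / batch_size)
  per_batch <- sapply(split(seq_along(y_true), batch_id), function(idx) {
    r2(y_true[idx], y_pred[idx])
  })
  # Weight each batch by its number of observations
  sizes <- as.numeric(table(batch_id))
  weighted.mean(per_batch, sizes)
}

set.seed(1)
n <- 102  # same size as the Boston housing test split
y_true <- rnorm(n, mean = 22, sd = 9)
y_pred <- y_true + rnorm(n, sd = 7)

full <- r2(y_true, y_pred)
batched <- batched_r2(y_true, y_pred)
cat("Full-set R2:", round(full, 3), "\n")
cat("Batch-averaged R2:", round(batched, 3), "\n")
```

With 102 observations the last batch holds only 6 rows, so its R2 is very noisy and pulls the average away from the full-set value. Calling `batched_r2(y_true, y_pred, batch_size = length(y_true))` puts everything in one batch, and the two numbers then agree.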

My question is: how can I do this in R?

Any help is much appreciated!