OK, so I solved this myself, with a little help from my friends. The trick is to use a custom
generator function that draws the same number of samples from each class for every batch, e.g. like so:
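balanced_generator = function(X_data, Y_data, batch_size){
  function(){
    # Half of each batch comes from each class (integer division, so odd sizes round down)
    half = batch_size %/% 2
    # Sample row indices per class with replacement, so the minority class can be oversampled
    i_0 = sample(x = which(Y_data == 0), size = half, replace = TRUE)
    i_1 = sample(x = which(Y_data == 1), size = half, replace = TRUE)
    # Interleave the two classes: 0, 1, 0, 1, ...
    i = c(rbind(i_0, i_1))
    # drop = FALSE keeps X a matrix even when it has a single column
    list(X_data[i, , drop = FALSE], Y_data[i])
  }
}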
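and then train the network using the fit_generator() function, passing it the generator returned by balanced_generator(). Since the generator never runs out of data, steps_per_epoch has to be set explicitly.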
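Before training, it can be worth checking that the batches really come out balanced. Here is a rough sketch in base R on made-up data: X_toy, Y_toy and the 10% positive rate are just assumptions for illustration, and the generator is restated (with integer division for the half-batch size) so the snippet runs on its own:

```r
# Hypothetical toy data: 1000 rows, 3 features, roughly 10% positives
set.seed(42)
n <- 1000
X_toy <- matrix(rnorm(n * 3), nrow = n, ncol = 3)
Y_toy <- rbinom(n, size = 1, prob = 0.1)

# Generator restated so this snippet is self-contained
balanced_generator <- function(X_data, Y_data, batch_size) {
  function() {
    half <- batch_size %/% 2
    i_0 <- sample(x = which(Y_data == 0), size = half, replace = TRUE)
    i_1 <- sample(x = which(Y_data == 1), size = half, replace = TRUE)
    i <- c(rbind(i_0, i_1))
    list(X_data[i, , drop = FALSE], Y_data[i])
  }
}

gen <- balanced_generator(X_toy, Y_toy, batch_size = 32)
batch <- gen()

dim(batch[[1]])    # rows = batch size, columns = number of features
table(batch[[2]])  # number of samples per class in this batch
mean(Y_toy)        # share of positives in the full toy data, for comparison
```

With batch_size = 32, each call to gen() should return 16 rows of each class, however skewed the full data set is. One thing to watch out for: if a class has only a single row, sample() treats that lone index as the upper bound of a range rather than as the only candidate, so the generator assumes at least two rows per class.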