Filtering a torch dataset

I'm trying to follow along with the book "Deep Learning with PyTorch", using the new R packages torch and torchvision.

In section 7.2.1 (page 173), I'm not sure how to filter this dataset to include only labels 1 and 3 (corresponding to labels 0 and 2 in the book).

This is my code. I'd like to know how to filter transformed_cifar10 the way the book does: keep only the observations whose transformed_cifar10$y label is 1 or 3, and then remap {1, 3} to {1, 2}.

library(dplyr)
library(torch)
library(torchvision)

data_path <- "./ch7/data" # directory where CIFAR-10 is downloaded; adjust as needed

train_transforms <- function (img) {
  img %>% 
    transform_to_tensor() %>% 
    transform_normalize(mean = c(0.4915, 0.4823, 0.4468),
                        std = c(0.2470, 0.2435, 0.2616))
}

transformed_cifar10 <- cifar10_dataset(data_path, 
                                       train = TRUE, 
                                       download = TRUE, 
                                       transform = train_transforms)
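As a minimal base-R sketch of the label logic on its own (no torch needed), here is the filter-and-remap step applied to a hypothetical toy label vector standing in for `transformed_cifar10$y`. `%in%` finds the airplane/bird positions, and `match()` maps 1 to 1 and 3 to 2:

```r
# Hypothetical toy labels standing in for transformed_cifar10$y
y <- c(1, 7, 3, 3, 10, 1, 2)

# Positions of airplanes (1) and birds (3)
keep <- which(y %in% c(1, 3))

# match() returns each label's position in c(1, 3), so 1 -> 1 and 3 -> 2
new_y <- match(y[keep], c(1, 3))

keep   # 1 3 4 6
new_y  # 1 2 2 1
```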

This is the Python code from the book:

label_map = {0: 0, 2: 1}
class_names = ['airplane', 'bird']
cifar2 = [(img, label_map[label])
          for img, label in cifar10
          if label in [0, 2]]

My first thought was something like the following, but it clearly doesn't work. Any ideas?

transformed_cifar10[transformed_cifar10$y == 1]
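As far as I can tell, `[` on a torch dataset is routed to its `.getitem()` method, which expects an index rather than a logical mask, so an attempt like this doesn't filter anything. One alternative, sketched here under the assumption that the labels live in a `y` field as they do in `cifar10_dataset`, is to compute the indices to keep and wrap the dataset with torch's `dataset_subset()`. The `toy_dataset` below is hypothetical and exists only so the example runs without downloading CIFAR-10:

```r
library(torch)

# Hypothetical toy dataset whose items look like cifar10_dataset's list(x, y)
toy_dataset <- dataset(
  name = "toy_dataset",
  initialize = function(y) {
    self$y <- y
  },
  .getitem = function(i) {
    list(x = array(0L, dim = c(32, 32, 3)), y = self$y[i])
  },
  .length = function() {
    length(self$y)
  }
)

ds <- toy_dataset(c(1, 7, 3, 3, 10, 1, 2))

# Indices of airplanes (1) and birds (3), then a wrapper exposing only those items
idx <- which(ds$y %in% c(1, 3))
airplanes_and_birds <- dataset_subset(ds, idx)

length(airplanes_and_birds)  # 4
airplanes_and_birds[2]$y     # 3 (the item at original index 3)
```

`dataset_subset()` only filters; the {1, 3} to {1, 2} remap would still have to happen elsewhere, for example inside a dataset subclass.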

AFAICT, the code in the book is doing something like this:

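# cifar10 is a cifar10_dataset() object, e.g.
# cifar10 <- cifar10_dataset(data_path, train = TRUE, download = TRUE)
cifar2 <- list()
for (i in seq_len(length(cifar10))) {
  obs <- cifar10[i]  # list(x = image, y = label)
  if (obs$y == 1) {         # airplane
    obs$y <- 0
    cifar2[[length(cifar2) + 1]] <- obs
  } else if (obs$y == 3) {  # bird
    obs$y <- 1
    cifar2[[length(cifar2) + 1]] <- obs
  }
}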

This creates a list where each element is itself a list containing an image and its label:

List of 10000
 $ :List of 2
  ..$ x: int [1:32, 1:32, 1:3] 164 167 140 102 73 69 76 63 64 52 ...
  ..$ y: num 1
 $ :List of 2
  ..$ x: int [1:32, 1:32, 1:3] 17 18 18 16 16 16 14 15 17 17 ...
  ..$ y: num 1
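If the items are already in an R list, the closest analogue of the book's list comprehension is `Filter()` followed by `lapply()`. A minimal sketch, using a hypothetical `items` list of tiny mock images shaped like the output above and a named `label_map` vector standing in for the book's dictionary:

```r
# Hypothetical mock items shaped like cifar10[i]: list(x = image array, y = label)
items <- lapply(c(1, 7, 3, 10), function(label) {
  list(x = array(0L, dim = c(2, 2, 3)), y = label)
})

# Named lookup vector playing the role of the book's label_map dict
label_map <- c("1" = 0, "3" = 1)  # airplane -> 0, bird -> 1

# Keep airplanes and birds, then relabel each kept item
cifar2 <- lapply(
  Filter(function(obs) obs$y %in% c(1, 3), items),
  function(obs) {
    obs$y <- unname(label_map[as.character(obs$y)])
    obs
  }
)

length(cifar2)                          # 2
vapply(cifar2, `[[`, numeric(1), "y")   # 0 1
```

Materializing a whole dataset into a list like this keeps every image in memory, which is one reason to prefer filtering inside the dataset itself.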

That said, in both Python and R it's probably more idiomatic to subclass the dataset, with something like:

cifar2_dataset <- torch::dataset(
  inherit = cifar10_dataset,
  initialize = function(...) {
    super$initialize(...) # creates the dataset 
    
    # filter obs
    is_airplane_or_bird <- self$y %in% c(1L, 3L)
    self$x <- self$x[is_airplane_or_bird,,,]
    self$y <- self$y[is_airplane_or_bird]
    
    # switch from 1 and 3 to 0 and 1
    self$y <- ifelse(self$y == 1, 0, 1) 
  }
)
d <- fs::dir_create(tempfile())
cifar2 <- cifar2_dataset(root = d, download = TRUE)
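To try the subclassing pattern without downloading CIFAR-10, here is the same idea applied to a hypothetical `toy_parent` dataset that stores `x` and `y` the way `cifar10_dataset` does. This variant remaps {1, 3} to {1, 2} with `match()`, as the question asked, rather than to {0, 1}; as far as I know, torch for R expects 1-based class indices in losses such as `nn_cross_entropy_loss()`, so the 1-based labels may be the safer choice there. `drop = FALSE` keeps `x` four-dimensional even if only one observation survives the filter.

```r
library(torch)

# Hypothetical parent dataset with cifar10_dataset's layout
# (x: N x 32 x 32 x 3 integer array, y: numeric labels)
toy_parent <- dataset(
  name = "toy_parent",
  initialize = function(y) {
    self$y <- y
    self$x <- array(0L, dim = c(length(y), 32, 32, 3))
  },
  .getitem = function(i) {
    list(x = self$x[i, , , ], y = self$y[i])
  },
  .length = function() {
    length(self$y)
  }
)

# Same pattern as cifar2_dataset, but remapping {1, 3} to {1, 2}
toy_cifar2 <- dataset(
  name = "toy_cifar2",
  inherit = toy_parent,
  initialize = function(...) {
    super$initialize(...)  # builds the full parent dataset
    keep <- self$y %in% c(1, 3)
    self$x <- self$x[keep, , , , drop = FALSE]
    self$y <- match(self$y[keep], c(1, 3))  # 1 -> 1, 3 -> 2
  }
)

ds <- toy_cifar2(c(1, 7, 3, 3, 10, 1, 2))
length(ds)  # 4
ds$y        # 1 2 2 1
```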

Thank you very much, I learnt a lot from your example.

