Importing a torch tensor from a Python pickle file

A collaborator sent me a Python pickle file that contains, among other things, a PyTorch tensor. For example, say they ran this on their computer:

import torch
import pickle

exple = {'some_number': 1, 'some_tensor': torch.tensor([[1., 2., 3.]])}

with open('my_tensor.pkl', 'wb') as file:
    pickle.dump(exple, file)

Now, on my computer, I can load the file like this:

# read_pickle.py
import pickle

def read_pickle(filepath):
    with open(filepath, 'rb') as file:
        return pickle.load(file)

Then, in R:

reticulate::use_virtualenv("torch")
reticulate::source_python("read_pickle.py")

data <- read_pickle("my_tensor.pkl")

data
#> $some_number
#> [1] 1
#> 
#> $some_tensor
#> tensor([[1., 2., 3.]])

I can easily convert the non-Torch part to standard R objects:

as.data.frame(data[1])
#>   some_number
#> 1           1

but I can't seem to do the same for the tensor:

as.data.frame(data)
#> Error in as.data.frame.default(x[[i]], optional = TRUE) : 
#>   cannot coerce class ‘c("torch.Tensor", "torch._C.TensorBase", "python.builtin.object"’ to a data.frame

torch::as_array(data$some_tensor)
#> Error in UseMethod("as_array", x) : 
#>   no applicable method for 'as_array' applied to an object of class "c('torch.Tensor', 'torch._C.TensorBase', 'python.builtin.object')"

For that last one, it seems that a tensor created in R and a tensor imported from a pickle do not have the same class:

class(data$some_tensor)
#> [1] "torch.Tensor"          "torch._C.TensorBase"   "python.builtin.object"

t <- torch::torch_tensor(1:3)
class(t)
#> [1] "torch_tensor" "R7"  
torch::as_array(t)
#> [1] 1 2 3

So, is there an "easy" way to read that Torch tensor as an R object?

A not-very-satisfying workaround is to modify the Python reading function so that it converts Torch tensors to NumPy arrays before returning to R (reticulate then converts NumPy arrays to R arrays automatically):

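# read_pickle.py
import pickle
import torch

def read_pickle(filepath):
    with open(filepath, 'rb') as file:
        content = pickle.load(file)
    # Convert top-level tensors to NumPy arrays; detach() and cpu() also make this
    # work for tensors that require grad or live on a GPU
    return {k: (v.detach().cpu().numpy() if isinstance(v, torch.Tensor) else v)
            for k, v in content.items()}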
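The comprehension above only looks at the top level of the dict, so a tensor nested inside a list or a sub-dict would still reach R as a Python object. A possible generalization walks dicts, lists and tuples recursively. This is only a sketch: the helper name `to_numpy` is hypothetical, and it identifies tensors by duck typing (any object with `detach()`, `cpu()` and `numpy()` methods) instead of `isinstance(v, torch.Tensor)`, an assumption made so that the file does not have to import torch itself.

```python
# read_pickle.py (sketch)
import pickle


def to_numpy(obj):
    """Recursively replace tensor-like objects with NumPy arrays.

    Assumption: anything exposing detach(), cpu() and numpy() is treated as a
    torch tensor (duck typing), so this module never imports torch directly.
    """
    if hasattr(obj, "detach") and hasattr(obj, "cpu") and hasattr(obj, "numpy"):
        # detach() drops autograd history and cpu() copies GPU tensors to host
        # memory; numpy() can then return an array sharing that CPU storage.
        return obj.detach().cpu().numpy()
    if isinstance(obj, dict):
        return {k: to_numpy(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [to_numpy(v) for v in obj]
    if isinstance(obj, tuple):
        return tuple(to_numpy(v) for v in obj)
    # Numbers, strings, NumPy arrays, etc. are returned unchanged
    return obj


def read_pickle(filepath):
    with open(filepath, "rb") as file:
        content = pickle.load(file)
    return to_numpy(content)
```

Sourced from R with `reticulate::source_python("read_pickle.py")`, `read_pickle("my_tensor.pkl")$some_tensor` should then come back as an ordinary R matrix rather than a `torch.Tensor` proxy object.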
