I'm trying to deploy a PyTorch model to Connect as an API with vetiver and getting stuck. I think the issue is that my deployed API doesn't 'have' the class from which the model was created. This is all being done with the Bank Marketing Dataset from Kaggle.
The model itself is created from a class that inherits from `torch.nn.Module`, which seems to be a common way to do this:
```python
from torch import nn


class BankModel(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        # Small feed-forward network: input -> 16 hidden units -> 1 probability
        self.linear_stack = nn.Sequential(
            nn.Linear(input_size, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # One value in (0, 1) per input row
        return self.linear_stack(x)
```
I then save the trained model's weights (its state dict) to disk with `torch.save`. In the deployment script I read the model from disk, pin it to Connect, and then deploy an API as described in the vetiver docs. This is the relevant section of the code:
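```python
import os

import torch
import vetiver

# input_size, model_filepath, model_name, board, connect_server,
# req_filepath, cert_filepath and api_title are defined earlier in the script.

# Rebuild the network from the class definition, then load the saved weights
bank_model = BankModel(input_size=input_size).to("cpu")
bank_model.load_state_dict(torch.load(model_filepath, map_location="cpu"))

# Wrap the model and write it as a versioned pin on the board
v = vetiver.VetiverModel(bank_model, model_name=model_name, versioned=True)
vetiver.vetiver_pin_write(board, v)

# Pin versions are timestamp-prefixed strings, so the last one after sorting
# is the newest; deploy_rsconnect expects a single version, not a list
latest_version = sorted(board.pin_versions(model_name)["version"].to_list())[-1]

# Reuse an existing content item when APP_ID is set
app_id = os.getenv("APP_ID")
app_id = None if app_id is None else int(app_id)

vetiver.deploy_rsconnect(
    connect_server=connect_server,
    board=board,
    pin_name=model_name,
    version=latest_version,
    extra_files=[req_filepath, cert_filepath],
    new=False,
    app_id=app_id,
    title=api_title,
)
```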
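A minimal, torch-free sketch of what gets stored, assuming (not verified here) that `vetiver_pin_write` saves the model object itself as a joblib file, which is pickle-based. The `BankModel` below is a hypothetical plain-Python stand-in for the `nn.Module` subclass. The point it illustrates: a pickle holds the instance's attributes plus a reference to the class by module and name, but none of the class's code. When the defining script is run directly, that module is `__main__`.

```python
import pickle
import pickletools


class BankModel:
    """Hypothetical stand-in for the nn.Module subclass (no torch needed)."""

    def __init__(self, input_size):
        self.input_size = input_size  # instance state: stored in the pickle

    def forward(self, x):
        # Method code lives on the class and is NOT stored in the pickle
        return x


payload = pickle.dumps(BankModel(input_size=20))

# The class is recorded only as "<module>.<qualname>",
# e.g. "__main__ BankModel" when this file is run as a script
print(BankModel.__module__, BankModel.__qualname__)

# Disassemble the pickle: shows the module/class name strings and the instance dict
pickletools.dis(payload)

# Loading works here only because this same process can look up BankModel
restored = pickle.loads(payload)
print(restored.input_size)
```

Unpickling therefore only succeeds in a process where `<module>.BankModel` can be looked up by name.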
Everything deploys, but when I try to access the API I get this error:
Unexpected error while running Python API: Can't get attribute 'BankModel' on <module '__main__' from '/opt/rstudio-connect/python/connect_fastapi_runtime.py'>.
I'm much less familiar with Python than with R, so it's entirely possible that I'm making an obvious mistake, but as far as I can tell I'm sending the model up to Connect without the custom class definition needed to use the model object. I can't figure out how to provide that class, though.
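A minimal, torch-free reproduction of what the error message describes, as a sketch under the assumption that the pinned model is unpickled by a separate server process (the `connect_fastapi_runtime.py` path in the error suggests that process's `__main__` is a different file). It runs hypothetical scripts as fresh Python processes: `write_inline.py` defines a stand-in `BankModel` in its own `__main__` and pickles it, `write_module.py` imports the same class from a hypothetical `bank_model.py` module and pickles it, and `read.py` plays the role of the API process. Nothing here is vetiver- or Connect-specific.

```python
import pathlib
import subprocess
import sys
import tempfile
import textwrap

# Hypothetical module file holding the class (bank_model.py)
MODULE = textwrap.dedent("""
    class BankModel:
        def __init__(self, input_size):
            self.input_size = input_size
""")

# Writer that defines the class inline, so pickle records "__main__.BankModel"
WRITE_INLINE = textwrap.dedent("""
    import pickle, sys

    class BankModel:
        def __init__(self, input_size):
            self.input_size = input_size

    with open(sys.argv[1], "wb") as f:
        pickle.dump(BankModel(input_size=20), f)
""")

# Writer that imports the class, so pickle records "bank_model.BankModel"
WRITE_FROM_MODULE = textwrap.dedent("""
    import pickle, sys
    from bank_model import BankModel

    with open(sys.argv[1], "wb") as f:
        pickle.dump(BankModel(input_size=20), f)
""")

# A different program (its own __main__) that only loads the pickle
READ = textwrap.dedent("""
    import pickle, sys

    with open(sys.argv[1], "rb") as f:
        model = pickle.load(f)
    print(type(model).__module__, model.input_size)
""")


def run(workdir, name, source, *args):
    """Write `source` to workdir/name and run it in a fresh Python process."""
    script = workdir / name
    script.write_text(source)
    return subprocess.run(
        [sys.executable, str(script), *args],
        capture_output=True,
        text=True,
        cwd=workdir,
    )


with tempfile.TemporaryDirectory() as tmp:
    workdir = pathlib.Path(tmp)
    (workdir / "bank_model.py").write_text(MODULE)

    # Class lived in the writer's __main__; the reader's __main__ has no BankModel
    run(workdir, "write_inline.py", WRITE_INLINE, "inline.pkl")
    inline = run(workdir, "read.py", READ, "inline.pkl")

    # Class lives in an importable module; pickle imports bank_model on load
    run(workdir, "write_module.py", WRITE_FROM_MODULE, "module.pkl")
    module = run(workdir, "read.py", READ, "module.pkl")

print("inline:", inline.returncode, inline.stderr.strip().splitlines()[-1])
print("module:", module.returncode, module.stdout.strip())
```

The inline pickle fails with an `AttributeError` naming `BankModel`, because the reader's `__main__` has no such class; the module-based pickle loads because `pickle` can import `bank_model` on its own. Whether shipping such a module alongside the app (for example through `extra_files`) is enough for the vetiver-generated API to find it is an assumption this sketch does not test.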