Pass map_location=torch.device("cpu") to torch.load in the inference stage, so that a checkpoint saved on a GPU can be loaded on a machine without CUDA:
map_location=torch.device("cpu")
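
A minimal sketch of where the argument goes, assuming the checkpoint is a state_dict written with torch.save. The nn.Linear model and the temporary checkpoint path are hypothetical stand-ins and not taken from the original code. The checkpoint here is saved on the CPU, so the example only shows the call itself.

```python
import os
import tempfile

import torch
from torch import nn

# Hypothetical stand-in model; the real architecture comes from the training code.
model = nn.Linear(4, 2)

# Hypothetical checkpoint path, used only for this sketch.
ckpt_path = os.path.join(tempfile.mkdtemp(), "model.pt")
torch.save(model.state_dict(), ckpt_path)

# Inference stage: remap every stored tensor to the CPU while loading.
state_dict = torch.load(ckpt_path, map_location=torch.device("cpu"))

# Rebuild the model and restore its weights from the remapped state_dict.
inference_model = nn.Linear(4, 2)
inference_model.load_state_dict(state_dict)
inference_model.eval()

# Run a forward pass on the CPU without tracking gradients.
with torch.no_grad():
    output = inference_model(torch.zeros(1, 4))
```

With a checkpoint that was saved from a GPU, the same call places the tensors on the CPU, where torch.load would otherwise try to restore them to the CUDA device they were saved from.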