TLDR: The default dtype of a NumPy float array is float64 (double). A tensor created from that array with torch.as_tensor adopts that dtype, which is incompatible with float32, the default dtype of a neural network model. Passing the tensor to the model then raises the error "expected scalar type Float but found Double".
Observation
import torch
from torch.nn import Linear
import numpy as np
model = Linear(3,1)
input = np.array([[3.14,3,3],[1,2,3]])
t_input = torch.as_tensor(input)
model(t_input)
After creating a neural network model and a seemingly appropriate tensor, passing the tensor to the model raises the following error:
RuntimeError: expected scalar type Float but found Double
Resolution
One of the first Google results suggests converting the model to the tensor's datatype, and that indeed works:
model.double()
model(t_input)
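To see what model.double() actually changes, one can inspect the parameter dtype before and after the call. This is only a minimal sketch that reuses the Linear(3, 1) model from above; the variable names dtype_before, dtype_after and output are my own, and it assumes the torch default dtype has not been changed beforehand.

```python
import numpy as np
import torch
from torch.nn import Linear

model = Linear(3, 1)
dtype_before = model.weight.dtype  # parameters start out in torch's default dtype

# Tensor built from a NumPy array, so it carries NumPy's float64
t_input = torch.as_tensor(np.array([[3.14, 3, 3], [1, 2, 3]]))

model.double()                     # casts the floating-point parameters to float64
dtype_after = model.weight.dtype

output = model(t_input)            # model and input now agree on float64
print(dtype_before, dtype_after, output.dtype)
```

Note that this makes the whole model run in double precision, which is usually slower and uses more memory than float32.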
But what is the root problem?
Upon further inspection you will notice that the tensor created from the NumPy array has the dtype torch.float64. This is because torch.as_tensor() infers the dtype from the data source (pytorch doc), and the default dtype of a NumPy float array is indeed float64 (numpy doc).
t_input.dtype
# torch.float64
input.dtype
# dtype('float64')
Further Remarks
Let’s switch perspectives! In the torch module the default dtype is float32 (pytorch doc), and it can be changed with torch.set_default_dtype(). This default determines the dtype used when creating a neural network or a floating-point tensor from plain Python data. A tensor created from a list of floats is therefore a compatible input for a neural network created under the same default. However, if you create a tensor from existing data such as a NumPy array using torch.as_tensor(), it inherits the dtype of that data.
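The difference between the two creation paths can be checked directly. The following is a small sketch under the assumption that the default dtype is still float32 when it starts; the names from_list, from_numpy and from_list_f64 are only illustrative.

```python
import numpy as np
import torch

default_dtype = torch.get_default_dtype()  # float32 unless changed elsewhere

values = [[3.14, 3, 3], [1, 2, 3]]

from_list = torch.tensor(values)                 # Python floats -> torch default dtype
from_numpy = torch.as_tensor(np.array(values))   # keeps NumPy's float64

# Changing the default affects tensors created from Python data afterwards
torch.set_default_dtype(torch.float64)
from_list_f64 = torch.tensor(values)
torch.set_default_dtype(default_dtype)           # restore the previous default

print(from_list.dtype, from_numpy.dtype, from_list_f64.dtype)
```

Restoring the default at the end keeps the change from leaking into code that runs later in the same session.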
Practically speaking, you could change the dtype of your NumPy array to float32; however, you would lose precision this way.
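Casting to float32 can happen either on the NumPy side or while creating the tensor. A hedged sketch of both options follows; the variable names data, t_a and t_b are mine, and the dtype argument of torch.as_tensor is used for the second variant.

```python
import numpy as np
import torch
from torch.nn import Linear

model = Linear(3, 1)
data = np.array([[3.14, 3, 3], [1, 2, 3]])  # float64 by default

# Option 1: cast the array before handing it to torch
t_a = torch.as_tensor(data.astype(np.float32))

# Option 2: ask as_tensor for the target dtype directly
t_b = torch.as_tensor(data, dtype=torch.float32)

out_a = model(t_a)  # both calls succeed because the dtypes match the model
out_b = model(t_b)
print(t_a.dtype, t_b.dtype)
```

Both options copy the data, since a float64 buffer cannot be reinterpreted as float32 without conversion.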
Additionally, calling torch.Tensor(input) causes no problems; however, that constructor implicitly converts the data to float32 (the default dtype), so it loses precision as well.
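The precision loss can be made visible by comparing the float32 tensor against the float64 original. This is a rough sketch assuming the default dtype is float32; t32, t64 and diff are illustrative names.

```python
import numpy as np
import torch

data = np.array([[3.14, 3, 3], [1, 2, 3]])

t32 = torch.Tensor(data)      # legacy constructor, converts to the default dtype
t64 = torch.as_tensor(data)   # keeps float64

# Upcast the float32 values again and measure how far they drifted
diff = (t32.double() - t64).abs()

print(t32.dtype, t64.dtype)
print(diff.max().item())      # small but nonzero: 3.14 has no exact float32 representation
```

The small integers in the array survive the conversion exactly; only values such as 3.14 that need more mantissa bits are rounded.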
Further Readings
Further googling also turns up exactly these findings in this SO post.