Hello! I have compiled the master branch of torchvision and I am using the pre-built libtorch library. I managed to run the simple HelloWorld example with ResNet18, but I get an "Unhandled exception at 0x00007FFDB7D2A799" error when using the VGG16 network. It fails in both Release and Debug configurations (while ResNet works in both). Any ideas? Thanks.
#include <cstdlib>   // std::system
#include <iostream>
#include <torchvision/models/vgg.h>
//#include <torchvision/models/resnet.h>
int main()
{
auto model = vision::models::VGG16();
//auto model = vision::models::ResNet18();
model->eval();
// Create a random input tensor and run it through the model.
//auto in = torch::rand({ 1, 3, 10, 10 });
auto in = torch::rand({ 10, 3, 224, 224 });
auto out = model->forward(in);
std::cout << out;
std::system("pause");  // keep the console window open (Windows only)
}
By the way, as a sanity check, the equivalent Python code works:
import torch
import torchvision.models as models
vgg16 = models.vgg16(pretrained=False)
vgg16.eval()
in_tensor = torch.rand(size=(10, 3, 224, 224))
out_tensor = vgg16.forward(in_tensor)
print(out_tensor)
Windows 10
Visual Studio 2017
Pre-built libtorch
torchvision built from the master branch
CMake 3.16
OK, since I couldn't fix this bug, I simply exported VGG16 directly from Python following the official TorchScript tutorial https://pytorch.org/tutorials/advanced/cpp_export.html . I used the "tracing" method; the only thing I added was a call to model.eval() in Python before tracing.
Feel free to close the issue, thanks.
Hi,
Thanks for the report!
I'm not sure what the cause of the issue could be with the C++ API. Maybe @glaringlee has an idea?
@ntatsisk
Which libtorch version were you using? Do you have an error log?
It seems the C++ code is trying to access an invalid memory address.
I will try to reproduce this on Linux first.
@peterjc123 Can you help reproduce this on Windows? Thanks.
I get a similar error with VGG19.
VGG19 model;
model->eval();
auto in = torch::rand({ 10, 3, 224, 224 });
auto out = model->forward(in);
In Debug mode, this snippet crashes in TensorBody.h:
int64_t dim() const {
return impl_->dim();
}
This function is reached from VGGImpl's forward method, when it calls features->forward(x):
torch::Tensor VGGImpl::forward(torch::Tensor x) {
x = features->forward(x);
x = torch::adaptive_avg_pool2d(x, { 7, 7 });
x = x.view({ x.size(0), -1 });
x = classifier->forward(x);
return x;
}
I am working with libtorch 1.5 + CUDA 10.1 on Windows 10 with Visual Studio 2017. I took the VGG19 implementation from the winbuild\0.6.0 branch of the torchvision repository.
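For reference, the torchvision forward above pools the features to 7×7 with adaptive_avg_pool2d before flattening, while the custom model later in this thread hard-codes 25088. Assuming the standard VGG layout (3×3 convolutions with padding 1 and stride 1, five 2×2 max pools with stride 2), the feature map of a 224×224 input is already 7×7, so the adaptive pool leaves it unchanged and both versions feed 512 × 7 × 7 = 25088 values per sample to the classifier. The standard-C++ sketch below does that arithmetic; the helper names are hypothetical, not libtorch API.

```cpp
#include <cassert>
#include <cstdint>

// Output size of a convolution or pooling window along one axis
// (dilation 1 assumed): (in + 2 * padding - kernel) / stride + 1.
constexpr std::int64_t window_out(std::int64_t in, std::int64_t kernel,
                                  std::int64_t stride, std::int64_t padding) {
  return (in + 2 * padding - kernel) / stride + 1;
}

// Spatial size after the VGG feature extractor. A 3x3 conv with padding 1 and
// stride 1 preserves the size, so the number of convs per block does not
// matter here; each of the five blocks ends with a 2x2 max pool, stride 2.
constexpr std::int64_t vgg_feature_size(std::int64_t in) {
  for (int block = 0; block < 5; ++block) {
    in = window_out(in, 3, 1, 1);  // conv: 3x3, stride 1, padding 1
    in = window_out(in, 2, 2, 0);  // max pool: 2x2, stride 2
  }
  return in;
}

// Values per sample after flattening 512 channels of size s x s.
constexpr std::int64_t flattened_features(std::int64_t s) {
  return 512 * s * s;
}
```

For example, vgg_feature_size(224) goes 224 → 112 → 56 → 28 → 14 → 7.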
@AndoniC Could you please post the complete stacktrace?
This is the stacktrace in my case:
KernelBase.dll!00007ffc2cc83e49() Unknown
vcruntime140d.dll!00007ffc05b47ec7() Unknown
c10.dll!c10::UndefinedTensorImpl::dim() Line 24 C++
> torch_cpu.dll!at::Tensor::dim() Line 108 C++
torch_cpu.dll!at::Tensor::ndimension() Line 202 C++
torch_cpu.dll!at::native::check_shape_forward(const at::Tensor & input, const c10::ArrayRef<__int64> & weight_sizes, const at::Tensor & bias, const at::native::ConvParams & params, bool input_is_mkldnn) Line 394 C++
torch_cpu.dll!at::native::_convolution(const at::Tensor & input_r, const at::Tensor & weight_r, const at::Tensor & bias_r, c10::ArrayRef<__int64> stride_, c10::ArrayRef<__int64> padding_, c10::ArrayRef<__int64> dilation_, bool transposed_, c10::ArrayRef<__int64> output_padding_, __int64 groups_, bool benchmark, bool deterministic, bool cudnn_enabled) Line 597 C++
torch_cpu.dll!at::TypeDefault::_convolution(const at::Tensor & input, const at::Tensor & weight, const at::Tensor & bias, c10::ArrayRef<__int64> stride, c10::ArrayRef<__int64> padding, c10::ArrayRef<__int64> dilation, bool transposed, c10::ArrayRef<__int64> output_padding, __int64 groups, bool benchmark, bool deterministic, bool cudnn_enabled) Line 774 C++
torch_cpu.dll!torch::autograd::VariableType::_convolution(const at::Tensor & input, const at::Tensor & weight, const at::Tensor & bias, c10::ArrayRef<__int64> stride, c10::ArrayRef<__int64> padding, c10::ArrayRef<__int64> dilation, bool transposed, c10::ArrayRef<__int64> output_padding, __int64 groups, bool benchmark, bool deterministic, bool cudnn_enabled) Line 769 C++
torch_cpu.dll!c10::detail::WrapRuntimeKernelFunctor_<at::Tensor (__cdecl*)(at::Tensor const &,at::Tensor const &,at::Tensor const &,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,bool,c10::ArrayRef<__int64>,__int64,bool,bool,bool),at::Tensor,c10::guts::typelist::typelist<at::Tensor const &,at::Tensor const &,at::Tensor const &,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,bool,c10::ArrayRef<__int64>,__int64,bool,bool,bool> >::operator()(const at::Tensor & <args_0>, const at::Tensor & <args_1>, const at::Tensor & <args_2>, c10::ArrayRef<__int64> <args_3>, c10::ArrayRef<__int64> <args_4>, c10::ArrayRef<__int64> <args_5>, bool <args_6>, c10::ArrayRef<__int64> <args_7>, __int64 <args_8>, bool <args_9>, bool <args_10>, bool <args_11>) Line 23 C++
torch_cpu.dll!c10::detail::wrap_kernel_functor_unboxed_<c10::detail::WrapRuntimeKernelFunctor_<at::Tensor (__cdecl*)(at::Tensor const &,at::Tensor const &,at::Tensor const &,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,bool,c10::ArrayRef<__int64>,__int64,bool,bool,bool),at::Tensor,c10::guts::typelist::typelist<at::Tensor const &,at::Tensor const &,at::Tensor const &,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,bool,c10::ArrayRef<__int64>,__int64,bool,bool,bool> >,at::Tensor __cdecl(at::Tensor const &,at::Tensor const &,at::Tensor const &,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,bool,c10::ArrayRef<__int64>,__int64,bool,bool,bool)>::call(c10::OperatorKernel * functor, const at::Tensor & <args_0>, const at::Tensor & <args_1>, const at::Tensor & <args_2>, c10::ArrayRef<__int64> <args_3>, c10::ArrayRef<__int64> <args_4>, c10::ArrayRef<__int64> <args_5>, bool <args_6>, c10::ArrayRef<__int64> <args_7>, __int64 <args_8>, bool <args_9>, bool <args_10>, bool <args_11>) Line 276 C++
torch_cpu.dll!c10::KernelFunction::callUnboxed<at::Tensor,at::Tensor const &,at::Tensor const &,at::Tensor const &,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,bool,c10::ArrayRef<__int64>,__int64,bool,bool,bool>(const c10::OperatorHandle & opHandle, const at::Tensor & <args_0>, const at::Tensor & <args_1>, const at::Tensor & <args_2>, c10::ArrayRef<__int64> <args_3>, c10::ArrayRef<__int64> <args_4>, c10::ArrayRef<__int64> <args_5>, bool <args_6>, c10::ArrayRef<__int64> <args_7>, __int64 <args_8>, bool <args_9>, bool <args_10>, bool <args_11>) Line 66 C++
torch_cpu.dll!c10::Dispatcher::callUnboxedWithDispatchKey<at::Tensor,at::Tensor const &,at::Tensor const &,at::Tensor const &,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,bool,c10::ArrayRef<__int64>,__int64,bool,bool,bool>(const c10::OperatorHandle & op, c10::DispatchKey dispatchKey, const at::Tensor & <args_0>, const at::Tensor & <args_1>, const at::Tensor & <args_2>, c10::ArrayRef<__int64> <args_3>, c10::ArrayRef<__int64> <args_4>, c10::ArrayRef<__int64> <args_5>, bool <args_6>, c10::ArrayRef<__int64> <args_7>, __int64 <args_8>, bool <args_9>, bool <args_10>, bool <args_11>) Line 221 C++
torch_cpu.dll!c10::Dispatcher::callUnboxed<at::Tensor,at::Tensor const &,at::Tensor const &,at::Tensor const &,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,bool,c10::ArrayRef<__int64>,__int64,bool,bool,bool>(const c10::OperatorHandle & op, const at::Tensor & <args_0>, const at::Tensor & <args_1>, const at::Tensor & <args_2>, c10::ArrayRef<__int64> <args_3>, c10::ArrayRef<__int64> <args_4>, c10::ArrayRef<__int64> <args_5>, bool <args_6>, c10::ArrayRef<__int64> <args_7>, __int64 <args_8>, bool <args_9>, bool <args_10>, bool <args_11>) Line 229 C++
torch_cpu.dll!c10::OperatorHandle::callUnboxed<at::Tensor,at::Tensor const &,at::Tensor const &,at::Tensor const &,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,bool,c10::ArrayRef<__int64>,__int64,bool,bool,bool>(const at::Tensor & <args_0>, const at::Tensor & <args_1>, const at::Tensor & <args_2>, c10::ArrayRef<__int64> <args_3>, c10::ArrayRef<__int64> <args_4>, c10::ArrayRef<__int64> <args_5>, bool <args_6>, c10::ArrayRef<__int64> <args_7>, __int64 <args_8>, bool <args_9>, bool <args_10>, bool <args_11>) Line 192 C++
torch_cpu.dll!at::_convolution(const at::Tensor & input, const at::Tensor & weight, const at::Tensor & bias, c10::ArrayRef<__int64> stride, c10::ArrayRef<__int64> padding, c10::ArrayRef<__int64> dilation, bool transposed, c10::ArrayRef<__int64> output_padding, __int64 groups, bool benchmark, bool deterministic, bool cudnn_enabled) Line 2970 C++
torch_cpu.dll!at::native::convolution(const at::Tensor & input, const at::Tensor & weight, const at::Tensor & bias, c10::ArrayRef<__int64> stride, c10::ArrayRef<__int64> padding, c10::ArrayRef<__int64> dilation, bool transposed, c10::ArrayRef<__int64> output_padding, __int64 groups) Line 544 C++
torch_cpu.dll!at::TypeDefault::convolution(const at::Tensor & input, const at::Tensor & weight, const at::Tensor & bias, c10::ArrayRef<__int64> stride, c10::ArrayRef<__int64> padding, c10::ArrayRef<__int64> dilation, bool transposed, c10::ArrayRef<__int64> output_padding, __int64 groups) Line 753 C++
torch_cpu.dll!torch::autograd::VariableType::convolution(const at::Tensor & input, const at::Tensor & weight, const at::Tensor & bias, c10::ArrayRef<__int64> stride, c10::ArrayRef<__int64> padding, c10::ArrayRef<__int64> dilation, bool transposed, c10::ArrayRef<__int64> output_padding, __int64 groups) Line 3917 C++
torch_cpu.dll!c10::detail::WrapRuntimeKernelFunctor_<at::Tensor (__cdecl*)(at::Tensor const &,at::Tensor const &,at::Tensor const &,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,bool,c10::ArrayRef<__int64>,__int64),at::Tensor,c10::guts::typelist::typelist<at::Tensor const &,at::Tensor const &,at::Tensor const &,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,bool,c10::ArrayRef<__int64>,__int64> >::operator()(const at::Tensor & <args_0>, const at::Tensor & <args_1>, const at::Tensor & <args_2>, c10::ArrayRef<__int64> <args_3>, c10::ArrayRef<__int64> <args_4>, c10::ArrayRef<__int64> <args_5>, bool <args_6>, c10::ArrayRef<__int64> <args_7>, __int64 <args_8>) Line 23 C++
torch_cpu.dll!c10::detail::wrap_kernel_functor_unboxed_<c10::detail::WrapRuntimeKernelFunctor_<at::Tensor (__cdecl*)(at::Tensor const &,at::Tensor const &,at::Tensor const &,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,bool,c10::ArrayRef<__int64>,__int64),at::Tensor,c10::guts::typelist::typelist<at::Tensor const &,at::Tensor const &,at::Tensor const &,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,bool,c10::ArrayRef<__int64>,__int64> >,at::Tensor __cdecl(at::Tensor const &,at::Tensor const &,at::Tensor const &,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,bool,c10::ArrayRef<__int64>,__int64)>::call(c10::OperatorKernel * functor, const at::Tensor & <args_0>, const at::Tensor & <args_1>, const at::Tensor & <args_2>, c10::ArrayRef<__int64> <args_3>, c10::ArrayRef<__int64> <args_4>, c10::ArrayRef<__int64> <args_5>, bool <args_6>, c10::ArrayRef<__int64> <args_7>, __int64 <args_8>) Line 276 C++
torch_cpu.dll!c10::KernelFunction::callUnboxed<at::Tensor,at::Tensor const &,at::Tensor const &,at::Tensor const &,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,bool,c10::ArrayRef<__int64>,__int64>(const c10::OperatorHandle & opHandle, const at::Tensor & <args_0>, const at::Tensor & <args_1>, const at::Tensor & <args_2>, c10::ArrayRef<__int64> <args_3>, c10::ArrayRef<__int64> <args_4>, c10::ArrayRef<__int64> <args_5>, bool <args_6>, c10::ArrayRef<__int64> <args_7>, __int64 <args_8>) Line 66 C++
torch_cpu.dll!c10::Dispatcher::callUnboxedWithDispatchKey<at::Tensor,at::Tensor const &,at::Tensor const &,at::Tensor const &,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,bool,c10::ArrayRef<__int64>,__int64>(const c10::OperatorHandle & op, c10::DispatchKey dispatchKey, const at::Tensor & <args_0>, const at::Tensor & <args_1>, const at::Tensor & <args_2>, c10::ArrayRef<__int64> <args_3>, c10::ArrayRef<__int64> <args_4>, c10::ArrayRef<__int64> <args_5>, bool <args_6>, c10::ArrayRef<__int64> <args_7>, __int64 <args_8>) Line 221 C++
torch_cpu.dll!c10::Dispatcher::callUnboxed<at::Tensor,at::Tensor const &,at::Tensor const &,at::Tensor const &,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,bool,c10::ArrayRef<__int64>,__int64>(const c10::OperatorHandle & op, const at::Tensor & <args_0>, const at::Tensor & <args_1>, const at::Tensor & <args_2>, c10::ArrayRef<__int64> <args_3>, c10::ArrayRef<__int64> <args_4>, c10::ArrayRef<__int64> <args_5>, bool <args_6>, c10::ArrayRef<__int64> <args_7>, __int64 <args_8>) Line 229 C++
torch_cpu.dll!c10::OperatorHandle::callUnboxed<at::Tensor,at::Tensor const &,at::Tensor const &,at::Tensor const &,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,bool,c10::ArrayRef<__int64>,__int64>(const at::Tensor & <args_0>, const at::Tensor & <args_1>, const at::Tensor & <args_2>, c10::ArrayRef<__int64> <args_3>, c10::ArrayRef<__int64> <args_4>, c10::ArrayRef<__int64> <args_5>, bool <args_6>, c10::ArrayRef<__int64> <args_7>, __int64 <args_8>) Line 192 C++
torch_cpu.dll!at::convolution(const at::Tensor & input, const at::Tensor & weight, const at::Tensor & bias, c10::ArrayRef<__int64> stride, c10::ArrayRef<__int64> padding, c10::ArrayRef<__int64> dilation, bool transposed, c10::ArrayRef<__int64> output_padding, __int64 groups) Line 2940 C++
torch_cpu.dll!at::native::conv2d(const at::Tensor & input, const at::Tensor & weight, const at::Tensor & bias, c10::ArrayRef<__int64> stride, c10::ArrayRef<__int64> padding, c10::ArrayRef<__int64> dilation, __int64 groups) Line 507 C++
torch_cpu.dll!at::TypeDefault::conv2d(const at::Tensor & input, const at::Tensor & weight, const at::Tensor & bias, c10::ArrayRef<__int64> stride, c10::ArrayRef<__int64> padding, c10::ArrayRef<__int64> dilation, __int64 groups) Line 802 C++
torch_cpu.dll!torch::autograd::VariableType::conv2d(const at::Tensor & input, const at::Tensor & weight, const at::Tensor & bias, c10::ArrayRef<__int64> stride, c10::ArrayRef<__int64> padding, c10::ArrayRef<__int64> dilation, __int64 groups) Line 2978 C++
torch_cpu.dll!c10::detail::WrapRuntimeKernelFunctor_<at::Tensor (__cdecl*)(at::Tensor const &,at::Tensor const &,at::Tensor const &,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,__int64),at::Tensor,c10::guts::typelist::typelist<at::Tensor const &,at::Tensor const &,at::Tensor const &,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,__int64> >::operator()(const at::Tensor & <args_0>, const at::Tensor & <args_1>, const at::Tensor & <args_2>, c10::ArrayRef<__int64> <args_3>, c10::ArrayRef<__int64> <args_4>, c10::ArrayRef<__int64> <args_5>, __int64 <args_6>) Line 23 C++
torch_cpu.dll!c10::detail::wrap_kernel_functor_unboxed_<c10::detail::WrapRuntimeKernelFunctor_<at::Tensor (__cdecl*)(at::Tensor const &,at::Tensor const &,at::Tensor const &,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,__int64),at::Tensor,c10::guts::typelist::typelist<at::Tensor const &,at::Tensor const &,at::Tensor const &,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,__int64> >,at::Tensor __cdecl(at::Tensor const &,at::Tensor const &,at::Tensor const &,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,__int64)>::call(c10::OperatorKernel * functor, const at::Tensor & <args_0>, const at::Tensor & <args_1>, const at::Tensor & <args_2>, c10::ArrayRef<__int64> <args_3>, c10::ArrayRef<__int64> <args_4>, c10::ArrayRef<__int64> <args_5>, __int64 <args_6>) Line 276 C++
torch_cpu.dll!c10::KernelFunction::callUnboxed<at::Tensor,at::Tensor const &,at::Tensor const &,at::Tensor const &,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,__int64>(const c10::OperatorHandle & opHandle, const at::Tensor & <args_0>, const at::Tensor & <args_1>, const at::Tensor & <args_2>, c10::ArrayRef<__int64> <args_3>, c10::ArrayRef<__int64> <args_4>, c10::ArrayRef<__int64> <args_5>, __int64 <args_6>) Line 66 C++
torch_cpu.dll!c10::Dispatcher::callUnboxedWithDispatchKey<at::Tensor,at::Tensor const &,at::Tensor const &,at::Tensor const &,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,__int64>(const c10::OperatorHandle & op, c10::DispatchKey dispatchKey, const at::Tensor & <args_0>, const at::Tensor & <args_1>, const at::Tensor & <args_2>, c10::ArrayRef<__int64> <args_3>, c10::ArrayRef<__int64> <args_4>, c10::ArrayRef<__int64> <args_5>, __int64 <args_6>) Line 221 C++
torch_cpu.dll!c10::Dispatcher::callUnboxed<at::Tensor,at::Tensor const &,at::Tensor const &,at::Tensor const &,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,__int64>(const c10::OperatorHandle & op, const at::Tensor & <args_0>, const at::Tensor & <args_1>, const at::Tensor & <args_2>, c10::ArrayRef<__int64> <args_3>, c10::ArrayRef<__int64> <args_4>, c10::ArrayRef<__int64> <args_5>, __int64 <args_6>) Line 229 C++
torch_cpu.dll!c10::OperatorHandle::callUnboxed<at::Tensor,at::Tensor const &,at::Tensor const &,at::Tensor const &,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,c10::ArrayRef<__int64>,__int64>(const at::Tensor & <args_0>, const at::Tensor & <args_1>, const at::Tensor & <args_2>, c10::ArrayRef<__int64> <args_3>, c10::ArrayRef<__int64> <args_4>, c10::ArrayRef<__int64> <args_5>, __int64 <args_6>) Line 192 C++
torch_cpu.dll!at::conv2d(const at::Tensor & input, const at::Tensor & weight, const at::Tensor & bias, c10::ArrayRef<__int64> stride, c10::ArrayRef<__int64> padding, c10::ArrayRef<__int64> dilation, __int64 groups) Line 3010 C++
torch_cpu.dll!torch::nn::functional::detail::conv2d(const at::Tensor & input, const at::Tensor & weight, const at::Tensor & bias, torch::ExpandingArray<2,__int64> stride, torch::ExpandingArray<2,__int64> padding, torch::ExpandingArray<2,__int64> dilation, __int64 groups) Line 67 C++
torch_cpu.dll!torch::nn::Conv2dImpl::_conv_forward(const at::Tensor & input, const at::Tensor & weight) Line 97 C++
torch_cpu.dll!torch::nn::Conv2dImpl::forward(const at::Tensor & input) Line 108 C++
vgg19.exe!torch::nn::AnyModuleHolder<torch::nn::Conv2dImpl,at::Tensor const &>::InvokeForward::operator()<at::Tensor>(at::Tensor && <ts_0>) Line 63 C++
vgg19.exe!torch::unpack<torch::nn::AnyValue,at::Tensor const &,torch::nn::AnyModuleHolder<torch::nn::Conv2dImpl,at::Tensor const &>::InvokeForward,torch::nn::AnyModuleHolder<torch::nn::Conv2dImpl,at::Tensor const &>::CheckedGetter,0>(torch::nn::AnyModuleHolder<torch::nn::Conv2dImpl,at::Tensor const &>::InvokeForward function, torch::nn::AnyModuleHolder<torch::nn::Conv2dImpl,at::Tensor const &>::CheckedGetter accessor, torch::Indices<0> __formal) Line 136 C++
vgg19.exe!torch::unpack<torch::nn::AnyValue,at::Tensor const &,torch::nn::AnyModuleHolder<torch::nn::Conv2dImpl,at::Tensor const &>::InvokeForward,torch::nn::AnyModuleHolder<torch::nn::Conv2dImpl,at::Tensor const &>::CheckedGetter>(torch::nn::AnyModuleHolder<torch::nn::Conv2dImpl,at::Tensor const &>::InvokeForward function, torch::nn::AnyModuleHolder<torch::nn::Conv2dImpl,at::Tensor const &>::CheckedGetter accessor) Line 128 C++
vgg19.exe!torch::nn::AnyModuleHolder<torch::nn::Conv2dImpl,at::Tensor const &>::forward(std::vector<torch::nn::AnyValue,std::allocator<torch::nn::AnyValue> > && arguments) Line 106 C++
vgg19.exe!torch::nn::AnyModule::any_forward<torch::nn::AnyValue>(torch::nn::AnyValue && <arguments_0>) Line 276 C++
vgg19.exe!torch::nn::SequentialImpl::forward<at::Tensor,at::Tensor &>(at::Tensor & <inputs_0>) Line 179 C++
vgg19.exe!VGGImpl::forward(at::Tensor x) Line 75 C++
vgg19.exe!main(int argc, char * * argv) Line 341 C++
vgg19.exe!invoke_main() Line 79 C++
vgg19.exe!__scrt_common_main_seh() Line 288 C++
vgg19.exe!__scrt_common_main() Line 331 C++
vgg19.exe!mainCRTStartup() Line 17 C++
kernel32.dll!00007ffc2dc16fd4() Unknown
ntdll.dll!00007ffc2f55cec1() Unknown
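Reading the trace from the bottom up: VGGImpl::forward calls SequentialImpl::forward, which calls Conv2dImpl::forward, and check_shape_forward then calls dim() on a tensor backed by UndefinedTensorImpl. So one of the stages of the Sequential most likely handed an undefined (default-constructed) tensor to a Conv2d as its input. One way to narrow that down is to run the stages one at a time and check the output after each; in libtorch the check would be torch::Tensor::defined(). The sketch below shows that loop in standard C++, with std::optional as a hypothetical stand-in for a tensor that may be undefined; MaybeTensor, Stage and first_undefined_output are illustrative names, not libtorch API.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <optional>
#include <vector>

// Hypothetical stand-in for a tensor: std::nullopt plays the role of an
// undefined at::Tensor (the UndefinedTensorImpl case in the stack trace).
using MaybeTensor = std::optional<std::vector<float>>;
using Stage = std::function<MaybeTensor(const MaybeTensor&)>;

// Runs the stages in order, like a Sequential, and returns the index of the
// first stage whose output is undefined, or -1 if every output is defined.
// Stopping there avoids passing the undefined value on to the next layer,
// which is where the real Conv2d ended up calling dim() on it.
long first_undefined_output(const std::vector<Stage>& stages, MaybeTensor x) {
  for (std::size_t i = 0; i < stages.size(); ++i) {
    x = stages[i](x);
    if (!x.has_value()) {
      return static_cast<long>(i);
    }
  }
  return -1;
}
```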
I have reimplemented the VGG19 model based on the torchvision VGG code, to try to find the error by debugging.
#ifndef VGGCUSTOM_H
#define VGGCUSTOM_H
#include <torch/torch.h>
struct VGG19CustomImpl: torch::nn::Module {
// Neural network model consisting of layers proposed by VGG19 model.
// N(int) : number of classes to predict with this model
// input size should be : (b x 3 x 224 x 224)
VGG19CustomImpl(int64_t N=1000)
: conv1(register_module("conv1", torch::nn::Conv2d(torch::nn::Conv2dOptions(3, 64, 3).padding(1).stride(1)))),
conv2(register_module("conv2", torch::nn::Conv2d(torch::nn::Conv2dOptions(64, 64, 3).padding(1).stride(1)))),
conv3(register_module("conv3", torch::nn::Conv2d(torch::nn::Conv2dOptions(64, 128, 3).padding(1)))),
conv4(register_module("conv4", torch::nn::Conv2d(torch::nn::Conv2dOptions(128, 128, 3).padding(1)))),
conv5(register_module("conv5", torch::nn::Conv2d(torch::nn::Conv2dOptions(128, 256, 3).padding(1)))),
conv6(register_module("conv6", torch::nn::Conv2d(torch::nn::Conv2dOptions(256, 256, 3).padding(1)))),
conv7(register_module("conv7", torch::nn::Conv2d(torch::nn::Conv2dOptions(256, 256, 3).padding(1)))),
conv8(register_module("conv8", torch::nn::Conv2d(torch::nn::Conv2dOptions(256, 256, 3).padding(1)))),
conv9(register_module("conv9", torch::nn::Conv2d(torch::nn::Conv2dOptions(256, 512, 3).padding(1)))),
conv10(register_module("conv10", torch::nn::Conv2d(torch::nn::Conv2dOptions(512, 512, 3).padding(1)))),
conv11(register_module("conv11", torch::nn::Conv2d(torch::nn::Conv2dOptions(512, 512, 3).padding(1)))),
conv12(register_module("conv12", torch::nn::Conv2d(torch::nn::Conv2dOptions(512, 512, 3).padding(1)))),
conv13(register_module("conv13", torch::nn::Conv2d(torch::nn::Conv2dOptions(512, 512, 3).padding(1)))),
conv14(register_module("conv14", torch::nn::Conv2d(torch::nn::Conv2dOptions(512, 512, 3).padding(1)))),
conv15(register_module("conv15", torch::nn::Conv2d(torch::nn::Conv2dOptions(512, 512, 3).padding(1)))),
conv16(register_module("conv16", torch::nn::Conv2d(torch::nn::Conv2dOptions(512, 512, 3).padding(1)))),
linear1(register_module("linear1", torch::nn::Linear(512 * 7 * 7, 4096))),
linear2(register_module("linear2", torch::nn::Linear(4096, 4096))),
linear3(register_module("linear3", torch::nn::Linear(4096, N))),
dropout1(register_module("dropout1", torch::nn::Dropout(torch::nn::DropoutOptions(0.5)))),
dropout2(register_module("dropout2", torch::nn::Dropout(torch::nn::DropoutOptions(0.5))))
{
_initialize_weights();
}
torch::Tensor forward(const torch::Tensor& input) {
auto x = torch::relu(conv1(input));
x = torch::relu(conv2(x));
x = torch::max_pool2d(x, 2, 2);
x = torch::relu(conv3(x));
x = torch::relu(conv4(x));
x = torch::max_pool2d(x, 2, 2);
x = torch::relu(conv5(x));
x = torch::relu(conv6(x));
x = torch::relu(conv7(x));
x = torch::relu(conv8(x));
x = torch::max_pool2d(x, 2, 2);
x = torch::relu(conv9(x));
x = torch::relu(conv10(x));
x = torch::relu(conv11(x));
x = torch::relu(conv12(x));
x = torch::max_pool2d(x, 2, 2);
x = torch::relu(conv13(x));
x = torch::relu(conv14(x));
x = torch::relu(conv15(x));
x = torch::relu(conv16(x));
x = torch::max_pool2d(x, 2, 2);
// Classifier, 512 * 7 * 7 = 25088
x = x.view({ x.size(0), 25088 });
x = torch::relu(linear1(x));
x = dropout1(x);
x = torch::relu(linear2(x));
x = dropout2(x);
x = linear3(x);
return x;
}
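// Declared in the same order as the constructor's initializer list
// (members are always initialized in declaration order).
torch::nn::Conv2d conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8, conv9, conv10, conv11, conv12, conv13, conv14, conv15, conv16;
torch::nn::Linear linear1, linear2, linear3;
torch::nn::Dropout dropout1, dropout2;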
void _initialize_weights() {
for (auto& module : modules(/*include_self=*/false)) {
if (auto M = dynamic_cast<torch::nn::Conv2dImpl*>(module.get())) {
torch::nn::init::kaiming_normal_(
M->weight,
/*a=*/0,
torch::kFanOut,
torch::kReLU);
torch::nn::init::constant_(M->bias, 0);
}
else if (
auto M = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
torch::nn::init::constant_(M->weight, 1);
torch::nn::init::constant_(M->bias, 0);
}
else if (auto M = dynamic_cast<torch::nn::LinearImpl*>(module.get())) {
torch::nn::init::normal_(M->weight, 0, 0.01);
torch::nn::init::constant_(M->bias, 0);
}
}
}
};
TORCH_MODULE_IMPL(VGG19Custom, VGG19CustomImpl);
#endif // VGGCUSTOM_H
But this model runs without problems with the following code:
VGG19Custom modelvggcustom;
auto in = torch::rand({ 10, 3, 224, 224 });
auto out = modelvggcustom->forward(in);
So I still don't know why the torchvision model fails. As far as I can tell, the two models are the same, aren't they?
@AndoniC What is the uncompressed size of the model?
Hi @peterjc123,
I am not sure what you mean by 'uncompressed size of the model'. Could it be the in-memory size of the VGG19Custom instance (the object 'modelvggcustom')? If so, do you know how I could check that in Visual Studio? Or do you mean the object 'model', which is an instance of the torchvision VGG19 class?
Thanks.
In the end, I used Debug / Windows / Show Diagnostic Tools to measure the memory increase when the object is declared.
Memory usage increases by 561 MB with an instance of the class I reimplemented from the torchvision model (modelvggcustom):
VGG19Custom modelvggcustom;
auto in = torch::rand({ 10, 3, 224, 224 });
auto out = modelvggcustom->forward(in);
If I use the torchvision VGG model (model_from_torchvision), the memory increase is 545 MB:
VGG19 model_from_torchvision;
model_from_torchvision->eval();
auto in = torch::rand({ 10, 3, 224, 224 });
auto out = model_from_torchvision->forward(in);
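The size of the weights can also be estimated from the architecture alone. Counting only parameters stored as 32-bit floats (no buffers, activations, or allocator overhead), VGG19 with 1000 classes has 143,667,240 parameters, i.e. 574,668,960 bytes or about 548 MiB. That is roughly in line with both measurements above and suggests the two objects hold essentially the same weights. A standard-C++ sketch of the count; the function names are illustrative only:

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Parameters of a 3x3 Conv2d with bias: weights (in * out * 3 * 3) + bias (out).
constexpr std::int64_t conv3x3_params(std::int64_t in, std::int64_t out) {
  return in * out * 9 + out;
}

// Parameters of a Linear layer with bias: weights (in * out) + bias (out).
constexpr std::int64_t linear_params(std::int64_t in, std::int64_t out) {
  return in * out + out;
}

// Total parameter count of VGG19 without batch norm:
// 16 conv layers followed by three fully connected layers.
std::int64_t vgg19_param_count(std::int64_t num_classes = 1000) {
  // (in_channels, out_channels) for each of the 16 conv layers.
  const std::vector<std::pair<std::int64_t, std::int64_t>> convs = {
      {3, 64},    {64, 64},                               // block 1
      {64, 128},  {128, 128},                             // block 2
      {128, 256}, {256, 256}, {256, 256}, {256, 256},     // block 3
      {256, 512}, {512, 512}, {512, 512}, {512, 512},     // block 4
      {512, 512}, {512, 512}, {512, 512}, {512, 512}};    // block 5
  std::int64_t total = 0;
  for (const auto& [in, out] : convs) {
    total += conv3x3_params(in, out);
  }
  total += linear_params(512 * 7 * 7, 4096);  // classifier[0]
  total += linear_params(4096, 4096);         // classifier[3]
  total += linear_params(4096, num_classes);  // classifier[6]
  return total;
}

// Bytes needed to store the parameters as 32-bit floats.
std::int64_t float32_bytes(std::int64_t params) { return params * 4; }
```

The convolutions contribute 20,024,384 parameters and the three linear layers 123,642,856, so most of the memory sits in the first fully connected layer.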
I can reproduce this issue locally with the latest PyTorch nightly and a fresh clone of torchvision. It throws the error "dim() called on undefined Tensor" on exit, when computing the convolutions in the second block. Any ideas, @fmassa?
Actually, it works after I replace torch::nn::Functional(modelsimpl::relu_) with torch::nn::ReLU() and torch::nn::Functional(modelsimpl::max_pool2d, 2, 2) with torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions({2,2})), which suggests there may be a problem with torch::nn::Functional.
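For context, torch::nn::Functional wraps a free function as a module. As far as we understand its implementation, it binds any extra arguments (here 2, 2) after an input placeholder and stores the result as a one-argument std::function, so whatever the wrapped function returns is copied out on every forward call. The sketch below mimics that pattern in standard C++; Values, max_pool1d and make_functional are hypothetical stand-ins, not libtorch API.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical stand-in for a tensor: a 1-D sequence of values.
using Values = std::vector<float>;

// Hypothetical 1-D analogue of max_pool2d(x, kernel_size, stride).
Values max_pool1d(const Values& x, std::size_t kernel, std::size_t stride) {
  Values out;
  for (std::size_t start = 0; start + kernel <= x.size(); start += stride) {
    float best = x[start];
    for (std::size_t i = start + 1; i < start + kernel; ++i) {
      if (x[i] > best) best = x[i];
    }
    out.push_back(best);
  }
  return out;
}

// Functional-like wrapper: binds the trailing arguments once, so the layer
// can be called with the input alone, e.g. make_functional(max_pool1d, 2, 2).
template <typename F, typename... Args>
std::function<Values(Values)> make_functional(F f, Args... args) {
  return std::bind(f, std::placeholders::_1, args...);
}
```

Because the stored callable returns by value, the wrapped function's own return type decides whether that copy is made from a live object.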
I opened an issue in pytorch: https://github.com/pytorch/pytorch/issues/41316.
The function signatures in modelsimpl.h are wrong. I fixed them in https://github.com/pytorch/vision/pull/2463. @fmassa Could you please take a look?
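One common way such a signature can be wrong (an assumption here, not a description of what the PR changed) is a helper that takes the tensor by value but returns torch::Tensor&. The returned reference points at the parameter, which is destroyed when the function returns, so the caller copies from a dangling reference. That is undefined behaviour, and the resulting garbage can surface as the "undefined tensor" seen in the trace above. The standard-C++ sketch below uses a hypothetical reference-counted Handle (standing in for at::Tensor) to contrast the broken pattern, shown only in a comment, with a safe one that returns by value.

```cpp
#include <cassert>
#include <functional>
#include <memory>

// Hypothetical handle type: like at::Tensor, copies share the same storage
// through a reference-counted pointer, and a null pointer means "undefined".
struct Handle {
  std::shared_ptr<float> data;
  bool defined() const { return data != nullptr; }
};

// In-place ReLU: modifies the shared value, not the handle itself.
void relu_inplace_storage(const Handle& h) {
  if (*h.data < 0.f) *h.data = 0.f;
}

// BROKEN pattern (never called here, shown only for contrast):
//   Handle& relu_(Handle x) { relu_inplace_storage(x); return x; }
// It returns a reference to the by-value parameter `x`, which no longer
// exists once the function returns; copying from it is undefined behaviour.

// Safe pattern: take the handle by const reference and return a handle by
// value. The copy shares storage, so the in-place update is still visible
// through the caller's handle.
Handle relu_(const Handle& x) {
  Handle out = x;  // shares storage with x
  relu_inplace_storage(out);
  return out;
}
```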
I made the changes you proposed and it worked!
Thank you very much, peterjc123.