This branch contains our modified model code specifically for ONNX export. We made a few modifications because exporting directly with PyTorch 1.9 has several limitations.
PyTorch only allows `F.interpolate(x, scale_factor)` to accept `scale_factor` as a float, not a tensor, so the value gets hardcoded into the ONNX graph. We modify downsampling to take the scale factor as a user-provided tensor, so that the `downsample_ratio` hyperparameter can be configured at runtime.
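A minimal sketch of the idea (the function name and shapes are illustrative, not the repo's actual code): compute the target size from a runtime ratio tensor instead of passing a Python float `scale_factor`, which export would bake in as a constant.

```python
import torch
import torch.nn.functional as F

def downsample(x: torch.Tensor, ratio: torch.Tensor) -> torch.Tensor:
    # Derive the output size from the runtime ratio tensor rather than
    # hardcoding a float scale_factor into the traced graph.
    h = int((x.shape[2] * ratio).int())
    w = int((x.shape[3] * ratio).int())
    return F.interpolate(x, size=(h, w), mode="bilinear", align_corners=False)

x = torch.randn(1, 3, 64, 64)
y = downsample(x, torch.tensor(0.5))
print(tuple(y.shape))  # (1, 3, 32, 32)
```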
PyTorch does not trace `Tensor.shape` very well and creates a messy graph. We customize it so that the exported graph is as clean as possible.
Our custom export logic is implemented in `model/onnx_helper.py`.
Export Yourself
The following procedures were used to generate our ONNX models.
Install dependencies
pip install -r requirements.txt
(Only for PyTorch <= 1.9) A few modifications to the PyTorch source are needed until pull request #60080 is merged into a later version of PyTorch. If you are exporting the MobileNetV3 variant, go to your local PyTorch install and override the following method in the file `site-packages/torch/onnx/symbolic_opset9.py`. This allows hardswish to be exported as a native op.
Also note: if your inference backend does not support hardswish or hardsigmoid, you can use this hack to replace them with primitive ops.
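As a sketch of what such a replacement looks like (module names here are illustrative), both ops can be expressed with ReLU6, since `hardsigmoid(x) = relu6(x + 3) / 6` and `hardswish(x) = x * hardsigmoid(x)`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class HardsigmoidPrimitive(nn.Module):
    def forward(self, x):
        # hardsigmoid(x) = clamp(x + 3, 0, 6) / 6, written via relu6
        return F.relu6(x + 3.0) / 6.0

class HardswishPrimitive(nn.Module):
    def forward(self, x):
        # hardswish(x) = x * hardsigmoid(x)
        return x * F.relu6(x + 3.0) / 6.0

x = torch.linspace(-4, 4, 9)
same = torch.allclose(HardswishPrimitive()(x), F.hardswish(x))
print(same)  # True
```

Swapping these modules in before export yields a graph built only from Add, Clip, Mul, and Div, which virtually every backend supports.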
Use the export script. The `device` argument is only for export tracing; float16 must be exported using a CUDA device. Our export script only supports opset 11 and up. If you need older opset support, you must adapt the code yourself.
Our model is tested to work on ONNX Runtime's CPU and CUDA backends. If your inference backend has compatibility issues with certain ops, you can file an issue on GitHub, but we don't guarantee solutions. Feel free to write your own export code that fits your needs.
About
Robust Video Matting in PyTorch, TensorFlow, TensorFlow.js, ONNX, CoreML!