diff --git a/README.md b/README.md
index c45ccf4..126f10c 100644
--- a/README.md
+++ b/README.md
@@ -282,6 +282,7 @@ ImportError: /home//lib/python3.6/site-packages/mgeconvert/backend/ir_to_tflite/
| matrix mul | ✓
✓ | ✓
✓ | ✓
✓ |
| max pool2d | ✓
✓ | ✓
✓ | ✓
✓ |
| mul | ✓
✓ | ✓
✓ | ✓
✓ |
+| pad | ✓
× | ×
× | ×
× |
| pow | ✓
✓ | ✓
✓ | ✓
✓ |
| reduce max | ✓
✓ | ✓
✓ | ✓
✓ |
| reduce min | ✓
✓ | ✓
✓ | ✓
✓ |
@@ -314,6 +315,7 @@ ImportError: /home//lib/python3.6/site-packages/mgeconvert/backend/ir_to_tflite/
| Concat | ✓
✓ |
| Dropout | ✓
✓ |
| Flatten | ✓
✓ |
+| Gather | ✓
✓ |
| Gemm | ✓
✓ |
| GlobalAveragePool | ✓
✓ |
| GlobalMaxPool | ✓
✓ |
diff --git a/ci/pytest_caffe_and_onnx.sh b/ci/pytest_caffe_and_onnx.sh
index 31d90a4..f91e39a 100755
--- a/ci/pytest_caffe_and_onnx.sh
+++ b/ci/pytest_caffe_and_onnx.sh
@@ -8,7 +8,7 @@ python3 -m pip install --no-binary=protobuf protobuf==3.11.1
apt install -y protobuf-compiler
-./mgeconvert/backend/ir_to_caffe/init.sh
+./mgeconvert/backend/ir_to_caffe/init.sh False /usr/bin/python3
pip3 install scikit-image==0.17.2 onnx onnxruntime
diff --git a/ci/pytest_caffe_and_onnx_complete.sh b/ci/pytest_caffe_and_onnx_complete.sh
index b5ce618..7dd8e53 100755
--- a/ci/pytest_caffe_and_onnx_complete.sh
+++ b/ci/pytest_caffe_and_onnx_complete.sh
@@ -8,7 +8,7 @@ python3 -m pip install --no-binary=protobuf protobuf==3.11.1
apt install -y protobuf-compiler
-./mgeconvert/backend/ir_to_caffe/init.sh
+./mgeconvert/backend/ir_to_caffe/init.sh False /usr/bin/python3
pip3 install scikit-image==0.17.2 onnx onnxruntime
diff --git a/ci/pytest_tflite.sh b/ci/pytest_tflite.sh
index 5240703..4d8352f 100755
--- a/ci/pytest_tflite.sh
+++ b/ci/pytest_tflite.sh
@@ -3,7 +3,7 @@
set -e
./mgeconvert/backend/ir_to_tflite/build_flatbuffer.sh
-./mgeconvert/backend/ir_to_tflite/init.sh
+./mgeconvert/backend/ir_to_tflite/init.sh False /usr/bin/python3
sudo python3 -m pip uninstall flatbuffers -y
sudo python3 -m pip install tensorflow==2.5.0
diff --git a/ci/pytest_tflite_complete.sh b/ci/pytest_tflite_complete.sh
index bdae149..b0e54e2 100755
--- a/ci/pytest_tflite_complete.sh
+++ b/ci/pytest_tflite_complete.sh
@@ -3,7 +3,7 @@
set -e
./mgeconvert/backend/ir_to_tflite/build_flatbuffer.sh
-./mgeconvert/backend/ir_to_tflite/init.sh
+./mgeconvert/backend/ir_to_tflite/init.sh False /usr/bin/python3
# try to find libflatbuffers.so
export LD_LIBRARY_PATH=$HOME/.local/lib:$LD_LIBRARY_PATH:
diff --git a/mgeconvert/backend/ir_to_caffe/init.sh b/mgeconvert/backend/ir_to_caffe/init.sh
index 62f9709..72fdc40 100755
--- a/mgeconvert/backend/ir_to_caffe/init.sh
+++ b/mgeconvert/backend/ir_to_caffe/init.sh
@@ -1,6 +1,13 @@
#!/bin/bash -e
-python3 -m pip install --no-binary=protobuf "protobuf>=3.11.1" --user
+ADD_USER=""
+if [[ $1 == "False" ]]; then
+ ADD_USER="--user"
+fi
+
+PYTHON3=$2
+
+$PYTHON3 -m pip install --no-binary=protobuf "protobuf>=3.11.1" $ADD_USER
hash wget || (echo "please install wget package" && exit -1)
hash protoc || (echo "please install protobuf-compiler package" && exit -1)
diff --git a/mgeconvert/backend/ir_to_onnx/init.sh b/mgeconvert/backend/ir_to_onnx/init.sh
index 38ff8e8..67f0da3 100755
--- a/mgeconvert/backend/ir_to_onnx/init.sh
+++ b/mgeconvert/backend/ir_to_onnx/init.sh
@@ -1,5 +1,12 @@
#!/bin/bash -e
-python3 -m pip install "onnx>=1.7.0" --user
-python3 -m pip install onnx-simplifier --user
-python3 -m pip install protobuf --user
+ADD_USER=""
+if [[ $1 == "False" ]]; then
+ ADD_USER="--user"
+fi
+
+PYTHON3=$2
+
+$PYTHON3 -m pip install "onnx>=1.7.0,<1.12.0" $ADD_USER
+$PYTHON3 -m pip install onnx-simplifier $ADD_USER
+$PYTHON3 -m pip install protobuf $ADD_USER
diff --git a/mgeconvert/backend/ir_to_tflite/init.sh b/mgeconvert/backend/ir_to_tflite/init.sh
index c1448d5..0047ac4 100755
--- a/mgeconvert/backend/ir_to_tflite/init.sh
+++ b/mgeconvert/backend/ir_to_tflite/init.sh
@@ -1,9 +1,15 @@
#!/bin/bash -e
basepath=$(cd `dirname $0`; pwd)
+ADD_USER=""
+if [[ $1 == "False" ]]; then
+ ADD_USER="--user"
+fi
+
+PYTHON3=$2
-if python3 -c "import flatbuffers">/dev/null 2>&1; then
- FLAT_BUFFER_VERSION="$(python3 -m pip show flatbuffers | grep Version)"
+if $PYTHON3 -c "import flatbuffers">/dev/null 2>&1; then
+ FLAT_BUFFER_VERSION="$($PYTHON3 -m pip show flatbuffers | grep Version)"
else
FLAT_BUFFER_VERSION=""
fi
@@ -11,9 +17,9 @@ echo ${FLAT_BUFFER_VERSION}
# install flatbuffers
if [[ "$FLAT_BUFFER_VERSION" != "Version: 1.12" ]]; then
- python3 -m pip uninstall flatbuffers -y
+ $PYTHON3 -m pip uninstall flatbuffers -y
echo "install flatbuffers..."
- python3 -m pip install flatbuffers==1.12 --user
+ $PYTHON3 -m pip install flatbuffers==1.12 $ADD_USER
fi
if [ ! -d /tmp/mgeconvert ]; then
@@ -26,8 +32,8 @@ export PATH=$PATH:$HOME/.local/bin
echo "building tflite schema..."
cd /tmp/mgeconvert
rm -f schema.fbs
-tf_version=$1
-if [ ! -n "$1" ] ;then
+tf_version=$3
+if [ ! -n "$3" ] ;then
tf_version="r2.3"
fi
echo "Use TFLite $tf_version!"
@@ -37,13 +43,13 @@ chmod 777 /tmp/mgeconvert/tflite
cp -r /tmp/mgeconvert/tflite $basepath
-python3 -m pip install pybind11==2.6.2 --user
+$PYTHON3 -m pip install pybind11==2.6.2 $ADD_USER
# using pyflexbuffers
cd $basepath/pyflexbuffers
-PYBIND11_HEADER=$(python3 -c "import pybind11; print(pybind11.get_include())")
-PYTHON_INCLUDE=$(python3 -c "import sysconfig; print(sysconfig.get_paths()['include'])")
-PYTHON_STDLIB=$(python3 -c "import sysconfig; print(sysconfig.get_paths()['stdlib'])")
+PYBIND11_HEADER=$($PYTHON3 -c "import pybind11; print(pybind11.get_include())")
+PYTHON_INCLUDE=$($PYTHON3 -c "import sysconfig; print(sysconfig.get_paths()['include'])")
+PYTHON_STDLIB=$($PYTHON3 -c "import sysconfig; print(sysconfig.get_paths()['stdlib'])")
g++ fbconverter.cc --std=c++14 -fPIC --shared -I$basepath/pyflexbuffers/include -I${PYBIND11_HEADER} -I${PYTHON_INCLUDE} -L${PYTHON_STDLIB} -L$basepath/pyflexbuffers/lib -lflatbuffers -o fbconverter.so
diff --git a/setup.py b/setup.py
index 5ac267c..fa42781 100644
--- a/setup.py
+++ b/setup.py
@@ -15,6 +15,9 @@
targets = []
tfversion = None
+IS_VENV = hasattr(sys, "real_prefix") or (
+ hasattr(sys, "base_prefix") and sys.base_prefix != sys.prefix
+)
def write_init(targets, tflite_schema_version=None):
@@ -97,9 +100,11 @@ def find_extension(self, name):
def build_all(self, ext):
if ext.script:
if ext.name == "tflite" and tfversion is not None:
- subprocess.check_call([ext.script, tfversion])
+ subprocess.check_call(
+ [ext.script, str(IS_VENV), sys.executable, tfversion]
+ )
else:
- subprocess.check_call(ext.script)
+ subprocess.check_call([ext.script, str(IS_VENV), sys.executable])
if ext.artifacts is not None:
self.copy_tree(ext.artifacts, os.path.join(self.build_lib, ext.artifacts))
diff --git a/test/mge/test_tflite.py b/test/mge/test_tflite.py
index 0997eaf..19b95a4 100644
--- a/test/mge/test_tflite.py
+++ b/test/mge/test_tflite.py
@@ -32,7 +32,7 @@
from mgeconvert.converters.mge_to_tflite import mge_to_tflite
from tensorflow.lite.python import interpreter # pylint: disable=import-error
-max_error = 1e-6
+max_error = 1e-4
tmp_file = "test_model"
diff --git a/test/traced_module/test_qat_tflite.py b/test/traced_module/test_qat_tflite.py
index 7fc1bbf..35a2bcb 100644
--- a/test/traced_module/test_qat_tflite.py
+++ b/test/traced_module/test_qat_tflite.py
@@ -23,6 +23,11 @@
from .tm_utils import get_traced_module
+if mge.__version__ > "1.7.0":
+ from megengine.traced_module.tm_config import disable_default_checker
+
+ disable_default_checker()
+
max_error = 1e-4
tmp_file = "test_model"
diff --git a/test/traced_module/test_tflite.py b/test/traced_module/test_tflite.py
index 7b5875e..b87670d 100644
--- a/test/traced_module/test_tflite.py
+++ b/test/traced_module/test_tflite.py
@@ -40,7 +40,7 @@
from mgeconvert.converters.tm_to_tflite import tracedmodule_to_tflite
from tensorflow.lite.python import interpreter # pylint: disable=import-error
-max_error = 1e-6
+max_error = 1e-4
tmp_file = "test_model"