diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index febefff6756e7..58722aa180021 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -93,6 +93,7 @@ option(onnxruntime_USE_OPENVINO "Build with OpenVINO support" OFF) option(onnxruntime_USE_COREML "Build with CoreML support" OFF) option(onnxruntime_USE_NNAPI_BUILTIN "Build with builtin NNAPI lib for Android NNAPI support" OFF) option(onnxruntime_USE_QNN "Build with QNN support" OFF) +option(onnxruntime_BUILD_QNN_EP_STATIC_LIB "Build with QNN EP as a static library" OFF) option(onnxruntime_USE_SNPE "Build with SNPE support" OFF) option(onnxruntime_USE_RKNPU "Build with RKNPU support" OFF) option(onnxruntime_USE_DNNL "Build with DNNL support" OFF) diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index d72b61a0859b2..78edb4179fafd 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -199,17 +199,12 @@ if(CMAKE_SYSTEM_NAME STREQUAL "Android" AND onnxruntime_BUILD_JAVA) endforeach() endif() -# This list is a reversed topological ordering of library dependencies. -# Earlier entries may depend on later ones. Later ones should not depend on earlier ones. -set(onnxruntime_INTERNAL_LIBRARIES - onnxruntime_session - ${onnxruntime_libs} +set(onnxruntime_INTERNAL_PROVIDER_LIBRARIES ${PROVIDERS_ACL} ${PROVIDERS_ARMNN} ${PROVIDERS_COREML} ${PROVIDERS_DML} ${PROVIDERS_NNAPI} - ${PROVIDERS_QNN} ${PROVIDERS_SNPE} ${PROVIDERS_RKNPU} ${PROVIDERS_VSINPU} @@ -218,6 +213,18 @@ set(onnxruntime_INTERNAL_LIBRARIES ${PROVIDERS_WEBNN} ${PROVIDERS_AZURE} ${PROVIDERS_INTERNAL_TESTING} +) + +if (onnxruntime_BUILD_QNN_EP_STATIC_LIB) + list(APPEND onnxruntime_INTERNAL_PROVIDER_LIBRARIES onnxruntime_providers_qnn) +endif() + +# This list is a reversed topological ordering of library dependencies. +# Earlier entries may depend on later ones. Later ones should not depend on earlier ones. +set(onnxruntime_INTERNAL_LIBRARIES + onnxruntime_session + ${onnxruntime_libs} + ${onnxruntime_INTERNAL_PROVIDER_LIBRARIES} ${onnxruntime_winml} onnxruntime_optimizer onnxruntime_providers diff --git a/cmake/onnxruntime_java.cmake b/cmake/onnxruntime_java.cmake index b15b9632e9e24..1227264e595ed 100644 --- a/cmake/onnxruntime_java.cmake +++ b/cmake/onnxruntime_java.cmake @@ -148,7 +148,7 @@ if (WIN32) if(NOT onnxruntime_ENABLE_STATIC_ANALYSIS) add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_LIB_DIR}/$) add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_JNI_DIR}/$) - if (onnxruntime_USE_CUDA OR onnxruntime_USE_DNNL OR onnxruntime_USE_OPENVINO OR onnxruntime_USE_TENSORRT) + if (onnxruntime_USE_CUDA OR onnxruntime_USE_DNNL OR onnxruntime_USE_OPENVINO OR onnxruntime_USE_TENSORRT OR (onnxruntime_USE_QNN AND NOT onnxruntime_BUILD_QNN_EP_STATIC_LIB)) add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_LIB_DIR}/$) endif() if (onnxruntime_USE_CUDA) @@ -163,11 +163,14 @@ if (WIN32) if (onnxruntime_USE_TENSORRT) add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_LIB_DIR}/$) endif() + if (onnxruntime_USE_QNN AND NOT onnxruntime_BUILD_QNN_EP_STATIC_LIB) + add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_LIB_DIR}/$) + endif() endif() else() add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_LIB_DIR}/$) add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_JNI_DIR}/$) - if (onnxruntime_USE_CUDA OR onnxruntime_USE_DNNL OR onnxruntime_USE_OPENVINO OR onnxruntime_USE_TENSORRT) + if (onnxruntime_USE_CUDA OR onnxruntime_USE_DNNL OR onnxruntime_USE_OPENVINO OR onnxruntime_USE_TENSORRT OR (onnxruntime_USE_QNN AND NOT onnxruntime_BUILD_QNN_EP_STATIC_LIB)) add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_LIB_DIR}/$) endif() if (onnxruntime_USE_CUDA) @@ -182,6 +185,9 @@ else() if (onnxruntime_USE_TENSORRT) add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_LIB_DIR}/$) endif() + if (onnxruntime_USE_QNN AND NOT onnxruntime_BUILD_QNN_EP_STATIC_LIB) + add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_LIB_DIR}/$) + endif() endif() # run the build process (this copies the results back into CMAKE_CURRENT_BINARY_DIR) diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 582491de9503d..67fa48b28278d 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -74,9 +74,6 @@ endif() if(onnxruntime_USE_JSEP) set(PROVIDERS_JS onnxruntime_providers_js) endif() -if(onnxruntime_USE_QNN) - set(PROVIDERS_QNN onnxruntime_providers_qnn) -endif() if(onnxruntime_USE_RKNPU) set(PROVIDERS_RKNPU onnxruntime_providers_rknpu) endif() diff --git a/cmake/onnxruntime_providers_cpu.cmake b/cmake/onnxruntime_providers_cpu.cmake index 91a2b13002ec9..4ae89a392278f 100644 --- a/cmake/onnxruntime_providers_cpu.cmake +++ b/cmake/onnxruntime_providers_cpu.cmake @@ -239,7 +239,9 @@ if (NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD set_property(TARGET onnxruntime_providers_shared APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker -exported_symbols_list ${ONNXRUNTIME_ROOT}/core/providers/shared/exported_symbols.lst") elseif(UNIX) if(NOT ${CMAKE_SYSTEM_NAME} MATCHES "AIX") - set_property(TARGET onnxruntime_providers_shared APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/shared/version_script.lds -Xlinker --gc-sections") + target_link_options(onnxruntime_providers_shared PRIVATE + "LINKER:--version-script=${ONNXRUNTIME_ROOT}/core/providers/shared/version_script.lds" + "LINKER:--gc-sections") endif() elseif(WIN32) set_property(TARGET onnxruntime_providers_shared APPEND_STRING PROPERTY LINK_FLAGS "-DEF:${ONNXRUNTIME_ROOT}/core/providers/shared/symbols.def") diff --git a/cmake/onnxruntime_providers_qnn.cmake b/cmake/onnxruntime_providers_qnn.cmake index b68d84c23bb32..303020145889b 100644 --- a/cmake/onnxruntime_providers_qnn.cmake +++ b/cmake/onnxruntime_providers_qnn.cmake @@ -3,41 +3,89 @@ add_compile_definitions(USE_QNN=1) - # These are shared utils, - # TODO, move to a separate lib when used by EPs other than QNN, NNAPI and CoreML - file(GLOB onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h" - "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc" - ) + if(onnxruntime_BUILD_QNN_EP_STATIC_LIB) + add_compile_definitions(BUILD_QNN_EP_STATIC_LIB=1) + endif() file(GLOB_RECURSE - onnxruntime_providers_qnn_ep_cc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/qnn/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/qnn/*.cc" + onnxruntime_providers_qnn_ep_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/qnn/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/qnn/*.cc" ) - file(GLOB_RECURSE - onnxruntime_providers_qnn_builder_cc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/qnn/builder/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/qnn/builder/*.cc" - ) + if(onnxruntime_BUILD_QNN_EP_STATIC_LIB) + # + # Build QNN EP as a static library + # + set(onnxruntime_providers_qnn_srcs ${onnxruntime_providers_qnn_ep_srcs}) + source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_qnn_srcs}) + onnxruntime_add_static_library(onnxruntime_providers_qnn ${onnxruntime_providers_qnn_srcs}) + onnxruntime_add_include_to_target(onnxruntime_providers_qnn onnxruntime_common onnxruntime_framework onnx + onnx_proto protobuf::libprotobuf-lite + flatbuffers::flatbuffers Boost::mp11) + add_dependencies(onnxruntime_providers_qnn onnx ${onnxruntime_EXTERNAL_DEPENDENCIES}) + set_target_properties(onnxruntime_providers_qnn PROPERTIES CXX_STANDARD_REQUIRED ON) + set_target_properties(onnxruntime_providers_qnn PROPERTIES FOLDER "ONNXRuntime") + target_include_directories(onnxruntime_providers_qnn PRIVATE ${ONNXRUNTIME_ROOT} + ${onnxruntime_QNN_HOME}/include/QNN + ${onnxruntime_QNN_HOME}/include) + set_target_properties(onnxruntime_providers_qnn PROPERTIES LINKER_LANGUAGE CXX) - set(onnxruntime_providers_qnn_cc_srcs - ${onnxruntime_providers_shared_utils_cc_srcs} - ${onnxruntime_providers_qnn_ep_cc_srcs} - ${onnxruntime_providers_qnn_builder_cc_srcs} - ) + # ignore the warning unknown-pragmas on "pragma region" + if(NOT MSVC) + target_compile_options(onnxruntime_providers_qnn PRIVATE "-Wno-unknown-pragmas") + endif() + else() + # + # Build QNN EP as a shared library + # + file(GLOB_RECURSE + onnxruntime_providers_qnn_shared_lib_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc" + ) + set(onnxruntime_providers_qnn_srcs ${onnxruntime_providers_qnn_ep_srcs} + ${onnxruntime_providers_qnn_shared_lib_srcs}) + + source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_qnn_srcs}) + onnxruntime_add_shared_library_module(onnxruntime_providers_qnn ${onnxruntime_providers_qnn_srcs}) + onnxruntime_add_include_to_target(onnxruntime_providers_qnn ${ONNXRUNTIME_PROVIDERS_SHARED} ${GSL_TARGET} onnx + onnxruntime_common Boost::mp11 safeint_interface) + target_link_libraries(onnxruntime_providers_qnn PRIVATE ${ONNXRUNTIME_PROVIDERS_SHARED} ${ABSEIL_LIBS} ${CMAKE_DL_LIBS}) + add_dependencies(onnxruntime_providers_qnn onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) + target_include_directories(onnxruntime_providers_qnn PRIVATE ${ONNXRUNTIME_ROOT} + ${CMAKE_CURRENT_BINARY_DIR} + ${onnxruntime_QNN_HOME}/include/QNN + ${onnxruntime_QNN_HOME}/include) + + # Set linker flags for function(s) exported by EP DLL + if(UNIX) + target_link_options(onnxruntime_providers_qnn PRIVATE + "LINKER:--version-script=${ONNXRUNTIME_ROOT}/core/providers/qnn/version_script.lds" + "LINKER:--gc-sections" + "LINKER:-rpath=\$ORIGIN" + ) + elseif(WIN32) + set_property(TARGET onnxruntime_providers_qnn APPEND_STRING PROPERTY LINK_FLAGS + "-DEF:${ONNXRUNTIME_ROOT}/core/providers/qnn/symbols.def") + else() + message(FATAL_ERROR "onnxruntime_providers_qnn unknown platform, need to specify shared library exports for it") + endif() + + # Set compile options + if(MSVC) + target_compile_options(onnxruntime_providers_qnn PUBLIC /wd4099 /wd4005) + else() + # ignore the warning unknown-pragmas on "pragma region" + target_compile_options(onnxruntime_providers_qnn PRIVATE "-Wno-unknown-pragmas") + endif() + + set_target_properties(onnxruntime_providers_qnn PROPERTIES LINKER_LANGUAGE CXX) + set_target_properties(onnxruntime_providers_qnn PROPERTIES CXX_STANDARD_REQUIRED ON) + set_target_properties(onnxruntime_providers_qnn PROPERTIES FOLDER "ONNXRuntime") - source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_qnn_cc_srcs}) - onnxruntime_add_static_library(onnxruntime_providers_qnn ${onnxruntime_providers_qnn_cc_srcs}) - onnxruntime_add_include_to_target(onnxruntime_providers_qnn onnxruntime_common onnxruntime_framework onnx onnx_proto protobuf::libprotobuf-lite flatbuffers::flatbuffers Boost::mp11) - target_link_libraries(onnxruntime_providers_qnn) - add_dependencies(onnxruntime_providers_qnn onnx ${onnxruntime_EXTERNAL_DEPENDENCIES}) - set_target_properties(onnxruntime_providers_qnn PROPERTIES CXX_STANDARD_REQUIRED ON) - set_target_properties(onnxruntime_providers_qnn PROPERTIES FOLDER "ONNXRuntime") - target_include_directories(onnxruntime_providers_qnn PRIVATE ${ONNXRUNTIME_ROOT} ${onnxruntime_QNN_HOME}/include/QNN ${onnxruntime_QNN_HOME}/include) - set_target_properties(onnxruntime_providers_qnn PROPERTIES LINKER_LANGUAGE CXX) - # ignore the warning unknown-pragmas on "pragma region" - if(NOT MSVC) - target_compile_options(onnxruntime_providers_qnn PRIVATE "-Wno-unknown-pragmas") + install(TARGETS onnxruntime_providers_qnn + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) endif() diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index 5b29d1093aa5c..15a2862cede0c 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -169,9 +169,7 @@ if (onnxruntime_ENABLE_LAZY_TENSOR) endif() endif() -target_link_libraries(onnxruntime_pybind11_state PRIVATE - onnxruntime_session - ${onnxruntime_libs} +set(onnxruntime_pybind11_state_static_providers ${PROVIDERS_NNAPI} ${PROVIDERS_VSINPU} ${PROVIDERS_XNNPACK} @@ -183,7 +181,16 @@ target_link_libraries(onnxruntime_pybind11_state PRIVATE ${PROVIDERS_XNNPACK} ${PROVIDERS_WEBGPU} ${PROVIDERS_AZURE} - ${PROVIDERS_QNN} +) + +if(onnxruntime_BUILD_QNN_EP_STATIC_LIB) + list(APPEND onnxruntime_pybind11_state_static_providers PRIVATE onnxruntime_providers_qnn) +endif() + +target_link_libraries(onnxruntime_pybind11_state PRIVATE + onnxruntime_session + ${onnxruntime_libs} + ${onnxruntime_pybind11_state_static_providers} onnxruntime_optimizer onnxruntime_providers onnxruntime_util @@ -1000,6 +1007,16 @@ if (onnxruntime_USE_COREML) endif() if (onnxruntime_USE_QNN) + if(NOT onnxruntime_BUILD_QNN_EP_STATIC_LIB) + add_custom_command( + TARGET onnxruntime_pybind11_state POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy + $ + $ + $/onnxruntime/capi/ + ) + endif() + add_custom_command( TARGET onnxruntime_pybind11_state POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 9e3ab4d41f416..223af25c026a1 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -627,16 +627,13 @@ if(onnxruntime_USE_ARMNN) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_armnn) endif() -set(ONNXRUNTIME_TEST_LIBS - onnxruntime_session - ${ONNXRUNTIME_INTEROP_TEST_LIBS} - ${onnxruntime_libs} - # CUDA, ROCM, TENSORRT, MIGRAPHX, DNNL, and OpenVINO are dynamically loaded at runtime +set(ONNXRUNTIME_TEST_STATIC_PROVIDER_LIBS + # CUDA, ROCM, TENSORRT, MIGRAPHX, DNNL, and OpenVINO are dynamically loaded at runtime. + # QNN EP can be built as either a dynamic and static libs. ${PROVIDERS_NNAPI} ${PROVIDERS_VSINPU} ${PROVIDERS_JS} ${PROVIDERS_WEBGPU} - ${PROVIDERS_QNN} ${PROVIDERS_SNPE} ${PROVIDERS_RKNPU} ${PROVIDERS_DML} @@ -645,6 +642,17 @@ set(ONNXRUNTIME_TEST_LIBS ${PROVIDERS_COREML} ${PROVIDERS_XNNPACK} ${PROVIDERS_AZURE} +) + +if (onnxruntime_BUILD_QNN_EP_STATIC_LIB) + list(APPEND ONNXRUNTIME_TEST_STATIC_PROVIDER_LIBS onnxruntime_providers_qnn) +endif() + +set(ONNXRUNTIME_TEST_LIBS + onnxruntime_session + ${ONNXRUNTIME_INTEROP_TEST_LIBS} + ${onnxruntime_libs} + ${ONNXRUNTIME_TEST_STATIC_PROVIDER_LIBS} onnxruntime_optimizer onnxruntime_providers onnxruntime_util @@ -708,7 +716,9 @@ if(onnxruntime_USE_QNN AND NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_RED list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/qnn/*) list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_qnn) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_qnn) - list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_qnn) + if(NOT onnxruntime_BUILD_QNN_EP_STATIC_LIB) + list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_shared) + endif() endif() if(onnxruntime_USE_SNPE) diff --git a/java/src/main/java/ai/onnxruntime/OnnxRuntime.java b/java/src/main/java/ai/onnxruntime/OnnxRuntime.java index b80debdde47c4..c28c79f1e723e 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxRuntime.java +++ b/java/src/main/java/ai/onnxruntime/OnnxRuntime.java @@ -76,6 +76,9 @@ final class OnnxRuntime { /** The short name of the ONNX runtime TensorRT provider library */ static final String ONNXRUNTIME_LIBRARY_TENSORRT_NAME = "onnxruntime_providers_tensorrt"; + /** The short name of the ONNX runtime QNN provider library */ + static final String ONNXRUNTIME_LIBRARY_QNN_NAME = "onnxruntime_providers_qnn"; + /** The OS & CPU architecture string */ private static final String OS_ARCH_STR = initOsArch(); @@ -159,8 +162,11 @@ static synchronized void init() throws IOException { // the ONNX Runtime native library will load it extractProviderLibrary(ONNXRUNTIME_LIBRARY_SHARED_NAME); - load(ONNXRUNTIME_LIBRARY_NAME); + if (!isAndroid()) { + load(ONNXRUNTIME_LIBRARY_NAME); + } load(ONNXRUNTIME_JNI_LIBRARY_NAME); + ortApiHandle = initialiseAPIBase(ORT_API_VERSION_14); if (ortApiHandle == 0L) { throw new IllegalStateException( @@ -252,6 +258,16 @@ static boolean extractTensorRT() { return extractProviderLibrary(ONNXRUNTIME_LIBRARY_TENSORRT_NAME); } + /** + * Extracts the QNN provider library from the classpath resources if present, or checks to see if + * the QNN provider library is in the directory specified by {@link #ONNXRUNTIME_NATIVE_PATH}. + * + * @return True if the QNN provider library is ready for loading, false otherwise. + */ + static boolean extractQNN() { + return extractProviderLibrary(ONNXRUNTIME_LIBRARY_QNN_NAME); + } + /** * Extracts a shared provider library from the classpath resources if present, or checks to see if * that library is in the directory specified by {@link #ONNXRUNTIME_NATIVE_PATH}. @@ -260,7 +276,7 @@ static boolean extractTensorRT() { * @return True if the library is ready for loading by ORT's native code, false otherwise. */ static synchronized boolean extractProviderLibrary(String libraryName) { - // Android does not need to extract library and it has no shared provider library + // Android does not need to extract provider libraries. if (isAndroid()) { return false; } @@ -312,7 +328,7 @@ static boolean isAndroid() { private static void load(String library) throws IOException { // On Android, we simply use System.loadLibrary if (isAndroid()) { - System.loadLibrary("onnxruntime4j_jni"); + System.loadLibrary(library); return; } diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java index 32dc9d9f84aaa..bd988e2bb7468 100644 --- a/java/src/main/java/ai/onnxruntime/OrtSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtSession.java @@ -1320,6 +1320,10 @@ public void addXnnpack(Map providerOptions) throws OrtException */ public void addQnn(Map providerOptions) throws OrtException { String qnnProviderName = "QNN"; + + // QNN can either be built as a shared or static library. extractQNN() will extract the + // (lib)onnxruntime_providers_qnn(.so/.dll) from classpath resources if present. + OnnxRuntime.extractQNN(); addExecutionProvider(qnnProviderName, providerOptions); } diff --git a/onnxruntime/core/platform/windows/logging/etw_sink.cc b/onnxruntime/core/platform/windows/logging/etw_sink.cc index 950ac247a2046..489cd19b11302 100644 --- a/onnxruntime/core/platform/windows/logging/etw_sink.cc +++ b/onnxruntime/core/platform/windows/logging/etw_sink.cc @@ -64,6 +64,10 @@ EtwRegistrationManager& EtwRegistrationManager::Instance() { return instance; } +bool EtwRegistrationManager::SupportsETW() { + return true; +} + bool EtwRegistrationManager::IsEnabled() const { std::lock_guard lock(provider_change_mutex_); return is_enabled_; @@ -248,5 +252,19 @@ void EtwSink::SendImpl(const Timestamp& timestamp, const std::string& logger_id, } } // namespace logging } // namespace onnxruntime +#else +// ETW is not supported on this platform but should still define a dummy EtwRegistrationManager +// so that it can be used in the EP provider bridge. +namespace onnxruntime { +namespace logging { +EtwRegistrationManager& EtwRegistrationManager::Instance() { + static EtwRegistrationManager instance; + return instance; +} +bool EtwRegistrationManager::SupportsETW() { + return false; +} +} // namespace logging +} // namespace onnxruntime #endif // ETW_TRACE_LOGGING_SUPPORTED diff --git a/onnxruntime/core/platform/windows/logging/etw_sink.h b/onnxruntime/core/platform/windows/logging/etw_sink.h index 2a798a28f13de..62b762886ca82 100644 --- a/onnxruntime/core/platform/windows/logging/etw_sink.h +++ b/onnxruntime/core/platform/windows/logging/etw_sink.h @@ -60,6 +60,9 @@ class EtwRegistrationManager { // Singleton instance access static EtwRegistrationManager& Instance(); + // Returns true if ETW is supported at all. + static bool SupportsETW(); + // Check if ETW logging is enabled bool IsEnabled() const; @@ -110,5 +113,33 @@ class EtwRegistrationManager { } // namespace logging } // namespace onnxruntime +#else +// ETW is not supported on this platform but should still define a dummy EtwRegistrationManager +// so that it can be used in the EP provider bridge. +#include "core/common/logging/severity.h" +namespace onnxruntime { +namespace logging { +class EtwRegistrationManager { + public: + using EtwInternalCallback = std::function; + + static EtwRegistrationManager& Instance(); + static bool SupportsETW(); + bool IsEnabled() const { return false; } + UCHAR Level() const { return 0; } + Severity MapLevelToSeverity() { return Severity::kFATAL; } + uint64_t Keyword() const { return 0; } + HRESULT Status() const { return 0; } + void RegisterInternalCallback(const EtwInternalCallback& callback) {} + void UnregisterInternalCallback(const EtwInternalCallback& callback) {} + + private: + EtwRegistrationManager() = default; + ~EtwRegistrationManager() = default; +}; +} // namespace logging +} // namespace onnxruntime #endif // ETW_TRACE_LOGGING_SUPPORTED diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc index 79674fd706151..3df231e53e7c0 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc @@ -2,13 +2,15 @@ // Licensed under the MIT License. #include "core/providers/qnn/builder/onnx_ctx_model_helper.h" -#include "core/graph/constants.h" -#include "core/providers/qnn/builder/qnn_model.h" #include #include #include +#include "core/providers/qnn/ort_api.h" +#include "core/providers/qnn/builder/qnn_utils.h" +#include "core/providers/qnn/builder/qnn_model.h" + namespace onnxruntime { namespace qnn { @@ -51,9 +53,9 @@ Status GetMainContextNode(const std::vectorOpType(), "Should only filter in the EPContext node."); - NodeAttrHelper node_helper(*ep_context_node); + const Node& ep_context_node = *graph_viewer.Nodes().begin(); + ORT_RETURN_IF_NOT(EPCONTEXT_OP == ep_context_node.OpType(), "Should only filter in the EPContext node."); + NodeAttrHelper node_helper(ep_context_node); int64_t is_main_context = node_helper.Get(MAIN_CONTEXT, static_cast(0)); if (1 == is_main_context) { main_context_pos.push_back(static_cast(i)); @@ -68,17 +70,16 @@ Status CreateNodeArgs(const std::vector& names, const std::unordered_map& tensor_info_table, std::vector& node_args, onnxruntime::Graph& graph) { - using namespace ONNX_NAMESPACE; for (size_t i = 0; i < names.size(); ++i) { std::string name = names[i]; ORT_RETURN_IF(tensor_info_table.find(name) == tensor_info_table.end(), "Tensor name: ", name, " not found in tensor_info_table"); const OnnxTensorInfo& tensor_info = tensor_info_table.at(name); - TypeProto tensor_type; - tensor_type.mutable_tensor_type()->set_elem_type(tensor_info.data_type_); + std::unique_ptr tensor_type = Factory::Create(); + tensor_type->mutable_tensor_type()->set_elem_type(tensor_info.data_type_); for (size_t j = 0; j < tensor_info.shape_.size(); ++j) { - tensor_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(tensor_info.shape_[j]); + tensor_type->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(tensor_info.shape_[j]); } - auto& input_arg = graph.GetOrCreateNodeArg(name, &tensor_type); + auto& input_arg = graph.GetOrCreateNodeArg(name, tensor_type.get()); node_args.push_back(&input_arg); } return Status::OK(); @@ -161,8 +162,8 @@ Status TryGetMaxSpillFillSize(const std::vector(0)); if (max_size > max_spill_fill_size) { max_spill_fill_size = max_size; diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h index 92c5391b40f09..3dfa0ae21001b 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h @@ -6,12 +6,8 @@ #include #include -#include "qnn_def.h" -#include "core/common/logging/logging.h" -#include "core/graph/graph_viewer.h" -#include "core/providers/shared/utils/utils.h" -#include "core/graph/model.h" -#include "core/framework/execution_provider.h" +#include "core/providers/qnn/builder/qnn_def.h" +#include "core/providers/qnn/ort_api.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/qnn/builder/op_builder.h b/onnxruntime/core/providers/qnn/builder/op_builder.h index 05398c3f22ea2..0846275496ebf 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder.h +++ b/onnxruntime/core/providers/qnn/builder/op_builder.h @@ -3,9 +3,7 @@ #pragma once -#include "core/graph/graph_viewer.h" -#include "core/framework/node_unit.h" -#include "core/providers/shared/utils/utils.h" +#include "core/providers/qnn/ort_api.h" namespace onnxruntime { namespace qnn { diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc index 6ef17b40d274b..3e337f679056f 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc @@ -5,8 +5,6 @@ #include #include -#include - #include "op_builder_factory.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/argmax_min_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/argmax_min_op_builder.cc index c685fa065e2ba..e3a6141c292dd 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/argmax_min_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/argmax_min_op_builder.cc @@ -1,14 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/providers/common.h" -#include "core/providers/shared/utils/utils.h" -#include "core/framework/tensorprotoutils.h" +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" +#include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/op_builder_factory.h" -#include "base_op_builder.h" - namespace onnxruntime { namespace qnn { diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc index ed70111087e19..cd1ee72e00d4f 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc @@ -2,15 +2,9 @@ // Licensed under the MIT License. #include "core/providers/qnn/builder/opbuilder/base_op_builder.h" +#include #include "core/providers/qnn/builder/qnn_utils.h" -#include - -#include "core/providers/shared/utils/utils.h" -#include "core/framework/tensorprotoutils.h" -#include "core/providers/cpu/tensor/transpose.h" -#include "core/common/safeint.h" - namespace onnxruntime { namespace qnn { @@ -271,37 +265,189 @@ Status BaseOpBuilder::SetOutputQParamEqualToInputIfNearlyEqual(QnnModelWrapper& return Status::OK(); } -Status BaseOpBuilder::TransposeInitializer(const QnnModelWrapper& qnn_model_wrapper, - const onnx::TensorProto& initializer, - const std::vector& perm, - std::vector& transposed_data) const { - const DataTypeImpl* tensor_dtype = DataTypeImpl::TensorTypeFromONNXEnum(initializer.data_type())->GetElementType(); - const auto tensor_shape_dims = onnxruntime::utils::GetTensorShapeFromTensorProto(initializer); - TensorShape tensor_shape{tensor_shape_dims}; - AllocatorPtr cpu_allocator = std::make_shared(); - Tensor in_tensor = Tensor(tensor_dtype, tensor_shape, cpu_allocator); - - auto rank = perm.size(); - std::vector new_tensor_shape_dims; - std::vector permutations; - new_tensor_shape_dims.reserve(rank); - permutations.reserve(rank); - for (int64_t p : perm) { - permutations.push_back(p); - new_tensor_shape_dims.push_back(tensor_shape_dims[p]); +static Status GetTransposeStrides(const TensorShape& input_shape, + gsl::span perm, + gsl::span input_strides, + gsl::span output_strides) { + const size_t rank = input_shape.NumDimensions(); + ORT_RETURN_IF_NOT(perm.size() == rank, "Expected perm size of ", rank); + ORT_RETURN_IF_NOT(input_strides.size() == rank, "Expected input_strides size of ", rank); + ORT_RETURN_IF_NOT(output_strides.size() == rank, "Expected output_strides size of ", rank); + std::vector output_shape_dims(rank); + ORT_RETURN_IF_ERROR((qnn::utils::PermuteShape(input_shape.GetDims(), perm, output_shape_dims))); + const TensorShape output_shape = TensorShape::FromExistingBuffer(output_shape_dims); + + for (size_t i = 0; i < rank; ++i) { + int64_t stride = (i < rank - 1) ? input_shape.SizeFromDimension(i + 1) : 1; + ORT_RETURN_IF_NOT(stride > 0, "Expected positive shape dims when computing strides."); + input_strides[i] = static_cast(stride); + } + + for (size_t i = 0; i < rank; ++i) { + int64_t stride = (i < rank - 1) ? output_shape.SizeFromDimension(i + 1) : 1; + ORT_RETURN_IF_NOT(stride > 0, "Expected positive shape dims when computing strides."); + output_strides[i] = static_cast(stride); } - TensorShape new_tensor_shape(new_tensor_shape_dims); - Tensor out_tensor = Tensor(tensor_dtype, new_tensor_shape, cpu_allocator); - ORT_RETURN_IF_ERROR(onnxruntime::utils::TensorProtoToTensor( - Env::Default(), qnn_model_wrapper.GetGraphViewer().ModelPath(), initializer, in_tensor)); - ORT_RETURN_IF_ERROR(Transpose::DoTranspose(permutations, in_tensor, out_tensor)); - onnx::TensorProto new_tensor_proto = onnxruntime::utils::TensorToTensorProto(out_tensor, "test"); - ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(new_tensor_proto, transposed_data)); + return Status::OK(); +} + +// Internal function to transpose data of rank 5 with the given permutation. +// Example: transpose input from either (N,C,H,W,D) or (C,N,H,W,D) to (H,W,D,C,N). +static Status TransposeDataRank5(const TensorShape& input_shape, + gsl::span perm, + size_t elem_byte_size, + gsl::span input_buffer, + gsl::span output_buffer) { + std::array input_strides = {}; + std::array output_strides = {}; + ORT_RETURN_IF_ERROR(GetTransposeStrides(input_shape, perm, input_strides, output_strides)); + + std::vector perm_inverse(perm.size()); + ORT_RETURN_IF_ERROR(qnn::utils::InvertPerm(perm, perm_inverse)); + + for (int64_t d0 = 0; d0 < input_shape[0]; ++d0) { + for (int64_t d1 = 0; d1 < input_shape[1]; ++d1) { + for (int64_t d2 = 0; d2 < input_shape[2]; ++d2) { + for (int64_t d3 = 0; d3 < input_shape[3]; ++d3) { + for (int64_t d4 = 0; d4 < input_shape[4]; ++d4) { + const size_t src_elem_index = ((d0 * input_strides[0]) + + (d1 * input_strides[1]) + + (d2 * input_strides[2]) + + (d3 * input_strides[3]) + + (d4 * input_strides[4])); + const size_t dst_elem_index = ((d0 * output_strides[perm_inverse[0]]) + + (d1 * output_strides[perm_inverse[1]]) + + (d2 * output_strides[perm_inverse[2]]) + + (d3 * output_strides[perm_inverse[3]]) + + (d4 * output_strides[perm_inverse[4]])); + + const size_t src_byte_index = src_elem_index * elem_byte_size; + const size_t dst_byte_index = dst_elem_index * elem_byte_size; + assert(src_byte_index < input_buffer.size()); + assert(dst_byte_index < output_buffer.size()); + + std::memcpy(&output_buffer[dst_byte_index], &input_buffer[src_byte_index], elem_byte_size); + } + } + } + } + } return Status::OK(); } +Status BaseOpBuilder::TwoDimensionTranspose(const QnnModelWrapper& qnn_model_wrapper, + std::vector& data_shape, + const onnx::TensorProto& initializer, + std::vector& transposed_data) const { + ORT_RETURN_IF_NOT(data_shape.size() == 2, "Expected shape of rank 2"); + + std::array perm = {1, 0}; + std::vector output_shape(data_shape.size()); + ORT_RETURN_IF_ERROR((qnn::utils::PermuteShape(data_shape, perm, output_shape))); + + auto onnx_type = static_cast(initializer.data_type()); + const size_t elem_byte_size = qnn::utils::GetElementSizeByType(onnx_type); + ORT_RETURN_IF_NOT(elem_byte_size != 0, "Can't get element byte size from given ONNX type"); + + std::vector input_buffer; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(initializer, input_buffer)); + transposed_data.resize(input_buffer.size()); + + for (size_t row = 0; row < data_shape[0]; row++) { + for (size_t col = 0; col < data_shape[1]; col++) { + const size_t src_elem_index = (row * data_shape[1] + col); + const size_t dst_elem_index = (col * output_shape[1] + row); + const size_t src_byte_index = src_elem_index * elem_byte_size; + const size_t dst_byte_index = dst_elem_index * elem_byte_size; + assert(src_byte_index < input_buffer.size()); + assert(dst_byte_index < transposed_data.size()); + + std::memcpy(&transposed_data[dst_byte_index], &input_buffer[src_byte_index], elem_byte_size); + } + } + + data_shape = std::move(output_shape); // Update parameter with final transposed shape + return Status::OK(); +} + +Status BaseOpBuilder::TransposeFromNchwToHwcn(const QnnModelWrapper& qnn_model_wrapper, + const onnx::TensorProto& initializer, + std::vector& transposed_data, + bool is_3d) const { + auto onnx_type = static_cast(initializer.data_type()); + const size_t elem_byte_size = qnn::utils::GetElementSizeByType(onnx_type); + std::vector input_shape = qnn::utils::GetInitializerShape(initializer); + std::vector input_buffer; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(initializer, input_buffer)); + transposed_data.resize(input_buffer.size()); + return TransposeFromNchwToHwcn(std::move(input_shape), elem_byte_size, input_buffer, transposed_data, is_3d); +} + +Status BaseOpBuilder::TransposeFromNchwToHwcn(std::vector&& original_input_shape_dims, + size_t elem_byte_size, + gsl::span input_buffer, + gsl::span output_buffer, + bool is_3d) const { + std::vector input_shape_dims = std::move(original_input_shape_dims); + const size_t rank = input_shape_dims.size(); + ORT_RETURN_IF_NOT((is_3d && rank == 5) || (!is_3d && rank == 4), "Only support input of rank 4 or 5 but got rank ", + rank); + ORT_RETURN_IF_NOT(output_buffer.size() == input_buffer.size(), + "Expected output buffer's size to equal the input buffer's size: ", + output_buffer.size(), " != ", input_buffer.size()); + ORT_RETURN_IF_NOT(elem_byte_size != 0, "Invalid element byte size due to potentially unsupported type"); + + if (!is_3d) { + input_shape_dims.push_back(1); // Make it 3D by making shape (N,C,H,W,1) + } + + return TransposeDataRank5(TensorShape::FromExistingBuffer(input_shape_dims), + nchw2hwcn_perm_3d, + elem_byte_size, + input_buffer, + output_buffer); +} + +Status BaseOpBuilder::TransposeFromCnhwToHwcn(const QnnModelWrapper& qnn_model_wrapper, + const onnx::TensorProto& initializer, + std::vector& transposed_data, + bool is_3d) const { + auto onnx_type = static_cast(initializer.data_type()); + const size_t elem_byte_size = qnn::utils::GetElementSizeByType(onnx_type); + std::vector input_shape = qnn::utils::GetInitializerShape(initializer); + std::vector input_buffer; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(initializer, input_buffer)); + transposed_data.resize(input_buffer.size()); + return TransposeFromCnhwToHwcn(std::move(input_shape), elem_byte_size, input_buffer, transposed_data, is_3d); +} + +Status BaseOpBuilder::TransposeFromCnhwToHwcn(std::vector&& original_input_shape_dims, + size_t elem_byte_size, + gsl::span input_buffer, + gsl::span output_buffer, + bool is_3d) const { + std::vector input_shape_dims = std::move(original_input_shape_dims); + const size_t rank = input_shape_dims.size(); + ORT_RETURN_IF_NOT((is_3d && rank == 5) || (!is_3d && rank == 4), "Only support input of rank 4 or 5 but got rank ", + rank); + ORT_RETURN_IF_NOT(output_buffer.size() == input_buffer.size(), + "Expected output buffer's size to equal the input buffer's size: ", + output_buffer.size(), " != ", input_buffer.size()); + ORT_RETURN_IF_NOT(elem_byte_size != 0, "Invalid element byte size due to potentially unsupported type"); + + if (!is_3d) { + input_shape_dims.push_back(1); // Make it 3D by making shape (C,N,H,W,1) + } + + return TransposeDataRank5(TensorShape::FromExistingBuffer(input_shape_dims), + cnhw2hwcn_perm_3d, + elem_byte_size, + input_buffer, + output_buffer); +} + Status BaseOpBuilder::ProcessAxisAttribute(const QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, Qnn_Scalar_t& axis_qnn_scalar, diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h index 055c0f6ccf2fa..8e34b5d87cc68 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h @@ -3,11 +3,11 @@ #pragma once -#include "core/providers/shared/utils/utils.h" +#include "core/providers/qnn/ort_api.h" +#include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/op_builder.h" #include "core/providers/qnn/builder/qnn_quant_params_wrapper.h" -#include "core/framework/allocator.h" #include "QnnOpDef.h" @@ -215,7 +215,8 @@ class BaseOpBuilder : public IOpBuilder { } // NCHW shape to channel last - Status NchwShapeToNhwc(const std::vector& nchw_shape, std::vector& nhwc_shape) const { + template + Status NchwShapeToNhwc(gsl::span nchw_shape, gsl::span nhwc_shape) const { ORT_RETURN_IF_NOT(nchw_shape.size() == 4, "shape should have 4 dimension NCHW."); nhwc_shape[0] = nchw_shape[0]; nhwc_shape[1] = nchw_shape[2]; @@ -226,7 +227,8 @@ class BaseOpBuilder : public IOpBuilder { } // NCHW shape to HWCN shape, required for Conv weight - Status NchwShapeToHwcn(const std::vector& nchw_shape, std::vector& hwcn_shape) const { + template + Status NchwShapeToHwcn(gsl::span nchw_shape, gsl::span hwcn_shape) const { if (nchw_shape.size() == 4) { hwcn_shape[0] = nchw_shape[2]; hwcn_shape[1] = nchw_shape[3]; @@ -246,7 +248,8 @@ class BaseOpBuilder : public IOpBuilder { } // CNHW shape to HWCN shape, required for Conv weight - Status CnhwShapeToHwcn(const std::vector& cnhw_shape, std::vector& hwcn_shape) const { + template + Status CnhwShapeToHwcn(gsl::span cnhw_shape, gsl::span hwcn_shape) const { if (cnhw_shape.size() == 4) { hwcn_shape[0] = cnhw_shape[2]; hwcn_shape[1] = cnhw_shape[3]; @@ -264,37 +267,31 @@ class BaseOpBuilder : public IOpBuilder { return Status::OK(); } - Status TransposeInitializer(const QnnModelWrapper& qnn_model_wrapper, - const onnx::TensorProto& initializer, - const std::vector& perm, - std::vector& transposed_data) const; Status TransposeFromNchwToHwcn(const QnnModelWrapper& qnn_model_wrapper, const onnx::TensorProto& initializer, std::vector& transposed_data, - bool is_3d = false) const { - auto& perm = is_3d ? nchw2hwcn_perm_3d : nchw2hwcn_perm; - return TransposeInitializer(qnn_model_wrapper, initializer, perm, transposed_data); - } + bool is_3d = false) const; + Status TransposeFromNchwToHwcn(std::vector&& input_shape_dims, + size_t elem_byte_size, + gsl::span input_buffer, + gsl::span output_buffer, + bool is_3d = false) const; Status TransposeFromCnhwToHwcn(const QnnModelWrapper& qnn_model_wrapper, const onnx::TensorProto& initializer, std::vector& transposed_data, - bool is_3d = false) const { - auto& perm = is_3d ? cnhw2hwcn_perm_3d : cnhw2hwcn_perm; - return TransposeInitializer(qnn_model_wrapper, initializer, perm, transposed_data); - } + bool is_3d = false) const; + Status TransposeFromCnhwToHwcn(std::vector&& input_shape_dims, + size_t elem_byte_size, + gsl::span input_buffer, + gsl::span output_buffer, + bool is_3d = false) const; Status TwoDimensionTranspose(const QnnModelWrapper& qnn_model_wrapper, std::vector& data_shape, const onnx::TensorProto& initializer, - std::vector& transposed_data) const { - auto tmp = data_shape[0]; - data_shape[0] = data_shape[1]; - data_shape[1] = tmp; - std::vector two_dim_trans_perm{1, 0}; - return TransposeInitializer(qnn_model_wrapper, initializer, two_dim_trans_perm, transposed_data); - } + std::vector& transposed_data) const; // Onnx Pads is [x1_begin, x2_begin, x1_end, x2_end], QNN requires [x1_begin, x1_end, x2_begin, x2_end] void ReArranagePads(std::vector& pads) const { diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc index 07abcf1c7bf84..14f50fa78c1a9 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc @@ -5,16 +5,11 @@ #include #include -#include "core/providers/common.h" -#include "core/providers/shared/utils/utils.h" -#include "core/framework/float16.h" -#include "core/framework/tensorprotoutils.h" +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/builder/op_builder_factory.h" -#include "base_op_builder.h" - namespace onnxruntime { namespace qnn { class BatchNormOpBuilder : public BaseOpBuilder { diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/cast_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/cast_op_builder.cc index d3bdee02437e4..3139c05378171 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/cast_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/cast_op_builder.cc @@ -4,12 +4,11 @@ #include #include +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/op_builder_factory.h" #include "core/providers/qnn/builder/qnn_utils.h" -#include "base_op_builder.h" - namespace onnxruntime { namespace qnn { diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/clip_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/clip_op_builder.cc index e5dc4d04afefd..23b3dfb063ba2 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/clip_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/clip_op_builder.cc @@ -4,14 +4,11 @@ #include #include -#include "core/providers/common.h" -#include "core/providers/shared/utils/utils.h" +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/op_builder_factory.h" #include "core/providers/qnn/builder/qnn_utils.h" -#include "base_op_builder.h" - namespace onnxruntime { namespace qnn { class ClipOpBuilder : public BaseOpBuilder { diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc index 12887f0fb72d6..0f92778252d48 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc @@ -1,16 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/providers/common.h" -#include "core/providers/shared/utils/utils.h" -#include "core/framework/tensorprotoutils.h" +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/op_builder_factory.h" -#include "core/common/safeint.h" #include "core/providers/qnn/builder/qnn_utils.h" -#include "base_op_builder.h" - namespace onnxruntime { namespace qnn { @@ -211,9 +206,9 @@ Status ConvOpBuilder::ProcessConv2D3DInputs(QnnModelWrapper& qnn_model_wrapper, // Change shape to HWCN, it could be initializer or normal input if (conv_type == OnnxConvType::kConv) { - ORT_RETURN_IF_ERROR(NchwShapeToHwcn(input_info.shape, actual_shape)); + ORT_RETURN_IF_ERROR(NchwShapeToHwcn(input_info.shape, actual_shape)); } else if (conv_type == OnnxConvType::kConvTranspose) { - ORT_RETURN_IF_ERROR(CnhwShapeToHwcn(input_info.shape, actual_shape)); + ORT_RETURN_IF_ERROR(CnhwShapeToHwcn(input_info.shape, actual_shape)); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN EP: Unexpected convolution op type: ", node_unit.OpType().c_str()); } @@ -413,9 +408,9 @@ Status ConvOpBuilder::ProcessConv1DInputs(QnnModelWrapper& qnn_model_wrapper, // Create the final shape after the weights are transposed to HWCN. if (conv_type == OnnxConvType::kConv) { - ORT_RETURN_IF_ERROR(NchwShapeToHwcn(shape_2d, final_shape)); + ORT_RETURN_IF_ERROR(NchwShapeToHwcn(shape_2d, final_shape)); } else if (conv_type == OnnxConvType::kConvTranspose) { - ORT_RETURN_IF_ERROR(CnhwShapeToHwcn(shape_2d, final_shape)); + ORT_RETURN_IF_ERROR(CnhwShapeToHwcn(shape_2d, final_shape)); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN EP: Unexpected convolution op type: ", node_unit.OpType().c_str()); } @@ -434,16 +429,6 @@ Status ConvOpBuilder::ProcessConv1DInputs(QnnModelWrapper& qnn_model_wrapper, return static_cast(dim); }); - const TensorShape tensor_shape = TensorShape::FromExistingBuffer(shape_2d_int64); // Does not own shape data. - const DataTypeImpl* tensor_dtype = DataTypeImpl::TensorTypeFromONNXEnum( - input_info.initializer_tensor->data_type()) - ->GetElementType(); - ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_info.initializer_tensor, unpacked_tensor)); - - Tensor tensor_2d(tensor_dtype, tensor_shape, unpacked_tensor.data(), OrtMemoryInfo{}); // Does not own data. - ONNX_NAMESPACE::TensorProto reshaped_initializer = onnxruntime::utils::TensorToTensorProto(tensor_2d, - reshape_output); - // The reshape (unsqueeze) may require us to shift the quant parameter's axis. if (input_info.quant_param.IsPerChannel()) { ORT_RETURN_IF_ERROR(input_info.quant_param.HandleUnsqueeze(input_info.shape, shape_2d)); @@ -452,10 +437,21 @@ Status ConvOpBuilder::ProcessConv1DInputs(QnnModelWrapper& qnn_model_wrapper, // // Get transposed initializer bytes. // + std::vector original_tensor_bytes; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_info.initializer_tensor, + original_tensor_bytes)); + unpacked_tensor.resize(original_tensor_bytes.size()); + const size_t elem_byte_size = qnn::utils::GetElementSizeByType( + static_cast(input_info.initializer_tensor->data_type())); + ORT_RETURN_IF(elem_byte_size == 0, "Can't get element byte size from given ONNX type for initializer ", + input1_name.c_str()); + if (conv_type == OnnxConvType::kConv) { - ORT_RETURN_IF_ERROR(TransposeFromNchwToHwcn(qnn_model_wrapper, reshaped_initializer, unpacked_tensor)); + ORT_RETURN_IF_ERROR(TransposeFromNchwToHwcn(std::move(shape_2d_int64), elem_byte_size, original_tensor_bytes, + unpacked_tensor, /*is_3d*/ false)); } else if (conv_type == OnnxConvType::kConvTranspose) { - ORT_RETURN_IF_ERROR(TransposeFromCnhwToHwcn(qnn_model_wrapper, reshaped_initializer, unpacked_tensor)); + ORT_RETURN_IF_ERROR(TransposeFromCnhwToHwcn(std::move(shape_2d_int64), elem_byte_size, original_tensor_bytes, + unpacked_tensor, /*is_3d*/ false)); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN EP: Unexpected convolution op type: ", node_unit.OpType().c_str()); } diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/expand_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/expand_op_builder.cc index 64f676aaa9875..2bae3452199a5 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/expand_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/expand_op_builder.cc @@ -1,14 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/providers/common.h" -#include "core/providers/shared/utils/utils.h" +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/op_builder_factory.h" #include "core/providers/qnn/builder/qnn_utils.h" -#include "core/common/safeint.h" - -#include "base_op_builder.h" namespace onnxruntime { namespace qnn { diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc index 3737fcb54f4cf..24b8c7fd7490a 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc @@ -2,14 +2,10 @@ // Licensed under the MIT License. #include -#include "core/providers/common.h" -#include "core/providers/shared/utils/utils.h" +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/op_builder_factory.h" #include "core/providers/qnn/builder/qnn_utils.h" -#include "core/common/safeint.h" - -#include "base_op_builder.h" namespace onnxruntime { namespace qnn { diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/gemm_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/gemm_op_builder.cc index eeee26c177281..76bc766d2b04d 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/gemm_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/gemm_op_builder.cc @@ -1,14 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/providers/common.h" -#include "core/providers/shared/utils/utils.h" +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/op_builder_factory.h" #include "core/providers/qnn/builder/qnn_utils.h" -#include "core/common/safeint.h" - -#include "base_op_builder.h" namespace onnxruntime { namespace qnn { diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/instance_norm_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/instance_norm_op_builder.cc index 4b8d079c0062a..d77d9534bf1c4 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/instance_norm_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/instance_norm_op_builder.cc @@ -1,16 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/providers/common.h" -#include "core/providers/shared/utils/utils.h" -#include "core/framework/tensorprotoutils.h" +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" #include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/op_builder_factory.h" -#include "core/common/safeint.h" -#include "onnx/defs/data_type_utils.h" - -#include "base_op_builder.h" namespace onnxruntime { namespace qnn { diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/layer_norm_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/layer_norm_op_builder.cc index d1a0e88686f39..fc92f42b376bc 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/layer_norm_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/layer_norm_op_builder.cc @@ -2,16 +2,10 @@ // Licensed under the MIT License. #include -#include "core/providers/common.h" -#include "core/providers/shared/utils/utils.h" -#include "core/framework/tensorprotoutils.h" +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" #include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/op_builder_factory.h" -#include "core/common/safeint.h" -#include "onnx/defs/data_type_utils.h" - -#include "base_op_builder.h" namespace onnxruntime { namespace qnn { diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/lrn_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/lrn_op_builder.cc index 2f66069b6609e..3c9bdf0e7f8aa 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/lrn_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/lrn_op_builder.cc @@ -2,11 +2,9 @@ // Licensed under the MIT License. #include "core/providers/qnn/builder/opbuilder/base_op_builder.h" -#include "core/providers/shared/utils/utils.h" +#include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/op_builder_factory.h" -#include "core/common/safeint.h" -#include "onnx/defs/data_type_utils.h" #include "QnnOpDef.h" // From QNN SDK: contains QNN constants (e.g., op names, param values). diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc index 5fc6d42a8a179..40e0ccdd4a6dd 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc @@ -1,15 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/providers/common.h" -#include "core/providers/shared/utils/utils.h" +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/op_builder_factory.h" -#include "core/providers/cpu/tensor/slice_helper.h" -#include "core/providers/qnn/builder/op_builder_factory.h" -#include "core/common/safeint.h" - -#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" #include "core/providers/qnn/builder/qnn_utils.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc index ef1990ad8e69a..795886fa255ed 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc @@ -1,16 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/providers/common.h" -#include "core/providers/shared/utils/utils.h" -#include "core/framework/tensorprotoutils.h" +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" #include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/op_builder_factory.h" -#include "core/common/safeint.h" -#include "onnx/defs/data_type_utils.h" - -#include "base_op_builder.h" namespace onnxruntime { namespace qnn { diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/reduce_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/reduce_op_builder.cc index 77bc58bd6f833..a98110bc96fb2 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/reduce_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/reduce_op_builder.cc @@ -2,16 +2,13 @@ // Licensed under the MIT License. #include -#include #include +#include +#include #include -#include "core/common/safeint.h" -#include "onnx/defs/data_type_utils.h" -#include "core/providers/common.h" -#include "core/framework/endian_utils.h" -#include "core/providers/shared/utils/utils.h" #include "core/providers/qnn/builder/opbuilder/base_op_builder.h" +#include "core/providers/qnn/ort_api.h" #include "core/providers/qnn/builder/op_builder_factory.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/qnn_utils.h" @@ -71,7 +68,7 @@ class ReduceOpBuilder : public BaseOpBuilder { using AxesQnnIntType = uint32_t; Status GetAxesSet(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, - InlinedHashSet& axes_set) const; + std::set& axes_set) const; // Maps an operator type to the opset in which "axes" became an input instead of an attribute. static const std::array opset_with_axes_as_input; @@ -87,7 +84,7 @@ const std::array ReduceOpBuilder::opset_with_axes_as_ }; Status ReduceOpBuilder::GetAxesSet(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, - InlinedHashSet& axes_set) const { + std::set& axes_set) const { ReduceOpType reduce_op_type = GetReduceOpType(node_unit.OpType()); if (reduce_op_type == ReduceOpType::REDUCE_OP_TYPE_UNKNOWN) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN EP: Unknown reduce operator ", node_unit.OpType()); @@ -146,10 +143,7 @@ Status ReduceOpBuilder::GetAxesSet(QnnModelWrapper& qnn_model_wrapper, const Nod auto src_span = gsl::make_span(axes_bytes.data(), axes_bytes.size()); auto dst_span = gsl::make_span(reduce_axes.data(), reduce_axes.size()); - // Copy initializer bytes (stored in little-endian order) to vector of int64_t. - // ReadLittleEndian returns a status error if the source and destination spans do not have - // matching byte sizes. - ORT_RETURN_IF_ERROR(onnxruntime::utils::ReadLittleEndian(src_span, dst_span)); + std::memcpy(dst_span.data(), src_span.data(), src_span.size_bytes()); } } @@ -218,7 +212,7 @@ Status ReduceOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w // // Handle axes param. // - InlinedHashSet axes_set; + std::set axes_set; ORT_RETURN_IF_ERROR(GetAxesSet(qnn_model_wrapper, node_unit, axes_set)); const size_t num_axes = axes_set.size(); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/reshape_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/reshape_op_builder.cc index b6f414da950d8..6fd67a72b64e1 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/reshape_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/reshape_op_builder.cc @@ -1,15 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/providers/common.h" -#include "core/providers/shared/utils/utils.h" -#include "core/framework/tensorprotoutils.h" +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" #include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/op_builder_factory.h" -#include "base_op_builder.h" - namespace onnxruntime { namespace qnn { diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/resize_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/resize_op_builder.cc index c62fca88b6ec2..5e173b7aff030 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/resize_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/resize_op_builder.cc @@ -5,17 +5,10 @@ #include #include -#include "core/providers/common.h" -#include "core/providers/shared/utils/utils.h" -#include "core/framework/tensorprotoutils.h" -#include "core/providers/qnn/builder/qnn_model_wrapper.h" -#include "core/providers/qnn/builder/op_builder_factory.h" -#include "core/providers/cpu/tensor/slice_helper.h" -#include "core/providers/qnn/builder/op_builder_factory.h" -#include "core/common/safeint.h" - #include "core/providers/qnn/builder/opbuilder/base_op_builder.h" #include "core/providers/qnn/builder/qnn_utils.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/op_builder_factory.h" namespace onnxruntime { namespace qnn { diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc index a6c4203ad92e4..6f9e4ae6b10d2 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc @@ -1,16 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/providers/common.h" -#include "core/providers/shared/utils/utils.h" -#include "core/framework/tensorprotoutils.h" +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/op_builder_factory.h" #include "core/providers/qnn/builder/qnn_utils.h" -#include "core/common/safeint.h" -#include "core/util/qmath.h" - -#include "base_op_builder.h" namespace onnxruntime { namespace qnn { @@ -260,15 +254,16 @@ Status ProcessAlphaAttributeAsInput(QnnModelWrapper& qnn_model_wrapper, // Check LeakyRelu input 0 to see if it's quantized tensor bool is_quantized_tensor = node_unit.Outputs()[0].quant_param.has_value(); if (is_quantized_tensor) { - float scale; - uint8_t zero_point; - int64_t num_of_elements = 1; - concurrency::ThreadPool* thread_pool = nullptr; - GetQuantizationParameter(&tensor_data.alpha, num_of_elements, scale, zero_point, thread_pool); - unpacked_data.resize(1); - ParQuantizeLinearStd(&tensor_data.alpha, unpacked_data.data(), num_of_elements, scale, zero_point, thread_pool); - quantize_param = QnnQuantParamsWrapper(scale, static_cast(zero_point)); qnn_data_type = QNN_DATATYPE_UFIXED_POINT_8; + std::array scales = {1.0f}; + std::array offsets = {0}; + std::array shape = {1}; + auto float_data = gsl::make_span(&tensor_data.alpha, 1); + ORT_RETURN_IF_ERROR(qnn::utils::GetDataQuantParams(float_data, shape, scales, offsets, qnn_data_type)); + + unpacked_data.resize(1); + ORT_RETURN_IF_ERROR(qnn::utils::QuantizeData(float_data, shape, scales, offsets, unpacked_data, qnn_data_type)); + quantize_param = QnnQuantParamsWrapper(scales[0], static_cast(offsets[0])); } else { const auto& inputs = node_unit.Inputs(); TensorInfo input_info = {}; diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/slice_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/slice_op_builder.cc index b033c8723ea86..fcc7d27c3ada4 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/slice_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/slice_op_builder.cc @@ -1,17 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/providers/common.h" -#include "core/providers/shared/utils/utils.h" +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/op_builder_factory.h" #include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/cpu/tensor/slice_helper.h" -#include "core/framework/tensorprotoutils.h" - -#include "base_op_builder.h" - namespace onnxruntime { namespace qnn { @@ -86,26 +81,22 @@ static Status GetInitializerInputData(const NodeUnitIODef& input, const QnnModel ORT_RETURN_IF_NOT(initializer_proto->has_data_type(), "Expected initializer ", input_name.c_str(), " to have a proto data type."); - // Create empty Tensor. - const auto* dtype = DataTypeImpl::TensorTypeFromONNXEnum(initializer_proto->data_type())->GetElementType(); - TensorShape shape = onnxruntime::utils::GetTensorShapeFromTensorProto(*initializer_proto); - Tensor tensor(dtype, shape, std::make_shared()); - - // Deserialize initializer into Tensor. - ORT_RETURN_IF_ERROR(onnxruntime::utils::TensorProtoToTensor( - onnxruntime::Env::Default(), qnn_model_wrapper.GetGraphViewer().ModelPath(), *initializer_proto, tensor)); + // Deserialize initializer into byte buffer + std::vector initializer_bytes; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*initializer_proto, initializer_bytes)); Status status; // Copy Tensor of int32_t or int64_t elems into output (int64_ts). - if (tensor.IsDataType()) { - gsl::span tensor_elems = tensor.DataAsSpan(); + auto onnx_type = static_cast(initializer_proto->data_type()); + if (onnx_type == ONNX_NAMESPACE::TensorProto_DataType_INT64) { + gsl::span tensor_elems = ReinterpretAsSpan(initializer_bytes); output.insert(output.end(), tensor_elems.begin(), tensor_elems.end()); - } else if (tensor.IsDataType()) { - gsl::span tensor_elems = tensor.DataAsSpan(); + } else if (onnx_type == ONNX_NAMESPACE::TensorProto_DataType_INT32) { + gsl::span tensor_elems = ReinterpretAsSpan(initializer_bytes); output.insert(output.end(), tensor_elems.begin(), tensor_elems.end()); } else { - status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Data type ", DataTypeImpl::ToString(dtype), + status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Data type ", onnx_type, " is not supported for Slice initializer input ", input.node_arg.Name().c_str()); } diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/softmax_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/softmax_op_builder.cc index b62534bacf426..7326523737383 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/softmax_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/softmax_op_builder.cc @@ -1,15 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/providers/common.h" -#include "core/providers/shared/utils/utils.h" -#include "core/framework/tensorprotoutils.h" +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" #include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/op_builder_factory.h" -#include "core/common/safeint.h" - -#include "base_op_builder.h" namespace onnxruntime { namespace qnn { diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/split_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/split_op_builder.cc index ba5ad2cf03cef..1db9a8f1e3e15 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/split_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/split_op_builder.cc @@ -1,16 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/providers/common.h" -#include "core/providers/shared/utils/utils.h" +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" #include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/op_builder_factory.h" #include "core/providers/cpu/tensor/slice_helper.h" -#include "core/providers/qnn/builder/op_builder_factory.h" -#include "core/common/safeint.h" - -#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" namespace onnxruntime { namespace qnn { diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/tile_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/tile_op_builder.cc index 851ca84dce075..1d518c3ed5359 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/tile_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/tile_op_builder.cc @@ -1,16 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/providers/common.h" -#include "core/providers/shared/utils/utils.h" +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" #include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/op_builder_factory.h" #include "core/providers/cpu/tensor/slice_helper.h" -#include "core/providers/qnn/builder/op_builder_factory.h" -#include "core/common/safeint.h" - -#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" namespace onnxruntime { namespace qnn { diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc index d22c0811682d0..adaa13912ae50 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. #include "core/providers/qnn/builder/opbuilder/base_op_builder.h" -#include "core/framework/utils.h" #include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_utils.h" namespace onnxruntime { namespace qnn { const int TOPK_MIN_INPUT = 2; diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/transpose_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/transpose_op_builder.cc index a42d7312f0203..bcd8a6d0f78f6 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/transpose_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/transpose_op_builder.cc @@ -4,12 +4,11 @@ #include #include +#include "core/providers/qnn/ort_api.h" +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/op_builder_factory.h" #include "core/providers/qnn/builder/qnn_utils.h" -#include "core/common/safeint.h" - -#include "base_op_builder.h" namespace onnxruntime { namespace qnn { diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index fa38fad8eeb59..2ca048a3d71d8 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -17,16 +17,12 @@ #include "HTP/QnnHtpContext.h" #include "Saver/QnnSaver.h" #include -#include "core/framework/endian_utils.h" -#include "core/common/logging/capture.h" + +#include "core/providers/qnn/ort_api.h" +#include "core/providers/qnn/qnn_telemetry.h" #include "core/providers/qnn/builder/onnx_ctx_model_helper.h" #include "core/providers/qnn/builder/qnn_configs_helper.h" -#ifdef _WIN32 -#include -#include "core/platform/tracing.h" -#endif - // Flag to determine if Backend should do node validation for each opNode added #define DO_GRAPH_NODE_VALIDATIONS 1 @@ -246,12 +242,12 @@ void QnnLogging(const char* format, const auto data_type = ::onnxruntime::logging::DataType::SYSTEM; if (logger.OutputIsEnabled(severity, data_type)) { - ::onnxruntime::logging::Capture(logger, - severity, - ::onnxruntime::logging::Category::onnxruntime, - data_type, - ORT_WHERE) - .ProcessPrintf(format, argument_parameter); + auto log_capture = Factory::Create(logger, + severity, + logging::Category::onnxruntime, + data_type, + ORT_WHERE); + log_capture->ProcessPrintf(format, argument_parameter); } } @@ -389,25 +385,25 @@ Status QnnBackendManager::CreateDevice() { // Set SoC Model. The *enum* Qnn_SocModel_t is deprecated and will not be updated in the future. Therefore, // must use the latest SDK documentation to get the SoC model of the latest HW. if (soc_model_ != QNN_SOC_MODEL_UNKNOWN) { - QnnHtpDevice_CustomConfig_t& custom_config = device_configs_builder.PushCustomConfig(); - custom_config.option = QNN_HTP_DEVICE_CONFIG_OPTION_SOC; - custom_config.socModel = soc_model_; + gsl::not_null custom_config = device_configs_builder.PushCustomConfig(); + custom_config->option = QNN_HTP_DEVICE_CONFIG_OPTION_SOC; + custom_config->socModel = soc_model_; - QnnDevice_Config_t& device_config = device_configs_builder.PushConfig(); - device_config.option = QNN_DEVICE_CONFIG_OPTION_CUSTOM; - device_config.customConfig = &custom_config; + gsl::not_null device_config = device_configs_builder.PushConfig(); + device_config->option = QNN_DEVICE_CONFIG_OPTION_CUSTOM; + device_config->customConfig = custom_config; } // Set the minimum HTP architecture. The driver will use ops that are compatible with this minimum architecture. if (htp_arch_ != QNN_HTP_DEVICE_ARCH_NONE) { - QnnHtpDevice_CustomConfig_t& custom_config = device_configs_builder.PushCustomConfig(); - custom_config.option = QNN_HTP_DEVICE_CONFIG_OPTION_ARCH; - custom_config.arch.arch = htp_arch_; - custom_config.arch.deviceId = device_id_; - - QnnDevice_Config_t& device_config = device_configs_builder.PushConfig(); - device_config.option = QNN_DEVICE_CONFIG_OPTION_CUSTOM; - device_config.customConfig = &custom_config; + gsl::not_null custom_config = device_configs_builder.PushCustomConfig(); + custom_config->option = QNN_HTP_DEVICE_CONFIG_OPTION_ARCH; + custom_config->arch.arch = htp_arch_; + custom_config->arch.deviceId = device_id_; + + gsl::not_null device_config = device_configs_builder.PushConfig(); + device_config->option = QNN_DEVICE_CONFIG_OPTION_CUSTOM; + device_config->customConfig = custom_config; } } @@ -1110,15 +1106,16 @@ Status QnnBackendManager::ExtractBackendProfilingInfo() { } bool tracelogging_provider_ep_enabled = false; - const Env& env = Env::Default(); - auto& provider = env.GetTelemetryProvider(); - auto level = provider.Level(); +#ifdef _WIN32 + auto& provider = QnnTelemetry::Instance(); if (provider.IsEnabled()) { + auto level = provider.Level(); auto keyword = provider.Keyword(); if ((keyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Profiling)) != 0 && level >= 5) { tracelogging_provider_ep_enabled = true; } } +#endif // defined(_WIN32) // ETW disabled previously, but enabled now if (ProfilingLevel::INVALID == profiling_level_etw_ && tracelogging_provider_ep_enabled) { @@ -1336,18 +1333,8 @@ void QnnBackendManager::LogQnnProfileEventAsTraceLogging( const std::string& timingSource, const std::string& eventLevel, const char* eventIdentifier) { - TraceLoggingWrite( - telemetry_provider_handle, - "QNNProfilingEvent", - TraceLoggingKeyword(static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Profiling)), - TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), - TraceLoggingValue(timestamp, "Timestamp"), - TraceLoggingString(message.c_str(), "Message"), - TraceLoggingString(qnnScalarValue.c_str(), "Value"), - TraceLoggingString(unit.c_str(), "Unit of Measurement"), - TraceLoggingString(timingSource.c_str(), "Timing Source"), - TraceLoggingString(eventLevel.c_str(), "Event Level"), - TraceLoggingString(eventIdentifier, "Event Identifier")); + QnnTelemetry& qnn_telemetry = QnnTelemetry::Instance(); + qnn_telemetry.LogQnnProfileEvent(timestamp, message, qnnScalarValue, unit, timingSource, eventLevel, eventIdentifier); } #endif @@ -1504,7 +1491,8 @@ void* QnnBackendManager::LoadLib(const char* file_name, int flags, std::string& auto file_path = std::filesystem::path(file_name); if (!file_path.is_absolute()) { // construct an absolute path from ORT runtime path + file_name and check whether it exists. - auto pathstring = Env::Default().GetRuntimePath() + ToPathString(file_name); + const Env& env = GetDefaultEnv(); + auto pathstring = env.GetRuntimePath() + ToPathString(file_name); auto absolute_path = pathstring.c_str(); if (std::filesystem::exists(std::filesystem::path(absolute_path))) { // load library from absolute path and search for dependencies there. diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index beabc9bd71b94..41a2e8777f61e 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -21,9 +21,8 @@ #include "QnnLog.h" #include "QnnTypes.h" #include "System/QnnSystemInterface.h" -#include "core/common/status.h" -#include "core/common/logging/logging.h" -#include "core/common/path_string.h" + +#include "core/providers/qnn/ort_api.h" #include "core/providers/qnn/builder/qnn_def.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/qnn/builder/qnn_configs_helper.h b/onnxruntime/core/providers/qnn/builder/qnn_configs_helper.h index 9dd9bbaa08d64..b581cd90537d9 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_configs_helper.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_configs_helper.h @@ -3,7 +3,8 @@ #pragma once -#include +#include +#include namespace onnxruntime { namespace qnn { @@ -49,9 +50,9 @@ class QnnConfigsBuilder { * * \return A reference to a default CustomConfigType object. */ - CustomConfigType& PushCustomConfig() { - custom_configs_.push_back(custom_config_init_); - return custom_configs_.back(); + gsl::not_null PushCustomConfig() { + custom_configs_.push_back(std::make_unique(custom_config_init_)); + return custom_configs_.back().get(); } /** @@ -60,15 +61,15 @@ class QnnConfigsBuilder { * * \return A reference to a default BaseConfigType object. */ - BaseConfigType& PushConfig() { - configs_.push_back(base_config_init_); - BaseConfigType& config = configs_.back(); + gsl::not_null PushConfig() { + configs_.push_back(std::make_unique(base_config_init_)); + BaseConfigType* config = configs_.back().get(); // Add pointer to this new config to the list of config pointers. if (IsNullTerminated()) { - config_ptrs_.back() = &config; // Replace last nullptr entry. + config_ptrs_.back() = config; // Replace last nullptr entry. } else { - config_ptrs_.push_back(&config); + config_ptrs_.push_back(config); } return config; @@ -81,9 +82,14 @@ class QnnConfigsBuilder { BaseConfigType base_config_init_; CustomConfigType custom_config_init_; - InlinedVector custom_configs_; - InlinedVector configs_; - InlinedVector config_ptrs_; + + // Store elements of unique_ptrs instead of by value because std::vector reallocation would change the + // location of elements in memory. BaseConfigType objects may contain pointers to CustomConfigType objects, + // so we need to make sure that pointers to these objects are stable in memory. + std::vector> custom_configs_; + std::vector> configs_; + + std::vector config_ptrs_; }; } // namespace qnn diff --git a/onnxruntime/core/providers/qnn/builder/qnn_def.h b/onnxruntime/core/providers/qnn/builder/qnn_def.h index ffd2dc9b11010..705212ae52c77 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_def.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_def.h @@ -9,8 +9,7 @@ #include #include #include -#include "core/graph/basic_types.h" -#include "core/common/common.h" +#include "core/providers/qnn/ort_api.h" #include "core/providers/qnn/builder/qnn_quant_params_wrapper.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index 4f73e4c532ed4..4721dc7d8499f 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -6,12 +6,9 @@ #include #include "QnnOpDef.h" +#include "core/providers/qnn/ort_api.h" #include "core/providers/qnn/builder/op_builder_factory.h" #include "core/providers/qnn/builder/qnn_node_group.h" -#include "core/providers/shared/utils/utils.h" -#include "core/framework/utils.h" -#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" -#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" #include "core/providers/qnn/builder/qnn_utils.h" namespace onnxruntime { @@ -104,7 +101,7 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer, // valid throughout the lifetime of the ModelBuilder std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger); + std::tie(node_unit_holder, node_unit_map) = GetQDQNodeUnits(graph_viewer, logger); // This name must be same with the EPContext node name const auto& graph_name = fused_node.Name(); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.h b/onnxruntime/core/providers/qnn/builder/qnn_model.h index 2e0935391ca78..489acaacde4fe 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.h @@ -3,16 +3,13 @@ #pragma once +#include #include -#include "core/common/status.h" -#include "core/framework/node_unit.h" -#include "core/graph/graph_viewer.h" -#include +#include "core/providers/qnn/ort_api.h" #include "core/providers/qnn/builder/qnn_def.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/qnn_backend_manager.h" -#include "core/session/onnxruntime_cxx_api.h" namespace onnxruntime { namespace qnn { diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc index 2c7f3c8b22ddd..afc588c67d435 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/providers/qnn/builder/qnn_model_wrapper.h" + #include #include #include @@ -8,10 +10,7 @@ #include #include -#include "qnn_model_wrapper.h" -#include "core/common/safeint.h" -#include "core/framework/tensorprotoutils.h" -#include "core/providers/shared/utils/utils.h" +#include "core/providers/qnn/ort_api.h" #include "core/providers/qnn/builder/qnn_utils.h" namespace onnxruntime { @@ -445,7 +444,7 @@ Status QnnModelWrapper::IsPerChannelQuantized(const onnxruntime::NodeUnitIODef& ORT_RETURN_IF(iter == graph_initializers.end(), "Unable to find initializer for scale(s): ", scale_name.c_str()); gsl::not_null scale_tensor_proto = iter->second; - TensorShape scale_shape = onnxruntime::utils::GetTensorShapeFromTensorProto(*scale_tensor_proto); + TensorShape scale_shape(qnn::utils::GetInitializerShape(*scale_tensor_proto)); // Check the number of scale values to determine if the tensor is per-channel. // This is consistent with CPU EP's Quant/Dequant logic. We can't use the presence of an axis because even a @@ -624,29 +623,13 @@ Status QnnModelWrapper::UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& // If this is an int4, we need to unpack it because QNN treats int4 as a full int8. if (onnx_data_type == ONNX_NAMESPACE::TensorProto_DataType_INT4) { - TensorShape shape = onnxruntime::utils::GetTensorShapeFromTensorProto(initializer); - const size_t num_elems = shape.Size(); - std::vector packed_int4_bytes = std::move(unpacked_tensor); - unpacked_tensor = std::vector(num_elems); - - auto dst = gsl::make_span(reinterpret_cast(unpacked_tensor.data()), unpacked_tensor.size()); - auto src = gsl::make_span(reinterpret_cast(packed_int4_bytes.data()), packed_int4_bytes.size()); - ORT_RETURN_IF_NOT(Int4x2::Unpack(dst, src), "Failed to unpack Tensor for QNN"); - - // NOTE: Masking off top 4 bits to workaround a QNN INT4 accuracy bug. - // Docs explicitly state that masking off top 4 bits should not be required. - for (size_t i = 0; i < dst.size(); i++) { - dst[i] &= 0x0F; // -3 (0b1111_1101) becomes 13 (0b0000_1101) - } + TensorShape shape(qnn::utils::GetInitializerShape(initializer)); + const size_t num_int4_elems = shape.Size(); + ORT_RETURN_IF_ERROR(qnn::utils::UnpackInt4ToInt8(num_int4_elems, unpacked_tensor)); } else if (onnx_data_type == ONNX_NAMESPACE::TensorProto_DataType_UINT4) { - TensorShape shape = onnxruntime::utils::GetTensorShapeFromTensorProto(initializer); - const size_t num_elems = shape.Size(); - std::vector packed_int4_bytes = std::move(unpacked_tensor); - unpacked_tensor = std::vector(num_elems); - - auto dst = gsl::make_span(reinterpret_cast(unpacked_tensor.data()), unpacked_tensor.size()); - auto src = gsl::make_span(reinterpret_cast(packed_int4_bytes.data()), packed_int4_bytes.size()); - ORT_RETURN_IF_NOT(UInt4x2::Unpack(dst, src), "Failed to unpack Tensor for QNN"); + TensorShape shape(qnn::utils::GetInitializerShape(initializer)); + const size_t num_uint4_elems = shape.Size(); + ORT_RETURN_IF_ERROR(qnn::utils::UnpackInt4ToInt8(num_uint4_elems, unpacked_tensor)); } return Status::OK(); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h index f3e52050e79e0..8cd7360606d71 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h @@ -7,13 +7,10 @@ #include #include -#include "core/common/status.h" #include "QnnInterface.h" #include "qnn_def.h" -#include "core/common/logging/logging.h" -#include "core/framework/node_unit.h" -#include "core/graph/graph_viewer.h" -#include "core/providers/shared/utils/utils.h" + +#include "core/providers/qnn/ort_api.h" #include "core/providers/qnn/builder/qnn_quant_params_wrapper.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group.h index f9ef01411310f..276fbaae3b3c9 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group.h @@ -8,8 +8,7 @@ #include #include -#include "core/common/logging/logging.h" -#include "core/framework/node_unit.h" +#include "core/providers/qnn/ort_api.h" namespace onnxruntime { namespace qnn { diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc index 813bba8a5952b..8c32db47c1b47 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc @@ -6,9 +6,10 @@ #include #include #include -#include "core/graph/graph_utils.h" -#include "core/framework/node_unit.h" -#include "core/providers/shared/utils/utils.h" +#include +#include + +#include "core/providers/qnn/ort_api.h" #include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/builder/op_builder_factory.h" #include "core/providers/qnn/builder/qnn_node_group/utils.h" @@ -110,8 +111,8 @@ static bool CanClipBeRemoved(const QnnModelWrapper& qnn_model_wrapper, float clip_min = std::numeric_limits::lowest(); float clip_max = std::numeric_limits::max(); - if (!onnxruntime::GetClipMinMax(qnn_model_wrapper.GetGraphViewer(), clip_node_unit.GetNode(), - clip_min, clip_max, logger)) { + if (!qnn::utils::GetClipMinMax(qnn_model_wrapper.GetGraphViewer(), clip_node_unit.GetNode(), + clip_min, clip_max, logger)) { return false; } @@ -174,7 +175,7 @@ static bool CanActivationBeRemoved(const QnnModelWrapper& qnn_model_wrapper, static std::vector FindParentDQNodes(const GraphViewer& graph_viewer, const Node& node) { // Get all parent DQ nodes sorted by destination argument index. std::vector parents(node.InputDefs().size(), nullptr); - for (auto it = node.InputEdgesBegin(); it != node.InputEdgesEnd(); it++) { + for (auto it = node.InputEdgesBegin(); it != node.InputEdgesEnd(); ++it) { if (it->GetNode().OpType().compare(DEQUANTIZE_LINEAR) == 0) { parents[it->GetDstArgIndex()] = &(it->GetNode()); } @@ -314,12 +315,9 @@ static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, inputs.reserve(num_dqs); for (const Node* dq_node : dq_nodes) { const auto dq_inputs = dq_node->InputDefs(); - const auto& dq_attrs = dq_node->GetAttributes(); + NodeAttrHelper dq_attrs(*dq_node); - std::optional axis; - if (auto entry = dq_attrs.find("axis"); entry != dq_attrs.end()) { - axis = entry->second.i(); - } + std::optional axis = dq_attrs.GetInt64("axis"); // quantization scale and zp are always the input[1, 2] NodeUnitIODef::QuantParam quant_param{*dq_inputs[1], dq_inputs.size() == 3 ? dq_inputs[2] : nullptr, axis}; @@ -328,16 +326,14 @@ static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, // Populate NodeUnit outputs and output edges std::vector outputs; - Node::EdgeSet output_edges; + std::vector> output_edges_holder; + std::vector output_edges; for (const Node* q_node : q_nodes) { const auto q_inputs = q_node->InputDefs(); - const auto& q_attrs = q_node->GetAttributes(); + NodeAttrHelper q_attrs(*q_node); const auto q_outputs = q_node->OutputDefs(); - std::optional axis; - if (auto entry = q_attrs.find("axis"); entry != q_attrs.end()) { - axis = entry->second.i(); - } + std::optional axis = q_attrs.GetInt64("axis"); // quantization scale and zp are always the input[1, 2] NodeUnitIODef::QuantParam quant_param{*q_inputs[1], q_inputs.size() == 3 ? q_inputs[2] : nullptr, axis}; @@ -347,22 +343,25 @@ static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, auto q_cur_edge = q_node->OutputEdgesBegin(); auto q_end_edge = q_node->OutputEdgesEnd(); for (; q_cur_edge != q_end_edge; ++q_cur_edge) { - output_edges.insert(Node::EdgeEnd{q_cur_edge->GetNode(), 0, q_cur_edge->GetDstArgIndex()}); + auto output_edge = Factory::Create(q_cur_edge->GetNode(), 0, q_cur_edge->GetDstArgIndex()); + output_edges.push_back(output_edge.get()); + output_edges_holder.push_back(std::move(output_edge)); } } - NodeUnit custom_node_unit(dq_nodes, target_node, q_nodes, NodeUnit::Type::QDQGroup, - inputs, outputs, num_dqs, output_edges); - const auto* conv_op_builder = qnn::GetOpBuilder(custom_node_unit.OpType()); + std::unique_ptr custom_node_unit = Factory::Create(dq_nodes, target_node, + q_nodes, NodeUnit::Type::QDQGroup, + inputs, outputs, num_dqs, output_edges); + const auto* conv_op_builder = qnn::GetOpBuilder(custom_node_unit->OpType()); if (conv_op_builder == nullptr) { return Status::OK(); } if (validate) { - return conv_op_builder->IsOpSupported(qnn_model_wrapper, custom_node_unit, logger); + return conv_op_builder->IsOpSupported(qnn_model_wrapper, *custom_node_unit, logger); } - return conv_op_builder->AddToModelBuilder(qnn_model_wrapper, custom_node_unit, logger, validate); + return conv_op_builder->AddToModelBuilder(qnn_model_wrapper, *custom_node_unit, logger, validate); } // Traverses graph to check if the given NodeUnit is part of a valid DQ* -> Conv -> Relu/Clip -> Q sequence. diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h index b604b25e943e6..a211c86c2301e 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h @@ -9,7 +9,7 @@ #include #include -#include "core/framework/node_unit.h" +#include "core/providers/qnn/ort_api.h" #include "core/providers/qnn/builder/qnn_node_group.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc index caf4725626338..3af2fdd1f0276 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc @@ -6,9 +6,8 @@ #include #include #include -#include "core/graph/graph_utils.h" -#include "core/framework/node_unit.h" -#include "core/providers/shared/utils/utils.h" + +#include "core/providers/qnn/ort_api.h" #include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/builder/op_builder_factory.h" #include "core/providers/qnn/builder/qnn_node_group/utils.h" diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h index 90fe44c3af059..d3d552bc172ec 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h @@ -7,8 +7,7 @@ #include #include -#include "core/common/common.h" -#include "core/framework/node_unit.h" +#include "core/providers/qnn/ort_api.h" #include "core/providers/qnn/builder/qnn_node_group.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc index 76b1726646486..5094ad96724f5 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc @@ -6,9 +6,8 @@ #include #include #include -#include "core/graph/graph_utils.h" -#include "core/framework/node_unit.h" -#include "core/providers/shared/utils/utils.h" + +#include "core/providers/qnn/ort_api.h" #include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/builder/op_builder_factory.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h index 3b67f13492a46..0a1b16d24ffcd 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h @@ -7,8 +7,7 @@ #include #include -#include "core/common/common.h" -#include "core/framework/node_unit.h" +#include "core/providers/qnn/ort_api.h" #include "core/providers/qnn/builder/qnn_node_group.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc index 9fb9e815321c0..56413b781b246 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc @@ -10,8 +10,7 @@ #include #include #include -#include "core/graph/graph_utils.h" -#include "core/framework/node_unit.h" +#include "core/providers/qnn/ort_api.h" #include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/op_builder_factory.h" diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc index 5548d7d37c378..93b2fca296389 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc @@ -4,8 +4,7 @@ #include #include -#include "core/graph/graph_viewer.h" -#include "core/framework/node_unit.h" +#include "core/providers/qnn/ort_api.h" #include "core/providers/qnn/builder/qnn_node_group.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h index 0d11d21906ccb..c4cf4e8a20a92 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h @@ -7,8 +7,7 @@ #include #include -#include "core/graph/graph_viewer.h" -#include "core/framework/node_unit.h" +#include "core/providers/qnn/ort_api.h" #include "core/providers/qnn/builder/qnn_node_group.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.h b/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.h index 23330f5616d73..01c15cf4bebe6 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.h @@ -4,10 +4,10 @@ #pragma once #include #include -#include "QnnTypes.h" -#include "core/common/common.h" #include -#include "core/framework/node_unit.h" + +#include "core/providers/qnn/ort_api.h" +#include "QnnTypes.h" namespace onnxruntime { namespace qnn { diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc index 8d2cb5bdb6da0..fe7d992fe6520 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc @@ -1,15 +1,17 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/providers/qnn/builder/qnn_utils.h" + +#include #include +#include +#include #include #include #include -#include -#include "core/common/common.h" -#include "core/framework/data_types.h" -#include "qnn_utils.h" +#include "core/providers/qnn/ort_api.h" #include "core/providers/qnn/builder/qnn_def.h" namespace onnxruntime { @@ -63,6 +65,42 @@ size_t GetElementSizeByType(ONNXTensorElementDataType elem_type) { return pos->second; } +size_t GetElementSizeByType(ONNX_NAMESPACE::TensorProto_DataType onnx_type) { + switch (onnx_type) { + case ONNX_NAMESPACE::TensorProto_DataType_INT4: + return sizeof(Int4x2); + case ONNX_NAMESPACE::TensorProto_DataType_UINT4: + return sizeof(UInt4x2); + case ONNX_NAMESPACE::TensorProto_DataType_INT8: + return sizeof(int8_t); + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: + return sizeof(uint8_t); + case ONNX_NAMESPACE::TensorProto_DataType_INT16: + return sizeof(int16_t); + case ONNX_NAMESPACE::TensorProto_DataType_UINT16: + return sizeof(uint16_t); + case ONNX_NAMESPACE::TensorProto_DataType_INT32: + return sizeof(int32_t); + case ONNX_NAMESPACE::TensorProto_DataType_UINT32: + return sizeof(uint32_t); + case ONNX_NAMESPACE::TensorProto_DataType_INT64: + return sizeof(int64_t); + case ONNX_NAMESPACE::TensorProto_DataType_UINT64: + return sizeof(uint64_t); + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + return 2; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + return sizeof(float); + case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: + return sizeof(double); + case ONNX_NAMESPACE::TensorProto_DataType_BOOL: + return sizeof(bool); + default: + return 0; + } + // Unreachable +} + std::ostream& operator<<(std::ostream& out, const Qnn_Scalar_t& scalar) { switch (scalar.dataType) { case QNN_DATATYPE_INT_8: @@ -487,39 +525,22 @@ bool OnnxDataTypeToQnnDataType(const int32_t onnx_data_type, Qnn_DataType_t& qnn } std::pair CheckMinMax(float rmin, float rmax) { - // Ensure a minimum range of 0.0001 (required by QNN) - rmax = std::max(rmax, rmin + 0.0001f); - // Both QNN and ORT require the range to include 0.0f rmin = std::min(rmin, 0.0f); rmax = std::max(rmax, 0.0f); + // Ensure a minimum range of 0.0001 (required by QNN) + rmax = std::max(rmax, rmin + 0.0001f); + return std::make_pair(rmin, rmax); } -template -Status GetQminQmax(const Qnn_DataType_t qnn_data_type, - T& qmin, - T& qmax) { - if (qnn_data_type == QNN_DATATYPE_SFIXED_POINT_8) { - qmin = static_cast(std::numeric_limits::min()); - qmax = static_cast(std::numeric_limits::max()); - } else if (qnn_data_type == QNN_DATATYPE_UFIXED_POINT_8) { - qmin = static_cast(std::numeric_limits::min()); - qmax = static_cast(std::numeric_limits::max()); - } else if (qnn_data_type == QNN_DATATYPE_SFIXED_POINT_16) { - qmin = static_cast(std::numeric_limits::min()); - qmax = static_cast(std::numeric_limits::max()); - } else if (qnn_data_type == QNN_DATATYPE_UFIXED_POINT_16) { - qmin = static_cast(std::numeric_limits::min()); - qmax = static_cast(std::numeric_limits::max()); - } else if (qnn_data_type == QNN_DATATYPE_SFIXED_POINT_32) { - qmin = static_cast(std::numeric_limits::min()); - qmax = static_cast(std::numeric_limits::max()); - } else { - ORT_RETURN_IF(true, "Qnn Data Type: %d not supported yet.", qnn_data_type); +inline float RoundHalfToEven(float input) { + if (!std::isfinite(input)) { + return input; } - return Status::OK(); + // std::remainder returns x - n, where n is the integral value nearest to x. When |x - n| = 0.5, n is chosen to be even + return input - std::remainderf(input, 1.f); } Status GetQuantParams(float rmin, @@ -535,20 +556,22 @@ Status GetQuantParams(float rmin, rmin = -abs_max; } - float qmin = 0.0f; - float qmax = 255.0f; - ORT_RETURN_IF_ERROR(GetQminQmax(qnn_data_type, qmin, qmax)); + double rmin_dbl = static_cast(rmin); + double rmax_dbl = static_cast(rmax); + double qmin = 0.0; + double qmax = 0.0; + ORT_RETURN_IF_ERROR(GetQminQmax(qnn_data_type, qmin, qmax, symmetric)); - scale = (rmax - rmin) / (qmax - qmin); - float initial_zero_point = 0.0f; + double scale_dbl = (rmax_dbl - rmin_dbl) / (qmax - qmin); + double initial_zero_point = 0.0; if (symmetric) { - initial_zero_point = std::round(rmin + rmax) / 2; + initial_zero_point = std::round(rmin_dbl + rmax_dbl) / 2; } else { - initial_zero_point = qmin - (rmin / scale); + initial_zero_point = qmin - (rmin_dbl / scale_dbl); } - zero_point = static_cast(RoundHalfToEven(Saturate(qmax, qmin, initial_zero_point))); - // To match QNN quantization definition - zero_point = 0 - zero_point; + zero_point = static_cast(RoundHalfToEven(static_cast(Saturate(qmax, qmin, initial_zero_point)))); + zero_point = -zero_point; // Negate to match QNN quantization definition. + scale = static_cast(scale_dbl); return Status::OK(); } @@ -570,6 +593,220 @@ Status Quantize(const double double_value, return Status::OK(); } +size_t ShapeSizeCalc(gsl::span shape, size_t start, size_t end) { + size_t size = 1; + for (size_t i = start; i < end; i++) { + size *= shape[i]; + } + return size; +} + +Status GetDataQuantParams(gsl::span data, gsl::span shape, + /*out*/ gsl::span scales, /*out*/ gsl::span offsets, + Qnn_DataType_t data_type, bool symmetric, std::optional axis) { + const size_t num_dims = shape.size(); + const size_t num_elems = ShapeSizeCalc(shape, 0, num_dims); + ORT_RETURN_IF_NOT(num_elems == data.size(), "Shape mismatch with data to quantize"); + + size_t block_count = 1; + size_t broadcast_dim = 1; + size_t block_size = num_elems; + + if (axis.has_value()) { + size_t axis_no_neg = *axis < 0 ? static_cast(*axis) + num_dims : static_cast(*axis); + block_count = ShapeSizeCalc(shape, 0, axis_no_neg); + broadcast_dim = shape[axis_no_neg]; + block_size = ShapeSizeCalc(shape, axis_no_neg + 1, num_dims); + } + + ORT_RETURN_IF_NOT(scales.size() == broadcast_dim, "Unexpected size of scales output buffer"); + ORT_RETURN_IF_NOT(offsets.size() == broadcast_dim, "Unexpected size of offsets output buffer"); + + size_t i = 0; + for (size_t n = 0; n < block_count; n++) { + for (size_t bd = 0; bd < broadcast_dim; bd++) { + float rmin = std::numeric_limits::max(); + float rmax = std::numeric_limits::lowest(); + for (size_t j = 0; j < block_size; j++) { + rmin = std::min(rmin, data[i]); + rmax = std::max(rmax, data[i]); + i++; + } + + scales[bd] = 1.0f; + offsets[bd] = 0; + ORT_RETURN_IF_ERROR(GetQuantParams(rmin, rmax, data_type, scales[bd], offsets[bd], symmetric)); + } + } + + assert(i == data.size()); + return Status::OK(); +} + +Status QuantizeData(gsl::span data, gsl::span shape, + gsl::span scales, gsl::span offsets, + /*out*/ gsl::span quant_bytes, Qnn_DataType_t data_type, + std::optional axis) { + const size_t num_dims = shape.size(); + const size_t num_elems = ShapeSizeCalc(shape, 0, num_dims); + ORT_RETURN_IF_NOT(num_elems == data.size(), "Shape mismatch with data to quantize"); + size_t expected_num_quant_bytes = GetElementSizeByType(data_type) * data.size(); + ORT_RETURN_IF_NOT(quant_bytes.size() == expected_num_quant_bytes, + "Cannot quantize data because output buffer is not the correct size"); + + size_t block_count = 1; + size_t broadcast_dim = 1; + size_t block_size = num_elems; + + if (axis.has_value()) { + size_t axis_no_neg = *axis < 0 ? static_cast(*axis) + num_dims : static_cast(*axis); + block_count = ShapeSizeCalc(shape, 0, axis_no_neg); + broadcast_dim = shape[axis_no_neg]; + block_size = ShapeSizeCalc(shape, axis_no_neg + 1, num_dims); + } + + ORT_RETURN_IF_NOT(scales.size() == broadcast_dim, "Unexpected size of scales output buffer"); + ORT_RETURN_IF_NOT(offsets.size() == broadcast_dim, "Unexpected size of offsets output buffer"); + + size_t i = 0; + for (size_t n = 0; n < block_count; n++) { + for (size_t bd = 0; bd < broadcast_dim; bd++) { + switch (data_type) { + case QNN_DATATYPE_SFIXED_POINT_8: { + auto input_span = gsl::make_span(&data[i], block_size); + auto output_span = gsl::make_span(&quant_bytes[i * sizeof(int8_t)], sizeof(int8_t) * block_size); + ORT_RETURN_IF_ERROR(QuantizeData(input_span, scales[bd], offsets[bd], output_span)); + break; + } + case QNN_DATATYPE_UFIXED_POINT_8: { + auto input_span = gsl::make_span(&data[i], block_size); + auto output_span = gsl::make_span(&quant_bytes[i * sizeof(uint8_t)], sizeof(uint8_t) * block_size); + ORT_RETURN_IF_ERROR(QuantizeData(input_span, scales[bd], offsets[bd], output_span)); + break; + } + case QNN_DATATYPE_SFIXED_POINT_16: { + auto input_span = gsl::make_span(&data[i], block_size); + auto output_span = gsl::make_span(&quant_bytes[i * sizeof(int16_t)], sizeof(int16_t) * block_size); + ORT_RETURN_IF_ERROR(QuantizeData(input_span, scales[bd], offsets[bd], output_span)); + break; + } + case QNN_DATATYPE_UFIXED_POINT_16: { + auto input_span = gsl::make_span(&data[i], block_size); + auto output_span = gsl::make_span(&quant_bytes[i * sizeof(uint16_t)], sizeof(uint16_t) * block_size); + ORT_RETURN_IF_ERROR(QuantizeData(input_span, scales[bd], offsets[bd], output_span)); + break; + } + case QNN_DATATYPE_SFIXED_POINT_32: { + auto input_span = gsl::make_span(&data[i], block_size); + auto output_span = gsl::make_span(&quant_bytes[i * sizeof(int32_t)], sizeof(int32_t) * block_size); + ORT_RETURN_IF_ERROR(QuantizeData(input_span, scales[bd], offsets[bd], output_span)); + break; + } + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported quantization data type for QuantizeData"); + } + i += block_size; + } + } + assert(i == data.size()); + + return Status::OK(); +} + +static bool GetType(const NodeArg& node_arg, int32_t& type, const logging::Logger& logger) { + type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED; + const auto* type_proto = node_arg.TypeAsProto(); + if (!type_proto || !type_proto->has_tensor_type() || !type_proto->tensor_type().has_elem_type()) { + LOGS(logger, WARNING) << "NodeArg [" << node_arg.Name() << "] has no input type"; + return false; + } + + type = type_proto->tensor_type().elem_type(); + return true; +} + +static bool GetClipMinMaxImpl(const GraphViewer& graph_viewer, const Node& node, float& min, float& max, + const logging::Logger& logger) { + const auto& node_name = node.Name(); + int32_t input_type; + if (!GetType(*node.InputDefs()[0], input_type, logger)) { + return false; + } + + min = std::numeric_limits::lowest(); + max = std::numeric_limits::max(); + + if (node.SinceVersion() < 11) { // Clip opset 1, 6 is using attributes for min/max + NodeAttrHelper helper(node); + // attributes will be always float + min = helper.Get("min", std::numeric_limits::lowest()); + max = helper.Get("max", std::numeric_limits::max()); + } else { + auto get_value = + [&](const ONNX_NAMESPACE::TensorProto* initializer, std::string_view type, float& value) -> bool { + if (!initializer) { + LOGS(logger, VERBOSE) << type << " input of Clip must be a constant initializer"; + return false; + } + + switch (input_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: { + std::vector bytes(sizeof(float)); + auto status = onnxruntime::utils::UnpackInitializerData(*initializer, graph_viewer.ModelPath(), bytes); + if (!status.IsOK()) { + LOGS(logger, ERROR) << "GetClipMinMax() failed to unpack float initializer: " << status.ErrorMessage(); + return false; + } + value = *reinterpret_cast(bytes.data()); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: { + std::vector bytes(sizeof(MLFloat16)); + auto status = onnxruntime::utils::UnpackInitializerData(*initializer, graph_viewer.ModelPath(), bytes); + if (!status.IsOK()) { + LOGS(logger, ERROR) << "GetClipMinMax() failed to unpack float16 initializer: " << status.ErrorMessage(); + return false; + } + value = reinterpret_cast(bytes.data())->ToFloat(); + break; + } + default: + LOGS(logger, VERBOSE) << "GetClipMinMax() only supports float and float16 as min and max inputs for now." + << " The node [" << node_name << "] has input type: " << input_type; + return false; + } + + return true; + }; + + // min and max are both optional. could have neither, one or both. + if (node.InputDefs().size() > 1 && node.InputDefs()[1]->Exists()) { + // we have input min + const auto& min_name = node.InputDefs()[1]->Name(); + const auto* min_value = graph_viewer.GetConstantInitializer(min_name); + if (!get_value(min_value, "Min", min)) { + return false; + } + } + + if (node.InputDefs().size() > 2 && node.InputDefs()[2]->Exists()) { + // we have input max + const auto& max_name = node.InputDefs()[2]->Name(); + const auto* max_value = graph_viewer.GetConstantInitializer(max_name); + if (!get_value(max_value, "Max", max)) { + return false; + } + } + } + + return true; +} + +bool GetClipMinMax(const GraphViewer& graph_viewer, const Node& node, float& min, float& max, + const logging::Logger& logger) { + return GetClipMinMaxImpl(graph_viewer, node, min, max, logger); +} + } // namespace utils } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.h b/onnxruntime/core/providers/qnn/builder/qnn_utils.h index aa4a27460563f..b7c9a238c3ff6 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.h @@ -2,16 +2,16 @@ // Licensed under the MIT License. #pragma once +#include #include #include #include #include +#include #include #include "QnnTypes.h" -#include "core/session/onnxruntime_cxx_api.h" -#include "core/framework/node_unit.h" -#include "core/util/qmath.h" +#include "core/providers/qnn/ort_api.h" namespace onnxruntime { namespace qnn { @@ -22,6 +22,8 @@ size_t GetElementSizeByType(const Qnn_DataType_t& data_type); size_t GetElementSizeByType(ONNXTensorElementDataType elem_type); +size_t GetElementSizeByType(ONNX_NAMESPACE::TensorProto_DataType onnx_type); + // TODO: make these work with Wrappers? std::ostream& operator<<(std::ostream& out, const Qnn_Param_t& qnn_param); std::ostream& operator<<(std::ostream& out, const Qnn_Tensor_t& tensor); @@ -74,7 +76,30 @@ static bool ArrayHasString(const std::array& strings, std:: std::pair CheckMinMax(float rmin, float rmax); template -Status GetQminQmax(const Qnn_DataType_t qnn_data_type, T& qmin, T& qmax); +Status GetQminQmax(const Qnn_DataType_t qnn_data_type, + T& qmin, + T& qmax, + bool symmetric = false) { + if (qnn_data_type == QNN_DATATYPE_SFIXED_POINT_8) { + qmin = static_cast(std::numeric_limits::min() + static_cast(symmetric)); + qmax = static_cast(std::numeric_limits::max()); + } else if (qnn_data_type == QNN_DATATYPE_UFIXED_POINT_8) { + qmin = static_cast(std::numeric_limits::min()); + qmax = static_cast(std::numeric_limits::max()); + } else if (qnn_data_type == QNN_DATATYPE_SFIXED_POINT_16) { + qmin = static_cast(std::numeric_limits::min() + static_cast(symmetric)); + qmax = static_cast(std::numeric_limits::max()); + } else if (qnn_data_type == QNN_DATATYPE_UFIXED_POINT_16) { + qmin = static_cast(std::numeric_limits::min()); + qmax = static_cast(std::numeric_limits::max()); + } else if (qnn_data_type == QNN_DATATYPE_SFIXED_POINT_32) { + qmin = static_cast(std::numeric_limits::min() + static_cast(symmetric)); + qmax = static_cast(std::numeric_limits::max()); + } else { + ORT_RETURN_IF(true, "Qnn Data Type: %d not supported yet.", qnn_data_type); + } + return Status::OK(); +} template inline T Saturate(const T qmax, @@ -104,6 +129,109 @@ Status Quantize(const double double_value, const Qnn_DataType_t qnn_data_type, int& quant_value); +size_t ShapeSizeCalc(gsl::span shape, size_t start, size_t end); + +// Computes the quantization parameters (scales and offsets) for the given data. +// Supports both per-tensor and per-channel quantization. Must provide an axis argument +// for per-channel quantization. +// The offsets use the QNN convention where offset = -zero_point. +Status GetDataQuantParams(gsl::span data, gsl::span shape, + /*out*/ gsl::span scales, /*out*/ gsl::span offsets, + Qnn_DataType_t data_type, bool symmetric = false, + std::optional axis = std::nullopt); + +// Quantizes the given float data using the provided quantization parameters (scales and offsets). +// Supports both per-tensor and per-channel quantization. Must provide an axis argument +// for per-channel quantization. +// The provided offsets must use the QNN convention where offset = -zero_point. +Status QuantizeData(gsl::span data, gsl::span shape, + gsl::span scales, gsl::span offsets, + /*out*/ gsl::span quant_bytes, Qnn_DataType_t data_type, + std::optional axis = std::nullopt); + +// Quantizes (per-tensor) the given float data using the provided scale and offset. +// The provided offset must use the QNN convention where offset = -zero_point. +template +inline Status QuantizeData(gsl::span data, float scale, int32_t offset, + /*out*/ gsl::span quant_bytes) { + const size_t num_elems = data.size(); + const size_t expected_output_bytes = sizeof(QuantType) * num_elems; + ORT_RETURN_IF_NOT(expected_output_bytes == quant_bytes.size(), + "Output buffer is not large enough to hold quantized bytes."); + const double clip_min = static_cast(std::numeric_limits::lowest()); + const double clip_max = static_cast(std::numeric_limits::max()); + + QuantType* output = reinterpret_cast(quant_bytes.data()); + for (size_t i = 0; i < num_elems; ++i) { + const double scale_dbl = static_cast(scale); + const double offset_dbl = static_cast(offset); + double float_val = std::nearbyint(static_cast(data[i]) / scale_dbl) - offset_dbl; + float_val = std::max(float_val, clip_min); + float_val = std::min(float_val, clip_max); + output[i] = static_cast(float_val); + } + return Status::OK(); +} + +// Re-writes a buffer of packed 4-bit elements to a buffer of unpacked 8-bit elements. +// QNN requires that 4-bit weights are unpacked to 8-bit. +template +Status UnpackInt4ToInt8(size_t num_int4_elems, std::vector& data_bytes) { + if constexpr (Signed) { // INT4 + std::vector packed_int4_bytes = std::move(data_bytes); + data_bytes = std::vector(num_int4_elems); + + auto dst = gsl::make_span(reinterpret_cast(data_bytes.data()), data_bytes.size()); + auto src = gsl::make_span(reinterpret_cast(packed_int4_bytes.data()), packed_int4_bytes.size()); + ORT_RETURN_IF_NOT(Int4x2::Unpack(dst, src), "Failed to unpack Tensor for QNN"); + + // NOTE: Masking off top 4 bits to workaround a QNN INT4 accuracy bug. + // Docs explicitly state that masking off top 4 bits should not be required, but we have to do it. + for (size_t i = 0; i < dst.size(); i++) { + dst[i] &= 0x0F; // -3 (0b1111_1101) becomes 13 (0b0000_1101) + } + } else { // UINT4 + std::vector packed_uint4_bytes = std::move(data_bytes); + data_bytes = std::vector(num_int4_elems); + + auto dst = gsl::make_span(reinterpret_cast(data_bytes.data()), data_bytes.size()); + auto src = gsl::make_span(reinterpret_cast(packed_uint4_bytes.data()), packed_uint4_bytes.size()); + ORT_RETURN_IF_NOT(UInt4x2::Unpack(dst, src), "Failed to unpack Tensor for QNN"); + } + + return Status::OK(); +} + +template +std::vector GetInitializerShape(const ONNX_NAMESPACE::TensorProto& tensor_proto) { + const auto& dims = tensor_proto.dims(); + std::vector tensor_shape_vec(static_cast(dims.size())); + for (int i = 0; i < dims.size(); ++i) { + tensor_shape_vec[i] = static_cast(dims[i]); + } + + return tensor_shape_vec; +} + +template +Status PermuteShape(gsl::span input_shape, gsl::span perm, gsl::span output_shape) { + const size_t rank = input_shape.size(); + ORT_RETURN_IF_NOT(rank == perm.size() && rank == output_shape.size(), + "PermuteShape(): expect all arguments to have the same rank."); + + for (size_t i = 0; i < rank; ++i) { + size_t p = static_cast(perm[i]); + output_shape[i] = input_shape[p]; + } + + return Status::OK(); +} + +// Get the min/max of a Clip operator. Reads values from attributes for opset < 11 and inputs after that. +// For opset 11+, if min/max are not constant initializers, will return false. +// For now we only support getting float min/max. +bool GetClipMinMax(const GraphViewer& graph_viewer, const Node& node, + float& min, float& max, const logging::Logger& logger); } // namespace utils } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/ort_api.cc b/onnxruntime/core/providers/qnn/ort_api.cc new file mode 100644 index 0000000000000..893708cc7e2e0 --- /dev/null +++ b/onnxruntime/core/providers/qnn/ort_api.cc @@ -0,0 +1,232 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/qnn/ort_api.h" + +#include +#include +#include + +namespace onnxruntime { + +#if BUILD_QNN_EP_STATIC_LIB +static std::unique_ptr>> s_run_on_unload_; + +void RunOnUnload(std::function function) { + static std::mutex mutex; + std::lock_guard guard(mutex); + if (!s_run_on_unload_) { + s_run_on_unload_ = std::make_unique>>(); + } + s_run_on_unload_->push_back(std::move(function)); +} + +struct OnUnload { + ~OnUnload() { + if (!s_run_on_unload_) + return; + + for (auto& function : *s_run_on_unload_) + function(); + + s_run_on_unload_.reset(); + } + +} g_on_unload; +#endif // BUILD_QNN_EP_STATIC_LIB + +std::vector Graph__Nodes(const Graph& graph) { +#if BUILD_QNN_EP_STATIC_LIB + std::vector nodes; + nodes.reserve(graph.NumberOfNodes()); + + for (const Node& node : graph.Nodes()) { + nodes.push_back(&node); + } + + return nodes; +#else + return graph.Nodes(); +#endif +} + +std::unique_ptr Factory::Create(gsl::span dq_nodes, + const Node& target_node, + gsl::span q_nodes, + NodeUnit::Type unit_type, + gsl::span inputs, + gsl::span outputs, + size_t input_edge_count, + gsl::span output_edges) { +#if BUILD_QNN_EP_STATIC_LIB + Node::EdgeSet output_edge_set; + for (const Node_EdgeEnd* edge_end : output_edges) { + output_edge_set.insert(*edge_end); + } + + return std::make_unique(dq_nodes, target_node, q_nodes, unit_type, + inputs, outputs, input_edge_count, output_edge_set); +#else + return NodeUnit::Create(dq_nodes, target_node, q_nodes, unit_type, inputs, outputs, input_edge_count, output_edges); +#endif +} + +#if BUILD_QNN_EP_STATIC_LIB +#define NODE_ATTR_ITER_VAL(iter) (iter)->second +#else +#define NODE_ATTR_ITER_VAL(iter) (iter)->second() +#endif + +NodeAttrHelper::NodeAttrHelper(const onnxruntime::Node& node) + : node_attributes_(node.GetAttributes()) {} + +NodeAttrHelper::NodeAttrHelper(const NodeUnit& node_unit) + : node_attributes_(node_unit.GetNode().GetAttributes()) {} + +float NodeAttrHelper::Get(const std::string& key, float def_val) const { + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + return NODE_ATTR_ITER_VAL(entry).f(); + } + + return def_val; +} + +int32_t NodeAttrHelper::Get(const std::string& key, int32_t def_val) const { + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + return narrow(NODE_ATTR_ITER_VAL(entry).i()); + } + + return def_val; +} + +uint32_t NodeAttrHelper::Get(const std::string& key, uint32_t def_val) const { + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + return narrow(NODE_ATTR_ITER_VAL(entry).i()); + } + + return def_val; +} + +int64_t NodeAttrHelper::Get(const std::string& key, int64_t def_val) const { + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + return NODE_ATTR_ITER_VAL(entry).i(); + } + + return def_val; +} + +const std::string& NodeAttrHelper::Get(const std::string& key, const std::string& def_val) const { + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + return NODE_ATTR_ITER_VAL(entry).s(); + } + + return def_val; +} + +std::vector NodeAttrHelper::Get(const std::string& key, const std::vector& def_val) const { + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + const auto& values = NODE_ATTR_ITER_VAL(entry).ints(); + const int64_t* cbegin = values.data(); + const int64_t* cend = values.data() + values.size(); + std::vector v; + v.reserve(static_cast(values.size())); + std::transform(cbegin, cend, std::back_inserter(v), + [](int64_t val) -> int32_t { return narrow(val); }); + return v; + } + + return def_val; +} + +std::vector NodeAttrHelper::Get(const std::string& key, const std::vector& def_val) const { + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + const auto& values = NODE_ATTR_ITER_VAL(entry).ints(); + const int64_t* cbegin = values.data(); + const int64_t* cend = values.data() + values.size(); + std::vector v; + v.reserve(static_cast(values.size())); + std::transform(cbegin, cend, std::back_inserter(v), + [](int64_t val) -> uint32_t { return narrow(val); }); + return v; + } + + return def_val; +} + +std::vector NodeAttrHelper::Get(const std::string& key, const std::vector& def_val) const { + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + const auto& values = NODE_ATTR_ITER_VAL(entry).ints(); + const int64_t* cbegin = values.data(); + const int64_t* cend = values.data() + values.size(); + return std::vector{cbegin, cend}; + } + + return def_val; +} + +std::vector NodeAttrHelper::Get(const std::string& key, const std::vector& def_val) const { + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + const auto& values = NODE_ATTR_ITER_VAL(entry).floats(); + const float* cbegin = values.data(); + const float* cend = values.data() + values.size(); + return std::vector{cbegin, cend}; + } + + return def_val; +} + +std::optional NodeAttrHelper::GetFloat(const std::string& key) const { + std::optional result; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + result = NODE_ATTR_ITER_VAL(entry).f(); + } + + return result; +} + +std::optional NodeAttrHelper::GetInt64(const std::string& key) const { + std::optional result; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + result = NODE_ATTR_ITER_VAL(entry).i(); + } + + return result; +} + +std::optional> NodeAttrHelper::GetFloats(const std::string& key) const { + std::optional> result; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + const auto& values = NODE_ATTR_ITER_VAL(entry).floats(); + const float* cbegin = values.data(); + const float* cend = values.data() + values.size(); + result = std::vector(cbegin, cend); + } + + return result; +} + +std::optional> NodeAttrHelper::GetInt64s(const std::string& key) const { + std::optional> result; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + const auto& values = NODE_ATTR_ITER_VAL(entry).ints(); + const int64_t* cbegin = values.data(); + const int64_t* cend = values.data() + values.size(); + result = std::vector(cbegin, cend); + } + + return result; +} + +std::optional NodeAttrHelper::GetString(const std::string& key) const { + std::optional result; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + result = NODE_ATTR_ITER_VAL(entry).s(); + } + + return result; +} + +bool NodeAttrHelper::HasAttr(const std::string& key) const { + return node_attributes_.find(key) != node_attributes_.end(); +} +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/ort_api.h b/onnxruntime/core/providers/qnn/ort_api.h new file mode 100644 index 0000000000000..c8f64a8975fd4 --- /dev/null +++ b/onnxruntime/core/providers/qnn/ort_api.h @@ -0,0 +1,201 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License + +#pragma once + +// This compilation unit (ort_api.h/.cc) encapsulates the interface between the EP and ORT in a manner +// that allows QNN EP to built either as a static library or a dynamic shared library. +// The preprocessor macro `BUILD_QNN_EP_STATIC_LIB` is defined and set to 1 if QNN EP +// is built as a static library. + +#if BUILD_QNN_EP_STATIC_LIB +// Includes when building QNN EP statically +#ifdef _WIN32 +#include +#include +#include "core/platform/tracing.h" +#include "core/platform/windows/logging/etw_sink.h" +#endif + +#include "onnx/defs/data_type_utils.h" +#include "core/common/common.h" +#include "core/common/status.h" +#include "core/common/safeint.h" +#include "core/common/logging/logging.h" +#include "core/common/logging/capture.h" +#include "core/common/path_string.h" +#include "core/platform/env.h" +#include "core/framework/data_types.h" +#include "core/framework/float16.h" +#include "core/framework/run_options.h" +#include "core/framework/execution_provider.h" +#include "core/framework/model_metadef_id_generator.h" +#include "core/framework/compute_capability.h" +#include "core/framework/tensor_shape.h" +#include "core/framework/node_unit.h" +#include "core/framework/tensorprotoutils.h" +#include "core/framework/utils.h" +#include "core/graph/constants.h" +#include "core/graph/basic_types.h" +#include "core/graph/model.h" +#include "core/graph/graph_viewer.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" +#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" +#include "core/providers/common.h" +#include "core/providers/partitioning_utils.h" +#include "core/session/onnxruntime_cxx_api.h" +#else +// Includes when building QNN EP as a shared library +#include "core/providers/shared_library/provider_api.h" +#define ORT_API_MANUAL_INIT +#include "core/session/onnxruntime_cxx_api.h" +#endif + +#include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/session/onnxruntime_run_options_config_keys.h" + +#include +#include + +namespace onnxruntime { +#if BUILD_QNN_EP_STATIC_LIB +using Node_EdgeEnd = Node::EdgeEnd; +#endif + +#if BUILD_QNN_EP_STATIC_LIB +void RunOnUnload(std::function function); +inline const Env& GetDefaultEnv() { return Env::Default(); } +#endif + +inline void InitOrtCppApi() { +#if BUILD_QNN_EP_STATIC_LIB + // Do nothing. Including "onnxruntime_cxx_api.h" normally initializes the global api_ object. +#else + // Call util function in provider bridge that initializes the global api_ object. + InitProviderOrtApi(); +#endif +} + +/// +/// Creates an onnxruntime or onnx object. Works for both static and shared library builds of QNN EP. +/// +/// Example: auto model = Factory<Model>::Create(/* args ... */); +/// +/// Type of the object to create +template +struct Factory { + template + static inline std::unique_ptr Create(Params&&... params) { +#if BUILD_QNN_EP_STATIC_LIB + return std::make_unique(std::forward(params)...); +#else + return T::Create(std::forward(params)...); +#endif + } +}; + +// Specialization of Factory for creating NodeUnit objects. +template <> +struct Factory { + static std::unique_ptr Create(gsl::span dq_nodes, + const Node& target_node, + gsl::span q_nodes, + NodeUnit::Type unit_type, + gsl::span inputs, + gsl::span outputs, + size_t input_edge_count, + gsl::span output_edges); +}; + +inline const ConfigOptions& RunOptions__GetConfigOptions(const RunOptions& run_options) { +#if BUILD_QNN_EP_STATIC_LIB + return run_options.config_options; +#else + return run_options.GetConfigOptions(); +#endif +} + +inline std::unique_ptr& ComputeCapability__SubGraph(ComputeCapability& compute_cability) { +#if BUILD_QNN_EP_STATIC_LIB + return compute_cability.sub_graph; +#else + return compute_cability.SubGraph(); +#endif +} + +inline std::vector& IndexedSubGraph__Nodes(IndexedSubGraph& indexed_sub_graph) { +#if BUILD_QNN_EP_STATIC_LIB + return indexed_sub_graph.nodes; +#else + return indexed_sub_graph.Nodes(); +#endif +} + +std::vector Graph__Nodes(const Graph& graph); + +inline std::pair>, std::unordered_map> +GetQDQNodeUnits(const GraphViewer& graph_viewer, const logging::Logger& logger) { +#if BUILD_QNN_EP_STATIC_LIB + return QDQ::GetAllNodeUnits(graph_viewer, logger); +#else + return QDQ::GetAllNodeUnits(&graph_viewer, logger); +#endif +} + +/** + * Wrapping onnxruntime::Node for retrieving attribute values + */ +class NodeAttrHelper { + public: + explicit NodeAttrHelper(const Node& node); + + // Get the attributes from the target node of the node_unit + explicit NodeAttrHelper(const NodeUnit& node_unit); + + /* + * Get with default + */ + float Get(const std::string& key, float def_val) const; + std::vector Get(const std::string& key, const std::vector& def_val) const; + + int64_t Get(const std::string& key, int64_t def_val) const; + std::vector Get(const std::string& key, const std::vector& def_val) const; + + const std::string& Get(const std::string& key, const std::string& def_val) const; + + // Convert the i() or ints() of the attribute from int64_t to int32_t + int32_t Get(const std::string& key, int32_t def_val) const; + std::vector Get(const std::string& key, const std::vector& def_val) const; + + // Convert the i() or ints() of the attribute from int64_t to uint32_t + uint32_t Get(const std::string& key, uint32_t def_val) const; + std::vector Get(const std::string& key, const std::vector& def_val) const; + + /* + * Get without default. + */ + std::optional GetFloat(const std::string& key) const; + std::optional> GetFloats(const std::string& key) const; + + std::optional GetInt64(const std::string& key) const; + std::optional> GetInt64s(const std::string& key) const; + + std::optional GetString(const std::string& key) const; + + bool HasAttr(const std::string& key) const; + + private: + const NodeAttributes& node_attributes_; +}; + +namespace logging { +inline std::unique_ptr Capture__Create(const Logger& logger, logging::Severity severity, const char* category, + logging::DataType data_type, const CodeLocation& location) { +#if BUILD_QNN_EP_STATIC_LIB + return std::make_unique(logger, severity, category, data_type, location); +#else + return Capture::Create(logger, severity, category, data_type, location); +#endif +} +} // namespace logging +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 1d9242f8a5939..e3ae0b0d68aa7 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -5,58 +5,19 @@ #include #include -#include "core/framework/compute_capability.h" -#include "core/graph/graph_viewer.h" -#include "core/session/onnxruntime_session_options_config_keys.h" -#include "core/session/onnxruntime_run_options_config_keys.h" -#include "core/session/onnxruntime_cxx_api.h" -#include "core/framework/kernel_registry.h" -#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" -#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" -#include "core/platform/env.h" -#include "core/providers/common.h" -#include "core/providers/partitioning_utils.h" -#include "core/providers/partitioning_utils.h" +#include "core/providers/qnn/ort_api.h" +#include "core/providers/qnn/qnn_telemetry.h" +#include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/op_builder_factory.h" #include "core/providers/qnn/builder/qnn_node_group.h" #include "core/providers/qnn/builder/qnn_def.h" #include "core/providers/qnn/builder/onnx_ctx_model_helper.h" -#include "core/framework/run_options.h" - -#ifdef _WIN32 -#include -#include "core/platform/windows/logging/etw_sink.h" -#endif namespace onnxruntime { constexpr const char* QNN = "QNN"; -static std::unique_ptr>> s_run_on_unload_; - -void RunOnUnload(std::function function) { - static std::mutex mutex; - std::lock_guard guard(mutex); - if (!s_run_on_unload_) { - s_run_on_unload_ = std::make_unique>>(); - } - s_run_on_unload_->push_back(std::move(function)); -} - -struct OnUnload { - ~OnUnload() { - if (!s_run_on_unload_) - return; - - for (auto& function : *s_run_on_unload_) - function(); - - s_run_on_unload_.reset(); - } - -} g_on_unload; - static void ParseProfilingLevel(std::string profiling_level_string, qnn::ProfilingLevel& profiling_level) { std::transform(profiling_level_string.begin(), @@ -193,17 +154,20 @@ qnn::ProfilingLevel QNNExecutionProvider::GetProfilingLevelFromETWLevel(unsigned } QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_options_map, - const SessionOptions* session_options) + const ConfigOptions* config_options) : IExecutionProvider{onnxruntime::kQnnExecutionProvider} { - if (session_options) { - disable_cpu_ep_fallback_ = session_options->config_options.GetConfigOrDefault( + InitOrtCppApi(); + metadef_id_generator_ = Factory::Create(); + + if (config_options) { + disable_cpu_ep_fallback_ = config_options->GetConfigOrDefault( kOrtSessionOptionsDisableCPUEPFallback, "0") == "1"; - context_cache_enabled_ = session_options->config_options.GetConfigOrDefault( + context_cache_enabled_ = config_options->GetConfigOrDefault( kOrtSessionOptionEpContextEnable, "0") == "1"; LOGS_DEFAULT(VERBOSE) << "Context cache enable: " << context_cache_enabled_; - std::string embed_mode = session_options->config_options.GetConfigOrDefault( + std::string embed_mode = config_options->GetConfigOrDefault( kOrtSessionOptionEpContextEmbedMode, "0"); if ("1" == embed_mode) { qnn_context_embed_mode_ = true; @@ -214,18 +178,18 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio } LOGS_DEFAULT(VERBOSE) << "User specified context cache embed mode: " << qnn_context_embed_mode_; - context_cache_path_cfg_ = session_options->config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); + context_cache_path_cfg_ = config_options->GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); LOGS_DEFAULT(VERBOSE) << "User specified context cache path: " << context_cache_path_cfg_; // For the case that workaround QNN context PD memory limit, user need split the model into pieces and // generate the QNN context model separately. // It could happen that the generated EPContext node in separate graph has same node name. // User can set this context_node_name_prefix for each split pieces to avoid that happens. - context_node_name_prefix_ = session_options->config_options.GetConfigOrDefault(kOrtSessionOptionEpContextNodeNamePrefix, ""); + context_node_name_prefix_ = config_options->GetConfigOrDefault(kOrtSessionOptionEpContextNodeNamePrefix, ""); LOGS_DEFAULT(VERBOSE) << "User specified QNN context node name prefix: " << context_node_name_prefix_; share_ep_contexts_ = - session_options->config_options.GetConfigOrDefault(kOrtSessionOptionShareEpContexts, "0") == "1"; + config_options->GetConfigOrDefault(kOrtSessionOptionShareEpContexts, "0") == "1"; LOGS_DEFAULT(VERBOSE) << "User specified option - share EP contexts across sessions: " << share_ep_contexts_; } @@ -246,8 +210,9 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio // separate out the profiling level for ETW in case it gets disabled later when we extract the events // set to invalid to indicate that ETW is no enabled when we setup QNN qnn::ProfilingLevel profiling_level_etw = qnn::ProfilingLevel::INVALID; - const Env& env = Env::Default(); - auto& provider = env.GetTelemetryProvider(); + +#ifdef _WIN32 + auto& provider = qnn::QnnTelemetry::Instance(); if (provider.IsEnabled()) { auto level = provider.Level(); auto keyword = provider.Keyword(); @@ -257,6 +222,7 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio } } } +#endif // defined(_WIN32) // In case ETW gets disabled later auto profiling_level_pos = provider_options_map.find(PROFILING_LEVEL); @@ -402,47 +368,53 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio soc_model, enable_htp_weight_sharing); -#ifdef _WIN32 - auto& etwRegistrationManager = logging::EtwRegistrationManager::Instance(); - // Register callback for ETW capture state (rundown) - callback_ETWSink_provider_ = onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback( - [&etwRegistrationManager, this]( - LPCGUID SourceId, - ULONG IsEnabled, - UCHAR Level, - ULONGLONG MatchAnyKeyword, - ULONGLONG MatchAllKeyword, - PEVENT_FILTER_DESCRIPTOR FilterData, - PVOID CallbackContext) { - ORT_UNUSED_PARAMETER(SourceId); - ORT_UNUSED_PARAMETER(MatchAnyKeyword); - ORT_UNUSED_PARAMETER(MatchAllKeyword); - ORT_UNUSED_PARAMETER(FilterData); - ORT_UNUSED_PARAMETER(CallbackContext); - - if (IsEnabled == EVENT_CONTROL_CODE_ENABLE_PROVIDER) { - if ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Logs)) != 0) { - auto ortETWSeverity = etwRegistrationManager.MapLevelToSeverity(); - (void)qnn_backend_manager_->ResetQnnLogLevel(ortETWSeverity); - } - if ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Profiling)) != 0) { - if (Level != 0) { - // Commenting out Dynamic QNN Profiling for now - // There seems to be a crash in 3rd party QC QnnHtp.dll with this. - // Repro Scenario - start ETW tracing prior to session creation. - // Then disable/enable ETW Tracing with the code below uncommented a few times - // auto profiling_level_etw = GetProfilingLevelFromETWLevel(Level); - // (void)qnn_backend_manager_->SetProfilingLevelETW(profiling_level_etw); +#if defined(_WIN32) + if (onnxruntime::logging::EtwRegistrationManager::SupportsETW()) { + auto& etwRegistrationManager = logging::EtwRegistrationManager::Instance(); + // Register callback for ETW capture state (rundown) + callback_ETWSink_provider_ = onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback( + [&etwRegistrationManager, this]( + LPCGUID SourceId, + ULONG IsEnabled, + UCHAR Level, + ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, + PEVENT_FILTER_DESCRIPTOR FilterData, + PVOID CallbackContext) { + ORT_UNUSED_PARAMETER(SourceId); + ORT_UNUSED_PARAMETER(MatchAnyKeyword); + ORT_UNUSED_PARAMETER(MatchAllKeyword); + ORT_UNUSED_PARAMETER(FilterData); + ORT_UNUSED_PARAMETER(CallbackContext); + + if (IsEnabled == EVENT_CONTROL_CODE_ENABLE_PROVIDER) { + if ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Logs)) != 0) { + auto ortETWSeverity = etwRegistrationManager.MapLevelToSeverity(); + (void)qnn_backend_manager_->ResetQnnLogLevel(ortETWSeverity); + } + if ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Profiling)) != 0) { + if (Level != 0) { + // Commenting out Dynamic QNN Profiling for now + // There seems to be a crash in 3rd party QC QnnHtp.dll with this. + // Repro Scenario - start ETW tracing prior to session creation. + // Then disable/enable ETW Tracing with the code below uncommented a few times + // auto profiling_level_etw = GetProfilingLevelFromETWLevel(Level); + // (void)qnn_backend_manager_->SetProfilingLevelETW(profiling_level_etw); + // + // NOTE(1/2/2025): It is possible that the above was not working in part because it is using the + // *logging ETW* subsystem to modify profiling, which should use an entirely different + // ETW provider (see QnnTelemetry). Should add callbacks for profiling to the QnnTelemetry ETW provider. + } } } - } - if (IsEnabled == EVENT_CONTROL_CODE_DISABLE_PROVIDER) { - // (void)qnn_backend_manager_->SetProfilingLevelETW(qnn::ProfilingLevel::INVALID); - (void)qnn_backend_manager_->ResetQnnLogLevel(std::nullopt); - } - }); - etwRegistrationManager.RegisterInternalCallback(callback_ETWSink_provider_); + if (IsEnabled == EVENT_CONTROL_CODE_DISABLE_PROVIDER) { + // (void)qnn_backend_manager_->SetProfilingLevelETW(qnn::ProfilingLevel::INVALID); + (void)qnn_backend_manager_->ResetQnnLogLevel(std::nullopt); + } + }); + etwRegistrationManager.RegisterInternalCallback(callback_ETWSink_provider_); + } #endif } @@ -456,7 +428,7 @@ QNNExecutionProvider::~QNNExecutionProvider() { } // Unregister the ETW callback -#ifdef _WIN32 +#if defined(_WIN32) if (callback_ETWSink_provider_ != nullptr) { logging::EtwRegistrationManager::Instance().UnregisterInternalCallback(callback_ETWSink_provider_); } @@ -488,9 +460,10 @@ static void LogNodeSupport(const logging::Logger& logger, oss << "\tREASON : " << support_status.ErrorMessage() << std::endl; } - logging::Capture(logger, log_severity, logging::Category::onnxruntime, - log_data_type, call_site) - .Stream() + auto log_capture = Factory::Create(logger, log_severity, + logging::Category::onnxruntime, + log_data_type, call_site); + log_capture->Stream() << (support_status.IsOK() ? "Validation PASSED " : "Validation FAILED ") << "for " << num_nodes << " nodes in " << qnn_node_group.Type() << " (" << qnn_node_group.GetTargetNodeUnit()->OpType() << ") :" << std::endl @@ -594,11 +567,11 @@ static bool EpSharedContextsHasAllGraphs(const std::vectorName(); + const std::string& graph_name = ep_context_node.Name(); bool has_shared_qnn_model = SharedContext::GetInstance().HasQnnModel(graph_name); if (!has_shared_qnn_model) { LOGS(logger, VERBOSE) << "Graph: " << graph_name << " from EpContext node not found from shared EP contexts."; @@ -613,7 +586,7 @@ static bool EpSharedContextsHasAllGraphs(const std::vector>& result, - const utils::GenerateMetadefNameFn& gen_metadef_name, + const std::function& gen_metadef_name, const logging::Logger& logger) { std::unordered_set supported_nodes{}; std::vector> supported_groups{}; @@ -673,7 +646,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer const auto gen_metadef_name = [&]() { uint64_t model_hash; - int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash); + int metadef_id = metadef_id_generator_->GenerateId(graph_viewer, model_hash); return MakeString(QNN, context_node_name_prefix_, "_", model_hash, "_", metadef_id); }; @@ -718,7 +691,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger); + std::tie(node_unit_holder, node_unit_map) = GetQDQNodeUnits(graph_viewer, logger); // remove is_qnn_ctx_model related code const auto supported_nodes = GetSupportedNodes(graph_viewer, node_unit_map, @@ -761,11 +734,14 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer bool is_valid_partition = true; size_t nodes_in_partition = 0; - if (partition && partition->sub_graph) { - nodes_in_partition = partition->sub_graph->nodes.size(); + if (partition && ComputeCapability__SubGraph(*partition)) { + const auto& subgraph = ComputeCapability__SubGraph(*partition); + const auto& subgraph_nodes = IndexedSubGraph__Nodes(*subgraph); + + nodes_in_partition = subgraph_nodes.size(); if (nodes_in_partition == 1 && !is_qnn_ctx_model) { - const Node* node = graph_viewer.GetNode(partition->sub_graph->nodes[0]); + const Node* node = graph_viewer.GetNode(subgraph_nodes[0]); if (!node) { LOGS(logger, ERROR) << "QNN EP: Invalid node in partition of one node."; @@ -834,34 +810,34 @@ Status QNNExecutionProvider::CreateComputeFunc(std::vector& nod void QNNExecutionProvider::InitQnnGraphConfigs(qnn::QnnConfigsBuilder& configs_builder) const { if (qnn_backend_manager_->GetQnnBackendType() == qnn::QnnBackendType::HTP) { if (htp_graph_finalization_opt_mode_ != qnn::HtpGraphFinalizationOptimizationMode::kDefault) { - QnnHtpGraph_CustomConfig_t& htp_graph_opt_config = configs_builder.PushCustomConfig(); - htp_graph_opt_config.option = QNN_HTP_GRAPH_CONFIG_OPTION_OPTIMIZATION; - htp_graph_opt_config.optimizationOption.type = QNN_HTP_GRAPH_OPTIMIZATION_TYPE_FINALIZE_OPTIMIZATION_FLAG; - htp_graph_opt_config.optimizationOption.floatValue = static_cast(htp_graph_finalization_opt_mode_); - - QnnGraph_Config_t& graph_opt_config = configs_builder.PushConfig(); - graph_opt_config.option = QNN_GRAPH_CONFIG_OPTION_CUSTOM; - graph_opt_config.customConfig = &htp_graph_opt_config; + gsl::not_null htp_graph_opt_config = configs_builder.PushCustomConfig(); + htp_graph_opt_config->option = QNN_HTP_GRAPH_CONFIG_OPTION_OPTIMIZATION; + htp_graph_opt_config->optimizationOption.type = QNN_HTP_GRAPH_OPTIMIZATION_TYPE_FINALIZE_OPTIMIZATION_FLAG; + htp_graph_opt_config->optimizationOption.floatValue = static_cast(htp_graph_finalization_opt_mode_); + + gsl::not_null graph_opt_config = configs_builder.PushConfig(); + graph_opt_config->option = QNN_GRAPH_CONFIG_OPTION_CUSTOM; + graph_opt_config->customConfig = htp_graph_opt_config; } if (vtcm_size_in_mb_ > 0) { - QnnHtpGraph_CustomConfig_t& htp_graph_opt_config_vtcm = configs_builder.PushCustomConfig(); - htp_graph_opt_config_vtcm.option = QNN_HTP_GRAPH_CONFIG_OPTION_VTCM_SIZE; - htp_graph_opt_config_vtcm.vtcmSizeInMB = static_cast(vtcm_size_in_mb_); + gsl::not_null htp_graph_opt_config_vtcm = configs_builder.PushCustomConfig(); + htp_graph_opt_config_vtcm->option = QNN_HTP_GRAPH_CONFIG_OPTION_VTCM_SIZE; + htp_graph_opt_config_vtcm->vtcmSizeInMB = static_cast(vtcm_size_in_mb_); - QnnGraph_Config_t& graph_opt_config_vtcm = configs_builder.PushConfig(); - graph_opt_config_vtcm.option = QNN_GRAPH_CONFIG_OPTION_CUSTOM; - graph_opt_config_vtcm.customConfig = &htp_graph_opt_config_vtcm; + gsl::not_null graph_opt_config_vtcm = configs_builder.PushConfig(); + graph_opt_config_vtcm->option = QNN_GRAPH_CONFIG_OPTION_CUSTOM; + graph_opt_config_vtcm->customConfig = htp_graph_opt_config_vtcm; } if (enable_HTP_FP16_precision_) { - QnnHtpGraph_CustomConfig_t& htp_graph_precision_config = configs_builder.PushCustomConfig(); - htp_graph_precision_config.option = QNN_HTP_GRAPH_CONFIG_OPTION_PRECISION; - htp_graph_precision_config.precision = QNN_PRECISION_FLOAT16; + gsl::not_null htp_graph_precision_config = configs_builder.PushCustomConfig(); + htp_graph_precision_config->option = QNN_HTP_GRAPH_CONFIG_OPTION_PRECISION; + htp_graph_precision_config->precision = QNN_PRECISION_FLOAT16; - QnnGraph_Config_t& graph_precision_config = configs_builder.PushConfig(); - graph_precision_config.option = QNN_GRAPH_CONFIG_OPTION_CUSTOM; - graph_precision_config.customConfig = &htp_graph_precision_config; + gsl::not_null graph_precision_config = configs_builder.PushConfig(); + graph_precision_config->option = QNN_GRAPH_CONFIG_OPTION_CUSTOM; + graph_precision_config->customConfig = htp_graph_precision_config; } } } @@ -918,10 +894,10 @@ Status QNNExecutionProvider::Compile(const std::vector& fused if (EpSharedContextsHasAllGraphs(fused_nodes_and_graphs, logger)) { for (auto fused_node_and_graph : fused_nodes_and_graphs) { const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph); - const auto& ep_context_node = graph_viewer.Nodes().begin(); + const Node& ep_context_node = *graph_viewer.Nodes().begin(); const Node& fused_node = fused_node_and_graph.fused_node; const std::string& graph_meta_id = fused_node.Name(); - std::string key = ep_context_node->Name(); + std::string key = ep_context_node.Name(); auto qnn_model_shared = SharedContext::GetInstance().GetSharedQnnModel(key); ORT_RETURN_IF(nullptr == qnn_model_shared, "Graph: " + key + " not found from shared EP contexts."); ORT_RETURN_IF_ERROR(qnn_model_shared->SetGraphInputOutputInfo(graph_viewer, fused_node, logger)); @@ -963,10 +939,10 @@ Status QNNExecutionProvider::Compile(const std::vector& fused for (auto fused_node_and_graph : fused_nodes_and_graphs) { const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph); - const auto& ep_context_node = graph_viewer.Nodes().begin(); + const Node& ep_context_node = *graph_viewer.Nodes().begin(); const Node& fused_node = fused_node_and_graph.fused_node; const std::string& graph_meta_id = fused_node.Name(); - std::string key = ep_context_node->Name(); + std::string key = ep_context_node.Name(); ORT_RETURN_IF(qnn_models.find(key) == qnn_models.end(), key + " key name not exist in table qnn_models."); auto qnn_model = std::move(qnn_models[key]); ORT_RETURN_IF_ERROR(qnn_model->SetGraphInputOutputInfo(graph_viewer, fused_node, logger)); @@ -1007,7 +983,7 @@ Status QNNExecutionProvider::Compile(const std::vector& fused buffer_size, max_spill_fill_buffer_size)); } - qnn_ep_context_model_ = std::make_unique("qnn_ep_context_model", false, logger); + qnn_ep_context_model_ = Factory::Create(std::string{"qnn_ep_context_model"}, false, logger); ORT_RETURN_IF_ERROR(qnn::CreateEPContextNodes(qnn_ep_context_model_.get(), context_buffer.get(), buffer_size, @@ -1026,8 +1002,8 @@ const InlinedVector QNNExecutionProvider::GetEpContextNodes() const InlinedVector ep_context_nodes; if (qnn_ep_context_model_) { const auto& graph = qnn_ep_context_model_->MainGraph(); - for (const auto& node : graph.Nodes()) { - ep_context_nodes.push_back(graph.GetNode(node.Index())); + for (gsl::not_null node : Graph__Nodes(graph)) { + ep_context_nodes.push_back(graph.GetNode(node->Index())); } } @@ -1118,22 +1094,34 @@ void QNNExecutionProvider::ReleasePerThreadContext() const { per_thread_context_cache->erase(cached_context_it); } +static bool TryGetConfigEntry(const ConfigOptions& config_options, const std::string& key, std::string& value) { + std::optional new_value = config_options.GetConfigEntry(key); + if (!new_value.has_value()) { + return false; + } + + value = *new_value; + return true; +} + Status QNNExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_options) { auto backend_type = qnn_backend_manager_->GetQnnBackendType(); if (qnn::QnnBackendType::HTP != backend_type && qnn::QnnBackendType::DSP != backend_type) { return Status::OK(); } + const ConfigOptions& config_options = RunOptions__GetConfigOptions(run_options); + std::string htp_perf_mode = ""; qnn::HtpPerformanceMode htp_performance_mode = qnn::HtpPerformanceMode::kHtpDefault; - if (run_options.config_options.TryGetConfigEntry(kOrtRunOptionsConfigQnnPerfMode, htp_perf_mode)) { + if (TryGetConfigEntry(config_options, kOrtRunOptionsConfigQnnPerfMode, htp_perf_mode)) { // set power mode ParseHtpPerformanceMode(htp_perf_mode, htp_performance_mode); } std::string rpc_latency = ""; uint32_t rpc_control_latency = 0; - if (run_options.config_options.TryGetConfigEntry(kOrtRunOptionsConfigQnnRpcControlLatency, rpc_latency)) { + if (TryGetConfigEntry(config_options, kOrtRunOptionsConfigQnnRpcControlLatency, rpc_latency)) { rpc_control_latency = static_cast(std::stoul(rpc_latency)); LOGS_DEFAULT(VERBOSE) << "rpc_control_latency: " << rpc_control_latency; } @@ -1159,9 +1147,11 @@ Status QNNExecutionProvider::OnRunEnd(bool /*sync_stream*/, const onnxruntime::R return Status::OK(); } + const ConfigOptions& config_options = RunOptions__GetConfigOptions(run_options); + std::string htp_perf_mode = ""; qnn::HtpPerformanceMode htp_performance_mode = qnn::HtpPerformanceMode::kHtpDefault; - if (run_options.config_options.TryGetConfigEntry(kOrtRunOptionsConfigQnnPerfModePostRun, htp_perf_mode)) { + if (TryGetConfigEntry(config_options, kOrtRunOptionsConfigQnnPerfModePostRun, htp_perf_mode)) { // set power mode ParseHtpPerformanceMode(htp_perf_mode, htp_performance_mode); } diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index a0577e8fd87f2..026f699b2d27c 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -3,26 +3,19 @@ #pragma once -#include "core/framework/execution_provider.h" -#include "core/framework/session_options.h" -#include "core/framework/model_metadef_id_generator.h" -#include "core/graph/model.h" +#include #include +#include +#include + +#include "core/providers/qnn/ort_api.h" #include "core/providers/qnn/builder/qnn_backend_manager.h" #include "core/providers/qnn/builder/qnn_model.h" #include "core/providers/qnn/builder/qnn_configs_helper.h" #include "HTP/QnnHtpGraph.h" -#include -#include -#include -#ifdef _WIN32 -#include "core/platform/windows/logging/etw_sink.h" -#endif namespace onnxruntime { -void RunOnUnload(std::function function); - class SharedContext { public: static SharedContext& GetInstance() { @@ -87,7 +80,7 @@ class SharedContext { // Logical device representation. class QNNExecutionProvider : public IExecutionProvider { public: - explicit QNNExecutionProvider(const ProviderOptions& provider_options_map, const SessionOptions* session_options); + explicit QNNExecutionProvider(const ProviderOptions& provider_options_map, const ConfigOptions* config_options); virtual ~QNNExecutionProvider(); ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QNNExecutionProvider); @@ -143,14 +136,14 @@ class QNNExecutionProvider : public IExecutionProvider { bool qnn_context_embed_mode_ = true; int32_t vtcm_size_in_mb_ = 0; std::unique_ptr qnn_ep_context_model_; - ModelMetadefIdGenerator metadef_id_generator_; + std::unique_ptr metadef_id_generator_; uint32_t device_id_ = 0; qnn::HtpPerformanceMode default_htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpDefault; uint32_t default_rpc_control_latency_ = 0; bool enable_HTP_FP16_precision_ = true; bool share_ep_contexts_ = false; bool enable_spill_fill_buffer_ = false; -#ifdef _WIN32 +#if defined(_WIN32) onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback callback_ETWSink_provider_ = nullptr; #endif qnn::ModelSettings model_settings_ = {}; diff --git a/onnxruntime/core/providers/qnn/qnn_provider_factory.cc b/onnxruntime/core/providers/qnn/qnn_provider_factory.cc index 4095d7ff02a33..d4dd446751359 100644 --- a/onnxruntime/core/providers/qnn/qnn_provider_factory.cc +++ b/onnxruntime/core/providers/qnn/qnn_provider_factory.cc @@ -2,32 +2,68 @@ // Licensed under the MIT License #include "core/providers/qnn/qnn_provider_factory_creator.h" - -#include "core/session/abi_session_options_impl.h" #include "core/providers/qnn/qnn_execution_provider.h" -#include "core/session/ort_apis.h" namespace onnxruntime { struct QNNProviderFactory : IExecutionProviderFactory { - QNNProviderFactory(const ProviderOptions& provider_options_map, const SessionOptions* session_options) - : provider_options_map_(provider_options_map), session_options_(session_options) { + QNNProviderFactory(const ProviderOptions& provider_options_map, const ConfigOptions* config_options) + : provider_options_map_(provider_options_map), config_options_(config_options) { } ~QNNProviderFactory() override { } std::unique_ptr CreateProvider() override { - return std::make_unique(provider_options_map_, session_options_); + return std::make_unique(provider_options_map_, config_options_); } private: ProviderOptions provider_options_map_; - const SessionOptions* session_options_; + const ConfigOptions* config_options_; }; +#if BUILD_QNN_EP_STATIC_LIB std::shared_ptr QNNProviderFactoryCreator::Create(const ProviderOptions& provider_options_map, const SessionOptions* session_options) { - return std::make_shared(provider_options_map, session_options); + const ConfigOptions* config_options = nullptr; + if (session_options != nullptr) { + config_options = &session_options->config_options; + } + + return std::make_shared(provider_options_map, config_options); } +#else +struct QNN_Provider : Provider { + std::shared_ptr CreateExecutionProviderFactory(const void* param) override { + if (param == nullptr) { + LOGS_DEFAULT(ERROR) << "[QNN EP] Passed NULL options to CreateExecutionProviderFactory()"; + return nullptr; + } + + std::array pointers_array = *reinterpret_cast*>(param); + const ProviderOptions* provider_options = reinterpret_cast(pointers_array[0]); + const ConfigOptions* config_options = reinterpret_cast(pointers_array[1]); + + if (provider_options == nullptr) { + LOGS_DEFAULT(ERROR) << "[QNN EP] Passed NULL ProviderOptions to CreateExecutionProviderFactory()"; + return nullptr; + } + + return std::make_shared(*provider_options, config_options); + } + + void Initialize() override {} + void Shutdown() override {} +} g_provider; +#endif // BUILD_QNN_EP_STATIC_LIB } // namespace onnxruntime + +#if !BUILD_QNN_EP_STATIC_LIB +extern "C" { + +ORT_API(onnxruntime::Provider*, GetProvider) { + return &onnxruntime::g_provider; +} +} +#endif // !BUILD_QNN_EP_STATIC_LIB diff --git a/onnxruntime/core/providers/qnn/qnn_provider_factory_creator.h b/onnxruntime/core/providers/qnn/qnn_provider_factory_creator.h index 80f9d99b804e7..46b6c15b40553 100644 --- a/onnxruntime/core/providers/qnn/qnn_provider_factory_creator.h +++ b/onnxruntime/core/providers/qnn/qnn_provider_factory_creator.h @@ -11,6 +11,9 @@ namespace onnxruntime { struct SessionOptions; +// Defined in core/session/provider_bridge_ort.cc if built as a shared library (default build config). +// Defined in core/providers/qnn/qnn_provider_factory.cc if built as a static library. +// The preprocessor macro `BUILD_QNN_EP_STATIC_LIB` is defined and set to 1 if QNN is built as a static library. struct QNNProviderFactoryCreator { static std::shared_ptr Create(const ProviderOptions& provider_options_map, const SessionOptions* session_options); diff --git a/onnxruntime/core/providers/qnn/qnn_telemetry.cc b/onnxruntime/core/providers/qnn/qnn_telemetry.cc new file mode 100644 index 0000000000000..b2c8350bfe8ca --- /dev/null +++ b/onnxruntime/core/providers/qnn/qnn_telemetry.cc @@ -0,0 +1,211 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/qnn/qnn_telemetry.h" + +#ifdef _WIN32 +#if !BUILD_QNN_EP_STATIC_LIB +// ETW includes +// need space after Windows.h to prevent clang-format re-ordering breaking the build. +// TraceLoggingProvider.h must follow Windows.h +#include + +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 26440) // Warning C26440 from TRACELOGGING_DEFINE_PROVIDER +#endif + +#include +#include +#include +#include "core/platform/windows/TraceLoggingConfig.h" + +// Seems this workaround can be dropped when we drop support for VS2017 toolchains +// https://developercommunity.visualstudio.com/content/problem/85934/traceloggingproviderh-is-incompatible-with-utf-8.html +#ifdef _TlgPragmaUtf8Begin +#undef _TlgPragmaUtf8Begin +#define _TlgPragmaUtf8Begin +#endif + +#ifdef _TlgPragmaUtf8End +#undef _TlgPragmaUtf8End +#define _TlgPragmaUtf8End +#endif + +// Different versions of TraceLoggingProvider.h contain different macro variable names for the utf8 begin and end, +// and we need to cover the lower case version as well. +#ifdef _tlgPragmaUtf8Begin +#undef _tlgPragmaUtf8Begin +#define _tlgPragmaUtf8Begin +#endif + +#ifdef _tlgPragmaUtf8End +#undef _tlgPragmaUtf8End +#define _tlgPragmaUtf8End +#endif + +TRACELOGGING_DEFINE_PROVIDER(telemetry_provider_handle, "Microsoft.ML.ONNXRuntime", + // {3a26b1ff-7484-7484-7484-15261f42614d} + (0x3a26b1ff, 0x7484, 0x7484, 0x74, 0x84, 0x15, 0x26, 0x1f, 0x42, 0x61, 0x4d), + TraceLoggingOptionMicrosoftTelemetry()); + +#ifdef _MSC_VER +#pragma warning(pop) +#endif +#endif // !BUILD_QNN_EP_STATIC_LIB + +#include "core/providers/qnn/ort_api.h" + +namespace onnxruntime { +namespace qnn { + +#if !BUILD_QNN_EP_STATIC_LIB +std::mutex QnnTelemetry::mutex_; +std::mutex QnnTelemetry::provider_change_mutex_; +uint32_t QnnTelemetry::global_register_count_ = 0; +bool QnnTelemetry::enabled_ = true; +UCHAR QnnTelemetry::level_ = 0; +UINT64 QnnTelemetry::keyword_ = 0; +std::vector QnnTelemetry::callbacks_; +std::mutex QnnTelemetry::callbacks_mutex_; +#endif // !BUILD_QNN_EP_STATIC_LIB + +QnnTelemetry::QnnTelemetry() { +#if !BUILD_QNN_EP_STATIC_LIB + std::lock_guard lock(mutex_); + if (global_register_count_ == 0) { + // TraceLoggingRegister is fancy in that you can only register once GLOBALLY for the whole process + HRESULT hr = TraceLoggingRegisterEx(telemetry_provider_handle, ORT_TL_EtwEnableCallback, nullptr); + if (SUCCEEDED(hr)) { + global_register_count_ += 1; + } + } +#endif // !BUILD_QNN_EP_STATIC_LIB +} + +QnnTelemetry::~QnnTelemetry() { +#if !BUILD_QNN_EP_STATIC_LIB + std::lock_guard lock(mutex_); + if (global_register_count_ > 0) { + global_register_count_ -= 1; + if (global_register_count_ == 0) { + TraceLoggingUnregister(telemetry_provider_handle); + } + } + + std::lock_guard lock_callbacks(callbacks_mutex_); + callbacks_.clear(); +#endif // !BUILD_QNN_EP_STATIC_LIB +} + +QnnTelemetry& QnnTelemetry::Instance() { + static QnnTelemetry instance; + return instance; +} + +bool QnnTelemetry::IsEnabled() const { +#if BUILD_QNN_EP_STATIC_LIB + const Env& env = GetDefaultEnv(); + auto& provider = env.GetTelemetryProvider(); + return provider.IsEnabled(); +#else + std::lock_guard lock(provider_change_mutex_); + return enabled_; +#endif +} + +UCHAR QnnTelemetry::Level() const { +#if BUILD_QNN_EP_STATIC_LIB + const Env& env = GetDefaultEnv(); + auto& provider = env.GetTelemetryProvider(); + return provider.Level(); +#else + std::lock_guard lock(provider_change_mutex_); + return level_; +#endif +} + +UINT64 QnnTelemetry::Keyword() const { +#if BUILD_QNN_EP_STATIC_LIB + const Env& env = GetDefaultEnv(); + auto& provider = env.GetTelemetryProvider(); + return provider.Keyword(); +#else + std::lock_guard lock(provider_change_mutex_); + return keyword_; +#endif +} + +void QnnTelemetry::LogQnnProfileEvent(uint64_t timestamp, + const std::string& message, + const std::string& qnnScalarValue, + const std::string& unit, + const std::string& timingSource, + const std::string& eventLevel, + const char* eventIdentifier) const { + TraceLoggingWrite( + telemetry_provider_handle, + "QNNProfilingEvent", + TraceLoggingKeyword(static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Profiling)), + TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE), + TraceLoggingValue(timestamp, "Timestamp"), + TraceLoggingString(message.c_str(), "Message"), + TraceLoggingString(qnnScalarValue.c_str(), "Value"), + TraceLoggingString(unit.c_str(), "Unit of Measurement"), + TraceLoggingString(timingSource.c_str(), "Timing Source"), + TraceLoggingString(eventLevel.c_str(), "Event Level"), + TraceLoggingString(eventIdentifier, "Event Identifier")); +} + +void QnnTelemetry::RegisterInternalCallback(const EtwInternalCallback& callback) { +#if BUILD_QNN_EP_STATIC_LIB + WindowsTelemetry::RegisterInternalCallback(callback); +#else + std::lock_guard lock_callbacks(callbacks_mutex_); + callbacks_.push_back(&callback); +#endif +} + +void QnnTelemetry::UnregisterInternalCallback(const EtwInternalCallback& callback) { +#if BUILD_QNN_EP_STATIC_LIB + WindowsTelemetry::UnregisterInternalCallback(callback); +#else + std::lock_guard lock_callbacks(callbacks_mutex_); + auto new_end = std::remove_if(callbacks_.begin(), callbacks_.end(), + [&callback](const EtwInternalCallback* ptr) { + return ptr == &callback; + }); + callbacks_.erase(new_end, callbacks_.end()); +#endif +} + +#if !BUILD_QNN_EP_STATIC_LIB +void NTAPI QnnTelemetry::ORT_TL_EtwEnableCallback( + _In_ LPCGUID SourceId, + _In_ ULONG IsEnabled, + _In_ UCHAR Level, + _In_ ULONGLONG MatchAnyKeyword, + _In_ ULONGLONG MatchAllKeyword, + _In_opt_ PEVENT_FILTER_DESCRIPTOR FilterData, + _In_opt_ PVOID CallbackContext) { + std::lock_guard lock(provider_change_mutex_); + enabled_ = (IsEnabled != 0); + level_ = Level; + keyword_ = MatchAnyKeyword; + + InvokeCallbacks(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext); +} + +void QnnTelemetry::InvokeCallbacks(LPCGUID SourceId, ULONG IsEnabled, UCHAR Level, ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, PEVENT_FILTER_DESCRIPTOR FilterData, + PVOID CallbackContext) { + std::lock_guard lock_callbacks(callbacks_mutex_); + for (const auto& callback : callbacks_) { + (*callback)(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext); + } +} +#endif // !BUILD_QNN_EP_STATIC_LIB + +} // namespace qnn +} // namespace onnxruntime +#endif // defined(_WIN32) diff --git a/onnxruntime/core/providers/qnn/qnn_telemetry.h b/onnxruntime/core/providers/qnn/qnn_telemetry.h new file mode 100644 index 0000000000000..a2d42c518c1ac --- /dev/null +++ b/onnxruntime/core/providers/qnn/qnn_telemetry.h @@ -0,0 +1,98 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#ifdef _WIN32 +#include + +#if !BUILD_QNN_EP_STATIC_LIB +#include +#endif + +#include +#include +#include +#include + +#include "core/providers/qnn/ort_api.h" + +#if !BUILD_QNN_EP_STATIC_LIB +TRACELOGGING_DECLARE_PROVIDER(telemetry_provider_handle); +#endif + +namespace onnxruntime { +namespace qnn { + +/// +/// Singleton class used to log QNN profiling events to the ONNX Runtime telemetry tracelogging provider. +/// +/// When QNN EP is a DLL, we must define our own tracelogging provider handle via TRACELOGGING_DEFINE_PROVIDER. +/// TraceLogging documentation states that separate DLLs cannot share the same tracelogging provider handle. See: +/// https://learn.microsoft.com/en-us/windows/win32/api/traceloggingprovider/nf-traceloggingprovider-tracelogging_define_provider#remarks +/// +/// When QNN EP is a static library, we use the tracelogging provider handle already defined +/// in core/platform/windows/telemetry.h/.cc. In this case, we forward method calls to the +/// ORT Env's telemetry provider. +/// +class QnnTelemetry { + public: + static QnnTelemetry& Instance(); + bool IsEnabled() const; + + // Get the current logging level + unsigned char Level() const; + + // Get the current keyword + UINT64 Keyword() const; + + // Logs QNN profiling event as trace logging event. + void LogQnnProfileEvent(uint64_t timestamp, + const std::string& message, + const std::string& qnnScalarValue, + const std::string& unit, + const std::string& timingSource, + const std::string& eventLevel, + const char* eventIdentifier) const; + + using EtwInternalCallback = std::function; + + static void RegisterInternalCallback(const EtwInternalCallback& callback); + + static void UnregisterInternalCallback(const EtwInternalCallback& callback); + + private: + QnnTelemetry(); + ~QnnTelemetry(); + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QnnTelemetry); + +#if !BUILD_QNN_EP_STATIC_LIB + static std::mutex mutex_; + static uint32_t global_register_count_; + static bool enabled_; + + static std::vector callbacks_; + static std::mutex callbacks_mutex_; + static std::mutex provider_change_mutex_; + static UCHAR level_; + static ULONGLONG keyword_; + + static void InvokeCallbacks(LPCGUID SourceId, ULONG IsEnabled, UCHAR Level, ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, PEVENT_FILTER_DESCRIPTOR FilterData, PVOID CallbackContext); + + static void NTAPI ORT_TL_EtwEnableCallback( + _In_ LPCGUID SourceId, + _In_ ULONG IsEnabled, + _In_ UCHAR Level, + _In_ ULONGLONG MatchAnyKeyword, + _In_ ULONGLONG MatchAllKeyword, + _In_opt_ PEVENT_FILTER_DESCRIPTOR FilterData, + _In_opt_ PVOID CallbackContext); +#endif +}; + +} // namespace qnn +} // namespace onnxruntime + +#endif // defined(_WIN32) diff --git a/onnxruntime/core/providers/qnn/symbols.def b/onnxruntime/core/providers/qnn/symbols.def new file mode 100644 index 0000000000000..4ec2f7914c208 --- /dev/null +++ b/onnxruntime/core/providers/qnn/symbols.def @@ -0,0 +1,2 @@ +EXPORTS + GetProvider diff --git a/onnxruntime/core/providers/qnn/version_script.lds b/onnxruntime/core/providers/qnn/version_script.lds new file mode 100644 index 0000000000000..094abb3329781 --- /dev/null +++ b/onnxruntime/core/providers/qnn/version_script.lds @@ -0,0 +1,9 @@ +#_init and _fini should be local +VERS_1.0 { + global: + GetProvider; + + # Hide everything else. + local: + *; +}; diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index 45f81ed22b7f7..6ff2572e5e668 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -9,6 +9,11 @@ #pragma once #define SHARED_PROVIDER 1 +#ifdef _WIN32 +#include +#include +#endif // defined(_WIN32) + #include #include #include @@ -136,6 +141,17 @@ enum class DataType { USER = 1 ///< Contains potentially sensitive user data. }; +enum class ORTTraceLoggingKeyword : uint64_t { + Session = 0x1, // ORT Session TraceLoggingWrite + Logs = 0x2, // LOGS() Macro ORT logs. Pair with an appropriate level depending on detail required + Reserved1 = 0x4, // Reserved if we want to add some specific sub-categories instead of just LOGS() or other uses + Reserved2 = 0x8, + Reserved3 = 0x10, + Reserved4 = 0x20, + Reserved5 = 0x40, + Reserved6 = 0x80, + Profiling = 0x100 // Enables profiling. At higher levels >5 can impact inference performance +}; } // namespace logging // OnnxRuntime Types (these are the internal types) @@ -143,6 +159,13 @@ struct CPUIDInfo; namespace logging { struct Logger; struct Capture; +#ifdef _WIN32 +struct EtwRegistrationManager; +using EtwRegistrationManager_EtwInternalCallback = std::function; +#endif } // namespace logging struct ComputeCapability; struct ConfigOptions; @@ -157,10 +180,12 @@ struct KernelRegistry; struct Function; struct Graph; class GraphViewer; +struct ConstGraphNodes; enum class DataLayout; struct Model; struct Path; struct Node; +struct Node_EdgeEnd; struct NodeArg; struct NodeAttributes; struct NodeUnitIODef; @@ -215,6 +240,7 @@ using DeleteFunc = void (*)(void*); using NodeArgInfo = ONNX_NAMESPACE::ValueInfoProto; using NameMLValMap = std::unordered_map; + } // namespace onnxruntime #include "core/platform/threadpool.h" @@ -368,6 +394,28 @@ template <> constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4; } + +inline std::vector> +CreateSupportedPartitions(const GraphViewer& graph_viewer, + const std::unordered_set& supported_nodes, + const std::unordered_set& stop_ops, + const std::function& generate_metadef_name, + const std::string& execution_provider_name, + const std::string& execution_provider_type, + const std::unordered_map* node_unit_map, + bool drop_constant_initializers = false) { + return g_host->Utils__CreateSupportedPartitions(graph_viewer, supported_nodes, stop_ops, generate_metadef_name, + execution_provider_name, execution_provider_type, node_unit_map, + drop_constant_initializers); +} +inline std::unique_ptr MakeComputeCapability(const GraphViewer& graph_viewer, + const std::vector& group, + const std::function& generate_metadef_name, + const std::string& execution_provider_name, + bool drop_constant_initializers) { + return g_host->Utils__MakeComputeCapability(graph_viewer, group, generate_metadef_name, + execution_provider_name, drop_constant_initializers); +} } // namespace utils namespace QDQ { @@ -381,6 +429,10 @@ GetAllNodeUnits(const GraphViewer* graph_viewer, const logging::Logger& logger) // So the C API (and C++) becomes available when ORT_API_MANUAL_INIT is used. void InitProviderOrtApi(); +// This is a replacement for Env::Default(). Returns a reference to the default ORT Environment. +inline Env& GetDefaultEnv() { + return g_host->Env__Default(); +} } // namespace onnxruntime #define CREATE_MESSAGE(logger, severity, category, datatype) \ diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index aa8c367d25d51..4c050534456da 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -505,6 +505,9 @@ Status UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& tensor, const st /*out*/ std::vector& unpacked_tensor) { return g_host->UnpackInitializerData(tensor, model_path, unpacked_tensor); } +Status UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& tensor, /*out*/ std::vector& unpacked_tensor) { + return g_host->UnpackInitializerData(tensor, std::filesystem::path(), unpacked_tensor); +} } // namespace utils @@ -788,5 +791,5 @@ std::string ToUTF8String(const std::wstring& s) { std::wstring ToWideString(const std::string& s) { return g_host->ToWideString(s); } -#endif +#endif // _WIN32 } // namespace onnxruntime diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 5a179ec622f8c..54b4564be4887 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -120,11 +120,20 @@ struct Node__EdgeIterator { virtual bool operator!=(const Node__EdgeIterator& p) const = 0; virtual void operator++() = 0; + virtual const Node_EdgeEnd& operator*() const = 0; virtual const Node& GetNode() const = 0; virtual int GetSrcArgIndex() const = 0; virtual int GetDstArgIndex() const = 0; }; +struct ConstGraphNodes_Iterator { + virtual ~ConstGraphNodes_Iterator() {} + + virtual bool operator!=(const ConstGraphNodes_Iterator& other) const = 0; + virtual void operator++() = 0; + virtual const Node& operator*() = 0; +}; + // There are two ways to route a function, one is a virtual method and the other is a function pointer (or pointer to // member function). // The function pointers are nicer in that they directly call the target function, but they cannot be used in cases @@ -273,20 +282,40 @@ struct ProviderHost { // logging::Logger virtual bool logging__Logger__OutputIsEnabled(const logging::Logger* p, logging::Severity severity, logging::DataType data_type) = 0; + virtual logging::Severity logging__Logger__GetSeverity(const logging::Logger* p) = 0; // logging::LoggingManager virtual const logging::Logger& logging__LoggingManager__DefaultLogger() = 0; // logging::Capture - virtual std::unique_ptr logging__Capture__construct(const logging::Logger& logger, logging::Severity severity, const char* category, logging::DataType dataType, const CodeLocation& location) = 0; + virtual std::unique_ptr logging__Capture__construct(const logging::Logger& logger, + logging::Severity severity, + const char* category, + logging::DataType data_type, + const CodeLocation& location) = 0; virtual void logging__Capture__operator_delete(logging::Capture* p) noexcept = 0; virtual std::ostream& logging__Capture__Stream(logging::Capture* p) noexcept = 0; + virtual void logging__Capture__ProcessPrintf(logging::Capture* p, const char* format, va_list args) = 0; + +#if defined(_WIN32) + // logging::EtwRegistrationManager + virtual logging::EtwRegistrationManager& logging__EtwRegistrationManager__Instance() = 0; + virtual bool logging__EtwRegistrationManager__SupportsETW() = 0; + virtual logging::Severity logging__EtwRegistrationManager__MapLevelToSeverity(logging::EtwRegistrationManager* p) = 0; + virtual void logging__EtwRegistrationManager__RegisterInternalCallback( + logging::EtwRegistrationManager* p, + const logging::EtwRegistrationManager_EtwInternalCallback& callback) = 0; + virtual void logging__EtwRegistrationManager__UnregisterInternalCallback( + logging::EtwRegistrationManager* p, + const logging::EtwRegistrationManager_EtwInternalCallback& callback) = 0; +#endif // defined(_WIN32) // Env virtual Env& Env__Default() = 0; // Utils::DataTypeUtils virtual const std::string* Utils__DataTypeUtils__ToType(const ONNX_NAMESPACE::TypeProto& type_proto) = 0; + virtual const std::string* Utils__DataTypeUtils__ToType(const std::string& type_str) = 0; // int64s virtual int int64s__size(const ONNX_NAMESPACE::int64s* p) = 0; @@ -328,6 +357,7 @@ struct ProviderHost { virtual bool TypeProto_Tensor__has_shape(const ONNX_NAMESPACE::TypeProto_Tensor* p) = 0; virtual const ONNX_NAMESPACE::TensorShapeProto& TypeProto_Tensor__shape(const ONNX_NAMESPACE::TypeProto_Tensor* p) = 0; virtual ONNX_NAMESPACE::TensorShapeProto* TypeProto_Tensor__mutable_shape(ONNX_NAMESPACE::TypeProto_Tensor* p) = 0; + virtual bool TypeProto_Tensor__has_elem_type(const ONNX_NAMESPACE::TypeProto_Tensor* p) = 0; virtual int32_t TypeProto_Tensor__elem_type(const ONNX_NAMESPACE::TypeProto_Tensor* p) = 0; virtual void TypeProto_Tensor__set_elem_type(ONNX_NAMESPACE::TypeProto_Tensor* p, int32_t value) = 0; @@ -342,6 +372,7 @@ struct ProviderHost { // TypeProto virtual std::unique_ptr TypeProto__construct() = 0; virtual void TypeProto__CopyFrom(ONNX_NAMESPACE::TypeProto* p, const ONNX_NAMESPACE::TypeProto* other) = 0; + virtual bool TypeProto__has_tensor_type(const ONNX_NAMESPACE::TypeProto* p) = 0; virtual const ONNX_NAMESPACE::TypeProto_Tensor& TypeProto__tensor_type(const ONNX_NAMESPACE::TypeProto* p) = 0; virtual ONNX_NAMESPACE::TypeProto_Tensor* TypeProto__mutable_tensor_type(ONNX_NAMESPACE::TypeProto* p) = 0; @@ -462,6 +493,7 @@ struct ProviderHost { virtual bool TensorProto__has_raw_data(const ONNX_NAMESPACE::TensorProto* p) = 0; virtual const std::string& TensorProto__raw_data(const ONNX_NAMESPACE::TensorProto* p) = 0; virtual std::string* TensorProto__mutable_raw_data(ONNX_NAMESPACE::TensorProto* p) = 0; + virtual bool TensorProto__has_data_type(const ONNX_NAMESPACE::TensorProto* p) = 0; virtual int32_t TensorProto__data_type(const ONNX_NAMESPACE::TensorProto* p) = 0; virtual void TensorProto__set_data_type(ONNX_NAMESPACE::TensorProto* p, int32_t type) = 0; virtual void TensorProto__CopyFrom(ONNX_NAMESPACE::TensorProto* p, const ONNX_NAMESPACE::TensorProto* other) = 0; @@ -495,6 +527,7 @@ struct ProviderHost { // TensorShapeProto_Dimensions virtual std::unique_ptr TensorShapeProto_Dimensions__begin(const ONNX_NAMESPACE::TensorShapeProto_Dimensions* p) = 0; virtual std::unique_ptr TensorShapeProto_Dimensions__end(const ONNX_NAMESPACE::TensorShapeProto_Dimensions* p) = 0; + virtual size_t TensorShapeProto_Dimensions__size(const ONNX_NAMESPACE::TensorShapeProto_Dimensions* p) = 0; // TensorShapeProto virtual int TensorShapeProto__dim_size(const ONNX_NAMESPACE::TensorShapeProto* p) = 0; @@ -823,6 +856,8 @@ struct ProviderHost { virtual const NodeAttributes& Node__GetAttributes(const Node* p) noexcept = 0; virtual void Node__AddAttribute(Node* p, const ::std::string& attr_name, const ONNX_NAMESPACE::GraphProto& value) = 0; + virtual void Node__AddAttribute(Node* p, const ::std::string& attr_name, const std::string& value) = 0; + virtual void Node__AddAttribute(Node* p, const ::std::string& attr_name, int64_t value) = 0; virtual size_t Node__GetInputEdgesCount(const Node* p) noexcept = 0; virtual size_t Node__GetOutputEdgesCount(const Node* p) noexcept = 0; @@ -842,6 +877,14 @@ struct ProviderHost { virtual const std::unordered_map>& Node__GetAttributeNameToMutableSubgraphMap(Node* p) = 0; virtual std::unordered_map> Node__GetAttributeNameToSubgraphMap(const Node* p) const = 0; + // Node_EdgeEnd + virtual std::unique_ptr Node_EdgeEnd__construct(const Node& node, int src_arg_index, int dst_arg_index) = 0; + virtual void Node_EdgeEnd__operator_delete(Node_EdgeEnd* p) noexcept = 0; + + virtual const Node& Node_EdgeEnd__GetNode(const Node_EdgeEnd* p) = 0; + virtual int Node_EdgeEnd__GetSrcArgIndex(const Node_EdgeEnd* p) = 0; + virtual int Node_EdgeEnd__GetDstArgIndex(const Node_EdgeEnd* p) = 0; + // NodeArg virtual const std::string& NodeArg__Name(const NodeArg* p) noexcept = 0; virtual const ONNX_NAMESPACE::TensorShapeProto* NodeArg__Shape(const NodeArg* p) = 0; @@ -872,6 +915,12 @@ struct ProviderHost { virtual void NodeAttributes__reserve(NodeAttributes* p, size_t size) = 0; // NodeUnit + virtual std::unique_ptr NodeUnit__construct(gsl::span dq_nodes, const Node& target_node, + gsl::span q_nodes, uint8_t unit_type, + gsl::span inputs, gsl::span outputs, + size_t input_edge_count, gsl::span output_edges) = 0; + virtual void NodeUnit__operator_delete(NodeUnit* p) noexcept = 0; + virtual int NodeUnit__UnitType(const NodeUnit* p) noexcept = 0; virtual const std::vector& NodeUnit__Inputs(const NodeUnit* p) noexcept = 0; @@ -897,10 +946,29 @@ struct ProviderHost { virtual std::pair>, std::unordered_map> QDQ__GetAllNodeUnits(const GraphViewer* graph_viewer, const logging::Logger& logger) = 0; + // Partitioning utils + virtual std::vector> + Utils__CreateSupportedPartitions(const GraphViewer& graph_viewer, + const std::unordered_set& supported_nodes, + const std::unordered_set& stop_ops, + const std::function& generate_metadef_name, + const std::string& execution_provider_name, + const std::string& execution_provider_type, + const std::unordered_map* node_unit_map, + bool drop_constant_initializers) = 0; + + virtual std::unique_ptr + Utils__MakeComputeCapability(const GraphViewer& graph_viewer, + const std::vector& group, + const std::function& generate_metadef_name, + const std::string& execution_provider_name, + bool drop_constant_initializers) = 0; // Model virtual std::unique_ptr Model__construct(ONNX_NAMESPACE::ModelProto&& model_proto, const PathString& model_path, const IOnnxRuntimeOpSchemaRegistryList* local_registries, const logging::Logger& logger) = 0; + virtual std::unique_ptr Model__construct(const std::string& graph_name, bool is_onnx_domain_only, + const logging::Logger& logger) = 0; virtual void Model__operator_delete(Model* p) = 0; virtual Graph& Model__MainGraph(Model* p) = 0; virtual std::unique_ptr Model__ToProto(Model* p) = 0; @@ -974,6 +1042,7 @@ struct ProviderHost { virtual const std::string& GraphViewer__Name(const GraphViewer* p) noexcept = 0; virtual const std::filesystem::path& GraphViewer__ModelPath(const GraphViewer* p) noexcept = 0; + virtual const ConstGraphNodes& GraphViewer__Nodes(const GraphViewer* p) noexcept = 0; virtual const Node* GraphViewer__GetNode(const GraphViewer* p, NodeIndex node_index) = 0; virtual const NodeArg* GraphViewer__GetNodeArg(const GraphViewer* p, const std::string& name) = 0; @@ -989,6 +1058,7 @@ struct ProviderHost { virtual const std::vector& GraphViewer__GetInputs(const GraphViewer* p) noexcept = 0; virtual const std::vector& GraphViewer__GetOutputs(const GraphViewer* p) noexcept = 0; + virtual bool GraphViewer__NodeProducesGraphOutput(const GraphViewer* p, const Node& node) = 0; virtual const std::unordered_set& GraphViewer__GetValueInfo(const GraphViewer* p) noexcept = 0; virtual const InitializedTensorSet& GraphViewer__GetAllInitializedTensors(const GraphViewer* p) = 0; @@ -1007,6 +1077,13 @@ struct ProviderHost { virtual const Node* GraphViewer__GetProducerNode(const GraphViewer* p, const std::string& node_arg_name) const = 0; virtual IOnnxRuntimeOpSchemaCollectionPtr GraphViewer__GetSchemaRegistry(const GraphViewer* p) const = 0; + // ConstGraphNodes + virtual std::unique_ptr ConstGraphNodes__begin(const ConstGraphNodes* p) = 0; + virtual std::unique_ptr ConstGraphNodes__end(const ConstGraphNodes* p) = 0; + virtual std::unique_ptr ConstGraphNodes__cbegin(const ConstGraphNodes* p) = 0; + virtual std::unique_ptr ConstGraphNodes__cend(const ConstGraphNodes* p) = 0; + virtual bool ConstGraphNodes__empty(const ConstGraphNodes* p) noexcept = 0; + // OpKernel virtual const Node& OpKernel__Node(const OpKernel* p) = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index 76b6d8063fd66..309432592d97c 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -23,6 +23,9 @@ namespace logging { struct Logger final { bool OutputIsEnabled(Severity severity, DataType data_type) const noexcept { return g_host->logging__Logger__OutputIsEnabled(this, severity, data_type); } + Severity GetSeverity() const noexcept { + return g_host->logging__Logger__GetSeverity(this); + } PROVIDER_DISALLOW_ALL(Logger) }; @@ -35,15 +38,34 @@ struct LoggingManager final { struct Capture final { static std::unique_ptr Create(const Logger& logger, logging::Severity severity, const char* category, - logging::DataType dataType, const CodeLocation& location) { return g_host->logging__Capture__construct(logger, severity, category, dataType, location); } + logging::DataType data_type, const CodeLocation& location) { + return g_host->logging__Capture__construct(logger, severity, category, data_type, location); + } static void operator delete(void* p) { g_host->logging__Capture__operator_delete(reinterpret_cast(p)); } std::ostream& Stream() noexcept { return g_host->logging__Capture__Stream(this); } + void ProcessPrintf(const char* format, va_list args) { g_host->logging__Capture__ProcessPrintf(this, format, args); } Capture() = delete; Capture(const Capture&) = delete; void operator=(const Capture&) = delete; }; + +#if defined(_WIN32) +struct EtwRegistrationManager final { + using EtwInternalCallback = EtwRegistrationManager_EtwInternalCallback; + static EtwRegistrationManager& Instance() { return g_host->logging__EtwRegistrationManager__Instance(); } + static bool SupportsETW() { return g_host->logging__EtwRegistrationManager__SupportsETW(); } + Severity MapLevelToSeverity() { return g_host->logging__EtwRegistrationManager__MapLevelToSeverity(this); } + void RegisterInternalCallback(const EtwInternalCallback& callback) { + g_host->logging__EtwRegistrationManager__RegisterInternalCallback(this, callback); + } + void UnregisterInternalCallback(const EtwInternalCallback& callback) { + g_host->logging__EtwRegistrationManager__UnregisterInternalCallback(this, callback); + } +}; +#endif // defined(_WIN32) + } // namespace logging } // namespace onnxruntime @@ -234,6 +256,7 @@ struct TensorProto final { const std::string& raw_data() const { return g_host->TensorProto__raw_data(this); } std::string* mutable_raw_data() { return g_host->TensorProto__mutable_raw_data(this); } + bool has_data_type() const { return g_host->TensorProto__has_data_type(this); } int32_t data_type() const { return g_host->TensorProto__data_type(this); } void set_data_type(int32_t type) { return g_host->TensorProto__set_data_type(this, type); } @@ -286,6 +309,7 @@ struct TensorShapeProto_Dimension final { struct TensorShapeProto_Dimensions final { IteratorHolder begin() const { return g_host->TensorShapeProto_Dimensions__begin(this); } IteratorHolder end() const { return g_host->TensorShapeProto_Dimensions__end(this); } + size_t size() const { return g_host->TensorShapeProto_Dimensions__size(this); } PROVIDER_DISALLOW_ALL(TensorShapeProto_Dimensions) }; @@ -305,6 +329,7 @@ struct TypeProto_Tensor final { bool has_shape() const { return g_host->TypeProto_Tensor__has_shape(this); } const TensorShapeProto& shape() const { return g_host->TypeProto_Tensor__shape(this); } TensorShapeProto* mutable_shape() { return g_host->TypeProto_Tensor__mutable_shape(this); } + bool has_elem_type() const { return g_host->TypeProto_Tensor__has_elem_type(this); } int32_t elem_type() const { return g_host->TypeProto_Tensor__elem_type(this); } void set_elem_type(int32_t value) { g_host->TypeProto_Tensor__set_elem_type(this, value); } @@ -339,6 +364,7 @@ struct TypeProto_Sequence final { struct TypeProto final { static std::unique_ptr Create() { return g_host->TypeProto__construct(); } + bool has_tensor_type() const { return g_host->TypeProto__has_tensor_type(this); } const TypeProto_Tensor& tensor_type() const { return g_host->TypeProto__tensor_type(this); } TypeProto_Tensor* mutable_tensor_type() { return g_host->TypeProto__mutable_tensor_type(this); } @@ -475,6 +501,7 @@ namespace Utils { struct DataTypeUtils final { static const std::string* ToType(const ONNX_NAMESPACE::TypeProto& type_proto) { return g_host->Utils__DataTypeUtils__ToType(type_proto); } + static const std::string* ToType(const std::string& type_str) { return g_host->Utils__DataTypeUtils__ToType(type_str); } PROVIDER_DISALLOW_ALL(DataTypeUtils) }; @@ -770,6 +797,21 @@ struct Function final { PROVIDER_DISALLOW_ALL(Function) }; +struct Node_EdgeEnd final { + static std::unique_ptr Create(const Node& node, int src_arg_index, int dst_arg_index) { + return g_host->Node_EdgeEnd__construct(node, src_arg_index, dst_arg_index); + } + static void operator delete(void* p) { g_host->Node_EdgeEnd__operator_delete(reinterpret_cast(p)); } + + const Node& GetNode() const { return g_host->Node_EdgeEnd__GetNode(this); } + int GetSrcArgIndex() const { return g_host->Node_EdgeEnd__GetSrcArgIndex(this); } + int GetDstArgIndex() const { return g_host->Node_EdgeEnd__GetDstArgIndex(this); } + + Node_EdgeEnd() = delete; + Node_EdgeEnd(const Node_EdgeEnd&) = delete; + void operator=(const Node_EdgeEnd&) = delete; +}; + struct Node final { enum class Type { Primitive = 0, @@ -801,6 +843,12 @@ struct Node final { void AddAttribute(const ::std::string& attr_name, const ONNX_NAMESPACE::GraphProto& value) { g_host->Node__AddAttribute(this, attr_name, value); } + void AddAttribute(const std::string& attr_name, const std::string& value) { + g_host->Node__AddAttribute(this, attr_name, value); + } + void AddAttribute(const std::string& attr_name, int64_t value) { + g_host->Node__AddAttribute(this, attr_name, value); + } size_t GetInputEdgesCount() const noexcept { return g_host->Node__GetInputEdgesCount(this); } size_t GetOutputEdgesCount() const noexcept { return g_host->Node__GetOutputEdgesCount(this); } @@ -832,6 +880,7 @@ struct Node final { } void operator++() { impl_->operator++(); } + const Node_EdgeEnd& operator*() { return impl_->operator*(); } const Node__EdgeIterator* operator->() const { return impl_.get(); } std::unique_ptr impl_; @@ -906,6 +955,15 @@ struct NodeUnit final { QDQGroup, // The NodeUnit contain a QDQ group of nodes, such as "DQ->Sigmoid->Q" }; + static std::unique_ptr Create(gsl::span dq_nodes, const Node& target_node, + gsl::span q_nodes, Type unit_type, + gsl::span inputs, gsl::span outputs, + size_t input_edge_count, gsl::span output_edges) { + return g_host->NodeUnit__construct(dq_nodes, target_node, q_nodes, static_cast(unit_type), + inputs, outputs, input_edge_count, output_edges); + } + static void operator delete(void* p) { g_host->NodeUnit__operator_delete(reinterpret_cast(p)); } + Type UnitType() const noexcept { return static_cast(g_host->NodeUnit__UnitType(this)); } const std::vector& Inputs() const noexcept { return g_host->NodeUnit__Inputs(this); } @@ -932,6 +990,10 @@ struct NodeUnit final { // output. any Q nodes are hidden. Node::EdgeConstIterator OutputEdgesBegin() const { return g_host->NodeUnit__OutputEdgesBegin(this); } Node::EdgeConstIterator OutputEdgesEnd() const { return g_host->NodeUnit__OutputEdgesEnd(this); } + + NodeUnit() = delete; + NodeUnit(const NodeUnit&) = delete; + void operator=(const NodeUnit& v) = delete; }; struct ModelSavingOptions; @@ -941,6 +1003,9 @@ struct Model final { const IOnnxRuntimeOpSchemaRegistryList* local_registries, const logging::Logger& logger) { return g_host->Model__construct(std::move(model_proto), model_path, local_registries, logger); } + static std::unique_ptr Create(const std::string& graph_name, bool is_onnx_domain_only, const logging::Logger& logger) { + return g_host->Model__construct(graph_name, is_onnx_domain_only, logger); + } static void operator delete(void* p) { g_host->Model__operator_delete(reinterpret_cast(p)); } static Status Load(const PathString& file_path, /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) { return g_host->Model__Load(file_path, model_proto); } @@ -1041,6 +1106,7 @@ class GraphViewer final { const std::string& Name() const noexcept { return g_host->GraphViewer__Name(this); } const std::filesystem::path& ModelPath() const noexcept { return g_host->GraphViewer__ModelPath(this); } + const ConstGraphNodes& Nodes() const noexcept { return g_host->GraphViewer__Nodes(this); } const Node* GetNode(NodeIndex node_index) const { return g_host->GraphViewer__GetNode(this, node_index); } const NodeArg* GetNodeArg(const std::string& name) const { return g_host->GraphViewer__GetNodeArg(this, name); } @@ -1058,6 +1124,9 @@ class GraphViewer final { const std::vector& GetInputs() const noexcept { return g_host->GraphViewer__GetInputs(this); } const std::vector& GetOutputs() const noexcept { return g_host->GraphViewer__GetOutputs(this); } + bool NodeProducesGraphOutput(const Node& node) const { + return g_host->GraphViewer__NodeProducesGraphOutput(this, node); + } const std::unordered_set& GetValueInfo() const noexcept { return g_host->GraphViewer__GetValueInfo(this); } const InitializedTensorSet& GetAllInitializedTensors() const noexcept { return g_host->GraphViewer__GetAllInitializedTensors(this); } @@ -1085,6 +1154,25 @@ class GraphViewer final { void operator=(const GraphViewer&) = delete; }; +struct ConstGraphNodes final { + IteratorHolder begin() const { + return g_host->ConstGraphNodes__begin(this); + } + IteratorHolder end() const { + return g_host->ConstGraphNodes__end(this); + } + IteratorHolder cbegin() const { + return g_host->ConstGraphNodes__cbegin(this); + } + IteratorHolder cend() const { + return g_host->ConstGraphNodes__cend(this); + } + + bool empty() const noexcept { return g_host->ConstGraphNodes__empty(this); } + + PROVIDER_DISALLOW_ALL(ConstGraphNodes) +}; + struct OpKernelContext final { template const T& RequiredInput(int index) const; diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index af39edae2074d..36ce529998836 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -37,7 +37,6 @@ #include "core/framework/model_metadef_id_generator.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" #include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" -#include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/onnxruntime_c_api.h" #include "core/common/string_helper.h" @@ -62,6 +61,10 @@ #include "orttraining/core/framework/distributed_run_context.h" #endif +#ifdef _WIN32 +#include "core/platform/windows/logging/etw_sink.h" +#endif + namespace ONNX_NAMESPACE { // We use these names in the provider API because we don't have the protobuf definitions of the RepeatedField* types using int64s = google::protobuf::RepeatedField; @@ -76,11 +79,18 @@ using FunctionProtos = google::protobuf::RepeatedPtrField; namespace onnxruntime { using IndexedSubGraph_MetaDef = IndexedSubGraph::MetaDef; using IndexedSubGraph_SourceOfSchema = IndexedSubGraph::SourceOfSchema; +using Node_EdgeEnd = Node::EdgeEnd; +#ifdef _WIN32 +namespace logging { +using EtwRegistrationManager_EtwInternalCallback = EtwRegistrationManager::EtwInternalCallback; +} +#endif } // namespace onnxruntime #include "core/common/cpuid_info.h" #include "core/common/logging/logging.h" #include "core/providers/shared_library/provider_interfaces.h" +#include "core/providers/partitioning_utils.h" #include "core/providers/cuda/cuda_provider_factory_creator.h" #include "core/providers/cann/cann_provider_factory_creator.h" @@ -90,6 +100,7 @@ using IndexedSubGraph_SourceOfSchema = IndexedSubGraph::SourceOfSchema; #include "core/providers/openvino/openvino_provider_factory_creator.h" #include "core/providers/tensorrt/tensorrt_provider_factory_creator.h" #include "core/providers/vitisai/vitisai_provider_factory_creator.h" +#include "core/providers/qnn/qnn_provider_factory_creator.h" #include "core/providers/cuda/cuda_provider_factory.h" #include "core/providers/cann/cann_provider_factory.h" @@ -181,6 +192,7 @@ struct Node__EdgeIterator_Impl : Node__EdgeIterator { bool operator!=(const Node__EdgeIterator& p) const override { return v_ != static_cast(&p)->v_; } void operator++() override { v_.operator++(); } + const Node_EdgeEnd& operator*() const override { return v_.operator*(); } const Node& GetNode() const override { return v_->GetNode(); } int GetSrcArgIndex() const override { return v_->GetSrcArgIndex(); } int GetDstArgIndex() const override { return v_->GetDstArgIndex(); } @@ -188,6 +200,18 @@ struct Node__EdgeIterator_Impl : Node__EdgeIterator { Node::EdgeConstIterator v_; }; +struct ConstGraphNodes_Iterator_Impl : ConstGraphNodes_Iterator { + ConstGraphNodes_Iterator_Impl(ConstGraphNodes::ConstNodeIterator&& v) : v_{std::move(v)} {} + + bool operator!=(const ConstGraphNodes_Iterator& other) const override { + return v_ != static_cast(&other)->v_; + } + void operator++() override { v_.operator++(); } + const Node& operator*() override { return *v_; } + + ConstGraphNodes::ConstNodeIterator v_; +}; + #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) common::Status LoadDynamicLibraryFromProvider(onnxruntime::PathString library_name) { const auto& platform_env = onnxruntime::Env::Default(); @@ -367,22 +391,57 @@ struct ProviderHostImpl : ProviderHost { // logging::Logger (wrapped) bool logging__Logger__OutputIsEnabled(const logging::Logger* p, logging::Severity severity, logging::DataType data_type) override { return p->OutputIsEnabled(severity, data_type); } + logging::Severity logging__Logger__GetSeverity(const logging::Logger* p) override { + return p->GetSeverity(); + } // logging::LoggingManager (wrapped) const logging::Logger& logging__LoggingManager__DefaultLogger() override { return logging::LoggingManager::DefaultLogger(); } // logging::Capture (wrapped) - std::unique_ptr logging__Capture__construct(const logging::Logger& logger, logging::Severity severity, const char* category, logging::DataType dataType, const CodeLocation& location) override { - return std::make_unique(logger, severity, category, dataType, location); + std::unique_ptr logging__Capture__construct(const logging::Logger& logger, + logging::Severity severity, const char* category, + logging::DataType data_type, + const CodeLocation& location) override { + return std::make_unique(logger, severity, category, data_type, location); } void logging__Capture__operator_delete(logging::Capture* p) noexcept override { delete p; } std::ostream& logging__Capture__Stream(logging::Capture* p) noexcept override { return p->Stream(); } + void logging__Capture__ProcessPrintf(logging::Capture* p, const char* format, va_list args) override { + p->ProcessPrintf(format, args); + } + +#if defined(_WIN32) + // logging::EtwRegistrationManager + logging::EtwRegistrationManager& logging__EtwRegistrationManager__Instance() override { + return logging::EtwRegistrationManager::Instance(); + } + bool logging__EtwRegistrationManager__SupportsETW() override { + return logging::EtwRegistrationManager::SupportsETW(); + } + logging::Severity logging__EtwRegistrationManager__MapLevelToSeverity(logging::EtwRegistrationManager* p) override { + return p->MapLevelToSeverity(); + } + void logging__EtwRegistrationManager__RegisterInternalCallback( + logging::EtwRegistrationManager* p, + const logging::EtwRegistrationManager_EtwInternalCallback& callback) override { + p->RegisterInternalCallback(callback); + } + void logging__EtwRegistrationManager__UnregisterInternalCallback( + logging::EtwRegistrationManager* p, + const logging::EtwRegistrationManager_EtwInternalCallback& callback) override { + p->UnregisterInternalCallback(callback); + } +#endif // defined(_WIN32) // Env Env& Env__Default() override { return Env::Default(); } // Utils::DataTypeUtils (wrapped) const std::string* Utils__DataTypeUtils__ToType(const ONNX_NAMESPACE::TypeProto& type_proto) override { return ONNX_NAMESPACE::Utils::DataTypeUtils::ToType(type_proto); } + const std::string* Utils__DataTypeUtils__ToType(const std::string& type_str) override { + return ONNX_NAMESPACE::Utils::DataTypeUtils::ToType(type_str); + } // int64s (wrapped) int int64s__size(const ONNX_NAMESPACE::int64s* p) override { return p->size(); } @@ -424,6 +483,7 @@ struct ProviderHostImpl : ProviderHost { bool TypeProto_Tensor__has_shape(const ONNX_NAMESPACE::TypeProto_Tensor* p) override { return p->has_shape(); } const ONNX_NAMESPACE::TensorShapeProto& TypeProto_Tensor__shape(const ONNX_NAMESPACE::TypeProto_Tensor* p) override { return p->shape(); } ONNX_NAMESPACE::TensorShapeProto* TypeProto_Tensor__mutable_shape(ONNX_NAMESPACE::TypeProto_Tensor* p) override { return p->mutable_shape(); } + bool TypeProto_Tensor__has_elem_type(const ONNX_NAMESPACE::TypeProto_Tensor* p) override { return p->has_elem_type(); } int32_t TypeProto_Tensor__elem_type(const ONNX_NAMESPACE::TypeProto_Tensor* p) override { return p->elem_type(); } void TypeProto_Tensor__set_elem_type(ONNX_NAMESPACE::TypeProto_Tensor* p, int32_t value) override { p->set_elem_type(value); }; @@ -444,6 +504,7 @@ struct ProviderHostImpl : ProviderHost { // TypeProto (wrapped) std::unique_ptr TypeProto__construct() override { return std::make_unique(); } void TypeProto__CopyFrom(ONNX_NAMESPACE::TypeProto* p, const ONNX_NAMESPACE::TypeProto* other) override { p->CopyFrom(*other); } + bool TypeProto__has_tensor_type(const ONNX_NAMESPACE::TypeProto* p) override { return p->has_tensor_type(); } const ONNX_NAMESPACE::TypeProto_Tensor& TypeProto__tensor_type(const ONNX_NAMESPACE::TypeProto* p) override { return p->tensor_type(); } ONNX_NAMESPACE::TypeProto_Tensor* TypeProto__mutable_tensor_type(ONNX_NAMESPACE::TypeProto* p) override { return p->mutable_tensor_type(); } int TypeProto__value_case(const ONNX_NAMESPACE::TypeProto* p) override { return p->value_case(); } @@ -572,6 +633,7 @@ struct ProviderHostImpl : ProviderHost { const std::string& TensorProto__raw_data(const ONNX_NAMESPACE::TensorProto* p) override { return p->raw_data(); } std::string* TensorProto__mutable_raw_data(ONNX_NAMESPACE::TensorProto* p) override { return p->mutable_raw_data(); } + bool TensorProto__has_data_type(const ONNX_NAMESPACE::TensorProto* p) override { return p->has_data_type(); } int32_t TensorProto__data_type(const ONNX_NAMESPACE::TensorProto* p) override { return p->data_type(); } void TensorProto__set_data_type(ONNX_NAMESPACE::TensorProto* p, int32_t type) override { p->set_data_type(type); } @@ -610,6 +672,10 @@ struct ProviderHostImpl : ProviderHost { return std::make_unique(p->end()); } + size_t TensorShapeProto_Dimensions__size(const ONNX_NAMESPACE::TensorShapeProto_Dimensions* p) override { + return p->size(); + } + // TensorShapeProto (wrapped) int TensorShapeProto__dim_size(const ONNX_NAMESPACE::TensorShapeProto* p) override { return p->dim_size(); } const ONNX_NAMESPACE::TensorShapeProto_Dimensions& TensorShapeProto__dim(const ONNX_NAMESPACE::TensorShapeProto* p) override { return p->dim(); } @@ -960,6 +1026,12 @@ struct ProviderHostImpl : ProviderHost { void Node__AddAttribute(Node* p, const ::std::string& attr_name, const ONNX_NAMESPACE::GraphProto& value) override { p->AddAttribute(attr_name, value); } + void Node__AddAttribute(Node* p, const ::std::string& attr_name, const std::string& value) override { + p->AddAttribute(attr_name, value); + } + void Node__AddAttribute(Node* p, const ::std::string& attr_name, int64_t value) override { + p->AddAttribute(attr_name, value); + } size_t Node__GetInputEdgesCount(const Node* p) noexcept override { return p->GetInputEdgesCount(); } size_t Node__GetOutputEdgesCount(const Node* p) noexcept override { return p->GetOutputEdgesCount(); } @@ -982,6 +1054,16 @@ struct ProviderHostImpl : ProviderHost { std::unordered_map> Node__GetAttributeNameToSubgraphMap(const Node* p) const override { return p->GetAttributeNameToSubgraphMap(); } int Node__NodeType(const Node* p) const noexcept override { return int(p->NodeType()); } + // Node_EdgeEnd (wrapped). Maps to Node::EdgeEnd struct. + std::unique_ptr Node_EdgeEnd__construct(const Node& node, int src_arg_index, int dst_arg_index) override { + return std::make_unique(node, src_arg_index, dst_arg_index); + } + void Node_EdgeEnd__operator_delete(Node_EdgeEnd* p) noexcept override { delete p; } + + const Node& Node_EdgeEnd__GetNode(const Node_EdgeEnd* p) override { return p->GetNode(); } + int Node_EdgeEnd__GetSrcArgIndex(const Node_EdgeEnd* p) override { return p->GetSrcArgIndex(); } + int Node_EdgeEnd__GetDstArgIndex(const Node_EdgeEnd* p) override { return p->GetDstArgIndex(); } + // NodeArg (wrapped) const std::string& NodeArg__Name(const NodeArg* p) noexcept override { return p->Name(); } const ONNX_NAMESPACE::TensorShapeProto* NodeArg__Shape(const NodeArg* p) override { return p->Shape(); } @@ -1017,6 +1099,24 @@ struct ProviderHostImpl : ProviderHost { void NodeAttributes__reserve(NodeAttributes* p, size_t size) override { p->reserve(size); } // NodeUnit (wrapped) + std::unique_ptr NodeUnit__construct(gsl::span dq_nodes, + const Node& target_node, + gsl::span q_nodes, + uint8_t unit_type, + gsl::span inputs, + gsl::span outputs, + size_t input_edge_count, + gsl::span output_edges) override { + Node::EdgeSet output_edge_set; + for (const Node_EdgeEnd* edge_end : output_edges) { + output_edge_set.insert(*edge_end); + } + + return std::make_unique(dq_nodes, target_node, q_nodes, static_cast(unit_type), + inputs, outputs, input_edge_count, output_edge_set); + } + void NodeUnit__operator_delete(NodeUnit* p) noexcept override { delete p; } + int NodeUnit__UnitType(const NodeUnit* p) noexcept override { return static_cast(p->UnitType()); } const std::vector& NodeUnit__Inputs(const NodeUnit* p) noexcept override { @@ -1064,12 +1164,46 @@ struct ProviderHostImpl : ProviderHost { return QDQ::GetAllNodeUnits(*graph_viewer, logger); } + // Partitioning utils + std::vector> + Utils__CreateSupportedPartitions(const GraphViewer& graph_viewer, + const std::unordered_set& supported_nodes, + const std::unordered_set& stop_ops, + const utils::GenerateMetadefNameFn& generate_metadef_name, + const std::string& execution_provider_name, + const std::string& execution_provider_type, + const std::unordered_map* node_unit_map, + bool drop_constant_initializers) override { + return onnxruntime::utils::CreateSupportedPartitions(graph_viewer, + supported_nodes, + stop_ops, + generate_metadef_name, + execution_provider_name, + execution_provider_type, + node_unit_map, + drop_constant_initializers); + } + + std::unique_ptr + Utils__MakeComputeCapability(const GraphViewer& graph_viewer, + const std::vector& group, + const std::function& generate_metadef_name, + const std::string& execution_provider_name, + bool drop_constant_initializers) override { + return onnxruntime::utils::MakeComputeCapability(graph_viewer, group, generate_metadef_name, + execution_provider_name, drop_constant_initializers); + } + // Model (wrapped) std::unique_ptr Model__construct(ONNX_NAMESPACE::ModelProto&& model_proto, const PathString& model_path, const IOnnxRuntimeOpSchemaRegistryList* local_registries, const logging::Logger& logger) override { return std::make_unique(model_proto, model_path, local_registries, logger); } + std::unique_ptr Model__construct(const std::string& graph_name, bool is_onnx_domain_only, + const logging::Logger& logger) override { + return std::make_unique(graph_name, is_onnx_domain_only, logger); + } void Model__operator_delete(Model* p) override { delete p; } Graph& Model__MainGraph(Model* p) override { return p->MainGraph(); } std::unique_ptr Model__ToProto(Model* p) override { return std::make_unique(p->ToProto()); } @@ -1179,6 +1313,7 @@ struct ProviderHostImpl : ProviderHost { const std::string& GraphViewer__Name(const GraphViewer* p) noexcept override { return p->Name(); } const std::filesystem::path& GraphViewer__ModelPath(const GraphViewer* p) noexcept override { return p->ModelPath(); } + const ConstGraphNodes& GraphViewer__Nodes(const GraphViewer* p) noexcept override { return p->Nodes(); } const Node* GraphViewer__GetNode(const GraphViewer* p, NodeIndex node_index) override { return p->GetNode(node_index); } const NodeArg* GraphViewer__GetNodeArg(const GraphViewer* p, const std::string& name) override { return p->GetNodeArg(name); } @@ -1196,6 +1331,9 @@ struct ProviderHostImpl : ProviderHost { const std::vector& GraphViewer__GetInputs(const GraphViewer* p) noexcept override { return p->GetInputs(); } const std::vector& GraphViewer__GetOutputs(const GraphViewer* p) noexcept override { return p->GetOutputs(); } + bool GraphViewer__NodeProducesGraphOutput(const GraphViewer* p, const Node& node) override { + return p->NodeProducesGraphOutput(node); + } const std::unordered_set& GraphViewer__GetValueInfo(const GraphViewer* p) noexcept override { return p->GetValueInfo(); } const InitializedTensorSet& GraphViewer__GetAllInitializedTensors(const GraphViewer* p) override { return p->GetAllInitializedTensors(); } @@ -1224,6 +1362,21 @@ struct ProviderHostImpl : ProviderHost { const Node* GraphViewer__GetProducerNode(const GraphViewer* p, const std::string& node_arg_name) const override { return p->GetProducerNode(node_arg_name); } IOnnxRuntimeOpSchemaCollectionPtr GraphViewer__GetSchemaRegistry(const GraphViewer* p) const override { return p->GetSchemaRegistry(); } + // ConstGraphNodes + std::unique_ptr ConstGraphNodes__begin(const ConstGraphNodes* p) override { + return std::make_unique(p->begin()); + } + std::unique_ptr ConstGraphNodes__end(const ConstGraphNodes* p) override { + return std::make_unique(p->end()); + } + std::unique_ptr ConstGraphNodes__cbegin(const ConstGraphNodes* p) override { + return std::make_unique(p->cbegin()); + } + std::unique_ptr ConstGraphNodes__cend(const ConstGraphNodes* p) override { + return std::make_unique(p->cend()); + } + bool ConstGraphNodes__empty(const ConstGraphNodes* p) noexcept override { return p->empty(); } + // OpKernel (direct) const Node& OpKernel__Node(const OpKernel* p) override { return p->OpKernel::Node(); } @@ -1651,6 +1804,14 @@ static ProviderLibrary s_library_tensorrt(LIBRARY_PREFIX ORT_TSTR("onnxruntime_p ); static ProviderLibrary s_library_migraphx(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_migraphx") LIBRARY_EXTENSION); +// QNN EP can be built either as a static library or a shared library. Can safely define s_library_qnn even if static. +static ProviderLibrary s_library_qnn(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_qnn") LIBRARY_EXTENSION +#ifndef _WIN32 + , + false /* unload - On Linux if we unload the qnn shared provider we crash */ +#endif +); + void UnloadSharedProviders() { s_library_dnnl.Unload(); s_library_vitisai.Unload(); @@ -1662,6 +1823,7 @@ void UnloadSharedProviders() { s_library_rocm.Unload(); s_library_shared.Unload(); s_library_migraphx.Unload(); + s_library_qnn.Unload(); } // Used by test code @@ -1832,6 +1994,20 @@ ProviderOptions OrtOpenVINOProviderOptionsToOrtOpenVINOProviderOptionsV2(const O return ov_options_converted_map; } +#if !BUILD_QNN_EP_STATIC_LIB +std::shared_ptr QNNProviderFactoryCreator::Create(const ProviderOptions& provider_options_map, + const SessionOptions* session_options) { + const ConfigOptions* config_options = nullptr; + if (session_options != nullptr) { + config_options = &session_options->config_options; + } + + std::array configs_array = {&provider_options_map, config_options}; + const void* arg = reinterpret_cast(&configs_array); + return s_library_qnn.Get().CreateExecutionProviderFactory(arg); +} +#endif // !BUILD_QNN_EP_STATIC_LIB + std::shared_ptr OpenVINOProviderFactoryCreator::Create( const ProviderOptions* provider_options_map, const SessionOptions* session_options) { // Append session options applicable for EP to EP Provider options. diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index a3f0ed55b83f2..38fde332ca992 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -7,7 +7,6 @@ #include "core/session/onnxruntime_cxx_api.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/inference_session.h" -#include "core/providers/shared/utils/utils.h" #include "test/providers/qnn/qnn_test_utils.h" @@ -25,6 +24,24 @@ namespace test { #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +static int64_t GetNodeAttr(const Node& node, const std::string& attr_name, int64_t default_val) { + const auto& attributes = node.GetAttributes(); + if (auto entry = attributes.find(attr_name); entry != attributes.end()) { + return entry->second.i(); + } + + return default_val; +} + +static const std::string& GetNodeAttr(const Node& node, const std::string& attr_name, const std::string& default_val) { + const auto& attributes = node.GetAttributes(); + if (auto entry = attributes.find(attr_name); entry != attributes.end()) { + return entry->second.s(); + } + + return default_val; +} + // Create a model with FusedMatMul + Add (quantized) // input1 -> Add -> Q -> DQ \ // FusedMatMul -> Q -> DQ -> output @@ -873,10 +890,9 @@ static void GetLastContextBinaryFileName(const std::string last_onnx_ctx_file, auto& ctx_graph = ctx_model->MainGraph(); for (auto& node : ctx_graph.Nodes()) { if (node.OpType() == "EPContext") { - NodeAttrHelper node_helper(node); - int64_t is_main_context = node_helper.Get("main_context", static_cast(0)); + int64_t is_main_context = GetNodeAttr(node, "main_context", static_cast(0)); if (1 == is_main_context) { - last_ctx_bin_file = node_helper.Get("ep_cache_context", ""); + last_ctx_bin_file = GetNodeAttr(node, "ep_cache_context", ""); return; } } @@ -899,10 +915,9 @@ static void UpdateEpContextModel(const std::vector& ep_ctx_files, for (auto& node : ctx_graph.Nodes()) { if (node.OpType() == "EPContext") { - NodeAttrHelper node_helper(node); - int64_t is_main_context = node_helper.Get("main_context", static_cast(0)); + int64_t is_main_context = GetNodeAttr(node, "main_context", static_cast(0)); if (1 == is_main_context) { - std::string old_qnn_ctx_binary_file_name = node_helper.Get("ep_cache_context", ""); + std::string old_qnn_ctx_binary_file_name = GetNodeAttr(node, "ep_cache_context", ""); auto file_path = path.replace_filename(old_qnn_ctx_binary_file_name); std::remove(file_path.string().c_str()); node.ClearAttribute("ep_cache_context"); diff --git a/onnxruntime/test/qnn_ctx_gen/main.cc b/onnxruntime/test/qnn_ctx_gen/main.cc index 3be0bd253c8a4..bb5007b40b072 100644 --- a/onnxruntime/test/qnn_ctx_gen/main.cc +++ b/onnxruntime/test/qnn_ctx_gen/main.cc @@ -16,7 +16,6 @@ #include "core/common/logging/sinks/clog_sink.h" #include "core/graph/model.h" -#include "core/providers/shared/utils/utils.h" #include "core/session/environment.h" #include "core/common/logging/logging.h" @@ -31,6 +30,24 @@ static void CheckStatus(const Status& status) { } } +static int64_t GetNodeAttr(const Node& node, const std::string& attr_name, int64_t default_val) { + const auto& attributes = node.GetAttributes(); + if (auto entry = attributes.find(attr_name); entry != attributes.end()) { + return entry->second.i(); + } + + return default_val; +} + +static const std::string& GetNodeAttr(const Node& node, const std::string& attr_name, const std::string& default_val) { + const auto& attributes = node.GetAttributes(); + if (auto entry = attributes.find(attr_name); entry != attributes.end()) { + return entry->second.s(); + } + + return default_val; +} + // from the last context cache Onnx model, find the EPContext node with main_context=1, // and get the QNN context binary file name, this context binary contains all graphs from all Onnx models // get the max spill fill buffer size @@ -44,11 +61,10 @@ static void GetLastContextBinaryFileName(const std::basic_string last auto& ctx_graph = ctx_model->MainGraph(); for (auto& node : ctx_graph.Nodes()) { if (node.OpType() == "EPContext") { - NodeAttrHelper node_helper(node); - int64_t is_main_context = node_helper.Get("main_context", static_cast(0)); - max_size = node_helper.Get("max_size", static_cast(0)); + int64_t is_main_context = GetNodeAttr(node, "main_context", static_cast(0)); + max_size = GetNodeAttr(node, "max_size", static_cast(0)); if (1 == is_main_context) { - last_ctx_bin_file = node_helper.Get("ep_cache_context", ""); + last_ctx_bin_file = GetNodeAttr(node, "ep_cache_context", ""); return; } } @@ -72,10 +88,9 @@ static void UpdateEpContextModel(const std::vector> for (auto& node : ctx_graph.Nodes()) { if (node.OpType() == "EPContext") { - NodeAttrHelper node_helper(node); - int64_t is_main_context = node_helper.Get("main_context", static_cast(0)); + int64_t is_main_context = GetNodeAttr(node, "main_context", static_cast(0)); if (1 == is_main_context) { - std::string old_qnn_ctx_binary_file_name = node_helper.Get("ep_cache_context", ""); + std::string old_qnn_ctx_binary_file_name = GetNodeAttr(node, "ep_cache_context", ""); auto file_path = path.replace_filename(old_qnn_ctx_binary_file_name); std::remove(file_path.string().c_str()); node.ClearAttribute("ep_cache_context"); diff --git a/setup.py b/setup.py index c1580eeb9e8f9..5c464eec537ec 100644 --- a/setup.py +++ b/setup.py @@ -311,17 +311,20 @@ def finalize_options(self): providers_tensorrt_or_migraphx = "onnxruntime_providers_" + ("migraphx" if is_migraphx else "tensorrt") providers_openvino = "onnxruntime_providers_openvino" providers_cann = "onnxruntime_providers_cann" +providers_qnn = "onnxruntime_providers_qnn" if platform.system() == "Linux": providers_cuda_or_rocm = "lib" + providers_cuda_or_rocm + ".so" providers_tensorrt_or_migraphx = "lib" + providers_tensorrt_or_migraphx + ".so" providers_openvino = "lib" + providers_openvino + ".so" providers_cann = "lib" + providers_cann + ".so" + providers_qnn = "lib" + providers_qnn + ".so" elif platform.system() == "Windows": providers_cuda_or_rocm = providers_cuda_or_rocm + ".dll" providers_tensorrt_or_migraphx = providers_tensorrt_or_migraphx + ".dll" providers_openvino = providers_openvino + ".dll" providers_cann = providers_cann + ".dll" + providers_qnn = providers_qnn + ".dll" # Additional binaries dl_libs = [] @@ -341,8 +344,9 @@ def finalize_options(self): dl_libs.append(providers_cuda_or_rocm) dl_libs.append(providers_tensorrt_or_migraphx) dl_libs.append(providers_cann) + dl_libs.append(providers_qnn) dl_libs.append("libonnxruntime.so*") - # DNNL, TensorRT & OpenVINO EPs are built as shared libs + # DNNL, TensorRT, OpenVINO, and QNN EPs are built as shared libs libs.extend(["libonnxruntime_providers_shared.so"]) libs.extend(["libonnxruntime_providers_dnnl.so"]) libs.extend(["libonnxruntime_providers_openvino.so"]) @@ -350,6 +354,7 @@ def finalize_options(self): libs.append(providers_cuda_or_rocm) libs.append(providers_tensorrt_or_migraphx) libs.append(providers_cann) + libs.append(providers_qnn) # QNN qnn_deps = [ "libQnnCpu.so", @@ -388,13 +393,14 @@ def finalize_options(self): providers_cann, "onnxruntime.dll", ] - # DNNL, TensorRT & OpenVINO EPs are built as shared libs + # DNNL, TensorRT, OpenVINO, and QNN EPs are built as shared libs libs.extend(["onnxruntime_providers_shared.dll"]) libs.extend(["onnxruntime_providers_dnnl.dll"]) libs.extend(["onnxruntime_providers_tensorrt.dll"]) libs.extend(["onnxruntime_providers_openvino.dll"]) libs.extend(["onnxruntime_providers_cuda.dll"]) libs.extend(["onnxruntime_providers_vitisai.dll"]) + libs.extend(["onnxruntime_providers_qnn.dll"]) # DirectML Libs libs.extend(["DirectML.dll"]) # QNN V68/V73 dependencies diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 53dcdc6e0c6fa..fc84436e9a879 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -128,6 +128,17 @@ def invalid_hetero_build(): return device_read +def _qnn_verify_library_kind(library_kind): + choices = ["shared_lib", "static_lib"] + if library_kind not in choices: + print("\nYou have specified an invalid library kind for QNN EP.") + print(f"The invalid library kind was: {library_kind}") + print("Provide a library kind from the following options: ", choices) + print(f"Example: --use_qnn {choices[0]}") + sys.exit("Incorrect build configuration") + return library_kind + + def parse_arguments(): class Parser(argparse.ArgumentParser): # override argument file line parsing behavior - allow multiple arguments per line and handle quotes @@ -579,7 +590,14 @@ def convert_arg_line_to_args(self, arg_line): parser.add_argument("--use_jsep", action="store_true", help="Build with JavaScript kernels.") parser.add_argument("--use_webgpu", action="store_true", help="Build with WebGPU support.") parser.add_argument("--use_external_dawn", action="store_true", help="Treat Dawn as an external dependency.") - parser.add_argument("--use_qnn", action="store_true", help="Build with QNN support.") + parser.add_argument( + "--use_qnn", + nargs="?", + const="shared_lib", # If provide --use_qnn without an arg, defaults to a shared library. + type=_qnn_verify_library_kind, + help="Build with QNN support. Specify 'shared_lib' or 'static_lib' to build QNN EP " + "as a shared or static library, respectively.", + ) parser.add_argument("--qnn_home", help="Path to QNN SDK dir.") parser.add_argument("--use_rknpu", action="store_true", help="Build with RKNPU.") parser.add_argument("--use_preinstalled_eigen", action="store_true", help="Use pre-installed Eigen.") @@ -1303,6 +1321,11 @@ def generate_build_tree( raise BuildError("qnn_home=" + qnn_home + " not valid." + " qnn_home paths must be specified and valid.") cmake_args += ["-Donnxruntime_USE_QNN=ON"] + if args.use_qnn == "static_lib": + cmake_args += ["-Donnxruntime_BUILD_QNN_EP_STATIC_LIB=ON"] + if args.android and args.use_qnn != "static_lib": + raise BuildError("Only support Android + QNN builds with QNN EP built as a static library.") + if args.use_coreml: cmake_args += ["-Donnxruntime_USE_COREML=ON"] @@ -2346,6 +2369,8 @@ def build_nuget_package( elif use_rocm: package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.ROCm" elif use_qnn: + if use_qnn != "shared_lib": + raise BuildError("Currently NuGet packages with QNN require QNN EP to be built as a shared library.") execution_provider = "/p:ExecutionProvider=qnn" package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.QNN" elif any(map(lambda x: "OrtPackageId=" in x, msbuild_extra_options)): diff --git a/tools/ci_build/github/android/build_aar_package.py b/tools/ci_build/github/android/build_aar_package.py index 1b34b3d302e57..a01a9265575c2 100644 --- a/tools/ci_build/github/android/build_aar_package.py +++ b/tools/ci_build/github/android/build_aar_package.py @@ -75,11 +75,18 @@ def _parse_build_settings(args): return build_settings +def _is_qnn_android_build(build_settings): + for build_arg in build_settings["build_params"]: + if build_arg.startswith("--use_qnn"): + return True + return False + + def _build_aar(args): build_settings = _parse_build_settings(args) build_dir = os.path.abspath(args.build_dir) ops_config_path = os.path.abspath(args.include_ops_by_config) if args.include_ops_by_config else None - qnn_android_build = "--use_qnn" in build_settings["build_params"] + qnn_android_build = _is_qnn_android_build(build_settings) # Setup temp environment for building temp_env = os.environ.copy() diff --git a/tools/ci_build/github/android/default_qnn_aar_build_settings.json b/tools/ci_build/github/android/default_qnn_aar_build_settings.json index 599c108f830e7..588a914ebf680 100644 --- a/tools/ci_build/github/android/default_qnn_aar_build_settings.json +++ b/tools/ci_build/github/android/default_qnn_aar_build_settings.json @@ -11,7 +11,7 @@ "--cmake_generator=Ninja", "--build_java", "--build_shared_lib", - "--use_qnn", + "--use_qnn=static_lib", "--cmake_extra_defines=onnxruntime_BUILD_UNIT_TESTS=OFF", "--skip_tests" diff --git a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml index c3dbee336b69d..11b9177e96787 100644 --- a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml @@ -72,7 +72,8 @@ jobs: --android_abi=x86_64 \ --android_api=31 \ --parallel \ - --use_qnn \ + --build_shared_lib \ + --use_qnn static_lib \ --qnn_home $(QnnSDKRootDir) \ --cmake_generator=Ninja \ --skip_tests diff --git a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml index d3826d90f9073..b2362a377728b 100644 --- a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml @@ -41,7 +41,12 @@ jobs: timeoutInMinutes: 60 workspace: clean: all - + strategy: + matrix: + SHARED_LIB: + QnnLibKind: 'shared_lib' + STATIC_LIB: + QnnLibKind: 'static_lib' steps: - script: | ls -R /data/qnn_test_data @@ -65,7 +70,8 @@ jobs: --config Release \ --use_binskim_compliant_compile_flags \ --build_java \ - --use_qnn \ + --build_shared_lib \ + --use_qnn $(QnnLibKind) \ --qnn_home $(QnnSDKRootDir) \ --cmake_generator=Ninja \ --update --build --parallel @@ -77,7 +83,8 @@ jobs: --config Release \ --use_binskim_compliant_compile_flags \ --build_java \ - --use_qnn \ + --build_shared_lib \ + --use_qnn $(QnnLibKind) \ --qnn_home $(QnnSDKRootDir) \ --cmake_generator=Ninja \ --test diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml index e07f0afa6109c..da58b70be7f83 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml @@ -94,6 +94,7 @@ jobs: --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --cmake_generator "$(VSGenerator)" + --build_shared_lib --use_qnn --qnn_home $(QnnSDKRootDir) --enable_pybind diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml index 8cc647c2464f3..e64a184d8ebeb 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml @@ -92,6 +92,7 @@ jobs: --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --cmake_generator "$(VSGenerator)" + --build_shared_lib --use_qnn --qnn_home $(QnnSDKRootDir) --enable_pybind diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml index 466fee92d0d5e..a61bfc7706818 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml @@ -92,6 +92,7 @@ jobs: --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --cmake_generator "$(VSGenerator)" + --build_shared_lib --use_qnn --qnn_home $(QnnSDKRootDir) --enable_pybind diff --git a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml index aa0b6bf6d391e..4a437be325e7a 100644 --- a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml +++ b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml @@ -93,12 +93,18 @@ stages: workingFolder: '$(Build.BinariesDirectory)\${{ parameters.build_config }}' createLogFile: true + - task: CmdLine@2 + displayName: 'Print contents of binaries directory' + inputs: + script: | + dir $(Build.BinariesDirectory)\${{ parameters.build_config }}\${{ parameters.build_config }} + - template: win-esrp-dll.yml parameters: FolderPath: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\${{ parameters.build_config }}' DisplayName: 'ESRP - Sign dlls' DoEsrp: ${{ parameters.DoEsrp }} - Pattern: 'onnxruntime.dll' + Pattern: 'onnxruntime*.dll' - task: MSBuild@1 displayName: 'Restore NuGet Packages and create project.assets.json' diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml index 5c013fae6be0b..18d4a615766a4 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml @@ -36,7 +36,7 @@ parameters: default: 2.28.2.241116 jobs: -- job: 'build' +- job: 'BUILD_QNN_EP' pool: 'onnxruntime-qnn-windows-vs-2022-arm64' variables: DOTNET_SKIP_FIRST_TIME_EXPERIENCE: true @@ -46,6 +46,12 @@ jobs: timeoutInMinutes: 240 workspace: clean: all + strategy: + matrix: + SHARED_LIB: + QnnLibKind: 'shared_lib' + STATIC_LIB: + QnnLibKind: 'static_lib' steps: - script: | @@ -79,7 +85,8 @@ jobs: --config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --cmake_generator "Visual Studio 17 2022" - --use_qnn + --build_shared_lib + --use_qnn $(QnnLibKind) --qnn_home $(QnnSDKRootDir) --update --build --parallel @@ -88,7 +95,8 @@ jobs: --config $(BuildConfig) ^ --build_dir $(Build.BinariesDirectory) ^ --cmake_generator "Visual Studio 17 2022" ^ - --use_qnn ^ + --build_shared_lib ^ + --use_qnn $(QnnLibKind) ^ --qnn_home $(QnnSDKRootDir) ^ --test --enable_onnx_tests displayName: 'Run unit tests' @@ -121,7 +129,7 @@ jobs: TargetFolder: '$(Build.ArtifactStagingDirectory)' CleanTargetFolder: true OverWrite: true - condition: and(succeeded(), ne(variables['Build.Reason'], 'PullRequest')) + condition: and(succeeded(), and(ne(variables['Build.Reason'], 'PullRequest'), eq(variables['QnnLibKind'], 'shared_lib'))) - task: PublishBuildArtifacts@1 displayName: 'Publish Artifact' @@ -129,4 +137,4 @@ jobs: PathtoPublish: '$(Build.ArtifactStagingDirectory)' ArtifactName: 'internal_release' publishLocation: 'Container' - condition: and(succeeded(), ne(variables['Build.Reason'], 'PullRequest')) + condition: and(succeeded(), and(ne(variables['Build.Reason'], 'PullRequest'), eq(variables['QnnLibKind'], 'shared_lib'))) diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml index 53700c58c7e7d..31897c592b0da 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml @@ -36,7 +36,7 @@ parameters: default: 2.28.2.241116 jobs: -- job: 'build' +- job: 'BUILD_QNN_EP' pool: 'Onnxruntime-QNNEP-Windows-2022-CPU' variables: MsbuildArguments: '-detailedsummary -maxcpucount -consoleloggerparameters:PerformanceSummary' @@ -50,6 +50,12 @@ jobs: timeoutInMinutes: 120 workspace: clean: all + strategy: + matrix: + SHARED_LIB: + QnnLibKind: 'shared_lib' + STATIC_LIB: + QnnLibKind: 'static_lib' steps: - task: UsePythonVersion@0 @@ -72,7 +78,8 @@ jobs: --build_dir $(Build.BinariesDirectory) --cmake_generator "Visual Studio 17 2022" --build_java - --use_qnn + --build_shared_lib + --use_qnn $(QnnLibKind) --qnn_home $(QnnSDKRootDir) --use_binskim_compliant_compile_flags --update --parallel @@ -87,7 +94,8 @@ jobs: --build_dir $(Build.BinariesDirectory) ^ --cmake_generator "Visual Studio 17 2022" ^ --build_java ^ - --use_qnn ^ + --build_shared_lib ^ + --use_qnn $(QnnLibKind) ^ --qnn_home $(QnnSDKRootDir) ^ --use_binskim_compliant_compile_flags ^ --test --enable_onnx_tests diff --git a/tools/ci_build/github/linux/build_linux_python_package.sh b/tools/ci_build/github/linux/build_linux_python_package.sh index e2e0cea69efb5..11997382d119c 100755 --- a/tools/ci_build/github/linux/build_linux_python_package.sh +++ b/tools/ci_build/github/linux/build_linux_python_package.sh @@ -94,7 +94,7 @@ fi if [ "$BUILD_DEVICE" == "NPU" ]; then #Enable QNN EP - BUILD_ARGS+=("--use_qnn" "--qnn_home=/qnn_sdk") + BUILD_ARGS+=("--build_shared_lib" "--use_qnn" "--qnn_home=/qnn_sdk") fi export ONNX_ML=1 diff --git a/tools/nuget/generate_nuspec_for_native_nuget.py b/tools/nuget/generate_nuspec_for_native_nuget.py index 11842f34ce45b..980455ccddb0e 100644 --- a/tools/nuget/generate_nuspec_for_native_nuget.py +++ b/tools/nuget/generate_nuspec_for_native_nuget.py @@ -382,6 +382,7 @@ def generate_files(line_list, args): "tensorrt_ep_shared_lib": "onnxruntime_providers_tensorrt.dll", "openvino_ep_shared_lib": "onnxruntime_providers_openvino.dll", "cuda_ep_shared_lib": "onnxruntime_providers_cuda.dll", + "qnn_ep_shared_lib": "onnxruntime_providers_qnn.dll", "onnxruntime_perf_test": "onnxruntime_perf_test.exe", "onnx_test_runner": "onnx_test_runner.exe", } @@ -777,6 +778,24 @@ def generate_files(line_list, args): + '\\native" />' ) + if args.execution_provider == "qnn" or is_qnn_package and not is_ado_packaging_build: + files_list.append( + "' + ) + files_list.append( + "' + ) + # process all other library dependencies if is_cpu_package or is_cuda_gpu_package or is_dml_package or is_mklml_package: # Process dnnl dependency