diff mbox series

[FFmpeg-devel,WIP,v2,9/9] avfilter/dnn: Use dnn_backend_info_list to search for dnn module

Message ID tencent_C5568B9DEAE1165AB6349981F234913FA909@qq.com
State New
Headers show
Series [FFmpeg-devel,WIP,v2,1/9] avfilter/dnn: Refactor DNN parameter configuration system | expand

Checks

Context Check Description
yinshiyou/make_loongarch64 success Make finished
yinshiyou/make_fate_loongarch64 success Make fate finished

Commit Message

Zhao Zhili April 28, 2024, 6:46 a.m. UTC
From: Zhao Zhili <zhilizhao@tencent.com>

---
 libavfilter/dnn/dnn_backend_openvino.c |  1 +
 libavfilter/dnn/dnn_backend_tf.c       |  1 +
 libavfilter/dnn/dnn_backend_torch.cpp  |  1 +
 libavfilter/dnn/dnn_interface.c        | 26 ++++++++------------------
 libavfilter/dnn_interface.h            |  1 +
 5 files changed, 12 insertions(+), 18 deletions(-)
diff mbox series

Patch

diff --git a/libavfilter/dnn/dnn_backend_openvino.c b/libavfilter/dnn/dnn_backend_openvino.c
index d8a6820dc2..9c699cdc8c 100644
--- a/libavfilter/dnn/dnn_backend_openvino.c
+++ b/libavfilter/dnn/dnn_backend_openvino.c
@@ -1613,6 +1613,7 @@  static int dnn_flush_ov(const DNNModel *model)
 
 const DNNModule ff_dnn_backend_openvino = {
     .clazz          = DNN_DEFINE_CLASS(dnn_openvino),
+    .type           = DNN_OV,
     .load_model     = dnn_load_model_ov,
     .execute_model  = dnn_execute_model_ov,
     .get_result     = dnn_get_result_ov,
diff --git a/libavfilter/dnn/dnn_backend_tf.c b/libavfilter/dnn/dnn_backend_tf.c
index 06ea6cbb8c..6afefe8115 100644
--- a/libavfilter/dnn/dnn_backend_tf.c
+++ b/libavfilter/dnn/dnn_backend_tf.c
@@ -886,6 +886,7 @@  static int dnn_flush_tf(const DNNModel *model)
 
 const DNNModule ff_dnn_backend_tf = {
     .clazz          = DNN_DEFINE_CLASS(dnn_tensorflow),
+    .type           = DNN_TF,
     .load_model     = dnn_load_model_tf,
     .execute_model  = dnn_execute_model_tf,
     .get_result     = dnn_get_result_tf,
diff --git a/libavfilter/dnn/dnn_backend_torch.cpp b/libavfilter/dnn/dnn_backend_torch.cpp
index 24e9f2c8e2..2557264713 100644
--- a/libavfilter/dnn/dnn_backend_torch.cpp
+++ b/libavfilter/dnn/dnn_backend_torch.cpp
@@ -561,6 +561,7 @@  static int dnn_flush_th(const DNNModel *model)
 
 extern const DNNModule ff_dnn_backend_torch = {
     .clazz          = DNN_DEFINE_CLASS(dnn_th),
+    .type           = DNN_TH,
     .load_model     = dnn_load_model_th,
     .execute_model  = dnn_execute_model_th,
     .get_result     = dnn_get_result_th,
diff --git a/libavfilter/dnn/dnn_interface.c b/libavfilter/dnn/dnn_interface.c
index ebd308cd84..cce3c45856 100644
--- a/libavfilter/dnn/dnn_interface.c
+++ b/libavfilter/dnn/dnn_interface.c
@@ -80,25 +80,15 @@  static const DnnBackendInfo dnn_backend_info_list[] = {
 
 const DNNModule *ff_get_dnn_module(DNNBackendType backend_type, void *log_ctx)
 {
-    switch(backend_type){
-    #if (CONFIG_LIBTENSORFLOW == 1)
-    case DNN_TF:
-        return &ff_dnn_backend_tf;
-    #endif
-    #if (CONFIG_LIBOPENVINO == 1)
-    case DNN_OV:
-        return &ff_dnn_backend_openvino;
-    #endif
-    #if (CONFIG_LIBTORCH == 1)
-    case DNN_TH:
-        return &ff_dnn_backend_torch;
-    #endif
-    default:
-        av_log(log_ctx, AV_LOG_ERROR,
-                "Module backend_type %d is not supported or enabled.\n",
-                backend_type);
-        return NULL;
+    for (int i = 1; i < FF_ARRAY_ELEMS(dnn_backend_info_list); i++) {
+        if (dnn_backend_info_list[i].module->type == backend_type)
+            return dnn_backend_info_list[i].module;
     }
+
+    av_log(log_ctx, AV_LOG_ERROR,
+            "Module backend_type %d is not supported or enabled.\n",
+            backend_type);
+    return NULL;
 }
 
 void *ff_dnn_child_next(DnnContext *obj, void *prev) {
diff --git a/libavfilter/dnn_interface.h b/libavfilter/dnn_interface.h
index 4e14a42d00..4b25ac2b84 100644
--- a/libavfilter/dnn_interface.h
+++ b/libavfilter/dnn_interface.h
@@ -170,6 +170,7 @@  typedef struct DnnContext {
 // Stores pointers to functions for loading, executing, freeing DNN models for one of the backends.
 struct DNNModule {
     const AVClass clazz;
+    DNNBackendType type;
     // Loads model and parameters from given file. Returns NULL if it is not possible.
     DNNModel *(*load_model)(DnnContext *ctx, DNNFunctionType func_type, AVFilterContext *filter_ctx);
     // Executes model with specified input and output. Returns the error code otherwise.