diff mbox series

[FFmpeg-devel,v4,03/11] avfilter/dnn: Don't show backends which are not supported by a filter

Message ID tencent_633E0956AD9C4CE08A5A2835E18066340B09@qq.com
State New
Headers show
Series [FFmpeg-devel,v4,01/11] avfilter/dnn: Refactor DNN parameter configuration system | expand

Checks

Context Check Description
andriy/make_x86 success Make finished
andriy/make_fate_x86 success Make fate finished

Commit Message

Zhao Zhili May 7, 2024, 4:08 p.m. UTC
From: Zhao Zhili <zhilizhao@tencent.com>

---
 libavfilter/dnn/dnn_interface.c | 11 ++++++++---
 libavfilter/dnn_filter_common.h | 11 +++++++++--
 libavfilter/dnn_interface.h     |  8 ++++++--
 libavfilter/vf_derain.c         |  2 +-
 libavfilter/vf_dnn_classify.c   |  2 +-
 libavfilter/vf_dnn_detect.c     |  2 +-
 libavfilter/vf_dnn_processing.c |  2 +-
 libavfilter/vf_sr.c             |  2 +-
 8 files changed, 28 insertions(+), 12 deletions(-)
diff mbox series

Patch

diff --git a/libavfilter/dnn/dnn_interface.c b/libavfilter/dnn/dnn_interface.c
index b56c22e4c7..dc1593821d 100644
--- a/libavfilter/dnn/dnn_interface.c
+++ b/libavfilter/dnn/dnn_interface.c
@@ -120,11 +120,16 @@  void *ff_dnn_child_next(DnnContext *obj, void *prev) {
     return NULL;
 }
 
-const AVClass *ff_dnn_child_class_iterate(void **iter)
+const AVClass *ff_dnn_child_class_iterate_with_mask(void **iter, uint32_t backend_mask)
 {
-    uintptr_t i = (uintptr_t) *iter;
+    for (uintptr_t i = (uintptr_t)*iter; i < FF_ARRAY_ELEMS(dnn_backend_info_list); i++) {
+        if (i > 0) {
+            const DNNModule *module = dnn_backend_info_list[i].module;
+
+            if (!(module->type & backend_mask))
+                continue;
+        }
 
-    if (i < FF_ARRAY_ELEMS(dnn_backend_info_list)) {
         *iter = (void *)(i + 1);
         return dnn_backend_info_list[i].class;
     }
diff --git a/libavfilter/dnn_filter_common.h b/libavfilter/dnn_filter_common.h
index b52b55a90d..42a4719997 100644
--- a/libavfilter/dnn_filter_common.h
+++ b/libavfilter/dnn_filter_common.h
@@ -26,6 +26,12 @@ 
 
 #include "dnn_interface.h"
 
+#define DNN_FILTER_CHILD_CLASS_ITERATE(name, backend_mask)                  \
+    static const AVClass *name##_child_class_iterate(void **iter)           \
+    {                                                                       \
+        return  ff_dnn_child_class_iterate_with_mask(iter, (backend_mask)); \
+    }
+
 #define AVFILTER_DNN_DEFINE_CLASS_EXT(name, desc, options) \
     static const AVClass name##_class = {       \
         .class_name = desc,                     \
@@ -34,10 +40,11 @@ 
         .version    = LIBAVUTIL_VERSION_INT,    \
         .category   = AV_CLASS_CATEGORY_FILTER,            \
         .child_next = ff_dnn_filter_child_next,            \
-        .child_class_iterate = ff_dnn_child_class_iterate, \
+        .child_class_iterate = name##_child_class_iterate, \
     }
 
-#define AVFILTER_DNN_DEFINE_CLASS(fname) \
+#define AVFILTER_DNN_DEFINE_CLASS(fname, backend_mask)      \
+    DNN_FILTER_CHILD_CLASS_ITERATE(fname, backend_mask)     \
     AVFILTER_DNN_DEFINE_CLASS_EXT(fname, #fname, fname##_options)
 
 void *ff_dnn_filter_child_next(void *obj, void *prev);
diff --git a/libavfilter/dnn_interface.h b/libavfilter/dnn_interface.h
index dd603534b2..697b9f3318 100644
--- a/libavfilter/dnn_interface.h
+++ b/libavfilter/dnn_interface.h
@@ -32,7 +32,11 @@ 
 
 #define DNN_GENERIC_ERROR FFERRTAG('D','N','N','!')
 
-typedef enum {DNN_TF = 1, DNN_OV, DNN_TH} DNNBackendType;
+typedef enum {
+    DNN_TF = 1,
+    DNN_OV = 1 << 1,
+    DNN_TH = 1 << 2
+} DNNBackendType;
 
 typedef enum {DNN_FLOAT = 1, DNN_UINT8 = 4} DNNDataType;
 
@@ -190,7 +194,7 @@  const DNNModule *ff_get_dnn_module(DNNBackendType backend_type, void *log_ctx);
 
 void ff_dnn_init_child_class(DnnContext *ctx);
 void *ff_dnn_child_next(DnnContext *obj, void *prev);
-const AVClass *ff_dnn_child_class_iterate(void **iter);
+const AVClass *ff_dnn_child_class_iterate_with_mask(void **iter, uint32_t backend_mask);
 
 static inline int dnn_get_width_idx_by_layout(DNNLayout layout)
 {
diff --git a/libavfilter/vf_derain.c b/libavfilter/vf_derain.c
index 7f665b73ab..5cefca6b55 100644
--- a/libavfilter/vf_derain.c
+++ b/libavfilter/vf_derain.c
@@ -49,7 +49,7 @@  static const AVOption derain_options[] = {
     { NULL }
 };
 
-AVFILTER_DNN_DEFINE_CLASS(derain);
+AVFILTER_DNN_DEFINE_CLASS(derain, DNN_TF);
 
 static int filter_frame(AVFilterLink *inlink, AVFrame *in)
 {
diff --git a/libavfilter/vf_dnn_classify.c b/libavfilter/vf_dnn_classify.c
index 965779a8ab..f6d3678796 100644
--- a/libavfilter/vf_dnn_classify.c
+++ b/libavfilter/vf_dnn_classify.c
@@ -56,7 +56,7 @@  static const AVOption dnn_classify_options[] = {
     { NULL }
 };
 
-AVFILTER_DNN_DEFINE_CLASS(dnn_classify);
+AVFILTER_DNN_DEFINE_CLASS(dnn_classify, DNN_OV);
 
 static int dnn_classify_post_proc(AVFrame *frame, DNNData *output, uint32_t bbox_index, AVFilterContext *filter_ctx)
 {
diff --git a/libavfilter/vf_dnn_detect.c b/libavfilter/vf_dnn_detect.c
index 926966368a..b4eee06fe7 100644
--- a/libavfilter/vf_dnn_detect.c
+++ b/libavfilter/vf_dnn_detect.c
@@ -84,7 +84,7 @@  static const AVOption dnn_detect_options[] = {
     { NULL }
 };
 
-AVFILTER_DNN_DEFINE_CLASS(dnn_detect);
+AVFILTER_DNN_DEFINE_CLASS(dnn_detect, DNN_TF | DNN_OV);
 
 static inline float sigmoid(float x) {
     return 1.f / (1.f + exp(-x));
diff --git a/libavfilter/vf_dnn_processing.c b/libavfilter/vf_dnn_processing.c
index 9a1dd2a356..7c0f84ec80 100644
--- a/libavfilter/vf_dnn_processing.c
+++ b/libavfilter/vf_dnn_processing.c
@@ -57,7 +57,7 @@  static const AVOption dnn_processing_options[] = {
     { NULL }
 };
 
-AVFILTER_DNN_DEFINE_CLASS(dnn_processing);
+AVFILTER_DNN_DEFINE_CLASS(dnn_processing, DNN_TF | DNN_OV | DNN_TH);
 
 static av_cold int init(AVFilterContext *context)
 {
diff --git a/libavfilter/vf_sr.c b/libavfilter/vf_sr.c
index f14c0c0cd3..3bfca7f042 100644
--- a/libavfilter/vf_sr.c
+++ b/libavfilter/vf_sr.c
@@ -53,7 +53,7 @@  static const AVOption sr_options[] = {
     { NULL }
 };
 
-AVFILTER_DNN_DEFINE_CLASS(sr);
+AVFILTER_DNN_DEFINE_CLASS(sr, DNN_TF);
 
 static av_cold int init(AVFilterContext *context)
 {