@@ -1073,9 +1073,15 @@ static int get_input_ov(void *model, DNNData *input, const char *input_name)
return AVERROR(ENOSYS);
}
- input->channels = dims[1];
- input->height = input_resizable ? -1 : dims[2];
- input->width = input_resizable ? -1 : dims[3];
+ if (dims[1] <= 3) { // NCHW
+ input->channels = dims[1];
+ input->height = input_resizable ? -1 : dims[2];
+ input->width = input_resizable ? -1 : dims[3];
+ } else { // NHWC
+ input->height = input_resizable ? -1 : dims[1];
+ input->width = input_resizable ? -1 : dims[2];
+ input->channels = dims[3];
+ }
input->dt = precision_to_datatype(precision);
return 0;
@@ -1105,9 +1111,15 @@ static int get_input_ov(void *model, DNNData *input, const char *input_name)
return DNN_GENERIC_ERROR;
}
- input->channels = dims.dims[1];
- input->height = input_resizable ? -1 : dims.dims[2];
- input->width = input_resizable ? -1 : dims.dims[3];
+ if (dims[1] <= 3) { // NCHW
+ input->channels = dims[1];
+ input->height = input_resizable ? -1 : dims[2];
+ input->width = input_resizable ? -1 : dims[3];
+ } else { // NHWC
+ input->height = input_resizable ? -1 : dims[1];
+ input->width = input_resizable ? -1 : dims[2];
+ input->channels = dims[3];
+ }
input->dt = precision_to_datatype(precision);
return 0;
}
@@ -699,13 +699,39 @@ static av_cold void dnn_detect_uninit(AVFilterContext *context)
free_detect_labels(ctx);
}
+static int config_input(AVFilterLink *inlink)
+{
+ AVFilterContext *context = inlink->dst;
+ DnnDetectContext *ctx = context->priv;
+ DNNData model_input;
+ int ret;
+
+ ret = ff_dnn_get_input(&ctx->dnnctx, &model_input);
+ if (ret != 0) {
+ av_log(ctx, AV_LOG_ERROR, "could not get input from the model\n");
+ return ret;
+ }
+ ctx->scale_width = model_input.width == -1 ? inlink->w : model_input.width;
+ ctx->scale_height = model_input.height == -1 ? inlink->h : model_input.height;
+
+ return 0;
+}
+
+static const AVFilterPad dnn_detect_inputs[] = {
+ {
+ .name = "default",
+ .type = AVMEDIA_TYPE_VIDEO,
+ .config_props = config_input,
+ },
+};
+
const AVFilter ff_vf_dnn_detect = {
.name = "dnn_detect",
.description = NULL_IF_CONFIG_SMALL("Apply DNN detect filter to the input."),
.priv_size = sizeof(DnnDetectContext),
.init = dnn_detect_init,
.uninit = dnn_detect_uninit,
- FILTER_INPUTS(ff_video_default_filterpad),
+ FILTER_INPUTS(dnn_detect_inputs),
FILTER_OUTPUTS(ff_video_default_filterpad),
FILTER_PIXFMTS_ARRAY(pix_fmts),
.priv_class = &dnn_detect_class,