// SPDX-License-Identifier: GPL-2.0 OR Linux-OpenIB
 /* Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. */
 
+#include <linux/refcount.h>
+
 #include "en_tc.h"
 #include "en/tc_priv.h"
 #include "en/tc_ct.h"
        netdev_dbg(fs->netdev, "ct_fs_smfs debug: " fmt "\n", ##args)
 #define MLX5_CT_TCP_FLAGS_MASK cpu_to_be16(be32_to_cpu(TCP_FLAG_RST | TCP_FLAG_FIN) >> 16)
 
+struct mlx5_ct_fs_smfs_matcher {
+       struct mlx5dr_matcher *dr_matcher;
+       struct list_head list;
+       int prio;
+       refcount_t ref;
+};
+
 struct mlx5_ct_fs_smfs_matchers {
-       struct mlx5dr_matcher *ipv4_tcp;
-       struct mlx5dr_matcher *ipv4_udp;
-       struct mlx5dr_matcher *ipv6_tcp;
-       struct mlx5dr_matcher *ipv6_udp;
+       struct mlx5_ct_fs_smfs_matcher smfs_matchers[4];
+       struct list_head used;
 };
 
 struct mlx5_ct_fs_smfs {
-       struct mlx5_ct_fs_smfs_matchers ct_matchers;
-       struct mlx5_ct_fs_smfs_matchers ct_matchers_nat;
+       struct mlx5dr_table *ct_tbl, *ct_nat_tbl;
+       struct mlx5_ct_fs_smfs_matchers matchers;
+       struct mlx5_ct_fs_smfs_matchers matchers_nat;
        struct mlx5dr_action *fwd_action;
        struct mlx5_flow_table *ct_nat;
+       struct mutex lock; /* Guards matchers */
 };
 
 struct mlx5_ct_fs_smfs_rule {
        struct mlx5_ct_fs_rule fs_rule;
        struct mlx5dr_rule *rule;
        struct mlx5dr_action *count_action;
+       struct mlx5_ct_fs_smfs_matcher *smfs_matcher;
 };
 
 static inline void
        return dr_matcher;
 }
 
-static int
-mlx5_ct_fs_smfs_matchers_create(struct mlx5_ct_fs *fs, struct mlx5dr_table *tbl,
-                               struct mlx5_ct_fs_smfs_matchers *ct_matchers)
+static struct mlx5_ct_fs_smfs_matcher *
+mlx5_ct_fs_smfs_matcher_get(struct mlx5_ct_fs *fs, bool nat, bool ipv4, bool tcp)
 {
-       const struct net_device *netdev = fs->netdev;
-       u32 prio = 0;
-       int err;
-
-       ct_matchers->ipv4_tcp = mlx5_ct_fs_smfs_matcher_create(fs, tbl, true, true, prio);
-       if (IS_ERR(ct_matchers->ipv4_tcp)) {
-               err = PTR_ERR(ct_matchers->ipv4_tcp);
-               netdev_warn(netdev,
-                           "%s, failed to create ipv4 tcp matcher, err: %d\n",
-                           INIT_ERR_PREFIX, err);
-               return err;
-       }
-
-       ++prio;
-       ct_matchers->ipv4_udp = mlx5_ct_fs_smfs_matcher_create(fs, tbl, true, false, prio);
-       if (IS_ERR(ct_matchers->ipv4_udp)) {
-               err = PTR_ERR(ct_matchers->ipv4_udp);
-               netdev_warn(netdev,
-                           "%s, failed to create ipv4 udp matcher, err: %d\n",
-                           INIT_ERR_PREFIX, err);
-               goto err_matcher_ipv4_udp;
+       struct mlx5_ct_fs_smfs *fs_smfs = mlx5_ct_fs_priv(fs);
+       struct mlx5_ct_fs_smfs_matcher *m, *smfs_matcher;
+       struct mlx5_ct_fs_smfs_matchers *matchers;
+       struct mlx5dr_matcher *dr_matcher;
+       struct mlx5dr_table *tbl;
+       struct list_head *prev;
+       int prio;
+
+       matchers = nat ? &fs_smfs->matchers_nat : &fs_smfs->matchers;
+       smfs_matcher = &matchers->smfs_matchers[ipv4 * 2 + tcp];
+
+       if (refcount_inc_not_zero(&smfs_matcher->ref))
+               return smfs_matcher;
+
+       mutex_lock(&fs_smfs->lock);
+
+       /* Retry with lock, as another thread might have already created the relevant matcher
+        * till we acquired the lock
+        */
+       if (refcount_inc_not_zero(&smfs_matcher->ref))
+               goto out_unlock;
+
+       // Find next available priority in sorted used list
+       prio = 0;
+       prev = &matchers->used;
+       list_for_each_entry(m, &matchers->used, list) {
+               prev = &m->list;
+
+               if (m->prio == prio)
+                       prio = m->prio + 1;
+               else
+                       break;
        }
 
-       ++prio;
-       ct_matchers->ipv6_tcp = mlx5_ct_fs_smfs_matcher_create(fs, tbl, false, true, prio);
-       if (IS_ERR(ct_matchers->ipv6_tcp)) {
-               err = PTR_ERR(ct_matchers->ipv6_tcp);
-               netdev_warn(netdev,
-                           "%s, failed to create ipv6 tcp matcher, err: %d\n",
-                           INIT_ERR_PREFIX, err);
-               goto err_matcher_ipv6_tcp;
-       }
+       tbl = nat ? fs_smfs->ct_nat_tbl : fs_smfs->ct_tbl;
+       dr_matcher = mlx5_ct_fs_smfs_matcher_create(fs, tbl, ipv4, tcp, prio);
+       if (IS_ERR(dr_matcher)) {
+               netdev_warn(fs->netdev,
+                           "ct_fs_smfs: failed to create matcher (nat %d, ipv4 %d, tcp %d), err: %ld\n",
+                           nat, ipv4, tcp, PTR_ERR(dr_matcher));
 
-       ++prio;
-       ct_matchers->ipv6_udp = mlx5_ct_fs_smfs_matcher_create(fs, tbl, false, false, prio);
-       if (IS_ERR(ct_matchers->ipv6_udp)) {
-               err = PTR_ERR(ct_matchers->ipv6_udp);
-               netdev_warn(netdev,
-                           "%s, failed to create ipv6 tcp matcher, err: %d\n",
-                            INIT_ERR_PREFIX, err);
-               goto err_matcher_ipv6_udp;
+               smfs_matcher = ERR_CAST(dr_matcher);
+               goto out_unlock;
        }
 
-       return 0;
+       smfs_matcher->dr_matcher = dr_matcher;
+       smfs_matcher->prio = prio;
+       list_add(&smfs_matcher->list, prev);
+       refcount_set(&smfs_matcher->ref, 1);
 
-err_matcher_ipv6_udp:
-       mlx5_smfs_matcher_destroy(ct_matchers->ipv6_tcp);
-err_matcher_ipv6_tcp:
-       mlx5_smfs_matcher_destroy(ct_matchers->ipv4_udp);
-err_matcher_ipv4_udp:
-       mlx5_smfs_matcher_destroy(ct_matchers->ipv4_tcp);
-       return 0;
+out_unlock:
+       mutex_unlock(&fs_smfs->lock);
+       return smfs_matcher;
 }
 
 static void
-mlx5_ct_fs_smfs_matchers_destroy(struct mlx5_ct_fs_smfs_matchers *ct_matchers)
+mlx5_ct_fs_smfs_matcher_put(struct mlx5_ct_fs *fs, struct mlx5_ct_fs_smfs_matcher *smfs_matcher)
 {
-       mlx5_smfs_matcher_destroy(ct_matchers->ipv6_udp);
-       mlx5_smfs_matcher_destroy(ct_matchers->ipv6_tcp);
-       mlx5_smfs_matcher_destroy(ct_matchers->ipv4_udp);
-       mlx5_smfs_matcher_destroy(ct_matchers->ipv4_tcp);
+       struct mlx5_ct_fs_smfs *fs_smfs = mlx5_ct_fs_priv(fs);
+
+       if (!refcount_dec_and_mutex_lock(&smfs_matcher->ref, &fs_smfs->lock))
+               return;
+
+       mlx5_smfs_matcher_destroy(smfs_matcher->dr_matcher);
+       list_del(&smfs_matcher->list);
+       mutex_unlock(&fs_smfs->lock);
 }
 
 static int
 {
        struct mlx5dr_table *ct_tbl, *ct_nat_tbl, *post_ct_tbl;
        struct mlx5_ct_fs_smfs *fs_smfs = mlx5_ct_fs_priv(fs);
-       int err;
 
        post_ct_tbl = mlx5_smfs_table_get_from_fs_ft(post_ct);
        ct_nat_tbl = mlx5_smfs_table_get_from_fs_ft(ct_nat);
 
        ct_dbg("using smfs steering");
 
-       err = mlx5_ct_fs_smfs_matchers_create(fs, ct_tbl, &fs_smfs->ct_matchers);
-       if (err)
-               goto err_init;
-
-       err = mlx5_ct_fs_smfs_matchers_create(fs, ct_nat_tbl, &fs_smfs->ct_matchers_nat);
-       if (err)
-               goto err_matchers_nat;
-
        fs_smfs->fwd_action = mlx5_smfs_action_create_dest_table(post_ct_tbl);
        if (!fs_smfs->fwd_action) {
-               err = -EINVAL;
-               goto err_action_create;
+               return -EINVAL;
        }
 
-       return 0;
+       fs_smfs->ct_tbl = ct_tbl;
+       fs_smfs->ct_nat_tbl = ct_nat_tbl;
+       mutex_init(&fs_smfs->lock);
+       INIT_LIST_HEAD(&fs_smfs->matchers.used);
+       INIT_LIST_HEAD(&fs_smfs->matchers_nat.used);
 
-err_action_create:
-       mlx5_ct_fs_smfs_matchers_destroy(&fs_smfs->ct_matchers_nat);
-err_matchers_nat:
-       mlx5_ct_fs_smfs_matchers_destroy(&fs_smfs->ct_matchers);
-err_init:
-       return err;
+       return 0;
 }
 
 static void
        struct mlx5_ct_fs_smfs *fs_smfs = mlx5_ct_fs_priv(fs);
 
        mlx5_smfs_action_destroy(fs_smfs->fwd_action);
-       mlx5_ct_fs_smfs_matchers_destroy(&fs_smfs->ct_matchers_nat);
-       mlx5_ct_fs_smfs_matchers_destroy(&fs_smfs->ct_matchers);
 }
 
 static inline bool
                            struct mlx5_flow_attr *attr, struct flow_rule *flow_rule)
 {
        struct mlx5_ct_fs_smfs *fs_smfs = mlx5_ct_fs_priv(fs);
-       struct mlx5_ct_fs_smfs_matchers *matchers;
+       struct mlx5_ct_fs_smfs_matcher *smfs_matcher;
        struct mlx5_ct_fs_smfs_rule *smfs_rule;
        struct mlx5dr_action *actions[5];
-       struct mlx5dr_matcher *matcher;
        struct mlx5dr_rule *rule;
        int num_actions = 0, err;
        bool nat, tcp, ipv4;
        tcp = MLX5_GET(fte_match_param, spec->match_value,
                       outer_headers.ip_protocol) == IPPROTO_TCP;
 
-       matchers = nat ? &fs_smfs->ct_matchers_nat : &fs_smfs->ct_matchers;
-       matcher = ipv4 ? (tcp ? matchers->ipv4_tcp : matchers->ipv4_udp) :
-                        (tcp ? matchers->ipv6_tcp : matchers->ipv6_udp);
+       smfs_matcher = mlx5_ct_fs_smfs_matcher_get(fs, nat, ipv4, tcp);
+       if (IS_ERR(smfs_matcher)) {
+               err = PTR_ERR(smfs_matcher);
+               goto err_matcher;
+       }
 
-       rule = mlx5_smfs_rule_create(matcher, spec, num_actions, actions,
+       rule = mlx5_smfs_rule_create(smfs_matcher->dr_matcher, spec, num_actions, actions,
                                     MLX5_FLOW_CONTEXT_FLOW_SOURCE_ANY_VPORT);
        if (!rule) {
                err = -EINVAL;
-               goto err_rule;
+               goto err_create;
        }
 
        smfs_rule->rule = rule;
+       smfs_rule->smfs_matcher = smfs_matcher;
 
        return &smfs_rule->fs_rule;
 
-err_rule:
+err_create:
+       mlx5_ct_fs_smfs_matcher_put(fs, smfs_matcher);
+err_matcher:
        mlx5_smfs_action_destroy(smfs_rule->count_action);
 err_count:
        kfree(smfs_rule);
                                                              fs_rule);
 
        mlx5_smfs_rule_destroy(smfs_rule->rule);
+       mlx5_ct_fs_smfs_matcher_put(fs, smfs_rule->smfs_matcher);
        mlx5_smfs_action_destroy(smfs_rule->count_action);
        kfree(smfs_rule);
 }