[Cake] [PATCH net v2] sched: sch_cake: add bounds checks to host bulk flow fairness counts

Dave Taht dave.taht at gmail.com
Wed Jan 8 11:10:36 EST 2025


On Tue, Jan 7, 2025 at 4:01 AM Toke Høiland-Jørgensen via Cake
<cake at lists.bufferbloat.net> wrote:
>
> Even though we fixed a logic error in the commit cited below, syzbot
> still managed to trigger an underflow of the per-host bulk flow
> counters, leading to an out-of-bounds memory access.
>
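> To avoid any such logic errors causing out-of-bounds memory accesses,
> this commit factors out all accesses to the per-host bulk flow counters
> to a series of helpers that perform bounds checking before any
> increments and decrements. This also has the benefit of improving
> readability by moving the conditional checks for the flow mode into
> these helpers, instead of having them spread out throughout the
> code (which was the cause of the original logic error).
>
> As part of this change, the flow quantum calculation is consolidated
> into a single helper, cake_get_flow_quantum(). As a result, the
> get_random_u16() dithering of the host load scaling is now applied
> both in the DRR rotation in cake_dequeue() and when a sparse flow's
> deficit is first set in cake_enqueue(); previously only the former
> was dithered, so the initial deficit may now be one larger in some
> cases. The WARN_ON(host_load > CAKE_QUEUES) check is dropped, since
> the increment helpers now cap the counters at CAKE_QUEUES.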
>
> v2:
> - Remove now-unused srchost and dsthost local variables in cake_dequeue()
>
> Fixes: 546ea84d07e3 ("sched: sch_cake: fix bulk flow accounting logic for host fairness")
> Reported-by: syzbot+f63600d288bfb7057424 at syzkaller.appspotmail.com
> Signed-off-by: Toke Høiland-Jørgensen <toke at redhat.com>
> ---
>  net/sched/sch_cake.c | 140 +++++++++++++++++++++++--------------------
>  1 file changed, 75 insertions(+), 65 deletions(-)
>
> diff --git a/net/sched/sch_cake.c b/net/sched/sch_cake.c
> index 8d8b2db4653c..2c2e2a67f3b2 100644
> --- a/net/sched/sch_cake.c
> +++ b/net/sched/sch_cake.c
> @@ -627,6 +627,63 @@ static bool cake_ddst(int flow_mode)
>         return (flow_mode & CAKE_FLOW_DUAL_DST) == CAKE_FLOW_DUAL_DST;
>  }
>
> +static void cake_dec_srchost_bulk_flow_count(struct cake_tin_data *q,
> +                                            struct cake_flow *flow,
> +                                            int flow_mode)
> +{
> +       if (likely(cake_dsrc(flow_mode) &&
> +                  q->hosts[flow->srchost].srchost_bulk_flow_count))
> +               q->hosts[flow->srchost].srchost_bulk_flow_count--;
> +}
> +
> +static void cake_inc_srchost_bulk_flow_count(struct cake_tin_data *q,
> +                                            struct cake_flow *flow,
> +                                            int flow_mode)
> +{
> +       if (likely(cake_dsrc(flow_mode) &&
> +                  q->hosts[flow->srchost].srchost_bulk_flow_count < CAKE_QUEUES))
> +               q->hosts[flow->srchost].srchost_bulk_flow_count++;
> +}
> +
> +static void cake_dec_dsthost_bulk_flow_count(struct cake_tin_data *q,
> +                                            struct cake_flow *flow,
> +                                            int flow_mode)
> +{
> +       if (likely(cake_ddst(flow_mode) &&
> +                  q->hosts[flow->dsthost].dsthost_bulk_flow_count))
> +               q->hosts[flow->dsthost].dsthost_bulk_flow_count--;
> +}
> +
> +static void cake_inc_dsthost_bulk_flow_count(struct cake_tin_data *q,
> +                                            struct cake_flow *flow,
> +                                            int flow_mode)
> +{
> +       if (likely(cake_ddst(flow_mode) &&
> +                  q->hosts[flow->dsthost].dsthost_bulk_flow_count < CAKE_QUEUES))
> +               q->hosts[flow->dsthost].dsthost_bulk_flow_count++;
> +}
> +
> +static u16 cake_get_flow_quantum(struct cake_tin_data *q,
> +                                struct cake_flow *flow,
> +                                int flow_mode)
> +{
> +       u16 host_load = 1;
> +
> +       if (cake_dsrc(flow_mode))
> +               host_load = max(host_load,
> +                               q->hosts[flow->srchost].srchost_bulk_flow_count);
> +
> +       if (cake_ddst(flow_mode))
> +               host_load = max(host_load,
> +                               q->hosts[flow->dsthost].dsthost_bulk_flow_count);
> +
> +       /* The get_random_u16() is a way to apply dithering to avoid
> +        * accumulating roundoff errors
> +        */
> +       return (q->flow_quantum * quantum_div[host_load] +
> +               get_random_u16()) >> 16;
> +}
> +
>  static u32 cake_hash(struct cake_tin_data *q, const struct sk_buff *skb,
>                      int flow_mode, u16 flow_override, u16 host_override)
>  {
> @@ -773,10 +830,8 @@ static u32 cake_hash(struct cake_tin_data *q, const struct sk_buff *skb,
>                 allocate_dst = cake_ddst(flow_mode);
>
>                 if (q->flows[outer_hash + k].set == CAKE_SET_BULK) {
> -                       if (allocate_src)
> -                               q->hosts[q->flows[reduced_hash].srchost].srchost_bulk_flow_count--;
> -                       if (allocate_dst)
> -                               q->hosts[q->flows[reduced_hash].dsthost].dsthost_bulk_flow_count--;
> +                       cake_dec_srchost_bulk_flow_count(q, &q->flows[outer_hash + k], flow_mode);
> +                       cake_dec_dsthost_bulk_flow_count(q, &q->flows[outer_hash + k], flow_mode);
>                 }
>  found:
>                 /* reserve queue for future packets in same flow */
> @@ -801,9 +856,10 @@ static u32 cake_hash(struct cake_tin_data *q, const struct sk_buff *skb,
>                         q->hosts[outer_hash + k].srchost_tag = srchost_hash;
>  found_src:
>                         srchost_idx = outer_hash + k;
> -                       if (q->flows[reduced_hash].set == CAKE_SET_BULK)
> -                               q->hosts[srchost_idx].srchost_bulk_flow_count++;
>                         q->flows[reduced_hash].srchost = srchost_idx;
> +
> +                       if (q->flows[reduced_hash].set == CAKE_SET_BULK)
> +                               cake_inc_srchost_bulk_flow_count(q, &q->flows[reduced_hash], flow_mode);
>                 }
>
>                 if (allocate_dst) {
> @@ -824,9 +880,10 @@ static u32 cake_hash(struct cake_tin_data *q, const struct sk_buff *skb,
>                         q->hosts[outer_hash + k].dsthost_tag = dsthost_hash;
>  found_dst:
>                         dsthost_idx = outer_hash + k;
> -                       if (q->flows[reduced_hash].set == CAKE_SET_BULK)
> -                               q->hosts[dsthost_idx].dsthost_bulk_flow_count++;
>                         q->flows[reduced_hash].dsthost = dsthost_idx;
> +
> +                       if (q->flows[reduced_hash].set == CAKE_SET_BULK)
> +                               cake_inc_dsthost_bulk_flow_count(q, &q->flows[reduced_hash], flow_mode);
>                 }
>         }
>
> @@ -1839,10 +1896,6 @@ static s32 cake_enqueue(struct sk_buff *skb, struct Qdisc *sch,
>
>         /* flowchain */
>         if (!flow->set || flow->set == CAKE_SET_DECAYING) {
> -               struct cake_host *srchost = &b->hosts[flow->srchost];
> -               struct cake_host *dsthost = &b->hosts[flow->dsthost];
> -               u16 host_load = 1;
> -
>                 if (!flow->set) {
>                         list_add_tail(&flow->flowchain, &b->new_flows);
>                 } else {
> @@ -1852,18 +1905,8 @@ static s32 cake_enqueue(struct sk_buff *skb, struct Qdisc *sch,
>                 flow->set = CAKE_SET_SPARSE;
>                 b->sparse_flow_count++;
>
> -               if (cake_dsrc(q->flow_mode))
> -                       host_load = max(host_load, srchost->srchost_bulk_flow_count);
> -
> -               if (cake_ddst(q->flow_mode))
> -                       host_load = max(host_load, dsthost->dsthost_bulk_flow_count);
> -
> -               flow->deficit = (b->flow_quantum *
> -                                quantum_div[host_load]) >> 16;
> +               flow->deficit = cake_get_flow_quantum(b, flow, q->flow_mode);
>         } else if (flow->set == CAKE_SET_SPARSE_WAIT) {
> -               struct cake_host *srchost = &b->hosts[flow->srchost];
> -               struct cake_host *dsthost = &b->hosts[flow->dsthost];
> -
>                 /* this flow was empty, accounted as a sparse flow, but actually
>                  * in the bulk rotation.
>                  */
> @@ -1871,12 +1914,8 @@ static s32 cake_enqueue(struct sk_buff *skb, struct Qdisc *sch,
>                 b->sparse_flow_count--;
>                 b->bulk_flow_count++;
>
> -               if (cake_dsrc(q->flow_mode))
> -                       srchost->srchost_bulk_flow_count++;
> -
> -               if (cake_ddst(q->flow_mode))
> -                       dsthost->dsthost_bulk_flow_count++;
> -
> +               cake_inc_srchost_bulk_flow_count(b, flow, q->flow_mode);
> +               cake_inc_dsthost_bulk_flow_count(b, flow, q->flow_mode);
>         }
>
>         if (q->buffer_used > q->buffer_max_used)
> @@ -1933,13 +1972,11 @@ static struct sk_buff *cake_dequeue(struct Qdisc *sch)
>  {
>         struct cake_sched_data *q = qdisc_priv(sch);
>         struct cake_tin_data *b = &q->tins[q->cur_tin];
> -       struct cake_host *srchost, *dsthost;
>         ktime_t now = ktime_get();
>         struct cake_flow *flow;
>         struct list_head *head;
>         bool first_flow = true;
>         struct sk_buff *skb;
> -       u16 host_load;
>         u64 delay;
>         u32 len;
>
> @@ -2039,11 +2076,6 @@ static struct sk_buff *cake_dequeue(struct Qdisc *sch)
>         q->cur_flow = flow - b->flows;
>         first_flow = false;
>
> -       /* triple isolation (modified DRR++) */
> -       srchost = &b->hosts[flow->srchost];
> -       dsthost = &b->hosts[flow->dsthost];
> -       host_load = 1;
> -
>         /* flow isolation (DRR++) */
>         if (flow->deficit <= 0) {
>                 /* Keep all flows with deficits out of the sparse and decaying
> @@ -2055,11 +2087,8 @@ static struct sk_buff *cake_dequeue(struct Qdisc *sch)
>                                 b->sparse_flow_count--;
>                                 b->bulk_flow_count++;
>
> -                               if (cake_dsrc(q->flow_mode))
> -                                       srchost->srchost_bulk_flow_count++;
> -
> -                               if (cake_ddst(q->flow_mode))
> -                                       dsthost->dsthost_bulk_flow_count++;
> +                               cake_inc_srchost_bulk_flow_count(b, flow, q->flow_mode);
> +                               cake_inc_dsthost_bulk_flow_count(b, flow, q->flow_mode);
>
>                                 flow->set = CAKE_SET_BULK;
>                         } else {
> @@ -2071,19 +2100,7 @@ static struct sk_buff *cake_dequeue(struct Qdisc *sch)
>                         }
>                 }
>
> -               if (cake_dsrc(q->flow_mode))
> -                       host_load = max(host_load, srchost->srchost_bulk_flow_count);
> -
> -               if (cake_ddst(q->flow_mode))
> -                       host_load = max(host_load, dsthost->dsthost_bulk_flow_count);
> -
> -               WARN_ON(host_load > CAKE_QUEUES);
> -
> -               /* The get_random_u16() is a way to apply dithering to avoid
> -                * accumulating roundoff errors
> -                */
> -               flow->deficit += (b->flow_quantum * quantum_div[host_load] +
> -                                 get_random_u16()) >> 16;
> +               flow->deficit += cake_get_flow_quantum(b, flow, q->flow_mode);
>                 list_move_tail(&flow->flowchain, &b->old_flows);
>
>                 goto retry;
> @@ -2107,11 +2124,8 @@ static struct sk_buff *cake_dequeue(struct Qdisc *sch)
>                                 if (flow->set == CAKE_SET_BULK) {
>                                         b->bulk_flow_count--;
>
> -                                       if (cake_dsrc(q->flow_mode))
> -                                               srchost->srchost_bulk_flow_count--;
> -
> -                                       if (cake_ddst(q->flow_mode))
> -                                               dsthost->dsthost_bulk_flow_count--;
> +                                       cake_dec_srchost_bulk_flow_count(b, flow, q->flow_mode);
> +                                       cake_dec_dsthost_bulk_flow_count(b, flow, q->flow_mode);
>
>                                         b->decaying_flow_count++;
>                                 } else if (flow->set == CAKE_SET_SPARSE ||
> @@ -2129,12 +2143,8 @@ static struct sk_buff *cake_dequeue(struct Qdisc *sch)
>                                 else if (flow->set == CAKE_SET_BULK) {
>                                         b->bulk_flow_count--;
>
> -                                       if (cake_dsrc(q->flow_mode))
> -                                               srchost->srchost_bulk_flow_count--;
> -
> -                                       if (cake_ddst(q->flow_mode))
> -                                               dsthost->dsthost_bulk_flow_count--;
> -
> +                                       cake_dec_srchost_bulk_flow_count(b, flow, q->flow_mode);
> +                                       cake_dec_dsthost_bulk_flow_count(b, flow, q->flow_mode);
>                                 } else
>                                         b->decaying_flow_count--;
>
> --
> 2.47.1
>

Acked-by: Dave Taht <dave.taht at gmail.com>



-- 
Dave Täht CSO, LibreQos


More information about the Cake mailing list