1 /*
2 * Copyright (c) 2024 Apple Inc. All rights reserved.
3 *
4 * @APPLE_OSREFERENCE_LICENSE_HEADER_START@
5 *
6 * This file contains Original Code and/or Modifications of Original Code
7 * as defined in and that are subject to the Apple Public Source License
8 * Version 2.0 (the 'License'). You may not use this file except in
9 * compliance with the License. The rights granted to you under the License
10 * may not be used to create, or enable the creation or redistribution of,
11 * unlawful or unlicensed copies of an Apple operating system, or to
12 * circumvent, violate, or enable the circumvention or violation of, any
13 * terms of an Apple operating system software license agreement.
14 *
15 * Please obtain a copy of the License at
16 * http://www.opensource.apple.com/apsl/ and read it before using this file.
17 *
18 * The Original Code and all software distributed under the License are
19 * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER
20 * EXPRESS OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES,
21 * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY,
22 * FITNESS FOR A PARTICULAR PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT.
23 * Please see the License for the specific language governing rights and
24 * limitations under the License.
25 *
26 * @APPLE_OSREFERENCE_LICENSE_HEADER_END@
27 */
28
29 #include "tcp_includes.h"
30
/* Callbacks installed in the tcp_cc_prague algorithm table */
static int tcp_prague_init(struct tcpcb *tp);
static int tcp_prague_cleanup(struct tcpcb *tp);
static void tcp_prague_cwnd_init_or_reset(struct tcpcb *tp);
static void tcp_prague_ack_rcvd(struct tcpcb *tp, struct tcphdr *th);
static void tcp_prague_pre_fr(struct tcpcb *tp);
static void tcp_prague_post_fr(struct tcpcb *tp, struct tcphdr *th);
static void tcp_prague_after_timeout(struct tcpcb *tp);
static int tcp_prague_delay_ack(struct tcpcb *tp, struct tcphdr *th);
static void tcp_prague_process_ecn(struct tcpcb *tp, struct tcphdr *th,
    uint32_t new_bytes_marked, uint32_t packets_marked, uint32_t packets_acked);
static void tcp_prague_set_bytes_acked(struct tcpcb *tp, uint32_t acked);
static void tcp_prague_switch_cc(struct tcpcb *tp);

/* File-local helpers that are used before they are defined */
static void tcp_prague_clear_state(struct tcpcb *tp);
static void prague_update_pacer_state(struct tcpcb *tp);
static void prague_ca_after_ce(struct tcpcb *tp, uint32_t acked);

/* Cube root, needed to derive the CUBIC epoch period K */
extern float cbrtf(float x);
49
50 struct tcp_cc_algo tcp_cc_prague = {
51 .name = "prague",
52 .init = tcp_prague_init,
53 .cleanup = tcp_prague_cleanup,
54 .cwnd_init = tcp_prague_cwnd_init_or_reset,
55 .congestion_avd = tcp_prague_ack_rcvd,
56 .ack_rcvd = tcp_prague_ack_rcvd,
57 .pre_fr = tcp_prague_pre_fr,
58 .post_fr = tcp_prague_post_fr,
59 .after_idle = tcp_prague_cwnd_init_or_reset,
60 .after_timeout = tcp_prague_after_timeout,
61 .delay_ack = tcp_prague_delay_ack,
62 .process_ecn = tcp_prague_process_ecn,
63 .set_bytes_acked = tcp_prague_set_bytes_acked,
64 .switch_to = tcp_prague_switch_cc
65 };
66
/*
 * Fixed-point parameters of the Prague state:
 * 1. alpha is the Exponentially Weighted Moving Average (EWMA) of the
 *    fraction of CE-marked packets, a value in [0,1] scaled by
 *    2^ALPHA_SHIFT.
 * 2. g is the estimation gain, a real number between 0 and 1; we use
 *    1/2^G_SHIFT = 1/16.
 * 3. scaled_alpha is alpha / g, i.e. alpha << G_SHIFT.
 */
#define ALPHA_SHIFT         (20)
#define G_SHIFT             (4)
#define CWND_SHIFT          (20)                    /* scaling applied to the additive-increase alpha */
#define MAX_ALPHA           (1ULL << ALPHA_SHIFT)   /* alpha == 1.0 */
#define REF_RTT_RATE        (25)                    /* 25 ms */

#define BURST_SHIFT         (12)                    /* 1/(2^12) s ~= 244us: burst is the bytes sent in roughly 250us at the pacing rate */

#define PACING_INITIAL_RTT  (100)                   /* 100ms, Only used to calculate startup pacer rate */
#define MSEC_PER_SEC        (1000)                  /* milliseconds per second */
83
/*
 * CUBIC tuning constants. They are only ever read, so they are declared
 * const to keep them out of writable data and to catch accidental writes.
 */
static const float cubic_beta = 0.7f;                     /* multiplicative decrease factor */
static const float cubic_one_sub_beta = 0.3f;             /* 1 - beta */
static const float cubic_one_add_beta = 1.7f;             /* 1 + beta */
static const float cubic_fast_convergence_factor = 0.85f; /* W_max scaling when the window shrank */
static const float cubic_c_factor = 0.4f;                 /* C in W(t) = C(t-K)^3 + W_max */
89
90 /*
91 * Compute the target congestion window for the next RTT according to
92 * cubic equation when an ack is received.
93 *
94 * W(t) = C(t-K)^3 + Wmax
95 */
96 static uint32_t
cubic_target(struct tcpcb * tp,uint32_t srtt)97 cubic_target(struct tcpcb *tp, uint32_t srtt)
98 {
99 struct tcp_globals *globals = tcp_get_globals(tp);
100 float K = 0;
101
102 if (tp->t_ccstate->cubic_epoch_start == 0) {
103 /*
104 * This is the beginning of a new epoch after recovery, initialize
105 * some of the variables that we need to use for computing the
106 * congestion window later.
107 */
108 tp->t_ccstate->cubic_epoch_start = tcp_globals_now(globals);
109 if (tp->t_ccstate->cubic_epoch_start == 0) {
110 tp->t_ccstate->cubic_epoch_start = 1;
111 }
112 if (tp->snd_cwnd < tp->t_ccstate->cubic_W_max) {
113 /*
114 * Compute cubic epoch period, this is the time
115 * period that the window will take to increase to
116 * Wmax again after reduction due to loss.
117 */
118
119 K = ((float)(tp->t_ccstate->cubic_W_max - tp->snd_cwnd)) / cubic_c_factor / tp->t_maxseg;
120 K = cbrtf(K);
121 tp->t_ccstate->cubic_K = K * TCP_RETRANSHZ; /* in milliseconds */
122 tp->t_ccstate->cubic_origin_point = tp->t_ccstate->cubic_W_max;
123 } else {
124 tp->t_ccstate->cubic_K = 0;
125 tp->t_ccstate->cubic_origin_point = tp->t_ccstate->cubic_W_max = tp->snd_cwnd;
126 }
127 }
128
129 if (tp->t_ccstate->cubic_origin_point == 0) {
130 os_log_error(OS_LOG_DEFAULT, "Prague CC: Cubic origin point should be greater than 0");
131 }
132 /*
133 * Compute the target window W(t+RTT) for the next RTT using,
134 * W(t) = C(t-K)^3 + W_max
135 */
136 uint32_t elapsed_time = timer_diff(tcp_globals_now(globals), 0, tp->t_ccstate->cubic_epoch_start, 0);
137 elapsed_time += srtt;
138
139 float var = (elapsed_time - tp->t_ccstate->cubic_K) / TCP_RETRANSHZ;
140 var = var * var * var * cubic_c_factor * tp->t_maxseg;
141
142 int32_t cwnd = (int32_t)((float)tp->t_ccstate->cubic_origin_point + var);
143 uint32_t w_cubic_next = cwnd > 0 ? (uint32_t)cwnd : 0;
144
145 /* Apply the lower and upper bound */
146 if (w_cubic_next < tp->snd_cwnd) {
147 /* Make sure that cwnd is non-decreasing */
148 w_cubic_next = tp->snd_cwnd;
149 } else if (w_cubic_next > (uint32_t)(1.5 * tp->snd_cwnd)) {
150 w_cubic_next = (uint32_t)(1.5 * tp->snd_cwnd);
151 }
152
153 return w_cubic_next;
154 }
155
156 static uint32_t
reno_friendly_win(struct tcpcb * tp,struct tcphdr * th)157 reno_friendly_win(struct tcpcb *tp, struct tcphdr *th)
158 {
159 if (tp->t_ccstate->reno_cwnd == 0) {
160 /*
161 * Start of the epoch, we set the reno_cwnd to cwnd at
162 * the start of the congestion avoidance stage
163 */
164 tp->t_ccstate->reno_cwnd = tp->snd_cwnd;
165 tp->t_ccstate->reno_acked = BYTES_ACKED(th, tp);
166 }
167 tp->t_ccstate->reno_acked += BYTES_ACKED(th, tp);
168
169 /*
170 * Increase by ai_factor * MSS, once per RTT. Counting bytes_acked
171 * against the snd_cwnd represents exactly one RTT at full rate.
172 */
173 while (tp->t_ccstate->reno_acked >= tp->snd_cwnd) {
174 /* Enough bytes have been ACK'd for TCP to do AIMD */
175 tp->t_ccstate->reno_acked -= tp->snd_cwnd;
176
177 /*
178 * Compute Reno Friendly window using,
179 * W_est(t) = W_max*ß + [3*(1-ß)/(1+ß)] * (bytes_acked/reno_cwnd)
180 */
181 float alpha_aimd = 0;
182 if (tp->t_ccstate->reno_cwnd < tp->t_ccstate->cubic_W_max) {
183 alpha_aimd = (float)3 * (cubic_one_sub_beta) / (cubic_one_add_beta);
184 } else {
185 alpha_aimd = 1;
186 }
187
188 tp->t_ccstate->reno_cwnd += (uint32_t)(tp->t_maxseg * alpha_aimd);
189 }
190
191 return tp->t_ccstate->reno_cwnd;
192 }
193
194 static void
tcp_cubic_ack_ca(struct tcpcb * tp,struct tcphdr * th,uint32_t acked)195 tcp_cubic_ack_ca(struct tcpcb *tp, struct tcphdr *th, uint32_t acked)
196 {
197 uint32_t cubic_target_win;
198 uint64_t incr_bytes_acked = 0;
199
200 /* Do not increase congestion window in non-validated phase */
201 if (tcp_cc_is_cwnd_nonvalidated(tp) != 0) {
202 return;
203 }
204
205 tp->t_bytes_acked += acked;
206 uint32_t srtt = tp->t_srtt >> TCP_RTT_SHIFT;
207 /*
208 * First compute cubic target window as given in RFC 9438 Section 4.2
209 */
210 cubic_target_win = cubic_target(tp, srtt);
211
212 /* Compute Reno-friendly window */
213 uint32_t reno_win = reno_friendly_win(tp, th);
214 reno_win = tcp_round_to(reno_win, tp->t_maxseg);
215
216 if (tp->snd_cwnd < cubic_target_win) {
217 /*
218 * Compute how many bytes need to be acknowledged
219 * before we can increase the cwnd by one segment.
220 * It is given by cwnd / (target - cwnd)
221 */
222 incr_bytes_acked = (uint64_t)tp->snd_cwnd * tp->t_maxseg;
223 incr_bytes_acked /= (cubic_target_win - tp->snd_cwnd);
224 }
225
226 if (tp->snd_cwnd < reno_win) {
227 uint64_t incr_reno_bytes_acked = (uint64_t)tp->snd_cwnd * tp->t_maxseg;
228 incr_reno_bytes_acked /= (reno_win - tp->snd_cwnd);
229
230 if (incr_reno_bytes_acked < incr_bytes_acked) {
231 /*
232 * Cubic is in Reno-friendly region as less bytes
233 * are needed to increase in Reno way.
234 */
235 incr_bytes_acked = incr_reno_bytes_acked;
236 }
237 }
238
239 if (incr_bytes_acked > 0 && tp->t_bytes_acked >= incr_bytes_acked) {
240 tp->t_bytes_acked -= incr_bytes_acked;
241 tp->snd_cwnd = min(tp->snd_cwnd + tp->t_maxseg, TCP_MAXWIN << tp->snd_scale);
242 }
243 }
244
245 static void
tcp_prague_ack_rcvd(struct tcpcb * tp,struct tcphdr * th)246 tcp_prague_ack_rcvd(struct tcpcb *tp, struct tcphdr *th)
247 {
248 /* Do not increase the congestion window in non-validated phase */
249 if (tcp_cc_is_cwnd_nonvalidated(tp) != 0) {
250 return;
251 }
252
253 uint32_t acked = tp->t_ccstate->prague_bytes_acked;
254
255 if (acked == 0) {
256 return;
257 }
258
259 if (tp->snd_cwnd >= tp->snd_ssthresh) {
260 /* Congestion avoidance phase */
261 if (tp->t_ccstate->reduced_due_to_ce) {
262 prague_ca_after_ce(tp, acked);
263 } else {
264 tcp_cubic_ack_ca(tp, th, acked);
265 }
266 } else {
267 /*
268 * Maximum burst-size is limited to the initial congestion-window.
269 * We know that the network can survive this kind of burst.
270 */
271 uint32_t abc_lim = tcp_initial_cwnd(tp);
272 uint32_t incr = min(acked, abc_lim);
273 tp->snd_cwnd += incr;
274 }
275
276 tp->snd_cwnd = tcp_round_to(tp->snd_cwnd, tp->t_maxseg);
277 if (tp->snd_cwnd < 2 * tp->t_maxseg) {
278 tp->snd_cwnd = 2 * tp->t_maxseg;
279 }
280
281 prague_update_pacer_state(tp);
282 }
283
284 static void
tcp_prague_pre_fr(struct tcpcb * tp)285 tcp_prague_pre_fr(struct tcpcb *tp)
286 {
287 tp->t_ccstate->cubic_epoch_start = 0;
288
289 uint32_t win = tp->snd_cwnd;
290 if (tp->t_flagsext & TF_CWND_NONVALIDATED) {
291 tp->t_lossflightsize = tcp_flight_size(tp);
292 win = max(tp->t_pipeack, tp->t_lossflightsize);
293 } else {
294 tp->t_lossflightsize = 0;
295 }
296 /*
297 * Note the congestion window at which packet loss occurred as
298 * cub_W_max.
299 *
300 * If the current congestion window is less than the last Wmax when
301 * loss occurred, it indicates that capacity available in the
302 * network has gone down. This can happen if a new flow has started
303 * and it is capturing some of the bandwidth. To reach convergence
304 * quickly, backoff a little more.
305 */
306 if (win < tp->t_ccstate->cubic_W_max) {
307 tp->t_ccstate->cubic_W_max = (uint32_t)((float)win * cubic_fast_convergence_factor);
308 } else {
309 tp->t_ccstate->cubic_W_max = win;
310 }
311
312 /* Calculate congestion window by multiplying with beta factor */
313 win = (uint32_t)(win * cubic_beta);
314 win = tcp_round_to(win, tp->t_maxseg);
315 if (win < 2 * tp->t_maxseg) {
316 win = 2 * tp->t_maxseg;
317 }
318 tp->snd_ssthresh = win;
319 tcp_cc_resize_sndbuf(tp);
320
321 tp->t_ccstate->num_cong_events_loss++;
322 tp->t_ccstate->in_loss = 1;
323 tp->t_ccstate->reduced_due_to_ce = false;
324 }
325
326 static void
tcp_prague_post_fr(struct tcpcb * tp,__unused struct tcphdr * th)327 tcp_prague_post_fr(struct tcpcb *tp, __unused struct tcphdr *th)
328 {
329 /*
330 * Since we do cwnd validation during pre_fr,
331 * we can safely set cwnd to ssthresh.
332 */
333 tp->snd_cwnd = tp->snd_ssthresh;
334
335 prague_update_pacer_state(tp);
336
337 tp->t_ccstate->reno_cwnd = 0;
338 tp->t_ccstate->reno_acked = 0;
339
340 tp->t_ccstate->in_loss = 0;
341 }
342
/*
 * Tell whether a round trip has passed since largest_snd_nxt was
 * recorded.
 *
 * largest_snd_nxt: highest sequence sent when the round started; 0 is
 *                  treated as "round already over".
 * ack:             acknowledgment number of the ACK being processed.
 *
 * Returns true when ack is beyond largest_snd_nxt (in sequence space),
 * i.e. a packet sent after the start of the round has been acknowledged.
 */
static bool
rtt_elapsed(uint32_t largest_snd_nxt, uint32_t ack)
{
	if (largest_snd_nxt == 0) {
		return true;
	}

	return SEQ_GT(ack, largest_snd_nxt);
}
353
354 static void
prague_update_pacer_state(struct tcpcb * tp)355 prague_update_pacer_state(struct tcpcb *tp)
356 {
357 uint32_t srtt = tp->t_srtt >> TCP_RTT_SHIFT;
358 if (srtt == 0) {
359 srtt = PACING_INITIAL_RTT;
360 }
361
362 uint64_t rate = tp->snd_cwnd;
363
364 /* Use 200% rate when in slow start */
365 if (tp->snd_cwnd < tp->snd_ssthresh) {
366 rate *= 2;
367 }
368
369 /* Multiply by MSEC_PER_SEC as srtt is in milliseconds */
370 rate *= MSEC_PER_SEC;
371 rate = rate / srtt;
372
373 uint32_t burst = (uint32_t)(rate >> BURST_SHIFT);
374
375 tp->t_pacer.rate = rate;
376 tp->t_pacer.tso_burst_size = burst;
377 }
378
379 /*
380 * RTT independence using square of RTT ratio to acheive rate fairness.
381 * For additive increase, alpha = (RTT / REF_RTT) ^ 2
382 * 1. Grow by 1MSS per target RTT; i.e. by a ratio of actual / target RTT and
383 * 2. Growing by 1MSS translates to a sending rate increase proportional to the
384 * same RTT ratio.
385 *
386 * Update infrequently whenever a change in srtt is expected.
387 */
388 static void
prague_ai_alpha_rate(struct tcpcb * tp)389 prague_ai_alpha_rate(struct tcpcb *tp)
390 {
391 uint32_t srtt = tp->t_srtt >> TCP_RTT_SHIFT;
392 if (srtt > REF_RTT_RATE) {
393 tp->t_ccstate->prague_alpha_ai = (1 << CWND_SHIFT);
394 return;
395 }
396
397 uint64_t numer = srtt << CWND_SHIFT;
398 numer *= srtt;
399 uint64_t divisor = REF_RTT_RATE * REF_RTT_RATE;
400
401 tp->t_ccstate->prague_alpha_ai = (uint64_t)((numer + (divisor >> 1)) / divisor);
402 }
403
404 /*
405 * Handle an ACK in congestion avoidance phase
406 * after the decrease happened due to CE
407 */
408 static void
prague_ca_after_ce(struct tcpcb * tp,uint32_t acked)409 prague_ca_after_ce(struct tcpcb *tp, uint32_t acked)
410 {
411 tp->t_bytes_acked += acked;
412 /*
413 * To increase cwnd by 1MSS, we need cwnd / alpha_ai bytes
414 * to be acknowledged. Scale cwnd by CWND_SHIFT as alpha_ai
415 * is already scaled to avoid floating point arithmetic
416 */
417 uint64_t bytes_needed_for_increase = (uint64_t)tp->snd_cwnd << CWND_SHIFT;
418 bytes_needed_for_increase /= tp->t_ccstate->prague_alpha_ai;
419
420 if (tp->t_bytes_acked >= bytes_needed_for_increase) {
421 tp->t_bytes_acked -= bytes_needed_for_increase;
422 tp->snd_cwnd += tp->t_maxseg;
423 }
424 }
425
426 static void
prague_update_alpha(struct tcpcb * tp,uint32_t ack,uint32_t packets_marked,uint32_t packets_acked)427 prague_update_alpha(struct tcpcb *tp, uint32_t ack, uint32_t packets_marked,
428 uint32_t packets_acked)
429 {
430 if (!rtt_elapsed(tp->t_ccstate->snd_nxt_alpha, ack)) {
431 /* One RTT hasn't elapsed yet, don't update alpha */
432 os_log(OS_LOG_DEFAULT, "one RTT hasn't elapsed, not updating alpha");
433 return;
434 }
435
436 if (!tp->t_ccstate->ever_saw_ce) {
437 return;
438 }
439
440 uint32_t newly_marked = 0, newly_acked = 0;
441
442 if (packets_marked > tp->t_ccstate->prague_packets_marked) {
443 newly_marked = packets_marked - tp->t_ccstate->prague_packets_marked;
444 }
445
446 if (packets_acked > tp->t_ccstate->prague_packets_acked) {
447 newly_acked = packets_acked - tp->t_ccstate->prague_packets_acked;
448 } else {
449 os_log_error(OS_LOG_DEFAULT,
450 "No new packets were ACK'ed, we shouldn't be called");
451 return;
452 }
453
454 uint64_t scaled_alpha = tp->t_ccstate->prague_scaled_alpha;
455
456 /*
457 * We currently don't react to local AQM for TCP Prague
458 */
459 uint64_t p = (newly_marked << ALPHA_SHIFT) / newly_acked;
460 /*
461 * Equation for alpha,
462 * alpha = (1 - g) * alpha + g * F (fraction of marked / acked)
463 * alpha = alpha - (alpha >> g_shift) + (marked << (alpha_shift -
464 * g_shift)) / acked, OR
465 * scaled_alpha = scaled_alpha - (scaled_alpha >> g_shift) +
466 * (marked << alpha_shift) / acked
467 */
468 scaled_alpha = scaled_alpha - (scaled_alpha >> G_SHIFT) + p;
469 tp->t_ccstate->prague_scaled_alpha = MIN(MAX_ALPHA << G_SHIFT, scaled_alpha);
470
471 /* New round for alpha */
472 tp->t_ccstate->snd_nxt_alpha = tp->snd_nxt;
473 tp->t_ccstate->prague_packets_marked = packets_marked;
474 tp->t_ccstate->prague_packets_acked = packets_acked;
475 }
476
477 static bool
prague_cwr(struct tcpcb * tp)478 prague_cwr(struct tcpcb *tp)
479 {
480 // If we are currently in loss recovery, then do nothing
481 if (tp->t_ccstate->in_loss) {
482 os_log(OS_LOG_DEFAULT, "currently in loss recovery, no need to do CWR");
483 return false;
484 }
485
486 tp->t_ccstate->num_cong_events_ce++;
487 const uint64_t alpha = tp->t_ccstate->prague_scaled_alpha >> G_SHIFT;
488
489 /*
490 * For Prague, the recovery time is only set during packet
491 * loss and we allow any ACKs that don't have CE marks to
492 * increase cwnd during ack_end, even in CWR state.
493 *
494 * On entering CWR, cwnd = cwnd * (1 - DCTCP.alpha) / 2
495 */
496 uint64_t reduction =
497 (tp->snd_cwnd * alpha) >> (ALPHA_SHIFT + 1);
498 tp->snd_cwnd -= reduction;
499
500 /* If no more increase due to non-CE acked bytes, then round it */
501 if (tp->t_ccstate->prague_bytes_acked == 0) {
502 tp->snd_cwnd = tcp_round_to(tp->snd_cwnd, tp->t_maxseg);
503 }
504 /* Should be at least 2 MSS */
505 if (tp->snd_cwnd < 2 * tp->t_maxseg) {
506 tp->snd_cwnd = 2 * tp->t_maxseg;
507 }
508
509 tp->snd_ssthresh = tp->snd_cwnd;
510
511 tp->t_ccstate->reduced_due_to_ce = true;
512
513 return true;
514 }
515
516 static void
tcp_prague_process_ecn(struct tcpcb * tp,struct tcphdr * th,uint32_t new_bytes_marked,uint32_t packets_marked,uint32_t packets_acked)517 tcp_prague_process_ecn(struct tcpcb *tp, struct tcphdr *th, uint32_t new_bytes_marked,
518 uint32_t packets_marked, uint32_t packets_acked)
519 {
520 if (__improbable(packets_marked < tp->t_ccstate->prague_ce_counter ||
521 packets_acked < tp->t_ccstate->prague_packets_acked)) {
522 os_log_error(OS_LOG_DEFAULT, "new CE count (%u) can't be less than current CE count (%u)"
523 "OR newly ACKed (%u) can't be less that current ACKed (%u)",
524 packets_marked, tp->t_ccstate->prague_ce_counter,
525 packets_acked, tp->t_ccstate->prague_packets_acked);
526 }
527
528 if (packets_marked > tp->t_ccstate->prague_ce_counter) {
529 tp->t_ccstate->ever_saw_ce = true;
530 }
531 /*
532 * update alpha of fraction of marked packets,
533 * even when there are no new CE counts
534 */
535 if (packets_acked > tp->t_ccstate->prague_packets_acked) {
536 prague_update_alpha(tp, th->th_ack, packets_marked, packets_acked);
537 }
538
539 if (packets_marked == tp->t_ccstate->prague_ce_counter) {
540 /* No change in CE */
541 return;
542 }
543
544 os_log(OS_LOG_DEFAULT, "%u packets were newly CE marked",
545 packets_marked - tp->t_ccstate->prague_ce_counter);
546 /*
547 * Received an ACK with new CE counts, subtract CE marked bytes
548 * from bytes_acked, so that we use only unmarked bytes to
549 * increase cwnd during ACK processing
550 */
551 if (tp->t_ccstate->prague_bytes_acked > new_bytes_marked) {
552 tp->t_ccstate->prague_bytes_acked -= new_bytes_marked;
553 } else {
554 tp->t_ccstate->prague_bytes_acked = 0;
555 }
556
557 /* Update CE count even if we are already in CWR */
558 tp->t_ccstate->prague_ce_counter = packets_marked;
559
560 /* Update AIMD alpha as SRTT might have changed */
561 prague_ai_alpha_rate(tp);
562
563 if (!rtt_elapsed(tp->t_ccstate->snd_nxt_cwr, th->th_ack)) {
564 /* One RTT hasn't elapsed yet, don't doing CWR */
565 os_log(OS_LOG_DEFAULT, "one RTT hasn't elapsed, not doing CWR");
566 return;
567 }
568
569 /* CWR reduction if new counts are received */
570 bool cwnd_changed = prague_cwr(tp);
571
572 /* Update pacer state if cwnd has changed */
573 if (cwnd_changed) {
574 prague_update_pacer_state(tp);
575 }
576 /* New round for CWR */
577 tp->t_ccstate->snd_nxt_cwr = tp->snd_nxt;
578 }
579
580 static void
tcp_prague_set_bytes_acked(struct tcpcb * tp,uint32_t acked)581 tcp_prague_set_bytes_acked(struct tcpcb *tp, uint32_t acked)
582 {
583 /* Set bytes_acked which will be used later during ack_rcvd() */
584 tp->t_ccstate->prague_bytes_acked = acked;
585 }
586
587 static void
tcp_prague_clear_state(struct tcpcb * tp)588 tcp_prague_clear_state(struct tcpcb *tp)
589 {
590 tp->snd_cwnd_prev = 0;
591 tp->t_ccstate->num_cong_events_loss = 0;
592 tp->t_ccstate->num_cong_events_ce = 0;
593 tp->t_ccstate->prague_alpha_ai = (1 << CWND_SHIFT);
594
595 /* CUBIC state */
596 tp->t_ccstate->cubic_K = 0;
597 //prague->cubic_acked = 0;
598 tp->t_ccstate->cubic_epoch_start = 0;
599 tp->t_ccstate->cubic_origin_point = 0;
600 tp->t_ccstate->cubic_W_max = 0;
601 }
602
603 int
tcp_prague_init(struct tcpcb * tp)604 tcp_prague_init(struct tcpcb *tp)
605 {
606 os_atomic_inc(&tcp_cc_prague.num_sockets, relaxed);
607
608 VERIFY(tp->t_ccstate != NULL);
609
610 tp->t_ccstate->prague_scaled_alpha = (MAX_ALPHA << G_SHIFT);
611 tcp_prague_clear_state(tp);
612 return 0;
613 }
614
615 int
tcp_prague_cleanup(struct tcpcb * tp)616 tcp_prague_cleanup(struct tcpcb *tp)
617 {
618 #pragma unused(tp)
619 os_atomic_dec(&tcp_cc_prague.num_sockets, relaxed);
620 return 0;
621 }
622
623 /*
624 * Initialize the congestion window for a connection
625 */
626 void
tcp_prague_cwnd_init_or_reset(struct tcpcb * tp)627 tcp_prague_cwnd_init_or_reset(struct tcpcb *tp)
628 {
629 VERIFY(tp->t_ccstate != NULL);
630
631 tcp_prague_clear_state(tp);
632 tcp_cc_cwnd_init_or_reset(tp);
633 tp->t_pipeack = 0;
634 tcp_clear_pipeack_state(tp);
635
636 /* Start counting bytes for RFC 3465 again */
637 tp->t_bytes_acked = 0;
638
639 /*
640 * slow start threshold could get initialized to a lower value
641 * when there is a cached value in the route metrics. In this case,
642 * the connection can enter congestion avoidance without any packet
643 * loss and Cubic will enter steady-state too early. It is better
644 * to always probe to find the initial slow-start threshold.
645 */
646 if (tp->t_inpcb->inp_stat->txbytes <= tcp_initial_cwnd(tp) &&
647 tp->snd_ssthresh < (TCP_MAXWIN << TCP_MAX_WINSHIFT)) {
648 tp->snd_ssthresh = TCP_MAXWIN << TCP_MAX_WINSHIFT;
649 }
650
651 /* Initialize cubic last max to be same as ssthresh */
652 tp->t_ccstate->cubic_W_max = tp->snd_ssthresh;
653
654 /* Set initial pacer state */
655 uint64_t startup_rate =
656 tp->snd_cwnd * MSEC_PER_SEC / PACING_INITIAL_RTT;
657 uint32_t startup_burst_size = tp->t_maxseg;
658 tp->t_pacer.rate = startup_rate;
659 tp->t_pacer.tso_burst_size = startup_burst_size;
660 }
661
662 static void
tcp_prague_after_timeout(struct tcpcb * tp)663 tcp_prague_after_timeout(struct tcpcb *tp)
664 {
665 VERIFY(tp->t_ccstate != NULL);
666
667 /*
668 * Avoid adjusting congestion window due to SYN retransmissions.
669 * If more than one byte (SYN) is outstanding then it is still
670 * needed to adjust the window.
671 */
672 if (tp->t_state < TCPS_ESTABLISHED &&
673 ((int)(tp->snd_max - tp->snd_una) <= 1)) {
674 return;
675 }
676
677 if (!IN_FASTRECOVERY(tp)) {
678 tcp_prague_clear_state(tp);
679 tcp_prague_pre_fr(tp);
680 }
681
682 /*
683 * Close the congestion window down to one segment as a retransmit
684 * timeout might indicate severe congestion.
685 */
686 tp->snd_cwnd = tp->t_maxseg;
687 }
688
/*
 * Delayed-ACK decision; Prague defers entirely to the common
 * tcp_cc_delay_ack() and returns its result.
 */
static int
tcp_prague_delay_ack(struct tcpcb *tp, struct tcphdr *th)
{
	const int verdict = tcp_cc_delay_ack(tp, th);

	return verdict;
}
694
695 /*
696 * When switching from a different CC it is better for Cubic to start
697 * fresh. The state required for Cubic calculation might be stale and it
698 * might not represent the current state of the network. If it starts as
699 * a new connection it will probe and learn the existing network conditions.
700 */
701 static void
tcp_prague_switch_cc(struct tcpcb * tp)702 tcp_prague_switch_cc(struct tcpcb *tp)
703 {
704 tcp_prague_cwnd_init_or_reset(tp);
705
706 os_atomic_inc(&tcp_cc_prague.num_sockets, relaxed);
707 }
708