xref: /xnu-12377.61.12/bsd/netinet/tcp_prague.c (revision 4d495c6e23c53686cf65f45067f79024cf5dcee8)
1 /*
2  * Copyright (c) 2024 Apple Inc. All rights reserved.
3  *
4  * @APPLE_OSREFERENCE_LICENSE_HEADER_START@
5  *
6  * This file contains Original Code and/or Modifications of Original Code
7  * as defined in and that are subject to the Apple Public Source License
8  * Version 2.0 (the 'License'). You may not use this file except in
9  * compliance with the License. The rights granted to you under the License
10  * may not be used to create, or enable the creation or redistribution of,
11  * unlawful or unlicensed copies of an Apple operating system, or to
12  * circumvent, violate, or enable the circumvention or violation of, any
13  * terms of an Apple operating system software license agreement.
14  *
15  * Please obtain a copy of the License at
16  * http://www.opensource.apple.com/apsl/ and read it before using this file.
17  *
18  * The Original Code and all software distributed under the License are
19  * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER
20  * EXPRESS OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES,
21  * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY,
22  * FITNESS FOR A PARTICULAR PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT.
23  * Please see the License for the specific language governing rights and
24  * limitations under the License.
25  *
26  * @APPLE_OSREFERENCE_LICENSE_HEADER_END@
27  */
28 
29  #include "tcp_includes.h"
30 
/* Per-connection lifecycle hooks. */
static int tcp_prague_init(struct tcpcb *tp);
static int tcp_prague_cleanup(struct tcpcb *tp);
static void tcp_prague_switch_cc(struct tcpcb *tp);
static void tcp_prague_clear_state(struct tcpcb *tp);
static void tcp_prague_cwnd_init_or_reset(struct tcpcb *tp);

/* ACK-clocked growth and loss-recovery hooks. */
static void tcp_prague_ack_rcvd(struct tcpcb *tp, struct tcphdr *th);
static void tcp_prague_pre_fr(struct tcpcb *tp);
static void tcp_prague_post_fr(struct tcpcb *tp, struct tcphdr *th);
static void tcp_prague_after_timeout(struct tcpcb *tp);
static int tcp_prague_delay_ack(struct tcpcb *tp, struct tcphdr *th);

/* ECN (CE-mark) feedback hooks. */
static void tcp_prague_process_ecn(struct tcpcb *tp, struct tcphdr *th,
    uint32_t new_bytes_marked, uint32_t packets_marked, uint32_t packets_acked);
static void tcp_prague_set_bytes_acked(struct tcpcb *tp, uint32_t acked);

/* Internal helper referenced before its definition. */
static void prague_ca_after_ce(struct tcpcb *tp, uint32_t acked);

/* Cube root; used to compute the CUBIC epoch period K. */
extern float cbrtf(float x);
48 
49 struct tcp_cc_algo tcp_cc_prague = {
50 	.name = "prague",
51 	.init = tcp_prague_init,
52 	.cleanup = tcp_prague_cleanup,
53 	.cwnd_init = tcp_prague_cwnd_init_or_reset,
54 	.congestion_avd = tcp_prague_ack_rcvd,
55 	.ack_rcvd = tcp_prague_ack_rcvd,
56 	.pre_fr = tcp_prague_pre_fr,
57 	.post_fr = tcp_prague_post_fr,
58 	.after_idle = tcp_prague_cwnd_init_or_reset,
59 	.after_timeout = tcp_prague_after_timeout,
60 	.delay_ack = tcp_prague_delay_ack,
61 	.process_ecn = tcp_prague_process_ecn,
62 	.set_bytes_acked = tcp_prague_set_bytes_acked,
63 	.switch_to = tcp_prague_switch_cc
64 };
65 
/*
 * Prague state includes:
 * 1. alpha: an Exponentially Weighted Moving Average (EWMA) of the fraction
 *    of CE-marked packets, in [0, 1]. It is kept in fixed point with
 *    ALPHA_SHIFT fractional bits, so MAX_ALPHA represents 1.0.
 * 2. g: the EWMA estimation gain, 1 / 2^G_SHIFT (= 1/16).
 * 3. scaled_alpha: alpha / g, i.e. alpha << G_SHIFT. Keeping alpha
 *    pre-scaled lets the EWMA update be done with a shift, a subtraction
 *    and an addition.
 */
#define ALPHA_SHIFT                 (20)    /* fixed-point fraction bits of alpha */
#define G_SHIFT                     (4)     /* g = 1 / 2^G_SHIFT */
#define CWND_SHIFT                  (20)    /* fixed-point fraction bits of prague_alpha_ai */
#define MAX_ALPHA                   (1ULL << ALPHA_SHIFT)   /* alpha == 1.0 */
#define REF_RTT_RATE                (25)    /* reference RTT for rate fairness, 25 ms */

/*
 * CUBIC parameters (RFC 9438). These are never modified, so they are const.
 *   beta                     multiplicative decrease factor
 *   1 - beta, 1 + beta       used by the Reno-friendly alpha_aimd
 *   fast convergence factor  (1 + beta) / 2, applied to W_max
 *   C                        cubic scaling constant
 */
static const float cubic_beta = 0.7f;
static const float cubic_one_sub_beta = 0.3f;
static const float cubic_one_add_beta = 1.7f;
static const float cubic_fast_convergence_factor = 0.85f;
static const float cubic_c_factor = 0.4f;
83 
84 /*
85  * Compute the target congestion window for the next RTT according to
86  * cubic equation when an ack is received.
87  *
88  * W(t) = C(t-K)^3 + Wmax
89  */
90 static uint32_t
cubic_target(struct tcpcb * tp,uint32_t srtt)91 cubic_target(struct tcpcb *tp, uint32_t srtt)
92 {
93 	struct tcp_globals *globals = tcp_get_globals(tp);
94 	float K = 0;
95 
96 	if (tp->t_ccstate->cubic_epoch_start == 0) {
97 		/*
98 		 * This is the beginning of a new epoch after recovery, initialize
99 		 * some of the variables that we need to use for computing the
100 		 * congestion window later.
101 		 */
102 		tp->t_ccstate->cubic_epoch_start = tcp_globals_now(globals);
103 		if (tp->t_ccstate->cubic_epoch_start == 0) {
104 			tp->t_ccstate->cubic_epoch_start = 1;
105 		}
106 		if (tp->snd_cwnd < tp->t_ccstate->cubic_W_max) {
107 			/*
108 			 * Compute cubic epoch period, this is the time
109 			 * period that the window will take to increase to
110 			 * Wmax again after reduction due to loss.
111 			 */
112 
113 			K = ((float)(tp->t_ccstate->cubic_W_max - tp->snd_cwnd)) / cubic_c_factor / tp->t_maxseg;
114 			K = cbrtf(K);
115 			tp->t_ccstate->cubic_K = K * TCP_RETRANSHZ; /* in milliseconds */
116 			tp->t_ccstate->cubic_origin_point = tp->t_ccstate->cubic_W_max;
117 		} else {
118 			tp->t_ccstate->cubic_K = 0;
119 			tp->t_ccstate->cubic_origin_point = tp->t_ccstate->cubic_W_max = tp->snd_cwnd;
120 		}
121 	}
122 
123 	if (tp->t_ccstate->cubic_origin_point == 0) {
124 		os_log_error(OS_LOG_DEFAULT, "Prague CC: Cubic origin point should be greater than 0");
125 	}
126 	/*
127 	 * Compute the target window W(t+RTT) for the next RTT using,
128 	 * W(t) = C(t-K)^3 + W_max
129 	 */
130 	uint32_t elapsed_time = timer_diff(tcp_globals_now(globals), 0, tp->t_ccstate->cubic_epoch_start, 0);
131 	elapsed_time += srtt;
132 
133 	float var = (elapsed_time - tp->t_ccstate->cubic_K) / TCP_RETRANSHZ;
134 	var = var * var * var * cubic_c_factor * tp->t_maxseg;
135 
136 	int32_t cwnd = (int32_t)((float)tp->t_ccstate->cubic_origin_point + var);
137 	uint32_t w_cubic_next = cwnd > 0 ? (uint32_t)cwnd : 0;
138 
139 	/* Apply the lower and upper bound */
140 	if (w_cubic_next < tp->snd_cwnd) {
141 		/* Make sure that cwnd is non-decreasing */
142 		w_cubic_next = tp->snd_cwnd;
143 	} else if (w_cubic_next > (uint32_t)(1.5 * tp->snd_cwnd)) {
144 		w_cubic_next = (uint32_t)(1.5 * tp->snd_cwnd);
145 	}
146 
147 	return w_cubic_next;
148 }
149 
150 static uint32_t
reno_friendly_win(struct tcpcb * tp,struct tcphdr * th)151 reno_friendly_win(struct tcpcb *tp, struct tcphdr *th)
152 {
153 	if (tp->t_ccstate->reno_cwnd == 0) {
154 		/*
155 		 * Start of the epoch, we set the reno_cwnd to cwnd at
156 		 * the start of the congestion avoidance stage
157 		 */
158 		tp->t_ccstate->reno_cwnd = tp->snd_cwnd;
159 		tp->t_ccstate->reno_acked = BYTES_ACKED(th, tp);
160 	}
161 	tp->t_ccstate->reno_acked += BYTES_ACKED(th, tp);
162 
163 	/*
164 	 * Increase by ai_factor * MSS, once per RTT. Counting bytes_acked
165 	 * against the snd_cwnd represents exactly one RTT at full rate.
166 	 */
167 	while (tp->t_ccstate->reno_acked >= tp->snd_cwnd) {
168 		/* Enough bytes have been ACK'd for TCP to do AIMD */
169 		tp->t_ccstate->reno_acked -= tp->snd_cwnd;
170 
171 		/*
172 		 * Compute Reno Friendly window using,
173 		 *  W_est(t) = W_max*ß + [3*(1-ß)/(1+ß)] * (bytes_acked/reno_cwnd)
174 		 */
175 		float alpha_aimd = 0;
176 		if (tp->t_ccstate->reno_cwnd < tp->t_ccstate->cubic_W_max) {
177 			alpha_aimd = (float)3 * (cubic_one_sub_beta) / (cubic_one_add_beta);
178 		} else {
179 			alpha_aimd = 1;
180 		}
181 
182 		tp->t_ccstate->reno_cwnd += (uint32_t)(tp->t_maxseg * alpha_aimd);
183 	}
184 
185 	return tp->t_ccstate->reno_cwnd;
186 }
187 
188 static void
tcp_cubic_ack_ca(struct tcpcb * tp,struct tcphdr * th,uint32_t acked)189 tcp_cubic_ack_ca(struct tcpcb *tp, struct tcphdr *th, uint32_t acked)
190 {
191 	uint32_t cubic_target_win;
192 	uint64_t incr_bytes_acked = 0;
193 
194 	/* Do not increase congestion window in non-validated phase */
195 	if (tcp_cc_is_cwnd_nonvalidated(tp) != 0) {
196 		return;
197 	}
198 
199 	tp->t_bytes_acked += acked;
200 	uint32_t srtt = tp->t_srtt >> TCP_RTT_SHIFT;
201 	/*
202 	 * First compute cubic target window as given in RFC 9438 Section 4.2
203 	 */
204 	cubic_target_win = cubic_target(tp, srtt);
205 
206 	/* Compute Reno-friendly window */
207 	uint32_t reno_win = reno_friendly_win(tp, th);
208 	reno_win = tcp_round_to(reno_win, tp->t_maxseg);
209 
210 	if (tp->snd_cwnd < cubic_target_win) {
211 		/*
212 		 * Compute how many bytes need to be acknowledged
213 		 * before we can increase the cwnd by one segment.
214 		 * It is given by cwnd / (target - cwnd)
215 		 */
216 		incr_bytes_acked = (uint64_t)tp->snd_cwnd * tp->t_maxseg;
217 		incr_bytes_acked /= (cubic_target_win - tp->snd_cwnd);
218 	}
219 
220 	if (tp->snd_cwnd < reno_win) {
221 		uint64_t incr_reno_bytes_acked = (uint64_t)tp->snd_cwnd * tp->t_maxseg;
222 		incr_reno_bytes_acked /= (reno_win - tp->snd_cwnd);
223 
224 		if (incr_reno_bytes_acked < incr_bytes_acked) {
225 			/*
226 			 * Cubic is in Reno-friendly region as less bytes
227 			 * are needed to increase in Reno way.
228 			 */
229 			incr_bytes_acked = incr_reno_bytes_acked;
230 		}
231 	}
232 
233 	if (incr_bytes_acked > 0 && tp->t_bytes_acked >= incr_bytes_acked) {
234 		tp->t_bytes_acked -= incr_bytes_acked;
235 		tp->snd_cwnd = min(tp->snd_cwnd + tp->t_maxseg, TCP_MAXWIN << tp->snd_scale);
236 	}
237 }
238 
239 static void
tcp_prague_ack_rcvd(struct tcpcb * tp,struct tcphdr * th)240 tcp_prague_ack_rcvd(struct tcpcb *tp, struct tcphdr *th)
241 {
242 	/* Do not increase the congestion window in non-validated phase */
243 	if (tcp_cc_is_cwnd_nonvalidated(tp) != 0) {
244 		return;
245 	}
246 
247 	uint32_t acked = tp->t_ccstate->prague_bytes_acked;
248 
249 	if (acked == 0) {
250 		return;
251 	}
252 
253 	if (tp->snd_cwnd >= tp->snd_ssthresh) {
254 		/* Congestion avoidance phase */
255 		if (tp->t_ccstate->reduced_due_to_ce) {
256 			prague_ca_after_ce(tp, acked);
257 		} else {
258 			tcp_cubic_ack_ca(tp, th, acked);
259 		}
260 	} else {
261 		/*
262 		 * Maximum burst-size is limited to the initial congestion-window.
263 		 * We know that the network can survive this kind of burst.
264 		 */
265 		uint32_t abc_lim = tcp_initial_cwnd(tp);
266 		uint32_t incr = min(acked, abc_lim);
267 		tp->snd_cwnd += incr;
268 	}
269 
270 	tp->snd_cwnd = tcp_round_to(tp->snd_cwnd, tp->t_maxseg);
271 	if (tp->snd_cwnd < 2 * tp->t_maxseg) {
272 		tp->snd_cwnd =  2 * tp->t_maxseg;
273 	}
274 
275 	tcp_update_pacer_state(tp);
276 }
277 
278 static void
tcp_prague_pre_fr(struct tcpcb * tp)279 tcp_prague_pre_fr(struct tcpcb *tp)
280 {
281 	tp->t_ccstate->cubic_epoch_start = 0;
282 
283 	uint32_t win = tp->snd_cwnd;
284 	if (tp->t_flagsext & TF_CWND_NONVALIDATED) {
285 		tp->t_lossflightsize = tcp_flight_size(tp);
286 		win = max(tp->t_pipeack, tp->t_lossflightsize);
287 	} else {
288 		tp->t_lossflightsize = 0;
289 	}
290 	/*
291 	 * Note the congestion window at which packet loss occurred as
292 	 * cub_W_max.
293 	 *
294 	 * If the current congestion window is less than the last Wmax when
295 	 * loss occurred, it indicates that capacity available in the
296 	 * network has gone down. This can happen if a new flow has started
297 	 * and it is capturing some of the bandwidth. To reach convergence
298 	 * quickly, backoff a little more.
299 	 */
300 	if (win < tp->t_ccstate->cubic_W_max) {
301 		tp->t_ccstate->cubic_W_max = (uint32_t)((float)win * cubic_fast_convergence_factor);
302 	} else {
303 		tp->t_ccstate->cubic_W_max = win;
304 	}
305 
306 	/* Calculate congestion window by multiplying with beta factor */
307 	win = (uint32_t)(win * cubic_beta);
308 	win = tcp_round_to(win, tp->t_maxseg);
309 	if (win < 2 * tp->t_maxseg) {
310 		win =  2 * tp->t_maxseg;
311 	}
312 	tp->snd_ssthresh = win;
313 	tcp_cc_resize_sndbuf(tp);
314 
315 	tp->t_ccstate->num_cong_events_loss++;
316 	tp->t_ccstate->in_loss = 1;
317 	tp->t_ccstate->reduced_due_to_ce = false;
318 }
319 
320 static void
tcp_prague_post_fr(struct tcpcb * tp,__unused struct tcphdr * th)321 tcp_prague_post_fr(struct tcpcb *tp, __unused struct tcphdr *th)
322 {
323 	/*
324 	 * Since we do cwnd validation during pre_fr,
325 	 * we can safely set cwnd to ssthresh.
326 	 */
327 	tp->snd_cwnd = tp->snd_ssthresh;
328 
329 	tcp_update_pacer_state(tp);
330 
331 	tp->t_ccstate->reno_cwnd = 0;
332 	tp->t_ccstate->reno_acked = 0;
333 
334 	tp->t_ccstate->in_loss = 0;
335 }
336 
/*
 * Has the current measurement round ended?
 *
 * largest_snd_nxt is the snd_nxt recorded when the round started (0 if no
 * round has started yet). The round is over once this ACK covers data
 * beyond that point, i.e. data sent after the round began.
 */
static bool
rtt_elapsed(uint32_t largest_snd_nxt, uint32_t ack)
{
	if (largest_snd_nxt == 0) {
		/* No round in progress yet */
		return true;
	}

	return SEQ_GT(ack, largest_snd_nxt);
}
347 
348 /*
349  * RTT independence using square of RTT ratio to acheive rate fairness.
350  * For additive increase, alpha = (RTT / REF_RTT) ^ 2
351  * 1. Grow by 1MSS per target RTT; i.e. by a ratio of actual / target RTT and
352  * 2. Growing by 1MSS translates to a sending rate increase proportional to the
353  * same RTT ratio.
354  *
355  * Update infrequently whenever a change in srtt is expected.
356  */
357 static void
prague_ai_alpha_rate(struct tcpcb * tp)358 prague_ai_alpha_rate(struct tcpcb *tp)
359 {
360 	uint32_t srtt = tp->t_srtt >> TCP_RTT_SHIFT;
361 	if (srtt == 0 || srtt > REF_RTT_RATE) {
362 		tp->t_ccstate->prague_alpha_ai = (1 << CWND_SHIFT);
363 		return;
364 	}
365 
366 	uint64_t numer = srtt << CWND_SHIFT;
367 	numer *= srtt;
368 	uint64_t divisor = REF_RTT_RATE * REF_RTT_RATE;
369 
370 	tp->t_ccstate->prague_alpha_ai = (uint64_t)((numer + (divisor >> 1)) / divisor);
371 }
372 
373 /*
374  * Handle an ACK in congestion avoidance phase
375  * after the decrease happened due to CE
376  */
377 static void
prague_ca_after_ce(struct tcpcb * tp,uint32_t acked)378 prague_ca_after_ce(struct tcpcb *tp, uint32_t acked)
379 {
380 	tp->t_bytes_acked += acked;
381 	/*
382 	 * To increase cwnd by 1MSS, we need cwnd / alpha_ai bytes
383 	 * to be acknowledged. Scale cwnd by CWND_SHIFT as alpha_ai
384 	 * is already scaled to avoid floating point arithmetic
385 	 */
386 	uint64_t bytes_needed_for_increase = (uint64_t)tp->snd_cwnd << CWND_SHIFT;
387 	bytes_needed_for_increase /= tp->t_ccstate->prague_alpha_ai;
388 
389 	if (tp->t_bytes_acked >= bytes_needed_for_increase) {
390 		tp->t_bytes_acked -= bytes_needed_for_increase;
391 		tp->snd_cwnd += tp->t_maxseg;
392 	}
393 }
394 
395 static void
prague_update_alpha(struct tcpcb * tp,uint32_t ack,uint32_t packets_marked,uint32_t packets_acked)396 prague_update_alpha(struct tcpcb *tp, uint32_t ack, uint32_t packets_marked,
397     uint32_t packets_acked)
398 {
399 	if (!rtt_elapsed(tp->t_ccstate->snd_nxt_alpha, ack)) {
400 		/* One RTT hasn't elapsed yet, don't update alpha */
401 		os_log(OS_LOG_DEFAULT, "one RTT hasn't elapsed, not updating alpha");
402 		return;
403 	}
404 
405 	if (!tp->t_ccstate->ever_saw_ce) {
406 		return;
407 	}
408 
409 	uint32_t newly_marked = 0, newly_acked = 0;
410 
411 	if (packets_marked > tp->t_ccstate->prague_packets_marked) {
412 		newly_marked = packets_marked - tp->t_ccstate->prague_packets_marked;
413 	}
414 
415 	if (packets_acked > tp->t_ccstate->prague_packets_acked) {
416 		newly_acked = packets_acked - tp->t_ccstate->prague_packets_acked;
417 	} else {
418 		os_log_error(OS_LOG_DEFAULT,
419 		    "No new packets were ACK'ed, we shouldn't be called");
420 		return;
421 	}
422 
423 	uint64_t scaled_alpha = tp->t_ccstate->prague_scaled_alpha;
424 
425 	/*
426 	 * We currently don't react to local AQM for TCP Prague
427 	 */
428 	uint64_t p = (newly_marked << ALPHA_SHIFT) / newly_acked;
429 	/*
430 	 * Equation for alpha,
431 	 * alpha = (1 - g) * alpha + g * F (fraction of marked / acked)
432 	 * alpha = alpha - (alpha >> g_shift) + (marked << (alpha_shift -
433 	 * g_shift)) / acked, OR
434 	 * scaled_alpha = scaled_alpha - (scaled_alpha >> g_shift) +
435 	 * (marked << alpha_shift) / acked
436 	 */
437 	scaled_alpha = scaled_alpha - (scaled_alpha >> G_SHIFT) + p;
438 	tp->t_ccstate->prague_scaled_alpha = MIN(MAX_ALPHA << G_SHIFT, scaled_alpha);
439 
440 	/* New round for alpha */
441 	tp->t_ccstate->snd_nxt_alpha = tp->snd_nxt;
442 	tp->t_ccstate->prague_packets_marked = packets_marked;
443 	tp->t_ccstate->prague_packets_acked = packets_acked;
444 }
445 
446 static bool
prague_cwr(struct tcpcb * tp)447 prague_cwr(struct tcpcb *tp)
448 {
449 	// If we are currently in loss recovery, then do nothing
450 	if (tp->t_ccstate->in_loss) {
451 		os_log(OS_LOG_DEFAULT, "currently in loss recovery, no need to do CWR");
452 		return false;
453 	}
454 
455 	tp->t_ccstate->num_cong_events_ce++;
456 	const uint64_t alpha = tp->t_ccstate->prague_scaled_alpha >> G_SHIFT;
457 
458 	/*
459 	 * For Prague, the recovery time is only set during packet
460 	 * loss and we allow any ACKs that don't have CE marks to
461 	 * increase cwnd during ack_end, even in CWR state.
462 	 *
463 	 * On entering CWR, cwnd = cwnd * (1 - DCTCP.alpha) / 2
464 	 */
465 	uint64_t reduction =
466 	    (tp->snd_cwnd * alpha) >> (ALPHA_SHIFT + 1);
467 	tp->snd_cwnd -= reduction;
468 
469 	/* If no more increase due to non-CE acked bytes, then round it */
470 	if (tp->t_ccstate->prague_bytes_acked == 0) {
471 		tp->snd_cwnd = tcp_round_to(tp->snd_cwnd, tp->t_maxseg);
472 	}
473 	/* Should be at least 2 MSS */
474 	if (tp->snd_cwnd < 2 * tp->t_maxseg) {
475 		tp->snd_cwnd =  2 * tp->t_maxseg;
476 	}
477 
478 	tp->snd_ssthresh = tp->snd_cwnd;
479 
480 	tp->t_ccstate->reduced_due_to_ce = true;
481 
482 	return true;
483 }
484 
485 static void
tcp_prague_process_ecn(struct tcpcb * tp,struct tcphdr * th,uint32_t new_bytes_marked,uint32_t packets_marked,uint32_t packets_acked)486 tcp_prague_process_ecn(struct tcpcb *tp, struct tcphdr *th, uint32_t new_bytes_marked,
487     uint32_t packets_marked, uint32_t packets_acked)
488 {
489 	if (__improbable(packets_marked < tp->t_ccstate->prague_ce_counter ||
490 	    packets_acked < tp->t_ccstate->prague_packets_acked)) {
491 		os_log_error(OS_LOG_DEFAULT, "new CE count (%u) can't be less than current CE count (%u)"
492 		    "OR newly ACKed (%u) can't be less that current ACKed (%u)",
493 		    packets_marked, tp->t_ccstate->prague_ce_counter,
494 		    packets_acked, tp->t_ccstate->prague_packets_acked);
495 	}
496 
497 	if (packets_marked > tp->t_ccstate->prague_ce_counter) {
498 		tp->t_ccstate->ever_saw_ce = true;
499 	}
500 	/*
501 	 * update alpha of fraction of marked packets,
502 	 * even when there are no new CE counts
503 	 */
504 	if (packets_acked > tp->t_ccstate->prague_packets_acked) {
505 		prague_update_alpha(tp, th->th_ack, packets_marked, packets_acked);
506 	}
507 
508 	if (packets_marked == tp->t_ccstate->prague_ce_counter) {
509 		/* No change in CE */
510 		return;
511 	}
512 
513 	os_log(OS_LOG_DEFAULT, "%u packets were newly CE marked",
514 	    packets_marked - tp->t_ccstate->prague_ce_counter);
515 	/*
516 	 * Received an ACK with new CE counts, subtract CE marked bytes
517 	 * from bytes_acked, so that we use only unmarked bytes to
518 	 * increase cwnd during ACK processing
519 	 */
520 	if (tp->t_ccstate->prague_bytes_acked > new_bytes_marked) {
521 		tp->t_ccstate->prague_bytes_acked -= new_bytes_marked;
522 	} else {
523 		tp->t_ccstate->prague_bytes_acked = 0;
524 	}
525 
526 	/* Update CE count even if we are already in CWR */
527 	tp->t_ccstate->prague_ce_counter = packets_marked;
528 
529 	/* Update AIMD alpha as SRTT might have changed */
530 	prague_ai_alpha_rate(tp);
531 
532 	if (!rtt_elapsed(tp->t_ccstate->snd_nxt_cwr, th->th_ack)) {
533 		/* One RTT hasn't elapsed yet, don't doing CWR */
534 		os_log(OS_LOG_DEFAULT, "one RTT hasn't elapsed, not doing CWR");
535 		return;
536 	}
537 
538 	/* CWR reduction if new counts are received */
539 	bool cwnd_changed = prague_cwr(tp);
540 
541 	/* Update pacer state if cwnd has changed */
542 	if (cwnd_changed) {
543 		tcp_update_pacer_state(tp);
544 	}
545 	/* New round for CWR */
546 	tp->t_ccstate->snd_nxt_cwr = tp->snd_nxt;
547 }
548 
549 static void
tcp_prague_set_bytes_acked(struct tcpcb * tp,uint32_t acked)550 tcp_prague_set_bytes_acked(struct tcpcb *tp, uint32_t acked)
551 {
552 	/* Set bytes_acked which will be used later during ack_rcvd() */
553 	tp->t_ccstate->prague_bytes_acked = acked;
554 }
555 
556 static void
tcp_prague_clear_state(struct tcpcb * tp)557 tcp_prague_clear_state(struct tcpcb *tp)
558 {
559 	tp->snd_cwnd_prev = 0;
560 	tp->t_ccstate->num_cong_events_loss = 0;
561 	tp->t_ccstate->num_cong_events_ce = 0;
562 	tp->t_ccstate->prague_alpha_ai = (1 << CWND_SHIFT);
563 
564 	/* CUBIC state */
565 	tp->t_ccstate->cubic_K = 0;
566 	//prague->cubic_acked = 0;
567 	tp->t_ccstate->cubic_epoch_start = 0;
568 	tp->t_ccstate->cubic_origin_point = 0;
569 	tp->t_ccstate->cubic_W_max = 0;
570 }
571 
572 int
tcp_prague_init(struct tcpcb * tp)573 tcp_prague_init(struct tcpcb *tp)
574 {
575 	os_atomic_inc(&tcp_cc_prague.num_sockets, relaxed);
576 
577 	VERIFY(tp->t_ccstate != NULL);
578 
579 	tp->t_ccstate->prague_scaled_alpha = (MAX_ALPHA << G_SHIFT);
580 	tcp_prague_clear_state(tp);
581 	return 0;
582 }
583 
584 int
tcp_prague_cleanup(struct tcpcb * tp)585 tcp_prague_cleanup(struct tcpcb *tp)
586 {
587 #pragma unused(tp)
588 	os_atomic_dec(&tcp_cc_prague.num_sockets, relaxed);
589 	return 0;
590 }
591 
592 /*
593  * Initialize the congestion window for a connection
594  */
595 void
tcp_prague_cwnd_init_or_reset(struct tcpcb * tp)596 tcp_prague_cwnd_init_or_reset(struct tcpcb *tp)
597 {
598 	VERIFY(tp->t_ccstate != NULL);
599 
600 	tcp_prague_clear_state(tp);
601 	tcp_cc_cwnd_init_or_reset(tp);
602 	tp->t_pipeack = 0;
603 	tcp_clear_pipeack_state(tp);
604 
605 	/* Start counting bytes for RFC 3465 again */
606 	tp->t_bytes_acked = 0;
607 
608 	/*
609 	 * slow start threshold could get initialized to a lower value
610 	 * when there is a cached value in the route metrics. In this case,
611 	 * the connection can enter congestion avoidance without any packet
612 	 * loss and Cubic will enter steady-state too early. It is better
613 	 * to always probe to find the initial slow-start threshold.
614 	 */
615 	if (tp->t_inpcb->inp_mstat.ms_total.ts_txbytes <= tcp_initial_cwnd(tp) &&
616 	    tp->snd_ssthresh < (TCP_MAXWIN << TCP_MAX_WINSHIFT)) {
617 		tp->snd_ssthresh = TCP_MAXWIN << TCP_MAX_WINSHIFT;
618 	}
619 
620 	/* Initialize cubic last max to be same as ssthresh */
621 	tp->t_ccstate->cubic_W_max = tp->snd_ssthresh;
622 
623 	/* Set initial pacer state */
624 	tcp_update_pacer_state(tp);
625 }
626 
627 static void
tcp_prague_after_timeout(struct tcpcb * tp)628 tcp_prague_after_timeout(struct tcpcb *tp)
629 {
630 	VERIFY(tp->t_ccstate != NULL);
631 
632 	/*
633 	 * Avoid adjusting congestion window due to SYN retransmissions.
634 	 * If more than one byte (SYN) is outstanding then it is still
635 	 * needed to adjust the window.
636 	 */
637 	if (tp->t_state < TCPS_ESTABLISHED &&
638 	    ((int)(tp->snd_max - tp->snd_una) <= 1)) {
639 		return;
640 	}
641 
642 	if (!IN_FASTRECOVERY(tp)) {
643 		tcp_prague_clear_state(tp);
644 		tcp_prague_pre_fr(tp);
645 	}
646 
647 	/*
648 	 * Close the congestion window down to one segment as a retransmit
649 	 * timeout might indicate severe congestion.
650 	 */
651 	tp->snd_cwnd = tp->t_maxseg;
652 
653 	tcp_update_pacer_state(tp);
654 }
655 
/*
 * Delayed-ACK policy: Prague has none of its own and defers to the common
 * CC helper.
 */
static int
tcp_prague_delay_ack(struct tcpcb *tp, struct tcphdr *th)
{
	int should_delay = tcp_cc_delay_ack(tp, th);

	return should_delay;
}
661 
662 /*
663  * When switching from a different CC it is better for Cubic to start
664  * fresh. The state required for Cubic calculation might be stale and it
665  * might not represent the current state of the network. If it starts as
666  * a new connection it will probe and learn the existing network conditions.
667  */
668 static void
tcp_prague_switch_cc(struct tcpcb * tp)669 tcp_prague_switch_cc(struct tcpcb *tp)
670 {
671 	tcp_prague_cwnd_init_or_reset(tp);
672 
673 	os_atomic_inc(&tcp_cc_prague.num_sockets, relaxed);
674 }
675