The original GAE authors define the generalized advantage estimator as
$$\hat{A}_t^{GAE(\gamma,\lambda)} := \sum_{l=0}^{\infty}(\gamma \lambda)^l \delta^V_{t+l}$$.
Thus, the last advantage $$\hat{A}_T^{GAE(\gamma,\lambda)} = \gamma \lambda \delta_T^V$$, where $$T$$ is the last timestep.
However, the RLax implementation of GAE sets $$A_T=\delta_T^V$$:
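```python
from typing import Union
import jax
import jax.numpy as jnp

# Local stand-ins for RLax's type aliases (RLax itself uses chex.Array and chex.Scalar).
Array = jnp.ndarray
Scalar = Union[float, int]

def truncated_generalized_advantage_estimation(r_t: Array, discount_t: Array,
        lambda_: Union[Array, Scalar], values: Array, stop_target_gradients: bool = False) -> Array:
    lambda_ = jnp.ones_like(discount_t) * lambda_  # If scalar, make into vector.
    # One-step TD errors; `values` holds one more entry than `r_t` (the bootstrap value).
    delta_t = r_t + discount_t * values[1:] - values[:-1]

    # Iterate backwards to calculate advantages.
    def _body(acc, xs):
        # `acc` carries A_{t+1}; `xs` holds (delta_t, gamma_t, lambda_t) for step t.
        deltas, discounts, lambda_ = xs
        acc = deltas + discounts * lambda_ * acc
        return acc, acc

    _, advantage_t = jax.lax.scan(
        _body, 0.0, (delta_t, discount_t, lambda_), reverse=True)
    return jax.lax.select(stop_target_gradients,
                          jax.lax.stop_gradient(advantage_t),
                          advantage_t)
```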
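As a cross-check of how the scan behaves, the following is a minimal plain-Python sketch that mirrors the function above step by step. `gae_backward` is a hypothetical name, not part of RLax. The sketch assumes plain lists of floats, with `values` one entry longer than `r_t` (the trailing bootstrap value), and it drops `stop_target_gradients` because plain Python tracks no gradients. The `reversed(range(n))` loop plays the role of `jax.lax.scan(..., reverse=True)`, and `acc` starts at `0.0` like the scan's initial carry.

```python
from typing import List, Sequence, Union


def gae_backward(r_t: Sequence[float], discount_t: Sequence[float],
                 lambda_: Union[float, Sequence[float]],
                 values: Sequence[float]) -> List[float]:
    """Hypothetical plain-Python mirror of the backward scan (not an RLax API)."""
    n = len(r_t)
    if len(values) != n + 1:
        raise ValueError("values must have exactly one more entry than r_t")
    # Broadcast a scalar lambda to one value per step, like jnp.ones_like(...) * lambda_.
    lambdas = [lambda_] * n if isinstance(lambda_, (int, float)) else list(lambda_)
    # One-step TD errors: delta_t = r_t + gamma_t * V(s_{t+1}) - V(s_t).
    deltas = [r_t[t] + discount_t[t] * values[t + 1] - values[t] for t in range(n)]
    advantages = [0.0] * n
    acc = 0.0  # Same initial carry as the scan: A_{T+1} = 0.
    for t in reversed(range(n)):
        # Same update as _body: A_t = delta_t + gamma_t * lambda_t * A_{t+1}.
        acc = deltas[t] + discount_t[t] * lambdas[t] * acc
        advantages[t] = acc
    return advantages


advantages = gae_backward(r_t=[1.0, 0.0, 2.0], discount_t=[0.9, 0.9, 0.9],
                          lambda_=0.5, values=[0.5, 1.0, 1.5, 2.0])
print([round(a, 5) for a in advantages])  # [2.02325, 1.385, 2.3]
```

Each iteration reads the carry left by the later timestep, so the first iteration (`t = n - 1`) starts from the zero carry and returns exactly that step's TD error, `2.3` in this example.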
To confirm this, we manually trace the first invocation of `_body(acc, xs)`, which returns $$A_T$$.
The `_body` function computes advantages with the recursion
$$A_t = \delta_t + \gamma_t \lambda_t A_{t+1}$$, where $$A_{t+1}$$ is the carry `acc` passed as the first argument of `_body`.
Because `jax.lax.scan` runs with `reverse=True` and an initial carry of `0.0`, the first invocation processes the last timestep $$T$$ with $$A_{T+1} = 0$$, so the last advantage reduces to
$$A_T = \delta_T + \gamma_T \lambda_T \cdot 0 = \delta_T$$, which does not adhere to the original GAE definition for $$\lambda \neq 1 \lor \gamma \neq 1$$.
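Unrolling the same recursion for a hypothetical three-step trajectory ($$T = 2$$) makes the trace explicit: the scan starts from the zero carry at the last step and feeds each result into the step before it.

$$
\begin{aligned}
A_2 &= \delta_2 + \gamma_2 \lambda_2 \cdot 0 = \delta_2, \\
A_1 &= \delta_1 + \gamma_1 \lambda_1 A_2 = \delta_1 + \gamma_1 \lambda_1 \delta_2, \\
A_0 &= \delta_0 + \gamma_0 \lambda_0 A_1 = \delta_0 + \gamma_0 \lambda_0 \delta_1 + \gamma_0 \lambda_0 \gamma_1 \lambda_1 \delta_2.
\end{aligned}
$$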