RLax does not adhere to the original GAE formula when calculating the last advantage.


The original authors define generalized advantage as $$\hat{A}_t^{GAE(\gamma,\lambda)} := \sum_{l=0}^{\infty}(\gamma \lambda)^l \delta^V_{t+l}$$. Thus, the last advantage $$\hat{A}_T^{GAE(\gamma,\lambda)} = \gamma \lambda \delta_T^V$$, where $$T$$ is the last timestep.

However, the RLax implementation of GAE sets $$A_T=\delta_T^V$$:

```python
# Imports and type aliases needed to run the excerpt standalone.
from typing import Union

import chex
import jax
import jax.numpy as jnp

Array = chex.Array
Scalar = chex.Scalar


def truncated_generalized_advantage_estimation(
    r_t: Array,
    discount_t: Array,
    lambda_: Union[Array, Scalar],
    values: Array,
    stop_target_gradients: bool = False,
) -> Array:
    lambda_ = jnp.ones_like(discount_t) * lambda_  # If scalar, make into vector.

    delta_t = r_t + discount_t * values[1:] - values[:-1]

    # Iterate backwards to calculate advantages.
    def _body(acc, xs):
        deltas, discounts, lambda_ = xs
        acc = deltas + discounts * lambda_ * acc
        return acc, acc

    # The initial carry is 0.0; reverse=True visits the last timestep first.
    _, advantage_t = jax.lax.scan(
        _body, 0.0, (delta_t, discount_t, lambda_), reverse=True)

    return jax.lax.select(stop_target_gradients,
                          jax.lax.stop_gradient(advantage_t),
                          advantage_t)
```

To confirm this, we manually trace the first invocation of `_body(acc, xs)`. Because `jax.lax.scan` runs with `reverse=True`, this invocation processes the last timestep and returns $$A_T$$. The `_body` function calculates advantages according to the recursion $$A_t = \delta_t + \gamma_t \lambda_t A_{t+1}$$, where $$A_{t+1}$$ is the carry `acc`, i.e. the first argument of `_body`. Since the carry is initialized to `0.0`, $$A_{T+1}$$ is zero during the first invocation, and the last advantage reduces to $$A_T = \delta_T + \gamma_T \lambda_T \cdot 0 = \delta_T$$, which does not adhere to the original GAE definition for $$\lambda \neq 1 \lor \gamma \neq 1$$.
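The trace can be reproduced numerically. The sketch below is a plain-Python mirror of the reverse scan in the excerpt; it does not call RLax, and the helper name `gae_backward` and the trajectory numbers are hypothetical. It shows that the first step of the backward pass returns exactly the last TD error $$\delta_T$$.

```python
from typing import List


def gae_backward(deltas: List[float], discounts: List[float],
                 lambdas: List[float]) -> List[float]:
    """Hypothetical plain-Python mirror of the reverse jax.lax.scan in the excerpt.

    Computes A_t = delta_t + discount_t * lambda_t * A_{t+1}, visiting timesteps
    from last to first, with the carry initialised to 0.0 (i.e. A_{T+1} = 0).
    """
    acc = 0.0  # initial carry, plays the role of A_{T+1}
    advantages = [0.0] * len(deltas)
    for t in reversed(range(len(deltas))):  # same order as reverse=True
        acc = deltas[t] + discounts[t] * lambdas[t] * acc  # same update as _body
        advantages[t] = acc
    return advantages


# Hypothetical 3-step trajectory; values has one extra bootstrap entry.
rewards = [1.0, 0.0, 2.0]
discounts = [0.9, 0.9, 0.9]
lambdas = [0.95, 0.95, 0.95]
values = [0.5, 1.0, 1.5, 0.0]

# delta_t = r_t + discount_t * values[t + 1] - values[t], as in the excerpt.
deltas = [r + g * v_next - v
          for r, g, v_next, v in zip(rewards, discounts, values[1:], values[:-1])]

advantages = gae_backward(deltas, discounts, lambdas)
# The first step of the backward pass (t = T) sees acc == 0.0, so A_T == delta_T.
print(advantages[-1] == deltas[-1])  # True
```

Because the carry entering the last timestep is zero, the $$\gamma_T \lambda_T$$ factor multiplies nothing, regardless of the discount and $$\lambda$$ values chosen.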

Contributions

FS worked on all aspects of this post, including research, analysis and writing.