# NOTE(review): SOURCE was a git unified diff whose newlines had been collapsed
# into spaces. Reconstructed below are the four added files, properly formatted,
# with each file's diff header preserved as a separator comment.

# ===================================================================== #
# diff --git a/ambersim/trajopt/base.py (new file, mode 100644)         #
# ===================================================================== #

from typing import Tuple

import jax
from flax import struct
from jax import grad, hessian

# ####################### #
# TRAJECTORY OPTIMIZATION #
# ####################### #


@struct.dataclass
class TrajectoryOptimizerParams:
    """The parameters for generic trajectory optimization algorithms.

    Parameters we may want to optimize should be included here.

    This is left completely empty to allow for maximum flexibility in the API. Some examples:
    - A direct collocation method might have parameters for the number of collocation points, the collocation
      scheme, and the number of optimization iterations.
    - A shooting method might have parameters for the number of shooting points, the shooting scheme, and the
      number of optimization iterations.
    The parameters also include initial iterates for each type of algorithm. Some examples:
    - A direct collocation method might have initial iterates for the controls and the state trajectory.
    - A shooting method might have initial iterates for the controls only.

    Parameters which we want to remain untouched by JAX transformations can be marked by pytree_node=False, e.g.,
    ```
    @struct.dataclass
    class ChildParams:
        ...
        # example field
        example: int = struct.field(pytree_node=False)
        ...
    ```
    """


@struct.dataclass
class TrajectoryOptimizer:
    """The API for generic trajectory optimization algorithms on mechanical systems.

    We choose to implement this as a flax dataclass (as opposed to a regular class whose functions operate on pytree
    nodes) because:
    (1) the OOP formalism allows us to define coherent abstractions through inheritance;
    (2) struct.dataclass registers dataclasses as pytree nodes, so we can deal with awkward issues like the `self`
        variable when using JAX transformations on methods of the dataclass.

    Further, we choose not to specify the mjx.Model as either a field of this dataclass or as a parameter. The reason
    is because we want to allow for maximum flexibility in the API. Two motivating scenarios:
    (1) we want to domain randomize over the model parameters and potentially optimize for them. In this case, it
        makes sense to specify the mjx.Model as a parameter that gets passed as an input into the optimize function.
    (2) we want to fix the model and only randomize/optimize over non-model-parameters. For instance, this is the
        situation in vanilla predictive sampling. If we don't need to pass the model, we instead initialize it as a
        field of this dataclass, which makes the optimize function more performant, since it can just reference the
        fixed model attribute of the optimizer instead of applying JAX transformations to the entire large model
        pytree.
    Similar logic applies for not specifying the role of the CostFunction - we trust that the user will either use the
    provided API or will ignore it and still end up implementing something custom and reasonable.

    Finally, abstract dataclasses are weird, so we just make all children implement the below functions by instead
    raising a NotImplementedError.
    """

    # fix(review): the annotation previously promised a 3-tuple while the docstring
    # and every subclass (e.g. ShootingAlgorithm.optimize) return a 2-tuple.
    def optimize(self, params: TrajectoryOptimizerParams) -> Tuple[jax.Array, jax.Array]:
        """Optimizes a trajectory.

        The shapes of the outputs include (?) because we may choose to return non-zero-order-hold parameterizations
        of the optimized trajectories (for example, we could choose to return a cubic spline parameterization of the
        control inputs over the trajectory as is done in the gradient-based methods of MJPC).

        Args:
            params: The parameters of the trajectory optimizer.

        Returns:
            xs_star (shape=(N + 1, nq + nv) or (?)): The optimized trajectory.
            us_star (shape=(N, nu) or (?)): The optimized controls.
        """
        raise NotImplementedError


# ############# #
# COST FUNCTION #
# ############# #


@struct.dataclass
class CostFunctionParams:
    """Generic parameters for cost functions."""


@struct.dataclass
class CostFunction:
    """The API for generic cost functions for trajectory optimization problems for mechanical systems.

    Rationale behind CostFunctionParams in this generic API:
    (1) computation of higher-order derivatives could depend on results or intermediates from lower-order
        derivatives. So, we can flexibly cache the requisite values to avoid repeated computation;
    (2) we may want to randomize or optimize the cost function parameters themselves, so specifying a generic pytree
        as input generically accounts for all possibilities;
    (3) there could simply be parameters that cannot be easily specified in advance that are key for cost evaluation,
        like a time-varying reference trajectory that gets updated in real time.
    (4) histories of higher-order derivatives can be useful for updating their current estimates, e.g., BFGS.
    """

    def cost(self, xs: jax.Array, us: jax.Array, params: CostFunctionParams) -> Tuple[jax.Array, CostFunctionParams]:
        """Computes the cost of a trajectory.

        Args:
            xs (shape=(N + 1, nq + nv)): The state trajectory.
            us (shape=(N, nu)): The controls over the trajectory.
            params: The parameters of the cost function.

        Returns:
            val (shape=(,)): The cost of the trajectory.
            new_params: The updated parameters of the cost function.
        """
        raise NotImplementedError

    def grad(
        self, xs: jax.Array, us: jax.Array, params: CostFunctionParams
    ) -> Tuple[jax.Array, jax.Array, CostFunctionParams, CostFunctionParams]:
        """Computes the gradient of the cost of a trajectory.

        The default implementation of this function uses JAX's autodiff. Simply override this function if you would
        like to supply an analytical gradient.

        Args:
            xs (shape=(N + 1, nq + nv)): The state trajectory.
            us (shape=(N, nu)): The controls over the trajectory.
            params: The parameters of the cost function.

        Returns:
            gcost_xs (shape=(N + 1, nq + nv)): The gradient of the cost wrt xs.
            gcost_us (shape=(N, nu)): The gradient of the cost wrt us.
            gcost_params: The gradient of the cost wrt params.
            new_params: The updated parameters of the cost function.
        """
        _fn = lambda xs, us, params: self.cost(xs, us, params)[0]  # only differentiate wrt the cost val
        return grad(_fn, argnums=(0, 1, 2))(xs, us, params) + (params,)

    def hess(
        self, xs: jax.Array, us: jax.Array, params: CostFunctionParams
    ) -> Tuple[
        jax.Array, jax.Array, CostFunctionParams, jax.Array, CostFunctionParams, CostFunctionParams, CostFunctionParams
    ]:
        """Computes the Hessian of the cost of a trajectory.

        The default implementation of this function uses JAX's autodiff. Simply override this function if you would
        like to supply an analytical Hessian.

        Let t, s be times 0, 1, 2, etc. Then, d^2H/da_{t,i}db_{s,j} = Hcost_asbs[t, i, s, j].

        Args:
            xs (shape=(N + 1, nq + nv)): The state trajectory.
            us (shape=(N, nu)): The controls over the trajectory.
            params: The parameters of the cost function.

        Returns:
            Hcost_xsxs (shape=(N + 1, nq + nv, N + 1, nq + nv)): The Hessian of the cost wrt xs.
            Hcost_xsus (shape=(N + 1, nq + nv, N, nu)): The Hessian of the cost wrt xs and us.
            Hcost_xsparams: The Hessian of the cost wrt xs and params.
            Hcost_usus (shape=(N, nu, N, nu)): The Hessian of the cost wrt us.
            Hcost_usparams: The Hessian of the cost wrt us and params.
            Hcost_paramsall: The Hessian of the cost wrt params and everything else.
            new_params: The updated parameters of the cost function.
        """
        _fn = lambda xs, us, params: self.cost(xs, us, params)[0]  # only differentiate wrt the cost val
        # hessian(..., argnums=(0, 1, 2)) returns a nested tuple-of-tuples: row i holds
        # the second derivatives of the cost wrt argument i crossed with each argument.
        hessians = hessian(_fn, argnums=(0, 1, 2))(xs, us, params)
        Hcost_xsxs, Hcost_xsus, Hcost_xsparams = hessians[0]
        _, Hcost_usus, Hcost_usparams = hessians[1]  # hessians[1][0] == Hcost_xsus transposed; skip it
        Hcost_paramsall = hessians[2]
        return Hcost_xsxs, Hcost_xsus, Hcost_xsparams, Hcost_usus, Hcost_usparams, Hcost_paramsall, params


# ===================================================================== #
# diff --git a/ambersim/trajopt/cost.py (new file, mode 100644)         #
# ===================================================================== #

"""A collection of common cost functions."""

from typing import Tuple

import jax
import jax.numpy as jnp
from flax import struct
from jax import lax, vmap

from ambersim.trajopt.base import CostFunction, CostFunctionParams


class StaticGoalQuadraticCost(CostFunction):
    """A quadratic cost function that penalizes the distance to a static goal.

    This is the most vanilla possible quadratic cost. The cost matrices are static (defined at init time) and so is
    the single, fixed goal. The gradient is as compressed as it can be in general (one matrix multiplication), but
    the Hessian can be far more compressed by simply referencing Q, Qf, and R - this implementation is inefficient
    and dense.
    """

    def __init__(self, Q: jax.Array, Qf: jax.Array, R: jax.Array, xg: jax.Array) -> None:
        """Initializes a quadratic cost function.

        Args:
            Q (shape=(nx, nx)): The state cost matrix.
            Qf (shape=(nx, nx)): The final state cost matrix.
            R (shape=(nu, nu)): The control cost matrix.
            xg (shape=(nx,)): The goal state, where nx = nq + nv (it is subtracted from full-state rows).
        """
        # NOTE(review): CostFunction is a frozen flax struct.dataclass; plain attribute
        # assignment on a subclass instance may raise FrozenInstanceError depending on
        # the flax version — confirm against the pinned flax release.
        self.Q = Q
        self.Qf = Qf
        self.R = R
        self.xg = xg

    @staticmethod
    def batch_quadform(bs: jax.Array, A: jax.Array) -> jax.Array:
        """Computes a batched quadratic form for a single instance of A.

        Args:
            bs (shape=(..., n)): The batch of vectors.
            A (shape=(n, n)): The matrix.

        Returns:
            val (shape=(...,)): The batch of quadratic forms.
        """
        return jnp.einsum("...i,ij,...j->...", bs, A, bs)

    @staticmethod
    def batch_matmul(bs: jax.Array, A: jax.Array) -> jax.Array:
        """Computes a batched matrix multiplication for a single instance of A.

        Args:
            bs (shape=(..., n)): The batch of vectors.
            A (shape=(n, n)): The matrix.

        Returns:
            val (shape=(..., n)): The batch of matrix multiplications.
        """
        return jnp.einsum("...i,ij->...j", bs, A)

    def cost(self, xs: jax.Array, us: jax.Array, params: CostFunctionParams) -> Tuple[jax.Array, CostFunctionParams]:
        """Computes the cost of a trajectory.

        cost = 0.5 * sum_t (x_t - xg)' @ Q @ (x_t - xg) + 0.5 * (x_N - xg)' @ Qf @ (x_N - xg)
               + 0.5 * sum_t u_t' @ R @ u_t

        Args:
            xs (shape=(N + 1, nq + nv)): The state trajectory.
            us (shape=(N, nu)): The controls over the trajectory.
            params: Unused. Included for API compliance.

        Returns:
            cost_val: The cost of the trajectory.
            new_params: Unused. Included for API compliance.
        """
        xs_err = xs[:-1, :] - self.xg  # errors before the terminal state
        xf_err = xs[-1, :] - self.xg
        val = 0.5 * jnp.squeeze(
            (
                jnp.sum(self.batch_quadform(xs_err, self.Q))
                + self.batch_quadform(xf_err, self.Qf)
                + jnp.sum(self.batch_quadform(us, self.R))
            )
        )
        return val, params

    def grad(
        self, xs: jax.Array, us: jax.Array, params: CostFunctionParams
    ) -> Tuple[jax.Array, jax.Array, CostFunctionParams, CostFunctionParams]:
        """Computes the gradient of the cost of a trajectory.

        Args:
            xs (shape=(N + 1, nq + nv)): The state trajectory.
            us (shape=(N, nu)): The controls over the trajectory.
            params: Unused. Included for API compliance.

        Returns:
            gcost_xs (shape=(N + 1, nq + nv)): The gradient of the cost wrt xs.
            gcost_us (shape=(N, nu)): The gradient of the cost wrt us.
            gcost_params: Unused. Included for API compliance.
            new_params: Unused. Included for API compliance.
        """
        xs_err = xs[:-1, :] - self.xg  # errors before the terminal state
        xf_err = xs[-1, :] - self.xg
        gcost_xs = jnp.concatenate(
            (
                self.batch_matmul(xs_err, self.Q),
                (self.Qf @ xf_err)[None, :],  # terminal row uses Qf
            ),
            axis=-2,
        )
        gcost_us = self.batch_matmul(us, self.R)
        return gcost_xs, gcost_us, params, params

    def hess(
        self, xs: jax.Array, us: jax.Array, params: CostFunctionParams
    ) -> Tuple[
        jax.Array, jax.Array, CostFunctionParams, jax.Array, CostFunctionParams, CostFunctionParams, CostFunctionParams
    ]:
        """Computes the Hessian of the cost of a trajectory.

        Let t, s be times 0, 1, 2, etc. Then, d^2H/da_{t,i}db_{s,j} = Hcost_asbs[t, i, s, j].

        Args:
            xs (shape=(N + 1, nq + nv)): The state trajectory.
            us (shape=(N, nu)): The controls over the trajectory.
            params: Unused. Included for API compliance.

        Returns:
            Hcost_xsxs (shape=(N + 1, nq + nv, N + 1, nq + nv)): The Hessian of the cost wrt xs.
            Hcost_xsus (shape=(N + 1, nq + nv, N, nu)): The Hessian of the cost wrt xs and us.
            Hcost_xsparams: The Hessian of the cost wrt xs and params.
            Hcost_usus (shape=(N, nu, N, nu)): The Hessian of the cost wrt us.
            Hcost_usparams: The Hessian of the cost wrt us and params.
            Hcost_paramsall: The Hessian of the cost wrt params and everything else.
            new_params: The updated parameters of the cost function.
        """
        # setting up
        nx = self.Q.shape[0]
        N, nu = us.shape
        Q = self.Q
        Qf = self.Qf
        R = self.R
        dummy_params = CostFunctionParams()

        # Hessian for state: only the block-diagonal terms [i, :, i, :] are nonzero.
        # (fix(review): removed a dead jnp.zeros assignment that was immediately overwritten.)
        Hcost_xsxs = vmap(
            lambda i: lax.dynamic_update_slice(
                jnp.zeros((nx, N + 1, nx)),
                Q[:, None, :],
                (0, i, 0),
            )
        )(jnp.arange(N + 1))
        Hcost_xsxs = Hcost_xsxs.at[-1, :, -1, :].set(Qf)  # last one is different

        # trivial cross-terms of Hessian
        Hcost_xsus = jnp.zeros((N + 1, nx, N, nu))
        Hcost_xsparams = dummy_params

        # Hessian for control inputs: only the block-diagonal terms [i, :, i, :] are nonzero.
        Hcost_usus = vmap(
            lambda i: lax.dynamic_update_slice(
                jnp.zeros((nu, N, nu)),
                R[:, None, :],
                (0, i, 0),
            )
        )(jnp.arange(N))

        # trivial cross-terms and Hessian for params
        Hcost_usparams = dummy_params
        Hcost_paramsall = dummy_params
        return Hcost_xsxs, Hcost_xsus, Hcost_xsparams, Hcost_usus, Hcost_usparams, Hcost_paramsall, params


# ===================================================================== #
# diff --git a/ambersim/trajopt/shooting.py (new file, mode 100644)     #
# ===================================================================== #

"""Shooting methods and their derived subclasses."""

from functools import partial
from typing import Tuple

import jax
import jax.numpy as jnp
from flax import struct
from jax import lax, vmap
from mujoco import mjx
from mujoco.mjx import step

from ambersim.trajopt.base import TrajectoryOptimizer, TrajectoryOptimizerParams
from ambersim.trajopt.cost import CostFunction

# ##### #
# UTILS #
# ##### #


def shoot(m: mjx.Model, x0: jax.Array, us: jax.Array) -> jax.Array:
    """Utility function that shoots a model forward given a sequence of control inputs.

    Args:
        m: The model.
        x0 (shape=(nq + nv,)): The initial state.
        us (shape=(N, nu)): The control inputs.

    Returns:
        xs (shape=(N + 1, nq + nv)): The state trajectory.
    """
    # initializing the data
    d = mjx.make_data(m)
    d = d.replace(qpos=x0[: m.nq], qvel=x0[m.nq :])  # setting the initial state.
    d = mjx.forward(m, d)  # setting other internal states like acceleration without integrating

    def scan_fn(d, u):
        """Integrates the model forward one step given the control input u."""
        d = d.replace(ctrl=u)
        d = step(m, d)
        x = jnp.concatenate((d.qpos, d.qvel))  # (nq + nv,)
        return d, x

    # scan over the control inputs to get the trajectory.
    _, _xs = lax.scan(scan_fn, d, us, length=us.shape[0])
    xs = jnp.concatenate((x0[None, :], _xs), axis=0)  # (N + 1, nq + nv)
    return xs


# ################ #
# SHOOTING METHODS #
# ################ #

# vanilla API


@struct.dataclass
class ShootingParams(TrajectoryOptimizerParams):
    """Parameters for shooting methods."""

    # inputs into the algorithm
    x0: jax.Array  # shape=(nq + nv,) or (?)
    us_guess: jax.Array  # shape=(N, nu) or (?)

    @property
    def N(self) -> int:
        """The number of time steps.

        By default, we assume us_guess represents the ZOH control inputs. However, in the case that it is actually an
        alternate parameterization, we may need to compute N some other way, which requires overwriting this method.
        """
        return self.us_guess.shape[0]


@struct.dataclass
class ShootingAlgorithm(TrajectoryOptimizer):
    """A trajectory optimization algorithm based on shooting methods."""

    def optimize(self, params: ShootingParams) -> Tuple[jax.Array, jax.Array]:
        """Optimizes a trajectory using a shooting method.

        Args:
            params: The parameters of the trajectory optimizer.

        Returns:
            xs_star (shape=(N + 1, nq + nv) or (?)): The optimized trajectory.
            us_star (shape=(N, nu) or (?)): The optimized controls.
        """
        raise NotImplementedError


# predictive sampling API


@struct.dataclass
class VanillaPredictiveSamplerParams(ShootingParams):
    """Parameters for generic predictive sampling methods."""

    key: jax.Array  # random key for sampling


@struct.dataclass
class VanillaPredictiveSampler(ShootingAlgorithm):
    """A vanilla predictive sampler object.

    The following choices are made:
    (1) the model parameters are fixed, and are therefore a field of this dataclass;
    (2) the control sequence is parameterized as a ZOH sequence instead of a spline;
    (3) the control inputs are sampled from a normal distribution with a uniformly chosen noise scale over all
        params.
    (4) the cost function is quadratic in the states and controls.
    """

    model: mjx.Model
    cost_function: CostFunction
    nsamples: int = struct.field(pytree_node=False)
    stdev: float = struct.field(pytree_node=False)  # noise scale, parameters theta_new ~ N(theta, (stdev ** 2) * I)

    def optimize(self, params: VanillaPredictiveSamplerParams) -> Tuple[jax.Array, jax.Array]:
        """Optimizes a trajectory using a vanilla predictive sampler.

        Args:
            params: The parameters of the trajectory optimizer.

        Returns:
            xs (shape=(N + 1, nq + nv)): The optimized trajectory.
            us (shape=(N, nu)): The optimized controls.
        """
        # unpack the params
        m = self.model
        nsamples = self.nsamples
        stdev = self.stdev

        x0 = params.x0
        us_guess = params.us_guess
        N = params.N
        key = params.key

        # sample over the control inputs - the first sample is the guess, since it's possible that it's the best one
        noise = jnp.concatenate(
            (jnp.zeros((1, N, m.nu)), jax.random.normal(key, shape=(nsamples - 1, N, m.nu)) * stdev), axis=0
        )
        _us_samples = us_guess + noise

        # clamping the samples to their control limits
        # NOTE(review): a_min/a_max are deprecated aliases in newer jax (now min/max) — confirm the pinned version.
        limits = m.actuator_ctrlrange
        clip_fn = partial(jnp.clip, a_min=limits[:, 0], a_max=limits[:, 1])  # clipping fn with limits already set
        us_samples = vmap(vmap(clip_fn))(_us_samples)  # apply limits only to the last dim, need a nested vmap

        # predict many samples, evaluate them, and return the best trajectory tuple
        # vmap over the input data and the control trajectories
        xs_samples = vmap(shoot, in_axes=(None, None, 0))(m, x0, us_samples)
        costs, _ = vmap(self.cost_function.cost, in_axes=(0, 0, None))(xs_samples, us_samples, None)  # (nsamples,)
        best_idx = jnp.argmin(costs)
        xs_star = lax.dynamic_slice(xs_samples, (best_idx, 0, 0), (1, N + 1, m.nq + m.nv))[0]  # (N + 1, nq + nv)
        us_star = lax.dynamic_slice(us_samples, (best_idx, 0, 0), (1, N, m.nu))[0]  # (N, nu)
        return xs_star, us_star


# ===================================================================== #
# diff --git a/tests/trajopt/test_cost.py (new file, mode 100644)       #
# ===================================================================== #

import jax
import jax.numpy as jnp

from ambersim.trajopt.base import CostFunctionParams
from ambersim.trajopt.cost import StaticGoalQuadraticCost
from ambersim.utils.io_utils import load_mjx_model_and_data_from_file


def test_sgqc():
    """Tests that the StaticGoalQuadraticCost works correctly."""
    # loading model and cost function
    model, _ = load_mjx_model_and_data_from_file("models/barrett_hand/bh280.xml", force_float=False)
    cost_function = StaticGoalQuadraticCost(
        Q=jnp.eye(model.nq + model.nv),
        Qf=10.0 * jnp.eye(model.nq + model.nv),
        R=0.01 * jnp.eye(model.nu),
        xg=jnp.zeros(model.nq + model.nv),
    )

    # generating dummy data
    # fix(review): use independent keys for xs and us instead of reusing one key.
    N = 10
    xkey, ukey = jax.random.split(jax.random.PRNGKey(0))
    xs = jax.random.normal(key=xkey, shape=(N + 1, model.nq + model.nv))
    us = jax.random.normal(key=ukey, shape=(N, model.nu))

    # comparing cost value vs. ground truth for loop
    val_test, _ = cost_function.cost(xs, us, params=CostFunctionParams())
    val_gt = 0.0
    for i in range(N):
        xs_err = xs[i, :] - cost_function.xg
        val_gt += 0.5 * (xs_err @ cost_function.Q @ xs_err + us[i, :] @ cost_function.R @ us[i, :])
    xs_err = xs[N, :] - cost_function.xg
    val_gt += 0.5 * (xs_err @ cost_function.Qf @ xs_err)
    val_gt = jnp.squeeze(val_gt)
    assert jnp.allclose(val_test, val_gt)

    # comparing cost gradients vs. jax autodiff (the CostFunction base implementation)
    gcost_xs_test, gcost_us_test, _, _ = cost_function.grad(xs, us, params=CostFunctionParams())
    gcost_xs_gt, gcost_us_gt, _, _ = super(StaticGoalQuadraticCost, cost_function).grad(
        xs, us, params=CostFunctionParams()
    )
    assert jnp.allclose(gcost_xs_test, gcost_xs_gt)
    assert jnp.allclose(gcost_us_test, gcost_us_gt)

    # comparing cost Hessians vs. jax autodiff
    Hcost_xsxs_test, Hcost_xsus_test, _, Hcost_usus_test, _, _, _ = cost_function.hess(
        xs, us, params=CostFunctionParams()
    )
    Hcost_xsxs_gt, Hcost_xsus_gt, _, Hcost_usus_gt, _, _, _ = super(StaticGoalQuadraticCost, cost_function).hess(
        xs, us, params=CostFunctionParams()
    )
    assert jnp.allclose(Hcost_xsxs_test, Hcost_xsxs_gt)
    assert jnp.allclose(Hcost_xsus_test, Hcost_xsus_gt)
    assert jnp.allclose(Hcost_usus_test, Hcost_usus_gt)


# ========================================================================== #
# diff --git a/tests/trajopt/test_predictive_sampler.py (new file, 100644)   #
# ========================================================================== #

import os

import jax
import jax.numpy as jnp
import pytest
from jax import jit, vmap
from mujoco.mjx._src.types import DisableBit

from ambersim.trajopt.base import CostFunctionParams
from ambersim.trajopt.cost import StaticGoalQuadraticCost
from ambersim.trajopt.shooting import VanillaPredictiveSampler, VanillaPredictiveSamplerParams, shoot
from ambersim.utils.io_utils import load_mjx_model_and_data_from_file

# NOTE(review): this must take effect before jax allocates device memory; setting it
# after `import jax` may be too late depending on the jax version — confirm.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # fixes OOM error


@pytest.fixture
def vps_data():
    """Makes data required for testing vanilla predictive sampling."""
    # initializing the predictive sampler
    model, _ = load_mjx_model_and_data_from_file("models/barrett_hand/bh280.xml", force_float=False)
    model = model.replace(
        opt=model.opt.replace(
            timestep=0.002,  # dt
            iterations=1,  # number of Newton steps to take during solve
            ls_iterations=4,  # number of line search iterations along step direction
            integrator=0,  # Euler semi-implicit integration
            solver=2,  # Newton solver
            disableflags=DisableBit.CONTACT,  # disable contact for this test
        )
    )
    cost_function = StaticGoalQuadraticCost(
        Q=jnp.eye(model.nq + model.nv),
        Qf=10.0 * jnp.eye(model.nq + model.nv),
        R=0.01 * jnp.eye(model.nu),
        xg=jnp.zeros(model.nq + model.nv),
    )
    nsamples = 100
    stdev = 0.01
    ps = VanillaPredictiveSampler(model=model, cost_function=cost_function, nsamples=nsamples, stdev=stdev)
    return ps, model, cost_function


def test_smoke_VPS(vps_data):
    """Simple smoke test to make sure we can run inputs through the vanilla predictive sampler + jit."""
    ps, model, _ = vps_data

    # sampler parameters
    key = jax.random.PRNGKey(0)  # random seed for the predictive sampler
    x0 = jnp.zeros(model.nq + model.nv)
    num_steps = 10
    us_guess = jnp.zeros((num_steps, model.nu))
    params = VanillaPredictiveSamplerParams(key=key, x0=x0, us_guess=us_guess)

    # sampling the best sequence of qs, vs, and us
    optimize_fn = jit(ps.optimize)
    assert optimize_fn(params)  # a non-empty (xs, us) tuple is truthy


def test_VPS_cost_decrease(vps_data):
    """Tests to make sure vanilla predictive sampling decreases (or maintains) the cost."""
    # set up sampler and cost function
    ps, model, cost_function = vps_data

    # batched sampler parameters
    batch_size = 10
    key = jax.random.PRNGKey(0)  # random seed for the predictive sampler
    x0 = jax.random.normal(key=key, shape=(batch_size, model.nq + model.nv))

    key, subkey = jax.random.split(key)
    num_steps = 10
    us_guess = jax.random.normal(key=subkey, shape=(batch_size, num_steps, model.nu))

    keys = jax.random.split(key, num=batch_size)
    params = VanillaPredictiveSamplerParams(key=keys, x0=x0, us_guess=us_guess)

    # sampling with the vanilla predictive sampler
    xs_stars, us_stars = vmap(ps.optimize)(params)

    # "optimal" rollout from predictive sampling
    vmap_cost = jit(vmap(lambda xs, us: cost_function.cost(xs, us, CostFunctionParams())[0], in_axes=(0, 0)))
    costs_star = vmap_cost(xs_stars, us_stars)

    # simply shooting the random initial guess
    xs_guess = vmap(shoot, in_axes=(None, 0, 0))(model, x0, us_guess)
    costs_guess = vmap_cost(xs_guess, us_guess)

    # the guess is always included as the first sample, so the best cost can never exceed it
    assert jnp.all(costs_star <= costs_guess)