From 49267de1814d487b37ef5020340f4eb25546f8df Mon Sep 17 00:00:00 2001 From: alberthli Date: Sat, 2 Dec 2023 14:52:50 -0800 Subject: [PATCH 01/28] extremely generic API for trajectory optimization --- ambersim/trajopt/base.py | 54 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 ambersim/trajopt/base.py diff --git a/ambersim/trajopt/base.py b/ambersim/trajopt/base.py new file mode 100644 index 00000000..0a6188e2 --- /dev/null +++ b/ambersim/trajopt/base.py @@ -0,0 +1,54 @@ +from typing import Tuple + +import jax +from flax import struct + + +@struct.dataclass +class TrajectoryOptimizerParams: + """The parameters for generic trajectory optimization algorithms. + + Parameters we may want to optimize should be included here. + + This is left completely empty to allow for maximum flexibility in the API. Some examples: + - A direct collocation method might have parameters for the number of collocation points, the collocation + scheme, and the number of optimization iterations. + - A shooting method might have parameters for the number of shooting points, the shooting scheme, and the number + of optimization iterations. + The parameters also include initial iterates for each type of algorithm. Some examples: + - A direct collocation method might have initial iterates for the controls and the state trajectory. + - A shooting method might have initial iterates for the controls only. + + Parameters which we want to remain untouched by JAX transformations can be marked by pytree_node=False, e.g., + ``` + @struct.dataclass + class ChildParams: + ... + # example field + example: int = struct.field(pytree_node=False) + ... 
+ ``` + """ + + +@struct.dataclass +class TrajectoryOptimizer: + """The API for generic trajectory optimization algorithms.""" + + def __init__(self) -> None: + """Initialize the trajopt object.""" + + @staticmethod + def optimize(params: TrajectoryOptimizerParams) -> Tuple[jax.Array, jax.Array]: + """Optimizes a trajectory. + + Args: + params: The parameters of the trajectory optimizer. + + Returns: + xs (shape=(N + 1, nq + nv)): The optimized trajectory. + us (shape=(N, nu)): The optimized controls. + """ + # abstract dataclasses are weird, so we just make all children implement this - to be useful, they need it + # anyway, so it isn't really a problem if an "abstract" TrajectoryOptimizer is instantiated. + raise NotImplementedError From 1035dd4899df3745b5b117e9cce233b4fdf3ecd0 Mon Sep 17 00:00:00 2001 From: alberthli Date: Sat, 2 Dec 2023 14:53:22 -0800 Subject: [PATCH 02/28] fix spacing --- ambersim/trajopt/base.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/ambersim/trajopt/base.py b/ambersim/trajopt/base.py index 0a6188e2..79c678ce 100644 --- a/ambersim/trajopt/base.py +++ b/ambersim/trajopt/base.py @@ -11,23 +11,23 @@ class TrajectoryOptimizerParams: Parameters we may want to optimize should be included here. This is left completely empty to allow for maximum flexibility in the API. Some examples: - - A direct collocation method might have parameters for the number of collocation points, the collocation - scheme, and the number of optimization iterations. - - A shooting method might have parameters for the number of shooting points, the shooting scheme, and the number - of optimization iterations. + - A direct collocation method might have parameters for the number of collocation points, the collocation + scheme, and the number of optimization iterations. + - A shooting method might have parameters for the number of shooting points, the shooting scheme, and the number + of optimization iterations. 
The parameters also include initial iterates for each type of algorithm. Some examples: - - A direct collocation method might have initial iterates for the controls and the state trajectory. - - A shooting method might have initial iterates for the controls only. + - A direct collocation method might have initial iterates for the controls and the state trajectory. + - A shooting method might have initial iterates for the controls only. Parameters which we want to remain untouched by JAX transformations can be marked by pytree_node=False, e.g., - ``` - @struct.dataclass - class ChildParams: - ... - # example field - example: int = struct.field(pytree_node=False) - ... - ``` + ``` + @struct.dataclass + class ChildParams: + ... + # example field + example: int = struct.field(pytree_node=False) + ... + ``` """ @@ -43,11 +43,11 @@ def optimize(params: TrajectoryOptimizerParams) -> Tuple[jax.Array, jax.Array]: """Optimizes a trajectory. Args: - params: The parameters of the trajectory optimizer. + params: The parameters of the trajectory optimizer. Returns: - xs (shape=(N + 1, nq + nv)): The optimized trajectory. - us (shape=(N, nu)): The optimized controls. + xs (shape=(N + 1, nq + nv)): The optimized trajectory. + us (shape=(N, nu)): The optimized controls. """ # abstract dataclasses are weird, so we just make all children implement this - to be useful, they need it # anyway, so it isn't really a problem if an "abstract" TrajectoryOptimizer is instantiated. 
From 22f9867caa9131c2f83f821b5eb1680249246c62 Mon Sep 17 00:00:00 2001 From: alberthli Date: Sat, 2 Dec 2023 14:56:36 -0800 Subject: [PATCH 03/28] remove unnecessary __init__ method (brainfart) --- ambersim/trajopt/base.py | 3 --- ambersim/trajopt/shooting.py | 14 ++++++++++++++ 2 files changed, 14 insertions(+), 3 deletions(-) create mode 100644 ambersim/trajopt/shooting.py diff --git a/ambersim/trajopt/base.py b/ambersim/trajopt/base.py index 79c678ce..0b1716b5 100644 --- a/ambersim/trajopt/base.py +++ b/ambersim/trajopt/base.py @@ -35,9 +35,6 @@ class ChildParams: class TrajectoryOptimizer: """The API for generic trajectory optimization algorithms.""" - def __init__(self) -> None: - """Initialize the trajopt object.""" - @staticmethod def optimize(params: TrajectoryOptimizerParams) -> Tuple[jax.Array, jax.Array]: """Optimizes a trajectory. diff --git a/ambersim/trajopt/shooting.py b/ambersim/trajopt/shooting.py new file mode 100644 index 00000000..d2242768 --- /dev/null +++ b/ambersim/trajopt/shooting.py @@ -0,0 +1,14 @@ +import jax +from flax import struct + +from ambersim.trajopt.base import TrajectoryOptimizer, TrajectoryOptimizerParams + + +@struct.dataclass +class ShootingParams(TrajectoryOptimizerParams): + """Parameters for shooting methods.""" + + +@struct.dataclass +class ShootingAlgorithm(TrajectoryOptimizer): + """A trajectory optimization algorithm based on shooting methods.""" From b8d7e8e2f042bdf3d5c2cace6443d8d5d65a3983 Mon Sep 17 00:00:00 2001 From: alberthli Date: Sat, 2 Dec 2023 15:09:56 -0800 Subject: [PATCH 04/28] some additional massaging for best abstraction + useless scaffolding for shooting methods --- ambersim/trajopt/base.py | 20 ++++++++++++++++---- ambersim/trajopt/shooting.py | 17 +++++++++++++++++ 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/ambersim/trajopt/base.py b/ambersim/trajopt/base.py index 0b1716b5..60ca2607 100644 --- a/ambersim/trajopt/base.py +++ b/ambersim/trajopt/base.py @@ -33,18 +33,30 @@ class 
ChildParams: @struct.dataclass class TrajectoryOptimizer: - """The API for generic trajectory optimization algorithms.""" + """The API for generic trajectory optimization algorithms on mechanical systems. + + We choose to implement this as a flax dataclass (as opposed to a regular class whose functions operate on pytree + nodes) because: + (1) the OOP formalism allows us to define coherent abstractions through inheritance; + (2) struct.dataclass registers dataclasses a pytree nodes, so we can deal with awkward issues like the `self` + variable when using JAX transformations on methods of the dataclass. + """ @staticmethod - def optimize(params: TrajectoryOptimizerParams) -> Tuple[jax.Array, jax.Array]: + def optimize(params: TrajectoryOptimizerParams) -> Tuple[jax.Array, jax.Array, jax.Array]: """Optimizes a trajectory. + The shapes of the outputs include (?) because we may choose to return non-zero-order-hold parameterizations of + the optimized trajectories (for example, we could choose to return a cubic spline parameterization of the + control inputs over the trajectory as is done in the gradient-based methods of MJPC). + Args: params: The parameters of the trajectory optimizer. Returns: - xs (shape=(N + 1, nq + nv)): The optimized trajectory. - us (shape=(N, nu)): The optimized controls. + qs (shape=(N + 1, nq) or (?)): The optimized trajectory. + vs (shape=(N + 1, nv) or (?)): The optimized generalized velocities. + us (shape=(N, nu) or (?)): The optimized controls. """ # abstract dataclasses are weird, so we just make all children implement this - to be useful, they need it # anyway, so it isn't really a problem if an "abstract" TrajectoryOptimizer is instantiated. 
diff --git a/ambersim/trajopt/shooting.py b/ambersim/trajopt/shooting.py index d2242768..24b9e881 100644 --- a/ambersim/trajopt/shooting.py +++ b/ambersim/trajopt/shooting.py @@ -1,5 +1,8 @@ +from typing import Tuple + import jax from flax import struct +from mujoco import mjx from ambersim.trajopt.base import TrajectoryOptimizer, TrajectoryOptimizerParams @@ -8,7 +11,21 @@ class ShootingParams(TrajectoryOptimizerParams): """Parameters for shooting methods.""" + model = mjx.Model + @struct.dataclass class ShootingAlgorithm(TrajectoryOptimizer): """A trajectory optimization algorithm based on shooting methods.""" + + @staticmethod + def optimize(params: ShootingParams) -> Tuple[jax.Array, jax.Array]: + """Optimizes a trajectory. + + Args: + params: The parameters of the trajectory optimizer. + + Returns: + xs (shape=(N + 1, nq + nv)): The optimized trajectory. + us (shape=(N, nu)): The optimized controls. + """ From f5610fc323d18f500dfb70ace21abe2cf4e87c76 Mon Sep 17 00:00:00 2001 From: alberthli Date: Sat, 2 Dec 2023 16:12:26 -0800 Subject: [PATCH 05/28] add some informative comments to base.py --- ambersim/trajopt/base.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/ambersim/trajopt/base.py b/ambersim/trajopt/base.py index 60ca2607..d5043ffd 100644 --- a/ambersim/trajopt/base.py +++ b/ambersim/trajopt/base.py @@ -40,10 +40,18 @@ class TrajectoryOptimizer: (1) the OOP formalism allows us to define coherent abstractions through inheritance; (2) struct.dataclass registers dataclasses a pytree nodes, so we can deal with awkward issues like the `self` variable when using JAX transformations on methods of the dataclass. + + Further, we choose not to specify the mjx.Model as either a field of this dataclass or as a parameter. The reason is + because we want to allow for maximum flexibility in the API. Two motivating scenarios: + (1) we want to domain randomize over the model parameters and potentially optimize for them. 
In this case, it makes + sense to specify the mjx.Model as a parameter that gets passed as an input into the optimize function. + (2) we want to fix the model and only randomize/optimize over non-model-parameters. For instance, this is the + situation in vanilla predictive sampling. If we don't need to pass the model, we instead initialize it as a + field of this dataclass, which makes the optimize function more performant, since it can just reference the + fixed model attribute of the optimizer instead of applying JAX transformations to the entire large model pytree. """ - @staticmethod - def optimize(params: TrajectoryOptimizerParams) -> Tuple[jax.Array, jax.Array, jax.Array]: + def optimize(self, params: TrajectoryOptimizerParams) -> Tuple[jax.Array, jax.Array, jax.Array]: """Optimizes a trajectory. The shapes of the outputs include (?) because we may choose to return non-zero-order-hold parameterizations of @@ -54,9 +62,9 @@ def optimize(params: TrajectoryOptimizerParams) -> Tuple[jax.Array, jax.Array, j params: The parameters of the trajectory optimizer. Returns: - qs (shape=(N + 1, nq) or (?)): The optimized trajectory. - vs (shape=(N + 1, nv) or (?)): The optimized generalized velocities. - us (shape=(N, nu) or (?)): The optimized controls. + qs_star (shape=(N + 1, nq) or (?)): The optimized trajectory. + vs_star (shape=(N + 1, nv) or (?)): The optimized generalized velocities. + us_star (shape=(N, nu) or (?)): The optimized controls. """ # abstract dataclasses are weird, so we just make all children implement this - to be useful, they need it # anyway, so it isn't really a problem if an "abstract" TrajectoryOptimizer is instantiated. 
From 17642f6883685179ffccf179a7ee5dbdfa98f5a7 Mon Sep 17 00:00:00 2001 From: alberthli Date: Sat, 2 Dec 2023 17:09:23 -0800 Subject: [PATCH 06/28] [maybe] added cost function to base API --- ambersim/trajopt/base.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/ambersim/trajopt/base.py b/ambersim/trajopt/base.py index d5043ffd..c4c69ea7 100644 --- a/ambersim/trajopt/base.py +++ b/ambersim/trajopt/base.py @@ -49,8 +49,24 @@ class TrajectoryOptimizer: situation in vanilla predictive sampling. If we don't need to pass the model, we instead initialize it as a field of this dataclass, which makes the optimize function more performant, since it can just reference the fixed model attribute of the optimizer instead of applying JAX transformations to the entire large model pytree. + + Finally, abstract dataclasses are weird, so we just make all children implement the below functions by instead + raising a NotImplementedError. """ + def cost(self, qs: jax.Array, vs: jax.Array, us: jax.Array) -> jax.Array: + """Computes the cost of a trajectory. + + Args: + qs: The generalized positions over the trajectory. + vs: The generalized velocities over the trajectory. + us: The controls over the trajectory. + + Returns: + The cost of the trajectory. + """ + raise NotImplementedError + def optimize(self, params: TrajectoryOptimizerParams) -> Tuple[jax.Array, jax.Array, jax.Array]: """Optimizes a trajectory. @@ -66,6 +82,4 @@ def optimize(self, params: TrajectoryOptimizerParams) -> Tuple[jax.Array, jax.Ar vs_star (shape=(N + 1, nv) or (?)): The optimized generalized velocities. us_star (shape=(N, nu) or (?)): The optimized controls. """ - # abstract dataclasses are weird, so we just make all children implement this - to be useful, they need it - # anyway, so it isn't really a problem if an "abstract" TrajectoryOptimizer is instantiated. 
raise NotImplementedError From 4e1ae77a24195f18db01432190ff252e2613f7f4 Mon Sep 17 00:00:00 2001 From: alberthli Date: Sat, 2 Dec 2023 17:10:21 -0800 Subject: [PATCH 07/28] shooting method APIs, no cost function implemented yet --- ambersim/trajopt/shooting.py | 159 +++++++++++++++++++++++++++++++++-- 1 file changed, 154 insertions(+), 5 deletions(-) diff --git a/ambersim/trajopt/shooting.py b/ambersim/trajopt/shooting.py index 24b9e881..14a7825d 100644 --- a/ambersim/trajopt/shooting.py +++ b/ambersim/trajopt/shooting.py @@ -1,31 +1,180 @@ from typing import Tuple import jax +import jax.numpy as jnp from flax import struct +from jax import lax, vmap from mujoco import mjx +from mujoco.mjx import step from ambersim.trajopt.base import TrajectoryOptimizer, TrajectoryOptimizerParams +"""Shooting methods and their derived subclasses.""" + + +# ##### # +# UTILS # +# ##### # + + +def shoot(m: mjx.Model, q0: jax.Array, v0: jax.Array, us: jax.Array) -> Tuple[jax.Array, jax.Array]: + """Utility function that shoots a model forward given a sequence of control inputs. + + Args: + m: The model. + q0: The initial generalized coordinates. + v0: The initial generalized velocities. + us: The control inputs. + + Returns: + qs (shape=(N + 1, nq)): The generalized coordinates. + vs (shape=(N + 1, nv)): The generalized velocities. + """ + d = mjx.make_data(m) + d = d.replace(qpos=q0, qvel=v0) # setting the initial state. + + def scan_fn(d, u): + """Integrates the model forward one step given the control input u.""" + d = d.replace(ctrl=u) + d = step(m, d) + x = jnp.concatenate((d.qpos, d.qvel)) # (nq + nv,) + return d, x + + # scan over the control inputs to get the trajectory. 
+ _, xs = lax.scan(scan_fn, d, us, length=us.shape[0]) + _qs = xs[:, : m.nq] + _vs = xs[:, m.nq : m.nq + m.nv] + qs = jnp.concatenate((q0[None, :], _qs), axis=0) # (N + 1, nq) + vs = jnp.concatenate((v0[None, :], _vs), axis=0) # (N + 1, nv) + return qs, vs + + +# ################ # +# SHOOTING METHODS # +# ################ # + +# vanilla API + @struct.dataclass class ShootingParams(TrajectoryOptimizerParams): """Parameters for shooting methods.""" - model = mjx.Model + # inputs into the algorithm + us_guess: jax.Array # shape=(N, nu) or (?) + + @property + def N(self) -> int: + """The number of time steps. + + By default, we assume us_guess represents the ZOH control inputs. However, in the case that it is actually an + alternate parameterization, we may need to compute N some other way, which requires overwriting this method. + """ + return self.us_guess.shape[0] @struct.dataclass class ShootingAlgorithm(TrajectoryOptimizer): """A trajectory optimization algorithm based on shooting methods.""" - @staticmethod - def optimize(params: ShootingParams) -> Tuple[jax.Array, jax.Array]: - """Optimizes a trajectory. + def cost(self, qs: jax.Array, vs: jax.Array, us: jax.Array) -> jax.Array: + """Computes the cost of a trajectory. + + Args: + qs: The generalized positions over the trajectory. + vs: The generalized velocities over the trajectory. + us: The controls over the trajectory. + + Returns: + The cost of the trajectory. + """ + raise NotImplementedError + + def optimize(self, params: ShootingParams) -> Tuple[jax.Array, jax.Array, jax.Array]: + """Optimizes a trajectory using a shooting method. + + Args: + params: The parameters of the trajectory optimizer. + + Returns: + qs (shape=(N + 1, nq) or (?)): The optimized trajectory. + vs (shape=(N + 1, nv) or (?)): The optimized generalized velocities. + us (shape=(N, nu) or (?)): The optimized controls. 
+ """ + raise NotImplementedError + + +# predictive sampling API + + +@struct.dataclass +class VanillaPredictiveSamplerParams(ShootingParams): + """Parameters for generic predictive sampling methods.""" + + key: jax.Array # random key for sampling + + +@struct.dataclass +class VanillaPredictiveSampler(ShootingAlgorithm): + """A vanilla predictive sampler object. + + The following choices are made: + (1) the model parameters are fixed, and are therefore a field of this dataclass; + (2) the control sequence is parameterized as a ZOH sequence instead of a spline; + (3) the control inputs are sampled from a normal distribution with a uniformly chosen noise scale over all params. + (4) the cost function is quadratic in the states and controls. + """ + + model: mjx.Model = struct.field(pytree_node=False) + nsamples: int = struct.field(pytree_node=False) + stdev: float = struct.field(pytree_node=False) # noise scale, parameters theta_new ~ N(theta, (stdev ** 2) * I) + + def cost(self, qs: jax.Array, vs: jax.Array, us: jax.Array) -> jax.Array: + """Computes the cost of a trajectory using quadratic weights.""" + raise NotImplementedError # TODO(ahl): implement this, maybe make a library of parametric cost functions + + def optimize(self, params: VanillaPredictiveSamplerParams) -> Tuple[jax.Array, jax.Array, jax.Array]: + """Optimizes a trajectory using a vanilla predictive sampler. Args: params: The parameters of the trajectory optimizer. Returns: - xs (shape=(N + 1, nq + nv)): The optimized trajectory. + qs (shape=(N + 1, nq)): The optimized trajectory. + vs (shape=(N + 1, nv)): The optimized generalized velocities. us (shape=(N, nu)): The optimized controls. 
""" + # unpack the params + m = self.model + nsamples = self.nsamples + + q0 = params.q0 + v0 = params.v0 + us_guess = params.us_guess + N = params.N + key = params.key + + # sample over the control inputs + # TODO(ahl): write a create classmethod that allows the user to set default_limits optionally with some semi- + # reasonable default value + keys = jax.random.split(key, nsamples) # splitting the random key so we can vmap over random sampling + _us_samples = vmap(jax.random.normal, in_axes=(0, None))(keys, shape=(N, m.nu)) + us_guess # (nsamples, N, nu) + + # clamping the samples to their control limits + # TODO(ahl): check whether joints with no limits have reasonable defaults for m.actuator_ctrlrange + limits = m.actuator_ctrlrange + us_samples = vmap(vmap(jnp.clip, in_axes=(0, None, None)), in_axes=(0, None, None))( + _us_samples, a_min=limits[:, 0], a_max=limits[:, 1] + ) # apply limits only to the last dim, need a nested vmap + # limited = m.actuator_ctrllimited[:, None] # (nu, 1) whether each actuator has limited control authority + # default_limits = jnp.array([[-1000.0, 1000.0]] * m.nu) # (nu, 2) default limits for each actuator + # limits = jnp.where(limited, m.actuator_ctrlrange, default_limits) # (nu, 2) + + # predict many samples, evaluate them, and return the best trajectory tuple + qs_samples, vs_samples = vmap(shoot, in_axes=(None, None, None, 0))(m, q0, v0, us_samples) + costs = vmap(self.cost)(qs_samples, vs_samples, us_samples) # (nsamples,) + best_idx = jnp.argmin(costs) + qs_star = lax.dynamic_slice(qs_samples, (best_idx, 0, 0), (1, N + 1, m.nq))[0] # (N + 1, nq) + vs_star = lax.dynamic_slice(vs_samples, (best_idx, 0, 0), (1, N + 1, m.nv))[0] # (N + 1, nv) + us_star = lax.dynamic_slice(us_samples, (best_idx, 0, 0), (1, N, m.nu))[0] # (N, nu) + return qs_star, vs_star, us_star From 8c9cd6b676bb3f1117c64c3c1e0b9cdc1198e0c0 Mon Sep 17 00:00:00 2001 From: alberthli Date: Sun, 3 Dec 2023 15:21:28 -0800 Subject: [PATCH 08/28] define generic 
CostFunction API --- ambersim/trajopt/base.py | 108 ++++++++++++++++++++++++++++++++++----- 1 file changed, 95 insertions(+), 13 deletions(-) diff --git a/ambersim/trajopt/base.py b/ambersim/trajopt/base.py index c4c69ea7..385405ca 100644 --- a/ambersim/trajopt/base.py +++ b/ambersim/trajopt/base.py @@ -2,6 +2,11 @@ import jax from flax import struct +from jax import grad, hessian + +# ####################### # +# TRAJECTORY OPTIMIZATION # +# ####################### # @struct.dataclass @@ -49,24 +54,13 @@ class TrajectoryOptimizer: situation in vanilla predictive sampling. If we don't need to pass the model, we instead initialize it as a field of this dataclass, which makes the optimize function more performant, since it can just reference the fixed model attribute of the optimizer instead of applying JAX transformations to the entire large model pytree. + Similar logic applies for not specifying the role of the CostFunction - we trust that the user will either use the + provided API or will ignore it and still end up implementing something custom and reasonable. Finally, abstract dataclasses are weird, so we just make all children implement the below functions by instead raising a NotImplementedError. """ - def cost(self, qs: jax.Array, vs: jax.Array, us: jax.Array) -> jax.Array: - """Computes the cost of a trajectory. - - Args: - qs: The generalized positions over the trajectory. - vs: The generalized velocities over the trajectory. - us: The controls over the trajectory. - - Returns: - The cost of the trajectory. - """ - raise NotImplementedError - def optimize(self, params: TrajectoryOptimizerParams) -> Tuple[jax.Array, jax.Array, jax.Array]: """Optimizes a trajectory. @@ -83,3 +77,91 @@ def optimize(self, params: TrajectoryOptimizerParams) -> Tuple[jax.Array, jax.Ar us_star (shape=(N, nu) or (?)): The optimized controls. 
""" raise NotImplementedError + + +# ############# # +# COST FUNCTION # +# ############# # + + +@struct.dataclass +class CostFunctionParams: + """Generic parameters for cost functions.""" + + +@struct.dataclass +class CostFunction: + """The API for generic cost functions for trajectory optimization problems for mechanical systems. + + Rationale behind CostFunctionParams in this generic API: + (1) computation of higher-order derivatives could depend on results or intermediates from lower-order derivatives. + So, we can flexibly cache the requisite values to avoid repeated computation; + (2) we may want to randomize or optimize the cost function parameters themselves, so specifying a generic pytree + as input generically accounts for all possibilities; + (3) there could simply be parameters that cannot be easily specified in advance that are key for cost evaluation, + like a time-varying reference trajectory that gets updated in real time. + (4) histories of higher-order derivatives can be useful for updating their current estimates, e.g., BFGS. + """ + + def cost( + self, qs: jax.Array, vs: jax.Array, us: jax.Array, params: CostFunctionParams + ) -> Tuple[jax.Array, CostFunctionParams]: + """Computes the cost of a trajectory. + + Args: + qs (shape=(N + 1, nq)): The generalized positions over the trajectory. + vs (shape=(N + 1, nv)): The generalized velocities over the trajectory. + us (shape=(N, nu)): The controls over the trajectory. + + Returns: + val (shape=(,)): The cost of the trajectory. + new_params: The updated parameters of the cost function. + """ + raise NotImplementedError + + def grad( + self, qs: jax.Array, vs: jax.Array, us: jax.Array, params: CostFunctionParams + ) -> Tuple[jax.Array, jax.Array, jax.Array, CostFunctionParams, CostFunctionParams]: + """Computes the gradient of the cost of a trajectory. + + The default implementation of this function uses JAX's autodiff. 
Simply override this function if you would like + to supply an analytical gradient. + + Args: + qs (shape=(N + 1, nq)): The generalized positions over the trajectory. + vs (shape=(N + 1, nv)): The generalized velocities over the trajectory. + us (shape=(N, nu)): The controls over the trajectory. + + Returns: + gcost_qs (shape=(N + 1, nq): The gradient of the cost wrt qs. + gcost_vs (shape=(N + 1, nv): The gradient of the cost wrt vs. + gcost_us (shape=(N, nu)): The gradient of the cost wrt us. + gcost_params: The gradient of the cost wrt params. + new_params: The updated parameters of the cost function. + """ + return grad(self.cost, argnums=(0, 1, 2, 3))(qs, vs, us, params) + (params,) + + def hess( + self, qs: jax.Array, vs: jax.Array, us: jax.Array, params: CostFunctionParams + ) -> Tuple[jax.Array, jax.Array, jax.Array, CostFunctionParams, CostFunctionParams]: + """Computes the Hessian of the cost of a trajectory. + + The default implementation of this function uses JAX's autodiff. Simply override this function if you would like + to supply an analytical Hessian. + + Args: + qs (shape=(N + 1, nq)): The generalized positions over the trajectory. + vs (shape=(N + 1, nv)): The generalized velocities over the trajectory. + us (shape=(N, nu)): The controls over the trajectory. + + Returns: + Hcost_qs (shape=(N + 1, nq, N + 1, nq)): The Hessian of the cost wrt qs. + Let t, s be times from 0 to N + 1. Then, d^2/dq_{t,i}dq_{s,j} = Hcost_qs[t, i, s, j]. + Hcost_vs (shape=(N + 1, nv, N + 1, nv)): The Hessian of the cost wrt vs. + Let t, s be times from 0 to N + 1. Then, d^2/dv_{t,i}dv_{s,j} = Hcost_vs[t, i, s, j]. + Hcost_us (shape=(N, nu, N, nu)): The Hessian of the cost wrt us. + Let t, s be times from 0 to N. Then, d^2/du_{t,i}du_{s,j} = Hcost_us[t, i, s, j]. + Hcost_params: The Hessian of the cost wrt params. + new_params: The updated parameters of the cost function. 
+ """ + return hessian(self.cost, argnums=(0, 1, 2, 3))(qs, vs, us, params) + (params,) From 0d175f022ee0bdc7aee31ed6e616313b6a6fd4c3 Mon Sep 17 00:00:00 2001 From: alberthli Date: Sun, 3 Dec 2023 15:22:25 -0800 Subject: [PATCH 09/28] expose a CostFunction field of shooting methods --- ambersim/trajopt/shooting.py | 23 ++++------------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/ambersim/trajopt/shooting.py b/ambersim/trajopt/shooting.py index 14a7825d..21d8db2d 100644 --- a/ambersim/trajopt/shooting.py +++ b/ambersim/trajopt/shooting.py @@ -8,6 +8,7 @@ from mujoco.mjx import step from ambersim.trajopt.base import TrajectoryOptimizer, TrajectoryOptimizerParams +from ambersim.trajopt.cost import CostFunction """Shooting methods and their derived subclasses.""" @@ -77,19 +78,6 @@ def N(self) -> int: class ShootingAlgorithm(TrajectoryOptimizer): """A trajectory optimization algorithm based on shooting methods.""" - def cost(self, qs: jax.Array, vs: jax.Array, us: jax.Array) -> jax.Array: - """Computes the cost of a trajectory. - - Args: - qs: The generalized positions over the trajectory. - vs: The generalized velocities over the trajectory. - us: The controls over the trajectory. - - Returns: - The cost of the trajectory. - """ - raise NotImplementedError - def optimize(self, params: ShootingParams) -> Tuple[jax.Array, jax.Array, jax.Array]: """Optimizes a trajectory using a shooting method. @@ -125,14 +113,11 @@ class VanillaPredictiveSampler(ShootingAlgorithm): (4) the cost function is quadratic in the states and controls. 
""" - model: mjx.Model = struct.field(pytree_node=False) + model: mjx.Model + cost: CostFunction nsamples: int = struct.field(pytree_node=False) stdev: float = struct.field(pytree_node=False) # noise scale, parameters theta_new ~ N(theta, (stdev ** 2) * I) - def cost(self, qs: jax.Array, vs: jax.Array, us: jax.Array) -> jax.Array: - """Computes the cost of a trajectory using quadratic weights.""" - raise NotImplementedError # TODO(ahl): implement this, maybe make a library of parametric cost functions - def optimize(self, params: VanillaPredictiveSamplerParams) -> Tuple[jax.Array, jax.Array, jax.Array]: """Optimizes a trajectory using a vanilla predictive sampler. @@ -172,7 +157,7 @@ def optimize(self, params: VanillaPredictiveSamplerParams) -> Tuple[jax.Array, j # predict many samples, evaluate them, and return the best trajectory tuple qs_samples, vs_samples = vmap(shoot, in_axes=(None, None, None, 0))(m, q0, v0, us_samples) - costs = vmap(self.cost)(qs_samples, vs_samples, us_samples) # (nsamples,) + costs, _ = vmap(self.cost, in_axes=(0, 0, 0, None))(qs_samples, vs_samples, us_samples, None) # (nsamples,) best_idx = jnp.argmin(costs) qs_star = lax.dynamic_slice(qs_samples, (best_idx, 0, 0), (1, N + 1, m.nq))[0] # (N + 1, nq) vs_star = lax.dynamic_slice(vs_samples, (best_idx, 0, 0), (1, N + 1, m.nv))[0] # (N + 1, nv) From 987e3348372f49dbd639f0c2e8a38e9ac5e61f64 Mon Sep 17 00:00:00 2001 From: alberthli Date: Sun, 3 Dec 2023 15:26:52 -0800 Subject: [PATCH 10/28] add specific implementation of (non-sparse) quadratic CostFunction with static goal and cost matrices --- ambersim/trajopt/cost.py | 170 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 170 insertions(+) create mode 100644 ambersim/trajopt/cost.py diff --git a/ambersim/trajopt/cost.py b/ambersim/trajopt/cost.py new file mode 100644 index 00000000..b09d1cb7 --- /dev/null +++ b/ambersim/trajopt/cost.py @@ -0,0 +1,170 @@ +from typing import Tuple + +import jax +import jax.numpy as jnp +from flax 
import struct
+
+from ambersim.trajopt.base import CostFunction, CostFunctionParams
+
+"""A collection of common cost functions."""
+
+
+class StaticGoalQuadraticCost(CostFunction):
+    """A quadratic cost function that penalizes the distance to a static goal.
+
+    This is the most vanilla possible quadratic cost. The cost matrices are static (defined at init time) and so is the
+    single, fixed goal. The gradient is as compressed as it can be in general (one matrix multiplication), but the
+    Hessian can be far more compressed by simply referencing Q, Qf, and R - this implementation is inefficient and
+    dense.
+    """
+
+    def __init__(self, Q: jax.Array, Qf: jax.Array, R: jax.Array, qg: jax.Array, vg: jax.Array) -> None:
+        """Initializes a quadratic cost function.
+
+        Args:
+            Q (shape=(nx, nx)): The state cost matrix.
+            Qf (shape=(nx, nx)): The final state cost matrix.
+            R (shape=(nu, nu)): The control cost matrix.
+            qg (shape=(nq,)): The goal generalized coordinates.
+            vg (shape=(nv,)): The goal generalized velocities.
+        """
+        self.Q = Q
+        self.Qf = Qf
+        self.R = R
+        self.qg = qg
+        self.vg = vg
+
+    @staticmethod
+    def _setup_util(
+        qs: jax.Array, vs: jax.Array, qg: jax.Array, vg: jax.Array
+    ) -> Tuple[jax.Array, jax.Array, jax.Array, jax.Array, int, int, int]:
+        """Utility function that sets up the cost function.
+
+        Args:
+            qs (shape=(N + 1, nq)): The generalized positions over the trajectory.
+            vs (shape=(N + 1, nv)): The generalized velocities over the trajectory.
+            qg (shape=(nq,)): The goal generalized position.
+            vg (shape=(nv,)): The goal generalized velocity.
+
+        Returns:
+            xs (shape=(N + 1, nx)): The states over the trajectory.
+            xg (shape=(nx,)): The goal state.
+            xs_err (shape=(N, nx)): The state errors up to the final state.
+            xf_err (shape=(nx,)): The state error at the final state.
+            nq: The number of generalized coordinates.
+            nv: The number of generalized velocities.
+ """ + xs = jnp.concatenate((qs, vs), axis=-1) + xg = jnp.concatenate((qg, vg), axis=-1) + xs_err = xs[:-1, :] - xg + xf_err = xs[-1, :] - xg + return xs, xg, xs_err, xf_err, qs.shape[-1], vs.shape[-1] + + @staticmethod + def batch_quadform(bs: jax.Array, A: jax.array) -> jax.Array: + """Computes a batched quadratic form for a single instance of A. + + Args: + bs (shape=(..., n)): The batch of vectors. + A (shape=(n, n)): The matrix. + + Returns: + val (shape=(...,)): The batch of quadratic forms. + """ + return jnp.einsum("...i,ij,...j->...", bs, A, bs) + + @staticmethod + def batch_matmul(bs: jax.Array, A: jax.Array) -> jax.Array: + """Computes a batched matrix multiplication for a single instance of A. + + Args: + bs (shape=(..., n)): The batch of vectors. + A (shape=(n, n)): The matrix. + + Returns: + val (shape=(..., n)): The batch of matrix multiplications. + """ + return jnp.einsum("...i,ij->...j", bs, A) + + def cost( + self, qs: jax.Array, vs: jax.Array, us: jax.Array, params: CostFunctionParams + ) -> Tuple[jax.Array, CostFunctionParams]: + """Computes the cost of a trajectory. + + cost = 0.5 * ([q; v] - [qg; vg])' @ Q @ ([q; v] - [qg; vg]) + 0.5 * u' @ R @ u + + Args: + qs (shape=(N + 1, nq)): The generalized positions over the trajectory. + vs (shape=(N + 1, nv)): The generalized velocities over the trajectory. + us (shape=(N, nu)): The controls over the trajectory. + + Returns: + cost_val: The cost of the trajectory. + new_params: Unused. Included for API compliance. + """ + xs, xg, xs_err, xf_err, _, _ = self._setup_util(qs, vs, self.qg, self.vg) + val = 0.5 * ( + self.batch_quadform(xs_err, self.Q) + self.batch_quadform(xf_err, self.Qf) + self.batch_quadform(us, self.R) + ) + return val, params + + def grad( + self, qs: jax.Array, vs: jax.Array, us: jax.Array, params: CostFunctionParams + ) -> Tuple[jax.Array, jax.Array, jax.Array, CostFunctionParams, CostFunctionParams]: + """Computes the gradient of the cost of a trajectory. 
+ + Args: + qs (shape=(N + 1, nq)): The generalized positions over the trajectory. + vs (shape=(N + 1, nv)): The generalized velocities over the trajectory. + us (shape=(N, nu)): The controls over the trajectory. + + Returns: + gcost_qs (shape=(N + 1, nq): The gradient of the cost wrt qs. + gcost_vs (shape=(N + 1, nv): The gradient of the cost wrt vs. + gcost_us (shape=(N, nu)): The gradient of the cost wrt us. + gcost_params: Unused. Included for API compliance. + new_params: Unused. Included for API compliance. + """ + xs, xg, xs_err, xf_err, nq, _, _ = self._setup_util(qs, vs, self.qg, self.vg) + gcost_xs = jnp.concatenate( + ( + self.batch_matmul(xs_err, self.Q), + (self.Qf @ xf_err)[None, :], + ), + axis=-1, + ) + gcost_qs = gcost_xs[:, :nq] + gcost_vs = gcost_xs[:, nq:] + gcost_us = self.batch_matmul(us, self.R) + return gcost_qs, gcost_vs, gcost_us, params, params + + def hess( + self, qs: jax.Array, vs: jax.Array, us: jax.Array, params: CostFunctionParams + ) -> Tuple[jax.Array, jax.Array, jax.Array, CostFunctionParams, CostFunctionParams]: + """Computes the Hessian of the cost of a trajectory. + + Args: + qs (shape=(N + 1, nq)): The generalized positions over the trajectory. + vs (shape=(N + 1, nv)): The generalized velocities over the trajectory. + us (shape=(N, nu)): The controls over the trajectory. + + Returns: + Hcost_qs (shape=(N + 1, nq, N + 1, nq)): The Hessian of the cost wrt qs. + Let t, s be times from 0 to N + 1. Then, d^2/dq_{t,i}dq_{s,j} = Hcost_qs[t, i, s, j]. + Hcost_vs (shape=(N + 1, nv, N + 1, nv)): The Hessian of the cost wrt vs. + Let t, s be times from 0 to N + 1. Then, d^2/dv_{t,i}dv_{s,j} = Hcost_vs[t, i, s, j]. + Hcost_us (shape=(N, nu, N, nu)): The Hessian of the cost wrt us. + Let t, s be times from 0 to N. Then, d^2/du_{t,i}du_{s,j} = Hcost_us[t, i, s, j]. + Hcost_params: Unused. Included for API compliance. + new_params: Unused. Included for API compliance. 
+ """ + N = us.shape[0] + xs, xg, xs_err, xf_err, nq, _, _ = self._setup_util(qs, vs, self.qg, self.vg) + Q_tiled = jnp.tile(self.Q[None, :, None, :], (N + 1, 1, N + 1, 1)) + Hcost_xs = Q_tiled.at[-1, :, -1, :].set(self.Qf) + + Hcost_qs = Hcost_xs[:, :nq, :, :nq] + Hcost_vs = Hcost_xs[:, nq:, :, nq:] + Hcost_us = jnp.tile(self.R[None, :, None, :], (N, 1, N, 1)) + + return Hcost_qs, Hcost_vs, Hcost_us, params, params From 2b268f8a0286b106cdb80199d8058aa6d8911b98 Mon Sep 17 00:00:00 2001 From: alberthli Date: Sun, 3 Dec 2023 22:42:17 -0800 Subject: [PATCH 11/28] fixed missing field in docstrings --- ambersim/trajopt/base.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ambersim/trajopt/base.py b/ambersim/trajopt/base.py index 385405ca..8d705144 100644 --- a/ambersim/trajopt/base.py +++ b/ambersim/trajopt/base.py @@ -112,6 +112,7 @@ def cost( qs (shape=(N + 1, nq)): The generalized positions over the trajectory. vs (shape=(N + 1, nv)): The generalized velocities over the trajectory. us (shape=(N, nu)): The controls over the trajectory. + params: The parameters of the cost function. Returns: val (shape=(,)): The cost of the trajectory. @@ -131,6 +132,7 @@ def grad( qs (shape=(N + 1, nq)): The generalized positions over the trajectory. vs (shape=(N + 1, nv)): The generalized velocities over the trajectory. us (shape=(N, nu)): The controls over the trajectory. + params: The parameters of the cost function. Returns: gcost_qs (shape=(N + 1, nq): The gradient of the cost wrt qs. @@ -153,6 +155,7 @@ def hess( qs (shape=(N + 1, nq)): The generalized positions over the trajectory. vs (shape=(N + 1, nv)): The generalized velocities over the trajectory. us (shape=(N, nu)): The controls over the trajectory. + params: The parameters of the cost function. Returns: Hcost_qs (shape=(N + 1, nq, N + 1, nq)): The Hessian of the cost wrt qs. 
From 13efcfd4c45cd7fac2935a0df4f94d04c6bf8658 Mon Sep 17 00:00:00 2001 From: alberthli Date: Sun, 3 Dec 2023 22:42:32 -0800 Subject: [PATCH 12/28] fixed missing field in docstrings --- ambersim/trajopt/cost.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ambersim/trajopt/cost.py b/ambersim/trajopt/cost.py index b09d1cb7..454d74e6 100644 --- a/ambersim/trajopt/cost.py +++ b/ambersim/trajopt/cost.py @@ -61,7 +61,7 @@ def _setup_util( return xs, xg, xs_err, xf_err, qs.shape[-1], vs.shape[-1] @staticmethod - def batch_quadform(bs: jax.Array, A: jax.array) -> jax.Array: + def batch_quadform(bs: jax.Array, A: jax.Array) -> jax.Array: """Computes a batched quadratic form for a single instance of A. Args: @@ -97,6 +97,7 @@ def cost( qs (shape=(N + 1, nq)): The generalized positions over the trajectory. vs (shape=(N + 1, nv)): The generalized velocities over the trajectory. us (shape=(N, nu)): The controls over the trajectory. + params: Unused. Included for API compliance. Returns: cost_val: The cost of the trajectory. @@ -117,6 +118,7 @@ def grad( qs (shape=(N + 1, nq)): The generalized positions over the trajectory. vs (shape=(N + 1, nv)): The generalized velocities over the trajectory. us (shape=(N, nu)): The controls over the trajectory. + params: Unused. Included for API compliance. Returns: gcost_qs (shape=(N + 1, nq): The gradient of the cost wrt qs. @@ -147,6 +149,7 @@ def hess( qs (shape=(N + 1, nq)): The generalized positions over the trajectory. vs (shape=(N + 1, nv)): The generalized velocities over the trajectory. us (shape=(N, nu)): The controls over the trajectory. + params: Unused. Included for API compliance. Returns: Hcost_qs (shape=(N + 1, nq, N + 1, nq)): The Hessian of the cost wrt qs. 
From b6ea473af53844750bbb23d15162afb75f0f334e Mon Sep 17 00:00:00 2001 From: alberthli Date: Sun, 3 Dec 2023 22:43:15 -0800 Subject: [PATCH 13/28] pass non-kwarg functions to vmap since vmap breaks in that case --- ambersim/trajopt/shooting.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/ambersim/trajopt/shooting.py b/ambersim/trajopt/shooting.py index 21d8db2d..b333ca57 100644 --- a/ambersim/trajopt/shooting.py +++ b/ambersim/trajopt/shooting.py @@ -1,3 +1,4 @@ +from functools import partial from typing import Tuple import jax @@ -62,6 +63,8 @@ class ShootingParams(TrajectoryOptimizerParams): """Parameters for shooting methods.""" # inputs into the algorithm + q0: jax.Array # shape=(nq,) or (?) + v0: jax.Array # shape=(nv,) or (?) us_guess: jax.Array # shape=(N, nu) or (?) @property @@ -114,7 +117,7 @@ class VanillaPredictiveSampler(ShootingAlgorithm): """ model: mjx.Model - cost: CostFunction + cost_function: CostFunction nsamples: int = struct.field(pytree_node=False) stdev: float = struct.field(pytree_node=False) # noise scale, parameters theta_new ~ N(theta, (stdev ** 2) * I) @@ -143,21 +146,23 @@ def optimize(self, params: VanillaPredictiveSamplerParams) -> Tuple[jax.Array, j # TODO(ahl): write a create classmethod that allows the user to set default_limits optionally with some semi- # reasonable default value keys = jax.random.split(key, nsamples) # splitting the random key so we can vmap over random sampling - _us_samples = vmap(jax.random.normal, in_axes=(0, None))(keys, shape=(N, m.nu)) + us_guess # (nsamples, N, nu) + jrn = partial(jax.random.normal, shape=(N, m.nu)) # jax random normal function with shape already set + _us_samples = vmap(jrn)(keys) + us_guess # (nsamples, N, nu) # clamping the samples to their control limits # TODO(ahl): check whether joints with no limits have reasonable defaults for m.actuator_ctrlrange limits = m.actuator_ctrlrange - us_samples = vmap(vmap(jnp.clip, in_axes=(0, None, None)), 
in_axes=(0, None, None))( - _us_samples, a_min=limits[:, 0], a_max=limits[:, 1] - ) # apply limits only to the last dim, need a nested vmap + clip_fn = partial(jnp.clip, a_min=limits[:, 0], a_max=limits[:, 1]) # clipping function with limits already set + us_samples = vmap(vmap(clip_fn))(_us_samples) # apply limits only to the last dim, need a nested vmap # limited = m.actuator_ctrllimited[:, None] # (nu, 1) whether each actuator has limited control authority # default_limits = jnp.array([[-1000.0, 1000.0]] * m.nu) # (nu, 2) default limits for each actuator # limits = jnp.where(limited, m.actuator_ctrlrange, default_limits) # (nu, 2) # predict many samples, evaluate them, and return the best trajectory tuple qs_samples, vs_samples = vmap(shoot, in_axes=(None, None, None, 0))(m, q0, v0, us_samples) - costs, _ = vmap(self.cost, in_axes=(0, 0, 0, None))(qs_samples, vs_samples, us_samples, None) # (nsamples,) + costs, _ = vmap(self.cost_function.cost, in_axes=(0, 0, 0, None))( + qs_samples, vs_samples, us_samples, None + ) # (nsamples,) best_idx = jnp.argmin(costs) qs_star = lax.dynamic_slice(qs_samples, (best_idx, 0, 0), (1, N + 1, m.nq))[0] # (N + 1, nq) vs_star = lax.dynamic_slice(vs_samples, (best_idx, 0, 0), (1, N + 1, m.nv))[0] # (N + 1, nv) From 6bf291d2e97de56965b30a67161b8fcdfd11f15a Mon Sep 17 00:00:00 2001 From: alberthli Date: Sun, 3 Dec 2023 22:43:30 -0800 Subject: [PATCH 14/28] added preliminary dead simple predictive sampling example --- examples/trajopt/ex_predictive_sampling.py | 72 ++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 examples/trajopt/ex_predictive_sampling.py diff --git a/examples/trajopt/ex_predictive_sampling.py b/examples/trajopt/ex_predictive_sampling.py new file mode 100644 index 00000000..895a81a0 --- /dev/null +++ b/examples/trajopt/ex_predictive_sampling.py @@ -0,0 +1,72 @@ +import timeit + +import jax +import jax.numpy as jnp +from jax import jit +from mujoco.mjx._src.types import DisableBit + +from 
ambersim.trajopt.cost import StaticGoalQuadraticCost +from ambersim.trajopt.shooting import VanillaPredictiveSampler, VanillaPredictiveSamplerParams +from ambersim.utils.io_utils import load_mjx_model_and_data_from_file + +if __name__ == "__main__": + # initializing the predictive sampler + model, _ = load_mjx_model_and_data_from_file("models/barrett_hand/bh280.xml", force_float=False) + model = model.replace( + opt=model.opt.replace( + timestep=0.002, # dt + iterations=1, # number of Newton steps to take during solve + ls_iterations=4, # number of line search iterations along step direction + integrator=1, # RK4 instead of Euler semi-implicit + solver=2, # Newton solver + disableflags=DisableBit.CONTACT, # [IMPORTANT] disable contact for this example + ) + ) + cost_function = StaticGoalQuadraticCost( + Q=jnp.eye(model.nq + model.nv), + Qf=10.0 * jnp.eye(model.nq + model.nv), + R=0.01 * jnp.eye(model.nu), + # qg=jnp.zeros(model.nq).at[6].set(1.0), # if force_float=True + qg=jnp.zeros(model.nq), + vg=jnp.zeros(model.nv), + ) + nsamples = 100 + stdev = 0.01 + ps = VanillaPredictiveSampler(model=model, cost_function=cost_function, nsamples=nsamples, stdev=stdev) + + # sampler parameters + key = jax.random.PRNGKey(0) # random seed for the predictive sampler + q0 = jnp.zeros(model.nq).at[6].set(1.0) + v0 = jnp.zeros(model.nv) + num_steps = 100 + us_guess = jnp.zeros((num_steps, model.nu)) + params = VanillaPredictiveSamplerParams(key=key, q0=q0, v0=v0, us_guess=us_guess) + + # sampling the best sequence of qs, vs, and us + optimize_fn = jit(ps.optimize) + + def _time_fn(): + qs_star, vs_star, us_star = optimize_fn(params) + qs_star.block_until_ready() + vs_star.block_until_ready() + us_star.block_until_ready() + + compile_time = timeit.timeit(_time_fn, number=1) + print(f"Compile time: {compile_time}") + + # informal timing test + # TODO(ahl): identify bottlenecks and zap them + # [Dec. 
3, 2023] on vulcan, I've informally tested the scaling of runtime with the number of steps and the number + # of samples. Here are a few preliminary results: + # * nsamples=100, numsteps=10. avg: 0.01s + # * nsamples=1000, numsteps=10. avg: 0.015s + # * nsamples=10000, numsteps=10. avg: 0.07s + # * nsamples=100, numsteps=100. avg: 0.1s + # we conclude that the runtime scales predictably linearly with numsteps, but we also have some sort of (perhaps + # logarithmic) scaling of runtime with nsamples. this outlook is somewhat grim, and we need to also keep in mind + # that we've completely disabled contact for this example and set the number of solver iterations and line search + # iterations to very runtime-friendly values + num_timing_iters = 100 + time = timeit.timeit(_time_fn, number=num_timing_iters) + print(f"Avg. runtime: {time / num_timing_iters}") # timeit returns TOTAL time, so we compute the average ourselves + breakpoint() From 2cfdbfd579db871f4cdd49c1f5897a8205c94765 Mon Sep 17 00:00:00 2001 From: alberthli Date: Sun, 3 Dec 2023 22:57:47 -0800 Subject: [PATCH 15/28] minor, get from property --- ambersim/trajopt/shooting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ambersim/trajopt/shooting.py b/ambersim/trajopt/shooting.py index b333ca57..6029a784 100644 --- a/ambersim/trajopt/shooting.py +++ b/ambersim/trajopt/shooting.py @@ -135,11 +135,11 @@ def optimize(self, params: VanillaPredictiveSamplerParams) -> Tuple[jax.Array, j # unpack the params m = self.model nsamples = self.nsamples + N = self.N q0 = params.q0 v0 = params.v0 us_guess = params.us_guess - N = params.N key = params.key # sample over the control inputs From dc5d247b6421fc356057d8fd5c1e39bca4b2baa2 Mon Sep 17 00:00:00 2001 From: alberthli Date: Sun, 3 Dec 2023 23:03:20 -0800 Subject: [PATCH 16/28] fixed some params + simplify sampling of us to one line --- ambersim/trajopt/shooting.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git 
a/ambersim/trajopt/shooting.py b/ambersim/trajopt/shooting.py index 6029a784..98a95b3c 100644 --- a/ambersim/trajopt/shooting.py +++ b/ambersim/trajopt/shooting.py @@ -135,21 +135,20 @@ def optimize(self, params: VanillaPredictiveSamplerParams) -> Tuple[jax.Array, j # unpack the params m = self.model nsamples = self.nsamples - N = self.N + stdev = self.stdev q0 = params.q0 v0 = params.v0 us_guess = params.us_guess + N = params.N key = params.key # sample over the control inputs - # TODO(ahl): write a create classmethod that allows the user to set default_limits optionally with some semi- - # reasonable default value - keys = jax.random.split(key, nsamples) # splitting the random key so we can vmap over random sampling - jrn = partial(jax.random.normal, shape=(N, m.nu)) # jax random normal function with shape already set - _us_samples = vmap(jrn)(keys) + us_guess # (nsamples, N, nu) + _us_samples = us_guess + jax.random.normal(key, shape=(nsamples, N, m.nu)) * stdev # clamping the samples to their control limits + # TODO(ahl): write a create classmethod that allows the user to set default_limits optionally with some semi- + # reasonable default value # TODO(ahl): check whether joints with no limits have reasonable defaults for m.actuator_ctrlrange limits = m.actuator_ctrlrange clip_fn = partial(jnp.clip, a_min=limits[:, 0], a_max=limits[:, 1]) # clipping function with limits already set From ff6dbccad4d475b98ff3c6e7101254b6012c2889 Mon Sep 17 00:00:00 2001 From: alberthli Date: Mon, 4 Dec 2023 10:41:16 -0800 Subject: [PATCH 17/28] [DIRTY] commit that contains commented out code for pre-allocating data in case we need it later --- ambersim/trajopt/shooting.py | 38 +++++++++++++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/ambersim/trajopt/shooting.py b/ambersim/trajopt/shooting.py index 98a95b3c..2a12b840 100644 --- a/ambersim/trajopt/shooting.py +++ b/ambersim/trajopt/shooting.py @@ -19,11 +19,13 @@ # ##### # +# def shoot(m: 
mjx.Model, d: mjx.Data, q0: jax.Array, v0: jax.Array, us: jax.Array) -> Tuple[jax.Array, jax.Array]: def shoot(m: mjx.Model, q0: jax.Array, v0: jax.Array, us: jax.Array) -> Tuple[jax.Array, jax.Array]: """Utility function that shoots a model forward given a sequence of control inputs. Args: m: The model. + # d: The data. q0: The initial generalized coordinates. v0: The initial generalized velocities. us: The control inputs. @@ -33,7 +35,8 @@ def shoot(m: mjx.Model, q0: jax.Array, v0: jax.Array, us: jax.Array) -> Tuple[ja vs (shape=(N + 1, nv)): The generalized velocities. """ d = mjx.make_data(m) - d = d.replace(qpos=q0, qvel=v0) # setting the initial state. + # d = d.replace(qpos=q0, qvel=v0) # setting the initial state. + # d = mjx.forward(m, d) # setting other internal states like acceleration without integrating def scan_fn(d, u): """Integrates the model forward one step given the control input u.""" @@ -117,10 +120,40 @@ class VanillaPredictiveSampler(ShootingAlgorithm): """ model: mjx.Model + # data: mjx.Data cost_function: CostFunction nsamples: int = struct.field(pytree_node=False) stdev: float = struct.field(pytree_node=False) # noise scale, parameters theta_new ~ N(theta, (stdev ** 2) * I) + # @classmethod + # def create( + # cls, + # model: mjx.Model, + # cost_function: CostFunction, + # nsamples: int, + # stdev: float, + # ) -> "VanillaPredictiveSampler": + # """Creates a vanilla predictive sampler from the given parameters. + + # You should always initialize the sampler using this class method. 
+ # """ + # # initializing a batched mjx.Data object + # def _init_data(model: mjx.Model) -> mjx.Data: + # data = mjx.make_data(model) + # data = mjx.forward(model, data) + # return data + + # _dummy_fn = lambda _dummy: _init_data(model) # dummy fn to allow vmapping over func with no inputs + # data = vmap(_dummy_fn)(jnp.empty(nsamples)) # (nsamples,) + + # return cls( + # model=model, + # data=data, + # cost_function=cost_function, + # nsamples=nsamples, + # stdev=stdev, + # ) + def optimize(self, params: VanillaPredictiveSamplerParams) -> Tuple[jax.Array, jax.Array, jax.Array]: """Optimizes a trajectory using a vanilla predictive sampler. @@ -134,6 +167,7 @@ def optimize(self, params: VanillaPredictiveSamplerParams) -> Tuple[jax.Array, j """ # unpack the params m = self.model + # d = self.data nsamples = self.nsamples stdev = self.stdev @@ -158,6 +192,8 @@ def optimize(self, params: VanillaPredictiveSamplerParams) -> Tuple[jax.Array, j # limits = jnp.where(limited, m.actuator_ctrlrange, default_limits) # (nu, 2) # predict many samples, evaluate them, and return the best trajectory tuple + # vmap over the input data and the control trajectories + # qs_samples, vs_samples = vmap(shoot, in_axes=(None, 0, None, None, 0))(m, d, q0, v0, us_samples) qs_samples, vs_samples = vmap(shoot, in_axes=(None, None, None, 0))(m, q0, v0, us_samples) costs, _ = vmap(self.cost_function.cost, in_axes=(0, 0, 0, None))( qs_samples, vs_samples, us_samples, None From 9f12d740c3e6ed18e0a82acba18a12ee162cd2d7 Mon Sep 17 00:00:00 2001 From: alberthli Date: Mon, 4 Dec 2023 10:46:08 -0800 Subject: [PATCH 18/28] revert to allocating data at shooting time --- ambersim/trajopt/shooting.py | 39 +++--------------------------------- 1 file changed, 3 insertions(+), 36 deletions(-) diff --git a/ambersim/trajopt/shooting.py b/ambersim/trajopt/shooting.py index 2a12b840..60a80606 100644 --- a/ambersim/trajopt/shooting.py +++ b/ambersim/trajopt/shooting.py @@ -19,13 +19,11 @@ # ##### # -# def shoot(m: 
mjx.Model, d: mjx.Data, q0: jax.Array, v0: jax.Array, us: jax.Array) -> Tuple[jax.Array, jax.Array]: def shoot(m: mjx.Model, q0: jax.Array, v0: jax.Array, us: jax.Array) -> Tuple[jax.Array, jax.Array]: """Utility function that shoots a model forward given a sequence of control inputs. Args: m: The model. - # d: The data. q0: The initial generalized coordinates. v0: The initial generalized velocities. us: The control inputs. @@ -34,9 +32,10 @@ def shoot(m: mjx.Model, q0: jax.Array, v0: jax.Array, us: jax.Array) -> Tuple[ja qs (shape=(N + 1, nq)): The generalized coordinates. vs (shape=(N + 1, nv)): The generalized velocities. """ + # initializing the data d = mjx.make_data(m) - # d = d.replace(qpos=q0, qvel=v0) # setting the initial state. - # d = mjx.forward(m, d) # setting other internal states like acceleration without integrating + d = d.replace(qpos=q0, qvel=v0) # setting the initial state. + d = mjx.forward(m, d) # setting other internal states like acceleration without integrating def scan_fn(d, u): """Integrates the model forward one step given the control input u.""" @@ -120,40 +119,10 @@ class VanillaPredictiveSampler(ShootingAlgorithm): """ model: mjx.Model - # data: mjx.Data cost_function: CostFunction nsamples: int = struct.field(pytree_node=False) stdev: float = struct.field(pytree_node=False) # noise scale, parameters theta_new ~ N(theta, (stdev ** 2) * I) - # @classmethod - # def create( - # cls, - # model: mjx.Model, - # cost_function: CostFunction, - # nsamples: int, - # stdev: float, - # ) -> "VanillaPredictiveSampler": - # """Creates a vanilla predictive sampler from the given parameters. - - # You should always initialize the sampler using this class method. 
- # """ - # # initializing a batched mjx.Data object - # def _init_data(model: mjx.Model) -> mjx.Data: - # data = mjx.make_data(model) - # data = mjx.forward(model, data) - # return data - - # _dummy_fn = lambda _dummy: _init_data(model) # dummy fn to allow vmapping over func with no inputs - # data = vmap(_dummy_fn)(jnp.empty(nsamples)) # (nsamples,) - - # return cls( - # model=model, - # data=data, - # cost_function=cost_function, - # nsamples=nsamples, - # stdev=stdev, - # ) - def optimize(self, params: VanillaPredictiveSamplerParams) -> Tuple[jax.Array, jax.Array, jax.Array]: """Optimizes a trajectory using a vanilla predictive sampler. @@ -167,7 +136,6 @@ def optimize(self, params: VanillaPredictiveSamplerParams) -> Tuple[jax.Array, j """ # unpack the params m = self.model - # d = self.data nsamples = self.nsamples stdev = self.stdev @@ -193,7 +161,6 @@ def optimize(self, params: VanillaPredictiveSamplerParams) -> Tuple[jax.Array, j # predict many samples, evaluate them, and return the best trajectory tuple # vmap over the input data and the control trajectories - # qs_samples, vs_samples = vmap(shoot, in_axes=(None, 0, None, None, 0))(m, d, q0, v0, us_samples) qs_samples, vs_samples = vmap(shoot, in_axes=(None, None, None, 0))(m, q0, v0, us_samples) costs, _ = vmap(self.cost_function.cost, in_axes=(0, 0, 0, None))( qs_samples, vs_samples, us_samples, None From 3fda8bad7d3de74789ff0e8b9055ef1082295c57 Mon Sep 17 00:00:00 2001 From: alberthli Date: Mon, 4 Dec 2023 13:18:52 -0800 Subject: [PATCH 19/28] [DIRTY] some additions to the example script for profiling help --- examples/trajopt/ex_predictive_sampling.py | 59 ++++++++++++---------- 1 file changed, 31 insertions(+), 28 deletions(-) diff --git a/examples/trajopt/ex_predictive_sampling.py b/examples/trajopt/ex_predictive_sampling.py index 895a81a0..73cca5c8 100644 --- a/examples/trajopt/ex_predictive_sampling.py +++ b/examples/trajopt/ex_predictive_sampling.py @@ -1,5 +1,3 @@ -import timeit - import jax 
import jax.numpy as jnp from jax import jit @@ -17,7 +15,7 @@ timestep=0.002, # dt iterations=1, # number of Newton steps to take during solve ls_iterations=4, # number of line search iterations along step direction - integrator=1, # RK4 instead of Euler semi-implicit + integrator=0, # Euler semi-implicit integration solver=2, # Newton solver disableflags=DisableBit.CONTACT, # [IMPORTANT] disable contact for this example ) @@ -30,7 +28,7 @@ qg=jnp.zeros(model.nq), vg=jnp.zeros(model.nv), ) - nsamples = 100 + nsamples = 100000 stdev = 0.01 ps = VanillaPredictiveSampler(model=model, cost_function=cost_function, nsamples=nsamples, stdev=stdev) @@ -38,35 +36,40 @@ key = jax.random.PRNGKey(0) # random seed for the predictive sampler q0 = jnp.zeros(model.nq).at[6].set(1.0) v0 = jnp.zeros(model.nv) - num_steps = 100 + num_steps = 25 us_guess = jnp.zeros((num_steps, model.nu)) params = VanillaPredictiveSamplerParams(key=key, q0=q0, v0=v0, us_guess=us_guess) # sampling the best sequence of qs, vs, and us optimize_fn = jit(ps.optimize) - def _time_fn(): - qs_star, vs_star, us_star = optimize_fn(params) - qs_star.block_until_ready() - vs_star.block_until_ready() - us_star.block_until_ready() + # [DEBUG] profiling with nsight systems + # qs_star, vs_star, us_star = optimize_fn(params) # JIT compiling + # with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True): + # qs_star, vs_star, us_star = optimize_fn(params) # after JIT + + # def _time_fn(): + # qs_star, vs_star, us_star = optimize_fn(params) + # qs_star.block_until_ready() + # vs_star.block_until_ready() + # us_star.block_until_ready() - compile_time = timeit.timeit(_time_fn, number=1) - print(f"Compile time: {compile_time}") + # compile_time = timeit.timeit(_time_fn, number=1) + # print(f"Compile time: {compile_time}") - # informal timing test - # TODO(ahl): identify bottlenecks and zap them - # [Dec. 
3, 2023] on vulcan, I've informally tested the scaling of runtime with the number of steps and the number - # of samples. Here are a few preliminary results: - # * nsamples=100, numsteps=10. avg: 0.01s - # * nsamples=1000, numsteps=10. avg: 0.015s - # * nsamples=10000, numsteps=10. avg: 0.07s - # * nsamples=100, numsteps=100. avg: 0.1s - # we conclude that the runtime scales predictably linearly with numsteps, but we also have some sort of (perhaps - # logarithmic) scaling of runtime with nsamples. this outlook is somewhat grim, and we need to also keep in mind - # that we've completely disabled contact for this example and set the number of solver iterations and line search - # iterations to very runtime-friendly values - num_timing_iters = 100 - time = timeit.timeit(_time_fn, number=num_timing_iters) - print(f"Avg. runtime: {time / num_timing_iters}") # timeit returns TOTAL time, so we compute the average ourselves - breakpoint() + # # informal timing test + # # TODO(ahl): identify bottlenecks and zap them + # # [Dec. 3, 2023] on vulcan, I've informally tested the scaling of runtime with the number of steps and the number + # # of samples. Here are a few preliminary results: + # # * nsamples=100, numsteps=10. avg: 0.01s + # # * nsamples=1000, numsteps=10. avg: 0.015s + # # * nsamples=10000, numsteps=10. avg: 0.07s + # # * nsamples=100, numsteps=100. avg: 0.1s + # # we conclude that the runtime scales predictably linearly with numsteps, but we also have some sort of (perhaps + # # logarithmic) scaling of runtime with nsamples. this outlook is somewhat grim, and we need to also keep in mind + # # that we've completely disabled contact for this example and set the number of solver iterations and line search + # # iterations to very runtime-friendly values + # num_timing_iters = 100 + # time = timeit.timeit(_time_fn, number=num_timing_iters) + # print(f"Avg. 
runtime: {time / num_timing_iters}") # timeit returns TOTAL time, so we compute the average ourselves + # breakpoint() From 0f507e3143bca668ca0cefba725a72b9c7caa73b Mon Sep 17 00:00:00 2001 From: alberthli Date: Mon, 4 Dec 2023 21:47:53 -0800 Subject: [PATCH 20/28] refactor API to take xs instead of qs and vs --- ambersim/trajopt/base.py | 34 +++---- ambersim/trajopt/cost.py | 109 +++++++++------------ ambersim/trajopt/shooting.py | 48 ++++----- examples/trajopt/ex_predictive_sampling.py | 66 ++++++------- 4 files changed, 110 insertions(+), 147 deletions(-) diff --git a/ambersim/trajopt/base.py b/ambersim/trajopt/base.py index 8d705144..835378e9 100644 --- a/ambersim/trajopt/base.py +++ b/ambersim/trajopt/base.py @@ -103,14 +103,11 @@ class CostFunction: (4) histories of higher-order derivatives can be useful for updating their current estimates, e.g., BFGS. """ - def cost( - self, qs: jax.Array, vs: jax.Array, us: jax.Array, params: CostFunctionParams - ) -> Tuple[jax.Array, CostFunctionParams]: + def cost(self, xs: jax.Array, us: jax.Array, params: CostFunctionParams) -> Tuple[jax.Array, CostFunctionParams]: """Computes the cost of a trajectory. Args: - qs (shape=(N + 1, nq)): The generalized positions over the trajectory. - vs (shape=(N + 1, nv)): The generalized velocities over the trajectory. + xs (shape=(N + 1, nq + nv)): The state trajectory. us (shape=(N, nu)): The controls over the trajectory. params: The parameters of the cost function. @@ -121,50 +118,45 @@ def cost( raise NotImplementedError def grad( - self, qs: jax.Array, vs: jax.Array, us: jax.Array, params: CostFunctionParams - ) -> Tuple[jax.Array, jax.Array, jax.Array, CostFunctionParams, CostFunctionParams]: + self, xs: jax.Array, us: jax.Array, params: CostFunctionParams + ) -> Tuple[jax.Array, jax.Array, CostFunctionParams, CostFunctionParams]: """Computes the gradient of the cost of a trajectory. The default implementation of this function uses JAX's autodiff. 
Simply override this function if you would like to supply an analytical gradient. Args: - qs (shape=(N + 1, nq)): The generalized positions over the trajectory. - vs (shape=(N + 1, nv)): The generalized velocities over the trajectory. + xs (shape=(N + 1, nq + nv)): The state trajectory. us (shape=(N, nu)): The controls over the trajectory. params: The parameters of the cost function. Returns: - gcost_qs (shape=(N + 1, nq): The gradient of the cost wrt qs. - gcost_vs (shape=(N + 1, nv): The gradient of the cost wrt vs. + gcost_xs (shape=(N + 1, nq + nv): The gradient of the cost wrt xs. gcost_us (shape=(N, nu)): The gradient of the cost wrt us. gcost_params: The gradient of the cost wrt params. new_params: The updated parameters of the cost function. """ - return grad(self.cost, argnums=(0, 1, 2, 3))(qs, vs, us, params) + (params,) + return grad(self.cost, argnums=(0, 1, 2))(xs, us, params) + (params,) def hess( - self, qs: jax.Array, vs: jax.Array, us: jax.Array, params: CostFunctionParams - ) -> Tuple[jax.Array, jax.Array, jax.Array, CostFunctionParams, CostFunctionParams]: + self, xs: jax.Array, us: jax.Array, params: CostFunctionParams + ) -> Tuple[jax.Array, jax.Array, CostFunctionParams, CostFunctionParams]: """Computes the Hessian of the cost of a trajectory. The default implementation of this function uses JAX's autodiff. Simply override this function if you would like to supply an analytical Hessian. Args: - qs (shape=(N + 1, nq)): The generalized positions over the trajectory. - vs (shape=(N + 1, nv)): The generalized velocities over the trajectory. + xs (shape=(N + 1, nq + nv)): The state trajectory. us (shape=(N, nu)): The controls over the trajectory. params: The parameters of the cost function. Returns: - Hcost_qs (shape=(N + 1, nq, N + 1, nq)): The Hessian of the cost wrt qs. - Let t, s be times from 0 to N + 1. Then, d^2/dq_{t,i}dq_{s,j} = Hcost_qs[t, i, s, j]. - Hcost_vs (shape=(N + 1, nv, N + 1, nv)): The Hessian of the cost wrt vs. 
- Let t, s be times from 0 to N + 1. Then, d^2/dv_{t,i}dv_{s,j} = Hcost_vs[t, i, s, j]. + Hcost_xs (shape=(N + 1, nq + nv, N + 1, nq + nv)): The Hessian of the cost wrt xs. + Let t, s be times from 0 to N + 1. Then, d^2/dx_{t,i}dx_{s,j} = Hcost_xs[t, i, s, j]. Hcost_us (shape=(N, nu, N, nu)): The Hessian of the cost wrt us. Let t, s be times from 0 to N. Then, d^2/du_{t,i}du_{s,j} = Hcost_us[t, i, s, j]. Hcost_params: The Hessian of the cost wrt params. new_params: The updated parameters of the cost function. """ - return hessian(self.cost, argnums=(0, 1, 2, 3))(qs, vs, us, params) + (params,) + return hessian(self.cost, argnums=(0, 1, 2))(xs, us, params) + (params,) diff --git a/ambersim/trajopt/cost.py b/ambersim/trajopt/cost.py index 454d74e6..feb579bb 100644 --- a/ambersim/trajopt/cost.py +++ b/ambersim/trajopt/cost.py @@ -18,47 +18,43 @@ class StaticGoalQuadraticCost(CostFunction): dense. """ - def __init__(self, Q: jax.Array, Qf: jax.Array, R: jax.Array, qg: jax.Array, vg: jax.Array) -> None: + def __init__(self, Q: jax.Array, Qf: jax.Array, R: jax.Array, xg: jax.Array) -> None: """Initializes a quadratic cost function. Args: Q (shape=(nx, nx)): The state cost matrix. Qf (shape=(nx, nx)): The final state cost matrix. R (shape=(nu, nu)): The control cost matrix. - qg (shape=(nq,)): The goal generalized coordinates. - vg (shape=(nv,)): The goal generalized velocities. + xg (shape=(nq,)): The goal state. """ self.Q = Q self.Qf = Qf self.R = R - self.qg = qg - self.vg = vg - - @staticmethod - def _setup_util( - qs: jax.Array, vs: jax.Array, qg: jax.Array, vg: jax.Array - ) -> Tuple[jax.Array, jax.Array, jax.Array, jax.Array, int, int, int]: - """Utility function that sets up the cost function. - - Args: - qs (shape=(N + 1, nq)): The generalized positions over the trajectory. - vs (shape=(N + 1, nv)): The generalized velocities over the trajectory. - qg (shape=(nx,)): The goal generalized position. - vg (shape=(nv,)): The goal generalized velocity. 
- - Returns: - xs (shape=(N + 1, nx)): The states over the trajectory. - xg (shape=(nx,)): The goal state. - xs_err (shape=(N, nx)): The state errors up to the final state. - xf_err (shape=(nx,)): The state error at the final state. - nq: The number of generalized coordinates. - nv: The number of generalized velocities. - """ - xs = jnp.concatenate((qs, vs), axis=-1) - xg = jnp.concatenate((qg, vg), axis=-1) - xs_err = xs[:-1, :] - xg - xf_err = xs[-1, :] - xg - return xs, xg, xs_err, xf_err, qs.shape[-1], vs.shape[-1] + self.xg = xg + + # @staticmethod + # def _setup_util( + # xs: jax.Array, qg: jax.Array, vg: jax.Array + # ) -> Tuple[jax.Array, jax.Array, jax.Array, jax.Array, int, int, int]: + # """Utility function that sets up the cost function. + + # Args: + # xs (shape=(N + 1, nq + nv)): The state trajectory + # xg (shape=(nq + nv,)): The goal state. + + # Returns: + # xs (shape=(N + 1, nx)): The states over the trajectory. + # xg (shape=(nx,)): The goal state. + # xs_err (shape=(N, nx)): The state errors up to the final state. + # xf_err (shape=(nx,)): The state error at the final state. + # nq: The number of generalized coordinates. + # nv: The number of generalized velocities. + # """ + # xs = jnp.concatenate((qs, vs), axis=-1) + # xg = jnp.concatenate((qg, vg), axis=-1) + # xs_err = xs[:-1, :] - xg + # xf_err = xs[-1, :] - xg + # return xs, xg, xs_err, xf_err, qs.shape[-1], vs.shape[-1] @staticmethod def batch_quadform(bs: jax.Array, A: jax.Array) -> jax.Array: @@ -86,16 +82,13 @@ def batch_matmul(bs: jax.Array, A: jax.Array) -> jax.Array: """ return jnp.einsum("...i,ij->...j", bs, A) - def cost( - self, qs: jax.Array, vs: jax.Array, us: jax.Array, params: CostFunctionParams - ) -> Tuple[jax.Array, CostFunctionParams]: + def cost(self, xs: jax.Array, us: jax.Array, params: CostFunctionParams) -> Tuple[jax.Array, CostFunctionParams]: """Computes the cost of a trajectory. 
- cost = 0.5 * ([q; v] - [qg; vg])' @ Q @ ([q; v] - [qg; vg]) + 0.5 * u' @ R @ u + cost = 0.5 * (xs - xg)' @ Q @ (xs - xg) + 0.5 * us' @ R @ us Args: - qs (shape=(N + 1, nq)): The generalized positions over the trajectory. - vs (shape=(N + 1, nv)): The generalized velocities over the trajectory. + xs (shape=(N + 1, nq + nv)): The state trajectory. us (shape=(N, nu)): The controls over the trajectory. params: Unused. Included for API compliance. @@ -103,31 +96,31 @@ def cost( cost_val: The cost of the trajectory. new_params: Unused. Included for API compliance. """ - xs, xg, xs_err, xf_err, _, _ = self._setup_util(qs, vs, self.qg, self.vg) + xs_err = xs[:-1, :] - self.xg # errors before the terminal state + xf_err = xs[-1, :] - self.xg val = 0.5 * ( self.batch_quadform(xs_err, self.Q) + self.batch_quadform(xf_err, self.Qf) + self.batch_quadform(us, self.R) ) return val, params def grad( - self, qs: jax.Array, vs: jax.Array, us: jax.Array, params: CostFunctionParams - ) -> Tuple[jax.Array, jax.Array, jax.Array, CostFunctionParams, CostFunctionParams]: + self, xs: jax.Array, us: jax.Array, params: CostFunctionParams + ) -> Tuple[jax.Array, jax.Array, CostFunctionParams, CostFunctionParams]: """Computes the gradient of the cost of a trajectory. Args: - qs (shape=(N + 1, nq)): The generalized positions over the trajectory. - vs (shape=(N + 1, nv)): The generalized velocities over the trajectory. + xs (shape=(N + 1, nq + nv)): The state trajectory. us (shape=(N, nu)): The controls over the trajectory. params: Unused. Included for API compliance. Returns: - gcost_qs (shape=(N + 1, nq): The gradient of the cost wrt qs. - gcost_vs (shape=(N + 1, nv): The gradient of the cost wrt vs. + gcost_xs (shape=(N + 1, nq + nv): The gradient of the cost wrt xs. gcost_us (shape=(N, nu)): The gradient of the cost wrt us. gcost_params: Unused. Included for API compliance. new_params: Unused. Included for API compliance. 
""" - xs, xg, xs_err, xf_err, nq, _, _ = self._setup_util(qs, vs, self.qg, self.vg) + xs_err = xs[:-1, :] - self.xg # errors before the terminal state + xf_err = xs[-1, :] - self.xg gcost_xs = jnp.concatenate( ( self.batch_matmul(xs_err, self.Q), @@ -135,39 +128,29 @@ def grad( ), axis=-1, ) - gcost_qs = gcost_xs[:, :nq] - gcost_vs = gcost_xs[:, nq:] gcost_us = self.batch_matmul(us, self.R) - return gcost_qs, gcost_vs, gcost_us, params, params + return gcost_xs, gcost_us, params, params def hess( - self, qs: jax.Array, vs: jax.Array, us: jax.Array, params: CostFunctionParams - ) -> Tuple[jax.Array, jax.Array, jax.Array, CostFunctionParams, CostFunctionParams]: - """Computes the Hessian of the cost of a trajectory. + self, xs: jax.Array, us: jax.Array, params: CostFunctionParams + ) -> Tuple[jax.Array, jax.Array, CostFunctionParams, CostFunctionParams]: + """Computes the gradient of the cost of a trajectory. Args: - qs (shape=(N + 1, nq)): The generalized positions over the trajectory. - vs (shape=(N + 1, nv)): The generalized velocities over the trajectory. + xs (shape=(N + 1, nq + nv)): The state trajectory. us (shape=(N, nu)): The controls over the trajectory. params: Unused. Included for API compliance. Returns: - Hcost_qs (shape=(N + 1, nq, N + 1, nq)): The Hessian of the cost wrt qs. - Let t, s be times from 0 to N + 1. Then, d^2/dq_{t,i}dq_{s,j} = Hcost_qs[t, i, s, j]. - Hcost_vs (shape=(N + 1, nv, N + 1, nv)): The Hessian of the cost wrt vs. - Let t, s be times from 0 to N + 1. Then, d^2/dv_{t,i}dv_{s,j} = Hcost_vs[t, i, s, j]. + Hcost_xs (shape=(N + 1, nq + nv, N + 1, nq + nv)): The Hessian of the cost wrt xs. + Let t, s be times from 0 to N + 1. Then, d^2/dx_{t,i}dx_{s,j} = Hcost_xs[t, i, s, j]. Hcost_us (shape=(N, nu, N, nu)): The Hessian of the cost wrt us. Let t, s be times from 0 to N. Then, d^2/du_{t,i}du_{s,j} = Hcost_us[t, i, s, j]. Hcost_params: Unused. Included for API compliance. new_params: Unused. Included for API compliance. 
""" N = us.shape[0] - xs, xg, xs_err, xf_err, nq, _, _ = self._setup_util(qs, vs, self.qg, self.vg) Q_tiled = jnp.tile(self.Q[None, :, None, :], (N + 1, 1, N + 1, 1)) Hcost_xs = Q_tiled.at[-1, :, -1, :].set(self.Qf) - - Hcost_qs = Hcost_xs[:, :nq, :, :nq] - Hcost_vs = Hcost_xs[:, nq:, :, nq:] Hcost_us = jnp.tile(self.R[None, :, None, :], (N, 1, N, 1)) - - return Hcost_qs, Hcost_vs, Hcost_us, params, params + return Hcost_xs, Hcost_us, params, params diff --git a/ambersim/trajopt/shooting.py b/ambersim/trajopt/shooting.py index 60a80606..baea9d9f 100644 --- a/ambersim/trajopt/shooting.py +++ b/ambersim/trajopt/shooting.py @@ -19,22 +19,20 @@ # ##### # -def shoot(m: mjx.Model, q0: jax.Array, v0: jax.Array, us: jax.Array) -> Tuple[jax.Array, jax.Array]: +def shoot(m: mjx.Model, x0: jax.Array, us: jax.Array) -> jax.Array: """Utility function that shoots a model forward given a sequence of control inputs. Args: m: The model. - q0: The initial generalized coordinates. - v0: The initial generalized velocities. + x0: The initial state. us: The control inputs. Returns: - qs (shape=(N + 1, nq)): The generalized coordinates. - vs (shape=(N + 1, nv)): The generalized velocities. + xs (shape=(N + 1, nq + nv)): The state trajectory. """ # initializing the data d = mjx.make_data(m) - d = d.replace(qpos=q0, qvel=v0) # setting the initial state. + d = d.replace(qpos=x0[: m.nq], qvel=x0[m.nq :]) # setting the initial state. d = mjx.forward(m, d) # setting other internal states like acceleration without integrating def scan_fn(d, u): @@ -45,12 +43,9 @@ def scan_fn(d, u): return d, x # scan over the control inputs to get the trajectory. 
- _, xs = lax.scan(scan_fn, d, us, length=us.shape[0]) - _qs = xs[:, : m.nq] - _vs = xs[:, m.nq : m.nq + m.nv] - qs = jnp.concatenate((q0[None, :], _qs), axis=0) # (N + 1, nq) - vs = jnp.concatenate((v0[None, :], _vs), axis=0) # (N + 1, nv) - return qs, vs + _, _xs = lax.scan(scan_fn, d, us, length=us.shape[0]) + xs = jnp.concatenate((x0[None, :], _xs), axis=0) # (N + 1, nq + nv) + return xs # ################ # @@ -65,8 +60,7 @@ class ShootingParams(TrajectoryOptimizerParams): """Parameters for shooting methods.""" # inputs into the algorithm - q0: jax.Array # shape=(nq,) or (?) - v0: jax.Array # shape=(nv,) or (?) + x0: jax.Array # shape=(nq + nv,) or (?) us_guess: jax.Array # shape=(N, nu) or (?) @property @@ -83,16 +77,15 @@ def N(self) -> int: class ShootingAlgorithm(TrajectoryOptimizer): """A trajectory optimization algorithm based on shooting methods.""" - def optimize(self, params: ShootingParams) -> Tuple[jax.Array, jax.Array, jax.Array]: + def optimize(self, params: ShootingParams) -> Tuple[jax.Array, jax.Array]: """Optimizes a trajectory using a shooting method. Args: params: The parameters of the trajectory optimizer. Returns: - qs (shape=(N + 1, nq) or (?)): The optimized trajectory. - vs (shape=(N + 1, nv) or (?)): The optimized generalized velocities. - us (shape=(N, nu) or (?)): The optimized controls. + xs_star (shape=(N + 1, nq) or (?)): The optimized trajectory. + us_star (shape=(N, nu) or (?)): The optimized controls. """ raise NotImplementedError @@ -123,15 +116,14 @@ class VanillaPredictiveSampler(ShootingAlgorithm): nsamples: int = struct.field(pytree_node=False) stdev: float = struct.field(pytree_node=False) # noise scale, parameters theta_new ~ N(theta, (stdev ** 2) * I) - def optimize(self, params: VanillaPredictiveSamplerParams) -> Tuple[jax.Array, jax.Array, jax.Array]: + def optimize(self, params: VanillaPredictiveSamplerParams) -> Tuple[jax.Array, jax.Array]: """Optimizes a trajectory using a vanilla predictive sampler. 
Args: params: The parameters of the trajectory optimizer. Returns: - qs (shape=(N + 1, nq)): The optimized trajectory. - vs (shape=(N + 1, nv)): The optimized generalized velocities. + xs (shape=(N + 1, nq + nv)): The optimized trajectory. us (shape=(N, nu)): The optimized controls. """ # unpack the params @@ -139,8 +131,7 @@ def optimize(self, params: VanillaPredictiveSamplerParams) -> Tuple[jax.Array, j nsamples = self.nsamples stdev = self.stdev - q0 = params.q0 - v0 = params.v0 + x0 = params.x0 us_guess = params.us_guess N = params.N key = params.key @@ -161,12 +152,9 @@ def optimize(self, params: VanillaPredictiveSamplerParams) -> Tuple[jax.Array, j # predict many samples, evaluate them, and return the best trajectory tuple # vmap over the input data and the control trajectories - qs_samples, vs_samples = vmap(shoot, in_axes=(None, None, None, 0))(m, q0, v0, us_samples) - costs, _ = vmap(self.cost_function.cost, in_axes=(0, 0, 0, None))( - qs_samples, vs_samples, us_samples, None - ) # (nsamples,) + xs_samples = vmap(shoot, in_axes=(None, None, 0))(m, x0, us_samples) + costs, _ = vmap(self.cost_function.cost, in_axes=(0, 0, None))(xs_samples, us_samples, None) # (nsamples,) best_idx = jnp.argmin(costs) - qs_star = lax.dynamic_slice(qs_samples, (best_idx, 0, 0), (1, N + 1, m.nq))[0] # (N + 1, nq) - vs_star = lax.dynamic_slice(vs_samples, (best_idx, 0, 0), (1, N + 1, m.nv))[0] # (N + 1, nv) + xs_star = lax.dynamic_slice(xs_samples, (best_idx, 0, 0), (1, N + 1, m.nq + m.nv))[0] # (N + 1, nq + nv) us_star = lax.dynamic_slice(us_samples, (best_idx, 0, 0), (1, N, m.nu))[0] # (N, nu) - return qs_star, vs_star, us_star + return xs_star, us_star diff --git a/examples/trajopt/ex_predictive_sampling.py b/examples/trajopt/ex_predictive_sampling.py index 73cca5c8..12f9686a 100644 --- a/examples/trajopt/ex_predictive_sampling.py +++ b/examples/trajopt/ex_predictive_sampling.py @@ -1,3 +1,5 @@ +import timeit + import jax import jax.numpy as jnp from jax import jit @@ -24,52 
+26,50 @@ Q=jnp.eye(model.nq + model.nv), Qf=10.0 * jnp.eye(model.nq + model.nv), R=0.01 * jnp.eye(model.nu), - # qg=jnp.zeros(model.nq).at[6].set(1.0), # if force_float=True - qg=jnp.zeros(model.nq), - vg=jnp.zeros(model.nv), + # xg=jnp.zeros(model.nq + model.nv).at[6].set(1.0), # if force_float=True + xg=jnp.zeros(model.nq + model.nv), ) - nsamples = 100000 + nsamples = 100 stdev = 0.01 ps = VanillaPredictiveSampler(model=model, cost_function=cost_function, nsamples=nsamples, stdev=stdev) # sampler parameters key = jax.random.PRNGKey(0) # random seed for the predictive sampler - q0 = jnp.zeros(model.nq).at[6].set(1.0) - v0 = jnp.zeros(model.nv) - num_steps = 25 + # x0 = jnp.zeros(model.nq + model.nv).at[6].set(1.0) # if force_float=True + x0 = jnp.zeros(model.nq + model.nv) + num_steps = 10 us_guess = jnp.zeros((num_steps, model.nu)) - params = VanillaPredictiveSamplerParams(key=key, q0=q0, v0=v0, us_guess=us_guess) + params = VanillaPredictiveSamplerParams(key=key, x0=x0, us_guess=us_guess) # sampling the best sequence of qs, vs, and us optimize_fn = jit(ps.optimize) # [DEBUG] profiling with nsight systems - # qs_star, vs_star, us_star = optimize_fn(params) # JIT compiling + # xs_star, us_star = optimize_fn(params) # JIT compiling # with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True): - # qs_star, vs_star, us_star = optimize_fn(params) # after JIT + # xs_star us_star = optimize_fn(params) # after JIT - # def _time_fn(): - # qs_star, vs_star, us_star = optimize_fn(params) - # qs_star.block_until_ready() - # vs_star.block_until_ready() - # us_star.block_until_ready() + def _time_fn(): + xs_star, us_star = optimize_fn(params) + xs_star.block_until_ready() + us_star.block_until_ready() - # compile_time = timeit.timeit(_time_fn, number=1) - # print(f"Compile time: {compile_time}") + compile_time = timeit.timeit(_time_fn, number=1) + print(f"Compile time: {compile_time}") - # # informal timing test - # # TODO(ahl): identify bottlenecks and zap them - 
# # [Dec. 3, 2023] on vulcan, I've informally tested the scaling of runtime with the number of steps and the number - # # of samples. Here are a few preliminary results: - # # * nsamples=100, numsteps=10. avg: 0.01s - # # * nsamples=1000, numsteps=10. avg: 0.015s - # # * nsamples=10000, numsteps=10. avg: 0.07s - # # * nsamples=100, numsteps=100. avg: 0.1s - # # we conclude that the runtime scales predictably linearly with numsteps, but we also have some sort of (perhaps - # # logarithmic) scaling of runtime with nsamples. this outlook is somewhat grim, and we need to also keep in mind - # # that we've completely disabled contact for this example and set the number of solver iterations and line search - # # iterations to very runtime-friendly values - # num_timing_iters = 100 - # time = timeit.timeit(_time_fn, number=num_timing_iters) - # print(f"Avg. runtime: {time / num_timing_iters}") # timeit returns TOTAL time, so we compute the average ourselves - # breakpoint() + # informal timing test + # TODO(ahl): identify bottlenecks and zap them + # [Dec. 3, 2023] on vulcan, I've informally tested the scaling of runtime with the number of steps and the number + # of samples. Here are a few preliminary results: + # * nsamples=100, numsteps=10. avg: 0.01s + # * nsamples=1000, numsteps=10. avg: 0.015s + # * nsamples=10000, numsteps=10. avg: 0.07s + # * nsamples=100, numsteps=100. avg: 0.1s + # we conclude that the runtime scales predictably linearly with numsteps, but we also have some sort of (perhaps + # logarithmic) scaling of runtime with nsamples. this outlook is somewhat grim, and we need to also keep in mind + # that we've completely disabled contact for this example and set the number of solver iterations and line search + # iterations to very runtime-friendly values + num_timing_iters = 100 + time = timeit.timeit(_time_fn, number=num_timing_iters) + print(f"Avg. 
runtime: {time / num_timing_iters}") # timeit returns TOTAL time, so we compute the average ourselves + breakpoint() From b2d829661945dc5c4006b1013685e0004facb2ca Mon Sep 17 00:00:00 2001 From: alberthli Date: Mon, 4 Dec 2023 23:03:56 -0800 Subject: [PATCH 21/28] tests for cost function and its derivatives --- ambersim/trajopt/base.py | 30 +++++++---- ambersim/trajopt/cost.py | 100 ++++++++++++++++++++++--------------- tests/trajopt/test_cost.py | 55 ++++++++++++++++++++ 3 files changed, 136 insertions(+), 49 deletions(-) create mode 100644 tests/trajopt/test_cost.py diff --git a/ambersim/trajopt/base.py b/ambersim/trajopt/base.py index 835378e9..7e4bb5d4 100644 --- a/ambersim/trajopt/base.py +++ b/ambersim/trajopt/base.py @@ -72,8 +72,7 @@ def optimize(self, params: TrajectoryOptimizerParams) -> Tuple[jax.Array, jax.Ar params: The parameters of the trajectory optimizer. Returns: - qs_star (shape=(N + 1, nq) or (?)): The optimized trajectory. - vs_star (shape=(N + 1, nv) or (?)): The optimized generalized velocities. + xs_star (shape=(N + 1, nq + nv) or (?)): The optimized trajectory. us_star (shape=(N, nu) or (?)): The optimized controls. """ raise NotImplementedError @@ -136,27 +135,38 @@ def grad( gcost_params: The gradient of the cost wrt params. new_params: The updated parameters of the cost function. """ - return grad(self.cost, argnums=(0, 1, 2))(xs, us, params) + (params,) + _fn = lambda xs, us, params: self.cost(xs, us, params)[0] # only differentiate wrt the cost val + return grad(_fn, argnums=(0, 1, 2))(xs, us, params) + (params,) def hess( self, xs: jax.Array, us: jax.Array, params: CostFunctionParams - ) -> Tuple[jax.Array, jax.Array, CostFunctionParams, CostFunctionParams]: + ) -> Tuple[ + jax.Array, jax.Array, CostFunctionParams, jax.Array, CostFunctionParams, CostFunctionParams, CostFunctionParams + ]: """Computes the Hessian of the cost of a trajectory. The default implementation of this function uses JAX's autodiff. 
Simply override this function if you would like to supply an analytical Hessian. + Let t, s be times from 0 to N + 1. Then, d^2H/da_{t,i}db_{s,j} = Hcost_asbs[t, i, s, j]. + Args: xs (shape=(N + 1, nq + nv)): The state trajectory. us (shape=(N, nu)): The controls over the trajectory. params: The parameters of the cost function. Returns: - Hcost_xs (shape=(N + 1, nq + nv, N + 1, nq + nv)): The Hessian of the cost wrt xs. - Let t, s be times from 0 to N + 1. Then, d^2/dx_{t,i}dx_{s,j} = Hcost_xs[t, i, s, j]. - Hcost_us (shape=(N, nu, N, nu)): The Hessian of the cost wrt us. - Let t, s be times from 0 to N. Then, d^2/du_{t,i}du_{s,j} = Hcost_us[t, i, s, j]. - Hcost_params: The Hessian of the cost wrt params. + Hcost_xsxs (shape=(N + 1, nq + nv, N + 1, nq + nv)): The Hessian of the cost wrt xs. + Hcost_xsus (shape=(N + 1, nq + nv, N, nu)): The Hessian of the cost wrt xs and us. + Hcost_xsparams: The Hessian of the cost wrt xs and params. + Hcost_usus (shape=(N, nu, N, nu)): The Hessian of the cost wrt us. + Hcost_usparams: The Hessian of the cost wrt us and params. + Hcost_paramsall: The Hessian of the cost wrt params and everything else. new_params: The updated parameters of the cost function. 
""" - return hessian(self.cost, argnums=(0, 1, 2))(xs, us, params) + (params,) + _fn = lambda xs, us, params: self.cost(xs, us, params)[0] # only differentiate wrt the cost val + hessians = hessian(_fn, argnums=(0, 1, 2))(xs, us, params) + Hcost_xsxs, Hcost_xsus, Hcost_xsparams = hessians[0] + _, Hcost_usus, Hcost_usparams = hessians[1] + Hcost_paramsall = hessians[2] + return Hcost_xsxs, Hcost_xsus, Hcost_xsparams, Hcost_usus, Hcost_usparams, Hcost_paramsall, params diff --git a/ambersim/trajopt/cost.py b/ambersim/trajopt/cost.py index feb579bb..cd3f502f 100644 --- a/ambersim/trajopt/cost.py +++ b/ambersim/trajopt/cost.py @@ -3,6 +3,7 @@ import jax import jax.numpy as jnp from flax import struct +from jax import lax, vmap from ambersim.trajopt.base import CostFunction, CostFunctionParams @@ -32,30 +33,6 @@ def __init__(self, Q: jax.Array, Qf: jax.Array, R: jax.Array, xg: jax.Array) -> self.R = R self.xg = xg - # @staticmethod - # def _setup_util( - # xs: jax.Array, qg: jax.Array, vg: jax.Array - # ) -> Tuple[jax.Array, jax.Array, jax.Array, jax.Array, int, int, int]: - # """Utility function that sets up the cost function. - - # Args: - # xs (shape=(N + 1, nq + nv)): The state trajectory - # xg (shape=(nq + nv,)): The goal state. - - # Returns: - # xs (shape=(N + 1, nx)): The states over the trajectory. - # xg (shape=(nx,)): The goal state. - # xs_err (shape=(N, nx)): The state errors up to the final state. - # xf_err (shape=(nx,)): The state error at the final state. - # nq: The number of generalized coordinates. - # nv: The number of generalized velocities. - # """ - # xs = jnp.concatenate((qs, vs), axis=-1) - # xg = jnp.concatenate((qg, vg), axis=-1) - # xs_err = xs[:-1, :] - xg - # xf_err = xs[-1, :] - xg - # return xs, xg, xs_err, xf_err, qs.shape[-1], vs.shape[-1] - @staticmethod def batch_quadform(bs: jax.Array, A: jax.Array) -> jax.Array: """Computes a batched quadratic form for a single instance of A. 
@@ -98,8 +75,12 @@ def cost(self, xs: jax.Array, us: jax.Array, params: CostFunctionParams) -> Tupl """ xs_err = xs[:-1, :] - self.xg # errors before the terminal state xf_err = xs[-1, :] - self.xg - val = 0.5 * ( - self.batch_quadform(xs_err, self.Q) + self.batch_quadform(xf_err, self.Qf) + self.batch_quadform(us, self.R) + val = 0.5 * jnp.squeeze( + ( + jnp.sum(self.batch_quadform(xs_err, self.Q)) + + self.batch_quadform(xf_err, self.Qf) + + jnp.sum(self.batch_quadform(us, self.R)) + ) ) return val, params @@ -126,31 +107,72 @@ def grad( self.batch_matmul(xs_err, self.Q), (self.Qf @ xf_err)[None, :], ), - axis=-1, + axis=-2, ) gcost_us = self.batch_matmul(us, self.R) return gcost_xs, gcost_us, params, params def hess( self, xs: jax.Array, us: jax.Array, params: CostFunctionParams - ) -> Tuple[jax.Array, jax.Array, CostFunctionParams, CostFunctionParams]: + ) -> Tuple[ + jax.Array, jax.Array, CostFunctionParams, jax.Array, CostFunctionParams, CostFunctionParams, CostFunctionParams + ]: """Computes the gradient of the cost of a trajectory. + Let t, s be times from 0 to N + 1. Then, d^2H/da_{t,i}db_{s,j} = Hcost_asbs[t, i, s, j]. + Args: xs (shape=(N + 1, nq + nv)): The state trajectory. us (shape=(N, nu)): The controls over the trajectory. params: Unused. Included for API compliance. Returns: - Hcost_xs (shape=(N + 1, nq + nv, N + 1, nq + nv)): The Hessian of the cost wrt xs. - Let t, s be times from 0 to N + 1. Then, d^2/dx_{t,i}dx_{s,j} = Hcost_xs[t, i, s, j]. - Hcost_us (shape=(N, nu, N, nu)): The Hessian of the cost wrt us. - Let t, s be times from 0 to N. Then, d^2/du_{t,i}du_{s,j} = Hcost_us[t, i, s, j]. - Hcost_params: Unused. Included for API compliance. - new_params: Unused. Included for API compliance. + Hcost_xsxs (shape=(N + 1, nq + nv, N + 1, nq + nv)): The Hessian of the cost wrt xs. + Hcost_xsus (shape=(N + 1, nq + nv, N, nu)): The Hessian of the cost wrt xs and us. + Hcost_xsparams: The Hessian of the cost wrt xs and params. 
+ Hcost_usus (shape=(N, nu, N, nu)): The Hessian of the cost wrt us. + Hcost_usparams: The Hessian of the cost wrt us and params. + Hcost_paramsall: The Hessian of the cost wrt params and everything else. + new_params: The updated parameters of the cost function. """ - N = us.shape[0] - Q_tiled = jnp.tile(self.Q[None, :, None, :], (N + 1, 1, N + 1, 1)) - Hcost_xs = Q_tiled.at[-1, :, -1, :].set(self.Qf) - Hcost_us = jnp.tile(self.R[None, :, None, :], (N, 1, N, 1)) - return Hcost_xs, Hcost_us, params, params + # setting up + nx = self.Q.shape[0] + N, nu = us.shape + Q = self.Q + Qf = self.Qf + R = self.R + dummy_params = CostFunctionParams() + + # Hessian for state + Hcost_xsxs = jnp.zeros((N + 1, nx, N + 1, nx)) + Hcost_xsxs = vmap( + lambda i: lax.dynamic_update_slice( + jnp.zeros((nx, N + 1, nx)), + Q[:, None, :], + (0, i, 0), + ) + )( + jnp.arange(N + 1) + ) # only the terms [i, :, i, :] are nonzero + Hcost_xsxs = Hcost_xsxs.at[-1, :, -1, :].set(Qf) # last one is different + + # trivial cross-terms of Hessian + Hcost_xsus = jnp.zeros((N + 1, nx, N, nu)) + Hcost_xsparams = dummy_params + + # Hessian for control inputs + Hcost_usus = jnp.zeros((N, nu, N, nu)) + Hcost_usus = vmap( + lambda i: lax.dynamic_update_slice( + jnp.zeros((nu, N, nu)), + R[:, None, :], + (0, i, 0), + ) + )( + jnp.arange(N) + ) # only the terms [i, :, i, :] are nonzero + + # trivial cross-terms and Hessian for params + Hcost_usparams = dummy_params + Hcost_paramsall = dummy_params + return Hcost_xsxs, Hcost_xsus, Hcost_xsparams, Hcost_usus, Hcost_usparams, Hcost_paramsall, params diff --git a/tests/trajopt/test_cost.py b/tests/trajopt/test_cost.py new file mode 100644 index 00000000..3eb0c8de --- /dev/null +++ b/tests/trajopt/test_cost.py @@ -0,0 +1,55 @@ +import jax +import jax.numpy as jnp +from jax import hessian, jacobian + +from ambersim.trajopt.base import CostFunctionParams +from ambersim.trajopt.cost import StaticGoalQuadraticCost +from ambersim.utils.io_utils import 
load_mjx_model_and_data_from_file + + +def test_sgqc(): + """Tests that the StaticGoalQuadraticCost works correctly.""" + # loading model and cost function + model, _ = load_mjx_model_and_data_from_file("models/barrett_hand/bh280.xml", force_float=False) + cost_function = StaticGoalQuadraticCost( + Q=jnp.eye(model.nq + model.nv), + Qf=10.0 * jnp.eye(model.nq + model.nv), + R=0.01 * jnp.eye(model.nu), + xg=jnp.zeros(model.nq + model.nv), + ) + + # generating dummy data + N = 10 + key = jax.random.PRNGKey(0) + xs = jax.random.normal(key=key, shape=(N + 1, model.nq + model.nv)) + us = jax.random.normal(key=key, shape=(N, model.nu)) + + # comparing cost value vs. ground truth for loop + val_test, _ = cost_function.cost(xs, us, params=CostFunctionParams()) + val_gt = 0.0 + for i in range(N): + xs_err = xs[i, :] - cost_function.xg + val_gt += 0.5 * (xs_err @ cost_function.Q @ xs_err + us[i, :] @ cost_function.R @ us[i, :]) + xs_err = xs[N, :] - cost_function.xg + val_gt += 0.5 * (xs_err @ cost_function.Qf @ xs_err) + val_gt = jnp.squeeze(val_gt) + assert jnp.allclose(val_test, val_gt) + + # comparing cost gradients vs. jax autodiff + gcost_xs_test, gcost_us_test, _, _ = cost_function.grad(xs, us, params=CostFunctionParams()) + gcost_xs_gt, gcost_us_gt, _, _ = super(StaticGoalQuadraticCost, cost_function).grad( + xs, us, params=CostFunctionParams() + ) + assert jnp.allclose(gcost_xs_test, gcost_xs_gt) + assert jnp.allclose(gcost_us_test, gcost_us_gt) + + # comparing cost Hessians vs. 
jax autodiff + Hcost_xsxs_test, Hcost_xsus_test, _, Hcost_usus_test, _, _, _ = cost_function.hess( + xs, us, params=CostFunctionParams() + ) + Hcost_xsxs_gt, Hcost_xsus_gt, _, Hcost_usus_gt, _, _, _ = super(StaticGoalQuadraticCost, cost_function).hess( + xs, us, params=CostFunctionParams() + ) + assert jnp.allclose(Hcost_xsxs_test, Hcost_xsxs_gt) + assert jnp.allclose(Hcost_xsus_test, Hcost_xsus_gt) + assert jnp.allclose(Hcost_usus_test, Hcost_usus_gt) From 0e24057b9966c85140b9ec6cc04ff4613038c325 Mon Sep 17 00:00:00 2001 From: alberthli Date: Mon, 4 Dec 2023 23:13:07 -0800 Subject: [PATCH 22/28] added smoke test for vanilla predictive sampler --- tests/trajopt/test_predictive_sampler.py | 45 ++++++++++++++++++++++++ 1 file changed, 45 insertions(+) create mode 100644 tests/trajopt/test_predictive_sampler.py diff --git a/tests/trajopt/test_predictive_sampler.py b/tests/trajopt/test_predictive_sampler.py new file mode 100644 index 00000000..f78b735f --- /dev/null +++ b/tests/trajopt/test_predictive_sampler.py @@ -0,0 +1,45 @@ +import jax +import jax.numpy as jnp +from jax import jit +from mujoco.mjx._src.types import DisableBit + +from ambersim.trajopt.cost import StaticGoalQuadraticCost +from ambersim.trajopt.shooting import VanillaPredictiveSampler, VanillaPredictiveSamplerParams +from ambersim.utils.io_utils import load_mjx_model_and_data_from_file + + +def test_smoke_VPS(): + """Simple smoke test to make sure we can run inputs through the vanilla predictive sampler + jit.""" + # initializing the predictive sampler + model, _ = load_mjx_model_and_data_from_file("models/barrett_hand/bh280.xml", force_float=False) + model = model.replace( + opt=model.opt.replace( + timestep=0.002, # dt + iterations=1, # number of Newton steps to take during solve + ls_iterations=4, # number of line search iterations along step direction + integrator=0, # Euler semi-implicit integration + solver=2, # Newton solver + disableflags=DisableBit.CONTACT, # disable contact for this test + 
) + ) + cost_function = StaticGoalQuadraticCost( + Q=jnp.eye(model.nq + model.nv), + Qf=10.0 * jnp.eye(model.nq + model.nv), + R=0.01 * jnp.eye(model.nu), + xg=jnp.zeros(model.nq + model.nv), + ) + nsamples = 100 + stdev = 0.01 + ps = VanillaPredictiveSampler(model=model, cost_function=cost_function, nsamples=nsamples, stdev=stdev) + + # sampler parameters + key = jax.random.PRNGKey(0) # random seed for the predictive sampler + # x0 = jnp.zeros(model.nq + model.nv).at[6].set(1.0) # if force_float=True + x0 = jnp.zeros(model.nq + model.nv) + num_steps = 10 + us_guess = jnp.zeros((num_steps, model.nu)) + params = VanillaPredictiveSamplerParams(key=key, x0=x0, us_guess=us_guess) + + # sampling the best sequence of qs, vs, and us + optimize_fn = jit(ps.optimize) + assert optimize_fn(params) From adf4243fb55eaba2bc396f74aaed4a99408f6d69 Mon Sep 17 00:00:00 2001 From: alberthli Date: Mon, 4 Dec 2023 23:16:31 -0800 Subject: [PATCH 23/28] remove the example since it's implemented as a benchmark in an upstream PR --- examples/trajopt/ex_predictive_sampling.py | 75 ---------------------- 1 file changed, 75 deletions(-) delete mode 100644 examples/trajopt/ex_predictive_sampling.py diff --git a/examples/trajopt/ex_predictive_sampling.py b/examples/trajopt/ex_predictive_sampling.py deleted file mode 100644 index 12f9686a..00000000 --- a/examples/trajopt/ex_predictive_sampling.py +++ /dev/null @@ -1,75 +0,0 @@ -import timeit - -import jax -import jax.numpy as jnp -from jax import jit -from mujoco.mjx._src.types import DisableBit - -from ambersim.trajopt.cost import StaticGoalQuadraticCost -from ambersim.trajopt.shooting import VanillaPredictiveSampler, VanillaPredictiveSamplerParams -from ambersim.utils.io_utils import load_mjx_model_and_data_from_file - -if __name__ == "__main__": - # initializing the predictive sampler - model, _ = load_mjx_model_and_data_from_file("models/barrett_hand/bh280.xml", force_float=False) - model = model.replace( - opt=model.opt.replace( - 
timestep=0.002, # dt - iterations=1, # number of Newton steps to take during solve - ls_iterations=4, # number of line search iterations along step direction - integrator=0, # Euler semi-implicit integration - solver=2, # Newton solver - disableflags=DisableBit.CONTACT, # [IMPORTANT] disable contact for this example - ) - ) - cost_function = StaticGoalQuadraticCost( - Q=jnp.eye(model.nq + model.nv), - Qf=10.0 * jnp.eye(model.nq + model.nv), - R=0.01 * jnp.eye(model.nu), - # xg=jnp.zeros(model.nq + model.nv).at[6].set(1.0), # if force_float=True - xg=jnp.zeros(model.nq + model.nv), - ) - nsamples = 100 - stdev = 0.01 - ps = VanillaPredictiveSampler(model=model, cost_function=cost_function, nsamples=nsamples, stdev=stdev) - - # sampler parameters - key = jax.random.PRNGKey(0) # random seed for the predictive sampler - # x0 = jnp.zeros(model.nq + model.nv).at[6].set(1.0) # if force_float=True - x0 = jnp.zeros(model.nq + model.nv) - num_steps = 10 - us_guess = jnp.zeros((num_steps, model.nu)) - params = VanillaPredictiveSamplerParams(key=key, x0=x0, us_guess=us_guess) - - # sampling the best sequence of qs, vs, and us - optimize_fn = jit(ps.optimize) - - # [DEBUG] profiling with nsight systems - # xs_star, us_star = optimize_fn(params) # JIT compiling - # with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True): - # xs_star us_star = optimize_fn(params) # after JIT - - def _time_fn(): - xs_star, us_star = optimize_fn(params) - xs_star.block_until_ready() - us_star.block_until_ready() - - compile_time = timeit.timeit(_time_fn, number=1) - print(f"Compile time: {compile_time}") - - # informal timing test - # TODO(ahl): identify bottlenecks and zap them - # [Dec. 3, 2023] on vulcan, I've informally tested the scaling of runtime with the number of steps and the number - # of samples. Here are a few preliminary results: - # * nsamples=100, numsteps=10. avg: 0.01s - # * nsamples=1000, numsteps=10. avg: 0.015s - # * nsamples=10000, numsteps=10. 
avg: 0.07s - # * nsamples=100, numsteps=100. avg: 0.1s - # we conclude that the runtime scales predictably linearly with numsteps, but we also have some sort of (perhaps - # logarithmic) scaling of runtime with nsamples. this outlook is somewhat grim, and we need to also keep in mind - # that we've completely disabled contact for this example and set the number of solver iterations and line search - # iterations to very runtime-friendly values - num_timing_iters = 100 - time = timeit.timeit(_time_fn, number=num_timing_iters) - print(f"Avg. runtime: {time / num_timing_iters}") # timeit returns TOTAL time, so we compute the average ourselves - breakpoint() From 773be6a4eb6378f6fe8ea3d93476d30ac2d826c6 Mon Sep 17 00:00:00 2001 From: alberthli Date: Mon, 4 Dec 2023 23:18:37 -0800 Subject: [PATCH 24/28] minor docstring edit --- ambersim/trajopt/base.py | 2 +- ambersim/trajopt/cost.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ambersim/trajopt/base.py b/ambersim/trajopt/base.py index 7e4bb5d4..2e50975e 100644 --- a/ambersim/trajopt/base.py +++ b/ambersim/trajopt/base.py @@ -148,7 +148,7 @@ def hess( The default implementation of this function uses JAX's autodiff. Simply override this function if you would like to supply an analytical Hessian. - Let t, s be times from 0 to N + 1. Then, d^2H/da_{t,i}db_{s,j} = Hcost_asbs[t, i, s, j]. + Let t, s be times 0, 1, 2, etc. Then, d^2H/da_{t,i}db_{s,j} = Hcost_asbs[t, i, s, j]. Args: xs (shape=(N + 1, nq + nv)): The state trajectory. diff --git a/ambersim/trajopt/cost.py b/ambersim/trajopt/cost.py index cd3f502f..fa560128 100644 --- a/ambersim/trajopt/cost.py +++ b/ambersim/trajopt/cost.py @@ -119,7 +119,7 @@ def hess( ]: """Computes the gradient of the cost of a trajectory. - Let t, s be times from 0 to N + 1. Then, d^2H/da_{t,i}db_{s,j} = Hcost_asbs[t, i, s, j]. + Let t, s be times 0, 1, 2, etc. Then, d^2H/da_{t,i}db_{s,j} = Hcost_asbs[t, i, s, j]. 
Args: xs (shape=(N + 1, nq + nv)): The state trajectory. From 544f9643f1f2e85bca07b8c4fc55afea6567fc34 Mon Sep 17 00:00:00 2001 From: alberthli Date: Mon, 4 Dec 2023 23:25:49 -0800 Subject: [PATCH 25/28] ensure predictive sampler also accounts for cost of guess vs. samples --- ambersim/trajopt/shooting.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/ambersim/trajopt/shooting.py b/ambersim/trajopt/shooting.py index baea9d9f..414917a1 100644 --- a/ambersim/trajopt/shooting.py +++ b/ambersim/trajopt/shooting.py @@ -136,19 +136,16 @@ def optimize(self, params: VanillaPredictiveSamplerParams) -> Tuple[jax.Array, j N = params.N key = params.key - # sample over the control inputs - _us_samples = us_guess + jax.random.normal(key, shape=(nsamples, N, m.nu)) * stdev + # sample over the control inputs - the first sample is the guess, since it's possible that it's the best one + noise = jnp.concatenate( + (jnp.zeros((1, N, m.nu)), jax.random.normal(key, shape=(nsamples - 1, N, m.nu)) * stdev), axis=0 + ) + _us_samples = us_guess + noise # clamping the samples to their control limits - # TODO(ahl): write a create classmethod that allows the user to set default_limits optionally with some semi- - # reasonable default value - # TODO(ahl): check whether joints with no limits have reasonable defaults for m.actuator_ctrlrange limits = m.actuator_ctrlrange clip_fn = partial(jnp.clip, a_min=limits[:, 0], a_max=limits[:, 1]) # clipping function with limits already set us_samples = vmap(vmap(clip_fn))(_us_samples) # apply limits only to the last dim, need a nested vmap - # limited = m.actuator_ctrllimited[:, None] # (nu, 1) whether each actuator has limited control authority - # default_limits = jnp.array([[-1000.0, 1000.0]] * m.nu) # (nu, 2) default limits for each actuator - # limits = jnp.where(limited, m.actuator_ctrlrange, default_limits) # (nu, 2) # predict many samples, evaluate them, and return the best trajectory tuple # vmap over the input 
data and the control trajectories From 4e22073b2370858a197741e621e95df47879c5fa Mon Sep 17 00:00:00 2001 From: alberthli Date: Mon, 4 Dec 2023 23:54:33 -0800 Subject: [PATCH 26/28] add sanity check for predictive sampling --- tests/trajopt/test_predictive_sampler.py | 46 +++++++++++++++++++++--- 1 file changed, 41 insertions(+), 5 deletions(-) diff --git a/tests/trajopt/test_predictive_sampler.py b/tests/trajopt/test_predictive_sampler.py index f78b735f..0d29a1b8 100644 --- a/tests/trajopt/test_predictive_sampler.py +++ b/tests/trajopt/test_predictive_sampler.py @@ -1,15 +1,16 @@ import jax import jax.numpy as jnp -from jax import jit +from jax import jit, vmap from mujoco.mjx._src.types import DisableBit +from ambersim.trajopt.base import CostFunctionParams from ambersim.trajopt.cost import StaticGoalQuadraticCost -from ambersim.trajopt.shooting import VanillaPredictiveSampler, VanillaPredictiveSamplerParams +from ambersim.trajopt.shooting import VanillaPredictiveSampler, VanillaPredictiveSamplerParams, shoot from ambersim.utils.io_utils import load_mjx_model_and_data_from_file -def test_smoke_VPS(): - """Simple smoke test to make sure we can run inputs through the vanilla predictive sampler + jit.""" +def _make_vps_data(): + """Makes data required for testing vanilla predictive sampling.""" # initializing the predictive sampler model, _ = load_mjx_model_and_data_from_file("models/barrett_hand/bh280.xml", force_float=False) model = model.replace( @@ -31,10 +32,15 @@ def test_smoke_VPS(): nsamples = 100 stdev = 0.01 ps = VanillaPredictiveSampler(model=model, cost_function=cost_function, nsamples=nsamples, stdev=stdev) + return ps, model, cost_function + + +def test_smoke_VPS(): + """Simple smoke test to make sure we can run inputs through the vanilla predictive sampler + jit.""" + ps, model, _ = _make_vps_data() # sampler parameters key = jax.random.PRNGKey(0) # random seed for the predictive sampler - # x0 = jnp.zeros(model.nq + model.nv).at[6].set(1.0) # if 
force_float=True x0 = jnp.zeros(model.nq + model.nv) num_steps = 10 us_guess = jnp.zeros((num_steps, model.nu)) @@ -43,3 +49,33 @@ def test_smoke_VPS(): # sampling the best sequence of qs, vs, and us optimize_fn = jit(ps.optimize) assert optimize_fn(params) + + +def test_VPS_cost_decrease(): + """Tests to make sure vanilla predictive sampling decreases (or maintains) the cost.""" + # set up sampler and cost function + ps, model, cost_function = _make_vps_data() + + # batched sampler parameters + batch_size = 10 + key = jax.random.PRNGKey(0) # random seed for the predictive sampler + x0 = jax.random.normal(key=key, shape=(batch_size, model.nq + model.nv)) + + key, subkey = jax.random.split(key) + num_steps = 10 + us_guess = jax.random.normal(key=subkey, shape=(batch_size, num_steps, model.nu)) + + keys = jax.random.split(key, num=batch_size) + params = VanillaPredictiveSamplerParams(key=keys, x0=x0, us_guess=us_guess) + + # sampling with the vanilla predictive sampler + xs_stars, us_stars = vmap(ps.optimize)(params) + + # "optimal" rollout from predictive sampling + vmap_cost = vmap(lambda xs, us: cost_function.cost(xs, us, CostFunctionParams())[0], in_axes=(0, 0)) + costs_star = vmap_cost(xs_stars, us_stars) + + # simply shooting the random initial guess + xs_guess = vmap(shoot, in_axes=(None, 0, 0))(model, x0, us_guess) + costs_guess = vmap_cost(xs_guess, us_guess) + assert jnp.all(costs_star <= costs_guess) From 0c49505398e49f950e9996fb95b6628860effad9 Mon Sep 17 00:00:00 2001 From: alberthli Date: Tue, 5 Dec 2023 12:21:45 -0800 Subject: [PATCH 27/28] added fixture to tests --- tests/trajopt/test_predictive_sampler.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/tests/trajopt/test_predictive_sampler.py b/tests/trajopt/test_predictive_sampler.py index 0d29a1b8..52fab13a 100644 --- a/tests/trajopt/test_predictive_sampler.py +++ b/tests/trajopt/test_predictive_sampler.py @@ -1,5 +1,8 @@ +import os + import jax import 
jax.numpy as jnp +import pytest from jax import jit, vmap from mujoco.mjx._src.types import DisableBit @@ -8,8 +11,11 @@ from ambersim.trajopt.shooting import VanillaPredictiveSampler, VanillaPredictiveSamplerParams, shoot from ambersim.utils.io_utils import load_mjx_model_and_data_from_file +os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" # fixes OOM error + -def _make_vps_data(): +@pytest.fixture +def vps_data(): """Makes data required for testing vanilla predictive sampling.""" # initializing the predictive sampler model, _ = load_mjx_model_and_data_from_file("models/barrett_hand/bh280.xml", force_float=False) @@ -35,9 +41,9 @@ def _make_vps_data(): return ps, model, cost_function -def test_smoke_VPS(): +def test_smoke_VPS(vps_data): """Simple smoke test to make sure we can run inputs through the vanilla predictive sampler + jit.""" - ps, model, _ = _make_vps_data() + ps, model, _ = vps_data # sampler parameters key = jax.random.PRNGKey(0) # random seed for the predictive sampler @@ -51,10 +57,10 @@ def test_smoke_VPS(): assert optimize_fn(params) -def test_VPS_cost_decrease(): +def test_VPS_cost_decrease(vps_data): """Tests to make sure vanilla predictive sampling decreases (or maintains) the cost.""" # set up sampler and cost function - ps, model, cost_function = _make_vps_data() + ps, model, cost_function = vps_data # batched sampler parameters batch_size = 10 @@ -72,7 +78,7 @@ def test_VPS_cost_decrease(): xs_stars, us_stars = vmap(ps.optimize)(params) # "optimal" rollout from predictive sampling - vmap_cost = vmap(lambda xs, us: cost_function.cost(xs, us, CostFunctionParams())[0], in_axes=(0, 0)) + vmap_cost = jit(vmap(lambda xs, us: cost_function.cost(xs, us, CostFunctionParams()))[0], in_axes=(0, 0)) costs_star = vmap_cost(xs_stars, us_stars) # simply shooting the random initial guess From 90ea8f4ced3eedf90fc7c4ca66ea382e54e7edad Mon Sep 17 00:00:00 2001 From: alberthli Date: Tue, 5 Dec 2023 13:00:40 -0800 Subject: [PATCH 28/28] fix stray 
parenthetical --- tests/trajopt/test_predictive_sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trajopt/test_predictive_sampler.py b/tests/trajopt/test_predictive_sampler.py index 52fab13a..7111fc05 100644 --- a/tests/trajopt/test_predictive_sampler.py +++ b/tests/trajopt/test_predictive_sampler.py @@ -78,7 +78,7 @@ def test_VPS_cost_decrease(vps_data): xs_stars, us_stars = vmap(ps.optimize)(params) # "optimal" rollout from predictive sampling - vmap_cost = jit(vmap(lambda xs, us: cost_function.cost(xs, us, CostFunctionParams()))[0], in_axes=(0, 0)) + vmap_cost = jit(vmap(lambda xs, us: cost_function.cost(xs, us, CostFunctionParams())[0], in_axes=(0, 0))) costs_star = vmap_cost(xs_stars, us_stars) # simply shooting the random initial guess