diff --git a/ambersim/control/base.py b/ambersim/control/base.py new file mode 100644 index 00000000..1a2c7a4a --- /dev/null +++ b/ambersim/control/base.py @@ -0,0 +1,32 @@ +import jax +from flax import struct + + +@struct.dataclass +class ControllerParams: + """The parameters for generic controllers. + + This is left completely empty for maximum flexibility in the API. Some examples: + - "Regular" inputs into feedback controllers (e.g., the state) belong here. + - Non-Markovian controllers can pass histories in this params object. + - Parameters of the controller that you may randomize/optimize go here. + """ + + +@struct.dataclass +class Controller: + """The API for a generic controller. + + See the notes in TrajectoryOptimizer on the generality of this class - much of the same applies. + """ + + def compute(self, ctrl_params: ControllerParams) -> jax.Array: + """Computes a control input. + + Args: + ctrl_params: ControllerParams + + Returns: + u (shape=(nu,)): The control input. + """ + raise NotImplementedError diff --git a/ambersim/control/predictive_control.py b/ambersim/control/predictive_control.py new file mode 100644 index 00000000..4946d6d9 --- /dev/null +++ b/ambersim/control/predictive_control.py @@ -0,0 +1,149 @@ +from typing import Tuple + +import jax +import jax.numpy as jnp +from flax import struct +from mujoco import mjx + +from ambersim.control.base import Controller, ControllerParams +from ambersim.trajopt.base import TrajectoryOptimizer +from ambersim.trajopt.shooting import PDPredictiveSamplerParams, PredictiveSampler, VanillaPredictiveSamplerParams + +# ########### # +# GENERIC API # +# ########### # + + +@struct.dataclass +class PredictiveControllerParams(ControllerParams): + """The generic API for predictive controller params.""" + + +@struct.dataclass +class PredictiveController(Controller): + """The generic API for a predictive controller.""" + + trajectory_optimizer: TrajectoryOptimizer + model: mjx.Model + + def compute(self, 
ctrl_params: PredictiveControllerParams) -> jax.Array: + """Computes a control input using forward prediction.""" + raise NotImplementedError + + +# ################### # +# PREDICTIVE SAMPLING # +# ################### # + + +@struct.dataclass +class VanillaPredictiveSamplingControllerParams(PredictiveControllerParams): + """Vanilla predictive sampling controller params.""" + + key: jax.Array # random key for sampling + x: jax.Array # shape=(nq + nv,) current state + guess: jax.Array # shape=(N, nu) current guess + + +@struct.dataclass +class VanillaPredictiveSamplingController(PredictiveController): + """Vanilla predictive sampling controller.""" + + def __post_init__(self) -> None: + """Post-initialization check.""" + assert isinstance( + self.trajectory_optimizer, PredictiveSampler + ), "trajectory_optimizer must be a PredictiveSampler!" + + def compute(self, ctrl_params: VanillaPredictiveSamplingControllerParams) -> jax.Array: + """Computes a control input using forward prediction. + + Args: + ctrl_params: Inputs into the controller. + + Returns: + u (shape=(nu,)): The control input. + """ + return self.compute_with_us_star(ctrl_params)[0] + + def compute_with_us_star( + self, ctrl_params: VanillaPredictiveSamplingControllerParams + ) -> Tuple[jax.Array, jax.Array]: + """Computes a control input using forward prediction + the optimal sequence of guesses. + + This is needed in practice because the current optimal sequence is used to warm start the sampling distribution + for the next call of the controller. + + Args: + ctrl_params: Inputs into the controller. + + Returns: + u (shape=(nu,)): The control input. + us_star (shape=(N, nu)): The optimal control sequence. 
+ """ + to_params = VanillaPredictiveSamplerParams( + key=ctrl_params.key, + x0=ctrl_params.x, + guess=ctrl_params.guess, + ) + xs_star, us_star = self.trajectory_optimizer.optimize(to_params) + u = us_star[0, :] + return u, us_star + + +@struct.dataclass +class PDPredictiveSamplingControllerParams(PredictiveControllerParams): + """PD predictive sampling controller params.""" + + key: jax.Array # random key for sampling + x: jax.Array # shape=(nq + nv,) current state + guess: jax.Array # shape=(N, nq) current guess + kp: float # proportional gain + kd: float # derivative gain + + +@struct.dataclass +class PDPredictiveSamplingController(PredictiveController): + """PD predictive sampling controller.""" + + def __post_init__(self) -> None: + """Post-initialization check.""" + assert isinstance( + self.trajectory_optimizer, PredictiveSampler + ), "trajectory_optimizer must be a PredictiveSampler!" + + def compute(self, ctrl_params: PDPredictiveSamplingControllerParams) -> jax.Array: + """Computes a control input using forward prediction. + + Args: + ctrl_params: Inputs into the controller. + + Returns: + u (shape=(nu,)): The control input. + """ + return self.compute_with_qs_star(ctrl_params)[0] + + def compute_with_qs_star(self, ctrl_params: PDPredictiveSamplingControllerParams) -> Tuple[jax.Array, jax.Array]: + """Computes a control input using forward prediction + the optimal sequence of guesses. + + This is needed in practice because the current optimal sequence is used to warm start the sampling distribution + for the next call of the controller. + + Args: + ctrl_params: Inputs into the controller. + + Returns: + u (shape=(nu,)): The control input. + us_star (shape=(N, nu)): The optimal control sequence. 
+ """ + to_params = PDPredictiveSamplerParams( + key=ctrl_params.key, + x0=ctrl_params.x, + guess=ctrl_params.guess, + kp=ctrl_params.kp, + kd=ctrl_params.kd, + ) + xs_star, us_star = self.trajectory_optimizer.optimize(to_params) + # u = us_star[0, :] + qs_star = xs_star[:, : self.model.nq] # the 0th index is the current state, so return the 1st index + return qs_star[1, :], qs_star diff --git a/ambersim/models/allegro_hand/assets/base_link.stl b/ambersim/models/allegro_hand/assets/base_link.stl new file mode 100644 index 00000000..5901473f Binary files /dev/null and b/ambersim/models/allegro_hand/assets/base_link.stl differ diff --git a/ambersim/models/allegro_hand/assets/base_link_left.stl b/ambersim/models/allegro_hand/assets/base_link_left.stl new file mode 100644 index 00000000..365d7e7f Binary files /dev/null and b/ambersim/models/allegro_hand/assets/base_link_left.stl differ diff --git a/ambersim/models/allegro_hand/assets/fileback.png b/ambersim/models/allegro_hand/assets/fileback.png new file mode 100644 index 00000000..e03322fc Binary files /dev/null and b/ambersim/models/allegro_hand/assets/fileback.png differ diff --git a/ambersim/models/allegro_hand/assets/filedown.png b/ambersim/models/allegro_hand/assets/filedown.png new file mode 100644 index 00000000..ee48edf6 Binary files /dev/null and b/ambersim/models/allegro_hand/assets/filedown.png differ diff --git a/ambersim/models/allegro_hand/assets/filefront.png b/ambersim/models/allegro_hand/assets/filefront.png new file mode 100644 index 00000000..931e8056 Binary files /dev/null and b/ambersim/models/allegro_hand/assets/filefront.png differ diff --git a/ambersim/models/allegro_hand/assets/fileleft.png b/ambersim/models/allegro_hand/assets/fileleft.png new file mode 100644 index 00000000..da6427bf Binary files /dev/null and b/ambersim/models/allegro_hand/assets/fileleft.png differ diff --git a/ambersim/models/allegro_hand/assets/fileright.png b/ambersim/models/allegro_hand/assets/fileright.png new file 
mode 100644 index 00000000..6c2cb391 Binary files /dev/null and b/ambersim/models/allegro_hand/assets/fileright.png differ diff --git a/ambersim/models/allegro_hand/assets/fileup.png b/ambersim/models/allegro_hand/assets/fileup.png new file mode 100644 index 00000000..4962dfd0 Binary files /dev/null and b/ambersim/models/allegro_hand/assets/fileup.png differ diff --git a/ambersim/models/allegro_hand/assets/grayback.png b/ambersim/models/allegro_hand/assets/grayback.png new file mode 100644 index 00000000..03ae4bf7 Binary files /dev/null and b/ambersim/models/allegro_hand/assets/grayback.png differ diff --git a/ambersim/models/allegro_hand/assets/graydown.png b/ambersim/models/allegro_hand/assets/graydown.png new file mode 100644 index 00000000..b8b7b314 Binary files /dev/null and b/ambersim/models/allegro_hand/assets/graydown.png differ diff --git a/ambersim/models/allegro_hand/assets/grayfront.png b/ambersim/models/allegro_hand/assets/grayfront.png new file mode 100644 index 00000000..7c45fff0 Binary files /dev/null and b/ambersim/models/allegro_hand/assets/grayfront.png differ diff --git a/ambersim/models/allegro_hand/assets/grayleft.png b/ambersim/models/allegro_hand/assets/grayleft.png new file mode 100644 index 00000000..2625e739 Binary files /dev/null and b/ambersim/models/allegro_hand/assets/grayleft.png differ diff --git a/ambersim/models/allegro_hand/assets/grayright.png b/ambersim/models/allegro_hand/assets/grayright.png new file mode 100644 index 00000000..4516b601 Binary files /dev/null and b/ambersim/models/allegro_hand/assets/grayright.png differ diff --git a/ambersim/models/allegro_hand/assets/grayup.png b/ambersim/models/allegro_hand/assets/grayup.png new file mode 100644 index 00000000..0b07c030 Binary files /dev/null and b/ambersim/models/allegro_hand/assets/grayup.png differ diff --git a/ambersim/models/allegro_hand/assets/link_0.0.stl b/ambersim/models/allegro_hand/assets/link_0.0.stl new file mode 100644 index 00000000..6db5cdb1 Binary files 
/dev/null and b/ambersim/models/allegro_hand/assets/link_0.0.stl differ diff --git a/ambersim/models/allegro_hand/assets/link_1.0.stl b/ambersim/models/allegro_hand/assets/link_1.0.stl new file mode 100644 index 00000000..84e43514 Binary files /dev/null and b/ambersim/models/allegro_hand/assets/link_1.0.stl differ diff --git a/ambersim/models/allegro_hand/assets/link_12.0_left.stl b/ambersim/models/allegro_hand/assets/link_12.0_left.stl new file mode 100644 index 00000000..aeaa37a1 Binary files /dev/null and b/ambersim/models/allegro_hand/assets/link_12.0_left.stl differ diff --git a/ambersim/models/allegro_hand/assets/link_12.0_right.stl b/ambersim/models/allegro_hand/assets/link_12.0_right.stl new file mode 100644 index 00000000..7eecece5 Binary files /dev/null and b/ambersim/models/allegro_hand/assets/link_12.0_right.stl differ diff --git a/ambersim/models/allegro_hand/assets/link_13.0.stl b/ambersim/models/allegro_hand/assets/link_13.0.stl new file mode 100644 index 00000000..3ca4ec36 Binary files /dev/null and b/ambersim/models/allegro_hand/assets/link_13.0.stl differ diff --git a/ambersim/models/allegro_hand/assets/link_14.0.stl b/ambersim/models/allegro_hand/assets/link_14.0.stl new file mode 100644 index 00000000..73d6f695 Binary files /dev/null and b/ambersim/models/allegro_hand/assets/link_14.0.stl differ diff --git a/ambersim/models/allegro_hand/assets/link_15.0.stl b/ambersim/models/allegro_hand/assets/link_15.0.stl new file mode 100644 index 00000000..e16bff45 Binary files /dev/null and b/ambersim/models/allegro_hand/assets/link_15.0.stl differ diff --git a/ambersim/models/allegro_hand/assets/link_15.0_tip.stl b/ambersim/models/allegro_hand/assets/link_15.0_tip.stl new file mode 100644 index 00000000..9b1d8119 Binary files /dev/null and b/ambersim/models/allegro_hand/assets/link_15.0_tip.stl differ diff --git a/ambersim/models/allegro_hand/assets/link_2.0.stl b/ambersim/models/allegro_hand/assets/link_2.0.stl new file mode 100644 index 
00000000..20a911a4 Binary files /dev/null and b/ambersim/models/allegro_hand/assets/link_2.0.stl differ diff --git a/ambersim/models/allegro_hand/assets/link_3.0.stl b/ambersim/models/allegro_hand/assets/link_3.0.stl new file mode 100644 index 00000000..a1fffee0 Binary files /dev/null and b/ambersim/models/allegro_hand/assets/link_3.0.stl differ diff --git a/ambersim/models/allegro_hand/assets/link_3.0_tip.stl b/ambersim/models/allegro_hand/assets/link_3.0_tip.stl new file mode 100644 index 00000000..b7d9be8a Binary files /dev/null and b/ambersim/models/allegro_hand/assets/link_3.0_tip.stl differ diff --git a/ambersim/models/allegro_hand/assets/link_4.0.stl b/ambersim/models/allegro_hand/assets/link_4.0.stl new file mode 100644 index 00000000..ae75f704 Binary files /dev/null and b/ambersim/models/allegro_hand/assets/link_4.0.stl differ diff --git a/ambersim/models/allegro_hand/cube.xml b/ambersim/models/allegro_hand/cube.xml new file mode 100644 index 00000000..c857243d --- /dev/null +++ b/ambersim/models/allegro_hand/cube.xml @@ -0,0 +1,29 @@ + + + + + + + + + + + + + + + + + + + + + + + diff --git a/ambersim/models/allegro_hand/left_hand.xml b/ambersim/models/allegro_hand/left_hand.xml new file mode 100644 index 00000000..223d5f6b --- /dev/null +++ b/ambersim/models/allegro_hand/left_hand.xml @@ -0,0 +1,259 @@ + + + + diff --git a/ambersim/models/allegro_hand/right_hand.xml b/ambersim/models/allegro_hand/right_hand.xml new file mode 100644 index 00000000..cc6f4dd0 --- /dev/null +++ b/ambersim/models/allegro_hand/right_hand.xml @@ -0,0 +1,260 @@ + + + + diff --git a/ambersim/models/allegro_hand/scene_left.xml b/ambersim/models/allegro_hand/scene_left.xml new file mode 100644 index 00000000..c1ad48f9 --- /dev/null +++ b/ambersim/models/allegro_hand/scene_left.xml @@ -0,0 +1,25 @@ + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/ambersim/models/allegro_hand/scene_right.xml b/ambersim/models/allegro_hand/scene_right.xml new file mode 100644 index 
00000000..b1f79cf6 --- /dev/null +++ b/ambersim/models/allegro_hand/scene_right.xml @@ -0,0 +1,25 @@ + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/ambersim/sim/simulation.py b/ambersim/sim/simulation.py new file mode 100644 index 00000000..d3da7dae --- /dev/null +++ b/ambersim/sim/simulation.py @@ -0,0 +1,214 @@ +import jax +import jax.numpy as jnp +from jax import jit, lax +from mujoco import mjx + +from ambersim.control.predictive_control import ( + VanillaPredictiveSamplingController, + VanillaPredictiveSamplingControllerParams, +) + +"""This file contains simulation utils for various controllers.""" + + +def simulate_predictive_sampling_controller( + model: mjx.Model, + controller: VanillaPredictiveSamplingController, + x0: jax.Array, + num_steps: int, + N: int, + seed: int = 0, + physics_steps_per_control_step: int = 1, +) -> None: + """Simulates a closed-loop system. + + Args: + model: The "real" model. Properties like the timestep are set in here. + controller: The controller. + x0: The initial state. + num_steps: The number of steps to simulate for. + seed: The random seed. + physics_steps_per_control_step: The number of physics steps to take per control step. + + Returns: + data: The internal data of the model after the simulation. + xs: The state trajectory. + """ + # initial setup + print("Initial setup...") + dt = model.opt.timestep + + data = mjx.make_data(model) + data = data.replace(qpos=x0[: model.nq], qvel=x0[model.nq :]) # setting the initial state. 
+ data = mjx.forward(model, data) # setting other internal states like acceleration without integrating + + xs_hist = jnp.zeros((num_steps + 1, model.nq + model.nv)) + xs_hist = xs_hist.at[0, :].set(x0) + + us_hist = jnp.zeros((num_steps, model.nu)) + key = jax.random.PRNGKey(seed) + + jit_compute = jit( + lambda key, x_meas, us_guess: controller.compute_with_us_star( + VanillaPredictiveSamplingControllerParams(key=key, x=x_meas, guess=us_guess) + ) + ) + jit_step = jit(lambda data, u: mjx.step(model, data.replace(ctrl=u))) + + # us_guess = jnp.zeros((N, model.nu)) + us_guess = jnp.array([0.0, 0.2, 0.2, 0.2, 0.0, 0.2, 0.2, 0.2, 0.0, 0.2, 0.2, 0.2, 0.2, 0.1, 0.1, 0.1]) + import time + + times = [] + for i in range(num_steps // physics_steps_per_control_step): + start = time.time() + # computing the control input + x_meas = jnp.concatenate((data.qpos, data.qvel)) + if i == 0: + print("Compiling the compute function...") + u, us_guess = jit_compute(key, x_meas, us_guess) + if i == 0: + print("Compiled the compute function!") + end = time.time() + times.append(end - start) + + # simulating the system forward + for j in range(physics_steps_per_control_step): + print(f"t: {(i * physics_steps_per_control_step + j) * dt:.2f} | u: {u}") + if i == 0 and j == 0: + print("Compiling the step function...") + data = jit_step(data, u) # <-- segfaults here! 
non-jitted version doesn't segfault + if i == 0 and j == 0: + print("Compiled the step function!") + us_hist = us_hist.at[i, :].set(u) + xs_hist = xs_hist.at[i + 1, :].set(jnp.concatenate((data.qpos, data.qvel))) + + key = jax.random.split(key)[0] + + return data, xs_hist, us_hist, times # [DEBUG] return iteration times + + +if __name__ == "__main__": + # TODO(ahl): move this into a proper script + import time + + import mujoco + import mujoco.viewer + import numpy as np + from jax.experimental import mesh_utils + from jax.experimental.shard_map import shard_map + from jax.sharding import Mesh, NamedSharding + from jax.sharding import PartitionSpec as P + from mujoco.mjx._src.types import DisableBit + + from ambersim.trajopt.cost import StaticGoalQuadraticCost + from ambersim.trajopt.shooting import VanillaPredictiveSampler + from ambersim.utils.io_utils import load_mj_model_from_file, mj_to_mjx_model_and_data + + # sharding logic + devices = np.array(jax.devices()) + mesh = Mesh(devices, ("i",)) + + # loading model + defining the predictive controller + # mj_model = load_mj_model_from_file("models/barrett_hand/bh280.xml") + mj_model = load_mj_model_from_file("models/allegro_hand/right_hand.xml") + mj_model.opt.timestep = 0.001 # dt, "framerate of reality" + + ctrl_model, _ = mj_to_mjx_model_and_data(mj_model) + ctrl_model = ctrl_model.replace( + opt=ctrl_model.opt.replace( + timestep=0.015, # dt for each step in the controller's internal model + iterations=1, # number of Newton steps to take during solve + ls_iterations=1, # number of line search iterations along step direction + integrator=0, # Euler semi-implicit integration + solver=2, # Newton solver + # disableflags=DisableBit.CONTACT, # disable contact for this test + ) + ) + cost_function = StaticGoalQuadraticCost( + Q=jnp.diag(jnp.array([10.0] * ctrl_model.nq + [0.01] * ctrl_model.nv)), + Qf=jnp.diag(jnp.array([10.0] * ctrl_model.nq + [0.01] * ctrl_model.nv)), + R=0.001 * jnp.eye(ctrl_model.nu), + 
xg=jnp.concatenate( + (jnp.array([0.0, 0.5, 0.5, 0.5, 0.0, 0.5, 0.5, 0.5, 0.0, 0.5, 0.5, 0.5, 1.0, 0.8, 0.5, 0.5]), jnp.zeros(16)) + ), + ) + nsamples = 100 + stdev = 0.1 + ps = VanillaPredictiveSampler(model=ctrl_model, cost_function=cost_function, nsamples=nsamples, stdev=stdev) + + N = 10 + controller = VanillaPredictiveSamplingController(trajectory_optimizer=ps, model=ctrl_model) + # jit_compute = jit( + # shard_map( + # lambda key, x_meas, us_guess: controller.compute_with_us_star( + # VanillaPredictiveSamplingControllerParams(key=key, x=x_meas, guess=us_guess) + # ), + # mesh=mesh, + # in_specs=(P(), P(), P('i', None)), + # out_specs=(P(), P('i', None)), + # check_rep=False, + # ) + # ) # [DEBUG] + jit_compute = jit( + lambda key, x_meas, us_guess: controller.compute_with_us_star( + VanillaPredictiveSamplingControllerParams(key=key, x=x_meas, guess=us_guess) + ) + ) + print("Controller created! Simulating...") + + # simulating forward + key = jax.random.PRNGKey(1234) + # x0 = jnp.array([0.14021026, 0.04142465, 0.99314054, -0.00491294, -0.00487147, 0.81353587, -0.00491293, -0.00487143, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) + x0 = 2.0 * (jax.random.uniform(key, shape=(ctrl_model.nq + ctrl_model.nv,)) - 0.5) + # T = 3.0 + # num_steps = int(T / model.opt.timestep) + # data, xs, us, times = simulate_predictive_sampling_controller( + # model, controller, x0, num_steps, N, seed=1337, physics_steps_per_control_step=2 + # ) # just testing warm starting the sim + # print(np.mean(times[1:])) + + # post-hoc visualization + mj_data = mujoco.MjData(mj_model) + mj_data.qpos[:] = x0[: mj_model.nq] + mj_data.qvel[:] = x0[mj_model.nq :] + dt = mj_model.opt.timestep + us_guess = jnp.zeros((N, mj_model.nu)) + # us_guess = jax.device_put(jnp.zeros((N, mj_model.nu)), NamedSharding(mesh, P('i', None))) # [DEBUG] + compiled = False + num_phys_steps_per_control_step = 15 + with mujoco.viewer.launch_passive(mj_model, mj_data) as viewer: + while viewer.is_running(): + while True: + 
if not compiled: + print("compiling jit_compute...") + compiled = True + start = time.time() + x_meas = jnp.concatenate((mj_data.qpos, mj_data.qvel)) + u, _us_guess = jit_compute(key, x_meas, us_guess) + us_guess = np.concatenate((_us_guess[1:, :], np.zeros((1, mj_model.nu)))) + mj_data.ctrl[:] = u + print(time.time() - start) + for _ in range(num_phys_steps_per_control_step): + start = time.time() + mujoco.mj_step(mj_model, mj_data) + viewer.sync() + elapsed = time.time() - start + if elapsed < mj_model.opt.timestep: + time.sleep(mj_model.opt.timestep) + key = jax.random.split(key)[0] + + # start = time.time() + # ui = us[i, :] + # mj_data.ctrl[:] = ui + # if i <= int(T / dt): + # mujoco.mj_step(mj_model, mj_data) + # # print(f"t: {i * dt:.2f} | qpos: {mj_data.qpos}") + # # print(f"t: {i * dt:.2f} | cost: {mj_data.qpos @ mj_data.qpos + mj_data.qvel @ mj_data.qvel:.2f}") + # viewer.sync() + # i += 1 + # elapsed = time.time() - start + # if elapsed < mj_model.opt.timestep: + # time.sleep(mj_model.opt.timestep) + + breakpoint() diff --git a/ambersim/trajopt/base.py b/ambersim/trajopt/base.py index 2e50975e..f0d17c41 100644 --- a/ambersim/trajopt/base.py +++ b/ambersim/trajopt/base.py @@ -61,7 +61,7 @@ class TrajectoryOptimizer: raising a NotImplementedError. """ - def optimize(self, params: TrajectoryOptimizerParams) -> Tuple[jax.Array, jax.Array, jax.Array]: + def optimize(self, to_params: TrajectoryOptimizerParams) -> Tuple[jax.Array, jax.Array, jax.Array]: """Optimizes a trajectory. The shapes of the outputs include (?) because we may choose to return non-zero-order-hold parameterizations of @@ -69,7 +69,7 @@ def optimize(self, params: TrajectoryOptimizerParams) -> Tuple[jax.Array, jax.Ar control inputs over the trajectory as is done in the gradient-based methods of MJPC). Args: - params: The parameters of the trajectory optimizer. + to_params: The parameters of the trajectory optimizer. Returns: xs_star (shape=(N + 1, nq + nv) or (?)): The optimized trajectory. 
@@ -102,22 +102,22 @@ class CostFunction: (4) histories of higher-order derivatives can be useful for updating their current estimates, e.g., BFGS. """ - def cost(self, xs: jax.Array, us: jax.Array, params: CostFunctionParams) -> Tuple[jax.Array, CostFunctionParams]: + def cost(self, xs: jax.Array, us: jax.Array, cf_params: CostFunctionParams) -> Tuple[jax.Array, CostFunctionParams]: """Computes the cost of a trajectory. Args: xs (shape=(N + 1, nq + nv)): The state trajectory. us (shape=(N, nu)): The controls over the trajectory. - params: The parameters of the cost function. + cf_params: The parameters of the cost function. Returns: val (shape=(,)): The cost of the trajectory. - new_params: The updated parameters of the cost function. + new_cf_params: The updated parameters of the cost function. """ raise NotImplementedError def grad( - self, xs: jax.Array, us: jax.Array, params: CostFunctionParams + self, xs: jax.Array, us: jax.Array, cf_params: CostFunctionParams ) -> Tuple[jax.Array, jax.Array, CostFunctionParams, CostFunctionParams]: """Computes the gradient of the cost of a trajectory. @@ -127,19 +127,19 @@ def grad( Args: xs (shape=(N + 1, nq + nv)): The state trajectory. us (shape=(N, nu)): The controls over the trajectory. - params: The parameters of the cost function. + cf_params: The parameters of the cost function. Returns: gcost_xs (shape=(N + 1, nq + nv): The gradient of the cost wrt xs. gcost_us (shape=(N, nu)): The gradient of the cost wrt us. - gcost_params: The gradient of the cost wrt params. - new_params: The updated parameters of the cost function. + gcost_params: The gradient of the cost wrt cf_params. + new_cf_params: The updated parameters of the cost function. 
""" - _fn = lambda xs, us, params: self.cost(xs, us, params)[0] # only differentiate wrt the cost val - return grad(_fn, argnums=(0, 1, 2))(xs, us, params) + (params,) + _fn = lambda xs, us, cf_params: self.cost(xs, us, cf_params)[0] # only differentiate wrt the cost val + return grad(_fn, argnums=(0, 1, 2))(xs, us, cf_params) + (cf_params,) def hess( - self, xs: jax.Array, us: jax.Array, params: CostFunctionParams + self, xs: jax.Array, us: jax.Array, cf_params: CostFunctionParams ) -> Tuple[ jax.Array, jax.Array, CostFunctionParams, jax.Array, CostFunctionParams, CostFunctionParams, CostFunctionParams ]: @@ -153,20 +153,20 @@ def hess( Args: xs (shape=(N + 1, nq + nv)): The state trajectory. us (shape=(N, nu)): The controls over the trajectory. - params: The parameters of the cost function. + cf_params: The parameters of the cost function. Returns: Hcost_xsxs (shape=(N + 1, nq + nv, N + 1, nq + nv)): The Hessian of the cost wrt xs. Hcost_xsus (shape=(N + 1, nq + nv, N, nu)): The Hessian of the cost wrt xs and us. - Hcost_xsparams: The Hessian of the cost wrt xs and params. + Hcost_xsparams: The Hessian of the cost wrt xs and cf_params. Hcost_usus (shape=(N, nu, N, nu)): The Hessian of the cost wrt us. - Hcost_usparams: The Hessian of the cost wrt us and params. - Hcost_paramsall: The Hessian of the cost wrt params and everything else. - new_params: The updated parameters of the cost function. + Hcost_usparams: The Hessian of the cost wrt us and cf_params. + Hcost_paramsall: The Hessian of the cost wrt cf_params and everything else. + new_cf_params: The updated parameters of the cost function. 
""" - _fn = lambda xs, us, params: self.cost(xs, us, params)[0] # only differentiate wrt the cost val - hessians = hessian(_fn, argnums=(0, 1, 2))(xs, us, params) + _fn = lambda xs, us, cf_params: self.cost(xs, us, cf_params)[0] # only differentiate wrt the cost val + hessians = hessian(_fn, argnums=(0, 1, 2))(xs, us, cf_params) Hcost_xsxs, Hcost_xsus, Hcost_xsparams = hessians[0] _, Hcost_usus, Hcost_usparams = hessians[1] Hcost_paramsall = hessians[2] - return Hcost_xsxs, Hcost_xsus, Hcost_xsparams, Hcost_usus, Hcost_usparams, Hcost_paramsall, params + return Hcost_xsxs, Hcost_xsus, Hcost_xsparams, Hcost_usus, Hcost_usparams, Hcost_paramsall, cf_params diff --git a/ambersim/trajopt/cost.py b/ambersim/trajopt/cost.py index fa560128..c50fb842 100644 --- a/ambersim/trajopt/cost.py +++ b/ambersim/trajopt/cost.py @@ -10,6 +10,7 @@ """A collection of common cost functions.""" +@struct.dataclass class StaticGoalQuadraticCost(CostFunction): """A quadratic cost function that penalizes the distance to a static goal. @@ -19,19 +20,10 @@ class StaticGoalQuadraticCost(CostFunction): dense. """ - def __init__(self, Q: jax.Array, Qf: jax.Array, R: jax.Array, xg: jax.Array) -> None: - """Initializes a quadratic cost function. - - Args: - Q (shape=(nx, nx)): The state cost matrix. - Qf (shape=(nx, nx)): The final state cost matrix. - R (shape=(nu, nu)): The control cost matrix. - xg (shape=(nq,)): The goal state. 
- """ - self.Q = Q - self.Qf = Qf - self.R = R - self.xg = xg + Q: jax.Array # shape=(nx, nx) state cost matrix + Qf: jax.Array # shape=(nx, nx) final state cost matrix + R: jax.Array # shape=(nu, nu) control cost matrix + xg: jax.Array # shape=(nq + nv,) goal state @staticmethod def batch_quadform(bs: jax.Array, A: jax.Array) -> jax.Array: diff --git a/ambersim/trajopt/shooting.py b/ambersim/trajopt/shooting.py index 414917a1..7fe492e4 100644 --- a/ambersim/trajopt/shooting.py +++ b/ambersim/trajopt/shooting.py @@ -24,8 +24,8 @@ def shoot(m: mjx.Model, x0: jax.Array, us: jax.Array) -> jax.Array: Args: m: The model. - x0: The initial state. - us: The control inputs. + x0 (shape=(nx,)): The initial state. + us (shape=(N, nu)): The control inputs. Returns: xs (shape=(N + 1, nq + nv)): The state trajectory. @@ -48,6 +48,44 @@ def scan_fn(d, u): return xs +def shoot_pd(m: mjx.Model, x0: jax.Array, qgs: jax.Array, kp: float, kd: float) -> jax.Array: + """Utility function that shoots a model forward given a sequence of goal states + PD gains. + + Args: + m: The model. + x0 (shape=(nx,)): The initial state. + qgs (shape=(N, nq)): The goal trajectory. + kp: The proportional gain. + kd: The derivative gain. + + Returns: + xs (shape=(N + 1, nq + nv)): The state trajectory. + us (shape=(N, nu)): The control inputs. + """ + # initializing the data + d = mjx.make_data(m) + d = d.replace(qpos=x0[: m.nq], qvel=x0[m.nq :]) # setting the initial state. + d = mjx.forward(m, d) # setting other internal states like acceleration without integrating + + def scan_fn(d, qg): + """Integrates the model forward one step given the control input u.""" + # applying PD controller + u = -kp * (d.qpos - qg) - kd * d.qvel + + d = d.replace(ctrl=u) + d = step(m, d) + x = jnp.concatenate((d.qpos, d.qvel)) # (nq + nv,) + xu = jnp.concatenate((x, u)) + return d, xu + + # scan over the control inputs to get the trajectory. 
+ _, _xus = lax.scan(scan_fn, d, qgs, length=qgs.shape[0]) + _xs = _xus[:, : m.nq + m.nv] + xs = jnp.concatenate((x0[None, :], _xs), axis=0) # (N + 1, nq + nv) + us = _xus[:, m.nq + m.nv :] + return xs, us + + # ################ # # SHOOTING METHODS # # ################ # @@ -61,27 +99,29 @@ class ShootingParams(TrajectoryOptimizerParams): # inputs into the algorithm x0: jax.Array # shape=(nq + nv,) or (?) - us_guess: jax.Array # shape=(N, nu) or (?) + guess: jax.Array # shape=(N, n?) or (?) @property def N(self) -> int: """The number of time steps. - By default, we assume us_guess represents the ZOH control inputs. However, in the case that it is actually an + By default, we assume guess represents the ZOH control inputs. However, in the case that it is actually an alternate parameterization, we may need to compute N some other way, which requires overwriting this method. + + In the case of PD controllers, for instance, the guess could actually be a trajectory of target positions. """ - return self.us_guess.shape[0] + return self.guess.shape[0] @struct.dataclass class ShootingAlgorithm(TrajectoryOptimizer): """A trajectory optimization algorithm based on shooting methods.""" - def optimize(self, params: ShootingParams) -> Tuple[jax.Array, jax.Array]: + def optimize(self, to_params: ShootingParams) -> Tuple[jax.Array, jax.Array]: """Optimizes a trajectory using a shooting method. Args: - params: The parameters of the trajectory optimizer. + to_params: The parameters of the trajectory optimizer. Returns: xs_star (shape=(N + 1, nq) or (?)): The optimized trajectory. 
# ########################### #
# PREDICTIVE SAMPLING METHODS #
# ########################### #

# generic API


@struct.dataclass
class PredictiveSamplerParams(ShootingParams):
    """Parameters for generic predictive sampling methods."""

    key: jax.Array  # random key for sampling


@struct.dataclass
class PredictiveSampler(ShootingAlgorithm):
    """A generic predictive sampler object."""

    def optimize(self, to_params: PredictiveSamplerParams) -> Tuple[jax.Array, jax.Array]:
        """Optimizes a trajectory using predictive sampling. Concrete samplers must implement this.

        Args:
            to_params: The parameters of the trajectory optimizer.

        Returns:
            xs (shape=(N + 1, nq + nv) or (?)): The optimized trajectory.
            us (shape=(N, nu) or (?)): The optimized controls.
        """
        # abstract method: the original silently returned None despite the declared Tuple return;
        # raising keeps this consistent with Controller.compute and ShootingAlgorithm.optimize.
        raise NotImplementedError


# vanilla


@struct.dataclass
class VanillaPredictiveSamplerParams(PredictiveSamplerParams):
    """Parameters for the vanilla predictive sampling method."""
# PD predictive sampling


@struct.dataclass
class PDPredictiveSamplerParams(PredictiveSamplerParams):
    """Parameters for the PD controller predictive sampler."""

    kp: float  # proportional gain
    kd: float  # derivative gain


@struct.dataclass
class PDPredictiveSampler(PredictiveSampler):
    """A PD predictive sampler object.

    The following choices are made:
    (1) the model parameters are fixed, and are therefore a field of this dataclass;
    (2) the control sequence is parameterized as a ZOH sequence instead of a spline;
    (3) the guess supplied is actually a trajectory of positions and the control signal is computed via PD control;
    (4) the cost function is quadratic in the states and controls.
    """

    model: mjx.Model
    cost_function: CostFunction
    nsamples: int = struct.field(pytree_node=False)
    stdev: float = struct.field(pytree_node=False)  # noise scale, parameters theta_new ~ N(theta, (stdev ** 2) * I)

    def optimize(self, to_params: PDPredictiveSamplerParams) -> Tuple[jax.Array, jax.Array]:
        """Optimizes a trajectory using a PD predictive sampler.

        Args:
            to_params: The parameters of the trajectory optimizer (must carry the kp/kd gains).

        Returns:
            xs_star (shape=(N + 1, nq + nv)): The optimized trajectory.
            us_star (shape=(N, nu)): The optimized controls.
        """
        # unpack the params
        m = self.model
        nsamples = self.nsamples
        stdev = self.stdev

        x0 = to_params.x0
        qs_guess = to_params.guess
        N = to_params.N
        key = to_params.key

        kp = to_params.kp
        kd = to_params.kd

        # sample random position trajectories - the first sample is the unperturbed guess,
        # since it's possible that it's the best one
        noise = jnp.concatenate(
            (jnp.zeros((1, N, m.nq)), jax.random.normal(key, shape=(nsamples - 1, N, m.nq)) * stdev), axis=0
        )
        _qgs_samples = qs_guess + noise

        # clamping the samples to their position limits
        # NOTE(review): assumes every joint is 1-DOF so jnt_range rows align with qpos entries - confirm.
        limits = m.jnt_range
        clip_fn = partial(jnp.clip, a_min=limits[:, 0], a_max=limits[:, 1])  # clipping function with limits already set
        qgs_samples = vmap(vmap(clip_fn))(_qgs_samples)  # apply limits only to the last dim, need a nested vmap

        # predict many samples, evaluate them, and return the best trajectory tuple
        # vmap over the input data and the control trajectories
        xs_samples, us_samples = vmap(shoot_pd, in_axes=(None, None, 0, None, None))(m, x0, qgs_samples, kp, kd)
        costs, _ = vmap(self.cost_function.cost, in_axes=(0, 0, None))(xs_samples, us_samples, None)  # (nsamples,)
        best_idx = jnp.argmin(costs)
        xs_star = lax.dynamic_slice(xs_samples, (best_idx, 0, 0), (1, N + 1, m.nq + m.nv))[0]  # (N + 1, nq + nv)
        us_star = lax.dynamic_slice(us_samples, (best_idx, 0, 0), (1, N, m.nu))[0]  # (N, nu)
        return xs_star, us_star
from ambersim.control.predictive_control import (
    VanillaPredictiveSamplingController,
    VanillaPredictiveSamplingControllerParams,
)
from ambersim.trajopt.cost import StaticGoalQuadraticCost

# from ambersim.trajopt.shooting import PDPredictiveSampler
from ambersim.trajopt.shooting import VanillaPredictiveSampler
from ambersim.utils.io_utils import load_mj_model_from_file, mj_to_mjx_model_and_data

"""Important info:
* the cube's states come after the hand's states
* the allegro hand is floating but fixed in space

TODOs:
* implement a check for the cube being stuck
* add some reasonable cost terms for the palm up task (see the mjpc code + talk to vince)
* add support for the PD controller, but we have to be more clever about underactuation
"""

# ##################### #
# CUBE TARGET POSITIONS #
# ##################### #

# position
POS = jnp.array([0.05, 0.0, 0.05])  # move the cube over the fingers
Z_FAIL = -0.05  # floor at -0.1

# rotations (wxyz quaternions)
IDENTITY = jnp.array([1.0, 0.0, 0.0, 0.0])
X90 = jnp.array([0.7071068, 0.7071068, 0.0, 0.0])
X180 = jnp.array([0.0, 1.0, 0.0, 0.0])
X270 = jnp.array([-0.7071068, 0.7071068, 0.0, 0.0])
Y90 = jnp.array([0.7071068, 0.0, 0.7071068, 0.0])
Y180 = jnp.array([0.0, 0.0, 1.0, 0.0])
Y270 = jnp.array([-0.7071068, 0.0, 0.7071068, 0.0])
Z90 = jnp.array([0.7071068, 0.0, 0.0, 0.7071068])
Z180 = jnp.array([0.0, 0.0, 0.0, 1.0])
Z270 = jnp.array([-0.7071068, 0.0, 0.0, 0.7071068])

# ########## #
# PARAMETERS #
# ########## #

dt = 0.001  # rate of reality
num_phys_steps_per_control_step = 15  # dt * npspcs = rate of control

w_pos = 10.0
w_vel = 0.001
w_ctrl = 0.001
q_goal_algr = jnp.zeros(16)
q_goal_cube = jnp.concatenate((POS, X90))
q_goal = jnp.concatenate((q_goal_algr, q_goal_cube))

nsamples = 100  # number of samples to draw for predictive sampling
stdev = 0.3  # standard deviation of control parameters to draw
N = 5  # number of time steps to predict

key = jax.random.PRNGKey(1234)

# ########## #
# SIMULATION #
# ########## #

# loading model
mj_model = load_mj_model_from_file("models/allegro_hand/scene_right.xml")
mj_model.opt.timestep = dt

# defining the predictive controller
ctrl_model, _ = mj_to_mjx_model_and_data(mj_model)
ctrl_model = ctrl_model.replace(
    opt=ctrl_model.opt.replace(
        timestep=num_phys_steps_per_control_step * dt,  # dt for each step in the controller's internal model
        iterations=1,  # number of Newton steps to take during solve
        ls_iterations=1,  # number of line search iterations along step direction
        integrator=0,  # Euler semi-implicit integration
        solver=2,  # Newton solver
    )
)
cost_function = StaticGoalQuadraticCost(
    Q=jnp.diag(jnp.array([w_pos] * ctrl_model.nq + [w_vel] * ctrl_model.nv)),
    Qf=jnp.diag(jnp.array([w_pos] * ctrl_model.nq + [w_vel] * ctrl_model.nv)),
    R=w_ctrl * jnp.eye(ctrl_model.nu),
    xg=jnp.concatenate((q_goal, jnp.zeros(ctrl_model.nv))),
)
ps = VanillaPredictiveSampler(model=ctrl_model, cost_function=cost_function, nsamples=nsamples, stdev=stdev)
controller = VanillaPredictiveSamplingController(trajectory_optimizer=ps, model=ctrl_model)
jit_compute = jit(
    lambda key, x_meas, us_guess: controller.compute_with_us_star(
        VanillaPredictiveSamplingControllerParams(key=key, x=x_meas, guess=us_guess)
    )
)
print("Controller created! Simulating...")

# simulating forward
q0_algr = jnp.zeros(16)
q0_cube = jnp.concatenate((POS, IDENTITY))
v0 = jnp.zeros(mj_model.nv)
x0 = jnp.concatenate((q0_algr, q0_cube, v0))

mj_data = mujoco.MjData(mj_model)
mj_data.qpos[:] = x0[: mj_model.nq]
mj_data.qvel[:] = x0[mj_model.nq :]

us_guess = jnp.zeros((N, mj_model.nu))

print("Compiling jit_compute...")
jit_compute(key, x0, us_guess)
print("Compiled! Beginning simulation...")

with mujoco.viewer.launch_passive(mj_model, mj_data) as viewer:
    # single loop: the original nested a `while True` inside this check, so closing the
    # viewer never stopped the simulation. Re-check is_running() every control step instead.
    while viewer.is_running():
        start = time.time()
        x_meas = jnp.concatenate((mj_data.qpos, mj_data.qvel))
        u, _us_guess = jit_compute(key, x_meas, us_guess)
        # warm start the next solve with the receding tail of the optimal sequence
        us_guess = np.concatenate((_us_guess[1:, :], np.zeros((1, mj_model.nu))))
        mj_data.ctrl[:] = u
        print(f"Controller delay: {time.time() - start}")
        for _ in range(num_phys_steps_per_control_step):
            start = time.time()
            mujoco.mj_step(mj_model, mj_data)
            viewer.sync()
            elapsed = time.time() - start
            if elapsed < mj_model.opt.timestep:
                time.sleep(mj_model.opt.timestep - elapsed)

        # check whether to reset the cube (qpos[-5] is the cube's z position: 3 pos + 4 quat at the end)
        if mj_data.qpos[-5] <= Z_FAIL:
            print("*** Cube fell! Resetting... *** ")
            mj_data.qpos[:] = x0[: mj_model.nq]
            mj_data.qvel[:] = x0[mj_model.nq :]
            us_guess = jnp.zeros((N, mj_model.nu))
            viewer.sync()
            time.sleep(1.0)

        key = jax.random.split(key)[0]
# ########## #
# PARAMETERS #
# ########## #

dt = 0.001  # rate of reality
num_phys_steps_per_control_step = 15  # dt * npspcs = rate of control

w_pos = 10.0
w_vel = 0.01
w_ctrl = 0.001
q_goal = jnp.array([0.0, 0.5, 0.5, 0.5, 0.0, 0.5, 0.5, 0.5, 0.0, 0.5, 0.5, 0.5, 1.0, 0.8, 0.5, 0.5])

nsamples = 100  # number of samples to draw for predictive sampling
stdev = 0.1  # standard deviation of control parameters to draw
N = 8  # number of time steps to predict

key = jax.random.PRNGKey(1234)

# ########## #
# SIMULATION #
# ########## #

# loading model
mj_model = load_mj_model_from_file("models/allegro_hand/right_hand.xml")
mj_model.opt.timestep = dt

# defining the predictive controller
ctrl_model, _ = mj_to_mjx_model_and_data(mj_model)
ctrl_model = ctrl_model.replace(
    opt=ctrl_model.opt.replace(
        timestep=num_phys_steps_per_control_step * dt,  # dt for each step in the controller's internal model
        iterations=1,  # number of Newton steps to take during solve
        ls_iterations=1,  # number of line search iterations along step direction
        integrator=0,  # Euler semi-implicit integration
        solver=2,  # Newton solver
    )
)
cost_function = StaticGoalQuadraticCost(
    Q=jnp.diag(jnp.array([w_pos] * ctrl_model.nq + [w_vel] * ctrl_model.nv)),
    Qf=jnp.diag(jnp.array([w_pos] * ctrl_model.nq + [w_vel] * ctrl_model.nv)),
    R=w_ctrl * jnp.eye(ctrl_model.nu),
    xg=jnp.concatenate((q_goal, jnp.zeros(ctrl_model.nv))),
)
ps = PDPredictiveSampler(model=ctrl_model, cost_function=cost_function, nsamples=nsamples, stdev=stdev)

kp = 5.0
kd = 0.1
controller = PDPredictiveSamplingController(trajectory_optimizer=ps, model=ctrl_model)
jit_compute = jit(
    lambda key, x_meas, qgs_guess: controller.compute_with_qs_star(
        PDPredictiveSamplingControllerParams(key=key, x=x_meas, guess=qgs_guess, kp=kp, kd=kd)
    )
)
print("Controller created! Simulating...")

# simulating forward
x0 = 2.0 * (jax.random.uniform(key, shape=(ctrl_model.nq + ctrl_model.nv,)) - 0.5)  # random initial state

mj_data = mujoco.MjData(mj_model)
mj_data.qpos[:] = x0[: mj_model.nq]
mj_data.qvel[:] = x0[mj_model.nq :]

qgs_guess = jnp.tile(x0[: mj_model.nq], (N, 1))

print("Compiling jit_compute...")
jit_compute(key, x0, qgs_guess)
print("Compiled! Beginning simulation...")

with mujoco.viewer.launch_passive(mj_model, mj_data) as viewer:
    # single loop: the original nested a `while True` inside this check, so closing the
    # viewer never stopped the simulation. Re-check is_running() every control step instead.
    while viewer.is_running():
        start = time.time()
        x_meas = jnp.concatenate((mj_data.qpos, mj_data.qvel))
        qg, _qgs_guess = jit_compute(key, x_meas, qgs_guess)
        qg.block_until_ready()
        # recede the horizon WITHOUT shrinking it: the original `_qgs_guess[1:, :]` dropped one
        # row per control step until the guess was empty. Repeat the last goal to keep length N.
        qgs_guess = jnp.concatenate((_qgs_guess[1:, :], _qgs_guess[-1:, :]))
        print(f"Controller delay: {time.time() - start}")
        for _ in range(num_phys_steps_per_control_step):
            start = time.time()
            u = -kp * (mj_data.qpos - qg) - kd * mj_data.qvel
            u = np.clip(u, mj_model.actuator_ctrlrange[:, 0], mj_model.actuator_ctrlrange[:, 1])
            mj_data.ctrl[:] = u
            mujoco.mj_step(mj_model, mj_data)
            viewer.sync()
            elapsed = time.time() - start
            if elapsed < mj_model.opt.timestep:
                time.sleep(mj_model.opt.timestep - elapsed)
        key = jax.random.split(key)[0]
dt = 0.001  # rate of reality
num_phys_steps_per_control_step = 15  # dt * npspcs = rate of control

w_pos = 10.0
w_vel = 0.01
w_ctrl = 0.001
q_goal = jnp.array([0.0, 0.5, 0.5, 0.5, 0.0, 0.5, 0.5, 0.5, 0.0, 0.5, 0.5, 0.5, 1.0, 0.8, 0.5, 0.5])

nsamples = 100  # number of samples to draw for predictive sampling
stdev = 0.1  # standard deviation of control parameters to draw
N = 10  # number of time steps to predict

key = jax.random.PRNGKey(1234)

# ########## #
# SIMULATION #
# ########## #

# loading model
mj_model = load_mj_model_from_file("models/allegro_hand/right_hand.xml")
mj_model.opt.timestep = dt

# defining the predictive controller
ctrl_model, _ = mj_to_mjx_model_and_data(mj_model)
ctrl_model = ctrl_model.replace(
    opt=ctrl_model.opt.replace(
        timestep=num_phys_steps_per_control_step * dt,  # dt for each step in the controller's internal model
        iterations=1,  # number of Newton steps to take during solve
        ls_iterations=1,  # number of line search iterations along step direction
        integrator=0,  # Euler semi-implicit integration
        solver=2,  # Newton solver
    )
)
cost_function = StaticGoalQuadraticCost(
    Q=jnp.diag(jnp.array([w_pos] * ctrl_model.nq + [w_vel] * ctrl_model.nv)),
    Qf=jnp.diag(jnp.array([w_pos] * ctrl_model.nq + [w_vel] * ctrl_model.nv)),
    R=w_ctrl * jnp.eye(ctrl_model.nu),
    xg=jnp.concatenate((q_goal, jnp.zeros(ctrl_model.nv))),
)
ps = VanillaPredictiveSampler(model=ctrl_model, cost_function=cost_function, nsamples=nsamples, stdev=stdev)
controller = VanillaPredictiveSamplingController(trajectory_optimizer=ps, model=ctrl_model)
jit_compute = jit(
    lambda key, x_meas, us_guess: controller.compute_with_us_star(
        VanillaPredictiveSamplingControllerParams(key=key, x=x_meas, guess=us_guess)
    )
)
print("Controller created! Simulating...")

# simulating forward
x0 = 2.0 * (jax.random.uniform(key, shape=(ctrl_model.nq + ctrl_model.nv,)) - 0.5)  # random initial state

mj_data = mujoco.MjData(mj_model)
mj_data.qpos[:] = x0[: mj_model.nq]
mj_data.qvel[:] = x0[mj_model.nq :]

us_guess = jnp.zeros((N, mj_model.nu))

print("Compiling jit_compute...")
jit_compute(key, x0, us_guess)
print("Compiled! Beginning simulation...")

with mujoco.viewer.launch_passive(mj_model, mj_data) as viewer:
    # single loop: the original nested a `while True` inside this check, so closing the
    # viewer never stopped the simulation. Re-check is_running() every control step instead.
    while viewer.is_running():
        start = time.time()
        x_meas = jnp.concatenate((mj_data.qpos, mj_data.qvel))
        u, _us_guess = jit_compute(key, x_meas, us_guess)
        # warm start the next solve with the receding tail of the optimal sequence
        us_guess = np.concatenate((_us_guess[1:, :], np.zeros((1, mj_model.nu))))
        mj_data.ctrl[:] = u
        print(f"Controller delay: {time.time() - start}")
        for _ in range(num_phys_steps_per_control_step):
            start = time.time()
            mujoco.mj_step(mj_model, mj_data)
            viewer.sync()
            elapsed = time.time() - start
            if elapsed < mj_model.opt.timestep:
                time.sleep(mj_model.opt.timestep - elapsed)
        key = jax.random.split(key)[0]
["*.png", "*.obj", "*.stl", "*.urdf", "*.xml", "install.sh", "install_mj_source.sh"] [tool.black] line-length = 120 diff --git a/tests/trajopt/test_predictive_sampler.py b/tests/trajopt/test_predictive_sampler.py index 7111fc05..f66f4b9a 100644 --- a/tests/trajopt/test_predictive_sampler.py +++ b/tests/trajopt/test_predictive_sampler.py @@ -8,17 +8,28 @@ from ambersim.trajopt.base import CostFunctionParams from ambersim.trajopt.cost import StaticGoalQuadraticCost -from ambersim.trajopt.shooting import VanillaPredictiveSampler, VanillaPredictiveSamplerParams, shoot +from ambersim.trajopt.shooting import ( + PDPredictiveSampler, + PDPredictiveSamplerParams, + VanillaPredictiveSampler, + VanillaPredictiveSamplerParams, + shoot, + shoot_pd, +) from ambersim.utils.io_utils import load_mjx_model_and_data_from_file os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" # fixes OOM error +# ########################## # +# VANILLA PREDICTIVE SAMPLER # +# ########################## # + @pytest.fixture def vps_data(): """Makes data required for testing vanilla predictive sampling.""" # initializing the predictive sampler - model, _ = load_mjx_model_and_data_from_file("models/barrett_hand/bh280.xml", force_float=False) + model, _ = load_mjx_model_and_data_from_file("models/allegro_hand/right_hand_motor.xml", force_float=False) model = model.replace( opt=model.opt.replace( timestep=0.002, # dt @@ -26,7 +37,6 @@ def vps_data(): ls_iterations=4, # number of line search iterations along step direction integrator=0, # Euler semi-implicit integration solver=2, # Newton solver - disableflags=DisableBit.CONTACT, # disable contact for this test ) ) cost_function = StaticGoalQuadraticCost( @@ -50,7 +60,7 @@ def test_smoke_VPS(vps_data): x0 = jnp.zeros(model.nq + model.nv) num_steps = 10 us_guess = jnp.zeros((num_steps, model.nu)) - params = VanillaPredictiveSamplerParams(key=key, x0=x0, us_guess=us_guess) + params = VanillaPredictiveSamplerParams(key=key, x0=x0, guess=us_guess) # 
@pytest.fixture
def pdps_data():
    """Makes data required for testing PD predictive sampling."""
    # initializing the predictive sampler
    model, _ = load_mjx_model_and_data_from_file("models/allegro_hand/right_hand_position.xml", force_float=False)
    model = model.replace(
        opt=model.opt.replace(
            timestep=0.002,  # dt
            iterations=1,  # number of Newton steps to take during solve
            ls_iterations=4,  # number of line search iterations along step direction
            integrator=0,  # Euler semi-implicit integration
            solver=2,  # Newton solver
        )
    )
    cost_function = StaticGoalQuadraticCost(
        Q=jnp.eye(model.nq + model.nv),
        Qf=10.0 * jnp.eye(model.nq + model.nv),
        R=0.01 * jnp.eye(model.nu),
        xg=jnp.zeros(model.nq + model.nv),
    )
    nsamples = 100
    stdev = 0.01
    ps = PDPredictiveSampler(model=model, cost_function=cost_function, nsamples=nsamples, stdev=stdev)
    return ps, model, cost_function


def test_smoke_PDPS(pdps_data):
    """Simple smoke test to make sure we can run inputs through the PD predictive sampler + jit."""
    ps, model, _ = pdps_data

    # sampler parameters
    key = jax.random.PRNGKey(0)  # random seed for the predictive sampler
    x0 = jnp.zeros(model.nq + model.nv)
    num_steps = 10
    qs_guess = jnp.tile(x0[: model.nq], (num_steps, 1))
    params = PDPredictiveSamplerParams(key=key, x0=x0, guess=qs_guess, kp=1.0, kd=0.1)

    # sampling the best sequence of qs, vs, and us
    optimize_fn = jit(ps.optimize)
    # `assert optimize_fn(params)` is vacuous (a non-empty tuple is always truthy);
    # check the returned trajectory shapes instead.
    xs_star, us_star = optimize_fn(params)
    assert xs_star.shape == (num_steps + 1, model.nq + model.nv)
    assert us_star.shape == (num_steps, model.nu)