diff --git a/ambersim/control/base.py b/ambersim/control/base.py
new file mode 100644
index 00000000..1a2c7a4a
--- /dev/null
+++ b/ambersim/control/base.py
@@ -0,0 +1,32 @@
+import jax
+from flax import struct
+
+
@struct.dataclass
class ControllerParams:
    """The parameters for generic controllers.

    This is left completely empty for maximum flexibility in the API. Some examples:
    - "Regular" inputs into feedback controllers (e.g., the state) belong here.
    - Non-Markovian controllers can pass histories in this params object.
    - Parameters of the controller that you may randomize/optimize go here.

    Concrete controllers declare their actual fields in subclasses of this type.
    """
+
+
@struct.dataclass
class Controller:
    """The API for a generic controller.

    See the notes in TrajectoryOptimizer on the generality of this class - much of the same applies.
    """

    def compute(self, ctrl_params: ControllerParams) -> jax.Array:
        """Computes a control input.

        Abstract method: concrete controllers must override this.

        Args:
            ctrl_params: The inputs/parameters of the controller (e.g., the measured state).

        Returns:
            u (shape=(nu,)): The control input.
        """
        raise NotImplementedError
diff --git a/ambersim/control/predictive_control.py b/ambersim/control/predictive_control.py
new file mode 100644
index 00000000..4946d6d9
--- /dev/null
+++ b/ambersim/control/predictive_control.py
@@ -0,0 +1,149 @@
+from typing import Tuple
+
+import jax
+import jax.numpy as jnp
+from flax import struct
+from mujoco import mjx
+
+from ambersim.control.base import Controller, ControllerParams
+from ambersim.trajopt.base import TrajectoryOptimizer
+from ambersim.trajopt.shooting import PDPredictiveSamplerParams, PredictiveSampler, VanillaPredictiveSamplerParams
+
+# ########### #
+# GENERIC API #
+# ########### #
+
+
@struct.dataclass
class PredictiveControllerParams(ControllerParams):
    """The generic API for predictive controller params.

    Concrete predictive controllers declare their inputs (measured state, RNG key, warm starts)
    in subclasses of this type.
    """
+
+
@struct.dataclass
class PredictiveController(Controller):
    """The generic API for a predictive controller."""

    trajectory_optimizer: TrajectoryOptimizer  # the trajopt algorithm used for forward prediction
    model: mjx.Model  # the internal model the controller plans with

    def compute(self, ctrl_params: PredictiveControllerParams) -> jax.Array:
        """Computes a control input using forward prediction.

        Abstract method: concrete predictive controllers must override this.

        Args:
            ctrl_params: Inputs into the controller.

        Returns:
            u (shape=(nu,)): The control input.
        """
        raise NotImplementedError
+
+
+# ################### #
+# PREDICTIVE SAMPLING #
+# ################### #
+
+
@struct.dataclass
class VanillaPredictiveSamplingControllerParams(PredictiveControllerParams):
    """Vanilla predictive sampling controller params."""

    key: jax.Array  # random key for sampling
    x: jax.Array  # shape=(nq + nv,) current (measured) state
    guess: jax.Array  # shape=(N, nu) current guess of the control sequence (used as the warm start)
+
+
@struct.dataclass
class VanillaPredictiveSamplingController(PredictiveController):
    """Vanilla predictive sampling controller.

    Plans a control trajectory with a PredictiveSampler and applies its first control input.
    """

    def __post_init__(self) -> None:
        """Validates that the supplied trajectory optimizer is a predictive sampler.

        Raises:
            TypeError: If trajectory_optimizer is not a PredictiveSampler. An explicit raise is
                used instead of `assert` so the check survives `python -O`.
        """
        if not isinstance(self.trajectory_optimizer, PredictiveSampler):
            raise TypeError("trajectory_optimizer must be a PredictiveSampler!")

    def compute(self, ctrl_params: VanillaPredictiveSamplingControllerParams) -> jax.Array:
        """Computes a control input using forward prediction.

        Args:
            ctrl_params: Inputs into the controller.

        Returns:
            u (shape=(nu,)): The control input.
        """
        return self.compute_with_us_star(ctrl_params)[0]

    def compute_with_us_star(
        self, ctrl_params: VanillaPredictiveSamplingControllerParams
    ) -> Tuple[jax.Array, jax.Array]:
        """Computes a control input using forward prediction + the optimal sequence of guesses.

        This is needed in practice because the current optimal sequence is used to warm start the sampling distribution
        for the next call of the controller.

        Args:
            ctrl_params: Inputs into the controller.

        Returns:
            u (shape=(nu,)): The control input.
            us_star (shape=(N, nu)): The optimal control sequence.
        """
        to_params = VanillaPredictiveSamplerParams(
            key=ctrl_params.key,
            x0=ctrl_params.x,
            guess=ctrl_params.guess,
        )
        # only the optimized controls are needed here; the state trajectory is discarded
        _, us_star = self.trajectory_optimizer.optimize(to_params)
        return us_star[0, :], us_star
+
+
@struct.dataclass
class PDPredictiveSamplingControllerParams(PredictiveControllerParams):
    """PD predictive sampling controller params."""

    key: jax.Array  # random key for sampling
    x: jax.Array  # shape=(nq + nv,) current (measured) state
    guess: jax.Array  # shape=(N, nq) current guess of the position-target trajectory (warm start)
    kp: float  # proportional gain of the tracking PD law
    kd: float  # derivative gain of the tracking PD law
+
+
@struct.dataclass
class PDPredictiveSamplingController(PredictiveController):
    """PD predictive sampling controller.

    Instead of optimizing torques directly, this controller optimizes a sequence of position
    targets that are tracked by a PD law inside the sampler's rollouts, and returns position
    setpoints (not torques) to the caller.
    """

    def __post_init__(self) -> None:
        """Validates that the supplied trajectory optimizer is a predictive sampler.

        Raises:
            TypeError: If trajectory_optimizer is not a PredictiveSampler. An explicit raise is
                used instead of `assert` so the check survives `python -O`.
        """
        if not isinstance(self.trajectory_optimizer, PredictiveSampler):
            raise TypeError("trajectory_optimizer must be a PredictiveSampler!")

    def compute(self, ctrl_params: PDPredictiveSamplingControllerParams) -> jax.Array:
        """Computes the next position setpoint using forward prediction.

        Args:
            ctrl_params: Inputs into the controller.

        Returns:
            qg (shape=(nq,)): The next position target for the downstream PD tracking law.
        """
        return self.compute_with_qs_star(ctrl_params)[0]

    def compute_with_qs_star(self, ctrl_params: PDPredictiveSamplingControllerParams) -> Tuple[jax.Array, jax.Array]:
        """Computes the next position setpoint + the optimal sequence of positions.

        This is needed in practice because the current optimal sequence is used to warm start the sampling distribution
        for the next call of the controller.

        Args:
            ctrl_params: Inputs into the controller.

        Returns:
            qg (shape=(nq,)): The next position target (row 1 of qs_star; row 0 is the current state).
            qs_star (shape=(N + 1, nq)): The optimal position trajectory, including the current configuration.
        """
        to_params = PDPredictiveSamplerParams(
            key=ctrl_params.key,
            x0=ctrl_params.x,
            guess=ctrl_params.guess,
            kp=ctrl_params.kp,
            kd=ctrl_params.kd,
        )
        # only the optimized state trajectory is needed here; the rolled-out controls are discarded
        xs_star, _ = self.trajectory_optimizer.optimize(to_params)
        # row 0 of xs_star is the current state, so the first actionable target is row 1
        qs_star = xs_star[:, : self.model.nq]
        return qs_star[1, :], qs_star
diff --git a/ambersim/models/allegro_hand/assets/base_link.stl b/ambersim/models/allegro_hand/assets/base_link.stl
new file mode 100644
index 00000000..5901473f
Binary files /dev/null and b/ambersim/models/allegro_hand/assets/base_link.stl differ
diff --git a/ambersim/models/allegro_hand/assets/base_link_left.stl b/ambersim/models/allegro_hand/assets/base_link_left.stl
new file mode 100644
index 00000000..365d7e7f
Binary files /dev/null and b/ambersim/models/allegro_hand/assets/base_link_left.stl differ
diff --git a/ambersim/models/allegro_hand/assets/fileback.png b/ambersim/models/allegro_hand/assets/fileback.png
new file mode 100644
index 00000000..e03322fc
Binary files /dev/null and b/ambersim/models/allegro_hand/assets/fileback.png differ
diff --git a/ambersim/models/allegro_hand/assets/filedown.png b/ambersim/models/allegro_hand/assets/filedown.png
new file mode 100644
index 00000000..ee48edf6
Binary files /dev/null and b/ambersim/models/allegro_hand/assets/filedown.png differ
diff --git a/ambersim/models/allegro_hand/assets/filefront.png b/ambersim/models/allegro_hand/assets/filefront.png
new file mode 100644
index 00000000..931e8056
Binary files /dev/null and b/ambersim/models/allegro_hand/assets/filefront.png differ
diff --git a/ambersim/models/allegro_hand/assets/fileleft.png b/ambersim/models/allegro_hand/assets/fileleft.png
new file mode 100644
index 00000000..da6427bf
Binary files /dev/null and b/ambersim/models/allegro_hand/assets/fileleft.png differ
diff --git a/ambersim/models/allegro_hand/assets/fileright.png b/ambersim/models/allegro_hand/assets/fileright.png
new file mode 100644
index 00000000..6c2cb391
Binary files /dev/null and b/ambersim/models/allegro_hand/assets/fileright.png differ
diff --git a/ambersim/models/allegro_hand/assets/fileup.png b/ambersim/models/allegro_hand/assets/fileup.png
new file mode 100644
index 00000000..4962dfd0
Binary files /dev/null and b/ambersim/models/allegro_hand/assets/fileup.png differ
diff --git a/ambersim/models/allegro_hand/assets/grayback.png b/ambersim/models/allegro_hand/assets/grayback.png
new file mode 100644
index 00000000..03ae4bf7
Binary files /dev/null and b/ambersim/models/allegro_hand/assets/grayback.png differ
diff --git a/ambersim/models/allegro_hand/assets/graydown.png b/ambersim/models/allegro_hand/assets/graydown.png
new file mode 100644
index 00000000..b8b7b314
Binary files /dev/null and b/ambersim/models/allegro_hand/assets/graydown.png differ
diff --git a/ambersim/models/allegro_hand/assets/grayfront.png b/ambersim/models/allegro_hand/assets/grayfront.png
new file mode 100644
index 00000000..7c45fff0
Binary files /dev/null and b/ambersim/models/allegro_hand/assets/grayfront.png differ
diff --git a/ambersim/models/allegro_hand/assets/grayleft.png b/ambersim/models/allegro_hand/assets/grayleft.png
new file mode 100644
index 00000000..2625e739
Binary files /dev/null and b/ambersim/models/allegro_hand/assets/grayleft.png differ
diff --git a/ambersim/models/allegro_hand/assets/grayright.png b/ambersim/models/allegro_hand/assets/grayright.png
new file mode 100644
index 00000000..4516b601
Binary files /dev/null and b/ambersim/models/allegro_hand/assets/grayright.png differ
diff --git a/ambersim/models/allegro_hand/assets/grayup.png b/ambersim/models/allegro_hand/assets/grayup.png
new file mode 100644
index 00000000..0b07c030
Binary files /dev/null and b/ambersim/models/allegro_hand/assets/grayup.png differ
diff --git a/ambersim/models/allegro_hand/assets/link_0.0.stl b/ambersim/models/allegro_hand/assets/link_0.0.stl
new file mode 100644
index 00000000..6db5cdb1
Binary files /dev/null and b/ambersim/models/allegro_hand/assets/link_0.0.stl differ
diff --git a/ambersim/models/allegro_hand/assets/link_1.0.stl b/ambersim/models/allegro_hand/assets/link_1.0.stl
new file mode 100644
index 00000000..84e43514
Binary files /dev/null and b/ambersim/models/allegro_hand/assets/link_1.0.stl differ
diff --git a/ambersim/models/allegro_hand/assets/link_12.0_left.stl b/ambersim/models/allegro_hand/assets/link_12.0_left.stl
new file mode 100644
index 00000000..aeaa37a1
Binary files /dev/null and b/ambersim/models/allegro_hand/assets/link_12.0_left.stl differ
diff --git a/ambersim/models/allegro_hand/assets/link_12.0_right.stl b/ambersim/models/allegro_hand/assets/link_12.0_right.stl
new file mode 100644
index 00000000..7eecece5
Binary files /dev/null and b/ambersim/models/allegro_hand/assets/link_12.0_right.stl differ
diff --git a/ambersim/models/allegro_hand/assets/link_13.0.stl b/ambersim/models/allegro_hand/assets/link_13.0.stl
new file mode 100644
index 00000000..3ca4ec36
Binary files /dev/null and b/ambersim/models/allegro_hand/assets/link_13.0.stl differ
diff --git a/ambersim/models/allegro_hand/assets/link_14.0.stl b/ambersim/models/allegro_hand/assets/link_14.0.stl
new file mode 100644
index 00000000..73d6f695
Binary files /dev/null and b/ambersim/models/allegro_hand/assets/link_14.0.stl differ
diff --git a/ambersim/models/allegro_hand/assets/link_15.0.stl b/ambersim/models/allegro_hand/assets/link_15.0.stl
new file mode 100644
index 00000000..e16bff45
Binary files /dev/null and b/ambersim/models/allegro_hand/assets/link_15.0.stl differ
diff --git a/ambersim/models/allegro_hand/assets/link_15.0_tip.stl b/ambersim/models/allegro_hand/assets/link_15.0_tip.stl
new file mode 100644
index 00000000..9b1d8119
Binary files /dev/null and b/ambersim/models/allegro_hand/assets/link_15.0_tip.stl differ
diff --git a/ambersim/models/allegro_hand/assets/link_2.0.stl b/ambersim/models/allegro_hand/assets/link_2.0.stl
new file mode 100644
index 00000000..20a911a4
Binary files /dev/null and b/ambersim/models/allegro_hand/assets/link_2.0.stl differ
diff --git a/ambersim/models/allegro_hand/assets/link_3.0.stl b/ambersim/models/allegro_hand/assets/link_3.0.stl
new file mode 100644
index 00000000..a1fffee0
Binary files /dev/null and b/ambersim/models/allegro_hand/assets/link_3.0.stl differ
diff --git a/ambersim/models/allegro_hand/assets/link_3.0_tip.stl b/ambersim/models/allegro_hand/assets/link_3.0_tip.stl
new file mode 100644
index 00000000..b7d9be8a
Binary files /dev/null and b/ambersim/models/allegro_hand/assets/link_3.0_tip.stl differ
diff --git a/ambersim/models/allegro_hand/assets/link_4.0.stl b/ambersim/models/allegro_hand/assets/link_4.0.stl
new file mode 100644
index 00000000..ae75f704
Binary files /dev/null and b/ambersim/models/allegro_hand/assets/link_4.0.stl differ
diff --git a/ambersim/models/allegro_hand/cube.xml b/ambersim/models/allegro_hand/cube.xml
new file mode 100644
index 00000000..c857243d
--- /dev/null
+++ b/ambersim/models/allegro_hand/cube.xml
@@ -0,0 +1,29 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/ambersim/models/allegro_hand/left_hand.xml b/ambersim/models/allegro_hand/left_hand.xml
new file mode 100644
index 00000000..223d5f6b
--- /dev/null
+++ b/ambersim/models/allegro_hand/left_hand.xml
@@ -0,0 +1,259 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/ambersim/models/allegro_hand/right_hand.xml b/ambersim/models/allegro_hand/right_hand.xml
new file mode 100644
index 00000000..cc6f4dd0
--- /dev/null
+++ b/ambersim/models/allegro_hand/right_hand.xml
@@ -0,0 +1,260 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/ambersim/models/allegro_hand/scene_left.xml b/ambersim/models/allegro_hand/scene_left.xml
new file mode 100644
index 00000000..c1ad48f9
--- /dev/null
+++ b/ambersim/models/allegro_hand/scene_left.xml
@@ -0,0 +1,25 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/ambersim/models/allegro_hand/scene_right.xml b/ambersim/models/allegro_hand/scene_right.xml
new file mode 100644
index 00000000..b1f79cf6
--- /dev/null
+++ b/ambersim/models/allegro_hand/scene_right.xml
@@ -0,0 +1,25 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/ambersim/sim/simulation.py b/ambersim/sim/simulation.py
new file mode 100644
index 00000000..d3da7dae
--- /dev/null
+++ b/ambersim/sim/simulation.py
@@ -0,0 +1,214 @@
+import jax
+import jax.numpy as jnp
+from jax import jit, lax
+from mujoco import mjx
+
+from ambersim.control.predictive_control import (
+ VanillaPredictiveSamplingController,
+ VanillaPredictiveSamplingControllerParams,
+)
+
+"""This file contains simulation utils for various controllers."""
+
+
def simulate_predictive_sampling_controller(
    model: mjx.Model,
    controller: VanillaPredictiveSamplingController,
    x0: jax.Array,
    num_steps: int,
    N: int,
    seed: int = 0,
    physics_steps_per_control_step: int = 1,
) -> tuple:
    """Simulates a closed-loop system.

    Args:
        model: The "real" model. Properties like the timestep are set in here.
        controller: The controller.
        x0 (shape=(nq + nv,)): The initial state.
        num_steps: The number of physics steps to simulate for.
        N: The planning horizon (number of control knots in the sampler's guess).
        seed: The random seed.
        physics_steps_per_control_step: The number of physics steps to take per control step.

    Returns:
        data: The internal data of the model after the simulation.
        xs_hist (shape=(num_steps + 1, nq + nv)): The state trajectory.
        us_hist (shape=(num_steps, nu)): The applied control trajectory.
        times: Wall-clock duration of each control computation (entry 0 includes JIT compilation).
    """
    import time  # local import: only used for debug timing of the control loop

    # initial setup
    print("Initial setup...")
    dt = model.opt.timestep

    data = mjx.make_data(model)
    data = data.replace(qpos=x0[: model.nq], qvel=x0[model.nq :])  # setting the initial state.
    data = mjx.forward(model, data)  # setting other internal states like acceleration without integrating

    xs_hist = jnp.zeros((num_steps + 1, model.nq + model.nv))
    xs_hist = xs_hist.at[0, :].set(x0)

    us_hist = jnp.zeros((num_steps, model.nu))
    key = jax.random.PRNGKey(seed)

    jit_compute = jit(
        lambda key, x_meas, us_guess: controller.compute_with_us_star(
            VanillaPredictiveSamplingControllerParams(key=key, x=x_meas, guess=us_guess)
        )
    )
    jit_step = jit(lambda data, u: mjx.step(model, data.replace(ctrl=u)))

    # zero warm start with the requested horizon (shape=(N, nu), as the sampler expects)
    us_guess = jnp.zeros((N, model.nu))

    times = []
    for i in range(num_steps // physics_steps_per_control_step):
        start = time.time()
        # computing the control input
        x_meas = jnp.concatenate((data.qpos, data.qvel))
        if i == 0:
            print("Compiling the compute function...")
        u, us_guess = jit_compute(key, x_meas, us_guess)
        if i == 0:
            print("Compiled the compute function!")
        end = time.time()
        times.append(end - start)

        # simulating the system forward with a zero-order hold on u
        for j in range(physics_steps_per_control_step):
            print(f"t: {(i * physics_steps_per_control_step + j) * dt:.2f} | u: {u}")
            if i == 0 and j == 0:
                print("Compiling the step function...")
            # NOTE(review): the jitted step previously segfaulted here while the non-jitted
            # version did not — keep an eye on this when upgrading mjx.
            data = jit_step(data, u)
            if i == 0 and j == 0:
                print("Compiled the step function!")
            us_hist = us_hist.at[i, :].set(u)
            xs_hist = xs_hist.at[i + 1, :].set(jnp.concatenate((data.qpos, data.qvel)))

        # fold the key so each control step samples fresh noise
        key = jax.random.split(key)[0]

    return data, xs_hist, us_hist, times  # [DEBUG] iteration times returned for profiling
+
+
if __name__ == "__main__":
    # TODO(ahl): move this into a proper script
    import time

    import mujoco
    import mujoco.viewer
    import numpy as np
    from jax.experimental import mesh_utils
    from jax.experimental.shard_map import shard_map
    from jax.sharding import Mesh, NamedSharding
    from jax.sharding import PartitionSpec as P
    from mujoco.mjx._src.types import DisableBit

    from ambersim.trajopt.cost import StaticGoalQuadraticCost
    from ambersim.trajopt.shooting import VanillaPredictiveSampler
    from ambersim.utils.io_utils import load_mj_model_from_file, mj_to_mjx_model_and_data

    # sharding logic
    devices = np.array(jax.devices())
    mesh = Mesh(devices, ("i",))

    # loading model + defining the predictive controller
    # mj_model = load_mj_model_from_file("models/barrett_hand/bh280.xml")
    mj_model = load_mj_model_from_file("models/allegro_hand/right_hand.xml")
    mj_model.opt.timestep = 0.001  # dt, "framerate of reality"

    # the controller plans with a coarser, cheaper copy of the simulated model
    ctrl_model, _ = mj_to_mjx_model_and_data(mj_model)
    ctrl_model = ctrl_model.replace(
        opt=ctrl_model.opt.replace(
            timestep=0.015,  # dt for each step in the controller's internal model
            iterations=1,  # number of Newton steps to take during solve
            ls_iterations=1,  # number of line search iterations along step direction
            integrator=0,  # Euler semi-implicit integration
            solver=2,  # Newton solver
            # disableflags=DisableBit.CONTACT,  # disable contact for this test
        )
    )
    # quadratic cost toward a fixed hand pose; magic goal vector is Allegro-specific (16 joints)
    cost_function = StaticGoalQuadraticCost(
        Q=jnp.diag(jnp.array([10.0] * ctrl_model.nq + [0.01] * ctrl_model.nv)),
        Qf=jnp.diag(jnp.array([10.0] * ctrl_model.nq + [0.01] * ctrl_model.nv)),
        R=0.001 * jnp.eye(ctrl_model.nu),
        xg=jnp.concatenate(
            (jnp.array([0.0, 0.5, 0.5, 0.5, 0.0, 0.5, 0.5, 0.5, 0.0, 0.5, 0.5, 0.5, 1.0, 0.8, 0.5, 0.5]), jnp.zeros(16))
        ),
    )
    nsamples = 100
    stdev = 0.1
    ps = VanillaPredictiveSampler(model=ctrl_model, cost_function=cost_function, nsamples=nsamples, stdev=stdev)

    N = 10  # planning horizon (number of control knots)
    controller = VanillaPredictiveSamplingController(trajectory_optimizer=ps, model=ctrl_model)
    # jit_compute = jit(
    #     shard_map(
    #         lambda key, x_meas, us_guess: controller.compute_with_us_star(
    #             VanillaPredictiveSamplingControllerParams(key=key, x=x_meas, guess=us_guess)
    #         ),
    #         mesh=mesh,
    #         in_specs=(P(), P(), P('i', None)),
    #         out_specs=(P(), P('i', None)),
    #         check_rep=False,
    #     )
    # )  # [DEBUG]
    jit_compute = jit(
        lambda key, x_meas, us_guess: controller.compute_with_us_star(
            VanillaPredictiveSamplingControllerParams(key=key, x=x_meas, guess=us_guess)
        )
    )
    print("Controller created! Simulating...")

    # simulating forward from a random initial state
    key = jax.random.PRNGKey(1234)
    # x0 = jnp.array([0.14021026, 0.04142465, 0.99314054, -0.00491294, -0.00487147, 0.81353587, -0.00491293, -0.00487143, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
    x0 = 2.0 * (jax.random.uniform(key, shape=(ctrl_model.nq + ctrl_model.nv,)) - 0.5)
    # T = 3.0
    # num_steps = int(T / model.opt.timestep)
    # data, xs, us, times = simulate_predictive_sampling_controller(
    #     model, controller, x0, num_steps, N, seed=1337, physics_steps_per_control_step=2
    # )  # just testing warm starting the sim
    # print(np.mean(times[1:]))

    # post-hoc visualization
    mj_data = mujoco.MjData(mj_model)
    mj_data.qpos[:] = x0[: mj_model.nq]
    mj_data.qvel[:] = x0[mj_model.nq :]
    dt = mj_model.opt.timestep
    us_guess = jnp.zeros((N, mj_model.nu))
    # us_guess = jax.device_put(jnp.zeros((N, mj_model.nu)), NamedSharding(mesh, P('i', None)))  # [DEBUG]
    compiled = False
    num_phys_steps_per_control_step = 15
    with mujoco.viewer.launch_passive(mj_model, mj_data) as viewer:
        while viewer.is_running():
            # NOTE(review): this inner loop never rechecks viewer.is_running(), so closing the
            # viewer does not exit cleanly — confirm whether this is intentional debug behavior.
            while True:
                if not compiled:
                    print("compiling jit_compute...")
                    compiled = True
                start = time.time()
                x_meas = jnp.concatenate((mj_data.qpos, mj_data.qvel))
                u, _us_guess = jit_compute(key, x_meas, us_guess)
                # warm start: shift the optimal sequence left by one step and pad with zeros
                us_guess = np.concatenate((_us_guess[1:, :], np.zeros((1, mj_model.nu))))
                mj_data.ctrl[:] = u
                print(time.time() - start)
                for _ in range(num_phys_steps_per_control_step):
                    start = time.time()
                    mujoco.mj_step(mj_model, mj_data)
                    viewer.sync()
                    elapsed = time.time() - start
                    # crude real-time pacing: sleep roughly one physics timestep if we ran fast
                    if elapsed < mj_model.opt.timestep:
                        time.sleep(mj_model.opt.timestep)
                key = jax.random.split(key)[0]  # fold the key so each control step uses fresh noise

    # (commented-out open-loop replay kept for reference)
    # start = time.time()
    # ui = us[i, :]
    # mj_data.ctrl[:] = ui
    # if i <= int(T / dt):
    #     mujoco.mj_step(mj_model, mj_data)
    #     # print(f"t: {i * dt:.2f} | qpos: {mj_data.qpos}")
    #     # print(f"t: {i * dt:.2f} | cost: {mj_data.qpos @ mj_data.qpos + mj_data.qvel @ mj_data.qvel:.2f}")
    #     viewer.sync()
    #     i += 1
    # elapsed = time.time() - start
    # if elapsed < mj_model.opt.timestep:
    #     time.sleep(mj_model.opt.timestep)

    breakpoint()  # NOTE(review): debug leftover — drops into pdb at script exit; remove before merging
diff --git a/ambersim/trajopt/base.py b/ambersim/trajopt/base.py
index 2e50975e..f0d17c41 100644
--- a/ambersim/trajopt/base.py
+++ b/ambersim/trajopt/base.py
@@ -61,7 +61,7 @@ class TrajectoryOptimizer:
raising a NotImplementedError.
"""
- def optimize(self, params: TrajectoryOptimizerParams) -> Tuple[jax.Array, jax.Array, jax.Array]:
+ def optimize(self, to_params: TrajectoryOptimizerParams) -> Tuple[jax.Array, jax.Array, jax.Array]:
"""Optimizes a trajectory.
The shapes of the outputs include (?) because we may choose to return non-zero-order-hold parameterizations of
@@ -69,7 +69,7 @@ def optimize(self, params: TrajectoryOptimizerParams) -> Tuple[jax.Array, jax.Ar
control inputs over the trajectory as is done in the gradient-based methods of MJPC).
Args:
- params: The parameters of the trajectory optimizer.
+ to_params: The parameters of the trajectory optimizer.
Returns:
xs_star (shape=(N + 1, nq + nv) or (?)): The optimized trajectory.
@@ -102,22 +102,22 @@ class CostFunction:
(4) histories of higher-order derivatives can be useful for updating their current estimates, e.g., BFGS.
"""
- def cost(self, xs: jax.Array, us: jax.Array, params: CostFunctionParams) -> Tuple[jax.Array, CostFunctionParams]:
+ def cost(self, xs: jax.Array, us: jax.Array, cf_params: CostFunctionParams) -> Tuple[jax.Array, CostFunctionParams]:
"""Computes the cost of a trajectory.
Args:
xs (shape=(N + 1, nq + nv)): The state trajectory.
us (shape=(N, nu)): The controls over the trajectory.
- params: The parameters of the cost function.
+ cf_params: The parameters of the cost function.
Returns:
val (shape=(,)): The cost of the trajectory.
- new_params: The updated parameters of the cost function.
+ new_cf_params: The updated parameters of the cost function.
"""
raise NotImplementedError
def grad(
- self, xs: jax.Array, us: jax.Array, params: CostFunctionParams
+ self, xs: jax.Array, us: jax.Array, cf_params: CostFunctionParams
) -> Tuple[jax.Array, jax.Array, CostFunctionParams, CostFunctionParams]:
"""Computes the gradient of the cost of a trajectory.
@@ -127,19 +127,19 @@ def grad(
Args:
xs (shape=(N + 1, nq + nv)): The state trajectory.
us (shape=(N, nu)): The controls over the trajectory.
- params: The parameters of the cost function.
+ cf_params: The parameters of the cost function.
Returns:
gcost_xs (shape=(N + 1, nq + nv): The gradient of the cost wrt xs.
gcost_us (shape=(N, nu)): The gradient of the cost wrt us.
- gcost_params: The gradient of the cost wrt params.
- new_params: The updated parameters of the cost function.
+ gcost_params: The gradient of the cost wrt cf_params.
+ new_cf_params: The updated parameters of the cost function.
"""
- _fn = lambda xs, us, params: self.cost(xs, us, params)[0] # only differentiate wrt the cost val
- return grad(_fn, argnums=(0, 1, 2))(xs, us, params) + (params,)
+ _fn = lambda xs, us, cf_params: self.cost(xs, us, cf_params)[0] # only differentiate wrt the cost val
+ return grad(_fn, argnums=(0, 1, 2))(xs, us, cf_params) + (cf_params,)
def hess(
- self, xs: jax.Array, us: jax.Array, params: CostFunctionParams
+ self, xs: jax.Array, us: jax.Array, cf_params: CostFunctionParams
) -> Tuple[
jax.Array, jax.Array, CostFunctionParams, jax.Array, CostFunctionParams, CostFunctionParams, CostFunctionParams
]:
@@ -153,20 +153,20 @@ def hess(
Args:
xs (shape=(N + 1, nq + nv)): The state trajectory.
us (shape=(N, nu)): The controls over the trajectory.
- params: The parameters of the cost function.
+ cf_params: The parameters of the cost function.
Returns:
Hcost_xsxs (shape=(N + 1, nq + nv, N + 1, nq + nv)): The Hessian of the cost wrt xs.
Hcost_xsus (shape=(N + 1, nq + nv, N, nu)): The Hessian of the cost wrt xs and us.
- Hcost_xsparams: The Hessian of the cost wrt xs and params.
+ Hcost_xsparams: The Hessian of the cost wrt xs and cf_params.
Hcost_usus (shape=(N, nu, N, nu)): The Hessian of the cost wrt us.
- Hcost_usparams: The Hessian of the cost wrt us and params.
- Hcost_paramsall: The Hessian of the cost wrt params and everything else.
- new_params: The updated parameters of the cost function.
+ Hcost_usparams: The Hessian of the cost wrt us and cf_params.
+ Hcost_paramsall: The Hessian of the cost wrt cf_params and everything else.
+ new_cf_params: The updated parameters of the cost function.
"""
- _fn = lambda xs, us, params: self.cost(xs, us, params)[0] # only differentiate wrt the cost val
- hessians = hessian(_fn, argnums=(0, 1, 2))(xs, us, params)
+ _fn = lambda xs, us, cf_params: self.cost(xs, us, cf_params)[0] # only differentiate wrt the cost val
+ hessians = hessian(_fn, argnums=(0, 1, 2))(xs, us, cf_params)
Hcost_xsxs, Hcost_xsus, Hcost_xsparams = hessians[0]
_, Hcost_usus, Hcost_usparams = hessians[1]
Hcost_paramsall = hessians[2]
- return Hcost_xsxs, Hcost_xsus, Hcost_xsparams, Hcost_usus, Hcost_usparams, Hcost_paramsall, params
+ return Hcost_xsxs, Hcost_xsus, Hcost_xsparams, Hcost_usus, Hcost_usparams, Hcost_paramsall, cf_params
diff --git a/ambersim/trajopt/cost.py b/ambersim/trajopt/cost.py
index fa560128..c50fb842 100644
--- a/ambersim/trajopt/cost.py
+++ b/ambersim/trajopt/cost.py
@@ -10,6 +10,7 @@
"""A collection of common cost functions."""
+@struct.dataclass
class StaticGoalQuadraticCost(CostFunction):
"""A quadratic cost function that penalizes the distance to a static goal.
@@ -19,19 +20,10 @@ class StaticGoalQuadraticCost(CostFunction):
dense.
"""
- def __init__(self, Q: jax.Array, Qf: jax.Array, R: jax.Array, xg: jax.Array) -> None:
- """Initializes a quadratic cost function.
-
- Args:
- Q (shape=(nx, nx)): The state cost matrix.
- Qf (shape=(nx, nx)): The final state cost matrix.
- R (shape=(nu, nu)): The control cost matrix.
- xg (shape=(nq,)): The goal state.
- """
- self.Q = Q
- self.Qf = Qf
- self.R = R
- self.xg = xg
+ Q: jax.Array # shape=(nx, nx) state cost matrix
+ Qf: jax.Array # shape=(nx, nx) final state cost matrix
+ R: jax.Array # shape=(nu, nu) control cost matrix
+ xg: jax.Array # shape=(nq + nv,) goal state
@staticmethod
def batch_quadform(bs: jax.Array, A: jax.Array) -> jax.Array:
diff --git a/ambersim/trajopt/shooting.py b/ambersim/trajopt/shooting.py
index 414917a1..7fe492e4 100644
--- a/ambersim/trajopt/shooting.py
+++ b/ambersim/trajopt/shooting.py
@@ -24,8 +24,8 @@ def shoot(m: mjx.Model, x0: jax.Array, us: jax.Array) -> jax.Array:
Args:
m: The model.
- x0: The initial state.
- us: The control inputs.
+ x0 (shape=(nx,)): The initial state.
+ us (shape=(N, nu)): The control inputs.
Returns:
xs (shape=(N + 1, nq + nv)): The state trajectory.
@@ -48,6 +48,44 @@ def scan_fn(d, u):
return xs
def shoot_pd(m: mjx.Model, x0: jax.Array, qgs: jax.Array, kp: float, kd: float) -> Tuple[jax.Array, jax.Array]:
    """Utility function that shoots a model forward given a sequence of goal states + PD gains.

    Args:
        m: The model.
        x0 (shape=(nx,)): The initial state.
        qgs (shape=(N, nq)): The goal trajectory.
        kp: The proportional gain.
        kd: The derivative gain.

    Returns:
        xs (shape=(N + 1, nq + nv)): The state trajectory.
        us (shape=(N, nu)): The control inputs.
    """
    # initializing the data
    d = mjx.make_data(m)
    d = d.replace(qpos=x0[: m.nq], qvel=x0[m.nq :])  # setting the initial state.
    d = mjx.forward(m, d)  # setting other internal states like acceleration without integrating

    def scan_fn(d, qg):
        """Integrates the model forward one step, tracking the goal configuration qg with a PD law."""
        # PD control law. NOTE(review): this assumes a fully-actuated system with nu == nq,
        # since u has shape (nq,) — confirm for each model this is used with.
        u = -kp * (d.qpos - qg) - kd * d.qvel

        d = d.replace(ctrl=u)
        d = step(m, d)
        x = jnp.concatenate((d.qpos, d.qvel))  # (nq + nv,)
        xu = jnp.concatenate((x, u))
        return d, xu

    # scan over the goal configurations to get the trajectory.
    _, _xus = lax.scan(scan_fn, d, qgs, length=qgs.shape[0])
    _xs = _xus[:, : m.nq + m.nv]
    xs = jnp.concatenate((x0[None, :], _xs), axis=0)  # (N + 1, nq + nv)
    us = _xus[:, m.nq + m.nv :]
    return xs, us
+
+
# ################ #
# SHOOTING METHODS #
# ################ #
@@ -61,27 +99,29 @@ class ShootingParams(TrajectoryOptimizerParams):
# inputs into the algorithm
x0: jax.Array # shape=(nq + nv,) or (?)
- us_guess: jax.Array # shape=(N, nu) or (?)
+ guess: jax.Array # shape=(N, n?) or (?)
@property
def N(self) -> int:
"""The number of time steps.
- By default, we assume us_guess represents the ZOH control inputs. However, in the case that it is actually an
+ By default, we assume guess represents the ZOH control inputs. However, in the case that it is actually an
alternate parameterization, we may need to compute N some other way, which requires overwriting this method.
+
+ In the case of PD controllers, for instance, the guess could actually be a trajectory of target positions.
"""
- return self.us_guess.shape[0]
+ return self.guess.shape[0]
@struct.dataclass
class ShootingAlgorithm(TrajectoryOptimizer):
"""A trajectory optimization algorithm based on shooting methods."""
- def optimize(self, params: ShootingParams) -> Tuple[jax.Array, jax.Array]:
+ def optimize(self, to_params: ShootingParams) -> Tuple[jax.Array, jax.Array]:
"""Optimizes a trajectory using a shooting method.
Args:
- params: The parameters of the trajectory optimizer.
+ to_params: The parameters of the trajectory optimizer.
Returns:
xs_star (shape=(N + 1, nq) or (?)): The optimized trajectory.
@@ -90,18 +130,46 @@ def optimize(self, params: ShootingParams) -> Tuple[jax.Array, jax.Array]:
raise NotImplementedError
-# predictive sampling API
+# ########################### #
+# PREDICTIVE SAMPLING METHODS #
+# ########################### #
+
+# generic API
@struct.dataclass
class PredictiveSamplerParams(ShootingParams):
    """Parameters for generic predictive sampling methods.

    Inherits x0 and guess from ShootingParams; adds the RNG key driving the sampling.
    """

    key: jax.Array  # random key for sampling
@struct.dataclass
-class VanillaPredictiveSampler(ShootingAlgorithm):
+class PredictiveSampler(ShootingAlgorithm):
+ """A generic predictive sampler object."""
+
+ def optimize(self, to_params: PredictiveSamplerParams) -> Tuple[jax.Array, jax.Array]:
+ """Optimizes a trajectory using a predictive sampler.
+
+ Args:
+ to_params: The parameters of the trajectory optimizer.
+
+ Returns:
+ xs (shape=(N + 1, nq + nv) or (?)): The optimized trajectory.
+ us (shape=(N, nu) or (?)): The optimized controls.
+ """
+
+
+# vanilla
+
+
+@struct.dataclass
+class VanillaPredictiveSamplerParams(PredictiveSamplerParams):
+ """Parameters for the vanilla predictive sampling method."""
+
+
+@struct.dataclass
+class VanillaPredictiveSampler(PredictiveSampler):
"""A vanilla predictive sampler object.
The following choices are made:
@@ -116,25 +184,25 @@ class VanillaPredictiveSampler(ShootingAlgorithm):
nsamples: int = struct.field(pytree_node=False)
stdev: float = struct.field(pytree_node=False) # noise scale, parameters theta_new ~ N(theta, (stdev ** 2) * I)
- def optimize(self, params: VanillaPredictiveSamplerParams) -> Tuple[jax.Array, jax.Array]:
+ def optimize(self, to_params: VanillaPredictiveSamplerParams) -> Tuple[jax.Array, jax.Array]:
"""Optimizes a trajectory using a vanilla predictive sampler.
Args:
- params: The parameters of the trajectory optimizer.
+ to_params: The parameters of the trajectory optimizer.
Returns:
- xs (shape=(N + 1, nq + nv)): The optimized trajectory.
- us (shape=(N, nu)): The optimized controls.
+ xs_star (shape=(N + 1, nq + nv)): The optimized trajectory.
+ us_star (shape=(N, nu)): The optimized controls.
"""
# unpack the params
m = self.model
nsamples = self.nsamples
stdev = self.stdev
- x0 = params.x0
- us_guess = params.us_guess
- N = params.N
- key = params.key
+ x0 = to_params.x0
+ us_guess = to_params.guess
+ N = to_params.N
+ key = to_params.key
# sample over the control inputs - the first sample is the guess, since it's possible that it's the best one
noise = jnp.concatenate(
@@ -155,3 +223,74 @@ def optimize(self, params: VanillaPredictiveSamplerParams) -> Tuple[jax.Array, j
xs_star = lax.dynamic_slice(xs_samples, (best_idx, 0, 0), (1, N + 1, m.nq + m.nv))[0] # (N + 1, nq + nv)
us_star = lax.dynamic_slice(us_samples, (best_idx, 0, 0), (1, N, m.nu))[0] # (N, nu)
return xs_star, us_star
+
+
+# PD predictive sampling
+
+
+@struct.dataclass
+class PDPredictiveSamplerParams(PredictiveSamplerParams):
+ """Parameters for the PD controller predictive sampler."""
+
+ kp: float # proportional gain
+ kd: float # derivative gain
+
+
+@struct.dataclass
+class PDPredictiveSampler(PredictiveSampler):
+ """A PD predictive sampler object.
+
+ The following choices are made:
+ (1) the model parameters are fixed, and are therefore a field of this dataclass;
+ (2) the control sequence is parameterized as a ZOH sequence instead of a spline;
+ (3) the guess supplied is actually a trajectory of positions and the control signal is computed via PD control;
+ (4) the cost function is quadratic in the states and controls.
+ """
+
+ model: mjx.Model
+ cost_function: CostFunction
+ nsamples: int = struct.field(pytree_node=False)
+ stdev: float = struct.field(pytree_node=False) # noise scale, parameters theta_new ~ N(theta, (stdev ** 2) * I)
+
+ def optimize(self, to_params: PDPredictiveSamplerParams) -> Tuple[jax.Array, jax.Array]:
+ """Optimizes a trajectory using a PD predictive sampler.
+
+ Args:
+ to_params: The parameters of the trajectory optimizer.
+
+ Returns:
+ xs_star (shape=(N + 1, nq + nv)): The optimized trajectory.
+ us_star (shape=(N, nu)): The optimized controls.
+ """
+ # unpack the params
+ m = self.model
+ nsamples = self.nsamples
+ stdev = self.stdev
+
+ x0 = to_params.x0
+ qs_guess = to_params.guess
+ N = to_params.N
+ key = to_params.key
+
+ kp = to_params.kp
+ kd = to_params.kd
+
+ # sample random position trajectories
+ noise = jnp.concatenate(
+ (jnp.zeros((1, N, m.nq)), jax.random.normal(key, shape=(nsamples - 1, N, m.nq)) * stdev), axis=0
+ )
+ _qgs_samples = qs_guess + noise
+
+ # clamping the samples to their position limits
+ limits = m.jnt_range
+ clip_fn = partial(jnp.clip, a_min=limits[:, 0], a_max=limits[:, 1]) # clipping function with limits already set
+ qgs_samples = vmap(vmap(clip_fn))(_qgs_samples) # apply limits only to the last dim, need a nested vmap
+
+ # predict many samples, evaluate them, and return the best trajectory tuple
+ # vmap over the input data and the control trajectories
+ xs_samples, us_samples = vmap(shoot_pd, in_axes=(None, None, 0, None, None))(m, x0, qgs_samples, kp, kd)
+ costs, _ = vmap(self.cost_function.cost, in_axes=(0, 0, None))(xs_samples, us_samples, None) # (nsamples,)
+ best_idx = jnp.argmin(costs)
+ xs_star = lax.dynamic_slice(xs_samples, (best_idx, 0, 0), (1, N + 1, m.nq + m.nv))[0] # (N + 1, nq + nv)
+ us_star = lax.dynamic_slice(us_samples, (best_idx, 0, 0), (1, N, m.nu))[0] # (N, nu)
+ return xs_star, us_star
diff --git a/examples/control/algr_palmup.py b/examples/control/algr_palmup.py
new file mode 100644
index 00000000..89d6a47c
--- /dev/null
+++ b/examples/control/algr_palmup.py
@@ -0,0 +1,147 @@
+import time
+
+import jax
+import jax.numpy as jnp
+import mujoco
+import mujoco.viewer
+import numpy as np
+from jax import jit
+
+# from ambersim.control.predictive_control import PDPredictiveSamplingController, PDPredictiveSamplingControllerParams
+from ambersim.control.predictive_control import (
+ VanillaPredictiveSamplingController,
+ VanillaPredictiveSamplingControllerParams,
+)
+from ambersim.trajopt.cost import StaticGoalQuadraticCost
+
+# from ambersim.trajopt.shooting import PDPredictiveSampler
+from ambersim.trajopt.shooting import VanillaPredictiveSampler
+from ambersim.utils.io_utils import load_mj_model_from_file, mj_to_mjx_model_and_data
+
+"""Important info:
+* the cube's states come after the hand's states
+* the allegro hand is floating but fixed in space
+
+TODOs:
+* implement a check for the cube being stuck
+* add some reasonable cost terms for the palm up task (see the mjpc code + talk to vince)
+* add support for the PD controller, but we have to be more clever about underactuation
+"""
+
+# ##################### #
+# CUBE TARGET POSITIONS #
+# ##################### #
+
+# position
+POS = jnp.array([0.05, 0.0, 0.05]) # move the cube over the fingers
+Z_FAIL = -0.05 # floor at -0.1
+
+# rotations
+IDENTITY = jnp.array([1.0, 0.0, 0.0, 0.0])
+X90 = jnp.array([0.7071068, 0.7071068, 0.0, 0.0])
+X180 = jnp.array([0.0, 1.0, 0.0, 0.0])
+X270 = jnp.array([-0.7071068, 0.7071068, 0.0, 0.0])
+Y90 = jnp.array([0.7071068, 0.0, 0.7071068, 0.0])
+Y180 = jnp.array([0.0, 0.0, 1.0, 0.0])
+Y270 = jnp.array([-0.7071068, 0.0, 0.7071068, 0.0])
+Z90 = jnp.array([0.7071068, 0.0, 0.0, 0.7071068])
+Z180 = jnp.array([0.0, 0.0, 0.0, 1.0])
+Z270 = jnp.array([-0.7071068, 0.0, 0.0, 0.7071068])
+
+# ########## #
+# PARAMETERS #
+# ########## #
+
+dt = 0.001 # rate of reality
+num_phys_steps_per_control_step = 15 # dt * npspcs = rate of control
+
+w_pos = 10.0
+w_vel = 0.001
+w_ctrl = 0.001
+q_goal_algr = jnp.zeros(16)
+q_goal_cube = jnp.concatenate((POS, X90))
+q_goal = jnp.concatenate((q_goal_algr, q_goal_cube))
+
+nsamples = 100 # number of samples to draw for predictive sampling
+stdev = 0.3 # standard deviation of control parameters to draw
+N = 5 # number of time steps to predict
+
+key = jax.random.PRNGKey(1234)
+
+# ########## #
+# SIMULATION #
+# ########## #
+
+# loading model
+mj_model = load_mj_model_from_file("models/allegro_hand/scene_right.xml")
+mj_model.opt.timestep = dt
+
+# defining the predictive controller
+ctrl_model, _ = mj_to_mjx_model_and_data(mj_model)
+ctrl_model = ctrl_model.replace(
+ opt=ctrl_model.opt.replace(
+ timestep=num_phys_steps_per_control_step * dt, # dt for each step in the controller's internal model
+ iterations=1, # number of Newton steps to take during solve
+ ls_iterations=1, # number of line search iterations along step direction
+ integrator=0, # Euler semi-implicit integration
+ solver=2, # Newton solver
+ )
+)
+cost_function = StaticGoalQuadraticCost(
+ Q=jnp.diag(jnp.array([w_pos] * ctrl_model.nq + [w_vel] * ctrl_model.nv)),
+ Qf=jnp.diag(jnp.array([w_pos] * ctrl_model.nq + [w_vel] * ctrl_model.nv)),
+ R=w_ctrl * jnp.eye(ctrl_model.nu),
+ xg=jnp.concatenate((q_goal, jnp.zeros(ctrl_model.nv))),
+)
+ps = VanillaPredictiveSampler(model=ctrl_model, cost_function=cost_function, nsamples=nsamples, stdev=stdev)
+controller = VanillaPredictiveSamplingController(trajectory_optimizer=ps, model=ctrl_model)
+jit_compute = jit(
+ lambda key, x_meas, us_guess: controller.compute_with_us_star(
+ VanillaPredictiveSamplingControllerParams(key=key, x=x_meas, guess=us_guess)
+ )
+)
+print("Controller created! Simulating...")
+
+# simulating forward
+q0_algr = jnp.zeros(16)
+q0_cube = jnp.concatenate((POS, IDENTITY))
+v0 = jnp.zeros(mj_model.nv)
+x0 = jnp.concatenate((q0_algr, q0_cube, v0))
+
+mj_data = mujoco.MjData(mj_model)
+mj_data.qpos[:] = x0[: mj_model.nq]
+mj_data.qvel[:] = x0[mj_model.nq :]
+
+us_guess = jnp.zeros((N, mj_model.nu))
+
+print("Compiling jit_compute...")
+jit_compute(key, x0, us_guess)
+print("Compiled! Beginning simulation...")
+
+with mujoco.viewer.launch_passive(mj_model, mj_data) as viewer:
+ while viewer.is_running():
+ while True:
+ start = time.time()
+ x_meas = jnp.concatenate((mj_data.qpos, mj_data.qvel))
+ u, _us_guess = jit_compute(key, x_meas, us_guess)
+ us_guess = np.concatenate((_us_guess[1:, :], np.zeros((1, mj_model.nu))))
+ mj_data.ctrl[:] = u
+ print(f"Controller delay: {time.time() - start}")
+ for _ in range(num_phys_steps_per_control_step):
+ start = time.time()
+ mujoco.mj_step(mj_model, mj_data)
+ viewer.sync()
+ elapsed = time.time() - start
+ if elapsed < mj_model.opt.timestep:
+ time.sleep(mj_model.opt.timestep - elapsed)
+
+ # check whether to reset the cube
+ if mj_data.qpos[-5] <= Z_FAIL:
+ print("*** Cube fell! Resetting... *** ")
+ mj_data.qpos[:] = x0[: mj_model.nq]
+ mj_data.qvel[:] = x0[mj_model.nq :]
+ us_guess = jnp.zeros((N, mj_model.nu))
+ viewer.sync()
+ time.sleep(1.0)
+
+ key = jax.random.split(key)[0]
diff --git a/examples/control/pd_predictive_sampling.py b/examples/control/pd_predictive_sampling.py
new file mode 100644
index 00000000..c4b18b99
--- /dev/null
+++ b/examples/control/pd_predictive_sampling.py
@@ -0,0 +1,106 @@
+import time
+
+import jax
+import jax.numpy as jnp
+import mujoco
+import mujoco.viewer
+import numpy as np
+from jax import jit
+
+from ambersim.control.predictive_control import PDPredictiveSamplingController, PDPredictiveSamplingControllerParams
+from ambersim.trajopt.cost import StaticGoalQuadraticCost
+from ambersim.trajopt.shooting import PDPredictiveSampler
+from ambersim.utils.io_utils import load_mj_model_from_file, mj_to_mjx_model_and_data
+
+"""
+By far the most sensitive parameters for this controller parameterization are the rollout horizon N and the proportional gain.
+"""
+
+# ########## #
+# PARAMETERS #
+# ########## #
+
+dt = 0.001 # rate of reality
+num_phys_steps_per_control_step = 15 # dt * npspcs = rate of control
+
+w_pos = 10.0
+w_vel = 0.01
+w_ctrl = 0.001
+q_goal = jnp.array([0.0, 0.5, 0.5, 0.5, 0.0, 0.5, 0.5, 0.5, 0.0, 0.5, 0.5, 0.5, 1.0, 0.8, 0.5, 0.5])
+
+nsamples = 100 # number of samples to draw for predictive sampling
+stdev = 0.1 # standard deviation of control parameters to draw
+N = 8 # number of time steps to predict
+
+key = jax.random.PRNGKey(1234)
+
+# ########## #
+# SIMULATION #
+# ########## #
+
+# loading model
+mj_model = load_mj_model_from_file("models/allegro_hand/right_hand.xml")
+mj_model.opt.timestep = dt
+
+# defining the predictive controller
+ctrl_model, _ = mj_to_mjx_model_and_data(mj_model)
+ctrl_model = ctrl_model.replace(
+ opt=ctrl_model.opt.replace(
+ timestep=num_phys_steps_per_control_step * dt, # dt for each step in the controller's internal model
+ iterations=1, # number of Newton steps to take during solve
+ ls_iterations=1, # number of line search iterations along step direction
+ integrator=0, # Euler semi-implicit integration
+ solver=2, # Newton solver
+ )
+)
+cost_function = StaticGoalQuadraticCost(
+ Q=jnp.diag(jnp.array([w_pos] * ctrl_model.nq + [w_vel] * ctrl_model.nv)),
+ Qf=jnp.diag(jnp.array([w_pos] * ctrl_model.nq + [w_vel] * ctrl_model.nv)),
+ R=w_ctrl * jnp.eye(ctrl_model.nu),
+ xg=jnp.concatenate((q_goal, jnp.zeros(ctrl_model.nv))),
+)
+ps = PDPredictiveSampler(model=ctrl_model, cost_function=cost_function, nsamples=nsamples, stdev=stdev)
+
+kp = 5.0
+kd = 0.1
+controller = PDPredictiveSamplingController(trajectory_optimizer=ps, model=ctrl_model)
+jit_compute = jit(
+ lambda key, x_meas, qgs_guess: controller.compute_with_qs_star(
+ PDPredictiveSamplingControllerParams(key=key, x=x_meas, guess=qgs_guess, kp=kp, kd=kd)
+ )
+)
+print("Controller created! Simulating...")
+
+# simulating forward
+x0 = 2.0 * (jax.random.uniform(key, shape=(ctrl_model.nq + ctrl_model.nv,)) - 0.5) # random initial state
+
+mj_data = mujoco.MjData(mj_model)
+mj_data.qpos[:] = x0[: mj_model.nq]
+mj_data.qvel[:] = x0[mj_model.nq :]
+
+qgs_guess = jnp.tile(x0[: mj_model.nq], (N, 1))
+
+print("Compiling jit_compute...")
+jit_compute(key, x0, qgs_guess)
+print("Compiled! Beginning simulation...")
+
+with mujoco.viewer.launch_passive(mj_model, mj_data) as viewer:
+ while viewer.is_running():
+ while True:
+ start = time.time()
+ x_meas = jnp.concatenate((mj_data.qpos, mj_data.qvel))
+ qg, _qgs_guess = jit_compute(key, x_meas, qgs_guess)
+ qg.block_until_ready()
+ qgs_guess = _qgs_guess[1:, :]
+ print(f"Controller delay: {time.time() - start}")
+ for _ in range(num_phys_steps_per_control_step):
+ start = time.time()
+ u = -kp * (mj_data.qpos - qg) - kd * mj_data.qvel
+ u = np.clip(u, mj_model.actuator_ctrlrange[:, 0], mj_model.actuator_ctrlrange[:, 1])
+ mj_data.ctrl[:] = u
+ mujoco.mj_step(mj_model, mj_data)
+ viewer.sync()
+ elapsed = time.time() - start
+ if elapsed < mj_model.opt.timestep:
+ time.sleep(mj_model.opt.timestep - elapsed)
+ key = jax.random.split(key)[0]
diff --git a/examples/control/vanilla_predictive_sampling.py b/examples/control/vanilla_predictive_sampling.py
new file mode 100644
index 00000000..b1e8dc31
--- /dev/null
+++ b/examples/control/vanilla_predictive_sampling.py
@@ -0,0 +1,99 @@
+import time
+
+import jax
+import jax.numpy as jnp
+import mujoco
+import mujoco.viewer
+import numpy as np
+from jax import jit
+
+from ambersim.control.predictive_control import (
+ VanillaPredictiveSamplingController,
+ VanillaPredictiveSamplingControllerParams,
+)
+from ambersim.trajopt.cost import StaticGoalQuadraticCost
+from ambersim.trajopt.shooting import VanillaPredictiveSampler
+from ambersim.utils.io_utils import load_mj_model_from_file, mj_to_mjx_model_and_data
+
+# ########## #
+# PARAMETERS #
+# ########## #
+
+dt = 0.001 # rate of reality
+num_phys_steps_per_control_step = 15 # dt * npspcs = rate of control
+
+w_pos = 10.0
+w_vel = 0.01
+w_ctrl = 0.001
+q_goal = jnp.array([0.0, 0.5, 0.5, 0.5, 0.0, 0.5, 0.5, 0.5, 0.0, 0.5, 0.5, 0.5, 1.0, 0.8, 0.5, 0.5])
+
+nsamples = 100 # number of samples to draw for predictive sampling
+stdev = 0.1 # standard deviation of control parameters to draw
+N = 10 # number of time steps to predict
+
+key = jax.random.PRNGKey(1234)
+
+# ########## #
+# SIMULATION #
+# ########## #
+
+# loading model
+mj_model = load_mj_model_from_file("models/allegro_hand/right_hand.xml")
+mj_model.opt.timestep = dt
+
+# defining the predictive controller
+ctrl_model, _ = mj_to_mjx_model_and_data(mj_model)
+ctrl_model = ctrl_model.replace(
+ opt=ctrl_model.opt.replace(
+ timestep=num_phys_steps_per_control_step * dt, # dt for each step in the controller's internal model
+ iterations=1, # number of Newton steps to take during solve
+ ls_iterations=1, # number of line search iterations along step direction
+ integrator=0, # Euler semi-implicit integration
+ solver=2, # Newton solver
+ )
+)
+cost_function = StaticGoalQuadraticCost(
+ Q=jnp.diag(jnp.array([w_pos] * ctrl_model.nq + [w_vel] * ctrl_model.nv)),
+ Qf=jnp.diag(jnp.array([w_pos] * ctrl_model.nq + [w_vel] * ctrl_model.nv)),
+ R=w_ctrl * jnp.eye(ctrl_model.nu),
+ xg=jnp.concatenate((q_goal, jnp.zeros(ctrl_model.nv))),
+)
+ps = VanillaPredictiveSampler(model=ctrl_model, cost_function=cost_function, nsamples=nsamples, stdev=stdev)
+controller = VanillaPredictiveSamplingController(trajectory_optimizer=ps, model=ctrl_model)
+jit_compute = jit(
+ lambda key, x_meas, us_guess: controller.compute_with_us_star(
+ VanillaPredictiveSamplingControllerParams(key=key, x=x_meas, guess=us_guess)
+ )
+)
+print("Controller created! Simulating...")
+
+# simulating forward
+x0 = 2.0 * (jax.random.uniform(key, shape=(ctrl_model.nq + ctrl_model.nv,)) - 0.5) # random initial state
+
+mj_data = mujoco.MjData(mj_model)
+mj_data.qpos[:] = x0[: mj_model.nq]
+mj_data.qvel[:] = x0[mj_model.nq :]
+
+us_guess = jnp.zeros((N, mj_model.nu))
+
+print("Compiling jit_compute...")
+jit_compute(key, x0, us_guess)
+print("Compiled! Beginning simulation...")
+
+with mujoco.viewer.launch_passive(mj_model, mj_data) as viewer:
+ while viewer.is_running():
+ while True:
+ start = time.time()
+ x_meas = jnp.concatenate((mj_data.qpos, mj_data.qvel))
+ u, _us_guess = jit_compute(key, x_meas, us_guess)
+ us_guess = np.concatenate((_us_guess[1:, :], np.zeros((1, mj_model.nu))))
+ mj_data.ctrl[:] = u
+ print(f"Controller delay: {time.time() - start}")
+ for _ in range(num_phys_steps_per_control_step):
+ start = time.time()
+ mujoco.mj_step(mj_model, mj_data)
+ viewer.sync()
+ elapsed = time.time() - start
+ if elapsed < mj_model.opt.timestep:
+ time.sleep(mj_model.opt.timestep - elapsed)
+ key = jax.random.split(key)[0]
diff --git a/examples/convex_decomposition.py b/examples/io/convex_decomposition.py
similarity index 100%
rename from examples/convex_decomposition.py
rename to examples/io/convex_decomposition.py
diff --git a/examples/load_from_file.py b/examples/io/load_from_file.py
similarity index 100%
rename from examples/load_from_file.py
rename to examples/io/load_from_file.py
diff --git a/examples/interactive_simulation.py b/examples/sim/interactive_simulation.py
similarity index 100%
rename from examples/interactive_simulation.py
rename to examples/sim/interactive_simulation.py
diff --git a/pyproject.toml b/pyproject.toml
index 64bd86e4..79e618d4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -59,7 +59,7 @@ install-mujoco-from-src = "ambersim._scripts._install_shim:entrypoint"
include = ["ambersim*"]
[tool.setuptools.package-data]
-"*" = ["*.obj", "*.urdf", "*.xml", "install.sh", "install_mj_source.sh"]
+"*" = ["*.png", "*.obj", "*.stl", "*.urdf", "*.xml", "install.sh", "install_mj_source.sh"]
[tool.black]
line-length = 120
diff --git a/tests/trajopt/test_predictive_sampler.py b/tests/trajopt/test_predictive_sampler.py
index 7111fc05..f66f4b9a 100644
--- a/tests/trajopt/test_predictive_sampler.py
+++ b/tests/trajopt/test_predictive_sampler.py
@@ -8,17 +8,28 @@
from ambersim.trajopt.base import CostFunctionParams
from ambersim.trajopt.cost import StaticGoalQuadraticCost
-from ambersim.trajopt.shooting import VanillaPredictiveSampler, VanillaPredictiveSamplerParams, shoot
+from ambersim.trajopt.shooting import (
+ PDPredictiveSampler,
+ PDPredictiveSamplerParams,
+ VanillaPredictiveSampler,
+ VanillaPredictiveSamplerParams,
+ shoot,
+ shoot_pd,
+)
from ambersim.utils.io_utils import load_mjx_model_and_data_from_file
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" # fixes OOM error
+# ########################## #
+# VANILLA PREDICTIVE SAMPLER #
+# ########################## #
+
@pytest.fixture
def vps_data():
"""Makes data required for testing vanilla predictive sampling."""
# initializing the predictive sampler
- model, _ = load_mjx_model_and_data_from_file("models/barrett_hand/bh280.xml", force_float=False)
+ model, _ = load_mjx_model_and_data_from_file("models/allegro_hand/right_hand_motor.xml", force_float=False)
model = model.replace(
opt=model.opt.replace(
timestep=0.002, # dt
@@ -26,7 +37,6 @@ def vps_data():
ls_iterations=4, # number of line search iterations along step direction
integrator=0, # Euler semi-implicit integration
solver=2, # Newton solver
- disableflags=DisableBit.CONTACT, # disable contact for this test
)
)
cost_function = StaticGoalQuadraticCost(
@@ -50,7 +60,7 @@ def test_smoke_VPS(vps_data):
x0 = jnp.zeros(model.nq + model.nv)
num_steps = 10
us_guess = jnp.zeros((num_steps, model.nu))
- params = VanillaPredictiveSamplerParams(key=key, x0=x0, us_guess=us_guess)
+ params = VanillaPredictiveSamplerParams(key=key, x0=x0, guess=us_guess)
# sampling the best sequence of qs, vs, and us
optimize_fn = jit(ps.optimize)
@@ -72,7 +82,7 @@ def test_VPS_cost_decrease(vps_data):
us_guess = jax.random.normal(key=subkey, shape=(batch_size, num_steps, model.nu))
keys = jax.random.split(key, num=batch_size)
- params = VanillaPredictiveSamplerParams(key=keys, x0=x0, us_guess=us_guess)
+ params = VanillaPredictiveSamplerParams(key=keys, x0=x0, guess=us_guess)
# sampling with the vanilla predictive sampler
xs_stars, us_stars = vmap(ps.optimize)(params)
@@ -85,3 +95,50 @@ def test_VPS_cost_decrease(vps_data):
xs_guess = vmap(shoot, in_axes=(None, 0, 0))(model, x0, us_guess)
costs_guess = vmap_cost(xs_guess, us_guess)
assert jnp.all(costs_star <= costs_guess)
+
+
+# ##################### #
+# PD PREDICTIVE SAMPLER #
+# ##################### #
+
+
+@pytest.fixture
+def pdps_data():
+ """Makes data required for testing PD predictive sampling."""
+ # initializing the predictive sampler
+ model, _ = load_mjx_model_and_data_from_file("models/allegro_hand/right_hand_position.xml", force_float=False)
+ model = model.replace(
+ opt=model.opt.replace(
+ timestep=0.002, # dt
+ iterations=1, # number of Newton steps to take during solve
+ ls_iterations=4, # number of line search iterations along step direction
+ integrator=0, # Euler semi-implicit integration
+ solver=2, # Newton solver
+ )
+ )
+ cost_function = StaticGoalQuadraticCost(
+ Q=jnp.eye(model.nq + model.nv),
+ Qf=10.0 * jnp.eye(model.nq + model.nv),
+ R=0.01 * jnp.eye(model.nu),
+ xg=jnp.zeros(model.nq + model.nv),
+ )
+ nsamples = 100
+ stdev = 0.01
+ ps = PDPredictiveSampler(model=model, cost_function=cost_function, nsamples=nsamples, stdev=stdev)
+ return ps, model, cost_function
+
+
+def test_smoke_PDPS(pdps_data):
+ """Simple smoke test to make sure we can run inputs through the PD predictive sampler + jit."""
+ ps, model, _ = pdps_data
+
+ # sampler parameters
+ key = jax.random.PRNGKey(0) # random seed for the predictive sampler
+ x0 = jnp.zeros(model.nq + model.nv)
+ num_steps = 10
+ qs_guess = jnp.tile(x0[: model.nq], (num_steps, 1))
+ params = PDPredictiveSamplerParams(key=key, x0=x0, guess=qs_guess, kp=1.0, kd=0.1)
+
+ # sampling the best sequence of qs, vs, and us
+ optimize_fn = jit(ps.optimize)
+ assert optimize_fn(params)