diff --git a/.github/workflows/code_checks.yml b/.github/workflows/code_checks.yml index 3a8a3d65..17aa5ce5 100644 --- a/.github/workflows/code_checks.yml +++ b/.github/workflows/code_checks.yml @@ -5,6 +5,7 @@ on: branches: [main] pull_request: branches: [main] + types: [ready_for_review] permissions: contents: read diff --git a/.gitignore b/.gitignore index d5b956c4..ffe707c6 100644 --- a/.gitignore +++ b/.gitignore @@ -164,4 +164,7 @@ cython_debug/ # mujoco MUJOCO_LOG.TXT -mujoco \ No newline at end of file +mujoco + +# media +*.png \ No newline at end of file diff --git a/README.md b/README.md index 621354e7..bbce0098 100644 --- a/README.md +++ b/README.md @@ -103,6 +103,43 @@ Major versioning decisions: * `python=3.11.5`. `torch`, `jax`, and `mujoco` all support it and there are major reported speed improvements over `python` 3.10. * `cuda==11.8`. Both `torch` and `jax` support `cuda12`; however, they annoyingly support different minor versions which makes them [incompatible in the same environment](https://github.com/google/jax/issues/18032). Once this is resolved, we will upgrade to `cuda-12.2` or later. It seems most likely that `torch` will support `cuda-12.3` once they do upgrade, since that is the most recent release. +### Code Profiling +The majority of the profiling we do will be on JAX code. Here's a generic template for profiling code using `tensorboard` (you need the test dependencies): +``` +# create the function you want to profile - we recommend jitting it, since this typically changes the profiling results +def fn_to_profile(): + ... + +jit_fn = jit(fn_to_profile) + +# choose a path to store the profiling results +with jax.profiler.trace("/dir/to/profiling/results"): + jit_fn(inputs) +``` +To view the profiling results, run +``` +tensorboard --logdir=/dir/to/profiling/results --port +``` +where `--port` should be some open port like `8008`. In the top right dropdown menu which should say "Inactive," scroll down and select "Profile." Select the run you'd like to analyze and under tools, the most useful tab will usually be "trace_viewer." + +Sometimes, we want to expose certain subroutines to the profiler. We can do so with the following: +``` +# in one file +def fn(): + # stuff that we don't want to profile + fn1() + + # stuff we do want to specifically profile + with jax.named_scope("name_of_your_choice"): + fn2() + +# in another file containing the jitted function to profile +jit_fn = jit(fn) +with jax.profiler.trace("/dir/to/profiling/results"): + jit_fn() +``` +Now, the traced results will specifically show the time spent in `fn2` under the name you chose. Note that you can also use `jax.profiler.TraceAnnotation` or `jax.profiler.annotate_function()` instead, [as recommended](https://jax.readthedocs.io/en/latest/profiling.html#adding-custom-trace-events). + ### Tooling We use various tools to ensure code quality. diff --git a/benchmarks/trajopt/bm_predictive_sampling.py b/benchmarks/trajopt/bm_predictive_sampling.py new file mode 100644 index 00000000..5d6833d4 --- /dev/null +++ b/benchmarks/trajopt/bm_predictive_sampling.py @@ -0,0 +1,99 @@ +import timeit + +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt +import numpy as np +from jax import jit +from mujoco import mjx +from mujoco.mjx._src.types import DisableBit + +from ambersim.trajopt.cost import CostFunction, StaticGoalQuadraticCost +from ambersim.trajopt.shooting import VanillaPredictiveSampler, VanillaPredictiveSamplerParams +from ambersim.utils.io_utils import load_mjx_model_and_data_from_file + + +def make_ps(model: mjx.Model, cost_function: CostFunction, nsamples: int) -> VanillaPredictiveSampler: + """Makes a predictive sampler for this quick and dirty timing script.""" + stdev = 0.01 + ps = VanillaPredictiveSampler(model=model, cost_function=cost_function, nsamples=nsamples, stdev=stdev) + return ps + + +if __name__ == "__main__": + # initializing the model + model, _ = load_mjx_model_and_data_from_file("models/barrett_hand/bh280.xml", force_float=False) + model = model.replace( + opt=model.opt.replace( + timestep=0.002, # dt + iterations=1, # number of Newton steps to take during solve + ls_iterations=4, # number of line search iterations along step direction + integrator=0, # Euler semi-implicit integration + solver=2, # Newton solver + disableflags=DisableBit.CONTACT, # [IMPORTANT] disable contact for this example + ) + ) + + # initializing the cost function + cost_function = StaticGoalQuadraticCost( + Q=jnp.eye(model.nq + model.nv), + Qf=10.0 * jnp.eye(model.nq + model.nv), + R=0.01 * jnp.eye(model.nu), + # qg=jnp.zeros(model.nq).at[6].set(1.0), # if force_float=True + qg=jnp.zeros(model.nq), + vg=jnp.zeros(model.nv), + ) + + # sampler parameters we pass in independent of the number of samples + key = jax.random.PRNGKey(0) # random seed for the predictive sampler + q0 = jnp.zeros(model.nq).at[6].set(1.0) + v0 = jnp.zeros(model.nv) + num_steps = 10 + us_guess = jnp.zeros((num_steps, model.nu)) + params = VanillaPredictiveSamplerParams(key=key, q0=q0, v0=v0, us_guess=us_guess) + + nsamples_list = [1, 10, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 2000, 3000, 4000, 5000, 10000] + runtimes = [] + throughputs = [] + for nsamples in nsamples_list: + print(f"Running with nsamples={nsamples}...") + ps = make_ps(model, cost_function, nsamples) + optimize_fn = jit(ps.optimize) + + # # [DEBUG] profiling with tensorboard + # qs_star, vs_star, us_star = optimize_fn(params) # JIT compiling + # with jax.profiler.trace("/home/albert/tensorboard"): + # qs_star, vs_star, us_star = optimize_fn(params) # after JIT + + def _time_fn(fn=optimize_fn) -> None: + """Function to time runtime.""" + qs_star, vs_star, us_star = fn(params) + qs_star.block_until_ready() + vs_star.block_until_ready() + us_star.block_until_ready() + + compile_time = timeit.timeit(_time_fn, number=1) + print(f" Compile time: {compile_time}") + + num_timing_iters = 100 + time = timeit.timeit(_time_fn, number=num_timing_iters) + print(f" Avg. runtime: {time / num_timing_iters}") # returns TOTAL time, so compute the average ourselves + + runtimes.append(time / num_timing_iters) + throughputs.append(nsamples / (time / num_timing_iters)) + + plt.scatter(np.array(nsamples_list), np.array(runtimes)) + plt.xlabel("number of samples") + plt.ylabel("runtime (s)") + plt.title("Predictive Sampling: Number of Samples vs. Runtime") + plt.xlim([-100, max(nsamples_list) + 100]) + plt.ylim([0, max(runtimes) + 0.01]) + plt.show() + + plt.scatter(np.array(nsamples_list), np.array(throughputs)) + plt.xlabel("number of samples") + plt.ylabel("samples per second (s)") + plt.title("Predictive Sampling: Sampling Throughput vs. Number of Samples") + plt.xlim([-100, max(nsamples_list) + 100]) + plt.ylim([0, max(throughputs) + 10000]) + plt.show() diff --git a/environment.yml b/environment.yml index 3857960e..77ce9b5f 100644 --- a/environment.yml +++ b/environment.yml @@ -1,5 +1,5 @@ channels: - - nvidia/label/cuda-11.8.0 + - nvidia/label/cuda-12.3.0 - conda-forge dependencies: - cuda diff --git a/pyproject.toml b/pyproject.toml index 64bd86e4..8ab06d84 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,15 +18,15 @@ dependencies = [ "coacd>=1.0.0", "dm_control>=1.0.0", "flax>=0.7.5", - "jax[cuda11_local]>=0.4.1", + "jax[cuda12_local]>=0.4.1", "jaxlib>=0.4.1", "matplotlib>=3.5.2", "mujoco>=3.0.0", "mujoco-mjx>=3.0.0", "numpy>=1.23.1", "scipy>=1.10.0", - "torch>=1.13.1", - "tensorboard>=2.15.1", + # "torch>=1.13.1", + "tensorboard>=2.13.0", # [Dec. 4, 2023] https://github.com/tensorflow/tensorflow/issues/62075#issuecomment-1808652131 ] [project.optional-dependencies] @@ -43,10 +43,12 @@ dev = [ # Test-specific packages for verification test = [ - "cvxpy>=1.4.1", - "drake>=1.21.0", + # "cvxpy>=1.4.1", + # "drake>=1.21.0", "libigl>=2.4.0", "pin>=2.6.20", + "tensorflow>=2.13.0", # [Dec. 4, 2023] https://github.com/tensorflow/tensorflow/issues/62075#issuecomment-1808652131 + "tensorboard-plugin-profile>=2.13.0", ] # All packages