feat: support Risk-Aligned Cache under Classifier-Free Guidance.#1273
feat: support Risk-Aligned Cache under Classifier-Free Guidance.#1273yiming-l21 wants to merge 1 commit intojd-opensource:mainfrom
Conversation
There was a problem hiding this comment.
Code Review
This pull request implements the RACFGCache policy, a joint decision-making caching mechanism for DiT models that incorporates accumulated risk and propagation-aware reweighting. The changes include the addition of hardcoded calibration tables, updates to global flags and configuration structures, and the integration of runtime context handling within the caching framework. Feedback identifies several style guide violations, specifically the need to use fixed-width integers (int64_t) instead of int and the requirement to use CHECK instead of TORCH_CHECK for assertions. Additionally, a recommendation was made to avoid changing default values in base configuration structs to prevent unintended side effects on other caching policies.
| inline float NaN() { return std::numeric_limits<float>::quiet_NaN(); } | ||
|
|
||
| // Build a square rho table with shape [steps, steps]. | ||
| torch::Tensor make_square_table(int steps, const std::vector<float>& values) { |
There was a problem hiding this comment.
The style guide mandates the use of fixed-width integers (e.g., int32_t, int64_t) instead of int. For consistency with other step-related variables in the codebase, int64_t is preferred.
| torch::Tensor make_square_table(int steps, const std::vector<float>& values) { | |
| torch::Tensor make_square_table(int64_t steps, const std::vector<float>& values) { |
References
- Use fixed-width integers (
int32_t,int64_t) instead of plainint, unless the API you are calling explicitly requiresint. (link)
| TORCH_CHECK(steps > 0, "steps must be positive, got ", steps); | ||
| TORCH_CHECK(values.size() == static_cast<size_t>(steps * steps), | ||
| "rho table size mismatch, expected ", | ||
| steps * steps, | ||
| ", got ", | ||
| values.size()); |
There was a problem hiding this comment.
The style guide requires using CHECK from glog for assertions, not TORCH_CHECK.
CHECK(steps > 0) << "steps must be positive, got " << steps;
CHECK(values.size() == static_cast<size_t>(steps * steps))
<< "rho table size mismatch, expected " << steps * steps
<< ", got " << values.size();References
- Use
CHECK(glog) instead ofTORCH_CHECKfor assertions. (link)
| struct RhoTableSpec { | ||
| std::string model_name; | ||
| float cfg_scale = 0.0f; | ||
| int infer_steps = 0; |
There was a problem hiding this comment.
The style guide requires using fixed-width integers like int64_t instead of int. Other parts of the codebase use int64_t for step counts.
| int infer_steps = 0; | |
| int64_t infer_steps = 0; |
References
- Use fixed-width integers (
int32_t,int64_t) instead of plainint, unless the API you are calling explicitly requiresint. (link)
| struct DiTBaseCacheOptions { | ||
| // the number of warmup steps. | ||
| int warmup_steps = 0; | ||
| int warmup_steps = 5; |
There was a problem hiding this comment.
Changing the default warmup_steps in the base DiTBaseCacheOptions from 0 to 5 will affect all caching policies, not just the new RACFGCache. This could introduce unintended behavior changes for existing policies. It would be safer to revert this change and set the warmup steps specifically for RACFGCache when parsing its flags.
| int warmup_steps = 5; | |
| int warmup_steps = 0; |
No description provided.