Skip to content

feat: support Risk-Aligned Cache under Classifier-Free Guidance.#1273

Open
yiming-l21 wants to merge 1 commit intojd-opensource:mainfrom
yiming-l21:main
Open

feat: support Risk-Aligned Cache under Classifier-Free Guidance.#1273
yiming-l21 wants to merge 1 commit intojd-opensource:mainfrom
yiming-l21:main

Conversation

@yiming-l21
Copy link
Copy Markdown
Collaborator

No description provided.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request implements the RACFGCache policy, a joint decision-making caching mechanism for DiT models that incorporates accumulated risk and propagation-aware reweighting. The changes include the addition of hardcoded calibration tables, updates to global flags and configuration structures, and the integration of runtime context handling within the caching framework. Feedback identifies several style guide violations, specifically the need to use fixed-width integers (int64_t) instead of int and the requirement to use CHECK instead of TORCH_CHECK for assertions. Additionally, a recommendation was made to avoid changing default values in base configuration structs to prevent unintended side effects on other caching policies.

inline float NaN() { return std::numeric_limits<float>::quiet_NaN(); }

// Build a square rho table with shape [steps, steps].
torch::Tensor make_square_table(int steps, const std::vector<float>& values) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The style guide mandates the use of fixed-width integers (e.g., int32_t, int64_t) instead of int. For consistency with other step-related variables in the codebase, int64_t is preferred.

Suggested change
torch::Tensor make_square_table(int steps, const std::vector<float>& values) {
torch::Tensor make_square_table(int64_t steps, const std::vector<float>& values) {
References
  1. Use fixed-width integers (int32_t, int64_t) instead of plain int, unless the API you are calling explicitly requires int. (link)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Comment on lines +35 to +40
TORCH_CHECK(steps > 0, "steps must be positive, got ", steps);
TORCH_CHECK(values.size() == static_cast<size_t>(steps * steps),
"rho table size mismatch, expected ",
steps * steps,
", got ",
values.size());
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The style guide requires using CHECK from glog for assertions, not TORCH_CHECK.

  CHECK(steps > 0) << "steps must be positive, got " << steps;
  CHECK(values.size() == static_cast<size_t>(steps * steps))
      << "rho table size mismatch, expected " << steps * steps
      << ", got " << values.size();
References
  1. Use CHECK (glog) instead of TORCH_CHECK for assertions. (link)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

struct RhoTableSpec {
std::string model_name;
float cfg_scale = 0.0f;
int infer_steps = 0;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The style guide requires using fixed-width integers like int64_t instead of int. Other parts of the codebase use int64_t for step counts.

Suggested change
int infer_steps = 0;
int64_t infer_steps = 0;
References
  1. Use fixed-width integers (int32_t, int64_t) instead of plain int, unless the API you are calling explicitly requires int. (link)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

struct DiTBaseCacheOptions {
// the number of warmup steps.
int warmup_steps = 0;
int warmup_steps = 5;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Changing the default warmup_steps in the base DiTBaseCacheOptions from 0 to 5 will affect all caching policies, not just the new RACFGCache. This could introduce unintended behavior changes for existing policies. It would be safer to revert this change and set the warmup steps specifically for RACFGCache when parsing its flags.

Suggested change
int warmup_steps = 5;
int warmup_steps = 0;

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant