Fridah/static fp4 export #858
base: asma/refactor-scale-sweep
Conversation
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Force-pushed 9f69993 to df4e6a9
| "algorithm": "max", | ||
| } | ||
|
|
||
| NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG = { |
Activation is quantized too; can we make that more evident in the name?

Suggested change:

```diff
-NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG = {
+NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG = {
```
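For illustration, a hedged sketch of what making the activation side explicit might look like; the `quant_cfg` structure, field names, and values below are assumptions for this example, not the PR's actual config:

```python
# Hypothetical sketch -- field names and values are assumptions, not the PR's config.
NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG = {
    "quant_cfg": {
        # Weights: FP4 (E2M1) with FP8 (E4M3) per-block scales.
        "*weight_quantizer": {"num_bits": (2, 1), "block_sizes": {-1: 16, "scale_bits": (4, 3)}},
        # Activations are quantized too -- the "W4A4" in the name makes that explicit.
        "*input_quantizer": {"num_bits": (2, 1), "block_sizes": {-1: 16, "scale_bits": (4, 3)}},
    },
    "algorithm": "max",
}
```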
```python
weight_scaling_factor_2 = global_amax / (6.0 * 448.0)
per_block_scale = per_block_amax / (6.0 * weight_scaling_factor_2.to(per_block_amax.device))
per_block_scale[per_block_scale == 0] = 1.0

# Reshape per_block_scale to match weight's block structure: (rows, num_blocks_per_row)
num_blocks_per_row = weight.shape[-1] // block_size
expected_shape = (*weight.shape[:-1], num_blocks_per_row)
per_block_scale = per_block_scale.view(expected_shape)

return per_block_scale.to(torch.float8_e4m3fn)
```
nit: This is the way I think about FP8 quantization of the scale; I find the following more intuitive:

Suggested change:

```diff
-weight_scaling_factor_2 = global_amax / (6.0 * 448.0)
-per_block_scale = per_block_amax / (6.0 * weight_scaling_factor_2.to(per_block_amax.device))
-per_block_scale[per_block_scale == 0] = 1.0
-# Reshape per_block_scale to match weight's block structure: (rows, num_blocks_per_row)
-num_blocks_per_row = weight.shape[-1] // block_size
-expected_shape = (*weight.shape[:-1], num_blocks_per_row)
-per_block_scale = per_block_scale.view(expected_shape)
-return per_block_scale.to(torch.float8_e4m3fn)
+per_block_scale_max = global_amax.float() / 6.0  # important: compute the scale in float
+per_block_scale = per_block_amax.float() / 6.0
+per_block_scale[per_block_scale == 0] = 1.0
+# Reshape per_block_scale to match weight's block structure: (rows, num_blocks_per_row)
+num_blocks_per_row = weight.shape[-1] // block_size
+expected_shape = (*weight.shape[:-1], num_blocks_per_row)
+per_block_scale = per_block_scale.view(expected_shape)
+per_block_scale_fp8 = (per_block_scale * 448.0 / per_block_scale_max).to(torch.float8_e4m3fn)
+return per_block_scale_fp8
```
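Note the two formulations agree algebraically: both reduce to `per_block_amax * 448.0 / global_amax` before the FP8 cast. For readers following along, here is a self-contained sketch of the per-block scale computation as written in the diff above. The function name and the assumptions (a 2-D weight whose last dimension is divisible by `block_size`, a PyTorch build with `torch.float8_e4m3fn`, i.e. >= 2.1) are mine, not the PR's:

```python
import torch

def nvfp4_per_block_scale(weight: torch.Tensor, block_size: int = 16) -> torch.Tensor:
    """Sketch of NVFP4 two-level weight scaling, mirroring the diff above."""
    # Global (level-2) scale: 6.0 is the FP4 E2M1 max, 448.0 the FP8 E4M3 max.
    global_amax = weight.abs().amax().float()
    weight_scaling_factor_2 = global_amax / (6.0 * 448.0)

    # Per-block amax over contiguous blocks along the last dimension;
    # reshaping first yields (rows, num_blocks_per_row) directly.
    blocks = weight.reshape(*weight.shape[:-1], -1, block_size)
    per_block_amax = blocks.abs().amax(dim=-1).float()

    per_block_scale = per_block_amax / (6.0 * weight_scaling_factor_2)
    per_block_scale[per_block_scale == 0] = 1.0  # guard all-zero blocks
    return per_block_scale.to(torch.float8_e4m3fn)

w = torch.randn(4, 64)
print(nvfp4_per_block_scale(w).shape)  # torch.Size([4, 4])
```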
```python
assert (
    hasattr(weight_quantizer, "_global_amax") and weight_quantizer._global_amax is not None
)
global_amax = weight_quantizer._global_amax.float()
```
Can we use the public accessor here?

Suggested change:

```diff
-global_amax = weight_quantizer._global_amax.float()
+global_amax = weight_quantizer.global_amax.float()
```
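A minimal sketch of the design point behind this suggestion: a public `global_amax` property keeps the underscore attribute an implementation detail. The class below is a hypothetical stand-in, not the PR's actual quantizer:

```python
from typing import Optional

import torch

class WeightQuantizer:
    """Hypothetical stand-in for the PR's quantizer class."""

    def __init__(self) -> None:
        self._global_amax: Optional[torch.Tensor] = None

    @property
    def global_amax(self) -> Optional[torch.Tensor]:
        # Call sites read the public property instead of reaching into _global_amax.
        return self._global_amax

quantizer = WeightQuantizer()
quantizer._global_amax = torch.tensor(3.5)
assert quantizer.global_amax is not None
global_amax = quantizer.global_amax.float()
```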
```python
assert (
    hasattr(weight_quantizer, "_global_amax") and weight_quantizer._global_amax is not None
)
```
Suggested change:

```diff
-assert (
-    hasattr(weight_quantizer, "_global_amax") and weight_quantizer._global_amax is not None
-)
+assert weight_quantizer.global_amax is not None
```
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
What does this PR do?
Type of change: ?
Overview: ?
Usage
```python
# Add a code snippet demonstrating how to use this
```

Testing
Before your PR is "Ready for review"
Additional Information