
Conversation

@vthumbe1503
Collaborator

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

vthumbe1503 and others added 2 commits January 5, 2026 18:11
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
vthumbe1503 and others added 4 commits January 6, 2026 12:34
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…ormerEngine into cpu_fp8_optimizations

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503
Collaborator Author

/te-ci L1 pytorch

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…ormerEngine into cpu_fp8_optimizations

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
vthumbe1503 marked this pull request as ready for review January 7, 2026 17:22
@vthumbe1503
Collaborator Author

/te-ci L1 pytorch


greptile-apps bot commented Jan 7, 2026

Greptile Summary

  • Implements CPU optimizations for FP8 operations across TransformerEngine by caching function call results, replacing redundant property accesses with cached values, and optimizing Python object creation patterns
  • Refactors extension initialization system to use a global flag instead of individual guards, improving thread safety and atomicity of the initialization process
  • Adds convenience properties (shape, is_cuda) to Float8Tensor class and dtype caching to quantized tensor classes for improved attribute access performance
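
As a rough illustration of the caching pattern summarized above (a minimal sketch only; the class and attribute names below are illustrative, not Float8Tensor's actual implementation), a quantized-tensor wrapper can cache its dtype at construction time and expose shape and is_cuda as thin properties over its underlying storage:

    # Illustrative sketch of dtype caching plus convenience properties; not TransformerEngine code.
    import torch


    class CachedQuantizedTensor:
        """Toy wrapper showing cached-attribute access patterns."""

        def __init__(self, data: torch.Tensor, fake_dtype: torch.dtype):
            self._data = data          # raw (e.g. uint8) storage
            self._dtype = fake_dtype   # cached once rather than recomputed on every access

        @property
        def dtype(self) -> torch.dtype:
            # Cheap Python attribute read at every call site.
            return self._dtype

        @property
        def shape(self) -> torch.Size:
            # Convenience property forwarding to the underlying storage tensor.
            return self._data.shape

        @property
        def is_cuda(self) -> bool:
            return self._data.is_cuda

Hot paths can then read tensor.dtype, tensor.shape, and tensor.is_cuda directly, which is the kind of cheap attribute access the summary refers to.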

Important Files Changed

Filename | Overview
transformer_engine/pytorch/csrc/util.cpp | Contains a critical logical error: `&&` was changed to `||`, inverting the condition (see the additional comments below)
transformer_engine/pytorch/module/linear.py | Introduces a subtle bug in the FP8 condition check on line 484, where the cached bias_requires_grad is used even when bias is None

Confidence score: 2/5

  • This PR contains multiple critical bugs that will cause functional issues in production, including a logical error that breaks FP8 scaling functionality and incorrect bias gradient handling
  • Score reflects the presence of serious implementation errors that outweigh the performance optimizations, particularly the complete reversal of intended logic in the util.cpp file
  • Pay close attention to transformer_engine/pytorch/csrc/util.cpp and transformer_engine/pytorch/module/linear.py for the critical bugs identified

Sequence Diagram

sequenceDiagram
    participant User
    participant Linear
    participant _Linear
    participant general_gemm
    participant cublas_gemm
    participant Float8Quantizer
    participant TensorWrapper
    participant CUDA

    User->>Linear: "forward(input)"
    Linear->>Linear: "prepare_forward(input)"
    Linear->>Linear: "_get_quantizers()"
    Linear->>Float8Quantizer: "quantize(input)"
    Float8Quantizer->>TensorWrapper: "set_rowwise_data()"
    Float8Quantizer-->>Linear: "quantized_input"
    
    Linear->>Linear: "get_weight_workspace()"
    Linear->>Float8Quantizer: "quantize(weight)"
    Float8Quantizer->>TensorWrapper: "set_rowwise_data()"
    Float8Quantizer-->>Linear: "quantized_weight"
    
    Linear->>_Linear: "apply(weight, input, bias, args)"
    _Linear->>general_gemm: "general_gemm(weight, input, quantizer)"
    general_gemm->>cublas_gemm: "cublas_gemm(A, B, D, quantizer)"
    
    cublas_gemm->>cublas_gemm: "CanonicalizeGemmInput()"
    cublas_gemm->>CUDA: "nvte_compute_scale_from_amax()"
    cublas_gemm->>CUDA: "cublasLtMatmul()"
    cublas_gemm->>CUDA: "update_tensor_scale_inv()"
    cublas_gemm-->>general_gemm: "gemm_result"
    
    general_gemm-->>_Linear: "output"
    _Linear-->>Linear: "output"
    Linear-->>User: "output"


greptile-apps bot left a comment


Additional Comments (3)

  1. transformer_engine/pytorch/csrc/util.cpp, line 18-20 (link)

    logic: Critical logical error: `||` should be `&&`. This condition will always be true, since a value cannot simultaneously be both scaling modes, causing the function to always return nullopt for valid inputs.

  2. transformer_engine/pytorch/quantized_tensor.py, line 373-393 (link)

    style: Commented-out code for the requires_grad caching optimization; consider removing the dead code entirely. Is this code planned to be implemented later, or should it be removed?

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

  3. transformer_engine/pytorch/module/linear.py, line 484 (link)

    logic: Logical error: this condition should use OR (`or`), not AND (`and`). The original logic checked whether ANY tensor requires gradients for FP8 handling, but the new code only activates when ALL three require gradients, including bias, which may be None.

    Should the FP8 condition check whether any tensor requires gradients (OR logic) rather than all tensors (AND logic)? A sketch of the intended check follows below.
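
A minimal sketch of that intended check, assuming the variable names used in this review (a hypothetical helper, not the actual linear.py code): cache each requires_grad flag once, guard against bias being None, and combine the flags with or so FP8 bookkeeping runs when any tensor needs gradients.

    # Sketch only: hypothetical helper illustrating the suggested OR logic, not linear.py itself.
    from typing import Optional

    import torch


    def fp8_state_update_needed(
        fp8_enabled: bool,
        inp: torch.Tensor,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor],
    ) -> bool:
        # Cache each flag in a local to avoid repeated attribute lookups.
        inp_requires_grad = inp.requires_grad
        weight_requires_grad = weight.requires_grad
        # bias may be None, so only read requires_grad when it exists.
        bias_requires_grad = bias is not None and bias.requires_grad
        # OR, not AND: FP8 state must be updated if ANY tensor needs gradients.
        return fp8_enabled and (
            inp_requires_grad or weight_requires_grad or bias_requires_grad
        )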

10 files reviewed, 3 comments


Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…ormerEngine into cpu_fp8_optimizations

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503
Collaborator Author

/te-ci L1 pytorch


greptile-apps bot left a comment


Greptile Overview

Greptile Summary

This PR implements CPU-side performance optimizations for FP8 operations by caching frequently accessed attributes and reducing redundant function calls. The optimizations target expensive PyObject attribute lookups on custom tensor types and repeated C++ function calls.

Key Changes:

  • Caches requires_grad, dtype, shape, and is_cuda attribute accesses to avoid expensive PyObject lookups on custom tensors
  • Reorders attribute checks in get_tensor_device() to prioritize internal quantized tensor attributes
  • Makes num_devices static in nvte_is_non_tn_fp8_gemm_supported() to cache device count
  • Stores GEMM support check results in local variables to avoid redundant function calls
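
The last two items are C++-side changes; as a rough Python analogue of the same memoization idea (compute an expensive query once, then reuse a local copy inside hot code), one could write something like the sketch below. The names here are illustrative and not part of TransformerEngine's API.

    # Illustrative sketch: memoize a device-count query and hoist it into a local variable.
    import functools

    import torch


    @functools.lru_cache(maxsize=1)
    def cached_num_devices() -> int:
        # Evaluated once per process, similar in spirit to the `static` local in the C++ change.
        return torch.cuda.device_count() if torch.cuda.is_available() else 0


    def hot_loop(n_iters: int) -> int:
        # Store the (already cached) result in a local before the loop, mirroring
        # "store the GEMM support check result in a local variable".
        num_devices = cached_num_devices()
        total = 0
        for _ in range(n_iters):
            total += num_devices  # reuse the local instead of re-querying each iteration
        return total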

Critical Issues Found:

  • Variable redeclaration error in cublaslt_gemm.cu (line 224) will prevent compilation
  • Logic bug in linear.py (line 484) changes FP8 state management from OR logic to AND logic, breaking functionality when bias is None or doesn't require grad

Confidence Score: 0/5

  • This PR cannot be merged due to a compilation error and a critical logic bug
  • Two critical issues prevent merging: (1) C++ compilation will fail due to a variable redeclaration at line 224 of cublaslt_gemm.cu, and (2) a logic bug at line 484 of linear.py breaks FP8 state management by requiring all three tensors to have requires_grad=True instead of any one of them
  • Pay close attention to transformer_engine/common/gemm/cublaslt_gemm.cu (compilation error) and transformer_engine/pytorch/module/linear.py (logic bug)

Important Files Changed

File Analysis

Filename | Score | Overview
transformer_engine/common/gemm/cublaslt_gemm.cu | 1/5 | Caches a function call result to reduce overhead, but contains a variable redeclaration error that will cause a compilation failure
transformer_engine/common/transformer_engine.cpp | 5/5 | Makes num_devices static to avoid redundant calls to cuda::num_devices() - a valid optimization
transformer_engine/pytorch/module/linear.py | 0/5 | Caches requires_grad checks for performance, but contains a critical logic bug at line 484 that changes FP8 state management behavior

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant Linear as Linear Module
    participant Quantizer as Quantizer/QuantizedTensor
    participant GEMM as GEMM Operations
    participant CPP as C++ Extensions

    Note over Linear,CPP: Performance Optimization Flow
    
    User->>Linear: forward(input, weight, bias)
    
    Note over Linear: Cache requires_grad checks
    Linear->>Linear: inp_requires_grad = inp.requires_grad<br/>weight_requires_grad = weight.requires_grad<br/>bias_requires_grad = bias.requires_grad
    
    Linear->>Quantizer: Check if quantized tensor
    alt QuantizedTensor
        Note over Quantizer: Use cached dtype property
        Quantizer->>Quantizer: return self._dtype
        Note over Quantizer: Use cached shape/is_cuda
        Quantizer->>Quantizer: return self._data.shape
    else Regular Tensor
        Quantizer->>Linear: Standard attribute access
    end
    
    Linear->>CPP: get_tensor_device(tensor)
    Note over CPP: Reordered attribute checks
    CPP->>CPP: Check _rowwise_data first<br/>Check _columnwise_data<br/>Check device last
    CPP-->>Linear: device_index
    
    Linear->>GEMM: Configure GEMM parameters
    Note over GEMM: Cache nvte_is_non_tn_fp8_gemm_supported
    GEMM->>CPP: nvte_is_non_tn_fp8_gemm_supported()
    Note over CPP: Static num_devices cached
    CPP-->>GEMM: support_flag
    GEMM->>GEMM: Store in local variable
    
    GEMM->>GEMM: Execute optimized GEMM
    GEMM-->>Linear: output
    
    Note over Linear: FP8 State Management
    alt FP8 enabled and requires_grad check
        Linear->>Linear: Update FP8 tensors<br/>based on cached flags
    end
    
    Linear-->>User: output


greptile-apps bot commented Jan 7, 2026

Additional Comments (2)

transformer_engine/common/gemm/cublaslt_gemm.cu
Variable redeclared in the same scope; it is already declared at line 132 and will cause a compilation failure.

    // int is_nvte_non_tn_fp8_gemm_supported already declared at line 132

transformer_engine/pytorch/module/linear.py
Logic change from the original requires_grad(inp, weight, bias), which returns True if ANY tensor requires grad. The new code requires ALL THREE to be True, breaking FP8 state management when bias is None or does not require grad. Suggested fix:

            if ctx.fp8 and (inp_requires_grad or weight_requires_grad or bias_requires_grad):
