Test metadata for untestable ops and fix test_eval.py#114
Test metadata for untestable ops and fix test_eval.py#114jiannanWang wants to merge 7 commits intomainfrom
Conversation
|
New commit: Remove redundant UNTESTABLE_OPERATORS. TODO: Once this PR is merged, remove UNTESTABLE_OPERATORS from dataset filter since they are testable now. |
PaliC
left a comment
There was a problem hiding this comment.
For this I'd actually make the distinction between random and tensor creation ops. I'd then route tensor creation ops to test/equals_metadata as it seems p similar to what you see here. Though I'd ensure its comprehensive as we'll add the other creation ops later.
For benoulli the testing code should be in one of the files at the bottom of #112 and we can likely just use the same testing methodology. If this pr gets merged, I still would not say we check for correctness for bernoulli yet.
To me equals_metadata seems correct! However, I'd personally not merge it until we do a branch cut for the alpha version we are releasing on 9/6 just out of an abundance of caution.
| _allclose(a.stride(), b.stride(), atol=0.0, rtol=0.0) | ||
| _allclose(a.dtype, b.dtype, atol=0.0, rtol=0.0) | ||
| _allclose(a.device, b.device, atol=0.0, rtol=0.0) | ||
| return True |
There was a problem hiding this comment.
I'd check the type string as well as per the reference
There was a problem hiding this comment.
Sure! I have added _allclose(a.is_sparse, b.is_sparse, atol=0.0, rtol=0.0).
The type string assertion checks for dtype, device, and is_sparse. The first two are checked already, so I only add is_sparse.
There was a problem hiding this comment.
wait ... let's just use the functions / machinery from pytorch directly. I feel like that's a bit more future proof / feeds into our desire to make these generated kernels mergable into pytorch
| return False | ||
|
|
||
|
|
||
| def equal_metadata(a, b): |
There was a problem hiding this comment.
One thing I'm not super clear on is that OpInfo this is indeed the way they test tensor creation ops, that's how we figured out this might be the right testing strategy. So why not just use OpInfo again here?
There was a problem hiding this comment.
there is a reference here to pytorch's testing strategy https://github.com/pytorch/pytorch/blob/332fa5b388521c05a19217649745c6edfdc2836d/test/test_tensor_creation_ops.py
BackendBench/eval.py
Outdated
|
|
||
| from BackendBench.utils import serialize_args, uses_cuda_stream, compute_errors | ||
| from BackendBench.scripts.pytorch_operators import extract_operator_name | ||
| from BackendBench.scripts.dataset_filters import UNTESTABLE_OPERATORS |
There was a problem hiding this comment.
The nam UNTESTABLE is no longer right, would be explicit and call it tensor creation ops
There was a problem hiding this comment.
changed to TENSOR_CREATION_OPERATORS
| assert counter == 20 | ||
| assert time_per_run > 0 | ||
|
|
||
| def test_gpu_bench(self): |
There was a problem hiding this comment.
was this giving a problem or do you jus think it's a useless test?
There was a problem hiding this comment.
There's no gpu_bench function in eval.py and we are using triton.testing.do_bench for gpu performance. This actually causes an import error and is fixed in this pr.
msaroufim
left a comment
There was a problem hiding this comment.
please check feedback before merge
Added
equal_metadataandtest_metadatato enable metadata correctness checks for previously untestable operators. Now, operators likeempty_like,new_empty,new_empty_strided, andbernoulliare tested for metadata.Running
uv run python BackendBench/scripts/main.py --suite opinfo --backend aten --ops "empty_like,new_empty,new_empty_strided,bernoulli"and I got:Before:
After:
In the meantime, I fixed a bug in
test_eval.pywhere it tried to importgpu_benchfromeval.py(which does not exist). This was causing all tests in that file to be skipped.I also noticed that
test_dataoverwrites previous entries when the multiple tests have the same arguments, leading to assertion failures intest_eval_correctness_multiple_testsandtest_eval_correctness_metadata. I commented out the affected assertion for now to let the tests pass, but this may need to be solved in future prs.