diff --git a/CHANGELOG.md b/CHANGELOG.md index 65d44f4..b1a56dc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html). ## [Unreleased] +### Added +- `TocountError` class +- `TocountValidationError` class ### Changed - `setup.py` updated ## [0.5] - 2026-01-02 diff --git a/tests/test_rule_based.py b/tests/test_rule_based.py index 93fa2f7..70f9d82 100644 --- a/tests/test_rule_based.py +++ b/tests/test_rule_based.py @@ -1,5 +1,6 @@ import pytest from tocount import estimate_text_tokens, TextEstimator +from tocount import TocountValidationError from tocount.params import INVALID_TEXT_MESSAGE, INVALID_TEXT_ESTIMATOR_MESSAGE @@ -135,12 +136,12 @@ def test_openai_gpt_4_text_with_non_english(): def test_raises_error_for_invalid_text(): invalid_text = 12345 - with pytest.raises(ValueError, match=INVALID_TEXT_MESSAGE): + with pytest.raises(TocountValidationError, match=INVALID_TEXT_MESSAGE): estimate_text_tokens(invalid_text) def test_raises_error_for_invalid_estimator(): valid_text = "sample prompt" invalid_estimator = "not a valid estimator" - with pytest.raises(ValueError, match=INVALID_TEXT_ESTIMATOR_MESSAGE): + with pytest.raises(TocountValidationError, match=INVALID_TEXT_ESTIMATOR_MESSAGE): estimate_text_tokens(valid_text, invalid_estimator) diff --git a/tocount/__init__.py b/tocount/__init__.py index 50ef7d4..70382cd 100644 --- a/tocount/__init__.py +++ b/tocount/__init__.py @@ -2,7 +2,8 @@ """Tocount modules.""" from .params import TOCOUNT_VERSION, TextEstimator +from .errors import TocountError, TocountValidationError from .functions import estimate_text_tokens __version__ = TOCOUNT_VERSION -__all__ = ["TextEstimator", "estimate_text_tokens"] +__all__ = ["TextEstimator", "estimate_text_tokens", "TocountError", "TocountValidationError"] diff --git a/tocount/errors.py b/tocount/errors.py new file mode 100644 index 0000000..d6029f2 --- /dev/null +++ b/tocount/errors.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- +"""Tocount errors.""" + +class TocountError(Exception): + """Base exception for all Tocount errors.""" + + pass + +class TocountValidationError(TocountError, ValueError): + """Base class for validation errors in Tocount.""" + + pass diff --git a/tocount/functions.py b/tocount/functions.py index 3f68487..bb69dcf 100644 --- a/tocount/functions.py +++ b/tocount/functions.py @@ -4,6 +4,7 @@ from .params import TextEstimator, _TextEstimatorRuleBased, _TextEstimatorTikTokenR50K from .params import _TextEstimatorTikTokenCL100K, _TextEstimatorTikTokenO200K from .params import _TextEstimatorDeepseekR1, _TextEstimatorQwenQwQ, _TextEstimatorLlama_3_1 +from .errors import TocountValidationError from .rule_based.functions import universal_tokens_estimator, openai_tokens_estimator_gpt_3_5, openai_tokens_estimator_gpt_4 from .tiktoken_r50k.functions import linear_tokens_estimator_all as r50k_linear_all from .tiktoken_r50k.functions import linear_tokens_estimator_english as r50k_linear_english @@ -47,7 +48,7 @@ def estimate_text_tokens(text: str, estimator: TextEstimator = TextEstimator.DEF :return: tokens number """ if not isinstance(text, str): - raise ValueError(INVALID_TEXT_MESSAGE) + raise TocountValidationError(INVALID_TEXT_MESSAGE) if not isinstance(estimator, (TextEstimator, _TextEstimatorRuleBased, _TextEstimatorTikTokenR50K, _TextEstimatorTikTokenCL100K, _TextEstimatorTikTokenO200K, _TextEstimatorDeepseekR1, _TextEstimatorQwenQwQ, _TextEstimatorLlama_3_1)): - raise ValueError(INVALID_TEXT_ESTIMATOR_MESSAGE) + raise TocountValidationError(INVALID_TEXT_ESTIMATOR_MESSAGE) return text_estimator_map[estimator](text)