@@ -99,7 +99,7 @@ def get_num_tokens_from_messages(
9999 except Exception as e :
100100 tokenizer = TokenizerManage .get_tokenizer ()
101101 return sum ([len (tokenizer .encode (get_buffer_string ([m ]))) for m in messages ])
102- return self .usage_metadata .get ('input_tokens' , 0 )
102+ return self .usage_metadata .get ('input_tokens' , self . usage_metadata . get ( 'prompt_tokens' , 0 ) )
103103
104104 def get_num_tokens (self , text : str ) -> int :
105105 if self .usage_metadata is None or self .usage_metadata == {}:
@@ -108,7 +108,8 @@ def get_num_tokens(self, text: str) -> int:
108108 except Exception as e :
109109 tokenizer = TokenizerManage .get_tokenizer ()
110110 return len (tokenizer .encode (text ))
111- return self .get_last_generation_info ().get ('output_tokens' , 0 )
111+ return self .get_last_generation_info ().get ('output_tokens' ,
112+ self .get_last_generation_info ().get ('completion_tokens' , 0 ))
112113
113114 def _stream (self , * args : Any , ** kwargs : Any ) -> Iterator [ChatGenerationChunk ]:
114115 kwargs ['stream_usage' ] = True
@@ -133,7 +134,7 @@ def _convert_chunk_to_generation_chunk(
133134 )
134135
135136 usage_metadata : Optional [UsageMetadata ] = (
136- _create_usage_metadata (token_usage ) if token_usage else None
137+ _create_usage_metadata (token_usage ) if token_usage and token_usage . get ( "prompt_tokens" ) else None
137138 )
138139 if len (choices ) == 0 :
139140 # logprobs is implicitly None
0 commit comments