From fc7e0799c819007c30aa4556bdbadae3a3a4e1f9 Mon Sep 17 00:00:00 2001 From: zzk Date: Mon, 16 Mar 2026 18:18:31 +0800 Subject: [PATCH 1/6] Removed auth for local backtest --- tqsdk/api.py | 32 +++++++++++++++------------- tqsdk/auth.py | 59 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 15 deletions(-) diff --git a/tqsdk/api.py b/tqsdk/api.py index 1785beb1..51021eb9 100644 --- a/tqsdk/api.py +++ b/tqsdk/api.py @@ -224,7 +224,8 @@ def __init__(self, account: Optional[Union[TqMultiAccount, UnionTradeable]] = No user_name, pwd = auth[:comma_index], auth[comma_index + 1:] self._auth = TqAuth(user_name, pwd) else: - self._auth = None + from tqsdk.auth import TqAuthDummy + self._auth = TqAuthDummy() self._account = TqSim() if account is None else account self._backtest = backtest self._stock = False if isinstance(self._backtest, TqReplay) else _stock @@ -3405,21 +3406,22 @@ def _setup_connection(self): py_version=platform.python_version(), py_arch=platform.architecture()[0], cmd=sys.argv, mem_total=mem.total, mem_free=mem.free) if self._auth is None: - raise Exception("请输入 auth (快期账户)参数,快期账户是使用 tqsdk 的前提,如果没有请点击注册,注册地址:https://account.shinnytech.com/。") - else: - self._auth.init(mode="bt" if isinstance(self._backtest, TqBacktest) else "real") - self._auth.login() # tqwebhelper 有可能会设置 self._auth + from tqsdk.auth import TqAuthDummy + self._auth = TqAuthDummy() + + self._auth.init(mode="bt" if isinstance(self._backtest, TqBacktest) else "real") + self._auth.login() # tqwebhelper 有可能会设置 self._auth - # tqsdk 内部捕获异常如果需要打印日志,则需要自定义异常 - # 对于第三方代码产生的异常需要逐个捕获,可以参考 connect.py TqConnect._run 函数中对于各类异常的捕获 - # 这里只是打印账户过期日期来提醒用户,不关心是否成功,也不记录日志,所以直接 pass - # 单独捕获 self._auth.expire_datetime 是为了语义清晰,表明异常的来源 - try: - self._auth.expire_datetime - except Exception: - pass - if self._auth._expire_days_left is not None and self._auth._product_type is not None and self._auth._expire_days_left < 30: - self._print(f"TqSdk {self._auth._product_type} 版剩余 {self._auth._expire_days_left} 天到期,如需续费或升级请访问 https://account.shinnytech.com/ 或联系相关工作人员。") + # tqsdk 内部捕获异常如果需要打印日志,则需要自定义异常 + # 对于第三方代码产生的异常需要逐个捕获,可以参考 connect.py TqConnect._run 函数中对于各类异常的捕获 + # 这里只是打印账户过期日期来提醒用户,不关心是否成功,也不记录日志,所以直接 pass + # 单独捕获 self._auth.expire_datetime 是为了语义清晰,表明异常的来源 + try: + self._auth.expire_datetime + except Exception: + pass + if self._auth._expire_days_left is not None and self._auth._product_type is not None and self._auth._expire_days_left < 30: + self._print(f"TqSdk {self._auth._product_type} 版剩余 {self._auth._expire_days_left} 天到期,如需续费或升级请访问 https://account.shinnytech.com/ 或联系相关工作人员。") # 在快期账户登录之后,对于账户的基本信息校验及更新 for acc in self._account._account_list: diff --git a/tqsdk/auth.py b/tqsdk/auth.py index 0c04099b..e3b2de62 100644 --- a/tqsdk/auth.py +++ b/tqsdk/auth.py @@ -212,3 +212,62 @@ def _has_td_grants(self, symbol): if symbol.split('.', 1)[0] in (FUTURE_EXCHANGES + KQ_EXCHANGES) and self._has_feature("futr"): return True raise Exception(f"您的账户不支持交易 {symbol},需要购买后才能使用。升级网址:https://www.shinnytech.com/tqsdk-buy/") + + +class TqAuthDummy(object): + """无认证桩类,授予所有权限,不做任何网络调用。""" + + def __init__(self): + self._user_name = "local_user" + self._password = "" + self._access_token = "" + self._refresh_token = "" + self._auth_id = "" + self._mode = "real" + self._grants = {"features": [], "accounts": []} + self._expire_datetime = None + self._expire_days_left = None + self._product_type = None + self._logger = ShinnyLoggerAdapter( + logging.getLogger("TqApi.TqAuth"), + headers=self._base_headers, + grants=self._grants, + ) + + @property + def _base_headers(self): + return { + "User-Agent": "tqsdk-python %s" % __version__, + "Accept": "application/json", + } + + @property + def expire_datetime(self): + return datetime.datetime(2099, 12, 31, 23, 59, 59, tzinfo=_cst_tz) + + def init(self, mode="real"): + self._mode = mode + + def login(self): + pass + + def _has_feature(self, feature): + return True + + def _has_account(self, account): + return True + + def _has_md_grants(self, symbol): + return True + + def _has_td_grants(self, symbol): + return True + + def _add_account(self, account_id): + return True + + def _get_td_url(self, broker_id, account_id): + raise Exception("无认证模式不支持 OTG 实盘交易,请提供 TqAuth 参数。") + + def _get_md_url(self, stock, backtest): + raise Exception("无认证模式无法自动发现行情服务器,请通过 url 参数或 TQ_MD_URL 环境变量指定。") From 00edea43669ba73478e7cb0163aa0c52912953bd Mon Sep 17 00:00:00 2001 From: snowbro3 Date: Thu, 19 Mar 2026 10:06:02 +0800 Subject: [PATCH 2/6] set main ins to usual --- tqsdk/backtest/backtest.py | 24 +++++------------------- 1 file changed, 5 insertions(+), 19 deletions(-) diff --git a/tqsdk/backtest/backtest.py b/tqsdk/backtest/backtest.py index 8007be5d..8be268a8 100644 --- a/tqsdk/backtest/backtest.py +++ b/tqsdk/backtest/backtest.py @@ -9,7 +9,7 @@ from datetime import date, datetime from typing import Union, Any, List, Dict -from tqsdk.backtest.utils import TqBacktestContinuous, TqBacktestDividend +from tqsdk.backtest.utils import TqBacktestDividend from tqsdk.channel import TqChan from tqsdk.datetime import _get_trading_day_start_time, _get_trading_day_end_time, _get_trading_day_from_timestamp, \ _timestamp_nano_to_str, _convert_user_input_to_nano @@ -87,16 +87,12 @@ def __init__(self, start_dt: Union[date, datetime], end_dt: Union[date, datetime async def _run(self, api, sim_send_chan, sim_recv_chan, md_send_chan, md_recv_chan): """回测task""" self._api = api - # 下载历史主连合约信息 - start_trading_day = _get_trading_day_from_timestamp(self._start_dt) # 回测开始交易日 - end_trading_day = _get_trading_day_from_timestamp(self._end_dt) # 回测结束交易日 - self._continuous_table = TqBacktestContinuous(start_dt=start_trading_day, - end_dt=end_trading_day, - headers=self._api._base_headers) + start_trading_day = _get_trading_day_from_timestamp(self._start_dt) + end_trading_day = _get_trading_day_from_timestamp(self._end_dt) self._stock_dividend = TqBacktestDividend(start_dt=start_trading_day, end_dt=end_trading_day, headers=self._api._base_headers) - self._logger = api._logger.getChild("TqBacktest") # 调试信息输出 + self._logger = api._logger.getChild("TqBacktest") self._sim_send_chan = sim_send_chan self._sim_recv_chan = sim_recv_chan self._md_send_chan = md_send_chan @@ -229,8 +225,6 @@ def _update_valid_quotes(self, quotes): invalid_keys.union({'open', 'close', 'settlement', 'lowest', 'lower_limit', 'upper_limit', 'pre_open_interest', 'pre_settlement', 'pre_close', 'expired'}) for symbol, quote in quotes.items(): [quote.pop(k, None) for k in invalid_keys] - if symbol.startswith("KQ.m"): - quote.pop("underlying_symbol", None) if quote.get('expire_datetime'): # 先删除所有的 quote 的 expired 字段,只在有 expire_datetime 字段时才会添加 expired 字段 quote['expired'] = quote.get('expire_datetime') * 1000000000 <= self._trading_day_start @@ -278,11 +272,6 @@ async def _send_snapshot(self): "option_class": quote.get("option_class", ""), "product_id": quote.get("product_id", ""), } - # 修改历史主连合约信息 - cont_quotes = self._continuous_table._get_history_cont_quotes(self._trading_day) - for k, v in cont_quotes.items(): - quotes.setdefault(k, {}) # 实际上,初始行情截面中只有下市合约,没有主连 - quotes[k].update(v) self._diffs.append({ "quotes": quotes, "ins_list": "", @@ -298,16 +287,13 @@ async def _send_diff(self): # 发送数据集中添加 backtest 字段,开始时间、结束时间、当前时间,表示当前行情推进是由 backtest 推进 self._diffs.append({"_tqsdk_backtest": self._get_backtest_time()}) - # 切换交易日,将历史的主连合约信息添加的 diffs + # 切换交易日 if self._current_dt > self._trading_day_end: # 使用交易日结束时间,每个交易日切换只需要计算一次交易日结束时间 # 相比发送 diffs 前每次都用 _current_dt 计算当前交易日,计算次数更少 self._trading_day = _get_trading_day_from_timestamp(self._current_dt) self._trading_day_start = _get_trading_day_start_time(self._trading_day) self._trading_day_end = _get_trading_day_end_time(self._trading_day) - self._diffs.append({ - "quotes": self._continuous_table._get_history_cont_quotes(self._trading_day) - }) self._diffs.append({ "quotes": self._stock_dividend._get_dividend(self._data.get('quotes'), self._trading_day) }) From 9de893e8de2f676d23b72d35c97e1fef5ea43f1e Mon Sep 17 00:00:00 2001 From: snowbro3 Date: Thu, 19 Mar 2026 16:00:16 +0800 Subject: [PATCH 3/6] dd --- tqsdk/datetime.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tqsdk/datetime.py b/tqsdk/datetime.py index 5234baef..811ce295 100644 --- a/tqsdk/datetime.py +++ b/tqsdk/datetime.py @@ -125,6 +125,9 @@ def _get_period_timestamp(real_date_timestamp, period_str): def _is_in_trading_time(quote, current_datetime, local_time_record): """ 判断是否在可交易时间段内,需在quote已收到行情后调用本函数""" # 只在需要用到可交易时间段时(即本函数中)才调用_get_trading_timestamp() + time_part = current_datetime.split(' ')[1] if ' ' in current_datetime else '' + if time_part in ('18:00:00.000000', '17:59:59.999999'): + return True trading_timestamp = _get_trading_timestamp(quote, current_datetime) now_ns_timestamp = _get_trade_timestamp(current_datetime, local_time_record) # 当前预估交易所纳秒时间戳 # 判断当前交易所时间(估计值)是否在交易时间段内 From 7a0d7aa6f6d85d084de1e1439616742c323acf45 Mon Sep 17 00:00:00 2001 From: snowbro3 Date: Thu, 19 Mar 2026 18:01:34 +0800 Subject: [PATCH 4/6] dd --- tqsdk/backtest.py | 4 +++- tqsdk/backtest/backtest.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tqsdk/backtest.py b/tqsdk/backtest.py index f8ec5b88..ecaa6822 100644 --- a/tqsdk/backtest.py +++ b/tqsdk/backtest.py @@ -445,7 +445,9 @@ async def _ensure_quote(self, ins): quote["_listener"].add(update_chan) while math.isnan(quote.get("price_tick")): await update_chan.recv() - if ins not in self._quotes or self._quotes[ins]["min_duration"] > 60000000000: + # if ins not in self._quotes or self._quotes[ins]["min_duration"] > 60000000000: + # await self._ensure_serial(ins, 60000000000) + if ins not in self._quotes: await self._ensure_serial(ins, 60000000000) async def _fetch_serial(self, key): diff --git a/tqsdk/backtest/backtest.py b/tqsdk/backtest/backtest.py index 8be268a8..5d42f1ea 100644 --- a/tqsdk/backtest/backtest.py +++ b/tqsdk/backtest/backtest.py @@ -453,7 +453,9 @@ async def _ensure_symbols(self, symbols): await update_chan.recv() async def _ensure_quote(self, symbol): - if symbol not in self._quotes or self._quotes[symbol]["min_duration"] > 60000000000: + # if symbol not in self._quotes or self._quotes[symbol]["min_duration"] > 60000000000: + # await self._ensure_serial(symbol, 60000000000) + if symbol not in self._quotes: await self._ensure_serial(symbol, 60000000000) async def _fetch_serial(self, key): From 8aecc8a8d0370d2cdb97f286a0b05a48de4d1fc1 Mon Sep 17 00:00:00 2001 From: zzk Date: Fri, 20 Mar 2026 15:53:19 +0800 Subject: [PATCH 5/6] Performance optimization --- tqsdk/api.py | 22 +++++++-------- tqsdk/backtest.py | 16 +++++------ tqsdk/backtest/backtest.py | 16 +++++------ tqsdk/diff.py | 10 +++---- tqsdk/entity.py | 57 +++++++++++++++++++++++++++++++------- tqsdk/objs.py | 46 ++++++++++++++++++++++++------ tqsdk/trading_status.py | 4 +-- 7 files changed, 119 insertions(+), 52 deletions(-) diff --git a/tqsdk/api.py b/tqsdk/api.py index 51021eb9..a1bcd8f0 100644 --- a/tqsdk/api.py +++ b/tqsdk/api.py @@ -2057,7 +2057,7 @@ def _is_obj_changing(self, obj: Any, diffs: List[Dict[str, Any]], key: List[str] if id(obj) in self._serials: paths = [] for root in self._serials[id(obj)]["root"]: - paths.append(root["_path"]) + paths.append(root._path) elif len(obj) == 0: return False else: # 处理传入的为一个 copy 出的 DataFrame (与原 DataFrame 数据相同的另一个object) @@ -2088,7 +2088,7 @@ def _is_obj_changing(self, obj: Any, diffs: List[Dict[str, Any]], key: List[str] paths.append(["ticks", obj["symbol"], "data", str(int(obj["id"]))]) else: - paths = [obj["_path"]] + paths = [obj._path] except (KeyError, IndexError): return False for diff in diffs: @@ -3629,12 +3629,12 @@ def _init_serial(self, root_list, width, default, adj_type): temp_df = pd.DataFrame() temp_df._mgr = bm serial["df"] = TqDataFrame(self, temp_df, copy=False) - serial["df"]["symbol"] = root_list[0]["_path"][1] + serial["df"]["symbol"] = root_list[0]._path[1] for i in range(1, len(root_list)): - serial["df"]["symbol" + str(i)] = root_list[i]["_path"][1] + serial["df"]["symbol" + str(i)] = root_list[i]._path[1] - serial["df"]["duration"] = 0 if root_list[0]["_path"][0] == "ticks" else int( - root_list[0]["_path"][-1]) // 1000000000 + serial["df"]["duration"] = 0 if root_list[0]._path[0] == "ticks" else int( + root_list[0]._path[-1]) // 1000000000 return serial def _update_serial_single(self, serial): @@ -3842,8 +3842,8 @@ def _process_serial_extra_array(self, serial): serial["all_attr"] = set(serial["df"].columns.values) if serial["update_row"] == serial["width"]: return - symbol = serial["root"][0]["_path"][1] # 主合约的symbol,标志绘图的主合约 - duration = 0 if serial["root"][0]["_path"][0] == "ticks" else int(serial["root"][0]["_path"][-1]) + symbol = serial["root"][0]._path[1] # 主合约的symbol,标志绘图的主合约 + duration = 0 if serial["root"][0]._path[0] == "ticks" else int(serial["root"][0]._path[-1]) cols = list(serial["extra_array"].keys()) # 归并数据序列 while len(cols) != 0: @@ -4033,8 +4033,8 @@ def _gen_security_prototype(self): @staticmethod def _deep_copy_dict(source, dest): - for key, value in source.__dict__.items(): - if isinstance(value, Entity): + for key, value in source._data.items(): + if hasattr(value, '_data'): dest[key] = {} TqApi._deep_copy_dict(value, dest[key]) else: @@ -4216,7 +4216,7 @@ def draw_report(self, report_datas): def _send_chart_data(self, base_kserial_frame, serial_id, serial_data): s = self._serials[id(base_kserial_frame)] - p = s["root"][0]["_path"] + p = s["root"][0]._path symbol = p[-2] dur_nano = int(p[-1]) pack = { diff --git a/tqsdk/backtest.py b/tqsdk/backtest.py index ecaa6822..09891978 100644 --- a/tqsdk/backtest.py +++ b/tqsdk/backtest.py @@ -232,7 +232,7 @@ def _update_valid_quotes(self, quotes): async def _send_snapshot(self): """发送初始合约信息""" async with TqChan(self._api, last_only=True) as update_chan: # 等待与行情服务器连接成功 - self._data["_listener"].add(update_chan) + self._data._listener.add(update_chan) while self._data.get("mdhis_more_data", True): await update_chan.recv() # 发送初始行情(合约信息截面)时 @@ -431,7 +431,7 @@ async def _ensure_query(self, pack): if query_pack.items() <= self._data.get("symbols", {}).get(pack["query_id"], {}).items(): return async with TqChan(self._api, last_only=True) as update_chan: - self._data["_listener"].add(update_chan) + self._data._listener.add(update_chan) while not query_pack.items() <= self._data.get("symbols", {}).get(pack["query_id"], {}).items(): await update_chan.recv() @@ -442,7 +442,7 @@ async def _ensure_quote(self, ins): query_pack = _query_for_quote(ins) await self._md_send_chan.send(query_pack) async with TqChan(self._api, last_only=True) as update_chan: - quote["_listener"].add(update_chan) + quote._listener.add(update_chan) while math.isnan(quote.get("price_tick")): await update_chan.recv() # if ins not in self._quotes or self._quotes[ins]["min_duration"] > 60000000000: @@ -483,9 +483,9 @@ async def _gen_serial(self, ins, dur): serials = [_get_obj(self._data, ["klines", s, str(dur)]) for s in symbol_list] async with TqChan(self._api, last_only=True) as update_chan: for serial in serials: - serial["_listener"].add(update_chan) - chart_a["_listener"].add(update_chan) - chart_b["_listener"].add(update_chan) + serial._listener.add(update_chan) + chart_a._listener.add(update_chan) + chart_b._listener.add(update_chan) await self._md_send_chan.send(chart_info.copy()) try: async for _ in update_chan: @@ -501,10 +501,10 @@ async def _gen_serial(self, ins, dur): if last_id == -1: continue # 数据序列还没收到 if self._data.get("mdhis_more_data", True): - self._data["_listener"].add(update_chan) + self._data._listener.add(update_chan) continue else: - self._data["_listener"].discard(update_chan) + self._data._listener.discard(update_chan) if current_id is None: current_id = max(left_id, 0) # 发送下一段 chart 8964 根 kline diff --git a/tqsdk/backtest/backtest.py b/tqsdk/backtest/backtest.py index 5d42f1ea..43f5a10a 100644 --- a/tqsdk/backtest/backtest.py +++ b/tqsdk/backtest/backtest.py @@ -233,7 +233,7 @@ def _update_valid_quotes(self, quotes): async def _send_snapshot(self): """发送初始合约信息""" async with TqChan(self._api, last_only=True) as update_chan: # 等待与行情服务器连接成功 - self._data["_listener"].add(update_chan) + self._data._listener.add(update_chan) while self._data.get("mdhis_more_data", True): await update_chan.recv() # 发送初始行情(合约信息截面)时 @@ -434,7 +434,7 @@ async def _ensure_query(self, pack): if query_pack.items() <= self._data.get("symbols", {}).get(pack["query_id"], {}).items(): return async with TqChan(self._api, last_only=True) as update_chan: - self._data["_listener"].add(update_chan) + self._data._listener.add(update_chan) while not query_pack.items() <= self._data.get("symbols", {}).get(pack["query_id"], {}).items(): await update_chan.recv() @@ -448,7 +448,7 @@ async def _ensure_symbols(self, symbols): await self._md_send_chan.send(query_pack) async with TqChan(self._api, last_only=True) as update_chan: for q in quotes: - q["_listener"].add(update_chan) + q._listener.add(update_chan) while any([math.isnan(q.get("price_tick")) for q in quotes]): await update_chan.recv() @@ -491,9 +491,9 @@ async def _gen_serial(self, ins, dur): serials = [_get_obj(self._data, ["klines", s, str(dur)]) for s in symbol_list] async with TqChan(self._api, last_only=True) as update_chan: for serial in serials: - serial["_listener"].add(update_chan) - chart_a["_listener"].add(update_chan) - chart_b["_listener"].add(update_chan) + serial._listener.add(update_chan) + chart_a._listener.add(update_chan) + chart_b._listener.add(update_chan) await self._md_send_chan.send(chart_info.copy()) try: async for _ in update_chan: @@ -520,10 +520,10 @@ async def _gen_serial(self, ins, dur): yield self._current_dt, diff, None, "OPEN" return if self._data.get("mdhis_more_data", True): - self._data["_listener"].add(update_chan) + self._data._listener.add(update_chan) continue else: - self._data["_listener"].discard(update_chan) + self._data._listener.discard(update_chan) left_id = chart.get("left_id", -1) right_id = chart.get("right_id", -1) if current_id is None: diff --git a/tqsdk/diff.py b/tqsdk/diff.py index 5dd89adc..28172009 100644 --- a/tqsdk/diff.py +++ b/tqsdk/diff.py @@ -29,7 +29,7 @@ def _merge_diff(result, diff, prototype, persist, reduce_diff=False, notify_upda else: if notify_update_diff: dv = result.pop(key, None) - _notify_update(dv, True, _gen_diff_obj(None, result["_path"] + [key])) + _notify_update(dv, True, _gen_diff_obj(None, result._path + [key])) else: dv = result.pop(key, None) _notify_update(dv, True, True) @@ -65,7 +65,7 @@ def _merge_diff(result, diff, prototype, persist, reduce_diff=False, notify_upda # 这里发的数据目前是不需要 copy (浅拷贝会有坑,深拷贝的话性能不知道有多大影响) # 因为这里现在会用到发送这个 diff 的只有 quote 对象,只有 sim 会收到使用,sim 收到之后是不会修改这个 diff # 所以这里就约定接收方不能改 diff 中的值 - diff_obj = _gen_diff_obj(diff, result["_path"]) + diff_obj = _gen_diff_obj(diff, result._path) _notify_update(result, False, diff_obj) @@ -79,7 +79,7 @@ def _gen_diff_obj(diff, path): def _notify_update(target, recursive, content): """同步通知业务数据更新""" - if isinstance(target, dict) or isinstance(target, Entity): + if type(target) is dict or hasattr(target, '_data'): for q in getattr(target, "_listener", {}): q.send_nowait(content) if recursive: @@ -96,7 +96,7 @@ def _get_obj(root, path, default=None): dv = Entity() else: dv = copy.copy(default) - dv._instance_entity(d["_path"] + [path[i]]) + dv._instance_entity(d._path + [path[i]]) d[path[i]] = dv d = d[path[i]] return d @@ -106,7 +106,7 @@ def _register_update_chan(objs, chan): if not isinstance(objs, list): objs = [objs] for o in objs: - o["_listener"].add(chan) + o._listener.add(chan) return chan diff --git a/tqsdk/entity.py b/tqsdk/entity.py index 317470ea..eec8f18f 100644 --- a/tqsdk/entity.py +++ b/tqsdk/entity.py @@ -8,31 +8,68 @@ class Entity(MutableMapping): + def __new__(cls, *args, **kwargs): + instance = super().__new__(cls) + object.__setattr__(instance, '_data', {}) + return instance + def _instance_entity(self, path): - self._path = path - self._listener = weakref.WeakSet() + object.__setattr__(self, '_path', path) + object.__setattr__(self, '_listener', weakref.WeakSet()) + + def __setattr__(self, key, value): + if key.startswith('_'): + object.__setattr__(self, key, value) + else: + self._data[key] = value + + def __getattr__(self, key): + try: + return self._data[key] + except KeyError: + raise AttributeError(key) + + def __delattr__(self, key): + if key.startswith('_'): + object.__delattr__(self, key) + else: + try: + del self._data[key] + except KeyError: + raise AttributeError(key) def __setitem__(self, key, value): - return self.__dict__.__setitem__(key, value) + self._data[key] = value def __delitem__(self, key): - return self.__dict__.__delitem__(key) + del self._data[key] def __getitem__(self, key): - return self.__dict__.__getitem__(key) + return self._data[key] def __iter__(self): - return iter({k: v for k, v in self.__dict__.items() if not k.startswith("_")}) + return iter(self._data) def __len__(self): - return len({k: v for k, v in self.__dict__.items() if not k.startswith("_")}) + return len(self._data) def __str__(self): - return str({k: v for k, v in self.__dict__.items() if not k.startswith("_")}) + return str(self._data) def __repr__(self): - return '{}, D({})'.format(super(Entity, self).__repr__(), - {k: v for k, v in self.__dict__.items() if not k.startswith("_")}) + return '{}, D({})'.format(super(Entity, self).__repr__(), self._data) + + def __contains__(self, key): + return key in self._data + + def __copy__(self): + new = type(self).__new__(type(self)) + # Copy private attrs from __dict__ (excluding _data which is handled separately) + for k, v in self.__dict__.items(): + if k != '_data': + object.__setattr__(new, k, v) + object.__setattr__(new, '_data', self._data.copy()) + return new def copy(self): return copy.copy(self) diff --git a/tqsdk/objs.py b/tqsdk/objs.py index a73d93fa..f32f113b 100644 --- a/tqsdk/objs.py +++ b/tqsdk/objs.py @@ -450,8 +450,16 @@ def orders(self): :return: dict, 其中每个元素的key为委托单ID, value为 :py:class:`~tqsdk.objs.Order` """ tdict = _get_obj(self._api._data, ["trade", self._path[1], "orders"]) - fts = {order_id: order for order_id, order in tdict.items() if (not order_id.startswith( - "_")) and order.instrument_id == self.instrument_id and order.exchange_id == self.exchange_id and order.status == "ALIVE"} + inst_id = self._data.get('instrument_id', '') + exch_id = self._data.get('exchange_id', '') + fts = {} + for order_id, order in tdict._data.items(): + try: + od = order._data + except AttributeError: + continue + if od.get('status') == "ALIVE" and od.get('instrument_id') == inst_id and od.get('exchange_id') == exch_id: + fts[order_id] = order return fts @@ -509,8 +517,15 @@ def trade_records(self): :return: dict, 其中每个元素的key为成交ID, value为 :py:class:`~tqsdk.objs.Trade` """ tdict = _get_obj(self._api._data, ["trade", self._path[1], "trades"]) - fts = {trade_id: trade for trade_id, trade in tdict.items() if - (not trade_id.startswith("_")) and trade.order_id == self.order_id} + target_order_id = self._data.get('order_id', '') + fts = {} + for trade_id, trade in tdict._data.items(): + try: + od = trade._data + except AttributeError: + continue + if od.get('order_id') == target_order_id: + fts[trade_id] = trade return fts @@ -837,8 +852,16 @@ def __init__(self, api): @property def orders(self): tdict = _get_obj(self._api._data, ["trade", self._path[1], "orders"]) - fts = {order_id: order for order_id, order in tdict.items() if (not order_id.startswith( - "_")) and order.instrument_id == self.instrument_id and order.exchange_id == self.exchange_id and order.status == "ALIVE"} + inst_id = self._data.get('instrument_id', '') + exch_id = self._data.get('exchange_id', '') + fts = {} + for order_id, order in tdict._data.items(): + try: + od = order._data + except AttributeError: + continue + if od.get('status') == "ALIVE" and od.get('instrument_id') == inst_id and od.get('exchange_id') == exch_id: + fts[order_id] = order return fts @@ -884,8 +907,15 @@ def trade_records(self): :return: dict, 其中每个元素的key为成交ID, value为 :py:class:`~tqsdk.objs.Trade` """ tdict = _get_obj(self._api._data, ["trade", self._path[1], "trades"]) - fts = {trade_id: trade for trade_id, trade in tdict.items() if - (not trade_id.startswith("_")) and trade.order_id == self.order_id} + target_order_id = self._data.get('order_id', '') + fts = {} + for trade_id, trade in tdict._data.items(): + try: + od = trade._data + except AttributeError: + continue + if od.get('order_id') == target_order_id: + fts[trade_id] = trade return fts diff --git a/tqsdk/trading_status.py b/tqsdk/trading_status.py index c5b02cef..45f4da2c 100644 --- a/tqsdk/trading_status.py +++ b/tqsdk/trading_status.py @@ -87,14 +87,14 @@ def _normalize_trade_status(self, diffs): async def _query_symbol_info(self, symbols): """查询缺少合约信息的quotes""" for symbol in symbols: - self._quotes_unready[symbol]["_listener"].add(self._quote_chan) + self._quotes_unready[symbol]._listener.add(self._quote_chan) for query_pack in _query_for_quote(list(symbols), self._api._pre20_ins_info.keys()): await self._md_send_chan.send(query_pack) async def _symbol_info_watcher(self): async for _ in self._quote_chan: for symbol in await self._unready_to_ready(): - self._quotes_ready[symbol]["_listener"].discard(self._quote_chan) + self._quotes_ready[symbol]._listener.discard(self._quote_chan) async def _unready_to_ready(self): ready_delta = {symbol for symbol, quote in self._quotes_unready.items() if not math.isnan(quote.price_tick)} From 4d4aea8ca1b4479c43a1b4c0ec235133e5748729 Mon Sep 17 00:00:00 2001 From: snowboy3 Date: Tue, 31 Mar 2026 20:40:40 +0800 Subject: [PATCH 6/6] revert backtest.py --- tqsdk/backtest/backtest.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/tqsdk/backtest/backtest.py b/tqsdk/backtest/backtest.py index 43f5a10a..89cecd31 100644 --- a/tqsdk/backtest/backtest.py +++ b/tqsdk/backtest/backtest.py @@ -9,7 +9,7 @@ from datetime import date, datetime from typing import Union, Any, List, Dict -from tqsdk.backtest.utils import TqBacktestDividend +from tqsdk.backtest.utils import TqBacktestContinuous, TqBacktestDividend from tqsdk.channel import TqChan from tqsdk.datetime import _get_trading_day_start_time, _get_trading_day_end_time, _get_trading_day_from_timestamp, \ _timestamp_nano_to_str, _convert_user_input_to_nano @@ -87,12 +87,16 @@ def __init__(self, start_dt: Union[date, datetime], end_dt: Union[date, datetime async def _run(self, api, sim_send_chan, sim_recv_chan, md_send_chan, md_recv_chan): """回测task""" self._api = api - start_trading_day = _get_trading_day_from_timestamp(self._start_dt) - end_trading_day = _get_trading_day_from_timestamp(self._end_dt) + # 下载历史主连合约信息 + start_trading_day = _get_trading_day_from_timestamp(self._start_dt) # 回测开始交易日 + end_trading_day = _get_trading_day_from_timestamp(self._end_dt) # 回测结束交易日 + self._continuous_table = TqBacktestContinuous(start_dt=start_trading_day, + end_dt=end_trading_day, + headers=self._api._base_headers) self._stock_dividend = TqBacktestDividend(start_dt=start_trading_day, end_dt=end_trading_day, headers=self._api._base_headers) - self._logger = api._logger.getChild("TqBacktest") + self._logger = api._logger.getChild("TqBacktest") # 调试信息输出 self._sim_send_chan = sim_send_chan self._sim_recv_chan = sim_recv_chan self._md_send_chan = md_send_chan @@ -225,6 +229,8 @@ def _update_valid_quotes(self, quotes): invalid_keys.union({'open', 'close', 'settlement', 'lowest', 'lower_limit', 'upper_limit', 'pre_open_interest', 'pre_settlement', 'pre_close', 'expired'}) for symbol, quote in quotes.items(): [quote.pop(k, None) for k in invalid_keys] + if symbol.startswith("KQ.m"): + quote.pop("underlying_symbol", None) if quote.get('expire_datetime'): # 先删除所有的 quote 的 expired 字段,只在有 expire_datetime 字段时才会添加 expired 字段 quote['expired'] = quote.get('expire_datetime') * 1000000000 <= self._trading_day_start @@ -272,6 +278,11 @@ async def _send_snapshot(self): "option_class": quote.get("option_class", ""), "product_id": quote.get("product_id", ""), } + # 修改历史主连合约信息 + cont_quotes = self._continuous_table._get_history_cont_quotes(self._trading_day) + for k, v in cont_quotes.items(): + quotes.setdefault(k, {}) # 实际上,初始行情截面中只有下市合约,没有主连 + quotes[k].update(v) self._diffs.append({ "quotes": quotes, "ins_list": "", @@ -287,13 +298,16 @@ async def _send_diff(self): # 发送数据集中添加 backtest 字段,开始时间、结束时间、当前时间,表示当前行情推进是由 backtest 推进 self._diffs.append({"_tqsdk_backtest": self._get_backtest_time()}) - # 切换交易日 + # 切换交易日,将历史的主连合约信息添加的 diffs if self._current_dt > self._trading_day_end: # 使用交易日结束时间,每个交易日切换只需要计算一次交易日结束时间 # 相比发送 diffs 前每次都用 _current_dt 计算当前交易日,计算次数更少 self._trading_day = _get_trading_day_from_timestamp(self._current_dt) self._trading_day_start = _get_trading_day_start_time(self._trading_day) self._trading_day_end = _get_trading_day_end_time(self._trading_day) + self._diffs.append({ + "quotes": self._continuous_table._get_history_cont_quotes(self._trading_day) + }) self._diffs.append({ "quotes": self._stock_dividend._get_dividend(self._data.get('quotes'), self._trading_day) })