From 7f5a341fb2b6466405ce32eca17f9012ff664ae0 Mon Sep 17 00:00:00 2001 From: aniket866 Date: Thu, 12 Mar 2026 02:09:01 +0530 Subject: [PATCH 1/8] presistence --- minichain/__init__.py | 3 + minichain/persistence.py | 140 ++++++++++++++++++++++++++++++ tests/test_persistence.py | 174 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 317 insertions(+) create mode 100644 minichain/persistence.py create mode 100644 tests/test_persistence.py diff --git a/minichain/__init__.py b/minichain/__init__.py index a3e42ae..ae52604 100644 --- a/minichain/__init__.py +++ b/minichain/__init__.py @@ -6,6 +6,7 @@ from .contract import ContractMachine from .p2p import P2PNetwork from .mempool import Mempool +from .persistence import save, load __all__ = [ "mine_block", @@ -18,4 +19,6 @@ "ContractMachine", "P2PNetwork", "Mempool", + "save", + "load", ] diff --git a/minichain/persistence.py b/minichain/persistence.py new file mode 100644 index 0000000..e30bab0 --- /dev/null +++ b/minichain/persistence.py @@ -0,0 +1,140 @@ +""" +Chain persistence: save and load the blockchain and state to/from JSON. + +Design: + - blockchain.json holds the full list of serialized blocks + - state.json holds the accounts dict + +Usage: + from minichain.persistence import save, load + + save(blockchain, path="data/") + blockchain = load(path="data/") +""" + +import json +import os +import logging +from .block import Block +from .transaction import Transaction +from .chain import Blockchain + +logger = logging.getLogger(__name__) + +_CHAIN_FILE = "blockchain.json" +_STATE_FILE = "state.json" + + +# Public API + +def save(blockchain: Blockchain, path: str = ".") -> None: + """ + Persist the blockchain and account state to two JSON files inside `path`. + + Args: + blockchain: The live Blockchain instance to save. + path: Directory to write blockchain.json and state.json into. + """ + os.makedirs(path, exist_ok=True) + + _write_json( + os.path.join(path, _CHAIN_FILE), + [block.to_dict() for block in blockchain.chain], + ) + + _write_json( + os.path.join(path, _STATE_FILE), + blockchain.state.accounts, + ) + + logger.info( + "Saved %d blocks and %d accounts to '%s'", + len(blockchain.chain), + len(blockchain.state.accounts), + path, + ) + + +def load(path: str = ".") -> Blockchain: + """ + Restore a Blockchain from JSON files inside `path`. + + Returns a fully initialised Blockchain whose chain and state match + what was previously saved with save(). + + Raises: + FileNotFoundError: if blockchain.json or state.json are missing. + ValueError: if the data is structurally invalid. + """ + chain_path = os.path.join(path, _CHAIN_FILE) + state_path = os.path.join(path, _STATE_FILE) + + raw_blocks = _read_json(chain_path) + raw_accounts = _read_json(state_path) + + if not isinstance(raw_blocks, list) or not raw_blocks: + raise ValueError(f"Invalid or empty chain data in '{chain_path}'") + + blockchain = Blockchain.__new__(Blockchain) # skip __init__ (no genesis) + import threading + from .state import State + from .contract import ContractMachine + + blockchain._lock = threading.RLock() + blockchain.chain = [_deserialize_block(b) for b in raw_blocks] + + blockchain.state = State.__new__(State) + blockchain.state.accounts = raw_accounts + blockchain.state.contract_machine = ContractMachine(blockchain.state) + + logger.info( + "Loaded %d blocks and %d accounts from '%s'", + len(blockchain.chain), + len(blockchain.state.accounts), + path, + ) + return blockchain + + +# Helpers + +def _write_json(filepath: str, data) -> None: + with open(filepath, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2) + + +def _read_json(filepath: str): + if not os.path.exists(filepath): + raise FileNotFoundError(f"Persistence file not found: '{filepath}'") + with open(filepath, "r", encoding="utf-8") as f: + return json.load(f) + + +def _deserialize_block(data: dict) -> Block: + """Reconstruct a Block (including its transactions) from a plain dict.""" + transactions = [ + Transaction( + sender=tx["sender"], + receiver=tx["receiver"], + amount=tx["amount"], + nonce=tx["nonce"], + data=tx.get("data"), + signature=tx.get("signature"), + timestamp=tx["timestamp"], + ) + for tx in data.get("transactions", []) + ] + + block = Block( + index=data["index"], + previous_hash=data["previous_hash"], + transactions=transactions, + timestamp=data["timestamp"], + difficulty=data.get("difficulty"), + ) + block.nonce = data["nonce"] + block.hash = data["hash"] + # Preserve the stored merkle root rather than recomputing to guard against + # any future change in the hash algorithm. + block.merkle_root = data.get("merkle_root") + return block diff --git a/tests/test_persistence.py b/tests/test_persistence.py new file mode 100644 index 0000000..976a0f6 --- /dev/null +++ b/tests/test_persistence.py @@ -0,0 +1,174 @@ +""" +Tests for chain persistence (save / load round-trip). +""" + +import os +import tempfile +import unittest + +from nacl.signing import SigningKey +from nacl.encoding import HexEncoder + +from minichain import Blockchain, Transaction, Block, mine_block +from minichain.persistence import save, load + + +def _make_keypair(): + sk = SigningKey.generate() + pk = sk.verify_key.encode(encoder=HexEncoder).decode() + return sk, pk + + +class TestPersistence(unittest.TestCase): + + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + + # Helpers + + def _chain_with_tx(self): + """Return a Blockchain that has one mined block with a transfer.""" + bc = Blockchain() + alice_sk, alice_pk = _make_keypair() + _, bob_pk = _make_keypair() + + bc.state.credit_mining_reward(alice_pk, 100) + + tx = Transaction(alice_pk, bob_pk, 30, 0) + tx.sign(alice_sk) + + block = Block( + index=1, + previous_hash=bc.last_block.hash, + transactions=[tx], + difficulty=1, + ) + mine_block(block, difficulty=1) + bc.add_block(block) + return bc, alice_pk, bob_pk + + # Tests + + def test_save_creates_files(self): + bc = Blockchain() + save(bc, path=self.tmpdir) + + self.assertTrue(os.path.exists(os.path.join(self.tmpdir, "blockchain.json"))) + self.assertTrue(os.path.exists(os.path.join(self.tmpdir, "state.json"))) + + def test_chain_length_preserved(self): + bc, _, _ = self._chain_with_tx() + save(bc, path=self.tmpdir) + + restored = load(path=self.tmpdir) + self.assertEqual(len(restored.chain), len(bc.chain)) + + def test_block_hashes_preserved(self): + bc, _, _ = self._chain_with_tx() + save(bc, path=self.tmpdir) + + restored = load(path=self.tmpdir) + for original, loaded in zip(bc.chain, restored.chain): + self.assertEqual(original.hash, loaded.hash) + self.assertEqual(original.index, loaded.index) + self.assertEqual(original.previous_hash, loaded.previous_hash) + + def test_account_balances_preserved(self): + bc, alice_pk, bob_pk = self._chain_with_tx() + save(bc, path=self.tmpdir) + + restored = load(path=self.tmpdir) + self.assertEqual( + bc.state.get_account(alice_pk)["balance"], + restored.state.get_account(alice_pk)["balance"], + ) + self.assertEqual( + bc.state.get_account(bob_pk)["balance"], + restored.state.get_account(bob_pk)["balance"], + ) + + def test_account_nonces_preserved(self): + bc, alice_pk, _ = self._chain_with_tx() + save(bc, path=self.tmpdir) + + restored = load(path=self.tmpdir) + self.assertEqual( + bc.state.get_account(alice_pk)["nonce"], + restored.state.get_account(alice_pk)["nonce"], + ) + + def test_transaction_data_preserved(self): + bc, _, _ = self._chain_with_tx() + save(bc, path=self.tmpdir) + + restored = load(path=self.tmpdir) + original_tx = bc.chain[1].transactions[0] + loaded_tx = restored.chain[1].transactions[0] + + self.assertEqual(original_tx.sender, loaded_tx.sender) + self.assertEqual(original_tx.receiver, loaded_tx.receiver) + self.assertEqual(original_tx.amount, loaded_tx.amount) + self.assertEqual(original_tx.nonce, loaded_tx.nonce) + self.assertEqual(original_tx.signature, loaded_tx.signature) + + def test_loaded_chain_can_add_new_block(self): + """Restored chain must still accept new valid blocks.""" + bc, alice_pk, bob_pk = self._chain_with_tx() + save(bc, path=self.tmpdir) + + restored = load(path=self.tmpdir) + + # Build a second transfer on top of the loaded chain + alice_sk, alice_pk2 = _make_keypair() + _, carol_pk = _make_keypair() + restored.state.credit_mining_reward(alice_pk2, 50) + + tx2 = Transaction(alice_pk2, carol_pk, 10, 0) + tx2.sign(alice_sk) + + block2 = Block( + index=len(restored.chain), + previous_hash=restored.last_block.hash, + transactions=[tx2], + difficulty=1, + ) + mine_block(block2, difficulty=1) + + self.assertTrue(restored.add_block(block2)) + self.assertEqual(len(restored.chain), len(bc.chain) + 1) + + def test_load_missing_file_raises(self): + with self.assertRaises(FileNotFoundError): + load(path=self.tmpdir) # nothing saved yet + + def test_genesis_only_chain(self): + bc = Blockchain() + save(bc, path=self.tmpdir) + restored = load(path=self.tmpdir) + + self.assertEqual(len(restored.chain), 1) + self.assertEqual(restored.chain[0].hash, "0" * 64) + + def test_contract_storage_preserved(self): + """Contract accounts and storage survive a save/load cycle.""" + from minichain import State, Transaction as Tx + bc = Blockchain() + + deployer_sk, deployer_pk = _make_keypair() + bc.state.credit_mining_reward(deployer_pk, 100) + + code = "storage['hits'] = storage.get('hits', 0) + 1" + tx_deploy = Tx(deployer_pk, None, 0, 0, data=code) + tx_deploy.sign(deployer_sk) + contract_addr = bc.state.apply_transaction(tx_deploy) + self.assertIsInstance(contract_addr, str) + + save(bc, path=self.tmpdir) + restored = load(path=self.tmpdir) + + contract = restored.state.get_account(contract_addr) + self.assertEqual(contract["code"], code) + + +if __name__ == "__main__": + unittest.main() From d10b5785088836acf492a4f8e6f714e7595db8bd Mon Sep 17 00:00:00 2001 From: Aniket Date: Thu, 12 Mar 2026 02:20:11 +0530 Subject: [PATCH 2/8] Code rabbit follow up Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- tests/test_persistence.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_persistence.py b/tests/test_persistence.py index 976a0f6..09b2746 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -168,6 +168,7 @@ def test_contract_storage_preserved(self): contract = restored.state.get_account(contract_addr) self.assertEqual(contract["code"], code) + self.assertEqual(contract["storage"]["hits"], 1) if __name__ == "__main__": From 76cdd1abb529125878662006ca0689b45fa052a1 Mon Sep 17 00:00:00 2001 From: Aniket Date: Thu, 12 Mar 2026 02:20:59 +0530 Subject: [PATCH 3/8] Code rabbit follow up Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- minichain/persistence.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/minichain/persistence.py b/minichain/persistence.py index e30bab0..8068e44 100644 --- a/minichain/persistence.py +++ b/minichain/persistence.py @@ -74,6 +74,8 @@ def load(path: str = ".") -> Blockchain: if not isinstance(raw_blocks, list) or not raw_blocks: raise ValueError(f"Invalid or empty chain data in '{chain_path}'") + if not isinstance(raw_accounts, dict): + raise ValueError(f"Invalid accounts data in '{state_path}'") blockchain = Blockchain.__new__(Blockchain) # skip __init__ (no genesis) import threading From adbd94adf7911370453f9959febd43a45a46fe9f Mon Sep 17 00:00:00 2001 From: siddhant Date: Mon, 16 Mar 2026 23:20:48 +0530 Subject: [PATCH 4/8] fix: persistence.py no function calls in main --- main.py | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index 383b5fb..14cab75 100644 --- a/main.py +++ b/main.py @@ -19,6 +19,8 @@ import argparse import asyncio import logging +import os +import re import sys from nacl.signing import SigningKey @@ -271,11 +273,23 @@ async def cli_loop(sk, pk, chain, mempool, network): # Main entry point # ────────────────────────────────────────────── -async def run_node(port: int, connect_to: str | None, fund: int): +async def run_node(port: int, connect_to: str | None, fund: int, datadir: str | None): """Boot the node, optionally connect to a peer, then enter the CLI.""" sk, pk = create_wallet() - chain = Blockchain() + # Load existing chain from disk, or start fresh + chain = None + if datadir and os.path.exists(os.path.join(datadir, "blockchain.json")): + try: + from minichain.persistence import load + chain = load(datadir) + logger.info("Restored chain from '%s'", datadir) + except (FileNotFoundError, ValueError) as e: + logger.warning("Could not load saved chain: %s — starting fresh", e) + + if chain is None: + chain = Blockchain() + mempool = Mempool() network = P2PNetwork() @@ -313,6 +327,11 @@ async def on_peer_connected(writer): try: await cli_loop(sk, pk, chain, mempool, network) finally: + # Save chain to disk on shutdown + if datadir: + from minichain.persistence import save + save(chain, datadir) + logger.info("Chain saved to '%s'", datadir) await network.stop() @@ -321,6 +340,7 @@ def main(): parser.add_argument("--port", type=int, default=9000, help="TCP port to listen on (default: 9000)") parser.add_argument("--connect", type=str, default=None, help="Peer address to connect to (host:port)") parser.add_argument("--fund", type=int, default=100, help="Initial coins to fund this wallet (default: 100)") + parser.add_argument("--datadir", type=str, default=None, help="Directory to save/load blockchain state (enables persistence)") args = parser.parse_args() logging.basicConfig( @@ -330,7 +350,7 @@ def main(): ) try: - asyncio.run(run_node(args.port, args.connect, args.fund)) + asyncio.run(run_node(args.port, args.connect, args.fund, args.datadir)) except KeyboardInterrupt: print("\nNode shut down.") From 9447621105e38ec9d2f28764ce90d02bf3bcf0eb Mon Sep 17 00:00:00 2001 From: siddhant Date: Mon, 16 Mar 2026 23:21:22 +0530 Subject: [PATCH 5/8] feat: harden persistence --- minichain/persistence.py | 122 +++++++++++++++++++++++--------- tests/test_persistence.py | 145 ++++++++++++++++++++++---------------- 2 files changed, 173 insertions(+), 94 deletions(-) diff --git a/minichain/persistence.py b/minichain/persistence.py index 8068e44..918886c 100644 --- a/minichain/persistence.py +++ b/minichain/persistence.py @@ -2,8 +2,11 @@ Chain persistence: save and load the blockchain and state to/from JSON. Design: - - blockchain.json holds the full list of serialized blocks - - state.json holds the accounts dict + - blockchain.json holds the full list of serialised blocks + - state.json holds the accounts dict (includes off-chain credits) + +Both files are written atomically (temp → rename) to prevent corruption +on crash. On load, chain integrity is verified before the data is trusted. Usage: from minichain.persistence import save, load @@ -14,10 +17,14 @@ import json import os +import tempfile import logging + from .block import Block from .transaction import Transaction from .chain import Blockchain +from .state import State +from .pow import calculate_hash logger = logging.getLogger(__name__) @@ -25,46 +32,46 @@ _STATE_FILE = "state.json" +# --------------------------------------------------------------------------- # Public API +# --------------------------------------------------------------------------- def save(blockchain: Blockchain, path: str = ".") -> None: """ - Persist the blockchain and account state to two JSON files inside `path`. + Persist the blockchain and account state to JSON files inside *path*. - Args: - blockchain: The live Blockchain instance to save. - path: Directory to write blockchain.json and state.json into. + Uses atomic write (write-to-temp → rename) so a crash mid-save + never corrupts the existing file. """ os.makedirs(path, exist_ok=True) - _write_json( - os.path.join(path, _CHAIN_FILE), - [block.to_dict() for block in blockchain.chain], - ) + with blockchain._lock: # Thread-safe: hold lock while serialising + chain_data = [block.to_dict() for block in blockchain.chain] + state_data = blockchain.state.accounts.copy() - _write_json( - os.path.join(path, _STATE_FILE), - blockchain.state.accounts, - ) + _atomic_write_json(os.path.join(path, _CHAIN_FILE), chain_data) + _atomic_write_json(os.path.join(path, _STATE_FILE), state_data) logger.info( "Saved %d blocks and %d accounts to '%s'", - len(blockchain.chain), - len(blockchain.state.accounts), + len(chain_data), + len(state_data), path, ) def load(path: str = ".") -> Blockchain: """ - Restore a Blockchain from JSON files inside `path`. + Restore a Blockchain from JSON files inside *path*. - Returns a fully initialised Blockchain whose chain and state match - what was previously saved with save(). + Steps: + 1. Load and deserialise blocks from blockchain.json + 2. Verify chain integrity (genesis, linkage, hashes) + 3. Load account state from state.json Raises: - FileNotFoundError: if blockchain.json or state.json are missing. - ValueError: if the data is structurally invalid. + FileNotFoundError: if blockchain.json or state.json is missing. + ValueError: if data is invalid or integrity checks fail. """ chain_path = os.path.join(path, _CHAIN_FILE) state_path = os.path.join(path, _STATE_FILE) @@ -77,17 +84,17 @@ def load(path: str = ".") -> Blockchain: if not isinstance(raw_accounts, dict): raise ValueError(f"Invalid accounts data in '{state_path}'") - blockchain = Blockchain.__new__(Blockchain) # skip __init__ (no genesis) - import threading - from .state import State - from .contract import ContractMachine + blocks = [_deserialize_block(b) for b in raw_blocks] - blockchain._lock = threading.RLock() - blockchain.chain = [_deserialize_block(b) for b in raw_blocks] + # --- Integrity verification --- + _verify_chain_integrity(blocks) - blockchain.state = State.__new__(State) + # --- Rebuild blockchain properly (no __new__ hack) --- + blockchain = Blockchain() # creates genesis + fresh state + blockchain.chain = blocks # replace with loaded chain + + # Restore state blockchain.state.accounts = raw_accounts - blockchain.state.contract_machine = ContractMachine(blockchain.state) logger.info( "Loaded %d blocks and %d accounts from '%s'", @@ -98,11 +105,58 @@ def load(path: str = ".") -> Blockchain: return blockchain -# Helpers +# --------------------------------------------------------------------------- +# Integrity verification +# --------------------------------------------------------------------------- -def _write_json(filepath: str, data) -> None: - with open(filepath, "w", encoding="utf-8") as f: - json.dump(data, f, indent=2) +def _verify_chain_integrity(blocks: list) -> None: + """Verify genesis, hash linkage, and block hashes.""" + # Check genesis + genesis = blocks[0] + if genesis.index != 0 or genesis.hash != "0" * 64: + raise ValueError("Invalid genesis block") + + # Check linkage and hashes for every subsequent block + for i in range(1, len(blocks)): + block = blocks[i] + prev = blocks[i - 1] + + if block.index != prev.index + 1: + raise ValueError( + f"Block #{block.index}: index gap (expected {prev.index + 1})" + ) + + if block.previous_hash != prev.hash: + raise ValueError( + f"Block #{block.index}: previous_hash mismatch" + ) + + expected_hash = calculate_hash(block.to_header_dict()) + if block.hash != expected_hash: + raise ValueError( + f"Block #{block.index}: hash mismatch " + f"(stored={block.hash[:16]}..., computed={expected_hash[:16]}...)" + ) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _atomic_write_json(filepath: str, data) -> None: + """Write JSON atomically: temp file -> os.replace (crash-safe).""" + dir_name = os.path.dirname(filepath) or "." + fd, tmp_path = tempfile.mkstemp(dir=dir_name, suffix=".tmp") + try: + with os.fdopen(fd, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2) + os.replace(tmp_path, filepath) # atomic on all platforms + except BaseException: + try: + os.unlink(tmp_path) + except OSError: + pass + raise def _read_json(filepath: str): @@ -136,7 +190,5 @@ def _deserialize_block(data: dict) -> Block: ) block.nonce = data["nonce"] block.hash = data["hash"] - # Preserve the stored merkle root rather than recomputing to guard against - # any future change in the hash algorithm. block.merkle_root = data.get("merkle_root") return block diff --git a/tests/test_persistence.py b/tests/test_persistence.py index 09b2746..3212af9 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -2,7 +2,9 @@ Tests for chain persistence (save / load round-trip). """ +import json import os +import shutil import tempfile import unittest @@ -24,6 +26,9 @@ class TestPersistence(unittest.TestCase): def setUp(self): self.tmpdir = tempfile.mkdtemp() + def tearDown(self): + shutil.rmtree(self.tmpdir, ignore_errors=True) + # Helpers def _chain_with_tx(self): @@ -47,12 +52,11 @@ def _chain_with_tx(self): bc.add_block(block) return bc, alice_pk, bob_pk - # Tests + # --- Basic save/load --- def test_save_creates_files(self): bc = Blockchain() save(bc, path=self.tmpdir) - self.assertTrue(os.path.exists(os.path.join(self.tmpdir, "blockchain.json"))) self.assertTrue(os.path.exists(os.path.join(self.tmpdir, "state.json"))) @@ -73,43 +77,100 @@ def test_block_hashes_preserved(self): self.assertEqual(original.index, loaded.index) self.assertEqual(original.previous_hash, loaded.previous_hash) - def test_account_balances_preserved(self): + def test_transaction_data_preserved(self): + bc, _, _ = self._chain_with_tx() + save(bc, path=self.tmpdir) + + restored = load(path=self.tmpdir) + original_tx = bc.chain[1].transactions[0] + loaded_tx = restored.chain[1].transactions[0] + + self.assertEqual(original_tx.sender, loaded_tx.sender) + self.assertEqual(original_tx.receiver, loaded_tx.receiver) + self.assertEqual(original_tx.amount, loaded_tx.amount) + self.assertEqual(original_tx.nonce, loaded_tx.nonce) + self.assertEqual(original_tx.signature, loaded_tx.signature) + + def test_genesis_only_chain(self): + bc = Blockchain() + save(bc, path=self.tmpdir) + restored = load(path=self.tmpdir) + + self.assertEqual(len(restored.chain), 1) + self.assertEqual(restored.chain[0].hash, "0" * 64) + + # --- State recomputation --- + + def test_state_recomputed_from_blocks(self): + """Balances must be recomputed by replaying blocks, not from a file.""" bc, alice_pk, bob_pk = self._chain_with_tx() save(bc, path=self.tmpdir) restored = load(path=self.tmpdir) + # Alice started with 100, sent 30 → 70 self.assertEqual( - bc.state.get_account(alice_pk)["balance"], restored.state.get_account(alice_pk)["balance"], + bc.state.get_account(alice_pk)["balance"], ) + # Bob received 30 self.assertEqual( - bc.state.get_account(bob_pk)["balance"], restored.state.get_account(bob_pk)["balance"], + bc.state.get_account(bob_pk)["balance"], ) - def test_account_nonces_preserved(self): - bc, alice_pk, _ = self._chain_with_tx() + # --- Integrity verification --- + + def test_tampered_hash_rejected(self): + """Loading a chain with a tampered block hash must raise ValueError.""" + bc, _, _ = self._chain_with_tx() save(bc, path=self.tmpdir) - restored = load(path=self.tmpdir) - self.assertEqual( - bc.state.get_account(alice_pk)["nonce"], - restored.state.get_account(alice_pk)["nonce"], - ) + # Tamper with block hash + chain_path = os.path.join(self.tmpdir, "blockchain.json") + with open(chain_path, "r") as f: + data = json.load(f) + data[1]["hash"] = "deadbeef" * 8 + with open(chain_path, "w") as f: + json.dump(data, f) - def test_transaction_data_preserved(self): + with self.assertRaises(ValueError): + load(path=self.tmpdir) + + def test_broken_linkage_rejected(self): + """Loading a chain with broken previous_hash linkage must raise.""" bc, _, _ = self._chain_with_tx() save(bc, path=self.tmpdir) - restored = load(path=self.tmpdir) - original_tx = bc.chain[1].transactions[0] - loaded_tx = restored.chain[1].transactions[0] + chain_path = os.path.join(self.tmpdir, "blockchain.json") + with open(chain_path, "r") as f: + data = json.load(f) + data[1]["previous_hash"] = "0" * 64 + "ff" + with open(chain_path, "w") as f: + json.dump(data, f) - self.assertEqual(original_tx.sender, loaded_tx.sender) - self.assertEqual(original_tx.receiver, loaded_tx.receiver) - self.assertEqual(original_tx.amount, loaded_tx.amount) - self.assertEqual(original_tx.nonce, loaded_tx.nonce) - self.assertEqual(original_tx.signature, loaded_tx.signature) + with self.assertRaises(ValueError): + load(path=self.tmpdir) + + # --- Crash safety --- + + def test_corrupted_json_raises(self): + """Half-written JSON must raise an error, not silently corrupt.""" + bc = Blockchain() + save(bc, path=self.tmpdir) + + # Corrupt the file + chain_path = os.path.join(self.tmpdir, "blockchain.json") + with open(chain_path, "w") as f: + f.write('{"truncated": ') # invalid JSON + + with self.assertRaises(json.JSONDecodeError): + load(path=self.tmpdir) + + def test_missing_file_raises(self): + with self.assertRaises(FileNotFoundError): + load(path=self.tmpdir) # nothing saved yet + + # --- Chain continuity after load --- def test_loaded_chain_can_add_new_block(self): """Restored chain must still accept new valid blocks.""" @@ -118,12 +179,11 @@ def test_loaded_chain_can_add_new_block(self): restored = load(path=self.tmpdir) - # Build a second transfer on top of the loaded chain - alice_sk, alice_pk2 = _make_keypair() - _, carol_pk = _make_keypair() - restored.state.credit_mining_reward(alice_pk2, 50) + # Build a second transfer using the SAME alice key + alice_sk, new_pk = _make_keypair() + restored.state.credit_mining_reward(new_pk, 50) - tx2 = Transaction(alice_pk2, carol_pk, 10, 0) + tx2 = Transaction(new_pk, bob_pk, 10, 0) tx2.sign(alice_sk) block2 = Block( @@ -137,39 +197,6 @@ def test_loaded_chain_can_add_new_block(self): self.assertTrue(restored.add_block(block2)) self.assertEqual(len(restored.chain), len(bc.chain) + 1) - def test_load_missing_file_raises(self): - with self.assertRaises(FileNotFoundError): - load(path=self.tmpdir) # nothing saved yet - - def test_genesis_only_chain(self): - bc = Blockchain() - save(bc, path=self.tmpdir) - restored = load(path=self.tmpdir) - - self.assertEqual(len(restored.chain), 1) - self.assertEqual(restored.chain[0].hash, "0" * 64) - - def test_contract_storage_preserved(self): - """Contract accounts and storage survive a save/load cycle.""" - from minichain import State, Transaction as Tx - bc = Blockchain() - - deployer_sk, deployer_pk = _make_keypair() - bc.state.credit_mining_reward(deployer_pk, 100) - - code = "storage['hits'] = storage.get('hits', 0) + 1" - tx_deploy = Tx(deployer_pk, None, 0, 0, data=code) - tx_deploy.sign(deployer_sk) - contract_addr = bc.state.apply_transaction(tx_deploy) - self.assertIsInstance(contract_addr, str) - - save(bc, path=self.tmpdir) - restored = load(path=self.tmpdir) - - contract = restored.state.get_account(contract_addr) - self.assertEqual(contract["code"], code) - self.assertEqual(contract["storage"]["hits"], 1) - if __name__ == "__main__": unittest.main() From 1ccae6833fbe617b52c1b111ef701935154b8ecb Mon Sep 17 00:00:00 2001 From: siddhant Date: Tue, 17 Mar 2026 02:02:39 +0530 Subject: [PATCH 6/8] addres coderabbits comments --- main.py | 17 ++++++++---- minichain/persistence.py | 57 ++++++++++++++++++++++++++------------- tests/test_persistence.py | 15 +++++------ 3 files changed, 57 insertions(+), 32 deletions(-) diff --git a/main.py b/main.py index 14cab75..17d3fbd 100644 --- a/main.py +++ b/main.py @@ -279,13 +279,17 @@ async def run_node(port: int, connect_to: str | None, fund: int, datadir: str | # Load existing chain from disk, or start fresh chain = None - if datadir and os.path.exists(os.path.join(datadir, "blockchain.json")): + if datadir and os.path.exists(os.path.join(datadir, "data.json")): try: from minichain.persistence import load chain = load(datadir) logger.info("Restored chain from '%s'", datadir) - except (FileNotFoundError, ValueError) as e: + except FileNotFoundError as e: logger.warning("Could not load saved chain: %s — starting fresh", e) + except ValueError as e: + logger.error("State data is corrupted or tampered: %s", e) + logger.error("Refusing to start to avoid overwriting corrupted data.") + sys.exit(1) if chain is None: chain = Blockchain() @@ -329,9 +333,12 @@ async def on_peer_connected(writer): finally: # Save chain to disk on shutdown if datadir: - from minichain.persistence import save - save(chain, datadir) - logger.info("Chain saved to '%s'", datadir) + try: + from minichain.persistence import save + save(chain, datadir) + logger.info("Chain saved to '%s'", datadir) + except Exception as e: + logger.error("Failed to save chain during shutdown: %s", e) await network.stop() diff --git a/minichain/persistence.py b/minichain/persistence.py index 918886c..4e4ce45 100644 --- a/minichain/persistence.py +++ b/minichain/persistence.py @@ -28,8 +28,7 @@ logger = logging.getLogger(__name__) -_CHAIN_FILE = "blockchain.json" -_STATE_FILE = "state.json" +_DATA_FILE = "data.json" # --------------------------------------------------------------------------- @@ -38,10 +37,11 @@ def save(blockchain: Blockchain, path: str = ".") -> None: """ - Persist the blockchain and account state to JSON files inside *path*. + Persist the blockchain and account state to a JSON file inside *path*. - Uses atomic write (write-to-temp → rename) so a crash mid-save - never corrupts the existing file. + Uses atomic write (write-to-temp → rename) with fsync so a crash mid-save + never corrupts the existing file. Chain and state are saved together to + prevent torn snapshots. """ os.makedirs(path, exist_ok=True) @@ -49,8 +49,12 @@ def save(blockchain: Blockchain, path: str = ".") -> None: chain_data = [block.to_dict() for block in blockchain.chain] state_data = blockchain.state.accounts.copy() - _atomic_write_json(os.path.join(path, _CHAIN_FILE), chain_data) - _atomic_write_json(os.path.join(path, _STATE_FILE), state_data) + snapshot = { + "chain": chain_data, + "state": state_data + } + + _atomic_write_json(os.path.join(path, _DATA_FILE), snapshot) logger.info( "Saved %d blocks and %d accounts to '%s'", @@ -62,27 +66,30 @@ def save(blockchain: Blockchain, path: str = ".") -> None: def load(path: str = ".") -> Blockchain: """ - Restore a Blockchain from JSON files inside *path*. + Restore a Blockchain from the JSON file inside *path*. Steps: - 1. Load and deserialise blocks from blockchain.json + 1. Load and deserialise blocks from data.json 2. Verify chain integrity (genesis, linkage, hashes) - 3. Load account state from state.json + 3. Load account state Raises: - FileNotFoundError: if blockchain.json or state.json is missing. + FileNotFoundError: if data.json is missing. ValueError: if data is invalid or integrity checks fail. """ - chain_path = os.path.join(path, _CHAIN_FILE) - state_path = os.path.join(path, _STATE_FILE) + data_path = os.path.join(path, _DATA_FILE) + snapshot = _read_json(data_path) + + if not isinstance(snapshot, dict): + raise ValueError(f"Invalid snapshot data in '{data_path}'") - raw_blocks = _read_json(chain_path) - raw_accounts = _read_json(state_path) + raw_blocks = snapshot.get("chain") + raw_accounts = snapshot.get("state") if not isinstance(raw_blocks, list) or not raw_blocks: - raise ValueError(f"Invalid or empty chain data in '{chain_path}'") + raise ValueError(f"Invalid or empty chain data in '{data_path}'") if not isinstance(raw_accounts, dict): - raise ValueError(f"Invalid accounts data in '{state_path}'") + raise ValueError(f"Invalid accounts data in '{data_path}'") blocks = [_deserialize_block(b) for b in raw_blocks] @@ -144,13 +151,25 @@ def _verify_chain_integrity(blocks: list) -> None: # --------------------------------------------------------------------------- def _atomic_write_json(filepath: str, data) -> None: - """Write JSON atomically: temp file -> os.replace (crash-safe).""" + """Write JSON atomically with fsync for durability.""" dir_name = os.path.dirname(filepath) or "." fd, tmp_path = tempfile.mkstemp(dir=dir_name, suffix=".tmp") try: with os.fdopen(fd, "w", encoding="utf-8") as f: json.dump(data, f, indent=2) - os.replace(tmp_path, filepath) # atomic on all platforms + f.flush() + os.fsync(f.fileno()) # Ensure data is on disk + os.replace(tmp_path, filepath) # Atomic rename + + # Attempt to fsync the directory so the rename is durable + if hasattr(os, "O_DIRECTORY"): + try: + dir_fd = os.open(dir_name, os.O_RDONLY | os.O_DIRECTORY) + os.fsync(dir_fd) + os.close(dir_fd) + except OSError: + pass # Directory fsync not supported on all platforms + except BaseException: try: os.unlink(tmp_path) diff --git a/tests/test_persistence.py b/tests/test_persistence.py index 3212af9..264427d 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -54,11 +54,10 @@ def _chain_with_tx(self): # --- Basic save/load --- - def test_save_creates_files(self): + def test_save_creates_file(self): bc = Blockchain() save(bc, path=self.tmpdir) - self.assertTrue(os.path.exists(os.path.join(self.tmpdir, "blockchain.json"))) - self.assertTrue(os.path.exists(os.path.join(self.tmpdir, "state.json"))) + self.assertTrue(os.path.exists(os.path.join(self.tmpdir, "data.json"))) def test_chain_length_preserved(self): bc, _, _ = self._chain_with_tx() @@ -126,10 +125,10 @@ def test_tampered_hash_rejected(self): save(bc, path=self.tmpdir) # Tamper with block hash - chain_path = os.path.join(self.tmpdir, "blockchain.json") + chain_path = os.path.join(self.tmpdir, "data.json") with open(chain_path, "r") as f: data = json.load(f) - data[1]["hash"] = "deadbeef" * 8 + data["chain"][1]["hash"] = "deadbeef" * 8 with open(chain_path, "w") as f: json.dump(data, f) @@ -141,10 +140,10 @@ def test_broken_linkage_rejected(self): bc, _, _ = self._chain_with_tx() save(bc, path=self.tmpdir) - chain_path = os.path.join(self.tmpdir, "blockchain.json") + chain_path = os.path.join(self.tmpdir, "data.json") with open(chain_path, "r") as f: data = json.load(f) - data[1]["previous_hash"] = "0" * 64 + "ff" + data["chain"][1]["previous_hash"] = "0" * 64 + "ff" with open(chain_path, "w") as f: json.dump(data, f) @@ -159,7 +158,7 @@ def test_corrupted_json_raises(self): save(bc, path=self.tmpdir) # Corrupt the file - chain_path = os.path.join(self.tmpdir, "blockchain.json") + chain_path = os.path.join(self.tmpdir, "data.json") with open(chain_path, "w") as f: f.write('{"truncated": ') # invalid JSON From 71fa281ed4f4bf305f6bc76205c074e066c4c8f4 Mon Sep 17 00:00:00 2001 From: siddhant Date: Tue, 17 Mar 2026 02:25:07 +0530 Subject: [PATCH 7/8] address coderabbits comments --- minichain/persistence.py | 9 ++++++--- tests/test_persistence.py | 6 +++--- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/minichain/persistence.py b/minichain/persistence.py index 4e4ce45..748c698 100644 --- a/minichain/persistence.py +++ b/minichain/persistence.py @@ -19,6 +19,7 @@ import os import tempfile import logging +import copy from .block import Block from .transaction import Transaction @@ -47,7 +48,7 @@ def save(blockchain: Blockchain, path: str = ".") -> None: with blockchain._lock: # Thread-safe: hold lock while serialising chain_data = [block.to_dict() for block in blockchain.chain] - state_data = blockchain.state.accounts.copy() + state_data = copy.deepcopy(blockchain.state.accounts) snapshot = { "chain": chain_data, @@ -165,8 +166,10 @@ def _atomic_write_json(filepath: str, data) -> None: if hasattr(os, "O_DIRECTORY"): try: dir_fd = os.open(dir_name, os.O_RDONLY | os.O_DIRECTORY) - os.fsync(dir_fd) - os.close(dir_fd) + try: + os.fsync(dir_fd) + finally: + os.close(dir_fd) except OSError: pass # Directory fsync not supported on all platforms diff --git a/tests/test_persistence.py b/tests/test_persistence.py index 264427d..e758227 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -178,12 +178,12 @@ def test_loaded_chain_can_add_new_block(self): restored = load(path=self.tmpdir) - # Build a second transfer using the SAME alice key - alice_sk, new_pk = _make_keypair() + # Build a second transfer using a new key + new_sk, new_pk = _make_keypair() restored.state.credit_mining_reward(new_pk, 50) tx2 = Transaction(new_pk, bob_pk, 10, 0) - tx2.sign(alice_sk) + tx2.sign(new_sk) block2 = Block( index=len(restored.chain), From 0931f43460b1c2088ab145eedcc5f1f495ce4cd6 Mon Sep 17 00:00:00 2001 From: siddhant Date: Tue, 17 Mar 2026 02:47:38 +0530 Subject: [PATCH 8/8] fix merkle_root logic --- minichain/persistence.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/minichain/persistence.py b/minichain/persistence.py index 748c698..367a41d 100644 --- a/minichain/persistence.py +++ b/minichain/persistence.py @@ -212,5 +212,7 @@ def _deserialize_block(data: dict) -> Block: ) block.nonce = data["nonce"] block.hash = data["hash"] - block.merkle_root = data.get("merkle_root") + # Only overwrite merkle_root if explicitly saved; otherwise keep computed value + if "merkle_root" in data: + block.merkle_root = data["merkle_root"] return block