diff --git a/main.py b/main.py
index 383b5fb..17d3fbd 100644
--- a/main.py
+++ b/main.py
@@ -19,6 +19,7 @@
 import argparse
 import asyncio
 import logging
+import os
 import sys
 
 from nacl.signing import SigningKey
@@ -271,11 +272,27 @@ async def cli_loop(sk, pk, chain, mempool, network):
 # Main entry point
 # ──────────────────────────────────────────────
 
-async def run_node(port: int, connect_to: str | None, fund: int):
+async def run_node(port: int, connect_to: str | None, fund: int, datadir: str | None):
     """Boot the node, optionally connect to a peer, then enter the CLI."""
     sk, pk = create_wallet()
 
-    chain = Blockchain()
+    # Load existing chain from disk, or start fresh
+    chain = None
+    if datadir and os.path.exists(os.path.join(datadir, "data.json")):
+        try:
+            from minichain.persistence import load
+            chain = load(datadir)
+            logger.info("Restored chain from '%s'", datadir)
+        except FileNotFoundError as e:
+            logger.warning("Could not load saved chain: %s — starting fresh", e)
+        except ValueError as e:
+            # Covers json.JSONDecodeError (a ValueError subclass) and all
+            # integrity/deserialisation failures raised by load().
+            logger.error("State data is corrupted or tampered: %s", e)
+            logger.error("Refusing to start to avoid overwriting corrupted data.")
+            sys.exit(1)
+
+    if chain is None:
+        chain = Blockchain()
 
     mempool = Mempool()
     network = P2PNetwork()
@@ -313,6 +330,14 @@ async def on_peer_connected(writer):
     try:
         await cli_loop(sk, pk, chain, mempool, network)
     finally:
+        # Save chain to disk on shutdown
+        if datadir:
+            try:
+                from minichain.persistence import save
+                save(chain, datadir)
+                logger.info("Chain saved to '%s'", datadir)
+            except Exception as e:
+                logger.error("Failed to save chain during shutdown: %s", e)
         await network.stop()
 
 
@@ -321,6 +346,7 @@ def main():
     parser.add_argument("--port", type=int, default=9000, help="TCP port to listen on (default: 9000)")
     parser.add_argument("--connect", type=str, default=None, help="Peer address to connect to (host:port)")
     parser.add_argument("--fund", type=int, default=100, help="Initial coins to fund this wallet (default: 100)")
+    parser.add_argument("--datadir", type=str, default=None, help="Directory to save/load blockchain state (enables persistence)")
     args = parser.parse_args()
 
     logging.basicConfig(
@@ -330,7 +356,7 @@ def main():
     )
 
     try:
-        asyncio.run(run_node(args.port, args.connect, args.fund))
+        asyncio.run(run_node(args.port, args.connect, args.fund, args.datadir))
     except KeyboardInterrupt:
         print("\nNode shut down.")
diff --git a/minichain/__init__.py b/minichain/__init__.py
index a3e42ae..ae52604 100644
--- a/minichain/__init__.py
+++ b/minichain/__init__.py
@@ -6,6 +6,7 @@
 from .contract import ContractMachine
 from .p2p import P2PNetwork
 from .mempool import Mempool
+from .persistence import save, load
 
 __all__ = [
     "mine_block",
@@ -18,4 +19,6 @@
     "ContractMachine",
     "P2PNetwork",
     "Mempool",
+    "save",
+    "load",
 ]
diff --git a/minichain/persistence.py b/minichain/persistence.py
new file mode 100644
index 0000000..367a41d
--- /dev/null
+++ b/minichain/persistence.py
@@ -0,0 +1,217 @@
+"""
+Chain persistence: save and load the blockchain and state to/from JSON.
+
+Design:
+    - a single ``data.json`` holds both the serialised block list
+      (key ``"chain"``) and the accounts dict (key ``"state"``, which
+      includes off-chain credits), so a snapshot can never be torn
+      across two files
+
+The file is written atomically (temp → rename) to prevent corruption
+on crash.  On load, chain integrity is verified before the data is trusted.
+
+Usage:
+    from minichain.persistence import save, load
+
+    save(blockchain, path="data/")
+    blockchain = load(path="data/")
+"""
+
+import json
+import os
+import tempfile
+import logging
+import copy
+
+from .block import Block
+from .transaction import Transaction
+from .chain import Blockchain
+from .pow import calculate_hash
+
+logger = logging.getLogger(__name__)
+
+_DATA_FILE = "data.json"
+
+
+# ---------------------------------------------------------------------------
+# Public API
+# ---------------------------------------------------------------------------
+
+def save(blockchain: Blockchain, path: str = ".") -> None:
+    """
+    Persist the blockchain and account state to a JSON file inside *path*.
+
+    Uses atomic write (write-to-temp → rename) with fsync so a crash mid-save
+    never corrupts the existing file.  Chain and state are saved together to
+    prevent torn snapshots.
+    """
+    os.makedirs(path, exist_ok=True)
+
+    with blockchain._lock:  # Thread-safe: hold lock while serialising
+        chain_data = [block.to_dict() for block in blockchain.chain]
+        state_data = copy.deepcopy(blockchain.state.accounts)
+
+    snapshot = {
+        "chain": chain_data,
+        "state": state_data,
+    }
+
+    _atomic_write_json(os.path.join(path, _DATA_FILE), snapshot)
+
+    logger.info(
+        "Saved %d blocks and %d accounts to '%s'",
+        len(chain_data),
+        len(state_data),
+        path,
+    )
+
+
+def load(path: str = ".") -> Blockchain:
+    """
+    Restore a Blockchain from the JSON file inside *path*.
+
+    Steps:
+        1. Load and deserialise blocks from data.json
+        2. Verify chain integrity (genesis, linkage, hashes)
+        3. Restore account state from the snapshot
+
+    Raises:
+        FileNotFoundError: if data.json is missing.
+        ValueError: if data is invalid, malformed, or integrity checks fail
+            (json.JSONDecodeError, a ValueError subclass, covers truncated
+            files).
+    """
+    data_path = os.path.join(path, _DATA_FILE)
+    snapshot = _read_json(data_path)
+
+    if not isinstance(snapshot, dict):
+        raise ValueError(f"Invalid snapshot data in '{data_path}'")
+
+    raw_blocks = snapshot.get("chain")
+    raw_accounts = snapshot.get("state")
+
+    if not isinstance(raw_blocks, list) or not raw_blocks:
+        raise ValueError(f"Invalid or empty chain data in '{data_path}'")
+    if not isinstance(raw_accounts, dict):
+        raise ValueError(f"Invalid accounts data in '{data_path}'")
+
+    try:
+        blocks = [_deserialize_block(b) for b in raw_blocks]
+    except (KeyError, TypeError) as e:
+        # Missing keys / wrong shapes are corruption, not a programming
+        # error — surface them as ValueError so callers treat the file as
+        # tampered instead of crashing with an unrelated exception.
+        raise ValueError(f"Malformed block data in '{data_path}': {e!r}") from e
+
+    # --- Integrity verification ---
+    _verify_chain_integrity(blocks)
+
+    # --- Rebuild blockchain properly (no __new__ hack) ---
+    blockchain = Blockchain()   # creates genesis + fresh state
+    blockchain.chain = blocks   # replace with loaded chain
+
+    # Restore state
+    blockchain.state.accounts = raw_accounts
+
+    logger.info(
+        "Loaded %d blocks and %d accounts from '%s'",
+        len(blockchain.chain),
+        len(blockchain.state.accounts),
+        path,
+    )
+    return blockchain
+
+
+# ---------------------------------------------------------------------------
+# Integrity verification
+# ---------------------------------------------------------------------------
+
+def _verify_chain_integrity(blocks: list) -> None:
+    """Verify genesis, hash linkage, and block hashes."""
+    # Check genesis
+    genesis = blocks[0]
+    if genesis.index != 0 or genesis.hash != "0" * 64:
+        raise ValueError("Invalid genesis block")
+
+    # Check linkage and hashes for every subsequent block
+    for i in range(1, len(blocks)):
+        block = blocks[i]
+        prev = blocks[i - 1]
+
+        if block.index != prev.index + 1:
+            raise ValueError(
+                f"Block #{block.index}: index gap (expected {prev.index + 1})"
+            )
+
+        if block.previous_hash != prev.hash:
+            raise ValueError(
+                f"Block #{block.index}: previous_hash mismatch"
+            )
+
+        expected_hash = calculate_hash(block.to_header_dict())
+        if block.hash != expected_hash:
+            raise ValueError(
+                f"Block #{block.index}: hash mismatch "
+                f"(stored={block.hash[:16]}..., computed={expected_hash[:16]}...)"
+            )
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+def _atomic_write_json(filepath: str, data) -> None:
+    """Write JSON atomically with fsync for durability."""
+    dir_name = os.path.dirname(filepath) or "."
+    fd, tmp_path = tempfile.mkstemp(dir=dir_name, suffix=".tmp")
+    try:
+        with os.fdopen(fd, "w", encoding="utf-8") as f:
+            json.dump(data, f, indent=2)
+            f.flush()
+            os.fsync(f.fileno())  # Ensure data is on disk
+        os.replace(tmp_path, filepath)  # Atomic rename
+
+        # Attempt to fsync the directory so the rename is durable
+        if hasattr(os, "O_DIRECTORY"):
+            try:
+                dir_fd = os.open(dir_name, os.O_RDONLY | os.O_DIRECTORY)
+                try:
+                    os.fsync(dir_fd)
+                finally:
+                    os.close(dir_fd)
+            except OSError:
+                pass  # Directory fsync not supported on all platforms
+
+    except BaseException:
+        try:
+            os.unlink(tmp_path)
+        except OSError:
+            pass
+        raise
+
+
+def _read_json(filepath: str):
+    if not os.path.exists(filepath):
+        raise FileNotFoundError(f"Persistence file not found: '{filepath}'")
+    with open(filepath, "r", encoding="utf-8") as f:
+        return json.load(f)
+
+
+def _deserialize_block(data: dict) -> Block:
+    """Reconstruct a Block (including its transactions) from a plain dict."""
+    transactions = [
+        Transaction(
+            sender=tx["sender"],
+            receiver=tx["receiver"],
+            amount=tx["amount"],
+            nonce=tx["nonce"],
+            data=tx.get("data"),
+            signature=tx.get("signature"),
+            timestamp=tx["timestamp"],
+        )
+        for tx in data.get("transactions", [])
+    ]
+
+    block = Block(
+        index=data["index"],
+        previous_hash=data["previous_hash"],
+        transactions=transactions,
+        timestamp=data["timestamp"],
+        difficulty=data.get("difficulty"),
+    )
+    block.nonce = data["nonce"]
+    block.hash = data["hash"]
+    # Only overwrite merkle_root if explicitly saved; otherwise keep computed value
+    if "merkle_root" in data:
+        block.merkle_root = data["merkle_root"]
+    return block
diff --git a/tests/test_persistence.py b/tests/test_persistence.py
new file mode 100644
index 0000000..e758227
--- /dev/null
+++ b/tests/test_persistence.py
@@ -0,0 +1,215 @@
+"""
+Tests for chain persistence (save / load round-trip).
+"""
+
+import json
+import os
+import shutil
+import tempfile
+import unittest
+
+from nacl.signing import SigningKey
+from nacl.encoding import HexEncoder
+
+from minichain import Blockchain, Transaction, Block, mine_block
+from minichain.persistence import save, load
+
+
+def _make_keypair():
+    sk = SigningKey.generate()
+    pk = sk.verify_key.encode(encoder=HexEncoder).decode()
+    return sk, pk
+
+
+class TestPersistence(unittest.TestCase):
+
+    def setUp(self):
+        self.tmpdir = tempfile.mkdtemp()
+
+    def tearDown(self):
+        shutil.rmtree(self.tmpdir, ignore_errors=True)
+
+    # Helpers
+
+    def _chain_with_tx(self):
+        """Return a Blockchain that has one mined block with a transfer."""
+        bc = Blockchain()
+        alice_sk, alice_pk = _make_keypair()
+        _, bob_pk = _make_keypair()
+
+        bc.state.credit_mining_reward(alice_pk, 100)
+
+        tx = Transaction(alice_pk, bob_pk, 30, 0)
+        tx.sign(alice_sk)
+
+        block = Block(
+            index=1,
+            previous_hash=bc.last_block.hash,
+            transactions=[tx],
+            difficulty=1,
+        )
+        mine_block(block, difficulty=1)
+        bc.add_block(block)
+        return bc, alice_pk, bob_pk
+
+    # --- Basic save/load ---
+
+    def test_save_creates_file(self):
+        bc = Blockchain()
+        save(bc, path=self.tmpdir)
+        self.assertTrue(os.path.exists(os.path.join(self.tmpdir, "data.json")))
+
+    def test_chain_length_preserved(self):
+        bc, _, _ = self._chain_with_tx()
+        save(bc, path=self.tmpdir)
+
+        restored = load(path=self.tmpdir)
+        self.assertEqual(len(restored.chain), len(bc.chain))
+
+    def test_block_hashes_preserved(self):
+        bc, _, _ = self._chain_with_tx()
+        save(bc, path=self.tmpdir)
+
+        restored = load(path=self.tmpdir)
+        for original, loaded in zip(bc.chain, restored.chain):
+            self.assertEqual(original.hash, loaded.hash)
+            self.assertEqual(original.index, loaded.index)
+            self.assertEqual(original.previous_hash, loaded.previous_hash)
+
+    def test_transaction_data_preserved(self):
+        bc, _, _ = self._chain_with_tx()
+        save(bc, path=self.tmpdir)
+
+        restored = load(path=self.tmpdir)
+        original_tx = bc.chain[1].transactions[0]
+        loaded_tx = restored.chain[1].transactions[0]
+
+        self.assertEqual(original_tx.sender, loaded_tx.sender)
+        self.assertEqual(original_tx.receiver, loaded_tx.receiver)
+        self.assertEqual(original_tx.amount, loaded_tx.amount)
+        self.assertEqual(original_tx.nonce, loaded_tx.nonce)
+        self.assertEqual(original_tx.signature, loaded_tx.signature)
+
+    def test_genesis_only_chain(self):
+        bc = Blockchain()
+        save(bc, path=self.tmpdir)
+        restored = load(path=self.tmpdir)
+
+        self.assertEqual(len(restored.chain), 1)
+        self.assertEqual(restored.chain[0].hash, "0" * 64)
+
+    # --- State restoration ---
+
+    def test_state_restored_matches_original(self):
+        """Balances restored from the snapshot must match the original chain."""
+        bc, alice_pk, bob_pk = self._chain_with_tx()
+        save(bc, path=self.tmpdir)
+
+        restored = load(path=self.tmpdir)
+        # Alice started with 100, sent 30 → 70
+        self.assertEqual(
+            restored.state.get_account(alice_pk)["balance"],
+            bc.state.get_account(alice_pk)["balance"],
+        )
+        # Bob received 30
+        self.assertEqual(
+            restored.state.get_account(bob_pk)["balance"],
+            bc.state.get_account(bob_pk)["balance"],
+        )
+
+    # --- Integrity verification ---
+
+    def test_tampered_hash_rejected(self):
+        """Loading a chain with a tampered block hash must raise ValueError."""
+        bc, _, _ = self._chain_with_tx()
+        save(bc, path=self.tmpdir)
+
+        # Tamper with block hash
+        chain_path = os.path.join(self.tmpdir, "data.json")
+        with open(chain_path, "r") as f:
+            data = json.load(f)
+        data["chain"][1]["hash"] = "deadbeef" * 8
+        with open(chain_path, "w") as f:
+            json.dump(data, f)
+
+        with self.assertRaises(ValueError):
+            load(path=self.tmpdir)
+
+    def test_broken_linkage_rejected(self):
+        """Loading a chain with broken previous_hash linkage must raise."""
+        bc, _, _ = self._chain_with_tx()
+        save(bc, path=self.tmpdir)
+
+        chain_path = os.path.join(self.tmpdir, "data.json")
+        with open(chain_path, "r") as f:
+            data = json.load(f)
+        data["chain"][1]["previous_hash"] = "0" * 64 + "ff"
+        with open(chain_path, "w") as f:
+            json.dump(data, f)
+
+        with self.assertRaises(ValueError):
+            load(path=self.tmpdir)
+
+    def test_malformed_block_dict_rejected(self):
+        """A block dict missing required keys must raise ValueError, not KeyError."""
+        bc = Blockchain()
+        save(bc, path=self.tmpdir)
+
+        chain_path = os.path.join(self.tmpdir, "data.json")
+        with open(chain_path, "r") as f:
+            data = json.load(f)
+        del data["chain"][0]["hash"]
+        with open(chain_path, "w") as f:
+            json.dump(data, f)
+
+        with self.assertRaises(ValueError):
+            load(path=self.tmpdir)
+
+    # --- Crash safety ---
+
+    def test_corrupted_json_raises(self):
+        """Half-written JSON must raise an error, not silently corrupt."""
+        bc = Blockchain()
+        save(bc, path=self.tmpdir)
+
+        # Corrupt the file
+        chain_path = os.path.join(self.tmpdir, "data.json")
+        with open(chain_path, "w") as f:
+            f.write('{"truncated": ')  # invalid JSON
+
+        with self.assertRaises(json.JSONDecodeError):
+            load(path=self.tmpdir)
+
+    def test_missing_file_raises(self):
+        with self.assertRaises(FileNotFoundError):
+            load(path=self.tmpdir)  # nothing saved yet
+
+    # --- Chain continuity after load ---
+
+    def test_loaded_chain_can_add_new_block(self):
+        """Restored chain must still accept new valid blocks."""
+        bc, alice_pk, bob_pk = self._chain_with_tx()
+        save(bc, path=self.tmpdir)
+
+        restored = load(path=self.tmpdir)
+
+        # Build a second transfer using a new key
+        new_sk, new_pk = _make_keypair()
+        restored.state.credit_mining_reward(new_pk, 50)
+
+        tx2 = Transaction(new_pk, bob_pk, 10, 0)
+        tx2.sign(new_sk)
+
+        block2 = Block(
+            index=len(restored.chain),
+            previous_hash=restored.last_block.hash,
+            transactions=[tx2],
+            difficulty=1,
+        )
+        mine_block(block2, difficulty=1)
+
+        self.assertTrue(restored.add_block(block2))
+        self.assertEqual(len(restored.chain), len(bc.chain) + 1)
+
+
+if __name__ == "__main__":
+    unittest.main()