Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import argparse
import asyncio
import logging
import os
import re
import sys

from nacl.signing import SigningKey
Expand Down Expand Up @@ -271,11 +273,27 @@ async def cli_loop(sk, pk, chain, mempool, network):
# Main entry point
# ──────────────────────────────────────────────

async def run_node(port: int, connect_to: str | None, fund: int):
async def run_node(port: int, connect_to: str | None, fund: int, datadir: str | None):
"""Boot the node, optionally connect to a peer, then enter the CLI."""
sk, pk = create_wallet()

chain = Blockchain()
# Load existing chain from disk, or start fresh
chain = None
if datadir and os.path.exists(os.path.join(datadir, "data.json")):
try:
from minichain.persistence import load
chain = load(datadir)
logger.info("Restored chain from '%s'", datadir)
except FileNotFoundError as e:
logger.warning("Could not load saved chain: %s — starting fresh", e)
except ValueError as e:
logger.error("State data is corrupted or tampered: %s", e)
logger.error("Refusing to start to avoid overwriting corrupted data.")
sys.exit(1)

if chain is None:
chain = Blockchain()

mempool = Mempool()
network = P2PNetwork()

Expand Down Expand Up @@ -313,6 +331,14 @@ async def on_peer_connected(writer):
try:
await cli_loop(sk, pk, chain, mempool, network)
finally:
# Save chain to disk on shutdown
if datadir:
try:
from minichain.persistence import save
save(chain, datadir)
logger.info("Chain saved to '%s'", datadir)
except Exception as e:
logger.error("Failed to save chain during shutdown: %s", e)
await network.stop()


Expand All @@ -321,6 +347,7 @@ def main():
parser.add_argument("--port", type=int, default=9000, help="TCP port to listen on (default: 9000)")
parser.add_argument("--connect", type=str, default=None, help="Peer address to connect to (host:port)")
parser.add_argument("--fund", type=int, default=100, help="Initial coins to fund this wallet (default: 100)")
parser.add_argument("--datadir", type=str, default=None, help="Directory to save/load blockchain state (enables persistence)")
args = parser.parse_args()

logging.basicConfig(
Expand All @@ -330,7 +357,7 @@ def main():
)

try:
asyncio.run(run_node(args.port, args.connect, args.fund))
asyncio.run(run_node(args.port, args.connect, args.fund, args.datadir))
except KeyboardInterrupt:
print("\nNode shut down.")

Expand Down
3 changes: 3 additions & 0 deletions minichain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .contract import ContractMachine
from .p2p import P2PNetwork
from .mempool import Mempool
from .persistence import save, load

__all__ = [
"mine_block",
Expand All @@ -18,4 +19,6 @@
"ContractMachine",
"P2PNetwork",
"Mempool",
"save",
"load",
]
218 changes: 218 additions & 0 deletions minichain/persistence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
"""
Chain persistence: save and load the blockchain and state to/from JSON.
Design:
  - data.json  holds both the full list of serialised blocks and the
               accounts dict (includes off-chain credits) in one snapshot
Comment on lines +4 to +6
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Docstring describes two-file design but implementation uses single file.

The docstring mentions blockchain.json and state.json as separate files, but the implementation uses a single _DATA_FILE = "data.json" containing both chain and state. This inconsistency could confuse future maintainers.

📝 Proposed fix to update docstring
 Design:
-  - blockchain.json  holds the full list of serialised blocks
-  - state.json       holds the accounts dict (includes off-chain credits)
+  - data.json  holds both the serialised blocks and account state in a single
+               atomic snapshot to prevent torn writes on crash
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@minichain/persistence.py` around lines 4 - 6, The module docstring is out of
sync with the implementation: it describes separate files `blockchain.json` and
`state.json` but the code uses a single `_DATA_FILE = "data.json"`. Update the
top-level docstring in minichain/persistence.py to accurately describe the
single-file design (that `_DATA_FILE` stores both the serialized chain and the
accounts/state) and mention any format/keys used, referencing the `_DATA_FILE`
constant and the persistence class or functions (e.g., Persistence or load/save
methods) so maintainers can find where the file is read/written.

The file is written atomically (temp → rename) to prevent corruption
on crash. On load, chain integrity is verified before the data is trusted.
Usage:
from minichain.persistence import save, load
save(blockchain, path="data/")
blockchain = load(path="data/")
"""

import json
import os
import tempfile
import logging
import copy

from .block import Block
from .transaction import Transaction
from .chain import Blockchain
from .state import State
from .pow import calculate_hash

logger = logging.getLogger(__name__)

_DATA_FILE = "data.json"


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------

def save(blockchain: Blockchain, path: str = ".") -> None:
    """
    Write the blockchain and its account state to a single JSON file in *path*.

    The chain and the accounts dict are captured together under the
    blockchain's own lock, so the snapshot can never mix blocks and state
    from two different moments. The write itself is atomic (temp file +
    rename; see _atomic_write_json), so a crash mid-save never corrupts an
    existing file.
    """
    os.makedirs(path, exist_ok=True)

    # Serialise under the chain's lock to avoid capturing a half-updated view.
    with blockchain._lock:
        serialized_blocks = [blk.to_dict() for blk in blockchain.chain]
        accounts_snapshot = copy.deepcopy(blockchain.state.accounts)

    _atomic_write_json(
        os.path.join(path, _DATA_FILE),
        {"chain": serialized_blocks, "state": accounts_snapshot},
    )

    logger.info(
        "Saved %d blocks and %d accounts to '%s'",
        len(serialized_blocks),
        len(accounts_snapshot),
        path,
    )


def load(path: str = ".") -> Blockchain:
    """
    Restore a Blockchain from the JSON file inside *path*.

    Steps:
      1. Load and deserialise blocks from data.json
      2. Verify chain integrity (genesis, linkage, hashes)
      3. Load account state

    Raises:
        FileNotFoundError: if data.json is missing.
        ValueError: if data is invalid or integrity checks fail.
    """
    data_path = os.path.join(path, _DATA_FILE)
    snapshot = _read_json(data_path)

    if not isinstance(snapshot, dict):
        raise ValueError(f"Invalid snapshot data in '{data_path}'")

    raw_blocks = snapshot.get("chain")
    raw_accounts = snapshot.get("state")

    if not isinstance(raw_blocks, list) or not raw_blocks:
        raise ValueError(f"Invalid or empty chain data in '{data_path}'")
    if not isinstance(raw_accounts, dict):
        raise ValueError(f"Invalid accounts data in '{data_path}'")

    # A block dict with missing or mistyped keys would otherwise escape as a
    # bare KeyError/TypeError, breaking the documented "raises ValueError"
    # contract that callers (e.g. run_node's corruption handling) rely on.
    try:
        blocks = [_deserialize_block(b) for b in raw_blocks]
    except (KeyError, TypeError) as exc:
        raise ValueError(
            f"Malformed block data in '{data_path}': {exc}"
        ) from exc

    # --- Integrity verification ---
    _verify_chain_integrity(blocks)

    # --- Rebuild blockchain properly (no __new__ hack) ---
    blockchain = Blockchain()  # creates genesis + fresh state
    blockchain.chain = blocks  # replace with loaded chain

    # Restore state
    blockchain.state.accounts = raw_accounts

    logger.info(
        "Loaded %d blocks and %d accounts from '%s'",
        len(blockchain.chain),
        len(blockchain.state.accounts),
        path,
    )
    return blockchain


# ---------------------------------------------------------------------------
# Integrity verification
# ---------------------------------------------------------------------------

def _verify_chain_integrity(blocks: list) -> None:
"""Verify genesis, hash linkage, and block hashes."""
# Check genesis
genesis = blocks[0]
if genesis.index != 0 or genesis.hash != "0" * 64:
raise ValueError("Invalid genesis block")

# Check linkage and hashes for every subsequent block
for i in range(1, len(blocks)):
block = blocks[i]
prev = blocks[i - 1]

if block.index != prev.index + 1:
raise ValueError(
f"Block #{block.index}: index gap (expected {prev.index + 1})"
)

if block.previous_hash != prev.hash:
raise ValueError(
f"Block #{block.index}: previous_hash mismatch"
)

expected_hash = calculate_hash(block.to_header_dict())
if block.hash != expected_hash:
raise ValueError(
f"Block #{block.index}: hash mismatch "
f"(stored={block.hash[:16]}..., computed={expected_hash[:16]}...)"
)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _atomic_write_json(filepath: str, data) -> None:
"""Write JSON atomically with fsync for durability."""
dir_name = os.path.dirname(filepath) or "."
fd, tmp_path = tempfile.mkstemp(dir=dir_name, suffix=".tmp")
try:
with os.fdopen(fd, "w", encoding="utf-8") as f:
json.dump(data, f, indent=2)
f.flush()
os.fsync(f.fileno()) # Ensure data is on disk
os.replace(tmp_path, filepath) # Atomic rename

# Attempt to fsync the directory so the rename is durable
if hasattr(os, "O_DIRECTORY"):
try:
dir_fd = os.open(dir_name, os.O_RDONLY | os.O_DIRECTORY)
try:
os.fsync(dir_fd)
finally:
os.close(dir_fd)
except OSError:
pass # Directory fsync not supported on all platforms

except BaseException:
try:
os.unlink(tmp_path)
except OSError:
pass
raise


def _read_json(filepath: str):
if not os.path.exists(filepath):
raise FileNotFoundError(f"Persistence file not found: '{filepath}'")
with open(filepath, "r", encoding="utf-8") as f:
return json.load(f)


def _deserialize_block(data: dict) -> Block:
    """Rebuild a Block, including its transactions, from a plain dict."""
    txs = []
    for raw_tx in data.get("transactions", []):
        txs.append(
            Transaction(
                sender=raw_tx["sender"],
                receiver=raw_tx["receiver"],
                amount=raw_tx["amount"],
                nonce=raw_tx["nonce"],
                data=raw_tx.get("data"),
                signature=raw_tx.get("signature"),
                timestamp=raw_tx["timestamp"],
            )
        )

    restored = Block(
        index=data["index"],
        previous_hash=data["previous_hash"],
        transactions=txs,
        timestamp=data["timestamp"],
        difficulty=data.get("difficulty"),
    )
    restored.nonce = data["nonce"]
    restored.hash = data["hash"]
    # The constructor recomputes merkle_root; only override it when the
    # snapshot explicitly recorded one.
    if "merkle_root" in data:
        restored.merkle_root = data["merkle_root"]
    return restored
Loading