diff --git a/.claude/settings.local.json b/.claude/settings.local.json index a57520a..f5cefa4 100644 --- a/.claude/settings.local.json +++ b/.claude/settings.local.json @@ -29,7 +29,9 @@ "Bash(find:*)", "Skill(frontend-design)", "Bash(FLASK_DEBUG=1 python:*)", - "Bash(docker compose:*)" + "Bash(docker compose:*)", + "Bash(wc:*)", + "Bash(grep:*)" ], "deny": [], "ask": [] diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..5a22eae --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,61 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## What this is + +WriteBot turns digital text into realistic handwritten SVG/PDF using a trained TensorFlow RNN. There are two layers: + +- `handwriting_synthesis/` — the ML engine: model loading, RNN stroke generation, and layout → SVG. No Flask dependency; usable standalone. +- `webapp/` — a multi-user Flask app (auth, batch jobs, character overrides, admin, usage stats) wrapping the engine. + +## Commands + +Python 3.10 in CI (local `.venv` is 3.11). `pip install -r requirements.txt` pulls TensorFlow + CUDA and is large. + +Run the app: +- Dev: `python webapp/app.py` — Flask dev server on port 5000, **single-threaded by design** (see model note below). +- Full stack: `docker compose up` — gunicorn (`webapp.app:app`), a Celery worker, Celery beat, Redis, and nginx. The Docker entrypoint runs `python webapp/init_db.py --auto` first to create/seed the DB. + +Tests (model-free, fast — these are the only automated tests): +- All: `python tests/test_operations.py` or `pytest tests/test_operations.py` +- One: `pytest tests/test_operations.py::test_balanced_breaks_at_punctuation` +- They cover `handwriting_synthesis/hand/operations/chunking.py`. Anything that touches `Hand` needs the model checkpoint and is slow, so it is not in the suite. + +Lint (as CI runs it): `flake8 webapp --select=E9,F63,F7,F82` (hard gate on syntax/undefined names) and `pylint $(git ls-files '*.py')`. + +DB migrations (Flask-Migrate/Alembic, `migrations/`): `FLASK_APP=webapp.app:app flask db migrate -m "..."` then `flask db upgrade`. + +CLI batch generation: `python scripts/batch_generate.py` (see `test_batch.csv` for the expected column layout). + +## The model — read before changing any generation code + +- `Hand()` (`handwriting_synthesis/hand/Hand.py`) builds a `tf.compat.v1` graph and restores a checkpoint (`model/checkpoint/model-17900.*`) plus style-priming arrays (`model/style/style-N-*.npy`). Construction is expensive and the TF session is effectively a **process-global singleton**: route modules instantiate `hand = Hand()` at import time, and `webapp/tasks.py` caches a single lazy `_hand_instance`. This is why the dev server is single-threaded and the Celery worker runs `-P solo --concurrency 1`. Do not assume concurrent calls are safe. +- The model only understands a fixed, restricted ASCII alphabet (`handwriting_synthesis/drawing/operations.py: alphabet`). Input must be normalized to it via `webapp/utils/text_utils.py: normalize_text_for_model` before generation. Limits: `MAX_CHAR_LEN=120` per line, `MAX_STROKE_LEN=2400`. Retraining is out of scope — improvements happen at the chunking / stitching / layout level, not the weights. + +## Generation pipeline (the part that requires reading several files) + +1. `webapp/routes/{generation,batch,job}_routes.py` receive a request with 30+ parameters. +2. `webapp/utils/generation_utils.py`: `parse_generation_params()` normalizes them; `generate_handwriting_to_file()` dispatches. +3. Text is normalized to the model alphabet. In **non-chunked** mode it is also wrapped to the page width by `text_processor.py` (`TextProcessor`, via `text_utils.wrap_by_canvas`). +4. `Hand.write()` (one RNN sample per line) or `Hand.write_chunked()` (the default) runs. Chunked mode splits each line into small chunks, samples each, and stitches them into lines using **measured** stroke widths — better line filling and shorter RNN sequences. The stages map to `operations/`: `chunking.py` (text → chunks), `sampling.py` (RNN inference), `stroke_ops.py` (stitch / baseline / rotation). +5. `handwriting_synthesis/hand/_draw.py: _draw()` does all page layout: unit conversion (`PX_PER_MM = 96/25.4`), auto-sizing strokes to fit the content box, alignment, line-height, baseline/margin jitter, and SVG emission via `svgwrite`. + +Chunked vs non-chunked is the `use_chunked` flag (default true). Generation defaults (chunking strategy, words/chars per chunk, page, style/bias) live in `config.json` under `defaults`. + +## Character overrides ("character insert") + +Users upload custom SVG glyphs for specific characters that get injected into otherwise model-generated handwriting. Persisted as `CharacterOverride` / `CharacterOverrideCollection` (`webapp/models.py`); helpers in `handwriting_synthesis/hand/character_override_utils.py`; rendered in `_draw.py`'s override path. The current approach generates the line with placeholder spaces, then uses the model's attention `char_indices` to cut the strokes precisely and shift them to open a gap for the inserted glyph (`_render_strokes_with_overrides`, `override_positions`). Override characters are exempt from alphabet validation. This subsystem is intricate and changes often — trace the `char_indices` / `override_positions` flow end-to-end before modifying it. Uploaded SVG is untrusted input. + +## Web app layout + +- Entry: `webapp/app.py` (`webapp.app:app`); Flask extensions in `webapp/extensions.py`; SQLAlchemy models in `webapp/models.py` (User, CharacterOverride(Collection), BatchJob, PageSize/TemplatePreset, Usage/Activity). +- Routes are split by concern under `webapp/routes/`; reusable logic lives in `webapp/utils/` (`generation_utils`, `text_utils`, `page_utils`, `secure_urls`, `auth_utils`). +- Async/batch work goes through Celery (`webapp/celery_app.py`, `webapp/tasks.py`, Redis broker) and the `BatchJob` model. +- Runtime config comes from env (`DATABASE_URL`, `REDIS_URL`, `SECRET_KEY`, Sentry, mail) — see `.env.example`. + +## Gotchas + +- Page geometry is computed in pixels internally but user-facing values may be mm or px (`units`). `PX_PER_MM` and the paper-size table are defined in **both** `_draw.py` and `webapp/utils/page_utils.py` — keep them consistent. +- `legibility` (`high` | `normal` | `natural`) sets jitter/interpolation defaults in `_draw.py`; `high` disables all randomness, which is what you want for deterministic output or tests. +- The RNN/TF code (`handwriting_synthesis/{rnn,tf}/`) uses graph-mode `tf.compat.v1` and legacy Keras (`tf-keras`, `TF_USE_LEGACY_KERAS=1`); it is not idiomatic TF2. diff --git a/deploy/db-migrate.sh b/deploy/db-migrate.sh index e16cd0f..dbbdae1 100644 --- a/deploy/db-migrate.sh +++ b/deploy/db-migrate.sh @@ -52,7 +52,7 @@ fi case $COMMAND in upgrade) log_info "Applying pending migrations..." - docker exec ${CONTAINER} flask db upgrade + docker exec -e FLASK_APP=webapp.app:app ${CONTAINER} flask db upgrade log_info "Migrations applied successfully" ;; @@ -60,7 +60,7 @@ case $COMMAND in log_warn "This will revert the last migration!" read -p "Are you sure? (y/N) " confirm if [ "$confirm" = "y" ] || [ "$confirm" = "Y" ]; then - docker exec ${CONTAINER} flask db downgrade + docker exec -e FLASK_APP=webapp.app:app ${CONTAINER} flask db downgrade log_info "Migration reverted" else log_info "Cancelled" @@ -69,36 +69,52 @@ case $COMMAND in current) log_info "Current migration version:" - docker exec ${CONTAINER} flask db current + docker exec -e FLASK_APP=webapp.app:app ${CONTAINER} flask db current ;; history) log_info "Migration history:" - docker exec ${CONTAINER} flask db history + docker exec -e FLASK_APP=webapp.app:app ${CONTAINER} flask db history ;; migrate) MESSAGE="${2:-Auto-generated migration}" log_info "Generating new migration: ${MESSAGE}" - docker exec ${CONTAINER} flask db migrate -m "${MESSAGE}" + docker exec -e FLASK_APP=webapp.app:app ${CONTAINER} flask db migrate -m "${MESSAGE}" log_warn "Review the generated migration before applying!" log_info "Apply with: $0 upgrade" ;; heads) log_info "Current head revisions:" - docker exec ${CONTAINER} flask db heads + docker exec -e FLASK_APP=webapp.app:app ${CONTAINER} flask db heads ;; init) log_info "Initializing migrations directory..." - docker exec ${CONTAINER} flask db init + docker exec -e FLASK_APP=webapp.app:app ${CONTAINER} flask db init ;; stamp) REVISION="${2:-head}" log_info "Stamping database with revision: ${REVISION}" - docker exec ${CONTAINER} flask db stamp ${REVISION} + docker exec -e FLASK_APP=webapp.app:app ${CONTAINER} flask db stamp ${REVISION} + ;; + + check) + log_info "Checking Flask-Migrate installation..." + echo "" + echo "1. Checking if Flask-Migrate is installed:" + docker exec ${CONTAINER} pip show flask-migrate 2>/dev/null || log_error "Flask-Migrate NOT installed!" + echo "" + echo "2. Checking available Flask commands:" + docker exec -e FLASK_APP=webapp.app:app ${CONTAINER} flask --help 2>&1 | grep -E "(db|Commands)" || true + echo "" + echo "3. Checking FLASK_APP environment:" + docker exec ${CONTAINER} printenv | grep FLASK || log_warn "FLASK_APP not in container env" + echo "" + echo "4. Checking migrations directory:" + docker exec ${CONTAINER} ls -la /app/migrations 2>/dev/null || log_error "Migrations directory not found!" ;; *) @@ -113,6 +129,7 @@ case $COMMAND in echo " heads Show current head revisions" echo " init Initialize migrations (first time only)" echo " stamp Mark database at specific revision" + echo " check Verify Flask-Migrate installation and setup" echo "" echo "Options:" echo " --production Use production container" diff --git a/handwriting_synthesis/hand/Hand.py b/handwriting_synthesis/hand/Hand.py index 9f24134..b74f395 100644 --- a/handwriting_synthesis/hand/Hand.py +++ b/handwriting_synthesis/hand/Hand.py @@ -19,6 +19,7 @@ get_stroke_width, stitch_strokes, split_text_into_chunks, + balanced_line_breaks, sample_strokes, ) @@ -76,6 +77,7 @@ def write( empty_line_spacing=None, auto_size=True, manual_size_scale=1.0, + writing_size_mm=None, character_override_collection_id=None, margin_jitter_frac=None, margin_jitter_coherence=None, @@ -176,70 +178,65 @@ def _normalize_seq(value, desired_len, cast_fn=None, name='param'): stroke_colors = _normalize_seq(stroke_colors, num_lines, str, 'stroke_colors') stroke_widths = _normalize_seq(stroke_widths, num_lines, float, 'stroke_widths') - # Split lines with character overrides + # Handle character overrides using SPACE PLACEHOLDER approach + # Key insight: Generate full lines with SPACES where overrides go. + # The space creates a natural gap in the stroke sequence (pen lift). + # We then insert the override SVG into that gap - no stroke clipping needed! + # This preserves full RNN context for the surrounding text. if overrides_dict: - print(f"DEBUG: Processing text with overrides enabled") - from handwriting_synthesis.hand.character_override_utils import split_text_with_overrides + print(f"DEBUG: Processing text with SPACE-PLACEHOLDER override approach") - # Create expanded line data with override info - line_segments = [] - texts_to_generate = [] - segment_to_line_idx = [] + # Use SPACE as placeholder - creates natural gap in strokes + placeholder_char = ' ' - for line_idx, line in enumerate(lines): - print(f"DEBUG: Processing line {line_idx}: '{line}'") - chunks = split_text_with_overrides(line, overrides_dict) - print(f"DEBUG: Split into {len(chunks)} chunks: {chunks}") - line_segment_list = [] + # Track override positions: {line_idx: [(char_idx, original_char), ...]} + override_positions = {} + modified_lines = [] - for chunk_text, is_override in chunks: - print(f"DEBUG: Chunk: '{chunk_text}', is_override={is_override}") - if is_override: - line_segment_list.append({ - 'type': 'override', - 'text': chunk_text, - 'line_idx': line_idx - }) + for line_idx, line in enumerate(lines): + override_positions[line_idx] = [] + modified_line_chars = [] + + for char_idx, char in enumerate(line): + if char in overrides_dict: + # Track the position and original character + override_positions[line_idx].append((char_idx, char)) + # Replace with SPACE - creates natural gap for override insertion + modified_line_chars.append(placeholder_char) + print(f"DEBUG: Line {line_idx}, char {char_idx}: replacing '{char}' with SPACE placeholder") else: - if chunk_text.strip(): # Only generate non-empty chunks - gen_idx = len(texts_to_generate) - texts_to_generate.append(chunk_text) - segment_to_line_idx.append(line_idx) - line_segment_list.append({ - 'type': 'generated', - 'gen_idx': gen_idx, - 'text': chunk_text, - 'line_idx': line_idx - }) - else: - # Empty space, generate it - gen_idx = len(texts_to_generate) - texts_to_generate.append(chunk_text) - segment_to_line_idx.append(line_idx) - line_segment_list.append({ - 'type': 'generated', - 'gen_idx': gen_idx, - 'text': chunk_text, - 'line_idx': line_idx - }) + modified_line_chars.append(char) - line_segments.append(line_segment_list) + modified_lines.append(''.join(modified_line_chars)) - print(f"DEBUG: Texts to generate: {texts_to_generate}") + print(f"DEBUG: Original lines: {lines}") + print(f"DEBUG: Modified lines (with placeholders): {modified_lines}") + print(f"DEBUG: Override positions: {override_positions}") - # Generate strokes for non-override chunks - if texts_to_generate: - gen_biases = [biases[idx] if biases else None for idx in segment_to_line_idx] - gen_styles = [styles[idx] if styles else None for idx in segment_to_line_idx] - generated_strokes = self._sample(texts_to_generate, biases=gen_biases, styles=gen_styles) - else: - generated_strokes = [] + # Generate strokes for FULL lines with CHAR INDICES from attention + # This gives us precise knowledge of where each character was written! + generated_strokes, char_indices_list = self._sample( + modified_lines, biases=biases, styles=styles, return_char_indices=True + ) - # Map generated strokes back to segments - for line_segment_list in line_segments: - for segment in line_segment_list: - if segment['type'] == 'generated': - segment['strokes'] = generated_strokes[segment['gen_idx']] + print(f"DEBUG: Got char_indices for {len(char_indices_list)} lines") + for i, ci in enumerate(char_indices_list): + print(f"DEBUG: Line {i}: {len(ci)} char indices, range [{ci.min() if len(ci) > 0 else 'N/A'}, {ci.max() if len(ci) > 0 else 'N/A'}]") + + # Convert to line_segments format (single segment per line, like non-override) + line_segments = [] + for line_idx, (original_line, strokes, char_indices) in enumerate( + zip(lines, generated_strokes, char_indices_list) + ): + line_segments.append([{ + 'type': 'generated', + 'text': original_line, # Keep original text for reference + 'modified_text': modified_lines[line_idx], # Text that was actually generated + 'strokes': strokes, + 'char_indices': char_indices, # NEW: Character index per stroke from attention + 'line_idx': line_idx, + 'override_positions': override_positions[line_idx] # [(char_idx, char), ...] + }]) else: # No overrides, use normal generation print(f"DEBUG: No overrides, using normal generation") @@ -279,13 +276,14 @@ def _normalize_seq(value, desired_len, cast_fn=None, name='param'): empty_line_spacing=empty_line_spacing, auto_size=auto_size, manual_size_scale=manual_size_scale, + writing_size_mm=writing_size_mm, character_override_collection_id=character_override_collection_id, overrides_dict=overrides_dict, margin_jitter_frac=margin_jitter_frac, margin_jitter_coherence=margin_jitter_coherence, ) - def _sample(self, lines, biases=None, styles=None): + def _sample(self, lines, biases=None, styles=None, return_char_indices=False): """ Sample stroke sequences from the RNN. @@ -293,11 +291,239 @@ def _sample(self, lines, biases=None, styles=None): lines: List of text lines biases: Optional biases styles: Optional styles + return_char_indices: If True, also return character indices per stroke + (from the attention mechanism) Returns: - List of stroke sequences + If return_char_indices is False: + List of stroke sequences + If return_char_indices is True: + Tuple of (strokes_list, char_indices_list) """ - return sample_strokes(self.nn.session, self.nn, lines, biases, styles) + return sample_strokes( + self.nn.session, self.nn, lines, biases, styles, + return_char_indices=return_char_indices + ) + + @staticmethod + def _measure_model_xheight(chunk_strokes): + """Median x-height of sampled chunks in raw model units. + + Measured EXACTLY the way ``_draw`` measures it for sizing: deslant via + ``drawing.align``, then take the robust body band. Measuring post-align + (rather than on the raw slanted strokes) is what makes wrap-time + predictions match the rendered scale regardless of handwriting style. + Returns ``None`` when no chunk is measurable. + """ + from handwriting_synthesis.hand._draw import _estimate_xheight + xheights = [] + for stroke in chunk_strokes: + if stroke is None or len(stroke) < 8: + continue + coords = drawing.offsets_to_coords(stroke) + coords[:, :2] = drawing.align(coords[:, :2]) + xheights.append(_estimate_xheight(coords)) + return float(np.median(xheights)) if xheights else None + + @staticmethod + def _estimate_stitched_xheight(sampled_lines, chunk_spacing, rotate_chunks, group_size=4): + """Measure stitched-line statistics: (x-height, width inflation factor). + + ``_draw`` derives its render scale from the x-height of the final + STITCHED lines, which reads systematically taller than individual chunks + (baseline joins and residual drift widen the percentile band). Predicting + wrap widths from per-chunk measurements therefore left every rendered + line ~15-20% short of the right margin. Stitching chunks in groups of + ``group_size`` (about one final line's worth) with the same stitch + parameters reproduces the statistics ``_draw`` will actually see. + + The width factor is how much wider a stitched line measures than the sum + of its chunk widths plus spacing (adaptive gaps and rotation correction + widen it, by a style-dependent amount). The wrap budget must be deflated + by it, otherwise the widest lines overrun the page and the global width + clamp trades the writing size away to fit them. + + ``sampled_lines`` is the per-input-line list of ``None`` or + ``(chunks, strokes)`` built by ``write_chunked``; groups never span + input lines. Returns ``(None, 1.0)`` when nothing is measurable. + """ + from handwriting_synthesis.hand._draw import _estimate_xheight + bands = [] + width_factors = [] + for entry in sampled_lines: + if not entry: + continue + strokes = [s for s in entry[1] if s is not None and len(s) >= 8] + for i in range(0, len(strokes), group_size): + group = strokes[i:i + group_size] + predicted_w = (sum(get_stroke_width(s) for s in group) + + chunk_spacing * (len(group) - 1)) + stitched = group[0] + for nxt in group[1:]: + stitched = stitch_strokes(stitched, nxt, chunk_spacing, + rotate_to_match=rotate_chunks) + coords = drawing.offsets_to_coords(stitched) + coords[:, :2] = drawing.align(coords[:, :2]) + bands.append(_estimate_xheight(coords)) + actual_w = float(coords[:, 0].max() - coords[:, 0].min()) + if len(group) > 1 and predicted_w > 1e-6: + width_factors.append(actual_w / predicted_w) + if not bands: + return None, 1.0 + factor = float(np.median(width_factors)) if width_factors else 1.0 + return float(np.median(bands)), min(max(factor, 1.0), 1.3) + + @staticmethod + def _content_box_px(page_size, units, margins, orientation): + """Resolve the page's content box (width_px, height_px) inside margins.""" + from handwriting_synthesis.hand._draw import _resolve_page_size, _normalize_margins + width_px, height_px, _ = _resolve_page_size(page_size, units, 1, 60.0) + if orientation == 'landscape': + width_px, height_px = height_px, width_px + m_top, m_right, m_bottom, m_left = _normalize_margins(margins, units) + return (max(1.0, width_px - (m_left + m_right)), + max(1.0, height_px - (m_top + m_bottom))) + + # Average rendered character advance as a fraction of the x-height. Used only + # to pre-estimate line capacity for chunk granularity, before any sampling. + _CHAR_ADVANCE_PER_XHEIGHT = 0.55 + + def _adaptive_chunk_chars( + self, input_lines, target_chars_per_chunk, page_size, units, margins, + orientation, x_stretch, writing_size_mm, auto_size, + ): + """Scale chunk granularity to the writing size, before any sampling. + + At larger writing sizes a line holds fewer characters, so fixed ~25-char + chunks quantize badly: a line that holds 60 chars wraps after 2 chunks and + leaves up to a third of the width empty. This pre-estimates the page-fill + x-height from text statistics alone (the per-style stroke scale cancels + out of the solver when widths are expressed per character) and targets + ~3.5 chunks per line, clamped to [12, target_chars_per_chunk]. + """ + try: + from handwriting_synthesis.hand._draw import ( + solve_fill_xheight_px, PX_PER_MM, + NATURAL_WRITING_MIN_FILL_MM, NATURAL_WRITING_MAX_FILL_MM, + ) + if not auto_size: + return target_chars_per_chunk + n_chars = sum(len(ln.strip()) for ln in input_lines if ln.strip()) + n_blank = sum(1 for ln in input_lines if not ln.strip()) + if n_chars == 0: + return target_chars_per_chunk + content_w, content_h = self._content_box_px(page_size, units, margins, orientation) + xs = float(x_stretch) if x_stretch and float(x_stretch) > 0 else 1.0 + if writing_size_mm is not None: + h_px = float(writing_size_mm) * PX_PER_MM + else: + # Express widths per character: raw width ~= chars * advance * xh, + # so passing model_xheight=1 cancels the style-dependent scale. + h_px = solve_fill_xheight_px( + n_chars * self._CHAR_ADVANCE_PER_XHEIGHT, 1.0, n_blank, + content_w, content_h, xs) + if not h_px: + return target_chars_per_chunk + h_px = min(max(h_px, NATURAL_WRITING_MIN_FILL_MM * PX_PER_MM), + NATURAL_WRITING_MAX_FILL_MM * PX_PER_MM) + # ~4.5 chunks per line: finer quanta let the line breaker land close + # to the margin (with ~3 chunks/line the achievable widths step by + # ~30%, which is what made wrapping look conservative). + chars_per_line = content_w / (self._CHAR_ADVANCE_PER_XHEIGHT * h_px * xs) + return int(min(target_chars_per_chunk, max(10, chars_per_line / 4.5))) + except Exception as exc: # granularity heuristic must never block generation + print(f"Warning: adaptive chunk sizing failed, using default: {exc}") + return target_chars_per_chunk + + def _auto_fill_writing_size( + self, chunk_strokes, n_blank_lines, page_size, units, margins, + orientation, x_stretch, model_xheight=None, total_raw_width=None, + ): + """Pick a writing size (mm) so the text fills the page vertically. + + A short letter at the base natural size only covers the top of the page; + a person writing the same letter by hand would simply write larger, and + someone with too much to say writes smaller to keep it on one page. This + solves for the x-height at which the wrapped text spans the content + height, clamped to [NATURAL_WRITING_MIN_FILL_MM, NATURAL_WRITING_MAX_FILL_MM]. + Returns mm, or ``None`` to keep the default. + """ + try: + from handwriting_synthesis.hand._draw import ( + solve_fill_xheight_px, PX_PER_MM, + NATURAL_WRITING_MIN_FILL_MM, NATURAL_WRITING_MAX_FILL_MM, + ) + if model_xheight is None: + model_xheight = self._measure_model_xheight(chunk_strokes) + if not model_xheight: + return None + content_w, content_h = self._content_box_px(page_size, units, margins, orientation) + if total_raw_width is None: + total_raw_width = sum( + get_stroke_width(s) for s in chunk_strokes + if s is not None and len(s) > 0 + ) + xs = float(x_stretch) if x_stretch else 1.0 + if xs <= 0: + xs = 1.0 + h_px = solve_fill_xheight_px( + total_raw_width, model_xheight, n_blank_lines, content_w, content_h, xs) + if not h_px: + return None + # Short text -> grow toward the max; long text -> shrink toward the + # min so it fits one page with full-width lines (the wrap width and + # the rendered size must come from the SAME x-height). + return min(max(h_px / PX_PER_MM, NATURAL_WRITING_MIN_FILL_MM), + NATURAL_WRITING_MAX_FILL_MM) + except Exception as exc: # never block generation on a sizing heuristic + print(f"Warning: page-fill sizing failed, using default size: {exc}") + return None + + def _size_aware_max_line_width( + self, chunk_strokes, max_line_width, page_size, units, margins, + orientation, writing_size_mm, x_stretch, auto_size, model_xheight=None, + ): + """Cap the wrap width so a full line still renders at the natural size. + + The chunked wrapper measures widths in raw model units, but ``_draw`` + renders in page pixels. If lines are allowed to grow to ``max_line_width`` + raw units, a full line ends up wider than the page at the natural + x-height, so ``_draw`` shrinks ALL text to make it fit -- which is what + makes the handwriting come out small. Here we cap the wrap width to the + raw width that exactly fills the page's content box at the target + x-height, so letters keep their natural size and a line simply holds + fewer words (wrapping to the next line instead of shrinking). + + Returns the (possibly reduced) wrap width in raw units. Only applies when + auto-sizing; returns ``max_line_width`` unchanged otherwise. + """ + if not auto_size or not chunk_strokes: + return max_line_width + try: + from handwriting_synthesis.hand._draw import PX_PER_MM, NATURAL_WRITING_SIZE_MM + + content_width_px, _ = self._content_box_px(page_size, units, margins, orientation) + if model_xheight is None: + model_xheight = self._measure_model_xheight(chunk_strokes) + if not model_xheight: + return max_line_width + + target_mm = NATURAL_WRITING_SIZE_MM if writing_size_mm is None else float(writing_size_mm) + target_xheight_px = max(1.0, target_mm * PX_PER_MM) + xs = float(x_stretch) if x_stretch else 1.0 + if xs <= 0: + xs = 1.0 + + # Raw width whose rendered width == content width at the target size: + # rendered_width = raw_width * s_render * x_stretch, and + # s_render = target_xheight_px / model_xheight (height-driven sizing) + # => raw_width that fills the page = content_width * model_xheight + # / (target_xheight_px * x_stretch). No empirical fudge factor needed. + fit_raw = content_width_px * model_xheight / (target_xheight_px * xs) + return max(1.0, min(float(max_line_width), fit_raw)) + except Exception as exc: # never block generation on a sizing heuristic + print(f"Warning: size-aware wrap width failed, using max_line_width: {exc}") + return max_line_width def write_chunked( self, @@ -330,6 +556,7 @@ def write_chunked( empty_line_spacing=None, auto_size=True, manual_size_scale=1.0, + writing_size_mm=None, character_override_collection_id=None, margin_jitter_frac=None, margin_jitter_coherence=None, @@ -383,9 +610,17 @@ def write_chunked( all_lines = [] all_line_texts = [] - # If we have overrides, we need to handle text splitting differently + # The writing size actually rendered. The no-override path may grow this + # (page-fill auto sizing) when the user did not request an explicit size. + effective_writing_size_mm = writing_size_mm + + # If we have overrides, use SPACE-PLACEHOLDER approach + # KEY FIX: Chunk the ORIGINAL text first, THEN replace override chars in each chunk. + # This preserves the position mapping between chunks and the original text. if overrides_dict: - from handwriting_synthesis.hand.character_override_utils import split_text_with_overrides + from handwriting_synthesis.hand.character_override_utils import estimate_override_width, get_random_override + + print(f"DEBUG write_chunked: Using SPACE-PLACEHOLDER approach for overrides") # Track segments for each line (will be used later for line_segments) all_line_segment_data = [] @@ -398,51 +633,52 @@ def write_chunked( all_line_segment_data.append([]) continue - # Split line into override and non-override chunks - text_chunks = split_text_with_overrides(input_line, overrides_dict) - - # Process each chunk - line_segments_data = [] - texts_to_generate = [] - chunk_metadata = [] - - for chunk_text, is_override in text_chunks: - if is_override: - # Mark as override - will be handled during drawing - line_segments_data.append({ - 'type': 'override', - 'text': chunk_text, - 'is_override': True - }) - else: - # Non-override text - chunk it and prepare for generation - sub_chunks = split_text_into_chunks( - chunk_text, - words_per_chunk=words_per_chunk, - target_chars_per_chunk=target_chars_per_chunk, - min_words=min_words_per_chunk, - max_words=max_words_per_chunk, - adaptive_chunking=adaptive_chunking, - adaptive_strategy=adaptive_strategy - ) - - for sub_chunk in sub_chunks: - gen_idx = len(texts_to_generate) - texts_to_generate.append(sub_chunk) - chunk_metadata.append({ - 'gen_idx': gen_idx, - 'text': sub_chunk - }) - line_segments_data.append({ - 'type': 'generated', - 'text': sub_chunk, - 'gen_idx': gen_idx, - 'is_override': False - }) - - # Validate characters in texts to generate + # STEP 1: Chunk the ORIGINAL text first (before any modification) + # This preserves word boundaries and spacing correctly + original_chunks = split_text_into_chunks( + input_line, + words_per_chunk=words_per_chunk, + target_chars_per_chunk=target_chars_per_chunk, + min_words=min_words_per_chunk, + max_words=max_words_per_chunk, + adaptive_chunking=adaptive_chunking, + adaptive_strategy=adaptive_strategy + ) + + if not original_chunks: + all_lines.append(np.empty((0, 3))) + all_line_texts.append('') + all_line_segment_data.append([]) + continue + + print(f"DEBUG: Original line: '{input_line}'") + print(f"DEBUG: Original chunks: {original_chunks}") + + # STEP 2: For each chunk, identify overrides and create modified version + modified_chunks = [] # Chunks with override chars replaced by spaces + chunk_override_info = [] # Override positions for each chunk + + for chunk_idx, original_chunk in enumerate(original_chunks): + chunk_overrides = [] # [(local_idx, char), ...] + modified_chars = [] + + for char_idx, char in enumerate(original_chunk): + if char in overrides_dict: + chunk_overrides.append((char_idx, char)) + modified_chars.append(' ') # Space placeholder + print(f"DEBUG: Chunk {chunk_idx}: Replacing '{char}' at local position {char_idx} with SPACE") + else: + modified_chars.append(char) + + modified_chunk = ''.join(modified_chars) + modified_chunks.append(modified_chunk) + chunk_override_info.append(chunk_overrides) + + print(f"DEBUG: Chunk {chunk_idx}: original='{original_chunk}' modified='{modified_chunk}' overrides={chunk_overrides}") + + # STEP 3: Validate modified chunks (should only contain valid alphabet chars) valid_char_set = set(drawing.alphabet) - for chunk_num, chunk in enumerate(texts_to_generate): + for chunk_num, chunk in enumerate(modified_chunks): for char in chunk: if char not in valid_char_set: raise ValueError( @@ -450,120 +686,108 @@ def write_chunked( f"Valid character set is {valid_char_set}" ) - # Generate strokes for non-override chunks only - if texts_to_generate: - chunk_strokes = self._sample( - texts_to_generate, - biases=[biases] * len(texts_to_generate) if biases is not None else None, - styles=[styles] * len(texts_to_generate) if styles is not None else None - ) + # STEP 4: Generate strokes for modified chunks WITH char_indices + # NOTE: char_indices offset is detected automatically from min(char_indices) in each chunk + chunk_strokes, chunk_char_indices = self._sample( + modified_chunks, + biases=[biases] * len(modified_chunks) if biases is not None else None, + styles=[styles] * len(modified_chunks) if styles is not None else None, + return_char_indices=True # Get char indices from attention + ) + + print(f"DEBUG: Generated {len(modified_chunks)} chunks with char_indices") - # Map generated strokes back to segments - for segment in line_segments_data: - if segment['type'] == 'generated': - segment['strokes'] = chunk_strokes[segment['gen_idx']] - else: - chunk_strokes = [] + # Wrap to the page at the natural size: cap line width so a full + # line renders at the target x-height instead of being shrunk. + effective_max_line_width = self._size_aware_max_line_width( + chunk_strokes, max_line_width, page_size, units, margins, + orientation, writing_size_mm, x_stretch, auto_size, + ) - # Now stitch the generated chunks together, handling overrides + # STEP 6: Build segment data with override info for each chunk + # Stitch chunks into lines based on actual widths current_line_stroke = np.empty((0, 3)) current_line_text = [] current_line_width = 0.0 current_line_segment_list = [] - for seg_idx, segment in enumerate(line_segments_data): - if segment['type'] == 'override': - # Estimate override width for layout - from handwriting_synthesis.hand.character_override_utils import get_random_override, estimate_override_width - override_data = get_random_override(overrides_dict, segment['text']) + for chunk_idx, (original_chunk, modified_chunk, chunk_stroke, char_indices, chunk_overrides) in enumerate( + zip(original_chunks, modified_chunks, chunk_strokes, chunk_char_indices, chunk_override_info) + ): + has_overrides = len(chunk_overrides) > 0 + + # CRITICAL FIX: Detect the actual char_indices offset from the data itself + # char_indices from the model start at min(char_indices), not 0 + # This accounts for style priming and any other offsets automatically + actual_offset = int(char_indices.min()) if len(char_indices) > 0 else 0 + + # Adjust override positions using the detected offset + adjusted_overrides = [(local_idx + actual_offset, char) for local_idx, char in chunk_overrides] + + print(f"DEBUG: Processing chunk {chunk_idx} '{modified_chunk}': has_overrides={has_overrides}") + print(f"DEBUG: Original positions: {chunk_overrides}") + print(f"DEBUG: Detected char_indices offset: {actual_offset}") + print(f"DEBUG: Adjusted positions: {adjusted_overrides}") + if has_overrides: + print(f"DEBUG: char_indices range: [{char_indices.min()}, {char_indices.max()}], unique values: {np.unique(char_indices)[:20]}...") + + # Calculate chunk width (including estimated override widths) + chunk_width = get_stroke_width(chunk_stroke) + + # For width calculation, estimate how much extra space overrides need + extra_override_width = 0.0 + for local_idx, override_char in chunk_overrides: + override_data = get_random_override(overrides_dict, override_char) if override_data: - # Estimate width (using typical line height of 60px) - override_width = estimate_override_width(override_data, target_height=60, x_stretch=1.0) - else: - override_width = 20 # fallback width - - # FIXED: Check for adjacent spaces and apply appropriate spacing - # This matches the logic in _draw.py for consistent line breaking - has_space_before = False - if seg_idx > 0: - prev_seg = line_segments_data[seg_idx - 1] - if prev_seg.get('type') == 'generated': - prev_text = prev_seg.get('text', '') - has_space_before = prev_text.strip() == '' or prev_text.endswith(' ') - - has_space_after = False - if seg_idx < len(line_segments_data) - 1: - next_seg = line_segments_data[seg_idx + 1] - if next_seg.get('type') == 'generated': - next_text = next_seg.get('text', '') - has_space_after = next_text.strip() == '' or next_text.startswith(' ') - - # When there's a space adjacent, use space-width spacing - # When there's no space, use minimal character spacing - space_width = override_width * 0.35 - spacing_before = space_width if has_space_before else override_width * 0.15 - spacing_after = space_width if has_space_after else override_width * 0.15 - override_width_with_spacing = spacing_before + override_width + spacing_after - - potential_width = current_line_width - if current_line_width > 0: - potential_width += override_width_with_spacing - else: - potential_width = override_width_with_spacing + override_w = estimate_override_width(override_data, target_height=60, x_stretch=1.0) + extra_override_width += override_w + (override_w * 0.3) - if potential_width <= max_line_width or current_line_width == 0: - # Fits on current line - current_line_text.append(segment['text']) - current_line_segment_list.append(segment) - current_line_width = potential_width - else: - # Start new line - if len(current_line_stroke) > 0 or len(current_line_text) > 0: - all_lines.append(current_line_stroke) - all_line_texts.append(''.join(current_line_text)) - all_line_segment_data.append(current_line_segment_list) - - current_line_stroke = np.empty((0, 3)) - current_line_text = [segment['text']] - current_line_segment_list = [segment] - current_line_width = override_width_with_spacing - else: - # Generated chunk - chunk_stroke = segment['strokes'] - chunk_width = get_stroke_width(chunk_stroke) + effective_chunk_width = chunk_width + extra_override_width - # Check if chunk fits on current line - potential_width = current_line_width + # Check if chunk fits on current line + potential_width = current_line_width + if current_line_width > 0: + potential_width += chunk_spacing + effective_chunk_width + else: + potential_width = effective_chunk_width + + # Build segment data + # NOTE: 'text' is the MODIFIED chunk (what was generated) + # override_positions are ADJUSTED for style offset (to match char_indices) + segment = { + 'type': 'generated', + 'text': modified_chunk, # Text that was generated (with spaces) + 'original_text': original_chunk, # Original text (with override chars) + 'strokes': chunk_stroke, + 'char_indices': char_indices, # Attention-based character indices + 'override_positions': adjusted_overrides, # [(adjusted_idx, char), ...] - ADJUSTED for style offset + } + + if potential_width <= effective_max_line_width or current_line_width == 0: + # Chunk fits on current line if current_line_width > 0: - potential_width += chunk_spacing + chunk_width - else: - potential_width = chunk_width - - if potential_width <= max_line_width or current_line_width == 0: - # Chunk fits on current line - if current_line_width > 0: - current_line_stroke = stitch_strokes( - current_line_stroke, - chunk_stroke, - chunk_spacing, - rotate_to_match=rotate_chunks - ) - else: - current_line_stroke = chunk_stroke - current_line_text.append(segment['text']) - current_line_segment_list.append(segment) - current_line_width = potential_width + current_line_stroke = stitch_strokes( + current_line_stroke, + chunk_stroke, + chunk_spacing, + rotate_to_match=rotate_chunks + ) else: - # Start new line (width exceeded) - if len(current_line_stroke) > 0 or len(current_line_text) > 0: - all_lines.append(current_line_stroke) - all_line_texts.append(''.join(current_line_text)) - all_line_segment_data.append(current_line_segment_list) - current_line_stroke = chunk_stroke - current_line_text = [segment['text']] - current_line_segment_list = [segment] - current_line_width = chunk_width + current_line_text.append(original_chunk) + current_line_segment_list.append(segment) + current_line_width = potential_width + else: + # Start new line (width exceeded) + if len(current_line_stroke) > 0 or len(current_line_text) > 0: + all_lines.append(current_line_stroke) + all_line_texts.append(''.join(current_line_text)) + all_line_segment_data.append(current_line_segment_list) + + current_line_stroke = chunk_stroke + current_line_text = [original_chunk] + current_line_segment_list = [segment] + current_line_width = effective_chunk_width # Add last line from this input line if len(current_line_stroke) > 0 or len(current_line_text) > 0: @@ -574,32 +798,36 @@ def write_chunked( # No overrides - use original logic all_line_segment_data = None + # PHASE 1: chunk + sample every input line up front, so the writing + # size can be chosen from the WHOLE text before any line wrapping. + # Chunk granularity follows the estimated writing size: larger writing + # means fewer characters per line, which needs smaller chunks to wrap + # without leaving big quantization gaps at the right margin. + eff_target_chars = self._adaptive_chunk_chars( + input_lines, target_chars_per_chunk, page_size, units, margins, + orientation, x_stretch, writing_size_mm, auto_size, + ) + valid_char_set = set(drawing.alphabet) + sampled_lines = [] # per input line: None (blank) or (chunks, strokes) for input_line in input_lines: - # Handle blank lines if not input_line.strip(): - all_lines.append(np.empty((0, 3))) - all_line_texts.append('') + sampled_lines.append(None) continue # Split line into chunks with adaptive sizing chunks = split_text_into_chunks( input_line, words_per_chunk=words_per_chunk, - target_chars_per_chunk=target_chars_per_chunk, + target_chars_per_chunk=eff_target_chars, min_words=min_words_per_chunk, max_words=max_words_per_chunk, adaptive_chunking=adaptive_chunking, adaptive_strategy=adaptive_strategy ) - if not chunks: - all_lines.append(np.empty((0, 3))) - all_line_texts.append('') + sampled_lines.append(None) continue - # Expand valid character set with overrides - valid_char_set = set(drawing.alphabet) - # Validate characters for chunk_num, chunk in enumerate(chunks): for char in chunk: @@ -609,55 +837,76 @@ def write_chunked( f"Valid character set is {valid_char_set}" ) - # Generate strokes for all chunks chunk_strokes = self._sample( chunks, biases=[biases] * len(chunks) if biases is not None else None, styles=[styles] * len(chunks) if styles is not None else None ) + sampled_lines.append((chunks, chunk_strokes)) + + # PHASE 2: choose the writing size, then the wrap width that fills the + # page at that size. When the user did not pick a size, grow it (up to + # a natural cap) so short texts fill the page vertically instead of + # leaving the bottom half empty -- the way a real one-page letter is + # simply written larger. Long texts stay at the base natural size. + all_strokes_flat = [s for entry in sampled_lines if entry for s in entry[1]] + n_blank_lines = sum(1 for entry in sampled_lines if entry is None) + # Measure x-height and width inflation on line-sized STITCHED groups + # -- the statistics _draw actually renders with -- so wrap-time + # predictions match the rendered output. + stitched_xheight, stitch_width_factor = self._estimate_stitched_xheight( + sampled_lines, chunk_spacing, rotate_chunks) + if auto_size and writing_size_mm is None and all_strokes_flat: + # Stitched lines come out wider than the sum of their chunks, so + # the text effectively occupies stitch_width_factor more width + # when the solver estimates how many lines it will wrap into. + total_w = sum(get_stroke_width(s) for s in all_strokes_flat + if s is not None and len(s) > 0) + fitted_mm = self._auto_fill_writing_size( + all_strokes_flat, n_blank_lines, page_size, units, margins, + orientation, x_stretch, model_xheight=stitched_xheight, + total_raw_width=total_w * stitch_width_factor, + ) + if fitted_mm: + effective_writing_size_mm = fitted_mm + + effective_max_line_width = self._size_aware_max_line_width( + all_strokes_flat, max_line_width, page_size, units, margins, + orientation, effective_writing_size_mm, x_stretch, auto_size, + model_xheight=stitched_xheight, + ) + # Deflate the budget by the stitch widening: the DP below compares it + # against SUMS of chunk widths, but the stitched line will measure + # stitch_width_factor wider -- without this the widest lines overrun + # the page and the width clamp shrinks the writing size to fit them. + effective_max_line_width /= stitch_width_factor + # Allow a squeeze past the wrap limit: a writer fits one more word by + # tightening slightly rather than leaving a ragged gap. _draw condenses + # such lines by the same few percent per line (line_scale_x). + from handwriting_synthesis.hand._draw import LINE_SQUEEZE_TOLERANCE + squeeze_limit = effective_max_line_width * LINE_SQUEEZE_TOLERANCE + + # PHASE 3: break chunks into balanced lines (raggedness spread evenly + # rather than greedy fill), then stitch each line's chunks together. + for entry in sampled_lines: + if entry is None: + all_lines.append(np.empty((0, 3))) + all_line_texts.append('') + continue + chunks, chunk_strokes = entry - # Stitch chunks into lines based on actual widths - current_line_stroke = np.empty((0, 3)) - current_line_text = [] - current_line_width = 0.0 - - for chunk_text, chunk_stroke in zip(chunks, chunk_strokes): - chunk_width = get_stroke_width(chunk_stroke) - - # Check if chunk fits on current line - potential_width = current_line_width - if current_line_width > 0: - potential_width += chunk_spacing + chunk_width - else: - potential_width = chunk_width - - if potential_width <= max_line_width or current_line_width == 0: - # Chunk fits on current line - if current_line_width > 0: - current_line_stroke = stitch_strokes( - current_line_stroke, - chunk_stroke, - chunk_spacing, - rotate_to_match=rotate_chunks - ) - current_line_text.append(chunk_text) - else: - current_line_stroke = chunk_stroke - current_line_text.append(chunk_text) - current_line_width = potential_width - else: - # Start new line (width exceeded) - all_lines.append(current_line_stroke) - all_line_texts.append(' '.join(current_line_text)) - - current_line_stroke = chunk_stroke - current_line_text = [chunk_text] - current_line_width = chunk_width + widths = [get_stroke_width(s) for s in chunk_strokes] + breaks = balanced_line_breaks( + widths, chunk_spacing, effective_max_line_width, squeeze_limit) - # Add last line from this input line - if len(current_line_stroke) > 0 or len(current_line_text) > 0: - all_lines.append(current_line_stroke) - all_line_texts.append(' '.join(current_line_text)) + for start, end in breaks: + line_stroke = chunk_strokes[start] + for nxt in chunk_strokes[start + 1:end]: + line_stroke = stitch_strokes( + line_stroke, nxt, chunk_spacing, + rotate_to_match=rotate_chunks) + all_lines.append(line_stroke) + all_line_texts.append(' '.join(chunks[start:end])) # Use the collected lines lines = all_lines @@ -725,6 +974,7 @@ def _normalize_seq(value, desired_len, cast_fn=None, name='param'): empty_line_spacing=empty_line_spacing, auto_size=auto_size, manual_size_scale=manual_size_scale, + writing_size_mm=effective_writing_size_mm, character_override_collection_id=character_override_collection_id, overrides_dict=overrides_dict, margin_jitter_frac=margin_jitter_frac, diff --git a/handwriting_synthesis/hand/_draw.py b/handwriting_synthesis/hand/_draw.py index 28b91c0..1c200f2 100644 --- a/handwriting_synthesis/hand/_draw.py +++ b/handwriting_synthesis/hand/_draw.py @@ -16,6 +16,147 @@ 'Legal': (215.9, 355.6), } +# --- Natural handwriting sizing ------------------------------------------------- +# Auto-sizing targets a physical x-height (the height of lowercase letters such as +# a / e / o) rather than fitting the worst-case stroke extent. ~4.5 mm matches +# normal ballpoint handwriting. Override per call with writing_size_mm. +NATURAL_WRITING_SIZE_MM = 4.5 +# Auto line advance as a multiple of the rendered x-height. ~2.1x leaves room for +# ascenders/descenders without large gaps (natural single spacing). +LINE_SPACING_PER_XHEIGHT = 2.1 +# Width clamp fits every line up to this multiple of the median line width; lines +# wider than that are treated as outliers (e.g. an unwrapped long token) and are +# condensed per-line at render time instead of shrinking the whole document. +WIDTH_OUTLIER_FACTOR = 2.0 + +# Bounds for the auto page-fill writing size. Short texts grow (like a real +# one-page letter written larger) up to the max; beyond it, blank space at the +# bottom looks more natural than huge letters. Long texts shrink below the base +# natural size down to the min so they fit one page with full-width lines -- +# crucially the WRAP width shrinks with the render size, otherwise _draw's +# fallback shrink leaves every line short of the right margin. +NATURAL_WRITING_MIN_FILL_MM = 2.5 +NATURAL_WRITING_MAX_FILL_MM = 7.0 + +# Fraction of the content height the page-fill solver aims to use. Below 1.0 to +# absorb what the closed-form estimate ignores: the first-line offset, integer +# line rounding, and inter-chunk stitch gaps widening lines slightly. +PAGE_FILL_FRACTION = 0.92 + +# How far past the wrap budget a single line may go before it is condensed +# horizontally (line_scale_x) at render time. A writer squeezes the last word in +# rather than leaving a gap; an 8% horizontal tightening is visually invisible. +# Used by the wrapper (line-break limit) and by the global width clamp, which +# tolerates this much overhang on the widest line instead of shrinking ALL text. +LINE_SQUEEZE_TOLERANCE = 1.08 + + +def solve_fill_xheight_px( + total_raw_width, + model_xheight, + n_blank_lines, + content_width_px, + content_height_px, + x_stretch=1.0, + spacing_per_xheight=LINE_SPACING_PER_XHEIGHT, + fill_frac=PAGE_FILL_FRACTION, +): + """Solve for the x-height (px) at which wrapped text fills the page height. + + At rendered x-height ``h`` the text scales by ``h / model_xheight``, so it + wraps into roughly ``n(h) = total_raw_width * h * x_stretch / (model_xheight + * content_width)`` lines, each advancing ``spacing_per_xheight * h``; blank + (paragraph-break) lines add the same advance without consuming text. Setting + the resulting height to ``fill_frac * content_height`` gives a quadratic in + ``h``:: + + a*h^2 + b*h - fill_frac*content_height = 0, + a = spacing_per_xheight * total_raw_width * x_stretch + / (model_xheight * content_width) + b = spacing_per_xheight * n_blank_lines + + Returns the positive root, or ``None`` if the inputs are degenerate. The + caller is expected to clamp the result to a sensible size range. + """ + if total_raw_width <= 0 or model_xheight <= 0 or content_width_px <= 0 or content_height_px <= 0: + return None + a = spacing_per_xheight * total_raw_width * max(x_stretch, 1e-6) / (model_xheight * content_width_px) + b = spacing_per_xheight * max(0, n_blank_lines) + c = -fill_frac * content_height_px + disc = b * b - 4.0 * a * c + if disc <= 0 or a <= 0: + return None + return (-b + math.sqrt(disc)) / (2.0 * a) + + +def _estimate_xheight(ls): + """Estimate the x-height (lowercase body height) of an aligned stroke array. + + ``ls`` is in the normalised layout space produced in the first pass (y in + ``[0, raw_h]`` with the baseline near the bottom). Letter bodies form a dense + central band while ascenders (l, h, k) and descenders (g, y, p) are a small + fraction of the points. Taking the 10th..90th percentile span of the y values + yields a body-height estimate that stays stable regardless of which letters + happen to appear -- unlike the raw maximum extent, which a single tall stroke + inflates and which is the reason the previous logic shrank text inconsistently. + + Returns the band height (a positive float); falls back to the full extent for + very short stroke arrays. + """ + if ls.shape[0] < 8: + return max(1e-6, float(ls[:, 1].max()) if ls.shape[0] else 1e-6) + ys = ls[:, 1] + band = float(np.percentile(ys, 90.0) - np.percentile(ys, 10.0)) + if band <= 1e-6: + band = max(1e-6, float(ys.max())) + return band + + +def _extract_svg_coordinates(d_string): + """ + Extract all coordinate points from an SVG path 'd' attribute. + + Handles M, L, C, Q, A commands (absolute and relative) to properly + calculate bounding boxes for characters with curves (like '!' dot). + + Args: + d_string: The 'd' attribute value from an SVG path element. + + Returns: + List of (x, y) tuples representing all coordinate points. + """ + coords = [] + + # M/L: x y (move/line commands) + for match in re.finditer(r'[MLml]\s*([-\d.]+)[,\s]+([-\d.]+)', d_string): + coords.append((float(match.group(1)), float(match.group(2)))) + + # C (cubic bezier): x1 y1, x2 y2, x y - capture all 3 points for bounding box + for match in re.finditer(r'[Cc]\s*([-\d.]+)[,\s]+([-\d.]+)[,\s]+([-\d.]+)[,\s]+([-\d.]+)[,\s]+([-\d.]+)[,\s]+([-\d.]+)', d_string): + coords.append((float(match.group(1)), float(match.group(2)))) # control point 1 + coords.append((float(match.group(3)), float(match.group(4)))) # control point 2 + coords.append((float(match.group(5)), float(match.group(6)))) # endpoint + + # Q (quadratic bezier): x1 y1, x y - capture both points + for match in re.finditer(r'[Qq]\s*([-\d.]+)[,\s]+([-\d.]+)[,\s]+([-\d.]+)[,\s]+([-\d.]+)', d_string): + coords.append((float(match.group(1)), float(match.group(2)))) # control point + coords.append((float(match.group(3)), float(match.group(4)))) # endpoint + + # S (smooth cubic): x2 y2, x y - capture both points + for match in re.finditer(r'[Ss]\s*([-\d.]+)[,\s]+([-\d.]+)[,\s]+([-\d.]+)[,\s]+([-\d.]+)', d_string): + coords.append((float(match.group(1)), float(match.group(2)))) + coords.append((float(match.group(3)), float(match.group(4)))) + + # T (smooth quadratic): x y + for match in re.finditer(r'[Tt]\s*([-\d.]+)[,\s]+([-\d.]+)', d_string): + coords.append((float(match.group(1)), float(match.group(2)))) + + # A (arc): rx ry angle large-arc sweep x y - capture endpoint + for match in re.finditer(r'[Aa]\s*[-\d.]+[,\s]+[-\d.]+[,\s]+[-\d.]+[,\s]+[01][,\s]+[01][,\s]+([-\d.]+)[,\s]+([-\d.]+)', d_string): + coords.append((float(match.group(1)), float(match.group(2)))) + + return coords + def _to_px(value, units): """ @@ -95,6 +236,392 @@ def _resolve_page_size(page_size, units, num_lines, default_line_height_px): return width_px, height_px, svg_size +def _compute_inter_segment_spacing(prev_segment, current_segment, reference_height): + """ + Compute spacing to add before current_segment based on the previous segment. + + Args: + prev_segment: The previous segment dict (or None if first segment) + current_segment: The current segment dict + reference_height: Height to use for computing proportional spacing + + Returns: + Spacing amount in pixels + """ + if prev_segment is None: + return 0.0 + + current_type = current_segment.get('type') + prev_type = prev_segment.get('type') + + if current_type == 'generated' and prev_type == 'generated': + # Generated-to-generated: add spacing based on text boundaries + prev_text = prev_segment.get('text', '') + current_text = current_segment.get('text', '') + has_space = prev_text.endswith(' ') or current_text.startswith(' ') + return reference_height * 0.35 if has_space else reference_height * 0.1 + + # Override spacing is handled separately in override rendering + return 0.0 + + +def _render_strokes_with_overrides( + dwg, ls, original_text, override_positions, overrides_dict, + cursor_x, line_offset_y, s_global, x_stretch, line_scale_x, + color, width, target_h, char_indices=None +): + """ + Render generated strokes with override SVGs inserted at precise positions. + + MODEL-LEVEL CHAR INDEX APPROACH WITH GAP CREATION: + The text was generated with SPACES where override characters should be. + We use the model's attention-based char_indices to know EXACTLY which + strokes correspond to each character. Since spaces create minimal horizontal + movement, we SHIFT subsequent strokes to CREATE ROOM for the override. + + This ensures: + 1. Full RNN context for surrounding text (space is a valid character) + 2. PRECISE cuts based on model's internal knowledge + 3. PROPER SPACING by shifting strokes to make room for overrides + 4. Clean override insertion at natural positions + + Args: + dwg: SVG drawing object + ls: Stroke coordinates array (already scaled) + original_text: Original text of the line (with override chars) + override_positions: List of (char_idx, char) tuples for override positions + overrides_dict: Dictionary of override character data + cursor_x: Starting X position + line_offset_y: Y position for this line + s_global: Global scale factor + x_stretch: Horizontal stretch factor + line_scale_x: Line-specific horizontal scale (for overflow prevention) + color: Stroke color + width: Stroke width + target_h: Target height for scaling overrides + char_indices: Array of character indices per stroke (from model attention). + If provided, uses precise cutting; otherwise falls back to estimation. + + Returns: + Final cursor_x position after rendering + """ + from handwriting_synthesis.hand.character_override_utils import get_random_override + + if ls.shape[0] == 0: + return cursor_x + + # Calculate dimensions + stroke_min_x = ls[:, 0].min() + stroke_max_x = ls[:, 0].max() + total_stroke_width = stroke_max_x - stroke_min_x + stroke_height = ls[:, 1].max() + num_chars = len(original_text) if original_text else 1 + + # Sort override positions by character index (process left to right) + sorted_overrides = sorted(override_positions, key=lambda x: x[0]) + + # Determine if we can use precise char_indices + use_precise_indices = ( + char_indices is not None and + len(char_indices) == ls.shape[0] + ) + + # Calculate average character width for sizing overrides + # Exclude override positions from calculation + if use_precise_indices: + non_override_chars = set(range(num_chars)) - set(ci for ci, _ in sorted_overrides) + char_widths = [] + for ci in non_override_chars: + matching = np.where(char_indices == ci)[0] + if len(matching) > 1: + w = ls[matching[-1], 0] - ls[matching[0], 0] + if w > 0: + char_widths.append(w) + avg_char_width = np.mean(char_widths) if char_widths else total_stroke_width / max(1, num_chars) + else: + avg_char_width = total_stroke_width / max(1, num_chars) + + print(f"DEBUG render_with_overrides: avg_char_width={avg_char_width:.2f}") + + # STEP 1: Calculate override widths and insertion points + override_info = [] # List of override details + + for char_idx, override_char in sorted_overrides: + # Get override data and calculate its rendered width + override_data = get_random_override(overrides_dict, override_char) + if not override_data: + print(f"Warning: No override data for '{override_char}'") + continue + + # Parse override SVG to get dimensions + try: + svg_root = ET.fromstring(override_data['svg_data']) + all_x_coords = [] + all_y_coords = [] + + for elem in svg_root.iter(): + tag_name = elem.tag.split('}')[-1] if '}' in elem.tag else elem.tag + if tag_name == 'path': + d = elem.get('d', '') + # Use comprehensive SVG parsing to capture bezier curves (e.g., for '!' dot) + coords = _extract_svg_coordinates(d) + for x, y in coords: + all_x_coords.append(x) + all_y_coords.append(y) + + if not all_x_coords or not all_y_coords: + print(f"Warning: No coordinates found for override '{override_char}'") + continue + + char_min_x = min(all_x_coords) + char_max_x = max(all_x_coords) + char_min_y = min(all_y_coords) + char_max_y = max(all_y_coords) + + char_width = char_max_x - char_min_x + char_height = char_max_y - char_min_y + + # Calculate scale to match stroke height + if char_height > 0: + scale = stroke_height / char_height + else: + scale = 1.0 + + scale_x = scale * x_stretch * line_scale_x + scale_y = scale + rendered_width = char_width * scale_x + + # Find insertion point and EXPANDED stroke range using char_indices + # We expand the range to include transition strokes (buffer zone) + stroke_range = None + exclusion_range = None # Expanded range for excluding transition strokes + + if use_precise_indices: + print(f"DEBUG: Looking for char_idx={char_idx} in char_indices") + print(f"DEBUG: char_indices range: [{char_indices.min()}, {char_indices.max()}]") + + # IMPROVED APPROACH: Find characters with SUFFICIENT strokes (not just immediate neighbors) + # Spaces may have very few strokes, so we search outward until we find substantial characters + min_strokes_threshold = 3 # Require at least this many strokes to be reliable + + # Search backwards for previous substantial character + prev_strokes = np.array([], dtype=int) + for search_idx in range(char_idx - 1, int(char_indices.min()) - 1, -1): + candidate_strokes = np.where(char_indices == search_idx)[0] + if len(candidate_strokes) >= min_strokes_threshold: + prev_strokes = candidate_strokes + print(f"DEBUG: Found prev char at idx {search_idx} with {len(candidate_strokes)} strokes") + break + + # Search forwards for next substantial character + next_strokes = np.array([], dtype=int) + for search_idx in range(char_idx + 1, int(char_indices.max()) + 1): + candidate_strokes = np.where(char_indices == search_idx)[0] + if len(candidate_strokes) >= min_strokes_threshold: + next_strokes = candidate_strokes + print(f"DEBUG: Found next char at idx {search_idx} with {len(candidate_strokes)} strokes") + break + + if len(prev_strokes) > 0 and len(next_strokes) > 0: + # Get the X position at the END of previous character + prev_end_x = ls[prev_strokes[-1], 0] + # Get the X position at the START of next character + next_start_x = ls[next_strokes[0], 0] + # Insert closer to the start of the next character (leave room for any space) + # Weight towards next_start_x since we want override right before the number/letter + insertion_x = prev_end_x + (next_start_x - prev_end_x) * 0.3 + stroke_range = (prev_strokes[-1], next_strokes[0]) + print(f"DEBUG: Using BETWEEN approach: prev ends at {prev_end_x:.2f}, next starts at {next_start_x:.2f}") + print(f"DEBUG: Insertion X position: {insertion_x:.2f} (30% into gap)") + elif len(prev_strokes) > 0: + # Only have previous character - insert after it + prev_end_x = ls[prev_strokes[-1], 0] + insertion_x = prev_end_x + avg_char_width * 0.3 + stroke_range = (prev_strokes[-1], prev_strokes[-1]) + print(f"DEBUG: Using AFTER-PREV approach: inserting after {prev_end_x:.2f}") + elif len(next_strokes) > 0: + # Only have next character - insert before it + next_start_x = ls[next_strokes[0], 0] + insertion_x = next_start_x - avg_char_width * 0.3 + stroke_range = (next_strokes[0], next_strokes[0]) + print(f"DEBUG: Using BEFORE-NEXT approach: inserting before {next_start_x:.2f}") + else: + # Fallback to position estimate + print(f"DEBUG: No adjacent chars found. Falling back to position estimate.") + insertion_x = stroke_min_x + ((char_idx - char_indices.min()) * avg_char_width) + stroke_range = None + + exclusion_range = None + else: + insertion_x = stroke_min_x + (char_idx * avg_char_width) + stroke_range = None + exclusion_range = None + + override_info.append({ + 'char_idx': char_idx, + 'override_char': override_char, + 'insertion_x': insertion_x, + 'override_width': rendered_width, + 'stroke_range': stroke_range, + 'exclusion_range': exclusion_range, # Expanded range for transition strokes + 'override_data': override_data, + 'char_min_x': char_min_x, + 'char_min_y': char_min_y, + 'scale_x': scale_x, + 'scale_y': scale_y, + }) + + print(f"DEBUG: Override '{override_char}' at char_idx={char_idx}: insertion_x={insertion_x:.2f}, width={rendered_width:.2f}") + + except Exception as e: + print(f"Error processing override '{override_char}': {e}") + continue + + # STEP 2: Build shifted stroke coordinates + # We need to shift strokes AFTER each override to make room + ls_shifted = ls.copy() + + # Calculate cumulative shift needed at each stroke position + cumulative_shift = np.zeros(ls.shape[0]) + + # Build set of all stroke indices to exclude (using expanded exclusion ranges) + excluded_stroke_indices = set() + + for info in override_info: + char_idx = info['char_idx'] + override_width = info['override_width'] + stroke_range = info.get('stroke_range') + + # Add small spacing around override (like natural character spacing) + spacing = avg_char_width * 0.1 # Reduced from 0.15 + + # Calculate the existing gap width (space placeholder takes some natural width) + insertion_x = info['insertion_x'] + + # Get the existing space width from the stroke range + if stroke_range is not None: + prev_stroke_idx, next_stroke_idx = stroke_range + # The existing gap is from end of prev char to start of next char + existing_gap = ls[next_stroke_idx, 0] - ls[prev_stroke_idx, 0] + else: + existing_gap = avg_char_width * 0.5 # Fallback estimate + + # Only shift by the ADDITIONAL space needed beyond what's already there + # We want: existing_gap -> override_width + small_spacing + extra_needed = (override_width + spacing) - existing_gap + total_shift = max(0, extra_needed) + + print(f"DEBUG: existing_gap={existing_gap:.2f}, override_width={override_width:.2f}, extra_needed={extra_needed:.2f}") + + # Store for SVG positioning + info['existing_gap'] = existing_gap + + # ALWAYS use X-position based shifting - this is more reliable than stroke exclusion + # The char_indices boundaries are fuzzy and excluding strokes cuts into adjacent chars + mask = ls[:, 0] > insertion_x + cumulative_shift[mask] += total_shift + print(f"DEBUG: X-position shift at {insertion_x:.2f}, shifting {np.sum(mask)} strokes by {total_shift:.2f}") + + # Apply shifts to X coordinates + ls_shifted[:, 0] += cumulative_shift + + # Recalculate total width after shifting + total_shifted_width = ls_shifted[:, 0].max() - ls_shifted[:, 0].min() + + # STEP 3: Render strokes (excluding override positions AND transition strokes) + ls_render = ls_shifted.copy() + shifted_min_x = ls_shifted[:, 0].min() + ls_render[:, 0] += cursor_x - shifted_min_x + ls_render[:, 1] += line_offset_y + + prev_eos = 1.0 + commands = [] + + # RENDER ALL STROKES - no exclusion! + # We use X-position shifting to create gaps, so all strokes are valid + for x, y, eos in zip(*ls_render.T): + commands.append('{}{},{}'.format('M' if prev_eos == 1.0 else 'L', x, y)) + prev_eos = eos + + if commands: + p = ' '.join(commands) + path = svgwrite.path.Path(p) + path = path.stroke(color=color, width=width, linecap='round', linejoin='round', miterlimit=2).fill('none') + dwg.add(path) + + # STEP 4: Insert override SVGs at calculated positions (accounting for shifts) + running_shift = 0.0 + for info in override_info: + char_idx = info['char_idx'] + override_char = info['override_char'] + override_data = info['override_data'] + override_width = info['override_width'] + stroke_range = info['stroke_range'] + existing_gap = info.get('existing_gap', avg_char_width * 0.5) + + # Small spacing before override (consistent with shift calculation) + spacing = avg_char_width * 0.05 # Small gap before override + + # Calculate position accounting for previous shifts + if use_precise_indices and stroke_range is not None: + prev_stroke_idx, next_stroke_idx = stroke_range + # Position after the previous character ends (in shifted coordinates) + prev_end_x_shifted = ls_shifted[prev_stroke_idx, 0] + base_x = prev_end_x_shifted - shifted_min_x + cursor_x + else: + base_x = info['insertion_x'] - stroke_min_x + cursor_x + running_shift + + # Place override with small spacing after previous character + override_start_x = base_x + spacing + + # Position override SVG + pos_x = override_start_x - (info['char_min_x'] * info['scale_x']) + pos_y = line_offset_y - (info['char_min_y'] * info['scale_y']) + + print(f"DEBUG: Rendering override '{override_char}' at pos_x={pos_x:.2f}") + + # Create group with transform + g = dwg.g(transform=f"translate({pos_x},{pos_y}) scale({info['scale_x']},{info['scale_y']})") + + # Add paths from override SVG + try: + svg_root = ET.fromstring(override_data['svg_data']) + for elem in svg_root.iter(): + tag_name = elem.tag.split('}')[-1] if '}' in elem.tag else elem.tag + if tag_name == 'path': + d = elem.get('d') + if not d: + continue + + orig_stroke = elem.get('stroke', 'none') + path = dwg.path(d=d) + + if orig_stroke and orig_stroke.lower() not in ('none', 'transparent'): + avg_scale = (info['scale_x'] + info['scale_y']) / 2.0 + adjusted_stroke_width = width / avg_scale if avg_scale > 0 else width + path = path.stroke( + color=color, + width=adjusted_stroke_width, + linecap='round', + linejoin='round' + ).fill('none') + else: + path = path.fill(color) + + g.add(path) + + dwg.add(g) + except Exception as e: + print(f"Error rendering override '{override_char}': {e}") + + # Track cumulative shift for fallback mode + running_shift += override_width + spacing * 2 + + # Return final X position + final_x = cursor_x + total_shifted_width + return final_x + + def _draw( line_segments, # Changed from 'strokes' to 'line_segments' lines, @@ -115,6 +642,7 @@ def _draw( empty_line_spacing=None, auto_size=True, manual_size_scale=1.0, + writing_size_mm=None, # Target x-height in mm for natural sizing (None -> NATURAL_WRITING_SIZE_MM) character_override_collection_id=None, overrides_dict=None, # New parameter margin_jitter_frac=None, # Bi-directional left margin jitter (fraction of content width) @@ -174,10 +702,12 @@ def _draw( content_width_px = max(1.0, width_px - (m_left + m_right)) content_height_px = max(1.0, height_px - (m_top + m_bottom)) - line_height_px = _to_px(line_height, units) if line_height is not None else default_line_height_px - # Ensure all lines fit vertically - max_line_height_px = content_height_px / max(1, len(line_segments) + 0) - line_height_px = min(line_height_px, max_line_height_px) + # Requested line advance (px). When None we derive a natural one from the text + # size below. The previous content_height/num_lines cram is gone: keeping the + # output on one page is handled once, after sizing, by scaling size + spacing + # together (see the sizing block) so spacing always tracks the letter size. + line_height_given = line_height is not None + line_height_px = _to_px(line_height, units) if line_height_given else default_line_height_px # Empty line spacing: if not specified, use regular line_height_px empty_line_spacing_px = _to_px(empty_line_spacing, units) if empty_line_spacing is not None else line_height_px @@ -216,9 +746,14 @@ def _draw( if margin_jitter_coherence is None: margin_jitter_coherence = {'high': 0.0, 'normal': 0.4}.get(legibility, 0.3) - # First pass: preprocess each line and compute per-line max allowed scale + # First pass: preprocess each line, measuring a robust per-segment x-height + # (drives the natural text size) and per-line widths (drives the width clamp). + # target_h is only a stable REFERENCE for override scaling -- it cancels out of + # the override width math, so its exact value does not affect the output. preprocessed_lines = [] - scale_limits = [] + raw_heights = [] # full stroke extents, used for override size matching + xheights = [] # robust body heights, used to pick the natural text size + line_raw_widths = [] # summed generated raw width per line, for the width clamp target_h = 0.95 * line_height_px for line_idx, segment_list in enumerate(line_segments): @@ -229,6 +764,7 @@ def _draw( preprocessed_segments = [] color = stroke_colors[line_idx] width = stroke_widths[line_idx] + line_gen_raw_w = 0.0 # accumulated generated stroke width for this line for segment in segment_list: if segment['type'] == 'override': @@ -259,13 +795,29 @@ def _draw( offsets_cp = offsets.copy() offsets_cp[:, :2] *= float(global_scale) ls = drawing.offsets_to_coords(offsets_cp) - if denoise: - ls = drawing.denoise(ls) - if interpolate_factor > 1: - try: - ls = drawing.interpolate(ls, factor=interpolate_factor) - except Exception: - pass + + # Get char_indices and override_positions for this segment + segment_char_indices = segment.get('char_indices', None) + segment_override_positions = segment.get('override_positions', []) + has_overrides = bool(segment_override_positions) + + # IMPORTANT: Skip denoise/interpolate for segments with overrides + # This preserves the 1:1 correspondence between strokes and char_indices + # which is critical for precise model-based cutting + if has_overrides and segment_char_indices is not None: + print(f"DEBUG preprocess: Skipping denoise/interpolate for override segment to preserve char_indices alignment") + # Don't denoise or interpolate - keep exact correspondence + else: + if denoise: + ls = drawing.denoise(ls) + if interpolate_factor > 1: + try: + ls = drawing.interpolate(ls, factor=interpolate_factor) + except Exception: + pass + # Clear char_indices since they no longer align after denoise/interpolate + segment_char_indices = None + if ls.shape[0] == 0: preprocessed_segments.append({'type': 'empty'}) continue @@ -276,63 +828,103 @@ def _draw( ls[:, :2] -= min_xy raw_w = max(1e-6, ls[:, 0].max()) raw_h = max(1e-6, ls[:, 1].max()) - s_w = content_width_px / raw_w - s_h = target_h / raw_h - scale_limits.append(min(s_w, s_h)) + raw_heights.append(raw_h) # full extent (override matching) + xheights.append((_estimate_xheight(ls), raw_w)) # body height (sizing) + line_gen_raw_w += raw_w + preprocessed_segments.append({ 'type': 'generated', 'strokes': ls, + 'raw_h': raw_h, # Store for adjacent override sizing + 'raw_w': raw_w, # cached so the width pass need not re-measure 'color': color, 'width': width, - 'text': segment.get('text', '') # Add original text for spacing checks + 'text': segment.get('text', ''), # Add original text for spacing checks + 'override_positions': segment_override_positions, # Preserve override positions + 'char_indices': segment_char_indices # Character indices (preserved for override segments) }) + if line_gen_raw_w > 0: + line_raw_widths.append(line_gen_raw_w) preprocessed_lines.append(preprocessed_segments if preprocessed_segments else [{'empty': True}]) - # Determine global scale: automatic or manual + # ---- Choose the natural handwriting size and line spacing ----------------- + # + # Size is driven by a robust x-height target that is the SAME for every line, + # so a single tall or wide line no longer shrinks the whole document. Width is + # respected via a percentile clamp (most lines fit the page; the few widest are + # condensed slightly per line at render time). Vertically we keep one page by + # scaling the text AND the spacing down together when the lines would not fit. + x_stretch = float(x_stretch) if x_stretch is not None else 1.0 + if x_stretch <= 0: + x_stretch = 1.0 + + writing_mm = NATURAL_WRITING_SIZE_MM if writing_size_mm is None else float(writing_size_mm) + target_xheight_px = max(1.0, writing_mm * PX_PER_MM) + + # Typical x-height from the LONG segments only: on short lines (a signature, + # a paragraph's last few words) ascenders/descenders are a large fraction of + # the points, which inflates the percentile band and would make all text + # render smaller and narrower than the wrap predicted. + if xheights: + max_seg_w = max(w for _, w in xheights) + long_bands = [h for h, w in xheights if w >= 0.5 * max_seg_w] + typical_xheight = float(np.median(long_bands if long_bands + else [h for h, _ in xheights])) + else: + typical_xheight = target_xheight_px + size_scale = target_xheight_px / max(1e-6, typical_xheight) + if auto_size: - s_global = min(scale_limits) if scale_limits else 1.0 + s_global = size_scale + # Width clamp: fit every NORMAL line within the page, ignoring gross + # outliers (a single unwrapped long line is condensed per-line at render + # time via line_scale_x instead of shrinking every line -- which is what + # used to make the text tiny). The clamp tolerates LINE_SQUEEZE_TOLERANCE + # of overhang on the widest line: that line is condensed individually, + # so one well-packed line doesn't scale the whole document down. + if line_raw_widths: + median_w = float(np.median(line_raw_widths)) + normal_widths = [w for w in line_raw_widths if w <= WIDTH_OUTLIER_FACTOR * median_w] + width_ref = max(normal_widths) if normal_widths else median_w + if width_ref > 1e-6: + s_global = min(s_global, LINE_SQUEEZE_TOLERANCE * content_width_px + / (width_ref * x_stretch)) else: - s_global = float(manual_size_scale) - - # BUGFIX: For small pages where auto_size significantly reduces text scale, - # adjust line height to be proportional to the actual rendered text size. - # This prevents huge line spacing when text is scaled down to fit narrow pages. - if auto_size and scale_limits: - # Calculate what the text height would have been without width constraint - # scale_limits contains min(s_w, s_h) for each line, where s_h = target_h / raw_h - # If s_global is much smaller than what s_h alone would give, text is width-constrained - # In that case, effective line height should scale down proportionally - - # Recalculate scale limits considering only height (not width) - height_only_scales = [] - for preprocessed_segments in preprocessed_lines: - for segment in preprocessed_segments: - if segment.get('type') == 'generated' and 'strokes' in segment: - ls = segment['strokes'] - raw_h = max(1e-6, ls[:, 1].max()) - s_h = target_h / raw_h - height_only_scales.append(s_h) - break - - if height_only_scales: - # The ideal scale based on height alone - ideal_height_scale = min(height_only_scales) - # If actual scale is significantly smaller (width-constrained), reduce line height - if s_global < ideal_height_scale * 0.95: # Allow 5% tolerance - scale_ratio = s_global / ideal_height_scale - # Adjust line height proportionally, but keep some minimum spacing - adjusted_line_height = line_height_px * scale_ratio - # Ensure minimum spacing of at least 20% of original to prevent overlapping - line_height_px = max(adjusted_line_height, line_height_px * 0.2) - # Also adjust empty line spacing if it was based on line_height_px - if empty_line_spacing is None: - empty_line_spacing_px = line_height_px + # manual_size_scale is now a multiple of the natural size (1.0 == natural). + s_global = float(manual_size_scale) * size_scale + + # Rendered x-height after the width clamp, used to derive natural line spacing. + rendered_xheight = typical_xheight * s_global + + # Line advance: honour an explicit line_height, otherwise derive one from the + # rendered x-height so spacing always tracks the letter size. + line_advance_px = line_height_px if line_height_given else LINE_SPACING_PER_XHEIGHT * rendered_xheight + + # Keep everything on one page (auto-size only): if the lines would not fit the + # content height, scale the size and the spacing down by the same factor. + if auto_size: + n_rows = max(1, len(preprocessed_lines)) + needed_height = line_advance_px * (n_rows + 1.0) # first-line offset + descender slack + if needed_height > content_height_px: + vfit = content_height_px / needed_height + s_global *= vfit + line_advance_px *= vfit + rendered_xheight *= vfit + + line_height_px = max(1.0, line_advance_px) + if empty_line_spacing is None: + empty_line_spacing_px = line_height_px + + # Override sizing reference: overrides are sized to neighbouring generated text + # via raw_h * s_global; target_h cancels out of the override width math, so its + # exact value does not matter as long as it is used consistently. + avg_raw_h = sum(raw_heights) / len(raw_heights) if raw_heights else 1.0 + effective_target_h = avg_raw_h * s_global # Second pass: render with uniform scale across lines for consistent letter size cursor_y = m_top + (3.0 * line_height_px / 4.0) rng = np.random.RandomState(42) - x_stretch = float(x_stretch) if x_stretch is not None else 1.0 # Pre-generate bi-directional margin jitter for all lines (Gaussian + coherence smoothing) num_lines = len(preprocessed_lines) @@ -369,9 +961,37 @@ def _draw( ls_temp[:, :2] *= s_global if x_stretch != 1.0: ls_temp[:, 0] *= x_stretch - total_line_width += ls_temp[:, 0].max() + segment_height = ls_temp[:, 1].max() + segment_width = ls_temp[:, 0].max() + + # Add inter-segment spacing + prev_seg = preprocessed_segments[seg_idx - 1] if seg_idx > 0 else None + spacing = _compute_inter_segment_spacing(prev_seg, segment, segment_height) + total_line_width += spacing + segment_width + + # SPACE PLACEHOLDER APPROACH: No width adjustment needed + # The strokes already have natural gaps where spaces are, and we just fill them. + # The total width is the stroke width as-is. + elif segment.get('type') == 'override': - override_width = segment['estimated_width'] + # Scale estimated width using ADJACENT segment heights (same as rendering) + adjacent_raw_heights = [] + if seg_idx > 0: + prev_seg = preprocessed_segments[seg_idx - 1] + if prev_seg.get('type') == 'generated' and 'raw_h' in prev_seg: + adjacent_raw_heights.append(prev_seg['raw_h']) + if seg_idx < len(preprocessed_segments) - 1: + next_seg = preprocessed_segments[seg_idx + 1] + if next_seg.get('type') == 'generated' and 'raw_h' in next_seg: + adjacent_raw_heights.append(next_seg['raw_h']) + + if adjacent_raw_heights: + local_raw_h = sum(adjacent_raw_heights) / len(adjacent_raw_heights) + local_effective_target_h = local_raw_h * s_global + else: + local_effective_target_h = effective_target_h + + override_width = segment['estimated_width'] * (local_effective_target_h / target_h) # Check if there's a space before this override character has_space_before = False @@ -435,33 +1055,89 @@ def _draw( cursor_x = line_offset_x for seg_idx, segment in enumerate(preprocessed_segments): if segment.get('type') == 'generated': - ls = segment['strokes'].copy() - ls[:, :2] *= s_global - if x_stretch != 1.0: - ls[:, 0] *= x_stretch + # Check if this segment uses the placeholder-based override approach + override_positions = segment.get('override_positions', []) + + if override_positions and overrides_dict: + # MODEL-LEVEL CHAR INDEX APPROACH: Use char_indices from attention for precise cutting + char_indices = segment.get('char_indices', None) + print(f"DEBUG: Using MODEL-LEVEL CHAR INDEX rendering for segment with {len(override_positions)} overrides") + if char_indices is not None: + print(f"DEBUG: Have char_indices: {len(char_indices)} values") + else: + print(f"DEBUG: No char_indices, will fall back to width estimation") - # Apply line-specific horizontal scaling to prevent overflow - if line_scale_x != 1.0: - ls[:, 0] *= line_scale_x + ls = segment['strokes'].copy() + ls[:, :2] *= s_global + if x_stretch != 1.0: + ls[:, 0] *= x_stretch + if line_scale_x != 1.0: + ls[:, 0] *= line_scale_x + + segment_height = ls[:, 1].max() + + cursor_x = _render_strokes_with_overrides( + dwg=dwg, + ls=ls, + original_text=segment.get('text', ''), + override_positions=override_positions, + overrides_dict=overrides_dict, + cursor_x=cursor_x, + line_offset_y=line_offset_y, + s_global=s_global, + x_stretch=x_stretch, + line_scale_x=line_scale_x, + color=segment['color'], + width=segment['width'], + target_h=segment_height, + char_indices=char_indices # NEW: Pass char_indices for precise cutting + ) + else: + # STANDARD PATH: No overrides in this segment, render normally + ls = segment['strokes'].copy() + raw_h_before_scale = ls[:, 1].max() + + # NOTE: With the space-placeholder approach, we no longer need aggressive + # clipping for segments adjacent to overrides. Text is generated as a + # continuous sequence with spaces where overrides go, and char_indices + # from attention give us precise cutting positions. + + ls[:, :2] *= s_global + if x_stretch != 1.0: + ls[:, 0] *= x_stretch + + # Apply line-specific horizontal scaling to prevent overflow + if line_scale_x != 1.0: + ls[:, 0] *= line_scale_x - # Track segment width before translating - segment_width = ls[:, 0].max() + # Track segment width before translating + segment_width = ls[:, 0].max() if ls.shape[0] > 0 else 0 + segment_height = ls[:, 1].max() if ls.shape[0] > 0 else 0 - ls[:, 0] += cursor_x - ls[:, 1] += line_offset_y + # Add inter-segment spacing + prev_seg = preprocessed_segments[seg_idx - 1] if seg_idx > 0 else None + spacing = _compute_inter_segment_spacing(prev_seg, segment, segment_height) + cursor_x += spacing - prev_eos = 1.0 - commands = [] - for x, y, eos in zip(*ls.T): - commands.append('{}{},{}'.format('M' if prev_eos == 1.0 else 'L', x, y)) - prev_eos = eos - p = ' '.join(commands) - path = svgwrite.path.Path(p) - path = path.stroke(color=segment['color'], width=segment['width'], linecap='round', linejoin='round', miterlimit=2).fill('none') - dwg.add(path) + # DEBUG: Log generated segment dimensions + print(f"DEBUG generated: text='{segment.get('text', '')[:20]}', raw_h={raw_h_before_scale:.2f}, final_h={segment_height:.2f}") - # Advance cursor by segment width - cursor_x += segment_width + if ls.shape[0] > 0: + ls[:, 0] += cursor_x + ls[:, 1] += line_offset_y + + prev_eos = 1.0 + commands = [] + for x, y, eos in zip(*ls.T): + commands.append('{}{},{}'.format('M' if prev_eos == 1.0 else 'L', x, y)) + prev_eos = eos + p = ' '.join(commands) + path = svgwrite.path.Path(p) + path = path.stroke(color=segment['color'], width=segment['width'], linecap='round', linejoin='round', miterlimit=2).fill('none') + dwg.add(path) + + # Advance cursor by segment width + cursor_x += segment_width elif segment.get('type') == 'override': override_data = segment['override_data'] @@ -476,10 +1152,11 @@ def _draw( tag_name = elem.tag.split('}')[-1] if '}' in elem.tag else elem.tag if tag_name == 'path': d = elem.get('d', '') - coords = re.findall(r'[ML]\s*([-\d.]+)\s+([-\d.]+)', d) + # Use comprehensive SVG parsing to capture bezier curves (e.g., for '!' dot) + coords = _extract_svg_coordinates(d) for x, y in coords: - all_x_coords.append(float(x)) - all_y_coords.append(float(y)) + all_x_coords.append(x) + all_y_coords.append(y) if not all_x_coords or not all_y_coords: print(f"Warning: No coordinates found for override '{segment.get('char', '?')}'") @@ -494,12 +1171,27 @@ def _draw( char_width = char_max_x - char_min_x char_height = char_max_y - char_min_y - # Calculate scale to match generated text height - # Generated text: normalized to start at y=0, height=raw_h, then scaled by s_global - # Final height = raw_h * s_global ≈ target_h - # SVG character should have same final height: char_height * scale = target_h + # Calculate scale to match ADJACENT generated text height + # Use raw_h from neighboring segments for better local matching + adjacent_raw_heights = [] + if seg_idx > 0: + prev_seg = preprocessed_segments[seg_idx - 1] + if prev_seg.get('type') == 'generated' and 'raw_h' in prev_seg: + adjacent_raw_heights.append(prev_seg['raw_h']) + if seg_idx < len(preprocessed_segments) - 1: + next_seg = preprocessed_segments[seg_idx + 1] + if next_seg.get('type') == 'generated' and 'raw_h' in next_seg: + adjacent_raw_heights.append(next_seg['raw_h']) + + # Use adjacent average if available, otherwise fall back to global + if adjacent_raw_heights: + local_raw_h = sum(adjacent_raw_heights) / len(adjacent_raw_heights) + local_effective_target_h = local_raw_h * s_global + else: + local_effective_target_h = effective_target_h + if char_height > 0: - scale = target_h / char_height + scale = local_effective_target_h / char_height else: scale = 1.0 @@ -514,6 +1206,9 @@ def _draw( rendered_width = char_width * scale_x rendered_height = char_height * scale_y + # DEBUG: Log override dimensions + print(f"DEBUG override: char='{segment.get('char', '?')}', char_h={char_height:.2f}, scale={scale:.4f}, final_h={rendered_height:.2f}, local_target_h={local_effective_target_h:.2f}, adjacent_raw_h={adjacent_raw_heights}") + # Check if there's a space before this override character has_space_before = False if seg_idx > 0: @@ -553,19 +1248,19 @@ def _draw( continue orig_stroke = elem.get('stroke', 'none') - orig_stroke_width = elem.get('stroke-width', '3') path = dwg.path(d=d) if orig_stroke and orig_stroke.lower() not in ('none', 'transparent'): - try: - stroke_width = min(float(orig_stroke_width), 4.0) - except: - stroke_width = 2.0 + # Use line-level stroke width for consistency with generated text + # Compensate for transform scaling to maintain visual thickness + line_stroke_width = segment['width'] + avg_scale = (scale_x + scale_y) / 2.0 + adjusted_stroke_width = line_stroke_width / avg_scale if avg_scale > 0 else line_stroke_width path = path.stroke( color=segment['color'], - width=stroke_width, + width=adjusted_stroke_width, linecap='round', linejoin='round' ).fill('none') diff --git a/handwriting_synthesis/hand/character_override_utils.py b/handwriting_synthesis/hand/character_override_utils.py index d73b062..605c29d 100644 --- a/handwriting_synthesis/hand/character_override_utils.py +++ b/handwriting_synthesis/hand/character_override_utils.py @@ -269,10 +269,18 @@ def estimate_override_width(override_data, target_height, x_stretch=1.0): tag_name = elem.tag.split('}')[-1] if '}' in elem.tag else elem.tag if tag_name == 'path': d = elem.get('d', '') - coords = re.findall(r'[ML]\s*([-\d.]+)\s+([-\d.]+)', d) - for x, y in coords: - all_x_coords.append(float(x)) - all_y_coords.append(float(y)) + # Extract M/L coordinates + for match in re.finditer(r'[MLml]\s*([-\d.]+)[,\s]+([-\d.]+)', d): + all_x_coords.append(float(match.group(1))) + all_y_coords.append(float(match.group(2))) + # Extract C (cubic bezier) control and end points for bounding box + for match in re.finditer(r'[Cc]\s*([-\d.]+)[,\s]+([-\d.]+)[,\s]+([-\d.]+)[,\s]+([-\d.]+)[,\s]+([-\d.]+)[,\s]+([-\d.]+)', d): + all_x_coords.extend([float(match.group(1)), float(match.group(3)), float(match.group(5))]) + all_y_coords.extend([float(match.group(2)), float(match.group(4)), float(match.group(6))]) + # Extract Q (quadratic bezier) points + for match in re.finditer(r'[Qq]\s*([-\d.]+)[,\s]+([-\d.]+)[,\s]+([-\d.]+)[,\s]+([-\d.]+)', d): + all_x_coords.extend([float(match.group(1)), float(match.group(3))]) + all_y_coords.extend([float(match.group(2)), float(match.group(4))]) if all_x_coords and all_y_coords: char_width = max(all_x_coords) - min(all_x_coords) diff --git a/handwriting_synthesis/hand/operations/__init__.py b/handwriting_synthesis/hand/operations/__init__.py index d5a826d..c693231 100644 --- a/handwriting_synthesis/hand/operations/__init__.py +++ b/handwriting_synthesis/hand/operations/__init__.py @@ -14,7 +14,7 @@ calculate_adaptive_spacing, stitch_strokes, ) -from .chunking import split_text_into_chunks +from .chunking import split_text_into_chunks, balanced_line_breaks from .sampling import sample_strokes __all__ = [ @@ -26,5 +26,6 @@ 'calculate_adaptive_spacing', 'stitch_strokes', 'split_text_into_chunks', + 'balanced_line_breaks', 'sample_strokes', ] diff --git a/handwriting_synthesis/hand/operations/chunking.py b/handwriting_synthesis/hand/operations/chunking.py index 175ce6e..5c668aa 100644 --- a/handwriting_synthesis/hand/operations/chunking.py +++ b/handwriting_synthesis/hand/operations/chunking.py @@ -1,6 +1,126 @@ """Text chunking logic for improved handwriting generation.""" -from typing import List +from typing import List, Optional, Tuple + + +def balanced_line_breaks( + widths: List[float], + spacing: float, + target: float, + limit: float, +) -> List[Tuple[int, int]]: + """Choose line breaks over measured chunk widths, minimising raggedness. + + Greedy filling makes line lengths erratic (one line packed past the budget, + the next stopping at 60%), which reads as a jagged right margin. This is the + classic dynamic-programming line-breaking approach applied to chunk widths: + every line except the last pays for its deviation from the target, so slack + is spread evenly across lines instead of accumulating in one. The penalty is + asymmetric: undershoot pays the full quadratic, overshoot (up to ``limit``) + only a quarter -- a slightly over-full line is condensed a few percent at + render time, which looks like natural cramming, whereas an under-full line + leaves a visible gap at the margin. + + Args: + widths: Measured raw width of each chunk, in order. + spacing: Horizontal gap added between chunks on a line. + target: Ideal line width (the wrap budget). + limit: Hard maximum line width (target plus any squeeze allowance). A + single chunk wider than the limit still gets a line of its own. + + Returns: + List of (start, end) index pairs, one per line, covering all chunks. + """ + n = len(widths) + if n == 0: + return [] + inf = float('inf') + best = [0.0] + [inf] * n + back = [0] * (n + 1) + for j in range(1, n + 1): + w = 0.0 + for i in range(j - 1, -1, -1): + w = widths[i] + (spacing + w if w > 0 else 0.0) + if w > limit and i < j - 1: + break # adding earlier chunks only widens the line further + if j == n: + penalty = 0.0 # the final line may be any length + elif w <= target: + penalty = (target - w) ** 2 + else: + penalty = 0.25 * (w - target) ** 2 # mild: overshoot is condensed + if best[i] + penalty < best[j]: + best[j] = best[i] + penalty + back[j] = i + lines = [] + j = n + while j > 0: + i = back[j] + lines.append((i, j)) + j = i + return lines[::-1] + + +# Tokens that mark the end of a sentence -- strong, high-priority break points. +_SENTENCE_ENDERS = ('.', '!', '?') +# Softer break points; a comma/semicolon is a natural place to end a chunk. +_PUNCTUATION_BREAKS = (',', ';', ':') + + +def _hard_split_long_word(word: str, max_chars: int) -> List[str]: + """Break a single token longer than ``max_chars`` into ``<= max_chars`` pieces. + + Normal words are returned unchanged. This only triggers for pathological + tokens (URLs, long identifiers, base64 blobs) that would otherwise create an + over-long RNN sequence or run off the edge of the page, since such a token + contains no spaces for the wrapper to break on. + + Args: + word: The token to (possibly) split. + max_chars: Maximum characters allowed per piece. + + Returns: + List of one or more sub-tokens, each at most ``max_chars`` long. + """ + if max_chars <= 0 or len(word) <= max_chars: + return [word] + return [word[i:i + max_chars] for i in range(0, len(word), max_chars)] + + +def _find_break_point( + words: List[str], + start: int, + search_lo: int, + search_hi: int, + use_sentence: bool, + use_punctuation: bool, +) -> Optional[int]: + """Find a punctuation-based chunk boundary within a word range. + + Sentence terminators take priority and break as early as possible (so a + sentence becomes its own chunk). Failing that, the *latest* soft punctuation + break in range is used, which fills the chunk as much as possible while still + ending on a natural pause. + + Args: + words: Full list of words. + start: Index of the first word in the current chunk. + search_lo: First word index to consider as a break point. + search_hi: One past the last word index to consider. + use_sentence: Whether to break on sentence terminators (. ! ?). + use_punctuation: Whether to break on soft punctuation (, ; :). + + Returns: + A word count for the chunk if a break was found, otherwise ``None``. + """ + punctuation_break = None + for j in range(search_lo, search_hi): + word = words[j] + if use_sentence and word.endswith(_SENTENCE_ENDERS): + return j - start + 1 + if use_punctuation and word.endswith(_PUNCTUATION_BREAKS): + punctuation_break = j - start + 1 # keep the latest one in range + return punctuation_break def split_text_into_chunks( @@ -19,19 +139,22 @@ def split_text_into_chunks( - 'word_length': Adjusts based on average word length (original behavior) - 'sentence': Respects sentence boundaries (periods, !, ?) - 'punctuation': Prefers to break at punctuation marks (commas, semicolons) - - 'balanced': Combines word length + punctuation awareness + - 'balanced': Combines word length + sentence + punctuation awareness - 'off': Fixed chunk sizes (no adaptation) This method creates more natural chunks by: 1. Using more words if they're short (better context for the model) 2. Using fewer words if they're long (avoid exceeding limits) 3. Respecting sentence and punctuation boundaries when enabled - 4. Ensuring reasonable min/max bounds + 4. Keeping chunk length near ``target_chars_per_chunk`` for even line filling + 5. Ensuring reasonable min/max bounds Args: text: Input text to split. words_per_chunk: Target number of words per chunk (used as baseline). - target_chars_per_chunk: Target character count per chunk (default: 25). + target_chars_per_chunk: Soft upper bound on characters per chunk. Chunks + are trimmed back toward this length (never below ``min_words``) so the + generated pieces stay a consistent size. min_words: Minimum words per chunk. max_words: Maximum words per chunk. adaptive_chunking: Enable adaptive chunking. @@ -44,12 +167,36 @@ def split_text_into_chunks( leading_space = len(text) - len(text.lstrip()) trailing_space = len(text) - len(text.rstrip()) - words = text.split() - if not words: + raw_words = text.split() + if not raw_words: # If only whitespace, return it as-is return [text] if text else [] - # Non-adaptive mode: fixed chunk sizes + # Character budgets. The soft cap keeps chunks near the requested target; the + # hard cap only breaks pathological space-less tokens so they cannot blow past + # the model's sequence limit. A normal long word (e.g. "internationalization") + # stays intact because it is shorter than the hard cap. + soft_char_cap = max(1, int(target_chars_per_chunk)) + hard_word_cap = max(soft_char_cap * 2, 40) + + # Pre-split any token that is, on its own, longer than the hard cap. For normal + # text this is a no-op, so word-based logic below is unchanged. + words: List[str] = [] + for w in raw_words: + words.extend(_hard_split_long_word(w, hard_word_cap)) + + def _chunk_char_len(start: int, count: int) -> int: + """Character length of ``count`` words joined with single spaces.""" + return len(' '.join(words[start:start + count])) + + def _fit_to_char_budget(start: int, count: int) -> int: + """Shrink ``count`` so the chunk fits the soft char cap (keeps >= min_words).""" + lower_bound = min(min_words, len(words) - start) + while count > lower_bound and _chunk_char_len(start, count) > soft_char_cap: + count -= 1 + return max(1, count) + + # Non-adaptive mode: fixed chunk sizes (still honours the hard word cap above). if not adaptive_chunking or adaptive_strategy == 'off': chunks = [] for i in range(0, len(words), words_per_chunk): @@ -66,83 +213,64 @@ def split_text_into_chunks( chunks = [] i = 0 - # Sentence boundary markers - sentence_enders = {'.', '!', '?'} - punctuation_breaks = {',', ';', ':', '--'} + use_sentence = adaptive_strategy in ('sentence', 'balanced') + use_punctuation = adaptive_strategy in ('punctuation', 'balanced') + use_word_length = adaptive_strategy in ('word_length', 'balanced') + # A sentence terminator is a strong break for ANY punctuation-aware strategy + # (sentence / punctuation / balanced); soft commas/semicolons only break when + # punctuation awareness is on. NOTE: 'balanced' must consider both -- the old + # if/elif on overlapping sets made the punctuation branch unreachable for it + # (and 'balanced' is the default strategy). + break_on_sentence = use_sentence or use_punctuation while i < len(words): - # Start with the target words per chunk - chunk_word_count = words_per_chunk + remaining = len(words) - i - # Look ahead to see the average word length - lookahead_end = min(i + words_per_chunk * 2, len(words)) - lookahead_words = words[i:lookahead_end] - - # Word length adaptation (used in word_length and balanced strategies) - if adaptive_strategy in ('word_length', 'balanced') and lookahead_words: - avg_word_length = sum(len(w) for w in lookahead_words) / len(lookahead_words) + # 1. Baseline chunk size, optionally adapted to average word length so that + # short words pack more per chunk and long words pack fewer. + chunk_word_count = words_per_chunk + if use_word_length: + lookahead_words = words[i:min(i + words_per_chunk * 2, len(words))] + if lookahead_words: + avg_word_length = sum(len(w) for w in lookahead_words) / len(lookahead_words) + if avg_word_length < 4: # short words + chunk_word_count = min(max_words, int(words_per_chunk * 1.5)) + elif avg_word_length > 7: # long words + chunk_word_count = max(min_words, int(words_per_chunk * 0.75)) + chunk_word_count = max(min_words, min(max_words, chunk_word_count)) + chunk_word_count = min(chunk_word_count, remaining) - # Adjust chunk size based on word length - if avg_word_length < 4: # Short words (a, an, the, is, of, etc.) - # Use more words to provide better context - chunk_word_count = min(max_words, int(words_per_chunk * 1.5)) - elif avg_word_length > 7: # Long words - # Use fewer words to avoid too long chunks - chunk_word_count = max(min_words, int(words_per_chunk * 0.75)) + # 2. Character budget: the most words that still fit the soft cap. This + # bounds everything below so chunks stay near target_chars_per_chunk. + budget_max = _fit_to_char_budget(i, min(max_words, remaining)) - # Ensure we stay within bounds - chunk_word_count = max(min_words, min(max_words, chunk_word_count)) + # 3. Prefer a natural break (sentence/punctuation) *within* the budget + # window, so the break lands on real punctuation that also fits the + # target -- rather than trimming a good break back to mid-phrase. + if break_on_sentence: + search_lo = i + min_words + search_hi = i + min(budget_max, remaining) + break_point = _find_break_point( + words, i, search_lo, search_hi, break_on_sentence, use_punctuation + ) + if break_point: + chunk_word_count = break_point + else: + # No natural break in range: keep the baseline size, capped by budget. + chunk_word_count = min(chunk_word_count, budget_max) + else: + # word_length / off: no punctuation awareness, just honour the budget. + chunk_word_count = min(chunk_word_count, budget_max) - # Don't exceed remaining words - chunk_word_count = min(chunk_word_count, len(words) - i) - - # Sentence-aware chunking (sentence and balanced strategies) - if adaptive_strategy in ('sentence', 'balanced'): - # Look for sentence boundaries within our chunk range - search_end = min(i + max_words, len(words)) - for j in range(i + min_words, search_end): - word = words[j] - # Check if word ends with sentence terminator - if any(word.endswith(char) for char in sentence_enders): - # Found sentence end, use this as chunk boundary - chunk_word_count = j - i + 1 - break - - # Punctuation-aware chunking (punctuation and balanced strategies) - elif adaptive_strategy in ('punctuation', 'balanced'): - # Look for punctuation breaks within our chunk range - search_start = i + min_words - search_end = min(i + chunk_word_count + 2, len(words)) - best_break = None - - for j in range(search_start, search_end): - word = words[j] - # Check for sentence enders first (higher priority) - if any(word.endswith(char) for char in sentence_enders): - best_break = j - i + 1 - break - # Check for punctuation breaks (lower priority) - elif any(word.endswith(char) for char in punctuation_breaks): - best_break = j - i + 1 - - if best_break: - chunk_word_count = best_break - - # Final bounds check - chunk_word_count = max(min_words, min(max_words, chunk_word_count)) - chunk_word_count = min(chunk_word_count, len(words) - i) + # Final bounds: never below the word floor, never past the remaining words, + # and always at least one word so the loop is guaranteed to make progress. + chunk_word_count = min(max(min_words, chunk_word_count), remaining) + chunk_word_count = max(1, chunk_word_count) # Create the chunk chunk_words = words[i:i + chunk_word_count] chunk_text = ' '.join(chunk_words) - # If chunk is too long (> 50 chars), split it - if len(chunk_text) > 50 and len(chunk_words) > min_words: - # Use fewer words - chunk_word_count = max(min_words, len(chunk_words) // 2) - chunk_words = words[i:i + chunk_word_count] - chunk_text = ' '.join(chunk_words) - # Add leading space to first chunk if i == 0 and leading_space > 0: chunk_text = ' ' * leading_space + chunk_text diff --git a/handwriting_synthesis/hand/operations/sampling.py b/handwriting_synthesis/hand/operations/sampling.py index aa4f724..95e6d08 100644 --- a/handwriting_synthesis/hand/operations/sampling.py +++ b/handwriting_synthesis/hand/operations/sampling.py @@ -12,7 +12,8 @@ def sample_strokes( rnn_model, lines: List[str], biases: Optional[List[float]] = None, - styles: Optional[List[int]] = None + styles: Optional[List[int]] = None, + return_char_indices: bool = False ) -> List[np.ndarray]: """ Sample stroke sequences from the RNN model. @@ -28,10 +29,16 @@ def sample_strokes( consistency of the handwriting. Higher bias -> more legible, less random. styles: Optional list of style IDs (one per line). + return_char_indices: If True, also return the character indices per stroke + (from the attention mechanism's phi weights). Returns: - List of stroke sequences (numpy arrays of shape [T, 3]). - Each stroke point is (x, y, eos). + If return_char_indices is False: + List of stroke sequences (numpy arrays of shape [T, 3]). + Each stroke point is (x, y, eos). + If return_char_indices is True: + Tuple of (strokes_list, char_indices_list) where char_indices_list + contains the character index the model was attending to at each stroke. """ num_samples = len(lines) max_tsteps = 40 * max([len(i) for i in lines]) @@ -62,18 +69,37 @@ def sample_strokes( chars[i, :len(encoded)] = encoded chars_len[i] = len(encoded) - [samples] = rnn_session.run( - [rnn_model.sampled_sequence], - feed_dict={ - rnn_model.prime: styles is not None, - rnn_model.x_prime: x_prime, - rnn_model.x_prime_len: x_prime_len, - rnn_model.num_samples: num_samples, - rnn_model.sample_tsteps: max_tsteps, - rnn_model.c: chars, - rnn_model.c_len: chars_len, - rnn_model.bias: biases - } - ) - samples = [sample[~np.all(sample == 0.0, axis=1)] for sample in samples] - return samples + feed_dict = { + rnn_model.prime: styles is not None, + rnn_model.x_prime: x_prime, + rnn_model.x_prime_len: x_prime_len, + rnn_model.num_samples: num_samples, + rnn_model.sample_tsteps: max_tsteps, + rnn_model.c: chars, + rnn_model.c_len: chars_len, + rnn_model.bias: biases + } + + if return_char_indices: + # Fetch both stroke samples and character indices from attention + samples, char_indices = rnn_session.run( + [rnn_model.sampled_sequence, rnn_model.sampled_char_indices], + feed_dict=feed_dict + ) + # Remove zero-padded strokes (and corresponding char indices) + strokes_list = [] + char_indices_list = [] + for sample, ci in zip(samples, char_indices): + # Find non-zero strokes + valid_mask = ~np.all(sample == 0.0, axis=1) + strokes_list.append(sample[valid_mask]) + char_indices_list.append(ci[valid_mask]) + return strokes_list, char_indices_list + else: + # Original behavior: only fetch stroke samples + [samples] = rnn_session.run( + [rnn_model.sampled_sequence], + feed_dict=feed_dict + ) + samples = [sample[~np.all(sample == 0.0, axis=1)] for sample in samples] + return samples diff --git a/handwriting_synthesis/rnn/RNN.py b/handwriting_synthesis/rnn/RNN.py index 1263bfb..5278aa2 100644 --- a/handwriting_synthesis/rnn/RNN.py +++ b/handwriting_synthesis/rnn/RNN.py @@ -58,6 +58,7 @@ def __init__( self.initial_state = None self.final_state = None self.sampled_sequence = None + self.sampled_char_indices = None self.lstm_size = lstm_size self.output_mixture_components = output_mixture_components self.output_units = self.output_mixture_components * 6 + 1 @@ -149,20 +150,26 @@ def sample(self, cell): cell: The RNN cell to use for sampling. Returns: - Sampled sequence tensor. + Tuple of (sampled_sequence, char_indices) where: + - sampled_sequence: The stroke outputs + - char_indices: Character index per timestep from attention (argmax of phi) """ initial_state = cell.zero_state(self.num_samples, dtype=tf.float32) initial_input = tf.concat([ tf.zeros([self.num_samples, 2]), tf.ones([self.num_samples, 1]), ], axis=1) - return rnn_free_run( + states, outputs, final_state = rnn_free_run( cell=cell, sequence_length=self.sample_tsteps, initial_state=initial_state, initial_input=initial_input, scope='rnn' - )[1] + ) + # Extract char_indices from phi: states.phi has shape [batch, timesteps, char_len] + # argmax gives us which character the model is attending to at each timestep + char_indices = tf.argmax(states.phi, axis=2) # [batch, timesteps] + return outputs, char_indices def primed_sample(self, cell): """ @@ -172,7 +179,9 @@ def primed_sample(self, cell): cell: The RNN cell to use for sampling. Returns: - Sampled sequence tensor. + Tuple of (sampled_sequence, char_indices) where: + - sampled_sequence: The stroke outputs + - char_indices: Character index per timestep from attention (argmax of phi) """ initial_state = cell.zero_state(self.num_samples, dtype=tf.float32) primed_state = tfcompat.nn.dynamic_rnn( @@ -183,12 +192,15 @@ def primed_sample(self, cell): initial_state=initial_state, scope='rnn' )[1] - return rnn_free_run( + states, outputs, final_state = rnn_free_run( cell=cell, sequence_length=self.sample_tsteps, initial_state=primed_state, scope='rnn' - )[1] + ) + # Extract char_indices from phi: states.phi has shape [batch, timesteps, char_len] + char_indices = tf.argmax(states.phi, axis=2) # [batch, timesteps] + return outputs, char_indices def calculate_loss(self): """ @@ -236,9 +248,18 @@ def calculate_loss(self): pis, mus, sigmas, rhos, es = self.parse_parameters(params) sequence_loss, self.loss = self.nll(self.y, self.x_len, pis, mus, sigmas, rhos, es) + # Sample returns (outputs, char_indices) - use tf.cond on each + primed_outputs, primed_char_indices = self.primed_sample(cell) + unprimed_outputs, unprimed_char_indices = self.sample(cell) + self.sampled_sequence = tf.cond( self.prime, - lambda: self.primed_sample(cell), - lambda: self.sample(cell) + lambda: primed_outputs, + lambda: unprimed_outputs + ) + self.sampled_char_indices = tf.cond( + self.prime, + lambda: primed_char_indices, + lambda: unprimed_char_indices ) return self.loss diff --git a/tests/test_operations.py b/tests/test_operations.py new file mode 100644 index 0000000..18cb841 --- /dev/null +++ b/tests/test_operations.py @@ -0,0 +1,162 @@ +"""Tests for the text-chunking operation used to split text before generation. + +These tests are model-free -- they exercise the pure wrapping/sizing logic and +never load the RNN -- so they run fast and anywhere. Run with: + + pytest tests/test_operations.py # if pytest is installed + python tests/test_operations.py # standalone fallback runner +""" + +import os +import sys + +# Make the project importable when run directly (python tests/test_operations.py). +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))) + +from handwriting_synthesis.hand.operations.chunking import ( + split_text_into_chunks, + balanced_line_breaks, +) + + +# Tokens longer than this are hard-split by the chunker (see chunking.py). +def _hard_cap(target_chars): + return max(int(target_chars) * 2, 40) + + +PUNCTUATED = "I went home, then I slept, and later, after dinner, I read a long book." + + +def _max_token_len(chunks): + return max((len(tok) for c in chunks for tok in c.split()), default=0) + + +def test_balanced_breaks_at_punctuation(): + """'balanced' (the default) must use punctuation breaks, not only sentences. + + Regression guard for the unreachable elif branch that previously made the + punctuation logic dead code whenever the strategy was 'balanced'. + """ + chunks = split_text_into_chunks( + PUNCTUATED, words_per_chunk=3, target_chars_per_chunk=25, + adaptive_strategy='balanced', + ) + # At least one boundary lands right after a comma -> punctuation awareness ran. + assert any(c.rstrip().endswith(',') for c in chunks), chunks + + +def test_target_chars_is_honoured(): + """Chunks should stay near the character target (down to the min-word floor).""" + target = 20 + min_words = 2 + chunks = split_text_into_chunks( + PUNCTUATED, words_per_chunk=4, target_chars_per_chunk=target, + min_words=min_words, max_words=8, adaptive_strategy='balanced', + ) + for c in chunks: + # A chunk may exceed the soft cap only if it is already at the min-word floor. + assert len(c) <= target or len(c.split()) <= min_words, (c, len(c)) + + +def test_long_word_is_hard_split(): + """A space-less token longer than the hard cap must be broken up.""" + target = 25 + url = "see https://example.com/a/very/long/path/that/keeps/going/and/going/forever/" + chunks = split_text_into_chunks(url, words_per_chunk=3, target_chars_per_chunk=target) + assert _max_token_len(chunks) <= _hard_cap(target), chunks + # Reassembling the tokens must preserve the original characters (no loss). + assert "".join("".join(c.split()) for c in chunks) == url.replace(" ", "") + + +def test_normal_long_word_is_not_split(): + """A legitimately long word (shorter than the hard cap) stays intact.""" + word = "internationalization" # 20 chars, under the 50 hard cap + chunks = split_text_into_chunks("the " + word + " process", target_chars_per_chunk=25) + assert any(word in c for c in chunks), chunks + + +def test_off_strategy_is_fixed_size(): + chunks = split_text_into_chunks( + "one two three four five six seven", words_per_chunk=3, adaptive_strategy='off', + ) + assert chunks == ["one two three", "four five six", "seven"], chunks + + +def test_sentence_strategy_respects_budget(): + chunks = split_text_into_chunks( + PUNCTUATED, words_per_chunk=3, target_chars_per_chunk=25, adaptive_strategy='sentence', + ) + assert all(len(c) <= 25 or len(c.split()) <= 2 for c in chunks), chunks + + +def test_whitespace_and_empty_inputs(): + assert split_text_into_chunks("") == [] + assert split_text_into_chunks(" ") == [" "] + lead_trail = split_text_into_chunks(" hello world there ", words_per_chunk=2) + assert lead_trail[0].startswith(" "), lead_trail + assert lead_trail[-1].endswith(" "), lead_trail + + +def _line_widths(widths, spacing, breaks): + out = [] + for i, j in breaks: + w = sum(widths[i:j]) + spacing * (j - i - 1) + out.append(w) + return out + + +def test_balanced_breaks_cover_all_chunks_in_order(): + widths = [90.0, 110.0, 100.0, 95.0, 105.0, 80.0, 120.0] + breaks = balanced_line_breaks(widths, 8.0, target=250.0, limit=260.0) + flat = [k for i, j in breaks for k in range(i, j)] + assert flat == list(range(len(widths))), breaks + # No line exceeds the limit (none of these single chunks is oversized) + assert all(w <= 260.0 for w in _line_widths(widths, 8.0, breaks)), breaks + + +def test_balanced_breaks_spread_slack(): + """DP must not leave one line nearly empty when even splits exist. + + Greedy on these widths gives lines of 240 and 60; balanced breaking + should split 150/150 (both near-ish target, far better balance). + """ + widths = [120.0, 120.0, 30.0, 30.0] + breaks = balanced_line_breaks(widths, 0.0, target=160.0, limit=240.0) + line_w = _line_widths(widths, 0.0, breaks) + assert len(line_w) >= 2 + # the non-final lines must be closer to target than greedy's worst case + assert min(line_w[:-1]) >= 120.0, line_w + + +def test_balanced_breaks_oversized_chunk_gets_own_line(): + widths = [50.0, 500.0, 50.0] + breaks = balanced_line_breaks(widths, 5.0, target=200.0, limit=210.0) + assert (1, 2) in breaks, breaks # the huge chunk stands alone + + +def test_balanced_breaks_empty_and_single(): + assert balanced_line_breaks([], 5.0, 100.0, 105.0) == [] + assert balanced_line_breaks([42.0], 5.0, 100.0, 105.0) == [(0, 1)] + + +def test_progress_guaranteed_with_degenerate_min_words(): + """min_words=0 must not cause an infinite loop.""" + chunks = split_text_into_chunks( + "a b c d e", words_per_chunk=2, min_words=0, target_chars_per_chunk=5, + ) + assert "".join("".join(c.split()) for c in chunks) == "abcde", chunks + + +if __name__ == '__main__': + tests = [v for k, v in sorted(globals().items()) + if k.startswith('test_') and callable(v)] + failures = 0 + for fn in tests: + try: + fn() + print(f"PASS {fn.__name__}") + except Exception as exc: # noqa: BLE001 - report and continue + failures += 1 + print(f"FAIL {fn.__name__}: {type(exc).__name__}: {exc}") + print(f"\n{len(tests) - failures}/{len(tests)} passed") + sys.exit(1 if failures else 0) diff --git a/tests/test_sizing.py b/tests/test_sizing.py new file mode 100644 index 0000000..352f473 --- /dev/null +++ b/tests/test_sizing.py @@ -0,0 +1,201 @@ +"""Dimensional tests for the natural-handwriting sizing in `_draw`. + +Model-free: synthetic stroke arrays with a known body height are fed through +`_draw`, then the rendered SVG is parsed and measured in page pixels. These pin +down the *behaviour* of the sizing/spacing logic (consistent x-height, spacing +proportional to size, width does not shrink everything, shrink-to-fit-one-page, +manual scale as a multiple of natural) without judging visual "naturalness". + +Run: `python tests/test_sizing.py` or `pytest tests/test_sizing.py`. +""" + +import os +import re +import sys +import tempfile +import xml.etree.ElementTree as ET + +import numpy as np + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))) + +from handwriting_synthesis import drawing +from handwriting_synthesis.hand import _draw as draw_mod +from handwriting_synthesis.hand._draw import ( + _draw, PX_PER_MM, NATURAL_WRITING_SIZE_MM, + solve_fill_xheight_px, LINE_SPACING_PER_XHEIGHT, +) + +_COORD = re.compile(r'[ML]\s*([-\d.]+)[\s,]+([-\d.]+)') + + +def _segment(text, width_units=180.0, xheight=20.0, ascender_frac=0.0, n=240, line_idx=0): + """A synthetic generated segment. + + Most points sit in the body band [0, xheight]; a fraction are pushed up to + ~2x xheight to emulate ascenders, so the 10..90 percentile band ~= xheight. + """ + xs = np.linspace(0.0, width_units, n) + ys = np.abs(np.sin(np.linspace(0.0, 9 * np.pi, n))) * xheight # body in [0, xheight] + if ascender_frac > 0: + k = max(1, int(n * ascender_frac)) + ys[np.linspace(0, n - 1, k).astype(int)] = 2.0 * xheight # occasional tall strokes + coords = np.stack([xs, ys, np.zeros(n)], axis=1) + coords[-1, 2] = 1.0 + return {'type': 'generated', 'text': text, 'strokes': drawing.coords_to_offsets(coords), + 'line_idx': line_idx} + + +def _render(line_segments, page=(210.0, 297.0), margins=20.0, **kw): + fd, path = tempfile.mkstemp(suffix='.svg') + os.close(fd) + opts = dict(page_size=list(page), units='mm', margins=margins, align='left', + legibility='high', denoise=False, auto_size=True, background='white') + opts.update(kw) + _draw(line_segments, [s[0].get('text', '') for s in line_segments], path, **opts) + return path + + +def _line_y_bands(path): + """Return, per top-level path, (min_y, max_y, p10, p90) of its point y-coords.""" + root = ET.parse(path).getroot() + bands = [] + for child in root: + if child.tag.split('}')[-1] != 'path': + continue + ys = [float(y) for _, y in _COORD.findall(child.get('d', ''))] + if ys: + a = np.array(ys) + bands.append((a.min(), a.max(), float(np.percentile(a, 10)), float(np.percentile(a, 90)))) + return bands + + +def _rendered_xheight_mm(path): + """Median rendered 10..90 body band across lines, in mm.""" + bands = _line_y_bands(path) + assert bands, "no generated paths in output" + spans_px = [p90 - p10 for (_, _, p10, p90) in bands] + return float(np.median(spans_px)) / PX_PER_MM + + +def _baseline_spacing_mm(path): + """Median gap between consecutive lines' bottoms (a baseline proxy), in mm.""" + bands = sorted(_line_y_bands(path), key=lambda b: b[1]) + bottoms = [b[1] for b in bands] + if len(bottoms) < 2: + return None + gaps = np.diff(bottoms) + return float(np.median(gaps)) / PX_PER_MM + + +def test_default_xheight_is_natural(): + """Default auto-size renders a body x-height ~= NATURAL_WRITING_SIZE_MM.""" + segs = [[_segment("hello world", line_idx=i)] for i in range(4)] + xh = _rendered_xheight_mm(_render(segs)) + assert abs(xh - NATURAL_WRITING_SIZE_MM) <= 0.9, xh + + +def test_writing_size_mm_controls_size(): + """Rendered x-height tracks the writing_size_mm knob (and bigger = bigger).""" + segs = [[_segment("hello world", line_idx=i)] for i in range(4)] + small = _rendered_xheight_mm(_render(segs, writing_size_mm=3.0)) + big = _rendered_xheight_mm(_render(segs, writing_size_mm=6.0)) + assert abs(small - 3.0) <= 0.8, small + assert abs(big - 6.0) <= 1.0, big + assert big > small + 1.5 + + +def test_xheight_consistent_despite_ascenders(): + """A line full of ascenders must not shrink the document's x-height. + + This is the core fix: previously the tallest line set a global min scale that + shrank everything; now sizing uses a robust body height per line. + """ + no_asc = [[_segment("aaa eee ooo", ascender_frac=0.0, line_idx=i)] for i in range(4)] + with_asc = [[_segment("aaa eee ooo", ascender_frac=0.0, line_idx=0)], + [_segment("llll kkkk hhhh", ascender_frac=0.30, line_idx=1)], + [_segment("aaa eee ooo", ascender_frac=0.0, line_idx=2)], + [_segment("aaa eee ooo", ascender_frac=0.0, line_idx=3)]] + xh_plain = _rendered_xheight_mm(_render(no_asc)) + xh_mixed = _rendered_xheight_mm(_render(with_asc)) + assert abs(xh_plain - xh_mixed) <= 0.7, (xh_plain, xh_mixed) + + +def test_one_wide_line_does_not_shrink_others(): + """A single very long line must not shrink the size of the normal lines.""" + normal = [[_segment("hello", width_units=120.0, line_idx=i)] for i in range(5)] + with_wide = [[_segment("hello", width_units=120.0, line_idx=0)], + [_segment("x" * 50, width_units=2000.0, line_idx=1)]] # one huge outlier + with_wide += [[_segment("hello", width_units=120.0, line_idx=i)] for i in range(2, 5)] + xh_normal = _rendered_xheight_mm(_render(normal)) + xh_wide = _rendered_xheight_mm(_render(with_wide)) + # The outlier is condensed per-line, not allowed to shrink everyone. + assert xh_wide >= xh_normal - 0.6, (xh_normal, xh_wide) + + +def test_spacing_tracks_size_when_auto(): + """With auto line height, baseline spacing scales with the writing size.""" + segs = [[_segment("hello world", line_idx=i)] for i in range(5)] + sp_small = _baseline_spacing_mm(_render(segs, writing_size_mm=3.0)) + sp_big = _baseline_spacing_mm(_render(segs, writing_size_mm=6.0)) + assert sp_small and sp_big and sp_big > sp_small + 2.0, (sp_small, sp_big) + + +def test_shrinks_to_fit_one_page(): + """Many lines on a short page shrink (size + spacing) to stay on one page.""" + many = [[_segment("hello world", line_idx=i)] for i in range(60)] + path = _render(many, page=(210.0, 297.0), margins=20.0, writing_size_mm=5.0) + bands = _line_y_bands(path) + max_y_px = max(b[1] for b in bands) + page_h_px = 297.0 * PX_PER_MM + assert max_y_px <= page_h_px + 2.0, (max_y_px, page_h_px) # stays on page + assert _rendered_xheight_mm(path) < 5.0 # was scaled down + + +def test_fill_solver_fills_target_height(): + """The solved x-height plugs back into the height model at the fill target.""" + W, mxh, content_w, content_h = 5000.0, 20.0, 600.0, 900.0 + h = solve_fill_xheight_px(W, mxh, 0, content_w, content_h, fill_frac=0.92) + assert h and h > 0 + n_lines = W * h / (mxh * content_w) + height = n_lines * LINE_SPACING_PER_XHEIGHT * h + assert abs(height - 0.92 * content_h) < 1e-6, (height, 0.92 * content_h) + + +def test_fill_solver_monotonic(): + """More text or more blank lines -> smaller solved size; both reduce h.""" + args = dict(model_xheight=20.0, content_width_px=600.0, content_height_px=900.0) + h_short = solve_fill_xheight_px(2000.0, n_blank_lines=0, **args) + h_long = solve_fill_xheight_px(20000.0, n_blank_lines=0, **args) + h_blanks = solve_fill_xheight_px(2000.0, n_blank_lines=5, **args) + assert h_long < h_short, (h_long, h_short) + assert h_blanks < h_short, (h_blanks, h_short) + + +def test_fill_solver_degenerate_inputs(): + assert solve_fill_xheight_px(0.0, 20.0, 0, 600.0, 900.0) is None + assert solve_fill_xheight_px(100.0, 0.0, 0, 600.0, 900.0) is None + assert solve_fill_xheight_px(100.0, 20.0, 0, 0.0, 900.0) is None + + +def test_manual_scale_is_multiple_of_natural(): + """auto_size=False: manual_size_scale=2 renders ~2x the natural size.""" + segs = [[_segment("hello world", line_idx=i)] for i in range(3)] + natural = _rendered_xheight_mm(_render(segs, auto_size=True)) + doubled = _rendered_xheight_mm(_render(segs, auto_size=False, manual_size_scale=2.0)) + assert doubled > natural * 1.6, (natural, doubled) + + +if __name__ == '__main__': + tests = [v for k, v in sorted(globals().items()) + if k.startswith('test_') and callable(v)] + failures = 0 + for fn in tests: + try: + fn() + print(f"PASS {fn.__name__}") + except Exception as exc: # noqa: BLE001 + failures += 1 + print(f"FAIL {fn.__name__}: {type(exc).__name__}: {exc}") + print(f"\n{len(tests) - failures}/{len(tests)} passed") + sys.exit(1 if failures else 0) diff --git a/webapp/init_db.py b/webapp/init_db.py index c101c7f..f85a70a 100644 --- a/webapp/init_db.py +++ b/webapp/init_db.py @@ -6,6 +6,7 @@ """ import os import sys +from datetime import datetime from getpass import getpass import warnings @@ -41,30 +42,153 @@ def get_password_input(prompt="Password: "): return input(prompt).strip() -def init_database(): +def _placeholder_for(column): + """Return a safe non-null backfill value for a newly-added NOT NULL column.""" + from sqlalchemy import Integer, Numeric, Float, Boolean, DateTime, Date + col_type = column.type + if isinstance(col_type, (Integer, Numeric, Float)): + return 0 + if isinstance(col_type, Boolean): + return False + if isinstance(col_type, (DateTime, Date)): + return datetime.utcnow() + return '' # strings/text and anything else + + +def _reconcile_missing_columns(): + """Add columns present in the models but missing from existing tables. + + ``db.create_all()`` creates new tables but never ALTERs existing ones, so a DB + created against older models is left missing newly-added columns -- which is + exactly how ``users.email`` went missing and made every page 500. For each + existing table we add any missing column (NOT NULL columns are backfilled so + the ALTER succeeds on populated tables; unique columns get a unique index when + the current values allow it). Column drops / renames / type changes are NOT + handled here -- those need a real Alembic migration. """ - Initialize the database tables and run migrations. + from sqlalchemy import inspect, text + + inspector = inspect(db.engine) + existing_tables = set(inspector.get_table_names()) + added = [] + + for table in db.metadata.sorted_tables: + if table.name not in existing_tables: + continue # brand-new table: db.create_all() already created it + db_cols = {c['name'] for c in inspector.get_columns(table.name)} + for column in table.columns: + if column.name in db_cols: + continue + col_type = column.type.compile(dialect=db.engine.dialect) + with db.engine.begin() as conn: + conn.execute(text( + f'ALTER TABLE "{table.name}" ADD COLUMN "{column.name}" {col_type}')) + if not column.nullable: + conn.execute( + text(f'UPDATE "{table.name}" SET "{column.name}" = :val ' + f'WHERE "{column.name}" IS NULL'), + {"val": _placeholder_for(column)}) + if column.unique: + dupes = conn.execute(text( + f'SELECT COUNT(*) - COUNT(DISTINCT "{column.name}") ' + f'FROM "{table.name}"')).scalar() + if not dupes: + conn.execute(text( + f'CREATE UNIQUE INDEX IF NOT EXISTS ' + f'"ix_{table.name}_{column.name}" ' + f'ON "{table.name}" ("{column.name}")')) + else: + print(f" [WARN] added {table.name}.{column.name} but left it " + f"non-unique: existing rows have blank/duplicate values; " + f"set them and add a unique index manually.") + added.append(f"{table.name}.{column.name}") + + if added: + print(f" Added missing columns: {', '.join(added)}") + else: + print(" Schema already matches models (no missing columns).") + return added + - Attempts to run Alembic migrations first. If that fails (e.g., first run), - falls back to SQLAlchemy's `db.create_all()`. +def init_database(): + """ + Bring the database schema up to date from any starting state. + + Handles all three cases the app can encounter: + * Fresh DB, or a legacy DB created by db.create_all() with no Alembic stamp: + build the schema directly from the models (creating missing tables AND + adding columns missing from existing tables), then stamp Alembic head so + future `flask db upgrade` works. + * Alembic-managed DB: apply any pending migrations with `upgrade head`. + + The previous version ran migrations first and fell back to db.create_all() on + error, which could not ALTER existing tables and silently left the schema out + of date (the users.email outage). """ with app.app_context(): - print("Running database migrations...") - from alembic.config import Config - from alembic import command - - # Get the alembic config - alembic_cfg = Config(os.path.join(PROJECT_ROOT, "alembic.ini")) - - try: - # Run all pending migrations - command.upgrade(alembic_cfg, "head") - print("Database migrations completed successfully!") - except Exception as e: - print(f"Error running migrations: {e}") - print("\nFalling back to db.create_all()...") - db.create_all() - print("Database tables created successfully!") + # Use Flask-Migrate's helpers (not a hand-built alembic Config): they use + # the Migrate extension's configured migrations/ directory. The old code + # pointed Config at webapp/alembic.ini -> webapp/alembic/env.py, which does + # not exist, so every `upgrade` failed and silently fell back to + # create_all() -- the reason the schema drifted (users.email outage). + from flask_migrate import upgrade as fm_upgrade, stamp as fm_stamp + from alembic.runtime.migration import MigrationContext + + with db.engine.connect() as conn: + current_rev = MigrationContext.configure(conn).get_current_revision() + + if current_rev is None: + print("No Alembic revision found - syncing schema directly from models...") + db.create_all() # create any missing tables + _reconcile_missing_columns() # add columns missing from existing tables + fm_stamp(revision="head") # mark as current so future upgrades work + print("Schema synced from models and stamped to Alembic head.") + else: + print(f"Alembic revision {current_rev} - applying any pending migrations...") + try: + fm_upgrade() # to head + print("Database is at Alembic head.") + except Exception as e: + print(f"Error applying migrations: {e}") + print("Falling back to model-based schema sync...") + db.create_all() + _reconcile_missing_columns() + + +# Standard page sizes the UI expects (names must match the engine's PAPER_SIZES_MM +# and the frontend's predefined-size list so they resolve correctly). +DEFAULT_PAGE_SIZES = [ + ('A4', 210.0, 297.0), + ('A5', 148.0, 210.0), + ('Letter', 215.9, 279.4), + ('Legal', 215.9, 355.6), +] + + +def seed_default_page_sizes(): + """Seed the standard system page sizes if they are missing. + + Without these the page-size dropdown in the UI is empty, which forces every + generation onto the A4 fallback and hides the size options. Idempotent: only + inserts names that are not already present, so it is safe to run on every init. + """ + from models import PageSizePreset + with app.app_context(): + existing = {row[0] for row in db.session.query(PageSizePreset.name).all()} + created = [] + for name, width, height in DEFAULT_PAGE_SIZES: + if name in existing: + continue + db.session.add(PageSizePreset( + name=name, width=width, height=height, unit='mm', + is_active=True, is_default=True, created_by=None, + )) + created.append(name) + if created: + db.session.commit() + print(f"Seeded default page sizes: {', '.join(created)}") + else: + print("Default page sizes already present.") def create_admin_user(): @@ -202,8 +326,11 @@ def main(): # Initialize database init_database() + # Seed system defaults the UI depends on (page-size dropdown). + seed_default_page_sizes() + if args.auto: - # Automatic mode - just run migrations and exit + # Automatic mode - schema + system defaults, then exit print("Database initialization completed (auto mode)") return diff --git a/webapp/instance/writebot.db.pre-email-fix.20260531_092148.bak b/webapp/instance/writebot.db.pre-email-fix.20260531_092148.bak new file mode 100755 index 0000000..83d2922 Binary files /dev/null and b/webapp/instance/writebot.db.pre-email-fix.20260531_092148.bak differ diff --git a/webapp/routes/admin_routes.py b/webapp/routes/admin_routes.py index 4663033..a897353 100644 --- a/webapp/routes/admin_routes.py +++ b/webapp/routes/admin_routes.py @@ -1,7 +1,10 @@ """ Admin routes for user management and statistics. """ -from flask import Blueprint, render_template, redirect, url_for, flash, request, jsonify +import os +import re +import glob +from flask import Blueprint, render_template, redirect, url_for, flash, request, jsonify, current_app from flask_login import login_required, current_user from webapp.models import db, User, UserActivity, UsageStatistics, PageSizePreset, TemplatePreset from webapp.utils.auth_utils import admin_required, log_activity, get_user_statistics, get_user_activities, get_all_user_statistics @@ -760,3 +763,205 @@ def delete_template(template_id): log_activity('admin_action', f'Deleted template preset: {name} (ID: {template_id})') flash(f'Template preset "{name}" deleted successfully.', 'success') return redirect(url_for('admin.templates')) + + +# ============================================================================ +# Error Logs Management +# ============================================================================ + +def strip_ansi_codes(text): + """Remove ANSI escape codes from text.""" + ansi_pattern = re.compile(r'\x1b\[[0-9;]*m') + return ansi_pattern.sub('', text) + + +def parse_log_line(line): + """Parse a log line and extract level and content.""" + clean_line = strip_ansi_codes(line) + + # Determine log level based on content + level = 'info' + if 'ERROR' in clean_line.upper() or '500' in clean_line: + level = 'error' + elif 'WARNING' in clean_line.upper() or '400' in clean_line or '404' in clean_line: + level = 'warning' + elif 'DEBUG' in clean_line.upper(): + level = 'debug' + + return { + 'raw': line, + 'clean': clean_line, + 'level': level + } + + +@admin_bp.route('/logs') +@login_required +@admin_required +def error_logs(): + """ + View application error logs. + + Displays log files with filtering and search capabilities. + """ + # Get logs directory + logs_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'logs') + + # Get list of available log files + log_files = [] + if os.path.exists(logs_dir): + for filename in sorted(os.listdir(logs_dir), reverse=True): + if filename.endswith('.txt') or filename.endswith('.log'): + filepath = os.path.join(logs_dir, filename) + stat = os.stat(filepath) + log_files.append({ + 'name': filename, + 'size': stat.st_size, + 'modified': datetime.fromtimestamp(stat.st_mtime), + 'size_human': f"{stat.st_size / 1024:.1f} KB" if stat.st_size < 1024 * 1024 else f"{stat.st_size / (1024 * 1024):.1f} MB" + }) + + # Get selected log file (default to most recent) + selected_file = request.args.get('file', '') + if not selected_file and log_files: + selected_file = log_files[0]['name'] + + # Get filter parameters + level_filter = request.args.get('level', 'all') + search_query = request.args.get('search', '').strip() + lines_limit = request.args.get('limit', 500, type=int) + + # Read log file contents + log_entries = [] + total_lines = 0 + error_count = 0 + warning_count = 0 + + if selected_file: + filepath = os.path.join(logs_dir, selected_file) + # Security check: ensure the file is within logs_dir + real_logs_dir = os.path.realpath(logs_dir) + real_filepath = os.path.realpath(filepath) + if not real_filepath.startswith(real_logs_dir): + flash('Invalid log file path.', 'error') + return redirect(url_for('admin.error_logs')) + + if os.path.exists(filepath): + try: + with open(filepath, 'r', encoding='utf-8', errors='replace') as f: + lines = f.readlines() + + total_lines = len(lines) + + # Process lines in reverse order (newest first) + for line in reversed(lines): + if not line.strip(): + continue + + entry = parse_log_line(line) + + # Update counts + if entry['level'] == 'error': + error_count += 1 + elif entry['level'] == 'warning': + warning_count += 1 + + # Apply filters + if level_filter != 'all' and entry['level'] != level_filter: + continue + + if search_query and search_query.lower() not in entry['clean'].lower(): + continue + + log_entries.append(entry) + + # Limit entries + if len(log_entries) >= lines_limit: + break + + except Exception as e: + current_app.logger.exception(f'Error reading log file: {selected_file}') + flash(f'Error reading log file.', 'error') + + log_activity('admin_action', f'Viewed error logs: {selected_file}') + + return render_template('admin/logs.html', + active_nav='logs', + log_files=log_files, + selected_file=selected_file, + log_entries=log_entries, + total_lines=total_lines, + error_count=error_count, + warning_count=warning_count, + level_filter=level_filter, + search_query=search_query, + lines_limit=lines_limit) + + +@admin_bp.route('/logs/download/') +@login_required +@admin_required +def download_log(filename): + """ + Download a log file. + + Args: + filename: Name of the log file to download. + """ + from flask import send_file + + logs_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'logs') + filepath = os.path.join(logs_dir, filename) + + # Security check: ensure the file is within logs_dir + real_logs_dir = os.path.realpath(logs_dir) + real_filepath = os.path.realpath(filepath) + if not real_filepath.startswith(real_logs_dir): + flash('Invalid log file path.', 'error') + return redirect(url_for('admin.error_logs')) + + if not os.path.exists(filepath): + flash('Log file not found.', 'error') + return redirect(url_for('admin.error_logs')) + + log_activity('admin_action', f'Downloaded log file: {filename}') + + return send_file(filepath, as_attachment=True, download_name=filename) + + +@admin_bp.route('/logs/clear/', methods=['POST']) +@login_required +@admin_required +def clear_log(filename): + """ + Clear (truncate) a log file. + + Args: + filename: Name of the log file to clear. + """ + logs_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'logs') + filepath = os.path.join(logs_dir, filename) + + # Security check: ensure the file is within logs_dir + real_logs_dir = os.path.realpath(logs_dir) + real_filepath = os.path.realpath(filepath) + if not real_filepath.startswith(real_logs_dir): + flash('Invalid log file path.', 'error') + return redirect(url_for('admin.error_logs')) + + if not os.path.exists(filepath): + flash('Log file not found.', 'error') + return redirect(url_for('admin.error_logs')) + + try: + # Truncate the file + with open(filepath, 'w') as f: + f.write(f'# Log cleared by {current_user.username} at {datetime.now().isoformat()}\n') + + log_activity('admin_action', f'Cleared log file: {filename}') + flash(f'Log file "{filename}" has been cleared.', 'success') + except Exception as e: + current_app.logger.exception(f'Error clearing log file: {filename}') + flash('Error clearing log file.', 'error') + + return redirect(url_for('admin.error_logs', file=filename)) diff --git a/webapp/routes/batch_routes.py b/webapp/routes/batch_routes.py index 3489ebd..f9929c8 100644 --- a/webapp/routes/batch_routes.py +++ b/webapp/routes/batch_routes.py @@ -9,7 +9,7 @@ import json import uuid from typing import List, Tuple, Dict, Any -from flask import Blueprint, jsonify, request, send_file, Response, stream_with_context +from flask import Blueprint, jsonify, request, send_file, Response, stream_with_context, current_app from flask_login import login_required from werkzeug.utils import secure_filename @@ -687,11 +687,10 @@ def batch_stream(): else: return jsonify({"error": "File must be CSV or XLSX format"}), 400 - print(f"DEBUG: File parsed successfully. Rows: {len(df)}, Columns: {list(df.columns)}") + current_app.logger.debug(f"File parsed successfully. Rows: {len(df)}, Columns: {list(df.columns)}") except Exception as e: - error_msg = f"Failed to read file: {e}" - print(f"ERROR: {error_msg}") - return jsonify({"error": error_msg}), 400 + current_app.logger.exception('Failed to read uploaded batch file') + return jsonify({"error": "Failed to read file. Please ensure it's a valid CSV or XLSX format."}), 400 # Get defaults from form, but filter out None/empty values defaults = {k: v for k, v in request.form.to_dict(flat=True).items() if v} diff --git a/webapp/routes/character_override_routes.py b/webapp/routes/character_override_routes.py index 3a811ab..e0a8c61 100644 --- a/webapp/routes/character_override_routes.py +++ b/webapp/routes/character_override_routes.py @@ -1,7 +1,7 @@ """ Admin routes for managing character override collections. """ -from flask import Blueprint, render_template, redirect, url_for, flash, request, jsonify +from flask import Blueprint, render_template, redirect, url_for, flash, request, jsonify, current_app from flask_login import login_required, current_user from webapp.models import db, CharacterOverrideCollection, CharacterOverride from webapp.utils.auth_utils import admin_required, log_activity @@ -454,8 +454,8 @@ def save_drawn_character(collection_id): return jsonify({'success': True, 'message': f'Character "{character}" saved successfully.'}), 200 except Exception as e: - print(f"Error saving drawn character: {e}") - return jsonify({'error': str(e)}), 500 + current_app.logger.exception('Error saving drawn character') + return jsonify({'error': 'Failed to save character'}), 500 @character_override_bp.route('/character//delete', methods=['POST']) diff --git a/webapp/routes/generation_routes.py b/webapp/routes/generation_routes.py index fa581b2..86500b9 100644 --- a/webapp/routes/generation_routes.py +++ b/webapp/routes/generation_routes.py @@ -6,7 +6,7 @@ import tempfile import time from typing import Any, Dict -from flask import Blueprint, jsonify, request, Response +from flask import Blueprint, jsonify, request, Response, current_app from flask_login import login_required # Ensure project root is in sys.path @@ -107,8 +107,13 @@ def api_v1_generate(): meta['generation_time_seconds'] = round(processing_time, 3) return jsonify({"svg": svg_text, "meta": meta}) - except Exception as e: + except ValueError as e: + # ValueError is typically a validation error (invalid params), safe to show + current_app.logger.warning(f'Generation validation error: {e}') return jsonify({"error": str(e)}), 400 + except Exception as e: + current_app.logger.exception('Generation error') + return jsonify({"error": "Failed to generate handwriting. Please check your parameters."}), 400 @generation_bp.route("/api/v1/generate/svg", methods=["POST"]) @@ -146,8 +151,12 @@ def api_v1_generate_svg(): log_activity('generate', f'Generated {lines_count} lines (SVG only)') return Response(svg_text, mimetype="image/svg+xml") - except Exception as e: + except ValueError as e: + current_app.logger.warning(f'Generation validation error: {e}') return jsonify({"error": str(e)}), 400 + except Exception as e: + current_app.logger.exception('Generation error (SVG)') + return jsonify({"error": "Failed to generate handwriting. Please check your parameters."}), 400 @generation_bp.route("/api/generate", methods=["POST"]) @@ -184,5 +193,9 @@ def generate_svg(): log_activity('generate', f'Generated {lines_count} lines (legacy)') return Response(svg_text, mimetype="image/svg+xml") - except Exception as e: + except ValueError as e: + current_app.logger.warning(f'Generation validation error (legacy): {e}') return jsonify({"error": str(e)}), 400 + except Exception as e: + current_app.logger.exception('Generation error (legacy)') + return jsonify({"error": "Failed to generate handwriting. Please check your parameters."}), 400 diff --git a/webapp/routes/job_routes.py b/webapp/routes/job_routes.py index 99a0e34..3f6eccf 100644 --- a/webapp/routes/job_routes.py +++ b/webapp/routes/job_routes.py @@ -115,8 +115,10 @@ def list_jobs(): } }) except Exception as e: - current_app.logger.error(f'Error loading jobs: {str(e)}') - return jsonify({'error': 'Failed to load jobs', 'details': str(e)}), 500 + # Log full exception details server-side for debugging + current_app.logger.exception('Error loading jobs') + # Return generic error message to user without internal details + return jsonify({'error': 'Failed to load jobs'}), 500 @jobs_bp.route('/api/jobs', methods=['POST']) @@ -401,5 +403,7 @@ def job_stats(): return jsonify(stats) except Exception as e: - current_app.logger.error(f'Error loading job stats: {str(e)}') + # Log full exception details server-side (including stack trace) + current_app.logger.exception('Error loading job stats') + # Return generic error with default stats to keep UI functional return jsonify({'error': 'Failed to load stats', 'pending': 0, 'queued': 0, 'processing': 0, 'completed': 0, 'failed': 0, 'cancelled': 0, 'total': 0}), 500 diff --git a/webapp/routes/presets_routes.py b/webapp/routes/presets_routes.py index acf228e..fcf0acd 100644 --- a/webapp/routes/presets_routes.py +++ b/webapp/routes/presets_routes.py @@ -1,7 +1,7 @@ """ API endpoints for page size and template presets. """ -from flask import Blueprint, jsonify, request +from flask import Blueprint, jsonify, request, current_app from flask_login import login_required, current_user from webapp.models import PageSizePreset, TemplatePreset, db from webapp.utils.auth_utils import admin_required, log_activity @@ -29,7 +29,8 @@ def list_page_sizes(): 'page_sizes': [ps.to_dict() for ps in page_sizes] }) except Exception as e: - return jsonify({'page_sizes': [], 'error': str(e)}), 500 + current_app.logger.exception('Error loading page sizes') + return jsonify({'page_sizes': [], 'error': 'Failed to load page sizes'}), 500 @presets_bp.route('/api/templates', methods=['GET']) @@ -50,7 +51,8 @@ def list_templates(): 'templates': [t.to_dict() for t in templates] }) except Exception as e: - return jsonify({'templates': [], 'error': str(e)}), 500 + current_app.logger.exception('Error loading templates') + return jsonify({'templates': [], 'error': 'Failed to load templates'}), 500 @presets_bp.route('/api/templates/', methods=['GET']) @@ -72,7 +74,8 @@ def get_template(template_id): 'template': template.to_dict() }) except Exception as e: - return jsonify({'error': str(e)}), 404 + current_app.logger.exception(f'Error loading template {template_id}') + return jsonify({'error': 'Template not found or error loading template'}), 404 @presets_bp.route('/api/templates/', methods=['PATCH']) @@ -123,7 +126,8 @@ def update_template_status(template_id): except Exception as e: db.session.rollback() - return jsonify({'error': str(e)}), 500 + current_app.logger.exception(f'Error updating template {template_id}') + return jsonify({'error': 'Failed to update template'}), 500 @presets_bp.route('/api/templates', methods=['POST']) @@ -223,7 +227,10 @@ def create_template_from_form(): }), 201 except ValueError as e: - return jsonify({'error': f'Invalid value: {str(e)}'}), 400 + # ValueError is a controlled validation error, safe to show message + current_app.logger.warning(f'Invalid value when creating template: {e}') + return jsonify({'error': 'Invalid value provided'}), 400 except Exception as e: db.session.rollback() - return jsonify({'error': str(e)}), 500 + current_app.logger.exception('Error creating template') + return jsonify({'error': 'Failed to create template'}), 500 diff --git a/webapp/routes/style_routes.py b/webapp/routes/style_routes.py index 9c8a9fb..cb5c31b 100644 --- a/webapp/routes/style_routes.py +++ b/webapp/routes/style_routes.py @@ -4,7 +4,7 @@ import sys import re from typing import List, Dict, Any -from flask import Blueprint, jsonify, send_file, Response +from flask import Blueprint, jsonify, send_from_directory, Response, current_app from flask_login import login_required import numpy as np @@ -120,7 +120,8 @@ def list_styles(): return jsonify({"styles": styles}) except Exception as e: - return jsonify({"styles": [], "error": str(e)}), 200 + current_app.logger.exception('Error loading styles') + return jsonify({"styles": [], "error": "Failed to load styles"}), 200 @style_bp.route("/api/style-preview/", methods=["GET"]) @@ -136,26 +137,29 @@ def get_style_preview(style_id: int): SVG file content with 'image/svg+xml' mimetype, or a placeholder/error SVG if not found. """ try: - # Validate style_id is a positive integer (Flask already validates it's an int) + # Validate style_id range (Flask already validates it's an int via route) if style_id < 0 or style_id > 999999: return Response(_placeholder_svg(style_id), mimetype='image/svg+xml') - # Construct safe filename - only digits allowed in style_id due to route - safe_filename = f"style-{style_id}.svg" + # Construct safe filename - style_id is guaranteed to be an integer by Flask route + # Using string formatting with validated integer prevents path traversal + safe_filename = f"style-{style_id:d}.svg" - # Build and normalize paths - base_path = os.path.normpath(os.path.abspath(STYLE_DIR)) - file_path = os.path.normpath(os.path.join(base_path, safe_filename)) + # Get the absolute base directory (constant, not user-controlled) + base_directory = os.path.abspath(STYLE_DIR) - # Verify path stays within base directory (defense in depth) - if not file_path.startswith(base_path + os.sep) and file_path != base_path: + # Check if the file exists before attempting to serve + file_path = os.path.join(base_directory, safe_filename) + if not os.path.isfile(file_path): return Response(_placeholder_svg(style_id), mimetype='image/svg+xml') - if os.path.isfile(file_path): - return send_file(file_path, mimetype='image/svg+xml') - - # If no preview exists, return a placeholder SVG - return Response(_placeholder_svg(style_id), mimetype='image/svg+xml') + # Use send_from_directory for secure file serving + # This is Flask's safe way to serve files from a directory + return send_from_directory( + base_directory, + safe_filename, + mimetype='image/svg+xml' + ) except Exception: # Return error placeholder diff --git a/webapp/static/js/modules/alpine-app.js b/webapp/static/js/modules/alpine-app.js index 0a3ff72..29f81cb 100644 --- a/webapp/static/js/modules/alpine-app.js +++ b/webapp/static/js/modules/alpine-app.js @@ -40,6 +40,7 @@ document.addEventListener('alpine:init', () => { globalScale: '', autoSize: true, manualSizeScale: '', + writingSizeMm: '', // target x-height in mm (natural handwriting size; blank = engine default ~4.5) // Custom size pageWidth: '', @@ -66,6 +67,7 @@ document.addEventListener('alpine:init', () => { // Chunked generation useChunked: true, + reflowText: true, // reflow soft-wrapped input to fill the width (keeps blank-line paragraph breaks) adaptiveChunking: true, adaptiveStrategy: 'balanced', wordsPerChunk: '', @@ -282,7 +284,9 @@ document.addEventListener('alpine:init', () => { empty_line_spacing: this.emptyLineSpacing ? Number(this.emptyLineSpacing) : undefined, auto_size: this.autoSize, manual_size_scale: (!this.autoSize && this.manualSizeScale) ? Number(this.manualSizeScale) : undefined, + writing_size_mm: this.writingSizeMm ? Number(this.writingSizeMm) : undefined, use_chunked: this.useChunked, + reflow: this.reflowText, adaptive_chunking: this.adaptiveChunking, adaptive_strategy: this.adaptiveStrategy || undefined, words_per_chunk: this.wordsPerChunk ? Number(this.wordsPerChunk) : undefined, @@ -651,6 +655,7 @@ document.addEventListener('alpine:init', () => { formData.append('global_scale', this.globalScale || ''); formData.append('auto_size', this.autoSize ? 'true' : 'false'); formData.append('manual_size_scale', this.manualSizeScale || ''); + formData.append('writing_size_mm', this.writingSizeMm || ''); formData.append('biases', this.biases || ''); formData.append('stroke_colors', this.strokeColors || ''); formData.append('stroke_widths', this.strokeWidths || ''); @@ -662,6 +667,7 @@ document.addEventListener('alpine:init', () => { formData.append('wrap_ratio', this.wrapRatio || ''); formData.append('wrap_utilization', this.wrapUtil || ''); formData.append('use_chunked', this.useChunked ? 'true' : 'false'); + formData.append('reflow', this.reflowText ? 'true' : 'false'); formData.append('adaptive_chunking', this.adaptiveChunking ? 'true' : 'false'); formData.append('adaptive_strategy', this.adaptiveStrategy || ''); formData.append('words_per_chunk', this.wordsPerChunk || ''); @@ -808,6 +814,7 @@ document.addEventListener('alpine:init', () => { global_scale: this.globalScale || null, auto_size: this.autoSize, manual_size_scale: this.manualSizeScale || null, + writing_size_mm: this.writingSizeMm || null, biases: this.biases || null, stroke_colors: this.strokeColors || null, stroke_widths: this.strokeWidths || null, @@ -819,6 +826,7 @@ document.addEventListener('alpine:init', () => { wrap_ratio: this.wrapRatio || null, wrap_utilization: this.wrapUtil || null, use_chunked: this.useChunked, + reflow: this.reflowText, adaptive_chunking: this.adaptiveChunking, adaptive_strategy: this.adaptiveStrategy || null, words_per_chunk: this.wordsPerChunk || null, diff --git a/webapp/templates/admin/base.html b/webapp/templates/admin/base.html index aa41079..0287fdf 100644 --- a/webapp/templates/admin/base.html +++ b/webapp/templates/admin/base.html @@ -66,6 +66,12 @@

{% block admin_title %}Admin Dashboard{% endblock %}

Templates +
  • + + Error Logs + +
  • diff --git a/webapp/templates/admin/character_overrides/view.html b/webapp/templates/admin/character_overrides/view.html index 264c119..b8c3631 100644 --- a/webapp/templates/admin/character_overrides/view.html +++ b/webapp/templates/admin/character_overrides/view.html @@ -75,9 +75,12 @@

    Upload Character Variants

    -
    - - +
    + + + + Preview only. Final width uses generation settings. +
    diff --git a/webapp/templates/admin/logs.html b/webapp/templates/admin/logs.html new file mode 100644 index 0000000..b34def7 --- /dev/null +++ b/webapp/templates/admin/logs.html @@ -0,0 +1,488 @@ +{% extends "admin/base.html" %} + +{% set active_nav = 'logs' %} + +{% block title %}Error Logs - Admin - WriteBot{% endblock %} + +{% block admin_title %}Error Logs{% endblock %} + +{% block admin_extra_css %} + +{% endblock %} + +{% block admin_content %} +
    +
    +
    +

    Application Logs

    +

    View and manage application error and debug logs

    +
    +
    + + {% if log_files %} + + + + +
    +
    + {{ total_lines }} + Total Lines +
    +
    + {{ error_count }} + Errors +
    +
    + {{ warning_count }} + Warnings +
    +
    + {{ log_entries | length }} + Showing +
    +
    + + +
    + + +
    + + +
    + +
    + + +
    + +
    + + +
    + +
    + + +
    +
    + + +
    +
    + {{ selected_file }} +
    + + Download + + +
    +
    + +
    + {% if log_entries %} + {% for entry in log_entries %} +
    {{ entry.clean }}
    + {% endfor %} + {% else %} +
    + {% if search_query or level_filter != 'all' %} + No log entries match your filter criteria. + {% else %} + No log entries found in this file. + {% endif %} +
    + {% endif %} +
    +
    + + {% else %} +
    +
    No log files found in the logs directory.
    +
    + {% endif %} +
    + + +
    +
    +
    Clear Log File?
    +
    + This will permanently delete all entries in {{ selected_file }}. + This action cannot be undone. +
    +
    + +
    + +
    +
    +
    +
    +{% endblock %} + +{% block admin_extra_js %} + +{% endblock %} diff --git a/webapp/templates/index.html b/webapp/templates/index.html index 63655d5..cacabfd 100644 --- a/webapp/templates/index.html +++ b/webapp/templates/index.html @@ -372,18 +372,27 @@

    Page Settings

    +
    + + +
    @@ -448,6 +457,22 @@

    Page Settings

    Generate text in small chunks for better quality and longer lines
    +
    +
    + + + Joins soft-wrapped lines so paragraphs fill the page width, keeping blank-line paragraph breaks. Uncheck to keep your exact line breaks. + +
    + + Fills the page width; uncheck to keep your exact line breaks + +