Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ bazel-*
**/.bazel_user_root/
.whl
python/.cache
python/pyfory/__pycache__/
__pycache__/
python/dist
python/build
python/pyfory.egg-info
Expand Down
26 changes: 22 additions & 4 deletions compiler/fory_compiler/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def resolve_imports(
imported_enums = []
imported_messages = []
imported_unions = []
source_packages: Dict[str, Optional[str]] = {str(file_path): schema.package}

for imp in schema.imports:
# Resolve import path using search paths
Expand All @@ -149,6 +150,7 @@ def resolve_imports(
imported_enums.extend(imported_schema.enums)
imported_messages.extend(imported_schema.messages)
imported_unions.extend(imported_schema.unions)
source_packages.update(imported_schema.source_packages)

# Create merged schema with imported types first (so they can be referenced)
merged_schema = Schema(
Expand All @@ -161,6 +163,7 @@ def resolve_imports(
options=schema.options,
source_file=schema.source_file,
source_format=schema.source_format,
source_packages=source_packages,
)

cache[file_path] = copy.deepcopy(merged_schema)
Expand Down Expand Up @@ -482,6 +485,7 @@ def compile_file(
package_override: Optional package name override
import_paths: List of import search paths
"""
file_path = file_path.resolve()
print(f"Compiling {file_path}...")

# Parse and resolve imports
Expand Down Expand Up @@ -543,11 +547,25 @@ def compile_file(

generator_class = GENERATORS[lang]
generator = generator_class(schema, options)
files = generator.generate()
try:
files = generator.generate()

if grpc:
service_files = generator.generate_services()
files.extend(service_files)
except ValueError as e:
print(f"Error: {e}", file=sys.stderr)
return False

if grpc:
service_files = generator.generate_services()
files.extend(service_files)
if lang == "rust":
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This guard only rejects overwriting existing non-generated files. It does not catch two inputs in the same invocation that normalize to the same Rust output path. For example, packages foo.bar and foo_bar both produce foo_bar.rs; the second file sees the generated marker, overwrites the first file, and the command exits successfully with only the second schema's types. Please track generated target paths during the compile invocation and fail when distinct source files map to the same Rust output file, with a two-file CLI regression test.

for f in files:
target = (lang_output / f.path).resolve()
if target.exists() and not is_generated_file(target):
print(
f"Error: refusing to overwrite non-generated Rust file: {target}",
file=sys.stderr,
)
return False

generator.write_files(files)

Expand Down
2 changes: 1 addition & 1 deletion compiler/fory_compiler/frontend/fbs/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class FbsUnion:
"""A FlatBuffers union declaration."""

name: str
types: List[str] = field(default_factory=list)
types: List[FbsTypeName] = field(default_factory=list)
attributes: Dict[str, object] = field(default_factory=dict)
line: int = 0
column: int = 0
Expand Down
11 changes: 9 additions & 2 deletions compiler/fory_compiler/frontend/fbs/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,13 +266,20 @@ def parse_union(self) -> FbsUnion:
attributes = self.parse_metadata()
self.consume(TokenType.LBRACE, "Expected '{' after union name")

types: List[str] = []
types: List[FbsTypeName] = []
while not self.check(TokenType.RBRACE):
if self.check(TokenType.COMMA):
self.advance()
continue
type_start = self.current()
type_name = self.parse_qualified_ident()
types.append(type_name)
types.append(
FbsTypeName(
name=type_name,
line=type_start.line,
column=type_start.column,
)
)
if self.match(TokenType.COMMA):
continue
if self.check(TokenType.RBRACE):
Expand Down
11 changes: 6 additions & 5 deletions compiler/fory_compiler/frontend/fbs/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,19 +230,20 @@ def _translate_field_attributes(

def _translate_union(self, fbs_union: FbsUnion) -> Union:
fields: List[Field] = []
for index, type_name in enumerate(fbs_union.types, start=1):
for index, type_ref in enumerate(fbs_union.types, start=1):
type_name = type_ref.name
field_name = self._lower_name(type_name)
fields.append(
Field(
name=field_name,
field_type=NamedType(
type_name,
location=self._location(fbs_union.line, fbs_union.column),
location=self._location(type_ref.line, type_ref.column),
),
number=index,
line=fbs_union.line,
column=fbs_union.column,
location=self._location(fbs_union.line, fbs_union.column),
line=type_ref.line,
column=type_ref.column,
location=self._location(type_ref.line, type_ref.column),
)
)
return Union(
Expand Down
Loading
Loading