Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions apps/typegpu-docs/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,7 @@ src/content/docs/api
# tests
tests/artifacts
!tests/artifacts/README.md

# generated transformed files
*.tsnotover.ts
*.tsnotover.tsx
6 changes: 4 additions & 2 deletions apps/typegpu-docs/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
"version": "0.0.1",
"private": true,
"scripts": {
"dev": "astro dev",
"build": "astro check && astro build",
"transform-overloads": "find . -type f -name '*.tsnotover.ts' -delete && node scripts/transform-overloads.ts",
"dev": "pnpm run transform-overloads && astro dev",
"build": "pnpm run transform-overloads && astro check && astro build",
"test:types": "astro check",
"preview": "astro preview",
"astro": "astro"
Expand Down Expand Up @@ -78,6 +79,7 @@
"@webgpu/types": "catalog:types",
"astro-vtbot": "^2.1.10",
"autoprefixer": "^10.4.21",
"magic-string": "^0.30.21",
"tailwindcss": "^4.1.11",
"tailwindcss-motion": "^1.1.1",
"vite-imagetools": "catalog:frontend",
Expand Down
288 changes: 288 additions & 0 deletions apps/typegpu-docs/scripts/transform-overloads.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,288 @@
import ts from 'typescript';
import MagicString from 'magic-string';
import { readdir } from 'fs/promises';
import { basename, dirname, extname, join, relative } from 'path';
import { fileURLToPath } from 'url';
import { writeFile } from 'fs/promises';

const __dirname = dirname(fileURLToPath(import.meta.url));
const projectRoot = join(__dirname, '..');
const examplesDir = join(projectRoot, 'src', 'examples');

const operatorToMethod: Record<string, string> = {
[ts.SyntaxKind.PlusToken]: 'add',
[ts.SyntaxKind.PlusEqualsToken]: 'add',
[ts.SyntaxKind.MinusToken]: 'sub',
[ts.SyntaxKind.MinusEqualsToken]: 'sub',
[ts.SyntaxKind.AsteriskToken]: 'mul',
[ts.SyntaxKind.AsteriskEqualsToken]: 'mul',
[ts.SyntaxKind.SlashToken]: 'div',
[ts.SyntaxKind.SlashEqualsToken]: 'div',
[ts.SyntaxKind.AsteriskAsteriskToken]: 'pow',
[ts.SyntaxKind.AsteriskAsteriskEqualsToken]: 'pow',
};

const assignmentOperators = [
ts.SyntaxKind.PlusEqualsToken,
ts.SyntaxKind.MinusEqualsToken,
ts.SyntaxKind.AsteriskEqualsToken,
ts.SyntaxKind.SlashEqualsToken,
];

async function findTypeScriptFiles(dir: string): Promise<string[]> {
const files: string[] = [];

async function walk(currentDir: string): Promise<void> {
const entries = await readdir(currentDir, { withFileTypes: true });

for (const entry of entries) {
const fullPath = join(currentDir, entry.name);

if (entry.isDirectory()) {
await walk(fullPath);
} else if (entry.isFile()) {
const ext = extname(entry.name);
if (
(ext === '.ts' || ext === '.tsx') &&
!entry.name.endsWith('.d.ts') &&
!entry.name.endsWith('.d.tsx') &&
!entry.name.endsWith('.tsnotover.ts') &&
!entry.name.endsWith('.tsnotover.tsx')
) {
files.push(fullPath);
}
}
}
}

await walk(dir);
return files;
}

type Pattern =
| 'left.op(right)' // e.g. vec + 2 => vec.add(2)
| 'right.op(left)' // e.g. 2 * vec => vec.mul(2)
| 'std.op(left, right)'; // e.g. 2 / vec => std.div(2, vec)

function getOverloadPattern(
checker: ts.TypeChecker,
node: ts.BinaryExpression,
): Pattern | undefined {
const methodName = operatorToMethod[node.operatorToken.kind];
if (!methodName) {
// Not overlaoded
return undefined;
}

// Get the types of both operands
const leftType = checker.getTypeAtLocation(node.left);
const rightType = checker.getTypeAtLocation(node.right);

if (
!checker.__tsover__couldHaveOverloadedOperators(
node.left,
node.operatorToken.kind,
node.right,
leftType,
rightType,
)
) {
// Not overlaoded
return undefined;
}

// For non-commutative operators, use the standard library function
if (methodName === 'div' || methodName === 'pow') {
return 'std.op(left, right)';
}

// Since other supported operators are commutative, prefer left method, fall back to right
const leftHasMethod = leftType.getProperty(methodName) !== undefined;

return leftHasMethod ? 'left.op(right)' : 'right.op(left)';
}

function createProgram(allFiles: string[]): ts.Program {
const configPath = join(projectRoot, 'tsconfig.json');
const configText = ts.sys.readFile(configPath);

if (!configText) {
throw new Error(`Could not read tsconfig.json at ${configPath}`);
}

const { config } = ts.parseConfigFileTextToJson(configPath, configText);
const parsedConfig = ts.parseJsonConfigFileContent(
config,
ts.sys,
projectRoot,
);

const compilerOptions: ts.CompilerOptions = {
...parsedConfig.options,
noEmit: true,
};

const host = ts.createCompilerHost(compilerOptions, true);

return ts.createProgram(allFiles, compilerOptions, host);
}

function isStdDeclared(sourceFile: ts.SourceFile): boolean {
for (const stmt of sourceFile.statements) {
if (!ts.isImportDeclaration(stmt)) {
continue;
}
const moduleSpecifier = stmt.moduleSpecifier;
if (!ts.isStringLiteral(moduleSpecifier)) {
continue;
}
const namedBindings = stmt.importClause?.namedBindings;
// Case 1: import { std } from 'typegpu'
if (
moduleSpecifier.text === 'typegpu' &&
namedBindings &&
ts.isNamedImports(namedBindings) &&
namedBindings.elements.some((el) => el.name.text === 'std')
) {
return true;
}
// Case 2: import * as std from 'typegpu/std'
if (
moduleSpecifier.text === 'typegpu/std' &&
namedBindings &&
ts.isNamespaceImport(namedBindings) &&
namedBindings.name.text === 'std'
) {
return true;
}
}
return false;
}

function transformFile(
sourceFilePath: string,
program: ts.Program,
): { code: string; hasChanges: boolean } {
const checker = program.getTypeChecker();
const sourceFile = program.getSourceFile(sourceFilePath);

if (!sourceFile) {
throw new Error(`Could not get source file for ${sourceFilePath}`);
}

const sourceText = sourceFile.text;
const magic = new MagicString(sourceText);
let hasChanges = false;
let needsStd = false;

function visit(node: ts.Node): void {
// Visit all children of the node first, then process the node itself
ts.forEachChild(node, visit);

if (!ts.isBinaryExpression(node)) {
return;
}

const pattern = getOverloadPattern(checker, node);
const methodName = operatorToMethod[node.operatorToken.kind];

if (!pattern || !methodName) {
return;
}

hasChanges = true;

const start = node.getStart();
const end = node.getEnd();
const leftStart = node.left.getStart();
const leftEnd = node.left.getEnd();
const rightStart = node.right.getStart();
const rightEnd = node.right.getEnd();

const leftText = magic.slice(leftStart, leftEnd);
const rightText = magic.slice(rightStart, rightEnd);

let replacement = '';
if (pattern === 'std.op(left, right)') {
needsStd = true;
replacement = `std.${methodName}(${leftText}, ${rightText})`;
} else if (pattern === 'left.op(right)') {
replacement = `${leftText}.${methodName}(${rightText})`;
} else if (pattern === 'right.op(left)') {
replacement = `${rightText}.${methodName}(${leftText})`;
} else {
throw new Error(`Unsupported pattern: ${pattern}`);
}

if (assignmentOperators.includes(node.operatorToken.kind)) {
// E.g. transforms a += b into a = a.add(b)
magic.overwrite(start, end, `${leftText} = ${replacement}`);
} else {
magic.overwrite(start, end, replacement);
}
}

visit(sourceFile);

if (needsStd && !isStdDeclared(sourceFile)) {
magic.prepend("import { std } from 'typegpu';\n");
}

return { code: magic.toString(), hasChanges };
}

async function main() {
console.log('Finding TypeScript files in examples directory...');

const allFiles = await findTypeScriptFiles(examplesDir);

console.log(`Found ${allFiles.length} files to process`);
console.log('Creating TypeScript program...');

const program = createProgram(allFiles);

console.log('Transforming files...');

let transformedCount = 0;
let errorCount = 0;

for (const filePath of allFiles) {
try {
const { code, hasChanges } = transformFile(filePath, program);

const ext = filePath.endsWith('.tsx') ? '.tsx' : '.ts';
const baseName = basename(filePath, ext);
const dir = dirname(filePath);
const outputPath = join(dir, `${baseName}.tsnotover${ext}`);

if (hasChanges) {
transformedCount++;
await writeFile(outputPath, code, 'utf-8');
console.log(`Transformed: ${relative(projectRoot, filePath)}`);
}
} catch (error) {
errorCount++;
const errorMessage = error instanceof Error
? error.message
: String(error);
console.error(
`Error processing ${relative(projectRoot, filePath)}: ${errorMessage}`,
);
throw new Error(
`Failed to transform ${
relative(projectRoot, filePath)
}: ${errorMessage}`,
{ cause: error },
);
}
}

console.log(
`\nDone! Transformed ${transformedCount} files, ${errorCount} errors.`,
);
}

main().catch((error) => {
console.error('Fatal error:', error);
process.exit(1);
});
2 changes: 1 addition & 1 deletion apps/typegpu-docs/src/components/CodeEditor.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ const createCodeEditorComponent = (
>
<Editor
defaultLanguage={language}
value={file.content}
value={file.tsnotoverContent ?? file.content}
path={path}
beforeMount={beforeMount}
onMount={onMount}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ export const openInStackBlitz = (
const tsFiles: Record<string, string> = {};

for (const file of example.tsFiles) {
tsFiles[`src/${file.path}`] = file.content;
tsFiles[`src/${file.path}`] = file.tsnotoverContent ?? file.content;
}
for (const file of common) {
tsFiles[`src/common/${file.path}`] = file.content;
tsFiles[`src/common/${file.path}`] = file.tsnotoverContent ?? file.content;
}

for (const key of Object.keys(tsFiles)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ const material = new THREE.MeshBasicNodeMaterial();

material.colorNode = t3.toTSL(() => {
'use gpu';
const coords = t3.uv().$.mul(2);
const coords = t3.uv().$ * 2;
const pattern = perlin3d.sample(d.vec3f(coords, t3.time.$ * 0.2));
return d.vec4f(std.tanh(pattern * 5), 0.2, 0.4, 1);
});
Expand Down Expand Up @@ -244,12 +244,28 @@ Loop({ start: ptrStart, end: ptrEnd, type: 'uint', condition: '<' }, ({ i }) =>
```

TypeGPU:
```ts
```ts twoslash
'use tsover';
import * as t3 from '@typegpu/three';
import * as THREE from 'three/webgpu';
import { d, std } from 'typegpu';
type TSLStorageAccessor<T extends d.AnyWgslData> = t3.TSLAccessor<
T,
THREE.StorageBufferNode
>;
declare const springListBuffer: TSLStorageAccessor<d.WgslArray<d.U32>>;
declare const springForceBuffer: TSLStorageAccessor<d.WgslArray<d.Vec3f>>;
declare const springVertexIdBuffer: TSLStorageAccessor<d.WgslArray<d.Vec2u>>
declare const ptrStart: number;
declare const ptrEnd: number;
declare const idx: number;
declare let force: d.v3f;
// ---cut---
for (let i = ptrStart; i < ptrEnd; i++) {
const springId = springListBuffer.$[i];
const springForce = springForceBuffer.$[springId];
const springVertexIds = springVertexIdBuffer.$[springId];
const factor = std.select(-1, 1, springVertexIds.x === idx);
force = force.add(springForce.mul(d.f32(factor)));
force += springForce * factor;
}
```
Loading
Loading