From f272ffaca2bce267640fcf44a8f9db24f7e9e1e3 Mon Sep 17 00:00:00 2001 From: Dustin Swan Date: Thu, 26 Mar 2026 18:32:40 -0600 Subject: [PATCH] We're checking types!!!! --- src/ast.ts | 27 ++--- src/cg/01-stdlib.cg | 10 +- src/compiler.ts | 12 +- src/parser.ts | 8 +- src/runtime-js.ts | 12 +- src/typechecker.ts | 259 ++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 304 insertions(+), 24 deletions(-) create mode 100644 src/typechecker.ts diff --git a/src/ast.ts b/src/ast.ts index 11e601f..5e1d0f2 100644 --- a/src/ast.ts +++ b/src/ast.ts @@ -142,7 +142,7 @@ export type RecordUpdate = { export type Definition = { kind: 'definition' name: string - body: AST + body?: AST line?: number column?: number start?: number @@ -327,6 +327,7 @@ export function prettyPrint(ast: AST, indent = 0): string { const ann = ast.annotation ? ` : ${prettyPrintType(ast.annotation.type)}` : ''; + if (!ast.body) return `${ast.name}${ann};`; return `${ast.name}${ann} = ${prettyPrint(ast.body, indent)};`; default: @@ -398,18 +399,18 @@ export function prettyPrintType(type: TypeAST): string { } } -function prettyPrintTypeDefinition(td: TypeDefinition): string { - const params = td.params.length > 0 ? ' ' + td.params.join(' ') : ''; - const ctors = td.constructors.map(c => { - const args = c.args.map(a => - a.kind === 'type-function' || a.kind === 'type-apply' - ? `(${prettyPrintType(a)})` - : prettyPrintType(a) - ).join(' '); - return args ? `${c.name} ${args}` : c.name; - }).join(' | '); - return `${td.name}${params} = ${ctors};`; -} +// function prettyPrintTypeDefinition(td: TypeDefinition): string { +// const params = td.params.length > 0 ? ' ' + td.params.join(' ') : ''; +// const ctors = td.constructors.map(c => { +// const args = c.args.map(a => +// a.kind === 'type-function' || a.kind === 'type-apply' +// ? `(${prettyPrintType(a)})` +// : prettyPrintType(a) +// ).join(' '); +// return args ? `${c.name} ${args}` : c.name; +// }).join(' | '); +// return `${td.name}${params} = ${ctors};`; +// } function needsQuotes(key: string): boolean { return key === '_' || !/^[a-zA-Z_][a-zA-Z0-9_]*$/.test(key); diff --git a/src/cg/01-stdlib.cg b/src/cg/01-stdlib.cg index a23a6d1..90bb3ba 100644 --- a/src/cg/01-stdlib.cg +++ b/src/cg/01-stdlib.cg @@ -1,5 +1,13 @@ -Maybe a = None | Some a; +# builtins +# TODO: once we get typeclasses, make these the actual types +cat : a \ a \ a; +add : a \ a \ a; +sub : a \ a \ a; +mul : a \ a \ a; +div : a \ a \ a; +eq : a \ a \ Bool; +Maybe a = None | Some a; # nth : Int \ List a \ Maybe a # in host at the moment, until we get typeclasses or something and this can work on strings too diff --git a/src/compiler.ts b/src/compiler.ts index 36b1d61..0e62b89 100644 --- a/src/compiler.ts +++ b/src/compiler.ts @@ -1,5 +1,6 @@ import type { AST, Pattern, Definition } from './ast'; import { store } from './runtime-js'; +import { typecheck } from './typechecker'; let matchCounter = 0; @@ -221,11 +222,14 @@ function compilePattern(pattern: Pattern, expr: string): { condition: string, bi } export function compileAndRun(defs: Definition[]) { + typecheck(defs); + const compiledDefs: string[] = []; - const topLevel = new Set(defs.map(d => d.name)); + const topLevel = new Set(defs.filter(d => d.body).map(d => d.name)); for (const def of defs) { + if (!def.body) continue; // type declaration only definitions.set(def.name, def); const free = freeVars(def.body); const deps = new Set([...free].filter(v => topLevel.has(v))); @@ -244,6 +248,7 @@ export function compileAndRun(defs: Definition[]) { } for (const def of defs) { + if (!def.body) continue; const ctx: CompileCtx = { useStore: false, topLevel, bound: new Set() }; const compiled = `const ${sanitizeName(def.name)} = ${compile(def.body, ctx)};`; compiledDefs.push(compiled); @@ -257,8 +262,9 @@ export function compileAndRun(defs: Definition[]) { } } + const defsWithBody = defs.filter(d => d.body); const lastName = defs[defs.length - 1].name; - const defNames = defs.map(d => sanitizeName(d.name)).join(', '); + const defNames = defsWithBody.map(d => sanitizeName(d.name)).join(', '); const code = `${compiledDefs.join('\n')} return { ${defNames}, __result: ${sanitizeName(lastName)} };`; @@ -396,7 +402,7 @@ export function recompile(name: string, newAst: AST) { collectDependents(name); for (const defName of toRecompile) { - const ast = definitions.get(defName)!.body; + const ast = definitions.get(defName)!.body!; const compiled = compile(ast); const fn = new Function('store', `return ${compiled}`); diff --git a/src/parser.ts b/src/parser.ts index 61a88ef..4891672 100644 --- a/src/parser.ts +++ b/src/parser.ts @@ -1,5 +1,5 @@ import type { Token } from './lexer' -import type { AST, MatchCase, Pattern, Definition, TypeAST, TypeDefinition, TypeConstructor } from './ast' +import type { AST, MatchCase, Pattern, Definition, TypeAST, TypeDefinition, TypeConstructor, Annotation } from './ast' import { ParseError } from './error' export class Parser { @@ -174,6 +174,12 @@ export class Parser { if (this.current().kind === 'colon') { this.advance(); annotation = { constraints: [], type: this.parseType() }; + + // Declaration only + if (this.current().kind === 'semicolon') { + this.advance(); + return { kind: 'definition', name, annotation, ...this.getPos(nameToken) }; + } } this.expect('equals'); diff --git a/src/runtime-js.ts b/src/runtime-js.ts index dac4249..a235b24 100644 --- a/src/runtime-js.ts +++ b/src/runtime-js.ts @@ -1,8 +1,8 @@ import { tokenize } from './lexer' import { Parser } from './parser' import { compile, recompile, definitions, freeVars, dependencies, dependents, astRegistry } from './compiler' -import { prettyPrint, prettyPrintType } from './ast' -import type { AST } from './ast' +import { prettyPrint } from './ast' +import type { AST, Definition } from './ast' import { measure } from './ui'; const STORAGE_KEY = 'cg-definitions'; @@ -213,7 +213,7 @@ export const _rt = { if (defs.length > 0) { const def = defs[0]; - recompile(def.name, def.body); + recompile(def.name, def.body!); const source = prettyPrint({ kind: 'definition', name: def.name, body: def.body }); appendChangeLog(def.name, source); saveDefinitions(); @@ -230,7 +230,7 @@ export const _rt = { const tokens = tokenize(wrapped); const parser = new Parser(tokens, wrapped); const { definitions: defs } = parser.parse(); - const ast = defs[0].body; + const ast = defs[0].body!; // validate free vars const free = freeVars(ast); @@ -243,7 +243,7 @@ export const _rt = { return { _tag: 'Err', _0: `Unknown: ${unknown.join(', ')}` }; } - const compiled = compile(defs[0].body); + const compiled = compile(defs[0].body!); const fn = new Function('_rt', 'store', `return ${compiled}`); const result = fn(_rt, store); if (result === undefined) { @@ -277,7 +277,7 @@ export function loadDefinitions() { const parser = new Parser(tokens, source as string); const { definitions: defs } = parser.parse(); if (defs.length > 0) { - recompile(defs[0].name, defs[0].body); + recompile(defs[0].name, defs[0].body!); } } } catch (e) { diff --git a/src/typechecker.ts b/src/typechecker.ts new file mode 100644 index 0000000..5f02852 --- /dev/null +++ b/src/typechecker.ts @@ -0,0 +1,259 @@ +import type { AST, TypeAST, Pattern, Definition } from './ast' +import { prettyPrintType } from './ast' + +// Map type var names to their types +type Subst = Map + +// Map var names to their types +type TypeEnv = Map + +// Replace type vars with resolved types +function applySubst(type: TypeAST, subst: Subst): TypeAST { + switch (type.kind) { + case 'type-var': + const resolved = subst.get(type.name); + return resolved ? applySubst(resolved, subst) : type; + case 'type-function': + return { kind: 'type-function', param: applySubst(type.param, subst), result: applySubst(type.result, subst) }; + case 'type-apply': + return { kind: 'type-apply', constructor: applySubst(type.constructor, subst), args: type.args.map(a => applySubst(a, subst)) }; + case 'type-record': + return { kind: 'type-record', fields: type.fields.map(f => ({ name: f.name, type: applySubst(f.type, subst) } )) }; + case 'type-name': + return type; + } + +} + +function unify(t1: TypeAST, t2: TypeAST, subst: Subst): string | null { + const a = applySubst(t1, subst); + const b = applySubst(t2, subst); + + // Same type name + if (a.kind === 'type-name' && b.kind === 'type-name' && a.name === b.name) return null; + + // Type var binds to anything + if (a.kind === 'type-var') { subst.set(a.name, b); return null; } + if (b.kind === 'type-var') { subst.set(b.name, a); return null; } + + // Functions: unify param & result + if (a.kind === 'type-function' && b.kind === 'type-function') { + const err = unify(a.param, b.param, subst); + if (err) return err; + return unify(a.result, b.result, subst); + } + + // Type application: unify constructor and args + if (a.kind === 'type-apply' && b.kind === 'type-apply') { + const err = unify(a.constructor, b.constructor, subst); + if (err) return err; + if (a.args.length !== b.args.length) return `Type argument mismatch`; + for (let i = 0; i < a.args.length; i++) { + const err = unify(a.args[i], b.args[i], subst); + if (err) return err; + } + return null; + } + + // Records: unify matching fields + if (a.kind === 'type-record' && b.kind === 'type-record') { + for (const af of a.fields) { + const bf = b.fields.find(f => f.name == af.name); + if (bf) { + const err = unify(af.type, bf.type, subst); + if (err) return err; + } + } + return null; + } + + return `Cannot unify ${prettyPrintType(a)} with ${prettyPrintType(b)}`; +} + +function infer(expr: AST, env: TypeEnv, subst: Subst): TypeAST | null { + switch (expr.kind) { + case 'literal': + if (expr.value.kind === 'int') return { kind: 'type-name', name: 'Int' }; + if (expr.value.kind === 'float') return { kind: 'type-name', name: 'Float' }; + if (expr.value.kind === 'string') return { kind: 'type-name', name: 'String' }; + return null; + + case 'variable': { + const t = env.get(expr.name); + return t ? applySubst(t, subst) : null; + } + + case 'constructor': { + const t = env.get(expr.name); + return t ? applySubst(t, subst) : null; + } + + case 'record': { + const fields: { name: string, type: TypeAST }[] = []; + for (const entry of expr.entries) { + if (entry.kind === 'spread') continue; + const t = infer(entry.value, env, subst); + if (!t) return null; + fields.push({ name: entry.key, type: t }); + } + return { kind: 'type-record', fields }; + } + + case 'record-access': { + const recType = infer(expr.record, env, subst); + if (!recType) return null; + const resolved = applySubst(recType, subst); + if (resolved.kind === 'type-record') { + const field = resolved.fields.find(f => f.name === expr.field); + return field ? field.type : null; + } + return null; + } + + case 'apply': { + const funcType = infer(expr.func, env, subst); + if (!funcType) return null; + + let current = applySubst(funcType, subst); + for (const arg of expr.args) { + if (current.kind !== 'type-function') return null; + const err = check(arg, current.param, env, subst); + if (err) warn(err, arg); + current = applySubst(current.result, subst); + } + return current; + } + + case 'let': { + const valType = infer(expr.value, env, subst); + const newEnv = new Map(env); + if (valType) newEnv.set(expr.name, valType); + return infer(expr.body, newEnv, subst); + } + + case 'lambda': + return null; + + case 'match': { + const scrutType = infer(expr.expr, env, subst); + if (expr.cases.length === 0) return null; + const firstEnv = new Map(env); + if (scrutType) bindPattern(expr.cases[0].pattern, scrutType, firstEnv, subst); + return infer(expr.cases[0].result, firstEnv, subst); + } + + default: + return null; + } +} + +function check(expr: AST, expected: TypeAST, env: TypeEnv, subst: Subst): string | null { + const exp = applySubst(expected, subst); + + // Lambda against function type + if (expr.kind === 'lambda' && exp.kind === 'type-function') { + const newEnv = new Map(env); + newEnv.set(expr.params[0], exp.param); + + if (expr.params.length > 1) { + const innerLambda: AST = { + kind: 'lambda', + params: expr.params.slice(1), + body: expr.body, + }; + + return check(innerLambda, exp.result, newEnv, subst); + } + + return check(expr.body, exp.result, newEnv, subst); + } + + // Match: check each case result against expected + if (expr.kind === 'match') { + const scrutType = infer(expr.expr, env, subst); + for (const c of expr.cases) { + const caseEnv = new Map(env); + if (scrutType) bindPattern(c.pattern, scrutType, caseEnv, subst); + const err = check(c.result, expected, caseEnv, subst); + if (err) warn(err, c.result); + } + return null; + } + + // Let + if (expr.kind === 'let') { + const valType = infer(expr.value, env, subst); + const newEnv = new Map(env); + if (valType) newEnv.set(expr.name, valType); + return check(expr.body, expected, newEnv, subst); + } + + // Fallback: infer and unify + const inferred = infer(expr, env, subst); + if (!inferred) return null; // Can't infer, skip silently + return unify(inferred, expected, subst); +} + +function warn(msg: string, expr: AST) { + const loc = expr.line ? ` (line ${expr.line})` : ''; + console.warn(`TypeError${loc}: ${msg}`); +} + +function bindPattern(pattern: Pattern, type: TypeAST, env: TypeEnv, subst: Subst): void { + const t = applySubst(type, subst); + switch (pattern.kind) { + case 'var': + env.set(pattern.name, t); + break; + case 'constructor': + // TODO: look up ctor arg types + break; + case 'list': + case 'list-spread': + // TODO: bind element types + break; + case 'record': + if (t.kind === 'type-record') { + for (const [key, pat] of Object.entries(pattern.fields)) { + const field = t.fields.find(f => f.name === key); + if (field) bindPattern(pat, field.type, env, subst); + } + } + break; + } +} + +// const int: TypeAST = { kind: 'type-name', name: 'Int' }; +// const float: TypeAST = { kind: 'type-name', name: 'Float' }; +// const str: TypeAST = { kind: 'type-name', name: 'String' }; +// const bool: TypeAST = { kind: 'type-name', name: 'Bool' }; +// const tvar = (name: string): TypeAST => ({ kind: 'type-var', name }); +// const tfn = (param: TypeAST, result: TypeAST): TypeAST => ({ kind: 'type-function', param, result }); + +export function typecheck(defs: Definition[]) { + const env: TypeEnv = new Map(); + + // seed env with builtin types + // env.set('cat', tfn(str, tfn(str, str))); + // env.set('add', tfn(int, tfn(int, int))); + // env.set('sub', tfn(int, tfn(int, int))); + // env.set('mul', tfn(int, tfn(int, int))); + // env.set('div', tfn(int, tfn(int, int))); + // env.set('eq', tfn(tvar('a'), tfn(tvar('a'), bool))); + + // Register all annotated defs in env first so they can ref eachother + for (const def of defs) { + if (def.annotation) { + env.set(def.name, def.annotation.type); + } + } + + // Check each annotated def + for (const def of defs) { + if (def.annotation && def.body) { + const subst: Subst = new Map(); + const err = check(def.body, def.annotation.type, env, subst); + if (err) warn(err, def.body); + } + } +}