import type { AST, TypeAST, Pattern, Definition, TypeDefinition, ClassDefinition, InstanceDeclaration } from './ast'
import { prettyPrintType } from './ast'

/** Maps type-variable names to the types they have been bound to. */
type Subst = Map<string, TypeAST>

/** Maps term-level names (variables, constructors, class methods) to their types. */
type TypeEnv = Map<string, TypeAST>

/** A class constraint on one type variable (`param`) of a binding's declared type. */
type Constraint = { param: string, className: string }

// ---------------------------------------------------------------------------
// Module-level checker state. `typecheck` re-initialises all of it except the
// fresh-name counter, whose only job is to keep generated names unique.
// ---------------------------------------------------------------------------

/** Binding name -> its class constraint (class methods, and defs that inherit one). */
let moduleConstraints = new Map<string, Constraint>();

/** Class name -> names of the types declared to be instances of it. */
let moduleInstances = new Map<string, Set<string>>();

/** Constraints left on unresolved type variables while checking the current definition. */
let inferredConstraints: { varName: string, className: string }[] = [];

/** Type name -> names of all of its constructors, used for exhaustiveness checks. */
let typeConstructors = new Map<string, string[]>();

/** Source of unique suffixes for generated type variables (`_t0`, `_t1`, ...). */
let freshCounter = 0;

/**
 * Returns `type` with every bound type variable replaced by its binding in
 * `subst`, following chains of bindings. Neither argument is mutated.
 */
function applySubst(type: TypeAST, subst: Subst): TypeAST {
  switch (type.kind) {
    case 'type-var': {
      const resolved = subst.get(type.name);
      return resolved ? applySubst(resolved, subst) : type;
    }
    case 'type-function':
      return { kind: 'type-function', param: applySubst(type.param, subst), result: applySubst(type.result, subst) };
    case 'type-apply':
      return { kind: 'type-apply', constructor: applySubst(type.constructor, subst), args: type.args.map(a => applySubst(a, subst)) };
    case 'type-record':
      return { kind: 'type-record', fields: type.fields.map(f => ({ name: f.name, type: applySubst(f.type, subst) })) };
    case 'type-name':
      return type;
  }
}

/** Returns a new type variable whose name has not been handed out before. */
function freshVar(): TypeAST {
  return { kind: 'type-var', name: `_t${freshCounter++}` };
}

/** Builds the type `List elem`. */
function listOf(elem: TypeAST): TypeAST {
  return { kind: 'type-apply', constructor: { kind: 'type-name', name: 'List' }, args: [elem] };
}

/** Returns `x` when `t` is exactly `List x`, otherwise null. */
function listElementType(t: TypeAST): TypeAST | null {
  if (t.kind === 'type-apply' && t.constructor.kind === 'type-name' && t.constructor.name === 'List' && t.args.length === 1) {
    return t.args[0] ?? null;
  }
  return null;
}

/** True when type variable `name` appears anywhere in `type` (after applying `subst`). */
function occursIn(name: string, type: TypeAST, subst: Subst): boolean {
  const t = applySubst(type, subst);
  switch (t.kind) {
    case 'type-var':
      return t.name === name;
    case 'type-name':
      return false;
    case 'type-function':
      return occursIn(name, t.param, subst) || occursIn(name, t.result, subst);
    case 'type-apply':
      return occursIn(name, t.constructor, subst) || t.args.some(a => occursIn(name, a, subst));
    case 'type-record':
      return t.fields.some(f => occursIn(name, f.type, subst));
  }
}

/** Binds type variable `name` to `type` in `subst`, refusing bindings that would create an infinite type. */
function bindVar(name: string, type: TypeAST, subst: Subst): string | null {
  if (occursIn(name, type, subst)) return `Infinite type: ${name} occurs in ${prettyPrintType(type)}`;
  subst.set(name, type);
  return null;
}

/**
 * Unifies `t1` with `t2`, recording new bindings in `subst` (mutated in place).
 * Returns null on success or an error message on failure; bindings made before
 * the mismatch was found stay in `subst`. Records are unified on the fields
 * they share — a field present on only one side is ignored.
 */
function unify(t1: TypeAST, t2: TypeAST, subst: Subst): string | null {
  const a = applySubst(t1, subst);
  const b = applySubst(t2, subst);

  // Identical type names / identical type variables.
  if (a.kind === 'type-name' && b.kind === 'type-name' && a.name === b.name) return null;
  if (a.kind === 'type-var' && b.kind === 'type-var' && a.name === b.name) return null;

  // A type variable binds to anything it does not occur in.
  if (a.kind === 'type-var') return bindVar(a.name, b, subst);
  if (b.kind === 'type-var') return bindVar(b.name, a, subst);

  if (a.kind === 'type-function' && b.kind === 'type-function') {
    const err = unify(a.param, b.param, subst);
    if (err) return err;
    return unify(a.result, b.result, subst);
  }

  if (a.kind === 'type-apply' && b.kind === 'type-apply') {
    const err = unify(a.constructor, b.constructor, subst);
    if (err) return err;
    if (a.args.length !== b.args.length) return `Type argument mismatch`;
    for (let i = 0; i < a.args.length; i++) {
      const argErr = unify(a.args[i], b.args[i], subst);
      if (argErr) return argErr;
    }
    return null;
  }

  if (a.kind === 'type-record' && b.kind === 'type-record') {
    for (const af of a.fields) {
      const bf = b.fields.find(f => f.name === af.name);
      if (!bf) continue;
      const err = unify(af.type, bf.type, subst);
      if (err) return err;
    }
    return null;
  }

  return `Cannot unify ${prettyPrintType(a)} with ${prettyPrintType(b)}`;
}

/**
 * Instantiates `type`: after applying `subst`, every type variable still free
 * in it is replaced by a fresh one, consistently (the same original variable
 * always maps to the same fresh variable). Also returns a map from each
 * original variable name to its fresh name.
 *
 * NOTE(review): `subst` is applied before renaming, so a declared variable that
 * shares its name with one already bound in the current definition's
 * substitution (e.g. an annotation variable `a`) is resolved rather than
 * freshened — confirm this is intended.
 */
function freshen(type: TypeAST, subst: Subst): { type: TypeAST, varMap: Map<string, string> } {
  const varMap = new Map<string, string>();

  function go(t: TypeAST): TypeAST {
    const resolved = applySubst(t, subst);
    switch (resolved.kind) {
      case 'type-var': {
        let freshName = varMap.get(resolved.name);
        if (freshName === undefined) {
          freshName = `_t${freshCounter++}`;
          varMap.set(resolved.name, freshName);
        }
        return { kind: 'type-var', name: freshName };
      }
      case 'type-name':
        return resolved;
      case 'type-function':
        return { kind: 'type-function', param: go(resolved.param), result: go(resolved.result) };
      case 'type-apply':
        return { kind: 'type-apply', constructor: go(resolved.constructor), args: resolved.args.map(go) };
      case 'type-record':
        return { kind: 'type-record', fields: resolved.fields.map(f => ({ name: f.name, type: go(f.type) })) };
    }
  }

  return { type: go(type), varMap };
}

/**
 * Infers the type of `expr` bottom-up, extending `subst` as a side effect.
 * Returns null when no type can be determined (unknown names, unannotated
 * lambdas, unsupported literal kinds, ...); callers treat null as "skip".
 * Errors found in sub-expressions are reported through `warn`, not returned.
 */
function infer(expr: AST, env: TypeEnv, subst: Subst): TypeAST | null {
  switch (expr.kind) {
    case 'literal':
      if (expr.value.kind === 'int') return { kind: 'type-name', name: 'Int' };
      if (expr.value.kind === 'float') return { kind: 'type-name', name: 'Float' };
      if (expr.value.kind === 'string') return { kind: 'type-name', name: 'String' };
      return null;

    case 'variable': {
      const t = env.get(expr.name);
      if (!t) return null;
      const { type: fresh, varMap } = freshen(t, subst);
      // A bare reference to a constrained binding (e.g. a class method passed
      // as an argument) leaves its constraint on the fresh variable that now
      // stands for the constrained parameter.
      const constraint = moduleConstraints.get(expr.name);
      const constrainedVar = constraint ? varMap.get(constraint.param) : undefined;
      if (constraint && constrainedVar !== undefined) {
        inferredConstraints.push({ varName: constrainedVar, className: constraint.className });
      }
      return fresh;
    }

    case 'constructor': {
      const t = env.get(expr.name);
      if (!t) return null;
      return freshen(t, subst).type;
    }

    case 'list': {
      // All elements, and the element type of every spread list, must agree
      // on a single element type; start from a fresh variable so separate
      // empty lists never share one.
      const elemType = freshVar();
      const listType = listOf(elemType);
      for (const elem of expr.elements) {
        if (elem.kind === 'list-spread') {
          const err = check(elem.spread, listType, env, subst);
          if (err) warn(err, elem.spread);
        } else {
          const err = check(elem, elemType, env, subst);
          if (err) warn(err, elem);
        }
      }
      return applySubst(listType, subst);
    }

    case 'record': {
      const fields: { name: string, type: TypeAST }[] = [];
      for (const entry of expr.entries) {
        // Spread entries are skipped: their fields are not tracked here.
        if (entry.kind === 'spread') continue;
        const t = infer(entry.value, env, subst);
        if (!t) return null;
        fields.push({ name: entry.key, type: t });
      }
      return { kind: 'type-record', fields };
    }

    case 'record-access': {
      const recType = infer(expr.record, env, subst);
      if (!recType) return null;
      const resolved = applySubst(recType, subst);
      if (resolved.kind !== 'type-record') return null;
      const field = resolved.fields.find(f => f.name === expr.field);
      return field ? field.type : null;
    }

    case 'record-update': {
      const recType = infer(expr.record, env, subst);
      if (!recType) return null;
      const resolved = applySubst(recType, subst);
      if (resolved.kind !== 'type-record') return null;
      for (const [key, value] of Object.entries(expr.updates)) {
        const field = resolved.fields.find(f => f.name === key);
        if (field) {
          const err = check(value, field.type, env, subst);
          if (err) warn(err, value);
        }
      }
      return resolved;
    }

    case 'apply': {
      let funcType: TypeAST | null = null;
      // Kept for named callees so their class constraint can be traced onto
      // the fresh type variable that stands for the constrained parameter.
      let varMap: Map<string, string> | null = null;
      if (expr.func.kind === 'variable') {
        const t = env.get(expr.func.name);
        if (t) {
          const instantiated = freshen(t, subst);
          funcType = instantiated.type;
          varMap = instantiated.varMap;
        }
      } else {
        funcType = infer(expr.func, env, subst);
      }
      if (!funcType) return null;

      let current = applySubst(funcType, subst);
      for (const arg of expr.args) {
        if (current.kind !== 'type-function') return null;
        const err = check(arg, current.param, env, subst);
        if (err) warn(err, arg);
        current = applySubst(current.result, subst);
      }

      if (expr.func.kind === 'variable' && varMap) {
        checkCallConstraint(expr.func.name, varMap, expr, subst);
      }
      return current;
    }

    case 'let': {
      const valType = infer(expr.value, env, subst);
      const newEnv = new Map(env);
      if (valType) newEnv.set(expr.name, valType);
      return infer(expr.body, newEnv, subst);
    }

    case 'lambda':
      // Parameter types are only known when checking against an expected type.
      return null;

    case 'match': {
      const scrutType = infer(expr.expr, env, subst);
      if (expr.cases.length === 0) return null;
      // Only the first arm determines the inferred type.
      const firstEnv = new Map(env);
      if (scrutType) {
        bindPattern(expr.cases[0].pattern, scrutType, firstEnv, subst);
        checkExhaustiveness(scrutType, expr.cases, expr.expr, subst);
      }
      return infer(expr.cases[0].result, firstEnv, subst);
    }

    default:
      return null;
  }
}

/**
 * After a call to `funcName` has been checked, enforces the callee's class
 * constraint (if it has one). If the constrained variable resolved to a
 * concrete type, that type must have an instance of the class; if it is still
 * a variable, the constraint is recorded in `inferredConstraints` so it can
 * propagate to the enclosing definition.
 */
function checkCallConstraint(funcName: string, varMap: Map<string, string>, expr: AST, subst: Subst): void {
  const constraint = moduleConstraints.get(funcName);
  if (!constraint) return;
  const freshName = varMap.get(constraint.param) ?? constraint.param;
  const resolved = applySubst({ kind: 'type-var', name: freshName }, subst);

  if (resolved.kind === 'type-var') {
    inferredConstraints.push({ varName: resolved.name, className: constraint.className });
    return;
  }

  let typeName: string | null = null;
  if (resolved.kind === 'type-name') typeName = resolved.name;
  else if (resolved.kind === 'type-apply' && resolved.constructor.kind === 'type-name') typeName = resolved.constructor.name;
  if (typeName === null) return;

  const instances = moduleInstances.get(constraint.className);
  if (!instances || !instances.has(typeName)) {
    warn(`No instance ${constraint.className} ${typeName}`, expr);
  }
}

/**
 * Checks `expr` against `expected`, extending `subst` as a side effect.
 * Returns an error message for a mismatch, or null. Match arms report their
 * own errors through `warn` (and the match itself then returns null); an
 * expression whose type cannot be inferred is accepted silently.
 */
function check(expr: AST, expected: TypeAST, env: TypeEnv, subst: Subst): string | null {
  const exp = applySubst(expected, subst);

  // Lambda against a function type: bind the first parameter, then check the
  // remaining parameters (as a nested lambda) or the body against the result.
  if (expr.kind === 'lambda' && exp.kind === 'type-function') {
    const newEnv = new Map(env);
    newEnv.set(expr.params[0], exp.param);
    if (expr.params.length > 1) {
      const innerLambda: AST = {
        kind: 'lambda',
        params: expr.params.slice(1),
        body: expr.body,
      };
      return check(innerLambda, exp.result, newEnv, subst);
    }
    return check(expr.body, exp.result, newEnv, subst);
  }

  // Match: every arm's result must have the expected type.
  if (expr.kind === 'match') {
    const scrutType = infer(expr.expr, env, subst);
    for (const c of expr.cases) {
      const caseEnv = new Map(env);
      if (scrutType) {
        const patErr = bindPattern(c.pattern, scrutType, caseEnv, subst);
        if (patErr) warn(patErr, c.result);
      }
      const err = check(c.result, expected, caseEnv, subst);
      if (err) warn(err, c.result);
    }
    if (scrutType) checkExhaustiveness(scrutType, expr.cases, expr.expr, subst);
    return null;
  }

  // Let: the body carries the expected type.
  if (expr.kind === 'let') {
    const valType = infer(expr.value, env, subst);
    const newEnv = new Map(env);
    if (valType) newEnv.set(expr.name, valType);
    return check(expr.body, expected, newEnv, subst);
  }

  // List literal against `List elem`: check each element (or spread) directly.
  if (expr.kind === 'list') {
    const elemType = listElementType(exp);
    if (elemType) {
      for (const elem of expr.elements) {
        if (elem.kind === 'list-spread') {
          const err = check(elem.spread, exp, env, subst);
          if (err) return err;
          continue;
        }
        const err = check(elem, elemType, env, subst);
        if (err) return err;
      }
      return null;
    }
  }

  // Fallback: infer and unify. An expression we cannot infer is skipped silently.
  const inferred = infer(expr, env, subst);
  if (!inferred) return null;
  return unify(inferred, expected, subst);
}

/** Prints a type error for `expr` to the console, including its line when known. */
function warn(msg: string, expr: AST): void {
  const loc = expr.line ? ` (line ${expr.line})` : '';
  console.warn(`TypeError${loc}: ${msg}`);
}

/**
 * Adds the variables introduced by `pattern` to `env` (mutated in place),
 * given that the matched value has type `type`. Constructor and list patterns
 * are unified with `type`, extending `subst`. Returns the first error found,
 * or null; binding continues past errors where possible so later uses of the
 * pattern's variables can still be checked.
 */
function bindPattern(pattern: Pattern, type: TypeAST, env: TypeEnv, subst: Subst): string | null {
  const t = applySubst(type, subst);
  switch (pattern.kind) {
    case 'var':
      env.set(pattern.name, t);
      return null;

    case 'wildcard':
    case 'literal':
      return null;

    case 'constructor': {
      const ctorType = env.get(pattern.name);
      if (!ctorType) return null;
      // Split the constructor's fresh type `a1 -> ... -> an -> R` into its
      // argument types and its result type R.
      const argTypes: TypeAST[] = [];
      let resultType = freshen(ctorType, subst).type;
      while (resultType.kind === 'type-function') {
        argTypes.push(resultType.param);
        resultType = resultType.result;
      }
      if (argTypes.length !== pattern.args.length) {
        return `Constructor ${pattern.name} expects ${argTypes.length} argument(s) but the pattern has ${pattern.args.length}`;
      }
      // Unify with the scrutinee BEFORE binding sub-patterns, so that their
      // types are as concrete as possible (nested list and record patterns
      // inspect the shape of their type).
      let firstErr: string | null = unify(resultType, t, subst);
      for (let i = 0; i < pattern.args.length; i++) {
        const argErr = bindPattern(pattern.args[i], argTypes[i], env, subst);
        if (!firstErr) firstErr = argErr;
      }
      return firstErr;
    }

    case 'list':
    case 'list-spread': {
      // An unresolved scrutinee is refined to `List <fresh>`.
      if (t.kind === 'type-var') {
        const err = unify(t, listOf(freshVar()), subst);
        if (err) return err;
      }
      const listType = applySubst(t, subst);
      const elemType = listElementType(listType);
      if (!elemType) return `Cannot match ${prettyPrintType(t)} against list pattern`;

      let firstErr: string | null = null;
      const elemPatterns = pattern.kind === 'list' ? pattern.elements : pattern.head;
      for (const elem of elemPatterns) {
        const err = bindPattern(elem, elemType, env, subst);
        if (!firstErr) firstErr = err;
      }
      // The rest binding of `[a, ...rest]` has the whole list type.
      if (pattern.kind === 'list-spread') env.set(pattern.spread, listType);
      return firstErr;
    }

    case 'record': {
      if (t.kind !== 'type-record') return null;
      let firstErr: string | null = null;
      for (const [key, pat] of Object.entries(pattern.fields)) {
        const field = t.fields.find(f => f.name === key);
        if (!field) continue;
        const err = bindPattern(pat, field.type, env, subst);
        if (!firstErr) firstErr = err;
      }
      return firstErr;
    }

    default:
      return null;
  }
}

/**
 * Warns when a match on a value of a declared data type neither has a
 * top-level variable/wildcard arm nor names every constructor of that type.
 * Only top-level constructor names are considered; nested patterns are not.
 */
function checkExhaustiveness(scrutType: TypeAST, cases: { pattern: Pattern }[], expr: AST, subst: Subst) {
  const resolved = applySubst(scrutType, subst);

  let typeName: string | null = null;
  if (resolved.kind === 'type-name') typeName = resolved.name;
  else if (resolved.kind === 'type-apply' && resolved.constructor.kind === 'type-name') typeName = resolved.constructor.name;
  if (!typeName) return;

  const allCtors = typeConstructors.get(typeName);
  if (!allCtors) return;

  // A top-level variable or wildcard arm matches every value.
  if (cases.some(c => c.pattern.kind === 'var' || c.pattern.kind === 'wildcard')) return;

  const covered = new Set<string>();
  for (const c of cases) {
    if (c.pattern.kind === 'constructor') covered.add(c.pattern.name);
  }

  const missing = allCtors.filter(name => !covered.has(name));
  if (missing.length > 0) {
    warn(`Non-exhaustive match, missing: ${missing.join(', ')}`, expr);
  }
}

/**
 * Records `td`'s constructor names and puts each constructor into `env` with
 * its curried type `arg1 -> ... -> argN -> T params`.
 */
function registerTypeDefinition(td: TypeDefinition, env: TypeEnv): void {
  typeConstructors.set(td.name, td.constructors.map(c => c.name));
  const resultType: TypeAST = td.params.length === 0
    ? { kind: 'type-name', name: td.name }
    : {
        kind: 'type-apply',
        constructor: { kind: 'type-name', name: td.name },
        args: td.params.map((p): TypeAST => ({ kind: 'type-var', name: p })),
      };
  for (const ctor of td.constructors) {
    const ctorType = ctor.args.reduceRight<TypeAST>(
      (result, param) => ({ kind: 'type-function', param, result }),
      resultType,
    );
    env.set(ctor.name, ctorType);
  }
}

/** Adds the name of every type variable occurring in `t` to `out`. */
function collectTypeVars(t: TypeAST, out: Set<string>): void {
  switch (t.kind) {
    case 'type-var':
      out.add(t.name);
      break;
    case 'type-function':
      collectTypeVars(t.param, out);
      collectTypeVars(t.result, out);
      break;
    case 'type-apply':
      collectTypeVars(t.constructor, out);
      t.args.forEach(a => collectTypeVars(a, out));
      break;
    case 'type-record':
      t.fields.forEach(f => collectTypeVars(f.type, out));
      break;
  }
}

/**
 * Attaches each constraint collected while checking an annotated definition
 * to one of the annotation's type variables, so callers of `defName` are held
 * to it. Only one constraint per binding is kept: a later one overwrites an
 * earlier one.
 */
function propagateConstraints(defName: string, annotation: TypeAST, subst: Subst): void {
  if (inferredConstraints.length === 0) return;
  const annoVars = new Set<string>();
  collectTypeVars(annotation, annoVars);

  for (const ic of inferredConstraints) {
    const icResolved = applySubst({ kind: 'type-var', name: ic.varName }, subst);
    for (const av of annoVars) {
      const avResolved = applySubst({ kind: 'type-var', name: av }, subst);
      // NOTE(review): this tests *unifiability* on a scratch copy of the
      // substitution rather than identity, so an unresolved annotation
      // variable matches any constraint and the first such variable wins.
      if (unify(avResolved, icResolved, new Map(subst)) === null) {
        moduleConstraints.set(defName, { param: av, className: ic.className });
        break;
      }
    }
  }
}

/**
 * Type-checks one module. Returns nothing; problems are reported through
 * `warn` (console output).
 *
 * @param defs      Top-level definitions. Annotated ones are visible to every
 *                  definition and are checked against their annotation;
 *                  unannotated ones are inferred in order and become visible
 *                  only to later definitions.
 * @param typeDefs  Data types whose constructors are brought into scope.
 * @param classDefs Classes whose methods are brought into scope with their constraint.
 * @param instances Instance declarations consulted when a constrained method
 *                  is used at a concrete type.
 */
export function typecheck(defs: Definition[], typeDefs: TypeDefinition[] = [], classDefs: ClassDefinition[] = [], instances: InstanceDeclaration[] = []) {
  const env: TypeEnv = new Map();

  // Start from a clean slate so nothing from a previous call leaks in.
  moduleConstraints = new Map();
  typeConstructors = new Map();
  inferredConstraints = [];

  // className -> set of type names with an instance.
  moduleInstances = new Map();
  for (const inst of instances) {
    let types = moduleInstances.get(inst.className);
    if (!types) {
      types = new Set();
      moduleInstances.set(inst.className, types);
    }
    types.add(inst.typeName);
  }

  // Class methods, each constrained on the class parameter.
  for (const cls of classDefs) {
    for (const method of cls.methods) {
      env.set(method.name, method.type);
      moduleConstraints.set(method.name, { param: cls.param, className: cls.name });
    }
  }

  for (const td of typeDefs) registerTypeDefinition(td, env);

  // Register annotated defs first so they can reference each other.
  for (const def of defs) {
    if (def.annotation) env.set(def.name, def.annotation.type);
  }

  for (const def of defs) {
    if (!def.body) continue;
    const subst: Subst = new Map();
    inferredConstraints = [];
    if (def.annotation) {
      const err = check(def.body, def.annotation.type, env, subst);
      if (err) warn(err, def.body);
      propagateConstraints(def.name, def.annotation.type, subst);
    } else {
      const t = infer(def.body, env, subst);
      // `subst` is discarded after this definition, so resolve it now.
      if (t) env.set(def.name, applySubst(t, subst));
    }
  }
}