// cg/src/typechecker.ts
//
// Listing metadata: 539 lines, 21 KiB, TypeScript

import type { AST, TypeAST, Pattern, Definition, TypeDefinition, ClassDefinition, InstanceDeclaration } from './ast'
import { prettyPrintType } from './ast'
// Substitution: maps a type variable's name to the type it is currently bound to.
// unify() adds bindings to it; applySubst() reads them to resolve a type.
type Subst = Map<string, TypeAST>
// Typing environment: maps a term-level name (class method, constructor,
// annotated definition, pattern or let-bound variable) to its type.
type TypeEnv = Map<string, TypeAST>
/**
 * Resolves `type` under `subst`: every type variable that has a binding is
 * replaced by its (recursively resolved) binding.
 *
 * Unbound variables and named types are returned as the very same object;
 * function, application and record types are rebuilt as new objects.
 */
function applySubst(type: TypeAST, subst: Subst): TypeAST {
  const resolve = (inner: TypeAST): TypeAST => applySubst(inner, subst);

  switch (type.kind) {
    case 'type-name':
      return type;

    case 'type-var': {
      const binding = subst.get(type.name);
      if (!binding) {
        return type;
      }
      // A binding may itself mention bound variables, so keep resolving.
      return resolve(binding);
    }

    case 'type-record': {
      const fields = type.fields.map(field => ({ name: field.name, type: resolve(field.type) }));
      return { kind: 'type-record', fields };
    }

    case 'type-apply': {
      const constructor = resolve(type.constructor);
      const args = type.args.map(arg => resolve(arg));
      return { kind: 'type-apply', constructor, args };
    }

    case 'type-function': {
      const param = resolve(type.param);
      const result = resolve(type.result);
      return { kind: 'type-function', param, result };
    }
  }
}
/**
 * Unifies `t1` with `t2`, extending `subst` in place with whatever variable
 * bindings are needed.
 *
 * Returns null on success, or a human-readable message describing the first
 * mismatch. Bindings made before a mismatch is found are left in `subst`.
 */
function unify(t1: TypeAST, t2: TypeAST, subst: Subst): string | null {
  const left = applySubst(t1, subst);
  const right = applySubst(t2, subst);

  // Binds a variable to a type, refusing bindings that would make the
  // variable part of its own definition.
  const bindVar = (varName: string, target: TypeAST): string | null => {
    if (occursIn(varName, target, subst)) {
      return `Infinite type: ${varName} occurs in ${prettyPrintType(target)}`;
    }
    subst.set(varName, target);
    return null;
  };

  // Two occurrences of the same named type.
  if (left.kind === 'type-name' && right.kind === 'type-name' && left.name === right.name) {
    return null;
  }

  // Two occurrences of the same variable.
  if (left.kind === 'type-var' && right.kind === 'type-var' && left.name === right.name) {
    return null;
  }

  // A variable on either side binds to the other side; the left is tried first.
  if (left.kind === 'type-var') {
    return bindVar(left.name, right);
  }
  if (right.kind === 'type-var') {
    return bindVar(right.name, left);
  }

  // Function types: parameters first, then results.
  if (left.kind === 'type-function' && right.kind === 'type-function') {
    const paramErr = unify(left.param, right.param, subst);
    return paramErr ? paramErr : unify(left.result, right.result, subst);
  }

  // Type applications: the constructor, then the arguments pairwise.
  if (left.kind === 'type-apply' && right.kind === 'type-apply') {
    const ctorErr = unify(left.constructor, right.constructor, subst);
    if (ctorErr) {
      return ctorErr;
    }
    if (left.args.length !== right.args.length) {
      return `Type argument mismatch`;
    }
    for (const [index, leftArg] of left.args.entries()) {
      const argErr = unify(leftArg, right.args[index], subst);
      if (argErr) {
        return argErr;
      }
    }
    return null;
  }

  // Records: only fields present on both sides are compared; a field of the
  // left record with no counterpart on the right is ignored.
  if (left.kind === 'type-record' && right.kind === 'type-record') {
    for (const leftField of left.fields) {
      const rightField = right.fields.find(candidate => candidate.name === leftField.name);
      if (!rightField) {
        continue;
      }
      const fieldErr = unify(leftField.type, rightField.type, subst);
      if (fieldErr) {
        return fieldErr;
      }
    }
    return null;
  }

  return `Cannot unify ${prettyPrintType(left)} with ${prettyPrintType(right)}`;
}
/**
 * Infers the type of `expr` in environment `env`.
 *
 * Returns the inferred type, or null when nothing can be inferred (unknown
 * name, lambda, unsupported node kind, non-function applied to arguments, ...).
 * Callers treat null as "skip", not as an error.
 *
 * Side effects: may add bindings to `subst`, may report problems through
 * warn(), and may append to the module-level `inferredConstraints` list.
 */
function infer(expr: AST, env: TypeEnv, subst: Subst): TypeAST | null {
  switch (expr.kind) {
    case 'literal':
      if (expr.value.kind === 'int') return { kind: 'type-name', name: 'Int' };
      if (expr.value.kind === 'float') return { kind: 'type-name', name: 'Float' };
      if (expr.value.kind === 'string') return { kind: 'type-name', name: 'String' };
      return null;

    case 'variable': {
      const t = env.get(expr.name);
      if (!t) return null;
      // Every use of a name gets its own copy of the type's variables.
      const { type: fresh } = freshen(t, subst);
      // If this variable has a typeclass constraint and its first parameter
      // is still a type variable, record the constraint on that variable.
      const constraint = moduleConstraints.get(expr.name);
      if (constraint && fresh.kind === 'type-function' && fresh.param.kind === 'type-var') {
        inferredConstraints.push({ varName: fresh.param.name, className: constraint.className });
      }
      return fresh;
    }

    case 'constructor': {
      const t = env.get(expr.name);
      if (!t) return null;
      return freshen(t, subst).type;
    }

    case 'list': {
      if (expr.elements.length === 0) {
        // Each empty list needs its OWN element variable. Returning one shared
        // variable name for every `[]` would tie all empty lists in a
        // definition together: once one is unified with `List Int`, every
        // other `[]` would resolve to `List Int` as well. freshen() renames
        // the placeholder to a unique variable.
        const emptyList: TypeAST = {
          kind: 'type-apply',
          constructor: { kind: 'type-name', name: 'List' },
          args: [{ kind: 'type-var', name: '_empty' }],
        };
        return freshen(emptyList, subst).type;
      }
      // Only the first element determines the element type; the remaining
      // elements are not examined here.
      const first = infer(expr.elements[0], env, subst);
      if (!first) return null;
      return { kind: 'type-apply', constructor: { kind: 'type-name', name: 'List' }, args: [first] };
    }

    case 'record': {
      const fields: { name: string, type: TypeAST }[] = [];
      for (const entry of expr.entries) {
        // Spread entries contribute no fields to the inferred record type.
        if (entry.kind === 'spread') continue;
        const t = infer(entry.value, env, subst);
        if (!t) return null;
        fields.push({ name: entry.key, type: t });
      }
      return { kind: 'type-record', fields };
    }

    case 'record-access': {
      const recType = infer(expr.record, env, subst);
      if (!recType) return null;
      const resolved = applySubst(recType, subst);
      if (resolved.kind === 'type-record') {
        const field = resolved.fields.find(f => f.name === expr.field);
        return field ? field.type : null;
      }
      return null;
    }

    case 'record-update': {
      const recType = infer(expr.record, env, subst);
      if (!recType) return null;
      const resolved = applySubst(recType, subst);
      if (resolved.kind !== 'type-record') return null;
      // Updated values must fit the existing field types; updates to fields
      // the record type does not list are not checked.
      for (const [key, value] of Object.entries(expr.updates)) {
        const field = resolved.fields.find(f => f.name === key);
        if (field) {
          const err = check(value, field.type, env, subst);
          if (err) warn(err, value);
        }
      }
      return resolved;
    }

    case 'apply': {
      let funcType: TypeAST | null = null;
      // Old-name -> fresh-name mapping, kept only for directly named functions
      // so their class constraint can be followed through the renaming.
      let varMap: Map<string, string> | null = null;
      if (expr.func.kind === 'variable') {
        const t = env.get(expr.func.name);
        if (t) {
          const result = freshen(t, subst);
          funcType = result.type;
          varMap = result.varMap;
        }
      } else {
        funcType = infer(expr.func, env, subst);
      }
      if (!funcType) return null;

      // Consume one parameter per argument; mismatches are warned about and
      // checking continues with the declared result type.
      let current = applySubst(funcType, subst);
      for (const arg of expr.args) {
        if (current.kind !== 'type-function') return null;
        const err = check(arg, current.param, env, subst);
        if (err) warn(err, arg);
        current = applySubst(current.result, subst);
      }

      // Check typeclass constraints
      if (expr.func.kind === 'variable') {
        const constraint = moduleConstraints.get(expr.func.name);
        if (constraint && varMap) {
          const freshName = varMap.get(constraint.param) ?? constraint.param;
          const resolved = applySubst({ kind: 'type-var', name: freshName }, subst);

          // Name of the concrete type the constrained variable resolved to:
          // either a plain named type or the head of a type application.
          let headName: string | null = null;
          if (resolved.kind === 'type-name') {
            headName = resolved.name;
          } else if (resolved.kind === 'type-apply' && resolved.constructor.kind === 'type-name') {
            headName = resolved.constructor.name;
          }

          if (headName !== null) {
            const instances = moduleInstances.get(constraint.className);
            if (!instances || !instances.has(headName)) {
              warn(`No instance ${constraint.className} ${headName}`, expr);
            }
          } else if (resolved.kind === 'type-var') {
            // Still unknown: remember the constraint for the enclosing definition.
            inferredConstraints.push({ varName: resolved.name, className: constraint.className });
          }
        }
      }
      return current;
    }

    case 'let': {
      const valType = infer(expr.value, env, subst);
      const newEnv = new Map(env);
      // When the bound value's type is unknown the name is simply left unbound.
      if (valType) newEnv.set(expr.name, valType);
      return infer(expr.body, newEnv, subst);
    }

    case 'lambda':
      // Lambdas are never inferred here; check() handles them when an
      // expected function type is available.
      return null;

    case 'match': {
      const scrutType = infer(expr.expr, env, subst);
      if (expr.cases.length === 0) return null;
      // The type of the whole match is taken from its first case only.
      const firstEnv = new Map(env);
      if (scrutType) {
        bindPattern(expr.cases[0].pattern, scrutType, firstEnv, subst);
        checkExhaustiveness(scrutType, expr.cases, expr.expr, subst);
      }
      return infer(expr.cases[0].result, firstEnv, subst);
    }

    default:
      return null;
  }
}
/**
 * Checks `expr` against the type `expected`.
 *
 * Lambdas, matches, lets and list literals are handled structurally; anything
 * else is inferred and the result unified with `expected`.
 *
 * Returns an error message, or null when the expression fits or when its type
 * cannot be inferred. Problems inside match cases are reported through warn()
 * and the match itself always yields null.
 */
function check(expr: AST, expected: TypeAST, env: TypeEnv, subst: Subst): string | null {
  const target = applySubst(expected, subst);

  // Lambda against a function type: bind the first parameter, then check the
  // remaining parameters (as a smaller lambda) or the body against the result.
  if (expr.kind === 'lambda' && target.kind === 'type-function') {
    const [firstParam, ...restParams] = expr.params;
    const scope = new Map(env);
    scope.set(firstParam, target.param);
    const remainder: AST = restParams.length > 0
      ? { kind: 'lambda', params: restParams, body: expr.body }
      : expr.body;
    return check(remainder, target.result, scope, subst);
  }

  // Match: every case result must fit the expected type.
  if (expr.kind === 'match') {
    const scrutineeType = infer(expr.expr, env, subst);
    for (const branch of expr.cases) {
      const scope = new Map(env);
      if (scrutineeType) {
        const patternErr = bindPattern(branch.pattern, scrutineeType, scope, subst);
        if (patternErr) {
          warn(patternErr, branch.result);
        }
      }
      const resultErr = check(branch.result, expected, scope, subst);
      if (resultErr) {
        warn(resultErr, branch.result);
      }
    }
    if (scrutineeType) {
      checkExhaustiveness(scrutineeType, expr.cases, expr.expr, subst);
    }
    return null;
  }

  // Let: infer the bound value, then check the body with the name in scope
  // (the name stays unbound when the value's type is unknown).
  if (expr.kind === 'let') {
    const boundType = infer(expr.value, env, subst);
    const scope = new Map(env);
    if (boundType) {
      scope.set(expr.name, boundType);
    }
    return check(expr.body, expected, scope, subst);
  }

  // List literal against a List type: plain elements are checked against the
  // element type, spread elements against the list type itself.
  const expectsList =
    target.kind === 'type-apply' &&
    target.constructor.kind === 'type-name' &&
    target.constructor.name === 'List';
  if (expr.kind === 'list' && target.kind === 'type-apply' && expectsList) {
    const elementType = target.args[0];
    for (const element of expr.elements) {
      const elementErr = element.kind === 'list-spread'
        ? check(element.spread, target, env, subst)
        : check(element, elementType, env, subst);
      if (elementErr) {
        return elementErr;
      }
    }
    return null;
  }

  // Fallback: infer and unify. An expression that cannot be inferred is
  // skipped silently.
  const actual = infer(expr, env, subst);
  return actual ? unify(actual, expected, subst) : null;
}
/**
 * Reports a type error on the console via console.warn.
 * The line of `expr` is appended to the "TypeError" prefix when it is set
 * (a missing or zero line is omitted).
 */
function warn(msg: string, expr: AST) {
  let prefix = 'TypeError';
  if (expr.line) {
    prefix += ` (line ${expr.line})`;
  }
  console.warn(`${prefix}: ${msg}`);
}
/**
 * Matches `pattern` against a scrutinee of type `type`, adding the variables
 * the pattern introduces to `env` (mutated in place) and extending `subst`
 * where the pattern pins down type variables.
 *
 * Returns null when the pattern is compatible with the type (or when nothing
 * can be said, e.g. an unknown constructor), otherwise a message describing
 * the first incompatibility found.
 */
function bindPattern(pattern: Pattern, type: TypeAST, env: TypeEnv, subst: Subst): string | null {
  const t = applySubst(type, subst);
  switch (pattern.kind) {
    case 'var':
      env.set(pattern.name, t);
      return null;

    case 'wildcard':
    case 'literal':
      // Neither binds a name; literal patterns are not checked against the type.
      return null;

    case 'constructor': {
      // Look up the constructor's type; unknown constructors are skipped.
      const ctorType = env.get(pattern.name);
      if (!ctorType) return null;
      // Walk the constructor's parameters alongside the sub-patterns. For a
      // nullary pattern the loop does not run and `cur` is the constructor's
      // own type.
      let cur = freshen(ctorType, subst).type;
      for (const arg of pattern.args) {
        if (cur.kind !== 'type-function') return null;
        const err = bindPattern(arg, cur.param, env, subst);
        if (err) return err;
        cur = applySubst(cur.result, subst);
      }
      // What remains must be the scrutinee's type. A failure here means the
      // constructor cannot produce a value of that type, so report it rather
      // than discarding it.
      return unify(cur, t, subst);
    }

    case 'list':
    case 'list-spread': {
      if (t.kind === 'type-apply' && t.constructor.kind === 'type-name' && t.constructor.name === 'List' && t.args.length === 1) {
        const elemType = t.args[0];
        const elementPatterns = pattern.kind === 'list' ? pattern.elements : pattern.head;
        // Bind every element so all pattern variables end up in scope, but
        // remember the first problem so it is not lost.
        let firstErr: string | null = null;
        for (const elem of elementPatterns) {
          const err = bindPattern(elem, elemType, env, subst);
          firstErr = firstErr ?? err;
        }
        // The spread name captures the rest of the list, so it has the list type.
        if (pattern.kind === 'list-spread') {
          env.set(pattern.spread, t);
        }
        return firstErr;
      }
      // An unresolved scrutinee type might still turn out to be a list.
      if (t.kind === 'type-var') return null;
      return `Cannot match ${prettyPrintType(t)} against list pattern`;
    }

    case 'record': {
      // Only fields the record type actually lists are bound; a non-record
      // scrutinee type is accepted without comment.
      let firstErr: string | null = null;
      if (t.kind === 'type-record') {
        for (const [key, pat] of Object.entries(pattern.fields)) {
          const field = t.fields.find(f => f.name === key);
          if (field) {
            const err = bindPattern(pat, field.type, env, subst);
            firstErr = firstErr ?? err;
          }
        }
      }
      return firstErr;
    }

    default:
      return null;
  }
}
// Checker state shared by the functions in this file.

// Name -> the typeclass constraint attached to it. typecheck() fills it for
// class methods and for annotated definitions whose body uses a constrained name.
let moduleConstraints = new Map<string, { param: string, className: string }>();
// Class name -> set of type names that have an instance of that class.
// Replaced on every typecheck() call.
let moduleInstances = new Map<string, Set<string>>();
// Constraints on still-unresolved type variables, collected by infer() while
// one definition is being checked; typecheck() resets it per definition.
let inferredConstraints: { varName: string, className: string }[] = [];
// Type name -> names of its constructors; consulted by checkExhaustiveness().
let typeConstructors = new Map<string, string[]>();
/**
 * Type-checks the definitions of one module. Problems are reported through
 * warn(); nothing is returned.
 *
 * @param defs       Definitions to check. Annotated ones are checked against
 *                   their annotation; unannotated ones are inferred and, when
 *                   inference succeeds, made visible to later definitions.
 * @param typeDefs   Data type declarations; their constructors are added to
 *                   the environment and recorded for exhaustiveness checks.
 * @param classDefs  Typeclass declarations; their methods are added to the
 *                   environment together with the class constraint.
 * @param instances  Instance declarations used to validate class constraints.
 */
export function typecheck(defs: Definition[], typeDefs: TypeDefinition[] = [], classDefs: ClassDefinition[] = [], instances: InstanceDeclaration[] = []) {
  const env: TypeEnv = new Map();

  // Start every run from a clean slate. The instance table below is rebuilt
  // per call, so constraints and constructor lists left over from an earlier
  // call would otherwise be checked against the wrong module's data.
  moduleConstraints.clear();
  typeConstructors.clear();

  // Register instances as a lookup: className -> Set of type names
  const instanceMap = new Map<string, Set<string>>();
  for (const inst of instances) {
    let typeNames = instanceMap.get(inst.className);
    if (!typeNames) {
      typeNames = new Set<string>();
      instanceMap.set(inst.className, typeNames);
    }
    typeNames.add(inst.typeName);
  }
  moduleInstances = instanceMap;

  // Register class methods with constraints in env
  for (const cls of classDefs) {
    for (const method of cls.methods) {
      env.set(method.name, method.type);
      moduleConstraints.set(method.name, { param: cls.param, className: cls.name });
    }
  }

  // Register constructors
  for (const td of typeDefs) {
    typeConstructors.set(td.name, td.constructors.map(c => c.name));
    // The type every constructor of this declaration produces: the bare name,
    // or the name applied to the declaration's type parameters.
    const resultType: TypeAST = td.params.length === 0
      ? { kind: 'type-name', name: td.name }
      : { kind: 'type-apply', constructor: { kind: 'type-name', name: td.name }, args: td.params.map(p => ({ kind: 'type-var', name: p })) };
    for (const ctor of td.constructors) {
      // Build `arg1 -> arg2 -> ... -> resultType` from right to left; a
      // constructor without arguments is just `resultType`.
      const ctorType = ctor.args.reduceRight<TypeAST>(
        (result, param) => ({ kind: 'type-function', param, result }),
        resultType,
      );
      env.set(ctor.name, ctorType);
    }
  }

  // Register all annotated defs in env first so they can reference each other
  for (const def of defs) {
    if (def.annotation) {
      env.set(def.name, def.annotation.type);
    }
  }

  // Collects the names of all type variables occurring in `t` into `out`.
  const collectVars = (t: TypeAST, out: Set<string>): void => {
    if (t.kind === 'type-var') out.add(t.name);
    if (t.kind === 'type-function') { collectVars(t.param, out); collectVars(t.result, out); }
    if (t.kind === 'type-apply') { collectVars(t.constructor, out); t.args.forEach(a => collectVars(a, out)); }
    if (t.kind === 'type-record') t.fields.forEach(f => collectVars(f.type, out));
  };

  // Check each def that has a body
  for (const def of defs) {
    if (!def.body) continue;

    // Each definition is checked with its own substitution and its own list
    // of inferred constraints.
    const subst: Subst = new Map();
    inferredConstraints = [];

    if (def.annotation) {
      const err = check(def.body, def.annotation.type, env, subst);
      if (err) warn(err, def.body);
    } else {
      const t = infer(def.body, env, subst);
      if (t) env.set(def.name, t);
    }

    // Propagate inferred constraints to this definition
    if (inferredConstraints.length === 0 || !def.annotation) continue;

    const annoVars = new Set<string>();
    collectVars(def.annotation.type, annoVars);

    for (const ic of inferredConstraints) {
      // `subst` is not modified below (unify works on a copy), so the
      // constrained variable only needs resolving once per constraint.
      const icResolved = applySubst({ kind: 'type-var', name: ic.varName }, subst);
      // Pick the first annotation variable that can be unified with the
      // constrained variable; the trial runs on a throwaway copy of `subst`.
      let found: string | null = null;
      for (const av of annoVars) {
        const resolved = applySubst({ kind: 'type-var', name: av }, subst);
        if (unify(resolved, icResolved, new Map(subst)) === null) {
          found = av;
          break;
        }
      }
      if (found) {
        // Callers of this definition now inherit the class constraint.
        moduleConstraints.set(def.name, { param: found, className: ic.className });
      }
    }
  }
}
/**
 * Occurs check: reports whether the type variable called `name` appears
 * anywhere inside `type` once `type` has been resolved under `subst`.
 */
function occursIn(name: string, type: TypeAST, subst: Subst): boolean {
  const resolved = applySubst(type, subst);
  const appearsIn = (inner: TypeAST): boolean => occursIn(name, inner, subst);

  switch (resolved.kind) {
    case 'type-name':
      return false;

    case 'type-var':
      return resolved.name === name;

    case 'type-record':
      return resolved.fields.some(field => appearsIn(field.type));

    case 'type-function': {
      if (appearsIn(resolved.param)) {
        return true;
      }
      return appearsIn(resolved.result);
    }

    case 'type-apply': {
      if (appearsIn(resolved.constructor)) {
        return true;
      }
      return resolved.args.some(arg => appearsIn(arg));
    }
  }
}
// Source of unique suffixes for generated type variable names (`_t0`, `_t1`, ...).
let freshCounter = 0;

/**
 * Produces a copy of `type` (resolved under `subst`) in which every remaining
 * type variable is replaced by a newly generated one. Repeated occurrences of
 * the same variable map to the same new variable.
 *
 * Returns the renamed type together with `varMap`, which maps each original
 * variable name to the name generated for it.
 */
function freshen(type: TypeAST, subst: Subst): { type: TypeAST, varMap: Map<string, string> } {
  // Original variable name -> the replacement variable node.
  const replacements = new Map<string, TypeAST>();
  // Original variable name -> the replacement's name (returned to the caller).
  const varMap = new Map<string, string>();

  const rename = (node: TypeAST): TypeAST => {
    const current = applySubst(node, subst);
    switch (current.kind) {
      case 'type-name':
        return current;

      case 'type-var': {
        const known = replacements.get(current.name);
        if (known) {
          return known;
        }
        const freshName = `_t${freshCounter++}`;
        const freshVar: TypeAST = { kind: 'type-var', name: freshName };
        replacements.set(current.name, freshVar);
        varMap.set(current.name, freshName);
        return freshVar;
      }

      case 'type-function': {
        const param = rename(current.param);
        const result = rename(current.result);
        return { kind: 'type-function', param, result };
      }

      case 'type-apply': {
        const constructor = rename(current.constructor);
        const args = current.args.map(arg => rename(arg));
        return { kind: 'type-apply', constructor, args };
      }

      case 'type-record': {
        const fields = current.fields.map(field => ({ name: field.name, type: rename(field.type) }));
        return { kind: 'type-record', fields };
      }
    }
  };

  const renamed = rename(type);
  return { type: renamed, varMap };
}
/**
 * Warns when a match over a known data type leaves constructors unhandled.
 *
 * Nothing is reported when the scrutinee's type has no recorded constructor
 * list, or when any case is a variable or wildcard pattern (a catch-all).
 * Only top-level constructor patterns count as covering a constructor.
 */
function checkExhaustiveness(scrutType: TypeAST, cases: { pattern: Pattern }[], expr: AST, subst: Subst) {
  const scrutinee = applySubst(scrutType, subst);

  // The data type's name: the type itself, or the head of a type application.
  let headName: string | null = null;
  if (scrutinee.kind === 'type-name') {
    headName = scrutinee.name;
  } else if (scrutinee.kind === 'type-apply' && scrutinee.constructor.kind === 'type-name') {
    headName = scrutinee.constructor.name;
  }
  if (!headName) {
    return;
  }

  const declared = typeConstructors.get(headName);
  if (!declared) {
    return;
  }

  // One pass over the cases: a catch-all ends the check immediately,
  // constructor patterns are collected.
  const handled = new Set<string>();
  for (const { pattern } of cases) {
    if (pattern.kind === 'var' || pattern.kind === 'wildcard') {
      return;
    }
    if (pattern.kind === 'constructor') {
      handled.add(pattern.name);
    }
  }

  const missing = declared.filter(ctorName => !handled.has(ctorName));
  if (missing.length > 0) {
    warn(`Non-exhaustive match, missing: ${missing.join(', ')}`, expr);
  }
}