import { basename, dirname, extname, join } from "node:path";
import type b from "@babel/core";
import hash from "@emotion/hash";
import { isPlainObject } from "lodash";
import invariant from "tiny-invariant";
import { type NodePath, type types as t } from "@babel/core";
import { type SourceLocation, type StyleMapEntry, macroNames } from "./shared";
import { type ResolveTailwindOptions, getClassName } from "./index";

/** Callback receiving each generated style module path and the entries recorded for it. */
export type ClassNameCollector = (path: string, entries: StyleMapEntry[]) => void;

type BabelTypes = typeof b.types;

/**
 * Output kind of a macro.
 * - `"css"` (the `tw` macro): entries go to a sibling `.css` module and the
 *   macro is replaced by a class-name string literal.
 * - `"js"` (the `tws` macro): entries go to a sibling `.tailwindStyle.js`
 *   module and the macro is replaced by `styles["<key>"]` on its namespace import.
 */
type Type = "css" | "js";

/** One reference to an imported macro, as collected by `getMacros`. */
type MacroReference = { local: NodePath; imported: string };

/** A node path whose node may be absent (e.g. an optional child). */
type MaybeNodePath = NodePath<t.Node | null | undefined>;

/**
 * Creates the Babel plugin that compiles the `tw` / `tws` macros and the JSX
 * style attribute (`jsxAttributeName`, `"css"` by default) into generated
 * class names, recording the class lists into `options.styleMap`.
 *
 * @param options   Resolved plugin options.
 * @param onCollect Optional callback invoked once per generated style module.
 * @returns A Babel plugin factory `(babel) => PluginObj`.
 */
export function babelTailwind(
  options: ResolveTailwindOptions,
  onCollect: ClassNameCollector | undefined
) {
  const {
    styleMap,
    clsx,
    getClassName: getClass = getClassName,
    jsxAttributeAction = "delete",
    jsxAttributeName = "css",
    vite: bustCache,
  } = options;

  type BabelPluginUtils = ReturnType<typeof getUtils>;

  /**
   * Builds the per-file helpers. They are merged into the plugin state on
   * `Program.enter`, so every visitor of the same file shares the lazily
   * created imports and the collected style entries.
   */
  function getUtils(path: NodePath<t.Program>, state: b.PluginPass, t: BabelTypes) {
    let cx: t.Identifier | undefined;
    let tslibImport: t.Identifier | undefined;
    let styleImport: t.Identifier | undefined;
    // Source split into lines, computed on first use rather than once per recorded node.
    let sourceLines: string[] | undefined;
    const cssMap = new Map<string, StyleMapEntry>();
    const jsMap = new Map<string, StyleMapEntry>();

    function getStyleImport() {
      styleImport ??= path.scope.generateUidIdentifier("styles");
      return t.cloneNode(styleImport);
    }

    return {
      /** Class name for `value`: the configured `getClassName` for CSS, `tw_<hash>` for JS. */
      getClass(type: Type, value: string) {
        return type === "css" ? getClass(value) : "tw_" + hash(value);
      },

      /** Location of `node` plus the full source lines it spans. */
      sliceText(node: t.Node): SourceLocation {
        const { start, end } = node.loc!;
        sourceLines ??= state.file.code.split("\n");
        return {
          filename: state.filename!,
          start,
          end,
          text: sourceLines.slice(start.line - 1, end.line).join("\n"),
        };
      },

      /** Stores `entry` under its key unless an entry with that key already exists. */
      recordIfAbsent(type: Type, entry: StyleMapEntry) {
        const map = type === "css" ? cssMap : jsMap;
        if (!map.has(entry.key)) {
          map.set(entry.key, entry);
        }
      },

      /**
       * Replaces `path` with the generated class name: a string literal for CSS,
       * or `styles["<className>"]` on the JS module's namespace import.
       */
      replaceWithImport({
        type,
        path,
        className,
      }: {
        type: Type;
        path: NodePath;
        className: string;
      }) {
        path.replaceWith(
          type === "css"
            ? t.stringLiteral(className)
            : t.memberExpression(getStyleImport(), t.stringLiteral(className), true)
        );
      },

      /** Identifier of the clsx-like helper; its import is prepended on first use. */
      getCx() {
        if (cx == null) {
          cx = path.scope.generateUidIdentifier("cx");
          path.node.body.unshift(getClsxImport(t, cx, clsx));
        }
        return t.cloneNode(cx);
      },

      /** Identifier of the `tslib` namespace import; the import is prepended on first use. */
      getTSlibImport() {
        if (tslibImport == null) {
          tslibImport = path.scope.generateUidIdentifier("tslib");
          path.node.body.unshift(
            t.importDeclaration(
              [t.importNamespaceSpecifier(tslibImport)],
              t.stringLiteral("tslib")
            )
          );
        }
        return t.cloneNode(tslibImport);
      },

      /**
       * Prepends the virtual style-module imports for everything recorded in this
       * file and publishes the entries to `styleMap` / `onCollect`.
       *
       * @throws if entries were recorded but Babel provided no `state.filename`.
       */
      finish(node: t.Program) {
        if (!cssMap.size && !jsMap.size) return;
        const { filename } = state;
        invariant(filename, "babel: missing state.filename");
        const dir = dirname(filename);
        const stem = basename(filename, extname(filename));
        const publish = (stylePath: string, entries: StyleMapEntry[]) => {
          styleMap.set(stylePath, entries);
          onCollect?.(stylePath, entries);
        };

        if (cssMap.size) {
          const cssName = stem + ".css";
          const entries = Array.from(cssMap.values());
          const importee = `tailwind:./${cssName}` + getSuffix(bustCache, entries);
          node.body.unshift(t.importDeclaration([], t.stringLiteral(importee)));
          publish(join(dir, cssName), entries);
        }

        if (jsMap.size) {
          const jsName = stem + ".tailwindStyle.js";
          const entries = Array.from(jsMap.values());
          const importee = `tailwind:./${jsName}` + getSuffix(bustCache, entries);
          node.body.unshift(
            t.importDeclaration(
              [t.importNamespaceSpecifier(getStyleImport())],
              t.stringLiteral(importee)
            )
          );
          publish(join(dir, jsName), entries);
        }
      },
    };
  }

  return definePlugin<BabelPluginUtils>(({ types: t }) => ({
    Program: {
      enter(path, state) {
        const _ = getUtils(path, state, t);
        // Expose the helpers to the other visitors through the plugin state.
        Object.assign(state, _);

        const macros = getMacros(t, path, macroNames).map(macro => mapMacro(t, macro));
        for (const { callee, target, imported, prefix } of macros) {
          const type: Type | undefined =
            imported === "tw" ? "css" : imported === "tws" ? "js" : undefined;
          if (!type) continue;

          let node: t.Node;
          let list: string[];
          if (isNodePath(callee, t.isTaggedTemplateExpression)) {
            node = callee.node;
            const { quasi } = callee.node;
            invariant(
              !quasi.expressions.length,
              `Macro call should not contain expressions`
            );
            const value = quasi.quasis[0].value.cooked;
            if (!value) continue;
            list = trimPrefix(value, prefix ? prefix + ":" : undefined);
          } else if (isNodePath(callee, t.isCallExpression)) {
            node = callee.node;
            // The macro chain (`tw`, `tw.hover`, `ns.tw`, ...) must be the callee,
            // not merely an argument of some other call.
            invariant(callee.node.callee === target.node, "Macro must be called directly");
            const classes = callee.get("arguments").flatMap(evaluateArgs);
            // Apply modifier prefixes (`tw.hover(...)`) exactly like the tagged form does.
            list = prefix ? classes.map(cls => `${prefix}:${cls}`) : classes;
          } else {
            continue;
          }

          const className = _.getClass(type, list.join(" "));
          _.recordIfAbsent(type, {
            key: className,
            classNames: list,
            location: _.sliceText(node),
          });
          _.replaceWithImport({
            type,
            path: callee,
            // `group` must stay a literal class on the element; it only makes sense
            // for CSS output (a JS lookup key with a suffix would not match).
            className: addIf(className, type === "css" && list.includes("group") && " group"),
          });
        }
      },
      exit({ node }, _) {
        _.finish(node);
      },
    },
    JSXAttribute(path, _) {
      const { name } = path.node;
      if (name.name !== jsxAttributeName) return;
      const valuePath = path.get("value");
      if (!valuePath.node) return;

      // Untouched copy of the value, restored afterwards unless the attribute is deleted.
      const copy = jsxAttributeAction === "delete" ? undefined : t.cloneNode(valuePath.node, true);
      const parent = path.parent as t.JSXOpeningElement;
      const classNameAttribute = parent.attributes.find(
        (attr): attr is t.JSXAttribute =>
          t.isJSXAttribute(attr) && attr.name.name === "className"
      );

      /** Records `classNames` as one CSS entry and swaps `target` for its class name. */
      const replaceWithClassName = (target: NodePath, classNames: string[]) => {
        const className = getClass(classNames.join(" "));
        _.recordIfAbsent("css", {
          key: className,
          classNames,
          location: _.sliceText(target.node),
        });
        target.replaceWith(t.stringLiteral(className));
      };

      // Rewrite every statically known string / object inside the value into a class name.
      matchPath(valuePath, go => ({
        StringLiteral(path) {
          const classNames = trim(path.node.value);
          if (classNames.length) {
            replaceWithClassName(path, classNames);
          }
        },
        ArrayExpression(path) {
          for (const element of path.get("elements")) {
            go(element);
          }
        },
        ObjectExpression(path) {
          replaceWithClassName(path, evaluateArgs(path));
        },
        JSXExpressionContainer(path) {
          go(path.get("expression"));
        },
        ConditionalExpression(path) {
          go(path.get("consequent"));
          go(path.get("alternate"));
        },
        // Only the right operand is rewritten; the left one is left as-is.
        LogicalExpression(path) {
          go(path.get("right"));
        },
        CallExpression(path) {
          for (const arg of path.get("arguments")) {
            go(arg);
          }
        },
      }));

      let valuePathNode = extractJSXContainer(valuePath.node);
      // An array made only of string literals collapses into one space-joined literal.
      if (
        t.isArrayExpression(valuePathNode) &&
        valuePathNode.elements.every(node => t.isStringLiteral(node))
      ) {
        valuePathNode = t.stringLiteral(
          // eslint-disable-next-line @typescript-eslint/no-unnecessary-type-assertion
          (valuePathNode.elements as t.StringLiteral[]).map(node => node.value).join(" ")
        );
      }

      if (classNameAttribute) {
        const attrValue = classNameAttribute.value!;
        const wrap = (originalValue: t.Expression) =>
          t.callExpression(_.getCx(), [originalValue, valuePathNode]);

        if (t.isStringLiteral(attrValue) && t.isStringLiteral(valuePathNode)) {
          // Both are plain strings: concatenate at compile time, no runtime helper needed.
          attrValue.value += (attrValue.value.at(-1) === " " ? "" : " ") + valuePathNode.value;
        } else {
          const internalAttrValue = extractJSXContainer(attrValue);
          if (
            t.isArrowFunctionExpression(internalAttrValue) &&
            !t.isBlockStatement(internalAttrValue.body)
          ) {
            // Function-valued className with an expression body: wrap what it returns.
            internalAttrValue.body = wrap(internalAttrValue.body);
          } else {
            classNameAttribute.value = t.jsxExpressionContainer(wrap(internalAttrValue));
          }
        }
      } else {
        const wrap = (originalValue: t.Expression) =>
          t.callExpression(_.getCx(), [valuePathNode, originalValue]);
        const spreads = parent.attributes.filter(attr => t.isJSXSpreadAttribute(attr));
        const arg = spreads.length === 1 ? spreads[0].argument : undefined;

        if (arg && t.isIdentifier(arg)) {
          // `{...arg}` may carry a `className` that must be merged, not overwritten.
          const binding = path.scope.getBinding(arg.name);
          const param = binding?.path.node;
          const fn = binding?.path.parent;
          const index = t.isFunctionDeclaration(fn)
            ? (fn.params as t.Node[]).indexOf(param!)
            : -1;

          if (
            binding &&
            !binding.constantViolations.length &&
            t.isFunctionDeclaration(fn) &&
            index !== -1 &&
            (t.isIdentifier(param) ||
              (t.isObjectPattern(param) && hasRestElement(t, param, arg.name)))
          ) {
            // `arg` is an unmodified parameter (or the rest of one): pull `className`
            // out in the parameter list so it is excluded from the spread.
            const clsVar = path.scope.generateUidIdentifier("className");
            if (t.isIdentifier(param)) {
              // function C(props) → function C({ className: _className, ...props })
              fn.params[index] = t.objectPattern([
                t.objectProperty(t.identifier("className"), clsVar),
                t.restElement(param),
              ]);
            } else {
              // function C({ a, ...props }) → function C({ className: _className, a, ...props })
              param.properties.unshift(t.objectProperty(t.identifier("className"), clsVar));
            }
            parent.attributes.push(
              t.jsxAttribute(
                t.jsxIdentifier("className"),
                t.jsxExpressionContainer(wrap(clsVar))
              )
            );
          } else {
            // Fall back to `{...tslib.__rest(arg, ["className"])}` plus an explicit className.
            const tslibImport = _.getTSlibImport();
            spreads[0].argument = t.callExpression(
              t.memberExpression(tslibImport, t.identifier("__rest")),
              [arg, t.arrayExpression([t.stringLiteral("className")])]
            );
            parent.attributes.push(
              t.jsxAttribute(
                t.jsxIdentifier("className"),
                t.jsxExpressionContainer(
                  wrap(t.memberExpression(arg, t.identifier("className")))
                )
              )
            );
          }
        } else {
          const containerValue = t.isStringLiteral(valuePathNode)
            ? valuePathNode
            : t.callExpression(_.getCx(), [valuePathNode]);
          parent.attributes.push(
            t.jsxAttribute(
              t.jsxIdentifier("className"),
              t.jsxExpressionContainer(containerValue)
            )
          );
        }
      }

      if (jsxAttributeAction === "delete") {
        path.remove();
      } else {
        path.node.value = copy!;
        if (Array.isArray(jsxAttributeAction) && jsxAttributeAction[0] === "rename") {
          path.node.name.name = jsxAttributeAction[1];
        }
      }
    },
  }));
}

/**
 * Builds the import declaration for the configured class-merging helper, bound
 * to the local identifier `cx`.
 *
 * @throws for any `clsx` value other than "emotion", "clsx" or "classnames".
 */
function getClsxImport(t: BabelTypes, cx: t.Identifier, clsx: string) {
  switch (clsx) {
    case "emotion":
      return t.importDeclaration(
        [t.importSpecifier(cx, t.identifier("cx"))],
        t.stringLiteral("@emotion/css")
      );
    case "clsx":
      return t.importDeclaration([t.importDefaultSpecifier(cx)], t.stringLiteral("clsx"));
    case "classnames":
      return t.importDeclaration(
        [t.importDefaultSpecifier(cx)],
        t.stringLiteral("classnames")
      );
    default:
      throw new Error("Unknown clsx library");
  }
}

/**
 * Statically evaluates a macro argument into a flat list of class names.
 *
 * - string: whitespace-separated classes;
 * - plain object `{ modifier: "classes" }`: each class becomes `modifier:cls`;
 * - `{ data: { key: "classes" } }` → `data-[key]:cls`, and
 *   `{ data: { key: { value: "classes" } } }` → `data-[key=value]:cls`.
 *
 * @throws if the argument cannot be statically evaluated or has another shape.
 */
function evaluateArgs(path: NodePath): string[] {
  const { confident, value } = path.evaluate();
  invariant(confident, "Argument cannot be statically evaluated");

  if (typeof value === "string") {
    return trim(value);
  }

  if (isPlainObject(value)) {
    return flatMapEntries(value as Record<string, unknown>, (classes, modifier) => {
      if (modifier === "data" && isPlainObject(classes)) {
        return flatMapEntries(classes as Record<string, unknown>, (cls, key) =>
          typeof cls === "string"
            ? trimPrefix(cls, `${modifier}-[${key}]:`)
            : flatMapEntries(cls as Record<string, string>, (attrClasses, attrValue) =>
                trimPrefix(attrClasses, `${modifier}-[${key}=${attrValue}]:`)
              )
        );
      }
      invariant(
        typeof classes === "string",
        `Value for "${modifier}" should be a string`
      );
      return trimPrefix(classes, modifier + ":");
    });
  }

  throw new Error("Invalid argument type");
}

/** Name of an identifier or the value of a string literal; `undefined` otherwise. */
function getName(t: BabelTypes, exp: t.Node) {
  if (t.isIdentifier(exp)) {
    return exp.name;
  } else if (t.isStringLiteral(exp)) {
    return exp.value;
  }
}

/**
 * Collects every reference to the macros imported from `importSources`, then
 * removes those import declarations from the program.
 *
 * Named imports yield each reference to the local binding; namespace imports
 * yield each `ns.name` member expression with `imported` set to `name`.
 * Default import specifiers are ignored.
 */
function getMacros(
  t: BabelTypes,
  programPath: NodePath<t.Program>,
  importSources: string[]
): MacroReference[] {
  const importDecs = programPath
    .get("body")
    .filter(x => isNodePath(x, t.isImportDeclaration))
    .filter(x => importSources.includes(x.node.source.value));

  const macros = importDecs
    .flatMap(x => x.get("specifiers"))
    .flatMap((specifier): MacroReference[] => {
      const local = specifier.get("local");
      if (isNodePath(specifier, t.isImportNamespaceSpecifier)) {
        return local.scope
          .getOwnBinding(local.node.name)!
          .referencePaths.map(p => p.parentPath)
          .filter(p => isNodePath(p, t.isMemberExpression))
          .map(p => ({ local: p, imported: getName(t, p.node.property)! }))
          .filter(p => p.imported);
      }
      if (t.isImportSpecifier(specifier.node)) {
        const imported = getName(t, specifier.node.imported)!;
        return local.scope
          .getOwnBinding(local.node.name)!
          .referencePaths.map(p => ({ local: p, imported }));
      }
      return [];
    });

  // All references were resolved above, so the macro imports can now be dropped.
  for (const x of importDecs) {
    x.remove();
  }

  return macros;
}

/**
 * Walks up the member-expression chain above a macro reference
 * (`tw.hover.focus...`), turning each camelCase property into a kebab-case
 * modifier.
 *
 * @returns `callee`: the first non-member ancestor (the call / tagged template,
 *          if any); `target`: the outermost member expression of the chain
 *          (or the reference itself), i.e. the node `callee` should be invoking;
 *          `prefix`: the collected modifiers joined with ":" (or `undefined`).
 */
function mapMacro(t: BabelTypes, macro: MacroReference) {
  let target: NodePath = macro.local;
  let callee = target.parentPath;
  const prefix: string[] = [];
  while (isNodePath(callee, t.isMemberExpression)) {
    invariant(t.isIdentifier(callee.node.property), "Invalid member expression");
    prefix.unshift(
      callee.node.property.name.replace(/([a-z])([A-Z])/g, "$1-$2").toLowerCase()
    );
    target = callee;
    callee = callee.parentPath;
  }
  return {
    callee,
    target,
    imported: macro.imported,
    prefix: prefix.length ? prefix.join(":") : undefined,
  };
}

/** Whether `pattern` has the rest element `...name` at its top level. */
function hasRestElement(t: BabelTypes, pattern: t.ObjectPattern, name: string) {
  return pattern.properties.some(
    property => t.isRestElement(property) && t.isIdentifier(property.argument, { name })
  );
}

/** Wraps a visitor factory into a Babel plugin factory whose state also carries `T`. */
const definePlugin =
  <T,>(fn: (runtime: typeof b) => b.Visitor<b.PluginPass & T>) =>
  (runtime: typeof b) => {
    const plugin: b.PluginObj<b.PluginPass & T> = {
      visitor: fn(runtime),
    };
    return plugin as b.PluginObj;
  };

/** Unwraps `{expr}` to `expr`; any other JSX attribute value is returned unchanged. */
const extractJSXContainer = (attr: NonNullable<t.JSXAttribute["value"]>): t.Expression =>
  attr.type === "JSXExpressionContainer" ? (attr.expression as t.Expression) : attr;

/**
 * Dispatches `nodePath` to the handler named after its node type in the visitor
 * returned by `fns`; handlers recurse into children through `dig`. Missing
 * nodes and node types without a handler are ignored.
 */
function matchPath(
  nodePath: MaybeNodePath,
  fns: (dig: (nodePath: MaybeNodePath) => void) => b.Visitor
) {
  if (!nodePath.node) return;
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- Visitor values are a union of handler shapes
  const fn = fns(path => matchPath(path, fns))[nodePath.node.type] as any;
  fn?.(nodePath);
}

/** Appends `suffix` to `text` when `suffix` is a non-empty string. */
function addIf(text: string, suffix: string | false) {
  return suffix ? text + suffix : text;
}

/** Type guard: `nodePath` exists, has a node, and that node satisfies `predicate`. */
const isNodePath = <T extends t.Node>(
  nodePath: MaybeNodePath | null,
  predicate: (node: t.Node) => node is T
): nodePath is NodePath<T> => Boolean(nodePath?.node && predicate(nodePath.node));

/**
 * Cache-busting query string for a virtual style import (only when `add`).
 * Class lists are serialized with JSON so that different groupings of the same
 * classes (e.g. [["a","b"],["c"]] vs [["a"],["b","c"]]) hash differently.
 */
function getSuffix(add: boolean | undefined, entries: StyleMapEntry[]) {
  if (!add) return "";
  const cacheKey = hash(JSON.stringify(entries.map(x => x.classNames)));
  return `?${cacheKey}`;
}

/** Splits `value` on whitespace, dropping empty tokens. */
const trim = (value: string) => value.replace(/\s+/g, " ").trim().split(" ").filter(Boolean);

/** `trim(cls)` with `prefix` prepended to every token. */
const trimPrefix = (cls: string, prefix = "") => trim(cls).map(value => prefix + value);

/** `Object.entries(map).flatMap(...)` with typed keys and values. */
const flatMapEntries = <K extends string, V, R>(
  map: Record<K, V>,
  fn: (value: V, key: K) => R[]
): R[] => Object.entries(map).flatMap(([key, value]) => fn(value as V, key as K));