import { basename, dirname, extname, join } from "node:path";
import type b from "@babel/core";
import { type Node, type NodePath, type types as t } from "@babel/core";
import hash from "@emotion/hash";
import { memoize } from "lodash-es";
import invariant from "tiny-invariant";
import { type ResolveTailwindOptions, getClassName } from "../index";
import { type SourceLocation, type StyleMapEntry, utilsName } from "../shared";
import { handleMacro } from "./macro";
import { evaluateArgs, trim } from "./utils";

/** Receives the path of every generated style file together with the entries written to it. */
export type ClassNameCollector = (path: string, entries: StyleMapEntry[]) => void;

type BabelTypes = typeof b.types;
/** `css` entries are emitted to a `.css`/`.module.css` file, `js` entries to a `.tailwindStyle.js` file. */
type Type = "css" | "js";
type Scope = NodePath["scope"];

export type BabelPluginUtils = ReturnType<typeof getUtils>;

/** Plugin state: Babel's pass state with the per-file utils merged in (see `Program.enter`). */
type PluginState = b.PluginPass & BabelPluginUtils;

/** Value-level bindings created by one `import` declaration. */
interface Import {
  source: string;
  specifiers: {
    local: string;
    /** Exported name; `"default"` for default imports. */
    imported: string;
  }[];
}

/**
 * Builds the per-file helper object used by the visitors.
 *
 * The helpers do the following:
 * - record style entries while the file is traversed;
 * - lazily inject the runtime imports that rewritten code needs (class-name joiner,
 *   tslib, `composeClassName`/`classed` from the utils module, CSS-module import);
 * - in `finish`, prepend the style-file imports and publish the entries to
 *   `options.styleMap` and `onCollect`.
 */
function getUtils({
  path,
  state,
  t,
  options,
  onCollect,
}: {
  path: NodePath<t.Program>;
  state: b.PluginPass;
  t: BabelTypes;
  options: ResolveTailwindOptions;
  onCollect: ClassNameCollector | undefined;
}) {
  const {
    styleMap,
    clsx,
    getClassName: getClass = getClassName,
    vite: bustCache,
    cssModules,
  } = options;

  let cx: t.Identifier | undefined;
  // Source split into lines once per file; filled on the first `sliceText` call.
  let codeLines: string[] | undefined;
  const cssMap = new Map<string, StyleMapEntry>();
  const jsMap = new Map<string, StyleMapEntry>();

  // Collect runtime (value) bindings only. Without the TypeScript/Flow parser plugins
  // Babel leaves `importKind` unset, so an absent kind must count as a value import.
  const imports: Import[] = path.node.body
    .filter((node): node is t.ImportDeclaration => t.isImportDeclaration(node))
    .filter(decl => !isTypeOnlyKind(decl.importKind))
    .map(decl => ({
      source: decl.source.value,
      specifiers: decl.specifiers.flatMap(spec => {
        if (t.isImportDefaultSpecifier(spec)) {
          return [{ local: spec.local.name, imported: "default" }];
        }
        if (t.isImportSpecifier(spec) && !isTypeOnlyKind(spec.importKind)) {
          return [
            {
              local: spec.local.name,
              imported: t.isStringLiteral(spec.imported)
                ? spec.imported.value
                : spec.imported.name,
            },
          ];
        }
        return [];
      }),
    }));

  /** All value specifiers imported from `source`, across every declaration of it. */
  const specifiersFrom = (source: string) =>
    imports.filter(i => i.source === source).flatMap(i => i.specifiers);

  // Local name of an already-imported class-name joiner that can be reused.
  let existingCx: string | undefined;
  switch (clsx) {
    case "emotion":
      existingCx = specifiersFrom("@emotion/css").find(s => s.imported === "cx")?.local;
      break;
    case "clsx":
      // The clsx package exports the same function as its default and as named `clsx`.
      existingCx = specifiersFrom("clsx").find(
        s => s.imported === "clsx" || s.imported === "default"
      )?.local;
      break;
    case "classnames":
      // classnames exports a single function; a default specifier is always listed first.
      existingCx = specifiersFrom("classnames")[0]?.local;
      break;
  }

  /** Returns the existing joiner identifier if `scope` sees the same (unshadowed) binding. */
  function reuseImport(scope: Scope) {
    if (
      existingCx &&
      scope.getBinding(existingCx) === path.scope.getBinding(existingCx)
    ) {
      return t.identifier(existingCx);
    }
  }

  /** Lazily creates a node once, then hands out fresh clones so no node is shared in the AST. */
  function cacheNode<N extends t.Node>(fn: () => N) {
    let cache: N | undefined;
    return Object.assign(
      (): N => {
        cache ??= fn();
        return t.cloneNode(cache);
      },
      {
        getCache() {
          return cache;
        },
      }
    );
  }

  const getStyleImport = cacheNode(() => path.scope.generateUidIdentifier("styles"));
  const getCssModuleImport = cacheNode(() =>
    path.scope.generateUidIdentifier("cssModule")
  );
  const getUtilsImport = memoize(() => {
    const importDecl = t.importDeclaration([], t.stringLiteral(utilsName));
    path.node.body.unshift(importDecl);
    return importDecl;
  });

  return {
    program: path,
    existingCx,
    getClass(type: Type, value: string) {
      return type === "css" ? getClass(value) : "tw_" + hash(value);
    },
    sliceText: (node: t.Node): SourceLocation => {
      const { start, end } = node.loc!;
      codeLines ??= state.file.code.split("\n");
      return {
        filename: state.filename!,
        start,
        end,
        text: codeLines.slice(start.line - 1, end.line).join("\n"),
      };
    },
    /** Stores `entry` under its key unless an entry with that key was already recorded. */
    recordIfAbsent(type: Type, entry: StyleMapEntry) {
      const map = type === "css" ? cssMap : jsMap;
      if (!map.has(entry.key)) {
        map.set(entry.key, entry);
      }
    },
    replaceWithImport({
      type,
      path,
      className,
    }: {
      type: Type;
      path: NodePath;
      className: string;
    }) {
      if (type === "css") {
        path.replaceWith(t.stringLiteral(className));
      } else {
        const styleImportId = getStyleImport();
        path.replaceWith(
          t.memberExpression(styleImportId, t.stringLiteral(className), true)
        );
      }
    },
    getCx: (localScope: Scope) => {
      if (cx == null) {
        const reuse = reuseImport(localScope);
        if (reuse) return reuse;

        cx = path.scope.generateUidIdentifier("cx");
        // If you unshift, react-refresh/babel will insert _s(Component) right above
        // the component declaration, which is invalid.
        path.node.body.push(getClsxImport(t, cx, clsx));
      }
      return t.cloneNode(cx);
    },
    getTSlibImport: cacheNode(() => {
      const tslibImport = path.scope.generateUidIdentifier("tslib");
      path.node.body.push(
        t.importDeclaration(
          [t.importNamespaceSpecifier(tslibImport)],
          t.stringLiteral("tslib")
        )
      );
      return tslibImport;
    }),
    getClsCompose: cacheNode(() => {
      const clsComposeImport = path.scope.generateUidIdentifier("composeClassName");
      getUtilsImport().specifiers.push(
        t.importSpecifier(clsComposeImport, t.identifier("composeClassName"))
      );
      return clsComposeImport;
    }),
    getClassedImport: cacheNode(() => {
      const classedImport = path.scope.generateUidIdentifier("classed");
      getUtilsImport().specifiers.push(
        t.importSpecifier(classedImport, t.identifier("classed"))
      );
      return classedImport;
    }),
    getCssModuleImport,
    /** Expression for a generated class: `_cssModule.name` in CSS-modules mode, else a string. */
    getClassNameValue: (className: string) => {
      const validId = t.isValidIdentifier(className);
      return cssModules
        ? t.memberExpression(
            getCssModuleImport(),
            validId ? t.identifier(className) : t.stringLiteral(className),
            !validId
          )
        : t.stringLiteral(className);
    },
    /**
     * Prepends the imports of the generated style files and publishes their entries.
     * Does nothing when no entries were recorded.
     */
    finish(node: t.Program) {
      const { filename } = state;
      if (!cssMap.size && !jsMap.size) return;
      invariant(filename, "babel: missing state.filename");

      const stem = basename(filename, extname(filename));
      const register = (
        outName: string,
        entries: StyleMapEntry[],
        specifiers: (t.ImportDefaultSpecifier | t.ImportNamespaceSpecifier)[]
      ) => {
        const outPath = join(dirname(filename), outName);
        const importee = `tailwind:./${outName}` + getSuffix(bustCache, entries);
        node.body.unshift(t.importDeclaration(specifiers, t.stringLiteral(importee)));
        styleMap.set(outPath, entries);
        onCollect?.(outPath, entries);
      };

      const cssModuleImport = getCssModuleImport.getCache();
      if (cssMap.size) {
        // A `.module.css` file is only produced when some class went through the CSS-module import.
        const cssName = stem + (cssModuleImport ? ".module" : "") + ".css";
        register(
          cssName,
          Array.from(cssMap.values()),
          cssModuleImport ? [t.importDefaultSpecifier(cssModuleImport)] : []
        );
      }

      if (jsMap.size) {
        register(stem + ".tailwindStyle.js", Array.from(jsMap.values()), [
          t.importNamespaceSpecifier(getStyleImport()),
        ]);
      }
    },
  };
}

/**
 * Babel plugin that compiles the `css` JSX attribute (name configurable through
 * `jsxAttributeName`) into generated class names. The generated names are merged into
 * `className`, and the attribute is deleted, kept, or renamed per `jsxAttributeAction`.
 *
 * @param options resolved tailwind options; the plugin also writes collected entries to
 *   `options.styleMap`.
 * @param onCollect optional callback invoked for every generated style file.
 */
export function babelTailwind(
  options: ResolveTailwindOptions,
  onCollect: ClassNameCollector | undefined
) {
  const {
    getClassName: getClass = getClassName,
    jsxAttributeAction = "delete",
    jsxAttributeName = "css",
    composeRenderProps,
  } = options;

  return definePlugin(({ types: t }) => ({
    Program: {
      enter(path, state) {
        const _ = getUtils({ path, state, t, options, onCollect });
        Object.assign(state, _);

        handleMacro({ t, path, _ });
      },
      exit({ node }, _) {
        _.finish(node);
      },
    },
    JSXAttribute(path, _) {
      const { name } = path.node;
      if (name.name !== jsxAttributeName) return;

      const valuePath = path.get("value");
      if (!valuePath.node) return;

      // Snapshot taken before rewriting; restored onto the attribute when it is kept.
      const copy =
        jsxAttributeAction === "delete" ? undefined : t.cloneNode(valuePath.node, true);

      const parent = path.parent as t.JSXOpeningElement;
      const classNameAttribute = parent.attributes.find(
        (attr): attr is t.JSXAttribute =>
          t.isJSXAttribute(attr) && attr.name.name === "className"
      );

      // Replace every static class list reachable in the value with its generated class.
      matchPath(valuePath, go => ({
        StringLiteral(path) {
          const { node } = path;
          const { value } = node;
          const trimmed = trim(value);
          if (trimmed.length) {
            const className = getClass(trimmed.join(" "));
            _.recordIfAbsent("css", {
              key: className,
              classNames: trimmed,
              location: _.sliceText(node),
            });
            path.replaceWith(_.getClassNameValue(className));
          }
        },
        ArrayExpression(path) {
          for (const element of path.get("elements")) {
            go(element);
          }
        },
        ObjectExpression(path) {
          const trimmed = evaluateArgs(path);
          const className = getClass(trimmed.join(" "));
          _.recordIfAbsent("css", {
            key: className,
            classNames: trimmed,
            location: _.sliceText(path.node),
          });
          path.replaceWith(_.getClassNameValue(className));
        },
        JSXExpressionContainer(path) {
          go(path.get("expression"));
        },
        ConditionalExpression(path) {
          go(path.get("consequent"));
          go(path.get("alternate"));
        },
        LogicalExpression(path) {
          go(path.get("right"));
        },
        CallExpression(path) {
          for (const arg of path.get("arguments")) {
            go(arg);
          }
        },
      }));

      const {
        identifier: id,
        jsxExpressionContainer: jsxBox,
        jsxIdentifier: jsxId,
        callExpression: call,
      } = t;

      let valuePathNode = extractJSXContainer(valuePath.node);

      // An array of plain strings collapses into one space-joined string literal.
      if (
        t.isArrayExpression(valuePathNode) &&
        valuePathNode.elements.every(node => t.isStringLiteral(node))
      ) {
        valuePathNode = t.stringLiteral(
          // eslint-disable-next-line @typescript-eslint/no-unnecessary-type-assertion
          (valuePathNode.elements as t.StringLiteral[]).map(node => node.value).join(" ")
        );
      }

      // Combines the generated value with an existing class expression.
      const wrap = (existing: b.types.Expression) =>
        composeRenderProps
          ? call(_.getClsCompose(), [valuePathNode, existing])
          : call(_.getCx(path.scope), [valuePathNode, existing]);

      // There is an existing className attribute
      if (classNameAttribute) {
        const attrValue = classNameAttribute.value;
        if (attrValue == null) {
          throw path.buildCodeFrameError(
            `A \`className\` attribute without a value cannot be combined with \`${jsxAttributeName}\``
          );
        }

        // If both are string literals, we can merge them directly here
        if (t.isStringLiteral(attrValue) && t.isStringLiteral(valuePathNode)) {
          const current = attrValue.value;
          const incoming = valuePathNode.value;
          // Skip the separator when either side is empty or the current value already ends with one.
          const separator =
            current === "" || incoming === "" || current.endsWith(" ") ? "" : " ";
          attrValue.value = current + separator + incoming;
        } else {
          const internal = extractJSXContainer(attrValue);
          if (
            t.isArrowFunctionExpression(internal) &&
            !t.isBlockStatement(internal.body)
          ) {
            // className={({ isEntering }) => isEntering ? "enter" : "exit"}
            // className: ({ isEntering }) => _cx("${clsName}", isEntering ? "enter" : "exit")
            internal.body = wrap(internal.body);
          } else if (
            // if the existing className is already wrapped with cx, we unwrap it
            // to avoid double calling: cx(cx())
            t.isCallExpression(internal) &&
            t.isIdentifier(internal.callee) &&
            _.existingCx &&
            _.program.scope
              .getBinding(_.existingCx)
              ?.referencePaths.map(p => p.node)
              .includes(internal.callee)
          ) {
            classNameAttribute.value = jsxBox(
              call(_.getCx(path.scope), [
                valuePathNode,
                ...(internal.arguments as (b.types.Expression | b.types.SpreadElement)[]),
              ])
            );
          } else {
            classNameAttribute.value = jsxBox(wrap(internal));
          }
        }
      } else {
        const rest = parent.attributes.filter(attr => t.isJSXSpreadAttribute(attr));
        let arg;
        // if there is only one JSX spread attribute and it's an identifier
        // ... {...props} />
        if (rest.length === 1 && (arg = rest[0].argument) && t.isIdentifier(arg)) {
          // props from argument and not modified anywhere, get the declaration of this argument
          const binding = path.scope.getBinding(arg.name);
          let index: number;
          // node is an identifier or object pattern in `params`
          // (props) => ... or ({ ...props }) => ...
          const node = binding?.path.node;
          if (
            binding &&
            !binding.constantViolations.length &&
            (t.isFunctionDeclaration(binding.path.parent) ||
              t.isArrowFunctionExpression(binding.path.parent)) &&
            (index = (binding.path.parent.params as t.Node[]).indexOf(node!)) !== -1 &&
            (t.isIdentifier(node) || t.isObjectPattern(node))
          ) {
            const clsVar = path.scope.generateUidIdentifier("className");
            if (t.isIdentifier(node)) {
              // (props) => ...
              // ↪ ({ className, ...props }) => ...
              binding.path.parent.params[index] = t.objectPattern([
                t.objectProperty(id("className"), clsVar),
                t.restElement(node),
              ]);
            } else {
              // ({ ...props }) => ...
              // ↪ ({ className, ...props }) => ...
              node.properties.unshift(t.objectProperty(id("className"), clsVar));
            }

            // clsVar already sits in the pattern; use a clone so no node is shared.
            parent.attributes.push(
              t.jsxAttribute(jsxId("className"), jsxBox(wrap(t.cloneNode(clsVar))))
            );
          } else {
            // {...props} ↪ {...__rest(props, ["className"])} className={cx(..., props.className)}
            const tslibImport = _.getTSlibImport();
            rest[0].argument = call(t.memberExpression(tslibImport, id("__rest")), [
              arg,
              t.arrayExpression([t.stringLiteral("className")]),
            ]);
            parent.attributes.push(
              t.jsxAttribute(
                jsxId("className"),
                jsxBox(wrap(t.memberExpression(t.cloneNode(arg), id("className"))))
              )
            );
          }
        } else {
          // Fallback
          const containerValue = t.isStringLiteral(valuePathNode)
            ? valuePathNode
            : call(_.getCx(path.scope), [valuePathNode]);

          parent.attributes.push(
            t.jsxAttribute(
              t.jsxIdentifier("className"),
              t.jsxExpressionContainer(containerValue)
            )
          );
        }
      }

      if (jsxAttributeAction === "delete") {
        path.remove();
      } else {
        path.node.value = copy!;
        if (Array.isArray(jsxAttributeAction) && jsxAttributeAction[0] === "rename") {
          // eslint-disable-next-line unicorn/consistent-destructuring
          path.node.name.name = jsxAttributeAction[1];
        }
      }
    },
  }));
}

/** Builds the import declaration that binds `cx` to the configured class-name joiner. */
function getClsxImport(t: BabelTypes, cx: t.Identifier, clsx: string) {
  switch (clsx) {
    case "emotion":
      return t.importDeclaration(
        [t.importSpecifier(cx, t.identifier("cx"))],
        t.stringLiteral("@emotion/css")
      );
    case "clsx":
      return t.importDeclaration(
        [t.importSpecifier(cx, t.identifier("clsx"))],
        t.stringLiteral("clsx")
      );
    case "classnames":
      return t.importDeclaration(
        [t.importDefaultSpecifier(cx)],
        t.stringLiteral("classnames")
      );
    default:
      throw new Error("Unknown clsx library");
  }
}

/** Wraps a visitor factory into a Babel plugin whose state carries the per-file utils. */
const definePlugin =
  (fn: (runtime: typeof b) => b.Visitor<PluginState>) => (runtime: typeof b) => {
    const plugin: b.PluginObj<PluginState> = {
      visitor: fn(runtime),
    };
    return plugin as b.PluginObj;
  };

/** Unwraps `{expr}` containers; other attribute values are returned unchanged. */
const extractJSXContainer = (attr: NonNullable<t.JSXAttribute["value"]>): t.Expression =>
  attr.type === "JSXExpressionContainer" ? (attr.expression as t.Expression) : attr;

/**
 * Dispatches `nodePath` to the handler named after its node type, if any.
 * Handlers receive `dig` to recurse into child paths with the same handler set.
 */
function matchPath(
  nodePath: NodePath<t.Node | null | undefined>,
  fns: (dig: (nodePath: NodePath<t.Node | null | undefined>) => void) => b.Visitor
) {
  if (!nodePath.node) return;
  const fn = fns(path => matchPath(path, fns))[nodePath.node.type] as any;
  fn?.(nodePath);
}

/**
 * Cache-busting query suffix for a generated style import (empty unless `add`).
 * Each entry's key is hashed together with its class list. Regrouping the same classes
 * into different entries, which yields different generated CSS, therefore changes the suffix.
 */
function getSuffix(add: boolean | undefined, entries: StyleMapEntry[]) {
  if (!add) return "";
  const cacheKey = hash(JSON.stringify(entries.map(entry => [entry.key, entry.classNames])));
  return `?${cacheKey}`;
}

/** True for `import type` / `import typeof` kinds, which create no runtime binding. */
function isTypeOnlyKind(kind: string | null | undefined): boolean {
  return kind === "type" || kind === "typeof";
}