// 2025-06-13 01:08:36 -04:00
//
// 545 lines
// 17 KiB
// TypeScript

import { basename, dirname, extname, join } from "node:path";
import type b from "@babel/core";
import { type Node, type NodePath, type types as t } from "@babel/core";
import hash from "@emotion/hash";
import { memoize } from "lodash-es";
import invariant from "tiny-invariant";
import { type ResolveTailwindOptions, getClassName } from "../index";
import { type SourceLocation, type StyleMapEntry, utilsName } from "../shared";
import { handleMacro } from "./macro";
import { evaluateArgs, trim } from "./utils";
/** Receives each generated style file path together with the entries written to it. */
export type ClassNameCollector = (path: string, entries: StyleMapEntry[]) => void;
/** The `types` namespace of the Babel runtime handed to the plugin. */
type BabelTypes = typeof b.types;
/** Destination of a collected style: the generated stylesheet ("css") or the generated JS style module ("js"). */
type Type = "js" | "css";
/** A Babel scope, as exposed on `NodePath#scope`. */
type Scope = NodePath["scope"];
/** The per-file helper bag built by `getUtils` and merged into the plugin state. */
export type BabelPluginUtils = ReturnType<typeof getUtils>;
/** A single binding introduced by an import declaration. */
interface ImportedBinding {
  /** Name of the binding in this file. */
  local: string;
  /** Name of the export in the source module. */
  imported: string;
}
/** Summary of one import declaration found in the program body. */
interface Import {
  source: string;
  specifiers: ImportedBinding[];
}
/**
 * Builds the per-file helper bag shared by the Babel visitors (merged into the
 * plugin state as `BabelPluginUtils`).
 *
 * It collects style entries found while visiting the file. It lazily creates,
 * and caches, the identifiers and imports the transformed code needs:
 * - the class-name combiner (`cx`), reusing an existing import when possible;
 * - `tslib`;
 * - the `composeClassName`/`classed` helpers from the utils module;
 * - the generated JS-style namespace and the CSS-module default import.
 *
 * In `finish`, it prepends imports for the generated files and reports them
 * through `styleMap` and `onCollect`.
 */
function getUtils({
  path,
  state,
  t,
  options,
  onCollect,
}: {
  path: NodePath<t.Program>;
  state: b.PluginPass;
  t: BabelTypes;
  options: ResolveTailwindOptions;
  onCollect: ClassNameCollector | undefined;
}) {
  const {
    styleMap,
    clsx,
    getClassName: getClass = getClassName,
    vite: bustCache,
    cssModules,
  } = options;
  // Local name of the combiner import injected by getCx (created lazily).
  let cx: t.Identifier;
  // Style entries collected from this file, keyed by generated class name.
  const cssMap = new Map<string, StyleMapEntry>();
  const jsMap = new Map<string, StyleMapEntry>();
  // The file's source split into lines, computed once on first use by sliceText.
  let codeLines: string[] | undefined;

  /**
   * Value bindings introduced by one import declaration. Default imports are
   * recorded as `imported: "default"`; namespace and type-only specifiers are
   * skipped.
   */
  function collectSpecifiers(decl: t.ImportDeclaration): Import["specifiers"] {
    const bindings: Import["specifiers"] = [];
    for (const spec of decl.specifiers) {
      if (t.isImportDefaultSpecifier(spec)) {
        bindings.push({ local: spec.local.name, imported: "default" });
      } else if (
        t.isImportSpecifier(spec) &&
        // importKind may be unset (e.g. files parsed without the TS/Flow
        // plugins), so exclude type-only specifiers instead of requiring "value".
        spec.importKind !== "type" &&
        spec.importKind !== "typeof"
      ) {
        bindings.push({
          local: spec.local.name,
          imported: t.isStringLiteral(spec.imported)
            ? spec.imported.value
            : spec.imported.name,
        });
      }
    }
    return bindings;
  }

  // Value imports already present at the top level of the file
  // (`import type` declarations are ignored).
  const imports: Import[] = path.node.body
    .filter(
      (node): node is t.ImportDeclaration =>
        t.isImportDeclaration(node) &&
        node.importKind !== "type" &&
        node.importKind !== "typeof"
    )
    .map(decl => ({ source: decl.source.value, specifiers: collectSpecifiers(decl) }));

  /** Local name of the first import of any of `names` from `source`, across all declarations. */
  function findImport(source: string, names: readonly string[]): string | undefined {
    for (const { source: from, specifiers } of imports) {
      if (from !== source) continue;
      const match = specifiers.find(s => names.includes(s.imported));
      if (match) return match.local;
    }
    return undefined;
  }

  // An import of the configured combiner that getCx can reuse instead of injecting one.
  let existingCx: string | undefined;
  switch (clsx) {
    case "emotion":
      existingCx = findImport("@emotion/css", ["cx"]);
      break;
    case "clsx":
      // clsx exposes the same function as both its default export and `clsx`.
      existingCx = findImport("clsx", ["clsx", "default"]);
      break;
    case "classnames":
      // classnames is default-imported (matching getClsxImport).
      existingCx = findImport("classnames", ["default"]);
      break;
  }

  /** Identifier for the existing combiner if `scope` resolves it to the program-level binding (not shadowed). */
  function reuseImport(scope: Scope) {
    if (
      existingCx &&
      scope.getBinding(existingCx) === path.scope.getBinding(existingCx)
    ) {
      return t.identifier(existingCx);
    }
  }

  /**
   * Runs `fn` at most once. Every call returns a fresh clone of the cached node.
   * `getCache()` exposes the cached original, or `undefined` if it was never created.
   */
  function cacheNode<N extends Node>(fn: () => N) {
    let cache: N | undefined;
    return Object.assign(
      (): N => {
        cache ??= fn();
        return t.cloneNode(cache);
      },
      {
        getCache() {
          return cache;
        },
      }
    );
  }
  const getStyleImport = cacheNode(() => path.scope.generateUidIdentifier("styles"));
  const getCssModuleImport = cacheNode(() =>
    path.scope.generateUidIdentifier("cssModule")
  );
  // Single `import {} from utilsName` at the top of the file; helpers add specifiers to it.
  const getUtilsImport = memoize(() => {
    const importDecl = t.importDeclaration([], t.stringLiteral(utilsName));
    path.node.body.unshift(importDecl);
    return importDecl;
  });

  return {
    program: path,
    existingCx,
    getClass(type: Type, value: string) {
      return type === "css" ? getClass(value) : "tw_" + hash(value);
    },
    sliceText: (node: t.Node): SourceLocation => {
      codeLines ??= state.file.code.split("\n");
      return {
        filename: state.filename!,
        start: node.loc!.start,
        end: node.loc!.end,
        text: codeLines.slice(node.loc!.start.line - 1, node.loc!.end.line).join("\n"),
      };
    },
    recordIfAbsent(type: Type, entry: StyleMapEntry) {
      const map = type === "css" ? cssMap : jsMap;
      if (!map.has(entry.key)) {
        map.set(entry.key, entry);
      }
    },
    replaceWithImport({
      type,
      path,
      className,
    }: {
      type: Type;
      path: NodePath;
      className: string;
    }) {
      if (type === "css") {
        path.replaceWith(t.stringLiteral(className));
      } else {
        const styleImportId = getStyleImport();
        path.replaceWith(
          t.memberExpression(styleImportId, t.stringLiteral(className), true)
        );
      }
    },
    getCx: (localScope: Scope) => {
      if (cx == null) {
        const reuse = reuseImport(localScope);
        if (reuse) return reuse;
        cx = path.scope.generateUidIdentifier("cx");
        // If you unshift, react-refresh/babel will insert _s(Component) right above
        // the component declaration, which is invalid.
        path.node.body.push(getClsxImport(t, cx, clsx));
      }
      return t.cloneNode(cx);
    },
    getTSlibImport: cacheNode(() => {
      const tslibImport = path.scope.generateUidIdentifier("tslib");
      path.node.body.push(
        t.importDeclaration(
          [t.importNamespaceSpecifier(tslibImport)],
          t.stringLiteral("tslib")
        )
      );
      return tslibImport;
    }),
    getClsCompose: cacheNode(() => {
      const clsComposeImport = path.scope.generateUidIdentifier("composeClassName");
      getUtilsImport().specifiers.push(
        t.importSpecifier(clsComposeImport, t.identifier("composeClassName"))
      );
      return clsComposeImport;
    }),
    getClassedImport: cacheNode(() => {
      const classedImport = path.scope.generateUidIdentifier("classed");
      getUtilsImport().specifiers.push(
        t.importSpecifier(classedImport, t.identifier("classed"))
      );
      return classedImport;
    }),
    getCssModuleImport,
    getClassNameValue: (className: string) => {
      const validId = t.isValidIdentifier(className);
      return cssModules
        ? t.memberExpression(
            getCssModuleImport(),
            validId ? t.identifier(className) : t.stringLiteral(className),
            !validId
          )
        : t.stringLiteral(className);
    },
    /**
     * Prepends imports for the generated stylesheet (`<name>[.module].css`) and/or
     * JS style module (`<name>.tailwindStyle.js`) next to the source file, and
     * reports each generated path with its entries via `styleMap` and `onCollect`.
     */
    finish(node: t.Program) {
      const { filename } = state;
      if (!cssMap.size && !jsMap.size) return;
      invariant(filename, "babel: missing state.filename");
      const stem = basename(filename, extname(filename));
      const dir = dirname(filename);
      const cssModuleImport = getCssModuleImport.getCache();
      if (cssMap.size) {
        const cssName = stem + (cssModuleImport ? ".module" : "") + ".css";
        const cssPath = join(dir, cssName);
        const value = Array.from(cssMap.values());
        const importee = `tailwind:./${cssName}` + getSuffix(bustCache, value);
        // CSS modules bind the default export; plain CSS is a side-effect import.
        node.body.unshift(
          t.importDeclaration(
            cssModuleImport ? [t.importDefaultSpecifier(cssModuleImport)] : [],
            t.stringLiteral(importee)
          )
        );
        styleMap.set(cssPath, value);
        onCollect?.(cssPath, value);
      }
      if (jsMap.size) {
        const jsName = stem + ".tailwindStyle.js";
        const jsPath = join(dir, jsName);
        const value = Array.from(jsMap.values());
        const importee = `tailwind:./${jsName}` + getSuffix(bustCache, value);
        node.body.unshift(
          t.importDeclaration(
            [t.importNamespaceSpecifier(getStyleImport())],
            t.stringLiteral(importee)
          )
        );
        styleMap.set(jsPath, value);
        onCollect?.(jsPath, value);
      }
    },
  };
}
/**
 * Creates the Babel plugin that compiles Tailwind classes written in a JSX
 * attribute (`options.jsxAttributeName`, default `css`) into generated class
 * names and merges them into the element's `className`.
 *
 * `Program.enter` installs the `getUtils` helpers on the plugin state and runs
 * the macro handler. `Program.exit` emits the imports for the generated style
 * files.
 *
 * For each matching attribute:
 * - String literals and object expressions in the value are recorded and
 *   replaced with generated class names. Replacement recurses through arrays,
 *   expression containers, both branches of conditionals, the right side of
 *   logical expressions and call arguments.
 * - The result is merged into one of: an existing `className`; a `className`
 *   pulled from a lone `{...props}` spread; or a newly added `className`.
 * - The original attribute is deleted, kept, or renamed according to
 *   `options.jsxAttributeAction` (default "delete").
 *
 * @param options resolved plugin options
 * @param onCollect optional callback forwarded to `getUtils`, which calls it
 *   with each generated style file path and its entries
 */
export function babelTailwind(
  options: ResolveTailwindOptions,
  onCollect: ClassNameCollector | undefined
) {
  const {
    getClassName: getClass = getClassName,
    jsxAttributeAction = "delete",
    jsxAttributeName = "css",
    composeRenderProps,
  } = options;
  return definePlugin<BabelPluginUtils>(({ types: t }) => ({
    Program: {
      enter(path, state) {
        const _ = getUtils({ path, state, t, options, onCollect });
        Object.assign(state, _);
        handleMacro({ t, path, _ });
      },
      exit({ node }, _) {
        _.finish(node);
      },
    },
    JSXAttribute(path, _) {
      const { name } = path.node;
      if (name.name !== jsxAttributeName) return;
      const valuePath = path.get("value");
      if (!valuePath.node) return;
      // Untouched copy of the value, restored when the attribute is kept/renamed.
      const copy =
        jsxAttributeAction === "delete" ? undefined : t.cloneNode(valuePath.node, true);
      const parent = path.parent as t.JSXOpeningElement;
      const classNameAttribute = parent.attributes.find(
        (attr): attr is t.JSXAttribute =>
          t.isJSXAttribute(attr) && attr.name.name === "className"
      );
      // Compile static class strings/objects in place, recursing through the value.
      matchPath(valuePath, go => ({
        StringLiteral(path) {
          const { node } = path;
          const { value } = node;
          const trimmed = trim(value);
          if (trimmed.length) {
            const className = getClass(trimmed.join(" "));
            _.recordIfAbsent("css", {
              key: className,
              classNames: trimmed,
              location: _.sliceText(node),
            });
            path.replaceWith(_.getClassNameValue(className));
          }
        },
        ArrayExpression(path) {
          for (const element of path.get("elements")) {
            go(element);
          }
        },
        ObjectExpression(path) {
          const trimmed = evaluateArgs(path);
          const className = getClass(trimmed.join(" "));
          _.recordIfAbsent("css", {
            key: className,
            classNames: trimmed,
            location: _.sliceText(path.node),
          });
          path.replaceWith(_.getClassNameValue(className));
        },
        JSXExpressionContainer(path) {
          go(path.get("expression"));
        },
        ConditionalExpression(path) {
          go(path.get("consequent"));
          go(path.get("alternate"));
        },
        LogicalExpression(path) {
          go(path.get("right"));
        },
        CallExpression(path) {
          for (const arg of path.get("arguments")) {
            go(arg);
          }
        },
      }));
      const {
        identifier: id,
        jsxExpressionContainer: jsxBox,
        jsxIdentifier: jsxId,
        callExpression: call,
      } = t;
      let valuePathNode = extractJSXContainer(valuePath.node);
      // ["a", "b"] → "a b" so it can be merged as a plain string.
      if (
        t.isArrayExpression(valuePathNode) &&
        valuePathNode.elements.every(node => t.isStringLiteral(node))
      ) {
        valuePathNode = t.stringLiteral(
          // eslint-disable-next-line @typescript-eslint/no-unnecessary-type-assertion
          (valuePathNode.elements as t.StringLiteral[]).map(node => node.value).join(" ")
        );
      }
      // Combine the compiled value with another className expression.
      const wrap = (existing: b.types.Expression) =>
        composeRenderProps
          ? call(_.getClsCompose(), [valuePathNode, existing])
          : call(_.getCx(path.scope), [valuePathNode, existing]);
      // There is an existing className attribute
      if (classNameAttribute) {
        const attrValue = classNameAttribute.value!;
        // If both are string literals, we can merge them directly here
        if (t.isStringLiteral(attrValue) && t.isStringLiteral(valuePathNode)) {
          attrValue.value +=
            (attrValue.value.at(-1) === " " ? "" : " ") + valuePathNode.value;
        } else {
          const internal = extractJSXContainer(attrValue);
          if (
            t.isArrowFunctionExpression(internal) &&
            !t.isBlockStatement(internal.body)
          ) {
            // className={({ isEntering }) => isEntering ? "enter" : "exit"}
            // className: ({ isEntering }) => _cx("${clsName}", isEntering ? "enter" : "exit")
            internal.body = wrap(internal.body);
          } else if (
            // if the existing className is already wrapped with cx, we unwrap it
            // to avoid double calling: cx(cx())
            t.isCallExpression(internal) &&
            t.isIdentifier(internal.callee) &&
            _.existingCx &&
            _.program.scope
              .getBinding(_.existingCx)
              ?.referencePaths.map(p => p.node)
              .includes(internal.callee)
          ) {
            classNameAttribute.value = jsxBox(
              call(_.getCx(path.scope), [
                valuePathNode,
                ...(internal.arguments as (b.types.Expression | b.types.SpreadElement)[]),
              ])
            );
          } else {
            classNameAttribute.value = jsxBox(wrap(internal));
          }
        }
      } else {
        const rest = parent.attributes.filter(attr => t.isJSXSpreadAttribute(attr));
        let arg;
        // if there is only one JSX spread attribute and it's an identifier
        // ... {...props} />
        if (rest.length === 1 && (arg = rest[0].argument) && t.isIdentifier(arg)) {
          const propsName = arg.name;
          // props from argument and not modified anywhere, get the declaration of this argument
          const binding = path.scope.getBinding(propsName);
          // param is an identifier or object pattern in `params`:
          // (props) => ... or ({ ...props }) => ...
          const fn = binding?.path.parent;
          const param = binding?.path.node;
          let index = -1;
          if (
            binding &&
            !binding.constantViolations.length &&
            (t.isFunctionDeclaration(fn) ||
              t.isFunctionExpression(fn) ||
              t.isArrowFunctionExpression(fn)) &&
            (index = (fn.params as t.Node[]).indexOf(param!)) !== -1 &&
            // `props?` cannot become a rest element, so leave it to the fallback.
            ((t.isIdentifier(param) && !param.optional) ||
              // Require `...props`. With a plain `{ props }` property, the added
              // `className` would be read from the outer object, not from `props`.
              (t.isObjectPattern(param) &&
                param.properties.some(
                  prop =>
                    t.isRestElement(prop) &&
                    t.isIdentifier(prop.argument, { name: propsName })
                )))
          ) {
            const clsVar = path.scope.generateUidIdentifier("className");
            if (t.isIdentifier(param)) {
              // (props) => ...
              // ↪ ({ className, ...props }) => ...
              const pattern = t.objectPattern([
                t.objectProperty(id("className"), clsVar),
                t.restElement(param),
              ]);
              // (props: Props) → ({ className, ...props }: Props): a rest element
              // inside an object pattern cannot carry its own type annotation.
              if (param.typeAnnotation) {
                pattern.typeAnnotation = param.typeAnnotation;
                param.typeAnnotation = null;
              }
              fn.params[index] = pattern;
            } else {
              // ({ ...props }) => ...
              // ↪ ({ className, ...props }) => ...
              param.properties.unshift(t.objectProperty(id("className"), clsVar));
            }
            parent.attributes.push(
              t.jsxAttribute(jsxId("className"), jsxBox(wrap(clsVar)))
            );
          } else {
            // {...props} → {...__rest(props, ["className"])} and merge props.className.
            const tslibImport = _.getTSlibImport();
            rest[0].argument = call(t.memberExpression(tslibImport, id("__rest")), [
              arg,
              t.arrayExpression([t.stringLiteral("className")]),
            ]);
            parent.attributes.push(
              t.jsxAttribute(
                jsxId("className"),
                jsxBox(wrap(t.memberExpression(arg, id("className"))))
              )
            );
          }
        } else {
          // Fallback
          const containerValue = t.isStringLiteral(valuePathNode)
            ? valuePathNode
            : call(_.getCx(path.scope), [valuePathNode]);
          parent.attributes.push(
            t.jsxAttribute(
              t.jsxIdentifier("className"),
              t.jsxExpressionContainer(containerValue)
            )
          );
        }
      }
      if (jsxAttributeAction === "delete") {
        path.remove();
      } else {
        path.node.value = copy!;
        if (Array.isArray(jsxAttributeAction) && jsxAttributeAction[0] === "rename") {
          // eslint-disable-next-line unicorn/consistent-destructuring
          path.node.name.name = jsxAttributeAction[1];
        }
      }
    },
  }));
}
/**
 * Builds the import declaration that binds the configured class-name combiner
 * to the local identifier `cx`.
 *
 * @param t Babel types namespace used to build the nodes
 * @param cx local identifier the combiner is bound to
 * @param clsx library to import from: "emotion", "clsx" or "classnames"
 * @returns the `ImportDeclaration` node
 * @throws Error("Unknown clsx library") for any other value
 */
function getClsxImport(t: BabelTypes, cx: t.Identifier, clsx: string) {
  // classnames is imported through its default export.
  if (clsx === "classnames") {
    return t.importDeclaration(
      [t.importDefaultSpecifier(cx)],
      t.stringLiteral("classnames")
    );
  }
  // The remaining libraries use a named export: [export name, module specifier].
  const namedExports = new Map<string, readonly [string, string]>([
    ["emotion", ["cx", "@emotion/css"]],
    ["clsx", ["clsx", "clsx"]],
  ]);
  const target = namedExports.get(clsx);
  if (target === undefined) {
    throw new Error("Unknown clsx library");
  }
  const [exportName, moduleName] = target;
  return t.importDeclaration(
    [t.importSpecifier(cx, t.identifier(exportName))],
    t.stringLiteral(moduleName)
  );
}
/**
 * Adapts a visitor factory into a Babel plugin factory.
 *
 * The returned function calls `fn` with the Babel runtime and exposes the
 * resulting visitor as `{ visitor }`. `T` describes the extra fields the visitor
 * expects on the plugin state; it is erased from the returned `PluginObj` type.
 */
const definePlugin =
  <T>(fn: (runtime: typeof b) => b.Visitor<b.PluginPass & T>) =>
  (runtime: typeof b) => {
    const visitor = fn(runtime);
    const pluginObject: b.PluginObj<b.PluginPass & T> = { visitor };
    return pluginObject as b.PluginObj;
  };
/**
 * Unwraps a JSX attribute value. A `JSXExpressionContainer` yields its inner
 * expression; any other value (e.g. a string literal) is returned as-is.
 */
const extractJSXContainer = (attr: NonNullable<t.JSXAttribute["value"]>): t.Expression => {
  if (attr.type !== "JSXExpressionContainer") {
    return attr;
  }
  // The cast also covers an empty container (`css={}`); callers treat it as an expression.
  return attr.expression as t.Expression;
};
/**
 * Dispatches `nodePath` to the handler that `fns` registers for its node type,
 * if there is one. `fns` receives a `dig` callback that re-enters this dispatch,
 * so handlers can recurse into child paths. Paths without a node are ignored.
 */
function matchPath(
  nodePath: NodePath<t.Node | null | undefined>,
  fns: (dig: (nodePath: NodePath<t.Node | null | undefined>) => void) => b.Visitor
) {
  const { node } = nodePath;
  if (node == null) return;
  const dig = (child: NodePath<t.Node | null | undefined>) => matchPath(child, fns);
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  const handler = (fns(dig) as any)[node.type];
  handler?.(nodePath);
}
/**
 * Cache-busting query suffix appended to generated style imports. The module
 * id changes whenever the collected class names change.
 *
 * @param add whether cache busting is enabled (the `vite` option); falsy gives no suffix
 * @param entries style entries destined for the generated file
 * @returns `""`, or `?<hash>` of the entries' class names
 */
function getSuffix(add: boolean | undefined, entries: StyleMapEntry[]) {
  if (!add) return "";
  // Serialize with JSON to keep per-entry grouping. Joining the nested arrays
  // with "," made [["a","b"],["c"]] and [["a"],["b","c"]] hash identically.
  const cacheKey = hash(JSON.stringify(entries.map(x => x.classNames)));
  return `?${cacheKey}`;
}