import { flattenDeep, padStart } from "es-toolkit/compat";
import type { Node } from "prosemirror-model";
import type { Transaction } from "prosemirror-state";
import { Plugin, PluginKey } from "prosemirror-state";
import { Decoration, DecorationSet } from "prosemirror-view";
import type refractorType from "refractor/core";
import { getLoaderForLanguage, getRefractorLangForLanguage } from "../lib/code";
import { isRemoteTransaction } from "../lib/multiplayer";
import { findBlockNodes } from "../queries/findChildren";

/** A run of text produced by refractor together with its token class names. */
type ParsedNode = {
  text: string;
  classes: string[];
};

/** Decorations previously computed for a single code block. */
type CacheEntry = {
  /** The code block node the decorations were computed for. */
  node: Node;
  /** Inline token decorations plus the optional line-number node decoration. */
  decorations: Decoration[];
};

/**
 * Decorations per code block, keyed by the block's position in the document.
 * An entry is reused while the block at that position is `eq` to the cached node.
 *
 * NOTE(review): this cache is module-level, so it is shared by every editor that
 * installs this plugin, and each rebuild prunes entries that are not in the
 * document just decorated. Consider scoping it to the plugin instance.
 */
const cache: Record<number, CacheEntry> = {};

/** Values of the `language` attribute seen in blocks whose grammar is not yet registered. */
const languagesToImport = new Set<string>();

/** In-flight grammar loads keyed by language, so concurrent requests share one import. */
const languagePromises: Record<
  string,
  Promise<string | undefined> | undefined
> = {};

let refractor: typeof refractorType | undefined;

/** Lazily load refractor core, reusing the module after the first import. */
async function getRefractor() {
  refractor ??= (await import("refractor/core")).default;
  return refractor;
}

/**
 * Load and register the refractor grammar for `language`.
 *
 * @param language The code block's `language` attribute value.
 * @returns A promise for `language` once its grammar is registered. It resolves to
 *   `undefined` in four cases: the language is empty, the language is already
 *   registered, no loader exists for it, or the loader failed. A failed load is
 *   evicted from `languagePromises` so a later call retries.
 */
async function loadLanguage(language: string) {
  const r = await getRefractor();
  if (!language || r.registered(language)) {
    return;
  }

  if (languagePromises[language]) {
    return languagePromises[language];
  }

  const loader = getLoaderForLanguage(language);
  if (!loader) {
    return;
  }

  languagePromises[language] = loader()
    .then((syntax) => {
      r.register(syntax);
      return language;
    })
    .catch((err) => {
      // It will retry loading the language on the next render
      // oxlint-disable-next-line no-console
      console.error(
        `Failed to load language ${language} for code highlighting`,
        err
      );
      delete languagePromises[language]; // Remove failed promise from cache
      return undefined;
    });

  return languagePromises[language];
}

/**
 * Flatten refractor's nested token tree into a flat list of text runs. Each run
 * carries the class names accumulated from all of its ancestor elements.
 */
function parseNodes(
  nodes: refractorType.RefractorNode[],
  classNames: string[] = []
): ParsedNode[] {
  return flattenDeep(
    nodes.map((node) => {
      if (node.type === "element") {
        const classes = [...classNames, ...(node.properties.className || [])];
        return parseNodes(node.children, classes);
      }
      return {
        text: node.value,
        classes: classNames,
      };
    })
  );
}

/**
 * Build the node decoration that exposes a code block's line numbers.
 * The numbers go into a `data-line-numbers` attribute as newline-separated,
 * right-aligned text, and the gutter width goes into a CSS variable.
 */
function lineNumbersDecoration(node: Node, pos: number): Decoration {
  const lineCount = (node.textContent.match(/\n/g) ?? []).length + 1;
  const gutterWidth = String(lineCount).length;
  const lineCountText = Array.from({ length: lineCount }, (_, i) =>
    padStart(`${i + 1}`, gutterWidth, " ")
  ).join("\n");

  return Decoration.node(
    pos,
    pos + node.nodeSize,
    {
      "data-line-numbers": lineCountText,
      style: `--line-number-gutter-width: ${gutterWidth};`,
    },
    {
      key: `line-${lineCount}-gutter`,
    }
  );
}

/** Build one inline decoration for every classed token refractor finds in a block. */
function tokenDecorations(
  highlighter: typeof refractorType,
  node: Node,
  pos: number,
  lang: string
): Decoration[] {
  const result: Decoration[] = [];
  // The block's text content starts one position after its opening boundary.
  let from = pos + 1;

  for (const token of parseNodes(highlighter.highlight(node.textContent, lang))) {
    const to = from + token.text.length;
    if (token.classes.length) {
      result.push(
        Decoration.inline(from, to, {
          class: token.classes.join(" "),
        })
      );
    }
    from = to;
  }

  return result;
}

/**
 * Compute the syntax-highlighting and line-number decorations for every block
 * node named `name` in `doc`.
 *
 * Blocks unchanged since the last call (same position, `eq` node) reuse their
 * cached decorations. If a block's grammar is not registered yet, its language is
 * queued in `languagesToImport` and only line numbers are produced for it.
 * Cache entries for positions that no longer hold a matching block are dropped.
 */
function getDecorations({
  doc,
  name,
  lineNumbers,
}: {
  /** The prosemirror document to operate on. */
  doc: Node;
  /** The node name. */
  name: string;
  /** Whether to include decorations representing line numbers */
  lineNumbers?: boolean;
}) {
  const decorations: Decoration[] = [];
  const blocks: { node: Node; pos: number }[] = findBlockNodes(
    doc,
    true
  ).filter((item) => item.node.type.name === name);
  // Snapshot the lazily loaded module. It cannot change during this synchronous
  // pass, and a const narrows without a non-null assertion.
  const highlighter = refractor;

  for (const block of blocks) {
    let entry = cache[block.pos];

    if (!entry || !entry.node.eq(block.node)) {
      const language = block.node.attrs.language;
      const lang = getRefractorLangForLanguage(language);
      // Line numbers are not rendered for blocks with the `wrap` attribute set.
      const lineDecorations: Decoration[] =
        lineNumbers && !block.node.attrs.wrap
          ? [lineNumbersDecoration(block.node, block.pos)]
          : [];
      let blockDecorations = lineDecorations;

      if (!lang) {
        // No refractor grammar is mapped to this language; only line numbers apply.
      } else if (highlighter?.registered(lang)) {
        languagesToImport.delete(language);
        blockDecorations = tokenDecorations(
          highlighter,
          block.node,
          block.pos,
          lang
        ).concat(lineDecorations);
      } else {
        // Refractor or the grammar is not loaded yet; the view hook loads it later.
        languagesToImport.add(language);
      }

      entry = { node: block.node, decorations: blockDecorations };
      cache[block.pos] = entry;
    }

    for (const decoration of entry.decorations) {
      decorations.push(decoration);
    }
  }

  const livePositions = new Set(blocks.map((block) => block.pos));
  for (const key of Object.keys(cache)) {
    const pos = Number(key);
    if (!livePositions.has(pos)) {
      delete cache[pos];
    }
  }

  return DecorationSet.create(doc, decorations);
}

/**
 * Create a plugin that syntax-highlights block nodes named `name` using
 * refractor, optionally adding line-number decorations.
 *
 * Decorations are rebuilt in these cases:
 * - the first transaction the plugin sees;
 * - a document change while the selection (before or after) is inside a matching node;
 * - a paste;
 * - a remote transaction;
 * - a transaction tagged with `codeHighlighting` meta after refractor or a grammar finishes loading.
 *
 * Otherwise the existing decorations are mapped through the transaction.
 *
 * NOTE(review): a document change to a code block made while the selection is
 * outside every code block, and that is neither a paste nor remote, is only
 * position-mapped, not re-highlighted.
 */
export function CodeHighlighting({
  name,
  lineNumbers,
}: {
  /** The node name. */
  name: string;
  /** Whether to include decorations representing line numbers */
  lineNumbers?: boolean;
}) {
  // Until decorations have been built once, every transaction triggers a rebuild.
  let highlighted = false;

  return new Plugin({
    key: new PluginKey("codeHighlighting"),
    state: {
      init: (_, { doc }) => DecorationSet.create(doc, []),
      apply: (transaction: Transaction, decorationSet, oldState, state) => {
        const nodeName = state.selection.$head.parent.type.name;
        const previousNodeName = oldState.selection.$head.parent.type.name;
        const codeBlockChanged =
          transaction.docChanged && [nodeName, previousNodeName].includes(name);
        // prosemirror-view tags transactions created by a paste with "paste" meta.
        const isPaste = transaction.getMeta("paste");
        // `true` once refractor core loads, or the list of languages whose grammar just loaded.
        const langLoaded = transaction.getMeta("codeHighlighting")?.langLoaded;

        if (
          !highlighted ||
          codeBlockChanged ||
          isPaste ||
          langLoaded ||
          isRemoteTransaction(transaction)
        ) {
          // Invalidate cached entries for blocks whose language just loaded
          // so getDecorations rebuilds them with syntax highlighting applied.
          if (Array.isArray(langLoaded)) {
            for (const key of Object.keys(cache)) {
              const pos = Number(key);
              if (langLoaded.includes(cache[pos]?.node.attrs.language)) {
                delete cache[pos];
              }
            }
          }

          highlighted = true;
          return getDecorations({ doc: transaction.doc, name, lineNumbers });
        }

        return decorationSet.map(transaction.mapping, transaction.doc);
      },
    },
    view: (view) => {
      if (!highlighted) {
        // Once refractor core is available, dispatch a tagged transaction so
        // apply() rebuilds. The resulting view update then starts grammar loads.
        void getRefractor().then(
          () => {
            if (!view.isDestroyed) {
              view.dispatch(
                view.state.tr.setMeta("codeHighlighting", {
                  langLoaded: true,
                })
              );
            }
          },
          (err: unknown) => {
            // Only import failures land here. getRefractor retries on its next call.
            // oxlint-disable-next-line no-console
            console.error("Failed to load refractor for code highlighting", err);
          }
        );
      }

      return {
        update: () => {
          if (!languagesToImport.size) {
            return;
          }

          void Promise.all([...languagesToImport].map(loadLanguage)).then(
            (results) => {
              const loaded = results.filter((lang): lang is string => !!lang);
              if (loaded.length && !view.isDestroyed) {
                view.dispatch(
                  view.state.tr.setMeta("codeHighlighting", {
                    langLoaded: loaded,
                  })
                );
              }
            },
            (err: unknown) => {
              // loadLanguage handles grammar failures itself, so this is a
              // refractor core import failure. It is retried on the next update.
              // oxlint-disable-next-line no-console
              console.error("Failed to load refractor for code highlighting", err);
            }
          );
        },
      };
    },
    props: {
      decorations(state) {
        return this.getState(state);
      },
    },
  });
}