Fetching ML models under Bazel

This is a repeat of my earlier article https://blog.aspect.dev/lazy-tool-fetching-under-bazel, applied to a different tool. Instead of lazy-fetching a NodeJS tool used by vercel/pkg, this time we're going to fetch NLTK Data https://www.nltk.org/data.html, the collection of corpora and trained models used by the NLTK Python package for natural language processing and machine learning tasks.

As in the previous article, we start by reading the installation documentation for the tool we want to run as a Bazel action. It suggests a non-hermetic "installation" step that creates a /usr/local/share/nltk_data folder on your machine. Here's the error printed when the data isn't found:

Resource punkt not found.
Please use the NLTK Downloader to obtain the resource:
>>> import nltk
>>> nltk.download('punkt')
For more information see: https://www.nltk.org/data.html

We don't want to follow this guidance. Calling the downloader from a test would also force us to add tags = ["requires-network"] to the target, which is a smell that something is wrong. The installed state is not managed by the build system, and won't exist on the CI machine where your tests run. You could of course prepare the CI machines the same way, but that isn't what we want under Bazel: the build becomes non-reproducible because it depends on machine state, and it isn't portable to other executors such as remote execution.

Instead, we'll use the Bazel Downloader to fetch the data. As before, the key observation is that these "lazy fetching" patterns always have a cache or install folder layout the tool expects to read. As long as we can construct a folder that meets all of the tool's assumptions, we can stitch it into the tool's runtime using whatever hook the tool provides, in this case the NLTK_DATA environment variable.
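To make the "folder layout" idea concrete, here is a simplified model of how NLTK resolves a resource name such as tokenizers/punkt: each directory listed in NLTK_DATA (separated by os.pathsep) is checked for the resource, either unpacked or as a .zip. This is an illustrative sketch, not NLTK's actual nltk.data.find implementation, which also searches built-in default directories; the helper name find_resource is hypothetical.

```python
import os
from typing import Mapping, Optional


def find_resource(resource: str, environ: Mapping[str, str] = os.environ) -> Optional[str]:
    """Simplified model of NLTK's data lookup (hypothetical helper, not NLTK's API).

    Returns the first matching path under a directory listed in NLTK_DATA,
    or None if the resource isn't found in any of them.
    """
    # NLTK_DATA may hold several directories separated by os.pathsep (":" on Unix).
    search_dirs = [d for d in environ.get("NLTK_DATA", "").split(os.pathsep) if d]
    for root in search_dirs:
        # The resource may be unpacked as a folder, e.g. <root>/tokenizers/punkt ...
        candidate = os.path.join(root, *resource.split("/"))
        if os.path.exists(candidate):
            return candidate
        # ... or left as a zip archive, e.g. <root>/tokenizers/punkt.zip
        if os.path.exists(candidate + ".zip"):
            return candidate + ".zip"
    return None
```

In this model, pointing NLTK_DATA at any directory that contains tokenizers/punkt/ is enough for the lookup to succeed, which is exactly the property the Bazel setup relies on.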

Fetching NLTK data

For the sake of example, let's say our NLTK usage assumes that the Punkt tokenizer is "installed". We can see that it's distributed in their GitHub repo: https://github.com/nltk/nltk_data/tree/gh-pages/packages/tokenizers

Our first step is to fetch the data. We want to use the Bazel Downloader, for reasons described in my earlier post: https://blog.aspect.dev/configuring-bazels-downloader. A custom repository rule could call the repository_ctx.download* functions, but this case is simple enough that we can use the http_archive rule from our WORKSPACE or MODULE.bazel file:

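    # MODULE.bazel form; in a WORKSPACE file use this instead:
    # load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
    http_archive = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

    # Bazel download from https://www.nltk.org/nltk_data/ rather than follow their lame instructions
    http_archive(
        name = "nltk_data_punkt",
        # The archive extracts to a top-level "punkt" folder; expose it as a single target.
        build_file_content = """exports_files(["punkt"], visibility = ["//visibility:public"])""",
        sha256 = "51c3078994aeaf650bfc8e028be4fb42b4a0d177d41c012b6a983979653660ec",
        # note: 'gh-pages' branch replaced by a commit hash for determinism
        urls = ["https://raw.githubusercontent.com/nltk/nltk_data/1d3c34b4cfd6059986bf4bc604e5929335ab92ff/packages/tokenizers/punkt.zip"],
    )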
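When pinning another package, the sha256 value can be computed from the downloaded archive (for example with sha256sum punkt.zip on Linux). A standard-library Python equivalent is sketched below; it streams the file so large model archives aren't read into memory at once. The helper name sha256_of is hypothetical.

```python
import hashlib


def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA-256 digest of a file, the format http_archive's sha256 expects."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # Read in 1 MiB chunks until read() returns b"" at end of file.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
```

Bazel fails the fetch if the downloaded bytes don't match this hash, so pointing at a moving branch like gh-pages would eventually break the build rather than silently change it; pinning the commit in the URL avoids that.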

Now we need to prepare a folder that mimics the cache folder structure NLTK's downloader creates. A single copy_to_directory action from aspect_bazel_lib is sufficient. To make it look pretty in a Bazel context, we'll wrap it in a small macro that adds some documentation. Let's put this content in nltk.bzl:

"Helpers for https://www.nltk.org/"

load("@aspect_bazel_lib//lib:copy_to_directory.bzl", "copy_to_directory")

def nltk_data(name, corpora):
    """Assemble a folder following instructions at https://www.nltk.org/data.html#manual-installation

    Args:
        name: name of resulting target
        corpora: list of packages from https://github.com/nltk/nltk_data/tree/gh-pages/packages
            e.g. ["tokenizers/punkt", "sentiment/vader_lexicon"]
            Note that these need to be fetched by Bazel, see /tools/bazel/fetch.bzl to add more.
    """

    copy_to_directory(
        name = name,
        # Prohibit this data being used in production,
        # which would pose a vulnerability issue and could cause outages.
        testonly = True,
        # convention for external repos is using the last segment of the name, e.g.
        # tokenizers/punkt is in @nltk_data_punkt
        srcs = ["@nltk_data_{0}//:{0}".format(each.split("/")[-1]) for each in corpora],
        # Files from external repositories are skipped unless their repo matches a pattern here.
        include_external_repositories = ["nltk_data_*"],
        # Move each package under its category folder, e.g. punkt/ -> tokenizers/punkt/
        replace_prefixes = {each.split("/")[-1]: each for each in corpora},
    )
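To see what the macro hands to copy_to_directory, the same list and dict comprehensions can be evaluated in plain Python (Starlark is a Python dialect, and these expressions behave the same). The remap helper is a hypothetical, simplified stand-in for how replace_prefixes rewrites an output path: it matches whole leading path segments and lets the longest key win. aspect_bazel_lib's real matching rules are richer than this, so treat it as an approximation.

```python
from typing import Dict, List


def macro_args(corpora: List[str]) -> Dict[str, object]:
    """Evaluate the expressions nltk_data() passes to copy_to_directory."""
    return {
        # e.g. "tokenizers/punkt" -> "@nltk_data_punkt//:punkt"
        "srcs": ["@nltk_data_{0}//:{0}".format(each.split("/")[-1]) for each in corpora],
        # e.g. "punkt" -> "tokenizers/punkt"
        "replace_prefixes": {each.split("/")[-1]: each for each in corpora},
    }


def remap(path: str, replace_prefixes: Dict[str, str]) -> str:
    """Approximate a prefix mapping on an output path (hypothetical helper).

    Only whole leading path segments match here, and the longest matching key wins.
    """
    matches = [k for k in replace_prefixes if path == k or path.startswith(k + "/")]
    if not matches:
        return path  # no key applies; the path is copied unchanged
    key = max(matches, key=len)
    return replace_prefixes[key] + path[len(key):]
```

For example, macro_args(["tokenizers/punkt"]) yields the source label @nltk_data_punkt//:punkt and the mapping {"punkt": "tokenizers/punkt"}, so a file extracted at punkt/<file> lands at tokenizers/punkt/<file>, the layout NLTK looks for.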

Using in a BUILD target

Now we'll edit our Python target, most likely a py_binary or py_test, so that it can resolve the data. First, load the nltk.bzl file and call the macro, for example:

load("//tools/bazel:nltk.bzl", "nltk_data")  # hypothetical path: wherever you put nltk.bzl

nltk_data(
    name = "nltk_data",
    corpora = ["tokenizers/punkt"],
)

The :nltk_data target produces a directory laid out like one populated by the NLTK downloader.

So in our Python target, we just need to add it to data and set an environment variable pointing to it:

py_test(
    ...
    data = [
        ":nltk_data",
        "//path/to/tests/words",
    ],
    env = {"NLTK_DATA": "$(rootpath :nltk_data)"},
)
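Inside the test, $(rootpath :nltk_data) expands to a path relative to the runfiles root of the main repository, and Bazel starts tests from that directory, so the relative NLTK_DATA value resolves correctly. A cheap guard is to fail fast with a clear message when the variable or a package is missing, instead of hitting NLTK's "Resource not found" error mid-test. A minimal standard-library sketch; missing_corpora is a hypothetical helper, not part of NLTK:

```python
import os
from typing import List, Mapping


def missing_corpora(corpora: List[str], environ: Mapping[str, str] = os.environ) -> List[str]:
    """Return the packages (e.g. "tokenizers/punkt") not unpacked under $NLTK_DATA."""
    root = environ.get("NLTK_DATA")
    if not root:
        # Variable not set: nothing to search, so every requested package is missing.
        return list(corpora)
    return [c for c in corpora if not os.path.isdir(os.path.join(root, *c.split("/")))]
```

At the top of a test module, a line like assert not missing_corpora(["tokenizers/punkt"]) would surface a mismatch between the data and env attributes immediately.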

Now the data is downloaded by Bazel and provided hermetically to the test action at runtime, with no network access, so the requires-network tag is no longer needed.