During my work on Paper To Project, I was looking into Categories of AI Research Ideas. There, I manually looked at papers and created categories, and then I used OpenAI's o4-mini to create a dataset with 800 examples. This is a bit slow and costs a lot of LLM tokens (even though I saved a lot by using a token/traffic sharing program). So I thought: maybe I already have enough data to train a text classifier? The task is fairly challenging, so a simple regression model probably won't work, and I decided to fine-tune a BERT-like model instead.
Selecting a model
The original BERT (and DistilBERT) has a number of limitations, such as a small context size (512 tokens) and slow training. I had heard lots of good things about ModernBERT, so it was my first choice for fine-tuning. However, another model, NeoBERT, got some attention on X, so I decided to try it out as well. It didn't work locally on my Mac because it depends on xformers, so I started with ModernBERT to test my training script without spending GPU time.
Finding a GPU
I didn't want to spend a lot on training, so I spent some time searching for a GPU. I started with Google Colab, but ran into too many problems with environment setup and various errors. I tried SageMaker Studio Lab, but couldn't even launch an instance and only got an unhelpful error. Finally, I saw this post and registered at https://lightning.ai/, and it has been awesome. I created my environment with uv on a free CPU instance, and when I switched to a GPU, the environment was preserved and everything worked out of the box!
Creating a dataset
I had a big JSON file with my labels, like this:
{
  "HqLHY4TzGj": {
    "Most influential source": "Ren et al., \u201cFaster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks\u201d (2015)",
    "Possible to create from the source above": "yes \u2013 by observing that Faster R-CNN\u2019s per-proposal regression to the full ground truth is difficult when a proposal only partially overlaps an object, and that its NMS discards valuable complementary proposals, one can instead regress each proposal to its intersection with the truth and then take the union of those intersections.",
    "Idea Source": "Re-examining the two canonical stages of modern detectors\u2014box regression (e.g. in Faster R-CNN) and non-maximum suppression\u2014led to the insight that regressing only visible parts and merging all candidate parts could improve localization.",
    "Idea Scenario": "While implementing Faster R-CNN and experimenting with IoU-aware losses and Soft-NMS variants, the authors noticed that high-IoU proposals were still regressed poorly if they only covered part of an object, and that lower-scoring proposals often contained complementary object regions. In a lab meeting they discussed splitting regression targets by intersection and replacing winner-takes-all with a union operator over grouped proposals, leading directly to UoI.",
    "Idea Generation Category": "Direct Enhancement"
  }
}
and the paper texts, extracted from the PDFs with PyMuPDF4LLM, in Markdown files named like HqLHY4TzGj.md. I used the Hugging Face Trainer, so I needed to create a dataset in the right format. First, I created a Polars DataFrame to do some EDA:
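import json
from pathlib import Path

import polars as pl

# hypothetical locations: the folder with the .md files and the JSON file with the labels
data_dir = Path("papers")
with open("ideas.json", "r", encoding="utf-8") as f:
    clean_ideas = json.load(f)

rows = []
for file in data_dir.glob("*.md"):
    fname = file.stem
    if fname in clean_ideas:
        with open(file, "r", encoding="utf-8") as f:
            content = f.read()
        label = clean_ideas[fname]["Idea Generation Category"]
        # merge these categories into a single "Other" class
        if label in ["Framework Unification", "Empirical Re-evaluation", "Theoretical Advancement", "Benchmark Advancement"]:
            label = "Other"
        # drop the repeated page header and keep at most the first 5000 words
        prompt = content.replace("Published as a conference paper at ICLR 2025", "")
        prompt = " ".join(prompt.split()[:5000])
        prompt += "\nIdea Generation Category: "
        rows.append({"text": prompt, "labels": label, "id": fname})

df = pl.DataFrame(rows)
df.write_parquet("iclr_2025_ideas.parquet")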
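For the EDA, the first thing worth looking at is the class balance. The four categories merged into "Other" above are hard-coded; I assume they were picked because they are rare. A minimal sketch of deriving that list from the label counts instead, using a hypothetical helper and a hypothetical `min_count` threshold (plain Python, so it works on `df["labels"].to_list()`):

```python
from collections import Counter


def merge_rare_labels(labels, min_count, other="Other"):
    """Hypothetical helper: replace every label seen fewer than `min_count` times with `other`."""
    counts = Counter(labels)
    return [label if counts[label] >= min_count else other for label in labels]


labels = ["Direct Enhancement"] * 3 + ["Benchmark Advancement"] + ["Other"] * 2
merged = merge_rare_labels(labels, min_count=2)
# class counts after merging, most frequent first
print(Counter(merged).most_common())
```

The merged labels could then be written back with `df.with_columns(pl.Series("labels", merged))`.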
Then I converted it into a Hugging Face dataset with 800 examples and pushed it to the Hub:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset, Dataset
import pandas as pd
from datasets import DatasetDict
from transformers import Trainer, TrainingArguments
import numpy as np
from sklearn.metrics import f1_score
df = pd.read_parquet("iclr_2025_ideas.parquet")
dataset = Dataset.from_pandas(df)
dataset = dataset.class_encode_column("labels")
dataset = dataset.train_test_split(test_size=0.1)
dataset.push_to_hub("eamag/iclr2025ideas")
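With only 800 examples and some rare classes, a purely random 10% split can leave a class with almost no test examples. Since `class_encode_column` turns `labels` into a `ClassLabel`, `datasets` can stratify the split directly with `dataset.train_test_split(test_size=0.1, stratify_by_column="labels", seed=42)`. The idea behind stratification, as a plain-Python sketch (a hypothetical helper, not the `datasets` implementation):

```python
import random
from collections import defaultdict


def stratified_split(labels, test_size=0.1, seed=42):
    """Return (train_idx, test_idx) so that each class keeps roughly its share in both splits."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    train_idx, test_idx = [], []
    for idxs in by_class.values():
        rng.shuffle(idxs)
        # every class with at least two examples gets at least one test example
        n_test = max(1, round(len(idxs) * test_size)) if len(idxs) > 1 else 0
        test_idx.extend(idxs[:n_test])
        train_idx.extend(idxs[n_test:])
    return sorted(train_idx), sorted(test_idx)


labels = ["a"] * 20 + ["b"] * 10 + ["c"]
train_idx, test_idx = stratified_split(labels)
print(len(train_idx), len(test_idx))
```

The resulting index lists could be passed to `dataset.select(...)` to build the two splits.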
Training
First, we need to tokenize our dataset and load the model.
model_id = "chandar-lab/NeoBERT"
# model_id = "answerdotai/ModernBERT-base"
tokenizer = AutoTokenizer.from_pretrained(model_id)

def tokenize(batch):
    # pad/truncate every example to the tokenizer's model_max_length;
    # datasets stores the output as lists, so return_tensors is not needed here
    return tokenizer(batch["text"], padding="max_length", truncation=True)

tokenized_dataset = dataset.map(tokenize, batched=True, remove_columns=["text"])

# Prepare model labels - useful for inference
labels = tokenized_dataset["train"].features["labels"].names
num_labels = len(labels)
# integer ids, matching the ClassLabel encoding created by class_encode_column
label2id = {label: i for i, label in enumerate(labels)}
id2label = {i: label for i, label in enumerate(labels)}

model = AutoModelForSequenceClassification.from_pretrained(
    model_id,
    num_labels=num_labels,
    label2id=label2id,
    id2label=id2label,
    trust_remote_code=True,
)
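One detail worth double-checking: `truncation=True` cuts tokens from the right by default, and 5000 words are likely longer than NeoBERT's 4,096-token context, so the `Idea Generation Category:` suffix added during preprocessing may be cut off for long papers. A minimal sketch of truncating only the paper body while keeping the suffix, working on plain token-id lists (the helper and its default of two special tokens are assumptions):

```python
def truncate_keep_suffix(body_ids, suffix_ids, max_length, num_special_tokens=2):
    """Hypothetical helper: trim the body so body + suffix + special tokens fit in max_length."""
    budget = max_length - len(suffix_ids) - num_special_tokens
    if budget < 0:
        raise ValueError("the suffix alone does not fit into max_length")
    # keep the beginning of the paper and always keep the full suffix
    return body_ids[:budget] + suffix_ids


print(truncate_keep_suffix(list(range(10)), [100, 101], max_length=8))
```

The token ids for the body and the suffix could come from `tokenizer.encode(text, add_special_tokens=False)`, and `tokenizer.num_special_tokens_to_add()` returns how many special tokens the tokenizer adds around a single sequence.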
Then we can define the evaluation metric and the training arguments; some of the arguments are quite interesting:
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    # support-weighted F1 over all classes (pos_label only matters for binary tasks)
    score = f1_score(labels, predictions, average="weighted")
    return {"f1": float(score)}
model_name = model_id.split("/")[-1]
training_args = TrainingArguments(
    output_dir=f"{model_name}-multiclass-classifier-ICLR-dir",
    auto_find_batch_size=True,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=5e-4,
    num_train_epochs=20,
    # bfloat16 training
    bf16=True,
    bf16_full_eval=True,
    # improved (fused) optimizer
    optim="adamw_torch_fused",
    # logging
    logging_strategy="steps",
    logging_steps=100,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    report_to="tensorboard",
    # try torch.compile (speeds up ModernBERT training, but it crashes)
    # torch_compile=True,
    # torch_compile_backend="inductor",
    # torch_compile_mode="default",
    # against overfitting
    # weight_decay=0.01,
    label_smoothing_factor=0.1,
    # pushing to the Hugging Face Hub, optional
    # hub_model_id="eamag/NeoBERT-multiclass-classifier-ICLR",
    # hub_strategy="end",
)
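`label_smoothing_factor=0.1` is one of the knobs against overfitting: instead of a one-hot target, the loss uses a target that puts `1 - ε` on the true class and spreads `ε` uniformly over all K classes, which discourages over-confident logits. As far as I understand, Trainer's label smoother computes the equivalent form `(1 - ε) * NLL + ε * mean_k(-log p_k)`; below is a plain-Python restatement of that formula, not the Transformers code:

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax of a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def smoothed_cross_entropy(logits, target, epsilon):
    """Cross-entropy against the target (1 - epsilon) * one_hot(target) + epsilon / K."""
    log_probs = log_softmax(logits)
    nll = -log_probs[target]                    # the usual cross-entropy term
    uniform = -sum(log_probs) / len(log_probs)  # cross-entropy against a uniform target
    return (1 - epsilon) * nll + epsilon * uniform


# a confident, correct prediction is penalised a bit more once smoothing is on
print(smoothed_cross_entropy([4.0, 0.0, 0.0], target=0, epsilon=0.0))
print(smoothed_cross_entropy([4.0, 0.0, 0.0], target=0, epsilon=0.1))
```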
You can start training already, but I was struggling with overfitting due to the small dataset and the difficult task (still not solved!). So I also tried freezing some layers:
# freeze the whole base model; the classification head stays trainable
for param in model.base_model.parameters():
    param.requires_grad = False

# unfreeze the last transformer layer (index 27) and the layer norms in NeoBERT
for name, param in model.named_parameters():
    if "layer_norm" in name or "27" in name:
        param.requires_grad = True
    print(name, param.requires_grad)
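The `"27" in name` check is a plain substring match: it works for NeoBERT, where I assume 27 only shows up as the index of the last layer, but it would also match, say, layer 127 in a deeper model. A slightly stricter, hypothetical helper that matches the index only as a whole dot-separated part of the parameter name (the parameter names below are made up for illustration, not NeoBERT's real ones):

```python
import re


def is_trainable(name, layer_idx, extra_keys=("layer_norm",)):
    """Hypothetical helper: True if `name` contains an extra key or sits in layer `layer_idx`."""
    if any(key in name for key in extra_keys):
        return True
    # the index must be a whole dot-separated component, so 27 does not match 127
    return re.search(rf"(^|\.){layer_idx}(\.|$)", name) is not None


# hypothetical parameter names, only for illustration
for name in ["encoder.27.ffn.weight", "encoder.127.ffn.weight", "final_layer_norm.weight"]:
    print(name, is_trainable(name, 27))
```

In the unfreezing loop above, `if is_trainable(name, 27):` would replace the substring condition.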
Finally, we can train the model.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    compute_metrics=compute_metrics,
)
trainer.train()
# if you interrupted training, you can resume from the last checkpoint
# trainer.train(resume_from_checkpoint=True)
trainer.save_model(f"{model_name}-multiclass-classifier-ICLR")
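Because of `metric_for_best_model="f1"` and `load_best_model_at_end=True`, the model saved at the end is the checkpoint with the best weighted F1 on the test split. "Weighted" means the per-class F1 scores are averaged with weights equal to each class's number of true examples, so frequent classes dominate the score. A plain-Python sketch of that computation, meant to mirror `f1_score(..., average="weighted")`:

```python
from collections import Counter


def weighted_f1(y_true, y_pred):
    """Per-class F1 = 2TP / (2TP + FP + FN), averaged with weights = class support."""
    support = Counter(y_true)
    total = len(y_true)
    score = 0.0
    for cls, n in support.items():
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == cls and p == cls)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != cls and p == cls)
        fn = n - tp
        denom = 2 * tp + fp + fn
        f1 = 2 * tp / denom if denom else 0.0
        score += f1 * n / total
    return score


print(weighted_f1([0, 0, 1, 1], [0, 1, 1, 1]))
```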
That’s it! You can now use the model for inference:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# the tokenizer comes from the base model, the weights from the fine-tuned checkpoint
tokenizer = AutoTokenizer.from_pretrained("chandar-lab/NeoBERT")
model = AutoModelForSequenceClassification.from_pretrained(
    "eamag/NeoBERT-multiclass-classifier-ICLR",
    trust_remote_code=True,
)
model.eval()

# placeholder: paper text preprocessed as in training (ending with "\nIdea Generation Category: ")
text = "..."
inputs = tokenizer(text, padding="max_length", truncation=True, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)
logits = outputs.logits
predicted_class_id = logits.argmax().item()
print(model.config.id2label[predicted_class_id])
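`argmax` only gives the single most likely category. For a noisy multi-class problem like this one, it can help to look at the class probabilities as well; a minimal sketch of a hypothetical helper that turns one row of logits into the top-k labels with their softmax probabilities:

```python
import math


def top_k_probs(logits, id2label, k=3):
    """Hypothetical helper: softmax over one row of logits, then the k most likely labels."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]  # subtract the max for numerical stability
    z = sum(exps)
    probs = [e / z for e in exps]
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(id2label[i], probs[i]) for i in ranked[:k]]


print(top_k_probs([0.0, math.log(3)], {0: "A", 1: "B"}, k=2))
```

With the model above, `top_k_probs(logits[0].tolist(), model.config.id2label)` would list the three most likely categories.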
Results
Unfortunately, I couldn't solve the overfitting problem. I know my dataset is dirty and the problem is difficult, so possible next steps are to extract text from the PDFs in some other way, add more context for the embedding, and play around with the training setup (I already tried weight decay, label smoothing, learning rate schedules, parameter freezing, classifying raw BERT embeddings, and a few other things). If you know other things to try, write to me on Telegram or X! You can also find the final model on Hugging Face.