Deploying a spam filter with Snorkel, scikit-learn and ONNX Runtime
I created the .NET Twitter Bot in January 2020 and it has been running ever since, amassing a decent following with over 2,000 followers at the time of writing this post. The bot was initially set up to simply search for #dotnet related tweets every 30 minutes and retweet them. As you can imagine, with such a simple bot, it was sometimes retweeting tweets that were not related to .NET at all.
Initially, this was handled by programming the bot to skip tweets that contained certain keywords, such as “domain” and “registration”. This worked well for a while, but as the bot grew in popularity, certain accounts started to take advantage of it for spamming.
Clearly, a more sophisticated spam filter was required to ensure the bot was surfacing the right tweets.
Spam filters, like the one that runs on your email provider, are generally built with machine learning models as it is intractable to programmatically identify spam with a rules-based algorithm. Therefore, to create a spam filter, a dataset of spam and non-spam tweets was needed to train a model.
Collecting a training dataset
In a previous data analysis post I had used Vicinitas to download approximately 3,200 of .NET Bot's tweets. This is a great tool to quickly grab a bunch of tweet data, and their website is very easy to use.
I still had access to this initial dataset, so this seemed like a good starting point for a training dataset. However, since it was from over a year ago, it seemed reasonable to collect a more recent set as well to ensure the bot was aligned with the current data distribution. So I downloaded another 3,200 tweets, giving me a dataset of around 6,400 tweets from 2020 and 2021!
This data consists of the following columns:
- Tweet ID
- Tweet full text
- Created time
- Number of favourites
- Number of retweets
- Detected language
- List of URLs in the tweet
- Number of hashtags
- Number of mentions
- Media type
- Media links
This was a great starting point, but to train a spam filter, which is a supervised learning task, a labelled dataset is needed. However, I wasn’t too keen on labelling 6,400 tweets by hand!! Is there another way…?
Labelling with Snorkel
Snorkel is an open source Python library that can help quickly label training data using Weak Supervision. Cool! This sounded promising, and I had been keen on trying Snorkel out for a while, so this seemed like a great direction to go in.
Test Set
I must admit, I did do a small amount of labelling to create a hold-out Test Set that would serve as a benchmark for the final trained model. To create the Test Set, I randomly went through the tweet data and labelled approximately 100 spam tweets (label = 1) and 100 non-spam tweets (label = 0). All other tweets were given a label of -1, meaning I "abstain" from deciding what they are and leave it up to Snorkel to determine their labels. Snorkel will end up labelling 6,200 of the 6,400 tweets (about 97% of the data) - that's a pretty good productivity gain!
Installing Snorkel
Snorkel is available as a Python package:
pip install snorkel
or via conda:
conda install snorkel -c conda-forge
Then to get up and running with basic Snorkel functionality, you’ll need the following imports:
from snorkel.analysis import get_label_buckets
from snorkel.labeling import labeling_function, filter_unlabeled_dataframe, PandasLFApplier, LFAnalysis
from snorkel.labeling.model import MajorityLabelVoter, LabelModel
from snorkel.utils import probs_to_preds
The following settings are also applied to achieve quiet, reproducible runs:
import os

# Turn off TensorFlow logging messages
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
# For reproducibility
os.environ["PYTHONHASHSEED"] = "0"
Loading data
Snorkel works nicely with Pandas, and since I didn’t have a huge amount of data, Pandas was a good choice for loading the data:
import pandas as pd
df = pd.read_excel('tweets.xlsx', usecols=['Label', 'Text', 'Language', 'Hashtags', 'URLs', 'Mentions', 'Media Type'])
df_test = df[(df['Label'] == 1) | (df['Label'] == 0)]
df_train = df[(df['Label'] == -1)]
Y_test = df_test['Label'].values
Here, the 6,400-tweet Excel file is read, loading only the required columns and separating out the Train and Test sets. Note that the training set consists entirely of unlabelled data!
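As a quick sanity check (a sketch; the exact counts depend on the hand-labelling described above), the split sizes can be verified:

# Roughly 200 hand-labelled tweets form the Test Set; everything
# labelled -1 becomes the unlabelled training pool (~97% of the data)
print(df['Label'].value_counts())
print(f"Unlabelled fraction: {(df['Label'] == -1).mean():.1%}")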
Labelling Functions
One of the key ways to interact with Snorkel is through a Labelling Function (LF). These act as noisy heuristics to label the data with simple rules. Each rule should label something as either SPAM, HAM (not spam) or ABSTAIN:
ABSTAIN = -1
HAM = 0
SPAM = 1
For example, I noticed many of the spam tweets had lots of hashtags, so I created the following LF:
@labeling_function()
def hashtags(x):
    return SPAM if x['Hashtags'] > 8 else ABSTAIN
Here, an LF called hashtags is defined to label tweets that have more than 8 hashtags as SPAM; otherwise, no label is given (abstaining).
I also noticed that there were tweets that were heavily geared towards cryptocurrency, blockchain and so-forth, that also had nothing to do with .NET. So I created the following LF:
@labeling_function()
def crypto(x):
    return SPAM if "crypto" in x['Text'].lower() else ABSTAIN
Here, an LF called crypto is defined to label tweets that contain "crypto" in the text as SPAM.
All in all, I ended up defining 28 LFs. Arriving at a final set of LFs is an iterative process: inspect your data, write new LFs, and repeat until you run out of ideas or your model achieves good performance. General Data-Centric AI methodologies work best here.
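Many of the LFs ended up being simple keyword checks like the crypto example above. One way to cut down the boilerplate (a sketch, not necessarily how this project's LFs were written) is Snorkel's LabelingFunction class with a shared lookup function:

from snorkel.labeling import LabelingFunction

def keyword_lookup(x, keyword, label):
    # Vote `label` when the keyword appears in the tweet text, otherwise abstain
    return label if keyword in x['Text'].lower() else ABSTAIN

def make_keyword_lf(keyword, label=SPAM):
    return LabelingFunction(
        name=f"keyword_{keyword}",
        f=keyword_lookup,
        resources=dict(keyword=keyword, label=label),
    )

# Hypothetical keyword list, purely for illustration
keyword_lfs = [make_keyword_lf(k) for k in ["hiring", "subscribe", "sale"]]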
Label function analysis
Once the LFs are finalised, they can be applied to the data to generate a weakly labelled data set:
lfs = [hashtags, crypto, ...] # list of 28 labelling functions
applier = PandasLFApplier(lfs=lfs)
L_train = applier.apply(df=df_train)
L_test = applier.apply(df=df_test)
The weakly labelled training set L_train
can then be analysed:
LFAnalysis(L=L_train, lfs=lfs).lf_summary()
| | j | Polarity | Coverage | Overlaps | Conflicts |
|---|---|---|---|---|---|
| hiring | 0 | [1] | 0.028185 | 0.027696 | 0.012708 |
| hire | 1 | [1] | 0.006680 | 0.006680 | 0.001955 |
| job | 2 | [1] | 0.074780 | 0.072988 | 0.027533 |
| work | 3 | [1] | 0.084881 | 0.078853 | 0.054252 |
| career | 4 | [1] | 0.017758 | 0.016129 | 0.003747 |
| offer | 5 | [1] | 0.005702 | 0.005213 | 0.001629 |
| candidate | 6 | [1] | 0.013848 | 0.013359 | 0.002770 |
| remote | 7 | [1] | 0.018736 | 0.018736 | 0.005702 |
| money | 8 | [1] | 0.021668 | 0.021342 | 0.003584 |
| percent | 9 | [1] | 0.004562 | 0.004399 | 0.001955 |
| exclaim | 10 | [1] | 0.161779 | 0.145976 | 0.105083 |
| visit | 11 | [1] | 0.019062 | 0.013522 | 0.009775 |
| subscribe | 12 | [1] | 0.009286 | 0.009286 | 0.008961 |
| firewall | 13 | [1] | 0.000489 | 0.000489 | 0.000163 |
| buy | 14 | [1] | 0.000815 | 0.000489 | 0.000489 |
| free | 15 | [1] | 0.021180 | 0.020365 | 0.010590 |
| today | 16 | [1] | 0.035191 | 0.033561 | 0.026393 |
| come | 17 | [1] | 0.012382 | 0.011730 | 0.006517 |
| blockchain | 18 | [1] | 0.041544 | 0.041382 | 0.000815 |
| crypto | 19 | [1] | 0.006680 | 0.006517 | 0.000489 |
| hashtags | 20 | [1] | 0.290160 | 0.155262 | 0.032095 |
| sale | 21 | [1] | 0.005865 | 0.005702 | 0.002281 |
| mentions | 22 | [1] | 0.023135 | 0.021342 | 0.018247 |
| lang_und | 23 | [1] | 0.022972 | 0.020854 | 0.005702 |
| minimal_hashtags | 24 | [0] | 0.447540 | 0.277615 | 0.177256 |
| net | 25 | [0] | 0.164386 | 0.147931 | 0.074780 |
| nuget | 26 | [0] | 0.010264 | 0.009938 | 0.003421 |
| at_dotnet | 27 | [0] | 0.144184 | 0.063213 | 0.037634 |
This analysis shows how much coverage, overlap and conflicts each LF has. This helps to understand how much data is being labelled and the level of redundancy for each LF. As mentioned above, this information can be used to iterate on the LFs.
Digging deeper, it is possible to sample the training set to see what type of tweets are being labelled.
For example, a sample of some weakly labelled SPAM tweets:
df_train.iloc[L_train[:, 1] == SPAM].sample(10, random_state=1)[['Text', 'Hashtags', 'Label']]
| Text | Hashtags | Label |
|---|---|---|
| RT @unicorntalents : Hire the best C# programm… | 17 | -1 |
| RT @HirectApp : .@Prostooservices is #Hiring o… | 18 | -1 |
| RT @JobHookup4U : Senior C# Developer - New Yo… | 8 | -1 |
| RT @SoftwaredevIn : Need a #Dotnet programmer … | 4 | -1 |
| RT @Julieeeee06 : hiring kami!!!!\n\nDM mo sak… | 9 | -1 |
| RT @JobHookup4U : Senior C# Developer - New Yo… | 8 | -1 |
| RT @Katheri01147964 : #Hire #dotNET #Developer… | 8 | -1 |
| RT @HindleyWesley : WHY HIRE A DEDICATED #ASPD… | 5 | -1 |
| RT @JobHookup4U : Senior C# Developer - New Yo… | 8 | -1 |
| RT @OfficialHaritha : Hire .net Programmer\nht… | 8 | -1 |
So job advertisement tweets are being labelled as SPAM; that seems reasonable.
A similar analysis can also be done for the L_test set.
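Because the Test Set carries ground-truth labels, the same summary can also report each LF's empirical accuracy, which is handy when deciding which LFs to keep; a minimal sketch:

# Passing Y adds Correct, Incorrect and Emp. Acc. columns to the summary
LFAnalysis(L=L_test, lfs=lfs).lf_summary(Y=Y_test)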
Generating weak labels
Now that the LFs have been applied, this is where Snorkel really shines. Using its built-in models, it will handle any label conflicts and produce weak probabilistic labels that can finally be used to train a spam classifier.
The MajorityLabelVoter can be used to combine labels based on a majority vote:
majority_model = MajorityLabelVoter()
majority_model.predict(L=L_train)
majority_acc = majority_model.score(L=L_test, Y=Y_test, tie_break_policy="random")["accuracy"]
The LabelModel can be used to combine labels based on the algorithm described in the Snorkel paper:
label_model = LabelModel(cardinality=2, verbose=True)
label_model.fit(L_train=L_train, n_epochs=500, log_freq=100, seed=123)
label_model_acc = label_model.score(L=L_test, Y=Y_test, tie_break_policy="random")["accuracy"]
Comparing the two approaches:
print(f"{'Majority Vote Accuracy:':<25} {majority_acc * 100:.1f}%")
print(f"{'Label Model Accuracy:':<25} {label_model_acc * 100:.1f}%")
Majority Vote Accuracy: 79.3%
Label Model Accuracy: 84.5%
In this case, the Label Model approach produces the better accuracy. Choosing this as the weak labelling method, the training probabilities can then be generated with:
probs_train = label_model.predict_proba(L=L_train)
and the training Pandas DataFrame, filtered to remove tweets that received no labels, can be generated with:
df_train_filtered, probs_train_filtered = filter_unlabeled_dataframe(
X=df_train, y=probs_train, L=L_train
)
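It's worth knowing how many training tweets survive this filter, i.e. received at least one non-abstain vote; a quick check (a sketch):

# Rows where every LF abstained carry no training signal and are dropped
n_dropped = len(df_train) - len(df_train_filtered)
print(f"Dropped {n_dropped} of {len(df_train)} tweets with no LF coverage "
      f"({n_dropped / len(df_train):.1%})")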
Training the spam classifier
Firstly, the data needs to be arranged to be suitable for a scikit-learn model:
X_train = df_train_filtered['Text'].tolist()
X_test = df_test['Text'].tolist()
preds_train_filtered = probs_to_preds(probs=probs_train_filtered)
Here, only the tweet text is used, and the probabilities are converted to binary labels (preds).
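For reference, probs_to_preds simply picks the most probable class for each row (with random tie-breaking); a toy illustration:

import numpy as np
from snorkel.utils import probs_to_preds

# Each row is [P(HAM), P(SPAM)]; the output is the argmax per row
toy_probs = np.array([[0.9, 0.1], [0.2, 0.8]])
print(probs_to_preds(probs=toy_probs))  # [0 1]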
Now the model pipeline can be created:
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.pipeline import Pipeline
vectorizer = CountVectorizer(ngram_range=(1, 5))
sklearn_model = GradientBoostingClassifier(n_estimators=100, learning_rate=1.0, max_depth=1, random_state=0)
pipe = Pipeline([("vectorizer", vectorizer), ("model", sklearn_model)])
Here, the CountVectorizer feature extraction class is used to convert text into something the model can interpret, and the GradientBoostingClassifier is used as the classification model. These are then added to a Pipeline so that feature extraction can be part of the exported model.
Now the model can finally be trained:
pipe.fit(X=X_train, y=preds_train_filtered)
And the final accuracy on the test set:
print(f"Test Accuracy: {pipe.score(X=X_test, y=Y_test) * 100:.1f}%")
Test Accuracy: 85.0%
The classifier shows a 0.5 percentage point improvement over the Label Model!
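Accuracy alone can hide how the two classes behave individually, so it's also worth looking at per-class metrics; a sketch using scikit-learn's classification_report:

from sklearn.metrics import classification_report

# Per-class precision, recall and F1 on the hand-labelled test set
print(classification_report(Y_test, pipe.predict(X_test), target_names=["HAM", "SPAM"]))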
Exporting the model
Now that the model has been trained, it is ready to be exported and deployed to the bot service. As the bot is not written in Python, it is not as simple as exporting a native scikit-learn model. However, ONNX and ONNX Runtime are our saviours here, as they allow the model to be exported to ONNX format and loaded up in C# (or any other language for that matter!) with ONNX Runtime.
To export the model to ONNX:
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import StringTensorType
model_onnx = convert_sklearn(pipe, initial_types=[("input", StringTensorType([None, 1]))])
with open("../DotNetTwitterBot/spam_filter.onnx", "wb") as f:
    f.write(model_onnx.SerializeToString())
Here, the convert_sklearn function from skl2onnx is used to create an ONNX model, which is then saved to the bot's service location.
The exported model can then be visualised using the open source tool Netron.
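Before wiring the model into the bot, the exported file can be sanity-checked from Python to confirm it agrees with the scikit-learn pipeline; a minimal sketch using the onnxruntime package (the sample tweet is made up):

import numpy as np
import onnxruntime as rt

sess = rt.InferenceSession("../DotNetTwitterBot/spam_filter.onnx", providers=["CPUExecutionProvider"])

# The model expects a [None, 1] string tensor named "input"
sample = np.array([["RT @example : Hire a #dotnet developer today!!!"]])
label, probs = sess.run(None, {"input": sample})
print(label, probs)  # predicted class and per-class probabilities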
Using the model for inference
To load this model in C#, the ONNX Runtime package needs to be installed:
dotnet add package Microsoft.ML.OnnxRuntime
Then a simple class can be created to perform Spam Classification:
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
using System.Collections.Generic;
using System.Linq;

namespace DotNetTwitterBot
{
    public class SpamFilter
    {
        private readonly InferenceSession _session;
        private readonly string _inputName = "input";
        private readonly string _outputName = "output_probability";
        private readonly int _spamIndex = 1;
        private readonly float _spamThreshold = 0.5f;

        public SpamFilter(string modelPath)
        {
            _session = new InferenceSession(modelPath);
        }

        public IEnumerable<bool> Run(IEnumerable<string> textList)
        {
            var textArray = textList.ToArray();
            var count = textArray.Length;

            // Load the text into a [count, 1] string tensor, matching the model's input shape
            var input = new DenseTensor<string>(new[] { count, 1 });
            for (var i = 0; i < count; i++)
            {
                input.SetValue(i, textArray[i]);
            }

            var inputs = new List<NamedOnnxValue>()
            {
                NamedOnnxValue.CreateFromTensor(_inputName, input)
            };
            var outputs = new List<string> { _outputName };

            // Run inference and flag each tweet whose spam probability exceeds the threshold
            var results = _session.Run(inputs, outputs).ToArray()[0].Value as IList<DisposableNamedOnnxValue>;
            return results.Select(x => (x.Value as IDictionary<long, float>)[_spamIndex] > _spamThreshold);
        }
    }
}
Here, the input text list is loaded into a DenseTensor and passed to the InferenceSession to run inference; predictions are then extracted from the inference output, and a threshold is applied for what is considered spam (>0.5).
This can then be used in the existing AWS Lambda function to filter out spam tweets:
// Start of method omitted for brevity
var tweets = await SearchAsync.SearchTweets(param);
var spamFilter = new SpamFilter("spam_filter.onnx");
var isSpam = spamFilter.Run(tweets.Select(t => $"RT @{t.CreatedBy.ScreenName} : {t.Text}")).ToArray();
var tasks = tweets.Select(async (t, i) =>
{
    if (!isSpam[i])
        await t.PublishRetweetAsync();
});
await Task.WhenAll(tasks);
Now tweets are only retweeted by the bot if they are not considered spam, and hardcoded rules are no longer required!
This has been deployed to the AWS Lambda Function, and the bot is using the spam filter live right now!
Summary
In this post, Snorkel was used to create a weakly labelled training dataset of Spam/Ham tweets using Labelling Functions and Snorkel's Label Model. This allowed a scikit-learn classification model to be trained very quickly (the whole thing took me all of one day!). The model was then exported to ONNX format and loaded into a C# AWS Lambda Function using ONNX Runtime. In the future, I'd like to develop this into a proper MLOps pipeline and use a transformer model to improve accuracy - stay posted for blogs on these topics!