{
"cells": [
{
"cell_type": "markdown",
"id": "802074a3",
"metadata": {},
"source": [
"# Searching 20-Newsgroups \n",
"\n",
"In this notebook, we will show you how to prepare the [20-Newsgroups](https://scikit-learn.org/0.19/datasets/twenty_newsgroups.html) dataset for Thematic Search. \n",
"\n",
"This dataset consists of ~18,000 posts from [Usenet](https://en.wikipedia.org/wiki/Usenet) which are sorted into newsgroups that have a hierarchical structure. For example, two newsgroups are called `rec.sport.hockey` and `rec.sport.baseball`, which are both groups under the namespace `rec.sport`, itself under the namespace `rec`. We can use this hierarchical structure to build the topic tree for our thematic search.\n",
"\n",
"To begin, we will fetch a dataset from HuggingFace that contains 20-Newsgroups with precomputed embeddings and UMAP reduced vectors:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "1d550261",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" post newsgroup \\\n",
"0 \\n\\nI am sure some bashers of Pens fans are pr... rec.sport.hockey \n",
"\n",
" embedding \\\n",
"0 [-0.04380008950829506, 0.08495834469795227, -0... \n",
"\n",
" map \n",
"0 [-0.13199903070926666, 10.1972017288208] "
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import thematic_search as ts\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"newsgroups_df = pd.read_parquet(\"hf://datasets/lmcinnes/20newsgroups_embedded/data/train-00000-of-00001.parquet\")\n",
"newsgroups_df.head(1)"
]
},
{
"cell_type": "markdown",
"id": "5455ec64",
"metadata": {},
"source": [
"## Building the Topic Tree\n",
"\n",
"The first thing we need to do is build the cluster tree that we will pass to `thematic_search.TopicDatabase`. The cluster tree should be a dictionary with entries `{vertex: [child_1, child_2, ..., child_n]}`, where each vertex is a tuple `(layer, cluster)`. \n",
"\n",
"First, we will build a tree of strings by assigning each vertex a parent according to its name structure: for example, the parent of `rec.sport.hockey` is `rec.sport`. Using this, we can form a dictionary whose keys are vertices and whose values are lists of children. \n",
"\n",
"Then we can convert this dictionary into the required form using `thematic_search.utils.convert_tree`. This takes a tree of strings and converts it to a tree of tuples. It returns the tree, `cluster_tree`, and a dictionary, `cluster_labels`, that maps clusters `(l, c)` to their string names.\n",
"\n",
"The utility `thematic_search.utils.print_tree` can be used to print the tree and check that it is correct."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "45901982",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"root\n",
"--alt\n",
"----alt.atheism\n",
"--comp\n",
"----comp.graphics\n",
"----comp.os\n",
"------comp.os.ms-windows\n",
"--------comp.os.ms-windows.misc\n",
"----comp.sys\n",
"------comp.sys.ibm\n",
"--------comp.sys.ibm.pc\n",
"----------comp.sys.ibm.pc.hardware\n",
"------comp.sys.mac\n",
"--------comp.sys.mac.hardware\n",
"----comp.windows\n",
"------comp.windows.x\n",
"--misc\n",
"----misc.forsale\n",
"--rec\n",
"----rec.autos\n",
"----rec.motorcycles\n",
"----rec.sport\n",
"------rec.sport.baseball\n",
"------rec.sport.hockey\n",
"--sci\n",
"----sci.crypt\n",
"----sci.electronics\n",
"----sci.med\n",
"----sci.space\n",
"--soc\n",
"----soc.religion\n",
"------soc.religion.christian\n",
"--talk\n",
"----talk.politics\n",
"------talk.politics.guns\n",
"------talk.politics.mideast\n",
"------talk.politics.misc\n",
"----talk.religion\n",
"------talk.religion.misc\n"
]
}
],
"source": [
"tags = np.unique(newsgroups_df['newsgroup'].to_numpy())\n",
"\n",
"from collections import defaultdict\n",
"def build_tree(paths):\n",
" tree = defaultdict(set)\n",
" tree[\"root\"]\n",
" for p in paths:\n",
" parts = p.split(\".\")\n",
" for i in range(len(parts)):\n",
" node = \".\".join(parts[:i+1])\n",
" parent = \"root\" if i == 0 else \".\".join(parts[:i])\n",
" tree[parent].add(node)\n",
" tree[node]\n",
" return {k: sorted(v) for k, v in tree.items()}\n",
"\n",
"tree = build_tree(tags)\n",
"\n",
"cluster_tree, cluster_labels = ts.utils.convert_tree(tree) \n",
"ts.utils.print_tree(cluster_tree, cluster_labels=cluster_labels)"
]
},
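{
"cell_type": "markdown",
"id": "c4a7e1f0",
"metadata": {},
"source": [
"To make the intermediate string tree concrete, here is a minimal, self-contained sketch that restates the `build_tree` helper from the cell above and runs it on three newsgroup names. The variable name `toy_tree` is ours and purely illustrative; the point is only the shape of the dictionary before it is converted to `(layer, cluster)` tuples.\n",
"\n",
"```python\n",
"from collections import defaultdict\n",
"\n",
"def build_tree(paths):\n",
"    # Map every dotted prefix to the sorted list of its direct children.\n",
"    tree = defaultdict(set)\n",
"    tree['root']  # make sure the root key exists even for empty input\n",
"    for path in paths:\n",
"        parts = path.split('.')\n",
"        for i in range(len(parts)):\n",
"            node = '.'.join(parts[:i + 1])\n",
"            parent = 'root' if i == 0 else '.'.join(parts[:i])\n",
"            tree[parent].add(node)\n",
"            tree[node]  # register the node so leaves get an empty child list\n",
"    return {k: sorted(v) for k, v in tree.items()}\n",
"\n",
"# Illustrative input: three leaf newsgroups under the `rec` namespace.\n",
"toy_tree = build_tree(['rec.autos', 'rec.sport.hockey', 'rec.sport.baseball'])\n",
"for parent, children in toy_tree.items():\n",
"    print(parent, '->', children)\n",
"```\n",
"\n",
"Because every node is registered as a key, leaves such as `rec.autos` appear with an empty list of children rather than being absent from the dictionary."
]
},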
{
"cell_type": "markdown",
"id": "ea2d1263",
"metadata": {},
"source": [
"## Building the Topic Metadata\n",
"\n",
"Next, we need to make a pandas dataframe with metadata about the topics. This dataframe's **index** must match the **keys of your cluster tree**. In our case, our cluster tree's keys are the newsgroups' names. Having these equal allows the TopicDatabase to correctly match rows of the topic dataframe with vertices of the cluster tree.\n",
"\n",
"\n",
"For the 20-newsgroups dataset, each topic has a string name that we also want to store, and we may as well also include the layer and cluster number in the metadata:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "5b3faa59",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" name layer cluster\n",
"index \n",
"root root 5 0\n",
"alt alt 1 0"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tag_to_tuple = {v:k for k,v in cluster_labels.items()}\n",
"\n",
"data = []\n",
"for tag in tree.keys():\n",
" layer, cluster = tag_to_tuple[tag]\n",
" data.append({\n",
" 'index':tag,\n",
" 'name':tag,\n",
" 'layer':layer,\n",
" 'cluster':cluster,\n",
" })\n",
"\n",
"topic_df = pd.DataFrame(data)\n",
"topic_df = topic_df.set_index('index')\n",
"topic_df.head(2)"
]
},
{
"cell_type": "markdown",
"id": "d6e4cfe8",
"metadata": {},
"source": [
"## Building the Topic Inclusion Matrices\n",
"\n",
"The final piece of information we need to construct a `TopicDatabase` is the set of matrices encoding the inclusion strength of each sample in each topic. For 20-newsgroups, each of these will be a binary matrix. \n",
"\n",
"We need a list of matrices, one for each layer of the cluster hierarchy. For 20-newsgroups, the upper topics (such as `alt`) don't have their own posts, but we will consider any post in a topic's children to also be in the topic. As this is a common situation, there is a utility `thematic_search.utils.cluster_layers_from_leaf_matrix` that only requires us to build the matrix for the layer-0 nodes.\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "b974e8c7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(18170, 20)\n",
"(18170, 11)\n",
"(18170, 5)\n",
"(18170, 1)\n",
"(18170, 1)\n"
]
}
],
"source": [
"leaves = [(l,c) for l,c in cluster_tree.keys() if l == 0]\n",
"\n",
"n_samples = len(newsgroups_df)\n",
"n_leaves = len(leaves)\n",
"\n",
"leaf_matrix = np.zeros((n_samples, n_leaves))\n",
"\n",
"for l,c in leaves:\n",
" tag_name = cluster_labels[(l, c)]\n",
" leaf_matrix[:, c] = newsgroups_df['newsgroup'].apply(\n",
" lambda x: x ==tag_name\n",
" )\n",
"\n",
"cluster_matrices = ts.utils.cluster_layers_from_leaf_matrix(\n",
" cluster_tree, leaf_matrix\n",
")\n",
"\n",
"for matrix in cluster_matrices:\n",
" print(matrix.shape)\n"
]
},
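{
"cell_type": "markdown",
"id": "9b2d6f31",
"metadata": {},
"source": [
"As a rough illustration of what propagating leaf membership upwards can look like, the sketch below builds one parent layer from a toy leaf matrix by taking, for each parent, the elementwise maximum over its children's columns. This rule is an assumption made for illustration, not a description of how `cluster_layers_from_leaf_matrix` is implemented, and the names `children_of_parent` and `parent_matrix` are hypothetical. For binary matrices the maximum is the same as a logical OR, which matches the idea that a post in any child also counts as being in the parent.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Toy setup (hypothetical): 4 posts and 3 leaf topics.\n",
"# Leaves 0 and 1 share a parent; leaf 2 has a parent of its own.\n",
"leaf_matrix = np.array([\n",
"    [1, 0, 0],\n",
"    [0, 1, 0],\n",
"    [0, 0, 1],\n",
"    [1, 0, 0],\n",
"], dtype=float)\n",
"\n",
"# children_of_parent[j] lists the leaf columns that sit under parent column j.\n",
"children_of_parent = [[0, 1], [2]]\n",
"\n",
"parent_matrix = np.zeros((leaf_matrix.shape[0], len(children_of_parent)))\n",
"for j, children in enumerate(children_of_parent):\n",
"    # A post belongs to a parent as strongly as to its strongest child.\n",
"    parent_matrix[:, j] = leaf_matrix[:, children].max(axis=1)\n",
"\n",
"print(parent_matrix)\n",
"```\n",
"\n",
"Each row of `parent_matrix` still has a single nonzero entry here because every toy post belongs to exactly one leaf; with soft memberships the same rule would give fractional strengths."
]
},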
{
"cell_type": "markdown",
"id": "3ee96fca",
"metadata": {},
"source": [
"## Initializing the Database\n",
"\n",
"Now we have everything we need to initialize the database for thematic search. We're going to load the sentence-transformers model that was used to embed the dataset, so we have it available for semantic nearest-neighbour search. Everything else was constructed in the previous steps!"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "2f59c48b",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "aa15aaeb36cb4a4faf8e6a0e11d24535",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading weights: 0%| | 0/199 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[1mMPNetModel LOAD REPORT\u001b[0m from: sentence-transformers/all-mpnet-base-v2\n",
"Key | Status | | \n",
"------------------------+------------+--+-\n",
"embeddings.position_ids | UNEXPECTED | | \n",
"\n",
"\u001b[3mNotes:\n",
"- UNEXPECTED\u001b[3m\t:can be ignored when loading from different task/architecture; not ok if you expect identical arch.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sentence_transformers import SentenceTransformer\n",
"\n",
"model = SentenceTransformer(\"sentence-transformers/all-mpnet-base-v2\")\n",
"\n",
"soft_cluster_tree = ts.SoftClusterTree(\n",
" cluster_matrices,\n",
" cluster_tree,\n",
")\n",
"topic_db = ts.TopicDatabase(\n",
" soft_cluster_tree = soft_cluster_tree,\n",
" embedding_vectors = np.stack(newsgroups_df['embedding'].values),\n",
" reduced_vectors = np.stack(newsgroups_df['map'].values),\n",
" sample_df = newsgroups_df[['post', 'newsgroup']],\n",
" topic_df = topic_df,\n",
" cluster_labels = cluster_labels,\n",
" embedding_model = model,\n",
")\n",
"topic_db"
]
},
{
"cell_type": "markdown",
"id": "795aafa4",
"metadata": {},
"source": [
"## Example Queries\n",
"\n",
"Finally, all that's left to do is query the dataset!\n",
"\n",
"First, we can query by topic name; we will query for \"documents inside the topic named `rec.sport`\"."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "e7282911",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" post newsgroup\n",
"0 \\n\\nI am sure some bashers of Pens fans are pr... rec.sport.hockey\n",
"7 \\n[stuff deleted]\\n\\nOk, here's the solution t... rec.sport.hockey\n",
"8 \\n\\n\\nYeah, it's the second one. And I believ... rec.sport.hockey\n",
"24 I don't know the exact coverage in the states.... rec.sport.hockey\n",
"33 \\nBe patient. He has a sore shoulder from cras... rec.sport.baseball\n",
"... ... ...\n",
"18132 Can someone send me ticket ordering informatio... rec.sport.baseball\n",
"18135 \\n\\n\\n\\nSaku isn't that small any longer I gue... rec.sport.hockey\n",
"18152 \\n \\n Well, I'm a Wings fan and I think the F... rec.sport.hockey\n",
"18154 \\n\\n\\n\\n\\n Anaheim. rec.sport.baseball\n",
"18161 \\nAnd won't they have to change their name to ... rec.sport.hockey\n",
"\n",
"[1927 rows x 2 columns]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"topic_db.q.topic_name('rec.sport').samples().metadata()"
]
},
{
"cell_type": "markdown",
"id": "412de198",
"metadata": {},
"source": [
"Next, let's query by semantic search. We can query for \"information about the theme of the documents semantically close to the string 'Recent advancements in space exploration'\"."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "ede42d29",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" name layer cluster\n",
"index \n",
"14 sci.space 0 14"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"topic_db.q.neighbours(\"Recent advancements in space exploration\").neighbours().theme().metadata()"
]
},
{
"cell_type": "markdown",
"id": "ef1033f2",
"metadata": {},
"source": [
"Let's do one more example of searching for a theme. We will pick out a handful of documents from `rec.sport.hockey` and `rec.sport.baseball`, and ask what their theme is:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "b3dfa607",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" post newsgroup\n",
"0 \\n\\nI am sure some bashers of Pens fans are pr... rec.sport.hockey\n",
"7 \\n[stuff deleted]\\n\\nOk, here's the solution t... rec.sport.hockey\n",
"8 \\n\\n\\nYeah, it's the second one. And I believ... rec.sport.hockey\n",
"33 \\nBe patient. He has a sore shoulder from cras... rec.sport.baseball\n",
"18132 Can someone send me ticket ordering informatio... rec.sport.baseball\n"
]
},
{
"data": {
"text/plain": [
" name layer cluster\n",
"index \n",
"26 rec.sport 1 6"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"print(topic_db.q.samples([0,7,8,33,18132]).metadata())\n",
"\n",
"topic_db.q.samples([0,7,8,33,18132]).theme().metadata()"
]
},
{
"cell_type": "markdown",
"id": "633a47f6",
"metadata": {},
"source": [
"As you probably expected, the theme is `rec.sport`; this is the least upper bound that contains the documents in our query."
]
},
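{
"cell_type": "markdown",
"id": "e58c0a7d",
"metadata": {},
"source": [
"Because newsgroup names encode the hierarchy, the least upper bound of a set of newsgroups can also be read off as their longest shared dotted prefix. The sketch below uses a hypothetical helper, `common_ancestor`, to illustrate the idea on names alone; it says nothing about how `theme()` itself is implemented.\n",
"\n",
"```python\n",
"def common_ancestor(names):\n",
"    # Split each dotted name into its components.\n",
"    split_names = [name.split('.') for name in names]\n",
"    prefix = []\n",
"    # Walk the components position by position while all names agree.\n",
"    for parts in zip(*split_names):\n",
"        if len(set(parts)) != 1:\n",
"            break\n",
"        prefix.append(parts[0])\n",
"    # An empty prefix means the names only meet at the root.\n",
"    return '.'.join(prefix) if prefix else 'root'\n",
"\n",
"print(common_ancestor(['rec.sport.hockey', 'rec.sport.baseball']))\n",
"print(common_ancestor(['rec.sport.hockey', 'sci.space']))\n",
"```\n",
"\n",
"The first call gives `rec.sport`, in line with the query result above, while the second falls back to `root` because the two names share no namespace."
]
},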
{
"cell_type": "markdown",
"id": "e840fbf6",
"metadata": {},
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}