Merged
75 changes: 68 additions & 7 deletions medcat-plugins/embedding-linker/README.md
@@ -28,19 +28,21 @@ pip install medcat-embedding-linker

## Quick Start

### Replacing the current linker with a static embedding linker

```python
from medcat.cat import CAT
from medcat.config import Config
from medcat.components.types import CoreComponentType

from medcat_embedding_linker import EmbeddingLinking
from medcat_embedding_linker.embedding_linker import Linker as StaticEmbeddingLinker
from medcat_embedding_linker.config import EmbeddingLinking

# Load your MedCAT model
cat = CAT.load_model_pack("path/to/model_pack")

# Configure the embedding linker
cat.config.components.linking = EmbeddingLinking()
cat.config.components.linking.embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2"
cat.config.components.linking.comp_name = StaticEmbeddingLinker.name

# Recreate the pipeline to register the new linker
cat._recreate_pipe()
@@ -54,11 +56,52 @@ linker.create_embeddings()
entities = cat.get_entities("Patient presents with chest pain and dyspnea.")
```

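`get_entities` returns a nested dictionary of linked entities. A sketch of iterating over it, using a hypothetical result (field names follow MedCAT's usual entity output, but verify them against your version):

```python
# hypothetical example of what cat.get_entities(...) might return
entities = {
    "entities": {
        0: {
            "pretty_name": "Chest pain",
            "cui": "C0008031",          # assumed CUI, for illustration only
            "source_value": "chest pain",
            "acc": 0.91,
        },
    }
}

# iterate over all linked entities and print the detected span, CUI and score
for ent in entities["entities"].values():
    print(f'{ent["source_value"]} -> {ent["cui"]} (score {ent["acc"]:.2f})')
```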
### Replacing the current linker with a trainable embedding linker and training it

```python
from medcat.cat import CAT
from medcat.components.types import CoreComponentType
from medcat.stats.stats import get_stats
from medcat_embedding_linker.trainable_embedding_linker import TrainableEmbeddingLinker
from medcat_embedding_linker.config import EmbeddingLinking

# Load your MedCAT model
cat = CAT.load_model_pack("path/to/model_pack")

# Configure the embedding linker
cat.config.components.linking = EmbeddingLinking()
cat.config.components.linking.comp_name = TrainableEmbeddingLinker.name

# Recreate the pipeline to register the new linker
cat._recreate_pipe()

# Fetch the linker and generate embeddings for your concept database
# (the pipeline accessor may vary between MedCAT versions)
linker = cat._pipeline.get_component(CoreComponentType.linking)
linker.create_embeddings()

# load your data in the MedCATtrainer export (MedCATTrainerExport) format
train_projects, test_projects = your_dataset_loading_method()

# Training loop - four epochs is usually a reasonable stopping point
num_epochs = 4

# the first epoch is run outside the loop in case new concepts / names are detected
cat.trainer.train_supervised_raw(train_projects, test_size=0, nepochs=1)
# refreshing the structure here is required so that the efficient lookup lists
# are recreated for any newly detected CUIs / names
linker.refresh_structure()
linker.create_embeddings()
get_stats(cat=cat, data=test_projects, use_project_filters=False)
for i in range(num_epochs - 1):
cat.trainer.train_supervised_raw(train_projects, test_size=0, nepochs=1)
linker.create_embeddings()
get_stats(cat=cat, data=test_projects, use_project_filters=False)
```
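`train_supervised_raw` consumes data in the MedCATtrainer export format. A hypothetical minimal example of that structure (field names based on typical MedCATtrainer exports; check your own export for the exact schema):

```python
# hypothetical minimal MedCATtrainer-style export; real exports contain more fields
train_projects = {
    "projects": [{
        "name": "demo_project",
        "documents": [{
            "name": "doc_1",
            "text": "Patient presents with chest pain.",
            "annotations": [{
                "start": 22,
                "end": 32,
                "cui": "C0008031",        # assumed CUI, for illustration only
                "value": "chest pain",
            }],
        }],
    }]
}

# the annotation span should match the text it labels
doc = train_projects["projects"][0]["documents"][0]
ann = doc["annotations"][0]
assert doc["text"][ann["start"]:ann["end"]] == ann["value"]
```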

## How It Works

### Component Registration

- The embedding linker automatically registers itself as `embedding_linker` when `EmbeddingLinking` config is detected. It implements MedCAT's `AbstractEntityProvidingComponent` interface and is lazily loaded when the pipeline is created.
+ The embedding linker requires the component name (`comp_name`) of either the static or the trainable linker to be set when the `EmbeddingLinking` config is detected. It implements MedCAT's `AbstractEntityProvidingComponent` interface and is lazily loaded when the pipeline is created.

### Embedding Generation

@@ -87,8 +130,7 @@ For each detected entity:

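Conceptually, the linking step above reduces to a thresholded nearest-neighbour search over the pre-computed concept embeddings. A minimal illustrative sketch (not the plugin's actual implementation; the 0.7 default mirrors the threshold discussed in the configuration):

```python
import math

def cosine(a, b):
    # cosine similarity between two equal-length vectors
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb)

def link(mention_emb, concept_embs, cuis, threshold=0.7):
    # score every concept embedding against the mention embedding
    sims = [cosine(mention_emb, c) for c in concept_embs]
    best = max(range(len(sims)), key=sims.__getitem__)
    # only link when the best candidate clears the similarity threshold
    if sims[best] >= threshold:
        return cuis[best], sims[best]
    return None, sims[best]
```

Returning `None` below the threshold trades recall for precision, which is exactly the trade-off the similarity-threshold settings control.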
## Configuration

- ### Key Parameters
+ ### Key Parameters - Static and Trainable

```python
config.components.linking = EmbeddingLinking(
# Model settings
@@ -119,6 +161,21 @@ config.components.linking = EmbeddingLinking(
gpu_device="cuda:0" # or None for auto-detect
)
```

### Key Parameters - Trainable ONLY

```python
config.components.linking = EmbeddingLinking(
    # Training settings
    train_on_names=True,
    training_batch_size=32,
    embed_per_n_batches=0,

    # Model settings
    use_mention_attention=True,
    use_projection_layer=True,
    top_n_layers_to_unfreeze=0,
)
```

### Embedding Models

@@ -127,6 +184,7 @@ Any HuggingFace model compatible with sentence transformers will work. Popular o
- `sentence-transformers/all-MiniLM-L6-v2` (default, fast and lightweight)
- `sentence-transformers/all-mpnet-base-v2` (higher quality)
- `UFNLP/gatortron-medium` (biomedical domain)
- `abhinand/MedEmbed-small-v0.1` (often the best performing)
- `microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext`

## Advanced Usage
@@ -171,11 +229,14 @@ cat.config.components.linking.filters.cuis_exclude = {"C0000001"}
- **GPU recommended**: 10-50x faster inference with CUDA
- **Batch sizes**: Increase if you have GPU memory available
- **Model selection**: Smaller models (e.g., MiniLM) are faster but may be less accurate than larger domain-specific models
- **Unfreezing layers**: The more layers of the model you unfreeze, the more its predictive power _should_ increase. This comes at the cost of increased computation.
- **Using a projection layer**: This has no (or a slightly negative) impact with static embeddings. With trainable embeddings it yields a large performance increase (e.g. 50-75% higher recall, or more). The projection layer is always trainable, as that is its point, and its computational cost is minimal.
- **Mention attention**: This generates embeddings from the tokens of the entity mention itself rather than the entire context window of the detected entity. It should always result in a performance increase, at zero computational cost. The only case where this might not hold is when the entire context window consists solely of the detected entity, in which case performance is exactly equal to not using mention attention.
- **embed_per_n_batches**: The number of training batches to complete before re-embedding all names / CUIs. Setting this to 0 means re-embedding never occurs automatically and must be done manually. Re-embedding more often can give a slight performance increase, but it is a slow process and should be benchmarked before enabling. It is recommended to set this to 0 and re-embed manually every epoch.
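To make the mention-attention bullet concrete, here is a toy sketch (illustrative values only, not the plugin's code) contrasting pooling over the mention tokens with pooling over the whole context window:

```python
# hypothetical per-token embeddings for a 6-token window, 2 dims for readability
token_embs = [
    [0.1, 0.9], [0.2, 0.8], [0.3, 0.7],   # surrounding context tokens
    [1.0, 0.0], [0.8, 0.2],               # tokens of the detected entity (positions 3-4)
    [0.4, 0.6],
]
mention_positions = [3, 4]
dims = range(len(token_embs[0]))

# mention attention (sketch): pool ONLY the entity-mention tokens
mention_emb = [
    sum(token_embs[i][d] for i in mention_positions) / len(mention_positions)
    for d in dims
]

# plain pooling: average over the entire context window
context_emb = [
    sum(row[d] for row in token_embs) / len(token_embs)
    for d in dims
]
```

When every token in the window belongs to the mention, the two pooled vectors coincide, which matches the edge case described in the bullet above.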

## Limitations

- Does not support `prefer_frequent_concepts` or `prefer_primary_name` from the default linker (logs warnings if set)
- Training mode is not applicable (logs warning if enabled)
- Requires pre-computed embeddings before inference

## Citation
2 changes: 1 addition & 1 deletion medcat-plugins/embedding-linker/pyproject.toml
@@ -55,7 +55,7 @@ classifiers = [
# For an analysis of this field vs pip's requirements files see:
# https://packaging.python.org/discussions/install-requires-vs-requirements/
dependencies = [
- "medcat[spacy]>=2.5",
+ "medcat[spacy]>=2.7",
"transformers>=4.41.0,<5.0", # avoid major bump
"torch>=2.4.0,<3.0",
"tqdm",
@@ -1,17 +1,42 @@
from typing import Optional, Any

from medcat.config import Linking


class EmbeddingLinking(Linking):
"""The config exclusively used for the embedding linker"""

comp_name: str = "embedding_linker"
"""Changing component name"""
filter_before_disamb: bool = False
"""Filtering CUIs before disambiguation"""
train_on_names: bool = True
"""Training on names or CUIs. If True all names of all CUIs will be used to train.
If False only the preferred (or longest) name of each CUI will be used to train.
Training on names is more expensive computationally (and in RAM/VRAM), but can
lead to better performance."""
train: bool = False
"""The embedding linker never needs to be trained in its
current implementation."""
training_batch_size: int = 32
"""The size of the batch to be used for training."""
embed_per_n_batches: int = 0
"""How many batches to train on before re-embedding all the names in the context
model. This is used to control how often the context model is updated during
training."""
use_similarity_threshold: bool = True
"""Do we have a similarity threshold we care about?"""
negative_sampling_k: int = 10
"""How many negative samples to generate for each positive sample during
training."""
negative_sampling_candidate_pool_size: int = 4096
"""When generating negative samples, sample top_n candidates to consider when
sampling. Higher numbers will make training slower but can provide varied negative
samples."""
negative_sampling_temperature: float = 0.1
"""Temperature to use when generating negative samples in training. Lower
temperatures will make the sampling more focused on the highest scoring candidates,
while higher temperatures will make it more random. Must be > 0."""
use_mention_attention: bool = True
"""Improves performance (and is fun to say). Mention attention can help the model focus
on the most relevant parts of the context when making linking decisions. Will only
pool on the tokens that contain the entity mention, with no context."""
long_similarity_threshold: float = 0.0
"""Used in the inference step to choose the best CUI given the
link candidates. Testing shows a threshold of 0.7 increases precision
@@ -26,11 +51,16 @@ class EmbeddingLinking(Linking):
embedding_model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
"""Name of the embedding model. It must be downloadable from
huggingface linked from an appropriate file directory"""
use_projection_layer: bool = True
"""Projection-layer default for trainable embedding linker."""
top_n_layers_to_unfreeze: int = 0
"""LM unfreezing default for trainable embedding linker.
-1 unfreezes all LM layers, 0 freezes all LM layers,
n unfreezes the top n layers."""
max_token_length: int = 64
"""Max number of tokens to be embedded from a name.
If the max token length is changed then the linker will need to be created
- with a new config.
- """
+ with a new config."""
embedding_batch_size: int = 4096
"""How many pieces names can be embedded at once, useful when
embedding name2info names, cui2info names"""
@@ -44,5 +74,3 @@
use_ner_link_candidates: bool = True
"""Link candidates are provided by some NER steps. This will flag if
you want to trust them or not."""
- use_similarity_threshold: bool = True
- """Do we have a similarity threshold we care about?"""