Estrazione migliorata delle relazioni ottimizzando Llama3–8B con un set di dati sintetici creato utilizzando Llama3–70B
L'estrazione delle relazioni (RE) è il compito di estrarre le relazioni dal testo non strutturato per identificare le connessioni tra varie entità denominate. Viene eseguito insieme al riconoscimento dell'entità denominata (NER) ed è un passaggio essenziale in una pipeline di elaborazione del linguaggio naturale. Con l'avvento dei Giant Language Fashions (LLM), gli approcci tradizionali supervisionati che comportano l'etichettatura degli intervalli di entità e la classificazione delle relazioni (se presenti) tra di essi vengono migliorati o completamente sostituiti da approcci basati su LLM (1).
Llama3 è la versione principale più recente nel dominio di GenerativeAI (2). Il modello base è disponibile in due dimensioni, 8B e 70B, con un modello 400B previsto a breve. Questi modelli sono disponibili sulla piattaforma HuggingFace; Vedere (3) per dettagli. La variante 70B alimenta il nuovo sito di chat di Meta Meta.ai e mostra prestazioni paragonabili a ChatGPT. Il modello 8B è tra i più performanti della sua categoria. L'architettura di Llama3 è simile a quella di Llama2, con l'aumento delle prestazioni dovuto principalmente all'aggiornamento dei dati. Il modello viene fornito con un tokenizzatore aggiornato e una finestra di contesto estesa. È etichettato come open supply, sebbene venga rilasciata solo una piccola percentuale dei dati. Nel complesso è un modello eccellente e non vedo l'ora di provarlo.
Llama3–70B può produrre risultati sorprendenti, ma a causa delle sue dimensioni è poco pratico, proibitivamente costoso e difficile da usare sui sistemi locali. Pertanto, per sfruttare le sue capacità, abbiamo chiesto a Llama3–70B di insegnare al più piccolo Llama3–8B il compito di estrarre le relazioni dal testo non strutturato.
Nello specifico, con l'aiuto di Llama3–70B, costruiamo un set di dati supervisionato per l'ottimizzazione mirata all'estrazione delle relazioni. Utilizziamo quindi questo set di dati per ottimizzare Llama3–8B per migliorare le sue capacità di estrazione delle relazioni.
Per riprodurre il codice nel file Taccuino di Google Colab associato a questo weblog, avrai bisogno di:
- Credenziali HuggingFace (per salvare il modello messo a punto, opzionale) e accesso a Llama3, ottenibile seguendo le istruzioni di una delle schede dei modelli;
- Un libero GroqCloud account (puoi accedere con un account Google) e una chiave API corrispondente.
Per questo progetto ho utilizzato un Google Colab Professional dotato di GPU A100 e impostazione Excessive-RAM.
Iniziamo installando tutte le librerie richieste:
!pip set up -q groq
!pip set up -U speed up bitsandbytes datasets consider
!pip set up -U peft transformers trl
Mi ha fatto molto piacere notare che l'intera configurazione ha funzionato dall'inizio senza problemi di dipendenze o necessità di installazione transformers
dalla fonte, nonostante la novità del modello.
Dobbiamo anche consentire l'accesso a Goggle Colab all'unità e ai file e impostare la listing di lavoro:
# For Google Colab settings
from google.colab import userdata, drive# It will immediate for authorization
drive.mount('/content material/drive')
# Set the working listing
%cd '/content material/drive/MyDrive/postedBlogs/llama3RE'
Per coloro che desiderano caricare il modello su HuggingFace Hub, dobbiamo caricare le credenziali dell'Hub. Nel mio caso, questi sono archiviati nei segreti di Google Colab, a cui è possibile accedere tramite il pulsante chiave a sinistra. Questo passaggio è facoltativo.
# For Hugging Face Hub setting
from huggingface_hub import login# Add the HuggingFace token (ought to have WRITE entry) from Colab secrets and techniques
HF = userdata.get('HF')
# That is wanted to add the mannequin to HuggingFace
login(token=HF,add_to_git_credential=True)
Ho anche aggiunto alcune variabili di percorso per semplificare l'accesso ai file:
# Create a path variable for the info folder
data_path = '/content material/drive/MyDrive/postedBlogs/llama3RE/datas/'# Full fine-tuning dataset
sft_dataset_file = f'{data_path}sft_train_data.json'
# Information collected from the the mini-test
mini_data_path = f'{data_path}mini_data.json'
# Take a look at information containing all three outputs
all_tests_data = f'{data_path}all_tests.json'
# The adjusted coaching dataset
train_data_path = f'{data_path}sft_train_data.json'
# Create a path variable for the SFT mannequin to be saved domestically
sft_model_path = '/content material/drive/MyDrive/llama3RE/Llama3_RE/'
Ora che il nostro spazio di lavoro è configurato, possiamo passare al primo passaggio, ovvero creare un set di dati sintetico per l'attività di estrazione delle relazioni.
Sono disponibili diversi set di dati per l'estrazione delle relazioni, il più noto è il CoNLL04 set di dati. Inoltre, ci sono set di dati eccellenti come web_nlgdisponibile su HuggingFace e SciREX sviluppato da AllenAI. Tuttavia, la maggior parte di questi set di dati prevede licenze restrittive.
Ispirato al formato del web_nlg
set di dati costruiremo il nostro set di dati. Questo approccio sarà particolarmente utile se intendiamo mettere a punto un modello addestrato sul nostro set di dati. Per iniziare, abbiamo bisogno di una raccolta di frasi brevi per il nostro compito di estrazione delle relazioni. Possiamo compilare questo corpus in vari modi.
Raccogli una raccolta di frasi
Noi useremo databricks-dolly-15k, un set di dati open supply generato dai dipendenti di Databricks nel 2023. Questo set di dati è progettato per la messa a punto supervisionata e embrace quattro funzionalità: istruzione, contesto, risposta e categoria. Dopo aver analizzato le otto categorie, ho deciso di mantenere la prima frase del contesto della information_extraction
categoria. I passaggi di analisi dei dati sono descritti di seguito:
from datasets import load_dataset# Load the dataset
dataset = load_dataset("databricks/databricks-dolly-15k")
# Select the specified class from the dataset
ie_category = (e for e in dataset("practice") if e("class")=="information_extraction")
# Retain solely the context from every occasion
ie_context = (e("context") for e in ie_category)
# Break up the textual content into sentences (on the interval) and maintain the primary sentence
reduced_context = (textual content.cut up('.')(0) + '.' for textual content in ie_context)
# Retain sequences of specified lengths solely (use character size)
sampler = (e for e in reduced_context if 30 < len(e) < 170)
Il processo di selezione produce un set di dati comprendente 1.041 frasi. Dato che si tratta di un mini-progetto, non ho selezionato manualmente le frasi e, di conseguenza, alcuni esempi potrebbero non essere ideali per il nostro compito. In un progetto destinato alla produzione, selezionerei attentamente solo le frasi più applicable. Tuttavia, per gli scopi di questo progetto, questo set di dati sarà sufficiente.
Formattare i dati
Dobbiamo prima creare un messaggio di sistema che definirà il immediate di enter e istruirà il modello su come generare le risposte:
system_message = """You might be an skilled annontator.
Extract all entities and the relations between them from the next textual content.
Write the reply as a triple entity1|relationship|entitity2.
Don't add the rest.
Instance Textual content: Alice is from France.
Reply: Alice|is from|France.
"""
Poiché si tratta di una fase sperimentale, mantengo le esigenze del modello al minimo. Ho testato diversi altri immediate, inclusi alcuni che richiedevano output in formato CoNLL in cui le entità vengono categorizzate, e il modello ha funzionato abbastanza bene. Tuttavia, per semplicità, per ora ci limiteremo alle nozioni di base.
Dobbiamo anche convertire i dati in un formato conversazionale:
messages = ((
{"position": "system","content material": f"{system_message}"},
{"position": "person", "content material": e}) for e in sampler)
Il consumer e l'API Groq
Llama3 è stato rilasciato solo pochi giorni fa e la disponibilità delle opzioni API è ancora limitata. Sebbene sia disponibile un'interfaccia di chat per Llama3–70B, questo progetto richiede un'API in grado di elaborare le mie 1.000 frasi con un paio di righe di codice. L'ho trovato eccellente Video Youtube che spiega come utilizzare gratuitamente l'API GroqCloud. Per maggiori dettagli fare riferimento al video.
Solo un promemoria: dovrai accedere e recuperare una chiave API gratuita dal GroqCloud sito internet. La mia chiave API è già salvata nei segreti di Google Colab. Iniziamo inizializzando il consumer Groq:
import os
from groq import Groqgclient = Groq(
api_key=userdata.get("GROQ"),
)
Successivamente dobbiamo definire un paio di funzioni di supporto che ci consentiranno di interagire con il file Meta.ai interfaccia di chat in modo efficace (questi sono adattati da Video Youtube):
import time
from tqdm import tqdmdef process_data(immediate):
"""Ship one request and retrieve mannequin's era."""
chat_completion = gclient.chat.completions.create(
messages=immediate, # enter immediate to ship to the mannequin
mannequin="llama3-70b-8192", # in response to GroqCloud labeling
temperature=0.5, # controls variety
max_tokens=128, # max quantity tokens to generate
top_p=1, # proportion of probability weighted choices to think about
cease=None, # string that indicators to cease producing
stream=False, # if set partial messages are despatched
)
return chat_completion.selections(0).message.content material
def send_messages(messages):
"""Course of messages in batches with a pause between batches."""
batch_size = 10
solutions = ()
for i in tqdm(vary(0, len(messages), batch_size)): # batches of measurement 10
batch = messages(i:i+10) # get the subsequent batch of messages
for message in batch:
output = process_data(message)
solutions.append(output)
if i + 10 < len(messages): # verify if there are batches left
time.sleep(10) # anticipate 10 seconds
return solutions
La prima funzione process_data()
funge da wrapper per la funzione di completamento della chat del consumer Groq. La seconda funzione send_messages()
, elabora i dati in piccoli batch. Se segui il collegamento Impostazioni nella pagina del parco giochi Groq, troverai un collegamento a Limiti che descrive in dettaglio le condizioni alle quali possiamo utilizzare l'API gratuita, inclusi i limiti al numero di richieste e token generati. Per evitare di superare questi limiti, ho aggiunto un ritardo di 10 secondi dopo ogni batch di 10 messaggi, anche se nel mio caso non period strettamente necessario. Potresti voler sperimentare queste impostazioni.
Ciò che resta ora è generare i nostri dati di estrazione delle relazioni e integrarli con il set di dati iniziale:
# Information era with Llama3-70B
solutions = send_messages(messages)# Mix enter information with the generated dataset
combined_dataset = ({'textual content': person, 'gold_re': output} for person, output in zip(sampler, solutions))
Prima di procedere con la messa a punto del modello, è importante valutare le sue prestazioni su diversi campioni per determinare se la messa a punto è effettivamente necessaria.
Creazione di un set di dati di check
Selezioneremo 20 campioni dal set di dati che abbiamo appena costruito e li metteremo da parte per i check. Il resto del set di dati verrà utilizzato per la messa a punto.
import random
random.seed(17)# Choose 20 random entries
mini_data = random.pattern(combined_dataset, 20)
# Construct conversational format
parsed_mini_data = (({'position': 'system', 'content material': system_message},
{'position': 'person', 'content material': e('textual content')}) for e in mini_data)
# Create the coaching set
train_data = (merchandise for merchandise in combined_dataset if merchandise not in mini_data)
Utilizzeremo l'API GroqCloud e le utilità sopra particular, specificando mannequin=llama3-8b-8192
mentre il resto della funzione rimane invariato. In questo caso, possiamo elaborare direttamente il nostro piccolo set di dati senza preoccuparci di superare i limiti API.
Ecco un output di esempio che fornisce l'originale textual content
denotava la generazione Llama3-70B gold_re
e la generazione Llama3-8B etichettata test_re
.
shocks'
Per il set di dati completo del check, fare riferimento a Taccuino di Google Colab.
Proprio da questo esempio diventa chiaro che Llama3–8B potrebbe beneficiare di alcuni miglioramenti nelle sue capacità di estrazione delle relazioni. Lavoriamo per migliorarlo.
Utilizzeremo un arsenale completo di tecniche per assisterci, tra cui QLoRA e Flash Consideration. Non approfondirò qui i dettagli della scelta degli iperparametri, ma se sei interessato a esplorare ulteriormente, dai un'occhiata a questi ottimi riferimenti (4) E (5).
La GPU A100 supporta Flash Consideration e bfloat16 e possiede circa 40 GB di memoria, sufficienti per le nostre esigenze di messa a punto.
Preparazione del set di dati SFT
Iniziamo analizzando il set di dati in un formato conversazionale, incluso un messaggio di sistema, testo di enter e la risposta desiderata, che ricaviamo dalla generazione Llama3–70B. Lo salviamo quindi come set di dati HuggingFace:
def create_conversation(pattern):
return {
"messages": (
{"position": "system","content material": system_message},
{"position": "person", "content material": pattern("textual content")},
{"position": "assistant", "content material": pattern("gold_re")}
)
}from datasets import load_dataset, Dataset
train_dataset = Dataset.from_list(train_data)
# Remodel to conversational format
train_dataset = train_dataset.map(create_conversation,
remove_columns=train_dataset.options,
batched=False)
Scegli il Modello
model_id = "meta-llama/Meta-Llama-3-8B"
Carica il tokenizzatore
from transformers import AutoTokenizer# Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id,
use_fast=True,
trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = 'left'
# Set a most size
tokenizer.model_max_length = 512
Scegli Parametri di quantizzazione
from transformers import BitsAndBytesConfigbnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16
)
Carica il modello
from transformers import AutoModelForCausalLM
from peft import prepare_model_for_kbit_training
from trl import setup_chat_formatdevice_map = {"": torch.cuda.current_device()} if torch.cuda.is_available() else None
mannequin = AutoModelForCausalLM.from_pretrained(
model_id,
device_map=device_map,
attn_implementation="flash_attention_2",
quantization_config=bnb_config
)
mannequin, tokenizer = setup_chat_format(mannequin, tokenizer)
mannequin = prepare_model_for_kbit_training(mannequin)
Configurazione LoRA
from peft import LoraConfig# Based on Sebastian Raschka findings
peft_config = LoraConfig(
lora_alpha=128, #32
lora_dropout=0.05,
r=256, #16
bias="none",
target_modules=("q_proj", "o_proj", "gate_proj", "up_proj",
"down_proj", "k_proj", "v_proj"),
task_type="CAUSAL_LM",
)
I migliori risultati si ottengono quando si prendono di mira tutti gli strati lineari. Se i vincoli di memoria sono un problema, può essere utile optare per valori più commonplace come alpha=32 e rating=16, poiché queste impostazioni comportano un numero significativamente inferiore di parametri.
Argomenti di formazione
from transformers import TrainingArguments# Tailored from Phil Schmid blogpost
args = TrainingArguments(
output_dir=sft_model_path, # listing to avoid wasting the mannequin and repository id
num_train_epochs=2, # variety of coaching epochs
per_device_train_batch_size=4, # batch measurement per gadget throughout coaching
gradient_accumulation_steps=2, # variety of steps earlier than performing a backward/replace move
gradient_checkpointing=True, # use gradient checkpointing to avoid wasting reminiscence, use in distributed coaching
optim="adamw_8bit", # select paged_adamw_8bit if not sufficient reminiscence
logging_steps=10, # log each 10 steps
save_strategy="epoch", # save checkpoint each epoch
learning_rate=2e-4, # studying fee, primarily based on QLoRA paper
bf16=True, # use bfloat16 precision
tf32=True, # use tf32 precision
max_grad_norm=0.3, # max gradient norm primarily based on QLoRA paper
warmup_ratio=0.03, # warmup ratio primarily based on QLoRA paper
lr_scheduler_type="fixed", # use fixed studying fee scheduler
push_to_hub=True, # push mannequin to Hugging Face hub
hub_model_id="llama3-8b-sft-qlora-re",
report_to="tensorboard", # report metrics to tensorboard
)
Se scegli di salvare il modello localmente, puoi omettere gli ultimi tre parametri. Potrebbe anche essere necessario regolare il per_device_batch_size
E gradient_accumulation_steps
per evitare errori di memoria esaurita (OOM).
Inizializza il coach e addestra il modello
from trl import SFTTrainercoach = SFTTrainer(
mannequin=mannequin,
args=args,
train_dataset=sft_dataset,
peft_config=peft_config,
max_seq_length=512,
tokenizer=tokenizer,
packing=False, # True if the dataset is massive
dataset_kwargs={
"add_special_tokens": False, # the template provides the particular tokens
"append_concat_token": False, # no want so as to add extra separator token
}
)
coach.practice()
coach.save_model()
La formazione, incluso il salvataggio del modello, ha richiesto circa 10 minuti.
Cancellamo la memoria per prepararci ai check di inferenza. Se utilizzi una GPU con meno memoria e riscontri errori CUDA Out of Reminiscence (OOM), potrebbe essere necessario riavviare il runtime.
import torch
import gc
del mannequin
del tokenizer
gc.accumulate()
torch.cuda.empty_cache()
In questo passaggio finale caricheremo il modello base in mezza precisione insieme all'adattatore Peft. Per questo check ho scelto di non unire il modello con l'adattatore.
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, pipeline
import torch# HF mannequin
peft_model_id = "solanaO/llama3-8b-sft-qlora-re"
# Load Mannequin with PEFT adapter
mannequin = AutoPeftModelForCausalLM.from_pretrained(
peft_model_id,
device_map="auto",
torch_dtype=torch.float16,
offload_buffers=True
)
Successivamente, carichiamo il tokenizzatore:
okenizer = AutoTokenizer.from_pretrained(peft_model_id)tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
E costruiamo la pipeline di generazione del testo:
pipe = pipeline("text-generation", mannequin=mannequin, tokenizer=tokenizer)
Carichiamo il dataset di check, composto dai 20 campioni che abbiamo messo da parte in precedenza, e formattiamo i dati in stile conversazionale. Tuttavia, questa volta omettiamo il messaggio dell'assistente e lo formattiamo come set di dati Hugging Face:
def create_input_prompt(pattern):
return {
"messages": (
{"position": "system","content material": system_message},
{"position": "person", "content material": pattern("textual content")},
)
}from datasets import Dataset
test_dataset = Dataset.from_list(mini_data)
# Remodel to conversational format
test_dataset = test_dataset.map(create_input_prompt,
remove_columns=test_dataset.options,
batched=False)
Un check campione
Generiamo l'output di estrazione della relazione utilizzando SFT Llama3–8B e confrontiamolo con i due output precedenti su una singola istanza:
Generate the enter immediate
immediate = pipe.tokenizer.apply_chat_template(test_dataset(2)("messages")(:2),
tokenize=False,
add_generation_prompt=True)
# Generate the output
outputs = pipe(immediate,
max_new_tokens=128,
do_sample=False,
temperature=0.1,
top_k=50,
top_p=0.1,
)
# Show the outcomes
print(f"Query: {test_dataset(2)('messages')(1)('content material')}n")
print(f"Gold-RE: {test_sampler(2)('gold_re')}n")
print(f"LLama3-8B-RE: {test_sampler(2)('test_re')}n")
print(f"SFT-Llama3-8B-RE: {outputs(0)('generated_text')(len(immediate):).strip()}")
Otteniamo quanto segue:
Query: Lengthy earlier than any information of electrical energy existed, individuals have been conscious of shocks from electrical fish.Gold-RE: individuals|have been conscious of|shocks
shocks|from|electrical fish
electrical fish|had|electrical energy
LLama3-8B-RE: electrical fish|have been conscious of|shocks
SFT-Llama3-8B-RE: individuals|have been conscious of|shocks
shocks|from|electrical fish
In questo esempio, osserviamo miglioramenti significativi nelle capacità di estrazione delle relazioni di Llama3–8B attraverso la messa a punto. Nonostante il set di dati di messa a punto non sia né molto pulito né particolarmente ampio, i risultati sono impressionanti.
Per i risultati completi sul set di dati di 20 campioni, fare riferimento a Taccuino di Google Colab. Tieni presente che il check di inferenza richiede più tempo perché carichiamo il modello con mezza precisione.
In conclusione, utilizzando Llama3–70B e un set di dati disponibile, abbiamo creato con successo un set di dati sintetico che è stato poi utilizzato per mettere a punto Llama3–8B per un compito specifico. Questo processo non solo ci ha fatto familiarizzare con Llama3, ma ci ha anche permesso di applicare le semplici tecniche di Hugging Face. Abbiamo osservato che lavorare con Llama3 somiglia molto all'esperienza con Llama2, con notevoli miglioramenti che riguardano una migliore qualità dell'output e un tokenizzatore più efficace.
Per coloro che sono interessati a spingersi oltre i confini, si consideri la possibilità di sfidare il modello con compiti più complessi come la categorizzazione di entità e relazioni e l'utilizzo di queste classificazioni per costruire un grafico della conoscenza.
- Somin Wadhwa, Silvio Amir, Byron C. Wallace, Revisiting Relation Extraction within the period of Giant Language Fashions, arXiv.2305.05003 (2023).
- Meta, presentazione di Meta Llama 3: il LLM più capace disponibile fino advert oggi, 18 aprile 2024 (collegamento).
- Philipp Schmid, Omar Sanseviero, Pedro Cuenca, Youndes Belkada, Leandro von Werra, Benvenuto Llama 3: il nuovo LLM aperto di Met, 18 aprile 2024.
- Sebastiano Raschka, Suggerimenti pratici per ottimizzare gli LLM utilizzando LoRA (adattamento di basso rango)Davanti all'IA, 19 novembre 2023.
- Filippo Schmid, Come perfezionare gli LLM nel 2024 con Hugging Face, 22 gennaio 2024.
databricks-dolly-15K sulla piattaforma Hugging Face (CC BY-SA 3.0)