Legal practitioners and organizations worldwide must review extensive legal documents in multiple languages, a process that is time-consuming and error-prone. The goal is to design a Transformer-based multilingual summarization model that generates concise, accurate summaries of legal documents in multiple languages (e.g., English, Spanish, French, Mandarin) while preserving key legal semantics.
Type of Data:
Size of Data: At least 500,000 legal documents covering various languages (English, French, Spanish, etc.).
Sources:
Preprocessing Steps:
Base Model: Use mBART (Multilingual Bidirectional and Auto-Regressive Transformer) or mT5 for fine-tuning. These models are pre-trained for multilingual tasks and handle both encoding and decoding.
Architecture Details:
Optimization Techniques:
Metrics:
Validation:
Week 1-2:
Week 3-4:
Week 5-6:
Week 7-8:
Week 9:
Week 10:
"""
Project: Multilingual Legal Document Summarization using Transformer (mBART)
Author: Jahaziel Titular
Description:
This script fine-tunes the mBART model for multilingual summarization of legal documents.
It uses Hugging Face's Transformers library, along with tokenizers and datasets for preprocessing, training, and evaluation.
"""
# ----------------------------
# Step 1: Import Required Libraries
# ----------------------------
import os

import evaluate
import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from datasets import load_dataset, DatasetDict
from nltk.translate.bleu_score import corpus_bleu
from rouge_score import rouge_scorer
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from transformers import AdamW
from transformers import DataCollatorForSeq2Seq
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
# ----------------------------
# Step 2: Load and Preprocess Data
# ----------------------------
# Assuming we have a multilingual dataset of legal documents with text and summary fields
# Replace with your actual data source or dataset
def load_legal_data(file_path):
    """
    Loads the multilingual legal dataset from a CSV file.

    Args:
        file_path (str): Path to the dataset file (CSV format).

    Returns:
        pandas.DataFrame: DataFrame with at least the 'text', 'summary',
        and 'language' columns.

    Raises:
        ValueError: If any of the required columns is missing from the file.
    """
    required_columns = ("text", "summary", "language")
    df = pd.read_csv(file_path)
    # Fail fast with a readable message instead of letting a later step
    # (stratified split, tokenization) die on a bare KeyError.
    missing = [column for column in required_columns if column not in df.columns]
    if missing:
        raise ValueError(
            f"Dataset at {file_path!r} is missing required column(s): {', '.join(missing)}"
        )
    print(f"Loaded dataset with {len(df)} rows.")
    return df
# Example: Load dataset (Replace 'path_to_dataset.csv' with your dataset file)
file_path = 'path_to_dataset.csv'  # CSV file with text, summary, and language columns
legal_data = load_legal_data(file_path)
# Hold out 20% for validation; stratifying on 'language' keeps the language
# mix the same in both splits.
train_data, val_data = train_test_split(
    legal_data, test_size=0.2, random_state=42, stratify=legal_data['language']
)
# Convert the splits into Hugging Face Datasets. `from_pandas` is a Dataset
# constructor (DatasetDict does not provide it). preserve_index=False keeps the
# shuffled pandas index from being stored as an extra column.
train_dataset = Dataset.from_pandas(train_data, preserve_index=False)
val_dataset = Dataset.from_pandas(val_data, preserve_index=False)
# ----------------------------
# Step 3: Load mBART Tokenizer and Model
# ----------------------------
# The same checkpoint name is passed to both from_pretrained calls below, so the
# tokenizer and the model always come from the same source.
model_name = "facebook/mbart-large-50" # Pretrained multilingual mBART model
# Module-level tokenizer: read by preprocess_function and handed to the trainer.
tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
# Module-level model: the object that is fine-tuned and later saved.
model = MBartForConditionalGeneration.from_pretrained(model_name)
# Short language code -> mBART-50 language token.
# Covers the same languages the project targets for summarization
# (English, French, Spanish, German, Simplified Mandarin).
target_language_token = {
    "en": "en_XX",
    "fr": "fr_XX",
    "es": "es_XX",
    "de": "de_DE",
    "zh": "zh_CN",
}
# Preprocessing Function for Tokenization
def preprocess_function(examples, source_lang="en", target_lang="en"):
    """
    Tokenizes a batch of examples for the mBART model.

    Args:
        examples (dict): Batch as passed by `Dataset.map(..., batched=True)`:
            a dict mapping column names to lists, with 'text' and 'summary' columns.
        source_lang (str): Source language, either a short code listed in
            `target_language_token` (e.g., 'en') or an mBART language token
            (e.g., 'en_XX').
        target_lang (str): Target language, in the same format as `source_lang`.

    Returns:
        dict: Tokenizer output for the documents plus a 'labels' entry holding
        the token ids of the summaries. Sequences are truncated but not padded.
    """
    # Tell the tokenizer which language tokens to attach to inputs and labels;
    # short codes are translated, anything else is passed through unchanged.
    tokenizer.src_lang = target_language_token.get(source_lang, source_lang)
    tokenizer.tgt_lang = target_language_token.get(target_lang, target_lang)

    # A batched map hands over a dict of columns, so read each column directly
    # (iterating the dict would yield column names, not rows).
    inputs = examples["text"]
    targets = examples["summary"]

    # No padding here: the seq2seq data collator pads each batch dynamically and
    # masks label padding so it does not contribute to the loss.
    model_inputs = tokenizer(inputs, max_length=1024, truncation=True)
    labels = tokenizer(text_target=targets, max_length=256, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    # decoder_input_ids are deliberately not set: the model builds them by
    # shifting the labels, whereas passing the labels unshifted would let the
    # decoder see the token it is supposed to predict.
    return model_inputs
# Tokenize both splits with identical settings.
# NOTE(review): both language codes are fixed to English here even though the
# data carries a 'language' column — confirm this is intended.
language_kwargs = {"source_lang": "en", "target_lang": "en"}
train_dataset = train_dataset.map(preprocess_function, batched=True, fn_kwargs=language_kwargs)
val_dataset = val_dataset.map(preprocess_function, batched=True, fn_kwargs=language_kwargs)
# ----------------------------
# Step 4: Set Up Training Arguments and Trainer
# ----------------------------
# Define Training Arguments
training_args = Seq2SeqTrainingArguments(
    output_dir="./results",  # Save results to this directory
    # NOTE(review): newer transformers releases call this argument
    # `eval_strategy` — rename it if the installed version rejects this one.
    evaluation_strategy="epoch",  # Evaluate at the end of every epoch
    save_strategy="epoch",  # Save the model every epoch
    learning_rate=5e-5,  # Initial learning rate
    per_device_train_batch_size=8,  # Batch size per device during training
    per_device_eval_batch_size=8,  # Batch size for evaluation
    weight_decay=0.01,  # Weight decay for regularization
    save_total_limit=3,  # Keep only last 3 model checkpoints
    num_train_epochs=3,  # Number of training epochs
    predict_with_generate=True,  # Use `generate()` for predictions
    generation_max_length=256,  # Same 256-token cap used when tokenizing the reference summaries
    logging_dir="./logs",  # Log directory
    fp16=torch.cuda.is_available(),  # Mixed precision needs a CUDA GPU; stay in fp32 otherwise
)
# Collator that pads each batch for the seq2seq model
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
# Gather everything the trainer needs in one place, then build it
trainer_config = dict(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)
trainer = Seq2SeqTrainer(**trainer_config)
# ----------------------------
# Step 5: Train the Model
# ----------------------------
print("Starting model training...")
trainer.train()
# Write the model first, then the tokenizer, into the same export directory
export_dir = "./trained_model"
for artifact in (model, tokenizer):
    artifact.save_pretrained(export_dir)
# ----------------------------
# Step 6: Evaluate the Model
# ----------------------------
# Generate predictions for the validation set
print("Evaluating the model...")
predictions = trainer.predict(val_dataset)
generated_ids = predictions.predictions
# Some configurations return a tuple whose first element holds the token ids.
if isinstance(generated_ids, tuple):
    generated_ids = generated_ids[0]
# Batches of different lengths are concatenated with -100 as filler, which is
# not a valid token id; swap it for the pad token before decoding.
generated_ids = np.where(generated_ids != -100, generated_ids, tokenizer.pad_token_id)
predicted_summaries = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
# val_data is a DataFrame: iterating it yields column names, so read the
# 'summary' column explicitly. Row order matches val_dataset, which was built
# from this frame.
reference_summaries = val_data["summary"].astype(str).tolist()
# Metrics: ROUGE and BLEU
rouge = evaluate.load("rouge")
rouge_results = rouge.compute(
    predictions=predicted_summaries,
    references=reference_summaries,
)
print("ROUGE Scores:", rouge_results)
# Compute BLEU score (one reference per hypothesis, whitespace tokenization).
# NOTE(review): whitespace splitting does not segment languages written without
# spaces (e.g., Mandarin) — confirm whether a language-aware tokenizer is needed.
bleu_score = corpus_bleu(
    [[reference.split()] for reference in reference_summaries],
    [summary.split() for summary in predicted_summaries],
)
print(f"BLEU Score: {bleu_score:.2f}")
# ----------------------------
# Step 7: Package and Deploy
# ----------------------------
# Create a simple REST API using FastAPI
from fastapi import FastAPI
from pydantic import BaseModel
# FastAPI application object; the /summarize route is registered on it.
app = FastAPI()


# Request body schema for the /summarize endpoint (validated by pydantic).
class SummarizationRequest(BaseModel):
    text: str  # Legal document text to summarize
    language: str  # Language associated with the request
@app.post("/summarize")
def summarize(request: SummarizationRequest):
    """
    Summarizes a given legal document in the specified language.

    Args:
        request (SummarizationRequest): Request object containing text and language.
            `language` may be a short code from `target_language_token` (e.g., 'en')
            or the corresponding mBART language token (e.g., 'en_XX').

    Returns:
        dict: Summarized text under the 'summary' key.

    Raises:
        HTTPException: With status 400 when the requested language is not supported.
    """
    # Imported here so the block stays self-contained; fastapi is already a
    # dependency of this file.
    from fastapi import HTTPException

    lang_token = target_language_token.get(request.language, request.language)
    if lang_token not in target_language_token.values():
        supported = ", ".join(sorted(target_language_token))
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported language {request.language!r}. Supported: {supported}",
        )

    # NOTE(review): src_lang is state on the shared tokenizer; concurrent requests
    # in different languages could interfere — confirm the deployment model.
    tokenizer.src_lang = lang_token
    inputs = tokenizer(request.text, return_tensors="pt", truncation=True, max_length=1024)
    # After training the model may live on a GPU; inputs must be on the same device.
    inputs = inputs.to(model.device)
    summary_ids = model.generate(
        **inputs,  # passes the attention mask along with the input ids
        max_length=256,
        num_beams=4,
        # mBART-50 selects the output language by forcing its language token as
        # the first generated token.
        forced_bos_token_id=tokenizer.convert_tokens_to_ids(lang_token),
    )
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return {"summary": summary}
# Run the API: `uvicorn script_name:app --reload`
"""
Enhanced Streamlit App: Multilingual Legal Document Summarization
Author: [Your Name]
Description:
This Streamlit app summarizes legal documents in multiple languages using a fine-tuned mBART model.
Features include:
- File upload support (.txt and .pdf)
- Length control for summary
- Multiple model selection
- Custom theming
"""
# ----------------------------
# Step 1: Import Required Libraries
# ----------------------------
import streamlit as st
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
import PyPDF2 # For PDF parsing
import os
# ----------------------------
# Step 2: Load Available Models and Tokenizers
# ----------------------------
@st.cache_resource  # Reuse the loaded objects instead of reloading on every rerun
def load_model_and_tokenizer(model_name):
    """
    Load an mBART model together with its matching tokenizer.

    Args:
        model_name (str): Local path or Hugging Face model name to load from.

    Returns:
        tuple: `(model, tokenizer)`, in that order.
    """
    loaded_tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
    loaded_model = MBartForConditionalGeneration.from_pretrained(model_name)
    return loaded_model, loaded_tokenizer
# Display label -> checkpoint location (extend this mapping to offer more models)
model_options = {
    "Fine-tuned mBART": "./trained_model",  # Path to your fine-tuned mBART model
    "Pretrained mBART-50": "facebook/mbart-large-50",  # Hugging Face pre-trained model
}
# Sidebar dropdown decides which checkpoint gets loaded
selected_model = st.sidebar.selectbox("Select a Model:", list(model_options))
selected_checkpoint = model_options[selected_model]
model, tokenizer = load_model_and_tokenizer(selected_checkpoint)
# ----------------------------
# Step 3: Streamlit App UI
# ----------------------------
# App Title and Description
# Introductory markdown shown under the page title
app_intro = """
This tool summarizes legal documents in multiple languages using a Transformer-based model (mBART).
Features include:
- **File Upload Support:** Summarize `.txt` or `.pdf` legal documents.
- **Language Selection:** Generate summaries in English, Spanish, French, etc.
- **Length Control:** Adjust the length of the summary.
"""
st.title("📜 Multilingual Legal Document Summarization")
st.markdown(app_intro)
# ----------------------------
# Step 4: Input Section
# ----------------------------
st.header("Upload or Paste Your Document")
# Option to upload a file
uploaded_file = st.file_uploader("Upload a `.txt` or `.pdf` file", type=["txt", "pdf"])
# Option to paste the text directly
document_text = st.text_area("Or Paste Your Document Here:", height=300)
# An uploaded file takes precedence over pasted text
if uploaded_file:
    # The reported MIME type is not always reliable, so also accept the file
    # extension; otherwise a valid upload could be silently ignored.
    file_name = uploaded_file.name.lower()
    if uploaded_file.type == "application/pdf" or file_name.endswith(".pdf"):
        # Read PDF file using PyPDF2
        pdf_reader = PyPDF2.PdfReader(uploaded_file)
        # `or ""` guards against a page that yields no text; joining with a
        # newline keeps words at page boundaries from running together.
        document_text = "\n".join(page.extract_text() or "" for page in pdf_reader.pages)
    elif uploaded_file.type == "text/plain" or file_name.endswith(".txt"):
        # Read text file; utf-8-sig drops a leading BOM, and undecodable bytes
        # are replaced instead of crashing the app.
        document_text = uploaded_file.read().decode("utf-8-sig", errors="replace")
# Warn if both file and text are empty
if not document_text.strip():
    st.warning("Please upload a file or paste a document to summarize.")
# Language selection: display name -> mBART-50 language token
language_map = {
    "English": "en_XX",
    "Spanish": "es_XX",
    "French": "fr_XX",
    "German": "de_DE",
    "Mandarin (Simplified)": "zh_CN",
}
selected_language = st.selectbox("Select Target Language for Summarization:", list(language_map.keys()))
# Length control slider. The value is passed to generate() as `max_length`,
# which counts tokens rather than words, so the label states the real unit.
summary_length = st.slider(
    "Select Maximum Summary Length (in tokens):",
    min_value=50,
    max_value=300,
    value=150,
    step=10,
)
# ----------------------------
# Step 5: Summarization Logic
# ----------------------------
if st.button("Summarize"):
    if document_text.strip():
        # Language token attached to the input document
        tokenizer.src_lang = "en_XX"  # Assuming input document is in English
        target_lang_token = language_map[selected_language]
        # Preprocess the input (truncated to the 1024-token limit)
        inputs = tokenizer(document_text, return_tensors="pt", max_length=1024, truncation=True)
        # Generate summary
        with st.spinner("Generating summary..."):
            summary_ids = model.generate(
                **inputs,  # passes the attention mask along with the input ids
                max_length=summary_length,
                num_beams=4,
                early_stopping=True,
                # mBART-50 selects the output language by forcing its language
                # token as the first generated token; the decoder start token
                # itself is left at the model's default.
                forced_bos_token_id=tokenizer.convert_tokens_to_ids(target_lang_token),
            )
        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        # Display the summary
        st.subheader("Summarized Document:")
        st.write(summary)
    else:
        st.warning("Please upload or paste a document to summarize.")
# ----------------------------
# Step 6: Footer
# ----------------------------
# Footer: horizontal rule followed by the credits line
footer_text = """
*Created by [Your Name]. Built with 🤗 Transformers, Streamlit, and PyPDF2.*
"""
st.markdown("---")
st.markdown(footer_text)
# Sidebar theme section with an optional dark-mode stylesheet
dark_mode_css = """
<style>
body {
background-color: #333;
color: white;
}
</style>
"""
st.sidebar.markdown("---")
st.sidebar.write("### Customize Theme:")
dark_mode_enabled = st.sidebar.checkbox("Enable Dark Mode")
if dark_mode_enabled:
    st.markdown(dark_mode_css, unsafe_allow_html=True)