# Auto-IOL — Hugging Face Space by luigi12345, commit cff8897 ("Commit from mac2").
# At capture time the Space reported the runtime error: "The truth value of an array
# with more than one element is ambiguous. Use a.any() or a.all()".
import gradio as gr
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from PIL import Image
import os
from datetime import datetime
import numpy as np
import json
import math
import spaces
import logging
from functools import lru_cache
# Root logging configuration: INFO and above, each record stamped with time,
# logger name and level.
logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)
# Module-level logger used by the helpers in this file.
logger = logging.getLogger(__name__)
# Sample IOL Master report used as the default input image and as the example.
TEST_REPORT_URL = (
    "https://www.eyenews.uk.com/media/9694/eyeam18-case-in-point-figure-2.png"
)
# Markdown header rendered at the top of the app.
DESCRIPTION = (
    "\n"
    "# IOL Report Analyzer\n"
    "Upload an IOL Master biometry report for AI-powered analysis and IOL calculations.\n"
)
# Initial values for the lens-selection controls: Alcon SN60WF, emmetropic target.
DEFAULT_SETTINGS = dict(
    manufacturer="Alcon",
    lens_model="SN60WF",
    a_constant=118.7,
    target_refraction=0.0,
)
# A-constants keyed by manufacturer, then lens model. The outer keys populate the
# Manufacturer dropdown; the selected lens's value fills the A-Constant field.
IOL_CONSTANTS = {
    manufacturer_name: dict(lens_models)
    for manufacturer_name, lens_models in (
        ("Alcon", (
            ("SN60WF", 118.7),
            ("SA60AT", 118.4),
            ("MA60MA", 118.9),
            ("SN6AT", 119.1),
        )),
        ("Johnson & Johnson", (
            ("ZCB00", 119.3),
            ("PCB00", 119.3),
            ("ZA9003", 119.1),
            ("AR40e", 118.7),
        )),
        ("Zeiss", (
            ("CT LUCIA 611P", 118.6),
            ("AT LISA tri", 118.3),
            ("AT TORBI", 118.2),
        )),
        ("Bausch & Lomb", (
            ("enVista", 119.1),
            ("LI61AO", 118.0),
            ("EyeCee One", 118.9),
        )),
    )
}
# Markdown shown in the "Documentation" tab. Attributions and equations must match
# the published formulas; the proprietary ones are only approximated by this app.
FORMULAS_DOC = """
# IOL Power Calculation Formulas
## 1. SRK/T Formula
A complex vergence formula that considers:
- Corneal height
- Effective lens position (ELP)
- Retinal thickness
- Adjusted axial length
Key variables: AL, K readings, A-constant
## 2. Barrett Universal II Formula
A proprietary formula developed by Dr. Graham Barrett that considers:
- Adjusted axial length
- Lens factor
- Effective lens position
- Optional ACD measurement
Note: The exact formula is unpublished; this app uses a simplified approximation
## 3. Haigis Formula
ELP = a0 + (a1 × ACD) + (a2 × AL)
P = n / (AL - ELP) - n / (n / z - ELP)
Where:
- ELP = Effective lens position
- ACD = Anterior chamber depth
- AL = Axial length
- n = 1.336 (refractive index of aqueous and vitreous)
- z = Corneal power (index 1.3315), adjusted for the target refraction
## 4. Holladay 1 Formula
P = (1336 / (AL - ELP)) - (1336 / (1336 / K - ELP))
(simplified thin-lens vergence form)
Where:
- P = IOL power (D)
- AL = Axial length (mm)
- ELP = Effective lens position (mm)
- K = Mean keratometry (D)
## 5. EVO Formula
A formula developed by Dr. Tun Kuan Yeo that considers:
- Axial length
- Keratometry
- Anterior chamber depth
- Lens thickness
- Target refraction
Note: The exact formula is unpublished; this app uses a simplified approximation
## 6. Kane Formula
A proprietary formula developed by Dr. Jack X. Kane that considers:
- Axial length
- Keratometry
- Anterior chamber depth
- Lens thickness
- Gender (optional)
Note: The exact formula is unpublished; this app uses a simplified approximation
"""
# Measurement key (as extracted from the report) -> (table label, unit).
# Drives the rows of the measurements table in the generated report.
expected_measurements = {
    'axial_length': ('Axial Length', 'mm'),
    'k1': ('K1', 'D'),
    'k2': ('K2', 'D'),
    'acd': ('ACD', 'mm'),
    'lens_thickness': ('Lens Thickness', 'mm'),
    'white_to_white': ('WTW', 'mm'),
    'pupil_size': ('Pupil Size', 'mm'),
    'astigmatism': ('Astigmatism', 'D'),
    'axis': ('Axis', '°'),
}
# Exceptions a formula may raise on implausible or malformed inputs (zero
# keratometry, math-domain errors, None/str values, a non-string ``gender``).
# Formulas turn these into ``None`` instead of aborting the whole report.
_FORMULA_ERRORS = (ArithmeticError, AttributeError, TypeError, ValueError)


class IOLCalculator:
    """IOL power formulas.

    Every formula returns the IOL power in diopters rounded to 2 decimals, or
    ``None`` when the inputs cannot be evaluated. ``target_ref`` is the desired
    postoperative spectacle refraction in diopters (a negative, myopic target
    raises the returned power).

    SRK/T, Holladay 1 and Haigis follow their published equations. Barrett
    Universal II, EVO and Kane are proprietary. The versions here are ad-hoc
    heuristics that do not reproduce them and must not be used clinically. For
    a typical 23.5 mm eye they return roughly 45-54 D.
    """

    @staticmethod
    def srkt_formula(al, k1, k2, a_const, target_ref=0.0):
        """SRK/T formula (Retzlaff, Sanders & Kraff, 1990).

        Args:
            al: Axial length (mm).
            k1, k2: Keratometry readings (D, keratometric index 1.3375).
            a_const: IOL A-constant.
            target_ref: Target spectacle refraction (D).
        """
        try:
            k_avg = (k1 + k2) / 2
            if al <= 0 or k_avg <= 0:
                return None
            r = 337.5 / k_avg  # corneal radius (mm)
            # Long eyes use a corrected axial length for the corneal-width estimate.
            lcor = al if al <= 24.2 else -3.446 + 1.716 * al - 0.0237 * al ** 2
            cw = -5.41 + 0.58412 * lcor + 0.098 * k_avg  # computed corneal width (mm)
            # Corneal dome height; clamp the radicand so extreme inputs do not hit a math-domain error.
            h = r - math.sqrt(max(r ** 2 - cw ** 2 / 4, 0.0))
            acd_const = 0.62467 * a_const - 68.747
            acd_est = h + (acd_const - 3.336)  # estimated postoperative ELP (mm)
            l_opt = al + IOLCalculator.calculate_retinal_thickness(al)
            na, ncm1, vertex = 1.336, 0.333, 12.0
            corr = 0.001 * target_ref
            num = 1000 * na * (na * r - ncm1 * l_opt - corr * (vertex * (na * r - ncm1 * l_opt) + l_opt * r))
            den = (l_opt - acd_est) * (na * r - ncm1 * acd_est - corr * (vertex * (na * r - ncm1 * acd_est) + acd_est * r))
            return round(num / den, 2)
        except _FORMULA_ERRORS:
            return None

    @staticmethod
    def barrett_universal_2(al, k1, k2, acd=None, lcf=1.67, target_ref=0.0):
        """Heuristic stand-in for Barrett Universal II (proprietary; not reproduced).

        The coefficients are ad hoc. When ``acd`` is None it is estimated from
        AL, mean K and ``lcf``.
        """
        try:
            k_avg = (k1 + k2) / 2
            r = 337.5 / k_avg
            if acd is None:
                acd = 3.0 + (0.1 * (al - 23.5)) + (0.05 * (k_avg - 43.5)) + (0.1 * (lcf - 1.67))
            lfa = lcf * (al / 23.5) * (1 + 0.02 * abs(al - 23.5)) * (0.98 if al > 26 else 1.02 if al < 22 else 1)
            elp = (lfa * acd + 0.1 * k_avg - 3.4) * (0.97 if k_avg > 46 else 1.03 if k_avg < 42 else 1)
            iol_power = ((1.336 * 1000 / (al - elp - (0.65696 - 0.02029 * al))) - ((1.3375 - 1) / (r / 1000)) - (target_ref * 1.458)) * (0.98 if al > 25 else 1.02 if al < 22 else 1)
            return round(iol_power, 2)
        except _FORMULA_ERRORS:
            return None

    @staticmethod
    def haigis_formula(al, acd, k_avg, a0=0.87, a1=0.4, a2=0.1, target_ref=0.0):
        """Haigis formula: ELP d = a0 + a1*ACD + a2*AL, then thin-lens vergence.

        ``a1 = 0.4`` and ``a2 = 0.1`` are Haigis's default constants. ``a0`` is
        lens specific. When only an A-constant is known it is commonly
        approximated as ``0.62467 * A - 72.434``. The corneal power uses the
        Haigis index 1.3315 on the radius ``337.5 / k_avg``.
        """
        try:
            if al <= 0 or k_avg <= 0:
                return None
            n, nc, dx = 1.336, 1.3315, 0.012  # aqueous/vitreous index, corneal index, vertex (m)
            d = (a0 + a1 * acd + a2 * al) / 1000  # ELP (m)
            length = al / 1000  # axial length (m)
            radius = 337.5 / k_avg / 1000  # corneal radius (m)
            dc = (nc - 1) / radius  # corneal power (D)
            z = dc + target_ref / (1 - target_ref * dx)  # vergence behind the cornea
            return round(n / (length - d) - n / (n / z - d), 2)
        except _FORMULA_ERRORS:
            return None

    @staticmethod
    def holladay1_formula(al, k_avg, a_const, target_ref=0.0):
        """Holladay 1 formula (Holladay et al., 1988).

        The surgeon factor comes from the A-constant via SF = 0.5663*A - 65.60.
        The anatomic ACD is predicted from corneal radius and axial length, and
        0.2 mm of retinal thickness is added to AL.
        """
        try:
            if al <= 0 or k_avg <= 0:
                return None
            sf = 0.5663 * a_const - 65.60  # surgeon factor (mm)
            r = 337.5 / k_avg  # corneal radius (mm)
            rag = max(r, 7.0)  # radius used for ACD prediction, floored at 7 mm
            ag = min(12.5 * al / 23.45, 13.5)  # anterior-segment diameter, capped at 13.5 mm
            acd = 0.56 + rag - math.sqrt(rag ** 2 - ag ** 2 / 4)
            elp = acd + sf
            al_opt = al + 0.2  # add retinal thickness
            na, nc1, vertex = 1.336, 1 / 3, 12.0
            corr = 0.001 * target_ref
            num = 1000 * na * (na * r - nc1 * al_opt - corr * (vertex * (na * r - nc1 * al_opt) + al_opt * r))
            den = (al_opt - elp) * (na * r - nc1 * elp - corr * (vertex * (na * r - nc1 * elp) + elp * r))
            return round(num / den, 2)
        except _FORMULA_ERRORS:
            return None

    @staticmethod
    def get_barrett_lcf(a_const):
        """Map an A-constant to the lens factor fed to ``barrett_universal_2``."""
        return (a_const - 115.8) / 1.2

    @staticmethod
    def calculate_retinal_thickness(al):
        """SRK/T retinal-thickness correction (mm) for axial length ``al`` (mm)."""
        return 0.65696 - 0.02029 * al

    @staticmethod
    def calculate_corneal_height(k_avg):
        """Corneal height from mean K.

        NOTE(review): the 0.0725 term does not match SRK/T's computed corneal
        width (about 12 mm). ``srkt_formula`` does not use this helper.
        """
        r = 337.5 / k_avg
        return r - math.sqrt(r**2 - (0.0725**2))

    @staticmethod
    def evo_formula(al, k1, k2, acd, lt, target_ref=0.0):
        """Heuristic linear stand-in for the EVO formula (not the real equations)."""
        try:
            avg_k = (k1 + k2) / 2
            k_adj = 0.97 if avg_k > 46 else 1.03 if avg_k < 42 else 1.0
            al_adj = 0.97 if al > 26 else 1.03 if al < 22 else 1.0
            acd_adj = 0.95 if al > 25 else 1.05 if al < 22 else 1.0
            return round((al * 0.64 * al_adj) + (avg_k * 0.87 * k_adj) - (acd * 0.45 * acd_adj) - (target_ref * 1.458) + (lt * 0.20), 2)
        except _FORMULA_ERRORS:
            return None

    @staticmethod
    def kane_formula(al, k1, k2, acd, lt, gender='neutral', target_ref=0.0):
        """Heuristic linear stand-in for the Kane formula (not the real equations).

        ``gender`` is 'male', 'female' or 'neutral' (case-insensitive). Unknown
        strings fall back to neutral; a non-string returns None.
        """
        try:
            avg_k = (k1 + k2) / 2
            gender_factor = {'male': 1.12, 'female': 1.06, 'neutral': 1.09}.get(gender.lower(), 1.09)
            al_adj = 0.96 if al > 25 else 1.04 if al < 22 else 1.0
            k_adj = 0.98 if avg_k > 46 else 1.02 if avg_k < 42 else 1.0
            acd_adj = 0.95 if acd > 3.5 else 1.05 if acd < 2.8 else 1.0
            return round((al * 0.67 * al_adj) + (avg_k * 0.84 * k_adj * gender_factor) - (acd * 0.42 * acd_adj) - (lt * 0.12) - (target_ref * 1.458), 2)
        except _FORMULA_ERRORS:
            return None
def validate_measurements(eye_data):
    """Flag biometry values outside the ranges this app treats as normal.

    Axial length is expected in 20-30 mm and each K reading in 39-48 D.
    Missing or zero values are skipped. K readings are only checked when both
    K1 and K2 are present.

    Returns:
        A list of warning strings (empty when nothing is flagged).
    """
    axial_length = eye_data.get('axial_length')
    messages = []
    if axial_length and not (20 <= axial_length <= 30):
        messages.append("Axial length outside normal range")
    if not (eye_data.get('k1') and eye_data.get('k2')):
        return messages
    for message, reading in (("K1 reading outside normal range", eye_data['k1']),
                             ("K2 reading outside normal range", eye_data['k2'])):
        if not (39 <= reading <= 48):
            messages.append(message)
    return messages
def _measurement_as_float(value):
    """Coerce one extracted measurement to float, or None if it is missing or not a single number."""
    # The vision model's JSON may contain numeric strings or one-element lists.
    if isinstance(value, (list, tuple)) and len(value) == 1:
        value = value[0]
    if isinstance(value, np.ndarray) and value.size == 1:
        value = value.item()
    if value is None or isinstance(value, bool):
        return None
    try:
        number = float(value)
    except (TypeError, ValueError):
        return None
    return number if math.isfinite(number) else None


def calculate_iol_powers(eye_data, settings):
    """Run every IOL formula for one eye.

    Args:
        eye_data: Dict of extracted measurements (``axial_length``, ``k1``,
            ``k2``, optionally ``acd`` and ``lens_thickness``). Values may be
            numbers, numeric strings or one-element lists/arrays.
        settings: Dict with at least ``a_constant`` and ``target_refraction``.

    Returns:
        ``{"error": "Missing required measurements"}`` when AL/K1/K2 are not all
        present and positive. Otherwise a dict with three keys:
        ``measurements`` (the input as given), ``calculations`` (formula name
        -> diopters, for formulas that produced a value) and
        ``recommendations`` (``warnings`` and, when available,
        ``suggested_power``).
    """
    if not isinstance(eye_data, dict):
        return {"error": "Missing required measurements"}
    values = {key: _measurement_as_float(eye_data.get(key))
              for key in ('axial_length', 'k1', 'k2', 'acd', 'lens_thickness')}
    al, k1, k2 = values['axial_length'], values['k1'], values['k2']
    acd, lt = values['acd'], values['lens_thickness']
    if not all(v is not None and v > 0 for v in (al, k1, k2)):
        return {"error": "Missing required measurements"}

    a_const = settings['a_constant']
    target = settings['target_refraction']
    k_avg = (k1 + k2) / 2
    has_acd = acd is not None and acd > 0
    has_lt = lt is not None and lt > 0
    calc = IOLCalculator()

    candidates = {
        'barrett': calc.barrett_universal_2(al, k1, k2, acd=acd if has_acd else None,
                                            lcf=calc.get_barrett_lcf(a_const), target_ref=target),
        # Haigis a0 approximated from the A-constant, with Haigis's default a1/a2.
        'haigis': calc.haigis_formula(al, acd, k_avg, a0=0.62467 * a_const - 72.434, a1=0.4, a2=0.1,
                                      target_ref=target) if has_acd else None,
        'holladay1': calc.holladay1_formula(al, k_avg, a_const, target),
        'srkt': calc.srkt_formula(al, k1, k2, a_const, target),
        'evo': calc.evo_formula(al, k1, k2, acd, lt, target) if has_acd and has_lt else None,
        'kane': calc.kane_formula(al, k1, k2, acd, lt, 'neutral', target) if has_acd and has_lt else None,
    }
    calculations = {name: power for name, power in candidates.items() if power is not None}

    recommendations = {"warnings": validate_measurements(values)}
    # Only the formulas with published equations feed the suggestion. Barrett,
    # EVO and Kane are proprietary and only roughly approximated here, so they
    # are reported but not averaged (and not rescaled).
    published = [calculations[name] for name in ('srkt', 'holladay1', 'haigis') if name in calculations]
    if published:
        recommendations['suggested_power'] = round(sum(published) / len(published), 2)
    return {"measurements": eye_data, "calculations": calculations, "recommendations": recommendations}
def ensure_temp_directory():
    """Make sure the relative ``temp`` working directory exists; an existing one is left untouched."""
    folder = "temp"
    os.makedirs(folder, exist_ok=True)
@spaces.GPU
def check_environment():
    """Diagnostic helper: print the installed PyTorch version."""
    torch_version = torch.__version__
    print(f"PyTorch version: {torch_version}")
@lru_cache(maxsize=1)
def get_model(model_id):
    """Load a Qwen2-VL model in eval mode, caching the most recently requested ``model_id``.

    With CUDA the model is loaded in float16 with ``device_map="auto"``.
    Without it the model goes to the CPU in float32, because float16 is slow
    and lacks kernels for some ops on CPU.

    Raises:
        RuntimeError: if loading fails (the original exception is chained).
    """
    use_cuda = torch.cuda.is_available()
    if not use_cuda:
        logger.warning("CUDA not available, using CPU")
    try:
        model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            device_map="auto" if use_cuda else "cpu",
        )
    except Exception as e:
        error_msg = f"Error loading model: {str(e)}"
        logger.error(error_msg)
        raise RuntimeError(error_msg) from e
    return model.eval()
@lru_cache(maxsize=1)
def get_processor(model_name):
    """Load the AutoProcessor for ``model_name``, caching the most recent one.

    The cache mirrors ``get_model``. ``cleanup_model_resources`` calls
    ``get_processor.cache_clear()``, which only exists on the cached wrapper.

    Raises:
        RuntimeError: if loading fails (the original exception is chained).
    """
    try:
        return AutoProcessor.from_pretrained(model_name)
    except Exception as e:
        raise RuntimeError(f"Failed to load processor {model_name}: {str(e)}") from e
def generate_extraction_prompt():
    """Build the instruction sent to the vision model for measurement extraction.

    The reply is decoded with ``json.loads``. The prompt therefore spells out
    the full schema for both eyes, with no placeholder comments (which are not
    JSON). It also asks for bare JSON, using ``null`` for unreported values.
    """
    fields = ("axial_length", "k1", "k2", "acd", "lens_thickness", "white_to_white", "pupil_size")
    eye_block = ",\n".join(f'    "{name}": float' for name in fields)
    return (
        "Please analyze this IOL Master report and extract the measurements in JSON format "
        "with the following structure:\n"
        "{\n"
        f'  "right_eye": {{\n{eye_block}\n  }},\n'
        f'  "left_eye": {{\n{eye_block}\n  }}\n'
        "}\n"
        "Replace each float with the numeric value shown on the report (no units). "
        "Use null for any measurement that is not present. "
        "Respond with the JSON object only, without markdown code fences or any other text."
    )
def _format_measurement_value(value):
    """Render one measurement cell: lists/arrays collapse to one number (mean if several), missing values to a dash."""
    if isinstance(value, (list, tuple, np.ndarray)):
        try:
            numbers = np.asarray(value, dtype=float)
        except (TypeError, ValueError):
            return str(value)
        if numbers.size == 0:
            return "—"
        value = numbers.item() if numbers.size == 1 else float(numbers.mean())
    return "—" if value is None else str(value)


def format_eye_data(measurements, results):
    """Render one eye's measurements, formula results and warnings as Markdown.

    Args:
        measurements: Dict of extracted values for one eye (keys as in ``expected_measurements``).
        results: Output of ``calculate_iol_powers`` for that eye, or None.

    Returns:
        ``(measurements_table, calculations_table, warnings_markdown)``. If
        there is no usable data, returns the placeholders ``"No data available"``,
        ``"No calculations available"`` and ``""``.
    """
    # isinstance() first: truth-testing a multi-element numpy array raises
    # "The truth value of an array with more than one element is ambiguous".
    if not isinstance(measurements, dict) or not measurements or not results:
        return "No data available", "No calculations available", ""

    meas_lines = ["| Parameter | Value | Unit |", "|-----------|--------|------|"]
    for key, (label, unit) in expected_measurements.items():
        if key in measurements:
            meas_lines.append(f"| {label} | {_format_measurement_value(measurements[key])} | {unit} |")
    meas_text = "\n".join(meas_lines) + "\n"

    if 'error' in results:
        # calculate_iol_powers signals missing inputs this way; show it instead of an empty table.
        calc_text = f"No calculations available: {results['error']}"
    else:
        calc_lines = ["| Formula | Power |", "|----------|--------|"]
        for formula, power in results.get('calculations', {}).items():
            calc_lines.append(f"| {formula.title()} | {power} D |")
        calc_text = "\n".join(calc_lines) + "\n"

    warnings = ""
    if 'recommendations' in results and results['recommendations'].get('warnings'):
        warnings = "\n### ⚠️ Warnings\n" + "\n".join(f"- {w}" for w in results['recommendations']['warnings'])
    return meas_text, calc_text, warnings
def cleanup_resources(image_path=None):
    """Best-effort cleanup after an analysis run.

    Deletes ``image_path`` if it is given and exists, then empties the CUDA
    cache when a GPU is present. Failures are printed, never raised.
    """
    if image_path and os.path.exists(image_path):
        try:
            os.remove(image_path)
        except Exception as err:
            print(f"Failed to remove temporary file {image_path}: {err}")

    gpu_present = torch.cuda.is_available()
    if not gpu_present:
        return
    try:
        torch.cuda.empty_cache()
    except Exception as err:
        print(f"Failed to clear GPU cache: {err}")
# Upload limits and the minimum free GPU memory checked before running the model.
MAX_IMAGE_SIZE = 5 << 20  # 5 MB
SUPPORTED_FORMATS = [".jpg", ".jpeg", ".png"]
MAX_GPU_MEMORY_THRESHOLD = 2 << 30  # 2 GB
def validate_inputs(image, manufacturer, lens_model, a_constant, target_refraction):
    """Validate analysis inputs, raising on the first problem found.

    Args:
        image: Uploaded image. Either a numpy array (as delivered by
            ``gr.Image(type="numpy")``) or an object that may expose ``format``
            / ``size_bytes`` attributes.
        manufacturer: Key of ``IOL_CONSTANTS``.
        lens_model: Lens model listed under ``manufacturer``.
        a_constant: A-constant; must be a number in [110, 125].
        target_refraction: Target refraction in diopters; must be a number.

    Raises:
        ValueError: describing the first invalid input.
    """
    if image is None or (isinstance(image, np.ndarray) and image.size == 0):
        raise ValueError("No image provided")
    # Arrays must be at least 2-D with a non-zero height and width.
    if isinstance(image, np.ndarray):
        if image.ndim < 2:
            raise ValueError("Invalid image format: Image must be 2D or 3D array")
        if image.shape[0] == 0 or image.shape[1] == 0:
            raise ValueError("Invalid image dimensions")
    # Format/size checks apply only when that metadata is present. PIL sets
    # ``format`` to None for images it created itself, so guard before .lower().
    image_format = getattr(image, 'format', None)
    if isinstance(image_format, str):
        allowed = [fmt.strip('.') for fmt in SUPPORTED_FORMATS]
        if image_format.lower() not in allowed:
            raise ValueError(f"Unsupported image format. Supported formats: {', '.join(SUPPORTED_FORMATS)}")
    size_bytes = getattr(image, 'size_bytes', None)
    if size_bytes is not None and size_bytes > MAX_IMAGE_SIZE:
        raise ValueError(f"Image size exceeds maximum allowed size of {MAX_IMAGE_SIZE/1024/1024}MB")
    # Numeric inputs.
    if not isinstance(target_refraction, (int, float)):
        raise ValueError("Target refraction must be a number")
    if not isinstance(a_constant, (int, float)):
        raise ValueError("A-constant must be a number")
    if not 110 <= a_constant <= 125:
        raise ValueError("A-constant must be between 110 and 125")
    # Lens selection must match the catalogue.
    if manufacturer not in IOL_CONSTANTS:
        raise ValueError("Invalid manufacturer selected")
    if not lens_model or lens_model not in IOL_CONSTANTS[manufacturer]:
        raise ValueError("Invalid lens model selected")
def cleanup_model_resources():
    """Drop the cached model/processor and return freed GPU memory to the driver.

    Best effort: any error is logged, not raised.
    """
    import gc

    try:
        for loader in (get_model, get_processor):
            # Only lru_cache-wrapped loaders expose cache_clear(); a plain function
            # would otherwise raise AttributeError and abort the rest of the cleanup.
            clear = getattr(loader, "cache_clear", None)
            if clear is not None:
                clear()
        # Release references (including cycles) first, so empty_cache() can
        # actually hand the model's blocks back to the driver.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except Exception as e:
        logger.error(f"Error cleaning up model resources: {str(e)}")
def check_memory_availability():
    """Ensure enough free GPU memory is available for inference.

    If CUDA is present and free memory on device 0 is below
    ``MAX_GPU_MEMORY_THRESHOLD``, cached model resources are released and the
    check is repeated.

    Returns:
        True when no GPU is present or enough memory is free.

    Raises:
        RuntimeError: if free memory is still below the threshold after cleanup.
    """
    if not torch.cuda.is_available():
        return True

    def free_bytes():
        # mem_get_info reports what the driver actually has free on the device.
        # Memory used by other processes or reserved by PyTorch's caching
        # allocator therefore does not count as free (total - allocated did).
        free, _total = torch.cuda.mem_get_info(0)
        return free

    if free_bytes() < MAX_GPU_MEMORY_THRESHOLD:
        cleanup_model_resources()
        if free_bytes() < MAX_GPU_MEMORY_THRESHOLD:
            raise RuntimeError("Insufficient GPU memory available. Please try again later.")
    return True
# Hugging Face model used to read the report.
MODEL_ID = "Qwen/Qwen2-VL-7B-Instruct"


def _image_to_pil(image):
    """Convert a Gradio numpy image (or nested list) into an 8-bit PIL image."""
    array = np.asarray(image)
    if array.dtype != np.uint8:
        # Float images in [0, 1] are rescaled. Everything else is clipped to
        # 0-255, so integer or float images already in that range do not wrap around.
        if np.issubdtype(array.dtype, np.floating) and array.size and array.max() <= 1.0:
            array = array * 255
        array = np.clip(array, 0, 255).astype(np.uint8)
    return Image.fromarray(array)


def _parse_model_json(text):
    """Decode the JSON object in the model reply, ignoring surrounding prose or markdown fences."""
    start, end = text.find("{"), text.rfind("}")
    if start == -1 or end < start:
        raise json.JSONDecodeError("No JSON object found in model output", text, 0)
    return json.loads(text[start:end + 1])


def _build_report(manufacturer, lens_model, a_constant, target_refraction, right, left):
    """Assemble the Markdown report; ``right``/``left`` are ``format_eye_data`` tuples."""
    right_meas, right_calc, right_warn = right
    left_meas, left_calc, left_warn = left
    return "\n".join([
        "# IOL Analysis Report",
        f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M')}",
        "## Right Eye Analysis",
        "### 📊 Measurements",
        f"{right_meas}",
        "### 💡 IOL Calculations",
        f"{right_calc}",
        f"{right_warn}",
        "## Left Eye Analysis",
        "### 📊 Measurements",
        f"{left_meas}",
        "### 💡 IOL Calculations",
        f"{left_calc}",
        f"{left_warn}",
        "-------------------",
        "### Device Settings",
        "| Parameter | Value |",
        "|:----------|:-------|",
        f"| Manufacturer | {manufacturer} |",
        f"| Lens Model | {lens_model} |",
        f"| A-Constant | {a_constant} |",
        f"| Target Refraction | {target_refraction} D |",
        "-------------------",
        "*Report generated by [Auto-IOL AI Tool](https://luigi12345-auto-iol.hf.space)*",
    ])


@spaces.GPU
def run_analysis(image, manufacturer, lens_model, a_constant, target_refraction):
    """Gradio handler: read an IOL Master report image and compute IOL powers.

    Always returns a Markdown string, either the report or an error message,
    and never raises. The temporary image written under ``temp/`` is always
    removed.
    """
    import uuid

    image_path = None
    try:
        # Cheap input validation first, before touching the GPU.
        if not isinstance(a_constant, (int, float)) or not isinstance(target_refraction, (int, float)):
            return "Error: A-Constant and Target Refraction must be numbers"
        if not 110 <= a_constant <= 122:  # also rejects NaN
            return "Error: A-Constant must be between 110 and 122"
        if image is None or (isinstance(image, np.ndarray) and image.size == 0):
            return "Please upload an image"
        if not isinstance(image, (np.ndarray, list)):
            return "Error: Invalid image format"
        try:
            check_memory_availability()
        except RuntimeError:
            return "Error: Insufficient GPU memory"

        settings = {
            "manufacturer": manufacturer,
            "lens_model": lens_model,
            "a_constant": a_constant,
            "target_refraction": target_refraction,
        }

        os.makedirs("temp", exist_ok=True)
        # Random suffix keeps two requests within the same second from sharing a file.
        image_path = os.path.join(
            "temp", f"image_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}.png"
        )
        try:
            _image_to_pil(image).save(image_path)
        except Exception as e:
            return f"Error handling image: {str(e)}"

        # Cached loaders: the 7B model is loaded once, not on every request.
        try:
            model = get_model(MODEL_ID)
            processor = get_processor(MODEL_ID)
        except Exception as e:
            return f"Error: Model initialization failed - {str(e)}"

        messages = [{
            "role": "user",
            "content": [
                {"type": "image", "image": image_path},
                {"type": "text", "text": generate_extraction_prompt()},
            ],
        }]
        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(text=[text], images=image_inputs, videos=video_inputs, return_tensors="pt").to(model.device)
        # Greedy decoding; temperature is ignored when do_sample=False, so it is not passed.
        generated_ids = model.generate(**inputs, max_new_tokens=1024, do_sample=False)
        ai_output = processor.batch_decode(
            [generated_ids[0][len(inputs.input_ids[0]):]], skip_special_tokens=True
        )[0]

        try:
            measurements = _parse_model_json(ai_output)
        except json.JSONDecodeError:
            logger.warning("Could not parse model output: %r", ai_output)
            return "Error: Could not parse AI output"
        if not isinstance(measurements, dict):
            return "Error: Invalid measurements format"

        try:
            eye_sections = []
            for side in ("right_eye", "left_eye"):
                eye = measurements.get(side)
                results = calculate_iol_powers(eye, settings) if eye else None
                eye_sections.append(format_eye_data(eye, results))
            return _build_report(manufacturer, lens_model, a_constant, target_refraction, *eye_sections)
        except Exception as e:
            logger.exception("IOL calculation failed")
            return f"Error during calculation: {str(e)}"
    except Exception as e:
        logger.exception("Unexpected error in run_analysis")
        return f"Unexpected error: {str(e)}"
    finally:
        if image_path and os.path.exists(image_path):
            try:
                os.remove(image_path)
            except OSError as e:
                logger.warning("Failed to remove temporary file %s: %s", image_path, e)
def _lens_model_choices(selected_manufacturer):
    """Dropdown update listing the selected manufacturer's lens models, selecting the first."""
    # Returning a bare list would set the Dropdown's *value* to that list rather
    # than its choices; the A-constant lookup would then fail on an unhashable key.
    models = list(IOL_CONSTANTS.get(selected_manufacturer, {}))
    return gr.update(choices=models, value=models[0] if models else None)


def _lens_a_constant(selected_manufacturer, selected_model):
    """A-constant of the selected lens, or no update while manufacturer and lens are out of sync."""
    a_const = IOL_CONSTANTS.get(selected_manufacturer, {}).get(selected_model)
    return gr.update() if a_const is None else a_const


with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="indigo")) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Tabs() as tabs:
        with gr.Tab("📊 Analysis", id="analysis"):
            with gr.Row():
                with gr.Column(scale=1):
                    input_img = gr.Image(label="IOL Report Image", type="numpy", value=TEST_REPORT_URL, height=350, width="100%", sources=["upload", "webcam", "clipboard"])
                    with gr.Group():
                        with gr.Row():
                            manufacturer = gr.Dropdown(choices=list(IOL_CONSTANTS.keys()), value=DEFAULT_SETTINGS["manufacturer"], label="Manufacturer", info="Select IOL manufacturer", scale=1)
                            lens_model = gr.Dropdown(choices=list(IOL_CONSTANTS[DEFAULT_SETTINGS["manufacturer"]].keys()), value=DEFAULT_SETTINGS["lens_model"], label="Lens Model", info="Select specific lens model", scale=1)
                        with gr.Row():
                            a_constant = gr.Number(value=DEFAULT_SETTINGS["a_constant"], label="A-Constant", info="Auto-updated based on lens selection", interactive=False, scale=1)
                            target_refraction = gr.Number(value=DEFAULT_SETTINGS["target_refraction"], label="Target Refraction (D)", info="Desired postoperative refraction", minimum=-10, maximum=10, step=0.25, scale=1)
                    analyze_btn = gr.Button("Analyze Report", variant="primary", size="lg", icon="🔍")
                with gr.Column(scale=2):
                    output_text = gr.Markdown(label="Analysis Results", show_label=True, value="Upload an IOL report image and click 'Analyze Report' to begin...")
        with gr.Tab("📚 Documentation", id="docs"):
            gr.Markdown(FORMULAS_DOC)
    analyze_btn.click(fn=run_analysis, inputs=[input_img, manufacturer, lens_model, a_constant, target_refraction], outputs=output_text, api_name="analyze", show_progress="full")
    # cache_examples=True would run the 7B vision model on the example while the
    # app is being built, so any failure there blocks startup. Run it on demand instead.
    gr.Examples([[TEST_REPORT_URL, DEFAULT_SETTINGS["manufacturer"], DEFAULT_SETTINGS["lens_model"], DEFAULT_SETTINGS["a_constant"], DEFAULT_SETTINGS["target_refraction"]]], inputs=[input_img, manufacturer, lens_model, a_constant, target_refraction], outputs=output_text, fn=run_analysis, cache_examples=False, label="Example Report")
    manufacturer.change(fn=_lens_model_choices, inputs=[manufacturer], outputs=[lens_model])
    lens_model.change(fn=_lens_a_constant, inputs=[manufacturer, lens_model], outputs=[a_constant])
if __name__ == "__main__":
    # Single-slot queue with the API closed; listen on all interfaces, port 7860.
    launch_options = dict(
        debug=True,
        show_error=True,
        share=False,
        server_name="0.0.0.0",
        server_port=7860,
    )
    demo.queue(max_size=1, api_open=False)
    demo.launch(**launch_options)