# Similarity ratio minus human features
# (exported notebook cell marker: Python)
from typing import List
import numpy as np
import pandas as pd
import plotly.express as px
from plotly.subplots import make_subplots
from sklearn.metrics import roc_curve, roc_auc_score
import torch
from prompt_playground.actionclip import (
images_features_df,
get_images_features,
get_all_text_features,
get_text_features,
VARIATION_NAMES,
)
from prompt_playground.tensor_utils import normalize_features
# (exported notebook cell marker: Python)
def get_similarities(
    variation: str,
    texts_positive: List[str],
    texts_negative: List[str],
):
    """Compute the image/text similarity matrix for one model variation.

    Positive prompts are embedded as-is. Negative prompts have the "human"
    text embedding subtracted before re-normalization, so each negative
    column scores similarity against the non-"human" component of that
    prompt.

    Returns a matrix of shape (num_images, len(texts_positive) +
    len(texts_negative)): normalized image features times stacked text
    features, positive columns first.
    """
    assert texts_positive
    assert texts_negative
    assert variation in VARIATION_NAMES

    img_feats = normalize_features(get_images_features(variation))
    pos_feats = get_all_text_features(texts_positive, variation, True)
    neg_feats = get_all_text_features(texts_negative, variation, False)
    human_feat = get_all_text_features(["human"], variation, False).squeeze(
        dim=0
    )
    # NOTE(review): in-place subtraction mutates the tensor returned by
    # get_all_text_features — confirm that helper does not cache or share it.
    neg_feats -= human_feat
    text_feats = torch.vstack((pos_feats, normalize_features(neg_feats)))
    return img_feats @ text_feats.T
def infer(TEXTS: List[str], TEXT_CLASSIFICATIONS: List[bool], VARIATION: str):
    """Run the prompt-set experiment for one model variation and plot results.

    Builds a per-(clip, text) similarity table and shows three figures:
    1. scatter + split-violin plots of similarities, faceted by ground truth,
    2. per-clip ratio of mean positive-prompt to mean negative-prompt
       similarity,
    3. ROC curve of that ratio against ground truth (AUC in the title).

    Args:
        TEXTS: all prompt strings.
        TEXT_CLASSIFICATIONS: parallel to TEXTS; True marks a positive prompt.
        VARIATION: model variation name (expected to be in VARIATION_NAMES;
            validated inside get_similarities).
    """
    texts_positive = [t for i, t in enumerate(TEXTS) if TEXT_CLASSIFICATIONS[i]]
    texts_negative = [t for i, t in enumerate(TEXTS) if not TEXT_CLASSIFICATIONS[i]]
    similarities = get_similarities(
        VARIATION,
        texts_positive,
        texts_negative,
    )
    # Ground-truth label ("Alarm") per clip; the series index is the clip id.
    alarms_series = images_features_df(VARIATION)["Alarm"]
    # One row per (clip, text). Similarity columns come out of
    # get_similarities ordered positives-then-negatives, which is why the
    # inner zip pairs texts_positive + texts_negative with True/False labels.
    # (The previous `.rename(columns={"classification": "y_predict"})` was a
    # no-op — no "classification" column ever exists — and has been removed.)
    df = (
        pd.DataFrame(
            [
                [clip, alarms_series[clip], text, y_predict, clip_similarity]
                for clip, clip_similarities in zip(alarms_series.index, similarities)
                for text, y_predict, clip_similarity in zip(
                    texts_positive + texts_negative,
                    [True] * len(texts_positive) + [False] * len(texts_negative),
                    clip_similarities,
                )
            ],
            columns=["clip", "y_true", "text", "y_predict", "similarity"],
        )
        .sort_values(["clip", "y_true", "y_predict", "text"])
        .reset_index(drop=True)
    )
    unique_y_true = df["y_true"].unique()
    # Two subplot columns per y_true facet: a per-text scatter and a split
    # violin (hence the doubled subplot titles).
    fig = make_subplots(
        rows=1,
        cols=4,
        column_widths=[0.3, 0.2, 0.3, 0.2],
        shared_yaxes=True,
        y_title="similarity",
        subplot_titles=[
            f"y_true={y_true}" for y_true in unique_y_true for _ in range(2)
        ],
    )
    # group by
    # 1. y_true (facet)
    # 2. y_predict / text class (color)
    for i, y_true in enumerate(unique_y_true):
        facet_df = df[df["y_true"] == y_true]
        for y_predict, class_color in zip(
            sorted(facet_df["y_predict"].unique()), ["CornflowerBlue", "Tomato"]
        ):
            facet_color_df = facet_df[facet_df["y_predict"] == y_predict]
            # Positive class on the right half of the violin, negative left.
            violin_side = "positive" if y_predict else "negative"
            fig.add_scatter(
                x=facet_color_df["text"],
                y=facet_color_df["similarity"],
                marker=dict(color=class_color, size=3),
                hovertext=facet_color_df["clip"],
                mode="markers",
                name=f"y_predict={str(y_predict)}",
                legendgroup=f"y_true={str(y_true)}",
                legendgrouptitle=dict(text=f"y_true={str(y_true)}"),
                row=1,
                col=i * 2 + 1,
            )
            fig.add_violin(
                x=np.repeat(str(y_true), len(facet_color_df)),
                y=facet_color_df["similarity"],
                box=dict(visible=True),
                scalegroup=str(y_true),
                scalemode="count",
                width=1,
                meanline=dict(visible=True),
                side=violin_side,
                marker=dict(color=class_color),
                showlegend=False,
                row=1,
                col=i * 2 + 2,
            )
        # Hoisted out of the inner loop: the axis title depends only on the
        # facet index, so it only needs to be set once per facet.
        fig.update_layout(**{f"xaxis{i*2+1}": dict(title="text")})
    fig.update_layout(height=900, violingap=0, violinmode="overlay")
    fig.show()
    # Mean similarity per (clip, text class) — idiomatic .mean() replaces the
    # equivalent hand-rolled sum()/count() — then the positive/negative ratio
    # per clip (the MultiIndex level selected by .loc[:, True/False]).
    mean_similarity = df.groupby(["clip", "y_predict"])["similarity"].mean()
    ratio = mean_similarity.groupby(level="clip").aggregate(
        lambda s: s.loc[:, True] / s.loc[:, False]
    )
    ratio_df = ratio.to_frame("ratio")
    ratio_df["y_true"] = alarms_series.loc[ratio_df.index]
    # NOTE(review): "line" is not a documented px.scatter render_mode
    # ("auto"/"svg"/"webgl") — verify against the installed plotly version.
    fig = px.scatter(
        ratio_df.sort_values(["y_true", "ratio"]),
        y="ratio",
        color="y_true",
        render_mode="line",
        marginal_y="violin",
        height=900,
    )
    fig.show()
    # ROC of the ratio used as a score for the ground-truth label.
    fpr, tpr, thresholds = roc_curve(ratio_df["y_true"], ratio_df["ratio"])
    auc_score = roc_auc_score(ratio_df["y_true"], ratio_df["ratio"])
    roc_df = pd.DataFrame(
        {
            "False Positive Rate": fpr,
            "True Positive Rate": tpr,
        },
        columns=pd.Index(["False Positive Rate", "True Positive Rate"], name="Rate"),
        index=pd.Index(thresholds, name="Thresholds"),
    )
    # Dashed diagonal = chance-level classifier.
    fig = px.line(
        roc_df,
        x="False Positive Rate",
        y="True Positive Rate",
        title=f"{VARIATION} - AUC: {auc_score:.5f}",
        color_discrete_sequence=["orange"],
        range_x=[0, 1],
        range_y=[0, 1],
        width=600,
        height=450,
    ).add_shape(type="line", line=dict(dash="dash"), x0=0, x1=1, y0=0, y1=1)
    fig.show()
# Prompt sets: one positive class ("human") against several negative
# distractor prompts; labels are built in parallel with the text list.
TEXTS_TRUE = ["human"]
TEXTS_FALSE = ["birds flying", "bag", "rabbits", "insects", "animals", "barbed wire"]
TEXTS = [*TEXTS_TRUE, *TEXTS_FALSE]
TEXT_CLASSIFICATIONS = [True] * len(TEXTS_TRUE) + [False] * len(TEXTS_FALSE)
infer(TEXTS, TEXT_CLASSIFICATIONS, VARIATION_NAMES[1])