"""promptbind.visualize_prompt_components: visualize learned PromptBind prompt components."""
import argparse
from os import makedirs
from os.path import join
import numpy as np
import matplotlib.pyplot as plt
def main(opt):
    """Print statistics for, and save heatmaps of, a PromptBind model's prompt components.

    Loads four ``.npy`` arrays from
    ``results/prompt_{opt.prompt_num}/prompt_components/epoch-{best_epoch}``.
    The arrays are the pocket/complex node and coordinate prompt components.
    For each array, prints its mean and standard deviation and saves an
    ``imshow`` heatmap with a colorbar under
    ``prompt_comp_visualization/prompt_{opt.prompt_num}``.

    Args:
        opt: Parsed arguments; only ``opt.prompt_num`` is read. It must be one
            of the keys of ``best_epoch_dict`` (8, 16, 32 or 48).

    Raises:
        KeyError: If ``opt.prompt_num`` has no configured best epoch.
        OSError: If one of the component files cannot be read (from ``np.load``).
    """
    # PromptBind Model Configuration: prompt count -> epoch whose components are visualized.
    best_epoch_dict = {8: 12, 16: 12, 32: 22, 48: 40}
    # Resolve the epoch before touching the filesystem. An unsupported
    # prompt_num then fails without leaving an empty output directory behind.
    best_epoch = best_epoch_dict[opt.prompt_num]

    # Create Directory for Saving Prompt Component Visualization Results
    save_path = f"prompt_comp_visualization/prompt_{opt.prompt_num}"
    makedirs(save_path, exist_ok=True)

    # (printed label, input .npy file name, output .png file name)
    component_specs = (
        ("Pocket Node Prompt Comp.", "pocket_prompt_node_components.npy", "prompt_node_components(pocket).png"),
        ("Pocket Coord. Prompt Comp.", "pocket_prompt_coord_components.npy", "prompt_coord_components(pocket).png"),
        ("Complex Node Prompt Comp.", "complex_prompt_node_components.npy", "prompt_node_components(complex).png"),
        ("Complex Coord. Prompt Comp.", "complex_prompt_coord_components.npy", "prompt_coord_components(complex).png"),
    )

    # Load every component up front, so a missing file aborts before any output is produced.
    prompt_comp_root = f"results/prompt_{opt.prompt_num}/prompt_components/epoch-{best_epoch}"
    loaded = [
        (label, np.load(join(prompt_comp_root, npy_name)), png_name)
        for label, npy_name, png_name in component_specs
    ]

    # Start Prompt Component Visualization
    print(f"Prompt Component Visualization [PromptBind-{opt.prompt_num}]")
    for label, components, png_name in loaded:
        print(f"[{label}] < (Mean) : {components.mean()} || (Std) : {components.std()} >")
        # NOTE: interpolation=None means "use rcParams default", not "no interpolation";
        # "none" renders each component value as an unsmoothed cell.
        # Assumes each array is 2-D (imshow requirement) — TODO confirm saved shapes.
        plt.imshow(components, interpolation="none")
        plt.colorbar()
        plt.savefig(join(save_path, png_name), bbox_inches="tight")
        # Reset the shared current figure so the next heatmap/colorbar starts clean.
        plt.clf()
if __name__ == "__main__":
    # Parse Arguments
    parser = argparse.ArgumentParser(
        description="Visualize the learned prompt components of a PromptBind model."
    )
    # required=True: if the flag were omitted, argparse would pass prompt_num=None.
    # main() cannot map None to a checkpoint epoch and would fail with a bare KeyError.
    parser.add_argument(
        "--prompt-num",
        type=int,
        choices=[8, 16, 32, 48],
        required=True,
        help="Number of prompts of the PromptBind model whose components are visualized.",
    )
    opt = parser.parse_args()
    # Save Prompt Component Visualization Result
    main(opt)