# Three-letter -> one-letter codes for the 20 standard amino acids,
# listed alphabetically by three-letter code.
amino_acid_map = {
    'ALA': 'A',  # alanine
    'ARG': 'R',  # arginine
    'ASN': 'N',  # asparagine
    'ASP': 'D',  # aspartate
    'CYS': 'C',  # cysteine
    'GLN': 'Q',  # glutamine
    'GLU': 'E',  # glutamate
    'GLY': 'G',  # glycine
    'HIS': 'H',  # histidine
    'ILE': 'I',  # isoleucine
    'LEU': 'L',  # leucine
    'LYS': 'K',  # lysine
    'MET': 'M',  # methionine
    'PHE': 'F',  # phenylalanine
    'PRO': 'P',  # proline
    'SER': 'S',  # serine
    'THR': 'T',  # threonine
    'TRP': 'W',  # tryptophan
    'TYR': 'Y',  # tyrosine
    'VAL': 'V',  # valine
}
from collections import Counter
from ast import literal_eval
from matplotlib.gridspec import GridSpec
import warnings
# Suppress warnings whose message starts with "Attempting to set identical
# low and high xlims" (matplotlib emits this when an axis gets equal limits).
# The filter is process-wide, so it also hides this warning for every later
# plot in the session.
# NOTE(review): nothing visible in this chunk appears to set identical
# x-limits -- rows use set_xlim(0, max_letters_per_row) and the colour-bar
# norm spans 0..max frequency (> 0), or 0..1 for an empty sequence. This may
# be left over from an earlier version; confirm before relying on or removing it.
warnings.filterwarnings("ignore", message="Attempting to set identical low and high xlims")
def convert_to_single_letter(aa_list, *, mapping=None, unknown=None):
    """Convert a sequence of three-letter residue codes to one-letter codes.

    Args:
        aa_list: Iterable of three-letter codes, e.g. ``['ALA', 'GLY']``,
            or a string holding the Python literal of such a list, e.g.
            ``"['ALA', 'GLY']"``. Any ``str`` instance is parsed, including
            ``str`` subclasses such as ``numpy.str_``.
        mapping: Dict from three-letter code to one-letter code. Defaults to
            the module-level ``amino_acid_map``.
        unknown: Replacement letter for codes missing from ``mapping``
            (e.g. ``'X'``). When ``None`` (the default) an unknown code
            raises ``KeyError``, as before.

    Returns:
        List of one-letter codes in the same order as the input.

    Raises:
        KeyError: If a code is not in ``mapping`` and ``unknown`` is None.
        ValueError, SyntaxError: If a string input is not a valid literal.
    """
    # isinstance (not type(...) == str) so str subclasses like numpy.str_ are
    # parsed instead of being iterated character by character.
    if isinstance(aa_list, str):
        # literal_eval only evaluates Python literals, unlike eval, so it is
        # safe to use on text read back from files.
        aa_list = literal_eval(aa_list)
    if mapping is None:
        mapping = amino_acid_map
    if unknown is None:
        return [mapping[aa] for aa in aa_list]
    return [mapping.get(aa, unknown) for aa in aa_list]
def create_sequence_visualizations(df, max_letters_per_row=20, cmap='viridis'):
    """Show one figure per row of *df*, colouring each residue by frequency.

    For every row, the ``AAss`` value is converted to one-letter codes with
    ``convert_to_single_letter`` and drawn as large bold letters, wrapped at
    ``max_letters_per_row`` letters per line. Each letter is coloured by the
    fraction of that row's sequence made up of the same residue. A horizontal
    colour bar labelled "Frequency" sits below the letters, and the figure is
    titled "Center residue <bsss>".

    Args:
        df: DataFrame with ``bsss`` and ``AAss`` columns. ``AAss`` holds
            anything ``convert_to_single_letter`` accepts (a list of
            three-letter codes or the string repr of one).
        max_letters_per_row: Maximum number of letters per text line; must
            be at least 1.
        cmap: Colormap name or ``Colormap`` instance, passed to
            ``plt.get_cmap``. Defaults to ``'viridis'``.

    Returns:
        None. Each figure is displayed with ``plt.show()`` and then closed.

    Raises:
        ValueError: If ``max_letters_per_row`` is less than 1.
    """
    # Imported locally so the function does not depend on a ``plt`` global
    # having been defined elsewhere (e.g. in an earlier notebook cell).
    import matplotlib.pyplot as plt

    if max_letters_per_row < 1:
        raise ValueError(
            f"max_letters_per_row must be >= 1, got {max_letters_per_row!r}"
        )

    # Loop-invariant: resolve the colormap once instead of once per row.
    colormap = plt.get_cmap(cmap)

    for _, record in df.iterrows():
        bsss = record['bsss']
        letters = convert_to_single_letter(record['AAss'])

        # Relative frequency of each residue within this sequence only.
        total = len(letters)
        frequencies = {aa: n / total for aa, n in Counter(letters).items()}
        # Colours span 0 .. most frequent residue; fall back to 0 .. 1 for an
        # empty sequence so Normalize still gets a valid range.
        norm = plt.Normalize(0, max(frequencies.values()) if frequencies else 1)

        # Ceiling division: number of text lines needed for the sequence.
        n_rows = (total + max_letters_per_row - 1) // max_letters_per_row
        fig = plt.figure(figsize=(max_letters_per_row * 0.6, n_rows * 1.2 + 0.5))
        # One grid row per text line plus a thin final row for the colour bar.
        gs = GridSpec(n_rows + 1, 1, height_ratios=[1] * n_rows + [0.1], hspace=0.3)

        for row_idx in range(n_rows):
            start = row_idx * max_letters_per_row
            chunk = letters[start:start + max_letters_per_row]
            ax = fig.add_subplot(gs[row_idx, 0])
            # Fixed x-range so letters line up in columns across lines, even
            # when the last line is shorter than the others.
            ax.set_xlim(0, max_letters_per_row)
            ax.set_ylim(0, 1)
            ax.axis('off')
            for i, aa in enumerate(chunk):
                ax.text(i + 0.5, 0.5, aa, ha='center', va='center', fontsize=24,
                        color=colormap(norm(frequencies[aa])), fontweight='bold')

        cbar_ax = fig.add_subplot(gs[-1, 0])
        sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)
        sm.set_array([])  # only needed on older matplotlib versions
        cbar = fig.colorbar(sm, cax=cbar_ax, orientation='horizontal')
        cbar.set_label('Frequency', fontsize=12)
        cbar.ax.tick_params(labelsize=12)

        fig.suptitle(f"Center residue {bsss}", fontsize=14)
        fig.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.1)
        plt.show()
        # Release the figure. With non-inline backends (e.g. Agg in a script)
        # plt.show() does not close it, and one figure per row would pile up.
        plt.close(fig)
            
# Render one figure per row of df_plot.
# NOTE(review): df_plot is not defined in this chunk -- presumably built in an
# earlier cell with 'bsss' and 'AAss' columns; confirm.
create_sequence_visualizations(df=df_plot)