Skip to content

Commit 52c6477

Browse files
committed
robustness fixes
1 parent b0c7f4c commit 52c6477

File tree

1 file changed

+9
-10
lines changed

1 file changed

+9
-10
lines changed

summarizer/sbert.py

+9-10
Original file line numberDiff line numberDiff line change
@@ -121,22 +121,21 @@ def cluster_runner(
121121
first_embedding = hidden[0, :]
122122
hidden = hidden[1:, :]
123123

124-
hidden_args = ClusterFeatures(
124+
summary_sentence_indices = ClusterFeatures(
125125
hidden, algorithm, random_state=self.random_state).cluster(ratio, num_sentences)
126126

127127
if use_first:
128-
# adjust for the first sentence to the right.
129-
hidden_args = [i + 1 for i in hidden_args]
130-
if not hidden_args:
131-
hidden_args.append(0)
132-
133-
elif hidden_args[0] != 0:
134-
hidden_args.insert(0, 0)
128+
if summary_sentence_indices:
129+
# adjust for the first sentence to the right.
130+
summary_sentence_indices = [i + 1 for i in summary_sentence_indices]
131+
summary_sentence_indices.insert(0, 0)
132+
else:
133+
summary_sentence_indices.append(0)
135134

136135
hidden = np.vstack([first_embedding, hidden])
137136

138-
sentences = [sentences[j] for j in hidden_args]
139-
embeddings = np.asarray([hidden[j] for j in hidden_args])
137+
sentences = [sentences[j] for j in summary_sentence_indices]
138+
embeddings = np.asarray([hidden[j] for j in summary_sentence_indices])
140139

141140
return sentences, embeddings
142141

0 commit comments

Comments
 (0)